class ColMajorAccess(FunctionIdentifier):
    """Function identifier for element access on a sum factorization matrix.

    Wraps an :class:`AMatrix`. The identifier's :attr:`name` is the
    ``colmajoraccess`` member of the matrix object named by
    ``amatrix.a_matrix``, so it can be used as the function of a pymbolic
    ``Call`` that takes the two matrix indices as parameters.
    """

    # pymbolic pairs these names with the values from __getinitargs__ when
    # an expression is restored from a pickle, so they must list the
    # constructor arguments in the same order. loopy's own stateful
    # FunctionIdentifier subclasses declare this attribute the same way.
    init_arg_names = ("amatrix",)

    def __init__(self, amatrix):
        """Store the matrix description this accessor belongs to.

        :arg amatrix: the :class:`AMatrix` whose entries are accessed.
        """
        # Internal invariant: only AMatrix records provide ``a_matrix``.
        assert isinstance(amatrix, AMatrix)
        self.amatrix = amatrix

    def __getinitargs__(self):
        """Return the constructor arguments (pymbolic expression protocol)."""
        return (self.amatrix,)

    @property
    def name(self):
        """The member function name: ``<a_matrix>.colmajoraccess``."""
        return '{}.colmajoraccess'.format(self.amatrix.a_matrix)


@function_mangler
def colmajoraccess_mangler(target, func, dtypes):
    """Describe calls to :class:`ColMajorAccess` functions.

    :arg target: unused here.
    :arg func: the function identifier of the call being resolved.
    :arg dtypes: unused here; the signature is fixed.
    :returns: a ``CallMangleInfo`` built from ``func.name`` with one float64
        result and two int32 arguments if *func* is a
        :class:`ColMajorAccess`, otherwise ``None``.
    """
    if not isinstance(func, ColMajorAccess):
        # Not ours: fall through with None, as the original did implicitly.
        return None

    index_type = NumpyType(numpy.int32)
    return CallMangleInfo(func.name,
                          (NumpyType(numpy.float64),),
                          (index_type, index_type),
                          )
(Call, + Product, Subscript, Variable, ) @@ -117,7 +118,8 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): k = sumfact_iname(a_matrix.n, "red") # Construct the matrix-matrix-multiplication expression a_ik*in_kj - prod = Product((Subscript(Variable(a_matrix.a_matrix), (Variable(i), Variable(k))), + from dune.perftool.sumfact.amatrix import ColMajorAccess + prod = Product((Call(ColMajorAccess(a_matrix), (Variable(i), Variable(k))), Subscript(Variable(inp), (Variable(k), Variable(j))) ))