diff --git a/python/dune/perftool/loopy/functions.py b/python/dune/perftool/loopy/functions.py index ec06459e4167968a61f46bcac1b2f8d2b07b8947..7f6c519ed1e8849872ab6e01a8793a881543ec44 100644 --- a/python/dune/perftool/loopy/functions.py +++ b/python/dune/perftool/loopy/functions.py @@ -4,6 +4,22 @@ from loopy.types import NumpyType import numpy +class LFSChild(FunctionIdentifier): + def __init__(self, lfs): + self.lfs = lfs + + def __getinitargs__(self): + return (self.lfs,) + + @property + def name(self): + return '{}.child'.format(self.lfs) + + +def lfs_child_mangler(target, func, dtypes): + if isinstance(func, LFSChild): + return CallMangleInfo(func.name, (NumpyType(str),), (NumpyType(numpy.int32),)) + class CoefficientAccess(FunctionIdentifier): def __init__(self, restriction): diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py index 843e73e4df092a18cefb958017c62a99e07c8683..90049a40def1f1dd5e69861f63072e5a573d17ac 100644 --- a/python/dune/perftool/loopy/transformer.py +++ b/python/dune/perftool/loopy/transformer.py @@ -115,7 +115,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp lfsi = lfs_iname(subel, ma.restriction, count=count) - accumargs[2 * icount] = Variable(lfs) + # If the LFS is not yet a pymbolic expression, make it one + from pymbolic.primitives import Expression + if not isinstance(lfs, Expression): + lfs = Variable(lfs) + + accumargs[2 * icount] = lfs accumargs[2 * icount + 1] = Variable(lfsi) arg_restr[icount] = ma.restriction diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index c96199fb2d8738918d9541359b78fbc6b22e9d82..6ff8444d27df79f4822629de12d15310422448e6 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -76,9 +76,16 @@ def name_coefficientcontainer(restriction): @pymbolic_expr def pymbolic_coefficient(lfs, index, restriction): # TODO introduce a proper type for local function spaces! - valuearg(lfs, dtype=loopy.types.NumpyType("str")) + if isinstance(lfs, str): + valuearg(lfs, dtype=loopy.types.NumpyType("str")) + + # If the LFS is not yet a pymbolic expression, make it one + from pymbolic.primitives import Expression + if not isinstance(lfs, Expression): + lfs = Variable(lfs) + from dune.perftool.loopy.functions import CoefficientAccess - return Call(CoefficientAccess(restriction), (Variable(lfs), Variable(index),)) + return Call(CoefficientAccess(restriction), (lfs, Variable(index),)) @symbol diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index 0031f468087c22de132ba186404188299e26a24f..9a45e0d297efe75c63e5e24c451a183b47d5cb41 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -76,7 +76,10 @@ def define_lfs(name, father, child): def lfs_child(lfs, child): - return "{}.child({})".format(lfs, child) + from pymbolic.primitives import Call + from dune.perftool.loopy.functions import LFSChild + return Call(LFSChild(lfs), (Variable(child),)) +# return "{}.child({})".format(lfs, child) @generator_factory(cache_key_generator=lambda e, r, **kw: (e, r)) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index ef2165e7060a6dc487bec341dc2b14dec60c8718..310dc616285883198b5ccabe12616a96a95cd673 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -184,7 +184,7 @@ def generate_kernel(integrals): arguments = [i for i in retrieve_cache_items("argument")] # Get the function manglers - from dune.perftool.loopy.functions import accumulation_mangler, coefficient_mangler + from dune.perftool.loopy.functions import accumulation_mangler, coefficient_mangler, lfs_child_mangler # Create the kernel from loopy import make_kernel, preprocess_kernel @@ -192,7 +192,7 @@ def generate_kernel(integrals): instructions + subst_rules, arguments, temporary_variables=temporaries, - function_manglers=[accumulation_mangler, coefficient_mangler], + function_manglers=[accumulation_mangler, coefficient_mangler, lfs_child_mangler], target=DuneTarget() )