diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index de9c30aea7dd6a08a35dc702aa3c91442209596f..04bf64bad3349da40f153261cf308f0ba6a2edb7 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): tags=frozenset({"sumfact_stage{}".format(sf.stage)}), ) + @property + def direct_input(self): + if get_option("fastdg"): + return self.coeff_func(self.restriction) + else: + return None + def _basis_functions_per_direction(element, component): """Number of basis functions per direction of a given component of an element""" diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index f939feaeb6fc8da05828e88ef25fece71ecf9408..8922a0b06c8bd5c0e7bc21afe0585b08bf804ab3 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf): ), })) - # Set up the input for stage 1 - if sf.stage == 1 and not get_option("fastdg"): - assert sf.input + direct_input = sf.input.direct_input + # Set up the input for stage 1 + if direct_input is None: if sf.vectorized: for i, inputsf in enumerate(sf.kernels): inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep)) @@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf): insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))})) - # Construct the direct_input for the FastDG case - direct_input = None - if get_option('fastdg') and sf.stage == 1: - direct_input = sf.input.coeff_func(sf.input.restriction) - direct_output = None if get_option('fastdg') and sf.stage == 3: direct_output = sf.accumvar + ".data()" diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index a8a42ed170ea97eead8741183d4145f0d5b7ba3e..9a3011cba1b3223e5af7b081a89cecd9d6326549 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -16,8 +16,11 @@ import inspect class SumfactKernelInputBase(object): @property - def flat_shape(self): - return False + def direct_input(self): + return None + + def realize(self, sf, i, dep): + pass class SumfactKernelBase(object):