From 63886624fc03010c34223c0053382f06c59d07d7 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 20 Apr 2017 15:57:29 +0200 Subject: [PATCH] Fix FastDG --- python/dune/perftool/sumfact/basis.py | 7 +++++++ python/dune/perftool/sumfact/realization.py | 11 +++-------- python/dune/perftool/sumfact/symbolic.py | 7 +++++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index de9c30ae..04bf64ba 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): tags=frozenset({"sumfact_stage{}".format(sf.stage)}), ) + @property + def direct_input(self): + if get_option("fastdg"): + return self.coeff_func(self.restriction) + else: + return None + def _basis_functions_per_direction(element, component): """Number of basis functions per direction of a given component of an element""" diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index f939feae..8922a0b0 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf): ), })) - # Set up the input for stage 1 - if sf.stage == 1 and not get_option("fastdg"): - assert sf.input + direct_input = sf.input.direct_input + # Set up the input for stage 1 + if direct_input is None: if sf.vectorized: for i, inputsf in enumerate(sf.kernels): inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep)) @@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf): insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))})) - # Construct the direct_input for the FastDG case - direct_input = None - if get_option('fastdg') and sf.stage == 1: - direct_input = sf.input.coeff_func(sf.input.restriction) - direct_output = None if get_option('fastdg') and sf.stage == 3: direct_output = sf.accumvar + ".data()" diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index a8a42ed1..9a3011cb 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -16,8 +16,11 @@ import inspect class SumfactKernelInputBase(object): @property - def flat_shape(self): - return False + def direct_input(self): + return None + + def realize(self, sf, i, dep): + pass class SumfactKernelBase(object): -- GitLab