From 4b2db15f2cd52ad2cd3e11f48926375e31153ca5 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Fri, 31 Mar 2017 14:38:02 +0200 Subject: [PATCH] Refactor output shape --- python/dune/perftool/loopy/symbolic.py | 23 +++++++++++++++++++++ python/dune/perftool/sumfact/basis.py | 8 ++----- python/dune/perftool/sumfact/realization.py | 20 +++--------------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py index 345819e5..402c8940 100644 --- a/python/dune/perftool/loopy/symbolic.py +++ b/python/dune/perftool/loopy/symbolic.py @@ -109,6 +109,29 @@ class SumfactKernel(ImmutableRecord, prim.Variable): shape = shape + (4,) return shape + @property + def output_shape(self): + """ The shape of the output temporary, ready to be fed into loopy """ + # In stage 1, the output may be of reduced dimensionality + if self.stage == 1: + shape = tuple(mat.rows for mat in self.a_matrices if mat.face is None) + else: + shape = tuple(mat.rows for mat in self.a_matrices) + if self.vectorized: + shape = shape + (4,) + return shape + + @property + def output_dimtags(self): + """ The dim_tags of the output temporary, ready to be fed into loopy """ + tags = ["f"] * len(self.output_shape) + if self.vectorized: + if self.stage == 1: + tags[-1] = 'c' + else: + tags[-1] = 'vec' + return ",".join(tags) + class FusedMultiplyAdd(prim.Expression): """ Represents an FMA operation """ diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 1f15491e..cf148a7f 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -97,9 +97,7 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v # evaluation of the gradients of basis functions at quadrature # points (stage 1) from dune.perftool.sumfact.realization import realize_sum_factorization_kernel - var, insn_dep = realize_sum_factorization_kernel(sf, - outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None), - ) + var, insn_dep = realize_sum_factorization_kernel(sf) buffers.append(var) @@ -144,9 +142,7 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor): # Add a sum factorization kernel that implements the evaluation of # the basis functions at quadrature points (stage 1) from dune.perftool.sumfact.realization import realize_sum_factorization_kernel - var, _ = realize_sum_factorization_kernel(sf, - outshape=tuple(mat.rows for mat in sf.a_matrices if mat.face is None), - ) + var, _ = realize_sum_factorization_kernel(sf) if sf.index: index = (sf.index,) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 0ed1641e..8da4339f 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -69,7 +69,7 @@ def _realize_input(sf, insn_dep): @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda s, **kw: s.cache_key) -def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_output=None): +def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), direct_output=None): # Unify the insn_dep parameter to be a frozenset if isinstance(insn_dep, str): insn_dep = frozenset({insn_dep}) @@ -279,23 +279,9 @@ def _realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, d insn_dep = instruction(code="HP_TIMER_START({});".format(qp_timer_name), depends_on=insn_dep) - if outshape is None: - assert sf.stage == 3 - outshape = tuple(mat.rows for mat in a_matrices) - - dim_tags = ",".join(['f'] * len(outshape)) - - if sf.vectorized: - outshape = outshape + vec_shape - # This is a 'bit' hacky: In stage 3 we need to return something with vectag, in stage 1 not. - if sf.stage == 1: - dim_tags = dim_tags + ",c" - else: - dim_tags = dim_tags + ",vec" - out = get_buffer_temporary(sf.buffer, - shape=outshape, - dim_tags=dim_tags, + shape=sf.output_shape, + dim_tags=sf.output_dimtags, ) silenced_warning('read_no_write({})'.format(out)) -- GitLab