From 8d2b63838311597d409ebadc1db2b48374204ce4 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 15 Feb 2018 13:50:48 +0100
Subject: [PATCH] Move accumulation code onto output object

---
Review notes (text between "---" and the diffstat is not part of the commit):

* v2: SumfactKernelOutputBase.realize still returned `dep` after its
  parameter was renamed to `insn_dep`; `dep` is an undefined name there.
  It now returns `insn_dep`, and the diffstat is updated accordingly.
* v2: "Isnt" -> "Isn't" in the reformatted TODO comment.
* This copy was rebuilt from a whitespace-collapsed paste: indentation,
  and the position of one blank context line in the last
  accumulation.py hunk, are inferred from the hunk counts. The commit
  and blob ids still name the original revision.
* To confirm: AccumulationOutput.realize takes the jacobian branch on
  sf.within_inames but derives rank and forced_iname_deps from
  self.within_inames.
* To confirm: the new code in symbolic.py uses `lp` and `prim`; their
  imports are not visible in this patch.

 python/dune/perftool/sumfact/accumulation.py | 117 ++++++++++---------
 python/dune/perftool/sumfact/symbolic.py     |  30 ++++-
 python/dune/perftool/tools.py                |  11 ++
 3 files changed, 102 insertions(+), 56 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index a01fb8ea..0b189892 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -15,6 +15,7 @@ from dune.perftool.generation import (backend,
                                       kernel_cached,
                                       temporary_variable,
                                       transform,
+                                      valuearg
                                       )
 from dune.perftool.options import (get_form_option,
                                    get_option,
@@ -26,6 +27,7 @@ from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.signatures import assembler_routine_name
 from dune.perftool.pdelab.geometry import world_dimension
+from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
                                               construct_basis_matrix_sequence,
                                               )
@@ -34,7 +36,7 @@ from dune.perftool.sumfact.switch import (get_facedir,
                                           )
 from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase
 from dune.perftool.ufl.modified_terminals import extract_modified_arguments
-from dune.perftool.tools import get_pymbolic_basename
+from dune.perftool.tools import get_pymbolic_basename, get_leaf
 from dune.perftool.error import PerftoolError
 from dune.perftool.sumfact.quadrature import quadrature_inames
 
@@ -90,7 +92,7 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
                  trial_element=None,
                  trial_element_index=None,
                  ):
-        #TODO: Isnt accumvar superfluous in the presence of all the other infos?
+        # TODO: Isn't accumvar superfluous in the presence of all the other infos?
         ImmutableRecord.__init__(self,
                                  accumvar=accumvar,
                                  restriction=None,
@@ -106,10 +108,65 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
             return ()
         else:
             from dune.perftool.sumfact.basis import lfs_inames
-            element = self.trial_element
-            if isinstance(element, MixedElement):
-                element = element.extract_component(self.trial_element_index)[1]
-            return lfs_inames(element, self.restriction)
+            return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
+
+
+    def realize(self, sf, result, insn_dep, inames=None, additional_inames=()):
+        trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None
+
+        basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence)
+
+        if inames is None:
+            inames = tuple(accum_iname(trial_leaf_element, mat.rows, i)
+                           for i, mat in enumerate(sf.matrix_sequence))
+
+        # Determine the expression to accumulate with. This depends on the vectorization strategy!
+        from dune.perftool.tools import maybe_wrap_subscript
+        result = maybe_wrap_subscript(result, tuple(prim.Variable(i) for i in inames))
+
+        # Collect the lfs and lfs indices for the accumulate call
+        restriction = (0, 0) if self.restriction is None else self.restriction
+        test_lfs = name_lfs(self.test_element, restriction[0], self.test_element_index)
+        valuearg(test_lfs, dtype=lp.types.NumpyType("str"))
+        test_lfs_index = flatten_index(tuple(prim.Variable(i) for i in inames),
+                                       basis_size,
+                                       order="f"
+                                       )
+
+        accum_args = [prim.Variable(test_lfs), test_lfs_index]
+
+        # In the jacobian case, also determine the space for the ansatz space
+        if sf.within_inames:
+            # TODO the next line should get its inames from
+            # elsewhere. This is *NOT* robust (but works right now)
+            ansatz_lfs = name_lfs(self.trial_element, restriction[1], self.trial_element_index)
+            valuearg(ansatz_lfs, dtype=lp.types.NumpyType("str"))
+            from dune.perftool.sumfact.basis import _basis_functions_per_direction
+            ansatz_lfs_index = flatten_index(tuple(prim.Variable(sf.within_inames[i])
+                                                   for i in range(world_dimension())),
+                                             _basis_functions_per_direction(trial_leaf_element),
+                                             order="f"
+                                             )
+
+            accum_args.append(prim.Variable(ansatz_lfs))
+            accum_args.append(ansatz_lfs_index)
+
+        accum_args.append(result)
+
+        if not get_form_option("fastdg"):
+            rank = 2 if self.within_inames else 1
+            expr = prim.Call(PDELabAccumulationFunction(self.accumvar, rank),
+                             tuple(accum_args)
+                             )
+            instruction(assignees=(),
+                        expression=expr,
+                        forced_iname_deps=frozenset(inames + additional_inames + self.within_inames),
+                        forced_iname_deps_is_final=True,
+                        depends_on=insn_dep,
+                        predicates=sf.predicates
+                        )
+
+        return frozenset()
 
 
 class SumfactAccumulationInfo(ImmutableRecord):
@@ -358,56 +415,10 @@ def generate_accumulation_instruction(expr, visitor):
                                           depends_on=insn_dep,
                                           within_inames=frozenset(jacobian_inames))})
 
-    inames = tuple(accum_iname(trial_leaf_element, mat.rows, i)
-                   for i, mat in enumerate(vsf.matrix_sequence))
-
-    # Collect the lfs and lfs indices for the accumulate call
-    test_lfs.index = flatten_index(tuple(prim.Variable(i) for i in inames),
-                                   basis_size,
-                                   order="f"
-                                   )
-
-    # In the jacobian case, also determine the space for the ansatz space
-    if jacobian_inames:
-        # TODO the next line should get its inames from
-        # elsewhere. This is *NOT* robust (but works right now)
-        from dune.perftool.sumfact.basis import _basis_functions_per_direction
-        ansatz_lfs.index = flatten_index(tuple(prim.Variable(jacobian_inames[i])
-                                               for i in range(world_dimension())),
-                                         _basis_functions_per_direction(trial_leaf_element),
-                                         order="f"
-                                         )
-
     # Add a sum factorization kernel that implements the multiplication
     # with the test function (stage 3)
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     result, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep)))
 
-    # Determine the expression to accumulate with. This depends on the vectorization strategy!
-    result = prim.Subscript(result, tuple(prim.Variable(i) for i in inames))
-    vecinames = ()
-
-    if vsf.vectorized:
-        iname = accum_iname(trial_leaf_element, vsf.vector_width, "vec")
-        vecinames = (iname,)
-        transform(lp.tag_inames, [(iname, "vec")])
-        from dune.perftool.tools import maybe_wrap_subscript
-        result = prim.Call(prim.Variable("horizontal_add"),
-                           (maybe_wrap_subscript(result, prim.Variable(iname)),),
-                           )
 
-    if not get_form_option("fastdg"):
-        rank = 2 if jacobian_inames else 1
-        expr = prim.Call(PDELabAccumulationFunction(accumvar, rank),
-                         (test_lfs.get_args() +
-                          ansatz_lfs.get_args() +
-                          (result,)
-                          )
-                         )
-        instruction(assignees=(),
-                    expression=expr,
-                    forced_iname_deps=frozenset(inames + vecinames + jacobian_inames),
-                    forced_iname_deps_is_final=True,
-                    depends_on=insn_dep,
-                    predicates=predicates
-                    )
+    vsf.output.realize(vsf, result, insn_dep)
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 322806fa..2b9c2e21 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -1,12 +1,15 @@
 """ A pymbolic node representing a sum factorization kernel """
 
 from dune.perftool.options import get_option
-from dune.perftool.generation import get_counted_variable
+from dune.perftool.generation import (get_counted_variable,
+                                      transform,
+                                      )
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
 from dune.perftool.loopy.target import dtype_floatingpoint
 from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
+from dune.perftool.tools import get_leaf
 
 from pytools import ImmutableRecord, product
 
@@ -81,7 +84,7 @@ class SumfactKernelOutputBase(object):
     def within_inames(self):
         return ()
 
-    def realize(self, sf, dep):
-        return dep
+    def realize(self, sf, result, insn_dep):
+        return insn_dep
 
     def realize_direct(self):
@@ -92,6 +95,27 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
     def __init__(self, outputs):
         self.outputs = outputs
 
+    def realize(self, sf, result, insn_dep):
+        outputs = set(self.outputs)
+        assert(len(outputs) == 1)
+
+        o, = outputs
+
+        from dune.perftool.sumfact.accumulation import accum_iname
+        element = get_leaf(o.trial_element, o.trial_element_index) if o.trial_element is not None else None
+        inames = tuple(accum_iname(element, mat.rows, i)
+                       for i, mat in enumerate(sf.matrix_sequence))
+
+        veciname = accum_iname(element, sf.vector_width, "vec")
+        transform(lp.tag_inames, [(veciname, "vec")])
+
+        from dune.perftool.tools import maybe_wrap_subscript
+        result = prim.Call(prim.Variable("horizontal_add"),
+                           (maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))),),
+                           )
+
+        return o.realize(sf, result, insn_dep, inames=inames, additional_inames=(veciname,))
+
 
 class SumfactKernelBase(object):
     pass
diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py
index e302f28c..b29ebe22 100644
--- a/python/dune/perftool/tools.py
+++ b/python/dune/perftool/tools.py
@@ -82,3 +82,14 @@ def list_diff(l1, l2):
         if item not in l2:
             l.append(item)
     return l
+
+
+def get_leaf(element, index):
+    """ return a leaf element if the given element is a MixedElement """
+    leaf_element = element
+    from ufl import MixedElement
+    if isinstance(element, MixedElement):
+        assert isinstance(index, int)
+        leaf_element = element.extract_component(index)[1]
+
+    return leaf_element
-- 
GitLab