From cb3390bd52c5884149bee1c8e7c6fa66a0d24b3f Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 7 Apr 2016 15:58:12 +0200 Subject: [PATCH] Do not remove mod args from the expression anymore UFL did not like my way of replacing some args with 1 as it destroyed the shape information. It now remains in the expression, but triggers code generation separately. --- python/dune/perftool/loopy/transformer.py | 53 +++++++++---------- python/dune/perftool/pdelab/quadrature.py | 2 +- .../dune/perftool/ufl/modified_terminals.py | 5 +- .../perftool/ufl/transformations/__init__.py | 6 +-- .../extract_accumulation_terms.py | 19 +++---- 5 files changed, 36 insertions(+), 49 deletions(-) diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py index 71681e95..15abdee6 100644 --- a/python/dune/perftool/loopy/transformer.py +++ b/python/dune/perftool/loopy/transformer.py @@ -79,25 +79,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper): return Variable(name_facetarea()) -def get_pymbolic_expr(expr): - """ Transform the given UFL expression into a pymbolic expression - and have all sorts of side effects on the generation cache. """ - # We do need to manually handle modified terminals related to trial functions - from dune.perftool.ufl.modified_terminals import extract_modified_arguments - from dune.perftool.ufl.transformations.replace import ReplaceExpression - from dune.perftool.pdelab.argument import name_trialfunction - from pymbolic.primitives import Variable - - trial_ma = extract_modified_arguments(expr, trialfunction=True) - # OLD CODE had: globalarg(name) - rmap = {ma.expr: Variable(name_trialfunction(ma)) for ma in trial_ma} - ufl2l_mf = UFL2LoopyVisitor() - re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf) - ufl2l_mf.call = re_mf.__call__ - - return re_mf(expr) - - class _Counter: counter = 0 @@ -109,15 +90,32 @@ def get_count(): def transform_accumulation_term(term): - # Get the accumulation expression and the modified arguments - expr, args = term + from dune.perftool.ufl.transformations.replace import ReplaceExpression + from pymbolic.primitives import Variable # We always have a quadrature loop quadrature_iname() # Get the pymbolic expression needed for this accumulation term. # This includes filling the cache with all sorts of necessary preambles! - pymbolic_expr = get_pymbolic_expr(expr) + from dune.perftool.ufl.modified_terminals import extract_modified_arguments + test_ma = extract_modified_arguments(term, trialfunction=False, testfunction=True) + trial_ma = extract_modified_arguments(term, trialfunction=True, testfunction=False) + + rmap = {} + for ma in test_ma: + from dune.perftool.pdelab.argument import name_testfunction + rmap[ma.expr] = Variable(name_testfunction(ma)) + for ma in trial_ma: + from dune.perftool.pdelab.argument import name_trialfunction + rmap[ma.expr] = Variable(name_trialfunction(ma)) + + # Get the transformer! + ufl2l_mf = UFL2LoopyVisitor() + re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf) + ufl2l_mf.call = re_mf.__call__ + + pymbolic_expr = re_mf(term) # Now simplify the expression # TODO: Add a switch to disable/configure this. @@ -132,16 +130,14 @@ def transform_accumulation_term(term): # The data that is used to collect the arguments for the accumulate function accumargs = [] - argument_code = [] # Generate the code for the modified arguments: - for arg in args: + for arg in test_ma: from dune.perftool.pdelab.argument import name_argumentspace, name_argument accumargs.append(name_argumentspace(arg)) accumargs.append(argument_iname(arg)) - name = name_argument(arg) - argument_code.append(name) - globalarg(name) + # TODO is this global + #globalarg(argument_iname(arg)+"_n") from dune.perftool.pdelab.argument import name_residual residual = name_residual() @@ -151,10 +147,9 @@ def transform_accumulation_term(term): from dune.perftool.pdelab.quadrature import name_factor c_instruction(loopy.CInstruction(inames, - "{}.accumulate({}, {}*{}*{})".format(residual, + "{}.accumulate({}, {}*{})".format(residual, ", ".join(accumargs), expr_tv_name, - "*".join(argument_code), name_factor() ) ) diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index 962cb735..c8346972 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -8,7 +8,7 @@ def quadrature_rule(): return "rule" -@quadrature_preamble(assignees="fac") +@quadrature_preamble() def define_quadrature_factor(fac): rule = quadrature_rule() return "auto {} = {}->weight();".format(fac, rule) diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 9bc23612..8bf5173a 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -90,9 +90,10 @@ class ModifiedArgumentDescriptor(MultiFunction): class _ModifiedArgumentExtractor(MultiFunction): """ A multifunction that extracts and returns the set of modified arguments """ - def __call__(self, o, argnumber=None, trialfunction=False): + def __call__(self, o, argnumber=None, testfunction=True, trialfunction=False): self.argnumber = argnumber self.trialfunction = trialfunction + self.testfunction = testfunction self.modified_arguments = set() ret = self.call(o) if ret: @@ -127,7 +128,7 @@ class _ModifiedArgumentExtractor(MultiFunction): return o def argument(self, o): - if not self.trialfunction: + if self.testfunction: if self.argnumber is None or o.number() == self.argnumber: return o diff --git a/python/dune/perftool/ufl/transformations/__init__.py b/python/dune/perftool/ufl/transformations/__init__.py index 81b3ff6e..19577bb3 100644 --- a/python/dune/perftool/ufl/transformations/__init__.py +++ b/python/dune/perftool/ufl/transformations/__init__.py @@ -48,11 +48,7 @@ class UFLTransformationWrapper(object): # We do also assume that the transformation returns an ufl expression or a list there of ret_for_print = self.extractExpressionListFromResult(ret) - try: - assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print) - except AssertionError: - from IPython import embed - embed() + assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print) # Maybe output the returned expression self.write_trafo(ret_for_print, False) diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index d14383ab..3dcb8b42 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -10,22 +10,22 @@ from dune.perftool.ufl.transformations import ufl_transformation from dune.perftool.ufl.transformations.replace import replace_expression from ufl.algorithms import MultiFunction -from ufl.classes import Zero, IntValue +from ufl.classes import Zero import itertools class _ReplacementDict(dict): - def __init__(self, *mod_args): + def __init__(self, *args): dict.__init__(self) - for ma in mod_args: - self[ma] = IntValue(1) + for a in args: + self[a] = a def __getitem__(self, key): return self.get(key, Zero()) -@ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l]) +@ufl_transformation(name="accterms2", extraction_lambda=lambda l: l) def split_into_accumulation_terms(expr): mod_args = extract_modified_arguments(expr) @@ -35,18 +35,13 @@ def split_into_accumulation_terms(expr): if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0: for arg in mod_args: # Do the replacement on the expression - accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg.expr)) - - # Store the found accumulation expression - accumulation_terms.append((accum_expr, (arg,))) + accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg.expr))) # and now the case of a rank 2 form: else: for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args), filter(lambda ma: ma.argexpr.count() == 1, mod_args) ): - accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr)) - - accumulation_terms.append((accum_expr, (arg1, arg2))) + accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr))) # and return the result return accumulation_terms -- GitLab