From 5e072d6c6142f0ea7e553120f5d200f9081bd050 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 7 Apr 2016 14:56:41 +0200 Subject: [PATCH] Various foxes to the generation machinery --- python/dune/perftool/generation/cache.py | 10 ++++++++-- python/dune/perftool/generation/cpp.py | 2 +- python/dune/perftool/generation/loopy.py | 3 ++- python/dune/perftool/loopy/transformer.py | 8 ++++---- python/dune/perftool/pdelab/__init__.py | 7 ++++++- python/dune/perftool/pdelab/argument.py | 18 +++++++++--------- python/dune/perftool/pdelab/localoperator.py | 7 ++++--- python/dune/perftool/pdelab/quadrature.py | 2 +- .../extract_accumulation_terms.py | 4 ++-- 9 files changed, 37 insertions(+), 24 deletions(-) diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 0780cd14..82217280 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -65,11 +65,17 @@ class _CacheItemMeta(type): def add_count(x): rettype._count = rettype._count + 1 - return (rettype._count, original_on_store(x)) + if isinstance(x, tuple): + return (rettype._count, original_on_store(*x)) + else: + return (rettype._count, original_on_store(x)) on_store = add_count def _init(s, x): - s.content = on_store(x) + if isinstance(x, tuple) and not counted: + s.content = on_store(*x) + else: + s.content = on_store(x) s.tags = item_tags s.counted = counted diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py index 652f7f71..32efb653 100644 --- a/python/dune/perftool/generation/cpp.py +++ b/python/dune/perftool/generation/cpp.py @@ -43,7 +43,7 @@ def class_member(classtag=None, access=AccessModifier.PRIVATE): from cgen import Value from dune.perftool.cgen.clazz import ClassMember - return generator_factory(item_tags=(classtag, "member"), on_store=lambda t, n: ClassMember(Value(_type, name), access=access), counted=True, cache_key_generator=lambda t, n: n) + return generator_factory(item_tags=(classtag, "member"), on_store=lambda t, n: ClassMember(Value(t, n), access=access), counted=True) def constructor_parameter(_type, name, classtag=None, constructortag=None): diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index d4098dec..55268c4f 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -22,5 +22,6 @@ def globalarg(name, shape=loopy.auto): @generator_factory(item_tags=("loopy", "kernel", "domain")) def domain(iname, shape): - valuearg(shape) + if isinstance(shape, str): + valuearg(shape) return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape) diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py index e5f26008..3ad304f1 100644 --- a/python/dune/perftool/loopy/transformer.py +++ b/python/dune/perftool/loopy/transformer.py @@ -37,9 +37,9 @@ def dimension_iname(index): def argument_iname(arg): # TODO extract the {iname}_n thing by a preamble from dune.perftool.ufl.modified_terminals import modified_argument_number - iname = "arg{}".format(chr(ord("i") + modified_argument_number()(arg))) - domain(iname, iname + "_n") - return iname + ainame = "arg{}".format(chr(ord("i") + arg.argexpr.number())) + domain(ainame, ainame + "_n") + return ainame @iname @@ -76,7 +76,7 @@ def get_pymbolic_expr(expr): trial_ma = extract_modified_arguments(expr, trialfunction=True) # OLD CODE had: globalarg(name) - rmap = {ma: Variable(name_trialfunction(ma)) for ma in trial_ma} + rmap = {ma.expr: Variable(name_trialfunction(ma)) for ma in trial_ma} ufl2l_mf = UFL2LoopyVisitor() re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf) ufl2l_mf.call = re_mf.__call__ diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py index 981c8b71..5743fbab 100644 --- a/python/dune/perftool/pdelab/__init__.py +++ b/python/dune/perftool/pdelab/__init__.py @@ -14,4 +14,9 @@ def quadrature_preamble(assignees=[]): # Now define some commonly used generators that do not fall into a specific category @symbol def name_index(index): - return str(index._indices[0]) + from ufl.classes import MultiIndex, Index + if isinstance(index, Index): + return str(index) + if isinstance(index, MultiIndex): + assert len(index) == 1 + return str(index._indices[0]) diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index bc7e6078..318c0040 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -6,7 +6,7 @@ from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor @symbol def name_testfunction(ma): - if len(ma.expr.element().sub_elements()) > 0: + if len(ma.argexpr.element().sub_elements()) > 0: pass return "{}a{}".format("grad_" if ma.grad else "", ma.argexpr.number()) @@ -29,19 +29,19 @@ def name_trialfunctionspace(*a): def name_argumentspace(ma): - if ma.expr.number() == 0: - return name_testfunctionspace(modarg) - if ma.expr.number() == 1: - return name_trialfunctionspace(modarg) + if ma.argexpr.number() == 0: + return name_testfunctionspace(ma) + if ma.argexpr.number() == 1: + return name_trialfunctionspace(ma) # We should never encounter an argument other than 0 or 1 assert False def name_argument(ma): - if ma.expr.number() == 0: - return name_testfunction(modarg) - if ma.expr.number() == 1: - return name_trialfunction(modarg) + if ma.argexpr.number() == 0: + return name_testfunction(ma) + if ma.argexpr.number() == 1: + return name_trialfunction(ma) # We should never encounter an argument other than 0 or 1 assert False diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index e844820e..bb2cd648 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -11,9 +11,9 @@ from pytools import memoize def define_initree(name): include_file('dune/common/parametertree.hh', filetag="operatorfile") constructor_parameter("const Dune::ParameterTree&", "iniParams", classtag="operator", constructortag="iniconstructor") - initializer_list("_iniParams", ["iniParams"]) + initializer_list("_iniParams", ["iniParams"], classtag="operator") - return "const Dune::ParameterTree&", "_iniParams" + return ("const Dune::ParameterTree&", "_iniParams") @symbol @@ -43,7 +43,8 @@ def measure_specific_details(measure): # Add the initializer list for that base class ini = name_initree_member() initializer_list("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype), - ["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())]) + ["{}.get(\"numerical_epsilon.{}\", 1e-9)".format(ini, which.lower())], + classtag="operator") if measure == "cell": base_class('Dune::PDELab::FullVolumePattern', classtag="operator") diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index a5a99681..962cb735 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -16,6 +16,6 @@ def define_quadrature_factor(fac): @symbol def name_factor(): - loopy_temporary_variable("fac") + temporary_variable("fac") define_quadrature_factor("fac") return "fac" diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 2f902aa2..d14383ab 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -35,7 +35,7 @@ def split_into_accumulation_terms(expr): if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0: for arg in mod_args: # Do the replacement on the expression - accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg)) + accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg.expr)) # Store the found accumulation expression accumulation_terms.append((accum_expr, (arg,))) @@ -44,7 +44,7 @@ def split_into_accumulation_terms(expr): for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args), filter(lambda ma: ma.argexpr.count() == 1, mod_args) ): - accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1, arg2)) + accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr)) accumulation_terms.append((accum_expr, (arg1, arg2))) -- GitLab