From f7705ce6e794eb14f302d9c533ee2a7a5d65b11b Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Fri, 3 Feb 2017 16:41:40 +0100 Subject: [PATCH] Refactor substitution rules to not show up in *all* kernels But only when they are actually used! --- python/dune/perftool/compile.py | 1 - python/dune/perftool/generation/__init__.py | 1 + python/dune/perftool/generation/loopy.py | 30 ++++++++++++++------ python/dune/perftool/pdelab/localoperator.py | 10 +++---- python/dune/perftool/ufl/visitor.py | 10 ++++--- 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 8e537016..d5427fa6 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace from dune.perftool.generation import (delete_cache_items, global_context, - subst_rule, ) from dune.perftool.interactive import start_interactive_session from dune.perftool.options import get_option diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index c7f16dfc..9d71f86c 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class, from dune.perftool.generation.loopy import (barrier, built_instruction, constantarg, + construct_subst_rule, domain, function_mangler, get_temporary_name, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index be835b7a..8bd656eb 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar return name -@generator_factory(item_tags=("substrule",), +@generator_factory(item_tags=("substrule_name",), + context_tags="kernel", + cache_key_generator=lambda e, n: e) +def _substrule_name(expr, name): + return name + + +@generator_factory(item_tags=("substrule_impl",), context_tags="kernel", - cache_key_generator=lambda e, r: e, + cache_key_generator=lambda n, e, **ex: e, ) -def subst_rule(expr, rule): - return rule +def subst_rule(name, expr, exists=False): + _substrule_name(expr, name) + return exists + +def set_subst_rule(name, expr): + subst_rule(name, expr, exists=True) -def set_subst_rule(name, expr, visitor): - rule = lp.SubstitutionRule(name, (), visitor(expr)) - subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None} - return subst_rule(expr, rule) + +@generator_factory(item_tags=("substrule",), + context_tags="kernel") +def construct_subst_rule(expr, visitor): + name = _substrule_name(expr, None) + assert name + return lp.SubstitutionRule(name, (), visitor(expr, donot_check_substrules=expr)) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index adda9c7f..1f6c899b 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -419,16 +419,16 @@ def visit_integrals(integrals): data = get_global_context_value("data") for name, expr in data.object_by_name.items(): if name.startswith("cse"): - set_subst_rule(name, expr, visitor) + set_subst_rule(name, expr) # Ensure CSE on detjac * quadrature weight domain = term.argument.argexpr.ufl_domain() if measure == "cell": - set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor) - set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor) + set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain))) + set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain)) else: - set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) - set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor) + set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain)) + set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain)) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index f107d94b..c5e33a3d 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions to pymbolic and loopy. """ from dune.perftool.error import PerftoolUFLError -from dune.perftool.generation import (domain, +from dune.perftool.generation import (construct_subst_rule, + domain, get_global_context_value, subst_rule, ) @@ -46,17 +47,18 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # Call base class constructors super(UFL2LoopyVisitor, self).__init__() - def __call__(self, o): + def __call__(self, o, donot_check_substrules=None): # Reset some state variables that are reinitialized for each accumulation term self.indices = None self._indices_backup = [] self.inames = () + self.donot_check_substrules = donot_check_substrules return self.call(o) def call(self, o): - rule = subst_rule(o, None) - if rule: + if o != self.donot_check_substrules and subst_rule(None, o): + rule = construct_subst_rule(o, self) return prim.Call(prim.Variable(rule.name), ()) else: return MultiFunction.__call__(self, o) -- GitLab