From 0bc540722b2e4cf0172598afd4abe4148243792a Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 24 Jan 2017 13:08:45 +0100 Subject: [PATCH] First implementation of manual CSE --- python/dune/perftool/generation/__init__.py | 1 + python/dune/perftool/generation/cache.py | 5 +---- python/dune/perftool/generation/loopy.py | 15 ++++++++++----- python/dune/perftool/pdelab/localoperator.py | 19 ++++++++++++------- python/dune/perftool/ufl/visitor.py | 11 ++++++----- 5 files changed, 30 insertions(+), 21 deletions(-) diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 0af1aff7..c7f16dfc 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier, kernel_cached, noop_instruction, silenced_warning, + set_subst_rule, subst_rule, temporary_variable, transform, diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index f0823bb5..29e128e0 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -107,7 +107,7 @@ class _RegisteredFunction(object): # get an unknown keyword... try: val = self.on_store(self.func(*args, **kwargs)) - except: + except TypeError: val = self.on_store(self.func(*args, **without_context)) # Maybe wrap it with a counter! @@ -119,9 +119,6 @@ class _RegisteredFunction(object): # Return the result for immediate usage return self._get_content(cache_key) - def remove_by_value(self, val): - self._memoize_cache = {k: v for k, v in self._memoize_cache.items() if v != val} - def generator_factory(**factory_kwargs): """ A function decorator factory diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index ac9f91bd..b618f0bc 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -203,9 +203,14 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar @generator_factory(item_tags=("substrule",), context_tags="kernel", - cache_key_generator=lambda n, e, v: e, + cache_key_generator=lambda e, r: e, ) -def subst_rule(name, expr, visitor): - if name is None: - raise ValueError("Requested a Substitution Rule that was not predefined!") - return lp.SubstitutionRule(name, (), visitor(expr)) +def subst_rule(expr, rule): + return rule + + +def set_subst_rule(name, expr, visitor): + rule = lp.SubstitutionRule(name, (), visitor(expr)) + subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None} + print [type(v) for v in subst_rule._memoize_cache.values()] + return subst_rule(expr, rule) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 5d53d25d..ecd2f787 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -19,7 +19,7 @@ from dune.perftool.generation import (backend, post_include, retrieve_cache_functions, retrieve_cache_items, - subst_rule, + set_subst_rule, template_parameter, ) from dune.perftool.cgen.clazz import (AccessModifier, @@ -419,14 +419,14 @@ def visit_integrals(integrals): data = get_global_context_value("data") for name, expr in data.object_by_name.items(): if name.startswith("cse"): - subst_rule(name, expr, visitor) + set_subst_rule(name, expr, visitor) # Ensure CSE on detjac * quadrature weight domain = term.argument.argexpr.ufl_domain() if term.argument.restriction: - subst_rule("integration_factor", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) + set_subst_rule("integration_factor", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) else: - subst_rule("integration_factor", uc.JacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) + set_subst_rule("integration_factor", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) @@ -467,6 +467,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): domains = ["{[stupid] : 0<=stupid<1}"] instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] + substrules = [i for i in retrieve_cache_items("{} and substrule".format(tag)) if i is not None] temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))} arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))] silenced = [l for l in retrieve_cache_items("{} and silenced_warning".format(tag))] @@ -486,7 +487,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): # Create the kernel from loopy import make_kernel, preprocess_kernel kernel = make_kernel(domains, - instructions, + instructions + substrules, arguments, temporary_variables=temporaries, target=DuneTarget(), @@ -494,14 +495,18 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): silenced_warnings=silenced, name=name, ) - from loopy import make_reduction_inames_unique - kernel = make_reduction_inames_unique(kernel) + #kernel = make_reduction_inames_unique(kernel) # Apply the transformations that were gathered during tree traversals for trafo in transformations: kernel = trafo[0](kernel, *trafo[1]) + # Precompute all the substrules + for sr in substrules: + print sr.name + kernel = lp.precompute(kernel, sr.name) + from dune.perftool.loopy import heuristic_duplication kernel = heuristic_duplication(kernel) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 76d169c9..1ba817c8 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -55,9 +55,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return self.call(o) def call(self, o): - try: - return subst_rule(None, o, self) - except ValueError: + rule = subst_rule(o, None) + if rule: + return prim.Call(prim.Variable(rule.name), ()) + else: return MultiFunction.__call__(self, o) # @@ -226,7 +227,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # def product(self, o): - return Product(tuple(self.call(op) for op in get_operands(o))) + return Product(tuple(self.call(op) for op in o.ufl_operands)) def float_value(self, o): return o.value() @@ -238,7 +239,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return Quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1])) def sum(self, o): - return Sum(tuple(self.call(op) for op in get_operands(o))) + return Sum(tuple(self.call(op) for op in o.ufl_operands)) def zero(self, o): return 0 -- GitLab