Skip to content
Snippets Groups Projects
Commit f7705ce6 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Refactor substitution rules to not show up in *all* kernels

But only when they are actually used!
parent 9a1cdd22
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace ...@@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace
from dune.perftool.generation import (delete_cache_items, from dune.perftool.generation import (delete_cache_items,
global_context, global_context,
subst_rule,
) )
from dune.perftool.interactive import start_interactive_session from dune.perftool.interactive import start_interactive_session
from dune.perftool.options import get_option from dune.perftool.options import get_option
......
...@@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class, ...@@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class,
from dune.perftool.generation.loopy import (barrier, from dune.perftool.generation.loopy import (barrier,
built_instruction, built_instruction,
constantarg, constantarg,
construct_subst_rule,
domain, domain,
function_mangler, function_mangler,
get_temporary_name, get_temporary_name,
......
...@@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar ...@@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
return name return name
@generator_factory(item_tags=("substrule",), @generator_factory(item_tags=("substrule_name",),
context_tags="kernel",
cache_key_generator=lambda e, n: e)
def _substrule_name(expr, name):
return name
@generator_factory(item_tags=("substrule_impl",),
context_tags="kernel", context_tags="kernel",
cache_key_generator=lambda e, r: e, cache_key_generator=lambda n, e, **ex: e,
) )
def subst_rule(expr, rule): def subst_rule(name, expr, exists=False):
return rule _substrule_name(expr, name)
return exists
def set_subst_rule(name, expr):
subst_rule(name, expr, exists=True)
def set_subst_rule(name, expr, visitor):
rule = lp.SubstitutionRule(name, (), visitor(expr)) @generator_factory(item_tags=("substrule",),
subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None} context_tags="kernel")
return subst_rule(expr, rule) def construct_subst_rule(expr, visitor):
name = _substrule_name(expr, None)
assert name
return lp.SubstitutionRule(name, (), visitor(expr, donot_check_substrules=expr))
...@@ -419,16 +419,16 @@ def visit_integrals(integrals): ...@@ -419,16 +419,16 @@ def visit_integrals(integrals):
data = get_global_context_value("data") data = get_global_context_value("data")
for name, expr in data.object_by_name.items(): for name, expr in data.object_by_name.items():
if name.startswith("cse"): if name.startswith("cse"):
set_subst_rule(name, expr, visitor) set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight # Ensure CSE on detjac * quadrature weight
domain = term.argument.argexpr.ufl_domain() domain = term.argument.argexpr.ufl_domain()
if measure == "cell": if measure == "cell":
set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor) set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor) set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
else: else:
set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor) set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
......
...@@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions ...@@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions
to pymbolic and loopy. to pymbolic and loopy.
""" """
from dune.perftool.error import PerftoolUFLError from dune.perftool.error import PerftoolUFLError
from dune.perftool.generation import (domain, from dune.perftool.generation import (construct_subst_rule,
domain,
get_global_context_value, get_global_context_value,
subst_rule, subst_rule,
) )
...@@ -46,17 +47,18 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -46,17 +47,18 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Call base class constructors # Call base class constructors
super(UFL2LoopyVisitor, self).__init__() super(UFL2LoopyVisitor, self).__init__()
def __call__(self, o): def __call__(self, o, donot_check_substrules=None):
# Reset some state variables that are reinitialized for each accumulation term # Reset some state variables that are reinitialized for each accumulation term
self.indices = None self.indices = None
self._indices_backup = [] self._indices_backup = []
self.inames = () self.inames = ()
self.donot_check_substrules = donot_check_substrules
return self.call(o) return self.call(o)
def call(self, o): def call(self, o):
rule = subst_rule(o, None) if o != self.donot_check_substrules and subst_rule(None, o):
if rule: rule = construct_subst_rule(o, self)
return prim.Call(prim.Variable(rule.name), ()) return prim.Call(prim.Variable(rule.name), ())
else: else:
return MultiFunction.__call__(self, o) return MultiFunction.__call__(self, o)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment