Skip to content
Snippets Groups Projects
Commit 3923568e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Introduce a general infrastructure to detect common subexpressions

parent bfa83c94
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,10 @@ import loopy
from ufl.algorithms import compute_form_data, read_ufl_file
from ufl.algorithms.formfiles import interpret_ufl_namespace
from dune.perftool.generation import delete_cache_items, global_context
from dune.perftool.generation import (delete_cache_items,
global_context,
subst_rule,
)
from dune.perftool.interactive import start_interactive_session
from dune.perftool.options import get_option
from dune.perftool.pdelab.driver import generate_driver
......@@ -82,6 +85,9 @@ def read_ufl(uflfile):
# Enrich data by some additional objects we deem worth keeping
if get_option("exact_solution_expression"):
data.object_by_name[get_option("exact_solution_expression")] = namespace[get_option("exact_solution_expression")]
for name, expr in namespace.items():
if name.startswith("cse"):
data.object_by_name[name] = namespace[name]
formdatas = []
forms = data.forms
......
......@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier,
kernel_cached,
noop_instruction,
silenced_warning,
subst_rule,
temporary_variable,
transform,
valuearg,
......
......@@ -199,3 +199,13 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
globalarg(name, **kwargs)
return name
@generator_factory(item_tags=("substrule",),
context_tags="kernel",
cache_key_generator=lambda n, e, v: e,
)
def subst_rule(name, expr, visitor):
if name is None:
raise ValueError("Requested a Substitution Rule that was not predefined!")
return lp.SubstitutionRule(name, (), visitor(expr))
......@@ -19,6 +19,7 @@ from dune.perftool.generation import (backend,
post_include,
retrieve_cache_functions,
retrieve_cache_items,
subst_rule,
template_parameter,
)
from dune.perftool.cgen.clazz import (AccessModifier,
......@@ -412,6 +413,13 @@ def visit_integrals(integrals):
interface = PDELabInterface()
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(interface, measure, indexmap)
# Inspect the UFL file for manual common subexpression elimination stuff!
data = get_global_context_value("data")
for name, expr in data.object_by_name.items():
if name.startswith("cse"):
subst_rule(name, expr, visitor)
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
......
......@@ -5,6 +5,7 @@ to pymbolic and loopy.
from dune.perftool.error import PerftoolUFLError
from dune.perftool.generation import (domain,
get_global_context_value,
subst_rule,
)
from dune.perftool.ufl.flatoperators import get_operands
from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
......@@ -45,9 +46,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Call base class constructors
super(UFL2LoopyVisitor, self).__init__()
# Allow recursion through self.call(..)
call = MultiFunction.__call__
def __call__(self, o):
# Reset some state variables that are reinitialized for each accumulation term
self.indices = None
......@@ -56,6 +54,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return self.call(o)
def call(self, o):
try:
return subst_rule(None, o, self)
except ValueError:
return MultiFunction.__call__(self, o)
#
# Form argument/coefficients handlers:
# This is where the actual domain specific work happens
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment