Skip to content
Snippets Groups Projects
Commit f60941e5 authored by René Heß's avatar René Heß
Browse files

Merge branch 'feature/manual-cse-strategy' into 'master'

Manual CSE

Closes #50

See merge request !106
parents bfa83c94 f7fba4c0
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,10 @@ import loopy
from ufl.algorithms import compute_form_data, read_ufl_file
from ufl.algorithms.formfiles import interpret_ufl_namespace
from dune.perftool.generation import delete_cache_items, global_context
from dune.perftool.generation import (delete_cache_items,
global_context,
subst_rule,
)
from dune.perftool.interactive import start_interactive_session
from dune.perftool.options import get_option
from dune.perftool.pdelab.driver import generate_driver
......@@ -82,6 +85,9 @@ def read_ufl(uflfile):
# Enrich data by some additional objects we deem worth keeping
if get_option("exact_solution_expression"):
data.object_by_name[get_option("exact_solution_expression")] = namespace[get_option("exact_solution_expression")]
for name, expr in namespace.items():
if name.startswith("cse"):
data.object_by_name[name] = namespace[name]
formdatas = []
forms = data.forms
......
......@@ -43,6 +43,8 @@ from dune.perftool.generation.loopy import (barrier,
kernel_cached,
noop_instruction,
silenced_warning,
set_subst_rule,
subst_rule,
temporary_variable,
transform,
valuearg,
......
......@@ -107,7 +107,7 @@ class _RegisteredFunction(object):
# get an unknown keyword...
try:
val = self.on_store(self.func(*args, **kwargs))
except:
except TypeError:
val = self.on_store(self.func(*args, **without_context))
# Maybe wrap it with a counter!
......@@ -119,9 +119,6 @@ class _RegisteredFunction(object):
# Return the result for immediate usage
return self._get_content(cache_key)
def remove_by_value(self, val):
self._memoize_cache = {k: v for k, v in self._memoize_cache.items() if v != val}
def generator_factory(**factory_kwargs):
""" A function decorator factory
......
......@@ -199,3 +199,17 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
globalarg(name, **kwargs)
return name
@generator_factory(item_tags=("substrule",),
context_tags="kernel",
cache_key_generator=lambda e, r: e,
)
def subst_rule(expr, rule):
return rule
def set_subst_rule(name, expr, visitor):
rule = lp.SubstitutionRule(name, (), visitor(expr))
subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None}
return subst_rule(expr, rule)
......@@ -19,6 +19,7 @@ from dune.perftool.generation import (backend,
post_include,
retrieve_cache_functions,
retrieve_cache_items,
set_subst_rule,
template_parameter,
)
from dune.perftool.cgen.clazz import (AccessModifier,
......@@ -33,6 +34,7 @@ from pymbolic.primitives import Variable
import pymbolic.primitives as prim
from pytools import Record
import ufl.classes as uc
import loopy as lp
import cgen
......@@ -412,6 +414,22 @@ def visit_integrals(integrals):
interface = PDELabInterface()
from dune.perftool.ufl.visitor import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(interface, measure, indexmap)
# Inspect the UFL file for manual common subexpression elimination stuff!
data = get_global_context_value("data")
for name, expr in data.object_by_name.items():
if name.startswith("cse"):
set_subst_rule(name, expr, visitor)
# Ensure CSE on detjac * quadrature weight
domain = term.argument.argexpr.ufl_domain()
if measure == "cell":
set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor)
set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor)
else:
set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor)
set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor)
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
......@@ -451,6 +469,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
domains = ["{[stupid] : 0<=stupid<1}"]
instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
substrules = [i for i in retrieve_cache_items("{} and substrule".format(tag)) if i is not None]
temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))}
arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))]
silenced = [l for l in retrieve_cache_items("{} and silenced_warning".format(tag))]
......@@ -470,7 +489,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
# Create the kernel
from loopy import make_kernel, preprocess_kernel
kernel = make_kernel(domains,
instructions,
instructions + substrules,
arguments,
temporary_variables=temporaries,
target=DuneTarget(),
......@@ -486,6 +505,18 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
for trafo in transformations:
kernel = trafo[0](kernel, *trafo[1])
# Precompute all the substrules
for sr in kernel.substitutions:
tmpname = "precompute_{}".format(sr)
kernel = lp.precompute(kernel,
sr,
temporary_name=tmpname,
)
# Vectorization strategies are actually very likely to eliminate the
# precomputation temporary. To avoid the temporary elimination warning
# we need to explicitly disable it.
kernel = kernel.copy(silenced_warnings=kernel.silenced_warnings + ["temp_to_write({})".format(tmpname)])
from dune.perftool.loopy import heuristic_duplication
kernel = heuristic_duplication(kernel)
......
......@@ -5,6 +5,7 @@ to pymbolic and loopy.
from dune.perftool.error import PerftoolUFLError
from dune.perftool.generation import (domain,
get_global_context_value,
subst_rule,
)
from dune.perftool.ufl.flatoperators import get_operands
from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
......@@ -45,9 +46,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Call base class constructors
super(UFL2LoopyVisitor, self).__init__()
# Allow recursion through self.call(..)
call = MultiFunction.__call__
def __call__(self, o):
# Reset some state variables that are reinitialized for each accumulation term
self.indices = None
......@@ -56,6 +54,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return self.call(o)
def call(self, o):
rule = subst_rule(o, None)
if rule:
return prim.Call(prim.Variable(rule.name), ())
else:
return MultiFunction.__call__(self, o)
#
# Form argument/coefficients handlers:
# This is where the actual domain specific work happens
......@@ -222,7 +227,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
#
def product(self, o):
return Product(tuple(self.call(op) for op in get_operands(o)))
return Product(tuple(self.call(op) for op in o.ufl_operands))
def float_value(self, o):
return o.value()
......@@ -234,7 +239,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return Quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
def sum(self, o):
return Sum(tuple(self.call(op) for op in get_operands(o)))
return Sum(tuple(self.call(op) for op in o.ufl_operands))
def zero(self, o):
return 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment