Skip to content
Snippets Groups Projects
Commit 1bf40e76 authored by Dominic Kempf
Browse files

Merge branch 'feature/refactor-manual-cse' into 'master'

Feature/refactor manual cse

See merge request !114
parents 9a1cdd22 e305f154
No related branches found
No related tags found
No related merge requests found
...@@ -16,20 +16,20 @@ n = FacetNormal(cell)('+') ...@@ -16,20 +16,20 @@ n = FacetNormal(cell)('+')
alpha = 3.0 alpha = 3.0
h_ext = CellVolume(cell) / FacetArea(cell) h_ext = CellVolume(cell) / FacetArea(cell)
gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext cse_gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext
h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell) h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell)
gamma_int = (alpha * degree * (degree + dim - 1)) / h_int cse_gamma_int = (alpha * degree * (degree + dim - 1)) / h_int
theta = -1.0 theta = -1.0
r = (inner(A*grad(u), grad(v)) + (c*u-f)*v)*dx \ r = (inner(A*grad(u), grad(v)) + (c*u-f)*v)*dx \
+ inner(n, A*avg(grad(u)))*jump(v)*dS \ + inner(n, A*avg(grad(u)))*jump(v)*dS \
+ gamma_int*jump(u)*jump(v)*dS \ + cse_gamma_int*jump(u)*jump(v)*dS \
- theta*jump(u)*inner(A*avg(grad(v)), n)*dS \ - theta*jump(u)*inner(A*avg(grad(v)), n)*dS \
- inner(n, A*grad(u))*v*ds \ - inner(n, A*grad(u))*v*ds \
+ gamma_ext*u*v*ds \ + cse_gamma_ext*u*v*ds \
+ theta*u*inner(A*grad(v), n)*ds \ + theta*u*inner(A*grad(v), n)*ds \
- theta*g*inner(A*grad(v), n)*ds \ - theta*g*inner(A*grad(v), n)*ds \
- gamma_ext*g*v*ds - cse_gamma_ext*g*v*ds
forms = [r] forms = [r]
...@@ -12,21 +12,21 @@ n = FacetNormal(cell)('+') ...@@ -12,21 +12,21 @@ n = FacetNormal(cell)('+')
alpha = 1.0 alpha = 1.0
h_ext = CellVolume(cell) / FacetArea(cell) h_ext = CellVolume(cell) / FacetArea(cell)
gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext cse_gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext
h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell) h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell)
gamma_int = (alpha * degree * (degree + dim - 1)) / h_int cse_gamma_int = (alpha * degree * (degree + dim - 1)) / h_int
theta = -1.0 theta = -1.0
r = inner(grad(u), grad(v))*dx \ r = inner(grad(u), grad(v))*dx \
+ inner(n, avg(grad(u)))*jump(v)*dS \ + inner(n, avg(grad(u)))*jump(v)*dS \
+ gamma_int*jump(u)*jump(v)*dS \ + cse_gamma_int*jump(u)*jump(v)*dS \
- theta*jump(u)*inner(avg(grad(v)), n)*dS \ - theta*jump(u)*inner(avg(grad(v)), n)*dS \
- inner(n, grad(u))*v*ds \ - inner(n, grad(u))*v*ds \
+ gamma_ext*u*v*ds \ + cse_gamma_ext*u*v*ds \
+ theta*u*inner(grad(v), n)*ds \ + theta*u*inner(grad(v), n)*ds \
- f*v*dx \ - f*v*dx \
- theta*g*inner(grad(v), n)*ds \ - theta*g*inner(grad(v), n)*ds \
- gamma_ext*g*v*ds - cse_gamma_ext*g*v*ds
forms = [r] forms = [r]
...@@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace ...@@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace
from dune.perftool.generation import (delete_cache_items, from dune.perftool.generation import (delete_cache_items,
global_context, global_context,
subst_rule,
) )
from dune.perftool.interactive import start_interactive_session from dune.perftool.interactive import start_interactive_session
from dune.perftool.options import get_option from dune.perftool.options import get_option
......
...@@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class, ...@@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class,
from dune.perftool.generation.loopy import (barrier, from dune.perftool.generation.loopy import (barrier,
built_instruction, built_instruction,
constantarg, constantarg,
construct_subst_rule,
domain, domain,
function_mangler, function_mangler,
get_temporary_name, get_temporary_name,
......
...@@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar ...@@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
return name return name
@generator_factory(item_tags=("substrule",), @generator_factory(item_tags=("substrule_name",),
context_tags="kernel",
cache_key_generator=lambda e, n: e)
def _substrule_name(expr, name):
return name
@generator_factory(item_tags=("substrule_impl",),
context_tags="kernel", context_tags="kernel",
cache_key_generator=lambda e, r: e, cache_key_generator=lambda n, e, **ex: e,
) )
def subst_rule(expr, rule): def subst_rule(name, expr, exists=False):
return rule _substrule_name(expr, name)
return exists
def set_subst_rule(name, expr):
subst_rule(name, expr, exists=True)
def set_subst_rule(name, expr, visitor):
rule = lp.SubstitutionRule(name, (), visitor(expr)) @generator_factory(item_tags=("substrule",),
subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None} context_tags="kernel")
return subst_rule(expr, rule) def construct_subst_rule(expr, visitor):
name = _substrule_name(expr, None)
assert name
return lp.SubstitutionRule(name, (), visitor(expr))
...@@ -194,7 +194,7 @@ def collect_vector_data_rotate(knl): ...@@ -194,7 +194,7 @@ def collect_vector_data_rotate(knl):
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
(vector_indices.get(vec_size) + last_index, prim.Variable(new_iname)), (vector_indices.get(vec_size) + last_index, prim.Variable(new_iname)),
) )
else: elif all(get_pymbolic_tag(expr) == 'sumfac' for expr in quantities[quantity]):
# Add a vector view to this quantity # Add a vector view to this quantity
expr, = quantities[quantity] expr, = quantities[quantity]
knl = add_vector_view(knl, quantity, flatview=True) knl = add_vector_view(knl, quantity, flatview=True)
......
...@@ -419,16 +419,16 @@ def visit_integrals(integrals): ...@@ -419,16 +419,16 @@ def visit_integrals(integrals):
data = get_global_context_value("data") data = get_global_context_value("data")
for name, expr in data.object_by_name.items(): for name, expr in data.object_by_name.items():
if name.startswith("cse"): if name.startswith("cse"):
set_subst_rule(name, expr, visitor) set_subst_rule(name, expr)
# Ensure CSE on detjac * quadrature weight # Ensure CSE on detjac * quadrature weight
domain = term.argument.argexpr.ufl_domain() domain = term.argument.argexpr.ufl_domain()
if measure == "cell": if measure == "cell":
set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor) set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor) set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
else: else:
set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor) set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
......
...@@ -54,7 +54,7 @@ class AMatrix(ImmutableRecord): ...@@ -54,7 +54,7 @@ class AMatrix(ImmutableRecord):
return False return False
def output_to_pymbolic(self, name): def output_to_pymbolic(self, name):
return prim.Variable(name) return lp.TaggedVariable(name, "sumfac")
class LargeAMatrix(ImmutableRecord): class LargeAMatrix(ImmutableRecord):
......
...@@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions ...@@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions
to pymbolic and loopy. to pymbolic and loopy.
""" """
from dune.perftool.error import PerftoolUFLError from dune.perftool.error import PerftoolUFLError
from dune.perftool.generation import (domain, from dune.perftool.generation import (construct_subst_rule,
domain,
get_global_context_value, get_global_context_value,
subst_rule, subst_rule,
) )
...@@ -38,10 +39,11 @@ import pymbolic.primitives as prim ...@@ -38,10 +39,11 @@ import pymbolic.primitives as prim
class UFL2LoopyVisitor(ModifiedTerminalTracker): class UFL2LoopyVisitor(ModifiedTerminalTracker):
def __init__(self, interface, measure, dimension_indices): def __init__(self, interface, measure, dimension_indices, donot_check_substrules=None):
self.interface = interface self.interface = interface
self.measure = measure self.measure = measure
self.dimension_indices = dimension_indices self.dimension_indices = dimension_indices
self.donot_check_substrules = donot_check_substrules
# Call base class constructors # Call base class constructors
super(UFL2LoopyVisitor, self).__init__() super(UFL2LoopyVisitor, self).__init__()
...@@ -55,8 +57,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -55,8 +57,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return self.call(o) return self.call(o)
def call(self, o): def call(self, o):
rule = subst_rule(o, None) if o != self.donot_check_substrules and subst_rule(None, o):
if rule: rule = construct_subst_rule(o,
type(self)(self.interface,
self.measure,
self.dimension_indices,
donot_check_substrules=o,
)
)
return prim.Call(prim.Variable(rule.name), ()) return prim.Call(prim.Variable(rule.name), ())
else: else:
return MultiFunction.__call__(self, o) return MultiFunction.__call__(self, o)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment