diff --git a/applications/convection_diffusion/conv_diff_dg.ufl b/applications/convection_diffusion/conv_diff_dg.ufl index d3c95a5226e26981cf78e5d8b78a686150a9310d..00267f15b9331b4705138cfb1a4698ee357a5767 100644 --- a/applications/convection_diffusion/conv_diff_dg.ufl +++ b/applications/convection_diffusion/conv_diff_dg.ufl @@ -16,20 +16,20 @@ n = FacetNormal(cell)('+') alpha = 3.0 h_ext = CellVolume(cell) / FacetArea(cell) -gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext +cse_gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell) -gamma_int = (alpha * degree * (degree + dim - 1)) / h_int +cse_gamma_int = (alpha * degree * (degree + dim - 1)) / h_int theta = -1.0 r = (inner(A*grad(u), grad(v)) + (c*u-f)*v)*dx \ + inner(n, A*avg(grad(u)))*jump(v)*dS \ - + gamma_int*jump(u)*jump(v)*dS \ + + cse_gamma_int*jump(u)*jump(v)*dS \ - theta*jump(u)*inner(A*avg(grad(v)), n)*dS \ - inner(n, A*grad(u))*v*ds \ - + gamma_ext*u*v*ds \ + + cse_gamma_ext*u*v*ds \ + theta*u*inner(A*grad(v), n)*ds \ - theta*g*inner(A*grad(v), n)*ds \ - - gamma_ext*g*v*ds + - cse_gamma_ext*g*v*ds forms = [r] diff --git a/applications/poisson_dg/poisson_dg.ufl b/applications/poisson_dg/poisson_dg.ufl index 7b5c3e548f81688a0735256411fd54e9d49fa9f4..2c0f2c34436f1c6b76817d6abad344fae3183d62 100644 --- a/applications/poisson_dg/poisson_dg.ufl +++ b/applications/poisson_dg/poisson_dg.ufl @@ -12,21 +12,21 @@ n = FacetNormal(cell)('+') alpha = 1.0 h_ext = CellVolume(cell) / FacetArea(cell) -gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext +cse_gamma_ext = (alpha * degree * (degree + dim - 1)) / h_ext h_int = Min(CellVolume(cell)('+'), CellVolume(cell)('-')) / FacetArea(cell) -gamma_int = (alpha * degree * (degree + dim - 1)) / h_int +cse_gamma_int = (alpha * degree * (degree + dim - 1)) / h_int theta = -1.0 r = inner(grad(u), grad(v))*dx \ + inner(n, avg(grad(u)))*jump(v)*dS \ - + gamma_int*jump(u)*jump(v)*dS \ + + cse_gamma_int*jump(u)*jump(v)*dS \ - theta*jump(u)*inner(avg(grad(v)), n)*dS \ - inner(n, grad(u))*v*ds \ - + gamma_ext*u*v*ds \ + + cse_gamma_ext*u*v*ds \ + theta*u*inner(grad(v), n)*ds \ - f*v*dx \ - theta*g*inner(grad(v), n)*ds \ - - gamma_ext*g*v*ds + - cse_gamma_ext*g*v*ds forms = [r] diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 8e537016b7acd691dca9f7425580fdf6f6798eb8..d5427fa6fe15f98d10ab0e3e491abc0ae7697432 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace from dune.perftool.generation import (delete_cache_items, global_context, - subst_rule, ) from dune.perftool.interactive import start_interactive_session from dune.perftool.options import get_option diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index c7f16dfc85c1d6e8fbf040d640a69baf7e082e12..9d71f86c0913fd766112ec9ccbc8c91e8d9b55f9 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class, from dune.perftool.generation.loopy import (barrier, built_instruction, constantarg, + construct_subst_rule, domain, function_mangler, get_temporary_name, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index be835b7a2dd33cc21355ef5a872197c0e442810e..f01bdf03f85c412f860f33ed1df202f81c5dda52 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar return name -@generator_factory(item_tags=("substrule",), +@generator_factory(item_tags=("substrule_name",), + context_tags="kernel", + cache_key_generator=lambda e, n: e) +def _substrule_name(expr, name): + return name + + +@generator_factory(item_tags=("substrule_impl",), context_tags="kernel", - cache_key_generator=lambda e, r: e, + cache_key_generator=lambda n, e, **ex: e, ) -def subst_rule(expr, rule): - return rule +def subst_rule(name, expr, exists=False): + _substrule_name(expr, name) + return exists + +def set_subst_rule(name, expr): + subst_rule(name, expr, exists=True) -def set_subst_rule(name, expr, visitor): - rule = lp.SubstitutionRule(name, (), visitor(expr)) - subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None} - return subst_rule(expr, rule) + +@generator_factory(item_tags=("substrule",), + context_tags="kernel") +def construct_subst_rule(expr, visitor): + name = _substrule_name(expr, None) + assert name + return lp.SubstitutionRule(name, (), visitor(expr)) diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index d99e530dbc4107691663c6bfd865bab274965caa..c0e222c2819015f272199c5bbb4440ba11af770c 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -194,7 +194,7 @@ def collect_vector_data_rotate(knl): replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (vector_indices.get(vec_size) + last_index, prim.Variable(new_iname)), ) - else: + elif all(get_pymbolic_tag(expr) == 'sumfac' for expr in quantities[quantity]): # Add a vector view to this quantity expr, = quantities[quantity] knl = add_vector_view(knl, quantity, flatview=True) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index adda9c7fc64c169cab0e237220f36d5c76302447..1f6c899bb6d47d8e37d622a51d3987821eb844af 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -419,16 +419,16 @@ def visit_integrals(integrals): data = get_global_context_value("data") for name, expr in data.object_by_name.items(): if name.startswith("cse"): - set_subst_rule(name, expr, visitor) + set_subst_rule(name, expr) # Ensure CSE on detjac * quadrature weight domain = term.argument.argexpr.ufl_domain() if measure == "cell": - set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor) - set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor) + set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain))) + set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain)) else: - set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor) - set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor) + set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain)) + set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain)) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index e3efb78fb93b0336dc89e9869dc4a83f9e47508b..16f8b5289e074665f540ad67593af7ff20598bde 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -54,7 +54,7 @@ class AMatrix(ImmutableRecord): return False def output_to_pymbolic(self, name): - return prim.Variable(name) + return lp.TaggedVariable(name, "sumfac") class LargeAMatrix(ImmutableRecord): diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index f107d94be1ce4b5eeb9e6ef822bb97b3fc69970f..503bfd83959e3d3efb7ae597e11b352c4b5f5952 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions to pymbolic and loopy. """ from dune.perftool.error import PerftoolUFLError -from dune.perftool.generation import (domain, +from dune.perftool.generation import (construct_subst_rule, + domain, get_global_context_value, subst_rule, ) @@ -38,10 +39,11 @@ import pymbolic.primitives as prim class UFL2LoopyVisitor(ModifiedTerminalTracker): - def __init__(self, interface, measure, dimension_indices): + def __init__(self, interface, measure, dimension_indices, donot_check_substrules=None): self.interface = interface self.measure = measure self.dimension_indices = dimension_indices + self.donot_check_substrules = donot_check_substrules # Call base class constructors super(UFL2LoopyVisitor, self).__init__() @@ -55,8 +57,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return self.call(o) def call(self, o): - rule = subst_rule(o, None) - if rule: + if o != self.donot_check_substrules and subst_rule(None, o): + rule = construct_subst_rule(o, + type(self)(self.interface, + self.measure, + self.dimension_indices, + donot_check_substrules=o, + ) + ) return prim.Call(prim.Variable(rule.name), ()) else: return MultiFunction.__call__(self, o)