Skip to content
Snippets Groups Projects
Commit 164dc781 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cache a few evaluations that take excessively long

parent e0825c0e
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,7 @@ from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.pdelab.driver import (isPk,
isQk,
isDG)
from dune.codegen.ufl.modified_terminals import Restriction
from pymbolic.primitives import Product, Subscript, Variable
import pymbolic.primitives as prim
......@@ -81,7 +82,11 @@ class BasisMixinBase(object):
@basis_mixin("generic")
class GenericBasisMixin(BasisMixinBase):
def initialize_function_spaces(self, expr):
return initialize_function_spaces(expr, self)
restriction = self.restriction
if self.measure == 'exterior_facet':
restriction = Restriction.POSITIVE
return initialize_function_spaces(expr, restriction, self.indices)
def lfs_inames(self, element, restriction, number, context=""):
return (lfs_iname(element, restriction, number, context),)
......
......@@ -89,6 +89,10 @@ class GenericPDELabGeometryMixin(GeometryMixinBase):
if restriction == Restriction.NONE:
return local
return self._to_cell(local, restriction)
@kernel_cached
def _to_cell(self, local, restriction):
basename = get_pymbolic_basename(local)
name = "{}_in_{}side".format(basename, "in" if restriction is Restriction.POSITIVE else "out")
temporary_variable(name, shape=(world_dimension(),), shape_impl=("fv",))
......
......@@ -26,6 +26,7 @@ from dune.codegen.generation import (accumulation_mixin,
iname,
include_file,
initializer_list,
kernel_cached,
post_include,
retrieve_cache_functions,
retrieve_cache_items,
......@@ -346,7 +347,11 @@ class AccumulationMixinBase(object):
@accumulation_mixin("generic")
class GenericAccumulationMixin(AccumulationMixinBase):
def get_accumulation_info(self, expr):
return get_accumulation_info(expr, self)
restriction = self.restriction
if self.measure == 'exterior_facet':
restriction = Restriction.POSITIVE
return get_accumulation_info(expr, restriction, self.indices, self)
def list_accumulation_infos(self, expr):
return list_accumulation_infos(expr, self)
......@@ -404,19 +409,16 @@ def list_accumulation_infos(expr, visitor):
return itertools.product(testgen, trialgen)
def get_accumulation_info(expr, visitor):
@kernel_cached
def get_accumulation_info(expr, restriction, indices, visitor):
element = expr.ufl_element()
leaf_element = element
element_index = 0
from ufl import MixedElement
if isinstance(expr.ufl_element(), MixedElement):
element_index = visitor.indices[0]
element_index = indices[0]
leaf_element = element.extract_component(element_index)[1]
restriction = visitor.restriction
if visitor.measure == 'exterior_facet':
restriction = Restriction.POSITIVE
inames = visitor.lfs_inames(leaf_element,
restriction,
expr.number()
......
......@@ -7,6 +7,7 @@ from dune.codegen.generation import (class_member,
iname,
include_file,
instruction,
kernel_cached,
preamble,
quadrature_mixin,
temporary_variable,
......@@ -71,6 +72,7 @@ class GenericQuadratureMixin(QuadratureMixinBase):
def quadrature_inames(self):
return (quadrature_iname(),)
@kernel_cached
def quadrature_position(self, index=None):
from dune.codegen.pdelab.geometry import local_dimension
dim = local_dimension()
......
......@@ -5,6 +5,7 @@ from dune.codegen.generation import (class_member,
function_mangler,
generator_factory,
include_file,
kernel_cached,
preamble,
valuearg,
)
......@@ -122,15 +123,12 @@ name_lfs = partial(_function_space_traversal, defaultname=available_lfs_names, r
type_gfs = partial(_function_space_traversal, defaultname=available_gfs_names, recfunc=_type_gfs)
def initialize_function_spaces(expr, visitor):
restriction = visitor.restriction
if visitor.measure == 'exterior_facet':
restriction = Restriction.POSITIVE
@kernel_cached
def initialize_function_spaces(expr, restriction, indices):
index = None
from ufl import MixedElement
if isinstance(expr.ufl_element(), MixedElement):
index = visitor.indices[0]
index = indices[0]
from ufl.classes import Argument, Coefficient
if isinstance(expr, Argument) and expr.number() == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment