Skip to content
Snippets Groups Projects
Commit eb1d9bbe authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add accumulation mixins

parent 7f2f128a
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,8 @@ from dune.codegen.generation.context import (cache_restoring,
get_global_context_value,
)
from dune.codegen.generation.mixins import (basis_mixin,
from dune.codegen.generation.mixins import (accumulation_mixin,
basis_mixin,
construct_from_mixins,
geometry_mixin,
quadrature_mixin,
......
......@@ -24,4 +24,5 @@ def construct_from_mixins(base=object, mixins=[], mixintype="geometry", name="Ge
# A list of specific mixins that we keep around explicitly
geometry_mixin = partial(mixin_base, "geometry")
quadrature_mixin = partial(mixin_base, "quadrature")
basis_mixin = partial(mixin_base, "basis")
\ No newline at end of file
basis_mixin = partial(mixin_base, "basis")
accumulation_mixin = partial(mixin_base, "accumulation")
......@@ -110,9 +110,10 @@ class CodegenFormOptionsArray(ImmutableRecord):
control_variable = CodegenOption(default=None, helpstr="Name of control variable in UFL file")
block_preconditioner_diagonal = CodegenOption(default=False, helpstr="Whether this operator should implement the diagonal part of a block preconditioner")
block_preconditioner_offdiagonal = CodegenOption(default=False, helpstr="Whether this operator should implement the off-diagonal part of a block preconditioner")
geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant")
quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic")
basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic")
geometry_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for geometries. Currently implemented mixins: generic, axiparallel, equidistant, sumfact_multilinear, sumfact_axiparallel, sumfact_equidistant")
quadrature_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for quadrature. Currently implemented: generic, sumfact")
basis_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for basis function evaluation. Currently implemented: generic, sumfact")
accumulation_mixins = CodegenOption(default="generic", helpstr="A comma separated list of mixin identifiers to use for accumulation. Currently implemented: generic")
enable_volume = CodegenOption(default=True, helpstr="Whether to assemble volume integrals")
enable_skeleton = CodegenOption(default=True, helpstr="Whether to assemble skeleton integrals")
enable_boundary = CodegenOption(default=True, helpstr="Whether to assemble boundary integrals")
......@@ -183,10 +184,8 @@ def process_form_options(opt, form):
opt = opt.copy(unroll_dimension_loops=True,
quadrature_mixins="sumfact",
basis_mixins="sumfact",
constant_transformation_matrix=True,
diagonal_transformation_matrix=True,
accumulation_mixins="sumfact",
)
#TODO Remove the trafo matrix ones!
if opt.numerical_jacobian:
opt = opt.copy(generate_jacobians=False, generate_jacobian_apply=False)
......
......@@ -9,7 +9,8 @@ from dune.codegen.options import (get_form_option,
get_option,
option_switch,
set_form_option)
from dune.codegen.generation import (backend,
from dune.codegen.generation import (accumulation_mixin,
backend,
base_class,
class_basename,
class_member,
......@@ -329,6 +330,30 @@ def boundary_predicates(measure, subdomain_id):
return frozenset(predicates)
@accumulation_mixin("base")
class AccumulationMixinBase(object):
def get_accumulation_info(self, expr):
raise NotImplementedError
def list_accumulation_infos(self, expr):
raise NotImplementedError
def generate_accumulation_instruction(self, expr):
raise NotImplementedError
@accumulation_mixin("generic")
class GenericAccumulationMixin(AccumulationMixinBase):
def get_accumulation_info(self, expr):
return get_accumulation_info(expr, self)
def list_accumulation_infos(self, expr):
return list_accumulation_infos(expr, self)
def generate_accumulation_instruction(self, expr):
return generate_accumulation_instruction(expr, self)
class PDELabAccumulationInfo(ImmutableRecord):
def __init__(self,
element=None,
......@@ -463,6 +488,10 @@ def get_visitor(measure, subdomain_id):
mixins = get_form_option("basis_mixins").split(",")
VisitorType = construct_from_mixins(base=VisitorType, mixins=mixins, mixintype="basis", name="UFLVisitor")
# Mix accumulation mixins in
mixins = get_form_option("accumulation_mixins").split(",")
VisitorType = construct_from_mixins(base=VisitorType, mixins=mixins, mixintype="accumulation", name="UFLVisitor")
return VisitorType(interface, measure, subdomain_id)
......
import dune.codegen.sumfact.geometry
from dune.codegen.generation import get_backend
from dune.codegen.options import option_switch
from dune.codegen.pdelab.argument import (name_applycontainer,
name_coefficientcontainer,
)
import dune.codegen.sumfact.accumulation
import dune.codegen.sumfact.switch
......@@ -13,14 +6,4 @@ from dune.codegen.pdelab import PDELabInterface
class SumFactInterface(PDELabInterface):
def get_accumulation_info(self, expr, visitor):
from dune.codegen.sumfact.accumulation import get_accumulation_info
return get_accumulation_info(expr, visitor)
def list_accumulation_infos(self, expr, visitor):
from dune.codegen.sumfact.accumulation import list_accumulation_infos
return list_accumulation_infos(expr, visitor)
def generate_accumulation_instruction(self, expr, visitor):
from dune.codegen.sumfact.accumulation import generate_accumulation_instruction
return generate_accumulation_instruction(expr, visitor)
pass
\ No newline at end of file
......@@ -3,7 +3,8 @@ import itertools
from dune.codegen.pdelab.argument import (name_accumulation_variable,
PDELabAccumulationFunction,
)
from dune.codegen.generation import (backend,
from dune.codegen.generation import (accumulation_mixin,
backend,
domain,
dump_accumulate_timer,
generator_factory,
......@@ -25,7 +26,7 @@ from dune.codegen.options import (get_form_option,
from dune.codegen.loopy.flatten import flatten_index
from dune.codegen.loopy.target import type_floatingpoint
from dune.codegen.pdelab.driver import FEM_name_mangling
from dune.codegen.pdelab.localoperator import determine_accumulation_space
from dune.codegen.pdelab.localoperator import determine_accumulation_space, AccumulationMixinBase
from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.pdelab.signatures import assembler_routine_name
from dune.codegen.pdelab.geometry import world_dimension
......@@ -267,6 +268,18 @@ def _dof_offset(element, component):
return sum(sizes[0:component])
@accumulation_mixin("sumfact")
class SumfactAccumulationMixin(AccumulationMixinBase):
def get_accumulation_info(self, expr):
return get_accumulation_info(expr, self)
def list_accumulation_infos(self, expr):
return list_accumulation_infos(expr, self)
def generate_accumulation_instruction(self, expr):
return generate_accumulation_instruction(expr, self)
class SumfactAccumulationInfo(ImmutableRecord):
def __init__(self,
element=None,
......
......@@ -64,7 +64,7 @@ def get_kernel_name(facedir_s=None, facemod_s=None, facedir_n=None, facemod_n=No
def decide_if_kernel_is_necessary(facedir_s, facemod_s, facedir_n, facemod_n):
# If we are not using YaspGrid, all variants need to be realized
if not get_form_option("diagonal_transformation_matrix"):
if get_form_option("geometry_mixins") == "sumfact_multilinear":
# Reduce the variability according to grid info file
if get_option("grid_info") is not None:
filename = get_option("grid_info")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment