Skip to content
Snippets Groups Projects
Commit d9998690 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move get_facedir/get_facemod into accumulation mixin

parent 94d31143
No related branches found
No related tags found
No related merge requests found
...@@ -495,7 +495,8 @@ def visit_integral(integral): ...@@ -495,7 +495,8 @@ def visit_integral(integral):
# Start the visiting process! # Start the visiting process!
visitor = get_visitor(measure, subdomain_id) visitor = get_visitor(measure, subdomain_id)
visitor.accumulate(integrand) with global_context(visitor=visitor):
visitor.accumulate(integrand)
run_hook(name="after_visit", run_hook(name="after_visit",
args=(visitor,), args=(visitor,),
......
...@@ -39,9 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, ...@@ -39,9 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence, construct_basis_matrix_sequence,
) )
from dune.codegen.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase
from dune.codegen.ufl.modified_terminals import extract_modified_arguments from dune.codegen.ufl.modified_terminals import extract_modified_arguments
from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord
...@@ -355,6 +352,22 @@ class SumfactAccumulationMixin(AccumulationMixinBase): ...@@ -355,6 +352,22 @@ class SumfactAccumulationMixin(AccumulationMixinBase):
def generate_accumulation_instruction(self, expr): def generate_accumulation_instruction(self, expr):
return generate_accumulation_instruction(expr, self) return generate_accumulation_instruction(expr, self)
def get_facedir(self, restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facedir_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facedir_n")
return None
def get_facemod(self, restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facemod_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facemod_n")
return None
class SumfactAccumulationInfo(ImmutableRecord): class SumfactAccumulationInfo(ImmutableRecord):
def __init__(self, def __init__(self,
...@@ -520,8 +533,8 @@ def generate_accumulation_instruction(expr, visitor): ...@@ -520,8 +533,8 @@ def generate_accumulation_instruction(expr, visitor):
matrix_sequence = construct_basis_matrix_sequence( matrix_sequence = construct_basis_matrix_sequence(
transpose=True, transpose=True,
derivative=test_info.grad_index, derivative=test_info.grad_index,
facedir=get_facedir(test_info.restriction), facedir=visitor.get_facedir(test_info.restriction),
facemod=get_facemod(test_info.restriction), facemod=visitor.get_facemod(test_info.restriction),
basis_size=basis_size) basis_size=basis_size)
jacobian_inames = trial_info.inames jacobian_inames = trial_info.inames
......
...@@ -31,9 +31,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, ...@@ -31,9 +31,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
sumfact_cost_permutation_strategy, sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy, sumfact_quadrature_permutation_strategy,
) )
from dune.codegen.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.codegen.pdelab.argument import name_coefficientcontainer, name_applycontainer from dune.codegen.pdelab.argument import name_coefficientcontainer, name_applycontainer
from dune.codegen.pdelab.basis import GenericBasisMixin from dune.codegen.pdelab.basis import GenericBasisMixin
from dune.codegen.pdelab.geometry import (local_dimension, from dune.codegen.pdelab.geometry import (local_dimension,
...@@ -86,7 +83,7 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -86,7 +83,7 @@ class SumfactBasisMixin(GenericBasisMixin):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
quad_inames = self.quadrature_inames() quad_inames = self.quadrature_inames()
inames = self.lfs_inames(element, restriction) inames = self.lfs_inames(element, restriction)
facedir = get_facedir(restriction) facedir = self.get_facedir(restriction)
# Collect the pairs of lfs/quad inames that are in use # Collect the pairs of lfs/quad inames that are in use
# On facets, the normal direction of the facet is excluded # On facets, the normal direction of the facet is excluded
...@@ -106,7 +103,7 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -106,7 +103,7 @@ class SumfactBasisMixin(GenericBasisMixin):
# Add the missing direction on facedirs by evaluating at either 0 or 1 # Add the missing direction on facedirs by evaluating at either 0 or 1
if facedir is not None: if facedir is not None:
facemod = get_facemod(restriction) facemod = self.get_facemod(restriction)
prod = prod + (prim.Call(PolynomialLookup(name_polynomials(element.degree()), False), prod = prod + (prim.Call(PolynomialLookup(name_polynomials(element.degree()), False),
(prim.Variable(inames[facedir]), facemod)),) (prim.Variable(inames[facedir]), facemod)),)
...@@ -141,7 +138,7 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -141,7 +138,7 @@ class SumfactBasisMixin(GenericBasisMixin):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
quad_inames = self.quadrature_inames() quad_inames = self.quadrature_inames()
inames = self.lfs_inames(element, restriction) inames = self.lfs_inames(element, restriction)
facedir = get_facedir(restriction) facedir = self.get_facedir(restriction)
# Map the direction to a quadrature iname # Map the direction to a quadrature iname
quadinamemapping = {} quadinamemapping = {}
...@@ -161,7 +158,7 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -161,7 +158,7 @@ class SumfactBasisMixin(GenericBasisMixin):
prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])))) prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i]))))
if facedir is not None: if facedir is not None:
facemod = get_facemod(restriction) facemod = self.get_facemod(restriction)
prod.append(prim.Call(PolynomialLookup(name_polynomials(element.degree()), index == facedir), prod.append(prim.Call(PolynomialLookup(name_polynomials(element.degree()), index == facedir),
(prim.Variable(inames[facedir]), facemod)),) (prim.Variable(inames[facedir]), facemod)),)
...@@ -197,8 +194,8 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -197,8 +194,8 @@ class SumfactBasisMixin(GenericBasisMixin):
# Construct the matrix sequence for this sum factorization # Construct the matrix sequence for this sum factorization
matrix_sequence = construct_basis_matrix_sequence(derivative=derivative, matrix_sequence = construct_basis_matrix_sequence(derivative=derivative,
facedir=get_facedir(restriction), facedir=self.get_facedir(restriction),
facemod=get_facemod(restriction), facemod=self.get_facemod(restriction),
basis_size=basis_size) basis_size=basis_size)
inp = LFSSumfactKernelInput(matrix_sequence=matrix_sequence, inp = LFSSumfactKernelInput(matrix_sequence=matrix_sequence,
......
...@@ -39,7 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward, ...@@ -39,7 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
sumfact_quadrature_permutation_strategy, sumfact_quadrature_permutation_strategy,
) )
from dune.codegen.sumfact.quadrature import additional_inames from dune.codegen.sumfact.quadrature import additional_inames
from dune.codegen.sumfact.switch import get_facedir, get_facemod
from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel
from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord
from dune.codegen.options import get_form_option, option_switch from dune.codegen.options import get_form_option, option_switch
...@@ -162,8 +161,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): ...@@ -162,8 +161,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
def outer_normal(self): def outer_normal(self):
""" This is the *unnormalized* outer normal """ """ This is the *unnormalized* outer normal """
name = "outer_normal" name = "outer_normal"
facedir_s = get_facedir(Restriction.POSITIVE) facedir_s = self.get_facedir(Restriction.POSITIVE)
facemod_s = get_facemod(Restriction.POSITIVE) facemod_s = self.get_facemod(Restriction.POSITIVE)
temporary_variable(name, shape=(world_dimension(),)) temporary_variable(name, shape=(world_dimension(),))
for i in range(world_dimension()): for i in range(world_dimension()):
...@@ -210,8 +209,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin): ...@@ -210,8 +209,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
restriction = enforce_boundary_restriction(self) restriction = enforce_boundary_restriction(self)
# Generate sum factorization kernel and add vectorization info # Generate sum factorization kernel and add vectorization info
matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction), matrix_sequence = construct_basis_matrix_sequence(facedir=self.get_facedir(restriction),
facemod=get_facemod(restriction), facemod=self.get_facemod(restriction),
basis_size=(2,) * world_dimension()) basis_size=(2,) * world_dimension())
inp = GeoCornersInput(matrix_sequence=matrix_sequence, inp = GeoCornersInput(matrix_sequence=matrix_sequence,
direction=self.indices[0], direction=self.indices[0],
...@@ -253,8 +252,8 @@ class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin): ...@@ -253,8 +252,8 @@ class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin):
assert isinstance(i, int) assert isinstance(i, int)
# Use facemod_s and facedir_s # Use facemod_s and facedir_s
if i == get_facedir(Restriction.POSITIVE): if i == self.get_facedir(Restriction.POSITIVE):
if get_facemod(Restriction.POSITIVE): if self.get_facemod(Restriction.POSITIVE):
return 1 return 1
else: else:
return -1 return -1
...@@ -270,7 +269,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle ...@@ -270,7 +269,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle
def facet_jacobian_determinant(self, o): def facet_jacobian_determinant(self, o):
name = "fdetjac" name = "fdetjac"
self.define_facet_jacobian_determinant(name) self.define_facet_jacobian_determinant(name)
facedir = get_facedir(Restriction.POSITIVE) facedir = self.get_facedir(Restriction.POSITIVE)
globalarg(name, shape=(world_dimension(),)) globalarg(name, shape=(world_dimension(),))
return prim.Subscript(prim.Variable(name), (facedir,)) return prim.Subscript(prim.Variable(name), (facedir,))
...@@ -304,8 +303,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle ...@@ -304,8 +303,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle
restriction = Restriction.NONE restriction = Restriction.NONE
if self.measure == "interior_facet": if self.measure == "interior_facet":
restriction = Restriction.POSITIVE restriction = Restriction.POSITIVE
from dune.codegen.sumfact.switch import get_facedir face = self.get_facedir(restriction)
face = get_facedir(restriction)
lowcorner = name_lowerleft_corner() lowcorner = name_lowerleft_corner()
meshwidth = name_meshwidth() meshwidth = name_meshwidth()
...@@ -527,8 +525,8 @@ def _name_jacobian(i, j, restriction, visitor): ...@@ -527,8 +525,8 @@ def _name_jacobian(i, j, restriction, visitor):
""" """
# Create matrix sequence with derivative in j direction # Create matrix sequence with derivative in j direction
matrix_sequence = construct_basis_matrix_sequence(derivative=j, matrix_sequence = construct_basis_matrix_sequence(derivative=j,
facedir=get_facedir(restriction), facedir=visitor.get_facedir(restriction),
facemod=get_facemod(restriction), facemod=visitor.get_facemod(restriction),
basis_size=(2,) * world_dimension()) basis_size=(2,) * world_dimension())
# Sum factorization input for the i'th component of the geometry mapping # Sum factorization input for the i'th component of the geometry mapping
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import itertools import itertools
from dune.codegen.options import get_option from dune.codegen.options import get_option
from dune.codegen.sumfact.switch import get_facedir, get_facemod
from dune.codegen.ufl.modified_terminals import Restriction from dune.codegen.ufl.modified_terminals import Restriction
...@@ -125,8 +124,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction): ...@@ -125,8 +124,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction):
# all others can be derived by rotating the cube and matching edge # all others can be derived by rotating the cube and matching edge
# directions. # directions.
def _order_on_self(restriction): def _order_on_self(restriction):
facedir = get_facedir(restriction) visitor = get_global_context_value("visitor")
facemod = get_facemod(restriction) facedir = visitor.get_facedir(restriction)
facemod = visitor.get_facemod(restriction)
quadrature_order = { quadrature_order = {
(0, 0): (0, 1, 2), (0, 0): (0, 1, 2),
......
...@@ -10,7 +10,6 @@ from dune.codegen.generation import (domain, ...@@ -10,7 +10,6 @@ from dune.codegen.generation import (domain,
quadrature_mixin, quadrature_mixin,
temporary_variable, temporary_variable,
) )
from dune.codegen.sumfact.switch import get_facedir
from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction, from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction,
local_quadrature_points_per_direction, local_quadrature_points_per_direction,
name_oned_quadrature_points, name_oned_quadrature_points,
...@@ -22,7 +21,6 @@ from dune.codegen.pdelab.geometry import (local_dimension, ...@@ -22,7 +21,6 @@ from dune.codegen.pdelab.geometry import (local_dimension,
) )
from dune.codegen.pdelab.quadrature import GenericQuadratureMixin from dune.codegen.pdelab.quadrature import GenericQuadratureMixin
from dune.codegen.options import get_form_option from dune.codegen.options import get_form_option
from dune.codegen.sumfact.switch import get_facedir
from dune.codegen.loopy.target import dtype_floatingpoint from dune.codegen.loopy.target import dtype_floatingpoint
from loopy import CallMangleInfo from loopy import CallMangleInfo
......
...@@ -169,25 +169,3 @@ def generate_interior_facet_switch(): ...@@ -169,25 +169,3 @@ def generate_interior_facet_switch():
block.append("}") block.append("}")
return ClassMember(signature + block) return ClassMember(signature + block)
def get_facedir(restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facedir_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facedir_n")
if restriction == Restriction.NONE:
return None
assert False
def get_facemod(restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facemod_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facemod_n")
if restriction == Restriction.NONE:
return None
assert False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment