Skip to content
Snippets Groups Projects
Commit d9998690 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move get_facedir/get_facemod into accumulation mixin

parent 94d31143
No related branches found
No related tags found
No related merge requests found
......@@ -495,7 +495,8 @@ def visit_integral(integral):
# Start the visiting process!
visitor = get_visitor(measure, subdomain_id)
visitor.accumulate(integrand)
with global_context(visitor=visitor):
visitor.accumulate(integrand)
run_hook(name="after_visit",
args=(visitor,),
......
......@@ -39,9 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence,
)
from dune.codegen.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase
from dune.codegen.ufl.modified_terminals import extract_modified_arguments
from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord
......@@ -355,6 +352,22 @@ class SumfactAccumulationMixin(AccumulationMixinBase):
def generate_accumulation_instruction(self, expr):
return generate_accumulation_instruction(expr, self)
def get_facedir(self, restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facedir_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facedir_n")
return None
def get_facemod(self, restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facemod_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facemod_n")
return None
class SumfactAccumulationInfo(ImmutableRecord):
def __init__(self,
......@@ -520,8 +533,8 @@ def generate_accumulation_instruction(expr, visitor):
matrix_sequence = construct_basis_matrix_sequence(
transpose=True,
derivative=test_info.grad_index,
facedir=get_facedir(test_info.restriction),
facemod=get_facemod(test_info.restriction),
facedir=visitor.get_facedir(test_info.restriction),
facemod=visitor.get_facemod(test_info.restriction),
basis_size=basis_size)
jacobian_inames = trial_info.inames
......
......@@ -31,9 +31,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
from dune.codegen.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.codegen.pdelab.argument import name_coefficientcontainer, name_applycontainer
from dune.codegen.pdelab.basis import GenericBasisMixin
from dune.codegen.pdelab.geometry import (local_dimension,
......@@ -86,7 +83,7 @@ class SumfactBasisMixin(GenericBasisMixin):
temporary_variable(name, shape=())
quad_inames = self.quadrature_inames()
inames = self.lfs_inames(element, restriction)
facedir = get_facedir(restriction)
facedir = self.get_facedir(restriction)
# Collect the pairs of lfs/quad inames that are in use
# On facets, the normal direction of the facet is excluded
......@@ -106,7 +103,7 @@ class SumfactBasisMixin(GenericBasisMixin):
# Add the missing direction on facedirs by evaluating at either 0 or 1
if facedir is not None:
facemod = get_facemod(restriction)
facemod = self.get_facemod(restriction)
prod = prod + (prim.Call(PolynomialLookup(name_polynomials(element.degree()), False),
(prim.Variable(inames[facedir]), facemod)),)
......@@ -141,7 +138,7 @@ class SumfactBasisMixin(GenericBasisMixin):
temporary_variable(name, shape=())
quad_inames = self.quadrature_inames()
inames = self.lfs_inames(element, restriction)
facedir = get_facedir(restriction)
facedir = self.get_facedir(restriction)
# Map the direction to a quadrature iname
quadinamemapping = {}
......@@ -161,7 +158,7 @@ class SumfactBasisMixin(GenericBasisMixin):
prod.append(tab.pymbolic((prim.Variable(quadinamemapping[i]), prim.Variable(inames[i]))))
if facedir is not None:
facemod = get_facemod(restriction)
facemod = self.get_facemod(restriction)
prod.append(prim.Call(PolynomialLookup(name_polynomials(element.degree()), index == facedir),
(prim.Variable(inames[facedir]), facemod)),)
......@@ -197,8 +194,8 @@ class SumfactBasisMixin(GenericBasisMixin):
# Construct the matrix sequence for this sum factorization
matrix_sequence = construct_basis_matrix_sequence(derivative=derivative,
facedir=get_facedir(restriction),
facemod=get_facemod(restriction),
facedir=self.get_facedir(restriction),
facemod=self.get_facemod(restriction),
basis_size=basis_size)
inp = LFSSumfactKernelInput(matrix_sequence=matrix_sequence,
......
......@@ -39,7 +39,6 @@ from dune.codegen.sumfact.permutation import (permute_backward,
sumfact_quadrature_permutation_strategy,
)
from dune.codegen.sumfact.quadrature import additional_inames
from dune.codegen.sumfact.switch import get_facedir, get_facemod
from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel
from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord
from dune.codegen.options import get_form_option, option_switch
......@@ -162,8 +161,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
def outer_normal(self):
""" This is the *unnormalized* outer normal """
name = "outer_normal"
facedir_s = get_facedir(Restriction.POSITIVE)
facemod_s = get_facemod(Restriction.POSITIVE)
facedir_s = self.get_facedir(Restriction.POSITIVE)
facemod_s = self.get_facemod(Restriction.POSITIVE)
temporary_variable(name, shape=(world_dimension(),))
for i in range(world_dimension()):
......@@ -210,8 +209,8 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
restriction = enforce_boundary_restriction(self)
# Generate sum factorization kernel and add vectorization info
matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction),
facemod=get_facemod(restriction),
matrix_sequence = construct_basis_matrix_sequence(facedir=self.get_facedir(restriction),
facemod=self.get_facemod(restriction),
basis_size=(2,) * world_dimension())
inp = GeoCornersInput(matrix_sequence=matrix_sequence,
direction=self.indices[0],
......@@ -253,8 +252,8 @@ class SumfactAxiParallelGeometryMixin(AxiparallelGeometryMixin):
assert isinstance(i, int)
# Use facemod_s and facedir_s
if i == get_facedir(Restriction.POSITIVE):
if get_facemod(Restriction.POSITIVE):
if i == self.get_facedir(Restriction.POSITIVE):
if self.get_facemod(Restriction.POSITIVE):
return 1
else:
return -1
......@@ -270,7 +269,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle
def facet_jacobian_determinant(self, o):
name = "fdetjac"
self.define_facet_jacobian_determinant(name)
facedir = get_facedir(Restriction.POSITIVE)
facedir = self.get_facedir(Restriction.POSITIVE)
globalarg(name, shape=(world_dimension(),))
return prim.Subscript(prim.Variable(name), (facedir,))
......@@ -304,8 +303,7 @@ class SumfactEqudistantGeometryMixin(EquidistantGeometryMixin, SumfactAxiParalle
restriction = Restriction.NONE
if self.measure == "interior_facet":
restriction = Restriction.POSITIVE
from dune.codegen.sumfact.switch import get_facedir
face = get_facedir(restriction)
face = self.get_facedir(restriction)
lowcorner = name_lowerleft_corner()
meshwidth = name_meshwidth()
......@@ -527,8 +525,8 @@ def _name_jacobian(i, j, restriction, visitor):
"""
# Create matrix sequence with derivative in j direction
matrix_sequence = construct_basis_matrix_sequence(derivative=j,
facedir=get_facedir(restriction),
facemod=get_facemod(restriction),
facedir=visitor.get_facedir(restriction),
facemod=visitor.get_facemod(restriction),
basis_size=(2,) * world_dimension())
# Sum factorization input for the i'th component of the geometry mapping
......
......@@ -3,7 +3,6 @@
import itertools
from dune.codegen.options import get_option
from dune.codegen.sumfact.switch import get_facedir, get_facemod
from dune.codegen.ufl.modified_terminals import Restriction
......@@ -125,8 +124,9 @@ def sumfact_quadrature_permutation_strategy(dim, restriction):
# all others can be derived by rotating the cube and matching edge
# directions.
def _order_on_self(restriction):
facedir = get_facedir(restriction)
facemod = get_facemod(restriction)
visitor = get_global_context_value("visitor")
facedir = visitor.get_facedir(restriction)
facemod = visitor.get_facemod(restriction)
quadrature_order = {
(0, 0): (0, 1, 2),
......
......@@ -10,7 +10,6 @@ from dune.codegen.generation import (domain,
quadrature_mixin,
temporary_variable,
)
from dune.codegen.sumfact.switch import get_facedir
from dune.codegen.sumfact.tabulation import (quadrature_points_per_direction,
local_quadrature_points_per_direction,
name_oned_quadrature_points,
......@@ -22,7 +21,6 @@ from dune.codegen.pdelab.geometry import (local_dimension,
)
from dune.codegen.pdelab.quadrature import GenericQuadratureMixin
from dune.codegen.options import get_form_option
from dune.codegen.sumfact.switch import get_facedir
from dune.codegen.loopy.target import dtype_floatingpoint
from loopy import CallMangleInfo
......
......@@ -169,25 +169,3 @@ def generate_interior_facet_switch():
block.append("}")
return ClassMember(signature + block)
def get_facedir(restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facedir_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facedir_n")
if restriction == Restriction.NONE:
return None
assert False
def get_facemod(restriction):
from dune.codegen.pdelab.restriction import Restriction
if restriction == Restriction.POSITIVE or get_global_context_value("integral_type") == "exterior_facet":
return get_global_context_value("facemod_s")
if restriction == Restriction.NEGATIVE:
return get_global_context_value("facemod_n")
if restriction == Restriction.NONE:
return None
assert False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment