Skip to content
Snippets Groups Projects
Commit 2178f7cd authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Avoid storing any quadrature dependent data in interface objects

parent 263e3a5d
No related branches found
No related tags found
No related merge requests found
...@@ -44,7 +44,7 @@ from dune.codegen.sumfact.switch import (get_facedir, ...@@ -44,7 +44,7 @@ from dune.codegen.sumfact.switch import (get_facedir,
) )
from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase
from dune.codegen.ufl.modified_terminals import extract_modified_arguments from dune.codegen.ufl.modified_terminals import extract_modified_arguments
from dune.codegen.tools import get_pymbolic_basename, get_leaf from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord
from dune.codegen.error import CodegenError from dune.codegen.error import CodegenError
from pytools import ImmutableRecord, product from pytools import ImmutableRecord, product
...@@ -90,7 +90,7 @@ def accum_iname(element, bound, i): ...@@ -90,7 +90,7 @@ def accum_iname(element, bound, i):
return sumfact_iname(bound, "accum{}".format(suffix)) return sumfact_iname(bound, "accum{}".format(suffix))
class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableCuttingRecord):
def __init__(self, def __init__(self,
matrix_sequence=None, matrix_sequence=None,
accumvar=None, accumvar=None,
...@@ -112,26 +112,23 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -112,26 +112,23 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
# recalculating it in the property. # recalculating it in the property.
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0]) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0])
# Calculate cost optimal permutation
matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
# TODO: Isnt accumvar superfluous in the presence of all the other infos? # TODO: Isnt accumvar superfluous in the presence of all the other infos?
# Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy!
ImmutableRecord.__init__(self, ImmutableCuttingRecord.__init__(self,
accumvar=accumvar, accumvar=accumvar,
restriction=restriction, restriction=restriction,
test_element=test_element, test_element=test_element,
test_element_index=test_element_index, test_element_index=test_element_index,
trial_element=trial_element, trial_element=trial_element,
trial_element_index=trial_element_index, trial_element_index=trial_element_index,
_quadrature_permutation=quadrature_permutation, _quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation, _permuted_matrix_sequence=matrix_sequence,
) )
def __repr__(self): def __repr__(self):
return ImmutableRecord.__repr__(self) return ImmutableCuttingRecord.__repr__(self)
def get_keyword_arguments(self): def get_keyword_arguments(self):
"""Get dictionary of keyword arguments needed to initialize this class """Get dictionary of keyword arguments needed to initialize this class
...@@ -141,7 +138,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -141,7 +138,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
this dict to create an interface. this dict to create an interface.
""" """
dict = self.get_copy_kwargs() dict = self.get_copy_kwargs()
del dict['_cost_permutation'] del dict['_permuted_matrix_sequence']
del dict['_quadrature_permutation'] del dict['_quadrature_permutation']
dict['matrix_sequence'] = None dict['matrix_sequence'] = None
return dict return dict
...@@ -152,7 +149,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -152,7 +149,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
@property @property
def cost_permutation(self): def cost_permutation(self):
return self._cost_permutation return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage)
@property @property
def stage(self): def stage(self):
......
...@@ -44,13 +44,13 @@ from dune.codegen.options import get_form_option ...@@ -44,13 +44,13 @@ from dune.codegen.options import get_form_option
from dune.codegen.pdelab.driver import FEM_name_mangling from dune.codegen.pdelab.driver import FEM_name_mangling
from dune.codegen.pdelab.restriction import restricted_name from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.pdelab.spaces import name_lfs, name_lfs_bound, name_leaf_lfs from dune.codegen.pdelab.spaces import name_lfs, name_lfs_bound, name_leaf_lfs
from dune.codegen.tools import maybe_wrap_subscript from dune.codegen.tools import maybe_wrap_subscript, ImmutableCuttingRecord
from dune.codegen.pdelab.basis import shape_as_pymbolic from dune.codegen.pdelab.basis import shape_as_pymbolic
from dune.codegen.sumfact.accumulation import sumfact_iname from dune.codegen.sumfact.accumulation import sumfact_iname
from ufl import MixedElement, VectorElement, TensorElement, TensorProductElement from ufl import MixedElement, VectorElement, TensorElement, TensorProductElement
from pytools import product, ImmutableRecord from pytools import product
from loopy.match import Writes from loopy.match import Writes
...@@ -235,7 +235,7 @@ class SumfactBasisMixin(GenericBasisMixin): ...@@ -235,7 +235,7 @@ class SumfactBasisMixin(GenericBasisMixin):
return prim.Subscript(var, vsf.quadrature_index(sf, self)) return prim.Subscript(var, vsf.quadrature_index(sf, self))
class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableCuttingRecord):
def __init__(self, def __init__(self,
matrix_sequence=None, matrix_sequence=None,
coeff_func=None, coeff_func=None,
...@@ -252,22 +252,20 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -252,22 +252,20 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
# recalculating it in the property. # recalculating it in the property.
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
# Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy!
ImmutableRecord.__init__(self, ImmutableCuttingRecord.__init__(self,
coeff_func=coeff_func, coeff_func=coeff_func,
element=element, element=element,
element_index=element_index, element_index=element_index,
restriction=restriction, restriction=restriction,
_quadrature_permutation=quadrature_permutation, _quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation, _permuted_matrix_sequence=matrix_sequence,
) )
def __repr__(self): def __repr__(self):
return ImmutableRecord.__repr__(self) return ImmutableCuttingRecord.__repr__(self)
def __str__(self): def __str__(self):
return repr(self) return repr(self)
...@@ -280,7 +278,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -280,7 +278,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
this dict to create an interface. this dict to create an interface.
""" """
dict = self.get_copy_kwargs() dict = self.get_copy_kwargs()
del dict['_cost_permutation'] del dict['_permuted_matrix_sequence']
del dict['_quadrature_permutation'] del dict['_quadrature_permutation']
dict['matrix_sequence'] = None dict['matrix_sequence'] = None
return dict return dict
...@@ -291,7 +289,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -291,7 +289,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
@property @property
def cost_permutation(self): def cost_permutation(self):
return self._cost_permutation return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage)
@property @property
def stage(self): def stage(self):
......
...@@ -41,12 +41,10 @@ from dune.codegen.sumfact.permutation import (permute_backward, ...@@ -41,12 +41,10 @@ from dune.codegen.sumfact.permutation import (permute_backward,
from dune.codegen.sumfact.quadrature import additional_inames from dune.codegen.sumfact.quadrature import additional_inames
from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.sumfact.switch import get_facedir, get_facemod
from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel
from dune.codegen.tools import get_pymbolic_basename from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord
from dune.codegen.options import get_form_option, option_switch from dune.codegen.options import get_form_option, option_switch
from dune.codegen.ufl.modified_terminals import Restriction from dune.codegen.ufl.modified_terminals import Restriction
from pytools import ImmutableRecord
from loopy.match import Writes from loopy.match import Writes
import pymbolic.primitives as prim import pymbolic.primitives as prim
...@@ -337,7 +335,7 @@ def global_corner_iname(restriction): ...@@ -337,7 +335,7 @@ def global_corner_iname(restriction):
return name return name
class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableCuttingRecord):
def __init__(self, def __init__(self,
matrix_sequence=None, matrix_sequence=None,
direction=None, direction=None,
...@@ -361,20 +359,18 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -361,20 +359,18 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
# recalculating it in the property. # recalculating it in the property.
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
# Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy!
ImmutableRecord.__init__(self, ImmutableCuttingRecord.__init__(self,
direction=direction, direction=direction,
restriction=restriction, restriction=restriction,
_quadrature_permutation=quadrature_permutation, _quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation, _permuted_matrix_sequence=matrix_sequence,
) )
def __repr__(self): def __repr__(self):
return ImmutableRecord.__repr__(self) return ImmutableCuttingRecord.__repr__(self)
def __str__(self): def __str__(self):
return repr(self) return repr(self)
...@@ -387,7 +383,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -387,7 +383,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
this dict to create an interface. this dict to create an interface.
""" """
dict = self.get_copy_kwargs() dict = self.get_copy_kwargs()
del dict['_cost_permutation'] del dict['_permutated_matrix_sequence']
del dict['_quadrature_permutation'] del dict['_quadrature_permutation']
dict['matrix_sequence'] = None dict['matrix_sequence'] = None
return dict return dict
...@@ -398,7 +394,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -398,7 +394,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
@property @property
def cost_permutation(self): def cost_permutation(self):
return self._cost_permutation return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage)
@property @property
def stage(self): def stage(self):
......
...@@ -1000,16 +1000,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -1000,16 +1000,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def horizontal_index(self, sf): def horizontal_index(self, sf):
for i, k in enumerate(self.kernels): for i, k in enumerate(self.kernels):
# We need to identify to which part of the vectorized kernel sf if sf.interface == k.interface:
# corresponds. Since splitting might change the cost_permutation we
# exclude it in the comparison below. We also make sure to check
# that derivatives are the same.
from copy import deepcopy
sf_interface = deepcopy(sf.interface)
sf_interface._cost_permutation = None
k_interface = deepcopy(k.interface)
k_interface._cost_permutation = None
if repr(sf_interface) == repr(k_interface):
if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted): if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted):
return i return i
......
...@@ -4,6 +4,20 @@ from __future__ import absolute_import ...@@ -4,6 +4,20 @@ from __future__ import absolute_import
import loopy as lp import loopy as lp
import pymbolic.primitives as prim import pymbolic.primitives as prim
import frozendict import frozendict
import pytools
class ImmutableCuttingRecord(pytools.ImmutableRecord):
"""
A record implementation that drops fields starting with an underscore
from hash and equality computation
"""
def __hash__(self):
return hash((type(self),) + tuple(getattr(self, field) for field in self.__class__.fields if not field.startswith("_")))
def __eq__(self, other):
return type(self) == type(other) and all(getattr(self, field) == getattr(other, field) for field in self.__class__.fields if not field.startswith("_"))
def get_pymbolic_basename(expr): def get_pymbolic_basename(expr):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment