diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index f327fceb0ba5e31d14dadaa8d977e7147c6655ed..e0a11c144678771bf74ecafddfb0b7e34355648e 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -44,7 +44,7 @@ from dune.codegen.sumfact.switch import (get_facedir, ) from dune.codegen.sumfact.symbolic import SumfactKernel, SumfactKernelInterfaceBase from dune.codegen.ufl.modified_terminals import extract_modified_arguments -from dune.codegen.tools import get_pymbolic_basename, get_leaf +from dune.codegen.tools import get_pymbolic_basename, get_leaf, ImmutableCuttingRecord from dune.codegen.error import CodegenError from pytools import ImmutableRecord, product @@ -90,7 +90,7 @@ def accum_iname(element, bound, i): return sumfact_iname(bound, "accum{}".format(suffix)) -class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): +class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableCuttingRecord): def __init__(self, matrix_sequence=None, accumvar=None, @@ -112,26 +112,23 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): # recalculating it in the property. dim = world_dimension() quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0]) - - # Calculate cost optimal permutation matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) - cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage) # TODO: Isnt accumvar superfluous in the presence of all the other infos? # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! - ImmutableRecord.__init__(self, - accumvar=accumvar, - restriction=restriction, - test_element=test_element, - test_element_index=test_element_index, - trial_element=trial_element, - trial_element_index=trial_element_index, - _quadrature_permutation=quadrature_permutation, - _cost_permutation=cost_permutation, - ) + ImmutableCuttingRecord.__init__(self, + accumvar=accumvar, + restriction=restriction, + test_element=test_element, + test_element_index=test_element_index, + trial_element=trial_element, + trial_element_index=trial_element_index, + _quadrature_permutation=quadrature_permutation, + _permuted_matrix_sequence=matrix_sequence, + ) def __repr__(self): - return ImmutableRecord.__repr__(self) + return ImmutableCuttingRecord.__repr__(self) def get_keyword_arguments(self): """Get dictionary of keyword arguments needed to initialize this class @@ -141,7 +138,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): this dict to create an interface. """ dict = self.get_copy_kwargs() - del dict['_cost_permutation'] + del dict['_permuted_matrix_sequence'] del dict['_quadrature_permutation'] dict['matrix_sequence'] = None return dict @@ -152,7 +149,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): @property def cost_permutation(self): - return self._cost_permutation + return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage) @property def stage(self): diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index 9fac39426df4339a5804237418e5edfcd6e092c6..b9466e260776fbef31cab9252fbb0c203569a6f5 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -44,13 +44,13 @@ from dune.codegen.options import get_form_option from dune.codegen.pdelab.driver import FEM_name_mangling from dune.codegen.pdelab.restriction import restricted_name from dune.codegen.pdelab.spaces import name_lfs, name_lfs_bound, name_leaf_lfs -from dune.codegen.tools import maybe_wrap_subscript +from dune.codegen.tools import maybe_wrap_subscript, ImmutableCuttingRecord from dune.codegen.pdelab.basis import shape_as_pymbolic from dune.codegen.sumfact.accumulation import sumfact_iname from ufl import MixedElement, VectorElement, TensorElement, TensorProductElement -from pytools import product, ImmutableRecord +from pytools import product from loopy.match import Writes @@ -235,7 +235,7 @@ class SumfactBasisMixin(GenericBasisMixin): return prim.Subscript(var, vsf.quadrature_index(sf, self)) -class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): +class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableCuttingRecord): def __init__(self, matrix_sequence=None, coeff_func=None, @@ -252,22 +252,20 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): # recalculating it in the property. dim = world_dimension() quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) - matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) - cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage) # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! - ImmutableRecord.__init__(self, - coeff_func=coeff_func, - element=element, - element_index=element_index, - restriction=restriction, - _quadrature_permutation=quadrature_permutation, - _cost_permutation=cost_permutation, - ) + ImmutableCuttingRecord.__init__(self, + coeff_func=coeff_func, + element=element, + element_index=element_index, + restriction=restriction, + _quadrature_permutation=quadrature_permutation, + _permuted_matrix_sequence=matrix_sequence, + ) def __repr__(self): - return ImmutableRecord.__repr__(self) + return ImmutableCuttingRecord.__repr__(self) def __str__(self): return repr(self) @@ -280,7 +278,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): this dict to create an interface. """ dict = self.get_copy_kwargs() - del dict['_cost_permutation'] + del dict['_permuted_matrix_sequence'] del dict['_quadrature_permutation'] dict['matrix_sequence'] = None return dict @@ -291,7 +289,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): @property def cost_permutation(self): - return self._cost_permutation + return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage) @property def stage(self): diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index 08e9d43e8fe685b3b24b60bd10be418b5f078ad7..ddb68dd006d5ff3bf1f3b5db09111f7c03b7f1d3 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -41,12 +41,10 @@ from dune.codegen.sumfact.permutation import (permute_backward, from dune.codegen.sumfact.quadrature import additional_inames from dune.codegen.sumfact.switch import get_facedir, get_facemod from dune.codegen.sumfact.symbolic import SumfactKernelInterfaceBase, SumfactKernel -from dune.codegen.tools import get_pymbolic_basename +from dune.codegen.tools import get_pymbolic_basename, ImmutableCuttingRecord from dune.codegen.options import get_form_option, option_switch from dune.codegen.ufl.modified_terminals import Restriction -from pytools import ImmutableRecord - from loopy.match import Writes import pymbolic.primitives as prim @@ -337,7 +335,7 @@ def global_corner_iname(restriction): return name -class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): +class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableCuttingRecord): def __init__(self, matrix_sequence=None, direction=None, @@ -361,20 +359,18 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): # recalculating it in the property. dim = world_dimension() quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) - matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation) - cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage) # Note: Do not put matrix_sequence into the Record. That screws up the vectorization strategy! - ImmutableRecord.__init__(self, - direction=direction, - restriction=restriction, - _quadrature_permutation=quadrature_permutation, - _cost_permutation=cost_permutation, - ) + ImmutableCuttingRecord.__init__(self, + direction=direction, + restriction=restriction, + _quadrature_permutation=quadrature_permutation, + _permuted_matrix_sequence=matrix_sequence, + ) def __repr__(self): - return ImmutableRecord.__repr__(self) + return ImmutableCuttingRecord.__repr__(self) def __str__(self): return repr(self) @@ -387,7 +383,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): this dict to create an interface. """ dict = self.get_copy_kwargs() - del dict['_cost_permutation'] + del dict['_permutated_matrix_sequence'] del dict['_quadrature_permutation'] dict['matrix_sequence'] = None return dict @@ -398,7 +394,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): @property def cost_permutation(self): - return self._cost_permutation + return sumfact_cost_permutation_strategy(self._permuted_matrix_sequence, self.stage) @property def stage(self): diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index da26c7b9e111fdc7d9da6253297e9a9814abe91b..656aa236b67c145b5fa35a6b459f7a0131df01e0 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -1000,16 +1000,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def horizontal_index(self, sf): for i, k in enumerate(self.kernels): - # We need to identify to which part of the vectorized kernel sf - # corresponds. Since splitting might change the cost_permutation we - # exclude it in the comparison below. We also make sure to check - # that derivatives are the same. - from copy import deepcopy - sf_interface = deepcopy(sf.interface) - sf_interface._cost_permutation = None - k_interface = deepcopy(k.interface) - k_interface._cost_permutation = None - if repr(sf_interface) == repr(k_interface): + if sf.interface == k.interface: if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted): return i diff --git a/python/dune/codegen/tools.py b/python/dune/codegen/tools.py index d5c0a18ebe8f7e32316aa459c959ab1eb9ba75f8..a27ee13384d62697bc984a33819b09e9982c4a39 100644 --- a/python/dune/codegen/tools.py +++ b/python/dune/codegen/tools.py @@ -4,6 +4,20 @@ from __future__ import absolute_import import loopy as lp import pymbolic.primitives as prim import frozendict +import pytools + + +class ImmutableCuttingRecord(pytools.ImmutableRecord): + """ + A record implementation that drops fields starting with an underscore + from hash and equality computation + """ + def __hash__(self): + return hash((type(self),) + tuple(getattr(self, field) for field in self.__class__.fields if not field.startswith("_"))) + + def __eq__(self, other): + return type(self) == type(other) and all(getattr(self, field) == getattr(other, field) for field in self.__class__.fields if not field.startswith("_")) + def get_pymbolic_basename(expr):