From dd0363d9780f7eb3d7c917c27ec358cd8adfe660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Mon, 22 Oct 2018 16:27:53 +0200 Subject: [PATCH] Move quadrature_permutation to interface SumfactKernelInterfaceBase --- python/dune/codegen/sumfact/accumulation.py | 15 +++++++ python/dune/codegen/sumfact/basis.py | 15 +++++++ python/dune/codegen/sumfact/geometry.py | 16 ++++++- python/dune/codegen/sumfact/realization.py | 14 +++--- python/dune/codegen/sumfact/symbolic.py | 48 +++++++++++---------- 5 files changed, 77 insertions(+), 31 deletions(-) diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index 6af42b2f..7059b3a5 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -94,6 +94,16 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): trial_element=None, trial_element_index=None, ): + + # Note: The function sumfact_quadrature_permutation_strategy does not + # work anymore after the visiting process since get_facedir and + # get_facemod are not well defined. But we need the + # quadrature_permutation to generate the name of the sumfact + # kernel. This means we need to store the value here instead of + # recalculating it in the property. + dim = world_dimension() + quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0]) + # TODO: Isnt accumvar superfluous in the presence of all the other infos? ImmutableRecord.__init__(self, accumvar=accumvar, @@ -102,11 +112,16 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): test_element_index=test_element_index, trial_element=trial_element, trial_element_index=trial_element_index, + _quadrature_permutation=quadrature_permutation, ) def __repr__(self): return ImmutableRecord.__repr__(self) + @property + def quadrature_permutation(self): + return self._quadrature_permutation + @property def stage(self): return 3 diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index 0a1ce5b8..a7d2e01e 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -57,11 +57,22 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): element_index=0, restriction=0, ): + + # Note: The function sumfact_quadrature_permutation_strategy does not + # work anymore after the visiting process since get_facedir and + # get_facemod are not well defined. But we need the + # quadrature_permutation to generate the name of the sumfact + # kernel. This means we need to store the value here instead of + # recalculating it in the property. + dim = world_dimension() + quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) + ImmutableRecord.__init__(self, coeff_func=coeff_func, element=element, element_index=element_index, restriction=restriction, + _quadrature_permutation=quadrature_permutation, ) def __repr__(self): @@ -70,6 +81,10 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): def __str__(self): return repr(self) + @property + def quadrature_permutation(self): + return self._quadrature_permutation + @property def stage(self): return 1 diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index 3d7df550..17aa5963 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -67,7 +67,17 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): argument 'direction' specifies the component (x-component: 0, y-component: 1, z-component: 2). """ - ImmutableRecord.__init__(self, direction=direction, restriction=restriction) + + # Note: The function sumfact_quadrature_permutation_strategy does not + # work anymore after the visiting process since get_facedir and + # get_facemod are not well defined. But we need the + # quadrature_permutation to generate the name of the sumfact + # kernel. This means we need to store the value here instead of + # recalculating it in the property. + dim = world_dimension() + quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) + + ImmutableRecord.__init__(self, direction=direction, restriction=restriction, _quadrature_permutation=quadrature_permutation) def __repr__(self): return ImmutableRecord.__repr__(self) @@ -75,6 +85,10 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): def __str__(self): return repr(self) + @property + def quadrature_permutation(self): + return self._quadrature_permutation + @property def stage(self): return 1 diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index ce4393ef..1d951eca 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -192,8 +192,8 @@ def realize_sumfact_kernel_function(sf): input_inames = permute_backward(input_inames, sf.cost_permutation) # And one more for permuted quadrature points, see comment below - inp_shape = permute_backward(inp_shape, sf.quadrature_permutation) - input_inames = permute_backward(input_inames, sf.quadrature_permutation) + inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation) + input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation) input_summand = sf.interface.realize_direct(inp_shape, input_inames) else: @@ -207,10 +207,10 @@ def realize_sumfact_kernel_function(sf): if sf.stage == 1: # In the unstructured case the sf.matrix_sequence_quadrature_permuted could # already be permuted according to - # sf.quadrature_permutation. We also need to reverse this + # sf.interface.quadrature_permutation. We also need to reverse this # permutation to get the input from 0 to d-1. - inp_shape = permute_backward(inp_shape, sf.quadrature_permutation) - input_inames = permute_backward(input_inames, sf.quadrature_permutation) + inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation) + input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation) # Get a temporary that interprets the base storage of the input # as a column-major matrix. In later iteration of the matrix loop @@ -242,7 +242,7 @@ def realize_sumfact_kernel_function(sf): if l == len(matrix_sequence) - 1: output_shape = permute_backward(output_shape, sf.cost_permutation) if sf.stage == 3: - output_shape = permute_backward(output_shape, sf.quadrature_permutation) + output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) out = buffer.get_temporary("buff_step{}_out".format(l), shape=output_shape + vec_shape, @@ -263,7 +263,7 @@ def realize_sumfact_kernel_function(sf): if l == len(matrix_sequence) - 1: output_inames = permute_backward(output_inames, sf.cost_permutation) if sf.stage == 3: - output_inames = permute_backward(output_inames, sf.quadrature_permutation) + output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) # Collect the key word arguments for the loopy instruction insn_args = {"depends_on": insn_dep} diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 91e48eb2..c23d77b6 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -37,6 +37,10 @@ class SumfactKernelInterfaceBase(object): def realize_direct(self, *a, **kw): raise NotImplementedError + @property + def quadrature_permutation(self): + return () + @property def within_inames(self): return () @@ -73,6 +77,13 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): def __repr__(self): return "_".join(repr(i) for i in self.interfaces) + @property + def quadrature_permutation(self): + # TODO: For now we assure that all kerneles have the same quadrature_permutation + for i in self.interfaces: + assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation + return self.interfaces[0].quadrature_permutation + @property def stage(self): return 1 @@ -140,6 +151,13 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): def __repr__(self): return "_".join(repr(o) for o in self.interfaces) + @property + def quadrature_permutation(self): + # TODO: For now we assure that all kerneles have the same quadrature_permutation + for i in self.interfaces: + assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation + return self.interfaces[0].quadrature_permutation + @property def stage(self): return 3 @@ -237,7 +255,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): insn_dep=frozenset(), interface=SumfactKernelInterfaceBase(), predicates=frozenset(), - quadrature_permutation=None, ): """Create a sum factorization kernel @@ -311,27 +328,12 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): for a in SumfactKernel.init_arg_names: defaultdict[a] = eval(a) - dim = len(matrix_sequence) - # Not sure if this whole permuting would make sense if we would do sum # factorized evaluation of intersections where len(matrix_sequence) # would not be equal to world dim. + dim = len(matrix_sequence) assert dim == world_dimension() - # Get restriction for this sum factorization kernel. Note: For - # accumulation output we have a restriction for the test (index 0) and - # ansatz (index 1) space. We need the restriction corresponding to the - # test space since we are in stage 3 - restriction = interface.restriction - if isinstance(restriction, tuple): - assert interface.stage is 3 - assert len(restriction) is 2 - restriction = restriction[0] - - # Store correct quadrature_permutation - quadrature_permuation = sumfact_quadrature_permutation_strategy(dim, restriction) - defaultdict['quadrature_permutation'] = quadrature_permuation - # Call the base class constructors ImmutableRecord.__init__(self, **defaultdict) prim.Variable.__init__(self, "SUMFACT") @@ -371,8 +373,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # different permuation of quadrature points on self and neighbor. Mangle # the permutation of the quadrature points into the name to generate # sperate functions. - if self.quadrature_permutation != tuple(range(len(self.matrix_sequence))): - name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.quadrature_permutation)) + if self.interface.quadrature_permutation != tuple(range(len(self.matrix_sequence))): + name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.interface.quadrature_permutation)) name = name + name_quad_perm return name @@ -383,7 +385,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # TODO: For now we do not vectorize SumfactKernels with different # quadrature_permutation. This should be handled like upper/lower # vectorization - return self.quadrature_permutation + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) + return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames) @property def cache_key(self): @@ -456,7 +458,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): to ensure that quadrature points are visited in the same order on self and neighbor. """ - perm = self.quadrature_permutation + perm = self.interface.quadrature_permutation matrix_sequence_quadrature_permuted = permute_forward(self.matrix_sequence, perm) return matrix_sequence_quadrature_permuted @@ -640,7 +642,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # Assert that quadrature permutation is the same for all kernels for k in kernels: - assert k.quadrature_permutation == kernels[0].quadrature_permutation + assert k.interface.quadrature_permutation == kernels[0].interface.quadrature_permutation # We currently assume that all subkernels are consecutive, 0-based within the vector assert None not in kernels @@ -722,7 +724,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def quadrature_permutation(self): - return self.kernels[0].quadrature_permutation + return self.kernels[0].interface.quadrature_permutation @property def within_inames(self): -- GitLab