Skip to content
Snippets Groups Projects
Commit 327658d5 authored by René Heß's avatar René Heß
Browse files

Add permutation methods to Interface classes

Note: They are not yet used but in the long term the permutation should be
handled here since it is about input/output setup.
parent db2e2977
No related branches found
No related tags found
No related merge requests found
...@@ -30,7 +30,10 @@ from dune.codegen.pdelab.restriction import restricted_name ...@@ -30,7 +30,10 @@ from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.pdelab.signatures import assembler_routine_name from dune.codegen.pdelab.signatures import assembler_routine_name
from dune.codegen.pdelab.geometry import world_dimension from dune.codegen.pdelab.geometry import world_dimension
from dune.codegen.pdelab.spaces import name_lfs from dune.codegen.pdelab.spaces import name_lfs
from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy from dune.codegen.sumfact.permutation import (permute_forward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence, construct_basis_matrix_sequence,
) )
...@@ -88,6 +91,7 @@ def accum_iname(element, bound, i): ...@@ -88,6 +91,7 @@ def accum_iname(element, bound, i):
class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
def __init__(self, def __init__(self,
matrix_sequence,
accumvar=None, accumvar=None,
restriction=None, restriction=None,
test_element=None, test_element=None,
...@@ -105,6 +109,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -105,6 +109,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0]) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0])
# Calculate cost optimal permutation
matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
# TODO: Isnt accumvar superfluous in the presence of all the other infos? # TODO: Isnt accumvar superfluous in the presence of all the other infos?
ImmutableRecord.__init__(self, ImmutableRecord.__init__(self,
accumvar=accumvar, accumvar=accumvar,
...@@ -114,6 +122,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -114,6 +122,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
trial_element=trial_element, trial_element=trial_element,
trial_element_index=trial_element_index, trial_element_index=trial_element_index,
_quadrature_permutation=quadrature_permutation, _quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation,
) )
def __repr__(self): def __repr__(self):
...@@ -123,6 +132,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -123,6 +132,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
def quadrature_permutation(self): def quadrature_permutation(self):
return self._quadrature_permutation return self._quadrature_permutation
@property
def cost_permutation(self):
return self._cost_permutation
@property @property
def stage(self): def stage(self):
return 3 return 3
...@@ -457,7 +470,8 @@ def generate_accumulation_instruction(expr, visitor): ...@@ -457,7 +470,8 @@ def generate_accumulation_instruction(expr, visitor):
if priority is None: if priority is None:
priority = 3 priority = 3
output = AccumulationOutput(accumvar=accumvar, output = AccumulationOutput(matrix_sequence,
accumvar=accumvar,
restriction=(test_info.restriction, trial_info.restriction), restriction=(test_info.restriction, trial_info.restriction),
test_element=test_info.element, test_element=test_info.element,
test_element_index=test_info.element_index, test_element_index=test_info.element_index,
......
...@@ -24,7 +24,10 @@ from dune.codegen.sumfact.tabulation import (basis_functions_per_direction, ...@@ -24,7 +24,10 @@ from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
name_polynomials, name_polynomials,
polynomial_degree, polynomial_degree,
) )
from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy from dune.codegen.sumfact.permutation import (permute_forward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
from dune.codegen.sumfact.quadrature import quadrature_inames from dune.codegen.sumfact.quadrature import quadrature_inames
from dune.codegen.sumfact.switch import (get_facedir, from dune.codegen.sumfact.switch import (get_facedir,
get_facemod, get_facemod,
...@@ -53,6 +56,7 @@ import pymbolic.primitives as prim ...@@ -53,6 +56,7 @@ import pymbolic.primitives as prim
class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
def __init__(self, def __init__(self,
matrix_sequence,
coeff_func=None, coeff_func=None,
element=None, element=None,
element_index=0, element_index=0,
...@@ -68,12 +72,16 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -68,12 +72,16 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
ImmutableRecord.__init__(self, ImmutableRecord.__init__(self,
coeff_func=coeff_func, coeff_func=coeff_func,
element=element, element=element,
element_index=element_index, element_index=element_index,
restriction=restriction, restriction=restriction,
_quadrature_permutation=quadrature_permutation, _quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation,
) )
def __repr__(self): def __repr__(self):
...@@ -86,6 +94,10 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -86,6 +94,10 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
def quadrature_permutation(self): def quadrature_permutation(self):
return self._quadrature_permutation return self._quadrature_permutation
@property
def cost_permutation(self):
return self._cost_permutation
@property @property
def stage(self): def stage(self):
return 1 return 1
...@@ -196,7 +208,8 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit ...@@ -196,7 +208,8 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
basis_size=basis_size, basis_size=basis_size,
) )
inp = LFSSumfactKernelInput(coeff_func=coeff_func, inp = LFSSumfactKernelInput(matrix_sequence,
coeff_func=coeff_func,
element=element, element=element,
element_index=index, element_index=index,
restriction=restriction, restriction=restriction,
...@@ -239,7 +252,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor): ...@@ -239,7 +252,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
facemod=get_facemod(restriction), facemod=get_facemod(restriction),
basis_size=basis_size) basis_size=basis_size)
inp = LFSSumfactKernelInput(coeff_func=coeff_func, inp = LFSSumfactKernelInput(matrix_sequence,
coeff_func=coeff_func,
element=element, element=element,
element_index=index, element_index=index,
restriction=restriction, restriction=restriction,
......
...@@ -28,7 +28,10 @@ from dune.codegen.pdelab.localoperator import (name_ansatz_gfs_constructor_param ...@@ -28,7 +28,10 @@ from dune.codegen.pdelab.localoperator import (name_ansatz_gfs_constructor_param
from dune.codegen.pdelab.restriction import restricted_name from dune.codegen.pdelab.restriction import restricted_name
from dune.codegen.sumfact.accumulation import basis_sf_kernels from dune.codegen.sumfact.accumulation import basis_sf_kernels
from dune.codegen.sumfact.basis import construct_basis_matrix_sequence from dune.codegen.sumfact.basis import construct_basis_matrix_sequence
from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy from dune.codegen.sumfact.permutation import (permute_forward,
sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy,
)
from dune.codegen.sumfact.quadrature import (additional_inames, from dune.codegen.sumfact.quadrature import (additional_inames,
default_quadrature_inames) default_quadrature_inames)
from dune.codegen.sumfact.realization import (name_buffer_storage, from dune.codegen.sumfact.realization import (name_buffer_storage,
...@@ -57,7 +60,7 @@ def global_corner_iname(restriction): ...@@ -57,7 +60,7 @@ def global_corner_iname(restriction):
class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
def __init__(self, direction, restriction): def __init__(self, matrix_sequence, direction, restriction):
"""Base class for sum-factorized evaluation of geometry mappings """Base class for sum-factorized evaluation of geometry mappings
At the moment we only do this for cells and not faces. For At the moment we only do this for cells and not faces. For
...@@ -78,7 +81,15 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -78,7 +81,15 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
dim = world_dimension() dim = world_dimension()
quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction) quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
ImmutableRecord.__init__(self, direction=direction, restriction=restriction, _quadrature_permutation=quadrature_permutation) matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
ImmutableRecord.__init__(self,
direction=direction,
restriction=restriction,
_quadrature_permutation=quadrature_permutation,
_cost_permutation=cost_permutation,
)
def __repr__(self): def __repr__(self):
return ImmutableRecord.__repr__(self) return ImmutableRecord.__repr__(self)
...@@ -90,6 +101,10 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -90,6 +101,10 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
def quadrature_permutation(self): def quadrature_permutation(self):
return self._quadrature_permutation return self._quadrature_permutation
@property
def cost_permutation(self):
return self._cost_permutation
@property @property
def stage(self): def stage(self):
return 1 return 1
...@@ -160,7 +175,7 @@ def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor): ...@@ -160,7 +175,7 @@ def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction), matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction),
facemod=get_facemod(restriction), facemod=get_facemod(restriction),
basis_size=(2,) * world_dimension()) basis_size=(2,) * world_dimension())
inp = GeoCornersInput(visitor.indices[0], restriction) inp = GeoCornersInput(matrix_sequence, visitor.indices[0], restriction)
sf = SumfactKernel(matrix_sequence=matrix_sequence, sf = SumfactKernel(matrix_sequence=matrix_sequence,
interface=inp, interface=inp,
) )
...@@ -537,7 +552,7 @@ def _name_jacobian(i, j, restriction, visitor): ...@@ -537,7 +552,7 @@ def _name_jacobian(i, j, restriction, visitor):
basis_size=(2,) * world_dimension()) basis_size=(2,) * world_dimension())
# Sum factorization input for the i'th component of the geometry mapping # Sum factorization input for the i'th component of the geometry mapping
inp = GeoCornersInput(i, restriction) inp = GeoCornersInput(matrix_sequence, i, restriction)
sf = SumfactKernel(matrix_sequence=matrix_sequence, sf = SumfactKernel(matrix_sequence=matrix_sequence,
interface=inp, interface=inp,
......
...@@ -45,16 +45,12 @@ def flop_cost(matrix_sequence): ...@@ -45,16 +45,12 @@ def flop_cost(matrix_sequence):
return 2 * cost return 2 * cost
def sumfact_cost_permutation_strategy(sf): def sumfact_cost_permutation_strategy(matrix_sequence, stage):
"""Choose permutation of the matrix sequence based on computational cost """Choose permutation of the matrix sequence based on computational cost
Note: If there are multiple permutations with the same cost a Note: If there are multiple permutations with the same cost a
heuristic is used to pick one. heuristic is used to pick one.
""" """
# Extract information from the SumfactKernel object
matrix_sequence = sf.matrix_sequence_quadrature_permuted
stage = sf.stage
# Combine permutation and matrix_sequence # Combine permutation and matrix_sequence
perm = [i for i, _ in enumerate(matrix_sequence)] perm = [i for i, _ in enumerate(matrix_sequence)]
perm_matrix_sequence = zip(perm, matrix_sequence) perm_matrix_sequence = zip(perm, matrix_sequence)
......
...@@ -7,6 +7,7 @@ from dune.codegen.generation import (get_counted_variable, ...@@ -7,6 +7,7 @@ from dune.codegen.generation import (get_counted_variable,
) )
from dune.codegen.pdelab.geometry import local_dimension, world_dimension from dune.codegen.pdelab.geometry import local_dimension, world_dimension
from dune.codegen.sumfact.permutation import (flop_cost, from dune.codegen.sumfact.permutation import (flop_cost,
permute_backward,
permute_forward, permute_forward,
sumfact_cost_permutation_strategy, sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy, sumfact_quadrature_permutation_strategy,
...@@ -41,6 +42,34 @@ class SumfactKernelInterfaceBase(object): ...@@ -41,6 +42,34 @@ class SumfactKernelInterfaceBase(object):
def quadrature_permutation(self): def quadrature_permutation(self):
return () return ()
@property
def cost_permutation(self):
return ()
@property
def combined_permutation(self):
return permute_forward(self.quadrature_permutation, self.cost_permutation)
def permute_backward_cost(self, shape, inames):
shape = permute_backward(shape, self.cost_permutation)
inames = permute_backward(inames, self.cost_permutation)
return shape, inames
def permute_backward_quadrature(self, shape, inames):
shape = permute_backward(shape, self.quadrature_permutation)
inames = permute_backward(inames, self.quadrature_permutation)
return shape, inames
def permute_forward_cost(self, shape, inames):
shape = permute_forward(shape, self.cost_permutation)
inames = permute_forward(inames, self.cost_permutation)
return shape_inames
def permute_forward_quadrature(self, shape, inames):
shape = permute_forward(shape, self.quadrature_permutation)
inames = permute_forward(inames, self.quadrature_permutation)
return shape_inames
@property @property
def within_inames(self): def within_inames(self):
return () return ()
...@@ -474,13 +503,13 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -474,13 +503,13 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
Rule of thumb: small m's early and large n's late. Rule of thumb: small m's early and large n's late.
""" """
perm = sumfact_cost_permutation_strategy(self) perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm) matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm)
return matrix_sequence_cost_permuted return matrix_sequence_cost_permuted
@property @property
def cost_permutation(self): def cost_permutation(self):
return sumfact_cost_permutation_strategy(self) return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
@property @property
def quadrature_shape(self): def quadrature_shape(self):
...@@ -713,13 +742,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -713,13 +742,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def matrix_sequence_cost_permuted(self): def matrix_sequence_cost_permuted(self):
perm = sumfact_cost_permutation_strategy(self) perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm) matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm)
return matrix_sequence_cost_permuted return matrix_sequence_cost_permuted
@property @property
def cost_permutation(self): def cost_permutation(self):
return sumfact_cost_permutation_strategy(self) return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
@property @property
def stage(self): def stage(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment