Skip to content
Snippets Groups Projects
Commit 0001fb03 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move assembly of input for sf kernels to a dedicated object

parent 7d85651b
No related branches found
No related tags found
No related merge requests found
......@@ -29,12 +29,13 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
from dune.perftool.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.perftool.sumfact.symbolic import SumfactKernel
from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.error import PerftoolError
from dune.perftool.sumfact.quadrature import quadrature_inames
from pytools import ImmutableRecord
import loopy as lp
import numpy as np
......@@ -59,6 +60,10 @@ def accum_iname(restriction, bound, i):
return sumfact_iname(bound, "accum")
class AlreadyAssembledInput(SumfactKernelInputBase, ImmutableRecord):
pass
@backend(interface="accum_insn", name="sumfact")
def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# When doing sumfactorization we want to split the test function
......@@ -126,7 +131,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
preferred_position=indices[-1] if accterm.new_indices else None,
accumvar=accum,
within_inames=jacobian_inames,
coeff_func_index=coeff_func_index,
input=AlreadyAssembledInput(index=coeff_func_index),
)
from dune.perftool.sumfact.vectorization import attach_vectorization_info
......
......@@ -33,19 +33,77 @@ from dune.perftool.pdelab.geometry import (local_dimension,
world_dimension,
)
from dune.perftool.loopy.buffer import initialize_buffer
from dune.perftool.sumfact.symbolic import SumfactKernel
from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
from dune.perftool.options import get_option
from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound, lfs_child, name_leaf_lfs
from dune.perftool.tools import maybe_wrap_subscript
from dune.perftool.pdelab.basis import shape_as_pymbolic
from dune.perftool.sumfact.accumulation import sumfact_iname
from pytools import product
from ufl.functionview import select_subelement
from ufl import VectorElement, TensorElement
from pytools import product, ImmutableRecord
from loopy.match import Writes
import pymbolic.primitives as prim
class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
def __init__(self,
coeff_func=None,
coeff_func_index=None,
element=None,
component=None,
restriction=0,
):
ImmutableRecord.__init__(self,
coeff_func=coeff_func,
coeff_func_index=coeff_func_index,
element=element,
component=component,
restriction=restriction,
)
def realize(self, sf, index, insn_dep):
lfs = name_lfs(self.element, self.restriction, self.component)
sub_element = select_subelement(self.element, self.component)
shape = sub_element.value_shape() + (self.element.cell().geometric_dimension(),)
if isinstance(sub_element, (VectorElement, TensorElement)):
# Could be 0 but shouldn't be None
assert self.coeff_func_index is not None
lfs_pym = lfs_child(lfs,
(self.coeff_func_index,),
shape=shape_as_pymbolic(shape[:-1]),
symmetry=self.element.symmetry())
leaf_element = sub_element
if isinstance(sub_element, (VectorElement, TensorElement)):
leaf_element = sub_element.sub_elements()[0]
lfs = name_leaf_lfs(leaf_element, self.restriction)
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = self.coeff_func(self.restriction)
if isinstance(sub_element, (VectorElement, TensorElement)):
from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
coeff = pc(container, lfs_pym, basisiname)
else:
from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
coeff = pc(container, lfs, basisiname)
assignee = prim.Subscript(prim.Variable("input_{}".format(sf.buffer)),
(prim.Variable(basisiname),) + (index,))
instruction(assignee=assignee,
expression=coeff,
depends_on=sf.insn_dep.union(insn_dep),
tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
)
def name_sumfact_base_buffer():
count = get_counter('sumfact_base_buffer')
name = "buffer_{}".format(str(count))
......@@ -102,14 +160,17 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
if len(indices) == 2:
coeff_func_index = indices[0]
inp = LFSSumfactKernelInput(coeff_func=coeff_func,
coeff_func_index=coeff_func_index,
element=element,
component=component,
restriction=restriction,
)
# The sum factorization kernel object gathering all relevant information
sf = SumfactKernel(matrix_sequence=matrix_sequence,
restriction=restriction,
preferred_position=indices[-1],
coeff_func=coeff_func,
coeff_func_index=coeff_func_index,
element=element,
component=component,
input=inp,
)
from dune.perftool.sumfact.vectorization import attach_vectorization_info
......@@ -156,11 +217,14 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor_in
facemod=get_facemod(restriction),
basis_size=basis_size)
inp = LFSSumfactKernelInput(coeff_func=coeff_func,
element=element,
component=component,
restriction=restriction,
)
sf = SumfactKernel(matrix_sequence=matrix_sequence,
restriction=restriction,
coeff_func=coeff_func,
element=element,
component=component,
input=inp,
)
from dune.perftool.sumfact.vectorization import attach_vectorization_info
......
......@@ -2,9 +2,6 @@
The code that triggers the creation of the necessary code constructs
to realize a sum factorization kernel
"""
from ufl.functionview import select_subelement
from ufl import VectorElement, TensorElement
from dune.perftool.generation import (barrier,
dump_accumulate_timer,
generator_factory,
......@@ -21,7 +18,6 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary,
from dune.perftool.pdelab.argument import pymbolic_coefficient
from dune.perftool.pdelab.basis import shape_as_pymbolic
from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound, lfs_child, name_leaf_lfs
from dune.perftool.options import get_option
from dune.perftool.pdelab.signatures import assembler_routine_name
from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
......@@ -63,7 +59,7 @@ def _realize_sum_factorization_kernel(sf):
# Set up the input for stage 1
if sf.stage == 1 and not get_option("fastdg"):
assert sf.coeff_func
assert sf.input
# Get the input temporary!
input_setup = get_buffer_temporary(sf.buffer,
......@@ -71,53 +67,18 @@ def _realize_sum_factorization_kernel(sf):
name="input_{}".format(sf.buffer)
)
def _write_input(inputsf, index=0):
# Write initial coefficients into buffer
lfs = name_lfs(inputsf.element, inputsf.restriction, inputsf.component)
sub_element = select_subelement(inputsf.element, inputsf.component)
shape = sub_element.value_shape() + (inputsf.element.cell().geometric_dimension(),)
if isinstance(sub_element, (VectorElement, TensorElement)):
# Could be 0 but shouldn't be None
assert inputsf.coeff_func_index is not None
lfs_pym = lfs_child(lfs,
(inputsf.coeff_func_index,),
shape=shape_as_pymbolic(shape[:-1]),
symmetry=inputsf.element.symmetry())
leaf_element = sub_element
if isinstance(sub_element, (VectorElement, TensorElement)):
leaf_element = sub_element.sub_elements()[0]
lfs = name_leaf_lfs(leaf_element, inputsf.restriction)
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = inputsf.coeff_func(inputsf.restriction)
if isinstance(sub_element, (VectorElement, TensorElement)):
coeff = pymbolic_coefficient(container, lfs_pym, basisiname)
else:
coeff = pymbolic_coefficient(container, lfs, basisiname)
assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + (index,))
instruction(assignee=assignee,
expression=coeff,
depends_on=inputsf.insn_dep.union(insn_dep),
tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
)
if sf.vectorized:
for i, inputsf in enumerate(sf.kernels):
_write_input(inputsf, i)
inputsf.input.realize(inputsf, i, insn_dep)
else:
_write_input(sf)
sf.input.realize(sf, 0, insn_dep)
insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
# Construct the direct_input for the FastDG case
direct_input = None
if get_option('fastdg') and sf.stage == 1:
direct_input = sf.coeff_func(sf.restriction)
direct_input = sf.input.coeff_func(sf.input.restriction)
direct_output = None
if get_option('fastdg') and sf.stage == 3:
......
......@@ -14,6 +14,10 @@ import frozendict
import inspect
class SumfactKernelInputBase(object):
pass
class SumfactKernelBase(object):
pass
......@@ -24,13 +28,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
buffer=None,
stage=1,
preferred_position=None,
restriction=0,
restriction=None,
within_inames=(),
insn_dep=frozenset(),
coeff_func=None,
coeff_func_index=None,
element=None,
component=None,
input=None,
accumvar=None,
):
"""Create a sum factorization kernel
......@@ -101,6 +102,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
element: The UFL element
component: The treepath to the correct component of above element
accumvar: The accumulation variable to accumulate into
input: An SumfactKernelInputBase instance describing the input of the kernel
"""
# Assert the inputs!
assert isinstance(matrix_sequence, tuple)
......@@ -112,6 +114,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
if preferred_position is not None:
assert isinstance(preferred_position, int)
if stage == 1:
assert isinstance(input, SumfactKernelInputBase)
restriction = input.restriction
if stage == 3:
assert isinstance(restriction, tuple)
......@@ -161,7 +167,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
work on the same input coefficient (and are suitable for simultaneous
treatment because of that)
"""
return (self.restriction, self.stage, self.coeff_func, self.coeff_func_index, self.element, self.component, self.accumvar)
return (self.input, self.restriction, self.accumvar)
#
# Some convenience methods to extract information about the sum factorization kernel
......@@ -380,24 +386,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
return self.kernels[0].within_inames
@property
def coeff_func(self):
assert len(set(k.coeff_func for k in self.kernels)) == 1
return self.kernels[0].coeff_func
@property
def coeff_func_index(self):
assert len(set(k.coeff_func_index for k in self.kernels)) == 1
return self.kernels[0].coeff_func_index
@property
def element(self):
assert len(set(k.element for k in self.kernels)) == 1
return self.kernels[0].element
@property
def component(self):
assert len(set(k.component for k in self.kernels)) == 1
return self.kernels[0].component
def input(self):
assert len(set(k.input for k in self.kernels)) == 1
return self.kernels[0].input
@property
def accumvar(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment