Skip to content
Snippets Groups Projects
Commit 4a55e374 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix fastdg vectorization (incl. h/l)

parent c85ca159
No related branches found
No related tags found
No related merge requests found
......@@ -186,10 +186,11 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
else:
rowsize = sum(tuple(s for s in _local_sizes(self.trial_element)))
manual_strides = tuple("stride:{}".format(rowsize * product(shape[:i])) for i in range(len(shape)))
valuearg("jacobian_offset")
offset = "jacobian_offset{}".format(which)
valuearg(offset)
globalarg(direct_output,
shape=shape,
offset=prim.Variable("jacobian_offset") + rowsize * _dof_offset(self.test_element, self.test_element_index) + _dof_offset(self.trial_element, self.trial_element_index),
offset=prim.Variable(offset) + rowsize * _dof_offset(self.test_element, self.test_element_index) + _dof_offset(self.trial_element, self.trial_element_index),
dim_tags=manual_strides,
)
lhs = prim.Subscript(prim.Variable(direct_output), inames)
......@@ -204,10 +205,10 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
def fastdg_args(self):
if get_form_option("fastdg"):
ret = ("{}.data()".format(self.accumvar),)
if get_form_option("fastdg") and sf.within_inames:
element = get_leaf(sf.output.trial_element, sf.output.trial_element_index)
if get_form_option("fastdg") and self.within_inames:
element = get_leaf(self.trial_element, self.trial_element_index)
shape = tuple(element.degree() + 1 for e in range(element.cell().geometric_dimension()))
jacobian_index = flatten_index(tuple(prim.Variable(i) for i in sf.within_inames), shape, order="f")
jacobian_index = flatten_index(tuple(prim.Variable(i) for i in self.within_inames), shape, order="f")
ret = ret + (str(jacobian_index),)
return ret
else:
......
......@@ -37,7 +37,7 @@ from dune.perftool.sumfact.vectorization import attach_vectorization_info
from dune.perftool.sumfact.accumulation import sumfact_iname
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.vcl import ExplicitVCLCast
from dune.perftool.tools import get_leaf
from dune.perftool.tools import get_leaf, remove_duplicates
from pytools import product
from ufl import MixedElement
......@@ -58,14 +58,19 @@ def _name_kernel_implementation_function(sf, qp):
if isinstance(sf, SumfactKernel):
fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index)
if isinstance(sf, VectorizedSumfactKernel):
1/0
fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(sf.input.inputs))
if sf.stage == 3:
if isinstance(sf, SumfactKernel):
fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index)
if sf.output.trial_element:
if sf.within_inames:
fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index)
if isinstance(sf, VectorizedSumfactKernel):
1/0
fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(sf.output.outputs))
if sf.within_inames:
fastdg = "{}x{}".format(fastdg,
"_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(sf.output.outputs))
)
name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg)
necessary_kernel_implementations((sf, qp))
return name
......@@ -347,9 +352,10 @@ def realize_sumfact_kernel_function(sf):
args = ["const char* buffer0", "const char* buffer1"]
if get_form_option('fastdg'):
const = "const " if sf.stage == 1 else ""
args = args + ["{}double* fastdg{}".format(const, i) for i in range(len(get_input_output_tuple(sf)))]
if sf.within_inames:
args = args + ["unsigned int jacobian_offset"]
for i in range(len(get_input_output_tuple(sf))):
args.append("{}double* fastdg{}".format(const, i))
if sf.within_inames:
args.append("unsigned int jacobian_offset{}".format(i))
signature = "void {}({}) const".format(name, ", ".join(args))
kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False)
......
......@@ -10,7 +10,7 @@ from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
from dune.perftool.tools import get_leaf, maybe_wrap_subscript
from dune.perftool.tools import get_leaf, maybe_wrap_subscript, remove_duplicates
from pytools import ImmutableRecord, product
......@@ -81,6 +81,10 @@ class VectorSumfactKernelInput(SumfactKernelInputBase):
# need to load scalars into the SIMD vector.
raise NotImplementedError("SIMD loads from scalars not implemented!")
@property
def fastdg_args(self):
return sum((i.fastdg_args for i in remove_duplicates(self.inputs)), ())
class SumfactKernelOutputBase(object):
@property
......@@ -143,15 +147,20 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
substname = "haddsubst_{}".format("_".join([i.name for i in inames]))
subst_rule(substname, (), result)
result = prim.Call(prim.Variable(substname), ())
transform(lp.precompute, substname, precompute_outer_inames=args["forced_iname_deps"])
transform(lp.precompute, substname)
deps = frozenset()
for o in outputs:
hadd_result = self._add_hadd(o, result)
deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=self.outputs.index(o), **args))
which = tuple(remove_duplicates(self.outputs)).index(o)
deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=which, **args))
return deps
@property
def fastdg_args(self):
return sum((i.fastdg_args for i in remove_duplicates(self.outputs)), ())
class SumfactKernelBase(object):
pass
......@@ -548,7 +557,12 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
#
@property
def function_key(self):
fastdg = self.inout_key if get_form_option("fastdg") else ()
fastdg = ()
if get_form_option("fastdg"):
if self.stage == 1:
fastdg = sum(((i.element, i.element_index) for i in remove_duplicates(self.input.inputs)), ())
if self.stage == 3:
fastdg = sum(((o.test_element, o.test_element_index, o.trial_element, o.trial_element_index) for o in remove_duplicates(self.output.outputs)), ())
return tuple(str(m) for m in self.matrix_sequence) + fastdg
@property
......@@ -775,14 +789,10 @@ def get_input_output_tuple(sf):
if isinstance(sf, SumfactKernel):
return (sf.input,)
if isinstance(sf, VectorizedSumfactKernel):
# This is a short recipe for removing duplicates from an iterable
# while preserving the order!
seen = set()
return tuple(x for x in self.input if not (x in seen or seen.add(x)))
return tuple(remove_duplicates(sf.input.inputs))
if sf.stage == 3:
if isinstance(sf, SumfactKernel):
return (sf.output,)
if isinstance(sf, VectorizedSumfactKernel):
seen = set()
return tuple(x for x in self.output if not (x in seen or seen.add(x)))
return tuple(remove_duplicates(sf.output.outputs))
assert(False)
......@@ -93,3 +93,12 @@ def get_leaf(element, index):
leaf_element = element.extract_component(index)[1]
return leaf_element
def remove_duplicates(iterable):
""" Remove duplicates from an iterable while preserving the order """
seen = set()
for i in iterable:
if i not in seen:
yield i
seen.add(i)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment