diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 37e1163086efecb056cd90df3f53ef9b8c9e177c..de9c30aea7dd6a08a35dc702aa3c91442209596f 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -32,7 +32,7 @@ from dune.perftool.pdelab.argument import name_coefficientcontainer from dune.perftool.pdelab.geometry import (local_dimension, world_dimension, ) -from dune.perftool.loopy.buffer import initialize_buffer +from dune.perftool.loopy.buffer import initialize_buffer, get_buffer_temporary from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase from dune.perftool.options import get_option from dune.perftool.pdelab.driver import FEM_name_mangling @@ -96,7 +96,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): from dune.perftool.pdelab.argument import pymbolic_coefficient as pc coeff = pc(container, lfs, basisiname) - assignee = prim.Subscript(prim.Variable("input_{}".format(sf.buffer)), + # Get the input temporary! + name = get_buffer_temporary(sf.buffer, + shape=(product(mat.basis_size for mat in sf.matrix_sequence), sf.vector_width), + name="input_{}".format(sf.buffer) + ) + + assignee = prim.Subscript(prim.Variable(name), (prim.Variable(basisiname),) + (index,)) instruction(assignee=assignee, expression=coeff, diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 1bd5ea3fab0c579b091e9c70b820587dd5741456..f939feaeb6fc8da05828e88ef25fece71ecf9408 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -61,12 +61,6 @@ def _realize_sum_factorization_kernel(sf): if sf.stage == 1 and not get_option("fastdg"): assert sf.input - # Get the input temporary! - input_setup = get_buffer_temporary(sf.buffer, - shape=sf.flat_input_shape, - name="input_{}".format(sf.buffer) - ) - if sf.vectorized: for i, inputsf in enumerate(sf.kernels): inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep)) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 88f884c80a54ff82aeeb184b636753becfae4dce..88e82c4aff265f0221608a9afaf18ef5a59abc62 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -15,7 +15,9 @@ import inspect class SumfactKernelInputBase(object): - pass + @property + def flat_shape(self): + return False class SumfactKernelBase(object): @@ -192,12 +194,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): """ return 0 - @property - def flat_input_shape(self): - """ The 'flat' input tensor shape """ - assert self.stage == 1 - return (product(mat.basis_size for mat in self.matrix_sequence), 1) - @property def quadrature_shape(self): """ The shape of a temporary for the quadrature points @@ -283,6 +279,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def vertical_width(self): return 1 + @property + def vector_width(self): + return 1 + # Extract the argument list and store it on the class. This needs to be done # outside of the class because the SumfactKernel class object needs to be fully # initialized in order to extract the information from __init__. @@ -470,10 +470,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width) - @property - def flat_input_shape(self): - return (product(mat.basis_size for mat in self.matrix_sequence), self.vector_width) - @property def quadrature_shape(self): return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)