Skip to content
Snippets Groups Projects
Commit 1bcfe841 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move input shape on the input object

parent 65c7ce35
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ from dune.perftool.pdelab.argument import name_coefficientcontainer
from dune.perftool.pdelab.geometry import (local_dimension,
world_dimension,
)
from dune.perftool.loopy.buffer import initialize_buffer
from dune.perftool.loopy.buffer import initialize_buffer, get_buffer_temporary
from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
from dune.perftool.options import get_option
from dune.perftool.pdelab.driver import FEM_name_mangling
......@@ -96,7 +96,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
coeff = pc(container, lfs, basisiname)
assignee = prim.Subscript(prim.Variable("input_{}".format(sf.buffer)),
# Get the input temporary!
name = get_buffer_temporary(sf.buffer,
shape=(product(mat.basis_size for mat in sf.matrix_sequence), sf.vector_width),
name="input_{}".format(sf.buffer)
)
assignee = prim.Subscript(prim.Variable(name),
(prim.Variable(basisiname),) + (index,))
instruction(assignee=assignee,
expression=coeff,
......
......@@ -61,12 +61,6 @@ def _realize_sum_factorization_kernel(sf):
if sf.stage == 1 and not get_option("fastdg"):
assert sf.input
# Get the input temporary!
input_setup = get_buffer_temporary(sf.buffer,
shape=sf.flat_input_shape,
name="input_{}".format(sf.buffer)
)
if sf.vectorized:
for i, inputsf in enumerate(sf.kernels):
inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
......
......@@ -15,7 +15,9 @@ import inspect
class SumfactKernelInputBase(object):
pass
@property
def flat_shape(self):
return False
class SumfactKernelBase(object):
......@@ -192,12 +194,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
"""
return 0
@property
def flat_input_shape(self):
""" The 'flat' input tensor shape """
assert self.stage == 1
return (product(mat.basis_size for mat in self.matrix_sequence), 1)
@property
def quadrature_shape(self):
""" The shape of a temporary for the quadrature points
......@@ -283,6 +279,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def vertical_width(self):
return 1
@property
def vector_width(self):
return 1
# Extract the argument list and store it on the class. This needs to be done
# outside of the class because the SumfactKernel class object needs to be fully
# initialized in order to extract the information from __init__.
......@@ -470,10 +470,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width)
@property
def flat_input_shape(self):
return (product(mat.basis_size for mat in self.matrix_sequence), self.vector_width)
@property
def quadrature_shape(self):
return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment