Skip to content
Snippets Groups Projects
Commit ed0d0683 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Broadcast theta matrices where possible

parent 70589afb
No related branches found
No related tags found
No related merge requests found
......@@ -20,10 +20,12 @@ from dune.perftool.generation import (class_member,
valuearg
)
from dune.perftool.loopy.buffer import get_buffer_temporary
from dune.perftool.loopy.vcl import get_vcl_typename
from dune.perftool.pdelab.localoperator import (name_domain_field,
lop_template_range_field,
)
from dune.perftool.pdelab.quadrature import quadrature_order
from dune.perftool.tools import maybe_wrap_subscript
from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier
from loopy.types import NumpyType
......@@ -32,7 +34,7 @@ from pytools import ImmutableRecord, product
import pymbolic.primitives as prim
import loopy as lp
import numpy
import numpy as np
def ceildiv(a, b):
......@@ -84,7 +86,6 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord):
return self.basis_size
def pymbolic(self, indices):
assert len(indices) == 2
name = "{}{}Theta{}{}_qp{}_dof_{}".format("face{}_".format(self.face) if self.face is not None else "",
"d" if self.derivative else "",
"T" if self.transpose else "",
......@@ -160,6 +161,13 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
def pymbolic(self, indices):
assert len(indices) == 3
# Check whether we can realize this by broadcasting the values of a simple tabulation
if len(set(self.tabs)) == 1:
vcltype = get_vcl_typename(np.float64, vector_width=len(self.tabs))
theta = self.tabs[0].pymbolic(indices[:-1])
return prim.Call(prim.Variable(vcltype), (theta,))
abbrevs = tuple("{}x{}".format("d" if t.derivative else "",
"s{}".format(t.slice_index) if t.slice_size is not None else "")
for t in self.tabs)
......@@ -183,7 +191,7 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
member = loopy_class_member(name,
classtag="operator",
dtype=numpy.float64,
dtype=np.float64,
dim_tags="f,f,vec",
shape=(self.rows, self.cols, self.width),
potentially_vectorized=True,
......@@ -218,7 +226,7 @@ def basis_functions_per_direction():
def define_oned_quadrature_weights(name):
loopy_class_member(name,
dtype=numpy.float64,
dtype=np.float64,
classtag="operator",
shape=(quadrature_points_per_direction(),),
)
......@@ -233,7 +241,7 @@ def name_oned_quadrature_weights():
def define_oned_quadrature_points(name):
loopy_class_member(name,
dtype=numpy.float64,
dtype=np.float64,
classtag="operator",
shape=(quadrature_points_per_direction(),),
)
......@@ -315,7 +323,7 @@ class PolynomialLookup(FunctionIdentifier):
@function_mangler
def polynomial_lookup_mangler(target, func, dtypes):
if isinstance(func, PolynomialLookup):
return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(numpy.int32), NumpyType(numpy.float64)))
return CallMangleInfo(func.name, (NumpyType(np.float64),), (NumpyType(np.int32), NumpyType(np.float64)))
def define_theta(name, tabmat, additional_indices=(), width=None):
......@@ -332,7 +340,7 @@ def define_theta(name, tabmat, additional_indices=(), width=None):
shape = shape + (width,)
loopy_class_member(name,
dtype=numpy.float64,
dtype=np.float64,
shape=shape,
classtag="operator",
dim_tags=dim_tags,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment