Skip to content
Snippets Groups Projects
Commit e8cda866 authored by Dominic Kempf
Browse files

Fix facetheta matrix generation

parent 1c476518
No related branches found
No related tags found
No related merge requests found
......@@ -37,17 +37,18 @@ import numpy
class AMatrix(ImmutableRecord):
def __init__(self, rows, cols, transpose=False, derivative=False):
def __init__(self, rows, cols, transpose=False, derivative=False, face=None):
ImmutableRecord.__init__(self,
rows=rows,
cols=cols,
transpose=transpose,
derivative=derivative,
face=face,
)
@property
def name(self):
return name_theta(self.transpose, self.derivative)
return name_theta(self.transpose, self.derivative, face=self.face)
@property
def vectorized(self):
......@@ -216,7 +217,7 @@ def polynomial_lookup_mangler(target, func, dtypes):
return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(numpy.int32), NumpyType(numpy.float64)))
def define_theta(name, shape, transpose, derivative, additional_indices=()):
def define_theta(name, shape, transpose, derivative, face=None, additional_indices=()):
sort_quadrature_points_weights()
polynomials = name_polynomials()
qp = name_oned_quadrature_points()
......@@ -238,28 +239,38 @@ def define_theta(name, shape, transpose, derivative, additional_indices=()):
i = theta_iname("i", shape[0])
j = theta_iname("j", shape[1])
inames = i, j
if transpose:
args = (prim.Variable(i), prim.Subscript(prim.Variable(qp), (prim.Variable(j),)))
else:
args = (prim.Variable(j), prim.Subscript(prim.Variable(qp), (prim.Variable(i),)))
inames = j, i
args = [prim.Variable(inames[1]), prim.Subscript(prim.Variable(qp), (prim.Variable(inames[0]),))]
if face is not None:
args[1] = face
instruction(assignee=prim.Subscript(prim.Variable(name), (prim.Variable(i), prim.Variable(j)) + additional_indices),
expression=prim.Call(PolynomialLookup(polynomials, derivative), args),
expression=prim.Call(PolynomialLookup(polynomials, derivative), tuple(args)),
kernel="operator",
)
def name_theta(transpose=False, derivative=False):
name = "{}Theta{}".format("d" if derivative else "", "T" if transpose else "")
def name_theta(transpose=False, derivative=False, face=None):
name = "{}{}Theta{}".format("face{}_".format(face) if face is not None else "",
"d" if derivative else "",
"T" if transpose else "",
)
shape = [quadrature_points_per_direction(), basis_functions_per_direction()]
if face is not None:
shape[0] = 1
if transpose:
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
else:
shape = (quadrature_points_per_direction(), basis_functions_per_direction())
define_theta(name, shape, transpose, derivative)
shape = shape[1], shape[0]
shape = tuple(shape)
define_theta(name, shape, transpose, derivative, face=face)
return name
def construct_amatrix_sequence(transpose=False, derivative=None, face=None):
def construct_amatrix_sequence(transpose=False, derivative=None, facedir=None, facemod=None):
dim = world_dimension()
result = [None] * dim
......@@ -267,12 +278,14 @@ def construct_amatrix_sequence(transpose=False, derivative=None, face=None):
rows = quadrature_points_per_direction()
cols = basis_functions_per_direction()
if face == i:
onface = None
if facedir == i:
rows = 1
onface = facemod
if transpose:
rows, cols = cols, rows
result[i] = AMatrix(rows, cols, transpose=transpose, derivative=derivative == i)
result[i] = AMatrix(rows, cols, transpose=transpose, derivative=derivative == i, face=onface)
return tuple(result)
......@@ -66,7 +66,10 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
insn_dep = None
for i in range(dim):
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(derivative=i, face=get_facedir(restriction))
a_matrices = construct_amatrix_sequence(derivative=i,
facedir=get_facedir(restriction),
facemod=get_facemod(restriction),
)
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
......@@ -130,7 +133,8 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
dim = world_dimension()
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(face=get_facedir(restriction))
a_matrices = construct_amatrix_sequence(facedir=get_facedir(restriction),
facemod=get_facemod(restriction),)
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
......
......@@ -43,7 +43,9 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction,
construct_amatrix_sequence,
)
from dune.perftool.sumfact.switch import get_facedir
from dune.perftool.sumfact.switch import (get_facedir,
get_facemod,
)
from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.tools import get_pymbolic_basename
from dune.perftool.error import PerftoolError
......@@ -106,7 +108,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
return
dim = world_dimension()
facedir = get_facedir(accterm.argument.restriction)
# Collect buffers we need
buffers = []
......@@ -128,7 +129,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(transpose=True,
derivative=i if accterm.argument.index else None,
face=facedir,
facedir=get_facedir(accterm.argument.restriction),
facemod=get_facemod(accterm.argument.restriction),
)
# Get the vectorization info. If this happens during the dry run, we get dummies
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment