Skip to content
Snippets Groups Projects
Commit f98117ab authored by Dominic Kempf
Browse files

Also do stage 3

parent 3758798e
No related branches found
No related tags found
No related merge requests found
...@@ -32,17 +32,16 @@ import numpy ...@@ -32,17 +32,16 @@ import numpy
class AMatrix(Record): class AMatrix(Record):
def __init__(self, a_matrix, m, n): def __init__(self, a_matrix, rows, cols):
Record.__init__(self, Record.__init__(self,
a_matrix=a_matrix, a_matrix=a_matrix,
m=m, rows=rows,
n=n, cols=cols,
) )
class ColMajorAccess(FunctionIdentifier): class ColMajorAccess(FunctionIdentifier):
def __init__(self, amatrix): def __init__(self, amatrix):
assert isinstance(amatrix, AMatrix)
self.amatrix = amatrix self.amatrix = amatrix
def __getinitargs__(self): def __getinitargs__(self):
...@@ -50,7 +49,7 @@ class ColMajorAccess(FunctionIdentifier): ...@@ -50,7 +49,7 @@ class ColMajorAccess(FunctionIdentifier):
@property @property
def name(self): def name(self):
return '{}.colmajoraccess'.format(self.amatrix.a_matrix) return '{}.colmajoraccess'.format(self.amatrix)
@function_mangler @function_mangler
...@@ -192,17 +191,19 @@ def sort_quadrature_points_weights(): ...@@ -192,17 +191,19 @@ def sort_quadrature_points_weights():
@constructor_block("operator") @constructor_block("operator")
def construct_theta(name): def construct_theta(name, transpose):
# Make sure that the quadrature points are sorted # Make sure that the quadrature points are sorted
sort_quadrature_points_weights() sort_quadrature_points_weights()
m = name_number_of_quadrature_points_per_direction() if transpose:
n = name_number_of_basis_functions_per_direction() shape = (name_number_of_basis_functions_per_direction(), name_number_of_quadrature_points_per_direction())
else:
shape = (name_number_of_quadrature_points_per_direction(), name_number_of_basis_functions_per_direction())
polynomials = name_polynomials() polynomials = name_polynomials()
qp = name_oned_quadrature_points() qp = name_oned_quadrature_points()
return ["for (std::size_t i=0; i<{}; i++){{".format(m), return ["for (std::size_t i=0; i<{}; i++){{".format(shape[0]),
" for (std::size_t j=0; j<{}; j++){{".format(n), " for (std::size_t j=0; j<{}; j++){{".format(shape[1]),
" {}.colmajoraccess(i,j) = {}.p(j,{}[i]);".format(name, polynomials, qp), " {}.colmajoraccess(i,j) = {}.p(j,{}[i]);".format(name, polynomials, qp),
" }", " }",
"}"] "}"]
...@@ -223,17 +224,25 @@ def type_theta(): ...@@ -223,17 +224,25 @@ def type_theta():
@class_member("operator") @class_member("operator")
def define_theta(name): def define_theta(name, transpose):
theta_type = type_theta() theta_type = type_theta()
number_qp = quadrature_points_per_direction() if transpose:
number_basis = basis_functions_per_direction() shape = (basis_functions_per_direction(), quadrature_points_per_direction())
globalarg(name, shape=(number_basis, number_qp), dtype=numpy.float32, dim_tags="f,f") else:
initializer_list(name, [str(number_qp), str(number_basis)], classtag="operator") shape = (quadrature_points_per_direction(), basis_functions_per_direction())
construct_theta(name) globalarg(name, shape=shape, dtype=numpy.float32, dim_tags="f,f")
initializer_list(name, [str(axis) for axis in shape], classtag="operator")
construct_theta(name, transpose)
return "{} {};".format(theta_type, name) return "{} {};".format(theta_type, name)
def name_theta(): def name_theta():
name = "Theta" name = "Theta"
define_theta(name) define_theta(name, False)
return name
def name_theta_transposed():
name = "ThetaT"
define_theta(name, True)
return name return name
...@@ -10,6 +10,13 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary, ...@@ -10,6 +10,13 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary,
switch_base_storage, switch_base_storage,
) )
from dune.perftool.pdelab.spaces import name_lfs from dune.perftool.pdelab.spaces import name_lfs
from dune.perftool.sumfact.amatrix import (AMatrix,
quadrature_points_per_direction,
basis_functions_per_direction,
name_theta,
name_theta_transposed,
)
from dune.perftool.loopy.stages import stage_insn
from pymbolic.primitives import (Call, from pymbolic.primitives import (Call,
Product, Product,
Subscript, Subscript,
...@@ -34,47 +41,51 @@ def sumfact_iname(bound, _type): ...@@ -34,47 +41,51 @@ def sumfact_iname(bound, _type):
return name return name
def setup_theta(element, container, restriction, component, a_matrices):
number_basis = product(mat.cols for mat in a_matrices)
shape = (number_basis,)
inp = get_buffer_temporary("buffer",
shape=shape)
silenced_warning('read_no_write({})'.format(inp))
# Write initial coefficients into buffer
basisiname = sumfact_iname(number_basis, "basis")
lfs = name_lfs(element, restriction, component)
coeff = pymbolic_coefficient(container, lfs, basisiname)
assignee = Subscript(Variable(inp), (Variable(basisiname),))
return instruction(assignee=assignee,
expression=coeff,
depends_on=frozenset({stage_insn(0)}),
)
# TODO this code is WIP and mainly used for experiments. # TODO this code is WIP and mainly used for experiments.
def start_sumfactorization(element, container, restriction, component): def start_sumfactorization(element, container, restriction, component):
from dune.perftool.sumfact.amatrix import (AMatrix,
quadrature_points_per_direction,
basis_functions_per_direction,
name_theta,
)
theta = name_theta() theta = name_theta()
m = quadrature_points_per_direction() rows = quadrature_points_per_direction()
n = basis_functions_per_direction() cols = basis_functions_per_direction()
a_matrix = AMatrix(theta, m, n) a_matrix = AMatrix(theta, rows, cols)
a_matrices = (a_matrix, a_matrix) a_matrices = (a_matrix, a_matrix)
# Do stage 1
initialize_buffer("buffer", initialize_buffer("buffer",
base_storage_size=product(max(mat.n, mat.m) for mat in a_matrices), base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
num=2 num=2
) )
insn_dep = setup_theta(element, container, restriction, component, a_matrices)
number_basis = product(mat.n for mat in a_matrices) sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
shape = (n,)
inp = get_buffer_temporary("buffer",
shape=shape)
silenced_warning('read_no_write({})'.format(inp))
# Write initial coefficients into buffer # Do stage 3 (for f=u => mass matrix)
basisiname = sumfact_iname(number_basis, "basis") theta_transposed = name_theta_transposed()
lfs = name_lfs(element, restriction, component) a_matrix_transposed = AMatrix(theta_transposed, cols, rows)
coeff = pymbolic_coefficient(container, lfs, basisiname) a_matrices_transposed = (a_matrix_transposed, a_matrix_transposed)
assignee = Subscript(Variable(inp), (Variable(basisiname),))
from dune.perftool.loopy.stages import stage_insn
insn_dep = instruction(assignee = assignee,
expression = coeff,
depends_on = frozenset({stage_insn(0)}),
)
return sum_factorization_kernel(a_matrices, inp, "buffer", insn_dep) return sum_factorization_kernel(a_matrices_transposed, "buffer", 2)
def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})):
""" """
Calculate a sum factorization matrix product. Calculate a sum factorization matrix product.
...@@ -87,15 +98,18 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): ...@@ -87,15 +98,18 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
a_matrices: An iterable of AMatrix instances a_matrices: An iterable of AMatrix instances
The list of tensors to be applied to the input. The list of tensors to be applied to the input.
Order of application is from 0 up. Order of application is from 0 up.
inp: A temporary that contains the input matrix
buffer: A string identifying the flip flop buffer in use buffer: A string identifying the flip flop buffer in use
for intermediate results. for intermediate results. The memory is expected to be
pre-initialized with the input.
insn_dep: an instruction ID that the first issued instruction
should depend upon. All following ones will depend on each
other.
""" """
for l, a_matrix in enumerate(a_matrices): for l, a_matrix in enumerate(a_matrices):
# Get a temporary that interprets the base storage of the input # Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the amatrix loop # as a column-major matrix. In later iteration of the amatrix loop
# this reinterprets the output of the previous iteration. # this reinterprets the output of the previous iteration.
inp_shape = (a_matrix.n, product(mat.m for mat in a_matrices[:l]) * product(mat.n for mat in a_matrices[l + 1:])) inp_shape = (a_matrix.cols, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
inp = get_buffer_temporary(buffer, inp = get_buffer_temporary(buffer,
shape=inp_shape, shape=inp_shape,
dim_tags="f,f") dim_tags="f,f")
...@@ -107,7 +121,7 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): ...@@ -107,7 +121,7 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
# Get a temporary that interprets the base storage of the output # Get a temporary that interprets the base storage of the output
# as row-major matrix # as row-major matrix
out_shape = (a_matrix.m, product(mat.m for mat in a_matrices[:l]) * product(mat.n for mat in a_matrices[l + 1:])) out_shape = (a_matrix.rows, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
out = get_buffer_temporary(buffer, out = get_buffer_temporary(buffer,
shape=out_shape, shape=out_shape,
dim_tags="c,c") dim_tags="c,c")
...@@ -115,21 +129,22 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): ...@@ -115,21 +129,22 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
# Get the inames needed for one matrix-matrix multiplication # Get the inames needed for one matrix-matrix multiplication
i = sumfact_iname(out_shape[0], "row") i = sumfact_iname(out_shape[0], "row")
j = sumfact_iname(out_shape[1], "col") j = sumfact_iname(out_shape[1], "col")
k = sumfact_iname(a_matrix.n, "red") k = sumfact_iname(a_matrix.cols, "red")
# Construct the matrix-matrix-multiplication expression a_ik*in_kj # Construct the matrix-matrix-multiplication expression a_ik*in_kj
from dune.perftool.sumfact.amatrix import ColMajorAccess from dune.perftool.sumfact.amatrix import ColMajorAccess
prod = Product((Call(ColMajorAccess(a_matrix), (Variable(i), Variable(k))), prod = Product((Call(ColMajorAccess(a_matrix.a_matrix), (Variable(i), Variable(k))),
Subscript(Variable(inp), (Variable(k), Variable(j))) Subscript(Variable(inp), (Variable(k), Variable(j)))
)) ))
# Issue the reduction instruction that implements the multiplication # Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on # at the same time store the instruction ID for the next instruction to depend on
insn_dep = instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))), insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
expression=Reduction("sum", k, prod), expression=Reduction("sum", k, prod),
forced_iname_deps=frozenset({i, j}), forced_iname_deps=frozenset({i, j}),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=frozenset({insn_dep}), depends_on=insn_dep.union(frozenset({stage_insn(stage)})),
) )
})
return out return out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment