Skip to content
Snippets Groups Projects
Commit 69158898 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Implement vectorization for poisson

parent 89f0fe86
No related branches found
No related tags found
No related merge requests found
......@@ -66,14 +66,10 @@ def collect_vector_data_rotate(knl, insns, inames):
basename = get_pymbolic_basename(expr)
quantities.setdefault(basename, frozenset())
quantities[basename] = quantities[basename].union(frozenset([expr]))
assert all(len(q) == 1 for q in quantities.values())
# Add vector size buffers for all these quantities
replacemap_arr = {}
replacemap_vec = {}
for quantity in quantities:
expr, = quantities[quantity]
# Check whether there is an instruction that writes this quantity within
# the given inames. If so, we need a buffer array.
iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
......@@ -83,38 +79,57 @@ def collect_vector_data_rotate(knl, insns, inames):
all_writers.extend([i.id for i in write_insns])
if write_insns:
# Determine the shape of the quantity
shape = knl.temporary_variables[quantity].shape
arrname = quantity + '_buffered_arr'
knl = add_temporary_with_vector_view(knl,
arrname,
dtype=np.float64,
shape=(vec_size,),
dim_tags="c",
shape=shape + (vec_size,),
dim_tags=",".join("c" for i in range(len(shape) + 1)),
base_storage=quantity + '_base_storage',
scope=lp.temp_var_scope.PRIVATE,
)
replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),))
def get_quantity_subscripts(e, zero=False):
    """Return the subscript indices of the expression ``e`` as a tuple.

    If ``e`` is a ``prim.Subscript``, its ``index`` is returned, wrapped
    into a one-element tuple when it is not already a tuple.

    For any other expression the result is ``(0,)`` when ``zero`` is true
    and the empty tuple otherwise.
    """
    # Anything that is not a subscript carries no indices of its own.
    if not isinstance(e, prim.Subscript):
        return (0,) if zero else ()

    idx = e.index
    # Normalize a single index to tuple form so callers can concatenate.
    return idx if isinstance(idx, tuple) else (idx,)
for expr in quantities[quantity]:
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),))
for insn in write_insns:
if isinstance(insn, lp.Assignment):
new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),))
new_insns.append(insn.copy(assignee=assignee,
depends_on_is_final=True,
)
)
elif isinstance(insn, lp.CInstruction):
# Rip apart the code and change the assignee
assignee, expression = insn.code.split("=")
assignee = assignee.strip()
assert assignee in replacemap_arr
code = "{} ={}".format(str(replacemap_arr[assignee]), expression)
# TODO This is a bit whacky: It only works for scalar assignees
# OTOH this code is on its way out anyway, because of CInstruction
assignee = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
code = "{} ={}".format(str(assignee), expression)
new_insns.append(insn.copy(code=code,
depends_on_is_final=True,
))
else:
raise NotImplementedError
else:
elif quantity in knl.temporary_variables:
# Add a vector view to this quantity
knl = add_vector_view(knl, quantity)
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
......
......@@ -26,6 +26,8 @@ from dune.perftool.ufl.modified_terminals import Restriction
from pymbolic.primitives import Variable
from pytools import Record
import loopy as lp
def name_form(formdata, data):
# Check whether the formdata has a name in UFL
......@@ -508,6 +510,19 @@ def generate_kernel(integrals):
from dune.perftool.loopy import heuristic_duplication
kernel = heuristic_duplication(kernel)
# Maybe apply vectorization strategies
if get_option("vectorize"):
if get_option("sumfact"):
# Vectorization of the quadrature loop
insns = [i.id for i in lp.find_instructions(kernel, lp.match.Tagged("quadvec"))]
from dune.perftool.sumfact.quadrature import quadrature_inames
inames = quadrature_inames()
from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate
kernel = collect_vector_data_rotate(kernel, insns, inames)
else:
raise NotImplementedError("Only vectorizing sumfactoized code right now!")
# Now add the preambles to the kernel
preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("preamble"))]
kernel = kernel.copy(preambles=preambles)
......
......@@ -41,32 +41,13 @@ def name_sumfact_base_buffer():
@cached
def sumfact_evaluate_coefficient_gradient(element, name, restriction, component):
# First we determine the rank of the tensor we are talking about
# Get a temporary for the gradient
from ufl.functionview import select_subelement
sub_element = select_subelement(element, component)
rank = len(sub_element.value_shape()) + 1
# We do then set some variables accordingly
shape = sub_element.value_shape() + (element.cell().geometric_dimension(),)
shape_impl = ('arr',) * rank
from dune.perftool.pdelab.geometry import dimension_iname
idims = tuple(dimension_iname(count=i) for i in range(rank))
leaf_element = sub_element
from ufl import VectorElement, TensorElement
if isinstance(sub_element, (VectorElement, TensorElement)):
leaf_element = sub_element.sub_elements()[0]
# and proceed to call the necessary generator functions
temporary_variable(name, shape=shape, shape_impl=shape_impl)
from dune.perftool.pdelab.spaces import name_lfs
lfs = name_lfs(element, restriction, component)
from dune.perftool.pdelab.basis import pymbolic_reference_gradient
basis = pymbolic_reference_gradient(leaf_element, restriction, 0, context='trialgrad')
from dune.perftool.tools import get_pymbolic_indices
index, _ = get_pymbolic_indices(basis)
if isinstance(sub_element, (VectorElement, TensorElement)):
lfs = lfs_child(lfs, idims[:-1], shape=shape_as_pymbolic(shape[:-1]), symmetry=element.symmetry())
# Calculate values with sumfactorization
theta = name_theta()
......@@ -111,7 +92,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
expression = Subscript(Variable(buf), tuple(Variable(i) for i in quadrature_inames()))
instruction(assignee=assignee,
expression=expression,
forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)),
forced_iname_deps=frozenset(get_backend("quad_inames")()),
forced_iname_deps_is_final=True,
)
......@@ -204,9 +185,6 @@ def pymbolic_basis(element, restriction, number):
@backend(interface="evaluate_grad")
@cached
def evaluate_reference_gradient(element, name, restriction):
# from dune.perftool.pdelab.basis import name_leaf_lfs
# lfs = name_leaf_lfs(element, restriction)
# from dune.perftool.pdelab.spaces import name_lfs_bound
from dune.perftool.pdelab.geometry import name_dimension
temporary_variable(
name,
......
......@@ -157,6 +157,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
expression=expression,
forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
forced_iname_deps_is_final=True,
tags=frozenset({"quadvec"}),
)
contribution_ids.append(contrib_dep)
......@@ -205,11 +206,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
transform(nest_quadrature_loops, visitor.inames)
# Maybe try to vectorize!
if get_option("vectorize"):
from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate
transform(collect_vector_data_rotate, contribution_ids, quadrature_inames())
def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional_inames=frozenset({})):
"""
......
__name = sumfact_poisson_order1_{__exec_suffix}
__exec_suffix = numdiff, symdiff | expand num
__exec_suffix = {diff_suffix}_{vec_suffix}
diff_suffix = numdiff, symdiff | expand num
vec_suffix = vec, nonvec | expand vec
cells = 8 8
extension = 1. 1.
......@@ -14,3 +17,4 @@ numerical_jacobian = 1, 0 | expand num
exact_solution_expression = g
compare_l2errorsquared = 1e-4
sumfact = 1
vectorize = 1, 0 | expand vec
__name = sumfact_poisson_order2_{__exec_suffix}
__exec_suffix = numdiff, symdiff | expand num
__exec_suffix = {diff_suffix}_{vec_suffix}
diff_suffix = numdiff, symdiff | expand num
vec_suffix = vec, nonvec | expand vec
cells = 8 8
extension = 1. 1.
......@@ -14,3 +17,4 @@ numerical_jacobian = 1, 0 | expand num
exact_solution_expression = g
compare_l2errorsquared = 1e-8
sumfact = 1
vectorize = 1, 0 | expand vec
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment