Skip to content
Snippets Groups Projects
Commit 0eb8cbba authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Checkpoint!

parent 6637071e
No related branches found
No related tags found
No related merge requests found
...@@ -42,6 +42,7 @@ class FlipFlopBuffer(object): ...@@ -42,6 +42,7 @@ class FlipFlopBuffer(object):
@kernel_cached @kernel_cached
def initialize_buffer(identifier): def initialize_buffer(identifier):
assert isinstance(identifier, str)
return FlipFlopBuffer(identifier) return FlipFlopBuffer(identifier)
......
...@@ -28,6 +28,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable): ...@@ -28,6 +28,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
padding=frozenset(), padding=frozenset(),
index=None, index=None,
insn_dep=frozenset(), insn_dep=frozenset(),
coeff_func=None,
element=None,
component=None,
): ):
# Check the input and apply defaults where necessary # Check the input and apply defaults where necessary
assert isinstance(a_matrices, tuple) assert isinstance(a_matrices, tuple)
...@@ -39,8 +42,8 @@ class SumfactKernel(ImmutableRecord, prim.Variable): ...@@ -39,8 +42,8 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
if preferred_position is not None: if preferred_position is not None:
assert isinstance(preferred_position, int) assert isinstance(preferred_position, int)
if not isinstance(restriction, tuple): if stage == 3:
restriction = (restriction, 0) assert isinstance(restriction, tuple)
assert isinstance(within_inames, tuple) assert isinstance(within_inames, tuple)
...@@ -57,6 +60,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable): ...@@ -57,6 +60,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
padding=padding, padding=padding,
index=index, index=index,
insn_dep=insn_dep, insn_dep=insn_dep,
coeff_func=coeff_func,
element=element,
component=component,
) )
prim.Variable.__init__(self, "SUMFACT") prim.Variable.__init__(self, "SUMFACT")
...@@ -65,12 +71,12 @@ class SumfactKernel(ImmutableRecord, prim.Variable): ...@@ -65,12 +71,12 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
# The methods/fields needed to get a well-formed pymbolic node # The methods/fields needed to get a well-formed pymbolic node
# #
def __getinitargs__(self): def __getinitargs__(self):
return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index, self.insn_dep) return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index, self.insn_dep, self.coeff_func, self.element, self.component)
def stringifier(self): def stringifier(self):
return lp.symbolic.StringifyMapper return lp.symbolic.StringifyMapper
init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index", "insn_dep") init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index", "insn_dep", "coeff_func", "element", "component")
mapper_method = "map_sumfact_kernel" mapper_method = "map_sumfact_kernel"
...@@ -94,6 +100,15 @@ class SumfactKernel(ImmutableRecord, prim.Variable): ...@@ -94,6 +100,15 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
""" """
return hash((self.a_matrices, self.restriction, self.stage, self.buffer)) return hash((self.a_matrices, self.restriction, self.stage, self.buffer))
@property
def flat_input_shape(self):
""" The 'flat' input tensor shape """
from pytools import product
shape = (product(mat.cols for mat in self.a_matrices),)
if self.vectorized:
shape = shape + (4,)
return shape
class FusedMultiplyAdd(prim.Expression): class FusedMultiplyAdd(prim.Expression):
""" Represents an FMA operation """ """ Represents an FMA operation """
......
...@@ -24,7 +24,6 @@ from dune.perftool.sumfact.sumfact import (get_facedir, ...@@ -24,7 +24,6 @@ from dune.perftool.sumfact.sumfact import (get_facedir,
setup_theta, setup_theta,
SumfactKernel, SumfactKernel,
sumfact_iname, sumfact_iname,
sum_factorization_kernel,
) )
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.switch import (get_facedir, from dune.perftool.sumfact.switch import (get_facedir,
...@@ -78,6 +77,9 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v ...@@ -78,6 +77,9 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
sf = SumfactKernel(a_matrices=a_matrices, sf = SumfactKernel(a_matrices=a_matrices,
restriction=restriction, restriction=restriction,
preferred_position=i, preferred_position=i,
coeff_func=coeff_func,
element=element,
component=component,
) )
from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.vectorization import attach_vectorization_info
...@@ -91,19 +93,19 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v ...@@ -91,19 +93,19 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
index = sf.index index = sf.index
padding = sf.padding padding = sf.padding
if buf is None: # if buf is None:
buf = get_counted_variable("buffer") # buf = get_counted_variable("buffer")
if inp is None: # if inp is None:
inp = get_counted_variable("input") # inp = get_counted_variable("input")
#
# Initialize the buffer for the sum fact kernel # # Initialize the buffer for the sum fact kernel
shape = (product(mat.cols for mat in a_matrices),) # shape = (product(mat.cols for mat in a_matrices),)
if index is not None: # if index is not None:
shape = shape + (4,) # shape = shape + (4,)
inp = initialize_buffer(buf).get_temporary(shape=shape, # inp = initialize_buffer(buf).get_temporary(shape=shape,
name=inp, # name=inp,
) # )
insn_dep = frozenset({Writes(inp)}) # insn_dep = frozenset({Writes(inp)})
if get_option('fastdg'): if get_option('fastdg'):
# Name of direct input, shape and globalarg is set in sum_factorization_kernel # Name of direct input, shape and globalarg is set in sum_factorization_kernel
...@@ -111,29 +113,16 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v ...@@ -111,29 +113,16 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
else: else:
direct_input = None direct_input = None
# Setup the input! # Setup the input!
setup_theta(inp, element, restriction, component, index, coeff_func) #setup_theta(inp, element, restriction, component, index, coeff_func)
# Add a sum factorization kernel that implements the # Add a sum factorization kernel that implements the
# evaluation of the gradients of basis functions at quadrature # evaluation of the gradients of basis functions at quadrature
# points (stage 1) # points (stage 1)
if not get_global_context_value("dry_run", False): from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel var, insn_dep = realize_sum_factorization_kernel(sf,
var, insn_dep = realize_sum_factorization_kernel(sf,
insn_dep=insn_dep,
outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
direct_input=direct_input, direct_input=direct_input,
) )
# var, insn_dep = sum_factorization_kernel(a_matrices,
# buf,
# 1,
# preferred_position=i,
# insn_dep=insn_dep,
# restriction=restriction,
# outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
# direct_input=direct_input,
# )
else:
var = sf
buffers.append(var) buffers.append(var)
...@@ -171,6 +160,9 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor): ...@@ -171,6 +160,9 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
sf = SumfactKernel(a_matrices=a_matrices, sf = SumfactKernel(a_matrices=a_matrices,
restriction=restriction, restriction=restriction,
coeff_func=coeff_func,
element=element,
component=component,
) )
# TODO: Move this away! # TODO: Move this away!
...@@ -184,19 +176,19 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor): ...@@ -184,19 +176,19 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
inp = sf.input inp = sf.input
index = sf.index index = sf.index
padding = sf.padding padding = sf.padding
#
if buf is None: # if buf is None:
buf = get_counted_variable("buffer") # buf = get_counted_variable("buffer")
if inp is None: # if inp is None:
inp = get_counted_variable("input") # inp = get_counted_variable("input")
#
# Flip flop buffers for sumfactorization # # Flip flop buffers for sumfactorization
shape = (product(mat.cols for mat in a_matrices),) # shape = (product(mat.cols for mat in a_matrices),)
if index is not None: # if index is not None:
shape = shape + (4,) # shape = shape + (4,)
initialize_buffer(buf).get_temporary(shape=shape, # initialize_buffer(buf).get_temporary(shape=shape,
name=inp, # name=inp,
) # )
if get_option('fastdg'): if get_option('fastdg'):
# Name of direct input, shape and globalarg is set in sum_factorization_kernel # Name of direct input, shape and globalarg is set in sum_factorization_kernel
...@@ -204,14 +196,14 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor): ...@@ -204,14 +196,14 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
else: else:
direct_input = None direct_input = None
# Setup the input! # Setup the input!
setup_theta(inp, element, restriction, component, index, coeff_func) # setup_theta(inp, element, restriction, component, index, coeff_func)
# Add a sum factorization kernel that implements the evaluation of # Add a sum factorization kernel that implements the evaluation of
# the basis functions at quadrature points (stage 1) # the basis functions at quadrature points (stage 1)
if not get_global_context_value("dry_run", False): if not get_global_context_value("dry_run", False):
from dune.perftool.sumfact.realization import realize_sum_factorization_kernel from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
var, _ = realize_sum_factorization_kernel(sf, var, _ = realize_sum_factorization_kernel(sf,
insn_dep=frozenset({Writes(inp)}), # insn_dep=frozenset({Writes(inp)}),
outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
direct_input=direct_input, direct_input=direct_input,
) )
......
...@@ -16,7 +16,9 @@ from dune.perftool.generation import (barrier, ...@@ -16,7 +16,9 @@ from dune.perftool.generation import (barrier,
from dune.perftool.loopy.buffer import (get_buffer_temporary, from dune.perftool.loopy.buffer import (get_buffer_temporary,
switch_base_storage, switch_base_storage,
) )
from dune.perftool.pdelab.argument import pymbolic_coefficient
from dune.perftool.pdelab.geometry import world_dimension from dune.perftool.pdelab.geometry import world_dimension
from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.pdelab.signatures import assembler_routine_name from dune.perftool.pdelab.signatures import assembler_routine_name
from dune.perftool.sumfact.permutation import (_sf_permutation_strategy, from dune.perftool.sumfact.permutation import (_sf_permutation_strategy,
...@@ -40,20 +42,40 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di ...@@ -40,20 +42,40 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di
insn_dep = frozenset({insn_dep}) insn_dep = frozenset({insn_dep})
assert isinstance(insn_dep, frozenset) assert isinstance(insn_dep, frozenset)
# Get the vectorization info. During dry run, this is a now op
# sf = attach_vectorization_info(sf)
if get_global_context_value("dry_run", False): if get_global_context_value("dry_run", False):
# During the dry run, we just return the kernel as passed into this # During the dry run, we just return the kernel as passed into this
# function. After the dry run, it can be used to attach information # function. After the dry run, it can be used to attach information
# about vectorization. # about vectorization.
return sf, insn_dep return sf, insn_dep
# else:
# # This is the second run: Retrieve the vectorization information
# # attached in dune.perftool.sumfact.vectorization
# sf = attach_vectorization_info(sf)
# Get the instruction dependencies of the sumfact kernel. This variable will be # Get the instruction dependencies of the sumfact kernel. This variable will be
# updated throughout this function. # updated throughout this function.
insn_dep = insn_dep.union(sf.insn_dep) insn_dep = insn_dep.union(sf.insn_dep)
# Define some helper constructs that make it easier to write generic code later
vecindex = () if sf.index is None else (sf.index,)
# Set up the input for stage 1
if sf.stage == 1 and not get_option("fastdg"):
assert sf.coeff_func
# Get the input temporary!
input_setup = get_buffer_temporary(sf.buffer,
shape=sf.flat_input_shape,
)
# Write initial coefficients into buffer
lfs = name_lfs(sf.element, sf.restriction, sf.component)
basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
container = sf.coeff_func(sf.restriction)
coeff = pymbolic_coefficient(container, lfs, basisiname)
assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + vecindex)
insn_dep = instruction(assignee=assignee,
expression=coeff,
)
# Prepare some dim_tags/shapes for later use # Prepare some dim_tags/shapes for later use
ftags = ",".join(["f"] * sf.length) ftags = ",".join(["f"] * sf.length)
novec_ftags = ftags novec_ftags = ftags
......
...@@ -184,7 +184,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -184,7 +184,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
index = () index = ()
vectag = frozenset() vectag = frozenset()
base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
temp = initialize_buffer(buf).get_temporary(shape=shape, temp = initialize_buffer(buf).get_temporary(shape=shape,
dim_tags=dim_tags, dim_tags=dim_tags,
name=inp, name=inp,
...@@ -356,19 +355,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -356,19 +355,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep = emit_sumfact_kernel(None, restriction, insn_dep) insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
def sum_factorization_kernel(a_matrices,
buf,
stage,
insn_dep=frozenset({}),
additional_inames=frozenset({}),
preferred_position=None,
outshape=None,
restriction=0,
direct_input=None,
direct_output=None,
visitor=None,
):
"""Create a sum factorization kernel """Create a sum factorization kernel
Sum factorization can be written as Sum factorization can be written as
...@@ -430,214 +416,4 @@ def sum_factorization_kernel(a_matrices, ...@@ -430,214 +416,4 @@ def sum_factorization_kernel(a_matrices,
restriction: Restriction for faces values. restriction: Restriction for faces values.
direct_input: Global data structure containing input for direct_input: Global data structure containing input for
sumfactorization (e.g. when using FastDGGridOperator). sumfactorization (e.g. when using FastDGGridOperator).
""" """
# Return a pymbolic SumfactKernel node in the dry run. This will \ No newline at end of file
# be used to decide on an appropriate vectorization strategy
# before we do the real thing.
if get_global_context_value("dry_run", False):
return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset()
ftags = ",".join(["f"] * len(a_matrices))
novec_ftags = ftags
ctags = ",".join(["c"] * len(a_matrices))
vec_shape = ()
if next(iter(a_matrices)).vectorized:
ftags = ftags + ",vec"
ctags = ctags + ",vec"
vec_shape = (4,)
# Measure times and count operations in c++ code
if get_option("instrumentation_level") >= 4:
timer_name = assembler_routine_name() + '_kernel' + '_stage{}'.format(stage)
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
dump_accumulate_timer(timer_name)
insn_dep = frozenset({instruction(code="HP_TIMER_START({});".format(timer_name),
depends_on=insn_dep,
within_inames=additional_inames)})
# Put a barrier before the sumfactorization kernel
insn_dep = frozenset({barrier(depends_on=insn_dep,
within_inames=additional_inames,
)})
# Decide in which order we want to process directions in the
# sumfactorization. A clever ordering can lead to a reduced
# complexity. This will e.g. happen at faces where we only have
# one quadratue point m_l=1 if l is the normal direction of the
# face.
#
# Rule of thumb: small m's early and large n's late.
perm = _sf_permutation_strategy(a_matrices, stage)
# Permute a_matrices
a_matrices = _permute_forward(a_matrices, perm)
# Product of all matrices
for l, a_matrix in enumerate(a_matrices):
# Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication
# and get inames that realize the product.
inp_shape = (a_matrix.cols,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l])
out_shape = (a_matrix.rows,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l])
out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape))
vec_iname = ()
if a_matrix.vectorized:
iname = sumfact_iname(4, "vec")
vec_iname = (prim.Variable(iname),)
transform(lp.tag_inames, [(iname, "vec")])
# A trivial reduction is implemented as a product, otherwise we run into
# a code generation corner case producing way too complicated code. This
# could be fixed upstream, but the loopy code realizing reductions is not
# trivial and the priority is kind of low.
if a_matrix.cols != 1:
k = sumfact_iname(a_matrix.cols, "red")
k_expr = prim.Variable(k)
else:
k_expr = 0
# Setup the input of the sum factorization kernel. In the
# first matrix multiplication this can be taken from
# * an input temporary (default)
# * a global data structure (if FastDGGridOperator is in use)
# * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator)
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and direct_input is not None:
# See comment below
input_inames = _permute_backward(input_inames, perm)
inp_shape = _permute_backward(inp_shape, perm)
globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags)
if a_matrix.vectorized:
input_summand = prim.Call(prim.Variable("Vec4d"),
(prim.Subscript(prim.Variable(direct_input),
input_inames),))
else:
input_summand = prim.Subscript(prim.Variable(direct_input),
input_inames + vec_iname)
else:
# If we did permute the order of a matrices above we also
# permuted the order of out_inames. Unfortunately the
# order of our input is from 0 to d-1. This means we need
# to permute _back_ to get the right coefficients.
if l == 0:
inp_shape = _permute_backward(inp_shape, perm)
input_inames = _permute_backward(input_inames, perm)
# Get a temporary that interprets the base storage of the input
# as a column-major matrix. In later iteration of the amatrix loop
# this reinterprets the output of the previous iteration.
inp = get_buffer_temporary(buf,
shape=inp_shape + vec_shape,
dim_tags=ftags)
# The input temporary will only be read from, so we need to silence the loopy warning
silenced_warning('read_no_write({})'.format(inp))
input_summand = prim.Subscript(prim.Variable(inp),
input_inames + vec_iname)
switch_base_storage(buf)
# Get a temporary that interprets the base storage of the output.
#
# Note: In this step the reordering of the fastest directions
# is happening. The new direction (out_inames[0]) and the
# corresponding shape (out_shape[0]) goes to the end (slowest
# direction) and everything stays column major (ftags->fortran
# style).
#
# If we are in the last step we reverse the permutation.
output_shape = tuple(out_shape[1:]) + (out_shape[0],)
if l == len(a_matrices) - 1:
output_shape = _permute_backward(output_shape, perm)
out = get_buffer_temporary(buf,
shape=output_shape + vec_shape,
dim_tags=ftags)
# Write the matrix-matrix multiplication expression
matprod = Product((prim.Subscript(prim.Variable(a_matrix.name),
(prim.Variable(out_inames[0]), k_expr) + vec_iname),
input_summand))
# ... which may be a reduction, if k>0
if a_matrix.cols != 1:
matprod = lp.Reduction("sum", k, matprod)
# Here we also move the new direction (out_inames[0]) to the
# end and reverse permutation
output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
if l == len(a_matrices) - 1:
output_inames = _permute_backward(output_inames, perm)
# In case of direct output we directly accumulate the result
# of the Sumfactorization into some global data structure.
if l == len(a_matrices) - 1 and direct_output is not None:
ft = get_global_context_value("form_type")
if ft == 'residual' or ft == 'jacobian_apply':
globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags)
assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
else:
assert ft == 'jacobian'
globalarg(direct_output,
dtype=np.float64,
shape=output_shape + output_shape,
dim_tags=novec_ftags + "," + novec_ftags)
# TODO the next line should get its inames from
# elsewhere. This is *NOT* robust (but works right
# now)
_ansatz_inames = tuple(Variable(visitor.inames[i]) for i in range(world_dimension()))
assignee = prim.Subscript(prim.Variable(direct_output), _ansatz_inames + output_inames)
# In case of vectorization we need to apply a horizontal add
if a_matrix.vectorized:
matprod = prim.Call(prim.Variable("horizontal_add"),
(matprod,))
# We need to accumulate
matprod = prim.Sum((assignee, matprod))
else:
assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname)
# Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on
insn_dep = frozenset({instruction(assignee=assignee,
expression=matprod,
forced_iname_deps=frozenset([iname for iname in out_inames]).union(additional_inames),
forced_iname_deps_is_final=True,
depends_on=insn_dep,
)
})
# Measure times and count operations in c++ code
if get_option("instrumentation_level") >= 4:
insn_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
depends_on=insn_dep,
within_inames=additional_inames)})
if stage == 1:
qp_timer_name = assembler_routine_name() + '_kernel' + '_quadratureloop'
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
dump_accumulate_timer(timer_name)
insn_dep = instruction(code="HP_TIMER_START({});".format(qp_timer_name),
depends_on=insn_dep)
if outshape is None:
assert stage == 3
outshape = tuple(mat.rows for mat in a_matrices)
dim_tags = ",".join(['f'] * len(outshape))
if next(iter(a_matrices)).vectorized:
outshape = outshape + vec_shape
# This is a 'bit' hacky: In stage 3 we need to return something with vectag, in stage 1 not.
if stage == 1:
dim_tags = dim_tags + ",c"
else:
dim_tags = dim_tags + ",vec"
out = get_buffer_temporary(buf,
shape=outshape,
dim_tags=dim_tags,
)
silenced_warning('read_no_write({})'.format(out))
return next(iter(a_matrices)).output_to_pymbolic(out), insn_dep
...@@ -21,10 +21,12 @@ def _cache_vectorization_info(old, new): ...@@ -21,10 +21,12 @@ def _cache_vectorization_info(old, new):
return new return new
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), no_deco=True)
def attach_vectorization_info(sf): def attach_vectorization_info(sf):
assert isinstance(sf, SumfactKernel) assert isinstance(sf, SumfactKernel)
if get_global_context_value("dry_run"): if get_global_context_value("dry_run"):
return sf return _collect_sumfact_nodes(sf)
else: else:
return _cache_vectorization_info(sf, None) return _cache_vectorization_info(sf, None)
...@@ -110,11 +112,15 @@ def decide_vectorization_strategy(): ...@@ -110,11 +112,15 @@ def decide_vectorization_strategy():
if not get_option("vectorize_grads"): if not get_option("vectorize_grads"):
no_vectorization(sumfacts) no_vectorization(sumfacts)
else: else:
for stage in (1, 3): res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)
res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE) # Stage 1 kernels
import itertools as it for restriction in res:
for restriction in it.product(res, res): decide_stage_vectorization_strategy(sumfacts, 1, restriction)
decide_stage_vectorization_strategy(sumfacts, stage, restriction)
# Stage 3 kernels
import itertools as it
for restriction in it.product(res, res):
decide_stage_vectorization_strategy(sumfacts, 3, restriction)
class HasSumfactMapper(lp.symbolic.CombineMapper): class HasSumfactMapper(lp.symbolic.CombineMapper):
...@@ -133,6 +139,9 @@ class HasSumfactMapper(lp.symbolic.CombineMapper): ...@@ -133,6 +139,9 @@ class HasSumfactMapper(lp.symbolic.CombineMapper):
def map_sumfact_kernel(self, expr): def map_sumfact_kernel(self, expr):
return frozenset({expr}) return frozenset({expr})
def map_tagged_variable(self, expr):
return frozenset()
def find_sumfact(expr): def find_sumfact(expr):
return HasSumfactMapper()(expr) return HasSumfactMapper()(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment