Skip to content
Snippets Groups Projects
Commit 8b60ecf4 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Bugfix on sumfact stage1: Input layout

parent 392c621a
No related branches found
No related tags found
No related merge requests found
......@@ -493,6 +493,7 @@ def sum_factorization_kernel(a_matrices,
return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset()
ftags = ",".join(["f"]*len(a_matrices))
novec_ftags = ftags
ctags = ",".join(["c"]*len(a_matrices))
vec_shape = ()
if next(iter(a_matrices)).vectorized:
......@@ -556,11 +557,11 @@ def sum_factorization_kernel(a_matrices,
# * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator)
input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
if l == 0 and direct_input is not None:
# See comment bellow
# See comment below
input_inames = _permute_backward(input_inames, perm)
inp_shape = _permute_backward(inp_shape, perm)
globalarg(direct_input, dtype=np.float64, shape=inp_shape)
globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags)
if a_matrix.vectorized:
input_summand = prim.Call(prim.Variable("Vec4d"),
(prim.Subscript(prim.Variable(direct_input),
......@@ -627,7 +628,6 @@ def sum_factorization_kernel(a_matrices,
# of the Sumfactorization into some global data structure.
if l == len(a_matrices)-1 and direct_output is not None:
ft = get_global_context_value("form_type")
novec_ftags = ",".join(["f"]*len(a_matrices))
if ft == 'residual' or ft == 'jacobian_apply':
globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags)
assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment