Skip to content
Snippets Groups Projects
Commit fca6404b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cleanup

parent c29aa746
No related branches found
No related tags found
No related merge requests found
......@@ -120,7 +120,7 @@ class _RegisteredFunction(object):
return self._get_content(cache_key)
def remove_by_value(self, val):
self._memoize_cache = {k:v for k, v in self._memoize_cache.items() if v != val}
self._memoize_cache = {k: v for k, v in self._memoize_cache.items() if v != val}
def generator_factory(**factory_kwargs):
......
......@@ -69,7 +69,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
dim = formdata.geometric_dimension
buffers = []
insn_dep = None
ret = False
for i in range(dim):
a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix
......@@ -98,28 +97,19 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
# Add a sum factorization kernel that implements the
# evaluation of the gradients of basis functions at quadrature
# points (stage 1)
if index:
assert len(visitor.indices) == 1
var, insn_dep = sum_factorization_kernel(a_matrices,
buffer,
1,
preferred_position=i,
insn_dep=insn_dep,
)
ret = True
else:
var, insn_dep = sum_factorization_kernel(a_matrices,
buffer,
1,
preferred_position=i,
insn_dep=insn_dep,
)
buffers.append(var)
var, insn_dep = sum_factorization_kernel(a_matrices,
buffer,
1,
preferred_position=i,
insn_dep=insn_dep,
)
buffers.append(var)
# Check whether we want to return early with something that has the indexing
# already handled! This happens with vectorization when the index coincides
# with the position in the vector register.
if ret:
if index:
assert len(visitor.indices) == 1
indices = visitor.indices
visitor.indices = None
return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices)
......@@ -179,8 +169,13 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
insn_dep=frozenset({Writes(input)}),
)
if index:
index = (index,)
else:
index = ()
return prim.Subscript(var,
tuple(prim.Variable(i) for i in quadrature_inames())
tuple(prim.Variable(i) for i in quadrature_inames() + index)
)
......
......@@ -237,7 +237,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
transform(nest_quadrature_loops, visitor.inames)
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a,b,s))
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None):
"""
Calculate a sum factorization matrix product.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment