Skip to content
Snippets Groups Projects
Commit fca6404b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cleanup

parent c29aa746
No related branches found
No related tags found
No related merge requests found
...@@ -120,7 +120,7 @@ class _RegisteredFunction(object): ...@@ -120,7 +120,7 @@ class _RegisteredFunction(object):
return self._get_content(cache_key) return self._get_content(cache_key)
def remove_by_value(self, val): def remove_by_value(self, val):
self._memoize_cache = {k:v for k, v in self._memoize_cache.items() if v != val} self._memoize_cache = {k: v for k, v in self._memoize_cache.items() if v != val}
def generator_factory(**factory_kwargs): def generator_factory(**factory_kwargs):
......
...@@ -69,7 +69,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -69,7 +69,6 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
dim = formdata.geometric_dimension dim = formdata.geometric_dimension
buffers = [] buffers = []
insn_dep = None insn_dep = None
ret = False
for i in range(dim): for i in range(dim):
a_matrices = [theta_matrix] * dim a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix a_matrices[i] = dtheta_matrix
...@@ -98,28 +97,19 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor): ...@@ -98,28 +97,19 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
# Add a sum factorization kernel that implements the # Add a sum factorization kernel that implements the
# evaluation of the gradients of basis functions at quadrature # evaluation of the gradients of basis functions at quadrature
# points (stage 1) # points (stage 1)
if index: var, insn_dep = sum_factorization_kernel(a_matrices,
assert len(visitor.indices) == 1 buffer,
var, insn_dep = sum_factorization_kernel(a_matrices, 1,
buffer, preferred_position=i,
1, insn_dep=insn_dep,
preferred_position=i, )
insn_dep=insn_dep, buffers.append(var)
)
ret = True
else:
var, insn_dep = sum_factorization_kernel(a_matrices,
buffer,
1,
preferred_position=i,
insn_dep=insn_dep,
)
buffers.append(var)
# Check whether we want to return early with something that has the indexing # Check whether we want to return early with something that has the indexing
# already handled! This happens with vectorization when the index coincides # already handled! This happens with vectorization when the index coincides
# with the position in the vector register. # with the position in the vector register.
if ret: if index:
assert len(visitor.indices) == 1
indices = visitor.indices indices = visitor.indices
visitor.indices = None visitor.indices = None
return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices) return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices)
...@@ -179,8 +169,13 @@ def pymbolic_trialfunction(element, restriction, component, visitor): ...@@ -179,8 +169,13 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
insn_dep=frozenset({Writes(input)}), insn_dep=frozenset({Writes(input)}),
) )
if index:
index = (index,)
else:
index = ()
return prim.Subscript(var, return prim.Subscript(var,
tuple(prim.Variable(i) for i in quadrature_inames()) tuple(prim.Variable(i) for i in quadrature_inames() + index)
) )
......
...@@ -237,7 +237,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -237,7 +237,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
transform(nest_quadrature_loops, visitor.inames) transform(nest_quadrature_loops, visitor.inames)
@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a,b,s)) @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s))
def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None): def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None):
""" """
Calculate a sum factorization matrix product. Calculate a sum factorization matrix product.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment