Skip to content
Snippets Groups Projects
Commit ad360799 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Implement ListTensors with dimension > 1.

parent 23ff178c
No related branches found
No related tags found
No related merge requests found
......@@ -10,10 +10,14 @@ import pymbolic.primitives as prim
import numpy as np
def define_list_tensor(name, expr, visitor):
def define_list_tensor(name, expr, visitor, stack=()):
for i, child in enumerate(expr.ufl_operands):
instruction(assignee=prim.Subscript(prim.Variable(name), (i,)),
expression=visitor.call(child))
from ufl.classes import ListTensor
if isinstance(child, ListTensor):
define_list_tensor(name, child, visitor, stack=stack + (i,))
else:
instruction(assignee=prim.Subscript(prim.Variable(name), stack + (i,)),
expression=visitor.call(child))
@kernel_cached
......@@ -22,6 +26,7 @@ def pymbolic_list_tensor(expr, visitor):
temporary_variable(name,
shape=expr.ufl_shape,
dtype=np.float64,
managed=True,
)
define_list_tensor(name, expr, visitor)
return prim.Variable(name)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment