diff --git a/python/dune/codegen/pdelab/tensors.py b/python/dune/codegen/pdelab/tensors.py index 6c6c55911675d4e0871697c7010b4ef986d62bdb..70a0115d287d0613b8dc8d4a24a3dada4b551c15 100644 --- a/python/dune/codegen/pdelab/tensors.py +++ b/python/dune/codegen/pdelab/tensors.py @@ -123,12 +123,6 @@ def name_matrix_inverse(name, shape, visitor): return name_inv -def matrix_inverse(name, shape, visitor): - name_inv = name_matrix_inverse(name, shape, visitor) - - return prim.Variable(name_inv) - - def define_assembled_tensor(name, expr, visitor): temporary_variable(name, shape=expr.ufl_shape, @@ -157,6 +151,18 @@ def pymbolic_matrix_inverse(o, visitor): indices = visitor.indices visitor.indices = None name = name_assembled_tensor(expr, visitor) + + if expr.shape[0] <= 3: + name = name_matrix_inverse(name, expr.ufl_shape, visitor) + else: + instruction(code="{}.invert();".format(name), + within_inames=frozenset(visitor.quadrature_inames()), + depends_on=frozenset({lp.match.Writes(name), + lp.match.Tagged("sumfact_stage1"), + }), + tags=frozenset({"quad"}), + ) + visitor.indices = indices - return matrix_inverse(name, expr.ufl_shape, visitor) + return prim.Variable(name)