diff --git a/python/dune/codegen/pdelab/tensors.py b/python/dune/codegen/pdelab/tensors.py
index 59bb314da1cca79a90f4284cb869a059186f78eb..68f6d393129600f0fda995ea1a4f239ff844c173 100644
--- a/python/dune/codegen/pdelab/tensors.py
+++ b/python/dune/codegen/pdelab/tensors.py
@@ -36,17 +36,55 @@ def name_assembled_tensor(o, visitor):
 
 
 @kernel_cached
+def code_generation_time_inversion(expr, visitor):
+    """Try to invert the operand of ``expr`` at code generation time.
+
+    Each entry of ``expr.ufl_operands[0]`` is evaluated with ``visitor.call``
+    while ``visitor.indices`` holds the index tuple of that entry. If every
+    entry is a plain ``int`` or ``float``, the result of ``np.linalg.inv`` on
+    the collected matrix is returned, otherwise ``None``.
+    ``visitor.indices`` is ``None`` when this function is left.
+    """
+    mat = np.empty(expr.ufl_shape)
+    try:
+        for indices in it.product(*(range(i) for i in expr.ufl_shape)):
+            visitor.indices = indices
+            val = visitor.call(expr.ufl_operands[0])
+            if not isinstance(val, (float, int)):
+                return None
+            mat[indices] = val
+    finally:
+        # Reset the visitor state on every exit, also if visitor.call raises.
+        visitor.indices = None
+
+    return np.linalg.inv(mat)
+
+
 def pymbolic_matrix_inverse(o, visitor):
+    # Try to evaluate the matrix at code generation time.
+    # If this works (it does e.g. for Maxwell on structured grids)
+    # the inverse is known here and no inversion code has to be generated.
+    # NOTE(review): @kernel_cached now decorates code_generation_time_inversion
+    # instead of this function - confirm that repeated calls for the same
+    # operand do not emit the invert() instruction below more than once.
     indices = visitor.indices
     visitor.indices = None
+
+    mat = code_generation_time_inversion(o, visitor)
+    if mat is not None:
+        return mat[indices]
+
+    # If code generation time inversion failed, we assemble it in C++
+    # and invert it there.
     name = name_assembled_tensor(o.ufl_operands[0], visitor)
 
     instruction(code="{}.invert();".format(name),
-                within_inames=frozenset(visitor.interface.quadrature_inames()),
+                within_inames=frozenset(visitor.quadrature_inames()),
                 depends_on=frozenset({lp.match.Writes(name),
                                       lp.match.Tagged("sumfact_stage1"),
                                       }),
                 tags=frozenset({"quad"}),
+                assignees=frozenset({name}),
                 )
 
     visitor.indices = indices
diff --git a/python/dune/codegen/ufl/visitor.py b/python/dune/codegen/ufl/visitor.py
index ab7c334f323bbc51221d926d357466d6ebc8cf88..e774e03d90e8bc61a1108767867ab046bfa56c81 100644
--- a/python/dune/codegen/ufl/visitor.py
+++ b/python/dune/codegen/ufl/visitor.py
@@ -37,6 +37,7 @@ from ufl.classes import (Coefficient,
                          JacobianDeterminant,
                          )
 
+from pytools import product as ptproduct
 import pymbolic.primitives as prim
 import numpy as np
 
@@ -278,7 +279,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def product(self, o):
-        return prim.flattened_product(tuple(self.call(op) for op in o.ufl_operands))
+        ops = tuple(self.call(op) for op in o.ufl_operands)
+        # Purely numeric operands are multiplied right away.
+        if all(isinstance(op, (int, float)) for op in ops):
+            return ptproduct(ops)
+        return prim.flattened_product(ops)
 
     def float_value(self, o):
         return o.value()
@@ -290,7 +295,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         return prim.quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
 
     def sum(self, o):
-        return prim.flattened_sum(tuple(self.call(op) for op in o.ufl_operands))
+        ops = tuple(self.call(op) for op in o.ufl_operands)
+        # Purely numeric operands are added up right away.
+        if all(isinstance(op, (int, float)) for op in ops):
+            return sum(ops)
+        return prim.flattened_sum(ops)
 
     def zero(self, o):
         # UFL has Zeroes with shape. We ignore those indices.