From d9eee811a1c1f7986dd292dae25b56dffc56b9f9 Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Wed, 30 Jan 2019 16:14:58 +0100
Subject: [PATCH] copy loopy computation of matrix inverse to tensors.py

---
 python/dune/codegen/pdelab/tensors.py | 146 +++++++++++++++++++++++---
 1 file changed, 134 insertions(+), 12 deletions(-)

diff --git a/python/dune/codegen/pdelab/tensors.py b/python/dune/codegen/pdelab/tensors.py
index a924a39a..6bd0eafd 100644
--- a/python/dune/codegen/pdelab/tensors.py
+++ b/python/dune/codegen/pdelab/tensors.py
@@ -8,12 +8,140 @@ from dune.codegen.generation import (get_counted_variable,
                                      temporary_variable,
                                      )
 
+from loopy.match import Writes
+
 import pymbolic.primitives as prim
 import numpy as np
 import loopy as lp
 import itertools as it
 
 
+def define_determinant(name, matrix, shape, visitor):
+    temporary_variable(name)
+
+    assert len(shape) == 2 and shape[0] == shape[1]
+    dim = shape[0]
+
+    matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)]
+    if dim == 2:
+        expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])),
+                                     -1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1]))))
+    elif dim == 3:
+        expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1], matrix_entry[2][2])),
+                                     prim.Product((matrix_entry[0][1], matrix_entry[1][2], matrix_entry[2][0])),
+                                     prim.Product((matrix_entry[0][2], matrix_entry[1][0], matrix_entry[2][1])),
+                                     -1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1], matrix_entry[2][0])),
+                                     -1 * prim.Product((matrix_entry[0][0], matrix_entry[1][2], matrix_entry[2][1])),
+                                     -1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0], matrix_entry[2][2]))
+                                     ))
+    else:
+        raise NotImplementedError()
+    instruction(expression=expr_determinant,
+                assignee=prim.Variable(name),
+                within_inames=frozenset(visitor.quadrature_inames()),
+                depends_on=frozenset({Writes(matrix)})
+                )
+
+
+def define_determinant_inverse(name, matrix, shape, visitor):
+    det = name_determinant(matrix, shape, visitor)
+
+    temporary_variable(name)
+
+    instruction(expression=prim.Quotient(1, prim.Variable(det)),
+                assignee=prim.Variable(name),
+                within_inames=frozenset(visitor.quadrature_inames()),
+                depends_on=frozenset({Writes(matrix)})
+                )
+
+
+def define_matrix_inverse(name, name_inv, shape, visitor):
+    temporary_variable(name_inv, shape=shape, managed=True)
+
+    det_inv = name_determinant_inverse(name, shape, visitor)
+
+    assert len(shape) == 2 and shape[0] == shape[1]
+    dim = shape[0]
+
+    matrix_entry = [[prim.Subscript(prim.Variable(name), (i, j)) for j in range(dim)] for i in range(dim)]
+    assignee = [[prim.Subscript(prim.Variable(name_inv), (i, j)) for j in range(dim)] for i in range(dim)]
+    exprs = [[None for _ in range(dim)] for _ in range(dim)]
+
+    if dim == 2:
+        for i in range(2):
+            for j in range(2):
+                sign = 1. if i == j else -1.
+                exprs[i][j] = prim.Product((sign, prim.Variable(det_inv), matrix_entry[1 - i][1 - j]))
+    elif dim == 3:
+        exprs[0][0] = prim.Product((1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[1][1], matrix_entry[2][2])),
+                                              -1 * prim.Product((matrix_entry[1][2], matrix_entry[2][1]))))))
+        exprs[1][0] = prim.Product((-1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][1], matrix_entry[2][2])),
+                                              -1 * prim.Product((matrix_entry[0][2], matrix_entry[2][1]))))))
+        exprs[2][0] = prim.Product((1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][1], matrix_entry[1][2])),
+                                              -1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1]))))))
+
+        exprs[0][1] = prim.Product((-1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[1][0], matrix_entry[2][2])),
+                                              -1 * prim.Product((matrix_entry[1][2], matrix_entry[2][0]))))))
+        exprs[1][1] = prim.Product((1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[2][2])),
+                                              -1 * prim.Product((matrix_entry[0][2], matrix_entry[2][0]))))))
+        exprs[2][1] = prim.Product((-1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][2])),
+                                              -1 * prim.Product((matrix_entry[0][2], matrix_entry[1][0]))))))
+
+        exprs[0][2] = prim.Product((1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[1][0], matrix_entry[2][1])),
+                                              -1 * prim.Product((matrix_entry[1][1], matrix_entry[2][0]))))))
+        exprs[1][2] = prim.Product((-1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[2][1])),
+                                              -1 * prim.Product((matrix_entry[0][1], matrix_entry[2][0]))))))
+        exprs[2][2] = prim.Product((1., prim.Variable(det_inv),
+                                    prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])),
+                                              -1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0]))))))
+    else:
+        raise NotImplementedError
+    for j in range(dim):
+        for i in range(dim):
+            instruction(expression=exprs[i][j],
+                        assignee=assignee[i][j],
+                        within_inames=frozenset(visitor.quadrature_inames()),
+                        depends_on=frozenset({Writes(name)}))
+
+
+def name_determinant(matrix, shape, visitor):
+    name = matrix + "_det"
+
+    define_determinant(name, matrix, shape, visitor)
+
+    return name
+
+
+def name_determinant_inverse(matrix, shape, visitor):
+    name = matrix + "_det_inv"
+
+    define_determinant_inverse(name, matrix, shape, visitor)
+
+    return name
+
+
+def name_matrix_inverse(name, shape, visitor):
+    name_inv =  name + "_inv"
+
+    define_matrix_inverse(name, name_inv, shape, visitor)
+
+    return name_inv
+
+
+def matrix_inverse(name, shape, visitor):
+    name_inv = name_matrix_inverse(name, shape, visitor)
+
+    return prim.Variable(name_inv)
+
+
 def define_assembled_tensor(name, expr, visitor):
     temporary_variable(name,
                        shape=expr.ufl_shape,
@@ -22,7 +150,7 @@ def define_assembled_tensor(name, expr, visitor):
         visitor.indices = indices
         instruction(assignee=prim.Subscript(prim.Variable(name), indices),
                     expression=visitor.call(expr),
-                    forced_iname_deps=frozenset(visitor.interface.quadrature_inames()),
+                    forced_iname_deps=frozenset(visitor.quadrature_inames()),
                     depends_on=frozenset({lp.match.Tagged("sumfact_stage1")}),
                     tags=frozenset({"quad"}),
                     )
@@ -37,17 +165,11 @@ def name_assembled_tensor(o, visitor):
 
 @kernel_cached
 def pymbolic_matrix_inverse(o, visitor):
+    expr = o.ufl_operands[0]
+
     indices = visitor.indices
     visitor.indices = None
-    name = name_assembled_tensor(o.ufl_operands[0], visitor)
-
-    instruction(code="{}.invert();".format(name),
-                within_inames=frozenset(visitor.interface.quadrature_inames()),
-                depends_on=frozenset({lp.match.Writes(name),
-                                      lp.match.Tagged("sumfact_stage1"),
-                                      }),
-                tags=frozenset({"quad"}),
-                )
-
+    name = name_assembled_tensor(expr, visitor)
     visitor.indices = indices
-    return prim.Variable(name)
+
+    return matrix_inverse(name, expr.ufl_shape, visitor)
-- 
GitLab