Skip to content
Snippets Groups Projects
Commit d9eee811 authored by Marcel Koch's avatar Marcel Koch
Browse files

copy loopy computation of matrix inverse to tensors.py

parent 0a07cf37
No related branches found
No related tags found
No related merge requests found
......@@ -8,12 +8,140 @@ from dune.codegen.generation import (get_counted_variable,
temporary_variable,
)
from loopy.match import Writes
import pymbolic.primitives as prim
import numpy as np
import loopy as lp
import itertools as it
def define_determinant(name, matrix, shape, visitor):
temporary_variable(name)
assert len(shape) == 2 and shape[0] == shape[1]
dim = shape[0]
matrix_entry = [[prim.Subscript(prim.Variable(matrix), (i, j)) for j in range(dim)] for i in range(dim)]
if dim == 2:
expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])),
-1 * prim.Product((matrix_entry[1][0], matrix_entry[0][1]))))
elif dim == 3:
expr_determinant = prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1], matrix_entry[2][2])),
prim.Product((matrix_entry[0][1], matrix_entry[1][2], matrix_entry[2][0])),
prim.Product((matrix_entry[0][2], matrix_entry[1][0], matrix_entry[2][1])),
-1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1], matrix_entry[2][0])),
-1 * prim.Product((matrix_entry[0][0], matrix_entry[1][2], matrix_entry[2][1])),
-1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0], matrix_entry[2][2]))
))
else:
raise NotImplementedError()
instruction(expression=expr_determinant,
assignee=prim.Variable(name),
within_inames=frozenset(visitor.quadrature_inames()),
depends_on=frozenset({Writes(matrix)})
)
def define_determinant_inverse(name, matrix, shape, visitor):
det = name_determinant(matrix, shape, visitor)
temporary_variable(name)
instruction(expression=prim.Quotient(1, prim.Variable(det)),
assignee=prim.Variable(name),
within_inames=frozenset(visitor.quadrature_inames()),
depends_on=frozenset({Writes(matrix)})
)
def define_matrix_inverse(name, name_inv, shape, visitor):
temporary_variable(name_inv, shape=shape, managed=True)
det_inv = name_determinant_inverse(name, shape, visitor)
assert len(shape) == 2 and shape[0] == shape[1]
dim = shape[0]
matrix_entry = [[prim.Subscript(prim.Variable(name), (i, j)) for j in range(dim)] for i in range(dim)]
assignee = [[prim.Subscript(prim.Variable(name_inv), (i, j)) for j in range(dim)] for i in range(dim)]
exprs = [[None for _ in range(dim)] for _ in range(dim)]
if dim == 2:
for i in range(2):
for j in range(2):
sign = 1. if i == j else -1.
exprs[i][j] = prim.Product((sign, prim.Variable(det_inv), matrix_entry[1 - i][1 - j]))
elif dim == 3:
exprs[0][0] = prim.Product((1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[1][1], matrix_entry[2][2])),
-1 * prim.Product((matrix_entry[1][2], matrix_entry[2][1]))))))
exprs[1][0] = prim.Product((-1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][1], matrix_entry[2][2])),
-1 * prim.Product((matrix_entry[0][2], matrix_entry[2][1]))))))
exprs[2][0] = prim.Product((1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][1], matrix_entry[1][2])),
-1 * prim.Product((matrix_entry[0][2], matrix_entry[1][1]))))))
exprs[0][1] = prim.Product((-1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[1][0], matrix_entry[2][2])),
-1 * prim.Product((matrix_entry[1][2], matrix_entry[2][0]))))))
exprs[1][1] = prim.Product((1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[2][2])),
-1 * prim.Product((matrix_entry[0][2], matrix_entry[2][0]))))))
exprs[2][1] = prim.Product((-1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][2])),
-1 * prim.Product((matrix_entry[0][2], matrix_entry[1][0]))))))
exprs[0][2] = prim.Product((1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[1][0], matrix_entry[2][1])),
-1 * prim.Product((matrix_entry[1][1], matrix_entry[2][0]))))))
exprs[1][2] = prim.Product((-1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[2][1])),
-1 * prim.Product((matrix_entry[0][1], matrix_entry[2][0]))))))
exprs[2][2] = prim.Product((1., prim.Variable(det_inv),
prim.Sum((prim.Product((matrix_entry[0][0], matrix_entry[1][1])),
-1 * prim.Product((matrix_entry[0][1], matrix_entry[1][0]))))))
else:
raise NotImplementedError
for j in range(dim):
for i in range(dim):
instruction(expression=exprs[i][j],
assignee=assignee[i][j],
within_inames=frozenset(visitor.quadrature_inames()),
depends_on=frozenset({Writes(name)}))
def name_determinant(matrix, shape, visitor):
name = matrix + "_det"
define_determinant(name, matrix, shape, visitor)
return name
def name_determinant_inverse(matrix, shape, visitor):
name = matrix + "_det_inv"
define_determinant_inverse(name, matrix, shape, visitor)
return name
def name_matrix_inverse(name, shape, visitor):
name_inv = name + "_inv"
define_matrix_inverse(name, name_inv, shape, visitor)
return name_inv
def matrix_inverse(name, shape, visitor):
name_inv = name_matrix_inverse(name, shape, visitor)
return prim.Variable(name_inv)
def define_assembled_tensor(name, expr, visitor):
temporary_variable(name,
shape=expr.ufl_shape,
......@@ -22,7 +150,7 @@ def define_assembled_tensor(name, expr, visitor):
visitor.indices = indices
instruction(assignee=prim.Subscript(prim.Variable(name), indices),
expression=visitor.call(expr),
forced_iname_deps=frozenset(visitor.interface.quadrature_inames()),
forced_iname_deps=frozenset(visitor.quadrature_inames()),
depends_on=frozenset({lp.match.Tagged("sumfact_stage1")}),
tags=frozenset({"quad"}),
)
......@@ -37,17 +165,11 @@ def name_assembled_tensor(o, visitor):
@kernel_cached
def pymbolic_matrix_inverse(o, visitor):
expr = o.ufl_operands[0]
indices = visitor.indices
visitor.indices = None
name = name_assembled_tensor(o.ufl_operands[0], visitor)
instruction(code="{}.invert();".format(name),
within_inames=frozenset(visitor.interface.quadrature_inames()),
depends_on=frozenset({lp.match.Writes(name),
lp.match.Tagged("sumfact_stage1"),
}),
tags=frozenset({"quad"}),
)
name = name_assembled_tensor(expr, visitor)
visitor.indices = indices
return prim.Variable(name)
return matrix_inverse(name, expr.ufl_shape, visitor)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment