Skip to content
Snippets Groups Projects
Commit c3fd8493 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Rip out horrible thing to determine bounds

parent d275dd12
No related branches found
No related tags found
No related merge requests found
...@@ -184,18 +184,17 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -184,18 +184,17 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
def index_sum(self, o): def index_sum(self, o):
from loopy import Reduction from loopy import Reduction
from dune.perftool.ufl.shape import determine_shape
oldinames = self.inames oldinames = self.inames
self.inames = [] self.inames = []
red_inames = () red_inames = ()
# Define an iname for each of the indices in the multiindex # Define an iname for each of the indices in the multiindex
for i in o.ufl_operands[1].indices(): for i, ind in enumerate(o.ufl_operands[1].indices()):
red_inames = red_inames + (index_sum_iname(i),) red_inames = red_inames + (index_sum_iname(ind),)
shape = determine_shape(o.ufl_operands[0], i) shape = o.ufl_operands[0].ufl_index_dimensions[i]
from dune.perftool.pdelab import name_index from dune.perftool.pdelab import name_index
domain(name_index(i), shape) domain(name_index(ind), shape)
# Recurse to get the summation expression # Recurse to get the summation expression
term = self.call(o.ufl_operands[0]) term = self.call(o.ufl_operands[0])
......
""" An algorithm to determine the shape aka the loop domain for a given index """
from ufl.algorithms import MultiFunction
class ShapeDetermination(MultiFunction):
def __init__(self, index):
MultiFunction.__init__(self)
self.index = index
def expr(self, o):
ret = tuple(self(op) for op in o.ufl_operands)
asset = set(ret)
if len(asset) == 0:
return None
if len(asset) == 1:
# All determined shapes are equal. We just use the result.
return ret[0]
if len(asset) == 2:
try:
# Two shapes determined, one might be "None" so we use the other.
asset.remove(None)
return ret[0]
except KeyError:
pass
raise AssertionError("I had trouble determining the shape of an expression!")
def indexed(self, o):
try:
position = o.ufl_operands[1].indices().index(self.index)
return o.ufl_operands[0].ufl_shape[position]
except ValueError:
return self(o.ufl_operands[0])
def determine_shape(expr, i):
return ShapeDetermination(i)(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment