Skip to content
Snippets Groups Projects
Commit eb28ceb1 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

New take on dimension indices, WIP

parent 09cc5d4c
No related branches found
No related tags found
No related merge requests found
......@@ -35,17 +35,12 @@ from dune.perftool.pdelab.quadrature import quadrature_iname
from pymbolic.primitives import Subscript, Variable
@iname
def index_sum_iname(i):
from dune.perftool.pdelab import name_index
return name_index(i)
class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper):
def __init__(self, measure, subdomain_id):
def __init__(self, measure, subdomain_id, dimension_index_aliases):
# Some variables describing the integral measure of this integral
self.measure = measure
self.subdomain_id = subdomain_id
self.dimension_index_aliases = dimension_index_aliases
# Call base class constructors
super(UFL2LoopyVisitor, self).__init__()
......@@ -55,7 +50,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
self.argshape = 0
self.redinames = ()
self.inames = []
self.dimension_index_aliases = []
self.substitution_rules = []
# Initialize the local function spaces that we might need for this term
......@@ -265,10 +259,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
use_indices = self.last_index[self.argshape:]
for i in range(self.argshape):
self.dimension_index_aliases.append(i)
# self.index_placeholder_removal_mapper.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg'))
self.argshape = 0
if isinstance(aggr, Subscript):
return Subscript(aggr.aggregate, aggr.index + use_indices)
......@@ -287,7 +277,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
# Get the iname for the reduction index
ind = o.ufl_operands[1][0]
self.redinames = self.redinames + (index_sum_iname(ind),)
self.redinames = self.redinames + (ind,)
shape = o.ufl_operands[0].ufl_index_dimensions[0]
from dune.perftool.pdelab import name_index
domain(name_index(ind), shape)
......@@ -301,11 +291,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
# Recurse to get the summation expression
term = self.call(o.ufl_operands[0])
self.redinames = tuple(i for i in self.redinames if i not in self.dimension_index_aliases)
if len(self.redinames) > 0:
ret = Reduction("sum", self.redinames, term)
ret = Reduction("sum", tuple(name_index(ind) for ind in self.redinames), term)
name = get_temporary_name()
# Generate a substitution rule for this one.
from loopy import SubstitutionRule
......@@ -333,6 +321,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
from dune.perftool.pdelab import name_index
if index in self.dimension_index_aliases:
from dune.perftool.pdelab.geometry import dimension_iname
self.inames.append(dimension_iname(context='arg'))
return Variable(dimension_iname(context='arg'))
else:
return Variable(name_index(index))
......
......@@ -161,13 +161,16 @@ def generate_kernel(integrals):
subdomain_id = integral.subdomain_id()
subdomain_data = integral.subdomain_data()
from dune.perftool.ufl.dimensionindex import collect_dimension_index_aliases
dimension_index_aliases = collect_dimension_index_aliases(integrand)
# Now split the given integrand into accumulation expressions
from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
accterms = split_into_accumulation_terms(integrand)
# Get a transformer instance for this kernel
from dune.perftool.loopy.transformer import UFL2LoopyVisitor
visitor = UFL2LoopyVisitor(measure, subdomain_id)
visitor = UFL2LoopyVisitor(measure, subdomain_id, dimension_index_aliases)
# Iterate over the terms and generate a kernel
for term in accterms:
......
""" Extract all the aliases of dimension indices """
from ufl.algorithms import MultiFunction
class _CollectDimensionIndexAliases(MultiFunction):
call = MultiFunction.__call__
def __call__(self, o):
self.shape = 0
return self.call(o)
def expr(self, o):
return frozenset({}).union(*tuple(self.call(op) for op in o.ufl_operands))
def terminal(self, o):
return frozenset({})
def function_view(self, o):
self.shape = len(o.ufl_operands[1])
return frozenset({})
def indexed(self, o):
ret = self.call(o.ufl_operands[0])
if self.shape:
ret = ret.union(frozenset({o.ufl_operands[1][:self.shape][0]}))
self.shape = 0
return ret
def collect_dimension_index_aliases(expr):
return _CollectDimensionIndexAliases()(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment