Skip to content
Snippets Groups Projects
Commit cb3390bd authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Do not remove mod args from the expression anymore

UFL did not like  my way of replacing some args with 1 as it
destroyed the shape information. It now remains in the expression,
but triggers code generation separately.
parent c2bb0616
No related branches found
No related tags found
No related merge requests found
......@@ -79,25 +79,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
return Variable(name_facetarea())
def get_pymbolic_expr(expr):
""" Transform the given UFL expression into a pymbolic expression
and have all sorts of side effects on the generation cache. """
# We do need to manually handle modified terminals related to trial functions
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.ufl.transformations.replace import ReplaceExpression
from dune.perftool.pdelab.argument import name_trialfunction
from pymbolic.primitives import Variable
trial_ma = extract_modified_arguments(expr, trialfunction=True)
# OLD CODE had: globalarg(name)
rmap = {ma.expr: Variable(name_trialfunction(ma)) for ma in trial_ma}
ufl2l_mf = UFL2LoopyVisitor()
re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf)
ufl2l_mf.call = re_mf.__call__
return re_mf(expr)
class _Counter:
counter = 0
......@@ -109,15 +90,32 @@ def get_count():
def transform_accumulation_term(term):
# Get the accumulation expression and the modified arguments
expr, args = term
from dune.perftool.ufl.transformations.replace import ReplaceExpression
from pymbolic.primitives import Variable
# We always have a quadrature loop
quadrature_iname()
# Get the pymbolic expression needed for this accumulation term.
# This includes filling the cache with all sorts of necessary preambles!
pymbolic_expr = get_pymbolic_expr(expr)
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
test_ma = extract_modified_arguments(term, trialfunction=False, testfunction=True)
trial_ma = extract_modified_arguments(term, trialfunction=True, testfunction=False)
rmap = {}
for ma in test_ma:
from dune.perftool.pdelab.argument import name_testfunction
rmap[ma.expr] = Variable(name_testfunction(ma))
for ma in trial_ma:
from dune.perftool.pdelab.argument import name_trialfunction
rmap[ma.expr] = Variable(name_trialfunction(ma))
# Get the transformer!
ufl2l_mf = UFL2LoopyVisitor()
re_mf = ReplaceExpression(replacemap=rmap, otherwise=ufl2l_mf)
ufl2l_mf.call = re_mf.__call__
pymbolic_expr = re_mf(term)
# Now simplify the expression
# TODO: Add a switch to disable/configure this.
......@@ -132,16 +130,14 @@ def transform_accumulation_term(term):
# The data that is used to collect the arguments for the accumulate function
accumargs = []
argument_code = []
# Generate the code for the modified arguments:
for arg in args:
for arg in test_ma:
from dune.perftool.pdelab.argument import name_argumentspace, name_argument
accumargs.append(name_argumentspace(arg))
accumargs.append(argument_iname(arg))
name = name_argument(arg)
argument_code.append(name)
globalarg(name)
# TODO is this global
#globalarg(argument_iname(arg)+"_n")
from dune.perftool.pdelab.argument import name_residual
residual = name_residual()
......@@ -151,10 +147,9 @@ def transform_accumulation_term(term):
from dune.perftool.pdelab.quadrature import name_factor
c_instruction(loopy.CInstruction(inames,
"{}.accumulate({}, {}*{}*{})".format(residual,
"{}.accumulate({}, {}*{})".format(residual,
", ".join(accumargs),
expr_tv_name,
"*".join(argument_code),
name_factor()
)
)
......
......@@ -8,7 +8,7 @@ def quadrature_rule():
return "rule"
@quadrature_preamble(assignees="fac")
@quadrature_preamble()
def define_quadrature_factor(fac):
rule = quadrature_rule()
return "auto {} = {}->weight();".format(fac, rule)
......
......@@ -90,9 +90,10 @@ class ModifiedArgumentDescriptor(MultiFunction):
class _ModifiedArgumentExtractor(MultiFunction):
""" A multifunction that extracts and returns the set of modified arguments """
def __call__(self, o, argnumber=None, trialfunction=False):
def __call__(self, o, argnumber=None, testfunction=True, trialfunction=False):
self.argnumber = argnumber
self.trialfunction = trialfunction
self.testfunction = testfunction
self.modified_arguments = set()
ret = self.call(o)
if ret:
......@@ -127,7 +128,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
return o
def argument(self, o):
if not self.trialfunction:
if self.testfunction:
if self.argnumber is None or o.number() == self.argnumber:
return o
......
......@@ -48,11 +48,7 @@ class UFLTransformationWrapper(object):
# We do also assume that the transformation returns an ufl expression or a list there of
ret_for_print = self.extractExpressionListFromResult(ret)
try:
assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print)
except AssertionError:
from IPython import embed
embed()
assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print)
# Maybe output the returned expression
self.write_trafo(ret_for_print, False)
......
......@@ -10,22 +10,22 @@ from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression
from ufl.algorithms import MultiFunction
from ufl.classes import Zero, IntValue
from ufl.classes import Zero
import itertools
class _ReplacementDict(dict):
def __init__(self, *mod_args):
def __init__(self, *args):
dict.__init__(self)
for ma in mod_args:
self[ma] = IntValue(1)
for a in args:
self[a] = a
def __getitem__(self, key):
return self.get(key, Zero())
@ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l])
@ufl_transformation(name="accterms2", extraction_lambda=lambda l: l)
def split_into_accumulation_terms(expr):
mod_args = extract_modified_arguments(expr)
......@@ -35,18 +35,13 @@ def split_into_accumulation_terms(expr):
if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0:
for arg in mod_args:
# Do the replacement on the expression
accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg.expr))
# Store the found accumulation expression
accumulation_terms.append((accum_expr, (arg,)))
accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg.expr)))
# and now the case of a rank 2 form:
else:
for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args),
filter(lambda ma: ma.argexpr.count() == 1, mod_args)
):
accum_expr = replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr))
accumulation_terms.append((accum_expr, (arg1, arg2)))
accumulation_terms.append(replace_expression(expr, replacemap=_ReplacementDict(arg1.expr, arg2.expr)))
# and return the result
return accumulation_terms
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment