Skip to content
Snippets Groups Projects
Commit ef543709 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

More fixes on new splitting implementation

and remove old implementation
parent 99d5562e
No related branches found
No related tags found
No related merge requests found
*~
.cache/*
......@@ -44,6 +44,49 @@ class ModifiedTerminalTracker(MultiFunction):
return ret
class ModifiedArgumentDescriptor(MultiFunction):
def __init__(self, e):
MultiFunction.__init__(self)
self.grad = False
self.reference_grad = False
self.index = None
self.restriction = Restriction.NONE
self.expr = e
self.__call__(e)
self.__call__ = None
def __eq__(self, other):
return self.expr == other.expr
def grad(self, o):
self.grad = True
self(o.ufl_operands[0])
def reference_grad(self, o):
self.reference_grad = True
self(o.ufl_operands[0])
def positive_restricted(self, o):
self.restriction = Restriction.POSITIVE
self(o.ufl_operands[0])
def negative_restricted(self, o):
self.restriction = Restriction.NEGATIVE
self(o.ufl_operands[0])
def indexed(self, o):
indexed = o.ufl_operands[1]
self(o.ufl_operands[0])
def argument(self, o):
self.argexpr = o
def coefficient(self, o):
self.argexpr = o
class _ModifiedArgumentExtractor(MultiFunction):
""" A multifunction that extracts and returns the set of modified arguments """
......@@ -55,7 +98,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
if ret:
# This indicates that this entire expression was a modified thing...
self.modified_arguments.add(ret)
return tuple(self.modified_arguments)
return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments)
def expr(self, o):
for op in o.ufl_operands:
......@@ -112,42 +155,3 @@ class _ModifiedArgumentNumber(MultiFunction):
def modified_argument_number(expr):
""" Given an expression, return the number() of the argument in it """
return _ModifiedArgumentNumber()(expr)
class ModifiedArgumentDescriptor(MultiFunction):
def __init__(self, e):
MultiFunction.__init__(self)
self.grad = False
self.reference_grad = False
self.index = None
self.restriction = Restriction.NONE
self.__call__(e)
self.__call__ = None
def grad(self, o):
self.grad = True
self(o.ufl_operands[0])
def reference_grad(self, o):
self.reference_grad = True
self(o.ufl_operands[0])
def positive_restricted(self, o):
self.restriction = Restriction.POSITIVE
self(o.ufl_operands[0])
def negative_restricted(self, o):
self.restriction = Restriction.NEGATIVE
self(o.ufl_operands[0])
def indexed(self, o):
indexed = o.ufl_operands[1]
self(o.ufl_operands[0])
def argument(self, o):
self.expr = o
def coefficient(self, o):
self.expr = o
"""
Define an UFL MultiFunction that removes the modified arguments from
a given UFL expression(, which represents one accumumlation term).
Examples:
e1=e2*v => (e2, (v, ))
e1=e2*v*w => (e2, (v, w))
Note that in PDELab, only test functions are arguments!
Trial functions are coefficients instead.
"""
# from __future__ import absolute_import
# from ufl.algorithms import MultiFunction
#
#
# class EliminateArguments(MultiFunction):
# """ This MultiFunction processes the expression bottom up and replaces
# all modified argument by None and eliminates all None-Terms afterwards.
# """
# call = MultiFunction.__call__
#
# def __call__(self, o):
# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor
#
# self.arguments = ModifiedArgumentExtractor()(o)
# e = self.call(o)
#
# # Catch the case that the entire expression vanished!
# if e is None:
# from ufl.classes import IntValue
# e = IntValue(1)
#
# return (e, self.arguments)
#
# def expr(self, o):
# if o in self.arguments:
# return None
# else:
# # Evaluate the multi function applied to the operands
# newop = tuple(self.call(op) for op in o.ufl_operands)
# # Find out whether an operand vanished. If so, the class needs special treatment.
# if None in newop:
# raise NotImplementedError("Operand vanished: {} needs special treatment in EliminateArguments".format(type(o)))
# return self.reuse_if_untouched(o, *newop)
#
# def sum(self, o):
# assert len(o.ufl_operands) == 2
#
# op0 = self.call(o.ufl_operands[0])
# op1 = self.call(o.ufl_operands[1])
#
# if op0 and op1:
# return self.reuse_if_untouched(o, op0, op1)
# # One term vanished, so there is no sum anymore
# else:
# if op0 or op1:
# # Return the term that did not vanish
# return op0 if op0 else op1
# else:
# # This entire sum vanished!!!
# return None
#
# # The handler for product is equal to the sum handler
# product = sum
#
# def index_sum(self, o):
# op = self.call(o.ufl_operands[0])
# if op:
# return self.reuse_if_untouched(o, op, o.ufl_operands[1])
# else:
# return None
......@@ -7,7 +7,22 @@ from __future__ import absolute_import
from dune.perftool.ufl.modified_terminals import extract_modified_arguments
from dune.perftool.ufl.transformations import ufl_transformation
from dune.perftool.ufl.transformations.replace import replace_expression
from ufl.algorithms import MultiFunction
from ufl.classes import Zero, IntValue
import itertools
class _ReplacementDict(dict):
def __init__(self, *mod_args):
dict.__init__(self)
for ma in mod_args:
self[ma] = IntValue(1)
def __getitem__(self, key):
return self.get(key, Zero())
@ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l])
......@@ -15,80 +30,23 @@ def split_into_accumulation_terms(expr):
mod_args = extract_modified_arguments(expr)
accumulation_terms = []
for arg in mod_args:
from dune.perftool.ufl.transformations.replace import replace_expression
from ufl.classes import Zero, IntValue
# Define a replacement map that maps the given arg to 1 and the rest to 0
rmap = {ma: Zero() for ma in mod_args}
rmap[arg] = IntValue(1)
# Do the replacement on the expression
accum_expr = replace_expression(expr, rmap)
# Store the foudn accumulation expression
accumulation_terms.append((accum_expr, arg))
# Treat the case of a rank 1 form:
if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0:
for arg in mod_args:
# Do the replacement on the expression
accum_expr = replace_expression(expr, _ReplacementDict(arg))
return accumulation_terms
# Store the found accumulation expression
accumulation_terms.append((accum_expr, (arg,)))
# and now the case of a rank 2 form:
else:
for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args),
filter(lambda ma: ma.argexpr.count() == 1, mod_args)
):
accum_expr = replace_expression(expr, _ReplacementDict(arg1, arg2))
accumulation_terms.append((accum_expr, (arg1, arg2)))
# class SplitIntoAccumulationTerms(MultiFunction):
# """ return a list of tuples of expressions and modified arguments! """
#
# call = MultiFunction.__call__
#
# def __call__(self, expr):
# from dune.perftool.ufl.rank import ufl_rank
# self.rank = ufl_rank(expr)
#
# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor
# self.mae = ModifiedArgumentExtractor()
#
# from dune.perftool.ufl.transformations.argument_elimination import EliminateArguments
# self.ea = EliminateArguments()
#
# # Collect the found terms
# self.terms = {}
#
# # Now, fill the terms dict by applying this multifunction!
# self.call(expr)
#
# from dune.perftool.ufl.flatoperators import construct_binary_operator
# from ufl.classes import Sum
# return [(construct_binary_operator(t, Sum), a) for a, t in self.terms.items()]
#
# def expr(self, o):
# # This needs to be a valid accumulation term!
# assert all(len(self.mae(o, i)) == 1 for i in range(self.rank))
#
# expr, arg = self.ea(o)
# if arg not in self.terms:
# self.terms[arg] = []
# self.terms[arg].append(expr)
#
# def sum(self, o):
# # Check whether this sums contains too many accumulation terms.
# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)):
# # This sum is part of a top level sum that separates accumulation terms!
# for op in o.ufl_operands:
# self.call(op)
# else:
# # This is a normal sum, we might treat it as any other expression
# self.expr(o)
#
# def index_sum(self, o):
# # Check whether this sums contains too many accumulation terms.
# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)):
# # This sum is part of a top level sum that separates accumulation terms!
# self.call(o.ufl_operands[0])
# else:
# # This is a normal sum, we might treat it as any other expression
# # TODO we need to eliminate topsum indexsum regardless of the thing being valid
# self.expr(o.ufl_operands[0])
# # old code
# #self.expr(o)
#
#
# from dune.perftool.ufl.transformations import ufl_transformation
# @ufl_transformation(name="accterms", extraction_lambda=lambda l:[i[0] for i in l])
# def split_into_accumulation_terms2(expr):
# return SplitIntoAccumulationTerms()(expr)
# and return the result
return accumulation_terms
# from __future__ import absolute_import
# from ufl.algorithms import MultiFunction
# from dune.perftool.ufl.transformations import ufl_transformation
#
# class SplitArguments(MultiFunction):
#
# call = MultiFunction.__call__
#
# def __init__(self):
# MultiFunction.__init__(self)
#
# def __call__(self, o):
# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor
# self.ae = ModifiedArgumentExtractor()
# from dune.perftool.ufl.rank import ufl_rank
# self.rank = ufl_rank(o)
# # call the actual recursive function
# return self.call(o)
#
# def expr(self, o):
# return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
#
# def product(self, o):
# from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator
# from itertools import product as iterproduct
# from ufl.classes import Sum, Product
#
# # Check whether this product is fine!
# if len(self.ae(o)) == self.rank:
# return self.reuse_if_untouched(o, *(self.call(op) for op in o.ufl_operands))
#
# # It is not, lets switch sums and products!
# # First we apply recursively to all
# product_operands = [get_operands(self.call(op)) if isinstance(op, Sum) else (self.call(op),) for op in get_operands(o)]
# # Multiply all terms by taking the cartesian product of terms
# distributive = [f for f in iterproduct(*product_operands)]
# # Prepare all sum terms by introducing products
# sum_terms = [construct_binary_operator(s, Product) for s in distributive]
# # Return the big sum.
# return construct_binary_operator(sum_terms, Sum)
#
#
# @ufl_transformation(name="split")
# def split_arguments(expr):
# return SplitArguments()(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment