From ef543709e75c3aa3e2a414b1e1f57bf9c357d88a Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 24 Feb 2016 13:31:48 +0100 Subject: [PATCH] More fixes on new splitting implementation and remove old implementation --- .gitignore | 1 + .../dune/perftool/ufl/modified_terminals.py | 84 +++++++------- .../transformations/argument_elimination.py | 71 ------------ .../extract_accumulation_terms.py | 104 ++++++------------ .../ufl/transformations/splitarguments.py | 46 -------- 5 files changed, 76 insertions(+), 230 deletions(-) delete mode 100644 python/dune/perftool/ufl/transformations/argument_elimination.py delete mode 100644 python/dune/perftool/ufl/transformations/splitarguments.py diff --git a/.gitignore b/.gitignore index b25c15b8..55c65aee 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *~ +.cache/* diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 9c2ce6e8..9bc23612 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -44,6 +44,49 @@ class ModifiedTerminalTracker(MultiFunction): return ret +class ModifiedArgumentDescriptor(MultiFunction): + def __init__(self, e): + MultiFunction.__init__(self) + + self.grad = False + self.reference_grad = False + self.index = None + self.restriction = Restriction.NONE + self.expr = e + + self.__call__(e) + self.__call__ = None + + def __eq__(self, other): + return self.expr == other.expr + + def grad(self, o): + self.grad = True + self(o.ufl_operands[0]) + + def reference_grad(self, o): + self.reference_grad = True + self(o.ufl_operands[0]) + + def positive_restricted(self, o): + self.restriction = Restriction.POSITIVE + self(o.ufl_operands[0]) + + def negative_restricted(self, o): + self.restriction = Restriction.NEGATIVE + self(o.ufl_operands[0]) + + def indexed(self, o): + indexed = o.ufl_operands[1] + self(o.ufl_operands[0]) + + def argument(self, o): + self.argexpr = o + + def coefficient(self, o): + self.argexpr = o + + class _ModifiedArgumentExtractor(MultiFunction): """ A multifunction that extracts and returns the set of modified arguments """ @@ -55,7 +98,7 @@ class _ModifiedArgumentExtractor(MultiFunction): if ret: # This indicates that this entire expression was a modified thing... self.modified_arguments.add(ret) - return tuple(self.modified_arguments) + return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments) def expr(self, o): for op in o.ufl_operands: @@ -112,42 +155,3 @@ class _ModifiedArgumentNumber(MultiFunction): def modified_argument_number(expr): """ Given an expression, return the number() of the argument in it """ return _ModifiedArgumentNumber()(expr) - - -class ModifiedArgumentDescriptor(MultiFunction): - def __init__(self, e): - MultiFunction.__init__(self) - - self.grad = False - self.reference_grad = False - self.index = None - self.restriction = Restriction.NONE - - self.__call__(e) - self.__call__ = None - - def grad(self, o): - self.grad = True - self(o.ufl_operands[0]) - - def reference_grad(self, o): - self.reference_grad = True - self(o.ufl_operands[0]) - - def positive_restricted(self, o): - self.restriction = Restriction.POSITIVE - self(o.ufl_operands[0]) - - def negative_restricted(self, o): - self.restriction = Restriction.NEGATIVE - self(o.ufl_operands[0]) - - def indexed(self, o): - indexed = o.ufl_operands[1] - self(o.ufl_operands[0]) - - def argument(self, o): - self.expr = o - - def coefficient(self, o): - self.expr = o diff --git a/python/dune/perftool/ufl/transformations/argument_elimination.py b/python/dune/perftool/ufl/transformations/argument_elimination.py deleted file mode 100644 index 4d380009..00000000 --- a/python/dune/perftool/ufl/transformations/argument_elimination.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -Define an UFL MultiFunction that removes the modified arguments from -a given UFL expression(, which represents one accumumlation term). - -Examples: -e1=e2*v => (e2, (v, )) -e1=e2*v*w => (e2, (v, w)) - -Note that in PDELab, only test functions are arguments! -Trial functions are coefficients instead. -""" -# from __future__ import absolute_import -# from ufl.algorithms import MultiFunction -# -# -# class EliminateArguments(MultiFunction): -# """ This MultiFunction processes the expression bottom up and replaces -# all modified argument by None and eliminates all None-Terms afterwards. -# """ -# call = MultiFunction.__call__ -# -# def __call__(self, o): -# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor -# -# self.arguments = ModifiedArgumentExtractor()(o) -# e = self.call(o) -# -# # Catch the case that the entire expression vanished! -# if e is None: -# from ufl.classes import IntValue -# e = IntValue(1) -# -# return (e, self.arguments) -# -# def expr(self, o): -# if o in self.arguments: -# return None -# else: -# # Evaluate the multi function applied to the operands -# newop = tuple(self.call(op) for op in o.ufl_operands) -# # Find out whether an operand vanished. If so, the class needs special treatment. -# if None in newop: -# raise NotImplementedError("Operand vanished: {} needs special treatment in EliminateArguments".format(type(o))) -# return self.reuse_if_untouched(o, *newop) -# -# def sum(self, o): -# assert len(o.ufl_operands) == 2 -# -# op0 = self.call(o.ufl_operands[0]) -# op1 = self.call(o.ufl_operands[1]) -# -# if op0 and op1: -# return self.reuse_if_untouched(o, op0, op1) -# # One term vanished, so there is no sum anymore -# else: -# if op0 or op1: -# # Return the term that did not vanish -# return op0 if op0 else op1 -# else: -# # This entire sum vanished!!! -# return None -# -# # The handler for product is equal to the sum handler -# product = sum -# -# def index_sum(self, o): -# op = self.call(o.ufl_operands[0]) -# if op: -# return self.reuse_if_untouched(o, op, o.ufl_operands[1]) -# else: -# return None diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 4b186da5..290df135 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -7,7 +7,22 @@ from __future__ import absolute_import from dune.perftool.ufl.modified_terminals import extract_modified_arguments from dune.perftool.ufl.transformations import ufl_transformation +from dune.perftool.ufl.transformations.replace import replace_expression + from ufl.algorithms import MultiFunction +from ufl.classes import Zero, IntValue + +import itertools + + +class _ReplacementDict(dict): + def __init__(self, *mod_args): + dict.__init__(self) + for ma in mod_args: + self[ma] = IntValue(1) + + def __getitem__(self, key): + return self.get(key, Zero()) @ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l]) @@ -15,80 +30,23 @@ def split_into_accumulation_terms(expr): mod_args = extract_modified_arguments(expr) accumulation_terms = [] - for arg in mod_args: - from dune.perftool.ufl.transformations.replace import replace_expression - from ufl.classes import Zero, IntValue - # Define a replacement map that maps the given arg to 1 and the rest to 0 - rmap = {ma: Zero() for ma in mod_args} - rmap[arg] = IntValue(1) - - # Do the replacement on the expression - accum_expr = replace_expression(expr, rmap) - # Store the foudn accumulation expression - accumulation_terms.append((accum_expr, arg)) + # Treat the case of a rank 1 form: + if len(filter(lambda ma: ma.argexpr.count() == 1, mod_args)) == 0: + for arg in mod_args: + # Do the replacement on the expression + accum_expr = replace_expression(expr, _ReplacementDict(arg)) - return accumulation_terms + # Store the found accumulation expression + accumulation_terms.append((accum_expr, (arg,))) + # and now the case of a rank 2 form: + else: + for arg1, arg2 in itertools.product(filter(lambda ma: ma.argexpr.count() == 0, mod_args), + filter(lambda ma: ma.argexpr.count() == 1, mod_args) + ): + accum_expr = replace_expression(expr, _ReplacementDict(arg1, arg2)) + accumulation_terms.append((accum_expr, (arg1, arg2))) -# class SplitIntoAccumulationTerms(MultiFunction): -# """ return a list of tuples of expressions and modified arguments! """ -# -# call = MultiFunction.__call__ -# -# def __call__(self, expr): -# from dune.perftool.ufl.rank import ufl_rank -# self.rank = ufl_rank(expr) -# -# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor -# self.mae = ModifiedArgumentExtractor() -# -# from dune.perftool.ufl.transformations.argument_elimination import EliminateArguments -# self.ea = EliminateArguments() -# -# # Collect the found terms -# self.terms = {} -# -# # Now, fill the terms dict by applying this multifunction! -# self.call(expr) -# -# from dune.perftool.ufl.flatoperators import construct_binary_operator -# from ufl.classes import Sum -# return [(construct_binary_operator(t, Sum), a) for a, t in self.terms.items()] -# -# def expr(self, o): -# # This needs to be a valid accumulation term! -# assert all(len(self.mae(o, i)) == 1 for i in range(self.rank)) -# -# expr, arg = self.ea(o) -# if arg not in self.terms: -# self.terms[arg] = [] -# self.terms[arg].append(expr) -# -# def sum(self, o): -# # Check whether this sums contains too many accumulation terms. -# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): -# # This sum is part of a top level sum that separates accumulation terms! -# for op in o.ufl_operands: -# self.call(op) -# else: -# # This is a normal sum, we might treat it as any other expression -# self.expr(o) -# -# def index_sum(self, o): -# # Check whether this sums contains too many accumulation terms. -# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): -# # This sum is part of a top level sum that separates accumulation terms! -# self.call(o.ufl_operands[0]) -# else: -# # This is a normal sum, we might treat it as any other expression -# # TODO we need to eliminate topsum indexsum regardless of the thing being valid -# self.expr(o.ufl_operands[0]) -# # old code -# #self.expr(o) -# -# -# from dune.perftool.ufl.transformations import ufl_transformation -# @ufl_transformation(name="accterms", extraction_lambda=lambda l:[i[0] for i in l]) -# def split_into_accumulation_terms2(expr): -# return SplitIntoAccumulationTerms()(expr) + # and return the result + return accumulation_terms diff --git a/python/dune/perftool/ufl/transformations/splitarguments.py b/python/dune/perftool/ufl/transformations/splitarguments.py deleted file mode 100644 index bbc89545..00000000 --- a/python/dune/perftool/ufl/transformations/splitarguments.py +++ /dev/null @@ -1,46 +0,0 @@ -# from __future__ import absolute_import -# from ufl.algorithms import MultiFunction -# from dune.perftool.ufl.transformations import ufl_transformation - -# -# class SplitArguments(MultiFunction): -# -# call = MultiFunction.__call__ -# -# def __init__(self): -# MultiFunction.__init__(self) -# -# def __call__(self, o): -# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor -# self.ae = ModifiedArgumentExtractor() -# from dune.perftool.ufl.rank import ufl_rank -# self.rank = ufl_rank(o) -# # call the actual recursive function -# return self.call(o) -# -# def expr(self, o): -# return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) -# -# def product(self, o): -# from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator -# from itertools import product as iterproduct -# from ufl.classes import Sum, Product -# -# # Check whether this product is fine! -# if len(self.ae(o)) == self.rank: -# return self.reuse_if_untouched(o, *(self.call(op) for op in o.ufl_operands)) -# -# # It is not, lets switch sums and products! -# # First we apply recursively to all -# product_operands = [get_operands(self.call(op)) if isinstance(op, Sum) else (self.call(op),) for op in get_operands(o)] -# # Multiply all terms by taking the cartesian product of terms -# distributive = [f for f in iterproduct(*product_operands)] -# # Prepare all sum terms by introducing products -# sum_terms = [construct_binary_operator(s, Product) for s in distributive] -# # Return the big sum. -# return construct_binary_operator(sum_terms, Sum) -# -# -# @ufl_transformation(name="split") -# def split_arguments(expr): -# return SplitArguments()(expr) -- GitLab