From 6b21e04da600ecf2f23418c837bad5380817c484 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.r.kempf@gmail.com> Date: Thu, 10 Sep 2015 17:45:48 +0200 Subject: [PATCH] Update the validation checker and introduce ufl subpackage --- python/dune/perftool/compile.py | 20 +++++- python/dune/perftool/transformer.py | 10 +-- python/dune/perftool/ufl/__init__.py | 0 .../dune/perftool/ufl/modified_terminals.py | 36 ++++++++++ python/dune/perftool/ufl/validity.py | 47 ++++++++++++++ python/dune/perftool/validity.py | 65 ------------------- 6 files changed, 108 insertions(+), 70 deletions(-) create mode 100644 python/dune/perftool/ufl/__init__.py create mode 100644 python/dune/perftool/ufl/modified_terminals.py create mode 100644 python/dune/perftool/ufl/validity.py delete mode 100644 python/dune/perftool/validity.py diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index e480fd50..e3124f69 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -11,9 +11,27 @@ def read_ufl(uflfile): uflcode = read_ufl_file(uflfile) namespace = {} uflcode = "from dune.perftool.runtime_ufl import *\n" + uflcode - exec uflcode in namespace + try: + exec uflcode in namespace + except: + import os + basename = os.path.splitext(os.path.basename(uflfile[0]))[0] + basename = "{}_debug".format(basename) + pyname = "{}.py".format(basename) + pycode = "#!/usr/bin/env python\nfrom dune.perftool.runtime_ufl import *\nset_level(DEBUG)\n" + uflcode + with file(pyname, "w") as f: + f.write(pycode) + raise SyntaxError("Not a valid ufl file, dumped a debug script: {}".format(pyname)) data = interpret_ufl_namespace(namespace) formdata = compute_form_data(data.forms[0]) + + # We do not expect more than one form + assert len(data.forms) == 1 + + # We make some assumptions on the UFL expression! + from dune.perftool.validity import check_validity + check_validity(data.forms[0]) + return formdata diff --git a/python/dune/perftool/transformer.py b/python/dune/perftool/transformer.py index 05da1205..903351ca 100644 --- a/python/dune/perftool/transformer.py +++ b/python/dune/perftool/transformer.py @@ -165,13 +165,14 @@ class UFLVisitor(MultiFunction): class TopSumSeparation(MultiFunction): """ A multifunction that separates the toplevel sum """ - def __init__(self): + def __init__(self, visitor=UFLVisitor()): MultiFunction.__init__(self) - self.visitor = UFLVisitor() + self.visitor = visitor self.inames = [] def expr(self, o): self.visitor(o, self.inames) + self.inames = [] def sum(self, o): for op in o.operands(): @@ -180,9 +181,10 @@ class TopSumSeparation(MultiFunction): def index_sum(self, o): # Generate an iname: We do this here and restrict ourselves to shape dim. # TODO revisit when implementing general index sums - self.inames.append(dimension_iname(o.operands()[1])) + if isinstance(self.visitor, UFLVisitor): + self.inames.append(dimension_iname(o.operands()[1])) self(o.operands()[0]) def transform_expression(expr): - return TopSumSeparation()(expr) + TopSumSeparation()(expr) diff --git a/python/dune/perftool/ufl/__init__.py b/python/dune/perftool/ufl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py new file mode 100644 index 00000000..b8c3900c --- /dev/null +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -0,0 +1,36 @@ +""" A module mimicking some functionality of uflacs' modified terminals """ + +from ufl.algorithms import MultiFunction +from dune.perftool.restriction import Restriction + + +class ModifiedTerminalTracker(MultiFunction): + def __init__(self): + MultiFunction.__init__(self) + self.grad = False + self.reference_grad = False + self.restriction = Restriction.NONE + + def positive_restricted(self, o): + self.restriction = Restriction.POSITIVE + ret = self.expr(o) + self.restriction = Restriction.NONE + return ret + + def negative_restricted(self, o): + self.restriction = Restriction.NEGATIVE + ret = self.expr(o) + self.restriction = Restriction.NONE + return ret + + def grad(self, o): + self.grad = True + ret = self.expr(o) + self.grad = False + return ret + + def reference_grad(self, o): + self.reference_grad = True + ret = self.expr(o) + self.reference_grad = False + return ret diff --git a/python/dune/perftool/ufl/validity.py b/python/dune/perftool/ufl/validity.py new file mode 100644 index 00000000..8690adb0 --- /dev/null +++ b/python/dune/perftool/ufl/validity.py @@ -0,0 +1,47 @@ +from ufl.algorithms import MultiFunction +from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker + + +class UFLRank(MultiFunction): + def __call__(self, expr): + return len(MultiFunction.__call__(self, expr)) + + def expr(self, o): + return set(a for op in o.operands() for a in MultiFunction.__call__(self, op)) + + def argument(self, o): + return (o.number(),) + + +class ArgumentCounter(ModifiedTerminalTracker): + def __call__(self, o, mock): + if not len(MultiFunction.__call__(self, o)) == UFLRank()(o): + raise ValueError("Interal Error: The transformed form violated the pdelab generation assumptions. Please send your ufl file to dominic.kempf@iwr.uni-heidelberg.de") + + def expr(self, o): + return set(a for op in o.operands() for a in MultiFunction.__call__(self, op)) + + def argument(self, o): + return ((o.number(), self.grad, self.reference_grad, self.restriction),) + + +def check_validity(uflexpr): + """ check an UFL expression (usually an integrand) for + compatibility with the dune-pdelab code generation tool chain. + + The assumptions made are the following: + - The expression is the sum of subtrees, where the number of test + function terms in such subtree is equal the rank of the entire expression. + """ + from ufl import Form + from ufl.classes import Expr + from dune.perftool.transformer import TopSumSeparation + + if isinstance(uflexpr, Form): + for integral in uflexpr.integrals(): + return TopSumSeparation(visitor=ArgumentCounter())(integral.integrand()) + + if isinstance(uflexpr, Expr): + return TopSumSeparation(visitor=ArgumentCounter())(uflexpr) + + raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr))) diff --git a/python/dune/perftool/validity.py b/python/dune/perftool/validity.py deleted file mode 100644 index b88b6881..00000000 --- a/python/dune/perftool/validity.py +++ /dev/null @@ -1,65 +0,0 @@ -from ufl.classes import Sum -from ufl.algorithms import MultiFunction - - -class UFLRank(MultiFunction): - def __call__(self, expr): - return len(MultiFunction.__call__(self, expr)) - - def expr(self, o): - return set(a for op in o.operands() for a in MultiFunction.__call__(self, op)) - - def argument(self, o): - return (o.number(),) - - -class UFLValidityChecker(MultiFunction): - def __init__(self): - MultiFunction.__init__(self) - self.rankCounter = UFLRank() - - def __call__(self, expr, topsum=True): - self.rank = self.rankCounter(expr) - self.sane = None - ret = MultiFunction.__call__(self, expr, topsum) - if self.sane is None: - # In this case, we did not encounter a sum and need to compare with the entire expression - self.sane = len(ret) == self.rank - return self.sane - - # define the default behaviour: just call the multifunction - # recursively and combine the sets of found arguments - def expr(self, o, topsum): - return set(a for op in o.operands() for a in MultiFunction.__call__(self, op, False)) - - def argument(self, o, topsum): - return (o.number(),) - - def sum(self, o, topsum): - if topsum: - for op in o.operands(): - if not isinstance(op, Sum): - self.sane = len(MultiFunction.__call__(self, op, True)) == self.rank - else: - # If this is sum is not part of the topsum, we treat it as all other expressions - return self.expr(o, topsum) - - def index_sum(self, o, topsum): - assert len(o.operands()) == 2 - if topsum: - op = o.operands()[0] - if not isinstance(op, Sum): - self.sane = len(MultiFunction.__call__(self, op, True)) == self.rank - else: - return self.expr(o, topsum) - - -def check_validity(uflexpr): - """ check an UFL expression (usually an integrand) for - compatibility with the dune-pdelab code generation tool chain. - - The assumptions made are the following: - - The expression is the sum of subtrees, where the number of test - function terms in such subtree is equal the rank of the entire expression. - """ - return UFLValidityChecker()(uflexpr) \ No newline at end of file -- GitLab