From 6b21e04da600ecf2f23418c837bad5380817c484 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.r.kempf@gmail.com>
Date: Thu, 10 Sep 2015 17:45:48 +0200
Subject: [PATCH] Update the validation checker and introduce ufl subpackage

---
 python/dune/perftool/compile.py               | 20 +++++-
 python/dune/perftool/transformer.py           | 10 +--
 python/dune/perftool/ufl/__init__.py          |  0
 .../dune/perftool/ufl/modified_terminals.py   | 36 ++++++++++
 python/dune/perftool/ufl/validity.py          | 47 ++++++++++++++
 python/dune/perftool/validity.py              | 65 -------------------
 6 files changed, 108 insertions(+), 70 deletions(-)
 create mode 100644 python/dune/perftool/ufl/__init__.py
 create mode 100644 python/dune/perftool/ufl/modified_terminals.py
 create mode 100644 python/dune/perftool/ufl/validity.py
 delete mode 100644 python/dune/perftool/validity.py

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index e480fd50..e3124f69 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -11,9 +11,27 @@ def read_ufl(uflfile):
     uflcode = read_ufl_file(uflfile)
     namespace = {}
     uflcode = "from dune.perftool.runtime_ufl import *\n" + uflcode
-    exec uflcode in namespace
+    try:
+        exec uflcode in namespace
+    except:
+        import os
+        basename = os.path.splitext(os.path.basename(uflfile[0]))[0]
+        basename = "{}_debug".format(basename)
+        pyname = "{}.py".format(basename)
+        pycode = "#!/usr/bin/env python\nfrom dune.perftool.runtime_ufl import *\nset_level(DEBUG)\n" + uflcode
+        with file(pyname, "w") as f:
+            f.write(pycode)
+        raise SyntaxError("Not a valid ufl file, dumped a debug script: {}".format(pyname))
     data = interpret_ufl_namespace(namespace)
     formdata = compute_form_data(data.forms[0])
+
+    # We do not expect more than one form
+    assert len(data.forms) == 1
+
+    # We make some assumptions on the UFL expression!
+    from dune.perftool.validity import check_validity
+    check_validity(data.forms[0])
+
     return formdata
 
 
diff --git a/python/dune/perftool/transformer.py b/python/dune/perftool/transformer.py
index 05da1205..903351ca 100644
--- a/python/dune/perftool/transformer.py
+++ b/python/dune/perftool/transformer.py
@@ -165,13 +165,14 @@ class UFLVisitor(MultiFunction):
 
 class TopSumSeparation(MultiFunction):
     """ A multifunction that separates the toplevel sum """
-    def __init__(self):
+    def __init__(self, visitor=UFLVisitor()):
         MultiFunction.__init__(self)
-        self.visitor = UFLVisitor()
+        self.visitor = visitor
         self.inames = []
 
     def expr(self, o):
         self.visitor(o, self.inames)
+        self.inames = []
 
     def sum(self, o):
         for op in o.operands():
@@ -180,9 +181,10 @@ class TopSumSeparation(MultiFunction):
     def index_sum(self, o):
         # Generate an iname: We do this here and restrict ourselves to shape dim.
         # TODO revisit when implementing general index sums
-        self.inames.append(dimension_iname(o.operands()[1]))
+        if isinstance(self.visitor, UFLVisitor):
+            self.inames.append(dimension_iname(o.operands()[1]))
         self(o.operands()[0])
 
 
 def transform_expression(expr):
-    return TopSumSeparation()(expr)
+    TopSumSeparation()(expr)
diff --git a/python/dune/perftool/ufl/__init__.py b/python/dune/perftool/ufl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
new file mode 100644
index 00000000..b8c3900c
--- /dev/null
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -0,0 +1,36 @@
+""" A module mimicking some functionality of uflacs' modified terminals """
+
+from ufl.algorithms import MultiFunction
+from dune.perftool.restriction import Restriction
+
+
+class ModifiedTerminalTracker(MultiFunction):
+    def __init__(self):
+        MultiFunction.__init__(self)
+        self.grad = False
+        self.reference_grad = False
+        self.restriction = Restriction.NONE
+
+    def positive_restricted(self, o):
+        self.restriction = Restriction.POSITIVE
+        ret = self.expr(o)
+        self.restriction = Restriction.NONE
+        return ret
+
+    def negative_restricted(self, o):
+        self.restriction = Restriction.NEGATIVE
+        ret = self.expr(o)
+        self.restriction = Restriction.NONE
+        return ret
+
+    def grad(self, o):
+        self.grad = True
+        ret = self.expr(o)
+        self.grad = False
+        return ret
+
+    def reference_grad(self, o):
+        self.reference_grad = True
+        ret = self.expr(o)
+        self.reference_grad = False
+        return ret
diff --git a/python/dune/perftool/ufl/validity.py b/python/dune/perftool/ufl/validity.py
new file mode 100644
index 00000000..8690adb0
--- /dev/null
+++ b/python/dune/perftool/ufl/validity.py
@@ -0,0 +1,47 @@
+from ufl.algorithms import MultiFunction
+from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
+
+
+class UFLRank(MultiFunction):
+    def __call__(self, expr):
+        return len(MultiFunction.__call__(self, expr))
+
+    def expr(self, o):
+        return set(a for op in o.operands() for a in MultiFunction.__call__(self, op))
+
+    def argument(self, o):
+        return (o.number(),)
+
+
+class ArgumentCounter(ModifiedTerminalTracker):
+    def __call__(self, o, mock):
+        if not len(MultiFunction.__call__(self, o)) == UFLRank()(o):
+            raise ValueError("Interal Error: The transformed form violated the pdelab generation assumptions. Please send your ufl file to dominic.kempf@iwr.uni-heidelberg.de")
+
+    def expr(self, o):
+        return set(a for op in o.operands() for a in MultiFunction.__call__(self, op))
+
+    def argument(self, o):
+        return ((o.number(), self.grad, self.reference_grad, self.restriction),)
+
+
+def check_validity(uflexpr):
+    """ check an UFL expression (usually an integrand) for
+    compatibility with the dune-pdelab code generation tool chain.
+
+    The assumptions made are the following:
+    - The expression is the sum of subtrees, where the number of test
+      function terms in such subtree is equal the rank of the entire expression.
+    """
+    from ufl import Form
+    from ufl.classes import Expr
+    from dune.perftool.transformer import TopSumSeparation
+
+    if isinstance(uflexpr, Form):
+        for integral in uflexpr.integrals():
+            return TopSumSeparation(visitor=ArgumentCounter())(integral.integrand())
+
+    if isinstance(uflexpr, Expr):
+        return TopSumSeparation(visitor=ArgumentCounter())(uflexpr)
+
+    raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr)))
diff --git a/python/dune/perftool/validity.py b/python/dune/perftool/validity.py
deleted file mode 100644
index b88b6881..00000000
--- a/python/dune/perftool/validity.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from ufl.classes import Sum
-from ufl.algorithms import MultiFunction
-
-
-class UFLRank(MultiFunction):
-    def __call__(self, expr):
-        return len(MultiFunction.__call__(self, expr))
-
-    def expr(self, o):
-        return set(a for op in o.operands() for a in MultiFunction.__call__(self, op))
-
-    def argument(self, o):
-        return (o.number(),)
-
-
-class UFLValidityChecker(MultiFunction):
-    def __init__(self):
-        MultiFunction.__init__(self)
-        self.rankCounter = UFLRank()
-
-    def __call__(self, expr, topsum=True):
-        self.rank = self.rankCounter(expr)
-        self.sane = None
-        ret = MultiFunction.__call__(self, expr, topsum)
-        if self.sane is None:
-            # In this case, we did not encounter a sum and need to compare with the entire expression
-            self.sane = len(ret) == self.rank
-        return self.sane
-
-    # define the default behaviour: just call the multifunction
-    # recursively and combine the sets of found arguments
-    def expr(self, o, topsum):
-        return set(a for op in o.operands() for a in MultiFunction.__call__(self, op, False))
-
-    def argument(self, o, topsum):
-        return (o.number(),)
-
-    def sum(self, o, topsum):
-        if topsum:
-            for op in o.operands():
-                if not isinstance(op, Sum):
-                    self.sane = len(MultiFunction.__call__(self, op, True)) == self.rank
-        else:
-            # If this is sum is not part of the topsum, we treat it as all other expressions
-            return self.expr(o, topsum)
-
-    def index_sum(self, o, topsum):
-        assert len(o.operands()) == 2
-        if topsum:
-            op = o.operands()[0]
-            if not isinstance(op, Sum):
-                self.sane = len(MultiFunction.__call__(self, op, True)) == self.rank
-        else:
-            return self.expr(o, topsum)
-
-
-def check_validity(uflexpr):
-    """ check an UFL expression (usually an integrand) for
-    compatibility with the dune-pdelab code generation tool chain.
-    
-    The assumptions made are the following:
-    - The expression is the sum of subtrees, where the number of test
-      function terms in such subtree is equal the rank of the entire expression.
-    """
-    return UFLValidityChecker()(uflexpr)
\ No newline at end of file
-- 
GitLab