diff --git a/python/dune/perftool/ufl/topsum.py b/python/dune/perftool/ufl/topsum.py index 18b266cf644cf1c9709adc0faecf4698c97ac64e..5f79d6cc88ad4310813a5b78560686ea91497484 100644 --- a/python/dune/perftool/ufl/topsum.py +++ b/python/dune/perftool/ufl/topsum.py @@ -22,5 +22,4 @@ class TopSumSplit(MultiFunction): MultiFunction.__call__(self, op) def index_sum(self, o): - print "ISUM" MultiFunction.__call__(self, o.operands()[0]) diff --git a/python/dune/perftool/ufl/transformations/__init__.py b/python/dune/perftool/ufl/transformations/__init__.py index 2ff2cd122d9d7f4ea4f8bfa13892682f22db249b..e723f548563feb4a0c41e128daf3d362d7f2c4b2 100644 --- a/python/dune/perftool/ufl/transformations/__init__.py +++ b/python/dune/perftool/ufl/transformations/__init__.py @@ -7,19 +7,26 @@ class TransformationWrapper(object): # Extract the name of the transformation from the given kwargs assert "name" in kwargs self.name = kwargs.pop("name") + self.printBefore = kwargs.pop("printBefore", True) def write_trafo(self, expr, before): + # Skip this if we explicitly disabled it + if before and not self.printBefore: + return + + # Write out a dot file from dune.perftool.options import get_option - if get_option("print_transformations", False): + # TODO This should be siabled by default! + if get_option("print_transformations", True): import os dir = get_option("print_transformations_dir", os.getcwd()) filename = "trafo_{}_{}_{}.dot".format(self.name, str(self.counter).zfill(4), "in" if before else "out") - filename = os.join(dir, filename) + filename = os.path.join(dir, filename) with open(filename,'w') as out: - from ufl.formatting import ufl2dot - out.write(ufl2dot(expr)) + from ufl.formatting.ufl2dot import ufl2dot + out.write(str(ufl2dot(expr)[0])) if not before: - self._counter = self._counter + 1 + self.counter = self.counter + 1 def __call__(self, expr, *args, **kwargs): # We assume that the first argument to any transformation is the expression @@ -47,3 +54,7 @@ def ufl_transformation(**kwargs): """ A decorator for ufl transformations. It allows us to output the result if needed. """ return lambda f: TransformationWrapper(f, **kwargs) + +@ufl_transformation(name="print", printBefore=False) +def print_expression(e): + return e \ No newline at end of file diff --git a/python/dune/perftool/ufl/validity.py b/python/dune/perftool/ufl/validity.py index 9b91c2620b58bb33f97327556f95ab932506b06c..c8c77787de90513bc077d97877b2e212242f8a16 100644 --- a/python/dune/perftool/ufl/validity.py +++ b/python/dune/perftool/ufl/validity.py @@ -38,16 +38,22 @@ def check_validity(uflexpr): tss = TopSumSplit() ae = ArgumentExtractor() + def check(term, rank): + if not len(ae(term)) == rank: + from dune.perftool.ufl.transformations import print_expression + print_expression(term) + raise ValueError("Form not valid for pdelab code generation. Dumped dot file to send to dominic.kempf@iwr.uni-heidelberg.de") + if isinstance(uflexpr, Form): rank = len(uflexpr.arguments()) for integral in uflexpr.integrals(): for term in tss(integral.integrand()): - assert len(ae(term)) == rank + check(term, rank) return if isinstance(uflexpr, Expr): rank = UFLRank()(uflexpr) - assert len(ae(uflexpr)) == rank + check(uflexpr, rank) return raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr)))