Skip to content
Snippets Groups Projects
Commit 5c970638 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Output invalid expressions!

parent 7a8f940d
No related branches found
No related tags found
No related merge requests found
......@@ -22,5 +22,4 @@ class TopSumSplit(MultiFunction):
MultiFunction.__call__(self, op)
def index_sum(self, o):
print "ISUM"
MultiFunction.__call__(self, o.operands()[0])
......@@ -7,19 +7,26 @@ class TransformationWrapper(object):
# Extract the name of the transformation from the given kwargs
assert "name" in kwargs
self.name = kwargs.pop("name")
self.printBefore = kwargs.pop("printBefore", True)
def write_trafo(self, expr, before):
# Skip this if we explicitly disabled it
if before and not self.printBefore:
return
# Write out a dot file
from dune.perftool.options import get_option
if get_option("print_transformations", False):
# TODO This should be siabled by default!
if get_option("print_transformations", True):
import os
dir = get_option("print_transformations_dir", os.getcwd())
filename = "trafo_{}_{}_{}.dot".format(self.name, str(self.counter).zfill(4), "in" if before else "out")
filename = os.join(dir, filename)
filename = os.path.join(dir, filename)
with open(filename,'w') as out:
from ufl.formatting import ufl2dot
out.write(ufl2dot(expr))
from ufl.formatting.ufl2dot import ufl2dot
out.write(str(ufl2dot(expr)[0]))
if not before:
self._counter = self._counter + 1
self.counter = self.counter + 1
def __call__(self, expr, *args, **kwargs):
# We assume that the first argument to any transformation is the expression
......@@ -47,3 +54,7 @@ def ufl_transformation(**kwargs):
""" A decorator for ufl transformations. It allows us to output the
result if needed. """
return lambda f: TransformationWrapper(f, **kwargs)
@ufl_transformation(name="print", printBefore=False)
def print_expression(e):
return e
\ No newline at end of file
......@@ -38,16 +38,22 @@ def check_validity(uflexpr):
tss = TopSumSplit()
ae = ArgumentExtractor()
def check(term, rank):
if not len(ae(term)) == rank:
from dune.perftool.ufl.transformations import print_expression
print_expression(term)
raise ValueError("Form not valid for pdelab code generation. Dumped dot file to send to dominic.kempf@iwr.uni-heidelberg.de")
if isinstance(uflexpr, Form):
rank = len(uflexpr.arguments())
for integral in uflexpr.integrals():
for term in tss(integral.integrand()):
assert len(ae(term)) == rank
check(term, rank)
return
if isinstance(uflexpr, Expr):
rank = UFLRank()(uflexpr)
assert len(ae(uflexpr)) == rank
check(uflexpr, rank)
return
raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment