Skip to content
Snippets Groups Projects
Commit 7a8f940d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Update the validation code

Introduce a new topsum splitter
parent b4f197b5
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ The methods to run the parts of the form compiler
Should also contain the entrypoint methods.
"""
from __future__ import absolute_import
def read_ufl(uflfile):
from ufl.algorithms.formfiles import read_ufl_file, interpret_ufl_namespace
......@@ -29,7 +30,7 @@ def read_ufl(uflfile):
assert len(data.forms) == 1
# We make some assumptions on the UFL expression!
from dune.perftool.validity import check_validity
from dune.perftool.ufl.validity import check_validity
check_validity(data.forms[0])
return formdata
......
......@@ -4,6 +4,8 @@ a complex requirement structure. This includes:
* a caching mechanism to avoid duplicated preambles where harmful
"""
from __future__ import absolute_import
# have one cache the module level. It is easier than handing around an instance of it.
_cache = {}
......
......@@ -26,5 +26,5 @@ def get_form_compiler_arguments():
# Return the argument dict. This result is memoized to turn all get_option calls into simple dict lookups.
return args
def get_option(key):
return get_form_compiler_arguments().get(key, None)
\ No newline at end of file
def get_option(key, default=None):
return get_form_compiler_arguments().get(key, default)
\ No newline at end of file
# Generate a loop kernel from that cell integral. This will map the UFL IR
# to the loopy IR.
"""
This is the module that contains the main transformation from ufl to loopy
(with pdelab as the hardcoded generation target)
"""
from __future__ import absolute_import
from ufl.algorithms import MultiFunction
# Spread the pymbolic import statements to where they are used.
from pymbolic.primitives import Variable, Subscript, Sum, Product
......
""" A multifunction for separation of the top sum
TODO: describe the purpose of topsums.
"""
from __future__ import absolute_import
from ufl.algorithms import MultiFunction
class TopSumSplit(MultiFunction):
""" Split an expression into summands of the top-level sum """
def __call__(self, o):
self.terms = []
MultiFunction.__call__(self, o)
return self.terms
def expr(self, o):
self.terms.append(o)
def sum(self, o):
for op in o.operands():
MultiFunction.__call__(self, op)
def index_sum(self, o):
print "ISUM"
MultiFunction.__call__(self, o.operands()[0])
from __future__ import absolute_import
from ufl.algorithms import MultiFunction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
......@@ -13,11 +15,7 @@ class UFLRank(MultiFunction):
return (o.number(),)
class ArgumentCounter(ModifiedTerminalTracker):
def __call__(self, o, mock):
if not len(MultiFunction.__call__(self, o)) == UFLRank()(o):
raise ValueError("Interal Error: The transformed form violated the pdelab generation assumptions. Please send your ufl file to dominic.kempf@iwr.uni-heidelberg.de")
class ArgumentExtractor(ModifiedTerminalTracker):
def expr(self, o):
return set(a for op in o.operands() for a in MultiFunction.__call__(self, op))
......@@ -35,13 +33,21 @@ def check_validity(uflexpr):
"""
from ufl import Form
from ufl.classes import Expr
from dune.perftool.transformer import TopSumSeparation
from dune.perftool.ufl.topsum import TopSumSplit
tss = TopSumSplit()
ae = ArgumentExtractor()
if isinstance(uflexpr, Form):
rank = len(uflexpr.arguments())
for integral in uflexpr.integrals():
return TopSumSeparation(visitor=ArgumentCounter())(integral.integrand())
for term in tss(integral.integrand()):
assert len(ae(term)) == rank
return
if isinstance(uflexpr, Expr):
return TopSumSeparation(visitor=ArgumentCounter())(uflexpr)
rank = UFLRank()(uflexpr)
assert len(ae(uflexpr)) == rank
return
raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment