diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 1306d7fc8d2fd1e7310e159bd45d40f1fb28639f..f7855277d79e9ff5b94d2effde9a98a7df597776 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -3,6 +3,7 @@ The methods to run the parts of the form compiler Should also contain the entrypoint methods. """ +from __future__ import absolute_import def read_ufl(uflfile): from ufl.algorithms.formfiles import read_ufl_file, interpret_ufl_namespace @@ -29,7 +30,7 @@ def read_ufl(uflfile): assert len(data.forms) == 1 # We make some assumptions on the UFL expression! - from dune.perftool.validity import check_validity + from dune.perftool.ufl.validity import check_validity check_validity(data.forms[0]) return formdata diff --git a/python/dune/perftool/generation.py b/python/dune/perftool/generation.py index 0e1ae5ff8b4600a8050cf2f7d416a9da81acf56e..25d2425e3e1d886d70b6aa6c360fcf18c58a7941 100644 --- a/python/dune/perftool/generation.py +++ b/python/dune/perftool/generation.py @@ -4,6 +4,8 @@ a complex requirement structure. This includes: * a caching mechanism to avoid duplicated preambles where harmful """ +from __future__ import absolute_import + # have one cache the module level. It is easier than handing around an instance of it. _cache = {} diff --git a/python/dune/perftool/options.py b/python/dune/perftool/options.py index 956213b756c54c3c9458d93ad507f5c11bb1ba13..1e0d1e89e6538e59e801455357fde30041d05a0c 100644 --- a/python/dune/perftool/options.py +++ b/python/dune/perftool/options.py @@ -26,5 +26,5 @@ def get_form_compiler_arguments(): # Return the argument dict. This result is memoized to turn all get_option calls into simple dict lookups. return args -def get_option(key): - return get_form_compiler_arguments().get(key, None) \ No newline at end of file +def get_option(key, default=None): + return get_form_compiler_arguments().get(key, default) \ No newline at end of file diff --git a/python/dune/perftool/transformer.py b/python/dune/perftool/transformer.py index 3213e572f48ab48ae77e5408be613241494df004..55b042d70cad63c2b2bb19282b0916dac8de87b3 100644 --- a/python/dune/perftool/transformer.py +++ b/python/dune/perftool/transformer.py @@ -1,5 +1,9 @@ -# Generate a loop kernel from that cell integral. This will map the UFL IR -# to the loopy IR. +""" +This is the module that contains the main transformation from ufl to loopy +(with pdelab as the hardcoded generation target) +""" + +from __future__ import absolute_import from ufl.algorithms import MultiFunction # Spread the pymbolic import statements to where they are used. from pymbolic.primitives import Variable, Subscript, Sum, Product diff --git a/python/dune/perftool/ufl/topsum.py b/python/dune/perftool/ufl/topsum.py new file mode 100644 index 0000000000000000000000000000000000000000..18b266cf644cf1c9709adc0faecf4698c97ac64e --- /dev/null +++ b/python/dune/perftool/ufl/topsum.py @@ -0,0 +1,26 @@ +""" A multifunction for separation of the top sum + +TODO: describe the purpose of topsums. +""" + +from __future__ import absolute_import +from ufl.algorithms import MultiFunction + + +class TopSumSplit(MultiFunction): + """ Split an expression into summands of the top-level sum """ + def __call__(self, o): + self.terms = [] + MultiFunction.__call__(self, o) + return self.terms + + def expr(self, o): + self.terms.append(o) + + def sum(self, o): + for op in o.operands(): + MultiFunction.__call__(self, op) + + def index_sum(self, o): + print "ISUM" + MultiFunction.__call__(self, o.operands()[0]) diff --git a/python/dune/perftool/ufl/validity.py b/python/dune/perftool/ufl/validity.py index 8690adb037ab3527fa87335786fc5bb7cb37720d..9b91c2620b58bb33f97327556f95ab932506b06c 100644 --- a/python/dune/perftool/ufl/validity.py +++ b/python/dune/perftool/ufl/validity.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import + from ufl.algorithms import MultiFunction from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker @@ -13,11 +15,7 @@ class UFLRank(MultiFunction): return (o.number(),) -class ArgumentCounter(ModifiedTerminalTracker): - def __call__(self, o, mock): - if not len(MultiFunction.__call__(self, o)) == UFLRank()(o): - raise ValueError("Interal Error: The transformed form violated the pdelab generation assumptions. Please send your ufl file to dominic.kempf@iwr.uni-heidelberg.de") - +class ArgumentExtractor(ModifiedTerminalTracker): def expr(self, o): return set(a for op in o.operands() for a in MultiFunction.__call__(self, op)) @@ -35,13 +33,21 @@ def check_validity(uflexpr): """ from ufl import Form from ufl.classes import Expr - from dune.perftool.transformer import TopSumSeparation + from dune.perftool.ufl.topsum import TopSumSplit + + tss = TopSumSplit() + ae = ArgumentExtractor() if isinstance(uflexpr, Form): + rank = len(uflexpr.arguments()) for integral in uflexpr.integrals(): - return TopSumSeparation(visitor=ArgumentCounter())(integral.integrand()) + for term in tss(integral.integrand()): + assert len(ae(term)) == rank + return if isinstance(uflexpr, Expr): - return TopSumSeparation(visitor=ArgumentCounter())(uflexpr) + rank = UFLRank()(uflexpr) + assert len(ae(uflexpr)) == rank + return raise TypeError("Unknown object type in check_validity: {}".format(type(uflexpr)))