diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..927590e28c828e71285fc6101c0faeda0aac3ab4 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pep8ignore = E501 diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 896d7d7b8e1516ce30fd75ab0ce491479999f68a..583ea39f623721a3233e9b5dcc47c0408ba08787 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -21,3 +21,6 @@ dune_install_python_package(PATH ufl MAJOR_VERSION 2) # Install out python package dune_install_python_package(PATH . MAJOR_VERSION 2) +add_python_test_command(COMMAND python -m pytest --pep8 VIRTUALENV dune-env-2 + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/python/dune/perftool + REQUIRED_PACKAGES pytest pytest-pep8) diff --git a/python/dune/perftool/__init__.py b/python/dune/perftool/__init__.py index fa06bca3d581b54ad724163642d813f35181d9c7..175c0adc5140f13fd0ce1583d1b13ed3bc0ea083 100644 --- a/python/dune/perftool/__init__.py +++ b/python/dune/perftool/__init__.py @@ -1,4 +1,4 @@ class Restriction: NONE = 0 POSITIVE = 1 - NEGATIVE = 2 \ No newline at end of file + NEGATIVE = 2 diff --git a/python/dune/perftool/cgen/__init__.py b/python/dune/perftool/cgen/__init__.py index de5710209bc084f8dba3f7e6eb0208530d9a11ff..128c4dedffb347b256ed64b518fdde9009ce9344 100644 --- a/python/dune/perftool/cgen/__init__.py +++ b/python/dune/perftool/cgen/__init__.py @@ -15,8 +15,8 @@ class Namespace(PrivateNamespace): def __init__(self, *args, **kwargs): name = kwargs.pop("name") PrivateNamespace.__init__(self, *args, **kwargs) - + self.name = name def get_namespace_name(self): - return self.name \ No newline at end of file + return self.name diff --git a/python/dune/perftool/cgen/clazz.py b/python/dune/perftool/cgen/clazz.py index 0fed3054c241d37ca8ab03017582f04f6887f503..0fa4b5b0aba76343d7bf744cf7b19dd319bdf772 100644 --- a/python/dune/perftool/cgen/clazz.py +++ b/python/dune/perftool/cgen/clazz.py @@ -1,10 +1,12 @@ from cgen import Generable, Block + class AccessModifier: PRIVATE = 1 PUBLIC = 2 PROTECTED = 3 + def access_modifier_string(am): if am == AccessModifier.PRIVATE: return "private" @@ -14,6 +16,7 @@ def access_modifier_string(am): return "protected" raise ValueError("Unknown access modifier in class generation") + class BaseClass(Generable): def __init__(self, name, inheritance=AccessModifier.PUBLIC, construction=[]): self.name = name @@ -27,6 +30,7 @@ class BaseClass(Generable): def generate(self): yield self.name + class ClassMember(Generable): def __init__(self, member, access=AccessModifier.PUBLIC): self.member = member @@ -40,6 +44,7 @@ class ClassMember(Generable): for line in self.member.generate(): yield line + '\n' + class Constructor(Generable): def __init__(self, block=Block([]), arg_decls=[], clsname=None, initializer_list=[], access=AccessModifier.PUBLIC): self.clsname = clsname diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 8a0193c1f46f305915d8b5ab4fbac6b255052c26..46d63854dce2fad1f717d99138b0fd300d3ee311 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -5,6 +5,7 @@ Should also contain the entrypoint methods. """ from __future__ import absolute_import + def read_ufl(uflfile): from ufl.algorithms.formfiles import read_ufl_file, interpret_ufl_namespace from ufl.algorithms import compute_form_data @@ -35,13 +36,13 @@ def read_ufl(uflfile): # apply some transformations unconditionally! from dune.perftool.ufl.transformations import transform_form - from dune.perftool.ufl.transformations.splitarguments import split_arguments + # from dune.perftool.ufl.transformations.splitarguments import split_arguments from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed from dune.perftool.ufl.transformations.reindexing import reindexing form = transform_form(form, pushdown_indexed) form = transform_form(form, reindexing) - form = transform_form(form, split_arguments) +# form = transform_form(form, split_arguments) return form diff --git a/python/dune/perftool/generation.py b/python/dune/perftool/generation.py index 03b791e31f6eef5d53bea9e408d5b73283f576f1..f80933d4da345b81b619f50f176092a1da9b44d6 100644 --- a/python/dune/perftool/generation.py +++ b/python/dune/perftool/generation.py @@ -5,6 +5,7 @@ a complex requirement structure. This includes: """ from __future__ import absolute_import +from pytools import memoize # have one cache the module level. It is easier than handing around an instance of it. _cache = {} @@ -32,15 +33,17 @@ def _freeze(data): # convert standard mutable containers if isinstance(data, MutableMapping): - return tuple((_freeze(k), _freeze(v)) for k,v in data.iteritems()) + return tuple((_freeze(k), _freeze(v)) for k, v in data.iteritems()) if isinstance(data, Iterable): return tuple(_freeze(i) for i in data) # we don't know how to handle this object, so we give up raise TypeError('Cannot freeze non-hashable object {} of type {}'.format(data, type(data))) + class _NoCachingCounter(object): counter = 0 + def get(self): _NoCachingCounter.counter = _NoCachingCounter.counter + 1 return _NoCachingCounter.counter @@ -59,6 +62,7 @@ class _CacheItemMeta(type): if counted: original_on_store = on_store setattr(rettype, '_count', 0) + def add_count(x): rettype._count = rettype._count + 1 return (rettype._count, original_on_store(x)) @@ -74,7 +78,6 @@ class _CacheItemMeta(type): return rettype -from pytools import memoize @memoize(use_kwargs=True) def _construct_cache_item_type(name, **kwargs): """ Wrap the generation of cache item types from the meta class. @@ -86,11 +89,10 @@ def _construct_cache_item_type(name, **kwargs): class _RegisteredFunction(object): """ The data structure for a function that accesses UFL2LoopyDataCache """ - import sys def __init__(self, func, - cache_key_generator=lambda *a : a, + cache_key_generator=lambda *a: a, **kwargs - ): + ): self.func = func self.cache_key_generator = cache_key_generator self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs) @@ -209,7 +211,7 @@ def retrieve_cache_items(tags, union=True, make_generable=False): if not item.counted: yield as_generable(item.content) - for item in sorted([i for i in choice if i.counted], key = lambda i: i.content[0]): + for item in sorted([i for i in choice if i.counted], key=lambda i: i.content[0]): from collections import Iterable if isinstance(item.content[1], Iterable) and not isinstance(item.content[1], str): for l in item.content[1]: @@ -228,11 +230,11 @@ def delete_cache_items(tags, union=True): # TODO this implementation is horribly inefficient, but does the job removing = retrieve_cache_items(tags, union) global _cache - _cache = {k: v for k,v in _cache.items() if v not in removing} + _cache = {k: v for k, v in _cache.items() if v not in removing} def delete_cache(tags=[], union=True): # TODO this implementation is horribly inefficient, but does the job keeping = retrieve_cache_items(tags, union) global _cache - _cache = {k: v for k,v in _cache.items() if v in keeping} + _cache = {k: v for k, v in _cache.items() if v in keeping} diff --git a/python/dune/perftool/interactive.py b/python/dune/perftool/interactive.py index 259ac0ff872d8e5b6144e726df8e57016beca385..0baed286214259cd60cd6f9970d607f45110370c 100644 --- a/python/dune/perftool/interactive.py +++ b/python/dune/perftool/interactive.py @@ -103,7 +103,7 @@ def optimize_kernel(which, kernels): kernel = {'a': partial(show_kernel, which), 'b': partial(choose_transformation, which), 'c': partial(show_code, which) - }[choice](kernel) + }[choice](kernel) except KeyError: pass diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index ce420b128efaf8e308a1278ead45f9782ec9935a..e093d671fade9817bb2306183184be1d59fed1e8 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -4,6 +4,7 @@ import six from loopy.target import TargetBase from loopy.target.c.codegen.expression import LoopyCCodeMapper + class AllToDouble(dict): """ This imitates a dict that maps everything to double and logs the requested keys """ def __getitem__(self, key): @@ -12,12 +13,14 @@ class AllToDouble(dict): _registry = {'float32': 'float', - 'int32' : 'int', - 'float64' : 'double'} + 'int32': 'int', + 'float64': 'double'} + class MyMapper(LoopyCCodeMapper): var_subst_map = {} + class DuneTarget(TargetBase): def get_or_register_dtype(self, names, dtype=None): diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py index d5d79c9ab934b01b8feb93f4be21169982d0dc32..c1d1ccef7ed0093287758f0c1024971574c21119 100644 --- a/python/dune/perftool/loopy/transformer.py +++ b/python/dune/perftool/loopy/transformer.py @@ -4,15 +4,17 @@ This is the module that contains the main transformation from ufl to loopy """ from __future__ import absolute_import +from dune.perftool import Restriction +from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker +from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper + from ufl.algorithms import MultiFunction -# Spread the pymbolic import statements to where they are used. +# TODO Spread the pymbolic import statements to where they are used. from pymbolic.primitives import Variable, Subscript, Sum, Product import loopy import numpy import ufl -from dune.perftool import Restriction - # Define the generators that are used here from dune.perftool.generation import generator_factory loopy_iname = generator_factory(item_tags=("loopy", "kernel", "iname")) @@ -21,17 +23,20 @@ loopy_temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temp loopy_c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True) loopy_valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True) + @generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg")) def loopy_globalarg(name, shape=loopy.auto): if isinstance(shape, str): shape = (shape,) return loopy.GlobalArg(name, numpy.float64, shape) + @generator_factory(item_tags=("loopy", "kernel", "domain")) def loopy_domain(iname, shape): loopy_valuearg(shape) return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape) + @loopy_iname def dimension_iname(index): from dune.perftool.pdelab import name_index @@ -41,21 +46,21 @@ def dimension_iname(index): loopy_domain(iname, dimname) return iname + @loopy_iname def argument_iname(arg): # TODO extract the {iname}_n thing by a preamble - from dune.perftool.ufl.modified_terminals import ModifiedArgumentNumber - iname = "arg{}".format(chr(ord("i") + ModifiedArgumentNumber()(arg))) + from dune.perftool.ufl.modified_terminals import modified_argument_number + iname = "arg{}".format(chr(ord("i") + modified_argument_number()(arg))) loopy_domain(iname, iname + "_n") return iname + @loopy_iname def quadrature_iname(): loopy_domain("q", "q_n") return "q" -from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker -from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper): def __init__(self, embracing_mf): @@ -84,8 +89,8 @@ class TrialFunctionExtractor(MultiFunction): call = MultiFunction.__call__ def __call__(self, o): - from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor - self.tf = ModifiedArgumentExtractor()(o, trialfunction=True) + from dune.perftool.ufl.modified_terminals import extract_modified_arguments + self.tf = extract_modified_arguments(o, trialfunction=True) self.u2l = UFL2LoopyVisitor(self) return self.call(o) @@ -100,14 +105,17 @@ class TrialFunctionExtractor(MultiFunction): else: return self.u2l(o) + class _Counter: counter = 0 + def get_count(): c = _Counter.counter _Counter.counter = c + 1 return c + def transform_accumulation_term(term): # Get the accumulation expression and the modified arguments expr, args = term @@ -125,7 +133,7 @@ def transform_accumulation_term(term): pymbolic_expr = simplify_pymbolic_expression(pymbolic_expr) # Define a temporary variable for this expression - expr_tv_name = "expr_"+str(get_count()).zfill(4) + expr_tv_name = "expr_" + str(get_count()).zfill(4) expr_tv = loopy_temporary_variable(expr_tv_name) loopy_expr_instruction(loopy.ExpressionInstruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr)) diff --git a/python/dune/perftool/options.py b/python/dune/perftool/options.py index 4816ce8ecf5b2be5e4f6fd139b1fa79f4efa61d8..82c53c1da31e890ad39fc321bea6e495b73ecab5 100644 --- a/python/dune/perftool/options.py +++ b/python/dune/perftool/options.py @@ -2,6 +2,7 @@ from pytools import memoize + @memoize def get_form_compiler_arguments(): # define an argument parser. @@ -36,9 +37,10 @@ def get_form_compiler_arguments(): # Return the argument dict. This result is memoized to turn all get_option calls into simple dict lookups. return args + def get_option(key, default=None): try: __IPYTHON__ return default except NameError: - return get_form_compiler_arguments().get(key, default) \ No newline at end of file + return get_form_compiler_arguments().get(key, default) diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py index 2d98c18e408bf09684ed404bb1a5b7d9f81bf604..fec0eba570d0a487dfbdc22fd19be002a6ed6216 100644 --- a/python/dune/perftool/pdelab/__init__.py +++ b/python/dune/perftool/pdelab/__init__.py @@ -2,12 +2,13 @@ # Define the generators that are used throughout all pdelab specific code generations. from dune.perftool.generation import generator_factory +from dune.perftool.loopy.transformer import quadrature_iname +from loopy import CInstruction + dune_symbol = generator_factory(item_tags=("pdelab", "kernel", "symbol")) dune_preamble = generator_factory(item_tags=("pdelab", "kernel", "preamble"), counted=True) dune_include = generator_factory(on_store=lambda i: "#include<{}>".format(i), item_tags=("pdelab", "include"), no_deco=True) -from dune.perftool.loopy.transformer import quadrature_iname -from loopy import CInstruction def quadrature_preamble(assignees=[]): # TODO: How to enforce the order of quadrature preambles? Counted? diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index 7b2ba8f4190facb910ff89783a27e3dfbc3cdeb7..d402b00aa5a0912189de67cfaabba49fa7a8ae3a 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -3,6 +3,7 @@ from dune.perftool.pdelab import dune_symbol, dune_preamble from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor + @dune_symbol def name_testfunction(modarg): ma = ModifiedArgumentDescriptor(modarg) @@ -10,21 +11,25 @@ def name_testfunction(modarg): pass return "{}a{}".format("grad_" if ma.grad else "", ma.expr.number()) + @dune_symbol def name_trialfunction(modarg): ma = ModifiedArgumentDescriptor(modarg) return "{}c{}".format("grad_" if ma.grad else "", ma.expr.count()) + @dune_symbol def name_testfunctionspace(*a): - #TODO + # TODO return "lfsv" + @dune_symbol def name_trialfunctionspace(*a): - #TODO + # TODO return "lfsu" + def name_argumentspace(modarg): ma = ModifiedArgumentDescriptor(modarg) if ma.expr.number() == 0: @@ -44,6 +49,7 @@ def name_argument(modarg): # We should never encounter an argument other than 0 or 1 assert False + @dune_symbol def name_residual(): return "r" diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py index 526ad1f5a86e4a7652f12af459a02474ba4e0117..d3fa4b8c83c57492ab9aae3c04f50e9cb4e27d3b 100644 --- a/python/dune/perftool/pdelab/driver.py +++ b/python/dune/perftool/pdelab/driver.py @@ -19,6 +19,7 @@ driver_preamble = generator_factory(item_tags=("driver", "preamble"), counted=Tr # through the entire generator chain. _form = None + # Have a function access this global data structure def set_form(f): global _form @@ -28,18 +29,23 @@ def set_form(f): def isLagrange(fem): return fem._short_name is 'CG' + def isSimplical(fem): return any(fem._cell._cellname in x for x in ["triangle", "tetrahedron"]) + def isQuadrilateral(fem): return any(fem._cell._cellname in x for x in ["vertex", "interval", "quadrilateral", "hexalateral"]) + def isPk(fem): return isLagrange(fem) and isSimplical(fem) + def isQk(fem): return isLagrange(fem) and isQuadrilateral(fem) + def FEM_name_mangling(fem): from ufl import MixedElement, VectorElement, FiniteElement if isinstance(fem, MixedElement): @@ -58,11 +64,13 @@ def FEM_name_mangling(fem): return "Q" + str(fem._degree) raise NotImplementedError("FEM NAME MANGLING") + @dune_symbol def name_inifile(): # TODO pass some other option here. return "argv[1]" + @driver_preamble def parse_initree(varname): driver_include("dune/common/parametertree.hh") @@ -70,21 +78,25 @@ def parse_initree(varname): filename = name_inifile() return ["Dune::ParameterTree initree;", "Dune::ParameterTreeParser::readINITree({}, {});".format(filename, varname)] + @dune_symbol def name_initree(): parse_initree("initree") - #TODO we can get some other ini file here. + # TODO we can get some other ini file here. return "initree" + @driver_preamble def define_dimension(name): return "static const int {} = {};".format(name, _form.cell().geometric_dimension()) + @dune_symbol def name_dimension(): define_dimension("dim") return "dim" + @driver_preamble def typedef_grid(name): dim = name_dimension() @@ -99,11 +111,13 @@ def typedef_grid(name): raise ValueError("Cant match your geometry with a DUNE grid. Please report bug.") return "typedef {} {};".format(gridt, name) + @dune_symbol def type_grid(): typedef_grid("Grid") return "Grid" + @driver_preamble def define_grid(name): driver_include("dune/testtools/gridconstruction.hh") @@ -112,53 +126,63 @@ def define_grid(name): return ["IniGridFactory<{}> factory({});".format(_type, ini), "std::shared_ptr<{}> grid = factory.getGrid();".format(_type)] + @dune_symbol def name_grid(): define_grid("grid") return "grid" + @driver_preamble def typedef_leafview(name): grid = type_grid() return "typedef {}::LeafGridView {};".format(grid, name) + @dune_symbol def type_leafview(): typedef_leafview("GV") return "GV" + @driver_preamble def define_leafview(name): _type = type_leafview() grid = name_grid() return "{} {} = {}->leafGridView();".format(_type, name, grid) + @dune_symbol def name_leafview(): define_leafview("gv") return "gv" + @driver_preamble def typedef_vtkwriter(name): driver_include("dune/grid/io/file/vtk/subsamplingvtkwriter.hh") gv = type_leafview() return "typedef Dune::SubsamplingVTKWriter<{}> {};".format(gv, name) + @dune_symbol def type_vtkwriter(): typedef_vtkwriter("VTKWriter") return "VTKWriter" + @driver_preamble def define_subsamplinglevel(name): ini = name_initree() return "int {} = {}.get<int>(\"vtk.subsamplinglevel\", 0);".format(name, ini) + @dune_symbol def name_subsamplinglevel(): define_subsamplinglevel("sublevel") return "sublevel" + @driver_preamble def define_vtkwriter(name): _type = type_vtkwriter() @@ -166,30 +190,36 @@ def define_vtkwriter(name): subsamp = name_subsamplinglevel() return "{} {}({}, {});".format(_type, name, gv, subsamp) + @dune_symbol def name_vtkwriter(): define_vtkwriter("vtkwriter") return "vtkwriter" + @driver_preamble def typedef_domainfield(name): gridt = type_grid() return "typedef {}::ctype {};".format(gridt, name) + @dune_symbol def type_domainfield(): typedef_domainfield("DF") return "DF" + @driver_preamble def typedef_range(name): return "typedef double {};".format(name) + @dune_symbol def type_range(): typedef_range("R") return "R" + @driver_preamble def typedef_fem(expr, name): gv = type_leafview() @@ -203,53 +233,63 @@ def typedef_fem(expr, name): return "typedef Dune::PDELab::QkLocalFiniteElementMap<{}, {}, {}, {}> {};".format(gv, df, r, expr._degree, name) raise NotImplementedError("FEM not implemented in dune-perftool") + @dune_symbol def type_fem(expr): name = "{}_FEM".format(FEM_name_mangling(expr).upper()) typedef_fem(expr, name) return name + @driver_preamble def define_fem(expr, name): femtype = type_fem(expr) gv = name_leafview() return "{} {}({});".format(femtype, name, gv) + @dune_symbol def name_fem(expr): name = "{}_fem".format(FEM_name_mangling(expr).lower()) define_fem(expr, name) return name + @driver_preamble def typedef_vectorbackend(name): driver_include("dune/pdelab/backend/istlvectorbackend.hh") return "typedef Dune::PDELab::ISTLVectorBackend<Dune::PDELab::ISTLParameters::no_blocking, 1> {};".format(name) + @dune_symbol def type_vectorbackend(): typedef_vectorbackend("VectorBackend") return "VectorBackend" + @dune_symbol def type_orderingtag(): return "Dune::PDELab::LexicographicOrderingTag" + @driver_preamble def typedef_constraintsassembler(name): driver_include("dune/pdelab/constraints/conforming.hh") return "typedef Dune::PDELab::ConformingDirichletConstraints {};".format(name) + @dune_symbol def type_constraintsassembler(): typedef_constraintsassembler("ConstraintsAssembler") return "ConstraintsAssembler" + @driver_preamble def typedef_constraintscontainer(expr, name): gfs = type_gfs(expr) r = type_range() - return "typedef {}::ConstraintsContainer<{}>::Type {};".format(gfs,r, name) + return "typedef {}::ConstraintsContainer<{}>::Type {};".format(gfs, r, name) + @dune_symbol def type_constraintscontainer(expr): @@ -257,17 +297,20 @@ def type_constraintscontainer(expr): typedef_constraintscontainer(expr, name) return name + @driver_preamble def define_constraintscontainer(expr, name): cctype = type_constraintscontainer(expr) return ["{} {};".format(cctype, name), "{}.clear();".format(name)] + @dune_symbol def name_constraintscontainer(expr): name = "{}_cc".format(FEM_name_mangling(expr)).lower() define_constraintscontainer(expr, name) return name + @driver_preamble def typedef_gfs(expr, name): vb = type_vectorbackend() @@ -293,12 +336,14 @@ def typedef_gfs(expr, name): if isinstance(expr, RestrictedElement): raise NotImplementedError("Dune does not support restricted elements!") + @dune_symbol def type_gfs(expr): name = "{}_GFS".format(FEM_name_mangling(expr).upper()) typedef_gfs(expr, name) return name + @driver_preamble def define_gfs(expr, name): gfstype = type_gfs(expr) @@ -319,12 +364,14 @@ def define_gfs(expr, name): if isinstance(expr, RestrictedElement): raise NotImplementedError("Dune does not support restricted elements!") + @dune_symbol def name_gfs(expr): name = "{}_gfs".format(FEM_name_mangling(expr)).lower() define_gfs(expr, name) return name + @driver_preamble def define_dofestimate(name): # Provide a worstcase estimate for the number of entries per row based on the given gridfunction space and cell geometry @@ -337,65 +384,77 @@ def define_dofestimate(name): return ["int generic_dof_estimate = {} * {}.maxLocalSize();".format(geo_factor, gfs), "int dof_estimate = {}.get<int>(\"istl.number_of_nnz\", generic_dof_estimate);".format(ini)] + @dune_symbol def name_dofestimate(): define_dofestimate("dofestimate") return "dofestimate" + @driver_preamble def typedef_matrixbackend(name): driver_include("dune/pdelab/backend/istl/bcrsmatrixbackend.hh") return "typedef Dune::PDELab::istl::BCRSMatrixBackend<> {};".format(name) + @dune_symbol def type_matrixbackend(): typedef_matrixbackend("MatrixBackend") return "MatrixBackend" + @driver_preamble def define_matrixbackend(name): mbtype = type_matrixbackend() dof = name_dofestimate() return "{} {}({});".format(mbtype, name, dof) + @dune_symbol def name_matrixbackend(): define_matrixbackend("mb") return "mb" + @driver_preamble def typedef_parameters(name): return "typedef LocalOperatorParameters {};".format(name) + @dune_symbol def type_parameters(): typedef_parameters("Params") return "Params" + @driver_preamble def define_parameters(name): partype = type_parameters() return "{} {}();".format(partype, name) + @dune_symbol def name_parameters(): define_parameters("params") return "params" + @driver_preamble def typedef_localoperator(name): # No Parameter class here, yet - #params = type_parameters() - #return "typedef LocalOperator<{}> {};".format(params, name) + # params = type_parameters() + # return "typedef LocalOperator<{}> {};".format(params, name) from dune.perftool.options import get_option driver_include(get_option('operator_file')) return "// Here in the future: typedef for the local operator with parameter class as template parameter" + @dune_symbol def type_localoperator(): typedef_localoperator("LocalOperator") return "LocalOperator" + @driver_preamble def define_localoperator(name): loptype = type_localoperator() @@ -403,11 +462,13 @@ def define_localoperator(name): params = name_parameters() return "{} {}({}, {});".format(loptype, name, ini, params) + @dune_symbol def name_localoperator(): define_localoperator("lop") return "lop" + @driver_preamble def typedef_gridoperator(name): ugfs = type_gfs(_form.coefficients()[0].element()) @@ -421,11 +482,13 @@ def typedef_gridoperator(name): driver_include("dune/pdelab/gridoperator/gridoperator.hh") return "typedef Dune::PDELab::GridOperator<{}, {}, {}, {}, {}, {}, {}, {}, {}> {};".format(ugfs, vgfs, lop, mb, df, r, r, ucc, vcc, name) + @dune_symbol def type_gridoperator(): typedef_gridoperator("GO") return "GO" + @driver_preamble def define_gridoperator(name): gotype = type_gridoperator() @@ -437,62 +500,74 @@ def define_gridoperator(name): mb = name_matrixbackend() return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb) + @dune_symbol def name_gridoperator(): define_gridoperator("go") return "go" + @driver_preamble def typedef_vector(name): gotype = type_gridoperator() return "typedef {}::Traits::Domain {};".format(gotype, name) + @dune_symbol def type_vector(): typedef_vector("V") return "V" + @driver_preamble def define_vector(name): vtype = type_vector() gfs = name_gfs(_form.coefficients()[0].element()) return ["{} {}({});".format(vtype, name, gfs), "{} = 0.0;".format(name)] + @dune_symbol def name_vector(): define_vector("x") return "x" + @driver_preamble def typedef_linearsolver(name): driver_include("dune/pdelab/backend/istlsolverbackend.hh") return "typedef Dune::PDELab::ISTLBackend_SEQ_UMFPack {};".format(name) + @dune_symbol def type_linearsolver(): typedef_linearsolver("LinearSolver") return "LinearSolver" + @driver_preamble def define_linearsolver(name): lstype = type_linearsolver() return "{} {}(false);".format(lstype, name) + @dune_symbol def name_linearsolver(): define_linearsolver("ls") return "ls" + @driver_preamble def define_reduction(name): ini = name_initree() return "double {} = {}.get<double>(\"reduction\", 1e-12);".format(name, ini) + @dune_symbol def name_reduction(): define_reduction("reduction") return "reduction" + @dune_symbol def typedef_stationarylinearproblemsolver(name): driver_include("dune/pdelab/stationary/linearproblem.hh") @@ -501,11 +576,13 @@ def typedef_stationarylinearproblemsolver(name): xtype = type_vector() return "typedef Dune::PDELab::StationaryLinearProblemSolver<{}, {}, {}> {}".format(gotype, lstype, xtype, name) + @dune_symbol def type_stationarylinearproblemsolver(): typedef_stationarylinearproblemsolver("SLP") return "SLP" + @driver_preamble def define_stationarylinearproblemsolver(name): slptype = type_stationarylinearproblemsolver() @@ -515,11 +592,13 @@ def define_stationarylinearproblemsolver(name): red = name_reduction() return "{} {}({}, {}, {}, {});".format(slptype, name, go, ls, x, red) + @dune_symbol def name_stationarylinearproblemsolver(): define_stationarylinearproblemsolver("slp") return "slp" + @driver_preamble def dune_solve(): from ufl.algorithms.predicates import is_multilinear @@ -531,17 +610,20 @@ def dune_solve(): else: raise NotImplementedError + @driver_preamble def define_vtkfile(name): ini = name_initree() driver_include("string") return "std::string {} = {}.get<std::string>(\"vtk.filename\", \"output\");".format(name, ini) + @dune_symbol def name_vtkfile(): define_vtkfile("vtkfile") return "vtkfile" + @driver_preamble def vtkoutput(): driver_include("dune/pdelab/gridfunctionspace/vtk.hh") diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py index 6d0db436c7cae4aed6d9d021f5097c8ee8f3aa36..020f5d8797199a885bf798d8f4201ecde52c2c50 100644 --- a/python/dune/perftool/pdelab/geometry.py +++ b/python/dune/perftool/pdelab/geometry.py @@ -1,11 +1,13 @@ from dune.perftool.pdelab import dune_symbol + @dune_symbol def name_dimension(): - #TODO preamble define_dimension + # TODO preamble define_dimension return "dim" + @dune_symbol def name_facetarea(): - #TODO preambles - return "farea" \ No newline at end of file + # TODO preambles + return "farea" diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 4a7d8d0910428d0e68bf3f7b0b752f8b462b86c3..444d9381ab287886f49dbdf8e8f0448114d64e76 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -3,6 +3,7 @@ from __future__ import absolute_import from dune.perftool.options import get_option from dune.perftool.generation import generator_factory from dune.perftool.pdelab import dune_symbol +from dune.perftool.cgen.clazz import BaseClass, ClassMember from cgen import Include @@ -10,7 +11,6 @@ from pytools import memoize # Define the generators used in-here operator_include = generator_factory(item_tags=("include", "operator"), on_store=lambda i: Include(i), no_deco=True) -from dune.perftool.cgen.clazz import BaseClass public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_store=lambda n: BaseClass(n), counted=True, no_deco=True) @@ -18,23 +18,27 @@ public_base_class = generator_factory(item_tags=("baseclass", "operator"), on_st def initializer_list(obj, params): return "{}({})".format(obj, ", ".join(params)) -@generator_factory(item_tags=("operator", "member"), counted=True, cache_key_generator=lambda t,n : n) + +@generator_factory(item_tags=("operator", "member"), counted=True, cache_key_generator=lambda t, n: n) def define_private_member(_type, name): from cgen import Value from dune.perftool.cgen.clazz import ClassMember, AccessModifier - return ClassMember(Value(_type, name), access=AccessModifier.PRIVATE) + return ClassMember(Value(_type, name), access=AccessModifier.PRIVATE) + @generator_factory(item_tags=("operator", "constructor_param"), counted=True) def constructor_parameter(_type, name): from cgen import Value return Value(_type, name) + @dune_symbol def name_initree_constructor(): operator_include('dune/common/parametertree.hh') constructor_parameter("const Dune::ParameterTree&", "iniParams") return "iniParams" + @dune_symbol def name_initree_member(): operator_include('dune/common/parametertree.hh') @@ -43,11 +47,13 @@ def name_initree_member(): initializer_list("_iniParams", [in_constructor]) return "_iniParams" + @dune_symbol def localoperator_type(): - #TODO use something from the form here to make it unique + # TODO use something from the form here to make it unique return "LocalOperator" + @memoize def measure_specific_details(measure): # The return dictionary that this memoized method will grant direct access to. @@ -116,13 +122,13 @@ def generate_kernel(integrand=None, measure=None): from dune.perftool.loopy.target import DuneTarget domains = [i for i in retrieve_cache_items("domain")] instructions = [i for i in retrieve_cache_items("instruction")] - temporaries = {i.name:i for i in retrieve_cache_items("temporary")} - #preambles = [i for i in retrieve_cache_items("preamble")] + temporaries = {i.name: i for i in retrieve_cache_items("temporary")} + # preambles = [i for i in retrieve_cache_items("preamble")] arguments = [i for i in retrieve_cache_items("argument")] # Create the kernel from loopy import make_kernel, preprocess_kernel - #kernel = make_kernel(domains, instructions, arguments, temporary_variables=temporaries, preambles=preambles, target=DuneTarget()) + # kernel = make_kernel(domains, instructions, arguments, temporary_variables=temporaries, preambles=preambles, target=DuneTarget()) kernel = make_kernel(domains, instructions, arguments, temporary_variables=temporaries, target=DuneTarget()) kernel = preprocess_kernel(kernel) @@ -134,12 +140,11 @@ def generate_kernel(integrand=None, measure=None): return kernel -from dune.perftool.cgen.clazz import ClassMember class AssemblyMethod(ClassMember): def __init__(self, signature, kernel): from loopy import generate_code from cgen import LiteralLines - content = LiteralLines('\n'+'\n'.join(signature) + '\n' + generate_code(kernel)[0]) + content = LiteralLines('\n' + '\n'.join(signature) + '\n' + generate_code(kernel)[0]) ClassMember.__init__(self, content) @@ -202,6 +207,7 @@ def generate_localoperator_kernels(form): # Return the set of generated kernels return operator_kernels + def generate_localoperator_file(kernels): operator_methods = [] diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index e306e3943e6581d7cfd87c0a5110657a10c7a2a7..514e94eedabacf5de90272078771e718c83f6688 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -2,15 +2,18 @@ from dune.perftool.generation import generator_factory from dune.perftool.loopy.transformer import quadrature_iname, loopy_temporary_variable from dune.perftool.pdelab import dune_symbol, quadrature_preamble, dune_preamble + @dune_symbol def quadrature_rule(): return "rule" + @quadrature_preamble(assignees="fac") def define_quadrature_factor(fac): rule = quadrature_rule() return "auto {} = {}->weight();".format(fac, rule) + @dune_symbol def name_factor(): loopy_temporary_variable("fac") diff --git a/python/dune/perftool/pymbolic/simplify.py b/python/dune/perftool/pymbolic/simplify.py index 3382cc245540971ceb00315244e2bf5da90edcf3..f0312b040443f0a24207940e2bf7bdd7f96ce843 100644 --- a/python/dune/perftool/pymbolic/simplify.py +++ b/python/dune/perftool/pymbolic/simplify.py @@ -3,7 +3,8 @@ from __future__ import absolute_import from pymbolic.sympy_interface import SympyToPymbolicMapper, PymbolicToSympyMapper from sympy import simplify + def simplify_pymbolic_expression(e): sympyexpr = PymbolicToSympyMapper()(e) simplified = simplify(sympyexpr) - return SympyToPymbolicMapper()(simplified) \ No newline at end of file + return SympyToPymbolicMapper()(simplified) diff --git a/python/dune/perftool/pymbolic/uflmapper.py b/python/dune/perftool/pymbolic/uflmapper.py index 535c2af747574ecb675a4b665e193c85ba84bde8..18b03cc1c82b95ed4d53f07d9b1a3420eb2e3d47 100644 --- a/python/dune/perftool/pymbolic/uflmapper.py +++ b/python/dune/perftool/pymbolic/uflmapper.py @@ -13,6 +13,7 @@ from pymbolic.primitives import Product, Quotient, Subscript, Sum, Variable # The constructed pymbolic expressions use n-ary operators instead of ufls binary operators from dune.perftool.ufl.flatoperators import get_operands + class UFL2PymbolicMapper(MultiFunction): def __init__(self): super(UFL2PymbolicMapper, self).__init__() diff --git a/python/dune/perftool/ufl/execution.py b/python/dune/perftool/ufl/execution.py index 026f5b11d66de541f9f3e1a8a69416cdf28a5110..9e91a2680e9f6b16158520ef6bbfe8e4ea977294 100644 --- a/python/dune/perftool/ufl/execution.py +++ b/python/dune/perftool/ufl/execution.py @@ -7,6 +7,7 @@ import ufl from ufl import * from ufl.split_functions import split + class TrialFunction(ufl.Coefficient): """ A coefficient that always takes the reserved index 0 """ def __init__(self, element, count=None): @@ -14,6 +15,7 @@ class TrialFunction(ufl.Coefficient): raise ValueError("The trial function must be the coefficient of index 0 in uflpdelab") ufl.Coefficient.__init__(self, element, count=0) + class Coefficient(ufl.Coefficient): """ A coefficient that honors the reserved index 0. """ def __init__(self, element, count=None): @@ -23,8 +25,10 @@ class Coefficient(ufl.Coefficient): count = 1 ufl.Coefficient.__init__(self, element, count) + def Coefficients(element): return split(Coefficient(element)) + def TrialFunctions(element): return split(TrialFunction(element)) diff --git a/python/dune/perftool/ufl/flatoperators.py b/python/dune/perftool/ufl/flatoperators.py index 911d9bde2908c1551c150bfbbb3a8fb02a900e70..d4614896598404e4ae6a5ceb91bfb3c80ca36bfc 100644 --- a/python/dune/perftool/ufl/flatoperators.py +++ b/python/dune/perftool/ufl/flatoperators.py @@ -26,5 +26,5 @@ def construct_binary_operator(operands, operator): if len(operands) == 2: return operator(operands[0], operands[1]) else: - mid = len(operands)//2 + 1 + mid = len(operands) // 2 + 1 return operator(construct_binary_operator(operands[:mid], operator), construct_binary_operator(operands[mid:len(operands)], operator)) diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py index 9c7578231294cf5d1bb5ab49e395b20120e02645..9c2ce6e8175af44147de8174550aa24078add369 100644 --- a/python/dune/perftool/ufl/modified_terminals.py +++ b/python/dune/perftool/ufl/modified_terminals.py @@ -44,7 +44,7 @@ class ModifiedTerminalTracker(MultiFunction): return ret -class ModifiedArgumentExtractor(MultiFunction): +class _ModifiedArgumentExtractor(MultiFunction): """ A multifunction that extracts and returns the set of modified arguments """ def __call__(self, o, argnumber=None, trialfunction=False): @@ -96,7 +96,11 @@ class ModifiedArgumentExtractor(MultiFunction): call = MultiFunction.__call__ -class ModifiedArgumentNumber(MultiFunction): +def extract_modified_arguments(expr, **kwargs): + return _ModifiedArgumentExtractor()(expr, **kwargs) + + +class _ModifiedArgumentNumber(MultiFunction): """ return the number() of a modified argument """ def expr(self, o): return self(o.ufl_operands[0]) @@ -105,6 +109,11 @@ class ModifiedArgumentNumber(MultiFunction): return o.number() +def modified_argument_number(expr): + """ Given an expression, return the number() of the argument in it """ + return _ModifiedArgumentNumber()(expr) + + class ModifiedArgumentDescriptor(MultiFunction): def __init__(self, e): MultiFunction.__init__(self) diff --git a/python/dune/perftool/ufl/rank.py b/python/dune/perftool/ufl/rank.py index 48f638de35557b975d1eb303a16cf42984d1dcdd..973b2c071fd422e41675d33972c03f2e1f53eaae 100644 --- a/python/dune/perftool/ufl/rank.py +++ b/python/dune/perftool/ufl/rank.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from ufl.algorithms import MultiFunction + class _UFLRank(MultiFunction): def __call__(self, expr): return len(MultiFunction.__call__(self, expr)) @@ -11,5 +12,6 @@ class _UFLRank(MultiFunction): def argument(self, o): return (o.number(),) + def ufl_rank(o): - return _UFLRank()(o) \ No newline at end of file + return _UFLRank()(o) diff --git a/python/dune/perftool/ufl/transformations/__init__.py b/python/dune/perftool/ufl/transformations/__init__.py index 006e44cd0de8c31ef9d7fbd99fec8a14f8e62007..81b3ff6e88312b1b01ab95db29bb0673556b77f5 100644 --- a/python/dune/perftool/ufl/transformations/__init__.py +++ b/python/dune/perftool/ufl/transformations/__init__.py @@ -1,5 +1,6 @@ """ Define the general infrastructure for debuggable UFL transformations""" + class UFLTransformationWrapper(object): def __init__(self, func, **kwargs): # Store the decorated function @@ -22,11 +23,11 @@ class UFLTransformationWrapper(object): if get_option("print_transformations", True): import os dir = get_option("print_transformations_dir", os.getcwd()) - + for i, exprtowrite in enumerate(expr): - filename = "trafo_{}_{}_{}{}.dot".format(self.name, str(self.counter).zfill(4), "in" if before else "out", "_{}".format(i) if len(expr)>1 else "") + filename = "trafo_{}_{}_{}{}.dot".format(self.name, str(self.counter).zfill(4), "in" if before else "out", "_{}".format(i) if len(expr) > 1 else "") filename = os.path.join(dir, filename) - with open(filename,'w') as out: + with open(filename, 'w') as out: from ufl.formatting.ufl2dot import ufl2dot out.write(str(ufl2dot(exprtowrite)[0])) @@ -50,7 +51,8 @@ class UFLTransformationWrapper(object): try: assert isinstance(ret_for_print, list) and all(isinstance(e, Expr) for e in ret_for_print) except AssertionError: - from IPython import embed; embed() + from IPython import embed + embed() # Maybe output the returned expression self.write_trafo(ret_for_print, False) @@ -65,10 +67,12 @@ def ufl_transformation(_positional_arg=None, **kwargs): assert not _positional_arg return lambda f: UFLTransformationWrapper(f, **kwargs) + @ufl_transformation(name="print", printBefore=False) def print_expression(e): return e + def transform_integral(integral, trafo): from ufl import Integral assert isinstance(integral, Integral) @@ -76,6 +80,7 @@ def transform_integral(integral, trafo): return integral.reconstruct(integrand=trafo(integral.integrand())) + def transform_form(form, trafo): from ufl import Form assert isinstance(form, Form) diff --git a/python/dune/perftool/ufl/transformations/argument_elimination.py b/python/dune/perftool/ufl/transformations/argument_elimination.py index 4bb6c80e693a8f757d8ed3939baee52cef30c634..4d380009855e11312449d7845beb95c94f11e96f 100644 --- a/python/dune/perftool/ufl/transformations/argument_elimination.py +++ b/python/dune/perftool/ufl/transformations/argument_elimination.py @@ -9,63 +9,63 @@ e1=e2*v*w => (e2, (v, w)) Note that in PDELab, only test functions are arguments! Trial functions are coefficients instead. """ -from __future__ import absolute_import -from ufl.algorithms import MultiFunction - - -class EliminateArguments(MultiFunction): - """ This MultiFunction processes the expression bottom up and replaces - all modified argument by None and eliminates all None-Terms afterwards. - """ - call = MultiFunction.__call__ - - def __call__(self, o): - from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor - - self.arguments = ModifiedArgumentExtractor()(o) - e = self.call(o) - - # Catch the case that the entire expression vanished! - if e is None: - from ufl.classes import IntValue - e = IntValue(1) - - return (e, self.arguments) - - def expr(self, o): - if o in self.arguments: - return None - else: - # Evaluate the multi function applied to the operands - newop = tuple(self.call(op) for op in o.ufl_operands) - # Find out whether an operand vanished. If so, the class needs special treatment. - if None in newop: - raise NotImplementedError("Operand vanished: {} needs special treatment in EliminateArguments".format(type(o))) - return self.reuse_if_untouched(o, *newop) - - def sum(self, o): - assert len(o.ufl_operands) == 2 - - op0 = self.call(o.ufl_operands[0]) - op1 = self.call(o.ufl_operands[1]) - - if op0 and op1: - return self.reuse_if_untouched(o, op0, op1) - # One term vanished, so there is no sum anymore - else: - if op0 or op1: - # Return the term that did not vanish - return op0 if op0 else op1 - else: - # This entire sum vanished!!! - return None - - # The handler for product is equal to the sum handler - product = sum - - def index_sum(self, o): - op = self.call(o.ufl_operands[0]) - if op: - return self.reuse_if_untouched(o, op, o.ufl_operands[1]) - else: - return None +# from __future__ import absolute_import +# from ufl.algorithms import MultiFunction +# +# +# class EliminateArguments(MultiFunction): +# """ This MultiFunction processes the expression bottom up and replaces +# all modified argument by None and eliminates all None-Terms afterwards. +# """ +# call = MultiFunction.__call__ +# +# def __call__(self, o): +# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor +# +# self.arguments = ModifiedArgumentExtractor()(o) +# e = self.call(o) +# +# # Catch the case that the entire expression vanished! +# if e is None: +# from ufl.classes import IntValue +# e = IntValue(1) +# +# return (e, self.arguments) +# +# def expr(self, o): +# if o in self.arguments: +# return None +# else: +# # Evaluate the multi function applied to the operands +# newop = tuple(self.call(op) for op in o.ufl_operands) +# # Find out whether an operand vanished. If so, the class needs special treatment. +# if None in newop: +# raise NotImplementedError("Operand vanished: {} needs special treatment in EliminateArguments".format(type(o))) +# return self.reuse_if_untouched(o, *newop) +# +# def sum(self, o): +# assert len(o.ufl_operands) == 2 +# +# op0 = self.call(o.ufl_operands[0]) +# op1 = self.call(o.ufl_operands[1]) +# +# if op0 and op1: +# return self.reuse_if_untouched(o, op0, op1) +# # One term vanished, so there is no sum anymore +# else: +# if op0 or op1: +# # Return the term that did not vanish +# return op0 if op0 else op1 +# else: +# # This entire sum vanished!!! +# return None +# +# # The handler for product is equal to the sum handler +# product = sum +# +# def index_sum(self, o): +# op = self.call(o.ufl_operands[0]) +# if op: +# return self.reuse_if_untouched(o, op, o.ufl_operands[1]) +# else: +# return None diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py index 7f64f65b4e241c3bd2f85f25b558a522bc22e0b2..4b186da5ebee208c0cfee7fbdb49dca7bbd92c2d 100644 --- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py +++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py @@ -4,67 +4,91 @@ and transforms it into a sum of terms that all contain the correct number of test function terms (which is equal to the form rank). """ from __future__ import absolute_import -from ufl.algorithms import MultiFunction - - -class SplitIntoAccumulationTerms(MultiFunction): - """ return a list of tuples of expressions and modified arguments! """ - - call = MultiFunction.__call__ - - def __call__(self, expr): - from dune.perftool.ufl.rank import ufl_rank - self.rank = ufl_rank(expr) - - from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor - self.mae = ModifiedArgumentExtractor() - from dune.perftool.ufl.transformations.argument_elimination import EliminateArguments - self.ea = EliminateArguments() - - # Collect the found terms - self.terms = {} +from dune.perftool.ufl.modified_terminals import extract_modified_arguments +from dune.perftool.ufl.transformations import ufl_transformation +from ufl.algorithms import MultiFunction - # Now, fill the terms dict by applying this multifunction! - self.call(expr) - from dune.perftool.ufl.flatoperators import construct_binary_operator - from ufl.classes import Sum - return [(construct_binary_operator(t, Sum), a) for a, t in self.terms.items()] +@ufl_transformation(name="accterms2", extraction_lambda=lambda l: [i[0] for i in l]) +def split_into_accumulation_terms(expr): + mod_args = extract_modified_arguments(expr) - def expr(self, o): - # This needs to be a valid accumulation term! - assert all(len(self.mae(o, i)) == 1 for i in range(self.rank)) + accumulation_terms = [] + for arg in mod_args: + from dune.perftool.ufl.transformations.replace import replace_expression + from ufl.classes import Zero, IntValue + # Define a replacement map that maps the given arg to 1 and the rest to 0 + rmap = {ma: Zero() for ma in mod_args} + rmap[arg] = IntValue(1) - expr, arg = self.ea(o) - if arg not in self.terms: - self.terms[arg] = [] - self.terms[arg].append(expr) + # Do the replacement on the expression + accum_expr = replace_expression(expr, rmap) - def sum(self, o): - # Check whether this sums contains too many accumulation terms. - if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): - # This sum is part of a top level sum that separates accumulation terms! - for op in o.ufl_operands: - self.call(op) - else: - # This is a normal sum, we might treat it as any other expression - self.expr(o) + # Store the foudn accumulation expression + accumulation_terms.append((accum_expr, arg)) - def index_sum(self, o): - # Check whether this sums contains too many accumulation terms. - if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): - # This sum is part of a top level sum that separates accumulation terms! - self.call(o.ufl_operands[0]) - else: - # This is a normal sum, we might treat it as any other expression - # TODO we need to eliminate topsum indexsum regardless of the thing being valid - self.expr(o.ufl_operands[0]) - # old code - #self.expr(o) + return accumulation_terms -from dune.perftool.ufl.transformations import ufl_transformation -@ufl_transformation(name="accterms", extraction_lambda=lambda l:[i[0] for i in l]) -def split_into_accumulation_terms(expr): - return SplitIntoAccumulationTerms()(expr) +# class SplitIntoAccumulationTerms(MultiFunction): +# """ return a list of tuples of expressions and modified arguments! """ +# +# call = MultiFunction.__call__ +# +# def __call__(self, expr): +# from dune.perftool.ufl.rank import ufl_rank +# self.rank = ufl_rank(expr) +# +# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor +# self.mae = ModifiedArgumentExtractor() +# +# from dune.perftool.ufl.transformations.argument_elimination import EliminateArguments +# self.ea = EliminateArguments() +# +# # Collect the found terms +# self.terms = {} +# +# # Now, fill the terms dict by applying this multifunction! +# self.call(expr) +# +# from dune.perftool.ufl.flatoperators import construct_binary_operator +# from ufl.classes import Sum +# return [(construct_binary_operator(t, Sum), a) for a, t in self.terms.items()] +# +# def expr(self, o): +# # This needs to be a valid accumulation term! +# assert all(len(self.mae(o, i)) == 1 for i in range(self.rank)) +# +# expr, arg = self.ea(o) +# if arg not in self.terms: +# self.terms[arg] = [] +# self.terms[arg].append(expr) +# +# def sum(self, o): +# # Check whether this sums contains too many accumulation terms. +# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): +# # This sum is part of a top level sum that separates accumulation terms! +# for op in o.ufl_operands: +# self.call(op) +# else: +# # This is a normal sum, we might treat it as any other expression +# self.expr(o) +# +# def index_sum(self, o): +# # Check whether this sums contains too many accumulation terms. +# if not all(len(self.mae(o, i)) == 1 for i in range(self.rank)): +# # This sum is part of a top level sum that separates accumulation terms! +# self.call(o.ufl_operands[0]) +# else: +# # This is a normal sum, we might treat it as any other expression +# # TODO we need to eliminate topsum indexsum regardless of the thing being valid +# self.expr(o.ufl_operands[0]) +# # old code +# #self.expr(o) +# +# +# from dune.perftool.ufl.transformations import ufl_transformation +# @ufl_transformation(name="accterms", extraction_lambda=lambda l:[i[0] for i in l]) +# def split_into_accumulation_terms2(expr): +# return SplitIntoAccumulationTerms()(expr) diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py index 340f87ec2963a7c94dc1470cfbc4acefae2c9d39..1f1e0b64307b72b916423089ca0ca2d6871c6957 100644 --- a/python/dune/perftool/ufl/transformations/indexpushdown.py +++ b/python/dune/perftool/ufl/transformations/indexpushdown.py @@ -17,8 +17,8 @@ class IndexPushDown(MultiFunction): else: # This is a normal indexed, we treat it as any other. return self.expr(o) - + @ufl_transformation(name="index_pushdown") def pushdown_indexed(e): - return IndexPushDown()(e) \ No newline at end of file + return IndexPushDown()(e) diff --git a/python/dune/perftool/ufl/transformations/reindexing.py b/python/dune/perftool/ufl/transformations/reindexing.py index c663becc364541c225f87e6df6811540e784403a..0b405005d469d904d74d6f57df66e4c3adf6a321 100644 --- a/python/dune/perftool/ufl/transformations/reindexing.py +++ b/python/dune/perftool/ufl/transformations/reindexing.py @@ -2,52 +2,53 @@ from __future__ import absolute_import from ufl.algorithms import MultiFunction from dune.perftool.ufl.transformations import ufl_transformation + # TODO: This follows a pattern that may be worth abstracting! # It replaces the ifinstance orgy mappers when having operand dependant behaviour. class IndexedMapper(MultiFunction): - + call = MultiFunction.__call__ - + def __init__(self, rim): MultiFunction.__init__(self) # The corresponding ReindexingMapper self.rim = rim - + def __call__(self, o): from ufl.classes import Indexed assert isinstance(o, Indexed) - + self.indexed_expr = o - + return self.call(o.ufl_operands[0]) - + def expr(self, o): # If we are here, this one does not need special treatment, we use the other one # TODO: apply renaming! return self.reuse_if_untouched(self.indexed_expr, self.rim.call(o), self.rim.call(self.indexed_expr.ufl_operands[1])) - + def component_tensor(self, o): tensor, idx = o.ufl_operands - + # Extract an index mapping from itertools import izip for inner, outer in izip(idx, self.indexed_expr.ufl_operands[1]): self.rim.replacement_map[inner] = self.rim.replacement_map.setdefault(outer, outer) - + # And return an expression with the component tensor removed return self.rim.call(tensor) - + class ReindexingMapper(MultiFunction): - + call = MultiFunction.__call__ - + def __init__(self): MultiFunction.__init__(self) self.replacement_map = {} self.multi_index_cache = {} self.im = IndexedMapper(self) - + def expr(self, o): return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) @@ -63,4 +64,3 @@ class ReindexingMapper(MultiFunction): @ufl_transformation(name="reindexing") def reindexing(e): return ReindexingMapper()(e) - diff --git a/python/dune/perftool/ufl/transformations/replace.py b/python/dune/perftool/ufl/transformations/replace.py new file mode 100644 index 0000000000000000000000000000000000000000..987d89ba98cc084affee36ad8c35c97dbfaeea42 --- /dev/null +++ b/python/dune/perftool/ufl/transformations/replace.py @@ -0,0 +1,21 @@ +""" A transformation that replaces expression with others from a given dictionary """ + +from dune.perftool.ufl.transformations import ufl_transformation +from ufl.algorithms import MultiFunction + + +class ReplaceExpression(MultiFunction): + def __init__(self, replacemap={}): + MultiFunction.__init__(self) + self.replacemap = replacemap + + def expr(self, o): + if o in self.replacemap: + return self.replacemap[o] + else: + return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands)) + + +@ufl_transformation(name="replace") +def replace_expression(expr, replacemap={}): + return ReplaceExpression(replacemap)(expr) diff --git a/python/dune/perftool/ufl/transformations/splitarguments.py b/python/dune/perftool/ufl/transformations/splitarguments.py index 9304f2ab33f8a56330b574839ac4b2027826ec4f..bbc8954577cd326aa8f4a969ea31aad2e3fe6f80 100644 --- a/python/dune/perftool/ufl/transformations/splitarguments.py +++ b/python/dune/perftool/ufl/transformations/splitarguments.py @@ -1,46 +1,46 @@ -from __future__ import absolute_import -from ufl.algorithms import MultiFunction -from dune.perftool.ufl.transformations import ufl_transformation - - -class SplitArguments(MultiFunction): - - call = MultiFunction.__call__ - - def __init__(self): - MultiFunction.__init__(self) - - def __call__(self, o): - from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor - self.ae = ModifiedArgumentExtractor() - from dune.perftool.ufl.rank import ufl_rank - self.rank = ufl_rank(o) - # call the actual recursive function - return self.call(o) - - def expr(self, o): - return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) - - def product(self, o): - from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator - from itertools import product as iterproduct - from ufl.classes import Sum, Product - - # Check whether this product is fine! - if len(self.ae(o)) == self.rank: - return self.reuse_if_untouched(o, *(self.call(op) for op in o.ufl_operands)) - - # It is not, lets switch sums and products! - # First we apply recursively to all - product_operands = [get_operands(self.call(op)) if isinstance(op, Sum) else (self.call(op),) for op in get_operands(o)] - # Multiply all terms by taking the cartesian product of terms - distributive = [f for f in iterproduct(*product_operands)] - # Prepare all sum terms by introducing products - sum_terms = [construct_binary_operator(s, Product) for s in distributive] - # Return the big sum. - return construct_binary_operator(sum_terms, Sum) - - -@ufl_transformation(name="split") -def split_arguments(expr): - return SplitArguments()(expr) +# from __future__ import absolute_import +# from ufl.algorithms import MultiFunction +# from dune.perftool.ufl.transformations import ufl_transformation + +# +# class SplitArguments(MultiFunction): +# +# call = MultiFunction.__call__ +# +# def __init__(self): +# MultiFunction.__init__(self) +# +# def __call__(self, o): +# from dune.perftool.ufl.modified_terminals import ModifiedArgumentExtractor +# self.ae = ModifiedArgumentExtractor() +# from dune.perftool.ufl.rank import ufl_rank +# self.rank = ufl_rank(o) +# # call the actual recursive function +# return self.call(o) +# +# def expr(self, o): +# return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) +# +# def product(self, o): +# from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator +# from itertools import product as iterproduct +# from ufl.classes import Sum, Product +# +# # Check whether this product is fine! +# if len(self.ae(o)) == self.rank: +# return self.reuse_if_untouched(o, *(self.call(op) for op in o.ufl_operands)) +# +# # It is not, lets switch sums and products! +# # First we apply recursively to all +# product_operands = [get_operands(self.call(op)) if isinstance(op, Sum) else (self.call(op),) for op in get_operands(o)] +# # Multiply all terms by taking the cartesian product of terms +# distributive = [f for f in iterproduct(*product_operands)] +# # Prepare all sum terms by introducing products +# sum_terms = [construct_binary_operator(s, Product) for s in distributive] +# # Return the big sum. +# return construct_binary_operator(sum_terms, Sum) +# +# +# @ufl_transformation(name="split") +# def split_arguments(expr): +# return SplitArguments()(expr)