From f0c015cc35112302272e092309b2722dc14dd781 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 19 Apr 2016 17:57:07 +0200
Subject: [PATCH] Getting local basis caches into code generation

as a consequence introduce template parameters
---
 python/dune/perftool/cgen/clazz.py           |  8 +-
 python/dune/perftool/generation/__init__.py  |  2 +
 python/dune/perftool/generation/cpp.py       | 12 +++
 python/dune/perftool/generation/loopy.py     | 38 ++++++--
 python/dune/perftool/pdelab/basis.py         | 95 ++++++++++++++++----
 python/dune/perftool/pdelab/driver.py        | 12 ++-
 python/dune/perftool/pdelab/geometry.py      |  2 +-
 python/dune/perftool/pdelab/localoperator.py | 63 ++++++++++---
 python/dune/perftool/pdelab/quadrature.py    |  2 +
 9 files changed, 189 insertions(+), 45 deletions(-)

diff --git a/python/dune/perftool/cgen/clazz.py b/python/dune/perftool/cgen/clazz.py
index 2586614a..8d3c9985 100644
--- a/python/dune/perftool/cgen/clazz.py
+++ b/python/dune/perftool/cgen/clazz.py
@@ -102,9 +102,6 @@ class Class(Generable):
             assert isinstance(bc, BaseClass)
         for mem in members:
             assert isinstance(mem, ClassMember)
-        from cgen import Declarator
-        for tp in tparam_decls:
-            assert isinstance(tp, Declarator)
         for con in constructors:
             assert isinstance(con, Constructor)
 
@@ -114,8 +111,9 @@ class Class(Generable):
         decl = Value('class', self.name)
 
         if self.tparam_decls:
-            from cgen import Template
-            decl = Template(self.tparam_decls, decl)
+            yield 'template<'
+            yield ', '.join('typename {}'.format(t) for t in self.tparam_decls)
+            yield '>\n'
 
         # Yield the definition
         for line in decl.generate(with_semicolon=False):
diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 79480b15..a51143e6 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -7,12 +7,14 @@ from dune.perftool.generation.cache import (cached,
                                             )
 
 from dune.perftool.generation.cpp import (base_class,
+                                          class_basename,
                                           class_member,
                                           constructor_parameter,
                                           include_file,
                                           initializer_list,
                                           preamble,
                                           symbol,
+                                          template_parameter,
                                           )
 
 from dune.perftool.generation.loopy import (domain,
diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py
index 4a0757cb..9bed7f13 100644
--- a/python/dune/perftool/generation/cpp.py
+++ b/python/dune/perftool/generation/cpp.py
@@ -51,3 +51,15 @@ def constructor_parameter(_type, name, classtag=None, constructortag=None):
 
     gen = generator_factory(item_tags=(classtag, constructortag, "constructor_param"), counted=True, no_deco=True)
     return gen(Value(_type, name))
+
+
+def template_parameter(classtag=None):
+    assert classtag
+
+    return generator_factory(item_tags=(classtag, "template_param"), counted=True)
+
+
+def class_basename(classtag=None):
+    assert classtag
+
+    return generator_factory(item_tags=(classtag, "basename"))
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index c8d6ee97..37b68b06 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -29,11 +29,34 @@ def domain(iname, shape):
     return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
 
 
+def _temporary_type(shape_impl, shape, first=True):
+    if len(shape_impl) == 0:
+        return 'double'
+    if shape_impl[0] == 'arr':
+        if not first or len(set(shape_impl)) != 1:
+            raise ValueError("We do not allow mixing of C++ containers and plain C arrays, for reasons of mental sanity")
+        return 'double'
+    if shape_impl[0] == 'vec':
+        return "std::vector<{}>".format(_temporary_type(shape_impl[1:], shape[1:], first=False))
+    if shape_impl[0] == 'fv':
+        return "Dune::FieldVector<{}, {}>".format(_temporary_type(shape_impl[1:], shape[1:], first=False), shape[0])
+    # 'fm' and any unknown tag have no C++ type mapping yet: fail loudly instead of returning None.
+    raise NotImplementedError("shape_impl entry {!r} is not supported".format(shape_impl[0]))
+
+
 @preamble
-def default_temporary_declaration(name, **kwargs):
-    shape = kwargs.get('shape', ())
-    array_bounds = ''.join('[{}]'.format(s) for s in shape)
-    return 'double {}{};'.format(name, array_bounds)
+def default_declaration(name, shape, shape_impl):
+    # Determine the C++ type to use for this temporary.
+    t = _temporary_type(shape_impl, shape)
+    if len(shape_impl) == 0:
+        # This is a scalar, just return it!
+        return '{} {}(0.0);'.format(t, name)
+    if shape_impl[0] == 'arr':
+        return '{} {}{};'.format(t, name, ''.join('[{}]'.format(s) for s in shape))
+    if shape_impl[0] == 'vec':
+        return '{} {}({});'.format(t, name, shape[0])
+    if shape_impl[0] == 'fv':
+        return '{} {}(0.0);'.format(t, name)
 
 
 @generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n)
@@ -41,8 +64,11 @@ def temporary_variable(name, **kwargs):
     if 'dtype' not in kwargs:
         kwargs['dtype'] = numpy.float64
 
-    decl_method = kwargs.pop('decl_method', default_temporary_declaration)
-    decl_method(name, **kwargs)
+    decl_method = kwargs.pop('decl_method', default_declaration)
+    shape = kwargs.get('shape', ())
+    shape_impl = kwargs.pop('shape_impl', ('arr',)*len(shape))
+
+    decl_method(name, shape, shape_impl)
 
     return loopy.TemporaryVariable(name, **kwargs)
 
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 4cdff4d7..00042f14 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -1,9 +1,11 @@
 """ Generators for basis evaluations """
 
 from dune.perftool.generation import (cached,
+                                      class_member,
                                       domain,
                                       generator_factory,
                                       iname,
+                                      include_file,
                                       instruction,
                                       preamble,
                                       symbol,
@@ -15,6 +17,29 @@ from dune.perftool.pdelab.quadrature import (name_quadrature_point,
 from dune.perftool.pdelab.geometry import (name_dimension,
                                            name_jacobian_inverse_transposed,
                                            )
+from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs,
+                                                lop_template_test_gfs,
+                                                )
+from dune.perftool.pdelab.driver import FEM_name_mangling
+
+
+@symbol
+def type_localbasis_cache(element):
+    return "Dune::PDELab::LocalBasisCache<typename {}::Traits::FiniteElementMap::Traits::FiniteElementType::Traits::LocalBasisType>".format(type_gfs(element))
+
+
+@class_member("operator")
+def define_localbasis_cache(element, name):
+    include_file("dune/pdelab/finiteelement/localbasiscache.hh", filetag="operatorfile")
+    t = type_localbasis_cache(element)
+    return "{} {};".format(t, name)
+
+
+@symbol
+def name_localbasis_cache(element):
+    name = "cache_{}".format(FEM_name_mangling(element))
+    define_localbasis_cache(element, name)
+    return name
 
 
 @preamble
@@ -58,19 +83,46 @@ def name_lfs(element, prefix=None):
     return prefix
 
 
+@generator_factory(cache_key_generator=lambda e, **kw: e)
+def type_gfs(element, basetype=None, index_stack=None):
+    # Omitting basetype and index_stack is only valid upon a second call,
+    # which will result in a cache hit.
+    assert basetype
+    assert index_stack is not None
+
+    # Additionally, element is expected to be a ufl finite element
+    from ufl import FiniteElementBase
+    assert isinstance(element, FiniteElementBase)
+
+    # Recurse into the sub-elements so that their grid function space types get cached as well!
+    from ufl import MixedElement
+    if isinstance(element, MixedElement):
+        for i, subelem in enumerate(element.sub_elements()):
+            type_gfs(subelem, basetype=basetype, index_stack=index_stack + (i,))
+
+    if len(index_stack) == 0:
+        return basetype
+    else:
+        include_file("dune/typetree/childextraction.hh", filetag="operatorfile")
+        return 'Dune::TypeTree::Child<{},{}>'.format(basetype, ','.join(str(i) for i in index_stack))
+
+
 def traverse_lfs_tree(arg):
     from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
     assert isinstance(arg, ModifiedArgumentDescriptor)
 
     # First we need to determine the basename as given in the signature of
     # this kernel method!
-    basename = None
+    lfs_basename = None
+    gfs_basename = None
     from ufl.classes import Argument, Coefficient
     if isinstance(arg.argexpr, Argument):
         if arg.argexpr.count() == 0:
-            basename = 'lfsv'
+            lfs_basename = 'lfsv'
+            gfs_basename = 'GFSV'
         if arg.argexpr.count() == 1:
-            basename = 'lfsu'
+            lfs_basename = 'lfsu'
+            gfs_basename = 'GFSU'
         # TODO add restrictions here.
 
     if isinstance(arg.argexpr, Coefficient):
@@ -78,14 +130,16 @@ def traverse_lfs_tree(arg):
         # is the coefficient of reserved index 0.
         assert arg.argexpr.count() == 0
 
-        basename = 'lfsu'
+        lfs_basename = 'lfsu'
+        gfs_basename = 'GFSU'
 
-    assert basename
+    assert lfs_basename and gfs_basename
 
     # Now start recursively extracting local function spaces and fill the cache with
     # all those values. That way we can later get a correct local function space with
     # just the ufl finite element.
-    name_lfs(arg.argexpr.element(), prefix=basename)
+    name_lfs(arg.argexpr.element(), prefix=lfs_basename)
+    type_gfs(arg.argexpr.element(), basetype=gfs_basename, index_stack=())
 
 
 @iname
@@ -124,15 +178,17 @@ def lfs_iname(element, argcount=0, context=''):
 
 @cached
 def evaluate_basis(element, name):
-    temporary_variable(name, shape=(name_lfs_bound(element),))
+    temporary_variable(name, shape=(name_lfs_bound(element),), shape_impl=('vec',))
+    cache = name_localbasis_cache(element)
     lfs = name_lfs(element)
     qp = name_quadrature_point()
     instruction(inames=(quadrature_iname(),
                         ),
-                code='{}.finiteElement().localBasis().evaluateFunction({}, {});'.format(lfs,
-                                                                                        qp,
-                                                                                        name,
-                                                                                        ),
+                code='auto& {} = {}.evaluateFunction({}, {}.finiteElement().localBasis());'.format(name,
+                                                                                                   cache,
+                                                                                                   qp,
+                                                                                                   lfs,
+                                                                                                   ),
                 assignees=frozenset({name}),
                 )
 
@@ -145,16 +201,17 @@ def name_basis(element):
 
 @cached
 def evaluate_reference_gradient(element, name):
-    # TODO this is of course not yet correct
-    temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()))
+    temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()), shape_impl=('vec', 'fv'))
+    cache = name_localbasis_cache(element)
     lfs = name_lfs(element)
     qp = name_quadrature_point()
     instruction(inames=(quadrature_iname(),
                         ),
-                code='{}.finiteElement().localBasis().evaluateJacobian({}, {});'.format(lfs,
-                                                                                        qp,
-                                                                                        name,
-                                                                                        ),
+                code='auto& {} = {}.evaluateJacobian({}, {}.finiteElement().localBasis());'.format(name,
+                                                                                                   cache,
+                                                                                                   qp,
+                                                                                                   lfs,
+                                                                                                   ),
                 assignees=frozenset({name}),
                 )
 
@@ -168,7 +225,7 @@ def name_reference_gradient(element):
 @cached
 def evaluate_basis_gradient(element, name):
     # TODO this is of course not yet correct
-    temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()))
+    temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()), shape_impl=('vec', 'fv'))
     jac = name_jacobian_inverse_transposed()
     index = lfs_iname(element, context='transformgrads')
     reference_gradients = name_reference_gradient(element)
@@ -215,7 +272,7 @@ def evaluate_trialfunction(element, name):
 @cached
 def evaluate_trialfunction_gradient(element, name):
     # TODO this is of course not yet correct
-    temporary_variable(name, shape=(name_dimension(),))
+    temporary_variable(name, shape=(name_dimension(),), shape_impl=('fv',))
     lfs = name_lfs(element)
     index = lfs_iname(element, context='trialgrad')
     basis = name_basis_gradient(element)
diff --git a/python/dune/perftool/pdelab/driver.py b/python/dune/perftool/pdelab/driver.py
index 0dea6a07..4a286b3f 100644
--- a/python/dune/perftool/pdelab/driver.py
+++ b/python/dune/perftool/pdelab/driver.py
@@ -438,15 +438,17 @@ def typedef_localoperator(name):
     # No Parameter class here, yet
     # params = type_parameters()
     # return "typedef LocalOperator<{}> {};".format(params, name)
+    ugfs = type_gfs(_form.coefficients()[0].element())
+    vgfs = type_gfs(_form.arguments()[0].element())
     from dune.perftool.options import get_option
     include_file(get_option('operator_file'), filetag="driver")
-    return "// Here in the future: typedef for the local operator with parameter class as template parameter"
+    return "using {} = LocalOperator<{}, {}>;".format(name, ugfs, vgfs)
 
 
 @symbol
 def type_localoperator():
-    typedef_localoperator("LocalOperator")
-    return "LocalOperator"
+    typedef_localoperator("LOP")
+    return "LOP"
 
 
 @preamble
@@ -493,7 +495,9 @@ def define_gridoperator(name):
     vcc = name_constraintscontainer(_form.arguments()[0].element())
     lop = name_localoperator()
     mb = name_matrixbackend()
-    return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb)
+    return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb),
+            "std::cout << \"gfs with \" << {}.size() << \" dofs generated  \"<< std::endl;".format(ugfs),
+            "std::cout << \"cc with \" << {}.size() << \" dofs generated  \"<< std::endl;".format(ucc)]
 
 
 @symbol
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 5a12c91b..355a4d6a 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -35,7 +35,7 @@ def name_dimension():
 
 
 @preamble
-def define_jacobian_inverse_transposed_temporary(name, **kwargs):
+def define_jacobian_inverse_transposed_temporary(name, shape, shape_impl):
     geo = name_geometry()
     return "auto {} = {}.jacobianInverseTransposed({{{{ 0.0 }}}});".format(name,
                                                                 geo,
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 294a7b0e..4671eba2 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -1,12 +1,33 @@
 from __future__ import absolute_import
 
 from dune.perftool.options import get_option
-from dune.perftool.generation import include_file, base_class, symbol, initializer_list, class_member, constructor_parameter
-from dune.perftool.cgen.clazz import AccessModifier, BaseClass, ClassMember
+from dune.perftool.generation import (base_class,
+                                      class_basename,
+                                      class_member,
+                                      constructor_parameter,
+                                      include_file,
+                                      initializer_list,
+                                      symbol,
+                                      template_parameter,
+                                      )
+from dune.perftool.cgen.clazz import (AccessModifier,
+                                      BaseClass,
+                                      ClassMember,
+                                      )
 
 from pytools import memoize
 
 
+@template_parameter("operator")
+def lop_template_ansatz_gfs():
+    return "GFSU"
+
+
+@template_parameter("operator")
+def lop_template_test_gfs():
+    return "GFSV"
+
+
 @symbol
 def name_initree_constructor_param():
     return "iniParams"
@@ -38,12 +59,27 @@ def name_initree_member():
     return "_iniParams"
 
 
-@symbol
-def localoperator_type():
-    # TODO use something from the form here to make it unique
+@class_basename("operator")
+def localoperator_basename():
     return "LocalOperator"
 
 
+def class_type_from_cache(classtag):
+    from dune.perftool.generation import retrieve_cache_items
+
+    # get the basename
+    basename = [i for i in retrieve_cache_items(condition="{} and basename".format(classtag))]
+    assert len(basename) == 1
+    basename = basename[0]
+
+    # get the template parameters
+    tparams = [i for i in retrieve_cache_items(condition="{} and template_param".format(classtag))]
+    tparam_str = ''
+    if len(tparams) > 0:
+        tparam_str = '<{}>'.format(', '.join(t for t in tparams))
+
+    return basename, basename + tparam_str
+
 @memoize
 def measure_specific_details(measure):
     # The return dictionary that this memoized method will grant direct access to.
@@ -52,8 +88,7 @@ def measure_specific_details(measure):
     def numerical_jacobian(which):
         if get_option("numerical_jacobian"):
             # Add a base class
-            from dune.perftool.pdelab.driver import type_localoperator
-            loptype = type_localoperator()
+            _, loptype = class_type_from_cache("operator")
             base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype), classtag="operator")
 
             # Add the initializer list for that base class
@@ -149,16 +184,20 @@ class AssemblyMethod(ClassMember):
 def cgen_class_from_cache(tag, members=[]):
     from dune.perftool.generation import retrieve_cache_items
 
+    # Generate the name by concatenating basename and template parameters of the class tagged `tag`
+    basename, fullname = class_type_from_cache(tag)
+
     base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))]
     constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))]
     il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))]
     pm = [m for m in retrieve_cache_items('{} and member'.format(tag))]
+    tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
 
     from dune.perftool.cgen.clazz import Constructor
-    constructor = Constructor(arg_decls=constructor_params, clsname=localoperator_type(), initializer_list=il)
+    constructor = Constructor(arg_decls=constructor_params, clsname=basename, initializer_list=il)
 
     from dune.perftool.cgen import Class
-    return Class(localoperator_type(), base_classes=base_classes, members=members + pm, constructors=[constructor])
+    return Class(basename, base_classes=base_classes, members=members + pm, constructors=[constructor], tparam_decls=tparams)
 
 
 def generate_localoperator_kernels(form):
@@ -175,7 +214,11 @@ def generate_localoperator_kernels(form):
     include_file('dune/pdelab/localoperator/idefault.hh', filetag="operatorfile")
     include_file('dune/pdelab/localoperator/flags.hh', filetag="operatorfile")
     include_file('dune/pdelab/localoperator/pattern.hh', filetag="operatorfile")
-    include_file('dune/geometry/quadraturerules.hh', filetag="operatorfile")
+
+    # Trigger this one once early on to avoid wrong stuff happening
+    localoperator_basename()
+    lop_template_ansatz_gfs()
+    lop_template_test_gfs()
 
     base_class('Dune::PDELab::LocalOperatorDefaultFlags', classtag="operator")
 
diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py
index eb5b724e..b4555eef 100644
--- a/python/dune/perftool/pdelab/quadrature.py
+++ b/python/dune/perftool/pdelab/quadrature.py
@@ -1,6 +1,7 @@
 from dune.perftool.generation import (cached,
                                       domain,
                                       iname,
+                                      include_file,
                                       instruction,
                                       symbol,
                                       temporary_variable,
@@ -64,6 +65,7 @@ def name_order():
 
 @cached
 def quadrature_loop_statement():
+    include_file('dune/pdelab/common/quadraturerules.hh', filetag='operatorfile')
     qp = name_quadrature_point()
     from dune.perftool.pdelab.geometry import name_geometry
     geo = name_geometry()
-- 
GitLab