From 27e8905bb6fbffe60c0eb14af199eca749424553 Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Thu, 13 Jul 2017 15:23:01 +0200
Subject: [PATCH] Adds localBasis function object to operator

---
 python/dune/perftool/blockstructured/basis.py | 43 ++++++++++++-------
 python/dune/perftool/pdelab/basis.py          | 22 ++++++++--
 .../ufl/transformations/blockstructured.py    | 39 -----------------
 3 files changed, 45 insertions(+), 59 deletions(-)
 delete mode 100644 python/dune/perftool/ufl/transformations/blockstructured.py

diff --git a/python/dune/perftool/blockstructured/basis.py b/python/dune/perftool/blockstructured/basis.py
index 2f545523..4b32d5d9 100644
--- a/python/dune/perftool/blockstructured/basis.py
+++ b/python/dune/perftool/blockstructured/basis.py
@@ -2,16 +2,17 @@ from dune.perftool.generation import (backend,
                                       kernel_cached,
                                       get_backend,
                                       instruction,
-                                      temporary_variable)
+                                      temporary_variable,
+                                      globalarg,
+                                      class_member,
+                                      initializer_list)
 from dune.perftool.tools import get_pymbolic_basename
-from dune.perftool.pdelab.driver import FEM_name_mangling
-from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.basis import (declare_cache_temporary,
-                                        name_localbasis_cache)
+                                        name_localbasis_cache,
+                                        type_localbasis
+                                        )
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.pdelab.quadrature import pymbolic_quadrature_position_in_cell
-from dune.perftool.blockstructured.spaces import lfs_inames
-import pymbolic.primitives as prim
 
 
 @backend(interface="evaluate_basis", name="blockstructured")
@@ -20,11 +21,9 @@ def evaluate_basis(leaf_element, name, restriction):
     temporary_variable(name, shape=(4,), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function'))
     cache = name_localbasis_cache(leaf_element)
     qp = pymbolic_quadrature_position_in_cell(restriction)
+    localbasis = name_localbasis(leaf_element)
     instruction(inames=get_backend("quad_inames")(),
-                code='{} = {}.evaluateFunction({}, lfs.finiteElement().localBasis());'.format(name,
-                                                                                             cache,
-                                                                                             str(qp),
-                                                                                             ),
+                code='{} = {}.evaluateFunction({}, {});'.format(name, cache, str(qp), localbasis),
                 assignees=frozenset({name}),
                 read_variables=frozenset({get_pymbolic_basename(qp)}),
                 )
@@ -33,14 +32,26 @@ def evaluate_basis(leaf_element, name, restriction):
 @backend(interface="evaluate_grad", name="blockstructured")
 @kernel_cached
 def evaluate_reference_gradient(leaf_element, name, restriction):
-    temporary_variable(name, shape=(4,1,world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
+    temporary_variable(name, shape=(4, 1, world_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
     cache = name_localbasis_cache(leaf_element)
     qp = pymbolic_quadrature_position_in_cell(restriction)
+    localbasis = name_localbasis(leaf_element)
     instruction(inames=get_backend("quad_inames")(),
-                code='{} = {}.evaluateJacobian({}, lfs.finiteElement().localBasis());'.format(name,
-                                                                                             cache,
-                                                                                             str(qp),
-                                                                                             ),
+                code='{} = {}.evaluateJacobian({}, {});'.format(name, cache, str(qp), localbasis),
                 assignees=frozenset({name}),
                 read_variables=frozenset({get_pymbolic_basename(qp)}),
-                )
\ No newline at end of file
+                )
+
+
+@class_member(classtag="operator")
+def define_localbasis(leaf_element, name):
+    localBasis_type = type_localbasis(leaf_element)
+    initializer_list(name, (), classtag="operator")
+    return "const {} {};".format(localBasis_type, name)
+
+
+def name_localbasis(leaf_element):
+    name = "microElementBasis"
+    globalarg(name)
+    define_localbasis(leaf_element, name)
+    return name
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index 2195e21f..fd51a2e6 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -32,18 +32,32 @@ import pymbolic.primitives as prim
 from loopy import Reduction
 
 
-def type_localbasis(element):
+@class_member(classtag="operator")
+def typedef_localbasis(element, name):
     df = "typename {}::Traits::GridView::ctype".format(type_gfs(element))
     r = basetype_range()
     dim = world_dimension()
     # if isPk(element):
     #     include_file("dune/localfunctions/lagrange/pk/pklocalbasis.hh", filetag="operatorfile")
     #     return "Dune::PkLocalBasis<{}, {}, {}, {}>".format(df, r, dim, element._degree)
+    #TODO add dg support
     if isQk(element):
         include_file("dune/localfunctions/lagrange/qk/qklocalbasis.hh", filetag="operatorfile")
-        return "Dune::QkLocalBasis<{}, {}, {}, {}>".format(df, r, dim, element._degree)
-    #TODO add dg support
-    raise NotImplementedError("Element type not known in code generation")
+        basis_type = "QkLocalBasis<{}, {}, {}, {}>".format(df, r, element._degree, dim)
+    else:
+        raise NotImplementedError("Element type not known in code generation")
+    return "using {} = Dune::{};".format(name, basis_type)
+
+
+def type_localbasis(element):
+    if isPk(element):
+        name = "P{}_LocalBasis".format(element._degree)
+    elif isQk(element):
+        name = "Q{}_LocalBasis".format(element._degree)
+    else:
+        raise NotImplementedError("Element type not known in code generation")
+    typedef_localbasis(element, name)
+    return name
 
 
 def type_localbasis_cache(element):
diff --git a/python/dune/perftool/ufl/transformations/blockstructured.py b/python/dune/perftool/ufl/transformations/blockstructured.py
deleted file mode 100644
index d25c95e1..00000000
--- a/python/dune/perftool/ufl/transformations/blockstructured.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from dune.perftool.options import get_option
-from dune.perftool.ufl.transformations import ufl_transformation
-from dune.perftool.ufl.transformations.replace import ReplaceExpression
-from ufl.algorithms import MultiFunction
-from ufl import as_ufl
-from ufl.classes import JacobianInverse, JacobianDeterminant, Product, Division, Indexed
-
-
-class ReplaceReferenceTransformation(MultiFunction):
-    def __init__(self, k):
-        MultiFunction.__init__(self)
-        self.k = k
-        self.visited_jit = False
-
-    def expr(self, o):
-        return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
-
-#TODO abs uses c abs -> only works for ints!!!
-    def abs(self,o):
-        if isinstance(o.ufl_operands[0], JacobianDeterminant):
-            return Division(o, as_ufl(self.k**2))
-        else:
-            return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
-
-    def jacobian_determinant(self,o):
-        return Division(o, as_ufl(self.k**2))
-
-    def indexed(self, o):
-        expr = o.ufl_operands[0]
-        multiindex = o.ufl_operands[1]
-        if isinstance(expr, JacobianInverse):
-            return Product(as_ufl(self.k), Indexed(expr, multiindex))
-        else:
-            return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
-
-
-@ufl_transformation(name="blockstructured")
-def blockstructured(expr):
-    return ReplaceReferenceTransformation(get_option("number_of_blocks"))(expr)
\ No newline at end of file
-- 
GitLab