From 632c592538821cd3c77586f8919ececada53f89f Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Fri, 14 Jul 2017 18:12:42 +0200
Subject: [PATCH] Sets right type for DG LocalBasis in the default case

---
 python/dune/perftool/blockstructured/basis.py | 34 ++++++++++++++++-
 python/dune/perftool/pdelab/basis.py          | 37 +++++--------------
 2 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/python/dune/perftool/blockstructured/basis.py b/python/dune/perftool/blockstructured/basis.py
index 4b32d5d9..ced039c9 100644
--- a/python/dune/perftool/blockstructured/basis.py
+++ b/python/dune/perftool/blockstructured/basis.py
@@ -5,14 +5,46 @@ from dune.perftool.generation import (backend,
                                       temporary_variable,
                                       globalarg,
                                       class_member,
-                                      initializer_list)
+                                      initializer_list,
+                                      include_file,)
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.pdelab.basis import (declare_cache_temporary,
                                         name_localbasis_cache,
                                         type_localbasis
                                         )
+from dune.perftool.pdelab.driver import (basetype_range,
+                                         isPk,
+                                         isQk)
 from dune.perftool.pdelab.geometry import world_dimension
 from dune.perftool.pdelab.quadrature import pymbolic_quadrature_position_in_cell
+from dune.perftool.pdelab.spaces import type_gfs
+
+
+@backend(interface="typedef_localbasis", name="blockstructured")
+@class_member(classtag="operator")
+def typedef_localbasis(element, name):
+    df = "typename {}::Traits::GridView::ctype".format(type_gfs(element))
+    r = basetype_range()
+    dim = world_dimension()
+    if isPk(element):
+        if dim == 1:
+            include_file("dune/localfunctions/lagrange/pk1d/pk1dlocalbasis.hh", filetag="operatorfile")
+            basis_type = "Pk1DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
+        elif dim == 2:
+            include_file("dune/localfunctions/lagrange/pk2d/pk2dlocalbasis.hh", filetag="operatorfile")
+            basis_type = "Pk2DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
+        elif dim == 3:
+            include_file("dune/localfunctions/lagrange/pk3d/pk3dlocalbasis.hh", filetag="operatorfile")
+            basis_type = "Pk3DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
+        else:
+            raise NotImplementedError("P{} in {}D is not implemented".format(element._degree, dim))
+    elif isQk(element):
+        include_file("dune/localfunctions/lagrange/qk/qklocalbasis.hh", filetag="operatorfile")
+        basis_type = "QkLocalBasis<{}, {}, {}, {}>".format(df, r, element._degree, dim)
+    else:
+        raise NotImplementedError("Element type not known in code generation")
+    return "using {} = Dune::{};".format(name, basis_type)
+
 
 
 @backend(interface="evaluate_basis", name="blockstructured")
diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py
index b15497c0..5fec4cc6 100644
--- a/python/dune/perftool/pdelab/basis.py
+++ b/python/dune/perftool/pdelab/basis.py
@@ -23,39 +23,18 @@ from dune.perftool.tools import (get_pymbolic_basename,
                                  )
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.restriction import restricted_name
-from dune.perftool.pdelab.driver import (type_domainfield,
-                                         basetype_range,
-                                         isPk,
-                                         isQk,)
+from dune.perftool.pdelab.driver import (isPk,
+                                         isQk,
+                                         isDG)
 from pymbolic.primitives import Product, Subscript, Variable
-import pymbolic.primitives as prim
 from loopy import Reduction
 
 
+@backend(interface="typedef_localbasis")
 @class_member(classtag="operator")
 def typedef_localbasis(element, name):
-    df = "typename {}::Traits::GridView::ctype".format(type_gfs(element))
-    r = basetype_range()
-    dim = world_dimension()
-    if isPk(element):
-        if dim == 1:
-            include_file("dune/localfunctions/lagrange/pk1d/pk1dlocalbasis.hh", filetag="operatorfile")
-            basis_type = "Pk1DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
-        elif dim == 2:
-            include_file("dune/localfunctions/lagrange/pk2d/pk2dlocalbasis.hh", filetag="operatorfile")
-            basis_type = "Pk2DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
-        elif dim == 3:
-            include_file("dune/localfunctions/lagrange/pk3d/pk3dlocalbasis.hh", filetag="operatorfile")
-            basis_type = "Pk3DLocalBasis<{}, {}, {}>".format(df, r, element._degree)
-        else:
-            raise NotImplementedError("P{} in {}D is not implemented".format(element._degree, dim))
-    elif isQk(element):
-        include_file("dune/localfunctions/lagrange/qk/qklocalbasis.hh", filetag="operatorfile")
-        basis_type = "QkLocalBasis<{}, {}, {}, {}>".format(df, r, element._degree, dim)
-    #TODO add dg support
-    else:
-        raise NotImplementedError("Element type not known in code generation")
-    return "using {} = Dune::{};".format(name, basis_type)
+    basis_type = "{}::Traits::FiniteElementMap::Traits::FiniteElementType::Traits::LocalBasisType".format(type_gfs(element))
+    return "using {} = typename {};".format(name, basis_type)
 
 
 def type_localbasis(element):
@@ -63,9 +42,11 @@ def type_localbasis(element):
         name = "P{}_LocalBasis".format(element._degree)
     elif isQk(element):
         name = "Q{}_LocalBasis".format(element._degree)
+    elif isDG(element):
+        name = "DG{}_LocalBasis".format(element._degree)
     else:
         raise NotImplementedError("Element type not known in code generation")
-    typedef_localbasis(element, name)
+    get_backend("typedef_localbasis", selector=option_switch("blockstructured"))(element, name)
     return name
 
 
-- 
GitLab