From d8f1c90a2fb2bd9f606e7485c0d6d05c88fb8644 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 10 Apr 2018 13:44:41 +0200
Subject: [PATCH] Fix operator counting in sumfact kernel function signatures

Avoid explicit doubles.
---
 python/dune/perftool/sumfact/accumulation.py | 3 ++-
 python/dune/perftool/sumfact/basis.py        | 3 ++-
 python/dune/perftool/sumfact/symbolic.py     | 6 +++---
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index 47b0a982..4e7c2a69 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -23,6 +23,7 @@ from dune.perftool.options import (get_form_option,
                                    get_option,
                                    )
 from dune.perftool.loopy.flatten import flatten_index
+from dune.perftool.loopy.target import type_floatingpoint
 from dune.perftool.sumfact.quadrature import nest_quadrature_loops
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.localoperator import determine_accumulation_space
@@ -236,7 +237,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
     @property
     def signature_args(self):
         if get_form_option('fastdg'):
-            ret = ("double* fastdg0",)
+            ret = ("{}* fastdg0".format(type_floatingpoint()),)
             if self.within_inames:
                 ret = ret + ("unsigned int jacobian_offset0",)
             return ret
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 3c48f419..39bba49e 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -17,6 +17,7 @@ from dune.perftool.generation import (backend,
                                       kernel_cached,
                                       temporary_variable,
                                       )
+from dune.perftool.loopy.target import type_floatingpoint
 from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
                                               construct_basis_matrix_sequence,
                                               BasisTabulationMatrix,
@@ -134,7 +135,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
     @property
     def signature_args(self):
         if get_form_option("fastdg"):
-            return ("const double* fastdg0",)
+            return ("const {}* fastdg0".format(type_floatingpoint()),)
         else:
             return ()
 
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 517ef5ab..fb283a05 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -8,7 +8,7 @@ from dune.perftool.generation import (get_counted_variable,
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
-from dune.perftool.loopy.target import dtype_floatingpoint
+from dune.perftool.loopy.target import dtype_floatingpoint, type_floatingpoint
 from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
 from dune.perftool.tools import get_leaf, maybe_wrap_subscript, remove_duplicates
 
@@ -115,7 +115,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
     @property
     def signature_args(self):
         if get_form_option("fastdg"):
-            return tuple("const double* fastdg{}".format(i)for i, _ in enumerate(remove_duplicates(self.interfaces)))
+            return tuple("const {}* fastdg{}".format(type_floatingpoint(), i) for i, _ in enumerate(remove_duplicates(self.interfaces)))
         else:
             return ()
 
@@ -199,7 +199,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
     def signature_args(self):
         if get_form_option("fastdg"):
             def _get_pair(i):
-                ret = ("double* fastdg{}".format(i),)
+                ret = ("{}* fastdg{}".format(type_floatingpoint(), i),)
                 if self.within_inames:
                     ret = ret + ("unsigned int jacobian_offset{}".format(i),)
                 return ret
-- 
GitLab