Skip to content
Snippets Groups Projects
Commit d8f1c90a authored by Dominic Kempf
Browse files

Fix operator counting in sumfact kernel function signatures

Avoid hard-coding `double` in the generated kernel signatures; use `type_floatingpoint()` instead.
parent 6cf20872
No related branches found
No related tags found
No related merge requests found
...@@ -23,6 +23,7 @@ from dune.perftool.options import (get_form_option, ...@@ -23,6 +23,7 @@ from dune.perftool.options import (get_form_option,
get_option, get_option,
) )
from dune.perftool.loopy.flatten import flatten_index from dune.perftool.loopy.flatten import flatten_index
from dune.perftool.loopy.target import type_floatingpoint
from dune.perftool.sumfact.quadrature import nest_quadrature_loops from dune.perftool.sumfact.quadrature import nest_quadrature_loops
from dune.perftool.pdelab.driver import FEM_name_mangling from dune.perftool.pdelab.driver import FEM_name_mangling
from dune.perftool.pdelab.localoperator import determine_accumulation_space from dune.perftool.pdelab.localoperator import determine_accumulation_space
...@@ -236,7 +237,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -236,7 +237,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
@property @property
def signature_args(self): def signature_args(self):
if get_form_option('fastdg'): if get_form_option('fastdg'):
ret = ("double* fastdg0",) ret = ("{}* fastdg0".format(type_floatingpoint()),)
if self.within_inames: if self.within_inames:
ret = ret + ("unsigned int jacobian_offset0",) ret = ret + ("unsigned int jacobian_offset0",)
return ret return ret
......
...@@ -17,6 +17,7 @@ from dune.perftool.generation import (backend, ...@@ -17,6 +17,7 @@ from dune.perftool.generation import (backend,
kernel_cached, kernel_cached,
temporary_variable, temporary_variable,
) )
from dune.perftool.loopy.target import type_floatingpoint
from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence, construct_basis_matrix_sequence,
BasisTabulationMatrix, BasisTabulationMatrix,
...@@ -134,7 +135,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): ...@@ -134,7 +135,7 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
@property @property
def signature_args(self): def signature_args(self):
if get_form_option("fastdg"): if get_form_option("fastdg"):
return ("const double* fastdg0",) return ("const {}* fastdg0".format(type_floatingpoint()),)
else: else:
return () return ()
......
...@@ -8,7 +8,7 @@ from dune.perftool.generation import (get_counted_variable, ...@@ -8,7 +8,7 @@ from dune.perftool.generation import (get_counted_variable,
from dune.perftool.pdelab.geometry import local_dimension, world_dimension from dune.perftool.pdelab.geometry import local_dimension, world_dimension
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.target import dtype_floatingpoint, type_floatingpoint
from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
from dune.perftool.tools import get_leaf, maybe_wrap_subscript, remove_duplicates from dune.perftool.tools import get_leaf, maybe_wrap_subscript, remove_duplicates
...@@ -115,7 +115,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): ...@@ -115,7 +115,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
@property @property
def signature_args(self): def signature_args(self):
if get_form_option("fastdg"): if get_form_option("fastdg"):
return tuple("const double* fastdg{}".format(i)for i, _ in enumerate(remove_duplicates(self.interfaces))) return tuple("const {}* fastdg{}".format(type_floatingpoint(), i) for i, _ in enumerate(remove_duplicates(self.interfaces)))
else: else:
return () return ()
...@@ -199,7 +199,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): ...@@ -199,7 +199,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
def signature_args(self): def signature_args(self):
if get_form_option("fastdg"): if get_form_option("fastdg"):
def _get_pair(i): def _get_pair(i):
ret = ("double* fastdg{}".format(i),) ret = ("{}* fastdg{}".format(type_floatingpoint(), i),)
if self.within_inames: if self.within_inames:
ret = ret + ("unsigned int jacobian_offset{}".format(i),) ret = ret + ("unsigned int jacobian_offset{}".format(i),)
return ret return ret
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment