Skip to content
Snippets Groups Projects
Commit 7f61c38a authored by Dominic Kempf
Browse files

Adapt the custom declaration mechanism to get better type understanding

parent ee14ba98
No related branches found
No related tags found
No related merge requests found
...@@ -161,7 +161,7 @@ class DuneASTBuilder(CASTBuilder): ...@@ -161,7 +161,7 @@ class DuneASTBuilder(CASTBuilder):
return CASTBuilder.get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info) return CASTBuilder.get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info)
if temp_var.custom_declaration: if temp_var.custom_declaration:
decl = temp_var.decl_method(temp_var.name, temp_var.shape, temp_var.shape_impl) decl = temp_var.decl_method(temp_var.name, codegen_state.kernel, decl_info)
if decl: if decl:
return cgen.Line(decl) return cgen.Line(decl)
......
...@@ -28,7 +28,10 @@ def _temporary_type(shape_impl, shape, first=True): ...@@ -28,7 +28,10 @@ def _temporary_type(shape_impl, shape, first=True):
return "Dune::FieldMatrix<{}, {}, {}>".format(_type, shape[0], shape[1]) return "Dune::FieldMatrix<{}, {}, {}>".format(_type, shape[0], shape[1])
def default_declaration(name, shape=(), shape_impl=()): def default_declaration(name, kernel, decl_info):
shape = kernel.temporary_variables[name].shape
shape_impl = kernel.temporary_variables[name].shape_impl
# Determine the C++ type to use for this temporary. # Determine the C++ type to use for this temporary.
t = _temporary_type(shape_impl, shape) t = _temporary_type(shape_impl, shape)
if len(shape_impl) == 0: if len(shape_impl) == 0:
...@@ -45,10 +48,10 @@ def default_declaration(name, shape=(), shape_impl=()): ...@@ -45,10 +48,10 @@ def default_declaration(name, shape=(), shape_impl=()):
return '{} {}(0.0);'.format(t, name) return '{} {}(0.0);'.format(t, name)
def custom_base_storage_temporary_declaration(storage, dtype): def custom_base_storage_temporary_declaration(storage):
def _decl(name, *a): def _decl(name, kernel, decl_info):
from dune.perftool.loopy.target import numpy_to_cpp_dtype dtype = kernel.temporary_variables[name].dtype
_type = numpy_to_cpp_dtype(lp.types.NumpyType(dtype).dtype.name) _type = kernel.target.dtype_to_typename(decl_info.dtype)
return "{0} *{1} = ({0} *){2};".format(_type, name, storage) return "{0} *{1} = ({0} *){2};".format(_type, name, storage)
return _decl return _decl
...@@ -71,7 +74,7 @@ class DuneTemporaryVariable(TemporaryVariable): ...@@ -71,7 +74,7 @@ class DuneTemporaryVariable(TemporaryVariable):
if custom_base_storage and self.decl_method is None: if custom_base_storage and self.decl_method is None:
assert shape_impl is None assert shape_impl is None
self.decl_method = custom_base_storage_temporary_declaration(custom_base_storage, kwargs["dtype"]) self.decl_method = custom_base_storage_temporary_declaration(custom_base_storage)
self.custom_declaration = self.decl_method is not None self.custom_declaration = self.decl_method is not None
......
...@@ -86,7 +86,7 @@ def declare_cache_temporary(element, restriction, which): ...@@ -86,7 +86,7 @@ def declare_cache_temporary(element, restriction, which):
t_cache = type_localbasis_cache(element) t_cache = type_localbasis_cache(element)
lfs = name_leaf_lfs(element, restriction) lfs = name_leaf_lfs(element, restriction)
def decl(name, shape, shape_impl): def decl(name, kernel, decl_info):
return "typename {}::{}ReturnType {};".format(t_cache, return "typename {}::{}ReturnType {};".format(t_cache,
which, which,
name, name,
......
...@@ -19,7 +19,7 @@ def bind_gridfunction_to_element(gf, restriction): ...@@ -19,7 +19,7 @@ def bind_gridfunction_to_element(gf, restriction):
def declare_grid_function_range(gridfunction): def declare_grid_function_range(gridfunction):
def _decl(name, *args): def _decl(name, kernel, decl_info):
return "typename decltype({})::Range {};".format(gridfunction, name) return "typename decltype({})::Range {};".format(gridfunction, name)
return _decl return _decl
......
...@@ -261,7 +261,7 @@ def evaluate_unit_outer_normal(name): ...@@ -261,7 +261,7 @@ def evaluate_unit_outer_normal(name):
@preamble @preamble
def declare_normal(name, shape, shape_impl): def declare_normal(name, kernel, decl_info):
ig = name_intersection_geometry_wrapper() ig = name_intersection_geometry_wrapper()
return "auto {} = {}.centerUnitOuterNormal();".format(name, ig) return "auto {} = {}.centerUnitOuterNormal();".format(name, ig)
...@@ -300,7 +300,7 @@ def type_jacobian_inverse_transposed(restriction): ...@@ -300,7 +300,7 @@ def type_jacobian_inverse_transposed(restriction):
@kernel_cached @kernel_cached
def define_jacobian_inverse_transposed_temporary(restriction): def define_jacobian_inverse_transposed_temporary(restriction):
@preamble @preamble
def _define_jacobian_inverse_transposed_temporary(name, shape, shape_impl): def _define_jacobian_inverse_transposed_temporary(name, kernel, decl_info):
t = type_jacobian_inverse_transposed(restriction) t = type_jacobian_inverse_transposed(restriction)
return "{} {};".format(t, return "{} {};".format(t,
name, name,
......
...@@ -40,11 +40,11 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord): ...@@ -40,11 +40,11 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
ImmutableRecord.__init__(self, dir=dir) ImmutableRecord.__init__(self, dir=dir)
def realize(self, sf, index, insn_dep): def realize(self, sf, index, insn_dep):
from dune.perftool.sumfact.realization import name_buffer_storage, buffer_decl, get_sumfact_dtype from dune.perftool.sumfact.realization import name_buffer_storage
storage = name_buffer_storage(sf.buffer, 0)
name = name="input_{}".format(sf.buffer) name = name="input_{}".format(sf.buffer)
temporary_variable(name, temporary_variable(name,
shape=(2 ** local_dimension(), sf.vector_width), shape=(2 ** local_dimension(), sf.vector_width),
custom_base_storage=name_buffer_storage(sf.buffer, 0),
decl_method=buffer_decl(storage, get_sumfact_dtype(sf)), decl_method=buffer_decl(storage, get_sumfact_dtype(sf)),
managed=True, managed=True,
) )
......
...@@ -90,7 +90,7 @@ def _realize_sum_factorization_kernel(sf): ...@@ -90,7 +90,7 @@ def _realize_sum_factorization_kernel(sf):
temporary_variable("{}_dummy".format(buf), temporary_variable("{}_dummy".format(buf),
shape=(10000,), shape=(10000,),
custom_base_storage=buf, custom_base_storage=buf,
decl_method=lambda *a: None, decl_method=lambda n, k, di: None,
) )
# Realize the input if it is not direct # Realize the input if it is not direct
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment