Skip to content
Snippets Groups Projects
Commit 7f61c38a authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Adapt the custom declaration mechanism to get better type understanding

parent ee14ba98
No related branches found
No related tags found
No related merge requests found
......@@ -161,7 +161,7 @@ class DuneASTBuilder(CASTBuilder):
return CASTBuilder.get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info)
if temp_var.custom_declaration:
decl = temp_var.decl_method(temp_var.name, temp_var.shape, temp_var.shape_impl)
decl = temp_var.decl_method(temp_var.name, codegen_state.kernel, decl_info)
if decl:
return cgen.Line(decl)
......
......@@ -28,7 +28,10 @@ def _temporary_type(shape_impl, shape, first=True):
return "Dune::FieldMatrix<{}, {}, {}>".format(_type, shape[0], shape[1])
def default_declaration(name, shape=(), shape_impl=()):
def default_declaration(name, kernel, decl_info):
shape = kernel.temporary_variables[name].shape
shape_impl = kernel.temporary_variables[name].shape_impl
# Determine the C++ type to use for this temporary.
t = _temporary_type(shape_impl, shape)
if len(shape_impl) == 0:
......@@ -45,10 +48,10 @@ def default_declaration(name, shape=(), shape_impl=()):
return '{} {}(0.0);'.format(t, name)
def custom_base_storage_temporary_declaration(storage, dtype):
def _decl(name, *a):
from dune.perftool.loopy.target import numpy_to_cpp_dtype
_type = numpy_to_cpp_dtype(lp.types.NumpyType(dtype).dtype.name)
def custom_base_storage_temporary_declaration(storage):
def _decl(name, kernel, decl_info):
dtype = kernel.temporary_variables[name].dtype
_type = kernel.target.dtype_to_typename(decl_info.dtype)
return "{0} *{1} = ({0} *){2};".format(_type, name, storage)
return _decl
......@@ -71,7 +74,7 @@ class DuneTemporaryVariable(TemporaryVariable):
if custom_base_storage and self.decl_method is None:
assert shape_impl is None
self.decl_method = custom_base_storage_temporary_declaration(custom_base_storage, kwargs["dtype"])
self.decl_method = custom_base_storage_temporary_declaration(custom_base_storage)
self.custom_declaration = self.decl_method is not None
......
......@@ -86,7 +86,7 @@ def declare_cache_temporary(element, restriction, which):
t_cache = type_localbasis_cache(element)
lfs = name_leaf_lfs(element, restriction)
def decl(name, shape, shape_impl):
def decl(name, kernel, decl_info):
return "typename {}::{}ReturnType {};".format(t_cache,
which,
name,
......
......@@ -19,7 +19,7 @@ def bind_gridfunction_to_element(gf, restriction):
def declare_grid_function_range(gridfunction):
def _decl(name, *args):
def _decl(name, kernel, decl_info):
return "typename decltype({})::Range {};".format(gridfunction, name)
return _decl
......
......@@ -261,7 +261,7 @@ def evaluate_unit_outer_normal(name):
@preamble
def declare_normal(name, shape, shape_impl):
def declare_normal(name, kernel, decl_info):
ig = name_intersection_geometry_wrapper()
return "auto {} = {}.centerUnitOuterNormal();".format(name, ig)
......@@ -300,7 +300,7 @@ def type_jacobian_inverse_transposed(restriction):
@kernel_cached
def define_jacobian_inverse_transposed_temporary(restriction):
@preamble
def _define_jacobian_inverse_transposed_temporary(name, shape, shape_impl):
def _define_jacobian_inverse_transposed_temporary(name, kernel, decl_info):
t = type_jacobian_inverse_transposed(restriction)
return "{} {};".format(t,
name,
......
......@@ -40,11 +40,11 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
ImmutableRecord.__init__(self, dir=dir)
def realize(self, sf, index, insn_dep):
from dune.perftool.sumfact.realization import name_buffer_storage, buffer_decl, get_sumfact_dtype
storage = name_buffer_storage(sf.buffer, 0)
from dune.perftool.sumfact.realization import name_buffer_storage
name = name="input_{}".format(sf.buffer)
temporary_variable(name,
shape=(2 ** local_dimension(), sf.vector_width),
custom_base_storage=name_buffer_storage(sf.buffer, 0),
decl_method=buffer_decl(storage, get_sumfact_dtype(sf)),
managed=True,
)
......
......@@ -90,7 +90,7 @@ def _realize_sum_factorization_kernel(sf):
temporary_variable("{}_dummy".format(buf),
shape=(10000,),
custom_base_storage=buf,
decl_method=lambda *a: None,
decl_method=lambda n, k, di: None,
)
# Realize the input if it is not direct
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment