Commit 1d8c8aa1 authored by Dominic Kempf

Implement a custom base storage concept

parent 91e6c5c0
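
In short, this commit lets several loopy temporaries share a single, statically allocated raw byte buffer (their "custom base storage"): the buffer is emitted once as an aligned char array, and each temporary bound to it is declared as a typed pointer cast into that array. A minimal, self-contained sketch of the generated-code pattern (plain Python, independent of the dune-perftool/loopy machinery; the names buffer_0 and input_buffer are purely illustrative):

def base_storage_decl(storage, size_bytes, alignment=8):
    # Emitted once per base storage (same format string as in the diff below)
    return "char {}[{}] __attribute__ ((aligned({})));".format(storage, size_bytes, alignment)

def temporary_decl(name, ctype, storage):
    # Emitted per temporary that lives inside the base storage
    return "{0} *{1} = ({0} *){2};".format(ctype, name, storage)

print(base_storage_decl("buffer_0", 10000 * 8))              # char buffer_0[80000] __attribute__ ((aligned(8)));
print(temporary_decl("input_buffer", "double", "buffer_0"))  # double *input_buffer = (double *)buffer_0;
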
@@ -154,14 +154,16 @@ class DuneASTBuilder(CASTBuilder):
     def get_c_expression_to_code_mapper(self):
         return DuneCExpressionToCodeMapper()

-    def get_temporary_decl(self, knl, schedule_index, temp_var, decl_info):
+    def get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info):
         # If this is not a DuneTemporaryVariable, it was introduced by loopy
         # and it should be totally under loopys control: Call the base class implementation!
         if not (isinstance(temp_var, DuneTemporaryVariable) and temp_var.custom_declaration):
-            return CASTBuilder.get_temporary_decl(self, knl, schedule_index, temp_var, decl_info)
+            return CASTBuilder.get_temporary_decl(self, codegen_state, schedule_index, temp_var, decl_info)

-        if temp_var.decl_method:
-            return cgen.Line(temp_var.decl_method(temp_var.name, temp_var.shape, temp_var.shape_impl))
+        if temp_var.custom_declaration:
+            decl = temp_var.decl_method(temp_var.name, temp_var.shape, temp_var.shape_impl)
+            if decl:
+                return cgen.Line(decl)

     def add_vector_access(self, access_expr, index):
         # There is no generic way of implementing a vector access with VCL, as
@@ -176,10 +178,33 @@ class DuneASTBuilder(CASTBuilder):
         return cgen.Line("BARRIER;")

     def get_temporary_decls(self, codegen_state, schedule_index):
+        temps = codegen_state.kernel.temporary_variables.values()
+
+        # Declare all the custom base storages
+        ret = []
+        for bs in set(t.custom_base_storage for t in temps if isinstance(t, DuneTemporaryVariable)) - set({None}):
+            if bs in [a.name for a in codegen_state.kernel.args]:
+                continue
+            # Find the alignment bytes
+            alignment = []
+            size = []
+            for t in temps:
+                if t.custom_base_storage == bs:
+                    # TODO: Extract correct size
+                    alignment.append(8)
+                    from pytools import product
+                    size.append(product(t.shape))
+            alignment = max(alignment)
+            size = max(size)
+            decl = "char {}[{}] __attribute__ ((aligned({})));".format(bs, size * alignment, alignment)
+            ret.append(cgen.Line(decl))
+
         if self.target.declare_temporaries:
-            return CASTBuilder.get_temporary_decls(self, codegen_state, schedule_index)
+            return ret + CASTBuilder.get_temporary_decls(self, codegen_state, schedule_index)
         else:
-            return []
+            return ret


 class DuneTarget(TargetBase):
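
For orientation: per base storage, the sizing in get_temporary_decls takes the largest element count over all temporaries bound to that storage and multiplies it by a (currently hard-coded, see the TODO) alignment of 8 bytes. A rough standalone sketch of that computation, with the temporaries faked as plain dicts for illustration:

from pytools import product

# Hypothetical temporaries bound to the base storage "buffer_0"
temps = [{"custom_base_storage": "buffer_0", "shape": (4, 4)},
         {"custom_base_storage": "buffer_0", "shape": (8, 2, 2)}]

alignment = 8  # bytes, hard-coded above
size = max(product(t["shape"]) for t in temps)  # largest element count: 32
print("char {}[{}] __attribute__ ((aligned({})));".format("buffer_0", size * alignment, alignment))
# -> char buffer_0[256] __attribute__ ((aligned(8)));
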
......
@@ -5,6 +5,7 @@ from dune.perftool.error import PerftoolLoopyError
 from loopy import TemporaryVariable

+import loopy as lp
 import numpy
@@ -44,11 +45,20 @@ def default_declaration(name, shape=(), shape_impl=()):
         return '{} {}(0.0);'.format(t, name)


+def custom_base_storage_temporary_declaration(storage, dtype):
+    def _decl(name, *a):
+        from dune.perftool.loopy.target import numpy_to_cpp_dtype
+        _type = numpy_to_cpp_dtype(lp.types.NumpyType(dtype).dtype.name)
+        return "{0} *{1} = ({0} *){2};".format(_type, name, storage)
+    return _decl
+
+
 class DuneTemporaryVariable(TemporaryVariable):
-    allowed_extra_kwargs = TemporaryVariable.allowed_extra_kwargs + ["managed", "shape_impl", "decl_method"]
+    allowed_extra_kwargs = TemporaryVariable.allowed_extra_kwargs + ["managed", "shape_impl", "decl_method", "custom_base_storage"]

-    def __init__(self, name, managed=False, shape_impl=None, decl_method=None, **kwargs):
+    def __init__(self, name, managed=False, shape_impl=None, decl_method=None, custom_base_storage=None, **kwargs):
         self.managed = managed
         self.decl_method = decl_method
         self.shape_impl = shape_impl
@@ -59,6 +69,15 @@ class DuneTemporaryVariable(TemporaryVariable):
         from dune.perftool.loopy.target import dtype_floatingpoint
         kwargs.setdefault('dtype', dtype_floatingpoint())

+        if custom_base_storage and self.decl_method is None:
+            assert shape_impl is None
+            self.decl_method = custom_base_storage_temporary_declaration(custom_base_storage, kwargs["dtype"])
+
         self.custom_declaration = self.decl_method is not None

-        TemporaryVariable.__init__(self, name, managed=self.managed, shape_impl=self.shape_impl, decl_method=self.decl_method, **kwargs)
+        TemporaryVariable.__init__(self, name,
+                                   managed=self.managed,
+                                   shape_impl=self.shape_impl,
+                                   decl_method=self.decl_method,
+                                   custom_base_storage=custom_base_storage,
+                                   **kwargs)
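
The helper custom_base_storage_temporary_declaration returns a closure, so the pointer-cast declaration is only rendered once loopy asks the AST builder for it. A runnable approximation of what the closure produces; numpy_to_cpp_dtype is replaced by a trivial stand-in (the real mapping lives in dune.perftool.loopy.target), and numpy.dtype(...).name is used instead of loopy's NumpyType wrapper:

import numpy

def numpy_to_cpp_dtype_stub(name):
    # Simplified stand-in for dune.perftool.loopy.target.numpy_to_cpp_dtype
    return {"float64": "double", "float32": "float"}[name]

def custom_base_storage_temporary_declaration(storage, dtype):
    def _decl(name, *a):
        _type = numpy_to_cpp_dtype_stub(numpy.dtype(dtype).name)
        return "{0} *{1} = ({0} *){2};".format(_type, name, storage)
    return _decl

decl = custom_base_storage_temporary_declaration("buffer_0", numpy.float64)
print(decl("input_temp", (4, 4), None))  # -> double *input_temp = (double *)buffer_0;
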
@@ -426,13 +426,12 @@ def generate_accumulation_instruction(expr, visitor):
     vectag = frozenset({"gradvec"}) if vsf.vectorized else frozenset()

-    from dune.perftool.sumfact.realization import name_buffer_storage, buffer_decl, get_sumfact_dtype
-    storage = name_buffer_storage(buffer, 0)
+    from dune.perftool.sumfact.realization import name_buffer_storage
     temp = "input_{}".format(buffer)
     temporary_variable(temp,
                        shape=vsf.quadrature_shape,
                        dim_tags=vsf.quadrature_dimtags,
-                       decl_method=buffer_decl(storage, get_sumfact_dtype(sf)),
+                       custom_base_storage=name_buffer_storage(buffer, 0),
                        managed=True,
                        )
......
@@ -82,12 +82,11 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
         coeff = pc(container, lfs, basisiname)

         # Get the input temporary!
-        from dune.perftool.sumfact.realization import name_buffer_storage, buffer_decl, get_sumfact_dtype
-        storage = name_buffer_storage(sf.buffer, 0)
+        from dune.perftool.sumfact.realization import name_buffer_storage
         name = "input_{}".format(sf.buffer)
         temporary_variable(name,
                            shape=(product(mat.basis_size for mat in sf.matrix_sequence), sf.vector_width),
-                           decl_method=buffer_decl(storage, get_sumfact_dtype(sf)),
+                           custom_base_storage=name_buffer_storage(sf.buffer, 0),
                            managed=True,
                            )
......
@@ -67,11 +67,6 @@ def alias_data_array(name, data):
     return "auto {} = {}.data();".format(name, data)


-@preamble
-def declare_buffer_storage(name, size, alignment):
-    return "char {}[{}] __attribute__ ((aligned({})));".format(name, size * alignment, alignment)
-
-
 def name_buffer_storage(buff, which):
     name = "{}_{}".format(buff, which)
     return name
@@ -88,9 +83,15 @@ def _realize_sum_factorization_kernel(sf):
     alignment = 8
     buffers = tuple(name_buffer_storage(sf.buffer, i) for i in range(2))

-    # Make sure that the storage is allocated
+    # Make sure that the storage is allocated and has a certain minimum size
+    # This is necessary to allocate buffers that will be passed to sumfact kernel
+    # functions. Loopy has no knowledge of what happens with those...
     for buf in buffers:
-        declare_buffer_storage(buf, size, alignment)
+        temporary_variable("{}_dummy".format(buf),
+                           shape=(10000,),
+                           custom_base_storage=buf,
+                           decl_method=lambda *a: None,
+                           )

     # Realize the input if it is not direct
     if not sf.input.direct_input_is_possible:
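
The "{}_dummy" temporaries exist only so the base-storage pass in get_temporary_decls sees each buffer and reserves at least 10000 entries for it; decl_method=lambda *a: None suppresses the per-temporary declaration because, with the builder change at the top of this commit, a declaration line is only emitted when the method returns something truthy. A tiny illustration of that interaction:

decl_method = lambda *a: None
decl = decl_method("buffer_0_dummy", (10000,), None)
if decl:  # None is falsy, so no cgen.Line is generated for the dummy
    print("would emit:", decl)
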
@@ -108,7 +109,7 @@ def _realize_sum_factorization_kernel(sf):
     temporary_variable(out,
                        shape=sf.output_shape,
                        dim_tags=sf.output_dimtags,
-                       decl_method=buffer_decl(buffers[sf.length % 2], get_sumfact_dtype(sf)),
+                       custom_base_storage=buffers[sf.length % 2],
                        managed=True,
                        )
     silenced_warning("read_no_write({})".format(out))
@@ -125,24 +126,17 @@ def buffer_decl(buffer, dtype):
     return _buffer_decl


-def get_sumfact_dtype(sf):
-    if sf.vectorized:
-        pass
-    else:
-        from dune.perftool.loopy.target import dtype_floatingpoint
-        from loopy.types import NumpyType
-        return NumpyType(dtype_floatingpoint()).dtype.name
-
-
 class BufferSwitcher(object):
     def __init__(self, buffers=("buffer0", "buffer1")):
         self.buffers = buffers
         self.current = 0

     def get_temporary(self, name=None, **kwargs):
         bs = self.buffers[self.current]
         globalarg(bs)
         temporary_variable(name,
                            managed=True,
-                           decl_method=buffer_decl(self.buffers[self.current], kwargs["dtype"]),
+                           custom_base_storage=self.buffers[self.current],
                            **kwargs
                            )
@@ -229,7 +223,6 @@ def realize_sumfact_kernel_function(sf):
         inp = buffer.get_temporary("buff_step{}_in".format(l),
                                    shape=inp_shape + vec_shape,
                                    dim_tags=ftags,
-                                   dtype=get_sumfact_dtype(sf),
                                    )

         # The input temporary will only be read from, so we need to silence the loopy warning
@@ -255,7 +248,6 @@ def realize_sumfact_kernel_function(sf):
         out = buffer.get_temporary("buff_step{}_out".format(l),
                                    shape=output_shape + vec_shape,
                                    dim_tags=ftags,
-                                   dtype=get_sumfact_dtype(sf),
                                    )

         # Write the matrix-matrix multiplication expression
......