Skip to content
Snippets Groups Projects
Commit f7191522 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Make sumfact data structure loopy-controlled

parent 4b6609e6
No related branches found
No related tags found
No related merge requests found
...@@ -72,7 +72,7 @@ class Class(Generable): ...@@ -72,7 +72,7 @@ class Class(Generable):
for bc in base_classes: for bc in base_classes:
assert isinstance(bc, BaseClass) assert isinstance(bc, BaseClass)
for mem in members: for mem in members:
assert isinstance(mem, ClassMember) assert isinstance(mem, Generable)
def generate(self): def generate(self):
# define the class header # define the class header
......
...@@ -36,6 +36,7 @@ from dune.perftool.generation.loopy import (barrier, ...@@ -36,6 +36,7 @@ from dune.perftool.generation.loopy import (barrier,
globalarg, globalarg,
iname, iname,
instruction, instruction,
loopy_class_member,
kernel_cached, kernel_cached,
noop_instruction, noop_instruction,
silenced_warning, silenced_warning,
......
...@@ -164,3 +164,17 @@ def barrier(**kwargs): ...@@ -164,3 +164,17 @@ def barrier(**kwargs):
name = 'barrier_{}'.format(get_counter('barrier')) name = 'barrier_{}'.format(get_counter('barrier'))
_barrier(id=name, **kwargs) _barrier(id=name, **kwargs)
return name return name
def loopy_class_member(name, classtag=None, **kwargs):
""" A class member is based on loopy! It is an
* temporary variable of the constructor kernel
* A globalarg of the requesting kernel (to make things pass)
"""
assert classtag
temporary_variable(name, kernel=classtag, **kwargs)
silenced_warning("read_no_write({})".format(name), kernel=classtag)
kwargs.pop("decl_method", None)
# TODO I guess some filtering has to be applied here.
globalarg(name, **kwargs)
\ No newline at end of file
...@@ -115,11 +115,18 @@ class DuneASTBuilder(CASTBuilder): ...@@ -115,11 +115,18 @@ class DuneASTBuilder(CASTBuilder):
post_include("#define BARRIER asm volatile(\"\": : :\"memory\")", filetag="operatorfile") post_include("#define BARRIER asm volatile(\"\": : :\"memory\")", filetag="operatorfile")
return cgen.Line("BARRIER;") return cgen.Line("BARRIER;")
def get_temporary_decls(self, codegen_state, schedule_index):
if self.target.declare_temporaries:
return CASTBuilder.get_temporary_decls(self, codegen_state, schedule_index)
else:
return []
class DuneTarget(TargetBase): class DuneTarget(TargetBase):
def __init__(self): def __init__(self, declare_temporaries=True):
# Set fortran_abi to allow reusing CASTBuilder for the moment # Set fortran_abi to allow reusing CASTBuilder for the moment
self.fortran_abi = False self.fortran_abi = False
self.declare_temporaries = declare_temporaries
def split_kernel_at_global_barriers(self): def split_kernel_at_global_barriers(self):
return False return False
......
...@@ -644,11 +644,29 @@ def cgen_class_from_cache(tag, members=[]): ...@@ -644,11 +644,29 @@ def cgen_class_from_cache(tag, members=[]):
# Construct the constructor # Construct the constructor
constructor_knl = extract_kernel_from_cache(tag) constructor_knl = extract_kernel_from_cache(tag)
from dune.perftool.loopy.target import DuneTarget
constructor_knl = constructor_knl.copy(target=DuneTarget(declare_temporaries=False))
signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params)) signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params))
constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il) constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il)
# Take any temporary declarations from the kernel and make them class members
target = DuneTarget()
from loopy.codegen import CodeGenerationState
codegen_state = CodeGenerationState(kernel=constructor_knl,
implemented_data_info=None,
implemented_domain=None,
implemented_predicates=frozenset(),
seen_dtypes=frozenset(),
seen_functions=frozenset(),
seen_atomic_dtypes=frozenset(),
var_subst_map={},
allow_complex=False,
is_generating_device_code=True,
)
decls = target.get_device_ast_builder().get_temporary_decls(codegen_state, 0)
from dune.perftool.cgen import Class from dune.perftool.cgen import Class
return Class(basename, base_classes=base_classes, members=[constructor] + members + pm, tparam_decls=tparams) return Class(basename, base_classes=base_classes, members=[constructor] + members + pm + decls, tparam_decls=tparams)
def generate_localoperator_kernels(formdata, data): def generate_localoperator_kernels(formdata, data):
......
...@@ -13,6 +13,7 @@ from dune.perftool.generation import (class_member, ...@@ -13,6 +13,7 @@ from dune.perftool.generation import (class_member,
include_file, include_file,
initializer_list, initializer_list,
instruction, instruction,
loopy_class_member,
preamble, preamble,
silenced_warning, silenced_warning,
temporary_variable, temporary_variable,
...@@ -93,30 +94,30 @@ def basis_functions_per_direction(): ...@@ -93,30 +94,30 @@ def basis_functions_per_direction():
return polynomial_degree() + 1 return polynomial_degree() + 1
@class_member(classtag="operator")
def define_oned_quadrature_weights(name): def define_oned_quadrature_weights(name):
range_field = lop_template_range_field() loopy_class_member(name,
number_qp = quadrature_points_per_direction() dtype=numpy.float64,
return "{} {}[{}];".format(range_field, name, number_qp) classtag="operator",
shape=(quadrature_points_per_direction(),),
)
def name_oned_quadrature_weights(): def name_oned_quadrature_weights():
name = "qw" name = "qw"
globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64))
define_oned_quadrature_weights(name) define_oned_quadrature_weights(name)
return name return name
@class_member(classtag="operator")
def define_oned_quadrature_points(name): def define_oned_quadrature_points(name):
range_field = lop_template_range_field() loopy_class_member(name,
number_qp = quadrature_points_per_direction() dtype=numpy.float64,
return "{} {}[{}];".format(range_field, name, number_qp) classtag="operator",
shape=(quadrature_points_per_direction(),),
)
def name_oned_quadrature_points(): def name_oned_quadrature_points():
name = "qp" name = "qp"
globalarg(name, shape=(quadrature_points_per_direction(),), dtype=NumpyType(numpy.float64))
define_oned_quadrature_points(name) define_oned_quadrature_points(name)
return name return name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment