Skip to content
Snippets Groups Projects
Commit 09718cae authored by Dominic Kempf's avatar Dominic Kempf
Browse files

generator constructor blocks as loopy kernel

parent f6b4eae7
No related branches found
No related tags found
No related merge requests found
...@@ -61,57 +61,18 @@ class ClassMember(Generable): ...@@ -61,57 +61,18 @@ class ClassMember(Generable):
yield line + '\n' yield line + '\n'
class Constructor(Generable):
def __init__(self, block=Block([]), arg_decls=[], clsname=None, initializer_list=[], access=AccessModifier.PUBLIC):
self.clsname = clsname
self.arg_decls = arg_decls
self.access = access
self.block = block
self.il = initializer_list
def generate(self):
assert self.clsname
yield '\n'
yield "{}:\n".format(access_modifier_string(self.access))
yield self.clsname + "("
if self.arg_decls:
for content in self.arg_decls[0].generate(with_semicolon=False):
yield content
for ad in self.arg_decls[1:]:
yield ", "
for content in ad.generate(with_semicolon=False):
yield content
yield ")\n"
# add the initializer list
if self.il:
yield " : {}".format(self.il[0])
for i in self.il[1:]:
yield ",\n"
yield " {}".format(i)
yield '\n'
for line in self.block.generate():
yield line
class Class(Generable): class Class(Generable):
""" Generator for a templated class """ """ Generator for a templated class """
def __init__(self, name, base_classes=[], members=[], tparam_decls=[], constructors=[]): def __init__(self, name, base_classes=[], members=[], tparam_decls=[]):
self.name = name self.name = name
self.base_classes = base_classes self.base_classes = base_classes
self.members = members self.members = members
self.tparam_decls = tparam_decls self.tparam_decls = tparam_decls
self.constructors = constructors
for bc in base_classes: for bc in base_classes:
assert isinstance(bc, BaseClass) assert isinstance(bc, BaseClass)
for mem in members: for mem in members:
assert isinstance(mem, ClassMember) assert isinstance(mem, ClassMember)
for con in constructors:
assert isinstance(con, Constructor)
def generate(self): def generate(self):
# define the class header # define the class header
...@@ -139,7 +100,7 @@ class Class(Generable): ...@@ -139,7 +100,7 @@ class Class(Generable):
yield '\n' yield '\n'
# Now yield the entire block # Now yield the entire block
block = Block(contents=self.constructors + self.members) block = Block(contents=self.members)
# Yield the block # Yield the block
for line in block.generate(): for line in block.generate():
......
...@@ -12,10 +12,9 @@ import cgen ...@@ -12,10 +12,9 @@ import cgen
preamble = generator_factory(item_tags=("preamble",), counted=True, context_tags="kernel") preamble = generator_factory(item_tags=("preamble",), counted=True, context_tags="kernel")
pre_include = generator_factory(item_tags=("file", "pre_include"), context_tags=("filetag",), no_deco=True) pre_include = generator_factory(item_tags=("file", "pre_include"), context_tags=("filetag",), no_deco=True)
post_include = generator_factory(item_tags=("file", "post_include"), context_tags=("filetag",), no_deco=True) post_include = generator_factory(item_tags=("file", "post_include"), context_tags=("filetag",), no_deco=True)
class_member = generator_factory(item_tags=("clazz", "member"), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True) class_member = generator_factory(item_tags=("member",), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True)
template_parameter = generator_factory(item_tags=("clazz", "template_param"), context_tags=("classtag",), counted=True) template_parameter = generator_factory(item_tags=("template_param",), context_tags=("classtag",), counted=True)
class_basename = generator_factory(item_tags=("clazz", "basename"), context_tags=("classtag",)) class_basename = generator_factory(item_tags=("basename",), context_tags=("classtag",))
constructor_block = generator_factory(item_tags=("clazz", "constructor_block"), context_tags=("classtag",), counted=True)
@generator_factory(item_tags=("file", "include"), context_tags=("filetag",)) @generator_factory(item_tags=("file", "include"), context_tags=("filetag",))
......
...@@ -3,7 +3,7 @@ from functools import partial ...@@ -3,7 +3,7 @@ from functools import partial
from dune.perftool.generation import global_context from dune.perftool.generation import global_context
from dune.perftool.loopy.transformations import get_loopy_transformations from dune.perftool.loopy.transformations import get_loopy_transformations
from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod from dune.perftool.pdelab.localoperator import assembly_routine_signature, LoopyKernelMethod
import os import os
...@@ -80,7 +80,7 @@ def show_code(which, kernel): ...@@ -80,7 +80,7 @@ def show_code(which, kernel):
with global_context(integral_type=which[0], form_type=which[1]): with global_context(integral_type=which[0], form_type=which[1]):
signature = assembly_routine_signature() signature = assembly_routine_signature()
print("".join(AssemblyMethod(signature, kernel).generate())) print("".join(LoopyKernelMethod(signature, kernel).generate()))
print("Press Return to return to the previous menu") print("Press Return to return to the previous menu")
input() input()
......
...@@ -476,7 +476,13 @@ def generate_kernel(integrals): ...@@ -476,7 +476,13 @@ def generate_kernel(integrals):
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
tag = get_global_context_value("kernel") tag = get_global_context_value("kernel")
return extract_kernel_from_cache(tag) knl = extract_kernel_from_cache(tag)
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items(tag)
return knl
def extract_kernel_from_cache(tag): def extract_kernel_from_cache(tag):
...@@ -485,6 +491,10 @@ def extract_kernel_from_cache(tag): ...@@ -485,6 +491,10 @@ def extract_kernel_from_cache(tag):
from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items
from dune.perftool.loopy.target import DuneTarget from dune.perftool.loopy.target import DuneTarget
domains = [i for i in retrieve_cache_items("{} and domain".format(tag))] domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
if not domains:
domains = ["{[stupid] : 0<=stupid<1}"]
instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))} temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))}
arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))] arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))]
...@@ -536,10 +546,6 @@ def extract_kernel_from_cache(tag): ...@@ -536,10 +546,6 @@ def extract_kernel_from_cache(tag):
# Do the loopy preprocessing! # Do the loopy preprocessing!
kernel = preprocess_kernel(kernel) kernel = preprocess_kernel(kernel)
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items(tag)
return kernel return kernel
...@@ -587,11 +593,19 @@ class TimerMethod(ClassMember): ...@@ -587,11 +593,19 @@ class TimerMethod(ClassMember):
ClassMember.__init__(self, content) ClassMember.__init__(self, content)
class AssemblyMethod(ClassMember): class LoopyKernelMethod(ClassMember):
def __init__(self, signature, kernel, filename): def __init__(self, signature, kernel, add_timings=True, initializer_list=[]):
from loopy import generate_body from loopy import generate_body
from cgen import LiteralLines, Block from cgen import LiteralLines, Block
content = signature content = signature
# Add initializer list if this is a constructor
if initializer_list:
content[-1] = content[-1] + " :"
for init in initializer_list[:-1]:
content.append(" "*4 + init + ",")
content.append(" "*4 + initializer_list[-1])
content.append('{') content.append('{')
if kernel is not None: if kernel is not None:
# Add kernel preamble # Add kernel preamble
...@@ -599,7 +613,7 @@ class AssemblyMethod(ClassMember): ...@@ -599,7 +613,7 @@ class AssemblyMethod(ClassMember):
content.append(' ' + p) content.append(' ' + p)
# Start timer # Start timer
if get_option('timer'): if add_timings and get_option('timer'):
timer_name = assembler_routine_name() + '_kernel' timer_name = assembler_routine_name() + '_kernel'
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile') post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
content.append(' ' + 'HP_TIMER_START({});'.format(timer_name)) content.append(' ' + 'HP_TIMER_START({});'.format(timer_name))
...@@ -609,7 +623,7 @@ class AssemblyMethod(ClassMember): ...@@ -609,7 +623,7 @@ class AssemblyMethod(ClassMember):
content.extend(l for l in generate_body(kernel).split('\n')[1:-1]) content.extend(l for l in generate_body(kernel).split('\n')[1:-1])
# Stop timer # Stop timer
if get_option('timer'): if add_timings and get_option('timer'):
content.append(' ' + 'HP_TIMER_STOP({});'.format(timer_name)) content.append(' ' + 'HP_TIMER_STOP({});'.format(timer_name))
content.append('}') content.append('}')
...@@ -624,17 +638,17 @@ def cgen_class_from_cache(tag, members=[]): ...@@ -624,17 +638,17 @@ def cgen_class_from_cache(tag, members=[]):
base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))] base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))]
constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))] constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))]
from cgen import Block
constructor_block = Block(contents=[i for i in retrieve_cache_items("{} and constructor_block".format(tag), make_generable=True)])
il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))] il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))]
pm = [m for m in retrieve_cache_items('{} and member'.format(tag))] pm = [m for m in retrieve_cache_items('{} and member'.format(tag))]
tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))] tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
from dune.perftool.cgen.clazz import Constructor # Construct the constructor
constructor = Constructor(block=constructor_block, arg_decls=constructor_params, clsname=basename, initializer_list=il) constructor_knl = extract_kernel_from_cache(tag)
signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params))
constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il)
from dune.perftool.cgen import Class from dune.perftool.cgen import Class
return Class(basename, base_classes=base_classes, members=members + pm, constructors=[constructor], tparam_decls=tparams) return Class(basename, base_classes=base_classes, members=[constructor] + members + pm, tparam_decls=tparams)
def generate_localoperator_kernels(formdata, data): def generate_localoperator_kernels(formdata, data):
...@@ -794,7 +808,7 @@ def generate_localoperator_file(formdata, kernels, filename): ...@@ -794,7 +808,7 @@ def generate_localoperator_file(formdata, kernels, filename):
it, ft = method it, ft = method
with global_context(integral_type=it, form_type=ft): with global_context(integral_type=it, form_type=ft):
signature = assembly_routine_signature(formdata) signature = assembly_routine_signature(formdata)
operator_methods.append(AssemblyMethod(signature, kernel, filename)) operator_methods.append(LoopyKernelMethod(signature, kernel))
if get_option('timer'): if get_option('timer'):
include_file('dune/perftool/common/timer.hh', filetag='operatorfile') include_file('dune/perftool/common/timer.hh', filetag='operatorfile')
......
...@@ -13,6 +13,8 @@ from dune.perftool.generation import (class_member, ...@@ -13,6 +13,8 @@ from dune.perftool.generation import (class_member,
iname, iname,
include_file, include_file,
initializer_list, initializer_list,
instruction,
preamble,
silenced_warning, silenced_warning,
temporary_variable, temporary_variable,
valuearg valuearg
...@@ -181,7 +183,7 @@ def name_polynomials(): ...@@ -181,7 +183,7 @@ def name_polynomials():
return name return name
@constructor_block(classtag="operator") @preamble(kernel="operator")
def sort_quadrature_points_weights(): def sort_quadrature_points_weights():
range_field = lop_template_range_field() range_field = lop_template_range_field()
domain_field = name_domain_field() domain_field = name_domain_field()
...@@ -192,7 +194,13 @@ def sort_quadrature_points_weights(): ...@@ -192,7 +194,13 @@ def sort_quadrature_points_weights():
return "onedQuadraturePointsWeights<{}, {}, {}>({}, {});".format(range_field, domain_field, number_qp, qp, qw) return "onedQuadraturePointsWeights<{}, {}, {}>({}, {});".format(range_field, domain_field, number_qp, qp, qw)
@constructor_block(classtag="operator") @iname(kernel="operator")
def theta_iname(name, bound):
name = "{}_{}".format(name, bound)
domain(name, bound)
return name
def construct_theta(name, transpose, derivative): def construct_theta(name, transpose, derivative):
# Make sure that the quadrature points are sorted # Make sure that the quadrature points are sorted
sort_quadrature_points_weights() sort_quadrature_points_weights()
...@@ -204,15 +212,18 @@ def construct_theta(name, transpose, derivative): ...@@ -204,15 +212,18 @@ def construct_theta(name, transpose, derivative):
polynomials = name_polynomials() polynomials = name_polynomials()
qp = name_oned_quadrature_points() qp = name_oned_quadrature_points()
i = theta_iname("i", shape[0])
j = theta_iname("j", shape[1])
# access = "j,i" if transpose else "i,j" # access = "j,i" if transpose else "i,j"
basispol = "dp" if derivative else "p" basispol = "dp" if derivative else "p"
polynomial_access = "i,{}[j]".format(qp) if transpose else "j,{}[i]".format(qp) polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i)
return ["for (std::size_t i=0; i<{}; i++){{".format(shape[0]), return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access),
" for (std::size_t j=0; j<{}; j++){{".format(shape[1]), kernel="operator",
" {}.colmajoraccess(i,j) = {}.{}({});".format(name, polynomials, basispol, polynomial_access), within_inames=frozenset({i, j}),
" }", within_inames_is_final=True,
"}"] )
@class_member(classtag="operator") @class_member(classtag="operator")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment