Skip to content
Snippets Groups Projects
Commit 09718cae authored by Dominic Kempf's avatar Dominic Kempf
Browse files

generator constructor blocks as loopy kernel

parent f6b4eae7
No related branches found
No related tags found
No related merge requests found
......@@ -61,57 +61,18 @@ class ClassMember(Generable):
yield line + '\n'
class Constructor(Generable):
def __init__(self, block=Block([]), arg_decls=[], clsname=None, initializer_list=[], access=AccessModifier.PUBLIC):
self.clsname = clsname
self.arg_decls = arg_decls
self.access = access
self.block = block
self.il = initializer_list
def generate(self):
assert self.clsname
yield '\n'
yield "{}:\n".format(access_modifier_string(self.access))
yield self.clsname + "("
if self.arg_decls:
for content in self.arg_decls[0].generate(with_semicolon=False):
yield content
for ad in self.arg_decls[1:]:
yield ", "
for content in ad.generate(with_semicolon=False):
yield content
yield ")\n"
# add the initializer list
if self.il:
yield " : {}".format(self.il[0])
for i in self.il[1:]:
yield ",\n"
yield " {}".format(i)
yield '\n'
for line in self.block.generate():
yield line
class Class(Generable):
""" Generator for a templated class """
def __init__(self, name, base_classes=[], members=[], tparam_decls=[], constructors=[]):
def __init__(self, name, base_classes=[], members=[], tparam_decls=[]):
self.name = name
self.base_classes = base_classes
self.members = members
self.tparam_decls = tparam_decls
self.constructors = constructors
for bc in base_classes:
assert isinstance(bc, BaseClass)
for mem in members:
assert isinstance(mem, ClassMember)
for con in constructors:
assert isinstance(con, Constructor)
def generate(self):
# define the class header
......@@ -139,7 +100,7 @@ class Class(Generable):
yield '\n'
# Now yield the entire block
block = Block(contents=self.constructors + self.members)
block = Block(contents=self.members)
# Yield the block
for line in block.generate():
......
......@@ -12,10 +12,9 @@ import cgen
preamble = generator_factory(item_tags=("preamble",), counted=True, context_tags="kernel")
pre_include = generator_factory(item_tags=("file", "pre_include"), context_tags=("filetag",), no_deco=True)
post_include = generator_factory(item_tags=("file", "post_include"), context_tags=("filetag",), no_deco=True)
class_member = generator_factory(item_tags=("clazz", "member"), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True)
template_parameter = generator_factory(item_tags=("clazz", "template_param"), context_tags=("classtag",), counted=True)
class_basename = generator_factory(item_tags=("clazz", "basename"), context_tags=("classtag",))
constructor_block = generator_factory(item_tags=("clazz", "constructor_block"), context_tags=("classtag",), counted=True)
class_member = generator_factory(item_tags=("member",), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True)
template_parameter = generator_factory(item_tags=("template_param",), context_tags=("classtag",), counted=True)
class_basename = generator_factory(item_tags=("basename",), context_tags=("classtag",))
@generator_factory(item_tags=("file", "include"), context_tags=("filetag",))
......
......@@ -3,7 +3,7 @@ from functools import partial
from dune.perftool.generation import global_context
from dune.perftool.loopy.transformations import get_loopy_transformations
from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod
from dune.perftool.pdelab.localoperator import assembly_routine_signature, LoopyKernelMethod
import os
......@@ -80,7 +80,7 @@ def show_code(which, kernel):
with global_context(integral_type=which[0], form_type=which[1]):
signature = assembly_routine_signature()
print("".join(AssemblyMethod(signature, kernel).generate()))
print("".join(LoopyKernelMethod(signature, kernel).generate()))
print("Press Return to return to the previous menu")
input()
......
......@@ -476,7 +476,13 @@ def generate_kernel(integrals):
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
tag = get_global_context_value("kernel")
return extract_kernel_from_cache(tag)
knl = extract_kernel_from_cache(tag)
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items(tag)
return knl
def extract_kernel_from_cache(tag):
......@@ -485,6 +491,10 @@ def extract_kernel_from_cache(tag):
from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items
from dune.perftool.loopy.target import DuneTarget
domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
if not domains:
domains = ["{[stupid] : 0<=stupid<1}"]
instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))}
arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))]
......@@ -536,10 +546,6 @@ def extract_kernel_from_cache(tag):
# Do the loopy preprocessing!
kernel = preprocess_kernel(kernel)
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items(tag)
return kernel
......@@ -587,11 +593,19 @@ class TimerMethod(ClassMember):
ClassMember.__init__(self, content)
class AssemblyMethod(ClassMember):
def __init__(self, signature, kernel, filename):
class LoopyKernelMethod(ClassMember):
def __init__(self, signature, kernel, add_timings=True, initializer_list=[]):
from loopy import generate_body
from cgen import LiteralLines, Block
content = signature
# Add initializer list if this is a constructor
if initializer_list:
content[-1] = content[-1] + " :"
for init in initializer_list[:-1]:
content.append(" "*4 + init + ",")
content.append(" "*4 + initializer_list[-1])
content.append('{')
if kernel is not None:
# Add kernel preamble
......@@ -599,7 +613,7 @@ class AssemblyMethod(ClassMember):
content.append(' ' + p)
# Start timer
if get_option('timer'):
if add_timings and get_option('timer'):
timer_name = assembler_routine_name() + '_kernel'
post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
content.append(' ' + 'HP_TIMER_START({});'.format(timer_name))
......@@ -609,7 +623,7 @@ class AssemblyMethod(ClassMember):
content.extend(l for l in generate_body(kernel).split('\n')[1:-1])
# Stop timer
if get_option('timer'):
if add_timings and get_option('timer'):
content.append(' ' + 'HP_TIMER_STOP({});'.format(timer_name))
content.append('}')
......@@ -624,17 +638,17 @@ def cgen_class_from_cache(tag, members=[]):
base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))]
constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))]
from cgen import Block
constructor_block = Block(contents=[i for i in retrieve_cache_items("{} and constructor_block".format(tag), make_generable=True)])
il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))]
pm = [m for m in retrieve_cache_items('{} and member'.format(tag))]
tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
from dune.perftool.cgen.clazz import Constructor
constructor = Constructor(block=constructor_block, arg_decls=constructor_params, clsname=basename, initializer_list=il)
# Construct the constructor
constructor_knl = extract_kernel_from_cache(tag)
signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params))
constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il)
from dune.perftool.cgen import Class
return Class(basename, base_classes=base_classes, members=members + pm, constructors=[constructor], tparam_decls=tparams)
return Class(basename, base_classes=base_classes, members=[constructor] + members + pm, tparam_decls=tparams)
def generate_localoperator_kernels(formdata, data):
......@@ -794,7 +808,7 @@ def generate_localoperator_file(formdata, kernels, filename):
it, ft = method
with global_context(integral_type=it, form_type=ft):
signature = assembly_routine_signature(formdata)
operator_methods.append(AssemblyMethod(signature, kernel, filename))
operator_methods.append(LoopyKernelMethod(signature, kernel))
if get_option('timer'):
include_file('dune/perftool/common/timer.hh', filetag='operatorfile')
......
......@@ -13,6 +13,8 @@ from dune.perftool.generation import (class_member,
iname,
include_file,
initializer_list,
instruction,
preamble,
silenced_warning,
temporary_variable,
valuearg
......@@ -181,7 +183,7 @@ def name_polynomials():
return name
@constructor_block(classtag="operator")
@preamble(kernel="operator")
def sort_quadrature_points_weights():
range_field = lop_template_range_field()
domain_field = name_domain_field()
......@@ -192,7 +194,13 @@ def sort_quadrature_points_weights():
return "onedQuadraturePointsWeights<{}, {}, {}>({}, {});".format(range_field, domain_field, number_qp, qp, qw)
@constructor_block(classtag="operator")
@iname(kernel="operator")
def theta_iname(name, bound):
name = "{}_{}".format(name, bound)
domain(name, bound)
return name
def construct_theta(name, transpose, derivative):
# Make sure that the quadrature points are sorted
sort_quadrature_points_weights()
......@@ -204,15 +212,18 @@ def construct_theta(name, transpose, derivative):
polynomials = name_polynomials()
qp = name_oned_quadrature_points()
i = theta_iname("i", shape[0])
j = theta_iname("j", shape[1])
# access = "j,i" if transpose else "i,j"
basispol = "dp" if derivative else "p"
polynomial_access = "i,{}[j]".format(qp) if transpose else "j,{}[i]".format(qp)
polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i)
return ["for (std::size_t i=0; i<{}; i++){{".format(shape[0]),
" for (std::size_t j=0; j<{}; j++){{".format(shape[1]),
" {}.colmajoraccess(i,j) = {}.{}({});".format(name, polynomials, basispol, polynomial_access),
" }",
"}"]
return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access),
kernel="operator",
within_inames=frozenset({i, j}),
within_inames_is_final=True,
)
@class_member(classtag="operator")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment