Skip to content
Snippets Groups Projects
Commit f0c015cc authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Getting loval basis caches into code generation

as a consequence introduce template parameters
parent 6f037c1e
No related branches found
No related tags found
No related merge requests found
......@@ -102,9 +102,6 @@ class Class(Generable):
assert isinstance(bc, BaseClass)
for mem in members:
assert isinstance(mem, ClassMember)
from cgen import Declarator
for tp in tparam_decls:
assert isinstance(tp, Declarator)
for con in constructors:
assert isinstance(con, Constructor)
......@@ -114,8 +111,9 @@ class Class(Generable):
decl = Value('class', self.name)
if self.tparam_decls:
from cgen import Template
decl = Template(self.tparam_decls, decl)
yield 'template<'
yield ', '.join('typename {}'.format(t) for t in self.tparam_decls)
yield '>\n'
# Yield the definition
for line in decl.generate(with_semicolon=False):
......
......@@ -7,12 +7,14 @@ from dune.perftool.generation.cache import (cached,
)
from dune.perftool.generation.cpp import (base_class,
class_basename,
class_member,
constructor_parameter,
include_file,
initializer_list,
preamble,
symbol,
template_parameter,
)
from dune.perftool.generation.loopy import (domain,
......
......@@ -51,3 +51,15 @@ def constructor_parameter(_type, name, classtag=None, constructortag=None):
gen = generator_factory(item_tags=(classtag, constructortag, "constructor_param"), counted=True, no_deco=True)
return gen(Value(_type, name))
def template_parameter(classtag=None):
assert classtag
return generator_factory(item_tags=(classtag, "template_param"), counted=True)
def class_basename(classtag=None):
assert classtag
return generator_factory(item_tags=(classtag, "basename"))
......@@ -29,11 +29,34 @@ def domain(iname, shape):
return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
def _temporary_type(shape_impl, shape, first=True):
if len(shape_impl) == 0:
return 'double'
if shape_impl[0] == 'arr':
if not first or len(set(shape_impl)) != 1:
raise ValueError("We do not allow mixing of C++ containers and plain C arrays, for reasons of mental sanity")
return 'double'
if shape_impl[0] == 'vec':
return "std::vector<{}>".format(_temporary_type(shape_impl[1:], shape[1:], first=False))
if shape_impl[0] == 'fv':
return "Dune::FieldVector<{}, {}>".format(_temporary_type(shape_impl[1:], shape[1:], first=False), shape[0])
if shape_impl[0] == 'fm':
pass
@preamble
def default_temporary_declaration(name, **kwargs):
shape = kwargs.get('shape', ())
array_bounds = ''.join('[{}]'.format(s) for s in shape)
return 'double {}{};'.format(name, array_bounds)
def default_declaration(name, shape, shape_impl):
# Determine the C++ type to use for this temporary.
t = _temporary_type(shape_impl, shape)
if len(shape_impl) == 0:
# This is a scalar, just return it!
return '{} {}(0.0);'.format(t, name);
if shape_impl[0] == 'arr':
return '{} {}{};'.format(t, name, ''.join('[{}]'.format(s) for s in shape))
if shape_impl[0] == 'vec':
return '{} {}({});'.format(t, name, shape[0])
if shape_impl[0] == 'fv':
return '{} {}(0.0);'.format(t, name)
@generator_factory(item_tags=("loopy", "kernel", "temporary"), cache_key_generator=lambda n, **kw: n)
......@@ -41,8 +64,11 @@ def temporary_variable(name, **kwargs):
if 'dtype' not in kwargs:
kwargs['dtype'] = numpy.float64
decl_method = kwargs.pop('decl_method', default_temporary_declaration)
decl_method(name, **kwargs)
decl_method = kwargs.pop('decl_method', default_declaration)
shape = kwargs.get('shape', ())
shape_impl = kwargs.pop('shape_impl', ('arr',)*len(shape))
decl_method(name, shape, shape_impl)
return loopy.TemporaryVariable(name, **kwargs)
......
""" Generators for basis evaluations """
from dune.perftool.generation import (cached,
class_member,
domain,
generator_factory,
iname,
include_file,
instruction,
preamble,
symbol,
......@@ -15,6 +17,29 @@ from dune.perftool.pdelab.quadrature import (name_quadrature_point,
from dune.perftool.pdelab.geometry import (name_dimension,
name_jacobian_inverse_transposed,
)
from dune.perftool.pdelab.localoperator import (lop_template_ansatz_gfs,
lop_template_test_gfs,
)
from dune.perftool.pdelab.driver import FEM_name_mangling
@symbol
def type_localbasis_cache(element):
return "Dune::PDELab::LocalBasisCache<typename {}::Traits::FiniteElementMap::Traits::FiniteElementType::Traits::LocalBasisType>".format(type_gfs(element))
@class_member("operator")
def define_localbasis_cache(element, name):
include_file("dune/pdelab/finiteelement/localbasiscache.hh", filetag="operatorfile")
t = type_localbasis_cache(element)
return "{} {};".format(t, name)
@symbol
def name_localbasis_cache(element):
name = "cache_{}".format(FEM_name_mangling(element))
define_localbasis_cache(element, name)
return name
@preamble
......@@ -58,19 +83,46 @@ def name_lfs(element, prefix=None):
return prefix
@generator_factory(cache_key_generator=lambda e, **kw: e)
def type_gfs(element, basetype=None, index_stack=None):
# Omitting basetype and index_stack is only valid upon a second call,
# which will result in a cache hit.
assert basetype
assert index_stack is not None
# Additionally, element is expected to be a ufl finite element
from ufl import FiniteElementBase
assert isinstance(element, FiniteElementBase)
# Recurse into the given element to define all other local function spaces!
from ufl import MixedElement
if isinstance(element, MixedElement):
for i, subelem in enumerate(element.sub_elements()):
type_gfs(subelem, basetype=basetype, index_stack=index_stack + (i,))
if len(index_stack) == 0:
return basetype
else:
include_file("dune/typetree/childextraction.hh", filetag="operatorfile")
return 'Dune::TypeTree::Child<{},{}>'.format(basetype, ','.join(str(i) for i in index_stack))
def traverse_lfs_tree(arg):
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
assert isinstance(arg, ModifiedArgumentDescriptor)
# First we need to determine the basename as given in the signature of
# this kernel method!
basename = None
lfs_basename = None
gfs_basename = None
from ufl.classes import Argument, Coefficient
if isinstance(arg.argexpr, Argument):
if arg.argexpr.count() == 0:
basename = 'lfsv'
lfs_basename = 'lfsv'
gfs_basename = 'GFSV'
if arg.argexpr.count() == 1:
basename = 'lfsu'
lfs_basename = 'lfsu'
gfs_basename = 'GFSU'
# TODO add restrictions here.
if isinstance(arg.argexpr, Coefficient):
......@@ -78,14 +130,16 @@ def traverse_lfs_tree(arg):
# is the coefficient of reserved index 0.
assert arg.argexpr.count() == 0
basename = 'lfsu'
lfs_basename = 'lfsu'
gfs_basename = 'GFSU'
assert basename
assert lfs_basename and gfs_basename
# Now start recursively extracting local function spaces and fill the cache with
# all those values. That way we can later get a correct local function space with
# just the ufl finite element.
name_lfs(arg.argexpr.element(), prefix=basename)
name_lfs(arg.argexpr.element(), prefix=lfs_basename)
type_gfs(arg.argexpr.element(), basetype=gfs_basename, index_stack=())
@iname
......@@ -124,15 +178,17 @@ def lfs_iname(element, argcount=0, context=''):
@cached
def evaluate_basis(element, name):
temporary_variable(name, shape=(name_lfs_bound(element),))
temporary_variable(name, shape=(name_lfs_bound(element),), shape_impl=('vec',))
cache = name_localbasis_cache(element)
lfs = name_lfs(element)
qp = name_quadrature_point()
instruction(inames=(quadrature_iname(),
),
code='{}.finiteElement().localBasis().evaluateFunction({}, {});'.format(lfs,
qp,
name,
),
code='auto& {} = {}.evaluateFunction({}, {}.finiteElement().localBasis());'.format(name,
cache,
qp,
lfs,
),
assignees=frozenset({name}),
)
......@@ -145,16 +201,17 @@ def name_basis(element):
@cached
def evaluate_reference_gradient(element, name):
# TODO this is of course not yet correct
temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()))
temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()), shape_impl=('vec', 'fv'))
cache = name_localbasis_cache(element)
lfs = name_lfs(element)
qp = name_quadrature_point()
instruction(inames=(quadrature_iname(),
),
code='{}.finiteElement().localBasis().evaluateJacobian({}, {});'.format(lfs,
qp,
name,
),
code='auto& {} = {}.evaluateJacobian({}, {}.finiteElement().localBasis());'.format(name,
cache,
qp,
lfs,
),
assignees=frozenset({name}),
)
......@@ -168,7 +225,7 @@ def name_reference_gradient(element):
@cached
def evaluate_basis_gradient(element, name):
# TODO this is of course not yet correct
temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()))
temporary_variable(name, shape=(name_lfs_bound(element), name_dimension()), shape_impl=('vec', 'fv'))
jac = name_jacobian_inverse_transposed()
index = lfs_iname(element, context='transformgrads')
reference_gradients = name_reference_gradient(element)
......@@ -215,7 +272,7 @@ def evaluate_trialfunction(element, name):
@cached
def evaluate_trialfunction_gradient(element, name):
# TODO this is of course not yet correct
temporary_variable(name, shape=(name_dimension(),))
temporary_variable(name, shape=(name_dimension(),), shape_impl=('fv',))
lfs = name_lfs(element)
index = lfs_iname(element, context='trialgrad')
basis = name_basis_gradient(element)
......
......@@ -438,15 +438,17 @@ def typedef_localoperator(name):
# No Parameter class here, yet
# params = type_parameters()
# return "typedef LocalOperator<{}> {};".format(params, name)
ugfs = type_gfs(_form.coefficients()[0].element())
vgfs = type_gfs(_form.arguments()[0].element())
from dune.perftool.options import get_option
include_file(get_option('operator_file'), filetag="driver")
return "// Here in the future: typedef for the local operator with parameter class as template parameter"
return "using {} = LocalOperator<{}, {}>;".format(name, ugfs, vgfs)
@symbol
def type_localoperator():
typedef_localoperator("LocalOperator")
return "LocalOperator"
typedef_localoperator("LOP")
return "LOP"
@preamble
......@@ -493,7 +495,9 @@ def define_gridoperator(name):
vcc = name_constraintscontainer(_form.arguments()[0].element())
lop = name_localoperator()
mb = name_matrixbackend()
return "{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb)
return ["{} {}({}, {}, {}, {}, {}, {});".format(gotype, name, ugfs, ucc, vgfs, vcc, lop, mb),
"std::cout << \"gfs with \" << {}.size() << \" dofs generated \"<< std::endl;".format(ugfs),
"std::cout << \"cc with \" << {}.size() << \" dofs generated \"<< std::endl;".format(ucc)]
@symbol
......
......@@ -35,7 +35,7 @@ def name_dimension():
@preamble
def define_jacobian_inverse_transposed_temporary(name, **kwargs):
def define_jacobian_inverse_transposed_temporary(name, shape, shape_impl):
geo = name_geometry()
return "auto {} = {}.jacobianInverseTransposed({{{{ 0.0 }}}});".format(name,
geo,
......
from __future__ import absolute_import
from dune.perftool.options import get_option
from dune.perftool.generation import include_file, base_class, symbol, initializer_list, class_member, constructor_parameter
from dune.perftool.cgen.clazz import AccessModifier, BaseClass, ClassMember
from dune.perftool.generation import (base_class,
class_basename,
class_member,
constructor_parameter,
include_file,
initializer_list,
symbol,
template_parameter,
)
from dune.perftool.cgen.clazz import (AccessModifier,
BaseClass,
ClassMember,
)
from pytools import memoize
@template_parameter("operator")
def lop_template_ansatz_gfs():
return "GFSU"
@template_parameter("operator")
def lop_template_test_gfs():
return "GFSV"
@symbol
def name_initree_constructor_param():
return "iniParams"
......@@ -38,12 +59,27 @@ def name_initree_member():
return "_iniParams"
@symbol
def localoperator_type():
# TODO use something from the form here to make it unique
@class_basename("operator")
def localoperator_basename():
return "LocalOperator"
def class_type_from_cache(classtag):
from dune.perftool.generation import retrieve_cache_items
# get the basename
basename = [i for i in retrieve_cache_items(condition="{} and basename".format(classtag))]
assert len(basename) == 1
basename = basename[0]
# get the template parameters
tparams = [i for i in retrieve_cache_items(condition="{} and template_param".format(classtag))]
tparam_str = ''
if len(tparams) > 0:
tparam_str = '<{}>'.format(', '.join(t for t in tparams))
return basename, basename + tparam_str
@memoize
def measure_specific_details(measure):
# The return dictionary that this memoized method will grant direct access to.
......@@ -52,8 +88,7 @@ def measure_specific_details(measure):
def numerical_jacobian(which):
if get_option("numerical_jacobian"):
# Add a base class
from dune.perftool.pdelab.driver import type_localoperator
loptype = type_localoperator()
_, loptype = class_type_from_cache("operator")
base_class("Dune::PDELab::NumericalJacobian{}<{}>".format(which, loptype), classtag="operator")
# Add the initializer list for that base class
......@@ -149,16 +184,20 @@ class AssemblyMethod(ClassMember):
def cgen_class_from_cache(tag, members=[]):
from dune.perftool.generation import retrieve_cache_items
# Generate the name by concatenating basename and template parameters
basename, fullname = class_type_from_cache("operator")
base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))]
constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))]
il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))]
pm = [m for m in retrieve_cache_items('{} and member'.format(tag))]
tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))]
from dune.perftool.cgen.clazz import Constructor
constructor = Constructor(arg_decls=constructor_params, clsname=localoperator_type(), initializer_list=il)
constructor = Constructor(arg_decls=constructor_params, clsname=basename, initializer_list=il)
from dune.perftool.cgen import Class
return Class(localoperator_type(), base_classes=base_classes, members=members + pm, constructors=[constructor])
return Class(basename, base_classes=base_classes, members=members + pm, constructors=[constructor], tparam_decls=tparams)
def generate_localoperator_kernels(form):
......@@ -175,7 +214,11 @@ def generate_localoperator_kernels(form):
include_file('dune/pdelab/localoperator/idefault.hh', filetag="operatorfile")
include_file('dune/pdelab/localoperator/flags.hh', filetag="operatorfile")
include_file('dune/pdelab/localoperator/pattern.hh', filetag="operatorfile")
include_file('dune/geometry/quadraturerules.hh', filetag="operatorfile")
# Trigger this one once early on to avoid wrong stuff happening
localoperator_basename()
lop_template_ansatz_gfs()
lop_template_test_gfs()
base_class('Dune::PDELab::LocalOperatorDefaultFlags', classtag="operator")
......
from dune.perftool.generation import (cached,
domain,
iname,
include_file,
instruction,
symbol,
temporary_variable,
......@@ -64,6 +65,7 @@ def name_order():
@cached
def quadrature_loop_statement():
include_file('dune/pdelab/common/quadraturerules.hh', filetag='operatorfile')
qp = name_quadrature_point()
from dune.perftool.pdelab.geometry import name_geometry
geo = name_geometry()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment