Skip to content
Snippets Groups Projects
Commit cb45cbe9 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Introduce proper function mangling!

parent 284f3fa9
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,6 @@ from dune.perftool.generation.loopy import (constantarg,
globalarg,
iname,
instruction,
function_mangler,
pymbolic_expr,
temporary_variable,
valuearg,
......
......@@ -11,7 +11,6 @@ import numpy
iname = generator_factory(item_tags=("iname",))
pymbolic_expr = generator_factory(item_tags=("kernel", "pymbolic"))
function_mangler = generator_factory(item_tags=("kernel", "mangler"))
@generator_factory(item_tags=("argument", "globalarg"),
......
from loopy import CallMangleInfo
from loopy.symbolic import FunctionIdentifier
from loopy.types import NumpyType
import numpy
class CoefficientAccess(FunctionIdentifier):
def __init__(self, restriction):
self.restriction = restriction
def __getinitargs__(self):
return (self.restriction,)
@property
def name(self):
from dune.perftool.pdelab.argument import name_coefficientcontainer
return name_coefficientcontainer(self.restriction)
def coefficient_mangler(target, func, dtypes):
if isinstance(func, CoefficientAccess):
return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(str), NumpyType(numpy.int32)))
class PDELabAccumulationFunction(FunctionIdentifier):
def __init__(self, accumobj):
self.accumobj = accumobj
def __getinitargs__(self):
return (self.accumobj,)
def accumulation_mangler(target, func, dtypes):
if isinstance(func, PDELabAccumulationFunction):
return CallMangleInfo('{}.accumulate'.format(func.accumobj), (), ())
......@@ -23,13 +23,6 @@ class MyMapper(ExpressionToCMapper):
ret = ret + '[{}]'.format(str(i))
return ret
def map_variable(self, expr, enclosing_prec, type_context):
from dune.perftool.pymbolic import VerbatimVariable
if isinstance(expr, VerbatimVariable):
return expr.name
else:
return super(MyMapper, self).map_variable(expr, enclosing_prec, type_context)
class DuneASTBuilder(CASTBuilder):
def get_expression_to_code_mapper(self, codegen_state):
......
""" Generator functions related to trial and test functions and the accumulation loop"""
from dune.perftool.generation import domain, iname, pymbolic_expr, symbol, globalarg, function_mangler, constantarg, get_global_context_value
from dune.perftool.generation import domain, iname, pymbolic_expr, symbol, globalarg, valuearg, get_global_context_value
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
from dune.perftool.pdelab import (name_index,
restricted_name,
......@@ -15,6 +15,8 @@ from dune.perftool.pdelab.basis import (evaluate_trialfunction,
from dune.perftool import Restriction
from pymbolic.primitives import Call, Subscript, Variable
import loopy
@symbol
def name_testfunction_gradient(element, restriction):
......@@ -68,28 +70,15 @@ def type_trialfunctionspace():
@symbol
def name_coefficientcontainer(restriction):
name = restricted_name("x", restriction)
from dune.perftool.pdelab.basis import name_lfs_bound, lfs_iname
return name
@function_mangler
def create_function_mangler(container):
def _mangler(kernel, name, dtypes):
if name == container:
import loopy
return loopy.types.NumpyType("int"), container
return _mangler
@pymbolic_expr
def pymbolic_coefficient(lfs, index, restriction):
container = name_coefficientcontainer(restriction)
create_function_mangler(container)
import loopy
constantarg(lfs, dtype=loopy.types.NumpyType("str"))
from dune.perftool.pymbolic import VerbatimVariable
return Call(VerbatimVariable(container), (VerbatimVariable(lfs), Variable(index),))
# TODO introduce a proper type for local function spaces!
valuearg(lfs, dtype=loopy.types.NumpyType("str"))
from dune.perftool.loopy.functions import CoefficientAccess
return Call(CoefficientAccess(restriction), (Variable(lfs), Variable(index),))
@symbol
......
......@@ -180,7 +180,9 @@ def generate_kernel(integrals):
instructions = [i for i in retrieve_cache_items("instruction")]
temporaries = {i.name: i for i in retrieve_cache_items("temporary")}
arguments = [i for i in retrieve_cache_items("argument")]
manglers = [i for i in retrieve_cache_items("mangler")]
# Get the function manglers
from dune.perftool.loopy.functions import accumulation_mangler, coefficient_mangler
# Create the kernel
from loopy import make_kernel, preprocess_kernel
......@@ -188,7 +190,7 @@ def generate_kernel(integrals):
instructions,
arguments,
temporary_variables=temporaries,
function_manglers=manglers,
function_manglers=[accumulation_mangler, coefficient_mangler],
target=DuneTarget()
)
......
from pymbolic.primitives import Variable
class VerbatimVariable(Variable):
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment