Skip to content
Snippets Groups Projects
Commit e005da8a authored by Dominic Kempf's avatar Dominic Kempf
Browse files

A bit work on arguments

laplace is again proceeding to code generation
parent cb3390bd
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ from dune.perftool.generation.loopy import (c_instruction,
expr_instruction,
globalarg,
iname,
pymbolic_expr,
temporary_variable,
valuearg,
)
......@@ -11,7 +11,8 @@ expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction"
temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True)
valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n:loopy.ConstantArg(n))
@generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg"))
def globalarg(name, shape=loopy.auto):
......
......@@ -104,11 +104,11 @@ def transform_accumulation_term(term):
rmap = {}
for ma in test_ma:
from dune.perftool.pdelab.argument import name_testfunction
rmap[ma.expr] = Variable(name_testfunction(ma))
from dune.perftool.pdelab.argument import pymbolic_testfunction
rmap[ma.expr] = pymbolic_testfunction(ma)
for ma in trial_ma:
from dune.perftool.pdelab.argument import name_trialfunction
rmap[ma.expr] = Variable(name_trialfunction(ma))
from dune.perftool.pdelab.argument import pymbolic_trialfunction
rmap[ma.expr] = pymbolic_trialfunction(ma)
# Get the transformer!
ufl2l_mf = UFL2LoopyVisitor()
......@@ -119,8 +119,8 @@ def transform_accumulation_term(term):
# Now simplify the expression
# TODO: Add a switch to disable/configure this.
from dune.perftool.pymbolic.simplify import simplify_pymbolic_expression
pymbolic_expr = simplify_pymbolic_expression(pymbolic_expr)
# from dune.perftool.pymbolic.simplify import simplify_pymbolic_expression
# pymbolic_expr = simplify_pymbolic_expression(pymbolic_expr)
# Define a temporary variable for this expression
expr_tv_name = "expr_" + str(get_count()).zfill(4)
......@@ -133,11 +133,9 @@ def transform_accumulation_term(term):
# Generate the code for the modified arguments:
for arg in test_ma:
from dune.perftool.pdelab.argument import name_argumentspace, name_argument
from dune.perftool.pdelab.argument import name_argumentspace, pymbolic_argument
accumargs.append(name_argumentspace(arg))
accumargs.append(argument_iname(arg))
# TODO is this global
#globalarg(argument_iname(arg)+"_n")
from dune.perftool.pdelab.argument import name_residual
residual = name_residual()
......
""" Generator functions related to trial and test functions and the accumulation loop"""
from dune.perftool.generation import symbol
from dune.perftool.generation import pymbolic_expr, symbol, globalarg
from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
......@@ -11,11 +11,37 @@ def name_testfunction(ma):
return "{}a{}".format("grad_" if ma.grad else "", ma.argexpr.number())
@pymbolic_expr
def pymbolic_testfunction(ma):
assert bool(ma.index) == ma.grad
from pymbolic.primitives import Subscript, Variable
v = Variable(name_testfunction(ma))
globalarg(name_testfunction(ma))
if ma.grad:
from dune.perftool.pdelab import name_index
return Subscript(v, Variable(name_index(ma.index)))
else:
return v
@symbol
def name_trialfunction(ma):
return "{}c{}".format("grad_" if ma.grad else "", ma.argexpr.count())
@pymbolic_expr
def pymbolic_trialfunction(ma):
assert bool(ma.index) == ma.grad
from pymbolic.primitives import Subscript, Variable
v = Variable(name_trialfunction(ma))
globalarg(name_trialfunction(ma))
if ma.grad:
from dune.perftool.pdelab import name_index
return Subscript(v, Variable(name_index(ma.index)))
else:
return v
@symbol
def name_testfunctionspace(*a):
# TODO
......@@ -46,6 +72,13 @@ def name_argument(ma):
assert False
def pymbolic_argument(ma):
if ma.argexpr.number() == 0:
return pymbolic_testfunction(ma)
if ma.argexpr.number() == 1:
return pymbolic_trialfunction(ma)
assert False
@symbol
def name_residual():
return "r"
......@@ -77,7 +77,7 @@ class ModifiedArgumentDescriptor(MultiFunction):
self(o.ufl_operands[0])
def indexed(self, o):
indexed = o.ufl_operands[1]
self.index = o.ufl_operands[1]
self(o.ufl_operands[0])
def argument(self, o):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment