Skip to content
Snippets Groups Projects
Commit e08b5ea8 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Move the loopy generators too

parent 8de34429
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import
# TODO I am not sure about whether to collect these here or not.
from dune.perftool.generation.cache import no_caching, generator_factory, retrieve_cache_items, delete_cache_items, delete_cache
from dune.perftool.generation.cpp import *
from dune.perftool.generation.cache import (generator_factory, # TODO get rid of this one, it is internal
retrieve_cache_items,
delete_cache_items,
delete_cache,
)
from dune.perftool.generation.cpp import (base_class,
class_member,
constructor_parameter,
include_file,
initializer_list,
preamble,
symbol,
)
from dune.perftool.generation.loopy import (c_instruction,
domain,
expr_instruction,
globalarg,
iname,
temporary_variable,
valuearg,
)
\ No newline at end of file
""" The loopy specific generators """
\ No newline at end of file
""" The loopy specific generators """
from __future__ import absolute_import
from dune.perftool.generation import generator_factory
import loopy
import numpy
iname = generator_factory(item_tags=("loopy", "iname"))
expr_instruction = generator_factory(item_tags=("loopy", "instruction", "exprinstruction"), no_deco=True)
temporary_variable = generator_factory(item_tags=("loopy", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
c_instruction = generator_factory(item_tags=("loopy", "instruction", "cinstruction"), no_deco=True)
valuearg = generator_factory(item_tags=("loopy", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
@generator_factory(item_tags=("loopy", "argument", "globalarg"))
def globalarg(name, shape=loopy.auto):
if isinstance(shape, str):
shape = (shape,)
return loopy.GlobalArg(name, numpy.float64, shape)
@generator_factory(item_tags=("loopy", "domain"))
def domain(iname, shape):
valuearg(shape)
return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
......@@ -4,61 +4,47 @@ This is the module that contains the main transformation from ufl to loopy
"""
from __future__ import absolute_import
from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper
from ufl.algorithms import MultiFunction
# TODO Spread the pymbolic import statements to where they are used.
from pymbolic.primitives import Variable, Subscript, Sum, Product
import loopy
import numpy
import ufl
# Define the generators that are used here
from dune.perftool.generation import generator_factory
loopy_iname = generator_factory(item_tags=("loopy", "kernel", "iname"))
loopy_expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"), no_deco=True)
loopy_temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
loopy_c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True)
loopy_valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
@generator_factory(item_tags=("loopy", "kernel", "argument", "globalarg"))
def loopy_globalarg(name, shape=loopy.auto):
if isinstance(shape, str):
shape = (shape,)
return loopy.GlobalArg(name, numpy.float64, shape)
from dune.perftool.generation import (c_instruction,
domain,
expr_instruction,
globalarg,
iname,
temporary_variable,
valuearg,
)
from ufl.algorithms import MultiFunction
@generator_factory(item_tags=("loopy", "kernel", "domain"))
def loopy_domain(iname, shape):
loopy_valuearg(shape)
return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
import loopy
@loopy_iname
@iname
def dimension_iname(index):
from dune.perftool.pdelab import name_index
from dune.perftool.pdelab.geometry import name_dimension
iname = name_index(index)
dimname = name_dimension()
loopy_domain(iname, dimname)
domain(iname, dimname)
return iname
@loopy_iname
@iname
def argument_iname(arg):
# TODO extract the {iname}_n thing by a preamble
from dune.perftool.ufl.modified_terminals import modified_argument_number
iname = "arg{}".format(chr(ord("i") + modified_argument_number()(arg)))
loopy_domain(iname, iname + "_n")
domain(iname, iname + "_n")
return iname
@loopy_iname
@iname
def quadrature_iname():
loopy_domain("q", "q_n")
domain("q", "q_n")
return "q"
......@@ -99,8 +85,9 @@ class TrialFunctionExtractor(MultiFunction):
if o in self.tf:
# This is a modified trial function!
from dune.perftool.pdelab.argument import name_trialfunction
from pymbolic.primitives import Variable
name = name_trialfunction(o)
loopy_globalarg(name)
globalarg(name)
return Variable(name)
else:
return self.u2l(o)
......@@ -134,8 +121,9 @@ def transform_accumulation_term(term):
# Define a temporary variable for this expression
expr_tv_name = "expr_" + str(get_count()).zfill(4)
expr_tv = loopy_temporary_variable(expr_tv_name)
loopy_expr_instruction(loopy.ExpressionInstruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr))
expr_tv = temporary_variable(expr_tv_name)
from pymbolic.primitives import Variable
expr_instruction(loopy.ExpressionInstruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr))
# The data that is used to collect the arguments for the accumulate function
accumargs = []
......@@ -148,7 +136,7 @@ def transform_accumulation_term(term):
accumargs.append(argument_iname(arg))
name = name_argument(arg)
argument_code.append(name)
loopy_globalarg(name)
globalarg(name)
from dune.perftool.pdelab.argument import name_residual
residual = name_residual()
......@@ -157,12 +145,12 @@ def transform_accumulation_term(term):
inames = retrieve_cache_items("iname")
from dune.perftool.pdelab.quadrature import name_factor
loopy_c_instruction(loopy.CInstruction(inames,
"{}.accumulate({}, {}*{}*{})".format(residual,
", ".join(accumargs),
expr_tv_name,
"*".join(argument_code),
name_factor()
)
)
)
c_instruction(loopy.CInstruction(inames,
"{}.accumulate({}, {}*{}*{})".format(residual,
", ".join(accumargs),
expr_tv_name,
"*".join(argument_code),
name_factor()
)
)
)
......@@ -8,7 +8,7 @@ from loopy import CInstruction
def quadrature_preamble(assignees=[]):
# TODO: How to enforce the order of quadrature preambles? Counted?
return generator_factory(item_tags=("pdelab", "instruction", "cinstruction", "quadrature"), on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees))
return generator_factory(item_tags=("instruction", "cinstruction"), on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees))
# Now define some commonly used generators that do not fall into a specific category
......
from dune.perftool.loopy.transformer import quadrature_iname, loopy_temporary_variable
from dune.perftool.generation import symbol
from dune.perftool.loopy.transformer import quadrature_iname
from dune.perftool.generation import symbol, temporary_variable
from dune.perftool.pdelab import quadrature_preamble
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment