Skip to content
Snippets Groups Projects
Commit c2c6100a authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Major overhaul of loopy instruction generation mechanisms

parent e005da8a
No related branches found
No related tags found
No related merge requests found
......@@ -14,11 +14,10 @@ from dune.perftool.generation.cpp import (base_class,
symbol,
)
from dune.perftool.generation.loopy import (c_instruction,
domain,
expr_instruction,
from dune.perftool.generation.loopy import (domain,
globalarg,
iname,
instruction,
pymbolic_expr,
temporary_variable,
valuearg,
......
......@@ -25,6 +25,10 @@ def _freeze(data):
if isinstance(data, ufl.classes.Expr):
return data
from pymbolic.primitives import Expression
if isinstance(data, Expression):
return data
# Check if the given data is already hashable
if isinstance(data, Hashable):
if isinstance(data, Iterable):
......@@ -103,16 +107,19 @@ class _RegisteredFunction(object):
self.cache_key_generator = cache_key_generator
self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs)
def __call__(self, *args):
def __call__(self, *args, **kwargs):
# Get the cache key from the given arguments
cache_key = (self, _freeze(self.cache_key_generator(*args)))
cache_args = self.cache_key_generator(*args, **kwargs)
# Make sure that all keyword arguments have vanished from the cache_args
assert (lambda *a, **k: len(k) == 0)(cache_args)
cache_key = (self, _freeze(self.cache_key_generator(*args, **kwargs)))
# check whether we have a cache hit
if cache_key in _cache:
# and return the result depending on the cache item type
return _cache[cache_key].content
else:
# evaluate the original function and wrap it in a cache item
citem = self.itemtype(self.func(*args))
citem = self.itemtype(self.func(*args, **kwargs))
_cache[cache_key] = citem
return citem.content
......
......@@ -7,9 +7,7 @@ import loopy
import numpy
iname = generator_factory(item_tags=("loopy", "kernel", "iname"))
expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"), no_deco=True)
temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True)
c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True)
valuearg = generator_factory(item_tags=("loopy", "kernel", "argument", "valuearg"), on_store=lambda n: loopy.ValueArg(n), no_deco=True)
pymbolic_expr = generator_factory(item_tags=("loopy", "kernel", "pymbolic"))
constantarg = generator_factory(item_tags=("loopy", "kernel", "argument", "constantarg"), on_store=lambda n:loopy.ConstantArg(n))
......@@ -26,3 +24,55 @@ def domain(iname, shape):
if isinstance(shape, str):
valuearg(shape)
return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)
# Now define generators for instructions. To ease dependency handling of instructions
# these generators are a bit more involved... We apply the following procedure:
# There is one generator that returns the unique id and forwards to a generator that
# actually adds the instruction. Hashing is done based on the code snippet.
@generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"),
cache_key_generator=lambda *a, **kw: kw['code'],
)
def c_instruction_impl(**kw):
kw['insn_deps'] = kw.pop('deps', None)
kw.setdefault('assignees', [])
inames = kw.pop('inames')
return loopy.CInstruction(inames, **kw)
@generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"),
cache_key_generator=lambda *a, **kw: kw['expression'],
)
def expr_instruction_impl(**kw):
return loopy.ExpressionInstruction(id=kw['id'], assignee=kw['assignee'], expression=kw['expression'])
class _IDCounter:
count = 0
def _insn_cache_key(inames, code=None, expr=None, deps=[]):
if code:
return code
if expr:
return expr
@generator_factory(item_tags=("insn_id"), no_deco=True, cache_key_generator=_insn_cache_key)
def instruction(code=None, expression=None, **kwargs):
assert code or expression
assert not (code and expression)
# Get an ID for this instruction
id = 'insn' + str(_IDCounter.count).zfill(4)
_IDCounter.count = _IDCounter.count + 1
# Now create the actual instruction
if code:
c_instruction_impl(id=id, code=code, **kwargs)
if expression:
expr_instruction_impl(id=id, expression=expression, **kwargs)
# return the ID, as it is the only useful information to the user
return id
......@@ -9,11 +9,10 @@ from dune.perftool import Restriction
from dune.perftool.ufl.modified_terminals import ModifiedTerminalTracker
from dune.perftool.pymbolic.uflmapper import UFL2PymbolicMapper
from dune.perftool.generation import (c_instruction,
domain,
expr_instruction,
from dune.perftool.generation import (domain,
globalarg,
iname,
instruction,
temporary_variable,
valuearg,
)
......@@ -48,6 +47,12 @@ def quadrature_iname():
return "q"
@iname
def index_sum_iname(i):
from dune.perftool.pdelab import name_index
return name_index(i)
class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
def __init__(self):
super(UFL2LoopyVisitor, self).__init__()
......@@ -64,10 +69,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper):
# Define an iname for each of the indices in the multiindex
for i in o.ufl_operands[1].indices():
shape = determine_shape(o.ufl_operands[0], i)
index_sum_iname(i)
from dune.perftool.pdelab import name_index
name = name_index(i)
iname(name)
domain(name, shape)
domain(name_index(i), shape)
# Now continue processing the expression
return self.call(o.ufl_operands[0])
......@@ -126,7 +130,7 @@ def transform_accumulation_term(term):
expr_tv_name = "expr_" + str(get_count()).zfill(4)
expr_tv = temporary_variable(expr_tv_name)
from pymbolic.primitives import Variable
expr_instruction(loopy.ExpressionInstruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr))
instruction(assignee=Variable(expr_tv_name), expression=pymbolic_expr)
# The data that is used to collect the arguments for the accumulate function
accumargs = []
......@@ -144,11 +148,10 @@ def transform_accumulation_term(term):
inames = retrieve_cache_items("iname")
from dune.perftool.pdelab.quadrature import name_factor
c_instruction(loopy.CInstruction(inames,
"{}.accumulate({}, {}*{})".format(residual,
", ".join(accumargs),
expr_tv_name,
name_factor()
)
)
)
instruction(inames=inames,
code="{}.accumulate({}, {}*{})".format(residual,
", ".join(accumargs),
expr_tv_name,
name_factor()
)
)
\ No newline at end of file
""" The pdelab specific parts of the code generation process """
# Define the generators that are used throughout all pdelab specific code generations.
from dune.perftool.generation import symbol, generator_factory
from dune.perftool.generation import symbol, instruction
from dune.perftool.loopy.transformer import quadrature_iname
from loopy import CInstruction
def quadrature_preamble(assignees=[]):
# TODO: How to enforce the order of quadrature preambles? Counted?
return generator_factory(item_tags=("instruction", "cinstruction"), on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees))
def quadrature_preamble(code, **kw):
return instruction(inames=quadrature_iname(), code=code, **kw)
# Now define some commonly used generators that do not fall into a specific category
......
......@@ -8,10 +8,10 @@ def quadrature_rule():
return "rule"
@quadrature_preamble()
def define_quadrature_factor(fac):
rule = quadrature_rule()
return "auto {} = {}->weight();".format(fac, rule)
code = "auto {} = {}->weight();".format(fac, rule)
return quadrature_preamble(code, assignees=fac)
@symbol
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment