Skip to content
Snippets Groups Projects
Commit 7959fa42 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reimplement cache context as a step towards multiple kernels

parent 273fa4b3
No related branches found
No related tags found
No related merge requests found
...@@ -37,6 +37,7 @@ from dune.perftool.generation.loopy import (barrier, ...@@ -37,6 +37,7 @@ from dune.perftool.generation.loopy import (barrier,
globalarg, globalarg,
iname, iname,
instruction, instruction,
kernel_cached,
noop_instruction, noop_instruction,
silenced_warning, silenced_warning,
temporary_variable, temporary_variable,
......
""" This module provides the memoization infrastructure for code """ This module provides the memoization infrastructure for code
generating functions. generating functions.
""" """
from dune.perftool.generation.context import get_global_context_value
from dune.perftool.generation.counter import get_counter from dune.perftool.generation.counter import get_counter
# Store a global list of generator functions # Store a global list of generator functions
...@@ -51,13 +52,15 @@ class _RegisteredFunction(object): ...@@ -51,13 +52,15 @@ class _RegisteredFunction(object):
cache_key_generator=lambda *a, **kw: a, cache_key_generator=lambda *a, **kw: a,
counted=False, counted=False,
on_store=lambda x: x, on_store=lambda x: x,
item_tags=() item_tags=(),
context_tags=(),
): ):
self.func = func self.func = func
self.cache_key_generator = cache_key_generator self.cache_key_generator = cache_key_generator
self.counted = counted self.counted = counted
self.on_store = on_store self.on_store = on_store
self.item_tags = item_tags self.item_tags = item_tags
self.context_tags = context_tags
# Initialize the memoization cache # Initialize the memoization cache
self._memoize_cache = {} self._memoize_cache = {}
...@@ -81,9 +84,14 @@ class _RegisteredFunction(object): ...@@ -81,9 +84,14 @@ class _RegisteredFunction(object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Get the cache key from the given arguments # Get the cache key from the given arguments
cache_key = self.cache_key_generator(*args, **kwargs) cache_key = self.cache_key_generator(*args, **kwargs)
# Make sure that all keyword arguments have vanished from the cache_args # Make sure that all keyword arguments have vanished from the cache_args
assert (lambda *a, **k: len(k) == 0)(cache_key) assert (lambda *a, **k: len(k) == 0)(cache_key)
# Add any context tags to the cache key
context_key = tuple(get_global_context_value(t, None) for t in self.context_tags)
cache_key = (cache_key, context_key)
# check whether we have a cache hit # check whether we have a cache hit
if cache_key not in self._memoize_cache: if cache_key not in self._memoize_cache:
# evaluate the original function # evaluate the original function
...@@ -128,10 +136,17 @@ def generator_factory(**factory_kwargs): ...@@ -128,10 +136,17 @@ def generator_factory(**factory_kwargs):
items automatically turns to a tuple with the counttag being the first entry. items automatically turns to a tuple with the counttag being the first entry.
no_deco : bool no_deco : bool
Instead of a decorator, return a function that uses identity as a body. Instead of a decorator, return a function that uses identity as a body.
context_tags: tuple, str
A single tag or tuple thereof, that will be added to the cache key. This
feature can be used to maintain multiple sets of memoized function evaluations,
for example if you generate multiple loopy kernels at the same time. The
given strings are used to look up in the global context manager for a tag.
""" """
# Tuplize the item_tags parameter # Tuplize the item_tags parameter
if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str): if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str):
factory_kwargs["item_tags"] = (factory_kwargs["item_tags"],) factory_kwargs["item_tags"] = (factory_kwargs["item_tags"],)
if "context_tags" in factory_kwargs and isinstance(factory_kwargs["context_tags"], str):
factory_kwargs["context_tags"] = (factory_kwargs["context_tags"],)
no_deco = factory_kwargs.pop("no_deco", False) no_deco = factory_kwargs.pop("no_deco", False)
......
...@@ -6,15 +6,15 @@ from dune.perftool.generation import (get_counter, ...@@ -6,15 +6,15 @@ from dune.perftool.generation import (get_counter,
no_caching, no_caching,
preamble, preamble,
) )
from dune.perftool.generation.context import get_global_context_value
from dune.perftool.error import PerftoolLoopyError from dune.perftool.error import PerftoolLoopyError
import loopy import loopy as lp
import numpy import numpy as np
iname = generator_factory(item_tags=("iname",)) iname = generator_factory(item_tags=("iname",), context_tags="kernel")
function_mangler = generator_factory(item_tags=("mangler",)) function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel")
silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True) silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel")
kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel")
class DuneGlobalArg(loopy.GlobalArg): class DuneGlobalArg(loopy.GlobalArg):
...@@ -22,6 +22,7 @@ class DuneGlobalArg(loopy.GlobalArg): ...@@ -22,6 +22,7 @@ class DuneGlobalArg(loopy.GlobalArg):
@generator_factory(item_tags=("argument", "globalarg"), @generator_factory(item_tags=("argument", "globalarg"),
context_tags="kernel",
cache_key_generator=lambda n, **kw: n) cache_key_generator=lambda n, **kw: n)
def globalarg(name, shape=loopy.auto, managed=True, **kw): def globalarg(name, shape=loopy.auto, managed=True, **kw):
if isinstance(shape, str): if isinstance(shape, str):
...@@ -31,21 +32,23 @@ def globalarg(name, shape=loopy.auto, managed=True, **kw): ...@@ -31,21 +32,23 @@ def globalarg(name, shape=loopy.auto, managed=True, **kw):
@generator_factory(item_tags=("argument", "constantarg"), @generator_factory(item_tags=("argument", "constantarg"),
context_tags="kernel",
cache_key_generator=lambda n, **kw: n) cache_key_generator=lambda n, **kw: n)
def constantarg(name, shape=loopy.auto, **kw): def constantarg(name, shape=None, **kw):
if isinstance(shape, str): if isinstance(shape, str):
shape = (shape,) shape = (shape,)
dtype = kw.pop("dtype", numpy.float64) dtype = kw.pop("dtype", np.float64)
return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw) return lp.GlobalArg(name, dtype=dtype, shape=shape, **kw)
@generator_factory(item_tags=("argument", "valuearg"), @generator_factory(item_tags=("argument", "valuearg"),
context_tags="kernel",
cache_key_generator=lambda n, **kw: n) cache_key_generator=lambda n, **kw: n)
def valuearg(name, **kw): def valuearg(name, **kw):
return loopy.ValueArg(name, **kw) return lp.ValueArg(name, **kw)
@generator_factory(item_tags=("domain",)) @generator_factory(item_tags=("domain",), context_tags="kernel")
def domain(iname, shape): def domain(iname, shape):
if isinstance(shape, str): if isinstance(shape, str):
valuearg(shape) valuearg(shape)
...@@ -56,7 +59,9 @@ def get_temporary_name(): ...@@ -56,7 +59,9 @@ def get_temporary_name():
return 'expr_{}'.format(str(get_counter('__temporary').zfill(4))) return 'expr_{}'.format(str(get_counter('__temporary').zfill(4)))
@generator_factory(item_tags=("temporary",), cache_key_generator=lambda n, **kw: n) @generator_factory(item_tags=("temporary",),
context_tags="kernel",
cache_key_generator=lambda n, **kw: n)
def temporary_variable(name, **kwargs): def temporary_variable(name, **kwargs):
from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.temporary import DuneTemporaryVariable
return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.PRIVATE, **kwargs) return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.PRIVATE, **kwargs)
...@@ -69,6 +74,7 @@ def temporary_variable(name, **kwargs): ...@@ -69,6 +74,7 @@ def temporary_variable(name, **kwargs):
@generator_factory(item_tags=("instruction", "cinstruction"), @generator_factory(item_tags=("instruction", "cinstruction"),
context_tags="kernel",
cache_key_generator=lambda *a, **kw: kw['code'], cache_key_generator=lambda *a, **kw: kw['code'],
) )
def c_instruction_impl(**kw): def c_instruction_impl(**kw):
...@@ -77,24 +83,26 @@ def c_instruction_impl(**kw): ...@@ -77,24 +83,26 @@ def c_instruction_impl(**kw):
kw['assignees'] = frozenset(Variable(i) for i in kw['assignees']) kw['assignees'] = frozenset(Variable(i) for i in kw['assignees'])
inames = kw.pop('inames', kw.get('forced_iname_deps', [])) inames = kw.pop('inames', kw.get('forced_iname_deps', []))
return loopy.CInstruction(inames, **kw) return lp.CInstruction(inames, **kw)
@generator_factory(item_tags=("instruction", "exprinstruction"), @generator_factory(item_tags=("instruction", "exprinstruction"),
context_tags="kernel",
cache_key_generator=lambda *a, **kw: kw['expression'], cache_key_generator=lambda *a, **kw: kw['expression'],
) )
def expr_instruction_impl(**kw): def expr_instruction_impl(**kw):
if 'assignees' in kw: if 'assignees' in kw:
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
kw['assignees'] = frozenset(Variable(i) for i in kw['assignees']) kw['assignees'] = frozenset(Variable(i) for i in kw['assignees'])
return loopy.ExpressionInstruction(**kw) return lp.ExpressionInstruction(**kw)
@generator_factory(item_tags=("instruction", "callinstruction"), @generator_factory(item_tags=("instruction", "callinstruction"),
context_tags="kernel",
cache_key_generator=lambda *a, **kw: kw['expression'], cache_key_generator=lambda *a, **kw: kw['expression'],
) )
def call_instruction_impl(**kw): def call_instruction_impl(**kw):
return loopy.CallInstruction(**kw) return lp.CallInstruction(**kw)
def _insn_cache_key(code=None, expression=None, **kwargs): def _insn_cache_key(code=None, expression=None, **kwargs):
...@@ -105,7 +113,9 @@ def _insn_cache_key(code=None, expression=None, **kwargs): ...@@ -105,7 +113,9 @@ def _insn_cache_key(code=None, expression=None, **kwargs):
raise ValueError("Please specify either code or expression for instruction!") raise ValueError("Please specify either code or expression for instruction!")
@generator_factory(item_tags=("insn_id",), cache_key_generator=_insn_cache_key) @generator_factory(item_tags=("insn_id",),
context_tags="kernel",
cache_key_generator=_insn_cache_key)
def instruction(code=None, expression=None, **kwargs): def instruction(code=None, expression=None, **kwargs):
assert (code is not None) or (expression is not None) assert (code is not None) or (expression is not None)
assert not ((code is not None) and (expression is not None)) assert not ((code is not None) and (expression is not None))
...@@ -127,12 +137,15 @@ def instruction(code=None, expression=None, **kwargs): ...@@ -127,12 +137,15 @@ def instruction(code=None, expression=None, **kwargs):
return id return id
@generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id']) @generator_factory(item_tags=("instruction",),
context_tags="kernel",
cache_key_generator=lambda **kw: kw['id'])
def noop_instruction(**kwargs): def noop_instruction(**kwargs):
return loopy.NoOpInstruction(**kwargs) return lp.NoOpInstruction(**kwargs)
@generator_factory(item_tags=("transformation",), @generator_factory(item_tags=("transformation",),
context_tags="kernel",
cache_key_generator=no_caching, cache_key_generator=no_caching,
) )
def transform(trafo, *args): def transform(trafo, *args):
......
...@@ -6,13 +6,13 @@ Namely: ...@@ -6,13 +6,13 @@ Namely:
""" """
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.generation import (cached, from dune.perftool.generation import (domain,
domain,
function_mangler, function_mangler,
iname, iname,
globalarg, globalarg,
valuearg, valuearg,
get_global_context_value get_global_context_value,
kernel_cached,
) )
from dune.perftool.pdelab.index import name_index from dune.perftool.pdelab.index import name_index
from dune.perftool.pdelab.basis import (evaluate_coefficient, from dune.perftool.pdelab.basis import (evaluate_coefficient,
...@@ -131,7 +131,7 @@ def name_applycontainer(restriction): ...@@ -131,7 +131,7 @@ def name_applycontainer(restriction):
return name return name
@cached @kernel_cached
def pymbolic_coefficient(container, lfs, index): def pymbolic_coefficient(container, lfs, index):
# TODO introduce a proper type for local function spaces! # TODO introduce a proper type for local function spaces!
if isinstance(lfs, str): if isinstance(lfs, str):
......
""" Generators for basis evaluations """ """ Generators for basis evaluations """
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
cached,
class_member, class_member,
generator_factory, generator_factory,
get_backend, get_backend,
include_file, include_file,
instruction, instruction,
kernel_cached,
preamble, preamble,
temporary_variable, temporary_variable,
) )
...@@ -65,7 +65,7 @@ def declare_cache_temporary(element, restriction, which): ...@@ -65,7 +65,7 @@ def declare_cache_temporary(element, restriction, which):
@backend(interface="evaluate_basis") @backend(interface="evaluate_basis")
@cached @kernel_cached
def evaluate_basis(leaf_element, name, restriction): def evaluate_basis(leaf_element, name, restriction):
lfs = name_leaf_lfs(leaf_element, restriction) lfs = name_leaf_lfs(leaf_element, restriction)
temporary_variable(name, shape=(name_lfs_bound(lfs),), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function')) temporary_variable(name, shape=(name_lfs_bound(lfs),), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function'))
...@@ -93,7 +93,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''): ...@@ -93,7 +93,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''):
@backend(interface="evaluate_grad") @backend(interface="evaluate_grad")
@cached @kernel_cached
def evaluate_reference_gradient(leaf_element, name, restriction): def evaluate_reference_gradient(leaf_element, name, restriction):
lfs = name_leaf_lfs(leaf_element, restriction) lfs = name_leaf_lfs(leaf_element, restriction)
temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian')) temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian'))
...@@ -129,7 +129,7 @@ def shape_as_pymbolic(shape): ...@@ -129,7 +129,7 @@ def shape_as_pymbolic(shape):
return tuple(_shape_as_pymbolic(s) for s in shape) return tuple(_shape_as_pymbolic(s) for s in shape)
@cached @kernel_cached
def evaluate_coefficient(element, name, container, restriction, component): def evaluate_coefficient(element, name, container, restriction, component):
from ufl.functionview import select_subelement from ufl.functionview import select_subelement
sub_element = select_subelement(element, component) sub_element = select_subelement(element, component)
...@@ -165,7 +165,7 @@ def evaluate_coefficient(element, name, container, restriction, component): ...@@ -165,7 +165,7 @@ def evaluate_coefficient(element, name, container, restriction, component):
) )
@cached @kernel_cached
def evaluate_coefficient_gradient(element, name, container, restriction, component): def evaluate_coefficient_gradient(element, name, container, restriction, component):
# First we determine the rank of the tensor we are talking about # First we determine the rank of the tensor we are talking about
from ufl.functionview import select_subelement from ufl.functionview import select_subelement
......
from dune.perftool.ufl.modified_terminals import Restriction from dune.perftool.ufl.modified_terminals import Restriction
from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
cached,
domain, domain,
get_backend, get_backend,
get_global_context_value, get_global_context_value,
globalarg, globalarg,
iname, iname,
include_file, include_file,
kernel_cached,
preamble, preamble,
temporary_variable, temporary_variable,
valuearg, valuearg,
...@@ -276,7 +276,7 @@ def type_jacobian_inverse_transposed(restriction): ...@@ -276,7 +276,7 @@ def type_jacobian_inverse_transposed(restriction):
return "typename {}::JacobianInverseTransposed".format(geo) return "typename {}::JacobianInverseTransposed".format(geo)
@cached @kernel_cached
def define_jacobian_inverse_transposed_temporary(restriction): def define_jacobian_inverse_transposed_temporary(restriction):
@preamble @preamble
def _define_jacobian_inverse_transposed_temporary(name, shape, shape_impl): def _define_jacobian_inverse_transposed_temporary(name, shape, shape_impl):
......
from dune.perftool.generation import cached from dune.perftool.generation import kernel_cached
from ufl.classes import MultiIndex, Index from ufl.classes import MultiIndex, Index
# Now define some commonly used generators that do not fall into a specific category # Now define some commonly used generators that do not fall into a specific category
@cached @kernel_cached
def name_index(index): def name_index(index):
if isinstance(index, Index): if isinstance(index, Index):
# This failed for index > 9 because ufl placed curly brackets around # This failed for index > 9 because ufl placed curly brackets around
......
""" Generators for parameter functions """ """ Generators for parameter functions """
from dune.perftool.generation import (cached, from dune.perftool.generation import (class_basename,
class_basename,
class_member, class_member,
constructor_parameter, constructor_parameter,
generator_factory, generator_factory,
get_backend, get_backend,
initializer_list, initializer_list,
kernel_cached,
preamble, preamble,
temporary_variable temporary_variable
) )
...@@ -206,7 +206,7 @@ def construct_nested_fieldvector(t, shape): ...@@ -206,7 +206,7 @@ def construct_nested_fieldvector(t, shape):
return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0]) return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0])
@cached @kernel_cached
def cell_parameter_function(name, expr, restriction, cellwise_constant, t='double'): def cell_parameter_function(name, expr, restriction, cellwise_constant, t='double'):
shape = expr.ufl_element().value_shape() shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape) shape_impl = ('fv',) * len(shape)
...@@ -218,7 +218,7 @@ def cell_parameter_function(name, expr, restriction, cellwise_constant, t='doubl ...@@ -218,7 +218,7 @@ def cell_parameter_function(name, expr, restriction, cellwise_constant, t='doubl
evaluate_cell_parameter_function(name, restriction) evaluate_cell_parameter_function(name, restriction)
@cached @kernel_cached
def intersection_parameter_function(name, expr, cellwise_constant, t='double'): def intersection_parameter_function(name, expr, cellwise_constant, t='double'):
shape = expr.ufl_element().value_shape() shape = expr.ufl_element().value_shape()
shape_impl = ('fv',) * len(shape) shape_impl = ('fv',) * len(shape)
......
import numpy import numpy
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
cached,
class_member, class_member,
domain, domain,
get_backend, get_backend,
......
...@@ -4,12 +4,12 @@ NB: Basis evaluation is only needed for the trial function argument in jacobians ...@@ -4,12 +4,12 @@ NB: Basis evaluation is only needed for the trial function argument in jacobians
multiplication with the test function is part of the sum factorization kernel. multiplication with the test function is part of the sum factorization kernel.
""" """
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
cached,
domain, domain,
get_counter, get_counter,
get_global_context_value, get_global_context_value,
iname, iname,
instruction, instruction,
kernel_cached,
temporary_variable, temporary_variable,
) )
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
...@@ -39,7 +39,7 @@ def name_sumfact_base_buffer(): ...@@ -39,7 +39,7 @@ def name_sumfact_base_buffer():
return name return name
@cached @kernel_cached
def sumfact_evaluate_coefficient_gradient(element, name, restriction, component): def sumfact_evaluate_coefficient_gradient(element, name, restriction, component):
# Get a temporary for the gradient # Get a temporary for the gradient
from ufl.functionview import select_subelement from ufl.functionview import select_subelement
...@@ -97,7 +97,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) ...@@ -97,7 +97,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
) )
@cached @kernel_cached
def pymbolic_trialfunction_gradient(element, restriction, component): def pymbolic_trialfunction_gradient(element, restriction, component):
rawname = "gradu" + "_".join(str(c) for c in component) rawname = "gradu" + "_".join(str(c) for c in component)
name = restricted_name(rawname, restriction) name = restricted_name(rawname, restriction)
...@@ -108,7 +108,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component): ...@@ -108,7 +108,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component):
return Variable(name) return Variable(name)
@cached @kernel_cached
def pymbolic_trialfunction(element, restriction, component): def pymbolic_trialfunction(element, restriction, component):
theta = name_theta() theta = name_theta()
rows = quadrature_points_per_direction() rows = quadrature_points_per_direction()
...@@ -151,7 +151,7 @@ def lfs_inames(element, restriction, number=1, context=''): ...@@ -151,7 +151,7 @@ def lfs_inames(element, restriction, number=1, context=''):
@backend(interface="evaluate_basis") @backend(interface="evaluate_basis")
@cached @kernel_cached
def evaluate_basis(element, name, restriction): def evaluate_basis(element, name, restriction):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
theta = name_theta() theta = name_theta()
...@@ -183,7 +183,7 @@ def pymbolic_basis(element, restriction, number): ...@@ -183,7 +183,7 @@ def pymbolic_basis(element, restriction, number):
@backend(interface="evaluate_grad") @backend(interface="evaluate_grad")
@cached @kernel_cached
def evaluate_reference_gradient(element, name, restriction): def evaluate_reference_gradient(element, name, restriction):
from dune.perftool.pdelab.geometry import name_dimension from dune.perftool.pdelab.geometry import name_dimension
temporary_variable( temporary_variable(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment