Skip to content
Snippets Groups Projects
Commit 9e098b54 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Refactor to be more loopyish

including CallInstructions for accumulation etc.
parent c22773d9
No related branches found
No related tags found
No related merge requests found
...@@ -131,6 +131,13 @@ def expr_instruction_impl(**kw): ...@@ -131,6 +131,13 @@ def expr_instruction_impl(**kw):
return loopy.ExpressionInstruction(**kw) return loopy.ExpressionInstruction(**kw)
@generator_factory(
    item_tags=("instruction", "callinstruction"),
    cache_key_generator=lambda *a, **kw: kw['expression'],
)
def call_instruction_impl(**kw):
    """Create a loopy ``CallInstruction`` from the given keyword arguments.

    The decorator is configured with the item tags "instruction" and
    "callinstruction" and derives its cache key from the ``expression``
    keyword argument, which therefore has to be present.
    """
    call_insn = loopy.CallInstruction(**kw)
    return call_insn
class _IDCounter: class _IDCounter:
count = 0 count = 0
...@@ -157,7 +164,10 @@ def instruction(code=None, expression=None, **kwargs): ...@@ -157,7 +164,10 @@ def instruction(code=None, expression=None, **kwargs):
if code: if code:
c_instruction_impl(id=id, code=code, **kwargs) c_instruction_impl(id=id, code=code, **kwargs)
if expression: if expression:
expr_instruction_impl(id=id, expression=expression, **kwargs) if 'assignees' in kwargs and len(kwargs['assignees']) == 0:
call_instruction_impl(id=id, expression=expression, **kwargs)
else:
expr_instruction_impl(id=id, expression=expression, **kwargs)
# return the ID, as it is the only useful information to the user # return the ID, as it is the only useful information to the user
return id return id
...@@ -24,13 +24,33 @@ def coefficient_mangler(target, func, dtypes): ...@@ -24,13 +24,33 @@ def coefficient_mangler(target, func, dtypes):
class PDELabAccumulationFunction(FunctionIdentifier): class PDELabAccumulationFunction(FunctionIdentifier):
def __init__(self, accumobj): def __init__(self, accumobj, rank):
self.accumobj = accumobj self.accumobj = accumobj
self.rank = rank
assert rank in (1, 2)
def __getinitargs__(self): def __getinitargs__(self):
return (self.accumobj,) return (self.accumobj, self.rank)
def accumulation_mangler(target, func, dtypes): def accumulation_mangler(target, func, dtypes):
if isinstance(func, PDELabAccumulationFunction): if isinstance(func, PDELabAccumulationFunction):
return CallMangleInfo('{}.accumulate'.format(func.accumobj), (), ()) if func.rank == 1:
return CallMangleInfo('{}.accumulate'.format(func.accumobj),
(),
(NumpyType(str),
NumpyType(numpy.int32),
NumpyType(numpy.float64),
)
)
if func.rank == 2:
return CallMangleInfo('{}.accumulate'.format(func.accumobj),
(),
(NumpyType(str),
NumpyType(numpy.int32),
NumpyType(str),
NumpyType(numpy.int32),
NumpyType(numpy.float64),
)
)
...@@ -5,7 +5,6 @@ from loopy.target import (TargetBase, ...@@ -5,7 +5,6 @@ from loopy.target import (TargetBase,
from loopy.target.c import CASTBuilder from loopy.target.c import CASTBuilder
from loopy.target.c.codegen.expression import ExpressionToCMapper from loopy.target.c.codegen.expression import ExpressionToCMapper
_registry = {'float32': 'float', _registry = {'float32': 'float',
'int32': 'int', 'int32': 'int',
'float64': 'double', 'float64': 'double',
......
...@@ -41,18 +41,6 @@ def index_sum_iname(i): ...@@ -41,18 +41,6 @@ def index_sum_iname(i):
return name_index(i) return name_index(i)
_outerloop = None
def set_outerloop(v):
    """Store ``v`` in the module-level ``_outerloop`` variable."""
    global _outerloop
    _outerloop = v
def get_outerloop():
    """Return the current value of the module-level ``_outerloop`` variable."""
    current = _outerloop
    return current
class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper): class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper):
def __init__(self, measure, subdomain_id): def __init__(self, measure, subdomain_id):
# Some variables describing the integral measure of this integral # Some variables describing the integral measure of this integral
...@@ -62,71 +50,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -62,71 +50,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
# Call base class constructors # Call base class constructors
super(UFL2LoopyVisitor, self).__init__() super(UFL2LoopyVisitor, self).__init__()
# Some state variables that need to be persistent over multiple calls
self.index_placeholder_removal_mapper = IndexPlaceholderRemoval()
def _assign(self, o):
    """Assign the expression ``o`` to a temporary and return a reference to it.

    Plain ``int``/``float`` values are returned unchanged. Otherwise an
    instruction assigning ``o`` to a temporary is issued, the temporary is
    declared, and a ``Variable`` (or a ``Subscript`` of it, if the
    instruction is not merged into the main loop nest) is returned.
    """
    # In some corner cases we do not even need a temporary variable
    if isinstance(o, (int, float)):
        return o

    # Assign a name to the temporary variable we want our result in
    temp = get_temporary_name()
    temp_shape = ()

    # Determine which inames this assignment depends on and whether it should
    # be merged into the main accumulation loop. Right now we apply the following
    # procedure: All instructions that depend on all argument loop indices are
    # merged into the main loop nest. Those instructions that depend on some
    # argument loop indices but not all are merged into the kernel by fixing the
    # loop ordering of the main loop (or are pulled outside if this already happened).
    assignee = Variable(temp)
    iname_deps = self.inames
    merge_into_main_loopnest = True
    if self.rank == 2 and len(set(iname_deps)) == 1:
        # The first such instruction fixes the outer loop
        if get_outerloop() is None:
            set_outerloop(iname_deps[0].number)
        if iname_deps[0].number != get_outerloop():
            merge_into_main_loopnest = False

    # Change the assignee!
    if not merge_into_main_loopnest:
        assignee_index_placeholder = LFSIndexPlaceholderExtraction()(o).pop()
        assignee_index = self.index_placeholder_removal_mapper(assignee_index_placeholder, duplicate_inames=True)
        assignee = Subscript(assignee, (assignee_index,))
        temp_shape = (name_lfs_bound(name_leaf_lfs(assignee_index_placeholder.element, assignee_index_placeholder.restriction)),)

    # Now introduce duplicate inames for the argument loop if necessary
    replaced_iname_deps = [self.index_placeholder_removal_mapper(i, duplicate_inames=not merge_into_main_loopnest, wrap_in_variable=False) for i in iname_deps]
    replaced_expr = self.index_placeholder_removal_mapper(o, duplicate_inames=not merge_into_main_loopnest)

    # Now we assign this expression to a new temporary variable
    insn_id = instruction(assignee=assignee,
                          expression=replaced_expr,
                          forced_iname_deps=frozenset(replaced_iname_deps).union(frozenset({quadrature_iname()})),
                          forced_iname_deps_is_final=True,
                          )

    # Actually, if we have a cache hit, we need to change our temporary.
    # A list is built explicitly: on Python 3, filter() returns an iterator
    # that cannot be subscripted.
    from dune.perftool.generation import retrieve_cache_items
    matching = [i for i in retrieve_cache_items("instruction") if i.id == insn_id]
    temp = matching[0].assignee_name

    retvar = Variable(temp)
    if not merge_into_main_loopnest:
        retvar_index = self.index_placeholder_removal_mapper(assignee_index_placeholder)
        retvar = Subscript(retvar, (retvar_index,))

    # Now that we know its exact name, declare the temporary
    temporary_variable(temp, shape=temp_shape)

    return retvar
def __call__(self, o): def __call__(self, o):
# Reset some state variables that are reinitialized for each accumulation term # Reset some state variables that are reinitialized for each accumulation term
self.argshape = 0 self.argshape = 0
self.redinames = () self.redinames = ()
self.inames = [] self.inames = []
self.dimension_index_aliases = []
self.substitution_rules = []
# Initialize the local function spaces that we might need for this term # Initialize the local function spaces that we might need for this term
# We therefore gather a list of modified trial functions too. # We therefore gather a list of modified trial functions too.
...@@ -152,13 +82,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -152,13 +82,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
if pymbolic_expr == 0: if pymbolic_expr == 0:
return return
# We assign the result to a temporary variable to ease things a bit
if not isinstance(pymbolic_expr, Variable):
pymbolic_expr = self._assign(pymbolic_expr)
# Transform the IndexPlaceholders into real inames
self.inames = [self.index_placeholder_removal_mapper(i, wrap_in_variable=False) for i in self.inames]
# Collect the arguments for the accumulate function # Collect the arguments for the accumulate function
accumargs = [None] * (2 * len(test_ma)) accumargs = [None] * (2 * len(test_ma))
residual_shape = [None] * len(test_ma) residual_shape = [None] * len(test_ma)
...@@ -177,6 +100,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -177,6 +100,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
# And generate a local function space for it! # And generate a local function space for it!
lfs = name_lfs(ma.argexpr.ufl_element(), ma.restriction, ma.component) lfs = name_lfs(ma.argexpr.ufl_element(), ma.restriction, ma.component)
from dune.perftool.generation import valuearg
from loopy.types import NumpyType
valuearg(lfs, dtype=NumpyType("str"))
if len(subel.value_shape()) != 0: if len(subel.value_shape()) != 0:
from dune.perftool.pdelab.geometry import dimension_iname from dune.perftool.pdelab.geometry import dimension_iname
from dune.perftool.pdelab.basis import lfs_child from dune.perftool.pdelab.basis import lfs_child
...@@ -188,8 +115,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -188,8 +115,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
lfsi = lfs_iname(subel, ma.restriction, count=count) lfsi = lfs_iname(subel, ma.restriction, count=count)
accumargs[2 * icount] = lfs accumargs[2 * icount] = Variable(lfs)
accumargs[2 * icount + 1] = lfsi accumargs[2 * icount + 1] = Variable(lfsi)
arg_restr[icount] = ma.restriction arg_restr[icount] = ma.restriction
...@@ -199,13 +126,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -199,13 +126,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
from dune.perftool.pdelab.quadrature import name_factor from dune.perftool.pdelab.quadrature import name_factor
factor = name_factor() factor = name_factor()
# Generate the code snippet for this accumulation instruction
code = "{}.accumulate({}, {}*{});".format(accumvar,
", ".join(accumargs),
pymbolic_expr.name,
factor,
)
predicates = frozenset({}) predicates = frozenset({})
# Maybe wrap this instruction into a conditional. This mostly happens with mixed boundary conditions # Maybe wrap this instruction into a conditional. This mostly happens with mixed boundary conditions
...@@ -241,9 +161,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -241,9 +161,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
predicates = predicates.union(['{} == {}'.format(name, self.subdomain_id)]) predicates = predicates.union(['{} == {}'.format(name, self.subdomain_id)])
# Finally, issue the instruction from dune.perftool.loopy.functions import PDELabAccumulationFunction
instruction(code=code, from pymbolic.primitives import Call, Product
read_variables=frozenset({factor, pymbolic_expr.name}), expr = Product((pymbolic_expr, Variable(factor)))
expr = Call(PDELabAccumulationFunction(accumvar, len(test_ma)), tuple(a for a in accumargs) + (expr,))
instruction(assignees=frozenset({}),
expression=expr,
forced_iname_deps=frozenset(self.inames).union(frozenset({quadrature_iname()})), forced_iname_deps=frozenset(self.inames).union(frozenset({quadrature_iname()})),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
predicates=predicates predicates=predicates
...@@ -279,14 +203,15 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -279,14 +203,15 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
leaf_element = element.sub_elements()[0] leaf_element = element.sub_elements()[0]
# Have the issued instruction depend on the iname for this localfunction space # Have the issued instruction depend on the iname for this localfunction space
self.inames.append(LFSIndexPlaceholder(leaf_element, restriction, o.number())) iname = lfs_iname(leaf_element, restriction, o.number())
self.inames.append(iname)
if self.grad: if self.grad:
from dune.perftool.pdelab.argument import name_testfunction_gradient from dune.perftool.pdelab.argument import name_testfunction_gradient
return Subscript(Variable(name_testfunction_gradient(leaf_element, restriction)), (LFSIndexPlaceholder(leaf_element, restriction, o.number()),)) return Subscript(Variable(name_testfunction_gradient(leaf_element, restriction)), (Variable(iname),))
else: else:
from dune.perftool.pdelab.argument import name_testfunction from dune.perftool.pdelab.argument import name_testfunction
return Subscript(Variable(name_testfunction(leaf_element, restriction)), (LFSIndexPlaceholder(leaf_element, restriction, o.number()),)) return Subscript(Variable(name_testfunction(leaf_element, restriction)), (Variable(iname),))
def coefficient(self, o): def coefficient(self, o):
# If this is a trialfunction, we do something entirely different # If this is a trialfunction, we do something entirely different
...@@ -336,8 +261,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -336,8 +261,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
use_indices = self.last_index[self.argshape:] use_indices = self.last_index[self.argshape:]
for i in range(self.argshape): for i in range(self.argshape):
from dune.perftool.pdelab.geometry import dimension_iname self.dimension_index_aliases.append(i)
self.index_placeholder_removal_mapper.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg')) # self.index_placeholder_removal_mapper.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg'))
self.argshape = 0 self.argshape = 0
if isinstance(aggr, Subscript): if isinstance(aggr, Subscript):
...@@ -369,37 +294,46 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp ...@@ -369,37 +294,46 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
else: else:
from loopy import Reduction from loopy import Reduction
oldinames = self.inames
self.inames = []
# Recurse to get the summation expression # Recurse to get the summation expression
term = self.call(o.ufl_operands[0]) term = self.call(o.ufl_operands[0])
from dune.perftool.pymbolic.inameset import get_index_inames self.redinames = tuple(i for i in self.redinames if i not in self.dimension_index_aliases)
used_inames = frozenset({self.index_placeholder_removal_mapper.index_replacement_map.get(i, i).name for i in get_index_inames(term, as_variables=True)})
self.inames = [i for i in used_inames.intersection(frozenset({i for i in oldinames}))] + self.inames
# Now filter all those reduction inames that are marked for removal
implicit_inames = [i.name for i in self.index_placeholder_removal_mapper.index_replacement_map]
self.redinames = tuple(i for i in self.redinames if i not in implicit_inames)
if len(self.redinames) > 0: if len(self.redinames) > 0:
ret = self._assign(Reduction("sum", self.redinames, term)) ret = Reduction("sum", self.redinames, term)
name = get_temporary_name()
# Generate a substitution rule for this one.
from loopy import SubstitutionRule
self.substitution_rules.append(SubstitutionRule(name,
(),
ret
)
)
ret = Variable(name)
else: else:
ret = term ret = term
self.inames = self.inames + oldinames
# Reset the reduction inames for future indexsums # Reset the reduction inames for future indexsums
self.redinames = () self.redinames = ()
return ret return ret
def _index_or_fixed_index(self, index):
    """Translate a single UFL index into its pymbolic counterpart.

    A ``FixedIndex`` yields its ``_value``. An index listed in
    ``self.dimension_index_aliases`` yields a ``Variable`` named by
    ``dimension_iname(context='arg')``. Any other index yields a
    ``Variable`` named by ``name_index(index)``.
    """
    from ufl.classes import FixedIndex
    if isinstance(index, FixedIndex):
        return index._value

    from pymbolic.primitives import Variable
    from dune.perftool.pdelab import name_index
    if index not in self.dimension_index_aliases:
        return Variable(name_index(index))

    from dune.perftool.pdelab.geometry import dimension_iname
    return Variable(dimension_iname(context='arg'))
def multi_index(self, o): def multi_index(self, o):
from dune.perftool.pdelab import pymbolic_index return tuple(self._index_or_fixed_index(i) for i in o)
return tuple(IndexPlaceholder(pymbolic_index(op)) for op in o.indices())
def index(self, o): def index(self, o):
# One might as well take the uflname as string here, but I apply this function return self._index_or_fixed_index(o)
from dune.perftool.pdelab import name_index
return IndexPlaceholder(Variable(name_index(o)))
""" The pdelab specific parts of the code generation process """ """ The pdelab specific parts of the code generation process """
from dune.perftool.generation import (preamble, from dune.perftool.generation import (preamble,
pymbolic_expr,
symbol, symbol,
) )
...@@ -21,7 +20,6 @@ def name_index(index): ...@@ -21,7 +20,6 @@ def name_index(index):
raise NotImplementedError raise NotImplementedError
@pymbolic_expr
def pymbolic_index(index): def pymbolic_index(index):
from ufl.classes import FixedIndex from ufl.classes import FixedIndex
if isinstance(index, FixedIndex): if isinstance(index, FixedIndex):
......
...@@ -154,6 +154,7 @@ def assembly_routine_signature(): ...@@ -154,6 +154,7 @@ def assembly_routine_signature():
def generate_kernel(integrals): def generate_kernel(integrals):
subst_rules = []
for integral in integrals: for integral in integrals:
integrand = integral.integrand() integrand = integral.integrand()
measure = integral.integral_type() measure = integral.integral_type()
...@@ -171,6 +172,7 @@ def generate_kernel(integrals): ...@@ -171,6 +172,7 @@ def generate_kernel(integrals):
# Iterate over the terms and generate a kernel # Iterate over the terms and generate a kernel
for term in accterms: for term in accterms:
visitor(term) visitor(term)
subst_rules.extend(visitor.substitution_rules)
# Extract the information, which is needed to create a loopy kernel. # Extract the information, which is needed to create a loopy kernel.
# First extracting it, might be useful to alter it before kernel generation. # First extracting it, might be useful to alter it before kernel generation.
...@@ -187,7 +189,7 @@ def generate_kernel(integrals): ...@@ -187,7 +189,7 @@ def generate_kernel(integrals):
# Create the kernel # Create the kernel
from loopy import make_kernel, preprocess_kernel from loopy import make_kernel, preprocess_kernel
kernel = make_kernel(domains, kernel = make_kernel(domains,
instructions, instructions + subst_rules,
arguments, arguments,
temporary_variables=temporaries, temporary_variables=temporaries,
function_manglers=[accumulation_mangler, coefficient_mangler], function_manglers=[accumulation_mangler, coefficient_mangler],
...@@ -213,7 +215,6 @@ def generate_kernel(integrals): ...@@ -213,7 +215,6 @@ def generate_kernel(integrals):
from dune.perftool.generation import delete_cache_items from dune.perftool.generation import delete_cache_items
delete_cache_items("(not file) and (not clazz)") delete_cache_items("(not file) and (not clazz)")
# Return the actual code (might instead return kernels...)
return kernel return kernel
...@@ -225,7 +226,7 @@ class AssemblyMethod(ClassMember): ...@@ -225,7 +226,7 @@ class AssemblyMethod(ClassMember):
content.append('{') content.append('{')
if kernel is not None: if kernel is not None:
for i, p in kernel.preambles: for i, p in kernel.preambles:
content.append(p) content.append(' ' + p)
content.extend(l for l in generate_body(kernel).split('\n')[1:-1]) content.extend(l for l in generate_body(kernel).split('\n')[1:-1])
content.append('}') content.append('}')
ClassMember.__init__(self, content) ClassMember.__init__(self, content)
...@@ -286,10 +287,6 @@ def generate_localoperator_kernels(formdata, namedata): ...@@ -286,10 +287,6 @@ def generate_localoperator_kernels(formdata, namedata):
# Generate the necessary residual methods # Generate the necessary residual methods
for measure in set(i.integral_type() for i in form.integrals()): for measure in set(i.integral_type() for i in form.integrals()):
with global_context(integral_type=measure): with global_context(integral_type=measure):
# Reset the outer loop
from dune.perftool.loopy.transformer import set_outerloop
set_outerloop(None)
enum_pattern() enum_pattern()
pattern_baseclass() pattern_baseclass()
enum_alpha() enum_alpha()
...@@ -321,10 +318,6 @@ def generate_localoperator_kernels(formdata, namedata): ...@@ -321,10 +318,6 @@ def generate_localoperator_kernels(formdata, namedata):
with global_context(form_type="jacobian"): with global_context(form_type="jacobian"):
for measure in set(i.integral_type() for i in jacform.integrals()): for measure in set(i.integral_type() for i in jacform.integrals()):
# Reset the outer loop
from dune.perftool.loopy.transformer import set_outerloop
set_outerloop(None)
with global_context(integral_type=measure): with global_context(integral_type=measure):
kernel = generate_kernel(jacform.integrals_by_type(measure)) kernel = generate_kernel(jacform.integrals_by_type(measure))
operator_kernels[(measure, 'jacobian')] = kernel operator_kernels[(measure, 'jacobian')] = kernel
......
from pymbolic.mapper import CombineMapper
from pymbolic.primitives import Variable
class INameMapper(CombineMapper):
    """Collect the names of the inames used as subscript indices.

    Every handler returns a frozenset of name strings; ``combine`` unites
    the sets obtained from subexpressions.
    """

    def _map_index(self, i):
        """Return the frozenset of names contributed by the single index ``i``."""
        if isinstance(i, Variable):
            return frozenset([str(i)])

        from dune.perftool.pymbolic.placeholder import IndexPlaceholder
        if isinstance(i, IndexPlaceholder):
            return frozenset([str(i.expr.name)])

        # LFS index placeholders contribute no name. Any other kind of index
        # (e.g. a plain constant) must also yield an empty set: falling
        # through and returning None would make combine() fail.
        return frozenset()

    def map_subscript(self, e):
        """Collect names from the index (or tuple of indices) of a subscript."""
        if isinstance(e.index, tuple):
            return self.combine(tuple(self._map_index(i) for i in e.index))
        return self._map_index(e.index)

    def map_constant(self, e):
        """Constants contribute no names."""
        return frozenset()

    def map_algebraic_leaf(self, e):
        """Algebraic leaves contribute no names."""
        return frozenset()

    def combine(self, values):
        """Unite the given iterable of name sets into one frozenset."""
        return frozenset().union(*values)
def get_index_inames(e, as_variables=False):
    """Return the frozenset of index names that ``INameMapper`` finds in ``e``.

    With ``as_variables=True`` each name is wrapped in a ``Variable``;
    otherwise the names are returned as produced by the mapper.
    """
    names = INameMapper()(e)
    if not as_variables:
        return names
    return frozenset(Variable(name) for name in names)
from pymbolic.mapper import Collector, IdentityMapper
from pymbolic.primitives import Variable
class IndexPlaceholderBase(object):
    """Common base class of all index placeholder types; carries no behaviour."""
class IndexPlaceholder(IndexPlaceholderBase):
    """Placeholder wrapping an index expression ``expr``.

    Two placeholders are equal if both are exactly of type
    ``IndexPlaceholder`` and wrap equal expressions; the hash is the hash
    of the wrapped expression.
    """

    def __init__(self, expr):
        self.expr = expr

    def __hash__(self):
        return hash(self.expr)

    def __eq__(self, o):
        return (type(o) == IndexPlaceholder) and (self.expr == o.expr)

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__; without this, two equal
        # placeholders would still compare as "not equal".
        return not self.__eq__(o)
class LFSIndexPlaceholder(IndexPlaceholderBase):
    """Placeholder identified by ``element``, ``restriction``, ``number`` and ``context``.

    Equality and hashing are based on exactly these four attributes.
    """

    def __init__(self, element, restriction, number, context=''):
        self.element = element
        self.restriction = restriction
        self.number = number
        self.context = context

    def __hash__(self):
        return hash((self.element, self.restriction, self.number, self.context))

    def __eq__(self, o):
        # Comparing against an unrelated object (e.g. inside a membership
        # test on a mixed container) must yield False rather than raise an
        # AttributeError on the missing attributes.
        if not isinstance(o, LFSIndexPlaceholder):
            return False
        return (
            (self.element == o.element)
            and (self.restriction == o.restriction)
            and (self.number == o.number)
            and (self.context == o.context)
        )

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(o)
class IndexPlaceholderRemoval(IdentityMapper):
    """Mapper replacing index placeholders in an expression by real indices.

    ``index_replacement_map`` persists over multiple calls and maps the
    wrapped expression of an ``IndexPlaceholder`` to its replacement.
    """

    def __init__(self):
        # Initialize base class
        super(IndexPlaceholderRemoval, self).__init__()

        # Initialize state variables that are persistent over multiple calls
        self.index_replacement_map = {}

    def __call__(self, o, wrap_in_variable=True, duplicate_inames=False):
        """Map ``o``; the two flags are stored on the mapper for this traversal.

        wrap_in_variable: wrap a resulting index in ``Variable`` unless it
            already is one.
        duplicate_inames: request the 'dup' context for LFS inames whose
            number differs from ``get_outerloop()``.
        """
        self.duplicate_inames = duplicate_inames
        self.wrap_in_variable = wrap_in_variable
        return self.rec(o)

    def map_foreign(self, o):
        """Handle constants, tuples of indices and the placeholders themselves."""
        # How do we map constants here? map_constant was not correct
        if isinstance(o, (int, float, str)):
            return o

        # There might be tuples of indices where only one is a placeholder,
        # so we recurse manually into the tuple.
        if isinstance(o, tuple):
            return tuple(self.rec(op) for op in o)

        # We only handle IndexPlaceholder instances from now on
        assert isinstance(o, IndexPlaceholderBase)

        if isinstance(o, LFSIndexPlaceholder):
            context = o.context
            from dune.perftool.loopy.transformer import get_outerloop
            if (self.duplicate_inames) and (o.number != get_outerloop()):
                context = 'dup'

            from dune.perftool.pdelab.basis import lfs_iname
            i = lfs_iname(o.element, o.restriction, count=o.number, context=context)
        elif isinstance(o, IndexPlaceholder):
            i = self.index_replacement_map.get(o.expr, o.expr)
        else:
            # Fail with a clear message instead of an UnboundLocalError on i
            raise NotImplementedError(
                "Unsupported index placeholder type: {}".format(type(o).__name__)
            )

        if self.wrap_in_variable and not isinstance(i, Variable):
            return Variable(i)
        return i

    def map_reduction(self, o):
        """Map the reduction's inner expression.

        NOTE: the given reduction object is modified in place and returned.
        """
        o.expr = self.rec(o.expr)
        return o
class LFSIndexPlaceholderExtraction(Collector):
    """Collector gathering the ``LFSIndexPlaceholder`` objects of an expression."""

    def map_foreign(self, o):
        """Return the LFS placeholders in the tuple ``o``; numbers yield an empty set."""
        if isinstance(o, (int, float)):
            return set()

        assert isinstance(o, tuple)
        return {entry for entry in o if isinstance(entry, LFSIndexPlaceholder)}

    def map_reduction(self, o):
        """Collect from the inner expression of a reduction."""
        inner = o.expr
        return self.rec(inner)
""" Implement a design pattern for a multifunction that delegates to another
multifunction and gets back control if the delegate does not define a handler.
This avoids writing isinstance-if-orgies in handlers
"""
def delegate(Delegate, *args, **kwargs):
    """Create a handler that forwards its work to a ``Delegate`` multifunction.

    ``Delegate`` must be a class whose ``expr`` fallback is still its
    ``undefined`` handler. The returned function is meant to be installed
    as a handler on another multifunction: it instantiates a subclass of
    ``Delegate`` (constructed with ``*args, **kwargs``), remembers the
    calling multifunction as ``_back`` and calls the delegate. Whenever
    the delegate has no handler of its own, its ``expr`` fallback hands
    control back to the calling multifunction.
    """
    assert isinstance(Delegate, type)
    assert Delegate.expr == Delegate.undefined

    class MyDelegate(Delegate):
        def __init__(self, *a, **ka):
            Delegate.__init__(self, *a, **ka)

        def expr(s, *a, **ka):
            # Propagate the result of the original multifunction; without
            # the return, every back-delegated call would evaluate to None.
            return s._back(*a, **ka)

    def _handler(s, *a, **ka):
        delegate_instance = MyDelegate(*args, **kwargs)
        delegate_instance._back = s
        return delegate_instance(*a, **ka)

    return _handler
...@@ -146,17 +146,3 @@ class _ModifiedArgumentExtractor(MultiFunction): ...@@ -146,17 +146,3 @@ class _ModifiedArgumentExtractor(MultiFunction):
def extract_modified_arguments(expr, **kwargs): def extract_modified_arguments(expr, **kwargs):
return _ModifiedArgumentExtractor()(expr, **kwargs) return _ModifiedArgumentExtractor()(expr, **kwargs)
class _ModifiedArgumentNumber(MultiFunction):
    """Return the number() of the argument inside a modified argument."""

    def expr(self, o):
        """Descend into the first operand of ``o``."""
        first_operand = o.ufl_operands[0]
        return self(first_operand)

    def argument(self, o):
        """Return the number of the argument ``o``."""
        return o.number()
def modified_argument_number(expr):
    """Return the number() of the argument contained in ``expr``."""
    extractor = _ModifiedArgumentNumber()
    return extractor(expr)
from __future__ import absolute_import
from ufl.algorithms import MultiFunction
class _UFLRank(MultiFunction):
    """Count the distinct argument numbers occurring in a UFL expression."""

    def __call__(self, expr):
        """Return how many entries the regular multifunction dispatch yields."""
        argument_numbers = MultiFunction.__call__(self, expr)
        return len(argument_numbers)

    def expr(self, o):
        """Unite the argument numbers found in all operands of ``o``."""
        collected = set()
        for op in o.ufl_operands:
            collected.update(MultiFunction.__call__(self, op))
        return collected

    def argument(self, o):
        """Return a one-element tuple holding the number of the argument."""
        return (o.number(),)
def ufl_rank(o):
    """Return the result of applying a fresh ``_UFLRank`` multifunction to ``o``."""
    rank_counter = _UFLRank()
    return rank_counter(o)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment