From 11ecb83965516fb7fe7425e4436ca9c9da3b120a Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 7 Nov 2016 12:59:00 +0100 Subject: [PATCH] Implement correct accumulation (also for jacobians!) --- python/dune/perftool/generation/loopy.py | 4 +- python/dune/perftool/loopy/flatten.py | 52 ++++++++++++++++++++++ python/dune/perftool/loopy/target.py | 6 --- python/dune/perftool/sumfact/quadrature.py | 11 ++++- python/dune/perftool/sumfact/sumfact.py | 45 ++++++++++++------- 5 files changed, 92 insertions(+), 26 deletions(-) create mode 100644 python/dune/perftool/loopy/flatten.py diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 4f237d61..451149cf 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -18,11 +18,11 @@ silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=Tr @generator_factory(item_tags=("argument", "globalarg"), cache_key_generator=lambda n, **kw: n) -def globalarg(name, shape=loopy.auto, argtype=loopy.GlobalArg, **kw): +def globalarg(name, shape=loopy.auto, **kw): if isinstance(shape, str): shape = (shape,) dtype = kw.pop("dtype", numpy.float64) - return argtype(name, dtype=dtype, shape=shape, **kw) + return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw) @generator_factory(item_tags=("argument", "constantarg"), diff --git a/python/dune/perftool/loopy/flatten.py b/python/dune/perftool/loopy/flatten.py new file mode 100644 index 00000000..5e5415d8 --- /dev/null +++ b/python/dune/perftool/loopy/flatten.py @@ -0,0 +1,52 @@ +from loopy.kernel.array import (convert_computed_to_fixed_dim_tags, + get_access_info, + parse_array_dim_tags, + ) + + +class _DummyArrayObject(object): + def __init__(self, dim_tags): + self.name = 'isthiseverused' + self.offset = None + self.dim_tags = dim_tags + + def num_target_axes(self): + return 1 + + def vector_size(self, target): + # This should call something on the target instead + return 1 + + +def flatten_index(index, shape, order="c"): + """ + A function that flattens a multiindex given the shape + of the multi dimensional array, a tuple of indices and + the specification of the axis order ("c" for row major, + "f" for column major) + + Loopy of course does this automatically in a lot of places. + This code is only meant to be used if a flat index needs + to be manually created. + """ + assert order in ("c", "f") + assert len(index) == len(shape) + + # Get a tuple of dim tags with the specified order + dim_tags = parse_array_dim_tags(",".join(order for i in index)) + + # Transform them to fixed stride tags + dim_tags = convert_computed_to_fixed_dim_tags("blubber", # Name unused + len(index), # number of user axes + 1, # number of implementation axes + shape, + dim_tags, + ) + accinfo = get_access_info(None, # the target fed into above _DummyArrayObject.vector_size + _DummyArrayObject(dim_tags), # the array duck + index, + lambda x: x, # eval_expr, semantics unclear + None, # vectorization info + ) + + return accinfo.subscripts[0] diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 2785737f..35097ac7 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -1,5 +1,4 @@ from dune.perftool.loopy.temporary import DuneTemporaryVariable -from dune.perftool.sumfact.sumfact import AccumulationArg from dune.perftool.pdelab.spaces import LFSLocalIndex from loopy.target import (TargetBase, @@ -33,11 +32,6 @@ class MyMapper(ExpressionToCExpressionMapper): for i in expr.index: ret = Subscript(ret, i) return ret - elif isinstance(arr, AccumulationArg): - pseudo_subscript = Subscript(expr.aggregate, expr.index) - flattened = ExpressionToCExpressionMapper.map_subscript(self, pseudo_subscript, type_context) - transformed = ExpressionToCExpressionMapper.map_call(self, Call(LFSLocalIndex(arr.lfs), (flattened.index,)), type_context) - return Subscript(Variable(flattened.aggregate.name + '.base()'), transformed) else: return ExpressionToCExpressionMapper.map_subscript(self, expr, type_context) diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py index ccb0c375..b59aa21f 100644 --- a/python/dune/perftool/sumfact/quadrature.py +++ b/python/dune/perftool/sumfact/quadrature.py @@ -59,6 +59,15 @@ def base_weight_function_mangler(target, func, dtypes): return CallMangleInfo(func.name, (NumpyType(numpy.float64),), ()) +def pymbolic_base_weight(): + """ This is the base weight that should be multiplied to the quadrature + weight. With the fast DG assembler this will handle the weighting of the + time discretization scheme. + TODO: Introduce backend switch that uses above BaseWeight function + """ + return 1.0 + + @iname def sumfact_quad_iname(d, context): name = "quad_{}_{}".format(context, d) @@ -92,7 +101,7 @@ def recursive_quadrature_weight(dir=0): formdata = get_global_context_value('formdata') dim = formdata.geometric_dimension if dir == dim: - return Call(BaseWeight(name_accumulation_variable()), ()) + return pymbolic_base_weight() else: name = 'weight_{}'.format(dir) define_recursive_quadrature_weight(name, dir) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 1e69bc32..fe489784 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -1,6 +1,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable, name_coefficientcontainer, pymbolic_coefficient, + PDELabAccumulationFunction, ) from dune.perftool.generation import (backend, domain, @@ -14,11 +15,13 @@ from dune.perftool.generation import (backend, temporary_variable, transform, ) +from dune.perftool.loopy.flatten import flatten_index from dune.perftool.loopy.buffer import (get_buffer_temporary, initialize_buffer, switch_base_storage, ) from dune.perftool.sumfact.quadrature import nest_quadrature_loops +from dune.perftool.pdelab.localoperator import determine_accumulation_space from dune.perftool.pdelab.spaces import name_lfs from dune.perftool.sumfact.amatrix import (AMatrix, quadrature_points_per_direction, @@ -38,10 +41,6 @@ from loopy.symbolic import FunctionIdentifier from pytools import product -class AccumulationArg(GlobalArg): - allowed_extra_kwargs = GlobalArg.allowed_extra_kwargs + ["lfs"] - - @iname def _sumfact_iname(bound, _type, count): name = "sf_{}_{}".format(_type, str(count)) @@ -111,21 +110,33 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): additional_inames=frozenset(visitor.inames), ) - # Now write all this into the correct residual - lfs = name_lfs(accterm.argument.argexpr.ufl_element(), - accterm.argument.restriction, - accterm.argument.component, - ) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) + + # Collect the lfs and lfs indices for the accumulate call + test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure) + test_lfs.index = flatten_index(tuple(Variable(i) for i in inames), + (basis_functions_per_direction(),) * dim + ) + + # In the jacobian case, also determine the space for the ansatz space + ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure) + rank = 2 if visitor.inames else 1 + if rank == 2: + ansatz_lfs.index = flatten_index(tuple(Variable(i) for i in visitor.inames), + (basis_functions_per_direction(),) * dim + ) + + # Construct the expression representing "{r,jac}.accumulate(..)" accum = name_accumulation_variable() - globalarg(accum, - shape=(basis_functions_per_direction(),) * dim, - argtype=AccumulationArg, - lfs=lfs, - ) - - instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)), - assignee=Subscript(Variable(accum), tuple(Variable(i) for i in inames)), + expr = Call(PDELabAccumulationFunction(accum, rank), + (ansatz_lfs.get_args() + + test_lfs.get_args() + + (Subscript(Variable(result), tuple(Variable(i) for i in inames)),) + ) + ) + + instruction(assignees=(), + expression=expr, forced_iname_deps=frozenset(inames + visitor.inames), forced_iname_deps_is_final=True, depends_on=insn_dep, -- GitLab