From fb1ce62e094c8ae69f5a4f1e95addb95f865b7d8 Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Tue, 30 Jan 2018 15:35:09 +0100
Subject: [PATCH] adds weight to accumulation

---
Review notes (this section is ignored when the patch is applied):
- The leftover `from pudb import set_trace; set_trace()` breakpoint in
  residual_weight_mangler was replaced by a one-line docstring. The breakpoint
  would stop code generation on every '*.weight' call and needs pudb installed.
  The insertion count is unchanged, so hunk headers and diffstat still hold.
- The post-image blob id on the index line still names the original revision.
- Line breaks and leading whitespace were restored by hand; regenerate with
  git format-patch if a context line fails to match the target file.

 .../dune/perftool/blockstructured/accumulation.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/python/dune/perftool/blockstructured/accumulation.py b/python/dune/perftool/blockstructured/accumulation.py
index 05b2c8b6..a6024cf9 100644
--- a/python/dune/perftool/blockstructured/accumulation.py
+++ b/python/dune/perftool/blockstructured/accumulation.py
@@ -1,8 +1,11 @@
 from dune.perftool.generation import temporary_variable, instruction
+from dune.perftool.loopy.target import dtype_floatingpoint
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.argument import name_accumulation_variable
 from dune.perftool.pdelab.localoperator import boundary_predicates
+from dune.perftool.generation.loopy import function_mangler
+import loopy as lp
 import pymbolic.primitives as prim
 
 
@@ -20,6 +23,13 @@ def name_accumulation_alias(container, accumspace):
     return name
 
 
+@function_mangler
+def residual_weight_mangler(knl, func, arg_dtypes):
+    """Mangle a '*.weight' call as taking no arguments and returning one dtype_floatingpoint() value."""
+    if isinstance(func, str) and func.endswith('.weight'):
+        return lp.CallMangleInfo(func, (lp.types.NumpyType(dtype_floatingpoint()),), ())
+
+
 def generate_accumulation_instruction(expr, visitor):
     # Collect the lfs and lfs indices for the accumulate call
     test_lfs = determine_accumulation_space(visitor.test_info, 0)
@@ -39,8 +49,10 @@ def generate_accumulation_instruction(expr, visitor):
     assignee = prim.Subscript(prim.Variable(accumvar_alias),
                               tuple(prim.Variable(i) for i in lfs_inames))
 
+    expr_with_weight = prim.Product((expr, prim.Call(prim.Variable(accumvar+'.weight'),())))
+
     instruction(assignee=assignee,
-                expression=prim.Sum((expr, assignee)),
+                expression=prim.Sum((expr_with_weight, assignee)),
                 forced_iname_deps=frozenset(lfs_inames).union(frozenset(quad_inames)),
                 forced_iname_deps_is_final=True,
                 predicates=predicates
-- 
GitLab