From fb1ce62e094c8ae69f5a4f1e95addb95f865b7d8 Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Tue, 30 Jan 2018 15:35:09 +0100
Subject: [PATCH] adds weight to accumulation

---
 .../dune/perftool/blockstructured/accumulation.py  | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/python/dune/perftool/blockstructured/accumulation.py b/python/dune/perftool/blockstructured/accumulation.py
index 05b2c8b6..a6024cf9 100644
--- a/python/dune/perftool/blockstructured/accumulation.py
+++ b/python/dune/perftool/blockstructured/accumulation.py
@@ -1,8 +1,11 @@
 from dune.perftool.generation import temporary_variable, instruction
+from dune.perftool.loopy.target import dtype_floatingpoint
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.argument import name_accumulation_variable
 from dune.perftool.pdelab.localoperator import boundary_predicates
+from dune.perftool.generation.loopy import function_mangler
+import loopy as lp
 import pymbolic.primitives as prim
 
 
@@ -20,6 +23,13 @@ def name_accumulation_alias(container, accumspace):
     return name
 
 
+@function_mangler
+def residual_weight_mangler(knl, func, arg_dtypes):
+    if isinstance(func, str) and func.endswith('.weight'):
+        from pudb import set_trace; set_trace()
+        return lp.CallMangleInfo(func, (lp.types.NumpyType(dtype_floatingpoint()),), ())
+
+
 def generate_accumulation_instruction(expr, visitor):
     # Collect the lfs and lfs indices for the accumulate call
     test_lfs = determine_accumulation_space(visitor.test_info, 0)
@@ -39,8 +49,10 @@ def generate_accumulation_instruction(expr, visitor):
 
     assignee = prim.Subscript(prim.Variable(accumvar_alias), tuple(prim.Variable(i) for i in lfs_inames))
 
+    expr_with_weight = prim.Product((expr, prim.Call(prim.Variable(accumvar+'.weight'),())))
+
     instruction(assignee=assignee,
-                expression=prim.Sum((expr, assignee)),
+                expression=prim.Sum((expr_with_weight, assignee)),
                 forced_iname_deps=frozenset(lfs_inames).union(frozenset(quad_inames)),
                 forced_iname_deps_is_final=True,
                 predicates=predicates
-- 
GitLab