Skip to content
Snippets Groups Projects
Commit fb1ce62e authored by Marcel Koch's avatar Marcel Koch
Browse files

adds weight to accumulation

parent e8a4edec
No related branches found
No related tags found
No related merge requests found
from dune.perftool.generation import temporary_variable, instruction from dune.perftool.generation import temporary_variable, instruction
from dune.perftool.loopy.target import dtype_floatingpoint
from dune.perftool.options import get_option from dune.perftool.options import get_option
from dune.perftool.pdelab.localoperator import determine_accumulation_space from dune.perftool.pdelab.localoperator import determine_accumulation_space
from dune.perftool.pdelab.argument import name_accumulation_variable from dune.perftool.pdelab.argument import name_accumulation_variable
from dune.perftool.pdelab.localoperator import boundary_predicates from dune.perftool.pdelab.localoperator import boundary_predicates
from dune.perftool.generation.loopy import function_mangler
import loopy as lp
import pymbolic.primitives as prim import pymbolic.primitives as prim
...@@ -20,6 +23,13 @@ def name_accumulation_alias(container, accumspace): ...@@ -20,6 +23,13 @@ def name_accumulation_alias(container, accumspace):
return name return name
@function_mangler
def residual_weight_mangler(knl, func, arg_dtypes):
if isinstance(func, str) and func.endswith('.weight'):
from pudb import set_trace; set_trace()
return lp.CallMangleInfo(func, (lp.types.NumpyType(dtype_floatingpoint()),), ())
def generate_accumulation_instruction(expr, visitor): def generate_accumulation_instruction(expr, visitor):
# Collect the lfs and lfs indices for the accumulate call # Collect the lfs and lfs indices for the accumulate call
test_lfs = determine_accumulation_space(visitor.test_info, 0) test_lfs = determine_accumulation_space(visitor.test_info, 0)
...@@ -39,8 +49,10 @@ def generate_accumulation_instruction(expr, visitor): ...@@ -39,8 +49,10 @@ def generate_accumulation_instruction(expr, visitor):
assignee = prim.Subscript(prim.Variable(accumvar_alias), tuple(prim.Variable(i) for i in lfs_inames)) assignee = prim.Subscript(prim.Variable(accumvar_alias), tuple(prim.Variable(i) for i in lfs_inames))
expr_with_weight = prim.Product((expr, prim.Call(prim.Variable(accumvar+'.weight'),())))
instruction(assignee=assignee, instruction(assignee=assignee,
expression=prim.Sum((expr, assignee)), expression=prim.Sum((expr_with_weight, assignee)),
forced_iname_deps=frozenset(lfs_inames).union(frozenset(quad_inames)), forced_iname_deps=frozenset(lfs_inames).union(frozenset(quad_inames)),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
predicates=predicates predicates=predicates
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment