Skip to content
Snippets Groups Projects
Commit e93da956 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

More fixes

parent 7ddc18e1
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ def get_temporary_name():
@generator_factory(item_tags=("temporary",), cache_key_generator=lambda n, **kw: n)
def temporary_variable(name, **kwargs):
from dune.perftool.loopy.temporary import DuneTemporaryVariable
return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.LOCAL, **kwargs)
return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.PRIVATE, **kwargs)
# Now define generators for instructions. To ease dependency handling of instructions
......
......@@ -14,6 +14,8 @@ from loopy.target.c import CASTBuilder
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper, CExpressionToCodeMapper
from loopy.types import NumpyType
from pymbolic.mapper.stringifier import PREC_NONE
import pymbolic.primitives as prim
import cgen
......@@ -62,6 +64,24 @@ class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
else:
return CExpressionToCodeMapper.map_remainder(expr, enclosing_prec)
def map_min(self, expr, enclosing_prec):
""" Max/Min is not implemented as a function!
TODO: Revisit this w.r.t. ADL problems
"""
what = type(expr).__name__.lower()
children = list(expr.children)
result = self.rec(children.pop(), PREC_NONE)
while children:
result = "std::%s(%s, %s)" % (what,
self.rec(children.pop(), PREC_NONE),
result)
return result
map_max = map_min
class DuneASTBuilder(CASTBuilder):
def function_manglers(self):
......
......@@ -23,14 +23,17 @@ def add_vector_view(knl, tmpname):
assert tmpname in temporaries
temp = temporaries[tmpname]
vecname = get_vector_view_name(tmpname)
bsname = tmpname + "_base"
if vecname in knl.temporary_variables:
return knl
# Add base storage to the original temporary!
if not temp.base_storage:
temp = temp.copy(base_storage=tmpname + "_base")
temp = temp.copy(base_storage=bsname)
temporaries[tmpname] = temp
else:
bsname = temp.base_storage
# Determine the shape by dividing total size by vector size
vecsize = get_vcl_type_size(temp.dtype)
......@@ -41,7 +44,7 @@ def add_vector_view(knl, tmpname):
temporaries[vecname] = lp.TemporaryVariable(vecname,
dim_tags="c,vec",
shape=(size, vecsize),
base_storage=tmpname + "_base",
base_storage=bsname,
dtype=np.float64,
scope=lp.temp_var_scope.PRIVATE,
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment