Skip to content
Snippets Groups Projects
Commit 030e2232 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Handle max/min through functions

parent e7470548
No related branches found
No related tags found
No related merge requests found
......@@ -6,9 +6,17 @@ from dune.perftool.generation import (function_mangler,
from loopy import CallMangleInfo
import numpy as np
def add_std(name, dtype):
if dtype.dtype.kind == "f":
return "std::{}".format(name)
@function_mangler
def dune_math_manglers(kernel, name, arg_dtypes):
dt = arg_dtypes[0]
if name == "exp":
include_file("dune/perftool/common/vectorclass.hh", filetag="operatorfile")
return CallMangleInfo("exp",
......@@ -21,3 +29,15 @@ def dune_math_manglers(kernel, name, arg_dtypes):
arg_dtypes,
arg_dtypes,
)
if name == "max":
return CallMangleInfo(add_std("min", dt),
(dt,),
arg_dtypes,
)
if name == "min":
return CallMangleInfo(add_std("min", dt),
(dt,),
arg_dtypes,
)
......@@ -108,24 +108,6 @@ class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
else:
return CExpressionToCodeMapper.map_remainder(expr, enclosing_prec)
def map_min(self, expr, enclosing_prec):
""" Max/Min is not implemented as a function!
TODO: Revisit this w.r.t. ADL problems
"""
what = type(expr).__name__.lower()
children = list(set(expr.children))
result = self.rec(children.pop(), PREC_NONE)
while children:
result = "%s(%s, %s)" % (what,
self.rec(children.pop(), PREC_NONE),
result)
return result
map_max = map_min
class DuneASTBuilder(CASTBuilder):
def function_manglers(self):
......
......@@ -266,10 +266,25 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
raise NotImplementedError("Power function not really implemented")
def max_value(self, o):
return prim.Max(tuple(self.call(op) for op in o.ufl_operands))
# NB: There is a pymbolic node Max/Min, which we are intentionally
# avoiding here, as the C++ nature is so much more like a function!
children = set(self.call(op) for op in o.ufl_operands)
ret = children.pop()
while children:
ret = prim.Call(prim.Variable("max"), (ret, children.pop()))
return ret
def min_value(self, o):
return prim.Min(tuple(self.call(op) for op in o.ufl_operands))
# NB: There is a pymbolic node Max/Min, which we are intentionally
# avoiding here, as the C++ nature is so much more like a function!
children = set(self.call(op) for op in o.ufl_operands)
ret = children.pop()
while children:
ret = prim.Call(prim.Variable("min"), (ret, children.pop()))
return ret
#
# Handler for conditionals, use pymbolic base implementation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment