diff --git a/python/dune/perftool/loopy/mangler.py b/python/dune/perftool/loopy/mangler.py
index 80a3d69a0ca92aaa4266410f24c1ebaf4306012d..a411512251df091e9d61403579a0893faf4d456d 100644
--- a/python/dune/perftool/loopy/mangler.py
+++ b/python/dune/perftool/loopy/mangler.py
@@ -6,9 +6,22 @@ from dune.perftool.generation import (function_mangler,
 
 from loopy import CallMangleInfo
 
+import numpy as np
+
+
+def add_std(name, dtype):
+    """Return the C++ spelling of the math function *name* for *dtype*.
+
+    Floating point dtypes (numpy kind "f") get the std:: qualified name;
+    every other dtype keeps the bare, unqualified name.
+    """
+    if dtype.dtype.kind == "f":
+        return "std::{}".format(name)
+    return name
+
 
 @function_mangler
 def dune_math_manglers(kernel, name, arg_dtypes):
     if name == "exp":
         include_file("dune/perftool/common/vectorclass.hh", filetag="operatorfile")
         return CallMangleInfo("exp",
@@ -21,3 +34,12 @@ def dune_math_manglers(kernel, name, arg_dtypes):
                               arg_dtypes,
                               arg_dtypes,
                               )
+
+    if name in ("max", "min"):
+        # Read the first argument's dtype only here: at the top of the
+        # mangler an argument-less call would raise an IndexError.
+        dt = arg_dtypes[0]
+        return CallMangleInfo(add_std(name, dt),
+                              (dt,),
+                              arg_dtypes,
+                              )
diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py
index 6ef27648aa9d659ab63838d84798f19eec3a91f3..f08f3fadc445bb774c6f750e521e448daddbc03f 100644
--- a/python/dune/perftool/loopy/target.py
+++ b/python/dune/perftool/loopy/target.py
@@ -108,24 +108,6 @@ class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
         else:
             return CExpressionToCodeMapper.map_remainder(expr, enclosing_prec)
 
-    def map_min(self, expr, enclosing_prec):
-        """ Max/Min is not implemented as a function!
-        TODO: Revisit this w.r.t. ADL problems
-        """
-        what = type(expr).__name__.lower()
-
-        children = list(set(expr.children))
-
-        result = self.rec(children.pop(), PREC_NONE)
-        while children:
-            result = "%s(%s, %s)" % (what,
-                                     self.rec(children.pop(), PREC_NONE),
-                                     result)
-
-        return result
-
-    map_max = map_min
-
 
 class DuneASTBuilder(CASTBuilder):
     def function_manglers(self):
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index de92f3a49f23346324a280c0e5b1565563adaf3a..f107d94be1ce4b5eeb9e6ef822bb97b3fc69970f 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -266,10 +266,25 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         raise NotImplementedError("Power function not really implemented")
 
     def max_value(self, o):
-        return prim.Max(tuple(self.call(op) for op in o.ufl_operands))
+        # NB: There is a pymbolic node Max/Min, which we are intentionally
+        # avoiding here, as the C++ nature is so much more like a function!
+        children = set(self.call(op) for op in o.ufl_operands)
+        ret = children.pop()
+        while children:
+            ret = prim.Call(prim.Variable("max"), (ret, children.pop()))
+
+        return ret
 
     def min_value(self, o):
-        return prim.Min(tuple(self.call(op) for op in o.ufl_operands))
+        # NB: There is a pymbolic node Max/Min, which we are intentionally
+        # avoiding here, as the C++ nature is so much more like a function!
+        children = set(self.call(op) for op in o.ufl_operands)
+        ret = children.pop()
+        while children:
+            ret = prim.Call(prim.Variable("min"), (ret, children.pop()))
+
+        return ret
+
 
     #
     # Handler for conditionals, use pymbolic base implementation