diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index f0e50e8f8718d4e31e5c6e5e8509fe98901e3c4d..48e441cb819a07ac1acce36a01edd8174e9d7d0a 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -94,6 +94,16 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): # additions and multiplications. return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + def map_if(self, expr, type_context): + if self.codegen_state.vectorization_info: + return prim.Call(prim.Variable("select"), + (self.rec(expr.condition, type_context), + self.rec(expr.then, type_context), + self.rec(expr.else_, type_context), + )) + else: + return ExpressionToCExpressionMapper.map_if(self, expr, type_context) + class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): def map_remainder(self, expr, enclosing_prec): diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index d994dbb32ab4b6a6a1673fb04ce0ca7b62756af4..271e547ddb4f7a49a27e0d68de6157c5c9d536f7 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -62,7 +62,11 @@ def get_vcl_type(nptype, register_size=None, vec_size=None): @function_mangler -def vcl_mul_add(knl, func, arg_dtypes): +def vcl_function_mangler(knl, func, arg_dtypes): if func == "mul_add": vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) - return lp.CallMangleInfo("mul_add", (vcl), (vcl, vcl, vcl)) + return lp.CallMangleInfo("mul_add", (vcl,), (vcl, vcl, vcl)) + + if func == "select": + vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) + return lp.CallMangleInfo("select", (vcl,), (vcl, vcl, vcl)) diff --git a/python/dune/perftool/ufl/transformations/replace.py b/python/dune/perftool/ufl/transformations/replace.py index 73f2555f41d47ff8bd236a80a640bee504067b82..79e48e8339b895a81128561511e2241f974fa1be 100644 --- a/python/dune/perftool/ufl/transformations/replace.py +++ b/python/dune/perftool/ufl/transformations/replace.py @@ -1,4 +1,8 @@ -""" A transformation that replaces expression with others from a given dictionary """ +""" +A transformation that replaces expression with others from a given dictionary. +The replace algorithm from ufl.algorithms.replace is limited to terminals, we +want arbitrary expressions. +""" from dune.perftool.ufl.transformations import ufl_transformation from ufl.algorithms import MultiFunction @@ -19,6 +23,20 @@ class ReplaceExpression(MultiFunction): else: return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands)) + def conditional(self, o): + """ + We need to handle this separately because we want to collapse + vanishing conditionals + """ + if o in self.replacemap: + return self.replacemap[o] + else: + ops = tuple(self(op) for op in o.ufl_operands) + if ops[1] == ops[2]: + return ops[1] + else: + return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands)) + @ufl_transformation(name="replace") def replace_expression(expr, **kwargs):