From 3fc5318116fb4cefef3699c3678142f26758916b Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 7 Feb 2017 17:14:28 +0100 Subject: [PATCH] Implement conditional handling for vector types --- python/dune/perftool/loopy/target.py | 10 ++++++++++ python/dune/perftool/loopy/vcl.py | 8 ++++++-- .../perftool/ufl/transformations/replace.py | 20 ++++++++++++++++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index f0e50e8f..48e441cb 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -94,6 +94,16 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): # additions and multiplications. return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + def map_if(self, expr, type_context): + if self.codegen_state.vectorization_info: + return prim.Call(prim.Variable("select"), + (self.rec(expr.condition, type_context), + self.rec(expr.then, type_context), + self.rec(expr.else_, type_context), + )) + else: + return ExpressionToCExpressionMapper.map_if(self, expr, type_context) + class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): def map_remainder(self, expr, enclosing_prec): diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index d994dbb3..271e547d 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -62,7 +62,11 @@ def get_vcl_type(nptype, register_size=None, vec_size=None): @function_mangler -def vcl_mul_add(knl, func, arg_dtypes): +def vcl_function_mangler(knl, func, arg_dtypes): if func == "mul_add": vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) - return lp.CallMangleInfo("mul_add", (vcl), (vcl, vcl, vcl)) + return lp.CallMangleInfo("mul_add", (vcl,), (vcl, vcl, vcl)) + + if func == "select": + vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) + return lp.CallMangleInfo("select", (vcl,), (vcl, vcl, vcl)) diff --git a/python/dune/perftool/ufl/transformations/replace.py b/python/dune/perftool/ufl/transformations/replace.py index 73f2555f..79e48e83 100644 --- a/python/dune/perftool/ufl/transformations/replace.py +++ b/python/dune/perftool/ufl/transformations/replace.py @@ -1,4 +1,8 @@ -""" A transformation that replaces expression with others from a given dictionary """ +""" +A transformation that replaces expression with others from a given dictionary. +The replace algorithm from ufl.algorithms.replace is limited to terminals, we +want arbitrary expressions. +""" from dune.perftool.ufl.transformations import ufl_transformation from ufl.algorithms import MultiFunction @@ -19,6 +23,20 @@ class ReplaceExpression(MultiFunction): else: return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands)) + def conditional(self, o): + """ + We need to handle this separately because we want to collapse + vanishing conditionals + """ + if o in self.replacemap: + return self.replacemap[o] + else: + ops = tuple(self(op) for op in o.ufl_operands) + if ops[1] == ops[2]: + return ops[1] + else: + return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands)) + @ufl_transformation(name="replace") def replace_expression(expr, **kwargs): -- GitLab