Skip to content
Snippets Groups Projects
Commit 3fc53181 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Implement conditional handling for vector types

parent 5ea3fa03
No related branches found
No related tags found
No related merge requests found
......@@ -94,6 +94,16 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
def map_if(self, expr, type_context):
if self.codegen_state.vectorization_info:
return prim.Call(prim.Variable("select"),
(self.rec(expr.condition, type_context),
self.rec(expr.then, type_context),
self.rec(expr.else_, type_context),
))
else:
return ExpressionToCExpressionMapper.map_if(self, expr, type_context)
class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
def map_remainder(self, expr, enclosing_prec):
......
......@@ -62,7 +62,11 @@ def get_vcl_type(nptype, register_size=None, vec_size=None):
@function_mangler
def vcl_mul_add(knl, func, arg_dtypes):
def vcl_function_mangler(knl, func, arg_dtypes):
if func == "mul_add":
vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
return lp.CallMangleInfo("mul_add", (vcl), (vcl, vcl, vcl))
return lp.CallMangleInfo("mul_add", (vcl,), (vcl, vcl, vcl))
if func == "select":
vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
return lp.CallMangleInfo("select", (vcl,), (vcl, vcl, vcl))
""" A transformation that replaces expression with others from a given dictionary """
"""
A transformation that replaces expression with others from a given dictionary.
The replace algorithm from ufl.algorithms.replace is limited to terminals, we
want arbitrary expressions.
"""
from dune.perftool.ufl.transformations import ufl_transformation
from ufl.algorithms import MultiFunction
......@@ -19,6 +23,20 @@ class ReplaceExpression(MultiFunction):
else:
return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
def conditional(self, o):
"""
We need to handle this separately because we want to collapse
vanishing conditionals
"""
if o in self.replacemap:
return self.replacemap[o]
else:
ops = tuple(self(op) for op in o.ufl_operands)
if ops[1] == ops[2]:
return ops[1]
else:
return self.reuse_if_untouched(o, *tuple(self(op) for op in o.ufl_operands))
@ufl_transformation(name="replace")
def replace_expression(expr, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment