Skip to content
Snippets Groups Projects
Commit 2999d749 authored by Marcel Koch's avatar Marcel Koch
Browse files

adds FMS node

parent 04e6f465
No related branches found
No related tags found
No related merge requests found
......@@ -16,8 +16,8 @@ from six.moves import intern
#
class FusedMultiplyAdd(prim.Expression):
""" Represents an FMA operation """
class FusedMultiplyAddSubBase(prim.Expression):
""" Base for FMA and FMS operation """
init_arg_names = ("mul_op1", "mul_op2", "add_op")
......@@ -32,9 +32,19 @@ class FusedMultiplyAdd(prim.Expression):
def stringifier(self):
    """Return the mapper class used to turn this node into a string."""
    stringify_mapper_class = lp.symbolic.StringifyMapper
    return stringify_mapper_class
class FusedMultiplyAdd(FusedMultiplyAddSubBase):
    """Expression node for a fused multiply-add (FMA) operation.

    Stands for ``mul_op1 * mul_op2 + add_op``.
    """

    # Name of the mapper method that handles nodes of this type.
    mapper_method = intern("map_fused_multiply_add")
class FusedMultiplySub(FusedMultiplyAddSubBase):
    """Expression node for a fused multiply-subtract (FMS) operation.

    Stands for ``mul_op1 * mul_op2 - add_op``.
    """

    # Name of the mapper method that handles nodes of this type.
    mapper_method = intern("map_fused_multiply_sub")
#
# Mapper methods to monkey patch into the visitor base classes!
#
......@@ -67,6 +77,13 @@ def identity_map_fused_multiply_add(self, expr, *args):
)
def identity_map_fused_multiply_sub(self, expr, *args):
    """Rebuild a FusedMultiplySub node from its recursively mapped operands.

    Each of the three operands is passed through ``self.rec`` (with any extra
    positional arguments forwarded) and a new node is built from the results.
    """
    mapped_mul_op1 = self.rec(expr.mul_op1, *args)
    mapped_mul_op2 = self.rec(expr.mul_op2, *args)
    mapped_add_op = self.rec(expr.add_op, *args)
    return FusedMultiplySub(mapped_mul_op1, mapped_mul_op2, mapped_add_op)
def walk_map_fused_multiply_add(self, expr, *args):
if not self.visit(expr):
return
......@@ -83,6 +100,13 @@ def stringify_map_fused_multiply_add(self, expr, enclosing_prec):
self.rec(expr.add_op, PREC_NONE))
def stringify_map_fused_multiply_sub(self, expr, enclosing_prec):
    """Render a FusedMultiplySub node as ``fms(<mul_op1>*<mul_op2>-<add_op>)``.

    ``enclosing_prec`` is accepted for the mapper interface but not used; all
    three operands are stringified with ``PREC_NONE``.
    """
    from pymbolic.mapper.stringifier import PREC_NONE

    operands = (expr.mul_op1, expr.mul_op2, expr.add_op)
    rendered = tuple(self.rec(operand, PREC_NONE) for operand in operands)
    return "fms(%s*%s-%s)" % rendered
def dependency_map_fused_multiply_add(self, expr):
return self.combine((self.rec(expr.mul_op1),
self.rec(expr.mul_op2),
......@@ -129,6 +153,14 @@ lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multi
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
# Install the handlers for FusedMultiplySub nodes ("map_fused_multiply_sub")
# on the mapper classes. Only the identity and stringify mappers get
# FMS-specific functions; the walk, dependency, type inference and
# vectorizability mappers reuse the FMA implementations -- presumably because
# those only look at the three operands both node types share (TODO confirm).
lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub
# NOTE(review): the substitution mapper reuses its own map_variable handler
# here, whereas the identity mapper recurses into the operands -- confirm
# that this is the intended behavior for substitutions inside an FMS node.
lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add
lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub
lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add
#
# Some helper functions!
#
......
......@@ -115,6 +115,19 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
def map_fused_multiply_sub(self, expr, type_context):
    """Map a fused multiply-subtract node.

    When ``self.codegen_state.vectorization_info`` is set, the workaround
    header is registered via ``include_file`` and the node becomes a call to
    ``mul_sub`` on the three recursively mapped operands. Otherwise the node
    is replaced by ``mul_op1 * mul_op2 - add_op`` and that expression is
    mapped instead.
    """
    if self.codegen_state.vectorization_info:
        include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
        # If this is vectorized we call the VCL function mul_sub
        return prim.Call(prim.Variable("mul_sub"),
                         (self.rec(expr.mul_op1, type_context),
                          self.rec(expr.mul_op2, type_context),
                          self.rec(expr.add_op, type_context)))
    else:
        # Default implementation that discards the node in favor of the resp.
        # subtraction and multiplication.
        return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
def map_if(self, expr, type_context):
if self.codegen_state.vectorization_info:
return prim.Call(prim.Variable("select"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment