Skip to content
Snippets Groups Projects
Commit 2999d749 authored by Marcel Koch's avatar Marcel Koch
Browse files

Adds FMS (fused multiply-subtract) node

parent 04e6f465
No related branches found
No related tags found
No related merge requests found
...@@ -16,8 +16,8 @@ from six.moves import intern ...@@ -16,8 +16,8 @@ from six.moves import intern
# #
class FusedMultiplyAdd(prim.Expression): class FusedMultiplyAddSubBase(prim.Expression):
""" Represents an FMA operation """ """ Base for FMA and FMS operation """
init_arg_names = ("mul_op1", "mul_op2", "add_op") init_arg_names = ("mul_op1", "mul_op2", "add_op")
...@@ -32,9 +32,19 @@ class FusedMultiplyAdd(prim.Expression): ...@@ -32,9 +32,19 @@ class FusedMultiplyAdd(prim.Expression):
def stringifier(self): def stringifier(self):
return lp.symbolic.StringifyMapper return lp.symbolic.StringifyMapper
class FusedMultiplyAdd(FusedMultiplyAddSubBase):
""" Represents an FMA operation """
mapper_method = intern("map_fused_multiply_add") mapper_method = intern("map_fused_multiply_add")
class FusedMultiplySub(FusedMultiplyAddSubBase):
    """Expression node for a fused multiply-subtract (FMS) operation.

    Stands for ``mul_op1 * mul_op2 - add_op``, using the operand names
    declared by the base class.
    """

    # Name of the method that mappers dispatch to for this node type.
    mapper_method = intern("map_fused_multiply_sub")
# #
# Mapper methods to monkey patch into the visitor base classes! # Mapper methods to monkey patch into the visitor base classes!
# #
...@@ -67,6 +77,13 @@ def identity_map_fused_multiply_add(self, expr, *args): ...@@ -67,6 +77,13 @@ def identity_map_fused_multiply_add(self, expr, *args):
) )
def identity_map_fused_multiply_sub(self, expr, *args):
    """Rebuild a FusedMultiplySub node from its recursively mapped operands.

    Each operand is passed through ``self.rec`` together with any extra
    positional arguments, and a new node is built from the results.
    """
    # Operands are mapped in the order mul_op1, mul_op2, add_op.
    new_mul_op1 = self.rec(expr.mul_op1, *args)
    new_mul_op2 = self.rec(expr.mul_op2, *args)
    new_add_op = self.rec(expr.add_op, *args)
    return FusedMultiplySub(new_mul_op1, new_mul_op2, new_add_op)
def walk_map_fused_multiply_add(self, expr, *args): def walk_map_fused_multiply_add(self, expr, *args):
if not self.visit(expr): if not self.visit(expr):
return return
...@@ -83,6 +100,13 @@ def stringify_map_fused_multiply_add(self, expr, enclosing_prec): ...@@ -83,6 +100,13 @@ def stringify_map_fused_multiply_add(self, expr, enclosing_prec):
self.rec(expr.add_op, PREC_NONE)) self.rec(expr.add_op, PREC_NONE))
def stringify_map_fused_multiply_sub(self, expr, enclosing_prec):
    """Render a FusedMultiplySub node as the string ``fms(a*b-c)``.

    All three operands are stringified with ``PREC_NONE``; the
    ``enclosing_prec`` argument is accepted but not used.
    """
    from pymbolic.mapper.stringifier import PREC_NONE

    factor1 = self.rec(expr.mul_op1, PREC_NONE)
    factor2 = self.rec(expr.mul_op2, PREC_NONE)
    subtrahend = self.rec(expr.add_op, PREC_NONE)
    return "fms(%s*%s-%s)" % (factor1, factor2, subtrahend)
def dependency_map_fused_multiply_add(self, expr): def dependency_map_fused_multiply_add(self, expr):
return self.combine((self.rec(expr.mul_op1), return self.combine((self.rec(expr.mul_op1),
self.rec(expr.mul_op2), self.rec(expr.mul_op2),
...@@ -129,6 +153,14 @@ lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multi ...@@ -129,6 +153,14 @@ lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multi
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
# Monkey patch the FMS (fused multiply-subtract) handlers into the loopy
# mapper classes under the method name ``map_fused_multiply_sub``.

# Identity mapping rebuilds a FusedMultiplySub node, so it needs its own function.
lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub
# Substitution reuses the mapper's own handler for variables.
# NOTE(review): presumably this mirrors the FMA registration - confirm.
lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable
# Walking reuses the FMA implementation.
lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add
# Stringification differs from FMA, so FMS gets its own function.
lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub
# The remaining mappers reuse the FMA implementations.
# NOTE(review): presumably safe because these handlers only look at the three
# operands that both node types share - confirm for type inference and
# vectorizability, whose bodies are not visible here.
lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add
# #
# Some helper functions! # Some helper functions!
# #
......
...@@ -115,6 +115,19 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ...@@ -115,6 +115,19 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
# additions and multiplications. # additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
def map_fused_multiply_sub(self, expr, type_context):
    """Translate a FusedMultiplySub node into a C expression.

    When ``self.codegen_state.vectorization_info`` is set, the node becomes
    a call to ``mul_sub`` on the three mapped operands, and the muladd
    workaround header is registered for inclusion. Otherwise the node is
    replaced by the plain expression ``mul_op1 * mul_op2 - add_op``, which
    is then mapped recursively.
    """
    if self.codegen_state.vectorization_info:
        include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
        # If this is vectorized we call the VCL function mul_sub
        return prim.Call(prim.Variable("mul_sub"),
                         (self.rec(expr.mul_op1, type_context),
                          self.rec(expr.mul_op2, type_context),
                          self.rec(expr.add_op, type_context)))
    else:
        # Default implementation that discards the node in favor of the
        # respective multiplication and subtraction.
        return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
def map_if(self, expr, type_context): def map_if(self, expr, type_context):
if self.codegen_state.vectorization_info: if self.codegen_state.vectorization_info:
return prim.Call(prim.Variable("select"), return prim.Call(prim.Variable("select"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment