Skip to content
Snippets Groups Projects
Commit 95ed172c authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Introduce explicit FMA node in dune-perftool

(instead of having the prototype in loopy)

Requires lots of monkey patches...
parent 96daed33
No related branches found
No related tags found
No related merge requests found
...@@ -44,6 +44,25 @@ class SumfactKernel(prim.Variable): ...@@ -44,6 +44,25 @@ class SumfactKernel(prim.Variable):
mapper_method = "map_sumfact_kernel" mapper_method = "map_sumfact_kernel"
class FusedMultiplyAdd(prim.Expression):
    """Expression node for a fused multiply-add.

    The node stands for ``mul_op1 * mul_op2 + add_op`` kept together as a
    single operation.

    :arg mul_op1: first factor of the multiplication
    :arg mul_op2: second factor of the multiplication
    :arg add_op: operand that is added to the product
    """

    # Constructor argument names, in the order __getinitargs__ returns them.
    init_arg_names = ("mul_op1", "mul_op2", "add_op")

    # Name of the method a mapper has to provide to handle this node.
    # A plain string is used (as for the SumfactKernel node in this module),
    # which avoids relying on ``intern`` being importable: it is not a
    # builtin on Python 3.
    mapper_method = "map_fused_multiply_add"

    def __init__(self, mul_op1, mul_op2, add_op):
        self.mul_op1 = mul_op1
        self.mul_op2 = mul_op2
        self.add_op = add_op

    def __getinitargs__(self):
        """Return the constructor arguments in ``init_arg_names`` order."""
        return (self.mul_op1, self.mul_op2, self.add_op)

    def stringifier(self):
        """Return the mapper class used to turn this node into a string.

        loopy's StringifyMapper is returned explicitly because that is the
        class this module patches ``map_fused_multiply_add`` onto; any other
        stringifier class would not know how to print this node.
        """
        return lp.symbolic.StringifyMapper
#
# Mapper methods to monkey patch into the visitor base classes!
#
...@@ -69,11 +88,45 @@ def needs_resolution(self, expr): ...@@ -69,11 +88,45 @@ def needs_resolution(self, expr):
raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!") raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!")
def identity_map_fused_multiply_add(self, expr, *args):
    """Rebuild a fused multiply-add node from its recursively mapped operands.

    Intended to be installed as ``map_fused_multiply_add`` on an identity
    mapper class.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :arg args: extra traversal arguments, forwarded unchanged to ``self.rec``
    :returns: a new node of the same type as *expr* built from the mapped
        operands
    """
    # Construct through type(expr) instead of a hard-coded class so that
    # instances of subclasses keep their type across the mapping.
    return type(expr)(
        self.rec(expr.mul_op1, *args),
        self.rec(expr.mul_op2, *args),
        self.rec(expr.add_op, *args),
    )
def walk_map_fused_multiply_add(self, expr, *args):
    """Visit a fused multiply-add node and then walk its three operands.

    Intended to be installed as ``map_fused_multiply_add`` on a walk mapper
    class.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :arg args: extra traversal arguments, forwarded to ``self.visit`` and
        ``self.rec``
    :returns: ``None``; the operands are not walked when ``self.visit``
        returns a false value
    """
    # Forward the traversal arguments to visit() as well as to rec(), so a
    # visitor that needs them sees the same context as the recursion does.
    # NOTE(review): assumes visit() accepts the extra arguments, as the
    # pymbolic WalkMapper protocol does -- confirm for the pinned version.
    if not self.visit(expr, *args):
        return

    for operand in (expr.mul_op1, expr.mul_op2, expr.add_op):
        self.rec(operand, *args)
def stringify_map_fused_multiply_add(self, expr, enclosing_prec):
    """Render a fused multiply-add node as ``fma(<op1>*<op2>+<add>)``.

    Intended to be installed as ``map_fused_multiply_add`` on a stringify
    mapper class.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :arg enclosing_prec: precedence of the surrounding expression; unused,
        because the result is always wrapped in ``fma(...)``
    :returns: the string representation of the node
    """
    from pymbolic.mapper.stringifier import PREC_PRODUCT, PREC_SUM

    # The operands are printed inside a product and a sum, so they are
    # rendered with those precedences. With PREC_NONE a compound factor
    # such as ``a + b`` came out unparenthesized ("fma(a + b*c+d)"), which
    # reads as a different expression.
    return "fma(%s*%s+%s)" % (
        self.rec(expr.mul_op1, PREC_PRODUCT),
        self.rec(expr.mul_op2, PREC_PRODUCT),
        self.rec(expr.add_op, PREC_SUM),
    )
def dependency_map_fused_multiply_add(self, expr):
    """Collect the dependencies of all three operands of a fused multiply-add.

    Intended to be installed as ``map_fused_multiply_add`` on a dependency
    mapper class.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :returns: whatever ``self.combine`` produces for the tuple of the
        per-operand results of ``self.rec``
    """
    operands = (expr.mul_op1, expr.mul_op2, expr.add_op)
    # combine() receives a tuple, in operand order.
    return self.combine(tuple(self.rec(operand) for operand in operands))
def type_inference_fused_multiply_add(self, expr):
    """Infer the type of a fused multiply-add node.

    Intended to be installed as ``map_fused_multiply_add`` on a type
    inference mapper class.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :returns: the result of ``self.rec`` for the expression
        ``mul_op1 * mul_op2 + add_op``
    """
    # Infer from the equivalent product-plus-sum so that all three operands
    # contribute to the result. Looking at mul_op1 alone gives the wrong
    # answer whenever the first factor has a narrower type than the others
    # (e.g. an integer constant multiplied with a floating point value).
    return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op)
#
# Do the actual monkey patching!!!
#
# SumfactKernel node
lp.symbolic.IdentityMapper.map_sumfact_kernel = identity_map_sumfact_kernel lp.symbolic.IdentityMapper.map_sumfact_kernel = identity_map_sumfact_kernel
lp.symbolic.SubstitutionMapper.map_sumfact_kernel = lp.symbolic.SubstitutionMapper.map_variable lp.symbolic.SubstitutionMapper.map_sumfact_kernel = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_sumfact_kernel = walk_map_sumfact_kernel lp.symbolic.WalkMapper.map_sumfact_kernel = walk_map_sumfact_kernel
...@@ -82,6 +135,13 @@ lp.symbolic.DependencyMapper.map_sumfact_kernel = dependency_map_sumfact_kernel ...@@ -82,6 +135,13 @@ lp.symbolic.DependencyMapper.map_sumfact_kernel = dependency_map_sumfact_kernel
lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_sumfact_kernel = needs_resolution lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_sumfact_kernel = needs_resolution
lp.type_inference.TypeInferenceMapper.map_sumfact_kernel = needs_resolution lp.type_inference.TypeInferenceMapper.map_sumfact_kernel = needs_resolution
# FusedMultiplyAdd node: install one handler per mapper class under the
# method name map_fused_multiply_add.
for _fma_mapper_class, _fma_handler in (
    (lp.symbolic.IdentityMapper, identity_map_fused_multiply_add),
    (lp.symbolic.WalkMapper, walk_map_fused_multiply_add),
    (lp.symbolic.StringifyMapper, stringify_map_fused_multiply_add),
    (lp.symbolic.DependencyMapper, dependency_map_fused_multiply_add),
    (lp.type_inference.TypeInferenceMapper, type_inference_fused_multiply_add),
):
    _fma_mapper_class.map_fused_multiply_add = _fma_handler
# Do not leave the loop variables behind as module attributes.
del _fma_mapper_class, _fma_handler
#
# Some helper functions!
......
...@@ -82,6 +82,11 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ...@@ -82,6 +82,11 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
ret = Literal("{}({})".format(_type, ret.s)) ret = Literal("{}({})".format(_type, ret.s))
return ret return ret
def map_fused_multiply_add(self, expr, type_context):
    """Lower a fused multiply-add node to a plain multiply followed by an add.

    This is the default implementation: the node itself is discarded and
    the expression ``mul_op1 * mul_op2 + add_op`` is mapped instead.

    :arg expr: node providing ``mul_op1``, ``mul_op2`` and ``add_op``
    :arg type_context: passed through unchanged to ``self.rec``
    :returns: the result of ``self.rec`` for the expanded expression
    """
    product = expr.mul_op1 * expr.mul_op2
    expanded = product + expr.add_op
    return self.rec(expanded, type_context)
class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
def map_remainder(self, expr, enclosing_prec): def map_remainder(self, expr, enclosing_prec):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment