diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py index ad3c74076c814c454b7081d018b1d5ff16efad62..58768edf82f4c7e1d6358ae8090ddc9ca5cd2a1c 100644 --- a/python/dune/perftool/loopy/symbolic.py +++ b/python/dune/perftool/loopy/symbolic.py @@ -44,6 +44,25 @@ class SumfactKernel(prim.Variable): mapper_method = "map_sumfact_kernel" +class FusedMultiplyAdd(prim.Expression): + """ Represents an FMA operation """ + + init_arg_names = ("mul_op1", "mul_op2", "add_op") + + def __init__(self, mul_op1, mul_op2, add_op): + self.mul_op1 = mul_op1 + self.mul_op2 = mul_op2 + self.add_op = add_op + + def __getinitargs__(self): + return (self.mul_op1, self.mul_op2, self.add_op) + + def stringifier(self): + return StringifyMapper + + mapper_method = intern("map_fused_multiply_add") + + # # Mapper methods to monkey patch into the visitor base classes! # @@ -69,11 +88,45 @@ def needs_resolution(self, expr): raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!") +def identity_map_fused_multiply_add(self, expr, *args): + return FusedMultiplyAdd(self.rec(expr.mul_op1, *args), + self.rec(expr.mul_op2, *args), + self.rec(expr.add_op, *args), + ) + + +def walk_map_fused_multiply_add(self, expr, *args): + if not self.visit(expr): + return + + self.rec(expr.mul_op1, *args) + self.rec(expr.mul_op2, *args) + self.rec(expr.add_op, *args) + + +def stringify_map_fused_multiply_add(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_NONE + return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE), + self.rec(expr.mul_op2, PREC_NONE), + self.rec(expr.add_op, PREC_NONE)) + + +def dependency_map_fused_multiply_add(self, expr): + return self.combine((self.rec(expr.mul_op1), + self.rec(expr.mul_op2), + self.rec(expr.add_op) + )) + + +def type_inference_fused_multiply_add(self, expr): + return self.rec(expr.mul_op1) + + # # Do the actual monkey patching!!! # - +# SumfactKernel node lp.symbolic.IdentityMapper.map_sumfact_kernel = identity_map_sumfact_kernel lp.symbolic.SubstitutionMapper.map_sumfact_kernel = lp.symbolic.SubstitutionMapper.map_variable lp.symbolic.WalkMapper.map_sumfact_kernel = walk_map_sumfact_kernel @@ -82,6 +135,13 @@ lp.symbolic.DependencyMapper.map_sumfact_kernel = dependency_map_sumfact_kernel lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_sumfact_kernel = needs_resolution lp.type_inference.TypeInferenceMapper.map_sumfact_kernel = needs_resolution +# FusedMultiplyAdd node +lp.symbolic.IdentityMapper.map_fused_multiply_add = identity_map_fused_multiply_add +lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add +lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add +lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add +lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add + # # Some helper functions! diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 3bb99b44f53321086da442f94bcaf754cb1980ac..18ebddbfa7fed5903281e69530d687668900d517 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -82,6 +82,11 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ret = Literal("{}({})".format(_type, ret.s)) return ret + def map_fused_multiply_add(self, expr, type_context): + # Default implementation that discards the node in favor of the resp. + # additions and multiplications. + return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): def map_remainder(self, expr, enclosing_prec):