diff --git a/python/dune/codegen/loopy/symbolic.py b/python/dune/codegen/loopy/symbolic.py index c5eef223d46ea37d85ad9ada5acfcfe90f5e5587..92772d6cd610910b45b9d239c59af80564470f7d 100644 --- a/python/dune/codegen/loopy/symbolic.py +++ b/python/dune/codegen/loopy/symbolic.py @@ -16,8 +16,8 @@ from six.moves import intern # -class FusedMultiplyAdd(prim.Expression): - """ Represents an FMA operation """ +class FusedMultiplyAddSubBase(prim.Expression): + """ Base for FMA and FMS operation """ init_arg_names = ("mul_op1", "mul_op2", "add_op") @@ -32,9 +32,19 @@ class FusedMultiplyAdd(prim.Expression): def stringifier(self): return lp.symbolic.StringifyMapper + +class FusedMultiplyAdd(FusedMultiplyAddSubBase): + """ Represents an FMA operation """ + mapper_method = intern("map_fused_multiply_add") +class FusedMultiplySub(FusedMultiplyAddSubBase): + """ Represents an FMS operation """ + + mapper_method = intern("map_fused_multiply_sub") + + # # Mapper methods to monkey patch into the visitor base classes! 
# @@ -67,6 +77,13 @@ def identity_map_fused_multiply_add(self, expr, *args): ) +def identity_map_fused_multiply_sub(self, expr, *args): + return FusedMultiplySub(self.rec(expr.mul_op1, *args), + self.rec(expr.mul_op2, *args), + self.rec(expr.add_op, *args), + ) + + def walk_map_fused_multiply_add(self, expr, *args): if not self.visit(expr): return @@ -83,6 +100,13 @@ def stringify_map_fused_multiply_add(self, expr, enclosing_prec): self.rec(expr.add_op, PREC_NONE)) +def stringify_map_fused_multiply_sub(self, expr, enclosing_prec): + from pymbolic.mapper.stringifier import PREC_NONE + return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE), + self.rec(expr.mul_op2, PREC_NONE), + self.rec(expr.add_op, PREC_NONE)) + + def dependency_map_fused_multiply_add(self, expr): return self.combine((self.rec(expr.mul_op1), self.rec(expr.mul_op2), @@ -129,6 +153,14 @@ lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multi lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add +lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub +lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable +lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add +lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub +lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add +lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add +lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add + # # Some helper functions! 
# diff --git a/python/dune/codegen/loopy/target.py b/python/dune/codegen/loopy/target.py index b8dd4d2552fb38b753be02fd6279074b4276d825..0d98e24aecc80dd17f222f0ac9d5f6e889209f12 100644 --- a/python/dune/codegen/loopy/target.py +++ b/python/dune/codegen/loopy/target.py @@ -115,6 +115,19 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): # additions and multiplications. return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + def map_fused_multiply_sub(self, expr, type_context): + if self.codegen_state.vectorization_info: + include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile") + # If this is vectorized we call the VCL function mul_sub + return prim.Call(prim.Variable("mul_sub"), + (self.rec(expr.mul_op1, type_context), + self.rec(expr.mul_op2, type_context), + self.rec(expr.add_op, type_context))) + else: + # Default implementation that discards the node in favor of the resp. + # subtractions and multiplications. + return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context) + def map_if(self, expr, type_context): if self.codegen_state.vectorization_info: return prim.Call(prim.Variable("select"),