From 15a35f9177d55d8a4fd8dd2bed2004617f73a451 Mon Sep 17 00:00:00 2001
From: Marcel Koch <marcel.koch@uni-muenster.de>
Date: Thu, 31 Jan 2019 14:37:19 +0100
Subject: [PATCH] reduce code duplication

---
 python/dune/codegen/loopy/symbolic.py | 79 +++++++++++----------------
 python/dune/codegen/loopy/target.py   | 24 +++-----
 2 files changed, 41 insertions(+), 62 deletions(-)

diff --git a/python/dune/codegen/loopy/symbolic.py b/python/dune/codegen/loopy/symbolic.py
index 92772d6c..16ee7ca1 100644
--- a/python/dune/codegen/loopy/symbolic.py
+++ b/python/dune/codegen/loopy/symbolic.py
@@ -32,18 +32,16 @@ class FusedMultiplyAddSubBase(prim.Expression):
     def stringifier(self):
         return lp.symbolic.StringifyMapper
 
+    mapper_method = intern("map_fused_multiply_add_sub")
+
 
 class FusedMultiplyAdd(FusedMultiplyAddSubBase):
     """ Represents an FMA operation """
 
-    mapper_method = intern("map_fused_multiply_add")
-
 
 class FusedMultiplySub(FusedMultiplyAddSubBase):
     """ Represents an FMS operation """
 
-    mapper_method = intern("map_fused_multiply_sub")
-
 
 #
 # Mapper methods to monkey patch into the visitor base classes!
@@ -70,21 +68,20 @@ def needs_resolution(self, expr):
     raise CodegenError("SumfactKernel node is a placeholder and needs to be removed!")
 
 
-def identity_map_fused_multiply_add(self, expr, *args):
-    return FusedMultiplyAdd(self.rec(expr.mul_op1, *args),
-                            self.rec(expr.mul_op2, *args),
-                            self.rec(expr.add_op, *args),
-                            )
-
+def identity_map_fused_multiply_add_sub(self, expr, *args):
+    if isinstance(expr, FusedMultiplyAdd):
+        return FusedMultiplyAdd(self.rec(expr.mul_op1, *args),
+                                self.rec(expr.mul_op2, *args),
+                                self.rec(expr.add_op, *args),
+                                )
+    else:
+        return FusedMultiplySub(self.rec(expr.mul_op1, *args),
+                                self.rec(expr.mul_op2, *args),
+                                self.rec(expr.add_op, *args),
+                                )
 
-def identity_map_fused_multiply_sub(self, expr, *args):
-    return FusedMultiplySub(self.rec(expr.mul_op1, *args),
-                            self.rec(expr.mul_op2, *args),
-                            self.rec(expr.add_op, *args),
-                            )
 
-
-def walk_map_fused_multiply_add(self, expr, *args):
+def walk_map_fused_multiply_add_sub(self, expr, *args):
     if not self.visit(expr):
         return
 
@@ -93,32 +90,30 @@ def walk_map_fused_multiply_add(self, expr, *args):
     self.rec(expr.add_op, *args)
 
 
-def stringify_map_fused_multiply_add(self, expr, enclosing_prec):
-    from pymbolic.mapper.stringifier import PREC_NONE
-    return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
-                              self.rec(expr.mul_op2, PREC_NONE),
-                              self.rec(expr.add_op, PREC_NONE))
-
-
-def stringify_map_fused_multiply_sub(self, expr, enclosing_prec):
+def stringify_map_fused_multiply_add_sub(self, expr, enclosing_prec):
     from pymbolic.mapper.stringifier import PREC_NONE
-    return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE),
-                              self.rec(expr.mul_op2, PREC_NONE),
-                              self.rec(expr.add_op, PREC_NONE))
+    if isinstance(expr, FusedMultiplyAdd):
+        return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
+                                  self.rec(expr.mul_op2, PREC_NONE),
+                                  self.rec(expr.add_op, PREC_NONE))
+    else:
+        return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE),
+                                  self.rec(expr.mul_op2, PREC_NONE),
+                                  self.rec(expr.add_op, PREC_NONE))
 
 
-def dependency_map_fused_multiply_add(self, expr):
+def dependency_map_fused_multiply_add_sub(self, expr):
     return self.combine((self.rec(expr.mul_op1),
                          self.rec(expr.mul_op2),
                          self.rec(expr.add_op)
                          ))
 
 
-def type_inference_fused_multiply_add(self, expr):
+def type_inference_fused_multiply_add_sub(self, expr):
     return self.rec(expr.mul_op1)
 
 
-def vectorizability_map_fused_multiply_add(self, expr):
+def vectorizability_map_fused_multiply_add_sub(self, expr):
     return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op)))
 
 
@@ -145,21 +140,13 @@ lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_vectorized_sumf
 lp.type_inference.TypeInferenceMapper.map_vectorized_sumfact_kernel = needs_resolution
 
 # FusedMultiplyAdd node
-lp.symbolic.IdentityMapper.map_fused_multiply_add = identity_map_fused_multiply_add
-lp.symbolic.SubstitutionMapper.map_fused_multiply_add = lp.symbolic.SubstitutionMapper.map_variable
-lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add
-lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add
-lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add
-lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
-lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
-
-lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub
-lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable
-lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add
-lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub
-lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add
-lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add
-lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add
+lp.symbolic.IdentityMapper.map_fused_multiply_add_sub = identity_map_fused_multiply_add_sub
+lp.symbolic.SubstitutionMapper.map_fused_multiply_add_sub = lp.symbolic.SubstitutionMapper.map_variable
+lp.symbolic.WalkMapper.map_fused_multiply_add_sub = walk_map_fused_multiply_add_sub
+lp.symbolic.StringifyMapper.map_fused_multiply_add_sub = stringify_map_fused_multiply_add_sub
+lp.symbolic.DependencyMapper.map_fused_multiply_add_sub = dependency_map_fused_multiply_add_sub
+lp.type_inference.TypeInferenceMapper.map_fused_multiply_add_sub = type_inference_fused_multiply_add_sub
+lp.expression.VectorizabilityChecker.map_fused_multiply_add_sub = vectorizability_map_fused_multiply_add_sub
 
 #
 # Some helper functions!
diff --git a/python/dune/codegen/loopy/target.py b/python/dune/codegen/loopy/target.py
index 0d98e24a..abfbf19d 100644
--- a/python/dune/codegen/loopy/target.py
+++ b/python/dune/codegen/loopy/target.py
@@ -102,31 +102,23 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
                 ret = Literal("{}({})".format(type_floatingpoint(), ret.s))
         return ret
 
-    def map_fused_multiply_add(self, expr, type_context):
+    def map_fused_multiply_add_sub(self, expr, type_context):
+        from dune.codegen.loopy.symbolic import FusedMultiplyAdd
         if self.codegen_state.vectorization_info:
             include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
+            func = "mul_add" if isinstance(expr, FusedMultiplyAdd) else "mul_sub"
             # If this is vectorized we call the VCL function mul_add
-            return prim.Call(prim.Variable("mul_add"),
+            return prim.Call(prim.Variable(func),
                              (self.rec(expr.mul_op1, type_context),
                               self.rec(expr.mul_op2, type_context),
                               self.rec(expr.add_op, type_context)))
         else:
             # Default implementation that discards the node in favor of the resp.
             # additions and multiplications.
-            return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
-
-    def map_fused_multiply_sub(self, expr, type_context):
-        if self.codegen_state.vectorization_info:
-            include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
-            # If this is vectorized we call the VCL function mul_add
-            return prim.Call(prim.Variable("mul_sub"),
-                             (self.rec(expr.mul_op1, type_context),
-                              self.rec(expr.mul_op2, type_context),
-                              self.rec(expr.add_op, type_context)))
-        else:
-            # Default implementation that discards the node in favor of the resp.
-            # additions and multiplications.
-            return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
+            if isinstance(expr, FusedMultiplyAdd):
+                return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
+            else:
+                return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
 
     def map_if(self, expr, type_context):
         if self.codegen_state.vectorization_info:
-- 
GitLab