From 15a35f9177d55d8a4fd8dd2bed2004617f73a451 Mon Sep 17 00:00:00 2001 From: Marcel Koch <marcel.koch@uni-muenster.de> Date: Thu, 31 Jan 2019 14:37:19 +0100 Subject: [PATCH] reduce code duplication --- python/dune/codegen/loopy/symbolic.py | 79 +++++++++++---------------- python/dune/codegen/loopy/target.py | 24 +++----- 2 files changed, 41 insertions(+), 62 deletions(-) diff --git a/python/dune/codegen/loopy/symbolic.py b/python/dune/codegen/loopy/symbolic.py index 92772d6c..16ee7ca1 100644 --- a/python/dune/codegen/loopy/symbolic.py +++ b/python/dune/codegen/loopy/symbolic.py @@ -32,18 +32,16 @@ class FusedMultiplyAddSubBase(prim.Expression): def stringifier(self): return lp.symbolic.StringifyMapper + mapper_method = intern("map_fused_multiply_add_sub") + class FusedMultiplyAdd(FusedMultiplyAddSubBase): """ Represents an FMA operation """ - mapper_method = intern("map_fused_multiply_add") - class FusedMultiplySub(FusedMultiplyAddSubBase): """ Represents an FMS operation """ - mapper_method = intern("map_fused_multiply_sub") - # # Mapper methods to monkey patch into the visitor base classes! @@ -70,21 +68,20 @@ def needs_resolution(self, expr): raise CodegenError("SumfactKernel node is a placeholder and needs to be removed!") -def identity_map_fused_multiply_add(self, expr, *args): - return FusedMultiplyAdd(self.rec(expr.mul_op1, *args), - self.rec(expr.mul_op2, *args), - self.rec(expr.add_op, *args), - ) - +def identity_map_fused_multiply_add_sub(self, expr, *args): + if isinstance(expr, FusedMultiplyAdd): + return FusedMultiplyAdd(self.rec(expr.mul_op1, *args), + self.rec(expr.mul_op2, *args), + self.rec(expr.add_op, *args), + ) + else: + return FusedMultiplySub(self.rec(expr.mul_op1, *args), + self.rec(expr.mul_op2, *args), + self.rec(expr.add_op, *args), + ) -def identity_map_fused_multiply_sub(self, expr, *args): - return FusedMultiplySub(self.rec(expr.mul_op1, *args), - self.rec(expr.mul_op2, *args), - self.rec(expr.add_op, *args), - ) - -def walk_map_fused_multiply_add(self, expr, *args): +def walk_map_fused_multiply_add_sub(self, expr, *args): if not self.visit(expr): return @@ -93,32 +90,30 @@ def walk_map_fused_multiply_add(self, expr, *args): self.rec(expr.add_op, *args) -def stringify_map_fused_multiply_add(self, expr, enclosing_prec): - from pymbolic.mapper.stringifier import PREC_NONE - return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE), - self.rec(expr.mul_op2, PREC_NONE), - self.rec(expr.add_op, PREC_NONE)) - - -def stringify_map_fused_multiply_sub(self, expr, enclosing_prec): +def stringify_map_fused_multiply_add_sub(self, expr, enclosing_prec): from pymbolic.mapper.stringifier import PREC_NONE - return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE), - self.rec(expr.mul_op2, PREC_NONE), - self.rec(expr.add_op, PREC_NONE)) + if isinstance(expr, FusedMultiplyAdd): + return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE), + self.rec(expr.mul_op2, PREC_NONE), + self.rec(expr.add_op, PREC_NONE)) + else: + return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE), + self.rec(expr.mul_op2, PREC_NONE), + self.rec(expr.add_op, PREC_NONE)) -def dependency_map_fused_multiply_add(self, expr): +def dependency_map_fused_multiply_add_sub(self, expr): return self.combine((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op) )) -def type_inference_fused_multiply_add(self, expr): +def type_inference_fused_multiply_add_sub(self, expr): return self.rec(expr.mul_op1) -def vectorizability_map_fused_multiply_add(self, expr): +def vectorizability_map_fused_multiply_add_sub(self, expr): return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op))) @@ -145,21 +140,13 @@ lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_vectorized_sumf lp.type_inference.TypeInferenceMapper.map_vectorized_sumfact_kernel = needs_resolution # FusedMultiplyAdd node -lp.symbolic.IdentityMapper.map_fused_multiply_add = identity_map_fused_multiply_add -lp.symbolic.SubstitutionMapper.map_fused_multiply_add = lp.symbolic.SubstitutionMapper.map_variable -lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add -lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add -lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add -lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add -lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add - -lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub -lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable -lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add -lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub -lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add -lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add -lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add +lp.symbolic.IdentityMapper.map_fused_multiply_add_sub = identity_map_fused_multiply_add_sub +lp.symbolic.SubstitutionMapper.map_fused_multiply_add_sub = lp.symbolic.SubstitutionMapper.map_variable +lp.symbolic.WalkMapper.map_fused_multiply_add_sub = walk_map_fused_multiply_add_sub +lp.symbolic.StringifyMapper.map_fused_multiply_add_sub = stringify_map_fused_multiply_add_sub +lp.symbolic.DependencyMapper.map_fused_multiply_add_sub = dependency_map_fused_multiply_add_sub +lp.type_inference.TypeInferenceMapper.map_fused_multiply_add_sub = type_inference_fused_multiply_add_sub +lp.expression.VectorizabilityChecker.map_fused_multiply_add_sub = vectorizability_map_fused_multiply_add_sub # # Some helper functions! diff --git a/python/dune/codegen/loopy/target.py b/python/dune/codegen/loopy/target.py index 0d98e24a..abfbf19d 100644 --- a/python/dune/codegen/loopy/target.py +++ b/python/dune/codegen/loopy/target.py @@ -102,31 +102,23 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ret = Literal("{}({})".format(type_floatingpoint(), ret.s)) return ret - def map_fused_multiply_add(self, expr, type_context): + def map_fused_multiply_add_sub(self, expr, type_context): + from dune.codegen.loopy.symbolic import FusedMultiplyAdd if self.codegen_state.vectorization_info: include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile") + func = "mul_add" if isinstance(expr, FusedMultiplyAdd) else "mul_sub" # If this is vectorized we call the VCL function mul_add - return prim.Call(prim.Variable("mul_add"), + return prim.Call(prim.Variable(func), (self.rec(expr.mul_op1, type_context), self.rec(expr.mul_op2, type_context), self.rec(expr.add_op, type_context))) else: # Default implementation that discards the node in favor of the resp. # additions and multiplications. - return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) - - def map_fused_multiply_sub(self, expr, type_context): - if self.codegen_state.vectorization_info: - include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile") - # If this is vectorized we call the VCL function mul_add - return prim.Call(prim.Variable("mul_sub"), - (self.rec(expr.mul_op1, type_context), - self.rec(expr.mul_op2, type_context), - self.rec(expr.add_op, type_context))) - else: - # Default implementation that discards the node in favor of the resp. - # additions and multiplications. - return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context) + if isinstance(expr, FusedMultiplyAdd): + return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + else: + return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context) def map_if(self, expr, type_context): if self.codegen_state.vectorization_info: -- GitLab