Skip to content
Snippets Groups Projects
Commit 15a35f91 authored by Marcel Koch's avatar Marcel Koch
Browse files

reduce code duplication

parent bee8a5f4
No related branches found
No related tags found
No related merge requests found
...@@ -32,18 +32,16 @@ class FusedMultiplyAddSubBase(prim.Expression): ...@@ -32,18 +32,16 @@ class FusedMultiplyAddSubBase(prim.Expression):
def stringifier(self): def stringifier(self):
return lp.symbolic.StringifyMapper return lp.symbolic.StringifyMapper
mapper_method = intern("map_fused_multiply_add_sub")
class FusedMultiplyAdd(FusedMultiplyAddSubBase): class FusedMultiplyAdd(FusedMultiplyAddSubBase):
""" Represents an FMA operation """ """ Represents an FMA operation """
mapper_method = intern("map_fused_multiply_add")
class FusedMultiplySub(FusedMultiplyAddSubBase): class FusedMultiplySub(FusedMultiplyAddSubBase):
""" Represents an FMS operation """ """ Represents an FMS operation """
mapper_method = intern("map_fused_multiply_sub")
# #
# Mapper methods to monkey patch into the visitor base classes! # Mapper methods to monkey patch into the visitor base classes!
...@@ -70,21 +68,20 @@ def needs_resolution(self, expr): ...@@ -70,21 +68,20 @@ def needs_resolution(self, expr):
raise CodegenError("SumfactKernel node is a placeholder and needs to be removed!") raise CodegenError("SumfactKernel node is a placeholder and needs to be removed!")
def identity_map_fused_multiply_add(self, expr, *args): def identity_map_fused_multiply_add_sub(self, expr, *args):
return FusedMultiplyAdd(self.rec(expr.mul_op1, *args), if isinstance(expr, FusedMultiplyAdd):
self.rec(expr.mul_op2, *args), return FusedMultiplyAdd(self.rec(expr.mul_op1, *args),
self.rec(expr.add_op, *args), self.rec(expr.mul_op2, *args),
) self.rec(expr.add_op, *args),
)
else:
return FusedMultiplySub(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
def identity_map_fused_multiply_sub(self, expr, *args):
return FusedMultiplySub(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
def walk_map_fused_multiply_add_sub(self, expr, *args):
def walk_map_fused_multiply_add(self, expr, *args):
if not self.visit(expr): if not self.visit(expr):
return return
...@@ -93,32 +90,30 @@ def walk_map_fused_multiply_add(self, expr, *args): ...@@ -93,32 +90,30 @@ def walk_map_fused_multiply_add(self, expr, *args):
self.rec(expr.add_op, *args) self.rec(expr.add_op, *args)
def stringify_map_fused_multiply_add(self, expr, enclosing_prec): def stringify_map_fused_multiply_add_sub(self, expr, enclosing_prec):
from pymbolic.mapper.stringifier import PREC_NONE
return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
def stringify_map_fused_multiply_sub(self, expr, enclosing_prec):
from pymbolic.mapper.stringifier import PREC_NONE from pymbolic.mapper.stringifier import PREC_NONE
return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE), if isinstance(expr, FusedMultiplyAdd):
self.rec(expr.mul_op2, PREC_NONE), return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.add_op, PREC_NONE)) self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
else:
return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
def dependency_map_fused_multiply_add(self, expr): def dependency_map_fused_multiply_add_sub(self, expr):
return self.combine((self.rec(expr.mul_op1), return self.combine((self.rec(expr.mul_op1),
self.rec(expr.mul_op2), self.rec(expr.mul_op2),
self.rec(expr.add_op) self.rec(expr.add_op)
)) ))
def type_inference_fused_multiply_add(self, expr): def type_inference_fused_multiply_add_sub(self, expr):
return self.rec(expr.mul_op1) return self.rec(expr.mul_op1)
def vectorizability_map_fused_multiply_add(self, expr): def vectorizability_map_fused_multiply_add_sub(self, expr):
return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op))) return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op)))
...@@ -145,21 +140,13 @@ lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_vectorized_sumf ...@@ -145,21 +140,13 @@ lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_vectorized_sumf
lp.type_inference.TypeInferenceMapper.map_vectorized_sumfact_kernel = needs_resolution lp.type_inference.TypeInferenceMapper.map_vectorized_sumfact_kernel = needs_resolution
# FusedMultiplyAdd node # FusedMultiplyAdd node
lp.symbolic.IdentityMapper.map_fused_multiply_add = identity_map_fused_multiply_add lp.symbolic.IdentityMapper.map_fused_multiply_add_sub = identity_map_fused_multiply_add_sub
lp.symbolic.SubstitutionMapper.map_fused_multiply_add = lp.symbolic.SubstitutionMapper.map_variable lp.symbolic.SubstitutionMapper.map_fused_multiply_add_sub = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add lp.symbolic.WalkMapper.map_fused_multiply_add_sub = walk_map_fused_multiply_add_sub
lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add lp.symbolic.StringifyMapper.map_fused_multiply_add_sub = stringify_map_fused_multiply_add_sub
lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add lp.symbolic.DependencyMapper.map_fused_multiply_add_sub = dependency_map_fused_multiply_add_sub
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add lp.type_inference.TypeInferenceMapper.map_fused_multiply_add_sub = type_inference_fused_multiply_add_sub
lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add lp.expression.VectorizabilityChecker.map_fused_multiply_add_sub = vectorizability_map_fused_multiply_add_sub
lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub
lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add
lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub
lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add
# #
# Some helper functions! # Some helper functions!
......
...@@ -102,31 +102,23 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ...@@ -102,31 +102,23 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
ret = Literal("{}({})".format(type_floatingpoint(), ret.s)) ret = Literal("{}({})".format(type_floatingpoint(), ret.s))
return ret return ret
def map_fused_multiply_add(self, expr, type_context): def map_fused_multiply_add_sub(self, expr, type_context):
from dune.codegen.loopy.symbolic import FusedMultiplyAdd
if self.codegen_state.vectorization_info: if self.codegen_state.vectorization_info:
include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile") include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
func = "mul_add" if isinstance(expr, FusedMultiplyAdd) else "mul_sub"
# If this is vectorized we call the VCL function mul_add # If this is vectorized we call the VCL function mul_add
return prim.Call(prim.Variable("mul_add"), return prim.Call(prim.Variable(func),
(self.rec(expr.mul_op1, type_context), (self.rec(expr.mul_op1, type_context),
self.rec(expr.mul_op2, type_context), self.rec(expr.mul_op2, type_context),
self.rec(expr.add_op, type_context))) self.rec(expr.add_op, type_context)))
else: else:
# Default implementation that discards the node in favor of the resp. # Default implementation that discards the node in favor of the resp.
# additions and multiplications. # additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) if isinstance(expr, FusedMultiplyAdd):
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
def map_fused_multiply_sub(self, expr, type_context): else:
if self.codegen_state.vectorization_info: return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
# If this is vectorized we call the VCL function mul_add
return prim.Call(prim.Variable("mul_sub"),
(self.rec(expr.mul_op1, type_context),
self.rec(expr.mul_op2, type_context),
self.rec(expr.add_op, type_context)))
else:
# Default implementation that discards the node in favor of the resp.
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
def map_if(self, expr, type_context): def map_if(self, expr, type_context):
if self.codegen_state.vectorization_info: if self.codegen_state.vectorization_info:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment