Skip to content
Snippets Groups Projects
Commit 15a35f91 authored by Marcel Koch's avatar Marcel Koch
Browse files

reduce code duplication

parent bee8a5f4
No related branches found
No related tags found
No related merge requests found
......@@ -32,18 +32,16 @@ class FusedMultiplyAddSubBase(prim.Expression):
def stringifier(self):
return lp.symbolic.StringifyMapper
mapper_method = intern("map_fused_multiply_add_sub")
class FusedMultiplyAdd(FusedMultiplyAddSubBase):
""" Represents an FMA operation """
mapper_method = intern("map_fused_multiply_add")
class FusedMultiplySub(FusedMultiplyAddSubBase):
""" Represents an FMS operation """
mapper_method = intern("map_fused_multiply_sub")
#
# Mapper methods to monkey patch into the visitor base classes!
......@@ -70,21 +68,20 @@ def needs_resolution(self, expr):
raise CodegenError("SumfactKernel node is a placeholder and needs to be removed!")
def identity_map_fused_multiply_add(self, expr, *args):
return FusedMultiplyAdd(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
def identity_map_fused_multiply_add_sub(self, expr, *args):
if isinstance(expr, FusedMultiplyAdd):
return FusedMultiplyAdd(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
else:
return FusedMultiplySub(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
def identity_map_fused_multiply_sub(self, expr, *args):
return FusedMultiplySub(self.rec(expr.mul_op1, *args),
self.rec(expr.mul_op2, *args),
self.rec(expr.add_op, *args),
)
def walk_map_fused_multiply_add(self, expr, *args):
def walk_map_fused_multiply_add_sub(self, expr, *args):
if not self.visit(expr):
return
......@@ -93,32 +90,30 @@ def walk_map_fused_multiply_add(self, expr, *args):
self.rec(expr.add_op, *args)
def stringify_map_fused_multiply_add(self, expr, enclosing_prec):
from pymbolic.mapper.stringifier import PREC_NONE
return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
def stringify_map_fused_multiply_sub(self, expr, enclosing_prec):
def stringify_map_fused_multiply_add_sub(self, expr, enclosing_prec):
from pymbolic.mapper.stringifier import PREC_NONE
return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
if isinstance(expr, FusedMultiplyAdd):
return "fma(%s*%s+%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
else:
return "fms(%s*%s-%s)" % (self.rec(expr.mul_op1, PREC_NONE),
self.rec(expr.mul_op2, PREC_NONE),
self.rec(expr.add_op, PREC_NONE))
def dependency_map_fused_multiply_add(self, expr):
def dependency_map_fused_multiply_add_sub(self, expr):
return self.combine((self.rec(expr.mul_op1),
self.rec(expr.mul_op2),
self.rec(expr.add_op)
))
def type_inference_fused_multiply_add(self, expr):
def type_inference_fused_multiply_add_sub(self, expr):
return self.rec(expr.mul_op1)
def vectorizability_map_fused_multiply_add(self, expr):
def vectorizability_map_fused_multiply_add_sub(self, expr):
return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op)))
......@@ -145,21 +140,13 @@ lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_vectorized_sumf
lp.type_inference.TypeInferenceMapper.map_vectorized_sumfact_kernel = needs_resolution
# FusedMultiplyAdd node
lp.symbolic.IdentityMapper.map_fused_multiply_add = identity_map_fused_multiply_add
lp.symbolic.SubstitutionMapper.map_fused_multiply_add = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add
lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add
lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
lp.symbolic.IdentityMapper.map_fused_multiply_sub = identity_map_fused_multiply_sub
lp.symbolic.SubstitutionMapper.map_fused_multiply_sub = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_sub = walk_map_fused_multiply_add
lp.symbolic.StringifyMapper.map_fused_multiply_sub = stringify_map_fused_multiply_sub
lp.symbolic.DependencyMapper.map_fused_multiply_sub = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_sub = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_sub = vectorizability_map_fused_multiply_add
lp.symbolic.IdentityMapper.map_fused_multiply_add_sub = identity_map_fused_multiply_add_sub
lp.symbolic.SubstitutionMapper.map_fused_multiply_add_sub = lp.symbolic.SubstitutionMapper.map_variable
lp.symbolic.WalkMapper.map_fused_multiply_add_sub = walk_map_fused_multiply_add_sub
lp.symbolic.StringifyMapper.map_fused_multiply_add_sub = stringify_map_fused_multiply_add_sub
lp.symbolic.DependencyMapper.map_fused_multiply_add_sub = dependency_map_fused_multiply_add_sub
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add_sub = type_inference_fused_multiply_add_sub
lp.expression.VectorizabilityChecker.map_fused_multiply_add_sub = vectorizability_map_fused_multiply_add_sub
#
# Some helper functions!
......
......@@ -102,31 +102,23 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
ret = Literal("{}({})".format(type_floatingpoint(), ret.s))
return ret
def map_fused_multiply_add(self, expr, type_context):
def map_fused_multiply_add_sub(self, expr, type_context):
from dune.codegen.loopy.symbolic import FusedMultiplyAdd
if self.codegen_state.vectorization_info:
include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
func = "mul_add" if isinstance(expr, FusedMultiplyAdd) else "mul_sub"
# If this is vectorized we call the VCL function mul_add
return prim.Call(prim.Variable("mul_add"),
return prim.Call(prim.Variable(func),
(self.rec(expr.mul_op1, type_context),
self.rec(expr.mul_op2, type_context),
self.rec(expr.add_op, type_context)))
else:
# Default implementation that discards the node in favor of the resp.
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
def map_fused_multiply_sub(self, expr, type_context):
if self.codegen_state.vectorization_info:
include_file("dune/codegen/common/muladd_workarounds.hh", filetag="operatorfile")
# If this is vectorized we call the VCL function mul_add
return prim.Call(prim.Variable("mul_sub"),
(self.rec(expr.mul_op1, type_context),
self.rec(expr.mul_op2, type_context),
self.rec(expr.add_op, type_context)))
else:
# Default implementation that discards the node in favor of the resp.
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
if isinstance(expr, FusedMultiplyAdd):
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
else:
return self.rec(expr.mul_op1 * expr.mul_op2 - expr.add_op, type_context)
def map_if(self, expr, type_context):
if self.codegen_state.vectorization_info:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment