Skip to content
Snippets Groups Projects
Commit 18609949 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

FMA working!

parent 95ed172c
No related branches found
No related tags found
No related merge requests found
...@@ -58,7 +58,7 @@ class FusedMultiplyAdd(prim.Expression): ...@@ -58,7 +58,7 @@ class FusedMultiplyAdd(prim.Expression):
return (self.mul_op1, self.mul_op2, self.add_op) return (self.mul_op1, self.mul_op2, self.add_op)
def stringifier(self): def stringifier(self):
return StringifyMapper return lp.symbolic.StringifyMapper
mapper_method = intern("map_fused_multiply_add") mapper_method = intern("map_fused_multiply_add")
...@@ -122,6 +122,10 @@ def type_inference_fused_multiply_add(self, expr): ...@@ -122,6 +122,10 @@ def type_inference_fused_multiply_add(self, expr):
return self.rec(expr.mul_op1) return self.rec(expr.mul_op1)
def vectorizability_map_fused_multiply_add(self, expr):
return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op)))
# #
# Do the actual monkey patching!!! # Do the actual monkey patching!!!
# #
...@@ -141,7 +145,7 @@ lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add ...@@ -141,7 +145,7 @@ lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add
lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add
lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add
lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add
lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add
# #
# Some helper functions! # Some helper functions!
......
...@@ -83,9 +83,14 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): ...@@ -83,9 +83,14 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper):
return ret return ret
def map_fused_multiply_add(self, expr, type_context): def map_fused_multiply_add(self, expr, type_context):
# Default implementation that discards the node in favor of the resp. if self.codegen_state.vectorization_info:
# additions and multiplications. # If this is vectorized we call the VCL function mul_add
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) return prim.Call(prim.Variable("mul_add"),
(self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op)))
else:
# Default implementation that discards the node in favor of the resp.
# additions and multiplications.
return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context)
class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): class DuneCExpressionToCodeMapper(CExpressionToCodeMapper):
......
...@@ -59,3 +59,13 @@ def get_vcl_type(nptype, register_size=None, vec_size=None): ...@@ -59,3 +59,13 @@ def get_vcl_type(nptype, register_size=None, vec_size=None):
vec_size = register_size // (np.dtype(nptype).itemsize * 8) vec_size = register_size // (np.dtype(nptype).itemsize * 8)
return VCLTypeRegistry.types[np.dtype(nptype), vec_size] return VCLTypeRegistry.types[np.dtype(nptype), vec_size]
@function_mangler
def vcl_mul_add(knl, func, arg_dtypes):
if func == "mul_add":
# This is not 100% within the loopy philosophy, as we are
# passing the vector registers as references and have them
# changed. Loopy assumes this function to be read-only.
vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
return lp.CallMangleInfo("mul_add", (vcl), (vcl, vcl, vcl))
...@@ -500,6 +500,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): ...@@ -500,6 +500,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
# *REALLY* ignore boostability. This is - so far - necessary due to a mystery bug. # *REALLY* ignore boostability. This is - so far - necessary due to a mystery bug.
kernel = kernel.copy(instructions=[i.copy(boostable=False, boostable_into=frozenset()) for i in kernel.instructions]) kernel = kernel.copy(instructions=[i.copy(boostable=False, boostable_into=frozenset()) for i in kernel.instructions])
from dune.perftool.loopy.transformations.matchfma import match_fused_multiply_add
kernel = match_fused_multiply_add(kernel)
if wrap_in_cgen: if wrap_in_cgen:
# Wrap the kernel in something which can generate code # Wrap the kernel in something which can generate code
from dune.perftool.pdelab.signatures import assembly_routine_signature from dune.perftool.pdelab.signatures import assembly_routine_signature
......
...@@ -136,4 +136,4 @@ class HasSumfactMapper(lp.symbolic.CombineMapper): ...@@ -136,4 +136,4 @@ class HasSumfactMapper(lp.symbolic.CombineMapper):
def find_sumfact(expr): def find_sumfact(expr):
return HasSumfactMapper()(expr) return HasSumfactMapper()(expr)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment