diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py index 58768edf82f4c7e1d6358ae8090ddc9ca5cd2a1c..131bed1c11c98bf51a12659fa507d3edc3eba778 100644 --- a/python/dune/perftool/loopy/symbolic.py +++ b/python/dune/perftool/loopy/symbolic.py @@ -58,7 +58,7 @@ class FusedMultiplyAdd(prim.Expression): return (self.mul_op1, self.mul_op2, self.add_op) def stringifier(self): - return StringifyMapper + return lp.symbolic.StringifyMapper mapper_method = intern("map_fused_multiply_add") @@ -122,6 +122,10 @@ def type_inference_fused_multiply_add(self, expr): return self.rec(expr.mul_op1) +def vectorizability_map_fused_multiply_add(self, expr): + return all((self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op))) + + # # Do the actual monkey patching!!! # @@ -141,7 +145,7 @@ lp.symbolic.WalkMapper.map_fused_multiply_add = walk_map_fused_multiply_add lp.symbolic.StringifyMapper.map_fused_multiply_add = stringify_map_fused_multiply_add lp.symbolic.DependencyMapper.map_fused_multiply_add = dependency_map_fused_multiply_add lp.type_inference.TypeInferenceMapper.map_fused_multiply_add = type_inference_fused_multiply_add - +lp.expression.VectorizabilityChecker.map_fused_multiply_add = vectorizability_map_fused_multiply_add # # Some helper functions! diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py index 18ebddbfa7fed5903281e69530d687668900d517..0d46b7e2d9c955fcb345b4f448158c732fc67c77 100644 --- a/python/dune/perftool/loopy/target.py +++ b/python/dune/perftool/loopy/target.py @@ -83,9 +83,14 @@ class DuneExpressionToCExpressionMapper(ExpressionToCExpressionMapper): return ret def map_fused_multiply_add(self, expr, type_context): - # Default implementation that discards the node in favor of the resp. - # additions and multiplications. - return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) + if self.codegen_state.vectorization_info: + # If this is vectorized we call the VCL function mul_add + return prim.Call(prim.Variable("mul_add"), + (self.rec(expr.mul_op1), self.rec(expr.mul_op2), self.rec(expr.add_op))) + else: + # Default implementation that discards the node in favor of the resp. + # additions and multiplications. + return self.rec(expr.mul_op1 * expr.mul_op2 + expr.add_op, type_context) class DuneCExpressionToCodeMapper(CExpressionToCodeMapper): diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 78dea3cfdf05c34d24b9ba872334a61d9b16d5be..2d8bf6cc7ed9161587b9f7d7442ebd254dcc719b 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -59,3 +59,13 @@ def get_vcl_type(nptype, register_size=None, vec_size=None): vec_size = register_size // (np.dtype(nptype).itemsize * 8) return VCLTypeRegistry.types[np.dtype(nptype), vec_size] + + +@function_mangler +def vcl_mul_add(knl, func, arg_dtypes): + if func == "mul_add": + # This is not 100% within the loopy philosophy, as we are + # passing the vector registers as references and have them + # changed. Loopy assumes this function to be read-only. + vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256)) + return lp.CallMangleInfo("mul_add", (vcl), (vcl, vcl, vcl)) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 4707d96f4d4bb3d83e0796b5532038ee537f29dd..e214b5239eaa628832eaf8a3c30585d2db59bff2 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -500,6 +500,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): # *REALLY* ignore boostability. This is - so far - necessary due to a mystery bug. kernel = kernel.copy(instructions=[i.copy(boostable=False, boostable_into=frozenset()) for i in kernel.instructions]) + from dune.perftool.loopy.transformations.matchfma import match_fused_multiply_add + kernel = match_fused_multiply_add(kernel) + if wrap_in_cgen: # Wrap the kernel in something which can generate code from dune.perftool.pdelab.signatures import assembly_routine_signature diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 6af1c14ddf5f11e5416da8f26a85a39556f7f457..100f3b2149d0630acaccaa6c6347bc9f2bae4849 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -136,4 +136,4 @@ class HasSumfactMapper(lp.symbolic.CombineMapper): def find_sumfact(expr): - return HasSumfactMapper()(expr) \ No newline at end of file + return HasSumfactMapper()(expr)