diff --git a/python/dune/perftool/blockstructured/vectorization.py b/python/dune/perftool/blockstructured/vectorization.py index d514198f5c29c4cea95997adcaf44f3be456494b..2e663ffe5acaf0fd06238830d5871083ec3adb08 100644 --- a/python/dune/perftool/blockstructured/vectorization.py +++ b/python/dune/perftool/blockstructured/vectorization.py @@ -4,6 +4,7 @@ import pymbolic.primitives as prim from loopy.match import Tagged, Id +from dune.perftool.generation import get_global_context_value from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.temporary import DuneTemporaryVariable from dune.perftool.loopy.symbolic import substitute @@ -352,9 +353,21 @@ def add_iname_array(knl, vec_iname): return knl +def replace_vcl_functions(knl, func_names): + replacemap = dict() + for name in func_names: + replacemap[name] = prim.Variable('vcl_' + name) + + new_insns = [] + for insn in knl.instructions: + new_insns.append(insn.with_transformed_expressions(lambda expr: substitute(expr, replacemap))) + + return knl.copy(instructions=new_insns) + + def vectorize_micro_elements(knl): vec_iname = "subel_x" - if vec_iname in knl.all_inames(): + if vec_iname in knl.all_inames() and get_global_context_value('integral_type') == 'cell': vcl_size = get_vcl_type_size(np.float64) assert get_form_option('number_of_blocks') % vcl_size == 0 @@ -370,4 +383,6 @@ def vectorize_micro_elements(knl): knl = add_vcl_temporaries(knl) knl = add_vcl_accum_insns(knl, vec_iname + '_inner', vec_iname + '_outer') knl = add_vcl_access(knl, vec_iname + '_inner') + + knl = replace_vcl_functions(knl, ['abs']) return knl diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 345dec931596c07c8641a03f5c0c5035f86d06b6..c0e1d57f277fee67b8779f43890e12c4284cc2d4 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -161,3 +161,11 @@ def vcl_store_and_load_mangler(knl, func, arg_dtypes): if isinstance(func, VCLStore): return lp.CallMangleInfo(func.name, (), (lp.types.NumpyType(np.int32),)) + + +@function_mangler +def vcl_math_mangler(knl, func, arg_dtypes): + if func == 'vcl_abs': + dtype = arg_dtypes[0] + vcl = lp.types.NumpyType(get_vcl_type(dtype)) + return lp.CallMangleInfo('abs', (vcl,), (vcl,)) \ No newline at end of file