From 77c56f647c05bdf0d1718f9995761fe6cbb62643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Thu, 7 Sep 2017 11:22:43 +0200 Subject: [PATCH] Do not generate code for stage 1 sumfact kernels that don't get used Save all stage 1 sum factorization kernels that are used in accumulation expressions in the cache during the dry run. Discard all inactive sum factorization kernels in decide_vectorization_strategy. --- python/dune/perftool/sumfact/accumulation.py | 23 +++++++++++++ python/dune/perftool/sumfact/basis.py | 10 ++++++ python/dune/perftool/sumfact/realization.py | 1 - python/dune/perftool/sumfact/vectorization.py | 33 ++++++++++++++----- 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index 50f4d169..c9c32b48 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -6,6 +6,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable, from dune.perftool.generation import (backend, domain, dump_accumulate_timer, + generator_factory, get_counted_variable, get_counter, iname, @@ -40,10 +41,29 @@ from pytools import ImmutableRecord import loopy as lp import numpy as np import pymbolic.primitives as prim +from pymbolic.mapper import WalkMapper import ufl.classes as uc from ufl import FiniteElement, MixedElement, TensorProductElement +basis_sf_kernels = generator_factory(item_tags=("basis_sf_kernels",), context_tags='kernel', no_deco=True) + + +class SumfactCollectMapper(WalkMapper): + def map_tagged_variable(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_variable(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + def map_sumfact_kernel(self, expr, *args, **kwargs): + basis_sf_kernels(expr) + self.visit(expr, 
*args, **kwargs) + self.post_visit(expr, *args, **kwargs) + + @iname def _sumfact_iname(bound, _type, count): name = "sf_{}_{}".format(_type, str(count)) @@ -199,6 +219,9 @@ def generate_accumulation_instruction(expr, visitor): test_info = visitor.test_info trial_info = visitor.trial_info + # Cache all stage 1 sum factorization kernels used in this expression + SumfactCollectMapper()(expr) + # Number of basis functions per direction leaf_element = test_info.element from ufl import MixedElement diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 8e745c7d..00ef2966 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -146,6 +146,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit from dune.perftool.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + # If this sum factorization kernel was not used in the dry run we + # just return 0 + if vsf == 0: + return 0, None + from dune.perftool.sumfact.realization import realize_sum_factorization_kernel var, insn_dep = realize_sum_factorization_kernel(vsf) @@ -182,6 +187,11 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice from dune.perftool.sumfact.vectorization import attach_vectorization_info vsf = attach_vectorization_info(sf) + # If this sum factorization kernel was not used in the dry run we + # just return 0 + if vsf == 0: + return 0, None + # Add a sum factorization kernel that implements the evaluation of # the basis functions at quadrature points (stage 1) from dune.perftool.sumfact.realization import realize_sum_factorization_kernel diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 3f1afe9c..5c9e4a5b 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -24,7 +24,6 @@ from dune.perftool.sumfact.permutation 
import (sumfact_permutation_strategy, permute_backward, permute_forward, ) -from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.vcl import ExplicitVCLCast diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 08e07782..3a044794 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -140,37 +140,52 @@ def decide_vectorization_strategy(): """ logger = logging.getLogger(__name__) + # Retrieve all sum factorization kernels for stage 1 and 3 from dune.perftool.generation import retrieve_cache_items - sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] + all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] + + # Stage 1 sumfactorizations that were actually used + basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')] + + # This means we can have sum factorizations that will not get used + inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts] + + # All sum factorization kernels that get used + active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts] + + # We map inactive sum factorization kernels to 0 sfdict = {} + for sf in inactive_sumfacts: + sfdict[sf] = 0 - logger.debug("decide_vectorization_strategy: Found {} sum factorization nodes".format(len(sumfacts))) + logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes" + .format(len(active_sumfacts))) if get_option("vectorize_grads"): # Currently we base our idea here on the fact that we only group sum # factorization kernels with the same input. 
- inputkeys = set(sf.input_key for sf in sumfacts) + inputkeys = set(sf.input_key for sf in active_sumfacts) for inputkey in inputkeys: width = get_vcl_type_size(np.float64) - sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey] + sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey] for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items(): sfdict[old] = new elif get_option("vectorize_slice"): - for sumfact in sumfacts: + for sumfact in active_sumfacts: width = get_vcl_type_size(np.float64) for old, new in vertical_vectorization_strategy(sumfact, width).items(): sfdict[old] = new elif get_option("vectorize_diagonal"): - inputkeys = set(sf.input_key for sf in sumfacts) + inputkeys = set(sf.input_key for sf in active_sumfacts) for inputkey in inputkeys: width = get_vcl_type_size(np.float64) - sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey] + sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey] for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items(): sfdict[old] = new else: - for old, new in no_vectorization(sumfacts).items(): + for old, new in no_vectorization(active_sumfacts).items(): sfdict[old] = new # Register the results - for sf in sumfacts: + for sf in all_sumfacts: _cache_vectorization_info(sf, sfdict.get(sf, no_vec(sf))) -- GitLab