From e7a7fe6db6656a5f37678311de3653d0e95287da Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 28 Nov 2016 16:11:19 +0100 Subject: [PATCH] Delay addition of sum factorization kernel until all tree visiting is done --- python/dune/perftool/__init__.py | 3 +- python/dune/perftool/generation/__init__.py | 1 + python/dune/perftool/generation/loopy.py | 1 + python/dune/perftool/loopy/symbolic.py | 85 ++++++++++++++++++++ python/dune/perftool/pdelab/localoperator.py | 12 ++- python/dune/perftool/sumfact/amatrix.py | 3 + python/dune/perftool/sumfact/basis.py | 22 ++--- python/dune/perftool/sumfact/sumfact.py | 68 +++++++++++++--- 8 files changed, 169 insertions(+), 26 deletions(-) create mode 100644 python/dune/perftool/loopy/symbolic.py diff --git a/python/dune/perftool/__init__.py b/python/dune/perftool/__init__.py index b907a800..f656fe92 100644 --- a/python/dune/perftool/__init__.py +++ b/python/dune/perftool/__init__.py @@ -1,4 +1,5 @@ -from dune.perftool.options import get_option +# Trigger imports that involve monkey patching! 
+import dune.perftool.loopy.symbolic # Trigger some imports that are needed to have all backend implementations visible # to the selection mechanisms diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 89a98fd7..93d3981f 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -29,6 +29,7 @@ from dune.perftool.generation.cpp import (base_class, ) from dune.perftool.generation.loopy import (barrier, + built_instruction, constantarg, domain, function_mangler, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 5f20ec16..a5e073cd 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -15,6 +15,7 @@ iname = generator_factory(item_tags=("iname",), context_tags="kernel") function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel") silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel") kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel") +built_instruction = generator_factory(item_tags=("instruction",), context_tags="kernel", no_deco=True) class DuneGlobalArg(lp.GlobalArg): diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py new file mode 100644 index 00000000..0c9d1f96 --- /dev/null +++ b/python/dune/perftool/loopy/symbolic.py @@ -0,0 +1,85 @@ +""" Monkey patches for loopy.symbolic + +Use this module to insert pymbolic nodes and the likes. 
+""" +from dune.perftool.error import PerftoolError +from pymbolic.mapper.substitutor import make_subst_func + +import loopy as lp +import pymbolic.primitives as prim + + +# +# Pymbolic nodes to insert into the symbolic language understood by loopy +# + + +class SumfactKernel(prim.Variable): + def __init__(self, a_matrices, buffer, insn_dep=frozenset({}), additional_inames=frozenset({})): + self.a_matrices = a_matrices + self.buffer = buffer + self.insn_dep = insn_dep + self.additional_inames = additional_inames + + prim.Variable.__init__(self, "SUMFACT") + + def __getinitargs__(self): + return (self.a_matrices, self.buffer, self.insn_dep, self.additional_inames) + + def stringifier(self): + return lp.symbolic.StringifyMapper + + init_arg_names = ("a_matrices", "buffer", "insn_dep", "additional_inames") + + mapper_method = "map_sumfact_kernel" + + +# +# Mapper methods to monkey patch into the visitor base classes! +# + + +def identity_map_sumfact_kernel(self, expr, *args): + return expr + + +def walk_map_sumfact_kernel(self, expr, *args): + self.visit(expr, *args) # forward extra mapper arguments + + +def stringify_map_sumfact_kernel(self, expr, *args): + return "SUMFACT" + + +def dependency_map_sumfact_kernel(self, expr): + return set() + + +def needs_resolution(self, expr): + raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!") + + +# +# Do the actual monkey patching!!! 
+# + + +lp.symbolic.IdentityMapper.map_sumfact_kernel = identity_map_sumfact_kernel +lp.symbolic.SubstitutionMapper.map_sumfact_kernel = lp.symbolic.SubstitutionMapper.map_variable +lp.symbolic.WalkMapper.map_sumfact_kernel = walk_map_sumfact_kernel +lp.symbolic.StringifyMapper.map_sumfact_kernel = stringify_map_sumfact_kernel +lp.symbolic.DependencyMapper.map_sumfact_kernel = dependency_map_sumfact_kernel +lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_sumfact_kernel = needs_resolution +lp.type_inference.TypeInferenceMapper.map_sumfact_kernel = needs_resolution + + +# +# Some helper functions! +# + + +def substitute(expr, replacemap): + """ A replacement for pymbolic.mapper.substitutor.substitute which is aware of all + monkey patches etc. + """ + return lp.symbolic.SubstitutionMapper(make_subst_func(replacemap))(expr) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 30454153..fb36e575 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -16,6 +16,7 @@ from dune.perftool.generation import (backend, include_file, initializer_list, post_include, + retrieve_cache_functions, retrieve_cache_items, template_parameter, ) @@ -485,9 +486,14 @@ def generate_kernel(integrals): def extract_kernel_from_cache(tag): - # Extract the information, which is needed to create a loopy kernel. - # First extracting it, might be useful to alter it before kernel generation. - from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items + # Preprocess some instruction! 
+ from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions + instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] + for insn in instructions: + expand_sumfact_kernels(insn) + filter_sumfact_instructions() + + # Now extract regular loopy kernel components from dune.perftool.loopy.target import DuneTarget domains = [i for i in retrieve_cache_items("{} and domain".format(tag))] diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 0296a06a..1c786393 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -42,6 +42,9 @@ class AMatrix(Record): cols=cols, ) + def __hash__(self): + return hash((self.a_matrix, self.rows, self.cols)) + def quadrature_points_per_direction(): # TODO use quadrature order from dune.perftool.pdelab.quadrature diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index a54eb2e8..10021fd5 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -20,7 +20,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix, quadrature_points_per_direction, ) from dune.perftool.sumfact.sumfact import (setup_theta, - sum_factorization_kernel, + SumfactKernel, sumfact_iname, ) from dune.perftool.sumfact.quadrature import quadrature_inames @@ -79,10 +79,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) # evaluation of the gradients of basis functions at quadrature # points (stage 1) insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name) - var, _ = sum_factorization_kernel(a_matrices, - buffer_name, - insn_dep=frozenset({insn_dep}), - ) + var = SumfactKernel(a_matrices, + buffer_name, + insn_dep=frozenset({insn_dep}), + ) buffers.append(var) @@ -91,7 +91,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) from pymbolic.primitives import 
Subscript, Variable from dune.perftool.generation import get_backend assignee = Subscript(Variable(name), i) - expression = Subscript(Variable(buf), tuple(Variable(i) for i in quadrature_inames())) + expression = Subscript(buf, tuple(Variable(i) for i in quadrature_inames())) instruction(assignee=assignee, expression=expression, forced_iname_deps=frozenset(get_backend("quad_inames")()), @@ -133,12 +133,12 @@ def pymbolic_trialfunction(element, restriction, component): # Add a sum factorization kernel that implements the evaluation of # the basis functions at quadrature points (stage 1) insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name) - var, _ = sum_factorization_kernel(a_matrices, - buffer_name, - insn_dep=frozenset({insn_dep}), - ) + var = SumfactKernel(a_matrices, + buffer_name, + insn_dep=frozenset({insn_dep}), + ) - return prim.Subscript(prim.Variable(var), + return prim.Subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) ) diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index cfef8155..691aa1db 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -1,7 +1,6 @@ import copy -from pymbolic.mapper.substitutor import substitute - +from dune.perftool.loopy.symbolic import substitute from dune.perftool.pdelab.argument import (name_accumulation_variable, name_coefficientcontainer, pymbolic_coefficient, @@ -9,6 +8,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable, ) from dune.perftool.generation import (backend, barrier, + built_instruction, domain, function_mangler, get_counter, @@ -37,6 +37,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix, name_theta, name_theta_transposed, ) +from dune.perftool.loopy.symbolic import SumfactKernel +from dune.perftool.error import PerftoolError from pymbolic.primitives import (Call, Product, Subscript, @@ -46,9 +48,55 @@ from dune.perftool.sumfact.quadrature 
import quadrature_inames from loopy import Reduction, GlobalArg from loopy.symbolic import FunctionIdentifier +import loopy as lp +import pymbolic.primitives as prim from pytools import product +class HasSumfactMapper(lp.symbolic.CombineMapper): + def combine(self, *args): + return frozenset().union(*tuple(*args)) + + def map_constant(self, expr): + return frozenset() + + def map_algebraic_leaf(self, expr): + return frozenset() + + def map_loopy_function_identifier(self, expr): + return frozenset() + + def map_sumfact_kernel(self, expr): + return frozenset({expr}) + + +def find_sumfact(expr): + return HasSumfactMapper()(expr) + + +def expand_sumfact_kernels(insn): + if isinstance(insn, (lp.Assignment, lp.CallInstruction)): + replace = {} + deps = [] + for sumf in find_sumfact(insn.expression): + var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, sumf.insn_dep, sumf.additional_inames) + replace[sumf] = prim.Variable(var) + deps.append(dep) + + if replace: + built_instruction(insn.copy(expression=substitute(insn.expression, replace), + depends_on=frozenset().union(*deps) # merge deps of all expanded kernels + ) + ) + + +def filter_sumfact_instructions(): + """ Remove all instructions that contain a SumfactKernel node """ + from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl + expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} + call_instruction_impl._memoize_cache = {k: v for k, v in call_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} + + @iname def _sumfact_iname(bound, _type, count): name = "sf_{}_{}".format(_type, str(count)) @@ -146,10 +194,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # Replace gradient iname with correct index for assignement replace_dict = {} - expression = copy.deepcopy(pymbolic_expr) for iname in additional_inames: replace_dict[Variable(iname)] = i - expression = substitute(expression, 
replace_dict) + expression = substitute(pymbolic_expr, replace_dict) # Issue an instruction in the quadrature loop that fills the buffer # with the evaluation of the contribution at all quadrature points @@ -164,11 +211,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # Add a sum factorization kernel that implements the multiplication # with the test function (stage 3) - result, insn_dep = sum_factorization_kernel(a_matrices, - buf, - insn_dep=frozenset({contrib_dep}), - additional_inames=frozenset(visitor.inames), - ) + result = SumfactKernel(a_matrices, + buf, + insn_dep=frozenset({contrib_dep}), + additional_inames=frozenset(visitor.inames), + ) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) @@ -193,7 +240,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): expr = Call(PDELabAccumulationFunction(accum, rank), (ansatz_lfs.get_args() + test_lfs.get_args() + - (Subscript(Variable(result), tuple(Variable(i) for i in inames)),) + (Subscript(result, tuple(Variable(i) for i in inames)),) ) ) @@ -201,7 +248,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): expression=expr, forced_iname_deps=frozenset(inames + visitor.inames), forced_iname_deps_is_final=True, - depends_on=insn_dep, ) # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application -- GitLab