Skip to content
Snippets Groups Projects
Commit e7a7fe6d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Delay addition of sum factorization kernel until all tree visiting is done

parent 3c090707
No related branches found
No related tags found
No related merge requests found
from dune.perftool.options import get_option # Trigger imports that involve monkey patching!
import dune.perftool.loopy.symbolic
# Trigger some imports that are needed to have all backend implementations visible # Trigger some imports that are needed to have all backend implementations visible
# to the selection mechanisms # to the selection mechanisms
......
...@@ -29,6 +29,7 @@ from dune.perftool.generation.cpp import (base_class, ...@@ -29,6 +29,7 @@ from dune.perftool.generation.cpp import (base_class,
) )
from dune.perftool.generation.loopy import (barrier, from dune.perftool.generation.loopy import (barrier,
built_instruction,
constantarg, constantarg,
domain, domain,
function_mangler, function_mangler,
......
...@@ -15,6 +15,7 @@ iname = generator_factory(item_tags=("iname",), context_tags="kernel") ...@@ -15,6 +15,7 @@ iname = generator_factory(item_tags=("iname",), context_tags="kernel")
function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel") function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel")
silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel") silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel")
kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel") kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel")
built_instruction = generator_factory(item_tags=("instruction",), context_tags="kernel", no_deco=True)
class DuneGlobalArg(lp.GlobalArg): class DuneGlobalArg(lp.GlobalArg):
......
""" Monkey patches for loopy.symbolic
Use this module to insert pymbolic nodes and the like.
"""
from dune.perftool.error import PerftoolError
from pymbolic.mapper.substitutor import make_subst_func
import loopy as lp
import pymbolic.primitives as prim
#
# Pymbolic nodes to insert into the symbolic language understood by loopy
#
class SumfactKernel(prim.Variable):
    """Pymbolic node that stands in for a sum factorization kernel.

    The node records the arguments it was built with (``a_matrices``,
    ``buffer``, ``insn_dep`` and ``additional_inames``) and presents itself
    as a variable named ``"SUMFACT"``.
    """

    # Name of the mapper method that mappers dispatch to for this node
    mapper_method = "map_sumfact_kernel"

    # Names of the constructor arguments, in the order of __getinitargs__
    init_arg_names = ("a_matrices", "buffer", "insn_dep", "additional_inames")

    def __init__(self, a_matrices, buffer, insn_dep=frozenset(), additional_inames=frozenset()):
        self.a_matrices = a_matrices
        self.buffer = buffer
        self.insn_dep = insn_dep
        self.additional_inames = additional_inames

        prim.Variable.__init__(self, "SUMFACT")

    def __getinitargs__(self):
        """Return the constructor arguments as a tuple."""
        return tuple(getattr(self, name) for name in self.init_arg_names)

    def stringifier(self):
        """Return the mapper class used to turn this node into a string."""
        return lp.symbolic.StringifyMapper
#
# Mapper methods to monkey patch into the visitor base classes!
#
def identity_map_sumfact_kernel(self, expr, *args):
    """Identity mapper handler: hand the node back untouched.

    Any additional positional arguments are accepted and ignored.
    """
    unchanged = expr
    return unchanged
def walk_map_sumfact_kernel(self, expr, *args, **kwargs):
    """Walk mapper handler: visit the node itself, it has no children.

    All extra positional and keyword arguments are forwarded to
    ``self.visit`` so that walkers whose ``visit`` takes additional
    arguments keep working. The function returns ``None``.
    """
    # Previously the extra arguments were accepted but silently dropped.
    self.visit(expr, *args, **kwargs)
def stringify_map_sumfact_kernel(self, expr, *args):
    """Stringify mapper handler: every node is rendered as ``SUMFACT``.

    Additional positional arguments are accepted and ignored.
    """
    rendered = "SUMFACT"
    return rendered
def dependency_map_sumfact_kernel(self, expr, *args, **kwargs):
    """Dependency mapper handler: the node contributes no dependencies.

    Extra positional and keyword arguments are accepted and ignored, in line
    with the other handlers in this module, so that a mapper invoking its
    methods with additional arguments does not fail with a ``TypeError``.

    :returns: a new, empty ``set``
    """
    return set()
def needs_resolution(self, expr, *args, **kwargs):
    """Handler for mappers that must never see a SumfactKernel node.

    Extra positional and keyword arguments are accepted and ignored, so the
    descriptive error below is raised regardless of how the mapper calls its
    handlers.
    NOTE(review): loopy's C expression mapper appears to pass additional
    context arguments (e.g. a type context) to its map methods; with the
    former ``(self, expr)`` signature that would surface as a ``TypeError``
    instead of this error - confirm against the loopy version in use.

    :raises PerftoolError: always
    """
    raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!")
#
# Do the actual monkey patching!!!
#
# Attach a ``map_sumfact_kernel`` handler to each mapper class listed below.
for _mapper_class, _handler in (
        (lp.symbolic.IdentityMapper, identity_map_sumfact_kernel),
        # Substitution handles the node exactly like a plain variable
        (lp.symbolic.SubstitutionMapper, lp.symbolic.SubstitutionMapper.map_variable),
        (lp.symbolic.WalkMapper, walk_map_sumfact_kernel),
        (lp.symbolic.StringifyMapper, stringify_map_sumfact_kernel),
        (lp.symbolic.DependencyMapper, dependency_map_sumfact_kernel),
        # These two raise, since the placeholder has to be gone by then
        (lp.target.c.codegen.expression.ExpressionToCExpressionMapper, needs_resolution),
        (lp.type_inference.TypeInferenceMapper, needs_resolution),
):
    _mapper_class.map_sumfact_kernel = _handler
# Do not leave the loop variables behind in the module namespace
del _mapper_class, _handler
#
# Some helper functions!
#
def substitute(expr, replacemap):
    """ Drop-in for pymbolic.mapper.substitutor.substitute that goes through
    loopy's SubstitutionMapper and therefore sees the monkey patches applied
    in this module.

    :arg expr: the expression to apply the substitution to
    :arg replacemap: the replacement mapping handed to ``make_subst_func``
    :returns: the result of applying the substitution mapper to ``expr``
    """
    subst_func = make_subst_func(replacemap)
    mapper = lp.symbolic.SubstitutionMapper(subst_func)
    return mapper(expr)
...@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend, ...@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend,
include_file, include_file,
initializer_list, initializer_list,
post_include, post_include,
retrieve_cache_functions,
retrieve_cache_items, retrieve_cache_items,
template_parameter, template_parameter,
) )
...@@ -485,9 +486,14 @@ def generate_kernel(integrals): ...@@ -485,9 +486,14 @@ def generate_kernel(integrals):
def extract_kernel_from_cache(tag): def extract_kernel_from_cache(tag):
# Extract the information, which is needed to create a loopy kernel. # Preprocess some instruction!
# First extracting it, might be useful to alter it before kernel generation. from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions
from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
for insn in instructions:
expand_sumfact_kernels(insn)
filter_sumfact_instructions()
# Now extract regular loopy kernel components
from dune.perftool.loopy.target import DuneTarget from dune.perftool.loopy.target import DuneTarget
domains = [i for i in retrieve_cache_items("{} and domain".format(tag))] domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
......
...@@ -42,6 +42,9 @@ class AMatrix(Record): ...@@ -42,6 +42,9 @@ class AMatrix(Record):
cols=cols, cols=cols,
) )
def __hash__(self):
    # The hash is derived from the a_matrix, rows and cols fields only
    key = (self.a_matrix, self.rows, self.cols)
    return hash(key)
def quadrature_points_per_direction(): def quadrature_points_per_direction():
# TODO use quadrature order from dune.perftool.pdelab.quadrature # TODO use quadrature order from dune.perftool.pdelab.quadrature
......
...@@ -20,7 +20,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix, ...@@ -20,7 +20,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
quadrature_points_per_direction, quadrature_points_per_direction,
) )
from dune.perftool.sumfact.sumfact import (setup_theta, from dune.perftool.sumfact.sumfact import (setup_theta,
sum_factorization_kernel, SumfactKernel,
sumfact_iname, sumfact_iname,
) )
from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.quadrature import quadrature_inames
...@@ -79,10 +79,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) ...@@ -79,10 +79,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
# evaluation of the gradients of basis functions at quadrature # evaluation of the gradients of basis functions at quadrature
# points (stage 1) # points (stage 1)
insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name) insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name)
var, _ = sum_factorization_kernel(a_matrices, var = SumfactKernel(a_matrices,
buffer_name, buffer_name,
insn_dep=frozenset({insn_dep}), insn_dep=frozenset({insn_dep}),
) )
buffers.append(var) buffers.append(var)
...@@ -91,7 +91,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) ...@@ -91,7 +91,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
from pymbolic.primitives import Subscript, Variable from pymbolic.primitives import Subscript, Variable
from dune.perftool.generation import get_backend from dune.perftool.generation import get_backend
assignee = Subscript(Variable(name), i) assignee = Subscript(Variable(name), i)
expression = Subscript(Variable(buf), tuple(Variable(i) for i in quadrature_inames())) expression = Subscript(buf, tuple(Variable(i) for i in quadrature_inames()))
instruction(assignee=assignee, instruction(assignee=assignee,
expression=expression, expression=expression,
forced_iname_deps=frozenset(get_backend("quad_inames")()), forced_iname_deps=frozenset(get_backend("quad_inames")()),
...@@ -133,12 +133,12 @@ def pymbolic_trialfunction(element, restriction, component): ...@@ -133,12 +133,12 @@ def pymbolic_trialfunction(element, restriction, component):
# Add a sum factorization kernel that implements the evaluation of # Add a sum factorization kernel that implements the evaluation of
# the basis functions at quadrature points (stage 1) # the basis functions at quadrature points (stage 1)
insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name) insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name)
var, _ = sum_factorization_kernel(a_matrices, var = SumfactKernel(a_matrices,
buffer_name, buffer_name,
insn_dep=frozenset({insn_dep}), insn_dep=frozenset({insn_dep}),
) )
return prim.Subscript(prim.Variable(var), return prim.Subscript(var,
tuple(prim.Variable(i) for i in quadrature_inames()) tuple(prim.Variable(i) for i in quadrature_inames())
) )
......
import copy import copy
from pymbolic.mapper.substitutor import substitute from dune.perftool.loopy.symbolic import substitute
from dune.perftool.pdelab.argument import (name_accumulation_variable, from dune.perftool.pdelab.argument import (name_accumulation_variable,
name_coefficientcontainer, name_coefficientcontainer,
pymbolic_coefficient, pymbolic_coefficient,
...@@ -9,6 +8,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable, ...@@ -9,6 +8,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable,
) )
from dune.perftool.generation import (backend, from dune.perftool.generation import (backend,
barrier, barrier,
built_instruction,
domain, domain,
function_mangler, function_mangler,
get_counter, get_counter,
...@@ -37,6 +37,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix, ...@@ -37,6 +37,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
name_theta, name_theta,
name_theta_transposed, name_theta_transposed,
) )
from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.error import PerftoolError
from pymbolic.primitives import (Call, from pymbolic.primitives import (Call,
Product, Product,
Subscript, Subscript,
...@@ -46,9 +48,55 @@ from dune.perftool.sumfact.quadrature import quadrature_inames ...@@ -46,9 +48,55 @@ from dune.perftool.sumfact.quadrature import quadrature_inames
from loopy import Reduction, GlobalArg from loopy import Reduction, GlobalArg
from loopy.symbolic import FunctionIdentifier from loopy.symbolic import FunctionIdentifier
import loopy as lp
import pymbolic.primitives as prim
from pytools import product from pytools import product
class HasSumfactMapper(lp.symbolic.CombineMapper):
    """Combine mapper that gathers the SumfactKernel nodes of an expression.

    Constants, algebraic leaves and loopy function identifiers contribute an
    empty frozenset, a SumfactKernel node contributes itself, and the
    results of subexpressions are merged by set union.
    """

    def combine(self, *args):
        # Materialize the collection of child results, then unite them
        child_results = tuple(*args)
        return frozenset().union(*child_results)

    def map_constant(self, expr):
        return frozenset()

    def map_algebraic_leaf(self, expr):
        return frozenset()

    def map_loopy_function_identifier(self, expr):
        return frozenset()

    def map_sumfact_kernel(self, expr):
        # The node itself is what we are looking for
        return frozenset([expr])
def find_sumfact(expr):
    """Apply a fresh HasSumfactMapper to ``expr`` and return its result."""
    mapper = HasSumfactMapper()
    return mapper(expr)
def expand_sumfact_kernels(insn):
    """Replace the SumfactKernel placeholders of an instruction by real kernels.

    For every SumfactKernel node found in ``insn.expression``,
    ``sum_factorization_kernel`` is called with the arguments stored on the
    node, and the node is substituted by a variable carrying the returned
    name. If at least one node was replaced, a copy of the instruction with
    the substituted expression is registered through ``built_instruction``.
    Instructions that are neither ``lp.Assignment`` nor
    ``lp.CallInstruction``, or that contain no placeholder, are left alone.

    :arg insn: the loopy instruction to inspect
    """
    if not isinstance(insn, (lp.Assignment, lp.CallInstruction)):
        return

    replace = {}
    deps = []
    for sumf in find_sumfact(insn.expression):
        var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, sumf.insn_dep, sumf.additional_inames)
        replace[sumf] = prim.Variable(var)
        deps.append(dep)

    if not replace:
        return

    # Merge the dependencies of all generated kernels. The former
    # frozenset(*deps) gave the same result for a single kernel, but raised
    # a TypeError as soon as the expression contained two or more.
    # Assumes each ``dep`` is an iterable of instruction ids - TODO confirm.
    # NOTE(review): as before, any ``depends_on`` already present on ``insn``
    # is replaced, not extended - confirm that this is intended.
    built_instruction(insn.copy(expression=substitute(insn.expression, replace),
                                depends_on=frozenset().union(*deps),
                                )
                      )
def filter_sumfact_instructions():
    """ Drop all cached instructions whose expression contains a SumfactKernel node """
    from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl

    # Rebuild the memoize cache of both instruction generators without the
    # entries that still carry a placeholder
    for impl in (expr_instruction_impl, call_instruction_impl):
        kept = {}
        for key, cached in impl._memoize_cache.items():
            if find_sumfact(cached.expression):
                continue
            kept[key] = cached
        impl._memoize_cache = kept
@iname @iname
def _sumfact_iname(bound, _type, count): def _sumfact_iname(bound, _type, count):
name = "sf_{}_{}".format(_type, str(count)) name = "sf_{}_{}".format(_type, str(count))
...@@ -146,10 +194,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -146,10 +194,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Replace gradient iname with correct index for assignement # Replace gradient iname with correct index for assignement
replace_dict = {} replace_dict = {}
expression = copy.deepcopy(pymbolic_expr)
for iname in additional_inames: for iname in additional_inames:
replace_dict[Variable(iname)] = i replace_dict[Variable(iname)] = i
expression = substitute(expression, replace_dict) expression = substitute(pymbolic_expr, replace_dict)
# Issue an instruction in the quadrature loop that fills the buffer # Issue an instruction in the quadrature loop that fills the buffer
# with the evaluation of the contribution at all quadrature points # with the evaluation of the contribution at all quadrature points
...@@ -164,11 +211,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -164,11 +211,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Add a sum factorization kernel that implements the multiplication # Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3) # with the test function (stage 3)
result, insn_dep = sum_factorization_kernel(a_matrices, result = SumfactKernel(a_matrices,
buf, buf,
insn_dep=frozenset({contrib_dep}), insn_dep=frozenset({contrib_dep}),
additional_inames=frozenset(visitor.inames), additional_inames=frozenset(visitor.inames),
) )
inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices) inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
...@@ -193,7 +240,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -193,7 +240,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
expr = Call(PDELabAccumulationFunction(accum, rank), expr = Call(PDELabAccumulationFunction(accum, rank),
(ansatz_lfs.get_args() + (ansatz_lfs.get_args() +
test_lfs.get_args() + test_lfs.get_args() +
(Subscript(Variable(result), tuple(Variable(i) for i in inames)),) (Subscript(result, tuple(Variable(i) for i in inames)),)
) )
) )
...@@ -201,7 +248,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -201,7 +248,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
expression=expr, expression=expr,
forced_iname_deps=frozenset(inames + visitor.inames), forced_iname_deps=frozenset(inames + visitor.inames),
forced_iname_deps_is_final=True, forced_iname_deps_is_final=True,
depends_on=insn_dep,
) )
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment