diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 467a81156f0f76084c566196b5fac1f7beaf95f1..7457dd834c081cd810f8701882a1cf2784fb3b58 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -501,10 +501,6 @@ def generate_kernel(integrals): def extract_kernel_from_cache(tag): - # Preprocess some instruction! - from dune.perftool.sumfact.sumfact import expand_sumfact_kernels - expand_sumfact_kernels(tag) - # Now extract regular loopy kernel components from dune.perftool.loopy.target import DuneTarget domains = [i for i in retrieve_cache_items("{} and domain".format(tag))] diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 7f32ed52aeaf20c46669a9b182e31271d93344ba..059488f99d873be0e354e9817e795d33292a0e8d 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -59,178 +59,6 @@ import pymbolic.primitives as prim from pytools import product -class IndexFiddleMapper(IdentityMapper): - def __init__(self, var, index, pos): - assert isinstance(var, str) - self.var = var - self.index = index - if isinstance(index, str): - self.index = prim.Variable(index) - self.pos = pos - - def map_variable(self, expr): - if expr.name == self.var: - return prim.Subscript(expr, (self.var,)) - else: - return IdentityMapper.map_variable(self, expr) - - def map_subscript(self, expr): - if expr.aggregate.name == self.var: - ind = expr.index - if not isinstance(ind, tuple): - ind = (ind,) - if self.pos is None: - ind = ind + (self.index,) - else: - raise NotImplementedError - return prim.Subscript(expr.aggregate, ind) - else: - return IdentityMapper.map_subscript(self, expr) - - -def fiddle_in_index(expr, var, index, pos=None): - return IndexFiddleMapper(var, index, pos)(expr) - - -def default_resolution(insns): - for insn in insns: - if isinstance(insn, (lp.Assignment, lp.CallInstruction)): - replace = {} - deps = [] - for sumf in find_sumfact(insn.expression): - # Maybe set up the input - if sumf.setup_method: - shape = product(mat.cols for mat in sumf.a_matrices) - shape = (shape,) - inp = get_buffer_temporary(sumf.buffer, shape=shape) - silenced_warning('read_no_write({})'.format(inp)) - - func, args = sumf.setup_method - insn_dep = frozenset({func(inp, *args)}) - else: - insn_dep = sumf.insn_dep - - # Call the sum factorization algorithm - var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, insn_dep, sumf.additional_inames) - replace[sumf] = prim.Variable(var) - deps.append(dep) - - if replace: - built_instruction(insn.copy(expression=substitute(insn.expression, replace), - depends_on=frozenset(*deps) - ) - ) - if isinstance(insn, lp.Assignment): - from dune.perftool.generation.loopy import expr_instruction_impl - expr_instruction_impl.remove_by_value(insn) - if isinstance(insn, lp.CallInstruction): - from dune.perftool.generation.loopy import call_instruction_impl - call_instruction_impl.remove_by_value(insn) - - -def apply_sumfact_grad_vectorization(insns, stage): - sumfact_kernels = [] - for insn in insns: - if isinstance(insn, (lp.Assignment, lp.CallInstruction)): - sumf = find_sumfact(insn.expression) - if sumf: - sumf, = sumf - if sumf.stage == stage: - sumfact_kernels.append(sumf) - - # Now apply some heuristics when to vectorize... - if len(set(sumfact_kernels)) < 3: - return - else: - # Vectorize!!! - sumfact_kernels = sorted(sumfact_kernels, key=lambda s: s.preferred_interleaving_position) - # Pad to 4 - if len(sumfact_kernels) < 4: - sumfact_kernels.append(sumfact_kernels[0]) - - # Determine a name for the buffer in use - buffer = "_".join(sf.buffer for sf in sumfact_kernels) - insn_dep = frozenset() - - # Maybe initialize the input buffer - if sumfact_kernels[0].setup_method: - assert stage == 1 - shape = product(mat.cols for mat in sumfact_kernels[0].a_matrices) - shape = (shape, 4) - initialize_buffer(buffer, base_storage_size=4 * shape[0]) - inp = get_buffer_temporary(buffer, shape=shape) - silenced_warning('read_no_write({})'.format(inp)) - - # Call the individual setup_methods - for i, sumf in enumerate(sumfact_kernels): - func, args = sumf.setup_method - insn_dep = insn_dep.union({func(inp, *args, additional_indices=(i,))}) - else: - assert stage == 3 - # No setup method defined. We need to make sure the input is correctly setup - shape = tuple(mat.cols for mat in sumfact_kernels[0].a_matrices) + (4,) - initialize_buffer(buffer, base_storage_size=4 * shape[0]) - dim_tags = ",".join("f" * (len(shape) - 1)) + ",c" - inp = get_buffer_temporary(buffer, shape=shape, dim_tags=dim_tags) - - for i, sumf in enumerate(sumfact_kernels): - assert sumf.input_temporary - for insn in insns: - if isinstance(insn, lp.Assignment): - if get_pymbolic_basename(insn.assignee) == sumf.input_temporary: - built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)), - id=insn.id + "__{}".format(i))) - insn_dep = insn_dep.union(insn.depends_on) - from dune.perftool.generation.loopy import expr_instruction_impl - expr_instruction_impl.remove_by_value(insn) - - # Determine the joined AMatrix - large_a_matrices = [] - for i in range(len(sumfact_kernels[0].a_matrices)): - assert len(set(tuple(sf.a_matrices[i].rows for sf in sumfact_kernels))) == 1 - assert len(set(tuple(sf.a_matrices[i].cols for sf in sumfact_kernels))) == 1 - large = LargeAMatrix(rows=sumfact_kernels[0].a_matrices[i].rows, - cols=sumfact_kernels[0].a_matrices[i].cols, - transpose=tuple(sf.a_matrices[i].transpose for sf in sumfact_kernels), - derivative=tuple(sf.a_matrices[i].derivative for sf in sumfact_kernels), - ) - large_a_matrices.append(large) - - # Join the instruction dependencies - insn_dep = insn_dep.union(*tuple(sf.insn_dep for sf in sumfact_kernels)) - - var, dep = sum_factorization_kernel(large_a_matrices, buffer, insn_dep, sumfact_kernels[0].additional_inames, add_vec_tag=True) - - for insn in insns: - if isinstance(insn, (lp.Assignment, lp.CallInstruction)): - sumfacts = find_sumfact(insn.expression) - if sumfacts: - sumf, = sumfacts - if sumf.stage == stage: - replace = {} - replace[sumf] = prim.Variable(var) - newexpr = substitute(insn.expression, replace) - newexpr = fiddle_in_index(newexpr, var, sumf.preferred_interleaving_position) - assert newexpr - built_instruction(insn.copy(expression=newexpr, - depends_on=dep, - )) - - if isinstance(insn, lp.Assignment): - from dune.perftool.generation.loopy import expr_instruction_impl - expr_instruction_impl.remove_by_value(insn) - if isinstance(insn, lp.CallInstruction): - from dune.perftool.generation.loopy import call_instruction_impl - call_instruction_impl.remove_by_value(insn) - - -def expand_sumfact_kernels(tag): - if get_option("vectorize_grads"): - apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 1) - apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 3) - default_resolution(retrieve_cache_items("{} and instruction".format(tag))) - - @iname def _sumfact_iname(bound, _type, count): name = "sf_{}_{}".format(_type, str(count))