""" Sum factorization vectorization """

import logging

from dune.perftool.loopy.vcl import get_vcl_type_size
from dune.perftool.loopy.symbolic import SumfactKernel, VectorizedSumfactKernel
from dune.perftool.generation import (generator_factory,
                                      get_counted_variable,
                                      get_global_context_value,
                                      )
from dune.perftool.pdelab.restriction import (Restriction,
                                              restricted_name,
                                              )
from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray
from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option

import loopy as lp
import numpy as np


@generator_factory(item_tags=("vecinfo", "dryrundata"), cache_key_generator=lambda o, n: o)
def _cache_vectorization_info(old, new):
    """ Register the vectorization info `new` for the kernel `old`.

    The cache key is built from `old` only (see cache_key_generator above).
    NOTE(review): a later call with the same `old` is presumably answered
    from the generator cache without running this body -- confirm against
    generator_factory.

    Raises a PerftoolError if this body runs with `new` being None.
    """
    if new is not None:
        return new

    raise PerftoolError("Vectorization info for sum factorization kernel was not gathered correctly!")


# Generator used to record sum factorization kernels: attach_vectorization_info
# calls it for every kernel during the dry run, and the recorded items are read
# back with retrieve_cache_items("kernel_default and sumfactnodes").
# NOTE(review): no_deco=True presumably makes generator_factory return a
# directly callable generator instead of a decorator -- confirm.
_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True)


def get_all_sumfact_nodes():
    """ Return a list of all cache items matching "kernel_default and sumfactnodes" """
    # Imported locally, as in the rest of this module
    from dune.perftool.generation import retrieve_cache_items

    nodes = retrieve_cache_items("kernel_default and sumfactnodes")
    return list(nodes)


def attach_vectorization_info(sf):
    """ Attach vectorization information to the sum factorization kernel `sf`.

    During a dry run the kernel is handed to _collect_sumfact_nodes and the
    result of that call is returned. Otherwise the result of
    _cache_vectorization_info(sf, None) is returned.
    NOTE(review): outside of the dry run this presumably relies on the info
    having been registered by decide_vectorization_strategy beforehand;
    if the body of _cache_vectorization_info runs it raises a PerftoolError.
    """
    assert isinstance(sf, SumfactKernel)

    if not get_global_context_value("dry_run"):
        return _cache_vectorization_info(sf, None)

    return _collect_sumfact_nodes(sf)


def no_vec(sf):
    """ Return a copy of `sf` that uses a newly counted "buffer" variable """
    fresh_buffer = get_counted_variable("buffer")
    return sf.copy(buffer=fresh_buffer)


def no_vectorization(sumfacts):
    """ Map every kernel in `sumfacts` to its non-vectorized replacement (see no_vec) """
    mapping = {}
    for kernel in sumfacts:
        mapping[kernel] = no_vec(kernel)
    return mapping


def vertical_vectorization_strategy(sumfact, depth):
    """ Slice sum factorization kernels for vertical vectorization.

    For every scalar kernel contained in `sumfact`, one matrix of its
    matrix_sequence is picked and the kernel is replicated `depth` times.
    Copy number i carries slice_size=depth and slice_index=i on that matrix.

    Arguments:
    sumfact -- a SumfactKernel or a VectorizedSumfactKernel whose matrices
               are not sliced yet
    depth -- the number of slices generated per scalar kernel

    Returns a dict mapping each original scalar kernel to the
    VectorizedSumfactKernel that replaces it.

    Raises a PerftoolError if no matrix of a kernel can be sliced.
    """
    # Assert that this is not already sliced
    assert all(mat.slice_size is None for mat in sumfact.matrix_sequence)

    # Determine which of the matrices in the kernel should be sliced
    def determine_slice_direction(sf):
        for i, mat in enumerate(sf.matrix_sequence):
            if mat.quadrature_size % depth == 0:
                return i
            elif mat.quadrature_size != 1:
                raise PerftoolError("Vertical vectorization is not possible!")

        # No matrix qualified (empty sequence or only quadrature size 1).
        # Falling through would return None and fail later with an obscure
        # TypeError on indexing, so report the problem explicitly.
        raise PerftoolError("Vertical vectorization is not possible!")

    is_scalar = isinstance(sumfact, SumfactKernel)
    if is_scalar:
        kernels = [sumfact]
    else:
        assert isinstance(sumfact, VectorizedSumfactKernel)
        kernels = sumfact.kernels

    newkernels = []
    for sf in kernels:
        sliced = determine_slice_direction(sf)
        oldtab = sf.matrix_sequence[sliced]
        for i in range(depth):
            seq = list(sf.matrix_sequence)
            seq[sliced] = oldtab.copy(slice_size=depth,
                                      slice_index=i)
            newkernels.append(sf.copy(matrix_sequence=tuple(seq)))
    newkernels = tuple(newkernels)

    if is_scalar:
        # A scalar kernel gets wrapped into a new vectorized kernel with its own buffer
        buffer = get_counted_variable("vertical_buffer")
        return {sumfact: VectorizedSumfactKernel(kernels=newkernels,
                                                 buffer=buffer,
                                                 vertical_width=depth,
                                                 )}

    # An already vectorized kernel is copied with the sliced kernels swapped in
    return {sf: sumfact.copy(kernels=newkernels,
                             vertical_width=depth,
                             )
            for sf in kernels}


def horizontal_vectorization_strategy(sumfacts, width, allow_padding=1):
    """ Group sum factorization kernels for horizontal vectorization.

    The distinct kernels from `sumfacts` are ordered by position_priority
    and packed into consecutive groups of at most `width` kernels. A group is
    vectorized if it holds at least width - allow_padding kernels.

    Arguments:
    sumfacts -- iterable of hashable kernels providing position_priority
    width -- the maximum number of kernels per group, needs to be positive
    allow_padding -- how many kernels a group may lack and still be vectorized

    Returns a dict mapping each kernel of an accepted group to a
    VectorizedSumfactKernel built from that group. Kernels from groups that
    are too small do not show up in the result.

    Raises a ValueError if there are kernels to group but width is not positive.
    """
    todo = set(sumfacts)
    if not todo:
        return {}
    if width < 1:
        # No kernel could ever be placed into a group, the grouping would not terminate
        raise ValueError("Horizontal vectorization needs a positive width, got {}".format(width))

    # Sorting once and cutting the result into chunks yields the same groups
    # as repeatedly picking the `width` kernels of lowest priority
    ordered = sorted(todo, key=lambda s: s.position_priority)

    result = {}
    for start in range(0, len(ordered), width):
        kernels = tuple(ordered[start:start + width])
        if not (width - allow_padding <= len(kernels) <= width):
            continue

        # Only draw a buffer name for groups that actually get vectorized
        buffer = get_counted_variable("joined_buffer")
        for sf in kernels:
            result[sf] = VectorizedSumfactKernel(kernels=kernels,
                                                 horizontal_width=width,
                                                 buffer=buffer,
                                                 )

    return result


def diagonal_vectorization_strategy(sumfacts, width):
    """ Combine horizontal and vertical vectorization.

    The kernels are first grouped horizontally: groups of 2 without padding
    for width 4, groups of 4 with a padding of 1 for width 8. Each resulting
    vectorized kernel is then sliced vertically with a depth of
    width // horizontal_width.

    Returns a dict mapping kernels to their vectorized replacement, as
    delivered by vertical_vectorization_strategy.

    Raises NotImplementedError for any width other than 4 and 8.
    """
    if width not in (4, 8):
        raise NotImplementedError

    # (size of a horizontal group, allowed padding of such a group)
    horizontal, padding = (2, 0) if width == 4 else (4, 1)

    grouped = horizontal_vectorization_strategy(sumfacts, horizontal, allow_padding=padding)

    result = {}
    for vectorized in grouped.values():
        if vectorized.horizontal_width <= 1:
            continue
        depth = width // vectorized.horizontal_width
        result.update(vertical_vectorization_strategy(vectorized, depth))

    return result


def decide_vectorization_strategy():
    """ Decide how to vectorize!
    Note that the vectorization of the quadrature loop is independent of this,
    as it is implemented through a post-processing (== loopy transformation) step.

    The strategy is chosen from the options "vectorize_grads" (horizontal),
    "vectorize_slice" (vertical) and "vectorize_diagonal" (both), checked in
    this order. Without any of them the kernels are not vectorized. The
    outcome for every collected sum factorization kernel is registered
    through _cache_vectorization_info.
    """
    logger = logging.getLogger(__name__)

    # Retrieve all sum factorization kernels for stage 1 and 3
    from dune.perftool.generation import retrieve_cache_items
    all_sumfacts = get_all_sumfact_nodes()

    # Stage 1 sumfactorizations that were actually used
    basis_sumfacts = list(retrieve_cache_items('kernel_default and basis_sf_kernels'))

    # This means we can have sum factorizations that will not get used
    inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts]

    # All sum factorization kernels that get used
    active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts]

    # We map inactive sum factorization kernels to 0
    sfdict = {sf: 0 for sf in inactive_sumfacts}

    logger.debug("decide_vectorization_strategy: Found %d active sum factorization nodes",
                 len(active_sumfacts))

    def grouped_by_input():
        # Currently we base our idea here on the fact that we only group sum
        # factorization kernels with the same input.
        for inputkey in set(sf.input_key for sf in active_sumfacts):
            yield [sf for sf in active_sumfacts if sf.input_key == inputkey]

    if get_option("vectorize_grads"):
        width = get_vcl_type_size(np.float64)
        for group in grouped_by_input():
            sfdict.update(horizontal_vectorization_strategy(group, width))
    elif get_option("vectorize_slice"):
        width = get_vcl_type_size(np.float64)
        for sumfact in active_sumfacts:
            sfdict.update(vertical_vectorization_strategy(sumfact, width))
    elif get_option("vectorize_diagonal"):
        width = get_vcl_type_size(np.float64)
        for group in grouped_by_input():
            sfdict.update(diagonal_vectorization_strategy(group, width))
    else:
        sfdict.update(no_vectorization(active_sumfacts))

    # Register the results
    for sf in all_sumfacts:
        # Kernels no strategy picked up fall back to the non-vectorized form.
        # The fallback is only built when needed, as no_vec draws a new
        # buffer name on every call. Membership is tested explicitly because
        # inactive kernels are mapped to the falsy value 0.
        if sf in sfdict:
            info = sfdict[sf]
        else:
            info = no_vec(sf)
        _cache_vectorization_info(sf, info)