import warnings
from typing import Callable, List, Optional, Sequence, Tuple, cast

from thinc.api import Model, Ops, registry
from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial

from ..attrs import ORTH
from ..errors import Errors, Warnings
from ..tokens import Doc
from ..vectors import Mode, Vectors
from ..vocab import Vocab


def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.

    nO (Optional[int]): Width of the projected output. May be left unset and
        filled in when the layer is initialized.
    nM (Optional[int]): Width of the vectors being projected. May be left
        unset and filled in when the layer is initialized.
    dropout (Optional[float]): Dropout rate, stored on the layer as the
        "dropout_rate" attribute. None disables dropout.
    init_W (Callable): Initializer for the projection parameter "W".
    key_attr (str): Stored on the layer as the "key_attr" attribute. Any
        value other than "ORTH" emits a DeprecationWarning (W125).
    RETURNS (Model[List[Doc], Ragged]): The "static_vectors" layer.
    """
    if key_attr != "ORTH":
        warnings.warn(Warnings.W125, DeprecationWarning)
    layer_attrs = {"key_attr": key_attr, "dropout_rate": dropout}
    layer_dims = {"nO": nO, "nM": nM}
    # Bind the initializer up front so the init callback has the
    # (model, X, Y) call shape.
    init_callback = partial(init, init_W)
    return Model(
        "static_vectors",
        forward,
        init=init_callback,
        params={"W": None},
        attrs=layer_attrs,
        dims=layer_dims,
    )


def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up a vector for every token in the batch and project it with the
    layer's "W" parameter.

    model (Model[List[Doc], Ragged]): The static vectors layer. Supplies the
        ops, the "W" parameter, the "nO" dimension and the "dropout_rate"
        attribute.
    docs (List[Doc]): The batch to embed. The vocab of the first doc is used
        for the whole batch.
    is_train (bool): Whether to sample and apply a dropout mask.
    RETURNS (Tuple[Ragged, Callable]): The projected vectors as a Ragged with
        one length per doc, and the backprop callback. The callback
        accumulates the gradient of "W" and returns an empty list.
    RAISES (RuntimeError): E896 if the vectors object is neither a
        default-mode Vectors table nor something providing `get_batch`, or if
        the projection gemm raises a ValueError.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    vocab: Vocab = docs[0].vocab
    vectors = vocab.vectors
    key_attr: int = getattr(vectors, "attr", ORTH)
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    # Only default-mode tables are indexed by row; everything else is asked
    # for a batch of vectors directly.
    uses_row_lookup = isinstance(vectors, Vectors) and vectors.mode == Mode.default
    if uses_row_lookup:
        table = model.ops.asarray(vectors.data)
        rows = vectors.find(keys=keys)
        # Indexing with an array of rows produces a fresh array, so V can be
        # modified below without touching the vectors table itself.
        V = model.ops.as_contig(table[rows])
    elif hasattr(vectors, "get_batch"):
        # Covers floret-mode Vectors as well as other vectors implementations
        # that expose get_batch.
        V = model.ops.as_contig(vectors.get_batch(keys))
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    if uses_row_lookup:
        # Convert negative indices to 0-vectors
        # TODO: more options for UNK tokens
        missing = rows < 0
        vectors_data[missing] = 0
        # A negative row selected an arbitrary (last) table row above. The
        # output for those tokens is a constant zero, so they must not
        # contribute to the gradient of W either: zero the inputs that
        # backprop multiplies with.
        V[missing] = 0
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # Apply the same dropout mask to the incoming gradient (in place).
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        # Docs are not differentiable inputs, so there is nothing to pass back.
        return []

    return output, backprop


def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the layer's dimensions and allocate the projection parameter.

    init_W (Callable): Initializer, called with the model's ops and the
        shape (nO, nM).
    model (Model[List[Doc], Ragged]): The layer to initialize.
    X (Optional[List[Doc]]): Sample input. If non-empty, nM is taken from
        the width of the first doc's vocab vectors.
    Y (Optional[Ragged]): Sample output. If given, nO is taken from the
        width of its data.
    RETURNS (Model[List[Doc], Ragged]): The same model, with "nM", "nO" and
        "W" set.
    RAISES (ValueError): E905 if nM cannot be determined, E904 if nO cannot
        be determined.
    """

    def configured_dim(name: str) -> Optional[int]:
        # A dimension that was never set is reported as None.
        if model.has_dim(name):
            return model.get_dim(name)
        return None

    nM = configured_dim("nM")
    nO = configured_dim("nO")
    # Sample data, when provided, takes precedence over configured values.
    have_sample_docs = X is not None and len(X) > 0
    if have_sample_docs:
        nM = X[0].vocab.vectors.shape[1]
    if Y is not None:
        nO = Y.data.shape[1]

    if nM is None:
        raise ValueError(Errors.E905)
    if nO is None:
        raise ValueError(Errors.E904)

    for dim_name, dim_size in (("nM", nM), ("nO", nO)):
        model.set_dim(dim_name, dim_size)
    projection = init_W(model.ops, (nO, nM))
    model.set_param("W", projection)
    return model


def _handle_empty(ops: Ops, nO: int):
    """Build the result for a batch that contains no tokens.

    ops (Ops): Backend used to allocate the arrays.
    nO (int): Output width.
    RETURNS: A Ragged holding a (0, nO) float array and an empty lengths
        array, paired with a backprop callback that returns an empty list.
    """
    empty_output = Ragged(ops.alloc2f(0, nO), ops.alloc1i(0))

    def backprop(d_ragged):
        return []

    return empty_output, backprop


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Sample a dropout mask of shape (nO,), or return None when no dropout
    rate is configured.

    ops (Ops): Backend used to generate the mask.
    nO (int): Number of entries in the mask.
    rate (Optional[float]): Dropout rate. None disables dropout.
    RETURNS (Optional[Floats1d]): The mask, or None if rate is None.
    """
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore
