# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Shared documentation text for the linear cross-entropy API, written in reST
# ``:param`` style. It is kept as a module constant so the same text can be
# reused in docstrings, e.g. by prepending it with ``add_doc_start``.
# NOTE(review): the consumers of this constant are not visible in this file.
LINEAR_CROSS_ENTROPY_DOC = """Computes cross-entropy loss using the logits generated by performing
    the matrix multiplication between the embeddings (e) and classifier (c).

    This method saves GPU memory by not materializing the logits into GPU
    main memory.


    Specifically, this computes

    ```python

    loss = F.cross_entropy((e @ c.T).float(), targets)
    ```

    without allocating the intermediary (e @ c.T).float() matrix.

    :param e: Embedding of the inputs used to compute the logits. Shape (..., D)
    :param c: Classifier matrix. Shape (NumClasses, D)
    :param targets: The target class for each input. Values must be in [0, NumClasses). Shape (...)
    :param ignore_index: If an input has a target of this value, it is ignored in the loss computation.
    :param softcap: The value for logit softcapping.
    :param reduction: The reduction to perform over the loss. Supports "mean", "sum", and "none".
    :param shift: If true, the embedding and targets are assumed to require a shift along the
        temporal axis to perform next token prediction. Specifically, setting this to true
        will efficiently compute

        ```python
        shift_e = e[..., :-1, :].flatten(0, -2)
        shift_targets = targets[..., 1:].flatten()

        loss = F.cross_entropy((shift_e @ c.T), shift_targets)
        ```
"""


def add_doc_start(*docstr: str):
    """Build a decorator that prepends text to a function's docstring.

    :param docstr: String fragments that are concatenated (with no separator)
        and placed in front of the decorated function's existing ``__doc__``.
    :return: A decorator that rewrites ``fn.__doc__`` in place and returns the
        very same ``fn`` object (the function is not wrapped).
    """

    def add_doc(fn):
        # Join at decoration time so the fragments are evaluated exactly when
        # the decorator is applied.
        prefix = "".join(docstr)

        # A function without a docstring has ``__doc__ = None``; treat that
        # as an empty string so the concatenation below cannot fail.
        existing = fn.__doc__
        if existing is None:
            existing = ""

        fn.__doc__ = prefix + existing
        return fn

    return add_doc
