from __future__ import annotations

import dataclasses
from typing import Any, Optional, TYPE_CHECKING


if TYPE_CHECKING:
    from torch.fx import Node


class CantChunk(RuntimeError):
    """RuntimeError subclass dedicated to chunking failures.

    NOTE(review): judging by the name, this signals that chunking cannot be
    applied; the raise sites are outside this file, so confirm against callers.
    """


@dataclasses.dataclass
class ChunkingMeta:
    """Chunking metadata attached to a node: pending scale, chunk dim, sum flag.

    Invariant (checked in ``__post_init__``): a non-None ``chunk_dim`` and a
    true ``need_sum`` are mutually exclusive.
    """

    # The value of the node carrying this ChunkingMeta should be scaled by the
    # scalar tensor given in `scale_by`. This is needed because the tangent is
    # first treated as 1 and the affected tensors are scaled later; that
    # information has to be propagated downstream.
    scale_by: Optional[Node] = None

    # The dimension of the current tensor that gets chunked.
    # Can be None if the current tensor is not chunked, e.g. when the current
    # tensor is a scalar tensor generated by summing a chunked tensor.
    #
    # To recover a tensor with a non-None chunk_dim, concatenate the chunks
    # along the `chunk_dim` dimension.
    chunk_dim: Optional[int] = None

    # The original tensor is the sum of each tensor in the chunking subgraph.
    # `chunk_dim` and `need_sum` are exclusive: a node can not have both a
    # non-None chunk_dim and a need_sum that is True.
    #
    # Note that for some special cases such as the tangent placeholder node,
    # chunk_dim can be None and need_sum can be False at the same time; in
    # that case scale_by is the tangent node itself.
    need_sum: bool = False

    def __post_init__(self) -> None:
        """Check that ``chunk_dim`` and ``need_sum`` are not both set."""
        assert self.chunk_dim is None or not self.need_sum, (
            f"Can not have both a non-None chunk_dim and a true need_sum: {self.chunk_dim}"
        )

    def copy(self, **kwargs: Any) -> ChunkingMeta:
        """Return a new ChunkingMeta with the given fields overridden.

        Args:
            **kwargs: Field names mapped to their replacement values.

        Returns:
            A new instance; ``self`` is left unmodified.

        Raises:
            AssertionError: If the resulting combination has both a non-None
                ``chunk_dim`` and a true ``need_sum``.
            TypeError: If a keyword is not a field of this dataclass.
        """
        # dataclasses.replace builds the copy through __init__, so the
        # invariant in __post_init__ is checked against the *final* field
        # values and unknown keywords are rejected rather than silently
        # becoming stray attributes.
        return dataclasses.replace(self, **kwargs)

    @staticmethod
    def equal(
        lhs_meta: ChunkingMeta, rhs_meta: ChunkingMeta, skip_scale_by: bool = False
    ) -> bool:
        """Compare two metas field by field.

        Args:
            lhs_meta: Left-hand operand.
            rhs_meta: Right-hand operand.
            skip_scale_by: When True, ``scale_by`` is ignored in the comparison.

        Returns:
            True if the (remaining) fields compare equal.
        """
        if skip_scale_by:
            # Neutralise scale_by on copies so the operands are not mutated.
            lhs_meta = lhs_meta.copy(scale_by=None)
            rhs_meta = rhs_meta.copy(scale_by=None)
        return lhs_meta == rhs_meta

    def chunked_by_dim(self, dim: int) -> bool:
        """Return True if this meta is chunked exactly along ``dim``.

        A meta whose ``chunk_dim`` is None never matches an integer ``dim``.
        """
        return self.chunk_dim == dim

    @staticmethod
    def is_nop(meta: ChunkingMeta | None) -> bool:
        """Return True if ``meta`` is None or equals a default ChunkingMeta."""
        return meta is None or meta == ChunkingMeta()
