# mypy: allow-untyped-defs
import logging
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch._logging import warning_once
from torch.distributed._local_tensor import (
    local_tensor_mode,
    maybe_run_for_local_tensor,
)
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
    broadcast,
    get_group_rank,
    get_rank,
    ProcessGroup,
    scatter,
    Work,
)


# Module-level logger; handed to ``warning_once`` by ``shard_dim_alltoall``.
logger = logging.getLogger(__name__)


@torch.library.register_fake("_dtensor::shard_dim_alltoall")
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
    """
    Fake kernel registered for the ``_dtensor::shard_dim_alltoall`` op.

    Builds a placeholder result: ``group_size`` uninitialized copies of
    ``input`` are concatenated along ``gather_dim``, the result is split into
    ``group_size`` chunks along ``shard_dim``, and the chunk indexed by this
    rank's position in the group is returned as a contiguous tensor.

    Args:
        input: the local tensor handed to the op.
        gather_dim: dimension along which the per-rank pieces are concatenated.
        shard_dim: dimension along which the concatenated tensor is re-split.
        group_name: name used to look up the process group and its size.
    """
    world = _get_group_size_by_name(group_name)
    pg = _resolve_process_group(group_name)
    my_rank = get_group_rank(pg, get_rank())

    placeholders = [torch.empty_like(input) for _ in range(world)]
    gathered = torch.cat(placeholders, dim=gather_dim)
    local_chunk = gathered.chunk(world, dim=shard_dim)[my_rank]
    return local_chunk.contiguous()


def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
    """
    Gather ``input`` along ``gather_dim`` across ``mesh_dim`` of ``mesh`` and
    keep only this rank's chunk along ``shard_dim``.

    When the mesh device type is "cpu" (and no local tensor mode is active)
    the result is produced with an all-gather followed by a chunk; otherwise
    the ``_dtensor.shard_dim_alltoall`` op is called directly.

    Args:
        input: the local tensor.
        gather_dim: tensor dimension to gather along.
        shard_dim: tensor dimension to re-shard along.
        mesh: the device mesh to communicate over.
        mesh_dim: the mesh dimension to communicate over.

    Returns:
        The local tensor after the exchange.
    """
    needs_cpu_fallback = mesh.device_type == "cpu" and local_tensor_mode() is None

    if not needs_cpu_fallback:
        # TODO: enable async op for shard_dim_alltoall
        group_name = funcol._resolve_group_name((mesh, mesh_dim))
        return torch.ops._dtensor.shard_dim_alltoall(
            input, gather_dim, shard_dim, group_name
        )

    # Gloo does not support alltoall, so falling back to allgather + chunk
    warning_once(
        logger,
        "CPU process group does not support alltoall yet, falling back with allgather + chunk!",
    )
    gathered = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
    if isinstance(gathered, funcol.AsyncCollectiveTensor):
        # stick to the same behavior for the alltoall case, remove this once we enable alltoall async
        gathered = gathered.wait()

    num_chunks = mesh.size(mesh_dim)
    local_rank = mesh.get_local_rank(mesh_dim)
    local_chunk = torch.chunk(gathered, num_chunks, dim=shard_dim)[local_rank]
    return local_chunk.contiguous()


def mesh_scatter(
    output: torch.Tensor,
    scatter_list: list[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
    *,
    group_src: int = 0,
) -> Work | None:
    """
    scatter a list of tensors to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
    scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
    2 to rank 2/3.

    Args:
        output (torch.Tensor): the tensor to receive the scattered list.
        scatter_list (List[torch.Tensor]): the tensor list to be scattered.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to scatter on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Keyword args:
        group_src (int, optional): the group rank of the source data for the
        logical/global tensor, on the specific mesh dimension. By default, we
        use ``group_rank=0`` on each DeviceMesh dimension as the source data
        to preserve the single-device semantic. If passing ``None`` explicitly,
        this method simply uses its local data with no communication.

    Returns:
        A :class:`Work` object
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if output.is_meta:
        return None
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    if group_src == get_rank(dim_group):
        fut = scatter(
            output,
            scatter_list=scatter_list,
            group=dim_group,
            async_op=async_op,
            group_src=group_src,
        )
    else:
        fut = scatter(
            output,
            scatter_list=None,
            group=dim_group,
            async_op=async_op,
            group_src=group_src,
        )

    return fut


def mesh_broadcast(
    tensor: torch.Tensor,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
    *,
    group_src: int = 0,
) -> Work | None:
    """
    broadcast the tensor to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
    broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2
    to rank 2/3.

    Args:
        tensor (torch.Tensor): tensor to broadcast.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to scatter on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Keyword args:
        group_src (int, optional): the group rank of the source data for the
        logical/global tensor, on the specific mesh dimension. By default, we
        use ``group_rank=0`` on each DeviceMesh dimension as the source data
        to preserve the single-device semantic. If passing ``None`` explicitly,
        this method simply uses its local data with no communication.

    Returns:
        A :class:`Work` object
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if tensor.is_meta:
        return None
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    return broadcast(tensor, group=dim_group, async_op=async_op, group_src=group_src)


@maybe_run_for_local_tensor
def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    """
    Append ``pad_size`` zero-filled entries at the end of dimension ``pad_dim``.

    Args:
        tensor: the tensor to pad.
        pad_dim: the dimension to pad. Negative values count from the last
            dimension, matching what ``unpad_tensor`` accepts.
        pad_size: number of entries to append; ``0`` returns ``tensor`` itself.

    Returns:
        ``tensor`` unchanged when ``pad_size`` is 0, otherwise a padded tensor.
    """
    if pad_size == 0:
        return tensor
    if pad_dim < 0:
        # Without this, a negative dim would make the pad list longer than
        # 2 * ndim and ``F.pad`` would reject it.
        pad_dim += tensor.ndim
    # ``F.pad`` consumes (before, after) pairs starting from the LAST dimension,
    # so a list of ndim - pad_dim pairs reaches back exactly to ``pad_dim`` and
    # its final element is the "after" amount for that dimension.
    pad = [0, 0] * (tensor.ndim - pad_dim)
    pad[-1] = pad_size
    return torch.nn.functional.pad(tensor, pad)


@maybe_run_for_local_tensor
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    """
    Drop the trailing ``pad_size`` entries of dimension ``pad_dim``.

    Args:
        tensor: the (padded) tensor.
        pad_dim: the dimension to trim.
        pad_size: number of trailing entries to remove; ``0`` returns
            ``tensor`` itself.

    Returns:
        ``tensor`` unchanged when ``pad_size`` is 0, otherwise a narrowed view
        that keeps the leading ``size(pad_dim) - pad_size`` entries.
    """
    if pad_size == 0:
        return tensor
    kept_length = tensor.size(pad_dim) - pad_size
    return tensor.narrow(pad_dim, 0, kept_length)


def fill_empty_tensor_to_shards(
    shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> list[torch.Tensor]:
    """
    Append ``num_empty_tensors`` empty placeholders to ``shards``.

    The placeholder has the shape of ``shards[0]`` with ``shard_dim`` set to
    size 0 and is created with ``new_zeros`` from ``shards[0]``. The same
    placeholder object is appended for every slot, and ``shards`` is extended
    in place.

    Args:
        shards: existing shards; mutated and returned.
        shard_dim: the dimension that gets size 0 in the placeholder.
        num_empty_tensors: how many placeholders to append.

    Returns:
        The same ``shards`` list object.
    """
    if num_empty_tensors == 0:
        return shards

    template = shards[0]
    empty_shape = list(template.size())
    empty_shape[shard_dim] = 0
    placeholder = template.new_zeros(empty_shape)

    # Every added slot refers to the one placeholder tensor.
    shards.extend([placeholder] * num_empty_tensors)
    return shards


def check_tensor_meta(
    local_tensor, check_shape_stride=False
) -> Optional["dtensor_spec.TensorMeta"]:
    """
    Verify that every rank holds a tensor with matching metadata.

    Each rank contributes its dtype and ``requires_grad`` flag (plus shape and
    stride when ``check_shape_stride`` is set); the entries are exchanged with
    ``all_gather_object`` over ``torch.distributed.get_world_size()`` ranks and
    compared against the local entry.

    Args:
        local_tensor: this rank's tensor.
        check_shape_stride: also compare shape and stride when True.

    Returns:
        Always ``None`` when the check passes.

    Raises:
        ValueError: if any gathered entry differs from the local one.
    """
    local_metadata = {
        "dtype": local_tensor.dtype,
        "requires_grad": local_tensor.requires_grad,
    }
    if check_shape_stride:
        local_metadata["shape"] = local_tensor.shape
        local_metadata["stride"] = local_tensor.stride()

    world_size = torch.distributed.get_world_size()
    gathered_metadata = [None] * world_size
    torch.distributed.all_gather_object(gathered_metadata, local_metadata)

    # Check if metadata is consistent across ranks
    for remote_metadata in gathered_metadata:
        if remote_metadata == local_metadata:
            continue
        raise ValueError(
            "Inconsistent tensor metadata (including shape and stride) across ranks."
        )
    return None


def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
    """
    Return the byte size described by ``spec``: the dtype's item size
    multiplied by the product of ``spec.shape``.

    Raises:
        AssertionError: if ``spec.tensor_meta`` is ``None``.
    """
    meta = spec.tensor_meta
    assert meta is not None, "spec should have tensor meta defined!"
    bytes_per_element = meta.dtype.itemsize
    num_elements = math.prod(spec.shape)
    return bytes_per_element * num_elements


@dataclass
class MeshTopoInfo:
    """
    Mesh information for collective cost estimation
    """

    mesh: DeviceMesh
    mesh_dim_devices: list[int]
    mesh_dim_bandwidth: list[float]
    mesh_dim_latency: list[float]

    @staticmethod
    @lru_cache(None)
    def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
        # Generate mesh topology info for intra-host/inter-host communication pattern
        # Note that we made bunch of assumptions for simplicity:
        # 1. we assume the mesh is homogeneous, and it's gpu/nccl model
        # 2. we assume gpu arch is Ampere or Hopper
        # 3. we assume collectives are all ring base algo for now
        num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
        # the base bw number (intra-node), GB/s
        base_bw = 87.7
        mesh_dim_bandwidth = [base_bw] * mesh.ndim
        # the latency in terms of us (intra-node, nv-link)
        mesh_dim_latency = [0.6] * mesh.ndim
        mesh_dim_devices = [1] * mesh.ndim

        total_num_devices = 1
        for mesh_dim in reversed(range(mesh.ndim)):
            num_devices = mesh.size(mesh_dim)
            mesh_dim_devices[mesh_dim] = num_devices
            total_num_devices *= num_devices
            if total_num_devices > num_devices_per_host:
                # magic number for inter-host communication bandwidth/latency factor
                # This number assumes latest GPU arch, i.e. Ampere or Hopper
                # TODO: see if we need to tweak this or offer a way for user
                # to specify the bandwidths/latency
                mesh_dim_bandwidth[mesh_dim] *= 0.22
                # set to ethernet latency for inter-host
                mesh_dim_latency[mesh_dim] = 2.7

        return MeshTopoInfo(
            mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
        )


def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    """
    Estimate the cost (in us) of an all-gather on ``mesh_dim``.

    Args:
        bytes_gb: communication volume in GB.
        mesh_topo: per-dimension device counts, bandwidths and latencies.
        mesh_dim: the mesh dimension the collective runs on.

    Returns:
        A fixed base latency plus per-hop latency plus the transfer time,
        using ``devices - 1`` hops.
    """
    group_size = mesh_topo.mesh_dim_devices[mesh_dim]
    hops = group_size - 1
    # base latency + comm latency
    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    transfer_s = (bytes_gb * hops / group_size) / mesh_topo.mesh_dim_bandwidth[
        mesh_dim
    ]
    # rescale the transfer time from s to us
    return latency_us + transfer_s * 1e6


def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
    """
    Estimate the cost (in us) of an all-reduce on ``mesh_dim``.

    Args:
        bytes_gb: communication volume in GB.
        mesh_topo: per-dimension device counts, bandwidths and latencies.
        mesh_dim: the mesh dimension the collective runs on.

    Returns:
        A fixed base latency plus per-hop latency plus the transfer time,
        using ``2 * (devices - 1)`` hops.
    """
    group_size = mesh_topo.mesh_dim_devices[mesh_dim]
    # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter
    hops = 2 * (group_size - 1)

    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    transfer_s = (bytes_gb * hops / group_size) / mesh_topo.mesh_dim_bandwidth[
        mesh_dim
    ]
    return latency_us + transfer_s * 1e6


def reduce_scatter_cost(
    bytes_gb: float,
    mesh_topo: MeshTopoInfo,
    mesh_dim: int,
) -> float:
    """
    Estimate the cost (in us) of a reduce-scatter on ``mesh_dim``.

    Args:
        bytes_gb: communication volume in GB.
        mesh_topo: per-dimension device counts, bandwidths and latencies.
        mesh_dim: the mesh dimension the collective runs on.

    Returns:
        A fixed base latency plus per-hop latency plus the transfer time,
        using ``devices - 1`` hops.
    """
    group_size = mesh_topo.mesh_dim_devices[mesh_dim]
    hops = group_size - 1
    # base latency + comm latency
    latency_us = 6.6 + hops * mesh_topo.mesh_dim_latency[mesh_dim]
    transfer_s = (bytes_gb * hops / group_size) / mesh_topo.mesh_dim_bandwidth[
        mesh_dim
    ]
    return latency_us + transfer_s * 1e6


def _compute_placement_transition_cost(
    current_placement: "dtensor_spec.Placement",
    target_placement: "dtensor_spec.Placement",
    mesh_topo: MeshTopoInfo,
    mesh_dim: int,
    comm_bytes_gb: float,
) -> tuple[float, float]:
    """
    Compute the cost of transitioning from one placement to another on a single mesh dimension.

    Args:
        current_placement: The current placement on the mesh dimension.
        target_placement: The target placement on the mesh dimension.
        mesh_topo: Mesh topology information for cost estimation.
        mesh_dim: The mesh dimension where the transition happens.
        comm_bytes_gb: The communication bytes in GB for this step.

    Returns:
        A tuple of (cost, updated_comm_bytes_gb):
            - cost: The communication cost for this transition (float("inf") if invalid).
            - updated_comm_bytes_gb: The updated communication bytes after this step.
    """
    if current_placement == target_placement:
        return 0.0, comm_bytes_gb

    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]

    if current_placement.is_shard() and target_placement.is_replicate():
        # allgather gives larger comm bytes
        comm_bytes_gb *= num_devices_on_mesh_dim
        return allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim), comm_bytes_gb
    elif current_placement.is_shard() and target_placement.is_shard():
        # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty
        # to favor allgather instead
        # TODO: add alltoall_cost
        return allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0, comm_bytes_gb
    elif current_placement.is_partial() and target_placement.is_replicate():
        return allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim), comm_bytes_gb
    elif current_placement.is_partial() and target_placement.is_shard():
        cost = reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim)
        # after reduce_scatter the comm bytes for further collectives halved.
        comm_bytes_gb /= num_devices_on_mesh_dim
        return cost, comm_bytes_gb
    elif current_placement.is_shard() and target_placement.is_partial():
        # ban shard -> partial as it does not make sense to perform
        # this redistribute
        return float("inf"), comm_bytes_gb
    elif current_placement.is_partial() and target_placement.is_partial():
        # we already handled the == case at the top, and we ban converting between partial types.
        return float("inf"), comm_bytes_gb
    elif current_placement.is_replicate() and target_placement.is_shard():
        comm_bytes_gb /= num_devices_on_mesh_dim
        return 0.0, comm_bytes_gb

    return 0.0, comm_bytes_gb


def one_step_redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
) -> float:
    """
    Communication cost of a single redistribution step between two specs.

    The two specs are expected to differ in at most one placement, i.e. on a
    single mesh dimension.

    Args:
        current_spec: The current DTensorSpec.
        target_spec: The target DTensorSpec.

    Returns:
        ``float("inf")`` when the meshes differ, ``0.0`` when no placement
        differs, otherwise the cost of the one differing transition (which may
        itself be ``float("inf")`` for a banned transition).

    Raises:
        ValueError: if more than one placement differs.
    """
    if current_spec.mesh != target_spec.mesh:
        return float("inf")

    if current_spec.placements == target_spec.placements:
        return 0.0

    # Collect every mesh dimension whose placement changes.
    changed = [
        (dim, src, dst)
        for dim, (src, dst) in enumerate(
            zip(current_spec.placements, target_spec.placements)
        )
        if src != dst
    ]

    if len(changed) > 1:
        # More than one dimension differs - not a single step
        raise ValueError(
            "one_step_redistribute_cost expects specs that differ by exactly one placement"
        )
    if not changed:
        return 0.0

    mesh_dim, current_placement, target_placement = changed[0]
    assert current_placement is not None and target_placement is not None

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    # per-shard volume, converted from bytes to GB
    local_bytes = spec_to_bytes(current_spec) / current_spec.num_shards
    comm_bytes_gb = local_bytes / 1024 / 1024 / 1024

    step_cost, _ = _compute_placement_transition_cost(
        current_placement, target_placement, mesh_topo, mesh_dim, comm_bytes_gb
    )
    return step_cost


def redistribute_cost(
    current_spec: "dtensor_spec.DTensorSpec",
    target_spec: "dtensor_spec.DTensorSpec",
) -> float:
    """
    This function returns the cost of redistribute from current to target DTensorSpec.

    NOTE:
    1. Only consider communication cost here, since computation costs for redistribute
       are quite trivial (i.e. we only need to narrow or simple division)
    2. Only consider redistribute cost on same mesh, cross mesh communication cost is
       not quite needed for operator strategy estimation/selection.

    Args:
        current_spec: The current DTensorSpec.
        target_spec: The target DTensorSpec.

    Returns:
        ``float("inf")`` if the meshes differ, if either spec has no
        ``shard_order``, or if any step of the transform path is banned;
        ``0.0`` if ``current_spec`` is fully replicated or the placements and
        shard order already match; otherwise the sum of the per-step costs.
    """
    if current_spec.mesh != target_spec.mesh:
        # make infinite cost if meshes are not same
        # TODO: see if we want to support this once there's cross mesh communication
        return float("inf")
    if current_spec.is_replicated():
        # short-cut: comm cost is 0 if current spec is already full replication
        return 0.0

    # No redistribution needed when placements are already identical.
    # This also prevents potential failures in _gen_transform_infos for certain configurations
    # (e.g., sub-meshes) where finding a transform path between identical states may error out.
    # TODO(zpcore): test placements with _StridedShard if we replace shard_order
    # with _StridedShard.
    if (
        current_spec.placements == target_spec.placements
        and current_spec.shard_order == target_spec.shard_order
    ):
        return 0.0

    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
    # spec_to_bytes asserts that current_spec.tensor_meta is defined, so no
    # further tensor_meta check is needed below.
    comm_bytes_gb = (
        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
    )
    # Transformation that considered for redistribute cost:
    # 1. allgather 2. alltoall
    # 3. allreduce 4. reduce_scatter
    from torch.distributed._functional_collectives import _are_we_tracing
    from torch.distributed.tensor._redistribute import (
        _gen_transform_infos,
        _gen_transform_infos_non_cached,
    )

    # TODO(zpcore): Support _StridedShard redistribution. Remove the temporary
    # fix, which is to prevent StridedShard erroring out.
    if current_spec.shard_order is None or target_spec.shard_order is None:
        return float("inf")

    if _are_we_tracing():
        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
    else:
        transform_infos = _gen_transform_infos(current_spec, target_spec)

    cost = 0.0
    for transform_info in transform_infos:
        current = transform_info.src_dst_placements[0]
        target = transform_info.src_dst_placements[1]
        # comm_bytes_gb is threaded through the steps: each transition may
        # grow or shrink the volume seen by the next one.
        step_cost, comm_bytes_gb = _compute_placement_transition_cost(
            current, target, mesh_topo, transform_info.mesh_dim, comm_bytes_gb
        )
        if step_cost == float("inf"):
            # one banned step makes the whole redistribution invalid
            return float("inf")
        cost += step_cost

    return cost
