# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import cast, Union

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpSpec,
    OpStrategy,
    PlacementList,
    RuntimeSchemaInfo,
    TupleStrategy,
)
from torch.distributed.tensor._ops.utils import (
    as_list,
    expand_to_full_mesh_op_strategy,
    generate_redistribute_costs,
    is_tensor_evenly_shardable,
    is_tensor_evenly_shardable_on_dim,
    normalize_dim,
    normalize_dims,
    register_op_strategy,
)
from torch.distributed.tensor._utils import normalize_to_torch_size
from torch.distributed.tensor.placement_types import (
    _StridedShard,
    Partial,
    Placement,
    Replicate,
    Shard,
)


# Shorthand for the ATen operator namespace; used below to name the ops that
# each strategy function is registered for.
aten = torch.ops.aten


class Reduction(Enum):
    """Integer codes compared against the ``reduction`` argument of the nll_loss strategies."""

    # no reduction: the loss output keeps the target's sharding
    NONE = 0
    # mapped to the "avg" reduce op
    MEAN = 1
    # mapped to the "sum" reduce op
    SUM = 2


@dataclass(frozen=True)
class NormReduction:
    """Reduction-op descriptor for vector norms; carries the norm order."""

    # order of the norm (a number, or a string such as "inf" / "-inf")
    norm_type: int | float | str


# A reduction is identified either by a NormReduction (for norms) or by a
# reduce-op string such as "sum", "avg", "max", "min" or "product".
ReductionOpType = Union[NormReduction, str]


@dataclass(frozen=True)
class _NormPartial(Partial):
    """
    This placement is used for partial vector norm.

    For p-norms (where p not inf or -inf), the p-norm over n elements computes
        (sum_i x_i^p)^(1/p)
    where the sum is from i=1 to n. The reduction op is the p-norm itself.
    For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm:
        Rank 0: [t1, t2] | Rank 1: [t3, t4]
    After computing 2-norm per gradient (partial placement):
        Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)]
    Converting from partial to replicate wants to ultimately get:
        Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)]
    This can be achieved by computing 2-norm on each rank's result. This holds
    similarly for inf and -inf norm. For 0-norm, the reduction op is sum.
    """

    norm_type: int | float | str = 2

    def __init__(self, norm_type: int | float | str = 2):
        reduce_op = None
        if norm_type in (float("inf"), "inf"):
            reduce_op = "max"
        elif norm_type in (float("-inf"), "-inf"):
            reduce_op = "min"
        elif isinstance(norm_type, (int, float)):
            reduce_op = "sum"
        else:
            raise NotImplementedError(f"Unsupported norm type: {norm_type}")

        super().__init__(reduce_op)
        object.__setattr__(self, "norm_type", norm_type)

    def _partition_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        """
        For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm:
            Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3)
        To convert from replicated to partial, we want f(x) such that
            sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2)
                                     = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2).
        One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and
        p-norm as f(x) = x / d^(1/p).
        """
        if self.reduce_op in ("max", "min"):
            return tensor
        elif self.reduce_op == "sum":
            if self.norm_type == 0:
                raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}")
            elif self.norm_type == 1:
                return tensor / mesh.size(mesh_dim)
            if not isinstance(self.norm_type, (int, float)):
                raise AssertionError(
                    f"Expected int or float, got {type(self.norm_type)}"
                )
            return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type)
        raise NotImplementedError(self.reduce_op)

    def _reduce_shard_value(
        self,
        tensor: torch.Tensor,
        mesh: DeviceMesh,
        mesh_dim: int,
        shard_spec: Placement,
    ) -> torch.Tensor:
        if not isinstance(shard_spec, Shard):
            raise AssertionError(f"Expected Shard, got {type(shard_spec)}")
        tensor = self._pre_reduce_transform(tensor)
        reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
        return self._post_reduce_transform(reduced_tensor)

    def _reduce_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        tensor = self._pre_reduce_transform(tensor)
        reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
        return self._post_reduce_transform(reduced_tensor)

    def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
        if self.reduce_op == "sum":
            if not isinstance(self.norm_type, (int, float)):
                raise AssertionError(
                    f"Expected int or float, got {type(self.norm_type)}"
                )
            if self.norm_type != 0 and self.norm_type != 1:
                return tensor**self.norm_type
        return tensor

    def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
        if self.reduce_op == "sum":
            if not isinstance(self.norm_type, (int, float)):
                raise AssertionError(
                    f"Expected int or float, got {type(self.norm_type)}"
                )
            if self.norm_type != 0 and self.norm_type != 1:
                return tensor ** (1.0 / self.norm_type)
        return tensor

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _NormPartial):
            return False
        return self.norm_type == other.norm_type

    def __hash__(self) -> int:
        return 1 + hash(self.norm_type)

    def __repr__(self) -> str:
        """
        machine readable representation of the _NormPartial placement
        """
        return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})"

    def __str__(self) -> str:
        """human readable representation of the _NormPartial placement"""
        return f"_NormP({self.reduce_op}, {self.norm_type})"


def _infer_reduction_dims(dims_arg: object, ndim: int) -> list[int] | None:
    if dims_arg is None:
        return None
    dims = cast(list[int], as_list(dims_arg))
    dims = cast(list[int], normalize_dims(dims, ndim))
    empty_dims = [[0], [-1], []]
    if ndim == 0 and dims_arg in empty_dims:
        return None
    return dims


def _infer_reduce_dims_map(
    reduction_dims: list[int], input_ndim: int, keep_dim=False
) -> list[int]:
    reduction_dims_map = []
    new_dim_count = 0
    for input_dim in range(input_ndim):
        if input_dim in reduction_dims and not keep_dim:
            # if input dim in reduction dims, mark it as -1
            reduction_dims_map.append(-1)
        else:
            # otherwise mark it as the new dim
            reduction_dims_map.append(new_dim_count)
            new_dim_count += 1

    return reduction_dims_map


def _replicate_dims_start_at(
    placements: Sequence[Placement], start_dim: int = 0
) -> tuple[Placement, ...]:
    new_placements: list[Placement] = []
    for p in placements:
        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
            new_placements.append(Replicate())  # make it replicate
        else:
            new_placements.append(p)  # keep the placement
    return tuple(new_placements)


# return new_placements which align with placements but skip the skipped_dim
def _skip_dim(
    placements: tuple[Placement, ...], skipped_dim: int
) -> tuple[Placement, ...]:
    new_placements: list[Placement] = []
    for p in placements:
        if isinstance(p, Shard) and p.dim >= skipped_dim:
            new_placements.append(Shard(p.dim - 1))
        else:
            new_placements.append(p)
    return tuple(new_placements)


def replicate_reduction_dims(
    placements: tuple[Placement, ...], reduction_dims: list[int]
) -> tuple[Placement, ...]:
    # replicate the reduction dims if not reduction_linear
    new_placements: list[Placement] = []

    for p in placements:
        if p.is_partial():
            new_placements.append(Replicate())
        elif isinstance(p, Shard) and p.dim in reduction_dims:
            new_placements.append(Replicate())
        else:
            new_placements.append(p)

    return tuple(new_placements)


def map_placements_after_reduction(
    placements: tuple[Placement, ...],
    reduction_dims: list[int],
    reduction_dims_map: list[int],
    reduction_op: ReductionOpType,
) -> tuple[Placement, ...]:
    """
    Translate input placements into placements valid for the reduced output.

    ``Replicate`` and ``Partial`` pass through. A shard on a reduced dim
    (its mapped dim is -1, or it is listed in ``reduction_dims`` for the
    keepdim case) becomes the partial placement of ``reduction_op``. Any
    other shard is re-pointed at its output dim via ``reduction_dims_map``.
    """
    mapped: list[Placement] = []
    for p in placements:
        if isinstance(p, (Replicate, Partial)):
            mapped.append(p)
            continue

        if not isinstance(p, Shard | _StridedShard):
            raise AssertionError(f"Expected Shard/_StridedShard, got {type(p)}")

        old_dim = p.dim
        new_dim = reduction_dims_map[old_dim]
        if new_dim == -1 or old_dim in reduction_dims:
            # the sharded dim is being reduced -> result is pending a reduction
            mapped.append(get_placement_from_reduction_op(reduction_op))
        elif isinstance(p, Shard):
            mapped.append(Shard(new_dim))
        else:
            mapped.append(_StridedShard(new_dim, split_factor=p.split_factor))
    return tuple(mapped)


def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
    """Build the partial placement for ``reduction_op``: ``_NormPartial`` for norms, ``Partial`` otherwise."""
    if not isinstance(reduction_op, NormReduction):
        return Partial(reduction_op)
    return _NormPartial(norm_type=reduction_op.norm_type)


def common_reduction_strategy(
    input_strategy: OpStrategy,
    reduce_dims: list[int],
    keep_dim: bool = False,
    reduction_linear: bool = True,
    reduction_op: ReductionOpType = "sum",
) -> OpStrategy:
    """
    Build the output strategy of a reduction, one OpSpec per input strategy.

    reduction_linear means that the reduction `f` follows this rule:
        f([f(a), f(b)]) = f([a, b])

    reduction linear should be super set of linearity.

    Args:
        input_strategy: strategy of the tensor being reduced.
        reduce_dims: tensor dims that are reduced.
        keep_dim: whether reduced dims are kept in the output.
        reduction_linear: whether the caller considers the reduction linear.
            Individual input strategies may still be handled as non-linear
            (see below).
        reduction_op: reduce op used for the partial placements produced
            when a sharded dim is reduced.

    Returns:
        An OpStrategy following each input strategy. For strategies handled
        as non-linear, partial placements and shards on the reduced dims are
        replicated on the input side first.
    """
    # by default follow reduction input strategy
    reduction_strategy = OpStrategy([])

    for op_spec in input_strategy.strategies:
        output_spec = op_spec.output_spec
        # Linearity is decided independently for every input strategy, so one
        # non-linear strategy does not force all later ones to replicate.
        spec_is_linear = reduction_linear

        if reduction_op == "avg":
            local_shape = list(output_spec.tensor_meta.shape)  # type:ignore[union-attr]
            # reduce(avg) is not linear for unevenly sharded tensors
            if not all(
                is_tensor_evenly_shardable_on_dim(local_shape, output_spec, dim)
                for dim in reduce_dims
            ):
                spec_is_linear = False

        # when the partial reduction op matches the global reduction op,
        # we can delay redistribution (i.e max, max)
        if any(
            isinstance(p, Partial) and p.reduce_op != reduction_op
            for p in output_spec.placements
        ):
            spec_is_linear = False

        if spec_is_linear:
            input_placements = output_spec.placements
        else:
            # input placements for this strategy should clear out pending sum
            # and sharding on the reduction dimension
            input_placements = replicate_reduction_dims(
                output_spec.placements, reduce_dims
            )

        input_spec = DTensorSpec(
            mesh=input_strategy.mesh,
            placements=input_placements,
            tensor_meta=output_spec.tensor_meta,
        )

        reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim)
        out_placements = map_placements_after_reduction(
            input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
        )
        redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
        reduction_strategy.strategies.append(
            OpSpec(
                output_specs=DTensorSpec(
                    mesh=input_strategy.mesh,
                    placements=out_placements,
                ),
                input_specs=(input_spec,),
                redistribute_cost=redistribute_cost,
            )
        )

    return reduction_strategy


# Maps each linear reduction op to the reduce-op string used for the partial
# placement it produces; the keys are the ops registered for
# linear_reduction_strategy.
LINEAR_REDUCTION_OP_MAP = {
    aten.all.default: "product",
    aten.all.dim: "product",
    aten.sum.default: "sum",
    aten.sum.dim_IntList: "sum",
    aten.any.default: "sum",
    aten.any.dim: "sum",
    aten.any.out: "sum",
    # These are only valid when there is no padding
    aten.prod.default: "product",
    aten.prod.dim_int: "product",
    aten.prod.int_out: "product",
    # avg is only linear when there is no padding
    aten.mean.default: "avg",
    aten.mean.dim: "avg",
    aten.mean.out: "avg",
    aten.max.default: "max",
    aten.max.dim: "max",
    aten.max.out: "max",
    aten.min.default: "min",
    aten.min.dim: "min",
    aten.min.out: "min",
    aten.amax.default: "max",
    aten.amax.out: "max",
    aten.amin.default: "min",
    aten.amin.out: "min",
    # argmax and argmin use a custom handler that leverages the linear reduction of max and min
    aten.argmax.default: "max",
    aten.argmin.default: "min",
}


@register_op_strategy(
    list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
)
def linear_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the ops in ``LINEAR_REDUCTION_OP_MAP``.

    Reads the input strategy from arg 0, the optional dims from arg 1 and the
    optional keepdim flag from arg 2, then defers to
    ``common_reduction_strategy`` with the op's mapped reduce op.
    """
    schema_args = op_schema.args_schema
    input_strategy = schema_args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    num_args = len(schema_args)
    ndim = input_strategy.ndim

    requested_dims = (
        _infer_reduction_dims(schema_args[1], ndim) if num_args > 1 else None
    )
    # no explicit dims -> reduce over every dim
    reduce_dims = requested_dims if requested_dims is not None else list(range(ndim))

    keep_dim = bool(schema_args[2]) if num_args > 2 else False
    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        reduction_linear=True,
        reduction_op=LINEAR_REDUCTION_OP_MAP[op_schema.op],
    )


@register_op_strategy(aten.cumsum.default, schema_info=RuntimeSchemaInfo(1))
def cumsum_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for ``aten.cumsum``: a non-linear, shape-preserving reduction
    along a single dim.

    Raises:
        AssertionError: if arg 0 is not an OpStrategy or arg 1 is not an int.
    """
    args_schema = op_schema.args_schema
    input_strategy = args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    dim = args_schema[1]
    if not isinstance(dim, int):
        raise AssertionError(f"Expected int, got {type(dim)}")

    # Shard placements store non-negative dims; a negative `dim` (e.g. -1)
    # must be normalized or it never matches, and the scan dim would stay
    # sharded and be accumulated per shard.
    dim = normalize_dim(dim, input_strategy.ndim)

    return common_reduction_strategy(
        input_strategy, [dim], keep_dim=True, reduction_linear=False
    )


@register_op_strategy(
    [
        aten.std.correction,
        aten.std.correction_out,
        aten.var.correction,
        aten.var.correction_out,
    ],
    schema_info=RuntimeSchemaInfo(1, ["keepdim"]),
)
def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for std/var: a non-linear reduction over the dims given in
    arg 1 (all dims when absent), with ``keepdim`` read from the kwargs.
    """
    schema_args = op_schema.args_schema
    input_strategy = schema_args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    ndim = input_strategy.ndim
    requested_dims = (
        _infer_reduction_dims(schema_args[1], ndim) if len(schema_args) > 1 else None
    )
    # no explicit dims -> reduce over every dim
    reduce_dims = requested_dims if requested_dims is not None else list(range(ndim))

    keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
    return common_reduction_strategy(
        input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False
    )


@register_op_strategy(
    [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1)
)
def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for ``linalg_vector_norm``: a linear reduction whose partial
    placement is derived from ``NormReduction(norm_type)``.

    Positional args after the input are optional: norm type (default 2),
    dims (default None, i.e. all dims) and keepdim (default False).
    """
    schema_args = op_schema.args_schema
    num_args = len(schema_args)

    input_strategy = schema_args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    norm_type = schema_args[1] if num_args > 1 else 2
    if not isinstance(norm_type, (int, float, str)):
        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")

    dim = schema_args[2] if num_args > 2 else None
    keepdim = schema_args[3] if num_args > 3 else False

    ndim = input_strategy.ndim
    requested_dims = _infer_reduction_dims(dim, ndim)
    reduce_dims = requested_dims if requested_dims is not None else list(range(ndim))

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=cast(bool, keepdim),
        reduction_linear=True,
        reduction_op=NormReduction(norm_type),
    )


@register_op_strategy(
    [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
    """
    Strategy for ``_foreach_norm``: every tensor in the input list is reduced
    over all of its dims with a norm reduction (norm type from arg 1,
    default 2).
    """
    schema_args = op_schema.args_schema
    tensors_strategy = schema_args[0]
    if not isinstance(tensors_strategy, TupleStrategy):
        raise AssertionError(
            f"Expected TupleStrategy, got {type(tensors_strategy)}"
        )

    norm_type = 2 if len(schema_args) <= 1 else schema_args[1]
    if not isinstance(norm_type, (int, float, str)):
        raise AssertionError(f"Expected int, float, or str, got {type(norm_type)}")

    per_tensor_strategies: list[OpStrategy] = []
    for child in tensors_strategy.children:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        per_tensor_strategies.append(
            common_reduction_strategy(
                child,
                list(range(child.ndim)),
                reduction_linear=True,
                reduction_op=NormReduction(norm_type),
            )
        )
    return TupleStrategy(per_tensor_strategies)


@register_op_strategy(
    [aten._foreach_max.default], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def foreach_max_strategy(op_schema: OpSchema) -> TupleStrategy:
    """
    Strategy for _foreach_max: each tensor of the input list is reduced over
    all of its dims with a linear "max" reduction.
    """
    tensors_strategy = op_schema.args_schema[0]
    if not isinstance(tensors_strategy, TupleStrategy):
        raise AssertionError(
            f"Expected TupleStrategy, got {type(tensors_strategy)}"
        )

    per_tensor_strategies: list[OpStrategy] = []
    for child in tensors_strategy.children:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(child)}")
        # reducing over every dim leaves a scalar per tensor
        all_dims = list(range(child.ndim))
        per_tensor_strategies.append(
            common_reduction_strategy(
                child,
                all_dims,
                reduction_linear=True,
                reduction_op="max",
            )
        )
    return TupleStrategy(per_tensor_strategies)


@register_op_strategy(
    [
        aten._linalg_svd.default,
        aten.linalg_qr.default,
        # TODO: The diagonal ops can have an improved sharding strategy for
        # shard placements that does not require redistributing to replicate.
        aten.diagonal_copy.default,
        aten.diag_embed.default,
        aten.diag.default,
        aten.diagonal.default,
        aten.tril.default,
        aten.triu.default,
        aten._linalg_eigh.default,
        aten.upsample_bicubic2d.default,
        aten.upsample_bilinear2d.default,
        aten.upsample_linear1d.default,
        aten.upsample_nearest2d.default,
        aten.upsample_trilinear3d.default,
        # TODO: support the full F.interpolate set of options.
    ],
    schema_info=RuntimeSchemaInfo(1),
)
def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Fallback strategy for ops without a dedicated sharding rule (e.g. SVD or
    QR decomposition): input and output are fully replicated on every mesh
    dim, with one OpSpec emitted per input strategy.
    """
    input_strategy = op_schema.args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    mesh = input_strategy.mesh

    def _replicated_op_spec(source: OpSpec) -> OpSpec:
        # fully replicated spec that keeps the source's tensor metadata
        spec = DTensorSpec(
            mesh=mesh,
            placements=tuple(Replicate() for _ in range(mesh.ndim)),
            tensor_meta=source.output_spec.tensor_meta,
        )
        return OpSpec(
            output_specs=spec,
            input_specs=(spec,),
            redistribute_cost=[generate_redistribute_costs(input_strategy, spec)],
        )

    return OpStrategy([_replicated_op_spec(s) for s in input_strategy.strategies])


@register_op_strategy(
    [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
    schema_info=RuntimeSchemaInfo(1),
)
def softmax_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for the softmax family: follow each input strategy, but require
    the input to be replicated along the softmax dim (partials are replicated
    too). The output uses the same spec as the redistributed input.
    """
    input_strategy, softmax_dim, *_ = op_schema.args_schema
    input_strategy = cast(OpStrategy, input_strategy)
    softmax_dim = normalize_dim(cast(int, softmax_dim), input_strategy.ndim)

    result = OpStrategy([])
    for candidate in input_strategy.strategies:
        source_spec = candidate.output_spec

        # make sure input is replicated along the softmax dim
        target_spec = DTensorSpec(
            mesh=input_strategy.mesh,
            placements=replicate_reduction_dims(source_spec.placements, [softmax_dim]),
            tensor_meta=source_spec.tensor_meta,
        )
        costs = [generate_redistribute_costs(input_strategy, target_spec)]
        result.strategies.append(
            OpSpec(
                output_specs=target_spec,
                input_specs=[target_spec],
                redistribute_cost=costs,
            )
        )

    return result


@register_op_strategy(
    [
        aten._log_softmax_backward_data.default,
        aten._softmax_backward_data.default,
    ],
    schema_info=RuntimeSchemaInfo(2),
)
def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for softmax backward: pair up the strategies of ``grad_out`` and
    ``out``, follow whichever has more shards, and replicate along the
    softmax dim. Both inputs and the output share the resulting placements.
    """
    grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema
    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
    out_strategy = cast(OpStrategy, out_strategy)
    softmax_dim = normalize_dim(cast(int, softmax_dim), grad_out_strategy.ndim)

    result = OpStrategy([])
    paired = zip(grad_out_strategy.strategies, out_strategy.strategies)
    for grad_out_candidate, out_candidate in paired:
        grad_out_spec = grad_out_candidate.output_spec
        out_spec = out_candidate.output_spec

        # follow the sharding of the grad_out or out depending on which has
        # more shards (grad_out wins ties)
        if out_spec.num_shards > grad_out_spec.num_shards:
            followed_spec = out_spec
        else:
            followed_spec = grad_out_spec

        # make sure inputs are replicated along the softmax dim
        tgt_spec = DTensorSpec(
            mesh=grad_out_strategy.mesh,
            placements=replicate_reduction_dims(
                followed_spec.placements, [softmax_dim]
            ),
        )
        # same placements, but each input keeps its own tensor metadata
        input_specs = (
            DTensorSpec(
                mesh=tgt_spec.mesh,
                placements=tgt_spec.placements,
                tensor_meta=grad_out_spec.tensor_meta,
            ),
            DTensorSpec(
                mesh=tgt_spec.mesh,
                placements=tgt_spec.placements,
                tensor_meta=out_spec.tensor_meta,
            ),
        )
        costs = [
            generate_redistribute_costs(grad_out_strategy, tgt_spec),
            generate_redistribute_costs(out_strategy, tgt_spec),
        ]
        result.strategies.append(
            OpSpec(
                output_specs=tgt_spec,
                input_specs=input_specs,
                redistribute_cost=costs,
            )
        )

    return result


@register_op_strategy(
    [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default],
    schema_info=RuntimeSchemaInfo(3),
)
def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy:
    """
    Strategy for nll_loss forward.

    For every input strategy:
      * the input is replicated along the channel dim (dim 1, or dim 0 for a
        1-D input);
      * the target follows the input placements with the channel dim skipped;
      * the weight, when given, is fully replicated;
      * with no reduction the loss output follows the target and the total
        weight is replicated; otherwise both outputs become partial on every
        mesh dim where the target is sharded ("avg"/"sum" for the loss,
        "sum" for the total weight).

    Raises:
        AssertionError: if the op does not receive exactly 5 args, or the
            weight is neither None nor an OpStrategy.
        ValueError: for mean reduction when the target cannot be evenly
            sharded, since the averaged result would be biased.
    """
    mesh = op_schema.get_mesh_from_args()

    if len(op_schema.args_schema) != 5:
        raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}")

    (
        input_strategy,
        target_strategy,
        weight_strategy,
        reduction,
        _,
    ) = op_schema.args_schema
    input_strategy = cast(OpStrategy, input_strategy)
    target_strategy = cast(OpStrategy, target_strategy)
    reduction = cast(int, reduction)

    input_shape = input_strategy.shape
    channel_dim = 1 if len(input_shape) >= 2 else 0

    output_strategy = OpStrategy([])
    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
        op_args_target_specs = []
        redistribute_costs = []

        # make sure input is replicated along the channel dim
        input_src_spec = input_placement_strategy.output_spec
        input_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=replicate_reduction_dims(
                input_src_spec.placements, [channel_dim]
            ),
            tensor_meta=input_src_spec.tensor_meta,
        )
        op_args_target_specs.append(input_expected_spec)
        redistribute_costs.append(
            generate_redistribute_costs(input_strategy, input_expected_spec)
        )

        # target doesn't have channel dim, and it follows input on other dims
        target_src_spec = target_strategy.strategies[idx].output_spec
        target_expected_spec = DTensorSpec(
            mesh=mesh,
            placements=_skip_dim(input_expected_spec.placements, channel_dim),
            tensor_meta=target_src_spec.tensor_meta,
        )
        op_args_target_specs.append(target_expected_spec)
        redistribute_costs.append(
            generate_redistribute_costs(target_strategy, target_expected_spec)
        )

        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
        # make sure it is replicated
        if weight_strategy is not None:
            if not isinstance(weight_strategy, OpStrategy):
                raise AssertionError(
                    f"Expected OpStrategy, got {type(weight_strategy)}"
                )
            weight_src_spec = weight_strategy.strategies[idx].output_spec
            weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(weight_src_spec.placements),
                tensor_meta=weight_src_spec.tensor_meta,
            )
            op_args_target_specs.append(weight_expected_spec)
            redistribute_costs.append(
                generate_redistribute_costs(weight_strategy, weight_expected_spec)
            )

        if reduction == Reduction.NONE.value:
            output_expected_spec = target_expected_spec
            total_weight_expected_spec = DTensorSpec(
                mesh=mesh, placements=tuple([Replicate()] * mesh.ndim)
            )
        else:
            if reduction == Reduction.MEAN.value:
                reduction_op = "avg"
                if not is_tensor_evenly_shardable(
                    target_expected_spec.shape, target_expected_spec
                ):
                    # Adjacent literals keep the message on one line without
                    # embedding source indentation in it.
                    raise ValueError(
                        "The intermediate results of nll_loss cannot be evenly "
                        "sharded, resulting in biased mean result."
                    )
            else:  # reduction == Reduction.SUM.value:
                reduction_op = "sum"
            reduce_dims = list(range(target_expected_spec.ndim))
            reduce_dims_map = _infer_reduce_dims_map(
                reduce_dims, target_expected_spec.ndim, keep_dim=False
            )
            out_placements = map_placements_after_reduction(
                target_expected_spec.placements,
                reduce_dims,
                reduce_dims_map,
                reduction_op,
            )
            output_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=out_placements,
            )

            # whether reduction is sum or mean, the total weight has to be summed up if not replicated
            total_weight_placements = map_placements_after_reduction(
                target_expected_spec.placements,
                reduce_dims,
                reduce_dims_map,
                "sum",
            )
            total_weight_expected_spec = DTensorSpec(
                mesh=mesh,
                placements=total_weight_placements,
            )

        output_strategy.strategies.append(
            OpSpec(
                output_specs=(output_expected_spec, total_weight_expected_spec),
                input_specs=op_args_target_specs,
                redistribute_cost=redistribute_costs,
            )
        )

    return output_strategy


@register_op_strategy(
    [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default],
    schema_info=RuntimeSchemaInfo(4),
)
def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``nll_loss_backward`` / ``nll_loss2d_backward``.

    One candidate ``OpSpec`` is produced per strategy of the ``input`` arg.
    Within each candidate:

    * ``input`` is replicated along the channel dim, and the produced
      ``grad_input`` reuses that very same spec;
    * ``target`` has no channel dim and follows ``input`` on the other dims;
    * ``grad_out`` reuses the target spec when there is no reduction,
      otherwise it is replicated;
    * ``weight`` (when given) and ``total_weight`` are replicated.

    The expected specs and redistribute costs are listed in the op's
    positional order: grad_out, input, target, [weight], total_weight.

    Raises:
        AssertionError: if the schema does not hold exactly 7 args, or if a
            non-``None`` weight arg is not an ``OpStrategy``.
    """
    # the forward op already validated the mesh, so validation is skipped here
    mesh = op_schema.get_mesh_from_args(validate=False)

    schema_args = op_schema.args_schema
    if len(schema_args) != 7:
        raise AssertionError(f"Expected 7 args, got {len(schema_args)}")
    (
        grad_out_strategy,
        input_strategy,
        target_strategy,
        weight_strategy,
        reduction,
        _,
        total_weight_strategy,
    ) = schema_args
    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
    input_strategy = cast(OpStrategy, input_strategy)
    target_strategy = cast(OpStrategy, target_strategy)
    reduction = cast(int, reduction)
    total_weight_strategy = cast(OpStrategy, total_weight_strategy)

    # inputs with at least 2 dims carry the channel at dim 1, otherwise at dim 0
    channel_dim = 0 if len(input_strategy.shape) < 2 else 1
    no_reduction = reduction == Reduction.NONE.value

    grad_in_strategy = OpStrategy([])
    for idx, input_candidate in enumerate(input_strategy.strategies):
        # input: must be replicated along the channel dim
        input_src = input_candidate.output_spec
        input_spec = DTensorSpec(
            mesh=mesh,
            placements=replicate_reduction_dims(input_src.placements, [channel_dim]),
            tensor_meta=input_src.tensor_meta,
        )
        input_costs = generate_redistribute_costs(input_strategy, input_spec)

        # target: input placements with the channel dim dropped
        target_src = target_strategy.strategies[idx].output_spec
        target_spec = DTensorSpec(
            mesh=mesh,
            placements=_skip_dim(input_spec.placements, channel_dim),
            tensor_meta=target_src.tensor_meta,
        )
        target_costs = generate_redistribute_costs(target_strategy, target_spec)

        # grad_out: same spec object as target without reduction (note that
        # this also carries the target's tensor_meta); replicated otherwise
        grad_out_src = grad_out_strategy.strategies[idx].output_spec
        if no_reduction:
            grad_out_spec = target_spec
        else:
            grad_out_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(grad_out_src.placements),
                tensor_meta=grad_out_src.tensor_meta,
            )
        grad_out_costs = generate_redistribute_costs(grad_out_strategy, grad_out_spec)

        # assemble in positional order; grad_out is the first op argument
        arg_specs = [grad_out_spec, input_spec, target_spec]
        arg_costs = [grad_out_costs, input_costs, target_costs]

        # weight: optional tensor of size input_shape[channel_dim], replicated
        if weight_strategy is not None:
            if not isinstance(weight_strategy, OpStrategy):
                raise AssertionError(
                    f"Expected OpStrategy, got {type(weight_strategy)}"
                )
            weight_src = weight_strategy.strategies[idx].output_spec
            weight_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(weight_src.placements),
                tensor_meta=weight_src.tensor_meta,
            )
            arg_specs.append(weight_spec)
            arg_costs.append(generate_redistribute_costs(weight_strategy, weight_spec))

        # total_weight: always replicated
        total_weight_src = total_weight_strategy.strategies[idx].output_spec
        total_weight_spec = DTensorSpec(
            mesh=mesh,
            placements=_replicate_dims_start_at(total_weight_src.placements),
            tensor_meta=total_weight_src.tensor_meta,
        )
        arg_specs.append(total_weight_spec)
        arg_costs.append(
            generate_redistribute_costs(total_weight_strategy, total_weight_spec)
        )

        # grad_input is laid out exactly like the (redistributed) input
        grad_in_strategy.strategies.append(
            OpSpec(
                output_specs=input_spec,
                input_specs=arg_specs,
                redistribute_cost=arg_costs,
            )
        )

    return grad_in_strategy


def _common_norm_forward_strategy(
    op_schema: OpSchema,
    rms_norm: bool = False,
) -> OpStrategy:
    """Shared forward sharding strategy for layer_norm and rms_norm.

    Args:
        op_schema: schema of the forward op. layer_norm carries 5 args
            (input, normalized_shape, weight, bias, eps); rms_norm carries 4
            (input, normalized_shape, weight, eps). A ``None`` weight/bias
            shows up as ``None`` in the schema.
        rms_norm: when true, use the rms_norm argument layout (no bias).

    Returns:
        An ``OpStrategy`` with one ``OpSpec`` per input strategy. A single
        output spec is used because all returned values share the same
        placements, and it is the same spec the input is redistributed to.

    Raises:
        AssertionError: on an unexpected arg count, a non-``OpStrategy``
            input/weight/bias, or an unsupported ``normalized_shape`` type.
    """
    mesh = op_schema.get_mesh_from_args()

    schema_args = op_schema.args_schema
    if rms_norm:
        # rms_norm args: input, normalized_shape, weight, eps
        if len(schema_args) != 4:
            raise AssertionError(f"Expected 4 args, got {len(schema_args)}")
        input_strategy, normalized_shape, weight_strategy, _ = schema_args
        bias_strategy = None
    else:
        # layer_norm args: input, normalized_shape, weight, bias, eps
        if len(schema_args) != 5:
            raise AssertionError(f"Expected 5 args, got {len(schema_args)}")
        (
            input_strategy,
            normalized_shape,
            weight_strategy,
            bias_strategy,
            _,
        ) = schema_args

    # the current norm implementation needs the input sharding as an OpStrategy
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    if not isinstance(normalized_shape, (int, Sequence, torch.Size)):
        raise AssertionError(
            f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}"
        )
    normalized_size = normalize_to_torch_size(normalized_shape)

    # first dim that belongs to the normalized (inner) part of the input
    axis = input_strategy.ndim - len(normalized_size)

    output_strategy = OpStrategy([])
    for idx, input_candidate in enumerate(input_strategy.strategies):
        input_src_spec = input_candidate.output_spec

        # input: replicate on the inner dims (from `axis` on) if necessary
        # TODO: we can avoid forcing the redistribution once we figure out
        # how to decompose layer norm
        input_target_spec = DTensorSpec(
            mesh=mesh,
            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
            tensor_meta=input_src_spec.tensor_meta,
        )
        arg_specs = [input_target_spec]
        arg_costs = [generate_redistribute_costs(input_strategy, input_target_spec)]

        # weight, then bias: each one, when present, is replicated on all dims
        # TODO: we can avoid forcing the redistribution once we figure out
        # how to decompose layer norm
        for param_strategy in (weight_strategy, bias_strategy):
            if param_strategy is None:
                continue
            if not isinstance(param_strategy, OpStrategy):
                raise AssertionError(
                    f"Expected OpStrategy, got {type(param_strategy)}"
                )
            param_src_spec = param_strategy.strategies[idx].output_spec
            param_target_spec = DTensorSpec(
                mesh=mesh,
                placements=_replicate_dims_start_at(param_src_spec.placements),
                tensor_meta=param_src_spec.tensor_meta,
            )
            arg_specs.append(param_target_spec)
            arg_costs.append(
                generate_redistribute_costs(param_strategy, param_target_spec)
            )

        # the output spec is the same object as the input target spec
        output_strategy.strategies.append(
            OpSpec(
                output_specs=input_target_spec,
                input_specs=arg_specs,
                redistribute_cost=arg_costs,
            )
        )

    return output_strategy


@register_op_strategy(
    [aten.native_layer_norm.default],
    schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.native_layer_norm``.

    Thin entry point that delegates to the shared norm forward logic using
    the layer_norm argument layout.
    """
    return _common_norm_forward_strategy(op_schema, rms_norm=False)


@register_op_strategy(
    [aten._fused_rms_norm.default],
    schema_info=RuntimeSchemaInfo(1),
)
def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten._fused_rms_norm``.

    Thin entry point that delegates to the shared norm forward logic using
    the rms_norm argument layout.
    """
    strategy = _common_norm_forward_strategy(op_schema, rms_norm=True)
    return strategy


def _common_norm_backward_strategy(
    op_schema: OpSchema,
    rms_norm: bool = False,
) -> OpStrategy:
    """Common backward strategy logic for layer_norm and rms_norm.

    One ``OpSpec`` is built per strategy of the ``input`` arg. Each ``OpSpec``
    lists its input specs in the op's positional order
    (grad_out, input, [mean], rstd, [weight], [bias]) and carries a tuple of
    output specs ``(d_input, d_weight[, d_bias])``. An output entry is ``None``
    when the matching ``output_mask`` flag is falsy or when the corresponding
    weight/bias argument is ``None``.

    Args:
        op_schema: schema of the backward op. It must hold 8 args for
            layer_norm and 6 args for rms_norm.
        rms_norm: when true, use the rms_norm argument layout, which has no
            mean and no bias and a 2-element ``output_mask``.

    Returns:
        An ``OpStrategy`` whose entries each describe the tuple of gradients.

    Raises:
        AssertionError: on an unexpected arg count, a non-``OpStrategy``
            tensor argument, an unsupported ``normalized_shape`` type, an
            ``output_mask`` of the wrong type/length, or an ``output_mask``
            flag that is not ``False`` while weight/bias is ``None``.
    """
    # backward op does not need to validate the mesh since forward op has already done it
    mesh = op_schema.get_mesh_from_args(validate=False)

    if not rms_norm:
        # layer_norm args: grad_out, input, normalized_shape, mean, rstd,
        # weight, bias, output_mask. For None weight and bias, their
        # corresponding objects will be None as well.
        if not len(op_schema.args_schema) == 8:
            raise AssertionError(f"Expected 8 args, got {len(op_schema.args_schema)}")
        (
            grad_out_strategy,
            input_strategy,
            normalized_shape,
            mean_strategy,
            rstd_strategy,
            weight_strategy,
            bias_strategy,
            output_mask,
        ) = op_schema.args_schema
    else:
        # rms_norm args: grad_out, input, normalized_shape, rstd,
        # weight, output_mask. There is no mean and no bias.
        if not len(op_schema.args_schema) == 6:
            raise AssertionError(f"Expected 6 args, got {len(op_schema.args_schema)}")
        (
            grad_out_strategy,
            input_strategy,
            normalized_shape,
            rstd_strategy,
            weight_strategy,
            output_mask,
        ) = op_schema.args_schema
        mean_strategy = None
        bias_strategy = None

    # all tensor arguments that are always present must come as OpStrategy
    if not isinstance(grad_out_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(grad_out_strategy)}")
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    if not isinstance(rstd_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(rstd_strategy)}")
    if mean_strategy is not None:
        if not isinstance(mean_strategy, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {type(mean_strategy)}")

    if not isinstance(normalized_shape, (int, Sequence, torch.Size)):
        raise AssertionError(
            f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}"
        )
    normalized_size = normalize_to_torch_size(normalized_shape)
    input_ndim = input_strategy.ndim
    # `axis` is the first normalized dim; every dim before it is an outer dim
    axis = input_ndim - len(normalized_size)
    outer_dims = list(range(axis))

    # output_mask has one flag per returned gradient: 3 for layer_norm, 2 for rms_norm
    if not rms_norm:
        if not (isinstance(output_mask, list) and len(output_mask) == 3):
            raise AssertionError(
                f"Expected output_mask to be list of length 3, got {type(output_mask)} "
                f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}"
            )
    else:
        if not (isinstance(output_mask, list) and len(output_mask) == 2):
            raise AssertionError(
                f"Expected output_mask to be list of length 2, got {type(output_mask)} "
                f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}"
            )

    # output tuple: (d_input, d_weight[, d_bias])
    out_tuple_strategy = OpStrategy([])
    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
        # args for OpSpec
        # NOTE: input_specs_list and redistribute_costs are appended in lock
        # step and must follow the op's positional argument order.
        output_specs_list: list[DTensorSpec | None] = []
        input_specs_list: list[DTensorSpec] = []
        redistribute_costs = []

        input_src_spec = input_placement_strategy.output_spec
        # arg: grad_out
        # TODO: change the strategy to the following rule.
        # d_input is basically a product of element-wise mul of
        # grad_out, rstd, and normalized input, among which rstd
        # and normalized input (x_hat) should have the same sharding
        # placements, and grad_out's sharding is determined by the
        # pointwise result of x_hat and weight/bias.
        # TODO: now grad_out spec follows input spec. we may need
        # to change it to apply a pointwise rule over grad_out,
        # input, and weight.
        # Note that this spec is built from the *input's* placements and
        # tensor_meta, and it doubles as the d_input output spec below.
        grad_out_target_spec = DTensorSpec(
            mesh=mesh,
            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
            tensor_meta=input_src_spec.tensor_meta,
        )
        input_specs_list.append(grad_out_target_spec)
        redistribute_costs.append(
            generate_redistribute_costs(grad_out_strategy, grad_out_target_spec)
        )
        # output: d_input (only when requested through output_mask[0])
        output_specs_list.append(grad_out_target_spec if output_mask[0] else None)

        # arg: input
        input_target_spec = DTensorSpec(
            mesh=mesh,
            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
            tensor_meta=input_src_spec.tensor_meta,
        )
        input_specs_list.append(input_target_spec)
        redistribute_costs.append(
            generate_redistribute_costs(input_strategy, input_target_spec)
        )

        # arg: mean
        # (layer_norm only) kept in its source spec with zero redistribute cost
        if not rms_norm:
            if mean_strategy is None:
                raise AssertionError("Expected mean_strategy to not be None")
            mean_src_spec = mean_strategy.strategies[idx].output_spec
            input_specs_list.append(mean_src_spec)
            redistribute_costs.append([0.0 for _ in mean_strategy.strategies])

        # arg: rstd
        # kept in its source spec with zero redistribute cost
        rstd_src_spec = rstd_strategy.strategies[idx].output_spec
        input_specs_list.append(rstd_src_spec)
        redistribute_costs.append([0.0 for _ in rstd_strategy.strategies])

        def _add_target_input_spec(strategy) -> DTensorSpec:
            # shared logic for setting the weight and bias target input specs:
            # appends the source spec and a zero cost list for the current
            # iteration (reads `idx` and mutates the two per-iteration lists),
            # then returns that source spec
            if not isinstance(strategy, OpStrategy):
                raise AssertionError(f"Expected OpStrategy, got {type(strategy)}")
            src_spec = strategy.strategies[idx].output_spec
            # no need to redistribute since they should be replicated in forward pass
            input_specs_list.append(src_spec)
            redistribute_costs.append([0.0 for _ in strategy.strategies])
            return src_spec

        # arg: weight
        # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False)
        # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False)
        # NOTE(review): if rstd is the reciprocal std (as the name suggests),
        # the formulas above should multiply by rstd instead of dividing;
        # this only concerns the comment, not the placement logic - confirm.
        if weight_strategy is not None:
            weight_src_spec = _add_target_input_spec(weight_strategy)
            # TODO: now d_weight spec follows input spec w/ a reduction.
            # we may need to change to a pointwise rule over grad_out and
            # input, then apply a reduction.
            inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis)
            reduce_dims_map = _infer_reduce_dims_map(
                outer_dims, input_src_spec.ndim, False
            )
            out_placements = map_placements_after_reduction(
                inp_placements, outer_dims, reduce_dims_map, "sum"
            )
            weight_out_spec = DTensorSpec(
                mesh=mesh,
                placements=out_placements,
                tensor_meta=weight_src_spec.tensor_meta,
            )
            # output: d_weight (only when requested through output_mask[1])
            output_specs_list.append(weight_out_spec if output_mask[1] else None)
        else:
            # without a weight argument, d_weight must not be requested
            if not rms_norm:
                error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward."
            else:
                error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward."
            if output_mask[1] is not False:
                raise AssertionError(error_msg)
            output_specs_list.append(None)

        # arg: bias
        # d_bias = sum(grad_out, outer_dim, keepdim=False)
        # (layer_norm only; rms_norm has neither a bias input nor a d_bias output)
        if not rms_norm:
            if bias_strategy is not None:
                bias_src_spec = _add_target_input_spec(bias_strategy)
                # d_bias spec follows a reduction over grad_out
                inp_placements = _replicate_dims_start_at(
                    grad_out_target_spec.placements, axis
                )
                reduce_dims_map = _infer_reduce_dims_map(
                    outer_dims, grad_out_target_spec.ndim, False
                )
                out_placements = map_placements_after_reduction(
                    inp_placements, outer_dims, reduce_dims_map, "sum"
                )
                bias_out_spec = DTensorSpec(
                    mesh=mesh,
                    placements=out_placements,
                    tensor_meta=bias_src_spec.tensor_meta,
                )
                # output: d_bias (only when requested through output_mask[2])
                output_specs_list.append(bias_out_spec if output_mask[2] else None)
            else:
                # without a bias argument, d_bias must not be requested
                if output_mask[2] is not False:
                    raise AssertionError(
                        "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward."
                    )
                output_specs_list.append(None)

        out_tuple_strategy.strategies.append(
            OpSpec(
                output_specs=tuple(output_specs_list),
                input_specs=input_specs_list,
                redistribute_cost=redistribute_costs,
            )
        )

    return out_tuple_strategy


@register_op_strategy(
    [aten.native_layer_norm_backward.default],
    schema_info=RuntimeSchemaInfo(2),
)
def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.native_layer_norm_backward``.

    Thin entry point that delegates to the shared norm backward logic using
    the layer_norm argument layout.
    """
    return _common_norm_backward_strategy(op_schema, rms_norm=False)


@register_op_strategy(
    [aten._fused_rms_norm_backward.default],
    schema_info=RuntimeSchemaInfo(2),
)
def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten._fused_rms_norm_backward``.

    Thin entry point that delegates to the shared norm backward logic using
    the rms_norm argument layout.
    """
    strategy = _common_norm_backward_strategy(op_schema, rms_norm=True)
    return strategy


def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy:
    """Shared sharding strategy for ops that order elements along one dim.

    The candidates per mesh dim are "everything replicated" plus, for each
    tensor dim other than ``sort_dim``, "everything sharded on that dim".
    The sorted dim itself is never offered as a sharding dim.

    Args:
        op_schema: schema of the op; its first arg is the input strategy.
        sort_dim: dim the op sorts along; it is normalized against the input
            rank before use.

    Returns:
        The strategy expanded over the full mesh of the input.
    """
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    ndim = input_strategy.ndim
    sort_dim = normalize_dim(sort_dim, ndim)

    # each candidate has three slots; with input_index=2 these are presumably
    # the two outputs followed by the input - see expand_to_full_mesh_op_strategy
    replicated: PlacementList = [Replicate()] * 3
    candidates = [replicated]
    shardable_dims = [d for d in range(ndim) if d != sort_dim]
    for d in shardable_dims:
        sharded: PlacementList = [Shard(d)] * 3
        candidates.append(sharded)

    return expand_to_full_mesh_op_strategy(
        input_strategy.mesh, op_schema, candidates, input_index=2
    )


@register_op_strategy(
    [aten.topk.default],
    schema_info=RuntimeSchemaInfo(2),
)
def topk_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.topk``; reuses the sort strategy.

    The dim is read from the third positional arg when present and falls
    back to ``-1`` otherwise.
    """
    schema_args = op_schema.args_schema
    if len(schema_args) > 2:
        topk_dim = cast(int, schema_args[2])
    else:
        topk_dim = -1
    return sort_strategy(op_schema, topk_dim)


@register_op_strategy(
    aten.sort.default,
    schema_info=RuntimeSchemaInfo(1),
)
def sort_default_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.sort.default``; reuses the sort strategy.

    The dim is read from the second positional arg when present and falls
    back to ``-1`` otherwise.

    Raises:
        AssertionError: if the first arg is not an ``OpStrategy``.
    """
    schema_args = op_schema.args_schema
    input_strategy = schema_args[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    sort_dim = cast(int, schema_args[1]) if len(schema_args) > 1 else -1
    return sort_strategy(op_schema, sort_dim)


@register_op_strategy(
    aten.sort.stable,
    schema_info=RuntimeSchemaInfo(
        1,
        static_kwargkey=["dim", "descending", "stable"],
    ),
)
def sort_stable_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.sort.stable``; reuses the sort strategy.

    Unlike the default overload, the dim is looked up in the keyword
    arguments (key ``"dim"``) and falls back to ``-1`` when absent.

    Raises:
        AssertionError: if the first arg is not an ``OpStrategy``.
    """
    input_strategy = op_schema.args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")
    sort_dim = cast(int, op_schema.kwargs_schema.get("dim", -1))
    return sort_strategy(op_schema, sort_dim)


@register_op_strategy(
    [aten.histc.default],
    # the chosen strategy depends on 'min' and 'max' (positions 2 and 3), so
    # args from position 2 on take part in the schema hash
    schema_info=RuntimeSchemaInfo(2),
)
def histc_strategy(op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.histc``.

    Full replication is always offered. A sharded input with a partial
    output is additionally offered for every input dim, but only when the
    schema holds all 4 positional args, i.e. min and max are user-specified.
    Without them each rank would derive the bin boundaries from the min/max
    of its own local data, the boundaries would differ across ranks, and the
    combined result would be incorrect.
    """
    input_strategy = cast(OpStrategy, op_schema.args_schema[0])

    # each candidate is [output placement, input placement]
    candidates: list[PlacementList] = [[Replicate(), Replicate()]]

    min_max_given = len(op_schema.args_schema) == 4
    if min_max_given:
        sharded: list[PlacementList] = [
            [Partial(), Shard(dim)] for dim in range(input_strategy.ndim)
        ]
        candidates += sharded

    return expand_to_full_mesh_op_strategy(input_strategy.mesh, op_schema, candidates)


@register_op_strategy(
    [aten.logsumexp.default],
    schema_info=RuntimeSchemaInfo(
        # static_argnum is the position where non-Tensor args begin.
        static_argnum=1,
        # static_kwargkey is the name of kwargs to hash (which determines
        # whether sharding prop can be cached).
        static_kwargkey=["keepdim"],
    ),
)
def logsumexp_strategy(op_schema: OpSchema) -> OpStrategy:
    """Implements the sharding propagation strategy for logsumexp.

    The op is handled as a non-linear reduction over the requested dims.

    Args:
        op_schema: schema of the op. ``args_schema`` holds the input strategy
            at position 0, the reduction dims at position 1 and, optionally,
            ``keepdim`` at position 2. ``keepdim`` may alternatively be given
            through ``kwargs_schema``.

    Returns:
        The reduction strategy produced by ``common_reduction_strategy``.

    Raises:
        AssertionError: if the dims arg is missing, the input is not an
            ``OpStrategy``, or the reduction dims cannot be inferred.
    """
    # args_schema holds the input strategy followed by the non-Tensor
    # positional args (dim and, when passed positionally, keepdim).
    args_schema = op_schema.args_schema
    if not len(args_schema) > 1:
        raise AssertionError(
            f"Expected more than 1 arg (input and dim are required), got {len(args_schema)}"
        )

    input_strategy = args_schema[0]
    if not isinstance(input_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}")

    dims_arg = args_schema[1]
    reduce_dims = _infer_reduction_dims(dims_arg, input_strategy.ndim)
    if reduce_dims is None:
        raise AssertionError("Expected reduce_dims to not be None")

    # keepdim is a positional-or-keyword argument of the aten op, so it can
    # show up as the third positional arg. Reading only the kwargs would
    # silently treat a positional keepdim=True as False and map the output
    # placements onto the wrong dims. Prefer the positional slot and fall
    # back to the keyword form.
    if len(args_schema) > 2:
        keep_dim = bool(args_schema[2])
    else:
        keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))

    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        reduction_linear=False,
    )
