import itertools
import logging
from collections.abc import Callable
from typing import Any

import sympy

import torch
import torch._inductor.config as config
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
    Buffer,
    FixedLayout,
    get_free_symbols,
    get_symbolic_inputs,
    gm_original_output_strides,
    ir_node_to_tensor,
    Layout,
)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import do_bench_using_profiling
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


def inline_subgraph_to_ir_nodes(
    gm: torch.fx.GraphModule, inputs: list[Any], name: str
) -> Any:
    """Lower the nodes of ``gm`` one by one into the graph currently being built.

    Instead of keeping the subgraph opaque, its FX operations become individual
    (fusable) IR nodes, which allows later operations to be fused as epilogues.

    Args:
        gm: The subgraph whose FX nodes are processed.
        inputs: Values forwarded to ``process_subgraph_nodes`` alongside ``gm``.
        name: Currently not used by this function.

    Returns:
        The value produced by ``process_subgraph_nodes`` (a TensorBox holding
        the result of the final operation).
    """
    from torch._inductor.lowering import process_subgraph_nodes

    # V.graph.module must point at the subgraph while its nodes are processed.
    # The previous module is always put back afterwards (even on failure) so
    # that later IR nodes are not attributed to the wrong graph.
    saved_module = V.graph.module
    try:
        V.graph.module = gm
        lowered = process_subgraph_nodes(gm, inputs)
    finally:
        V.graph.module = saved_module
    return lowered


class SubgraphChoiceCaller(ir.ChoiceCaller):
    """
    Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
    GraphModule. Compiles the Subgraph down to a module for benchmarking.
    """

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        description: str,
        make_fx_graph: Callable[..., Any],
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
    ) -> None:
        super().__init__(name, input_nodes, layout, description)

        self.example_inputs = []
        with V.fake_mode:
            for i, inp in enumerate(self.input_nodes):
                # Here there will be no unbacked symbols, as SubgraphBuffer does not support them
                assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0
                assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0

                inp.data.freeze_layout()  # type: ignore[attr-defined]

                # Use input_gen_fn if provided for this index (e.g., for range-based autotuning)
                # This ensures the graph is traced with correct target shapes
                if input_gen_fns is not None and i in input_gen_fns:
                    self.example_inputs.append(input_gen_fns[i](inp))
                else:
                    self.example_inputs.append(ir_node_to_tensor(inp))

        self.gm = make_fx_graph(*self.example_inputs)
        gm_original_output_strides(self.gm)

        self.sym_inputs = get_symbolic_inputs(self.input_nodes)
        self.sym_input_values = self._compute_sym_input_values()

        # Cached decomposition info for range-based dispatch (set via cache_decomposition)
        self.decomposition: Callable[..., Any] | None = None
        self.decomposition_kwargs: dict[str, Any] = {}
        # Cache compiled module to avoid recompiling on every benchmark call
        self._compiled_module: Any = None

    def _compute_sym_input_values(self) -> list[int]:
        """Extract concrete dimension values for sym_inputs from example_inputs.

        The compiled module expects symbolic dimension values as runtime arguments.
        This maps each symbolic variable to its concrete value from the example tensors.
        Used for range based autotuning.
        """
        sym_input_names = OrderedSet(
            [s.name for s in self.sym_inputs if hasattr(s, "name")]
        )

        # Build mapping: symbolic dimension name -> concrete value
        sym_name_to_value: dict[str, int] = {}
        for inp_node, example_inp in zip(self.input_nodes, self.example_inputs):
            if isinstance(example_inp, torch.Tensor):
                for sym_dim, actual_dim in zip(inp_node.get_size(), example_inp.shape):
                    if isinstance(sym_dim, sympy.Symbol):
                        sym_name_to_value[sym_dim.name] = int(actual_dim)
                    elif str(sym_dim) in sym_input_names:
                        sym_name_to_value[str(sym_dim)] = int(actual_dim)

        result = []
        for sym_var in self.sym_inputs:
            if isinstance(sym_var, sympy.Symbol) and sym_var.name in sym_name_to_value:
                result.append(sym_name_to_value[sym_var.name])
            else:
                hint = V.graph.sizevars.shape_env.size_hint(sym_var)
                result.append(int(hint) if hint is not None else 1)
        return result

    def cache_decomposition(
        self, decomposition: Callable[..., Any], kwargs: dict[str, Any]
    ) -> None:
        """Cache decomposition function and kwargs for range-based dispatch lookup."""
        self.decomposition = decomposition
        self.decomposition_kwargs = kwargs

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def _compile_for_benchmarking(self) -> Any:
        """Compile the subgraph for benchmarking, returns the compiled module."""
        from torch._inductor.graph import GraphLowering

        safe_name = self.name.replace("::", "_").replace(".", "_")

        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{safe_name}",
        )

        for sym_inp in self.sym_inputs:
            bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp
            bm_graph_lowering.graph_input_names.append(sym_inp.name)

        with V.set_graph_handler(bm_graph_lowering):
            # Don't bother autotuning on Triton here
            with config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                bm_graph_lowering.run(*self.example_inputs)
                return bm_graph_lowering.compile_to_module()

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        """Regular benchmarking: compile and use benchmarker with warmup/rep."""
        if self._compiled_module is None:
            self._compiled_module = self._compile_for_benchmarking()

        bm_func = self._compiled_module.call
        sym_inputs = self.sym_input_values
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
        return benchmarker.benchmark(
            # Shallow clone args since bm_func may clear args
            lambda: bm_func([*sym_inputs, *args]),
            device=benchmarker.infer_device(*sym_inputs, *args),
        )

    def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None:
        """Run once for collective benchmarking (barrier sync handled by caller)."""
        if self._compiled_module is None:
            self._compiled_module = self._compile_for_benchmarking()

        self._compiled_module.call([*self.sym_input_values, *args])

    def hash_key(self) -> str:
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                *[str(inp.get_size()) for inp in self.input_nodes],
                *[str(inp.get_stride()) for inp in self.input_nodes],
                str(self.gm.graph),
            ]
        )

    def output_node(self) -> ir.TensorBox:
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        return f"subgraph_{self.name}"


class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    This class allows creating customized subgraphs that can be appended
    as choices during the autotuning process, enabling the selection of
    optimal implementations for complex operations.
    """

    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
            graph: The FX graph
        """
        super().__init__(name=name)

    def generate(  # type: ignore[override]
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_fx_graph: Callable[..., Any],
        description: str = "",
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            name: The name for this subgraph choice
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            make_fx_graph: Callable that creates the FX graph for this subgraph
            description: Optional description of this choice
            input_gen_fns: Optional dict mapping input indices to tensor generators
            **kwargs: Additional keyword arguments

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """

        return SubgraphChoiceCaller(
            name=f"{name}_{next(SubgraphTemplate.index_counter)}",
            input_nodes=input_nodes,
            layout=layout,
            description=description,
            make_fx_graph=make_fx_graph,
            input_gen_fns=input_gen_fns,
        )

    def generate_custom_op_choices(
        self,
        name: str,
        decompositions: list[Callable[..., Any]],
        input_nodes: list[Buffer],
        non_tensor_args: list[dict[str, Any]],
        default_impl: Callable[..., Any] | None = None,
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
    ) -> list[SubgraphChoiceCaller]:
        """
        Generate multiple SubgraphChoiceCaller instances for custom op autotuning.

        This method extends SubgraphTemplate to support custom op decompositions,
        allowing multiple implementations to compete in autotuning.

        Args:
            name: Base name for the choices
            decompositions: List of decomposition functions to compete in autotuning
            input_nodes: List of tensor inputs. All tensor arguments must be passed here.
            non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition.
            default_impl: Default implementation for layout inference
            input_gen_fns: Optional dict mapping input indices to tensor generators

        Returns:
            List of SubgraphChoiceCaller instances for autotuning
        """
        if not decompositions:
            return []

        assert len(decompositions) == len(non_tensor_args), (
            f"decompositions and non_tensor_args must have same length, "
            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
        )

        # Infer layouts and ensure layout consistency for fair autotuning comparison
        layouts = [
            self._infer_custom_op_layout(
                input_nodes, decomp, kwargs, default_impl, input_gen_fns
            )
            for decomp, kwargs in zip(decompositions, non_tensor_args)
        ]

        # Validate all decompositions produce equivalent layouts for fair comparison
        self._validate_layout_equivalence(name, decompositions, layouts)
        layout = layouts[0]  # All layouts are now validated to be equivalent

        choices: list[SubgraphChoiceCaller] = []
        for decomp, decomp_kwargs in zip(decompositions, non_tensor_args):
            # Create make_fx_graph function for this decomposition
            import functools

            def make_fx_graph(
                *args: Any,
                decomp: Callable[..., Any] = decomp,
                decomp_kwargs: dict[str, Any] = decomp_kwargs,
            ) -> Any:
                # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
                from torch.fx.experimental.proxy_tensor import make_fx

                from ..decomposition import select_decomp_table

                decomposition_table = select_decomp_table()

                return make_fx(
                    functools.partial(decomp, **decomp_kwargs),
                    decomposition_table=decomposition_table,
                )(*args)

            # Generate descriptive name for this variant
            variant_name = self._generate_variant_name(decomp, decomp_kwargs)

            choice = self.generate(
                name=f"{name}_{variant_name}",
                input_nodes=input_nodes,
                layout=layout,
                make_fx_graph=make_fx_graph,
                description=f"CustomOp {decomp.__name__}",
                input_gen_fns=input_gen_fns,
            )
            # Cache decomposition info for range-based dispatch
            choice.cache_decomposition(decomp, decomp_kwargs)
            choices.append(choice)

        return choices

    def _generate_variant_name(
        self, decomp: Callable[..., Any], kwargs: dict[str, Any]
    ) -> str:
        """Generate a descriptive name for a decomposition variant with its parameters."""
        base_name = decomp.__name__
        if not kwargs:
            return base_name
        param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items()))
        return f"{base_name}_{param_suffix}"

    def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None:
        """Validate that kwargs contains only non-tensor arguments."""
        for key, value in kwargs.items():
            assert not isinstance(value, (torch.Tensor, Buffer)), (
                f"kwargs['{key}'] contains tensor {type(value)}. "
                f"Tensor arguments should be in input_nodes, not kwargs. "
                f"Only scalar/non-tensor parameters should be in kwargs."
            )

    def _validate_layout_equivalence(
        self,
        op_name: str,
        decompositions: list[Callable[..., Any]],
        layouts: list[Layout],
    ) -> None:
        """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning."""
        if not layouts:
            return

        reference = layouts[0]
        for i, layout in enumerate(layouts[1:], start=1):
            if (layout.device, layout.dtype, layout.size, layout.stride) != (
                reference.device,
                reference.dtype,
                reference.size,
                reference.stride,
            ):
                raise AssertionError(
                    f"Layout mismatch in custom op '{op_name}': "
                    f"decomposition '{decompositions[i].__name__}' produces "
                    f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) "
                    f"but '{decompositions[0].__name__}' produces "
                    f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})"
                )

    def _infer_custom_op_layout(
        self,
        input_nodes: list[Buffer],
        function_decomposition: Callable[..., Any],
        kwargs: dict[str, Any],
        default_impl: Callable[..., Any] | None = None,
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
    ) -> Layout:
        """Infer output layout for custom ops using the default implementation when available.
        Note that the Subgraph assumes custom ops return exactly one tensor output.
        TODO: Add support for multiple output custom ops.
        """
        import functools

        from torch._inductor.virtualized import V

        # Assert kwargs contain only non-tensor arguments
        self._validate_non_tensor_kwargs(kwargs)

        with V.fake_mode:
            example_inputs = []
            for i, inp in enumerate(input_nodes):
                if input_gen_fns and i in input_gen_fns:
                    fake_tensor = input_gen_fns[i](inp)
                else:
                    raw_shape = inp.get_size()
                    concrete_shape = V.graph.sizevars.size_hints(
                        raw_shape, fallback=config.unbacked_symint_fallback
                    )
                    raw_stride = inp.get_stride()
                    concrete_stride = V.graph.sizevars.size_hints(
                        raw_stride, fallback=config.unbacked_symint_fallback
                    )
                    fake_tensor = torch.empty_strided(
                        concrete_shape,
                        concrete_stride,
                        dtype=inp.get_dtype(),
                        device=inp.get_device(),
                    )
                example_inputs.append(fake_tensor)

            fn = functools.partial(function_decomposition, **kwargs)
            output = fn(*example_inputs)

            # Assert single output
            assert isinstance(output, torch.Tensor), (
                f"Expected single tensor output, got {type(output)}. "
                f"Multi-output custom ops not yet supported in autotuning."
            )

            return FixedLayout(
                device=output.device,
                dtype=output.dtype,
                size=output.shape,
                stride=output.stride(),
            )
