from ctypes import c_void_p
from typing import overload, Protocol

from torch import Tensor

# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp

# Tensor to AtenTensorHandle
def unsafe_alloc_void_ptrs_from_tensors(
    tensors: list[Tensor],
) -> list[c_void_p]:
    """Convert ``tensors`` into raw ``AtenTensorHandle`` pointers.

    Stub for a native binding in torch/csrc/inductor/aoti_runner/pybind.cpp.
    NOTE(review): the ``unsafe`` prefix presumably means the caller becomes
    responsible for the returned handles' lifetime -- confirm in the C++.
    """
def unsafe_alloc_void_ptr_from_tensor(
    tensor: Tensor,
) -> c_void_p:
    """Convert a single ``tensor`` into a raw ``AtenTensorHandle`` pointer.

    Single-tensor counterpart of ``unsafe_alloc_void_ptrs_from_tensors``.
    NOTE(review): handle lifetime is presumably the caller's responsibility
    -- confirm in torch/csrc/inductor/aoti_runner/pybind.cpp.
    """

# AtenTensorHandle to Tensor
def alloc_tensors_by_stealing_from_void_ptrs(handles: list[c_void_p]) -> list[Tensor]:
    """Wrap raw ``AtenTensorHandle`` pointers back into ``Tensor`` objects.

    Reverse direction of ``unsafe_alloc_void_ptrs_from_tensors``. Per the
    name the handles are "stolen": presumably ownership moves into the
    returned tensors, so ``handles`` should not be reused afterwards --
    TODO confirm in torch/csrc/inductor/aoti_runner/pybind.cpp.
    """
def alloc_tensor_by_stealing_from_void_ptr(handle: c_void_p) -> Tensor:
    """Wrap a single raw ``AtenTensorHandle`` pointer into a ``Tensor``.

    Single-handle counterpart of ``alloc_tensors_by_stealing_from_void_ptrs``.
    NOTE(review): "stealing" presumably transfers ownership of ``handle``
    into the returned tensor -- confirm in the native binding.
    """

class AOTIModelContainerRunner(Protocol):
    """Structural interface of the AOTI model-container runners.

    The device-specific ``AOTIModelContainerRunnerCpu``/``Cuda``/``Xpu``/
    ``Mps`` classes in this module declare exactly these method signatures,
    so each of them satisfies this Protocol without inheriting from it.
    """

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors.

        ``stream_handle`` is an opaque native pointer, presumably the stream
        to execute on (TODO confirm in pybind.cpp).
        """

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_names_to_original_fqns(self) -> dict[str, str]:
        """Map each constant name to its original fully-qualified name."""

    def get_constant_names_to_dtypes(self) -> dict[str, int]:
        """Map each constant name to an integer dtype code (encoding: TODO confirm)."""

    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]:
        """Return the constants keyed by name.

        ``use_inactive`` presumably selects the inactive rather than the
        active constant buffer.
        """

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer.

        NOTE(review): ``use_inactive`` presumably targets the inactive
        buffer and ``validate_full_updates`` presumably demands that every
        constant be supplied; the meaning and native default of
        ``user_managed`` should be confirmed in pybind.cpp.
        """

    def swap_constant_buffer(self) -> None:
        """Swap the active and inactive constant buffers (per the name)."""

    def free_inactive_constant_buffer(self) -> None:
        """Release the inactive constant buffer (per the name)."""

class AOTIModelContainerRunnerCpu:
    """CPU model-container runner.

    Declares the same methods as the ``AOTIModelContainerRunner`` Protocol
    in this module, so it satisfies that Protocol structurally.
    """

    def __init__(self, model_so_path: str, num_models: int) -> None:
        """Create a runner for the shared library at ``model_so_path``.

        ``num_models`` is presumably the number of model instances held by
        the container (TODO confirm in pybind.cpp).
        """

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors."""

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_names_to_original_fqns(self) -> dict[str, str]:
        """Map each constant name to its original fully-qualified name."""

    def get_constant_names_to_dtypes(self) -> dict[str, int]:
        """Map each constant name to an integer dtype code (encoding: TODO confirm)."""

    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]:
        """Return the constants keyed by name (``use_inactive``: which buffer)."""

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer.

        NOTE(review): flag semantics and the native default of
        ``user_managed`` should be confirmed in pybind.cpp.
        """

    def swap_constant_buffer(self) -> None:
        """Swap the active and inactive constant buffers (per the name)."""

    def free_inactive_constant_buffer(self) -> None:
        """Release the inactive constant buffer (per the name)."""

class AOTIModelContainerRunnerCuda:
    """CUDA model-container runner.

    Declares the same methods as the ``AOTIModelContainerRunner`` Protocol
    in this module, so it satisfies that Protocol structurally.
    """

    # Constructor overloads: the trailing ``device_str`` and ``cubin_dir``
    # arguments are optional, but ``cubin_dir`` may only be given together
    # with ``device_str`` (hence overloads rather than defaults).
    # NOTE(review): ``device_str`` presumably names the target device and
    # ``cubin_dir`` presumably locates compiled kernels -- confirm.
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
        cubin_dir: str,
    ) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors.

        ``stream_handle`` is an opaque native pointer, presumably the stream
        to execute on (TODO confirm in pybind.cpp).
        """

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_names_to_original_fqns(self) -> dict[str, str]:
        """Map each constant name to its original fully-qualified name."""

    def get_constant_names_to_dtypes(self) -> dict[str, int]:
        """Map each constant name to an integer dtype code (encoding: TODO confirm)."""

    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]:
        """Return the constants keyed by name (``use_inactive``: which buffer)."""

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer.

        NOTE(review): flag semantics and the native default of
        ``user_managed`` should be confirmed in pybind.cpp.
        """

    def swap_constant_buffer(self) -> None:
        """Swap the active and inactive constant buffers (per the name)."""

    def free_inactive_constant_buffer(self) -> None:
        """Release the inactive constant buffer (per the name)."""

class AOTIModelContainerRunnerXpu:
    """XPU model-container runner.

    Declares the same methods as the ``AOTIModelContainerRunner`` Protocol
    in this module, so it satisfies that Protocol structurally.
    """

    # Constructor overloads: the trailing ``device_str`` and
    # ``kernel_bin_dir`` arguments are optional, but ``kernel_bin_dir`` may
    # only be given together with ``device_str``.
    # NOTE(review): ``kernel_bin_dir`` presumably locates compiled kernel
    # binaries -- confirm in pybind.cpp.
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
        kernel_bin_dir: str,
    ) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors.

        ``stream_handle`` is an opaque native pointer, presumably the stream
        to execute on (TODO confirm in pybind.cpp).
        """

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_names_to_original_fqns(self) -> dict[str, str]:
        """Map each constant name to its original fully-qualified name."""

    def get_constant_names_to_dtypes(self) -> dict[str, int]:
        """Map each constant name to an integer dtype code (encoding: TODO confirm)."""

    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]:
        """Return the constants keyed by name (``use_inactive``: which buffer)."""

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer.

        NOTE(review): flag semantics and the native default of
        ``user_managed`` should be confirmed in pybind.cpp.
        """

    def swap_constant_buffer(self) -> None:
        """Swap the active and inactive constant buffers (per the name)."""

    def free_inactive_constant_buffer(self) -> None:
        """Release the inactive constant buffer (per the name)."""

class AOTIModelContainerRunnerMps:
    """MPS model-container runner.

    Declares the same methods as the ``AOTIModelContainerRunner`` Protocol
    in this module, so it satisfies that Protocol structurally.
    """

    def __init__(self, model_so_path: str, num_models: int) -> None:
        """Create a runner for the shared library at ``model_so_path``.

        ``num_models`` is presumably the number of model instances held by
        the container (TODO confirm in pybind.cpp).
        """

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors."""

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_names_to_original_fqns(self) -> dict[str, str]:
        """Map each constant name to its original fully-qualified name."""

    def get_constant_names_to_dtypes(self) -> dict[str, int]:
        """Map each constant name to an integer dtype code (encoding: TODO confirm)."""

    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]:
        """Return the constants keyed by name (``use_inactive``: which buffer)."""

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer.

        NOTE(review): flag semantics and the native default of
        ``user_managed`` should be confirmed in pybind.cpp.
        """

    def swap_constant_buffer(self) -> None:
        """Swap the active and inactive constant buffers (per the name)."""

    def free_inactive_constant_buffer(self) -> None:
        """Release the inactive constant buffer (per the name)."""

# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader:
    """Native loader for an AOTI model package.

    Unlike the ``AOTIModelContainerRunner*`` classes, this is constructed
    from a package path plus a model name rather than a shared-library path.
    """

    def __init__(
        self,
        model_package_path: str,
        model_name: str,
        run_single_threaded: bool,
        num_runners: int,
        device_index: int,
    ) -> None:
        """Load ``model_name`` from the package at ``model_package_path``.

        NOTE(review): ``run_single_threaded``, ``num_runners`` and
        ``device_index`` are forwarded to the native loader; their exact
        semantics should be confirmed in aoti_package/pybind.cpp.
        """

    def get_metadata(self) -> dict[str, str]:
        """Return the package metadata as string key/value pairs."""

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Run the model on ``inputs`` and return the resulting tensors."""

    def boxed_run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]:
        """Same signature as ``run`` but with a "boxed" calling convention.

        NOTE(review): a boxed call may consume ``inputs``; confirm before
        reusing the list after the call.
        """

    def get_call_spec(self) -> list[str]:
        """Return the model's call spec as a list of strings."""

    def get_constant_fqns(self) -> list[str]:
        """Return the fully-qualified names of the model's constants."""

    def load_constants(
        self,
        constants_map: dict[str, Tensor],
        use_inactive: bool,
        check_full_update: bool,
        user_managed: bool = ...,
    ) -> None:
        """Load ``constants_map`` into the model's constants.

        NOTE(review): flag semantics and whether the native binding really
        gives ``user_managed`` a default should be confirmed in pybind.cpp.
        """

    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None:
        """Write ``tensor_map`` into a constant buffer (flags: TODO confirm)."""
