from typing import Callable, Dict, Optional, Union, cast

import httpx

import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams


class WatsonXAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


iam_token_cache = InMemoryCache()


def get_watsonx_iam_url():
    return (
        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
    )


def generate_iam_token(api_key=None, **params) -> str:
    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore

    if result is None:
        headers = {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        if api_key is None:
            api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
        if api_key is None:
            raise ValueError("API key is required")
        headers["Accept"] = "application/json"
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        iam_token_url = get_watsonx_iam_url()
        verbose_logger.debug(
            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
            iam_token_url,
            headers,
            data,
        )
        response = httpx.post(iam_token_url, data=data, headers=headers)
        response.raise_for_status()
        json_data = response.json()

        result = json_data["access_token"]
        iam_token_cache.set_cache(
            key=api_key,
            value=result,
            ttl=json_data["expires_in"] - 10,  # leave some buffer
        )

    return cast(str, result)


def _get_api_params(
    params: dict,
    print_verbose: Optional[Callable] = None,
    generate_token: Optional[bool] = True,
) -> WatsonXAPIParams:
    """
    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
    """
    # Load auth variables from params
    url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
    api_key = params.pop("apikey", None)
    token = params.pop("token", None)
    project_id = params.pop(
        "project_id", params.pop("watsonx_project", None)
    )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
    space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
    region_name = params.pop("region_name", params.pop("region", None))
    if region_name is None:
        region_name = params.pop(
            "watsonx_region_name", params.pop("watsonx_region", None)
        )  # consistent with how vertex ai + aws regions are accepted
    wx_credentials = params.pop(
        "wx_credentials",
        params.pop(
            "watsonx_credentials", None
        ),  # follow {provider}_credentials, same as vertex ai
    )
    api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
    # Load auth variables from environment variables
    if url is None:
        url = (
            get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )
    if api_key is None:
        api_key = (
            get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )
    if token is None:
        token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
    if project_id is None:
        project_id = (
            get_secret_str("WATSONX_PROJECT_ID")
            or get_secret_str("WX_PROJECT_ID")
            or get_secret_str("PROJECT_ID")
        )
    if region_name is None:
        region_name = (
            get_secret_str("WATSONX_REGION")
            or get_secret_str("WX_REGION")
            or get_secret_str("REGION")
        )
    if space_id is None:
        space_id = (
            get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
            or get_secret_str("WATSONX_SPACE_ID")
            or get_secret_str("WX_SPACE_ID")
            or get_secret_str("SPACE_ID")
        )

    # credentials parsing
    if wx_credentials is not None:
        url = wx_credentials.get("url", url)
        api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
        token = wx_credentials.get(
            "token",
            wx_credentials.get(
                "watsonx_token", token
            ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
        )

    # verify that all required credentials are present
    if url is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
        )

    if token is None and api_key is not None and generate_token:
        # generate the auth token
        if print_verbose is not None:
            print_verbose("Generating IAM token for Watsonx.ai")
        token = generate_iam_token(api_key)
    elif token is None and api_key is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
        )
    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
        )

    return WatsonXAPIParams(
        url=url,
        api_key=api_key,
        token=cast(str, token),
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
        api_version=api_version,
    )
