"""Utilities of visualising an environment."""

from __future__ import annotations

from collections import deque
from collections.abc import Callable
from typing import TYPE_CHECKING

import numpy as np

import gymnasium as gym
from gymnasium import Env, logger
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled


if TYPE_CHECKING:
    from matplotlib.axes import Axes


try:
    import pygame
    from pygame import Surface
    from pygame.event import Event
except ImportError as e:
    raise gym.error.DependencyNotInstalled(
        'pygame is not installed, run `pip install "gymnasium[classic_control]"`'
    ) from e

try:
    import matplotlib

    matplotlib.use("TkAgg")
    import matplotlib.pyplot as plt
except ImportError:
    logger.warn('matplotlib is not installed, run `pip install "gymnasium[other]"`')
    matplotlib, plt = None, None


class MissingKeysToAction(Exception):
    """Error signalling that no ``keys_to_action`` mapping is available.

    It is raised when the caller did not pass a ``keys_to_action`` mapping and the
    environment does not expose a ``get_keys_to_action`` attribute to supply a default one.
    """


class PlayableGame:
    """Wraps an environment allowing keyboard inputs to interact with the environment.

    On construction a resizable PyGame window is opened that is sized to the (optionally zoomed)
    rendered frame of the environment. :meth:`process_event` then keeps track of which of the
    mapped ("relevant") keys are currently held down and whether the game should keep running.
    """

    def __init__(
        self,
        env: Env,
        keys_to_action: dict[tuple[int, ...], int] | None = None,
        zoom: float | None = None,
    ):
        """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.

        Args:
            env: The environment to play
            keys_to_action: The dictionary of keyboard tuples and action value
            zoom: If to zoom in on the environment render, must be a positive number

        Raises:
            ValueError: If the environment render mode is not ``rgb_array`` or ``rgb_array_list``,
                or if ``zoom`` is not a positive number
            MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment provides no default mapping
        """
        if env.render_mode not in {"rgb_array", "rgb_array_list"}:
            raise ValueError(
                "PlayableGame wrapper works only with rgb_array and rgb_array_list render modes, "
                f"but your environment render_mode = {env.render_mode}."
            )

        self.env = env
        self.relevant_keys = self._get_relevant_keys(keys_to_action)
        # self.video_size is the size of the video that is being displayed.
        # The window size may be larger, in that case we will add black bars
        self.video_size = self._get_video_size(zoom)
        self.screen = pygame.display.set_mode(self.video_size, pygame.RESIZABLE)
        # Relevant keys that are currently held down, in the order they were pressed
        self.pressed_keys = []
        self.running = True

    def _get_relevant_keys(
        self, keys_to_action: dict[tuple[int, ...], int] | None = None
    ) -> set:
        """Collects every individual key that appears in any key combination of the mapping.

        Args:
            keys_to_action: Mapping of key combinations to actions. If ``None``, the mapping returned by
                the environment's ``get_keys_to_action`` attribute is used.

        Returns:
            The set of all keys used by at least one key combination

        Raises:
            MissingKeysToAction: If no mapping is given and the environment does not provide one
        """
        if keys_to_action is None:
            if self.env.has_wrapper_attr("get_keys_to_action"):
                keys_to_action = self.env.get_wrapper_attr("get_keys_to_action")()
            else:
                assert self.env.spec is not None
                raise MissingKeysToAction(
                    f"{self.env.spec.id} does not have explicit key to action mapping, "
                    "please specify one manually, `play(env, keys_to_action=...)`"
                )
        assert isinstance(keys_to_action, dict)
        # Flatten all key combinations into a single set (linear time, unlike ``sum(lists, [])``)
        return {key for combination in keys_to_action for key in combination}

    def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]:
        """Computes the ``(width, height)`` of the video from a rendered frame of the environment.

        Args:
            zoom: Optional positive factor that both dimensions are multiplied with

        Returns:
            The video size in pixels, each dimension is at least one pixel when ``zoom`` is used

        Raises:
            ValueError: If ``zoom`` is given but is not a positive number
        """
        # ``not zoom > 0`` (rather than ``zoom <= 0``) also rejects NaN
        if zoom is not None and not zoom > 0:
            raise ValueError(f"zoom must be a positive number, actual value: {zoom}")

        rendered = self.env.render()
        if isinstance(rendered, list):
            # for list based render modes only the most recent frame is of interest
            rendered = rendered[-1]
        assert rendered is not None and isinstance(rendered, np.ndarray)
        # The frame is indexed (row, column, ...), i.e. (height, width, ...)
        video_size = (rendered.shape[1], rendered.shape[0])

        if zoom is not None:
            # Never let a tiny zoom collapse a dimension to zero pixels
            video_size = (
                max(1, int(video_size[0] * zoom)),
                max(1, int(video_size[1] * zoom)),
            )

        return video_size

    def process_event(self, event: Event):
        """Processes a PyGame event.

        In particular, this function is used to keep track of which buttons are currently pressed
        and to exit the :func:`play` function when the PyGame window is closed.

        Args:
            event: The event to process
        """
        if event.type == pygame.KEYDOWN:
            if event.key in self.relevant_keys:
                # Ignore repeated key-down events so that a key is only tracked once
                if event.key not in self.pressed_keys:
                    self.pressed_keys.append(event.key)
            elif event.key == pygame.K_ESCAPE:
                self.running = False
        elif event.type == pygame.KEYUP:
            # A key-up may arrive without a recorded key-down (e.g. the key was already held
            # when the window gained focus), in which case there is nothing to remove
            if event.key in self.relevant_keys and event.key in self.pressed_keys:
                self.pressed_keys.remove(event.key)
        elif event.type == pygame.QUIT:
            self.running = False
        elif event.type == pygame.WINDOWRESIZED:
            # Compute the maximum video size that fits into the new window while keeping
            # the aspect ratio. Note that video_size holds floats after a resize.
            width, height = self.video_size
            if width > 0 and height > 0:
                scale = min(event.x / width, event.y / height)
                # A zero-sized window would collapse the video size and make the next
                # resize divide by zero, so the previous size is kept in that case
                if scale > 0:
                    self.video_size = (scale * width, scale * height)


def display_arr(
    screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool
):
    """Draws a numpy image onto a PyGame surface, centred and scaled to ``video_size``.

    The surface is cleared to black first, so any area not covered by the image
    shows up as black bars.

    Args:
        screen: The screen to show the array on
        arr: The array to show, must be a ``uint8`` numpy array
        video_size: The video size of the screen
        transpose: If to transpose the array on the screen
    """
    assert isinstance(arr, np.ndarray) and arr.dtype == np.uint8
    frame = arr.swapaxes(0, 1) if transpose else arr
    image = pygame.transform.scale(pygame.surfarray.make_surface(frame), video_size)

    # Centre the image: whatever the window has left over is split evenly on both sides
    window_width, window_height = screen.get_size()
    top_left = (
        (window_width - video_size[0]) / 2,
        (window_height - video_size[1]) / 2,
    )

    screen.fill((0, 0, 0))
    screen.blit(image, top_left)


def play(
    env: Env,
    transpose: bool | None = True,
    fps: int | None = None,
    zoom: float | None = None,
    callback: Callable | None = None,
    keys_to_action: dict[tuple[str | int, ...] | str | int, ActType] | None = None,
    seed: int | None = None,
    noop: ActType = 0,
    wait_on_player: bool = False,
):
    """Allows the user to play the environment using a keyboard.

    If playing in a turn-based environment, set wait_on_player to True.

    The PyGame window is closed when the game loop ends, including when the loop
    is left because of an exception.

    Args:
        env: Environment to use for playing.
        transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
        fps: Maximum number of steps of the environment executed every second. If ``None`` (the default),
            ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used.
        zoom: Zoom the observation in, ``zoom`` amount, should be positive float
        callback: If a callback is provided, it will be executed after every step. It takes the following input:

            * obs_t: observation before performing action
            * obs_tp1: observation after performing action
            * action: action that was executed
            * rew: reward that was received
            * terminated: whether the environment is terminated or not
            * truncated: whether the environment is truncated or not
            * info: debug info
        keys_to_action:  Mapping from keys pressed to action performed.
            Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
            points of the keys, as a tuple of characters, or as a string where each character of the string represents
            one key.
            For example if pressing 'w' and space at the same time is supposed
            to trigger action number 2 then ``keys_to_action`` dict could look like this:

                >>> keys_to_action = {
                ...    # ...
                ...    (ord('w'), ord(' ')): 2
                ...    # ...
                ... }

            or like this:

                >>> keys_to_action = {
                ...    # ...
                ...    ("w", " "): 2
                ...    # ...
                ... }

            or like this:

                >>> keys_to_action = {
                ...    # ...
                ...    "w ": 2
                ...    # ...
                ... }

            If ``None``, default ``keys_to_action`` mapping for that environment is used, if provided.
        seed: Random seed used when resetting the environment. If None, no seed is used.
        noop: The action used when no key input has been entered, or the entered key combination is unknown.
        wait_on_player: Play should wait for a user action

    Raises:
        MissingKeysToAction: If ``keys_to_action`` is ``None`` and the environment provides no default mapping

    Example:
        >>> import gymnasium as gym
        >>> import numpy as np
        >>> from gymnasium.utils.play import play
        >>> play(gym.make("CarRacing-v3", render_mode="rgb_array"),  # doctest: +SKIP
        ...     keys_to_action={
        ...         "w": np.array([0, 0.7, 0], dtype=np.float32),
        ...         "a": np.array([-1, 0, 0], dtype=np.float32),
        ...         "s": np.array([0, 0, 1], dtype=np.float32),
        ...         "d": np.array([1, 0, 0], dtype=np.float32),
        ...         "wa": np.array([-1, 0.7, 0], dtype=np.float32),
        ...         "dw": np.array([1, 0.7, 0], dtype=np.float32),
        ...         "ds": np.array([1, 0, 1], dtype=np.float32),
        ...         "as": np.array([-1, 0, 1], dtype=np.float32),
        ...     },
        ...     noop=np.array([0, 0, 0], dtype=np.float32)
        ... )

        Above code works also if the environment is wrapped, so it's particularly useful in
        verifying that the frame-level preprocessing does not render the game
        unplayable.

        If you wish to plot real time statistics as you play, you can use
        :class:`PlayPlot`. Here's a sample code for plotting the reward
        for last 150 steps.

        >>> from gymnasium.utils.play import PlayPlot, play
        >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
        ...        return [rew,]
        >>> plotter = PlayPlot(callback, 150, ["reward"])             # doctest: +SKIP
        >>> play(gym.make("CartPole-v1", render_mode="rgb_array"), callback=plotter.callback)  # doctest: +SKIP
    """
    # The environment has to be reset before the game window is created,
    # as the window size is derived from a rendered frame
    env.reset(seed=seed)

    if keys_to_action is None:
        if env.has_wrapper_attr("get_keys_to_action"):
            keys_to_action = env.get_wrapper_attr("get_keys_to_action")()
        else:
            assert env.spec is not None
            raise MissingKeysToAction(
                f"{env.spec.id} does not have explicit key to action mapping, "
                "please specify one manually"
            )

    assert keys_to_action is not None

    # validate the `keys_to_action` set provided
    assert isinstance(keys_to_action, dict)
    for key, action in keys_to_action.items():
        if isinstance(key, tuple):
            assert len(key) > 0
            assert all(isinstance(k, (str, int)) for k in key)
        else:
            assert isinstance(key, (str, int))

        assert action in env.action_space

    # Normalise every key combination to a sorted tuple of integer key codes, such that
    # the lookup is independent of the order in which the keys were specified or pressed
    key_code_to_action = {}
    for key_combination, action in keys_to_action.items():
        if isinstance(key_combination, int):
            key_combination = (key_combination,)
        key_code = tuple(
            sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
        )
        key_code_to_action[key_code] = action

    game = PlayableGame(env, key_code_to_action, zoom)

    if fps is None:
        fps = env.metadata.get("render_fps", 30)

    # Starting with ``done = True`` makes the first loop iteration reset the environment
    done, obs = True, None
    clock = pygame.time.Clock()

    try:
        while game.running:
            if done:
                done = False
                # ``reset`` returns ``(observation, info)``, only the observation is kept
                # such that the callback receives a plain observation as ``obs_t``
                obs, _ = env.reset(seed=seed)
            elif wait_on_player is False or len(game.pressed_keys) > 0:
                action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
                prev_obs = obs
                obs, rew, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                if callback is not None:
                    callback(prev_obs, obs, action, rew, terminated, truncated, info)
            if obs is not None:
                rendered = env.render()
                if isinstance(rendered, list):
                    rendered = rendered[-1]
                assert rendered is not None and isinstance(rendered, np.ndarray)
                display_arr(
                    game.screen,
                    rendered,
                    transpose=transpose,
                    video_size=game.video_size,
                )

            # process pygame events
            for event in pygame.event.get():
                game.process_event(event)

            pygame.display.flip()
            clock.tick(fps)
    finally:
        # Always close the window, also if the environment or the callback raised
        pygame.quit()


class PlayPlot:
    """Provides a callback to create live plots of arbitrary metrics when using :func:`play`.

    This class is instantiated with a function that accepts information about a single environment transition:
        - obs_t: observation before performing action
        - obs_tp1: observation after performing action
        - action: action that was executed
        - rew: reward that was received
        - terminated: whether the environment is terminated or not
        - truncated: whether the environment is truncated or not
        - info: debug info

    It should return a list of metrics that are computed from this data.
    For instance, the function may look like this::

        >>> def compute_metrics(obs_t, obs_tp1, action, reward, terminated, truncated, info):
        ...     return [reward, info["cumulative_reward"], np.linalg.norm(action)]

    :class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
    and uses the returned values to update live plots of the metrics.

    Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::

        >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,                               # doctest: +SKIP
        ...                    plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
        >>> play(your_env, callback=plotter.callback)                                                # doctest: +SKIP
    """

    def __init__(
        self, callback: Callable, horizon_timesteps: int, plot_names: list[str]
    ):
        """Constructor of :class:`PlayPlot`.

        The function ``callback`` that is passed to this constructor should return
        a list of metrics that is of length ``len(plot_names)``.

        Args:
            callback: Function that computes metrics from environment transitions
            horizon_timesteps: The time horizon used for the live plots
            plot_names: List of plot titles

        Raises:
            DependencyNotInstalled: If matplotlib is not installed
        """
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )

        num_plots = len(self.plot_names)
        self.fig, self.ax = plt.subplots(num_plots)
        if num_plots == 1:
            # a single subplot is returned on its own, wrap it such that ``self.ax`` can always be indexed
            self.ax = [self.ax]
        for axis, name in zip(self.ax, plot_names):
            axis.set_title(name)
        # Number of transitions that have been recorded so far
        self.t = 0
        # The artist currently drawn on each axis (``None`` until the first callback).
        # NOTE(review): the stored objects are the return values of ``Axes.scatter``, which matplotlib
        # documents as ``PathCollection`` rather than ``Axes`` - confirm and adjust the annotation.
        self.cur_plot: list[Axes | None] = [None for _ in range(num_plots)]
        # One bounded series per plot, holding at most the last ``horizon_timesteps`` values
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]

    def callback(
        self,
        obs_t: ObsType,
        obs_tp1: ObsType,
        action: ActType,
        rew: float,
        terminated: bool,
        truncated: bool,
        info: dict,
    ):
        """The callback that calls the provided data callback and adds the data to the plots.

        Args:
            obs_t: The observation at time step t
            obs_tp1: The observation at time step t+1
            action: The action
            rew: The reward
            terminated: If the environment is terminated
            truncated: If the environment is truncated
            info: The information from the environment

        Raises:
            ValueError: If the data callback returns fewer metrics than there are plots
            DependencyNotInstalled: If matplotlib is not installed
        """
        points = list(
            self.data_callback(obs_t, obs_tp1, action, rew, terminated, truncated, info)
        )
        # ``zip`` below would silently drop the plots without a metric, leaving their series
        # shorter than the x-range and failing later inside matplotlib with an unrelated
        # looking size error. Fail early instead, before any state has been modified.
        if len(points) < len(self.data):
            raise ValueError(
                f"The data callback returned {len(points)} metrics, "
                f"but {len(self.data)} plots were created, expected one metric per plot name."
            )

        if plt is None:
            raise DependencyNotInstalled(
                'matplotlib is not installed, run `pip install "gymnasium[other]"`'
            )

        for point, data_series in zip(points, self.data):
            data_series.append(point)
        self.t += 1

        # Time steps covered by the bounded data series
        xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t

        for i, plot in enumerate(self.cur_plot):
            if plot is not None:
                # replace the previously drawn points instead of drawing on top of them
                plot.remove()
            self.cur_plot[i] = self.ax[i].scatter(
                range(xmin, xmax), list(self.data[i]), c="blue"
            )
            self.ax[i].set_xlim(xmin, xmax)

        # A very short pause lets matplotlib process its events and redraw the figure
        plt.pause(0.000001)
