from __future__ import annotations

from contextlib import closing
from io import StringIO
from os import path

import numpy as np

import gymnasium as gym
from gymnasium import Env, spaces, utils
from gymnasium.envs.toy_text.utils import categorical_sample
from gymnasium.error import DependencyNotInstalled
from gymnasium.utils import seeding


# Action indices accepted by ``FrozenLakeEnv.step``. The order matches the
# elf sprites (left, down, right, up) and the text labels used when rendering.
LEFT, DOWN, RIGHT, UP = range(4)

# Predefined layouts selectable through ``map_name``. Tile letters:
# S = start, F = frozen, H = hole, G = goal.
MAPS = {
    "4x4": [
        "SFFF",
        "FHFH",
        "FFFH",
        "HFFG",
    ],
    "8x8": [
        "SFFFFFFF",
        "FFFFFFFF",
        "FFFHFFFF",
        "FFFFFHFF",
        "FFFHFFFF",
        "FHHFFFHF",
        "FHFFHFHF",
        "FFFHFFFG",
    ],
}


def is_valid(board: list[list[str]], max_size: int) -> bool:
    """Check whether a goal tile can be reached from the top-left corner.

    Runs an iterative depth-first search over the 4-connected grid starting
    at ``(0, 0)``. Hole tiles (``"H"``) are never entered.

    Args:
        board: square grid of tile letters, indexed as ``board[row][col]``.
        max_size: side length of ``board``; neighbours outside
            ``[0, max_size)`` in either axis are ignored.

    Returns:
        ``True`` as soon as a neighbour of an explored tile is ``"G"``.
        ``False`` once every reachable tile has been explored without
        finding one.
    """
    neighbour_offsets = ((1, 0), (0, 1), (-1, 0), (0, -1))
    seen = set()
    stack = [(0, 0)]
    while stack:
        row, col = stack.pop()
        if (row, col) in seen:
            continue
        seen.add((row, col))
        for d_row, d_col in neighbour_offsets:
            nxt_row, nxt_col = row + d_row, col + d_col
            if not (0 <= nxt_row < max_size and 0 <= nxt_col < max_size):
                continue
            tile = board[nxt_row][nxt_col]
            if tile == "G":
                return True
            if tile != "H":
                stack.append((nxt_row, nxt_col))
    return False


def generate_random_map(
    size: int = 8, p: float = 0.8, seed: int | None = None
) -> list[str]:
    """Generates a random valid map (one that has a path from start to goal).

    The start ``S`` is placed in the top-left corner and the goal ``G`` in the
    bottom-right corner. Every other tile is independently frozen (``F``) with
    probability ``p``, otherwise a hole (``H``). Boards are re-sampled until
    ``is_valid`` finds a path from start to goal.

    Args:
        size: size of each side of the grid. Must be at least 2 so that the
            start and goal occupy different tiles.
        p: probability that a tile is frozen. Values above 1 are clamped to 1.
            Must be strictly positive; with ``p <= 0`` no path can ever exist.
        seed: optional seed to ensure the generation of reproducible maps

    Returns:
        A random valid map, one string per row.

    Raises:
        ValueError: if ``size < 2`` or ``p <= 0``. For these inputs the
            rejection-sampling loop below could never terminate.
    """
    if size < 2:
        # With size 1, "G" overwrites "S" on the only tile and no board is ever valid.
        raise ValueError(f"size must be at least 2, got {size}")
    if p <= 0:
        # Every non-corner tile would be a hole, so start and goal are never connected.
        raise ValueError(f"p must be positive, got {p}")

    # Clamp once; the value does not change between sampling attempts.
    p = min(1, p)
    np_random, _ = seeding.np_random(seed)

    while True:
        board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p])
        board[0][0] = "S"
        board[-1][-1] = "G"
        if is_valid(board, size):
            return ["".join(x) for x in board]


class FrozenLakeEnv(Env):
    """
    Frozen lake involves crossing a frozen lake from start to goal without falling into any holes
    by walking over the frozen lake.
    The player may not always move in the intended direction due to the slippery nature of the frozen lake.

    ## Description
    The game starts with the player at location `[0,0]` of the frozen lake grid world with the
    goal located at the far extent of the world, e.g. `[3,3]` for the 4x4 environment.

    Holes in the ice are distributed in set locations when using a pre-determined map
    or in random locations when a random map is generated.
    Randomly generated worlds will always have a path to the goal.

    The player makes moves until they reach the goal or fall in a hole.

    The lake is slippery (unless disabled), so the player may sometimes move perpendicular
    to the intended direction (see `is_slippery` in the Arguments section).

    Elf and stool from [https://franuka.itch.io/rpg-snow-tileset](https://franuka.itch.io/rpg-snow-tileset).
    All other assets by Mel Tillery [http://www.cyaneus.com/](http://www.cyaneus.com/).

    ## Action Space
    The action is an integer in the range `{0, 3}` (`Discrete(4)`) indicating
    which direction to move the player.

    - 0: Move left
    - 1: Move down
    - 2: Move right
    - 3: Move up

    Moving into the edge of the grid leaves the player where it is.

    ## Observation Space
    The observation is a value representing the player's current position as
    `current_row * ncol + current_col` (where both the row and col start at 0).
    Therefore, the observation is returned as an integer.

    For example, the goal position in the 4x4 map can be calculated as follows: 3 * 4 + 3 = 15.
    The number of possible observations is dependent on the size of the map.

    ## Starting State
    The episode starts with the player on the start tile `S`, i.e. state `[0]`
    (location [0, 0]) for the predefined and randomly generated maps. If a custom
    map contains several `S` tiles, one of them is chosen uniformly at random.

    ## Rewards

    Default reward schedule:
    - Reach goal: +1
    - Reach hole: 0
    - Reach frozen: 0

    See `reward_schedule` for reward customization in the Arguments section.

    ## Episode End
    The episode ends if the following happens:

    - Termination:
        1. The player moves into a hole.
        2. The player reaches the goal, which is at `nrow * ncol - 1` (location
           `[nrow - 1, ncol - 1]`) in the predefined and randomly generated maps.

    - Truncation (using the time_limit wrapper):
        1. The length of the episode is 100 for FrozenLake4x4, 200 for FrozenLake8x8.

    ## Information

    `step()` and `reset()` return a dict with the following keys:
    - `prob`: probability of the transition that was taken. It depends on the
      `is_slippery` and `success_rate` parameters; `reset()` always reports 1.

    ## Arguments

    FrozenLake has five parameters (in addition to the standard `render_mode`):
    ```python
    import gymnasium as gym
    gym.make(
        'FrozenLake-v1',
        desc=None,
        map_name="4x4",
        is_slippery=True,
        success_rate=1.0/3.0,
        reward_schedule=(1, 0, 0)
    )
    ```

    * `desc=None`: Used to specify non-preloaded maps.
        If `desc=None` then `map_name` will be used. If both `desc` and `map_name` are
        `None` a random 8x8 map with 80% of locations frozen will be generated.

        To specify a custom map - `desc=["SFFF", "FHFH", "FFFH", "HFFG"]`
        The tile letters denote:
        - "S" for Start tile
        - "G" for Goal tile
        - "F" for frozen tile
        - "H" for a tile with a hole

        A randomly generated map can be specified by calling the function `generate_random_map`.
        ```
        from gymnasium.envs.toy_text.frozen_lake import generate_random_map

        gym.make('FrozenLake-v1', desc=generate_random_map(size=8))
        ```

    * `map_name="4x4"`: Selects one of the two predefined maps (`4x4` and `8x8`)
        ```
        "4x4":[
            "SFFF",
            "FHFH",
            "FFFH",
            "HFFG"
        ]

        "8x8": [
            "SFFFFFFF",
            "FFFFFFFF",
            "FFFHFFFF",
            "FFFFFHFF",
            "FFFHFFFF",
            "FHHFFFHF",
            "FHFFHFHF",
            "FFFHFFFG",
        ]
        ```

    * `is_slippery=True`: If true, the player moves in the intended direction with the probability
        given by `success_rate`. Otherwise it moves in one of the two perpendicular directions,
        each with probability `(1 - success_rate) / 2`. If false, the player always moves in the
        intended direction.

        For example, if action is left, `is_slippery` is True, and `success_rate` is 1/3, then:
        - P(move left)=1/3
        - P(move up)=1/3
        - P(move down)=1/3

        If action is up, `is_slippery` is True, and `success_rate` is 3/4, then:
        - P(move up)=3/4
        - P(move left)=1/8
        - P(move right)=1/8

    * `success_rate=1.0/3.0`: Used to specify the probability of moving in the intended direction
        when `is_slippery=True`.

    * `reward_schedule=(1, 0, 0)`: Used to specify reward amounts for reaching certain tiles.
        The indices correspond to: Reach Goal, Reach Hole, Reach Frozen (includes Start), respectively.

    ## Version History
    * v1: Bug fixes to rewards (v1.3, added reward customization)
    * v0: Initial version release

    """

    metadata = {
        "render_modes": ["human", "ansi", "rgb_array"],
        "render_fps": 4,
    }

    def __init__(
        self,
        render_mode: str | None = None,
        desc: list[str] | None = None,
        map_name: str = "4x4",
        is_slippery: bool = True,
        success_rate: float = 1.0 / 3.0,
        reward_schedule: tuple[int, int, int] = (1, 0, 0),
    ):
        # An explicit `desc` wins over `map_name`. A random map is generated
        # only when both are None.
        if desc is None and map_name is None:
            desc = generate_random_map()
        elif desc is None:
            desc = MAPS[map_name]
        # dtype "c" stores one byte per tile, so tiles compare against b"S", b"H", ...
        self.desc = desc = np.asarray(desc, dtype="c")
        self.nrow, self.ncol = nrow, ncol = desc.shape
        self.reward_range = (min(reward_schedule), max(reward_schedule))

        nA = 4
        nS = nrow * ncol

        # Uniform distribution over every start tile in the map.
        self.initial_state_distrib = np.array(desc == b"S").astype("float64").ravel()
        self.initial_state_distrib /= self.initial_state_distrib.sum()

        # Tabular dynamics: P[s][a] is a list of (probability, next_state, reward, terminated).
        self.P = {s: {a: [] for a in range(nA)} for s in range(nS)}

        # The two perpendicular slip directions share the remaining probability equally.
        fail_rate = (1.0 - success_rate) / 2.0

        def to_s(row, col):
            """Flatten a (row, col) cell into its state index."""
            return row * ncol + col

        def inc(row, col, a):
            """Return the cell reached by moving in direction `a`, clamped to the grid."""
            if a == LEFT:
                col = max(col - 1, 0)
            elif a == DOWN:
                row = min(row + 1, nrow - 1)
            elif a == RIGHT:
                col = min(col + 1, ncol - 1)
            elif a == UP:
                row = max(row - 1, 0)
            return (row, col)

        def update_probability_matrix(row, col, action):
            """Return (next_state, reward, terminated) for moving from (row, col) in `action`."""
            new_row, new_col = inc(row, col, action)
            new_state = to_s(new_row, new_col)
            new_letter = desc[new_row, new_col]
            terminated = bytes(new_letter) in b"GH"
            # Any tile other than G/H (e.g. F or S) earns the "frozen" reward.
            reward = reward_schedule[
                b"GHF".index(new_letter if new_letter in b"GHF" else b"F")
            ]
            return new_state, reward, terminated

        for row in range(nrow):
            for col in range(ncol):
                s = to_s(row, col)
                letter = desc[row, col]
                for a in range(nA):
                    li = self.P[s][a]
                    if letter in b"GH":
                        # Goal and hole tiles are absorbing: stay put, no reward.
                        li.append((1.0, s, 0, True))
                    elif is_slippery:
                        # Intended direction plus its two perpendicular neighbours.
                        for b in [(a - 1) % nA, a, (a + 1) % nA]:
                            li.append(
                                (
                                    success_rate if b == a else fail_rate,
                                    *update_probability_matrix(row, col, b),
                                )
                            )
                    else:
                        li.append((1.0, *update_probability_matrix(row, col, a)))

        self.observation_space = spaces.Discrete(nS)
        self.action_space = spaces.Discrete(nA)

        self.render_mode = render_mode

        # pygame utils
        self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512))
        self.cell_size = (
            self.window_size[0] // self.ncol,
            self.window_size[1] // self.nrow,
        )
        self.window_surface = None
        self.clock = None
        self.hole_img = None
        self.cracked_hole_img = None
        self.ice_img = None
        self.elf_images = None
        self.goal_img = None
        self.start_img = None

    def step(self, a):
        """Sample one transition from `P[s][a]` and move the player there.

        Returns `(obs, reward, terminated, truncated, info)`, where `truncated`
        is always False and `info["prob"]` is the sampled transition's probability.
        """
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, t = transitions[i]
        self.s = s
        self.lastaction = a

        if self.render_mode == "human":
            self.render()
        # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
        return int(s), r, t, False, {"prob": p}

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ):
        """Reseed if requested and place the player on a start tile sampled from the initial distribution."""
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None

        if self.render_mode == "human":
            self.render()
        return int(self.s), {"prob": 1}

    def render(self):
        """Render according to `render_mode`: ANSI text, a pygame window, or an RGB array."""
        if self.render_mode is None:
            # `spec` is only set when the env is created through `gym.make`. Fall back
            # to the id used in this class's docs so direct construction still gets a warning.
            env_id = self.spec.id if self.spec is not None else "FrozenLake-v1"
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{env_id}", render_mode="rgb_array")'
            )
            return

        if self.render_mode == "ansi":
            return self._render_text()
        else:  # self.render_mode in {"human", "rgb_array"}:
            return self._render_gui(self.render_mode)

    def _render_gui(self, mode):
        """Draw the board with pygame.

        In "human" mode the window is updated; in "rgb_array" mode the frame is
        returned as a (height, width, 3) array.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                'pygame is not installed, run `pip install "gymnasium[toy-text]"`'
            ) from e

        if self.window_surface is None:
            pygame.init()

            if mode == "human":
                pygame.display.init()
                pygame.display.set_caption("Frozen Lake")
                self.window_surface = pygame.display.set_mode(self.window_size)
            elif mode == "rgb_array":
                self.window_surface = pygame.Surface(self.window_size)

        assert (
            self.window_surface is not None
        ), "Something went wrong with pygame. This should never happen."

        def load_sprite(name):
            # Sprites live in the `img/` folder next to this module, scaled to one cell.
            file_name = path.join(path.dirname(__file__), f"img/{name}")
            return pygame.transform.scale(pygame.image.load(file_name), self.cell_size)

        # Assets are loaded lazily on first render and cached on the instance.
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self.hole_img is None:
            self.hole_img = load_sprite("hole.png")
        if self.cracked_hole_img is None:
            self.cracked_hole_img = load_sprite("cracked_hole.png")
        if self.ice_img is None:
            self.ice_img = load_sprite("ice.png")
        if self.goal_img is None:
            self.goal_img = load_sprite("goal.png")
        if self.start_img is None:
            self.start_img = load_sprite("stool.png")
        if self.elf_images is None:
            # Indexed by action: LEFT, DOWN, RIGHT, UP.
            self.elf_images = [
                load_sprite(f"elf_{direction}.png")
                for direction in ("left", "down", "right", "up")
            ]

        desc = self.desc.tolist()
        assert isinstance(desc, list), f"desc should be a list or an array, got {desc}"
        for y in range(self.nrow):
            for x in range(self.ncol):
                pos = (x * self.cell_size[0], y * self.cell_size[1])
                rect = (*pos, *self.cell_size)

                # Ice is the background for every tile; special tiles are drawn on top.
                self.window_surface.blit(self.ice_img, pos)
                if desc[y][x] == b"H":
                    self.window_surface.blit(self.hole_img, pos)
                elif desc[y][x] == b"G":
                    self.window_surface.blit(self.goal_img, pos)
                elif desc[y][x] == b"S":
                    self.window_surface.blit(self.start_img, pos)

                pygame.draw.rect(self.window_surface, (180, 200, 230), rect, 1)

        # paint the elf
        bot_row, bot_col = self.s // self.ncol, self.s % self.ncol
        cell_rect = (bot_col * self.cell_size[0], bot_row * self.cell_size[1])
        # Before the first step, show the downward-facing (action 1) sprite.
        last_action = self.lastaction if self.lastaction is not None else 1
        elf_img = self.elf_images[last_action]

        if desc[bot_row][bot_col] == b"H":
            self.window_surface.blit(self.cracked_hole_img, cell_rect)
        else:
            self.window_surface.blit(elf_img, cell_rect)

        if mode == "human":
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        elif mode == "rgb_array":
            # surfarray is indexed (x, y, channel); transpose to (row, col, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)
            )

    @staticmethod
    def _center_small_rect(big_rect, small_dims):
        """Return the top-left corner that centres a `small_dims` (w, h) box inside `big_rect` (x, y, w, h)."""
        offset_w = (big_rect[2] - small_dims[0]) / 2
        offset_h = (big_rect[3] - small_dims[1]) / 2
        return (
            big_rect[0] + offset_w,
            big_rect[1] + offset_h,
        )

    def _render_text(self):
        """Return the board as text, with the player's tile highlighted and the last action shown above it."""
        desc = self.desc.tolist()
        outfile = StringIO()

        row, col = self.s // self.ncol, self.s % self.ncol
        desc = [[c.decode("utf-8") for c in line] for line in desc]
        desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True)
        if self.lastaction is not None:
            outfile.write(f"  ({['Left', 'Down', 'Right', 'Up'][self.lastaction]})\n")
        else:
            outfile.write("\n")
        outfile.write("\n".join("".join(line) for line in desc) + "\n")

        with closing(outfile):
            return outfile.getvalue()

    def close(self):
        """Shut down pygame if a render surface was created."""
        if self.window_surface is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            # Drop the now-invalid surface so a later render() re-initialises pygame
            # instead of blitting onto a surface whose display has been shut down.
            self.window_surface = None


# Elf and stool from https://franuka.itch.io/rpg-snow-tileset
# All other assets by Mel Tillery http://www.cyaneus.com/
