# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers for OpenAI Gym environments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import multiprocessing
import sys
import traceback

import gym
import gym.spaces
import numpy as np
try:
  import tensorflow.compat.v1 as tf
except Exception:
  import tensorflow as tf


class AutoReset(object):
  """Restart the wrapped environment transparently once an episode has ended.

  The wrapper begins in the "done" state, so the first call to `step()` resets
  the environment instead of stepping it.
  """

  def __init__(self, env):
    """Wrap an environment.

    Args:
      env: Environment providing `reset()` and `step(action)`.
    """
    self._env = env
    # Pretend the previous episode just finished so the first step resets.
    self._done = True

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment."""
    return getattr(self._env, name)

  def step(self, action):
    """Step the environment, or reset it if the last episode has ended.

    Args:
      action: Action to forward. It is ignored on the call that performs the
        reset.

    Returns:
      Tuple of observation, reward, done flag, and info object. After a reset
      this is the initial observation with reward 0.0, done False, and an
      empty info dict.
    """
    if not self._done:
      observ, reward, done, info = self._env.step(action)
    else:
      # The episode is over: start a new one instead of applying the action.
      observ = self._env.reset()
      reward, done, info = 0.0, False, {}
    self._done = done
    return observ, reward, done, info

  def reset(self):
    """Reset the environment explicitly and return the initial observation."""
    self._done = False
    observ = self._env.reset()
    return observ

class ActionRepeat(object):
  """Repeat the agent action multiple steps."""

  def __init__(self, env, amount):
    """Repeat the agent action multiple steps.

    Args:
      env: Environment providing `step(action)`.
      amount: Maximum number of times each action is applied; must be at
        least one.

    Raises:
      ValueError: The amount is smaller than one, so `step()` could never
        produce a transition.
    """
    if amount < 1:
      raise ValueError('Action repeat amount must be at least 1.')
    self._env = env
    self._amount = amount

  def __getattr__(self, name):
    """Forward unimplemented attributes to the wrapped environment."""
    return getattr(self._env, name)

  def step(self, action):
    """Apply the action repeatedly and accumulate the reward.

    Repetition stops early as soon as the wrapped environment reports the end
    of the episode.

    Args:
      action: Action to apply on every repeated step.

    Returns:
      Tuple of the last observation, the summed reward, the last done flag,
      and the last info object.
    """
    done = False
    total_reward = 0
    current_step = 0
    # The constructor guarantees at least one iteration, so `observ` and
    # `info` are always bound when the loop finishes.
    while current_step < self._amount and not done:
      observ, reward, done, info = self._env.step(action)
      total_reward += reward
      current_step += 1
    return observ, total_reward, done, info


class RandomStart(object):
  """Perform random number of random actions at the start of the episode."""

  def __init__(self, env, max_steps):
    """Perform random number of random actions at the start of the episode.

    Args:
      env: Environment providing `reset()`, `step(action)` and an
        `action_space` with a `sample()` method.
      max_steps: Exclusive upper bound on the number of random actions. A
        value of zero or less disables the random start.
    """
    self._env = env
    self._max_steps = max_steps

  def __getattr__(self, name):
    """Forward unimplemented attributes to the wrapped environment."""
    return getattr(self._env, name)

  def reset(self):
    """Reset the environment and advance it by a random number of steps.

    If the episode ends during the random steps, the environment is reset
    again and a new number of random steps is drawn.

    Returns:
      Observation after the random steps.
    """
    # Retry in a loop rather than recursively so that environments whose
    # episodes often end during the random start cannot exhaust the stack.
    while True:
      observ = self._env.reset()
      if self._max_steps > 0:
        random_steps = np.random.randint(0, self._max_steps)
      else:
        # randint() rejects an empty range; treat it as "no random steps".
        random_steps = 0
      for _ in range(random_steps):
        action = self._env.action_space.sample()
        observ, unused_reward, done, unused_info = self._env.step(action)
        if done:
          tf.logging.warning('Episode ended during random start.')
          break
      else:
        # All random steps completed without ending the episode.
        return observ


class FrameHistory(object):
  """Augment the observation with past observations."""

  def __init__(self, env, past_indices, flatten):
    """Augment the observation with past observations.

    Implemented as a Numpy ring buffer holding the necessary past observations.

    Args:
      env: OpenAI Gym environment to wrap.
      past_indices: List of non-negative integers indicating the time offsets
        from the current time step of observations to include.
      flatten: Concatenate the past observations rather than stacking them.

    Raises:
      KeyError: The current observation is not included in the indices.
    """
    if 0 not in past_indices:
      raise KeyError('Past indices should include 0 for the current frame.')
    self._env = env
    self._past_indices = past_indices
    self._step = 0
    self._buffer = None
    # Offsets 0..max(past_indices) must all be addressable at once, which
    # needs one slot more than the largest offset. With fewer slots the oldest
    # offset would wrap around onto the current frame.
    self._capacity = max(past_indices) + 1
    self._flatten = flatten

  def __getattr__(self, name):
    """Forward unimplemented attributes to the wrapped environment."""
    return getattr(self._env, name)

  @property
  def observation_space(self):
    """Box space covering the stacked or concatenated observations."""
    low = self._env.observation_space.low
    high = self._env.observation_space.high
    # Repeat the bounds once per selected frame along a new leading axis.
    low = np.repeat(low[None, ...], len(self._past_indices), 0)
    high = np.repeat(high[None, ...], len(self._past_indices), 0)
    if self._flatten:
      # Merge the frame axis into the first observation axis.
      low = np.reshape(low, (-1,) + low.shape[2:])
      high = np.reshape(high, (-1,) + high.shape[2:])
    return gym.spaces.Box(low, high)

  def step(self, action):
    """Step the environment and return the observation with its history.

    Args:
      action: Action to forward to the wrapped environment.

    Returns:
      Tuple of augmented observation, reward, done flag, and info object.
    """
    observ, reward, done, info = self._env.step(action)
    self._step += 1
    # Overwrite the oldest slot of the ring buffer with the new frame.
    self._buffer[self._step % self._capacity] = observ
    observ = self._select_frames()
    return observ, reward, done, info

  def reset(self):
    """Reset the environment and fill the history with the first frame.

    Returns:
      Augmented observation in which every selected frame equals the initial
      observation.
    """
    observ = self._env.reset()
    self._buffer = np.repeat(observ[None, ...], self._capacity, 0)
    self._step = 0
    return self._select_frames()

  def _select_frames(self):
    """Gather the frames at the configured offsets from the ring buffer.

    Returns:
      Array with the selected frames stacked along the first axis, or merged
      into the first observation axis when flattening is enabled.
    """
    indices = [
        (self._step - index) % self._capacity for index in self._past_indices]
    observ = self._buffer[indices]
    if self._flatten:
      observ = np.reshape(observ, (-1,) + observ.shape[2:])
    return observ


class FrameDelta(object):
  """Report each observation as its change since the previous observation."""

  def __init__(self, env):
    """Wrap an environment.

    Args:
      env: Environment providing `reset()` and `step(action)`.
    """
    self._env = env
    # Most recent raw observation; set by reset().
    self._last = None

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment."""
    return getattr(self._env, name)

  @property
  def observation_space(self):
    """Box space bounding the difference of two in-range observations."""
    lower = self._env.observation_space.low
    upper = self._env.observation_space.high
    # The largest possible decrease is low - high and the largest possible
    # increase is high - low.
    return gym.spaces.Box(lower - upper, upper - lower)

  def step(self, action):
    """Step the environment and return the observation difference.

    Args:
      action: Action to forward to the wrapped environment.

    Returns:
      Tuple of observation delta, reward, done flag, and info object.
    """
    observ, reward, done, info = self._env.step(action)
    previous = self._last
    difference = observ - previous
    self._last = observ
    return difference, reward, done, info

  def reset(self):
    """Reset the environment and remember the first observation.

    Returns:
      The raw initial observation, not a delta.
    """
    initial = self._env.reset()
    self._last = initial
    return initial


class RangeNormalize(object):
  """Rescale finite observation and action ranges to the interval [-1, 1]."""

  def __init__(self, env, observ=None, action=None):
    """Wrap an environment and decide which ranges to normalize.

    Args:
      env: Environment with `observation_space` and `action_space` exposing
        `low` and `high` bounds.
      observ: True to require observation normalization, False to disable it,
        None to normalize only if the observation range is finite.
      action: Same three options for the action range.

    Raises:
      ValueError: Normalization was required for a range that is not finite.
    """
    self._env = env

    # Observations are handled completely before the action space is touched.
    if observ is False:
      self._should_normalize_observ = False
    else:
      self._should_normalize_observ = self._is_finite(
          self._env.observation_space)
    if not self._should_normalize_observ:
      if observ is True:
        raise ValueError('Cannot normalize infinite observation range.')
      if observ is None:
        tf.logging.info('Not normalizing infinite observation range.')

    if action is False:
      self._should_normalize_action = False
    else:
      self._should_normalize_action = self._is_finite(self._env.action_space)
    if not self._should_normalize_action:
      if action is True:
        raise ValueError('Cannot normalize infinite action range.')
      if action is None:
        tf.logging.info('Not normalizing infinite action range.')

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment."""
    return getattr(self._env, name)

  @property
  def observation_space(self):
    """Unit box when observations are normalized, else the original space."""
    space = self._env.observation_space
    if self._should_normalize_observ:
      return gym.spaces.Box(-np.ones(space.shape), np.ones(space.shape))
    return space

  @property
  def action_space(self):
    """Unit box when actions are normalized, else the original space."""
    space = self._env.action_space
    if self._should_normalize_action:
      return gym.spaces.Box(-np.ones(space.shape), np.ones(space.shape))
    return space

  def step(self, action):
    """Step with a possibly normalized action; normalize the observation.

    Args:
      action: Action in [-1, 1] if actions are normalized, otherwise an
        action in the range of the wrapped environment.

    Returns:
      Tuple of observation, reward, done flag, and info object.
    """
    if self._should_normalize_action:
      action = self._denormalize_action(action)
    observ, reward, done, info = self._env.step(action)
    if self._should_normalize_observ:
      observ = self._normalize_observ(observ)
    return observ, reward, done, info

  def reset(self):
    """Reset the environment and return the possibly normalized observation."""
    observ = self._env.reset()
    if not self._should_normalize_observ:
      return observ
    return self._normalize_observ(observ)

  def _denormalize_action(self, action):
    """Map an action from [-1, 1] to the range of the wrapped environment."""
    lower = self._env.action_space.low
    upper = self._env.action_space.high
    unit = (action + 1) / 2
    return unit * (upper - lower) + lower

  def _normalize_observ(self, observ):
    """Map an observation from the environment range to [-1, 1]."""
    lower = self._env.observation_space.low
    upper = self._env.observation_space.high
    doubled = 2 * (observ - lower)
    return doubled / (upper - lower) - 1

  def _is_finite(self, space):
    """Return whether both bounds of the space contain only finite values."""
    low_is_finite = np.isfinite(space.low).all()
    if not low_is_finite:
      return low_is_finite
    return np.isfinite(space.high).all()


class ClipAction(object):
  """Limit actions to the bounds of the wrapped environment's action space."""

  def __init__(self, env):
    """Wrap an environment.

    Args:
      env: Environment whose `action_space` exposes `low`, `high` and `shape`.
    """
    self._env = env

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment."""
    return getattr(self._env, name)

  @property
  def action_space(self):
    """Unbounded box of the original action shape.

    Any value is accepted by the wrapper because it is clipped in `step()`.
    """
    shape = self._env.action_space.shape
    upper = np.inf * np.ones(shape)
    lower = -np.inf * np.ones(shape)
    return gym.spaces.Box(lower, upper)

  def step(self, action):
    """Clip the action to the valid range and forward it.

    Args:
      action: Action that may lie outside the environment's action bounds.

    Returns:
      The transition returned by the wrapped environment.
    """
    bounds = self._env.action_space
    clipped = np.clip(action, bounds.low, bounds.high)
    return self._env.step(clipped)


class LimitDuration(object):
  """Terminate episodes once a fixed number of steps has been taken."""

  def __init__(self, env, duration):
    """Wrap an environment.

    Args:
      env: Environment providing `reset()` and `step(action)`.
      duration: Number of steps after which the episode is reported as done.
    """
    self._env = env
    self._duration = duration
    # Steps taken in the current episode; None means a reset is required.
    self._step = None

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment."""
    return getattr(self._env, name)

  def step(self, action):
    """Step the environment and enforce the episode length limit.

    Args:
      action: Action to forward to the wrapped environment.

    Raises:
      RuntimeError: The environment was not reset since the limit was hit or
        since construction.

    Returns:
      Tuple of observation, reward, done flag, and info object. The done flag
      is forced to True on the step that reaches the duration.
    """
    if self._step is None:
      raise RuntimeError('Must reset environment.')
    observ, reward, done, info = self._env.step(action)
    self._step += 1
    if self._step >= self._duration:
      # Limit reached: end the episode and demand a reset before stepping.
      self._step = None
      return observ, reward, True, info
    return observ, reward, done, info

  def reset(self):
    """Restart the step counter and reset the wrapped environment."""
    self._step = 0
    observ = self._env.reset()
    return observ


class ExternalProcess(object):
  """Step environment in a separate process for lock free parallelism.

  The environment is created and lives inside a worker process. The main
  process communicates with it over a multiprocessing pipe by exchanging
  `(message_type, payload)` tuples.
  """

  # Message types for communication via the pipe.
  # _ACCESS and _CALL are requests sent to the worker, _RESULT and _EXCEPTION
  # are its replies, and _CLOSE asks the worker to leave its loop.
  _ACCESS = 1
  _CALL = 2
  _RESULT = 3
  _EXCEPTION = 4
  _CLOSE = 5

  def __init__(self, constructor):
    """Step environment in a separate process for lock free parallelism.

    The environment will be created in the external process by calling the
    specified callable. This can be an environment class, or a function
    creating the environment and potentially wrapping it. The returned
    environment should not access global variables.

    Args:
      constructor: Callable that creates and returns an OpenAI gym environment.

    Attributes:
      observation_space: The cached observation space of the environment.
      action_space: The cached action space of the environment.
    """
    # One end of the pipe stays in this process, the other goes to the worker.
    self._conn, conn = multiprocessing.Pipe()
    self._process = multiprocessing.Process(target=self._worker, args=(constructor, conn))
    # Ask the worker to shut down when the interpreter exits.
    atexit.register(self.close)
    self._process.start()
    # Spaces are requested from the worker on first access and then cached.
    self._observ_space = None
    self._action_space = None

  @property
  def observation_space(self):
    """Observation space of the environment, fetched once and then cached."""
    if not self._observ_space:
      # Call __getattr__ directly: normal attribute lookup would find this
      # property again instead of asking the worker.
      self._observ_space = self.__getattr__('observation_space')
    return self._observ_space

  @property
  def action_space(self):
    """Action space of the environment, fetched once and then cached."""
    if not self._action_space:
      # Call __getattr__ directly: normal attribute lookup would find this
      # property again instead of asking the worker.
      self._action_space = self.__getattr__('action_space')
    return self._action_space

  def __getattr__(self, name):
    """Request an attribute from the environment.

    Note that this involves communication with the external process, so it can
    be slow.

    Args:
      name: Attribute to access.

    Returns:
      Value of the attribute.
    """
    self._conn.send((self._ACCESS, name))
    return self._receive()

  def call(self, name, *args, **kwargs):
    """Asynchronously call a method of the external environment.

    The request is sent immediately. The returned promise is the bound
    `_receive` method, which reads the next message from the pipe; it does not
    remember which request it belongs to.

    Args:
      name: Name of the method to call.
      *args: Positional arguments to forward to the method.
      **kwargs: Keyword arguments to forward to the method.

    Returns:
      Promise object that blocks and provides the return value when called.
    """
    payload = name, args, kwargs
    self._conn.send((self._CALL, payload))
    return self._receive

  def close(self):
    """Send a close message to the external process and join it."""
    try:
      self._conn.send((self._CLOSE, None))
      self._conn.close()
    except IOError:
      # The connection was already closed.
      pass
    self._process.join()

  def step(self, action, blocking=True):
    """Step the environment.

    Args:
      action: The action to apply to the environment.
      blocking: Whether to wait for the result.

    Returns:
      Transition tuple when blocking, otherwise callable that returns the
      transition tuple.
    """
    promise = self.call('step', action)
    if blocking:
      return promise()
    else:
      return promise

  def reset(self, blocking=True):
    """Reset the environment.

    Args:
      blocking: Whether to wait for the result.

    Returns:
      New observation when blocking, otherwise callable that returns the new
      observation.
    """
    promise = self.call('reset')
    if blocking:
      return promise()
    else:
      return promise

  def _receive(self):
    """Wait for a message from the worker process and return its payload.

    Raises:
      Exception: An exception was raised inside the worker process.
      KeyError: The received message is of an unknown type.

    Returns:
      Payload object of the message.
    """
    message, payload = self._conn.recv()
    # Re-raise exceptions in the main process.
    if message == self._EXCEPTION:
      # The payload is the formatted stack trace string sent by the worker.
      stacktrace = payload
      raise Exception(stacktrace)
    if message == self._RESULT:
      return payload
    raise KeyError('Received message of unexpected type {}'.format(message))

  def _worker(self, constructor, conn):
    """The process waits for actions and sends back environment results.

    Runs inside the worker process. Any exception, including one raised while
    constructing the environment, is logged and reported to the main process;
    the worker then leaves its loop and closes its end of the pipe.

    Args:
      constructor: Constructor for the OpenAI Gym environment.
      conn: Connection for communication to the main process.
    """
    try:
      env = constructor()
      while True:
        try:
          # Only block for short times to have keyboard exceptions be raised.
          if not conn.poll(0.1):
            continue
          message, payload = conn.recv()
        except (EOFError, KeyboardInterrupt):
          # The other end was closed or the process was interrupted.
          break
        if message == self._ACCESS:
          # Attribute access: reply with the attribute value.
          name = payload
          result = getattr(env, name)
          conn.send((self._RESULT, result))
          continue
        if message == self._CALL:
          # Method call: reply with the return value.
          name, args, kwargs = payload
          result = getattr(env, name)(*args, **kwargs)
          conn.send((self._RESULT, result))
          continue
        if message == self._CLOSE:
          assert payload is None
          break
        raise KeyError('Received message of unknown type {}'.format(message))
    except Exception:  # pylint: disable=broad-except
      # Send the formatted trace as a string for the main process to raise.
      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
      tf.logging.error('Error in environment process: {}'.format(stacktrace))
      conn.send((self._EXCEPTION, stacktrace))
    conn.close()


class ConvertTo32Bit(object):
  """Cast 64-bit observations and rewards of an environment to 32 bit."""

  def __init__(self, env):
    """Wrap an environment.

    Args:
      env: OpenAI Gym environment.
    """
    self._env = env

  def __getattr__(self, name):
    """Resolve attributes missing on the wrapper in the wrapped environment.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value of that attribute on the wrapped environment.
    """
    return getattr(self._env, name)

  def step(self, action):
    """Step the wrapped environment and convert its outputs.

    Args:
      action: Action to apply to the environment.

    Raises:
      ValueError: The observation or reward contains non-finite values.

    Returns:
      Converted observation, converted reward, done flag, and info object.
    """
    raw_observ, raw_reward, done, info = self._env.step(action)
    observ = self._convert_observ(raw_observ)
    reward = self._convert_reward(raw_reward)
    return observ, reward, done, info

  def reset(self):
    """Reset the wrapped environment and convert the first observation.

    Returns:
      Converted observation.
    """
    return self._convert_observ(self._env.reset())

  def _convert_observ(self, observ):
    """Cast a 64-bit observation to the matching 32-bit type.

    Args:
      observ: Numpy observation.

    Raises:
      ValueError: Observation contains infinite values.

    Returns:
      Observation as float32 if it was float64, as int32 if it was int64, and
      unchanged for any other data type.
    """
    if not np.isfinite(observ).all():
      raise ValueError('Infinite observation encountered.')
    narrowing = ((np.float64, np.float32), (np.int64, np.int32))
    for wide, narrow in narrowing:
      if observ.dtype == wide:
        return observ.astype(narrow)
    return observ

  def _convert_reward(self, reward):
    """Cast the reward to a float32 Numpy array.

    Args:
      reward: Numpy reward.

    Raises:
      ValueError: Rewards contain infinite values.

    Returns:
      Numpy reward with 32-bit data type.
    """
    if not np.isfinite(reward).all():
      raise ValueError('Infinite reward encountered.')
    return np.array(reward, dtype=np.float32)
