"""Noise generators to simulate noise in sensor / actuator classes."""

import abc
import gin
import numpy as np


class NoiseGenerator(metaclass=abc.ABCMeta):
  """Abstract interface shared by the noise generators in this module.

  Concrete subclasses must implement both `_get_noise`, which produces raw
  noise, and `add_noise`, which applies that noise to data with clipping.
  """

  @abc.abstractmethod
  def _get_noise(self, shape, dtype=None):
    """Returns a numpy array of noise with the requested shape and dtype.

    Both shape and dtype are passed in explicitly because consumers such as
    Tensorflow need the noise to match the data's shape and type exactly.

    Args:
      shape: Shape the returned noise array must have.
      dtype: Datatype of the returned array, or None for the default.
    """

  @abc.abstractmethod
  def add_noise(self, data):
    """Returns `data` with noise from `_get_noise` added, with clipping.

    Args:
      data: Numpy array that the noise is applied to.
    """


@gin.configurable
class BiasNoise(NoiseGenerator):
  """Adds a constant bias to the data, possibly with clipping."""

  def __init__(self,
               bias=0.0,
               clipping_lower_bound=-np.inf,
               clipping_upper_bound=np.inf):
    """Create a bias noise generator.

    Args:
      bias: Constant offset added to every element of the input. The sign is
        not restricted, so a negative value shifts the data downwards.
      clipping_lower_bound: Lower bound of add_noise (use -np.inf to ignore).
      clipping_upper_bound: Upper bound of add_noise (use np.inf to ignore).
    """
    self._bias = bias
    self._clipping_lower_bound = clipping_lower_bound
    self._clipping_upper_bound = clipping_upper_bound

  def _get_noise(self, shape, dtype=None):
    """Create constant bias noise of the given shape and datatype.

    The bias is cast to `dtype`, so a fractional bias is truncated when
    `dtype` is an integer type.
    """
    return np.full(shape, self._bias, dtype)

  def add_noise(self, data):
    """Add bias noise to the given data, clipping to the given range.

    Args:
      data: Numpy array to modify. Array-likes without `.shape` / `.dtype`
        (lists, Python scalars) are accepted and converted to arrays.

    Returns:
      `data + noise`, clipped to [clipping_lower_bound, clipping_upper_bound].
    """
    # asanyarray gives lists/scalars a shape and dtype, while returning
    # existing ndarrays (including subclasses such as masked arrays) as-is.
    data = np.asanyarray(data)
    noise = self._get_noise(data.shape, data.dtype)
    return np.clip(data + noise, self._clipping_lower_bound,
                   self._clipping_upper_bound)


@gin.configurable
class NormalNoise(BiasNoise):
  """Adds Gaussian noise to the data, possibly with clipping."""

  def __init__(self, scale, **kwargs):
    """Create a normal noise generator.

    Args:
      scale: Standard deviation of the Gaussian noise. It is not validated
        here; numpy raises an error when noise is drawn if scale < 0.
      **kwargs: Forwarded to BiasNoise, e.g. bias (which becomes the mean of
        the Gaussian) and the clipping bounds.
    """
    super(NormalNoise, self).__init__(**kwargs)
    self._scale = scale

  def _get_noise(self, shape, dtype=None):
    """Draw Gaussian noise with mean `bias` in the given shape and datatype."""
    samples = np.random.normal(loc=self._bias, scale=self._scale, size=shape)
    return samples.astype(dtype)


@gin.configurable
class UniformNoise(NoiseGenerator):
  """Adds uniformly distributed noise to the data, possibly with clipping."""

  def __init__(self,
               low,
               high,
               clipping_lower_bound=-np.inf,
               clipping_upper_bound=np.inf):
    """Creates a uniform noise generator.

    Args:
      low: The lower bound of the noise (inclusive).
      high: The upper bound of the noise (exclusive). It should not be below
        `low`; numpy documents np.random.uniform as undefined in that case.
      clipping_lower_bound: Lower bound of add_noise (use -np.inf to ignore).
      clipping_upper_bound: Upper bound of add_noise (use np.inf to ignore).
    """
    super().__init__()
    self._low = low
    self._high = high
    self._clipping_lower_bound = clipping_lower_bound
    self._clipping_upper_bound = clipping_upper_bound

  def _get_noise(self, shape, dtype=None):
    """Draws uniform noise in [low, high) with the given shape and datatype."""
    return np.random.uniform(self._low, self._high, shape).astype(dtype)

  def add_noise(self, data):
    """Adds uniform noise to the given data, clipping to the given bounds.

    Args:
      data: Numpy array to modify. Array-likes without `.shape` / `.dtype`
        (lists, Python scalars) are accepted and converted to arrays.

    Returns:
      `data + noise`, clipped to [clipping_lower_bound, clipping_upper_bound].
    """
    # asanyarray gives lists/scalars a shape and dtype, while returning
    # existing ndarrays (including subclasses) unchanged.
    data = np.asanyarray(data)
    noise = self._get_noise(data.shape, data.dtype)
    return np.clip(data + noise, self._clipping_lower_bound,
                   self._clipping_upper_bound)


@gin.configurable
class RangeNoise(NormalNoise):
  """Add normally distributed noise in m, applied to hit fractions in (0, 1).

  This enables us to specify range noise in terms of meters of Gaussian noise
  between a maximum and minimum range, but add_noise is applied to values
  expected to be in a hit fraction range of (0, 1) as needed for the
  SimLidarSensor API. Separate methods return noise or noisify data in meters.
  """

  def __init__(self, range_noise_m, max_range_m, min_range_m=0.0, **kwargs):
    """Create a normal noise generator suitable for use in a range scanner.

    Args:
      range_noise_m: Absolute magnitude of standard deviation of Gaussian noise,
        applied to range observation readings, measured in meters.
      max_range_m: Maximum range in meters of the data, used for clipping.
      min_range_m: Minimum range in meters of the data, used for clipping.
      **kwargs: Other arguments passed to NormalNoise (principally bias). Note
        that bias is applied in hit-fraction units (a fraction of
        max_range_m - min_range_m), not in meters. The clipping bounds are
        fixed to [0, 1] and must not be passed here.

    Raises:
      ValueError: If range_noise_m is negative or min_range_m >= max_range_m.
    """
    # Validate range values.
    if range_noise_m < 0.0:
      raise ValueError("Range noise should not be negative: %r" % range_noise_m)
    if min_range_m >= max_range_m:
      raise ValueError("min_range_m %s must be less than max_range_m %s" %
                       (min_range_m, max_range_m))

    self._range_noise_m = range_noise_m
    self._max_range_m = max_range_m
    self._min_range_m = min_range_m
    self._total_range = max_range_m - min_range_m
    # Express the standard deviation as a fraction of the total range, so that
    # noise drawn in hit-fraction space corresponds to range_noise_m meters.
    super(RangeNoise, self).__init__(
        scale=range_noise_m / self._total_range,
        clipping_lower_bound=0.0,
        clipping_upper_bound=1.0,
        **kwargs)

  def _get_noise_m(self, shape, dtype=None):
    """Create normal noise of the given shape and datatype, in meters.

    Noise is a displacement rather than an absolute range, so it is converted
    by scaling with the total range only. Using range_to_m here would
    wrongly add the min_range_m offset to every noise sample.
    """
    return self._get_noise(shape=shape, dtype=dtype) * self._total_range

  def add_noise_m(self, data):
    """Add normal noise to the given range data, scaled in meters.

    The data is mapped to hit fractions, noised and clipped to [0, 1], then
    mapped back, so results stay within [min_range_m, max_range_m].
    """
    return self.range_to_m(self.add_noise(self.m_to_range(data)))

  def m_to_range(self, data):
    """Scale data in meters to a range of (0, 1)."""
    return (data - self._min_range_m) / self._total_range

  def range_to_m(self, data):
    """Scale data in range of (0, 1) to meters."""
    return data * self._total_range + self._min_range_m


@gin.configurable
class TwistNoise(object):
  """Add normally distributed noise to twist actions.

  Note this is a simplified noise model in action space designed for parity
  with DriveWorld's r/s/e/drive_models/twist_drive.py;rcl=307540784;l=161.
  This assumes that the TwistNoise will be applied to a twist action which is
  then clipped, as currently done in wheeled_robot_base.py:

    robotics/reinforcement_learning/minitaur/robots/wheeled_robot_base.py;l=533
    # We assume that the velocity clipping would be absorbed in this API.
    if self._action_filter:
      action = self._action_filter.filter(action)

  where action is a linear_velocity, angular_velocity pair, which is clipped
  to limits subsequently by the _compute_kinematic_base_velocity method.
  """

  def __init__(self,
               linear_velocity_noise_stdev_mps: float,
               linear_velocity_noise_max_stdevs: float,
               angular_velocity_noise_stdev_rps: float,
               angular_velocity_noise_max_stdevs: float,
               noise_scaling_cutoff_mps: float = 0.0):
    """Create a normal noise generator for (linear, angular) twist actions.

    Supports the API specified in the DriveWorld TwistDrive class:
    robotics/simulation/environments/drive_models/twist_drive.py;l=54

    Args:
      linear_velocity_noise_stdev_mps: One standard deviation of normal noise
        for linear velocity in meters per second.
      linear_velocity_noise_max_stdevs: Max stdevs for linear velocity noise.
        The noise is clipped to +/- this many standard deviations so that it
        cannot spike arbitrarily far.
      angular_velocity_noise_stdev_rps: One standard deviation of normal noise
        for angular velocity in radians per second.
      angular_velocity_noise_max_stdevs: Max stdevs for angular velocity noise,
        clipped in the same way as the linear noise.
      noise_scaling_cutoff_mps: If linear velocity is less than this cutoff,
        linear and angular noise are scaled so that zero velocity produces zero
        noise. This enables a robot at rest to remain at rest, while still
        applying reasonable noise values to finite velocities. Angular velocity
        does not contribute to this computation to keep the model simple.
        Zero disables the scaling.

    Raises:
      ValueError: If any argument is negative.
    """
    # Validate range values.
    if linear_velocity_noise_stdev_mps < 0.0:
      raise ValueError("Linear action noise should not be negative: %r" %
                       linear_velocity_noise_stdev_mps)
    if linear_velocity_noise_max_stdevs < 0.0:
      raise ValueError("Maximum linear noise should not be negative: %r" %
                       linear_velocity_noise_max_stdevs)
    if angular_velocity_noise_stdev_rps < 0.0:
      raise ValueError("Angular action noise should not be negative: %r" %
                       angular_velocity_noise_stdev_rps)
    if angular_velocity_noise_max_stdevs < 0.0:
      raise ValueError("Maximum action noise should not be negative: %r" %
                       angular_velocity_noise_max_stdevs)
    if noise_scaling_cutoff_mps < 0.0:
      raise ValueError("Noise scaling cutoff should not be negative: %r" %
                       noise_scaling_cutoff_mps)

    # Save the values to create our noise later. Despite its name,
    # _noise_shape holds the per-component standard deviations
    # [linear, angular] passed as the scale to np.random.normal.
    self._noise_shape = [
        linear_velocity_noise_stdev_mps, angular_velocity_noise_stdev_rps
    ]
    # The noise clipping is performed using one standard deviation as the unit.
    self._noise_lower_bound = np.array([
        -linear_velocity_noise_max_stdevs * linear_velocity_noise_stdev_mps,
        -angular_velocity_noise_max_stdevs * angular_velocity_noise_stdev_rps
    ])
    self._noise_upper_bound = -self._noise_lower_bound
    self._noise_scaling_cutoff_mps = noise_scaling_cutoff_mps

  def filter(self, action):
    """Filter the linear and angular velocity by adding noise.

    Args:
      action: A (linear_velocity, angular_velocity) pair.

    Returns:
      A (linear_velocity, angular_velocity) tuple with zero-mean Gaussian noise
      added. Each noise component is clipped to its max-stdevs bound. When
      noise_scaling_cutoff_mps is non-zero, both components are scaled by
      min(|linear_velocity|, cutoff) / cutoff, so there is zero noise at rest.
    """
    linear_velocity, angular_velocity = action
    linear_noise, angular_noise = np.clip(
        np.random.normal(0, self._noise_shape, 2), self._noise_lower_bound,
        self._noise_upper_bound)
    if self._noise_scaling_cutoff_mps:
      clipped_velocity = min(abs(linear_velocity),
                             self._noise_scaling_cutoff_mps)
      scaling_factor = clipped_velocity / self._noise_scaling_cutoff_mps
      linear_noise *= scaling_factor
      angular_noise *= scaling_factor

    return (linear_velocity + linear_noise, angular_velocity + angular_noise)
