# This code is part of Qiskit.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
==================================================================================
NoiseLearner result classes (:mod:`qiskit_ibm_runtime.utils.noise_learner_result`)
==================================================================================

.. autosummary::
   :toctree: ../stubs/
   :nosignatures:

   PauliLindbladError
   LayerError
"""

from __future__ import annotations

from typing import Any, Iterator, Optional, Sequence, Union, TYPE_CHECKING
from numpy.typing import NDArray
import numpy as np

from qiskit.providers.backend import BackendV2
from qiskit.circuit import QuantumCircuit
from qiskit.quantum_info import PauliList, Pauli

from ..utils.embeddings import Embedding

if TYPE_CHECKING:
    from plotly.graph_objects import Figure as PlotlyFigure


class PauliLindbladError:
    r"""A Pauli error channel generated by Pauli Lindblad dissipators.

    An instance describes an N-qubit quantum error channel
    :math:`E(\rho) = e^{\sum_j r_j D_{P_j}}(\rho)` whose generators are the Pauli Lindblad
    dissipators :math:`D_P(\rho) = P \rho P - \rho`, with every :math:`P_j` an N-qubit
    :class:`~.Pauli` operator.

    The Pauli generator terms are held in a :class:`~.PauliList` that is exposed through the
    :attr:`~generators` attribute, while the array of dissipator rates :math:`r_j` is exposed
    through the :attr:`~rates` attribute.

    The equivalent Pauli error channel can be written as a composition of single-Pauli
    channel terms

    .. math::

        E = e^{\sum_j r_j D_{P_j}} = \prod_j e^{r_j D_{P_j}}
        = \prod_j \left( (1 - p_j) S_I + p_j S_{P_j} \right)

    where :math:`p_j = \frac12 - \frac12 e^{-2 r_j}` [1].

    Args:
        generators: A list of the Pauli Lindblad generators for the error channel.
        rates: A list of the rates for the Pauli-Lindblad ``generators``. It is converted to
            a NumPy array of floats.

    Raises:
        ValueError: If ``rates`` is not one-dimensional with exactly one entry per element
            of ``generators``.

    References:
        1. E. van den Berg, Z. Minev, A. Kandala, K. Temme, *Probabilistic error
           cancellation with sparse Pauli–Lindblad models on noisy quantum processors*,
           Nature Physics volume 19, pages 1116–1121 (2023).
           `arXiv:2201.09866 [quant-ph] <https://arxiv.org/abs/2201.09866>`_
    """

    def __init__(self, generators: PauliList, rates: Sequence[float]) -> None:
        self._generators = generators
        self._rates = np.asarray(rates, dtype=float)

        # Exactly one rate per generator: ``rates`` must be 1-D and as long as ``generators``.
        num_generators = len(generators)
        rates_shape = self._rates.shape
        if rates_shape != (num_generators,):
            raise ValueError(
                f"``generators`` has length {num_generators} "
                f"but ``rates`` has shape {rates_shape}."
            )

    @property
    def generators(self) -> PauliList:
        r"""
        The Pauli Lindblad generators that define this :class:`~.PauliLindbladError`.
        """
        return self._generators

    @property
    def rates(self) -> NDArray[np.float64]:
        r"""
        The rates of the Lindblad generators of this quantum error, as a NumPy array.
        """
        return self._rates

    @property
    def num_qubits(self) -> int:
        r"""
        The number of qubits of this :class:`~.PauliLindbladError`, as reported by its
        generators.
        """
        return self.generators.num_qubits

    def restrict_num_bodies(self, num_qubits: int) -> PauliLindbladError:
        r"""
        Build the :class:`~.PauliLindbladError` made of only those terms that act on exactly
        ``num_qubits`` qubits.

        Args:
            num_qubits: The number of qubits that the terms of the returned error act on.

        Returns:
            A new error containing only the generators (and matching rates) acting on
            exactly ``num_qubits`` qubits.

        Raises:
            ValueError: If ``num_qubits`` is smaller than ``0``.
        """
        if num_qubits < 0:
            raise ValueError("``num_qubits`` must be ``0`` or larger.")

        paulis = self.generators
        # A generator acts on a qubit when either its ``x`` or its ``z`` entry is set there.
        acts_on = paulis.x | paulis.z
        weights = np.sum(acts_on, axis=1)
        selected = weights == num_qubits
        return PauliLindbladError(paulis[selected], self.rates[selected])

    def _json(self) -> dict:
        """Return a dictionary with everything that is needed to re-initialize this object."""
        return dict(generators=self.generators, rates=self.rates)

    def __repr__(self) -> str:
        generators = self.generators
        rates = self.rates.tolist()
        return f"PauliLindbladError(generators={generators}, rates={rates})"


class LayerError:
    """The error channel (in Pauli-Lindblad format) of a single layer of instructions.

    Args:
        circuit: A circuit whose noise has been learnt.
        qubits: The labels of the qubits in the ``circuit``.
        error: The Pauli Lindblad error channel affecting the ``circuit``, or ``None`` if the error
            channel is either unknown or explicitly disabled.

    Raises:
        ValueError: If ``circuit``, ``qubits``, and ``error`` have mismatching number of qubits.
    """

    def __init__(
        self,
        circuit: QuantumCircuit,
        qubits: Sequence[int],
        error: Optional[PauliLindbladError] = None,
    ) -> None:

        self._circuit = circuit
        # Stored as a new list so that any sequence type is accepted for ``qubits``.
        self._qubits = list(qubits)
        self._error = error

        message = "Mismatching numbers of qubits."
        if len(self.qubits) != self.circuit.num_qubits:
            raise ValueError(message)
        # The error is optional; its size is only checked when one is given.
        if self.error is not None and len(self.qubits) != self.error.num_qubits:
            raise ValueError(message)

    @property
    def circuit(self) -> QuantumCircuit:
        r"""
        The circuit in this :class:`~.LayerError`.
        """
        return self._circuit

    @property
    def qubits(self) -> list[int]:
        r"""
        The qubits in this :class:`~.LayerError`.
        """
        return self._qubits

    @property
    def error(self) -> Union[PauliLindbladError, None]:
        r"""
        The error channel in this :class:`~.LayerError`, or ``None`` if the error channel is either
        unknown or explicitly disabled.
        """
        return self._error

    @property
    def num_qubits(self) -> int:
        r"""
        The number of qubits in this :class:`~.LayerError`.
        """
        return len(self.qubits)

    def draw_map(
        self,
        embedding: Union[Embedding, BackendV2],
        colorscale: str = "Bluered",
        color_no_data: str = "lightgray",
        color_out_of_scale: str = "lightgreen",
        num_edge_segments: int = 16,
        edge_width: float = 4,
        height: int = 500,
        highest_rate: Optional[float] = None,
        background_color: str = "white",
        radius: float = 0.25,
        width: int = 800,
    ) -> PlotlyFigure:
        r"""
        Draw a map view of this layer error.

        All the arguments are forwarded, together with this layer error, to
        ``draw_layer_error_map`` from the ``visualization`` package.

        Args:
            embedding: An :class:`~.Embedding` object containing the coordinates and coupling map
                to draw the layer error on, or a backend to generate an :class:`~.Embedding` for.
            colorscale: The colorscale used to show the rates of this layer error.
            color_no_data: The color used for qubits and edges for which no data is available.
            color_out_of_scale: The color used for rates with value greater than ``highest_rate``.
            num_edge_segments: The number of equal-sized segments that edges are made of.
            edge_width: The line width of the edges in pixels.
            height: The height of the returned figure.
            highest_rate: The highest rate, used to normalize all other rates before choosing their
                colors. If ``None``, it defaults to the highest value found in the ``layer_error``.
            background_color: The background color.
            radius: The radius of the pie charts representing the qubits.
            width: The width of the returned figure.

        Returns:
            The figure returned by ``draw_layer_error_map``.

        .. code:: python

            from qiskit import QuantumCircuit
            from qiskit.quantum_info import PauliList
            from qiskit_ibm_runtime.utils.embeddings import Embedding
            from qiskit_ibm_runtime.utils.noise_learner_result import LayerError, PauliLindbladError

            # A six-qubit 1-D embedding with nearest neighbouring connectivity
            coordinates1 = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
            coupling_map1 = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
            embedding1 = Embedding(coordinates1, coupling_map1)

            # A six-qubit horseshoe-shaped embedding with nearest neighbouring connectivity
            coordinates2 = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
            coupling_map2 = [(0, 1), (1, 2), (0, 3), (3, 4), (4, 5)]
            embedding2 = Embedding(coordinates2, coupling_map2)

            # A LayerError object
            circuit = QuantumCircuit(4)
            qubits = [1, 2, 3, 4]
            generators = PauliList(["IIIX", "IIXI", "IXII", "YIII", "ZIII", "XXII", "ZZII"])
            rates = [0.01, 0.01, 0.01, 0.005, 0.02, 0.01, 0.01]
            error = PauliLindbladError(generators, rates)
            layer_error = LayerError(circuit, qubits, error)

            # Draw the layer error on embedding1
            layer_error.draw_map(embedding1)

            # Draw the layer error on embedding2
            layer_error.draw_map(embedding2)
        """
        # The import is deferred to call time; the pylint directive below silences the
        # warnings about a non-top-level and cyclic import.
        # pylint: disable=import-outside-toplevel, cyclic-import

        from ..visualization import draw_layer_error_map

        return draw_layer_error_map(
            layer_error=self,
            embedding=embedding,
            colorscale=colorscale,
            color_no_data=color_no_data,
            color_out_of_scale=color_out_of_scale,
            num_edge_segments=num_edge_segments,
            edge_width=edge_width,
            height=height,
            highest_rate=highest_rate,
            background_color=background_color,
            radius=radius,
            width=width,
        )

    def draw_swarm(
        self,
        num_bodies: Optional[int] = None,
        max_rate: Optional[float] = None,
        min_rate: Optional[float] = None,
        connected: Optional[Union[list[Pauli], list[str]]] = None,
        colors: Optional[list[str]] = None,
        num_bins: Optional[int] = None,
        opacities: Union[float, list[float]] = 0.4,
        names: Optional[list[str]] = None,
        x_coo: Optional[list[float]] = None,
        marker_size: Optional[float] = None,
        height: int = 500,
        width: int = 800,
    ) -> PlotlyFigure:
        r"""
        Draw a swarm plot of the rates in this layer error.

        This function plots the rates along a vertical axis, offsetting the rates along the ``x``
        axis so that they do not overlap with each other. It forwards all its arguments to
        ``draw_layer_errors_swarm``, passing a one-element list that contains this layer error.

        .. note::
            To draw multiple layer errors at once, consider calling
            :meth:`~qiskit_ibm_runtime.visualization.draw_layer_errors_swarm` directly.

        Args:
            num_bodies: The weight of the generators to include in the plot, or ``None`` if all the
                generators should be included.
            max_rate: The largest rate to include in the plot, or ``None`` if no upper limit should be
                set.
            min_rate: The smallest rate to include in the plot, or ``None`` if no lower limit should be
                set.
            connected: A list of generators whose markers are to be connected by lines.
            colors: A list of colors for the markers in the plot, or ``None`` if these colors are to be
                chosen automatically.
            num_bins: The number of bins to place the rates into when calculating the ``x``-axis
                offsets.
            opacities: A single opacity or a list of opacities for the markers.
            names: The names of the various layers as displayed in the legend. If ``None``, default
                names are assigned based on the layers' position inside the ``layer_errors`` list.
            x_coo: The ``x``-axis coordinates of the vertical axes that the markers are drawn around, or
                ``None`` if these axes should be placed at regular intervals.
            marker_size: The size of the marker in the plot.
            height: The height of the returned figure.
            width: The width of the returned figure.

        Returns:
            The figure returned by ``draw_layer_errors_swarm``.
        """
        # The import is deferred to call time; the pylint directive below silences the
        # warnings about a non-top-level and cyclic import.
        # pylint: disable=import-outside-toplevel, cyclic-import

        from ..visualization import draw_layer_errors_swarm

        return draw_layer_errors_swarm(
            layer_errors=[self],
            num_bodies=num_bodies,
            max_rate=max_rate,
            min_rate=min_rate,
            connected=connected,
            colors=colors,
            num_bins=num_bins,
            opacities=opacities,
            names=names,
            x_coo=x_coo,
            marker_size=marker_size,
            height=height,
            width=width,
        )

    def _json(self) -> dict:
        """Return a dictionary containing all the information to re-initialize this object."""
        return {"circuit": self.circuit, "qubits": self.qubits, "error": self.error}

    def __repr__(self) -> str:
        # The closing parenthesis is emitted exactly once so that the output is balanced.
        return (
            f"LayerError(circuit={repr(self.circuit)}, qubits={self.qubits}, "
            f"error={self.error})"
        )


class NoiseLearnerResult:
    """A container for the results of a noise learner experiment."""

    def __init__(self, data: Sequence[LayerError], metadata: dict[str, Any] | None = None):
        """
        Args:
            data: The data of a noise learner experiment. Its items are stored in a new list.
            metadata: Metadata that is common to all pub results; metadata specific to particular
                pubs should be placed in their metadata fields. Keys are expected to be strings.
                A copy of the given dictionary is stored, or an empty dictionary if ``None``.
        """
        self._data = [*data]
        self._metadata = metadata.copy() if metadata is not None else {}

    @property
    def data(self) -> list[LayerError]:
        """The list of layer errors held by this noise learner result."""
        return self._data

    @property
    def metadata(self) -> dict[str, Any]:
        """The metadata dictionary attached to this noise learner result."""
        return self._metadata

    def __len__(self) -> int:
        layer_errors = self.data
        return len(layer_errors)

    def __getitem__(self, index: int) -> LayerError:
        layer_errors = self.data
        return layer_errors[index]

    def __iter__(self) -> Iterator[LayerError]:
        layer_errors = self.data
        return iter(layer_errors)

    def __repr__(self) -> str:
        data, metadata = self.data, self.metadata
        return f"NoiseLearnerResult(data={data}, metadata={metadata})"
