# This code is part of Qiskit.
#
# (C) Copyright IBM 2018-2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Class for representing a Pauli noise channel generated by a Pauli Lindblad dissipator.
"""

from __future__ import annotations
from collections.abc import Sequence
import numpy as np

from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp, ScalarOp, SuperOp
from qiskit.quantum_info.operators.mixins import TolerancesMixin
from .base_quantum_error import BaseQuantumError
from .quantum_error import QuantumError
from .pauli_error import PauliError, sort_paulis
from ..noiseerror import NoiseError


class PauliLindbladError(BaseQuantumError, TolerancesMixin):
    r"""A Pauli channel generated by a Pauli Lindblad dissipator.

    This operator represents an N-qubit quantum error channel
    :math:`E(ρ) = e^{\sum_j r_j D_{P_j}}(ρ)` generated by Pauli Lindblad dissipators
    :math:`D_P(ρ) = P ρ P - ρ`, where :math:`P_j` are N-qubit :class:`~.Pauli` operators.

    The list of Pauli generator terms are stored as a :class:`~.PauliList` and can be
    accessed via the :attr:`generators` attribute. The array of dissipator rates
    :math:`r_j` can be accessed via the :attr:`rates` attribute.

    A Pauli Lindblad error is equivalent to a :class:`.PauliError` and can be converted
    using the :meth:`to_pauli_error` method. Though note that for a sparse generated
    ``PauliLindbladError`` there may in general be exponentially many terms in the
    converted :class:`.PauliError` operator. Because of this, this operator can be
    used to more efficiently represent N-qubit Pauli channels for simulation if they
    have few generator terms.

    The equivalent Pauli error is constructed as a composition of single-Pauli channel terms

    .. math::

        e^{\sum_j r_j D_{P_j}} = \prod_j e^{r_j D_{P_j}}
        = \prod_j \left( (1 - p_j) S_I + p_j S_{P_j} \right)

    where :math:`p_j = \frac12 - \frac12 e^{-2 r_j}`.

    .. note::

        This operator can also represent a non-physical (non-CPTP) channel if any of the
        rates are negative. Non-physical operators cannot be converted to a
        :class:`~.QuantumError` or used in an :class:`~.AerSimulator` simulation. You can
        check if an operator is physical using the :meth:`is_cptp` method.
    """

    def __init__(
        self,
        generators: Sequence[Pauli],
        rates: Sequence[float],
    ):
        """Initialize a Pauli-Lindblad dissipator model.

        Args:
            generators: A list of Pauli's corresponding to the Lindblad dissipator generator terms.
            rates: The Pauli Lindblad dissipator generator rates.

        Raises:
            NoiseError: If ``rates`` is not a 1D sequence with exactly one entry
                per generator.
        """
        self._generators = PauliList(generators)
        self._rates = np.asarray(rates, dtype=float)
        # The shape check rejects both a length mismatch and non-1D rate arrays.
        if self._rates.shape != (len(self._generators),):
            raise NoiseError("Input generators and rates are different lengths.")
        super().__init__(num_qubits=self._generators.num_qubits)

    def __repr__(self):
        return f"{type(self).__name__}({self.generators.to_labels()}, {self.rates.tolist()})"

    def __eq__(self, other):
        """Return True if ``other`` has the same simplified generators and rates.

        Both operands are simplified and sorted before comparison, so the
        ordering of terms, duplicate generators and identity generators do
        not affect the result. Rates are compared with :func:`numpy.allclose`.
        """
        # Use BaseOperator eq to check type and shape
        if not super().__eq__(other):
            return False
        lhs = self.simplify()
        rhs = other.simplify()
        if lhs.size != rhs.size:
            return False
        lpaulis, lrates = sort_paulis(lhs.generators, lhs.rates)
        rpaulis, rrates = sort_paulis(rhs.generators, rhs.rates)
        if not np.allclose(lrates, rrates):
            return False
        # Reduce with np.all so that a plain bool is returned regardless of
        # whether the PauliList comparison yields a scalar or an entrywise array.
        return bool(np.all(lpaulis == rpaulis))

    @property
    def size(self):
        """Return the number of Pauli generator terms."""
        return len(self.generators)

    @property
    def generators(self) -> PauliList:
        """Return the Pauli Lindblad dissipator generator terms"""
        return self._generators

    @property
    def rates(self) -> np.ndarray:
        """Return the Lindblad dissipator generator rates"""
        return self._rates

    @property
    def settings(self):
        """Settings for IBM RuntimeEncoder JSON encoding"""
        return {
            "generators": self.generators,
            "rates": self.rates,
        }

    def ideal(self) -> bool:
        """Return True if this error object is composed only of identity operations.

        Note that the identity check is best effort and up to global phase.

        An identity generator has a vanishing dissipator
        (:math:`D_I(ρ) = IρI - ρ = 0`), so only the rates attached to
        non-identity generators are required to be (numerically) zero.
        """
        # Same identity mask as used by ``simplify``: a generator is non-trivial
        # if any of its Z or X symplectic bits is set.
        non_iden = np.any(np.logical_or(self.generators.z, self.generators.x), axis=1)
        # An empty selection (all generators are identities) compares as all-close.
        return np.allclose(self.rates[non_iden], 0)

    def is_cptp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Return True if completely-positive trace-preserving (CPTP).

        Since :meth:`is_tp` is always True for this class, this is the same
        check as :meth:`is_cp`.
        """
        return self.is_cp(atol=atol, rtol=rtol)

    def is_tp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if a channel is trace-preserving (TP)"""
        # pylint: disable = unused-argument
        # This error is TP by construction regardless of rates.
        return True

    def is_cp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if Choi-matrix is completely-positive (CP)

        The channel is reported as CP when every negative rate is zero within
        tolerance.

        Args:
            atol: Absolute tolerance for treating a negative rate as zero.
                Defaults to ``self.atol``.
            rtol: Relative tolerance passed to :func:`numpy.allclose`.
                Defaults to ``self.rtol``.

        Returns:
            True if no rate is negative beyond the tolerance.
        """
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        neg_rates = self.rates[self.rates < 0]
        return np.allclose(neg_rates, 0, atol=atol, rtol=rtol)

    def tensor(self, other: PauliLindbladError) -> PauliLindbladError:
        """Return the tensor product error ``self ⊗ other``.

        Each generator is padded with identities on the other operand's qubits
        and the two generator/rate lists are concatenated.

        Args:
            other: The error to tensor with.

        Returns:
            The combined error on ``self.num_qubits + other.num_qubits`` qubits.

        Raises:
            NoiseError: If ``other`` is not a :class:`PauliLindbladError`.
        """
        if not isinstance(other, PauliLindbladError):
            raise NoiseError("other must be a PauliLindbladError")
        # All-False symplectic vectors describe the identity Pauli on each subsystem.
        zeros_l = np.zeros(self.num_qubits, dtype=bool)
        zeros_r = np.zeros(other.num_qubits, dtype=bool)
        gens_left = self.generators.tensor(Pauli((zeros_r, zeros_r)))
        gens_right = other.generators.expand(Pauli((zeros_l, zeros_l)))
        generators = gens_left + gens_right
        rates = np.concatenate([self.rates, other.rates])
        return PauliLindbladError(generators, rates)

    def expand(self, other) -> PauliLindbladError:
        """Return the reverse-order tensor product error ``other ⊗ self``.

        Args:
            other: The error to tensor with.

        Returns:
            The result of ``other.tensor(self)``.

        Raises:
            NoiseError: If ``other`` is not a :class:`PauliLindbladError`.
        """
        if not isinstance(other, PauliLindbladError):
            raise NoiseError("other must be a PauliLindbladError")
        return other.tensor(self)

    def compose(self, other, qargs=None, front=False) -> PauliLindbladError:
        """Return the composition of this error with ``other``.

        The generators of ``other`` are padded with identities up to the
        number of qubits of this error (placed on ``qargs`` if given) and
        appended to this error's generators; the rates are concatenated.
        ``front`` is only used for validating the operand shapes, the
        resulting generator list does not depend on it.

        Args:
            other: The error to compose with.
            qargs: Optional qubit positions of ``self`` that ``other`` acts on.
                If None, the ``qargs`` attribute of ``other`` is used when present.
            front: Passed to the shape validation only.

        Returns:
            The composed error.

        Raises:
            NoiseError: If ``other`` is not a :class:`PauliLindbladError`.
        """
        if qargs is None:
            qargs = getattr(other, "qargs", None)
        if not isinstance(other, PauliLindbladError):
            raise NoiseError("other must be a PauliLindbladError")
        # Validate composition dimensions and qargs match
        self._op_shape.compose(other._op_shape, qargs, front)
        # pylint: disable = unused-argument
        # Composing with an identity ScalarOp pads the other generators to the
        # full number of qubits on the requested qargs.
        padded = ScalarOp(self.num_qubits * (2,)).compose(
            SparsePauliOp(other.generators, copy=False, ignore_pauli_phase=True), qargs
        )
        generators = self.generators + padded.paulis
        rates = np.concatenate([self.rates, other.rates])
        return PauliLindbladError(generators, rates)

    def power(self, n: float) -> PauliLindbladError:
        """Return the error raised to the power ``n``.

        Args:
            n: The exponent; every rate is multiplied by this value.

        Returns:
            A new error with the same generators and rates ``n * self.rates``.
        """
        return PauliLindbladError(self.generators, n * self.rates)

    def inverse(self) -> PauliLindbladError:
        """Return the inverse (non-CPTP) channel"""
        return PauliLindbladError(self.generators, -self.rates)

    def simplify(
        self, atol: float | None = None, rtol: float | None = None
    ) -> "PauliLindbladError":
        """Simplify the generators by combining duplicates and removing zeros.

        Identity generators are dropped before the remaining terms are
        simplified via :meth:`~.SparsePauliOp.simplify`.

        Args:
            atol (float): Optional. Absolute tolerance for checking if
                          rates are zero (Default: ``self.atol``).
            rtol (float): Optional. Relative tolerance for checking if
                          rates are zero (Default: ``self.rtol``).

        Returns:
            PauliLindbladError: the simplified PauliLindbladError operator.
        """
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        # Remove identity terms
        non_iden = np.any(np.logical_or(self.generators.z, self.generators.x), axis=1)
        simplified = SparsePauliOp(
            self.generators[non_iden],
            self.rates[non_iden],
            copy=False,
            ignore_pauli_phase=True,
        ).simplify(atol=atol, rtol=rtol)
        # The rates were real on input; only the real part of the coefficients is kept.
        return PauliLindbladError(simplified.paulis, simplified.coeffs.real)

    def subsystem_errors(self) -> list[tuple[PauliLindbladError, tuple[int, ...]]]:
        """Return a list of errors for the subsystem decomposed error.

        .. note::

            This uses a greedy algorithm to find the largest non-identity subsystem
            Pauli, checks if its non identity terms are covered by any previously
            selected Pauli's, and if not adds it to list of covering subsystems.
            This is repeated until no generators remain.

            In terms of the number of Pauli terms this has runtime
            `O(num_terms * num_coverings)`,
            which in the worst case is `O(num_terms ** 2)`.

        Returns:
            A list of pairs of component PauliLindbladErrors and subsystem qubit
            indices that decompose the current error. Identity generators are
            not assigned to any component.
        """
        # Find non-identity paulis we wish to cover
        paulis = self.generators
        non_iden = np.logical_or(paulis.z, paulis.x)

        # Mask to exclude generator terms that are trivial (identities)
        non_trivial = np.any(non_iden, axis=1)

        # Paulis that aren't covered yet
        uncovered = np.arange(non_iden.shape[0])[non_trivial]

        # Indices that cover each Pauli
        # Note that trivial terms will be left at -1
        covered = -1 * np.ones(non_iden.shape[0], dtype=int)
        coverings = []

        # In general finding optimal coverings is NP-hard (I think?)
        # so we do a heuristic of just greedily finding the largest
        # pauli that isn't covered, and checking if it is covered
        # by any previous coverings, if not adding it to coverings
        # and repeat until we run out of paulis
        while uncovered.size:
            # Row (generator index) of the highest-weight uncovered Pauli
            imax = np.argmax(np.sum(non_iden[uncovered], axis=1))
            rmax = uncovered[imax]
            add_covering = True
            for rcov in coverings:
                # rcov covers rmax if it is non-identity on every qubit where rmax is
                if np.all(non_iden[rcov][non_iden[rmax]]):
                    add_covering = False
                    covered[rmax] = rcov
                    break
            if add_covering:
                covered[rmax] = rmax
                coverings.append(rmax)
            uncovered = uncovered[uncovered != rmax]

        # Extract subsystem errors and qinds of non-identity terms
        sub_errors = []
        for cover in coverings:
            # Generators assigned to this covering and the qubits they act on
            pinds = covered == cover
            qinds = np.any(non_iden[pinds], axis=0)
            sub_z = paulis.z[pinds][:, qinds]
            sub_x = paulis.x[pinds][:, qinds]
            sub_gens = PauliList.from_symplectic(sub_z, sub_x)
            sub_err = PauliLindbladError(sub_gens, self.rates[pinds])
            sub_qubits = tuple(np.where(qinds)[0])
            sub_errors.append((sub_err, sub_qubits))

        return sub_errors

    def to_pauli_error(self, simplify: bool = True) -> PauliError:
        """Convert to a PauliError operator.

        .. note::

            If this object represents a non-CPTP inverse channel with negative
            rates the returned "probabilities" will be a quasi-probability
            distribution containing negative values.

        Args:
            simplify: If True call :meth:`~.PauliError.simplify` after each single
                Pauli channel composition to reduce the number of duplicate terms.

        Returns:
            The :class:`~.PauliError` of the current Pauli channel.
        """
        # Start from the identity channel: a single identity Pauli with probability 1.
        chan_z = np.zeros((1, self.num_qubits), dtype=bool)
        chan_x = np.zeros_like(chan_z)
        chan_p = np.ones(1, dtype=float)
        for term_z, term_x, term_r in zip(self.generators.z, self.generators.x, self.rates):
            # Probability of applying this generator in its single-Pauli channel
            term_p = 0.5 - 0.5 * np.exp(-2 * term_r)
            # Each existing term is kept (weight 1 - p) and also multiplied by the
            # generator (weight p); the Pauli product up to phase is an XOR of the
            # symplectic bits.
            chan_z = np.concatenate([chan_z, np.logical_xor(chan_z, term_z)], axis=0)
            chan_x = np.concatenate([chan_x, np.logical_xor(chan_x, term_x)], axis=0)
            chan_p = np.concatenate([(1 - term_p) * chan_p, term_p * chan_p])
            if simplify:
                error_op = PauliError(PauliList.from_symplectic(chan_z, chan_x), chan_p).simplify()
                chan_z, chan_x, chan_p = (
                    error_op.paulis.z,
                    error_op.paulis.x,
                    error_op.probabilities,
                )
        return PauliError(PauliList.from_symplectic(chan_z, chan_x), chan_p)

    def to_quantum_error(self) -> QuantumError:
        """Convert to a general QuantumError object.

        Raises:
            NoiseError: If this error is not CPTP.
        """
        if not self.is_cptp():
            raise NoiseError("Cannot convert non-CPTP PauliLindbladError to a QuantumError")
        return self.to_pauli_error().to_quantum_error()

    def to_quantumchannel(self) -> SuperOp:
        """Convert to a dense N-qubit QuantumChannel"""
        return self.to_pauli_error().to_quantumchannel()

    def to_dict(self):
        """Return the current error as a dictionary."""
        # Assemble noise circuits for Aer simulator
        qubits = list(range(self.num_qubits))
        # One single-instruction "circuit" per generator, acting on all qubits
        instructions = [
            [{"name": "pauli", "params": [pauli.to_label()], "qubits": qubits}]
            for pauli in self.generators
        ]
        # Construct error dict
        error = {
            "type": "plerror",
            "id": self.id,
            "operations": [],
            "instructions": instructions,
            "rates": self.rates.tolist(),
        }
        return error

    @staticmethod
    def from_dict(error: dict) -> PauliLindbladError:
        """Construct a PauliLindbladError from its dictionary representation.

        Args:
            error: A dictionary in the format produced by :meth:`to_dict`.

        Returns:
            The error built from the ``"instructions"`` and ``"rates"`` entries.

        Raises:
            NoiseError: If ``error`` is not a dictionary, is missing an expected
                key, has mismatched instruction/rate lengths, or contains an
                instruction list that is not a single ``"pauli"`` instruction.
        """
        # check if dictionary
        if not isinstance(error, dict):
            raise NoiseError("error is not a dictionary")
        # check expected keys "type, id, operations, instructions, rates"
        required_keys = ("type", "id", "operations", "instructions", "rates")
        if any(key not in error for key in required_keys):
            raise NoiseError("error dictionary not containing expected keys")
        instructions = error["instructions"]
        rates = error["rates"]
        if len(instructions) != len(rates):
            raise NoiseError("rates not matching with instructions")
        # parse instructions and collect the Pauli label of each generator
        generators = []
        for inst in instructions:
            if len(inst) != 1 or inst[0]["name"] != "pauli":
                raise NoiseError("Invalid PauliLindbladError dict")
            generators.append(inst[0]["params"][0])
        return PauliLindbladError(generators, rates)
