# This code is part of Qiskit.
#
# (C) Copyright IBM 2018-2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Class for representing a Pauli noise channel generated by a Pauli Lindblad dissipator.
"""

from __future__ import annotations
from collections.abc import Sequence
import numpy as np

from qiskit.quantum_info import Pauli, PauliList, SparsePauliOp, SuperOp
from qiskit.quantum_info.operators.mixins import TolerancesMixin
from .base_quantum_error import BaseQuantumError
from .quantum_error import QuantumError
from ..noiseerror import NoiseError


class PauliError(BaseQuantumError, TolerancesMixin):
    r"""A Pauli channel quantum error.

    This represents an N-qubit quantum error channel
    :math:`E(ρ) = \sum_j p_j P_j ρ P_j^\dagger` where :math:`P_j` are N-qubit
    :class:`~.Pauli` operators.

    The list of Pauli terms is stored as a :class:`~.PauliList` and can be accessed
    via the :attr:`paulis` attribute. The array of probabilities :math:`p_j` can be
    accessed via the :attr:`probabilities` attribute.

    .. note::

        This operator can also represent a non-physical (non-CPTP) channel where some
        probabilities are negative or don't sum to 1. Non-physical operators
        cannot be converted to a :class:`~.QuantumError` or used in an
        :class:`~.AerSimulator` simulation. You can check if an operator is physical
        using the :meth:`is_cptp` method.

    .. note::

        A phase on a Pauli term (e.g. ``-X`` or ``iZ``) does not change
        :math:`P ρ P^\dagger`, so :meth:`simplify` (and therefore equality
        comparison) discards term phases before combining duplicate terms.
    """

    def __init__(
        self,
        paulis: Sequence[Pauli],
        probabilities: Sequence[float],
    ):
        """Initialize a Pauli error channel.

        Args:
            paulis: A sequence of Pauli channel terms.
            probabilities: A sequence of the probability for each Pauli channel term.

        Raises:
            NoiseError: If inputs are invalid.
        """
        self._paulis = PauliList(paulis)
        self._probabilities = np.asarray(probabilities, dtype=float)
        # Exactly one (scalar) probability is required per Pauli term.
        if self._probabilities.shape != (len(self._paulis),):
            raise NoiseError("Input Paulis and probabilities are different lengths.")
        super().__init__(num_qubits=self._paulis.num_qubits)

    def __repr__(self):
        return f"{type(self).__name__}({self.paulis.to_labels()}, {self.probabilities.tolist()})"

    def __eq__(self, other):
        # Use BaseOperator eq to check type and shape
        if not super().__eq__(other):
            return False
        # Compare canonical forms: duplicates merged, near-zero terms dropped,
        # and terms sorted so that term order does not matter.
        lhs = self.simplify()
        rhs = other.simplify()
        if lhs.size != rhs.size:
            return False
        lpaulis, lprobs = sort_paulis(lhs.paulis, lhs.probabilities)
        rpaulis, rprobs = sort_paulis(rhs.paulis, rhs.probabilities)
        return np.allclose(lprobs, rprobs) and lpaulis == rpaulis

    @property
    def size(self):
        """Return the number of Pauli terms in the channel."""
        return len(self.paulis)

    @property
    def paulis(self) -> PauliList:
        """Return the Pauli channel error terms"""
        return self._paulis

    @property
    def probabilities(self) -> np.ndarray:
        """Return the Pauli channel probabilities"""
        return self._probabilities

    @property
    def settings(self):
        """Settings for IBM RuntimeEncoder JSON encoding"""
        return {
            "paulis": self.paulis,
            "probabilities": self.probabilities,
        }

    def ideal(self) -> bool:
        """Return True if this error object is composed only of identity operations.
        Note that the identity check is best effort and up to global phase."""
        if not self.is_cptp():
            return False
        # Only terms with non-negligible probability can contribute a non-identity action.
        non_zero = self.paulis[~np.isclose(self.probabilities, 0)]
        return not (np.any(non_zero.z) or np.any(non_zero.x))

    def is_cptp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Return True if completely-positive trace-preserving (CPTP)."""
        return self.is_cp(atol=atol, rtol=rtol) and self.is_tp(atol=atol, rtol=rtol)

    def is_tp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if a channel is trace-preserving (TP).

        The channel is TP when its probabilities sum to 1 within tolerance.
        """
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        # np.isclose returns np.bool_; convert to honour the ``-> bool`` contract.
        return bool(np.isclose(np.sum(self.probabilities), 1, atol=atol, rtol=rtol))

    def is_cp(self, atol: float | None = None, rtol: float | None = None) -> bool:
        """Test if Choi-matrix is completely-positive (CP).

        The channel is CP when no probability is negative beyond tolerance.
        """
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        neg_probs = self.probabilities[self.probabilities < 0]
        return np.allclose(neg_probs, 0, atol=atol, rtol=rtol)

    def tensor(self, other: PauliError) -> PauliError:
        """Return the tensor product channel ``self ⊗ other``.

        Args:
            other: the Pauli error to tensor with.

        Returns:
            PauliError: the channel whose terms are the pairwise tensor products of
            the terms of ``self`` and ``other``, with probabilities multiplied.

        Raises:
            NoiseError: if ``other`` is not a :class:`PauliError`.
        """
        if not isinstance(other, PauliError):
            raise NoiseError("other must be a PauliError")
        # Term phases are irrelevant for a Pauli channel, so they are not moved
        # into the coefficients (ignore_pauli_phase=True).
        left = SparsePauliOp(self.paulis, self.probabilities, copy=False, ignore_pauli_phase=True)
        right = SparsePauliOp(
            other.paulis, other.probabilities, copy=False, ignore_pauli_phase=True
        )
        tens = left.tensor(right)
        return PauliError(tens.paulis, tens.coeffs.real)

    def expand(self, other: PauliError) -> PauliError:
        """Return the reverse-order tensor product channel ``other ⊗ self``.

        Args:
            other: the Pauli error to expand with.

        Returns:
            PauliError: ``other.tensor(self)``.

        Raises:
            NoiseError: if ``other`` is not a :class:`PauliError`.
        """
        if not isinstance(other, PauliError):
            raise NoiseError("other must be a PauliError")
        return other.tensor(self)

    def compose(self, other, qargs=None, front=False) -> PauliError:
        """Return the composition of this channel with ``other``.

        Pauli channels commute, so ``front`` is only used to validate ``qargs`` and
        the operand shapes; the resulting terms and probabilities are the same
        either way. The result is not simplified: it holds
        ``self.size * other.size`` terms.

        Args:
            other: the Pauli error to compose with.
            qargs: optional indices of the qubits of ``self`` that ``other`` acts on.
                If None, ``other.qargs`` is used when that attribute exists.
            front: if True, compose as ``self ∘ other`` instead of ``other ∘ self``.

        Returns:
            PauliError: the composed channel.

        Raises:
            NoiseError: if ``other`` is not a :class:`PauliError`.
        """
        if qargs is None:
            qargs = getattr(other, "qargs", None)
        if not isinstance(other, PauliError):
            raise NoiseError("other must be a PauliError")

        # This is similar to SparsePauliOp.compose but doesn't need to track
        # phases since it is equivalent to the abelian Pauli group compose

        # Validate composition dimensions and qargs match
        self._op_shape.compose(other._op_shape, qargs, front)

        if qargs is not None:
            x1, z1 = self.paulis.x[:, qargs], self.paulis.z[:, qargs]
        else:
            x1, z1 = self.paulis.x, self.paulis.z
        x2, z2 = other.paulis.x, other.paulis.z
        num_qubits = other.num_qubits

        # Multiplying Paulis up to phase is XOR of their symplectic bits; broadcast
        # over all (self term, other term) pairs, with row index i * other.size + j.
        x3 = np.logical_xor(x1[:, np.newaxis], x2).reshape((-1, num_qubits))
        z3 = np.logical_xor(z1[:, np.newaxis], z2).reshape((-1, num_qubits))

        if qargs is None:
            paulis = PauliList.from_symplectic(z3, x3)
        else:
            # Repeat each self term once per other term (same row order as x3/z3),
            # then overwrite only the qubits that other acts on.
            x4 = np.repeat(self.paulis.x, other.size, axis=0)
            z4 = np.repeat(self.paulis.z, other.size, axis=0)
            x4[:, qargs] = x3
            z4[:, qargs] = z3
            paulis = PauliList.from_symplectic(z4, x4)

        probabilities = np.multiply.outer(self.probabilities, other.probabilities).ravel()
        return PauliError(paulis, probabilities)

    def simplify(self, atol: float | None = None, rtol: float | None = None) -> PauliError:
        """Simplify the error by combining duplicate Pauli terms and removing zero terms.

        Term phases are discarded before duplicates are combined. They do not
        affect the channel action :math:`P ρ P^\\dagger`. Keeping them would
        fold a phase such as ``-1`` or ``i`` into the summed probability.

        Args:
            atol (float): Optional. Absolute tolerance for checking if
                          probabilities are zero (Default: ``self.atol``).
            rtol (float): Optional. Relative tolerance for checking if
                          probabilities are zero (Default: ``self.rtol``).

        Returns:
            PauliError: the simplified Pauli error.
        """
        if atol is None:
            atol = self.atol
        if rtol is None:
            rtol = self.rtol
        # Rebuilding the terms from their symplectic bits alone gives every term a
        # zero phase. Otherwise SparsePauliOp would move e.g. the -1 of "-X" into
        # the coefficient and the probability would come out negative.
        unphased = PauliList.from_symplectic(self.paulis.z, self.paulis.x)
        simplified = SparsePauliOp(unphased, self.probabilities).simplify(atol=atol, rtol=rtol)
        return PauliError(simplified.paulis, simplified.coeffs.real)

    def to_quantum_error(self) -> "QuantumError":
        """Convert to a general QuantumError object.

        Raises:
            NoiseError: if the channel is not CPTP.
        """
        if not self.is_cptp():
            raise NoiseError("Cannot convert non-CPTP PauliError to a QuantumError")
        return QuantumError(list(zip(self.paulis, self.probabilities)))

    def to_quantumchannel(self) -> SuperOp:
        """Convert to a dense N-qubit QuantumChannel"""
        # Sum terms as superoperator
        # We could do this more efficiently as a PTM or Chi, but would need
        # to map Pauli terms to integer index.
        chan = SuperOp(np.zeros(2 * [4**self.num_qubits]))
        for pauli, coeff in zip(self.paulis, self.probabilities):
            chan += coeff * SuperOp(pauli)
        return chan

    def to_dict(self) -> dict:
        """Return the current error as a dictionary.

        Each Pauli term becomes one instruction list containing a single
        ``"pauli"`` instruction on all qubits, paired with its probability.
        """
        # Assemble noise circuits for Aer simulator
        qubits = list(range(self.num_qubits))
        instructions = [
            [{"name": "pauli", "params": [pauli.to_label()], "qubits": qubits}]
            for pauli in self.paulis
        ]
        # Construct error dict
        error = {
            "type": "qerror",
            "id": self.id,
            "operations": [],
            "instructions": instructions,
            "probabilities": self.probabilities.tolist(),
        }
        return error

    @staticmethod
    def from_dict(error: dict) -> PauliError:
        """Implement current error from a dictionary.

        Args:
            error: a dictionary in the format produced by :meth:`to_dict`.

        Returns:
            PauliError: the reconstructed error.

        Raises:
            NoiseError: if ``error`` is not a dict, lacks a required key, has
                mismatched instruction and probability counts, or contains an
                instruction list that is not a single ``"pauli"`` instruction.
        """
        # check if dictionary
        if not isinstance(error, dict):
            raise NoiseError("error is not a dictionary")
        # check expected keys "type, id, operations, instructions, probabilities"
        if (
            ("type" not in error)
            or ("id" not in error)
            or ("operations" not in error)
            or ("instructions" not in error)
            or ("probabilities" not in error)
        ):
            raise NoiseError("error dictionary not containing expected keys")
        instructions = error["instructions"]
        probabilities = error["probabilities"]
        if len(instructions) != len(probabilities):
            raise NoiseError("probabilities not matching with instructions")
        # parse instructions and turn to noise_ops
        paulis = []
        for inst in instructions:
            if len(inst) != 1 or inst[0]["name"] != "pauli":
                raise NoiseError("Invalid PauliError dict")
            paulis.append(inst[0]["params"][0])

        return PauliError(paulis, probabilities)


def sort_paulis(
    paulis: PauliList, coeffs: Sequence | None = None
) -> PauliList | tuple[PauliList, np.ndarray]:
    """Sort terms in a way that can be used for equality checks between simplified error ops.

    Terms are ordered lexicographically by the bit-packed ``[x | z]`` symplectic
    tableau of each term. Ties are broken by coefficient (when given) and then by
    original position.

    Args:
        paulis: the Pauli terms to sort.
        coeffs: optional per-term coefficients to reorder alongside ``paulis``.
            Any sequence is accepted; it is converted to a NumPy array.

    Returns:
        The sorted ``paulis`` if ``coeffs`` is None, otherwise a tuple
        ``(sorted_paulis, sorted_coeffs)`` where ``sorted_coeffs`` is a NumPy array.

    Raises:
        ValueError: if ``coeffs`` is given and its length differs from ``paulis``.
    """
    if coeffs is not None:
        # Plain sequences such as lists cannot be indexed by the index list built
        # below; NumPy arrays can (np.asarray is a no-op for arrays).
        coeffs = np.asarray(coeffs)
        if len(coeffs) != len(paulis):
            raise ValueError("paulis and coeffs must have the same length.")

    # Get packed bits tableau of Paulis
    # Use python sorted and enumerate to implement an argsort of
    # rows based on python tuple sorting
    tableau = np.hstack([paulis.x, paulis.z])
    packed = np.packbits(tableau, axis=1)
    if coeffs is None:
        unsorted = ((*row.tolist(), i) for i, row in enumerate(packed))
    else:
        unsorted = ((*row.tolist(), coeff, i) for i, (row, coeff) in enumerate(zip(packed, coeffs)))
    # The trailing element of each key is the original row position.
    index = [tup[-1] for tup in sorted(unsorted)]

    if coeffs is None:
        return paulis[index]
    return paulis[index], coeffs[index]
