# This code is part of Qiskit.
#
# (C) Copyright IBM 2025.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Utilities for working with circuit schedule timing information returned 
from the Qiskit Runtime service."""

from __future__ import annotations

from itertools import cycle
from typing import Tuple, List, Set, Dict, TYPE_CHECKING
import numpy as np
from ..visualization.utils import plotly_module

if TYPE_CHECKING:
    from plotly.graph_objects import Figure as PlotlyFigure, Scatter


# Handles to plotly submodules, resolved through the ``plotly_module`` helper from the
# package's visualization utilities (NOTE(review): presumably this handles plotly being an
# optional dependency — confirm against ``visualization.utils``).
go = plotly_module(".graph_objects")
# Qualitative color palette; ``CircuitSchedule.preprocess`` cycles through it to assign one
# color per gate name.
colors = plotly_module(".colors").qualitative.Plotly


# Channel-name prefix of readout channels. AWGR stands for Arbitrary Wave Generator Readout:
# these channels carry the readout communication used for measuring qubits, as opposed to
# drive channels, which are used for driving the qubits.
READOUT_CHANNEL_PREFIX = "AWGR"
# Instruction / gate name that identifies barriers in the schedule data.
BARRIER = "barrier"


class CircuitSchedule:
    """Parsed timing data of a Qiskit circuit schedule, with functionality to visualize it.

    The schedule is parsed into a Numpy array of strings (one row per hardware instruction).
    It can then be filtered with :meth:`preprocess` and turned into plotly traces with
    :meth:`populate_figure`.
    """

    def __init__(
        self,
        circuit_schedule: str,
    ):
        """Initialize a CircuitSchedule object with the schedule data generated by the compiler.

        The data is given as a single string of consecutive hardware instructions separated by
        newlines. It is parsed into a Numpy array to allow efficient filtering and preparation
        of the data for plotting.

        Args:
            circuit_schedule: Schedule data as a string of hardware instructions as returned
                by the compiler. Each non-blank line has the comma-separated format
                ``Branch,Instruction,Channel,T0,Duration,Pulse``.

        Raises:
            TypeError: If ``circuit_schedule`` is not a string.
            ValueError: If a line does not have six comma-separated fields, or if its
                ``T0``/``Duration`` fields are not integers.

        Attributes:
            channels: A list of channels to be plotted (the rows in the plot). Set by
                :meth:`preprocess`.
            type_to_idx: A mapping from data type names to indices of the corresponding column
                in the ``circuit_scheduling`` Numpy array.
            circuit_scheduling: A 2-D Numpy string array with one row per instruction and the
                columns named in ``type_to_idx``.
            instruction_set: The distinct gate names within the circuit scheduling. Set by
                :meth:`preprocess`.
            max_time: The latest finish time of the scheduled circuit in cycles (a cycle is a
                global unit which may take a different amount of seconds on different
                backends). Set by :meth:`preprocess`.
            color_map: A mapping from gate names to colors.
            annotations: A list that contains annotations for traces.
            legend: The gate names that already have a legend entry.
            traces: A list of the plotly scatter traces to plot.
        """
        self.channels: List = None
        self.type_to_idx: Dict[str, int] = None
        self.circuit_scheduling = None

        raw_data = self._load(circuit_schedule)
        self._parse(raw_data)

        self.instruction_set: Set[str] = set()
        self.max_time: int = None
        self.color_map: Dict[str, str] = {}
        self.annotations: List[Dict] = []
        self.legend: Set[str] = set()
        self.traces: List[Scatter] = []

    @classmethod
    def _load(cls, circuit_schedule: str) -> List[str]:
        """Split the raw schedule string into its lines.

        Args:
            circuit_schedule: Schedule data as a string of hardware instructions as returned
                by the compiler.

        Raises:
            TypeError: If ``circuit_schedule`` is not a string.

        Returns:
            The schedule lines, without line terminators.
        """
        if not isinstance(circuit_schedule, str):
            raise TypeError("CircuitSchedule expects a str.")
        # ``splitlines`` (unlike ``split("\n")``) also drops the "\r" of CRLF line endings,
        # which would otherwise leak into the last (Pulse) field of every line.
        return circuit_schedule.splitlines()

    def _parse(self, raw_data: List[str]) -> None:
        """Parse the raw circuit schedule lines into a 2-D numpy string array.

        Sets ``self.type_to_idx`` and ``self.circuit_scheduling``. Each output row holds
        ``Branch, Instruction, Channel, Start, Finish, Pulse, GateName``, where
        ``Finish = T0 + Duration``. ``GateName`` is the instruction name up to its first
        underscore.

        Args:
            raw_data: A list of instruction schedules as strings.

        Raises:
            ValueError: If a line does not have six comma-separated fields, or if its
                ``T0``/``Duration`` fields are not integers.
        """
        data_names = ["Branch", "Instruction", "Channel", "Start", "Finish", "Pulse", "GateName"]
        circuit_scheduling = []
        for line in raw_data:
            # skip blank (including whitespace-only) lines
            if not line.strip():
                continue
            words = line.split(",")
            # lines whose first field mentions shift_phase are dropped (kept from the original
            # implementation; NOTE(review): confirm which input format produces such lines)
            if "shift_phase" in words[0]:
                continue
            if len(words) != 6:
                raise ValueError(
                    "Cannot interpret timeline data that doesn't have the format "
                    "<Branch, Instruction, Channel, T0, Duration, Pulse>"
                )
            branch, instruction, channel, start, duration, pulse = words
            circuit_scheduling.append(
                [
                    branch,
                    instruction,
                    channel,
                    start,
                    str(int(start) + int(duration)),  # Finish
                    pulse,
                    instruction.split("_")[0],  # GateName
                ]
            )
        self.type_to_idx = {data_name: idx for idx, data_name in enumerate(data_names)}
        # reshape guarantees a 2-D array even for an empty schedule, so that column indexing
        # such as ``[:, idx]`` keeps working
        self.circuit_scheduling = np.array(circuit_scheduling, dtype=str).reshape(
            -1, len(data_names)
        )

    def preprocess(
        self,
        filter_awgr: bool = False,
        filter_barriers: bool = False,
        included_channels: list = None,
    ) -> None:
        """Preprocess and filter the parsed circuit schedule data for visualization.

        After filtering, rows are sorted by channel (keeping the original order within each
        channel). ``channels``, ``max_time``, ``instruction_set`` and ``color_map`` are then
        derived from the remaining rows.

        Args:
            filter_awgr: If ``True``, remove all readout channels from scheduling data.
            filter_barriers: If ``True``, remove all barriers from scheduling data.
            included_channels: If a list, remove all channels from scheduling data
               that are not in the ``included_channels`` list.
        """
        channel_idx = self.type_to_idx["Channel"]

        # filter channels
        if included_channels is not None and isinstance(included_channels, list):
            mask = np.isin(self.circuit_scheduling[:, channel_idx], included_channels)
            self.circuit_scheduling = self.circuit_scheduling[mask]

        # filter AWGR channels
        if filter_awgr:
            mask = ~np.char.startswith(
                self.circuit_scheduling[:, channel_idx], READOUT_CHANNEL_PREFIX
            )
            self.circuit_scheduling = self.circuit_scheduling[mask]

        # filter barriers
        if filter_barriers:
            mask = self.circuit_scheduling[:, self.type_to_idx["Instruction"]] != BARRIER
            self.circuit_scheduling = self.circuit_scheduling[mask]

        # a stable sort keeps the relative order of instructions within each channel
        order = np.argsort(self.circuit_scheduling[:, channel_idx], kind="stable")
        self.circuit_scheduling = self.circuit_scheduling[order]
        # np.unique already returns its values sorted
        self.channels = list(np.unique(self.circuit_scheduling[:, channel_idx]))

        # times are stored as strings: convert before taking the maximum, since a
        # lexicographic max would rank e.g. "99" above "100"
        finish_times = self.circuit_scheduling[:, self.type_to_idx["Finish"]].astype(int)
        self.max_time = int(finish_times.max()) if finish_times.size else 0

        self.instruction_set = np.unique(self.circuit_scheduling[:, self.type_to_idx["GateName"]])
        self.color_map = dict(zip(self.instruction_set, cycle(colors)))

    def get_trace_finite_duration_y_shift(self, branch: str) -> Tuple[float, float, float]:
        """Return y-axis trace shift for a finite duration instruction schedule and its annotation.
        The shifts are to distinguish static and dynamic (control-flow) parts of the circuit.

        Args:
            branch: The branch type to get the shift for, 'main' / 'then' / 'else'.

        Raises:
            ValueError: For an unsupported branch name.

        Returns:
            A tuple ``(y0, y1, annotation_y)`` of the lower and upper trace edges and the
            annotation position, all relative to the channel's row.
        """
        if branch == "main":
            return (-0.4, 0.4, 0)
        elif branch == "then":
            return (0, 0.4, 0.25)
        elif branch == "else":
            return (-0.4, 0, -0.25)
        else:
            raise ValueError(f"Unexpected branch provided: {branch}")

    def get_trace_zero_duration_y_shift(self, branch: str) -> float:
        """Return y-axis trace shift for a zero duration instruction schedule.
        The shifts are to distinguish static and dynamic (control-flow) parts of the circuit.

        Args:
            branch: The branch type to get the shift for, 'main' / 'then' / 'else'.

        Raises:
            ValueError: For an unsupported branch name.

        Returns:
            The y-axis shift of the trace center relative to the channel's row.
        """
        if branch == "main":
            return 0
        elif branch == "then":
            return 0.2
        elif branch == "else":
            return -0.2
        else:
            raise ValueError(f"Unexpected branch provided: {branch}")

    def trace_finite_duration_instruction(self, instruction_schedule: np.ndarray) -> None:
        """Create a trace and annotation for a single finite duration instruction schedule.

        The trace is appended to ``self.traces`` and the annotation to ``self.annotations``.

        Args:
            instruction_schedule: A single row of ``circuit_scheduling``.
        """
        (branch, instruction, channel, t_i, t_f, pulse, gate_name) = instruction_schedule
        t_i, t_f = int(t_i), int(t_f)

        # trace y-shifts relative to the channel's row
        y0, y1, annotation_y = self.get_trace_finite_duration_y_shift(branch)

        # extend barriers vertically beyond operations; this must happen before the absolute
        # y-coordinates are derived from the shifts, otherwise it has no effect
        if gate_name == BARRIER:
            y0 -= 0.05
            y1 += 0.05

        channel_y_loc = self.channels.index(channel)
        y_low = channel_y_loc + y0
        y_high = channel_y_loc + y1
        t_mid = (t_i + t_f) / 2

        # The gate is drawn as a closed rectangle through six points (the first point is
        # repeated to close it). The midpoints on the top and bottom edges are extra markers
        # that the hover text can attach to.
        trace = go.Scatter(
            x=[t_i, t_mid, t_f, t_f, t_mid, t_i, t_i, None],
            y=[y_low, y_low, y_low, y_high, y_high, y_high, y_low, None],
            mode="markers",
            hoverinfo="x+text",
            name=gate_name,
            text="<br>".join(
                [
                    f"Instruction: {instruction}",
                    f"Pulse: {pulse}",
                    f"Start: {t_i}",
                    f"Finish: {t_f}",
                    f"Duration: {t_f - t_i}",
                ]
            ),
            legendgroup=gate_name,
            line={"color": "black"},
            fill="toself",
            fillcolor=self.color_map[gate_name],
            showlegend=gate_name not in self.legend,
        )
        self.traces.append(trace)

        # Get trace annotation
        # hide text if drawing a barrier
        text = "" if gate_name == BARRIER else f"{gate_name}_{pulse}"
        annotation = {
            "x": t_mid,
            "y": channel_y_loc + annotation_y,
            "showarrow": False,
            "font": {"color": "black", "size": 10},
            "text": text,
            "textangle": 0,
        }
        self.annotations.append(annotation)

    def trace_zero_duration_instruction(self, instruction_schedule: np.ndarray) -> None:
        """Create a trace and annotation for a single zero duration instruction schedule.

        The instruction is drawn as a diamond centered on its start time. The trace is
        appended to ``self.traces`` and the annotation to ``self.annotations``.

        Args:
            instruction_schedule: A single row of ``circuit_scheduling``.
        """
        (branch, instruction, channel, t_i, t_f, pulse, gate_name) = instruction_schedule
        t_i, t_f = int(t_i), int(t_f)

        y_shift = self.get_trace_zero_duration_y_shift(branch)
        channel_y_loc = self.channels.index(channel)
        y_mid = channel_y_loc + y_shift
        y_low = y_mid - 0.2
        y_high = y_mid + 0.2

        # drawing zero duration traces as diamonds
        trace = go.Scatter(
            x=[t_i, t_i + 1, t_i, t_i - 1, t_i, None],
            y=[y_low, y_mid, y_high, y_mid, y_low, None],
            mode="markers",
            hoverinfo="x+text",
            name=gate_name,
            text="<br>".join(
                [
                    f"Instruction: {instruction}",
                    f"Pulse: {pulse}",
                    f"Start: {t_i}",
                    f"Finish: {t_f}",
                    f"Duration: {t_f - t_i}",
                ]
            ),
            legendgroup=gate_name,
            line={"color": "black"},
            fill="toself",
            fillcolor=self.color_map[gate_name],
            showlegend=gate_name not in self.legend,
        )
        self.traces.append(trace)

        # Get trace annotation
        annotation = {
            "x": ((t_i + t_f) / 2),
            "y": y_mid,
            "showarrow": True,
            "font": {"color": "black", "size": 10},
            "text": f"{gate_name}_{pulse}",
            "textangle": 0,
        }
        self.annotations.append(annotation)

    def populate_figure(self, fig: PlotlyFigure) -> PlotlyFigure:
        """Generate traces and annotations for the processed instruction schedules and add the
        traces to the figure.

        Finite duration instructions are traced first. Zero duration instructions (those whose
        pulse mentions ``shift_phase``) are traced afterwards so they are not covered.
        Annotations are accumulated in ``self.annotations`` but are not added to ``fig`` here.
        NOTE: ``self.traces`` is never cleared, so repeated calls re-add earlier traces.

        Args:
            fig: A plotly figure to populate with traces.

        Returns:
            The populated figure.
        """
        # Process instructions
        shift_phase_instructions = []
        for instruction_schedule in self.circuit_scheduling:
            (_, _, _, _, _, pulse, gate_name) = instruction_schedule
            if "shift_phase" not in pulse:
                # Trace instructions of finite duration
                self.trace_finite_duration_instruction(instruction_schedule)
                self.legend.add(gate_name)
            else:
                # cache instructions of zero duration
                # for later tracing so it won't be covered
                shift_phase_instructions.append(instruction_schedule)

        # Trace instructions of zero duration
        for instruction_schedule in shift_phase_instructions:
            (_, _, _, _, _, _, gate_name) = instruction_schedule
            self.trace_zero_duration_instruction(instruction_schedule)
            self.legend.add(gate_name)

        fig.add_traces(self.traces)
        return fig
