# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.

from __future__ import annotations

import asyncio
import logging
from collections import Counter, defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Coroutine,
    Generic,
    Type,
    TypeVar,
    cast,
)

from llama_index_instrumentation.dispatcher import (
    active_instrument_tags,
    instrument_tags,
)


from workflows.utils import _nanoid as nanoid
from workflows.errors import WorkflowRuntimeError
from workflows.events import (
    Event,
    StartEvent,
)
from workflows.runtime.control_loop import control_loop, rebuild_state_from_ticks
from workflows.runtime.types.internal_state import BrokerState
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime, as_snapshottable
from workflows.runtime.types.results import (
    AddCollectedEvent,
    AddWaiter,
    DeleteCollectedEvent,
    DeleteWaiter,
    StepWorkerContext,
    StepWorkerStateContextVar,
    WaitingForEvent,
)
from workflows.runtime.types.step_function import (
    StepWorkerFunction,
    as_step_worker_function,
)
from workflows.runtime.types.ticks import TickAddEvent, TickCancelRun, WorkflowTick
from workflows.runtime.workflow_registry import workflow_registry

from ..context.state_store import MODEL_T

from workflows.handler import WorkflowHandler

if TYPE_CHECKING:
    from workflows import Workflow
    from workflows.context.context import Context


# Mapping of buffer name -> events collected into that buffer
# (presumably keyed like collect_events' buffer_id — confirm at call sites).
EventBuffer = dict[str, list[Event]]
# Event subtype variable: lets wait_for_event return the exact type requested.
T = TypeVar("T", bound=Event)

# Module-scoped logger (not the root logger) so records can be filtered and
# configured per module; records still propagate to the root handlers.
logger = logging.getLogger(__name__)


# Only warn once about unserializable keys
class UnserializableKeyWarning(Warning):
    pass


class WorkflowBroker(Generic[MODEL_T]):
    """
    The workflow broker manages starting up and connecting a workflow handler, a runtime, and triggering the
    execution of the workflow. From there it manages communication between the workflow and the outside world.
    """

    _context: Context[MODEL_T]
    _runtime: WorkflowRuntime
    _plugin: Plugin
    # True while the background run task is executing the workflow
    _is_running: bool
    # set by start(); non-None means this broker has already been started
    _handler: WorkflowHandler | None
    _workflow: Workflow
    # transient tasks to run async ops in background, exposing sync interfaces
    _workers: list[asyncio.Task]
    # state passed to start() as `previous`; replayed ticks are applied on top
    _init_state: BrokerState | None

    def __init__(
        self,
        workflow: Workflow,
        context: Context[MODEL_T],
        runtime: WorkflowRuntime,
        plugin: Plugin,
    ) -> None:
        """Create an idle broker; call ``start`` to begin the run."""
        self._context = context
        self._runtime = runtime
        self._plugin = plugin
        self._is_running = False
        self._handler = None
        self._workflow = workflow
        self._workers = []
        self._init_state = None

    def _execute_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
        """Schedule ``coro`` as a background task tracked in ``_workers``.

        The task removes itself from ``_workers`` once it is done.
        """
        task = asyncio.create_task(coro)
        self._workers.append(task)
        task.add_done_callback(self._forget_worker)
        return task

    def _forget_worker(self, task: asyncio.Task[Any]) -> None:
        """Done-callback: drop a finished task from ``_workers``."""
        try:
            self._workers.remove(task)
        except ValueError:
            # shutdown() clears the list before cancelled tasks report done;
            # the task is already forgotten, so there is nothing to do.
            pass

    @staticmethod
    def _build_step_workers(workflow: Workflow) -> dict[str, StepWorkerFunction]:
        """Wrap every step of ``workflow`` as a step worker function, by name."""
        # Avoid capturing a bound method (which retains the instance): if a step
        # is a bound method, use the underlying function from the class instead.
        return {
            name: as_step_worker_function(getattr(step_func, "__func__", step_func))
            for name, step_func in workflow._get_steps().items()
        }

    # context API only
    def start(
        self,
        workflow: Workflow,
        previous: BrokerState | None = None,
        start_event: StartEvent | None = None,
        before_start: Callable[[], Awaitable[None]] | None = None,
        after_complete: Callable[[], Awaitable[None]] | None = None,
    ) -> WorkflowHandler:
        """Start the workflow run. Can only be called once.

        Args:
            workflow: Workflow whose steps are executed.
            previous: Broker state to resume from; when falsy a fresh state is
                built with ``BrokerState.from_workflow(workflow)``.
            start_event: Passed to the registered workflow function.
            before_start: Awaited inside the run task before the workflow runs.
            after_complete: Awaited inside the run task after the workflow
                finished or failed (not awaited if ``before_start`` fails).

        Returns:
            The handler wrapping the background run task.

        Raises:
            WorkflowRuntimeError: If this broker was already started.
        """
        if self._handler is not None:
            raise WorkflowRuntimeError(
                "this WorkflowBroker already run or running. Cannot start again."
            )
        self._init_state = previous

        async def _run_workflow(run_id: str, tags: dict[str, Any]) -> None:
            with instrument_tags({"run_id": run_id, **tags}):
                self._is_running = True
                # Defer execution so the task can be captured and passed to the
                # handler (`result`, assigned below) before any user code runs;
                # an exception from before_start then surfaces via the run task.
                await asyncio.sleep(0)
                try:
                    if before_start is not None:
                        await before_start()
                except BaseException:
                    # Don't leave the broker reporting "running" forever when
                    # before_start fails or the task is cancelled during it.
                    self._is_running = False
                    raise
                try:
                    init_state = previous or BrokerState.from_workflow(workflow)
                    try:
                        registered = workflow_registry.get_registered_workflow(
                            workflow,
                            self._plugin,
                            control_loop,
                            self._build_step_workers(workflow),
                        )

                        # Register run context prior to invoking control loop.
                        # NOTE(review): `plugin=` receives the runtime here while
                        # get_registered_workflow receives self._plugin — confirm
                        # this asymmetry is intended.
                        workflow_registry.register_run(
                            run_id=run_id,
                            workflow=workflow,
                            plugin=self._runtime,
                            context=self._context,  # type: ignore
                            steps=registered.steps,
                        )

                        try:
                            workflow_result = await registered.workflow_function(
                                start_event,
                                init_state,
                                run_id,
                            )
                        finally:
                            # ensure run context is cleaned up even on failure
                            workflow_registry.delete_run(run_id)
                        result._set_stop_event(workflow_result)
                    except Exception as e:
                        # Surface the failure through the handler unless it has
                        # already been resolved.
                        if not result.done():
                            result.set_exception(e)
                finally:
                    try:
                        if after_complete is not None:
                            await after_complete()
                    finally:
                        # Clear the flag even if after_complete raises or the
                        # task is cancelled while awaiting it.
                        self._is_running = False

        run_id = nanoid()
        run_task = self._execute_task(
            _run_workflow(run_id, tags=active_instrument_tags.get())
        )
        result = WorkflowHandler(
            ctx=self._context,  # type: ignore
            run_id=run_id,
            run_task=run_task,
        )
        self._handler = result
        return result

    # outer handler API to cancel the workflow run
    def cancel_run(self) -> None:
        """Send a ``TickCancelRun`` to the runtime from a background task."""
        self._execute_task(self._runtime.send_event(TickCancelRun()))

    @property
    def is_running(self) -> bool:
        """Whether the run task is currently executing the workflow."""
        return self._is_running

    @property
    def _state(self) -> BrokerState:
        """Current broker state: the initial state with the tick log replayed.

        Recomputed from scratch on every access, so callers should read it once.
        """
        ticks = self._tick_log
        state = self._init_state or BrokerState.from_workflow(self._workflow)
        return rebuild_state_from_ticks(state, ticks)

    @property
    def _tick_log(self) -> list[WorkflowTick]:
        """Ticks recorded by the runtime.

        Raises:
            WorkflowRuntimeError: If the runtime is not snapshottable.
        """
        snapshottable = as_snapshottable(self._runtime)
        if snapshottable is None:
            raise WorkflowRuntimeError("Plugin is not snapshottable")
        return snapshottable.replay()

    # mostly a debug API. May be removed in the future.
    async def running_steps(self) -> list[str]:
        """Names of steps whose worker is currently marked in progress."""
        # _state replays the whole tick log; evaluate it once, not per step.
        state = self._state
        return [name for name, worker in state.workers.items() if worker.in_progress]

    # step api only
    def collect_events(
        self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
    ) -> list[Event] | None:
        """Accumulate ``ev`` into a buffer until all ``expected`` types are present.

        Args:
            ev: Event just received by the calling step.
            expected: Event types to collect; duplicates require multiple events.
            buffer_id: Name of the buffer; defaults to ``"default"``.

        Returns:
            ``None`` while the collection is incomplete. Once complete, one
            event per entry of ``expected``, in that order.

        Raises:
            WorkflowRuntimeError: If called outside a step function.
        """
        step_ctx = self._get_step_ctx(fn="collect_events")

        buffer_id = buffer_id or "default"

        collected_events = step_ctx.state.collected_events.get(buffer_id, [])

        # Counter subtraction drops non-positive counts, leaving only what is
        # still missing from the buffer.
        remaining_event_types = Counter(expected) - Counter(
            [type(e) for e in collected_events]
        )

        # Complete only if the single missing event is exactly this one.
        if remaining_event_types != Counter([type(ev)]):
            if type(ev) in remaining_event_types:
                step_ctx.returns.return_values.append(
                    AddCollectedEvent(event_id=buffer_id, event=ev)
                )
            return None

        by_type: defaultdict[type, list[Event]] = defaultdict(list)
        for e in collected_events + [ev]:
            by_type[type(e)].append(e)
        # order by expected type
        total = [by_type[e_type].pop(0) for e_type in expected]
        # if we got here, it means the collection is fulfilled. Clear the collected events when the step is complete
        step_ctx.returns.return_values.append(DeleteCollectedEvent(event_id=buffer_id))
        return total

    # may be called from both step API and outer handler API
    def send_event(self, message: Event, step: str | None = None) -> None:
        """Queue ``message`` for the runtime as a ``TickAddEvent``.

        Args:
            message: Event to deliver.
            step: Optional target step; it must exist and accept the exact type
                of ``message``.

        Raises:
            WorkflowRuntimeError: If ``step`` is unknown or does not accept the
                event's type.
        """
        if step is not None:
            steps = self._workflow._get_steps()
            if step not in steps:
                raise WorkflowRuntimeError(f"Step {step} does not exist")

            # Validate that the step accepts this event type
            step_config = steps[step]._step_config
            if type(message) not in step_config.accepted_events:
                raise WorkflowRuntimeError(
                    f"Step {step} does not accept event of type {type(message)}"
                )

        self._execute_task(
            self._runtime.send_event(TickAddEvent(event=message, step_name=step))
        )

    def _get_step_ctx(self, fn: str) -> StepWorkerContext:
        """Return the active step worker context.

        Raises:
            WorkflowRuntimeError: If no step is executing in this context.
        """
        try:
            return StepWorkerStateContextVar.get()
        except LookupError:
            raise WorkflowRuntimeError(
                f"{fn} may only be called from within a step function"
            ) from None

    # step api only
    async def wait_for_event(
        self,
        event_type: Type[T],
        waiter_event: Event | None = None,
        waiter_id: str | None = None,
        requirements: dict[str, Any] | None = None,
        timeout: float | None = 2000,
    ) -> T:
        """Return the event resolving this waiter, or suspend the step until one arrives.

        If no resolved waiter with this id exists, raises ``WaitingForEvent``
        carrying an ``AddWaiter`` request. Otherwise it schedules a
        ``DeleteWaiter`` and returns the resolved event.

        Raises:
            WaitingForEvent: When the waiter is not yet resolved.
            WorkflowRuntimeError: If called outside a step function.
        """
        step_ctx = self._get_step_ctx(fn="wait_for_event")

        collected_waiters = step_ctx.state.collected_waiters
        requirements = requirements or {}

        # Generate a unique key for the waiter
        event_str = self._get_full_path(event_type)
        requirements_str = str(requirements)
        waiter_id = waiter_id or f"waiter_{event_str}_{requirements_str}"

        waiter = next((w for w in collected_waiters if w.waiter_id == waiter_id), None)
        if waiter is None or waiter.resolved_event is None:
            raise WaitingForEvent(
                AddWaiter(
                    waiter_id=waiter_id,
                    requirements=requirements,
                    timeout=timeout,
                    event_type=event_type,
                    waiter_event=waiter_event,
                )
            )
        step_ctx.returns.return_values.append(DeleteWaiter(waiter_id=waiter_id))
        return cast(T, waiter.resolved_event)

    def _get_full_path(self, ev_type: Type[Event]) -> str:
        """Fully qualified ``module.ClassName`` of ``ev_type``."""
        return f"{ev_type.__module__}.{ev_type.__name__}"

    def stream_published_events(self) -> AsyncGenerator[Event, None]:
        """Async generator of events the runtime publishes to the stream."""
        return self._runtime.stream_published_events()

    # step API only
    def write_event_to_stream(self, ev: Event | None) -> None:
        """Publish ``ev`` to the event stream in the background; ``None`` is ignored."""
        if ev is not None:
            self._execute_task(self._runtime.write_to_event_stream(ev))

    async def shutdown(self) -> None:
        """Cancels the running workflow loop

        Sends a cancel tick to the runtime, requests cancellation of all
        outstanding workers (without awaiting them), forgets them, and closes
        the runtime. Queues and state remain available so callers can inspect
        or drain leftover events.
        """
        await self._runtime.send_event(TickCancelRun())
        workers = list(self._workers)
        self._workers.clear()
        for worker in workers:
            worker.cancel()
        await self._runtime.close()
