# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.

from __future__ import annotations

import asyncio
import functools
import uuid
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Generic,
    Type,
    TypeVar,
    cast,
)

from pydantic import BaseModel, ValidationError

from workflows.context.context_types import SerializedContext
from workflows.decorators import StepConfig
from workflows.errors import (
    ContextSerdeError,
    WorkflowRuntimeError,
)
from workflows.events import (
    Event,
    StartEvent,
    StopEvent,
)
from workflows.runtime.types.internal_state import BrokerState
from workflows.runtime.broker import WorkflowBroker
from workflows.plugins.basic import basic_runtime
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime
from workflows.types import RunResultT
from workflows.handler import WorkflowHandler

from .serializers import BaseSerializer, JsonSerializer
from .state_store import MODEL_T, DictState, InMemoryStateStore


if TYPE_CHECKING:  # pragma: no cover
    from workflows import Workflow


# Event-bound type variable: lets `Context.wait_for_event` return the same
# concrete event class that the caller passed as `event_type`.
T = TypeVar("T", bound=Event)
# Alias for a mapping of string key -> list of events.
# NOTE(review): not referenced in this module's visible code; presumably keyed
# by buffer id for `collect_events` — confirm against the broker.
EventBuffer = dict[str, list[Event]]


class UnserializableKeyWarning(Warning):
    """Warning category for context keys that cannot be serialized."""


# Register the category with the "once" action so that repeated warnings of
# this kind are not shown over and over.
warnings.simplefilter("once", UnserializableKeyWarning)


class Context(Generic[MODEL_T]):
    """
    Global, per-run context for a `Workflow`. Provides an interface into the
    underlying broker run, for both external (workflow run observers) and
    internal consumption by workflow steps.

    The `Context` coordinates event delivery between steps, tracks in-flight work,
    exposes a global state store, and provides utilities for streaming and
    synchronization. It is created by a `Workflow` at run time and can be
    persisted and restored.

    Args:
        workflow (Workflow): The owning workflow instance. Used to infer
            step configuration and instrumentation.
        previous_context: A previous context snapshot to resume from.
        serializer: A serializer to use for serializing and deserializing the current and previous context snapshots.

    Attributes:
        is_running (bool): Whether the workflow is currently running.
        store (InMemoryStateStore[MODEL_T]): Type-safe, async state store shared
            across steps. See also
            [InMemoryStateStore][workflows.context.state_store.InMemoryStateStore].

    Examples:
        Basic usage inside a step:

        ```python
        from workflows import step
        from workflows.events import StartEvent, StopEvent

        @step
        async def start(self, ctx: Context, ev: StartEvent) -> StopEvent:
            await ctx.store.set("query", ev.topic)
            ctx.write_event_to_stream(ev)  # surface progress to UI
            return StopEvent(result="ok")
        ```

        Persisting the state of a workflow across runs:

        ```python
        from workflows import Context

        # Create a context and run the workflow with the same context
        ctx = Context(my_workflow)
        result_1 = await my_workflow.run(..., ctx=ctx)
        result_2 = await my_workflow.run(..., ctx=ctx)

        # Serialize the context and restore it
        ctx_dict = ctx.to_dict()
        restored_ctx = Context.from_dict(my_workflow, ctx_dict)
        result_3 = await my_workflow.run(..., ctx=restored_ctx)
        ```


    See Also:
        - [Workflow][workflows.Workflow]
        - [Event][workflows.events.Event]
        - [InMemoryStateStore][workflows.context.state_store.InMemoryStateStore]
    """

    # These keys are set by pre-built workflows and
    # are known to be unserializable in some cases.
    known_unserializable_keys = ("memory",)

    # Backing state store; serialized as `state`
    _state_store: InMemoryStateStore[MODEL_T]
    # Broker of the current run; `None` until `_init_broker` attaches one.
    _broker_run: WorkflowBroker[MODEL_T] | None
    # Plugin used to create a runtime for each new broker.
    _plugin: Plugin
    # Workflow this context was created for; needed to rebuild broker state
    # from the init snapshot in `to_dict`.
    _workflow: Workflow

    def __init__(
        self,
        workflow: Workflow,
        previous_context: dict[str, Any] | None = None,
        serializer: BaseSerializer | None = None,
        plugin: Plugin = basic_runtime,
    ) -> None:
        self._serializer = serializer or JsonSerializer()
        self._broker_run = None
        self._plugin = plugin
        self._workflow = workflow

        # parse the serialized context
        serializer = serializer or JsonSerializer()
        if previous_context is not None:
            try:
                # Auto-detect and convert V0 to V1 if needed
                previous_context_parsed = SerializedContext.from_dict_auto(
                    previous_context
                )
                # validate it fully parses synchronously to avoid delayed validation errors
                BrokerState.from_serialized(
                    previous_context_parsed, workflow, serializer
                )
            except ValidationError as e:
                raise ContextSerdeError(
                    f"Context dict specified in an invalid format: {e}"
                ) from e
        else:
            previous_context_parsed = SerializedContext()

        self._init_snapshot = previous_context_parsed

        # initialization of the state store is a bit complex, due to inferring and validating its type from the
        # provided workflow context args

        state_types: set[Type[BaseModel]] = set()
        for _, step_func in workflow._get_steps().items():
            step_config: StepConfig = step_func._step_config
            if (
                step_config.context_state_type is not None
                and step_config.context_state_type != DictState
                and issubclass(step_config.context_state_type, BaseModel)
            ):
                state_type = step_config.context_state_type
                state_types.add(state_type)

        if len(state_types) > 1:
            raise ValueError(
                "Multiple state types are not supported. Make sure that each Context[...] has the same generic state type. Found: "
                + ", ".join([state_type.__name__ for state_type in state_types])
            )
        state_type = state_types.pop() if state_types else DictState
        if previous_context_parsed.state:
            # perhaps offer a way to clear on invalid
            store_state = InMemoryStateStore.from_dict(
                previous_context_parsed.state, serializer
            )
            if store_state.state_type != state_type:
                raise ValueError(
                    f"State type mismatch. Workflow context expected {state_type.__name__}, got {store_state.state_type.__name__}"
                )
            self._state_store = cast(InMemoryStateStore[MODEL_T], store_state)
        else:
            try:
                state_instance = cast(MODEL_T, state_type())
                self._state_store = InMemoryStateStore(state_instance)
            except Exception as e:
                raise WorkflowRuntimeError(
                    f"Failed to initialize state of type {state_type}. Does your state define defaults for all fields? Original error:\n{e}"
                ) from e

    @property
    def is_running(self) -> bool:
        """Whether the workflow is currently running."""
        if self._broker_run is None:
            return self._init_snapshot.is_running
        else:
            return self._broker_run.is_running

    def _init_broker(
        self, workflow: Workflow, plugin: WorkflowRuntime | None = None
    ) -> WorkflowBroker[MODEL_T]:
        if self._broker_run is not None:
            raise WorkflowRuntimeError("Broker already initialized")
        # Initialize a runtime plugin (asyncio-based by default)
        runtime: WorkflowRuntime = plugin or self._plugin.new_runtime(str(uuid.uuid4()))
        # Initialize the new broker implementation (broker2)
        self._broker_run = WorkflowBroker(
            workflow=workflow,
            context=self,
            runtime=runtime,
            plugin=self._plugin,
        )
        return self._broker_run

    def _workflow_run(
        self,
        workflow: Workflow,
        start_event: StartEvent | None = None,
        semaphore: asyncio.Semaphore | None = None,
    ) -> WorkflowHandler:
        """
        called by package internally from the workflow to run it
        """
        prev_broker: WorkflowBroker[MODEL_T] | None = None
        if self._broker_run is not None:
            prev_broker = self._broker_run
            self._broker_run = None

        self._broker_run = self._init_broker(workflow)

        async def before_start() -> None:
            if prev_broker is not None:
                try:
                    await prev_broker.shutdown()
                except Exception:
                    pass
            if semaphore is not None:
                await semaphore.acquire()

        async def after_complete() -> None:
            if semaphore is not None:
                semaphore.release()

        state = BrokerState.from_serialized(
            self._init_snapshot, workflow, self._serializer
        )
        return self._broker_run.start(
            workflow=workflow,
            previous=state,
            start_event=start_event,
            before_start=before_start,
            after_complete=after_complete,
        )

    def _workflow_cancel_run(self) -> None:
        """
        Called internally from the handler to cancel a context's run
        """
        self._running_broker.cancel_run()

    @property
    def _running_broker(self) -> WorkflowBroker[MODEL_T]:
        if self._broker_run is None:
            raise WorkflowRuntimeError(
                "Workflow run is not yet running. Make sure to only call this method after the context has been passed to a workflow.run call."
            )
        return self._broker_run

    @property
    def store(self) -> InMemoryStateStore[MODEL_T]:
        """Typed, process-local state store shared across steps.

        If no state was initialized yet, a default
        [DictState][workflows.context.state_store.DictState] store is created.

        Returns:
            InMemoryStateStore[MODEL_T]: The state store instance.
        """
        return self._state_store

    def to_dict(self, serializer: BaseSerializer | None = None) -> dict[str, Any]:
        """Serialize the context to a JSON-serializable dict.

        Persists the global state store, event queues, buffers, accepted events,
        broker log, and running flag. This payload can be fed to
        [from_dict][workflows.context.context.Context.from_dict] to resume a run
        or carry state across runs.

        Args:
            serializer (BaseSerializer | None): Value serializer used for state
                and event payloads. Defaults to
                [JsonSerializer][workflows.context.serializers.JsonSerializer].

        Returns:
            dict[str, Any]: A dict suitable for JSON encoding and later
            restoration via `from_dict`.

        See Also:
            - [InMemoryStateStore.to_dict][workflows.context.state_store.InMemoryStateStore.to_dict]

        Examples:
            ```python
            ctx_dict = ctx.to_dict()
            my_db.set("key", json.dumps(ctx_dict))

            ctx_dict = my_db.get("key")
            restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict))
            result = await my_workflow.run(..., ctx=restored_ctx)
            ```
        """
        serializer = serializer or self._serializer

        # Serialize state using the state manager's method
        state_data = {}
        if self._state_store is not None:
            state_data = self._state_store.to_dict(serializer)

        # Get the broker state - either from the running broker or from the init snapshot
        if self._broker_run is not None:
            broker_state = self._broker_run._state
        else:
            # Deserialize the init snapshot to get a BrokerState, then re-serialize it
            # This ensures we always output the current format
            broker_state = BrokerState.from_serialized(
                self._init_snapshot, self._workflow, serializer
            )

        context = broker_state.to_serialized(serializer)
        context.state = state_data
        # mode="python" to support pickling over json if one so chooses. This should perhaps be moved into the serializers
        return context.model_dump(mode="python")

    @classmethod
    def from_dict(
        cls,
        workflow: "Workflow",
        data: dict[str, Any],
        serializer: BaseSerializer | None = None,
    ) -> "Context[MODEL_T]":
        """Reconstruct a `Context` from a serialized payload.

        Args:
            workflow (Workflow): The workflow instance that will own this
                context.
            data (dict[str, Any]): Payload produced by
                [to_dict][workflows.context.context.Context.to_dict].
            serializer (BaseSerializer | None): Serializer used to decode state
                and events. Defaults to JSON.

        Returns:
            Context[MODEL_T]: A context instance initialized with the persisted
                state and queues.

        Raises:
            ContextSerdeError: If the payload is missing required fields or is
                in an incompatible format.

        Examples:
            ```python
            ctx_dict = ctx.to_dict()
            my_db.set("key", json.dumps(ctx_dict))

            ctx_dict = my_db.get("key")
            restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict))
            result = await my_workflow.run(..., ctx=restored_ctx)
            ```
        """
        try:
            return cls(workflow, previous_context=data, serializer=serializer)
        except KeyError as e:
            msg = "Error creating a Context instance: the provided payload has a wrong or old format."
            raise ContextSerdeError(msg) from e

    async def running_steps(self) -> list[str]:
        """Return the list of currently running step names.

        Returns:
            list[str]: Names of steps that have at least one active worker.
        """
        return await self._running_broker.running_steps()

    def collect_events(
        self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
    ) -> list[Event] | None:
        """
        Buffer events until all expected types are available, then return them.

        This utility is helpful when a step can receive multiple event types
        and needs to proceed only when it has a full set. The returned list is
        ordered according to `expected`.

        Args:
            ev (Event): The incoming event to add to the buffer.
            expected (list[Type[Event]]): Event types to collect, in order.
            buffer_id (str | None): Optional stable key to isolate buffers across
                steps or workers. Defaults to an internal key derived from the
                task name or expected types.

        Returns:
            list[Event] | None: The events in the requested order when complete,
            otherwise `None`.

        Examples:
            ```python
            @step
            async def synthesize(
                self, ctx: Context, ev: QueryEvent | RetrieveEvent
            ) -> StopEvent | None:
                events = ctx.collect_events(ev, [QueryEvent, RetrieveEvent])
                if events is None:
                    return None
                query_ev, retrieve_ev = events
                # ... proceed with both inputs present ...
            ```

        See Also:
            - [Event][workflows.events.Event]
        """
        return self._running_broker.collect_events(ev, expected, buffer_id)

    def send_event(self, message: Event, step: str | None = None) -> None:
        """Dispatch an event to one or all workflow steps.

        If `step` is omitted, the event is broadcast to all step queues and
        non-matching steps will ignore it. When `step` is provided, the target
        step must accept the event type or a
        [WorkflowRuntimeError][workflows.errors.WorkflowRuntimeError] is raised.

        Args:
            message (Event): The event to enqueue.
            step (str | None): Optional step name to target.

        Raises:
            WorkflowRuntimeError: If the target step does not exist or does not
                accept the event type.

        Examples:
            It's common to use this method to fan-out events:

            ```python
            @step
            async def my_step(self, ctx: Context, ev: StartEvent) -> WorkerEvent | GatherEvent:
                for i in range(10):
                    ctx.send_event(WorkerEvent(msg=i))
                return GatherEvent()
            ```

            You also see this method used from the caller side to send events into the workflow:

            ```python
            handler = my_workflow.run(...)
            async for ev in handler.stream_events():
                if isinstance(ev, SomeEvent):
                    handler.ctx.send_event(SomeOtherEvent(msg="Hello!"))

            result = await handler
            ```
        """
        return self._running_broker.send_event(message, step)

    async def wait_for_event(
        self,
        event_type: Type[T],
        waiter_event: Event | None = None,
        waiter_id: str | None = None,
        requirements: dict[str, Any] | None = None,
        timeout: float | None = 2000,
    ) -> T:
        """Wait for the next matching event of type `event_type`.

        Optionally emits a `waiter_event` to the event stream once per `waiter_id` to
        inform callers that the workflow is waiting for external input.
        This helps to prevent duplicate waiter events from being sent to the event stream.

        Args:
            event_type (type[T]): Concrete event class to wait for.
            waiter_event (Event | None): Optional event to write to the stream
                once when the wait begins.
            waiter_id (str | None): Stable identifier to avoid emitting multiple
                waiter events for the same logical wait.
            requirements (dict[str, Any] | None): Key/value filters that must be
                satisfied by the event via `event.get(key) == value`.
            timeout (float | None): Max seconds to wait. `None` means no
                timeout. Defaults to 2000 seconds.

        Returns:
            T: The received event instance of the requested type.

        Raises:
            asyncio.TimeoutError: If the timeout elapses.

        Examples:
            ```python
            @step
            async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
                response = await ctx.wait_for_event(
                    HumanResponseEvent,
                    waiter_event=InputRequiredEvent(msg="What's your name?"),
                    waiter_id="user_name",
                    timeout=60,
                )
                return StopEvent(result=response.response)
            ```
        """
        return await self._running_broker.wait_for_event(
            event_type, waiter_event, waiter_id, requirements, timeout
        )

    def write_event_to_stream(self, ev: Event | None) -> None:
        """Enqueue an event for streaming to [WorkflowHandler]](workflows.handler.WorkflowHandler).

        Args:
            ev (Event | None): The event to stream. `None` can be used as a
                sentinel in some streaming modes.

        Examples:
            ```python
            @step
            async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
                ctx.write_event_to_stream(ev)
                return StopEvent(result="ok")
            ```
        """
        self._running_broker.write_event_to_stream(ev)

    def get_result(self) -> RunResultT:
        """Return the final result of the workflow run.

        Deprecated:
            This method is deprecated and will be removed in a future release.
            Prefer awaiting the handler returned by `Workflow.run`, e.g.:
            `result = await workflow.run(..., ctx=ctx)`.

        Examples:
            ```python
            # Preferred
            result = await my_workflow.run(..., ctx=ctx)

            # Deprecated
            result_agent = ctx.get_result()
            ```

        Returns:
            RunResultT: The value provided via a `StopEvent`.
        """
        _warn_get_result()
        if self._running_broker._handler is None:
            raise WorkflowRuntimeError("Workflow handler is not set")
        return self._running_broker._handler.result()

    def stream_events(self) -> AsyncGenerator[Event, None]:
        """The internal queue used for streaming events to callers."""
        return self._running_broker.stream_published_events()

    @property
    def streaming_queue(self) -> asyncio.Queue:
        """Deprecated queue-based event stream.

        Returns an asyncio.Queue that is populated by iterating this context's
        stream_events(). A deprecation warning is emitted once per process.
        """
        _warn_streaming_queue()
        q: asyncio.Queue[Event] = asyncio.Queue()

        async def _pump() -> None:
            async for ev in self.stream_events():
                await q.put(ev)
                if isinstance(ev, StopEvent):
                    break

        try:
            asyncio.create_task(_pump())
        except RuntimeError:
            loop = asyncio.get_event_loop()
            loop.create_task(_pump())
        return q


@functools.lru_cache(maxsize=1)
def _warn_get_result() -> None:
    """Emit the `Context.get_result()` deprecation warning once per process.

    The `lru_cache` wrapper memoizes the argument-less call, so the body (and
    therefore the warning) runs only on the first invocation.
    """
    warnings.warn(
        (
            "Context.get_result() is deprecated and will be removed in a future "
            "release. Prefer awaiting the WorkflowHandler returned by "
            "Workflow.run: `result = await workflow.run(..., ctx=ctx)`."
        ),
        DeprecationWarning,
        # Frame 1 is this helper and frame 2 is the deprecated library method
        # that calls it; frame 3 is the user code that should be blamed so the
        # warning is attributed to the caller's module and line.
        stacklevel=3,
    )


@functools.lru_cache(maxsize=1)
def _warn_streaming_queue() -> None:
    """Emit the `Context.streaming_queue` deprecation warning once per process.

    The `lru_cache` wrapper memoizes the argument-less call, so the body (and
    therefore the warning) runs only on the first invocation.
    """
    warnings.warn(
        (
            "Context.streaming_queue is deprecated and will be removed in a future "
            "release. Prefer iterating Context.stream_events(): "
            "`async for ev in ctx.stream_events(): ...`"
        ),
        DeprecationWarning,
        # Frame 1 is this helper and frame 2 is the deprecated property getter
        # that calls it; frame 3 is the user code that should be blamed so the
        # warning is attributed to the caller's module and line.
        stacklevel=3,
    )
