from typing import Any, Optional
from collections import Counter
from dataclasses import dataclass

from workflows import Workflow, Context
from workflows.events import StartEvent, Event, EventType


@dataclass
class WorkflowTestResult:
    """
    Container for workflow test results.

    Attributes:
        collected (list[Event]): Events collected from the workflow's event stream.
        event_types (dict[EventType, int]): Maps each event type to its number of occurrences within the collected events.
        result (Any): Final output of the workflow run.
        ctx (Context): Context associated with the workflow run.
    """

    collected: list[Event]
    event_types: dict[EventType, int]
    result: Any
    ctx: Context


class WorkflowTestRunner:
    """
    Utility class that can be used to test workflows end-to-end.

    Attributes:
        _workflow (Workflow): The workflow to be tested
    """

    def __init__(
        self,
        workflow: "Workflow",
    ):
        self._workflow = workflow

    async def run(
        self,
        start_event: Optional[StartEvent] = None,
        ctx: Optional["Context"] = None,
        expose_internal: bool = True,
        exclude_events: Optional[list[EventType]] = None,
    ) -> WorkflowTestResult:
        """
        Run the workflow end-to-end and collect the events streamed during its execution.

        Args:
            start_event (Optional[StartEvent]): The input event for the workflow. When not set,
                a fresh ``StartEvent()`` is created for this call.
            ctx (Optional[Context]): Context forwarded to the workflow's ``run`` method. Defaults to None.
            expose_internal (bool): Whether or not to expose internal events. Defaults to True.
            exclude_events (Optional[list[EventType]]): Event types to leave out of the collected
                events. Matching is on the exact type of each event, so subclasses of an
                excluded type are still collected. Defaults to None (nothing excluded).

        Returns:
            WorkflowTestResult: The collected events, the count of each collected event type,
            the final result of the run, and the handler's context.

        Raises:
            RuntimeError: If the handler has no context once the run has completed.

        Example:
            ```
            wf = GreetingWorkflow()
            runner = WorkflowTestRunner(wf)
            test_result = await runner.run(
                start_event=StartEvent(message="hello"),
                expose_internal=True,
                exclude_events=[StepStateChanged],
            )
            assert StepStateChanged not in test_result.event_types
            assert len(test_result.collected) == sum(test_result.event_types.values())
            assert str(test_result.result) == "hello Adam!"
            ```
        """
        # Build the default per call: a `StartEvent()` default in the signature would be
        # evaluated once at definition time and the same instance shared by every run.
        if start_event is None:
            start_event = StartEvent()
        # Hoisted into a set so each streamed event costs one O(1) lookup; matching stays
        # on the exact runtime type, as before.
        excluded: frozenset[EventType] = frozenset(exclude_events or ())

        handler = self._workflow.run(start_event=start_event, ctx=ctx)
        collected_events: list[Event] = []
        async for event in handler.stream_events(expose_internal=expose_internal):
            if type(event) in excluded:
                continue
            collected_events.append(event)

        # Awaiting the handler yields the workflow's final output.
        result = await handler
        event_freqs: dict[EventType, int] = dict(
            Counter(type(ev) for ev in collected_events)
        )
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if handler.ctx is None:
            raise RuntimeError("Workflow handler has no context after the run completed")
        return WorkflowTestResult(
            collected=collected_events,
            result=result,
            event_types=event_freqs,
            ctx=handler.ctx,
        )
