# -*- coding: utf-8 -*-
import itertools
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple

if TYPE_CHECKING:
    from pyccolo.tracer import _InternalBaseTracer


class TraceStack:
    """Save and restore selected attributes of a manager object as stack frames.

    Attributes first assigned on the manager inside ``register_stack_state()``
    become "stack items". ``push()`` snapshots their current values into a
    tuple on an internal list and resets them; ``pop()`` writes the most
    recently saved values back into the manager's ``__dict__``.
    """

    def __init__(self, manager: "_InternalBaseTracer"):
        self._manager = manager
        # One tuple of saved values per pushed frame; oldest frame first.
        self._stack: List[Tuple[Any, ...]] = []
        # Attribute name -> zero-arg callable producing its value after a push.
        self._stack_item_initializers: Dict[str, Callable[[], Any]] = {}
        # Attribute names the caller must assign itself inside every push().
        self._stack_items_with_manual_initialization: Set[str] = set()
        # True only while a register_stack_state() block is executing.
        self._registering_stack_state_context = False
        # Attribute name -> position of its value inside each saved tuple.
        self._field_mapping: Dict[str, int] = {}

    def _stack_item_names(self):
        """Iterate over all stack item names: auto-initialized first, then manual.

        This order defines the layout of every saved tuple, so ``push``,
        ``pop`` and ``_field_mapping`` all rely on it.
        """
        return itertools.chain(
            self._stack_item_initializers.keys(),
            self._stack_items_with_manual_initialization,
        )

    def get_field(
        self, field: str, depth: Optional[int] = None, height: Optional[int] = None
    ) -> Any:
        """Return the saved value of ``field`` from one stack frame.

        :param field: name of a registered stack item.
        :param depth: distance from the top of the stack; 1 (the default) is
            the most recently pushed frame. Ignored when ``height`` is given.
            Note that ``depth=0`` resolves to index 0, i.e. the oldest frame.
        :param height: direct index into the list of frames (0 is the oldest).
        :raises KeyError: if ``field`` is not a registered stack item.
        :raises IndexError: if the selected frame does not exist.
        """
        if height is None:
            height = -(1 if depth is None else depth)
        return self._stack[height][self._field_mapping[field]]

    @staticmethod
    def _make_initer_from_val(init_val: Any) -> Callable[[], Any]:
        """Return a zero-argument callable that always returns ``init_val``."""
        # A dedicated factory binds ``init_val`` per call, which avoids the
        # late-binding pitfall of creating the lambda directly inside a loop.
        return lambda: init_val

    @contextmanager
    def register_stack_state(self):
        """Register every manager attribute first assigned inside the block.

        For each new attribute not marked as needing manual initialization,
        an initializer is derived from the value it holds when the block ends:

        * a ``TraceStack`` is re-created via its ``_clone`` method,
        * ``None`` stays ``None``,
        * ``int`` / ``bool`` / ``str`` / ``float`` values are reproduced as-is,
        * anything else is rebuilt by calling its type with no arguments.

        If the block raises, nothing is registered, but the "currently
        registering" flag is still reset.
        """
        self._registering_stack_state_context = True
        original_state = set(self._manager.__dict__.keys())
        try:
            yield
        finally:
            # Reset the flag even when the body raised; otherwise a failed
            # registration would leave needing_manual_initialization() usable
            # outside of any registration block.
            self._registering_stack_state_context = False
        stack_item_names = set(self._manager.__dict__.keys() - original_state)
        for stack_item_name in (
            stack_item_names - self._stack_items_with_manual_initialization
        ):
            stack_item = self._manager.__dict__[stack_item_name]
            if isinstance(stack_item, TraceStack):
                self._stack_item_initializers[stack_item_name] = stack_item._clone
            elif stack_item is None:
                self._stack_item_initializers[stack_item_name] = lambda: None
            elif isinstance(stack_item, (int, bool, str, float)):
                init_val = type(stack_item)(stack_item)
                self._stack_item_initializers[stack_item_name] = (
                    self._make_initer_from_val(init_val)
                )
            else:
                self._stack_item_initializers[stack_item_name] = type(stack_item)
        # Recompute tuple positions for every item now that the set of
        # auto-initialized names may have grown.
        for i, stack_item_name in enumerate(self._stack_item_names()):
            self._field_mapping[stack_item_name] = i

    @contextmanager
    def needing_manual_initialization(self):
        """Mark attributes first assigned inside the block as manually initialized.

        Must be used while a ``register_stack_state()`` block is active. The
        names are added to those recorded by earlier uses of this context.
        """
        assert self._registering_stack_state_context
        original_state = set(self._manager.__dict__.keys())
        yield
        newly_added = set(self._manager.__dict__.keys() - original_state)
        # Accumulate rather than overwrite, so several blocks can each
        # contribute names. A new set is bound (instead of updating in place)
        # so that objects aliasing the previous set are left untouched.
        self._stack_items_with_manual_initialization = (
            self._stack_items_with_manual_initialization | newly_added
        )

    @contextmanager
    def push(self):
        """
        Saves the current values of all stack items as a new frame, resets the
        automatically initialized items and removes the manually initialized
        ones from the manager, so that the body of the context has to set them.

        Checks at the end of the context that everything requiring manual init was manually inited.

        :raises ValueError: if the body completed without assigning every
            manually initialized item.
        """
        self._stack.append(
            tuple(
                self._manager.__dict__[stack_item]
                for stack_item in self._stack_item_names()
            )
        )
        for stack_item, initializer in self._stack_item_initializers.items():
            self._manager.__dict__[stack_item] = initializer()
        for stack_item in self._stack_items_with_manual_initialization:
            del self._manager.__dict__[stack_item]
        yield
        uninitialized_items = [
            stack_item
            for stack_item in self._stack_items_with_manual_initialization
            if stack_item not in self._manager.__dict__
        ]
        if len(uninitialized_items) > 0:
            raise ValueError(
                "Stack item(s) %s requiring manual initialization were not initialized"
                % uninitialized_items
            )

    def _clone(self) -> "TraceStack":
        """Return a new ``TraceStack`` for the same manager with the same registrations.

        The registration tables are shared with this instance, but the clone
        receives its own copy of the frame list so that pushing or popping on
        one stack never alters the frames of the other.
        """
        new_tracing_stack = TraceStack(self._manager)
        new_tracing_stack.__dict__ = dict(self.__dict__)
        # dict(...) above is shallow: without this copy both objects would
        # append to and pop from the very same list.
        new_tracing_stack._stack = list(self._stack)
        return new_tracing_stack

    def pop(self) -> "TraceStack":
        """Remove the top frame and write its values back onto the manager.

        :returns: this stack, to allow chaining.
        :raises IndexError: if there is no frame to pop.
        """
        for stack_item_name, stack_item in zip(
            self._stack_item_names(), self._stack.pop()
        ):
            self._manager.__dict__[stack_item_name] = stack_item
        return self

    def clear(self) -> None:
        """Discard all frames and restore the manager from the oldest one, if any."""
        # Keep only the oldest frame, then pop it so the manager ends up with
        # the values it had before the first push.
        self._stack = self._stack[:1]
        if len(self._stack) > 0:
            self.pop()

    def __len__(self) -> int:
        """Return the number of frames currently saved."""
        return len(self._stack)
