from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Callable, ClassVar, Union

from tyro._singleton import DEFAULT_SENTINEL_SINGLETONS

from .. import _resolver
from ._primitive_spec import (
    PrimitiveConstructorSpec,
    PrimitiveTypeInfo,
    UnsupportedTypeAnnotationError,
    apply_default_primitive_rules,
)
from ._struct_spec import (
    InvalidDefaultInstanceError,
    StructConstructorSpec,
    StructTypeInfo,
    apply_default_struct_rules,
)

# Module-level registry slot.
# NOTE(review): not referenced by the code shown here; presumably read or
# written elsewhere in the package — confirm before changing or removing.
current_registry: ConstructorRegistry | None = None

# Callable signature for a primitive rule. A rule returns a spec when it
# applies, an `UnsupportedTypeAnnotationError` when it applies but fails, or
# `None` when it does not apply (so the next rule is tried).
PrimitiveSpecRule = Callable[
    [PrimitiveTypeInfo],
    Union[PrimitiveConstructorSpec, UnsupportedTypeAnnotationError, None],
]
# Callable signature for a struct rule: a spec when it applies, else `None`.
StructSpecRule = Callable[[StructTypeInfo], Union[StructConstructorSpec, None]]

# Backing flag read by `check_default_instances()`; temporarily set to True by
# `check_default_instances_context()`.
_check_default_instances_flag: bool = False


def check_default_instances() -> bool:
    """Report whether default instances must strictly match their types.

    The answer is normally `False`: tyro tries to be lenient when it
    encounters inconsistent types. Strict checking is switched on (via
    `check_default_instances_context`) when annotated subcommands need to be
    matched against default values.

    Returns:
        The current value of the module-level strictness flag.
    """
    enabled = _check_default_instances_flag
    return enabled


@contextmanager
def check_default_instances_context():
    """Temporarily require default instances to strictly match their types.

    While the ``with`` body runs, `check_default_instances()` returns `True`.
    Outside of this context the flag is normally `False`, because tyro tries
    to be lenient when it encounters inconsistent types. Strictness is useful
    for matching annotated subcommands to default values.

    The previous flag value is restored on exit, even if the body raises.
    """
    global _check_default_instances_flag
    # Remember the old value and switch strict mode on in one step.
    saved_flag, _check_default_instances_flag = _check_default_instances_flag, True
    try:
        yield
    finally:
        # Restore (rather than reset to False) so nested contexts compose.
        _check_default_instances_flag = saved_flag


class ConstructorRegistry:
    """Registry for rules that define how types are constructed from
    command-line arguments.

    The behavior of CLIs generated by tyro is based on two types of rules.

    *Primitive rules* should be a callable with the signature:

    .. code-block:: python

        (type_info: PrimitiveTypeInfo) -> PrimitiveConstructorSpec | UnsupportedTypeAnnotationError | None

    where `None` is returned if the rule doesn't apply, and
    `UnsupportedTypeAnnotationError` is returned if the rule applies but an
    error was encountered. Each primitive rule defines behavior for a type that
    can be instantiated from a single command-line argument.


    *Struct rules* should be a callable with the signature:

    .. code-block:: python

        (type_info: StructTypeInfo) -> StructConstructorSpec | None

    where `None` is returned if the rule doesn't apply. Each struct rule
    defines behavior for a type that can be instantiated from multiple
    command-line arguments.


    To activate a registry, pass it directly to :func:`tyro.cli`:

    .. code-block:: python

        registry = ConstructorRegistry()
        tyro.cli(..., registry=registry)

    For backward compatibility, you can also use the context manager pattern, though
    the direct parameter approach is recommended:

    .. code-block:: python

        registry = ConstructorRegistry()
        with registry:
            tyro.cli(...)

    Entered registries form a class-wide stack; rules from the most recently
    entered registry are consulted first. A default registry holding the
    built-in rules is lazily installed at the bottom of that stack.
    """

    # Stack of active registries shared by all instances. Once
    # `_ensure_defaults_initialized()` has run, index 0 is the default registry.
    _active_registries: ClassVar[list[ConstructorRegistry]] = []

    def __init__(self) -> None:
        # Rules are stored in registration order and consulted in reverse, so
        # the most recently registered rule wins.
        self._primitive_rules: list[PrimitiveSpecRule] = []
        self._struct_rules: list[StructSpecRule] = []

    def primitive_rule(self, rule: PrimitiveSpecRule) -> PrimitiveSpecRule:
        """Define a rule for constructing a primitive type from a string. The
        most recently added rule will be applied first.

        Custom primitive rules will take precedence over both default primitive
        rules and struct rules.

        The rule is returned unchanged, so this method can be used as a
        decorator.
        """

        self._primitive_rules.append(rule)
        return rule

    def struct_rule(self, rule: StructSpecRule) -> StructSpecRule:
        """Define a rule for constructing a struct type from multiple
        command-line arguments. The most recently added rule will be applied
        first.

        The rule is returned unchanged, so this method can be used as a
        decorator.
        """

        self._struct_rules.append(rule)
        return rule

    @classmethod
    def _is_primitive_type(
        cls, type: Any, markers: set[Any], nondefault_only: bool = False
    ) -> bool:
        """Check if a type is a primitive type, i.e. whether some active
        primitive rule produces a `PrimitiveConstructorSpec` for it."""
        return isinstance(
            cls.get_primitive_spec(
                PrimitiveTypeInfo.make(type, markers), nondefault_only=nondefault_only
            ),
            PrimitiveConstructorSpec,
        )

    @classmethod
    def get_primitive_spec(
        cls, type_info: PrimitiveTypeInfo, nondefault_only: bool = False
    ) -> PrimitiveConstructorSpec | UnsupportedTypeAnnotationError:
        """Get a constructor specification for a given type.

        Args:
            type_info: Information about the type to construct. If it already
                carries a `_primitive_spec`, that spec is returned directly.
            nondefault_only: If True, skip the bottom (default) registry and
                consult only user-entered registries.

        Returns:
            The first non-`None` result of the active primitive rules, searched
            from the most recently entered registry and most recently added
            rule. A rule returning `UnsupportedTypeAnnotationError` also ends
            the search. If no rule applies, an `UnsupportedTypeAnnotationError`
            is returned (not raised).
        """

        cls._ensure_defaults_initialized()

        if type_info._primitive_spec is not None:
            return type_info._primitive_spec

        # Slicing copies the lists, so rules that register rules or enter
        # registries cannot disturb this iteration.
        registries = (
            cls._active_registries[1:] if nondefault_only else cls._active_registries
        )
        for registry in registries[::-1]:
            for spec_factory in registry._primitive_rules[::-1]:
                maybe_spec = spec_factory(type_info)
                if maybe_spec is not None:
                    return maybe_spec

        return UnsupportedTypeAnnotationError(
            f"Unsupported type annotation: {type_info.type}"
        )

    @classmethod
    def get_struct_spec(cls, type_info: StructTypeInfo) -> StructConstructorSpec | None:
        """Get a constructor specification for a given type. Returns `None` if
        unsuccessful.

        Struct rules are searched from the most recently entered registry and
        most recently added rule; the first non-`None` result wins. Rules run
        inside `type_info._typevar_context` (presumably to apply type-variable
        bindings while they inspect the type — confirm in `_struct_spec`).

        Raises:
            InvalidDefaultInstanceError: If strict checking is enabled (see
                `check_default_instances`) and `type_info.default` is neither
                one of `DEFAULT_SENTINEL_SINGLETONS` nor an instance of
                `type_info.type`.
        """

        cls._ensure_defaults_initialized()

        if (
            check_default_instances()
            and type_info.default not in DEFAULT_SENTINEL_SINGLETONS
            and not _resolver.is_instance(type_info.type, type_info.default)
        ):
            raise InvalidDefaultInstanceError(
                f"Invalid default instance for type {type_info.type}: {type_info.default}"
            )

        with type_info._typevar_context:
            for registry in cls._active_registries[::-1]:
                for spec_factory in registry._struct_rules[::-1]:
                    maybe_spec = spec_factory(type_info)
                    if maybe_spec is not None:
                        return maybe_spec

        return None

    def __enter__(self) -> None:
        """Push this registry onto the active stack (above the defaults)."""
        cls = self.__class__
        # Install the default registry first so it always stays at index 0.
        cls._ensure_defaults_initialized()
        cls._active_registries.append(self)

    def __exit__(self, *args: Any) -> None:
        """Pop this registry off the active stack."""
        cls = self.__class__
        # Pop *before* asserting: an `assert` statement is removed entirely
        # under `python -O`, so a side-effecting `assert stack.pop() is self`
        # would silently leave this registry active after the `with` block.
        popped = cls._active_registries.pop()
        assert popped is self, "ConstructorRegistry contexts exited out of order."

    @classmethod
    def _ensure_defaults_initialized(cls) -> None:
        """Initialize default registry if needed.

        When the active stack is empty, a registry populated with the default
        primitive and struct rules is pushed as its bottom entry.
        """
        if not cls._active_registries:
            registry = ConstructorRegistry()

            # Apply the default rules.
            apply_default_primitive_rules(registry)
            apply_default_struct_rules(registry)

            cls._active_registries.append(registry)
