diff --git a/modern_di/container.py b/modern_di/container.py index 0269945..8ecf828 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -3,7 +3,7 @@ import typing_extensions -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.group import Group from modern_di.providers.abstract import AbstractProvider from modern_di.providers.container_provider import container_provider @@ -62,39 +62,31 @@ def build_child_container( self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None ) -> "typing_extensions.Self": if scope and scope <= self.scope: - raise RuntimeError( - errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format( - parent_scope=self.scope.name, - child_scope=scope.name, - allowed_scopes=[x.name for x in Scope if x > self.scope], - ) + raise exceptions.InvalidChildScopeError( + parent_scope=self.scope, + child_scope=scope, + allowed_scopes=[x.name for x in Scope if x > self.scope], ) if not scope: try: scope = self.scope.__class__(self.scope.value + 1) except ValueError as exc: - raise RuntimeError( - errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=self.scope.name) - ) from exc + raise exceptions.MaxScopeReachedError(parent_scope=self.scope) from exc return self.__class__(scope=scope, parent_container=self, context=context) def find_container(self, scope: Scope) -> "typing_extensions.Self": if scope not in self.scope_map: if scope > self.scope: - raise RuntimeError( - errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format( - provider_scope=scope.name, container_scope=self.scope.name - ) - ) - raise RuntimeError(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=scope.name)) + raise exceptions.ScopeNotInitializedError(provider_scope=scope, container_scope=self.scope) + raise exceptions.ScopeSkippedError(provider_scope=scope) return self.scope_map[scope] def resolve(self, dependency_type: type[types.T]) -> types.T: provider = self.providers_registry.find_provider(dependency_type) if not provider: - raise RuntimeError(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=dependency_type)) + raise exceptions.ProviderNotRegisteredError(provider_type=dependency_type) return self.resolve_provider(provider) @@ -123,7 +115,7 @@ def _visit(provider: AbstractProvider[typing.Any]) -> None: cycle_start = next(i for i, p in enumerate(path) if p.provider_id == pid) cycle_names = [p.bound_type.__name__ if p.bound_type else repr(p) for p in path[cycle_start:]] cycle_names.append(cycle_names[0]) - raise RuntimeError(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_names))) + raise exceptions.CircularDependencyError(cycle_path=cycle_names) visiting.add(pid) path.append(provider) diff --git a/modern_di/exceptions.py b/modern_di/exceptions.py new file mode 100644 index 0000000..5270525 --- /dev/null +++ b/modern_di/exceptions.py @@ -0,0 +1,104 @@ +import typing + +from modern_di import errors +from modern_di.scope import Scope + + +class ModernDIError(RuntimeError): + """Base class for all modern-di errors. Inherits from RuntimeError for backwards compatibility.""" + + +class ContainerError(ModernDIError): + """Base class for container and scope errors.""" + + +class InvalidChildScopeError(ContainerError): + def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: list[str]) -> None: + self.parent_scope = parent_scope + self.child_scope = child_scope + self.allowed_scopes = allowed_scopes + super().__init__( + errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format( + parent_scope=parent_scope.name, + child_scope=child_scope.name, + allowed_scopes=allowed_scopes, + ) + ) + + +class MaxScopeReachedError(ContainerError): + def __init__(self, *, parent_scope: Scope) -> None: + self.parent_scope = parent_scope + super().__init__(errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=parent_scope.name)) + + +class ScopeNotInitializedError(ContainerError): + def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None: + self.provider_scope = provider_scope + self.container_scope = container_scope + super().__init__( + errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format( + provider_scope=provider_scope.name, + container_scope=container_scope.name, + ) + ) + + +class ScopeSkippedError(ContainerError): + def __init__(self, *, provider_scope: Scope) -> None: + self.provider_scope = provider_scope + super().__init__(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=provider_scope.name)) + + +class ResolutionError(ModernDIError): + """Base class for errors raised while resolving a provider.""" + + +class ProviderNotRegisteredError(ResolutionError): + def __init__(self, *, provider_type: type) -> None: + self.provider_type = provider_type + super().__init__(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type)) + + +class ArgumentResolutionError(ResolutionError): + def __init__(self, *, arg_name: str, arg_type: typing.Any, bound_type: typing.Any) -> None: # noqa: ANN401 + self.arg_name = arg_name + self.arg_type = arg_type + self.bound_type = bound_type + super().__init__( + errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( + arg_name=arg_name, + arg_type=arg_type, + bound_type=bound_type, + ) + ) + + +class CircularDependencyError(ResolutionError): + def __init__(self, *, cycle_path: list[str]) -> None: + self.cycle_path = cycle_path + super().__init__(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_path))) + + +class RegistrationError(ModernDIError): + """Base class for errors raised while registering providers.""" + + +class DuplicateProviderTypeError(RegistrationError): + def __init__(self, *, provider_type: type) -> None: + self.provider_type = provider_type + super().__init__(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type)) + + +class FinalizerError(ModernDIError): + def __init__(self, *, finalizer_errors: list[BaseException], is_async: bool) -> None: + self.finalizer_errors = finalizer_errors + self.is_async = is_async + kind = "async" if is_async else "sync" + super().__init__(f"Errors during {kind} cleanup: {finalizer_errors}") + + +class GroupInstantiationError(ModernDIError): + def __init__(self, *, group_name: str) -> None: + self.group_name = group_name + super().__init__(f"{group_name} cannot be instantiated") diff --git a/modern_di/group.py b/modern_di/group.py index 404bc6d..40b20e8 100644 --- a/modern_di/group.py +++ b/modern_di/group.py @@ -1,5 +1,6 @@ import typing +from modern_di import exceptions from modern_di.providers.abstract import AbstractProvider @@ -15,8 +16,7 @@ class Group: providers: list[AbstractProvider[typing.Any]] def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": # noqa: ANN401 - msg = f"{cls.__name__} cannot be instantiated" - raise RuntimeError(msg) + raise exceptions.GroupInstantiationError(group_name=cls.__name__) @classmethod def get_providers(cls) -> list[AbstractProvider[typing.Any]]: diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 41005fc..e1e36ea 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -3,7 +3,7 @@ import typing import warnings -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.providers import ContextProvider from modern_di.providers.abstract import AbstractProvider from modern_di.scope import Scope @@ -82,18 +82,14 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]: if provider: result[k] = provider if is_kwarg_not_found and isinstance(provider, ContextProvider) and provider.resolve(container) is None: - raise RuntimeError( - errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( - arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator - ) + raise exceptions.ArgumentResolutionError( + arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator ) continue if v.default == types.UNSET and is_kwarg_not_found: - raise RuntimeError( - errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( - arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator - ) + raise exceptions.ArgumentResolutionError( + arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator ) if self._kwargs: diff --git a/modern_di/registries/cache_registry.py b/modern_di/registries/cache_registry.py index 1a0c40e..d7ef8cc 100644 --- a/modern_di/registries/cache_registry.py +++ b/modern_di/registries/cache_registry.py @@ -2,7 +2,7 @@ import typing import warnings -from modern_di import types +from modern_di import exceptions, types from modern_di.providers import CacheSettings, Factory @@ -52,23 +52,21 @@ def fetch_cache_item(self, provider: Factory[types.T_co]) -> CacheItem: return self._items.setdefault(provider.provider_id, CacheItem(settings=provider.cache_settings)) async def close_async(self) -> None: - errors: list[BaseException] = [] + finalizer_errors: list[BaseException] = [] for cache_item in self._items.values(): try: await cache_item.close_async() except Exception as e: # noqa: BLE001, PERF203 - errors.append(e) - if errors: - msg = f"Errors during async cleanup: {errors}" - raise RuntimeError(msg) + finalizer_errors.append(e) + if finalizer_errors: + raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=True) def close_sync(self) -> None: - errors: list[BaseException] = [] + finalizer_errors: list[BaseException] = [] for cache_item in self._items.values(): try: cache_item.close_sync() except Exception as e: # noqa: BLE001, PERF203 - errors.append(e) - if errors: - msg = f"Errors during sync cleanup: {errors}" - raise RuntimeError(msg) + finalizer_errors.append(e) + if finalizer_errors: + raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=False) diff --git a/modern_di/registries/providers_registry.py b/modern_di/registries/providers_registry.py index 3ee30a5..9ace295 100644 --- a/modern_di/registries/providers_registry.py +++ b/modern_di/registries/providers_registry.py @@ -1,6 +1,6 @@ import typing -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.providers.abstract import AbstractProvider @@ -21,7 +21,7 @@ def find_provider(self, dependency_type: type[types.T]) -> AbstractProvider[type def register(self, provider_type: type, provider: AbstractProvider[typing.Any]) -> None: if provider_type in self._providers: - raise RuntimeError(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type)) + raise exceptions.DuplicateProviderTypeError(provider_type=provider_type) self._providers[provider_type] = provider diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index 59e15b6..bc665c4 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -4,6 +4,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ArgumentResolutionError request_context_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=datetime.datetime) @@ -44,8 +45,12 @@ def test_context_provider_not_found() -> None: def test_context_provider_not_found_but_required() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match=r"Argument arg1 of type cannot be resolved"): + with pytest.raises( + ArgumentResolutionError, match=r"Argument arg1 of type cannot be resolved" + ) as exc: app_container.resolve(SomeFactory) + assert exc.value.arg_name == "arg1" + assert exc.value.arg_type is datetime.datetime def test_context_provider_in_request_scope() -> None: diff --git a/tests/providers/test_factory.py b/tests/providers/test_factory.py index ed767aa..0e63602 100644 --- a/tests/providers/test_factory.py +++ b/tests/providers/test_factory.py @@ -5,6 +5,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ArgumentResolutionError, ScopeNotInitializedError @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) @@ -62,8 +63,10 @@ def test_app_factory_skip_creator_parsing() -> None: def test_app_factory_unresolvable() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match="Argument dep1 of type cannot be resolved"): + with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type cannot be resolved") as exc: app_container.validate_provider(MyGroup.app_factory_unresolvable) + assert exc.value.arg_name == "dep1" + assert exc.value.arg_type is str def test_func_with_union_factory() -> None: @@ -74,7 +77,7 @@ def test_func_with_union_factory() -> None: def test_func_with_broken_annotation() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match="Argument dep1 of type None cannot be resolved"): + with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type None cannot be resolved"): app_container.validate_provider(MyGroup.func_with_broken_annotation) @@ -156,8 +159,13 @@ def test_factory_overridden_request_scope() -> None: def test_factory_scope_is_not_initialized() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match=r"Provider of scope REQUEST cannot be resolved in container of scope APP."): + with pytest.raises( + ScopeNotInitializedError, + match=r"Provider of scope REQUEST cannot be resolved in container of scope APP.", + ) as exc: app_container.resolve_provider(MyGroup.request_factory) + assert exc.value.provider_scope == Scope.REQUEST + assert exc.value.container_scope == Scope.APP def test_factory_self_reference() -> None: diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py index 908375c..814a00e 100644 --- a/tests/providers/test_singleton.py +++ b/tests/providers/test_singleton.py @@ -6,6 +6,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import FinalizerError @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) @@ -113,8 +114,10 @@ class BrokenGroup(Group): app_container.resolve_provider(BrokenGroup.first) app_container.resolve_provider(BrokenGroup.second) - with pytest.raises(RuntimeError, match="Errors during sync cleanup"): + with pytest.raises(FinalizerError, match="Errors during sync cleanup") as exc: app_container.close_sync() + assert exc.value.is_async is False + assert len(exc.value.finalizer_errors) == 1 assert cleaned_up == ["done"] @@ -146,8 +149,10 @@ class BrokenAsyncGroup(Group): app_container.resolve_provider(BrokenAsyncGroup.first) app_container.resolve_provider(BrokenAsyncGroup.second) - with pytest.raises(RuntimeError, match="Errors during async cleanup"): + with pytest.raises(FinalizerError, match="Errors during async cleanup") as exc: await app_container.close_async() + assert exc.value.is_async is True + assert len(exc.value.finalizer_errors) == 1 assert cleaned_up == ["done"] diff --git a/tests/registries/test_providers_registry.py b/tests/registries/test_providers_registry.py index a47a77b..c37beda 100644 --- a/tests/registries/test_providers_registry.py +++ b/tests/registries/test_providers_registry.py @@ -1,6 +1,7 @@ import pytest from modern_di import providers +from modern_di.exceptions import DuplicateProviderTypeError from modern_di.registries.providers_registry import ProvidersRegistry @@ -15,7 +16,6 @@ def test_providers_registry_add_provider_duplicates() -> None: providers_registry = ProvidersRegistry() providers_registry.add_providers(str_factory) - with ( - pytest.raises(RuntimeError, match="Provider is duplicated by type "), - ): + with pytest.raises(DuplicateProviderTypeError, match="Provider is duplicated by type ") as exc: providers_registry.add_providers(str_factory) + assert exc.value.provider_type is str diff --git a/tests/test_container.py b/tests/test_container.py index 1af0b1f..5b56fa7 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -4,6 +4,13 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ( + CircularDependencyError, + InvalidChildScopeError, + MaxScopeReachedError, + ProviderNotRegisteredError, + ScopeSkippedError, +) def test_container_prevent_copy() -> None: @@ -16,8 +23,9 @@ def test_container_prevent_copy() -> None: def test_container_scope_skipped() -> None: app_factory = providers.Factory(creator=lambda: "test") container = Container(scope=Scope.REQUEST) - with pytest.raises(RuntimeError, match=r"Provider of scope APP is skipped in the chain of containers."): + with pytest.raises(ScopeSkippedError, match=r"Provider of scope APP is skipped in the chain of containers.") as exc: container.resolve_provider(app_factory) + assert exc.value.provider_scope == Scope.APP def test_container_build_child() -> None: @@ -29,20 +37,27 @@ def test_container_build_child() -> None: def test_container_scope_limit_reached() -> None: step_container = Container(scope=Scope.STEP) - with pytest.raises(RuntimeError, match=r"Max scope of STEP is reached."): + with pytest.raises(MaxScopeReachedError, match=r"Max scope of STEP is reached.") as exc: step_container.build_child_container() + assert exc.value.parent_scope == Scope.STEP def test_container_build_child_wrong_scope() -> None: app_container = Container() - with pytest.raises(RuntimeError, match="Scope of child container cannot be"): + with pytest.raises(InvalidChildScopeError, match="Scope of child container cannot be") as exc: app_container.build_child_container(scope=Scope.APP) + assert exc.value.parent_scope == Scope.APP + assert exc.value.child_scope == Scope.APP def test_container_resolve_missing_provider() -> None: app_container = Container() - with pytest.raises(RuntimeError, match=r"Provider of type is not registered in providers registry."): + with pytest.raises( + ProviderNotRegisteredError, + match=r"Provider of type is not registered in providers registry.", + ) as exc: assert app_container.resolve(str) is None + assert exc.value.provider_type is str def test_container_sync_context_manager() -> None: @@ -82,14 +97,17 @@ class CycleGroup(Group): def test_validate_on_creation() -> None: - with pytest.raises(RuntimeError, match="Circular dependency detected"): + with pytest.raises(CircularDependencyError, match="Circular dependency detected"): Container(groups=[CycleGroup], validate=True) def test_validate_detects_cycle() -> None: container = Container(groups=[CycleGroup]) - with pytest.raises(RuntimeError, match="Circular dependency detected: CycleA -> CycleB -> CycleA"): + with pytest.raises( + CircularDependencyError, match="Circular dependency detected: CycleA -> CycleB -> CycleA" + ) as exc: container.validate() + assert exc.value.cycle_path == ["CycleA", "CycleB", "CycleA"] def test_validate_passes_for_valid_graph() -> None: diff --git a/tests/test_group.py b/tests/test_group.py index 3499e6e..06275fa 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,10 +1,12 @@ import pytest from modern_di import Group +from modern_di.exceptions import GroupInstantiationError def test_group_cannot_be_instantiated() -> None: class Dependencies(Group): ... - with pytest.raises(RuntimeError, match="Dependencies cannot be instantiated"): + with pytest.raises(GroupInstantiationError, match="Dependencies cannot be instantiated") as exc: Dependencies() + assert exc.value.group_name == "Dependencies"