From da26c5d8863d9b3cebf556feeb464ae6961a6df6 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sat, 2 May 2026 12:39:44 +0300 Subject: [PATCH] Replace RuntimeError raises with custom exception hierarchy Introduces modern_di/exceptions.py with ModernDIError(RuntimeError) and a hierarchy keyed to the existing message templates: ContainerError, ResolutionError, RegistrationError, FinalizerError, GroupInstantiationError. Each subclass carries structured fields so callers can branch on type and introspect details instead of matching on RuntimeError message strings. Inheriting from RuntimeError preserves backwards compatibility for anyone catching the broad type today. Tests updated to assert the typed exceptions and verify the structured fields. Closes #177 Co-Authored-By: Claude Opus 4.7 --- modern_di/container.py | 28 ++---- modern_di/exceptions.py | 104 ++++++++++++++++++++ modern_di/group.py | 4 +- modern_di/providers/factory.py | 14 +-- modern_di/registries/cache_registry.py | 20 ++-- modern_di/registries/providers_registry.py | 4 +- tests/providers/test_context_provider.py | 7 +- tests/providers/test_factory.py | 14 ++- tests/providers/test_singleton.py | 9 +- tests/registries/test_providers_registry.py | 6 +- tests/test_container.py | 30 ++++-- tests/test_group.py | 4 +- 12 files changed, 186 insertions(+), 58 deletions(-) create mode 100644 modern_di/exceptions.py diff --git a/modern_di/container.py b/modern_di/container.py index 0269945..8ecf828 100644 --- a/modern_di/container.py +++ b/modern_di/container.py @@ -3,7 +3,7 @@ import typing_extensions -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.group import Group from modern_di.providers.abstract import AbstractProvider from modern_di.providers.container_provider import container_provider @@ -62,39 +62,31 @@ def build_child_container( self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None ) -> "typing_extensions.Self": if scope and scope <= self.scope: - raise RuntimeError( - errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format( - parent_scope=self.scope.name, - child_scope=scope.name, - allowed_scopes=[x.name for x in Scope if x > self.scope], - ) + raise exceptions.InvalidChildScopeError( + parent_scope=self.scope, + child_scope=scope, + allowed_scopes=[x.name for x in Scope if x > self.scope], ) if not scope: try: scope = self.scope.__class__(self.scope.value + 1) except ValueError as exc: - raise RuntimeError( - errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=self.scope.name) - ) from exc + raise exceptions.MaxScopeReachedError(parent_scope=self.scope) from exc return self.__class__(scope=scope, parent_container=self, context=context) def find_container(self, scope: Scope) -> "typing_extensions.Self": if scope not in self.scope_map: if scope > self.scope: - raise RuntimeError( - errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format( - provider_scope=scope.name, container_scope=self.scope.name - ) - ) - raise RuntimeError(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=scope.name)) + raise exceptions.ScopeNotInitializedError(provider_scope=scope, container_scope=self.scope) + raise exceptions.ScopeSkippedError(provider_scope=scope) return self.scope_map[scope] def resolve(self, dependency_type: type[types.T]) -> types.T: provider = self.providers_registry.find_provider(dependency_type) if not provider: - raise RuntimeError(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=dependency_type)) + raise exceptions.ProviderNotRegisteredError(provider_type=dependency_type) return self.resolve_provider(provider) @@ -123,7 +115,7 @@ def _visit(provider: AbstractProvider[typing.Any]) -> None: cycle_start = next(i for i, p in enumerate(path) if p.provider_id == pid) cycle_names = [p.bound_type.__name__ if p.bound_type else repr(p) for p in path[cycle_start:]] cycle_names.append(cycle_names[0]) - raise RuntimeError(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_names))) + raise exceptions.CircularDependencyError(cycle_path=cycle_names) visiting.add(pid) path.append(provider) diff --git a/modern_di/exceptions.py b/modern_di/exceptions.py new file mode 100644 index 0000000..5270525 --- /dev/null +++ b/modern_di/exceptions.py @@ -0,0 +1,104 @@ +import typing + +from modern_di import errors +from modern_di.scope import Scope + + +class ModernDIError(RuntimeError): + """Base class for all modern-di errors. Inherits from RuntimeError for backwards compatibility.""" + + +class ContainerError(ModernDIError): + """Base class for container and scope errors.""" + + +class InvalidChildScopeError(ContainerError): + def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: list[str]) -> None: + self.parent_scope = parent_scope + self.child_scope = child_scope + self.allowed_scopes = allowed_scopes + super().__init__( + errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format( + parent_scope=parent_scope.name, + child_scope=child_scope.name, + allowed_scopes=allowed_scopes, + ) + ) + + +class MaxScopeReachedError(ContainerError): + def __init__(self, *, parent_scope: Scope) -> None: + self.parent_scope = parent_scope + super().__init__(errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=parent_scope.name)) + + +class ScopeNotInitializedError(ContainerError): + def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None: + self.provider_scope = provider_scope + self.container_scope = container_scope + super().__init__( + errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format( + provider_scope=provider_scope.name, + container_scope=container_scope.name, + ) + ) + + +class ScopeSkippedError(ContainerError): + def __init__(self, *, provider_scope: Scope) -> None: + self.provider_scope = provider_scope + super().__init__(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=provider_scope.name)) + + +class ResolutionError(ModernDIError): + """Base class for errors raised while resolving a provider.""" + + +class ProviderNotRegisteredError(ResolutionError): + def __init__(self, *, provider_type: type) -> None: + self.provider_type = provider_type + super().__init__(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type)) + + +class ArgumentResolutionError(ResolutionError): + def __init__(self, *, arg_name: str, arg_type: typing.Any, bound_type: typing.Any) -> None: # noqa: ANN401 + self.arg_name = arg_name + self.arg_type = arg_type + self.bound_type = bound_type + super().__init__( + errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( + arg_name=arg_name, + arg_type=arg_type, + bound_type=bound_type, + ) + ) + + +class CircularDependencyError(ResolutionError): + def __init__(self, *, cycle_path: list[str]) -> None: + self.cycle_path = cycle_path + super().__init__(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_path))) + + +class RegistrationError(ModernDIError): + """Base class for errors raised while registering providers.""" + + +class DuplicateProviderTypeError(RegistrationError): + def __init__(self, *, provider_type: type) -> None: + self.provider_type = provider_type + super().__init__(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type)) + + +class FinalizerError(ModernDIError): + def __init__(self, *, finalizer_errors: list[BaseException], is_async: bool) -> None: + self.finalizer_errors = finalizer_errors + self.is_async = is_async + kind = "async" if is_async else "sync" + super().__init__(f"Errors during {kind} cleanup: {finalizer_errors}") + + +class GroupInstantiationError(ModernDIError): + def __init__(self, *, group_name: str) -> None: + self.group_name = group_name + super().__init__(f"{group_name} cannot be instantiated") diff --git a/modern_di/group.py b/modern_di/group.py index 404bc6d..40b20e8 100644 --- a/modern_di/group.py +++ b/modern_di/group.py @@ -1,5 +1,6 @@ import typing +from modern_di import exceptions from modern_di.providers.abstract import AbstractProvider @@ -15,8 +16,7 @@ class Group: providers: list[AbstractProvider[typing.Any]] def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": # noqa: ANN401 - msg = f"{cls.__name__} cannot be instantiated" - raise RuntimeError(msg) + raise exceptions.GroupInstantiationError(group_name=cls.__name__) @classmethod def get_providers(cls) -> list[AbstractProvider[typing.Any]]: diff --git a/modern_di/providers/factory.py b/modern_di/providers/factory.py index 41005fc..e1e36ea 100644 --- a/modern_di/providers/factory.py +++ b/modern_di/providers/factory.py @@ -3,7 +3,7 @@ import typing import warnings -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.providers import ContextProvider from modern_di.providers.abstract import AbstractProvider from modern_di.scope import Scope @@ -82,18 +82,14 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]: if provider: result[k] = provider if is_kwarg_not_found and isinstance(provider, ContextProvider) and provider.resolve(container) is None: - raise RuntimeError( - errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( - arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator - ) + raise exceptions.ArgumentResolutionError( + arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator ) continue if v.default == types.UNSET and is_kwarg_not_found: - raise RuntimeError( - errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format( - arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator - ) + raise exceptions.ArgumentResolutionError( + arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator ) if self._kwargs: diff --git a/modern_di/registries/cache_registry.py b/modern_di/registries/cache_registry.py index 1a0c40e..d7ef8cc 100644 --- a/modern_di/registries/cache_registry.py +++ b/modern_di/registries/cache_registry.py @@ -2,7 +2,7 @@ import typing import warnings -from modern_di import types +from modern_di import exceptions, types from modern_di.providers import CacheSettings, Factory @@ -52,23 +52,21 @@ def fetch_cache_item(self, provider: Factory[types.T_co]) -> CacheItem: return self._items.setdefault(provider.provider_id, CacheItem(settings=provider.cache_settings)) async def close_async(self) -> None: - errors: list[BaseException] = [] + finalizer_errors: list[BaseException] = [] for cache_item in self._items.values(): try: await cache_item.close_async() except Exception as e: # noqa: BLE001, PERF203 - errors.append(e) - if errors: - msg = f"Errors during async cleanup: {errors}" - raise RuntimeError(msg) + finalizer_errors.append(e) + if finalizer_errors: + raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=True) def close_sync(self) -> None: - errors: list[BaseException] = [] + finalizer_errors: list[BaseException] = [] for cache_item in self._items.values(): try: cache_item.close_sync() except Exception as e: # noqa: BLE001, PERF203 - errors.append(e) - if errors: - msg = f"Errors during sync cleanup: {errors}" - raise RuntimeError(msg) + finalizer_errors.append(e) + if finalizer_errors: + raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=False) diff --git a/modern_di/registries/providers_registry.py b/modern_di/registries/providers_registry.py index 3ee30a5..9ace295 100644 --- a/modern_di/registries/providers_registry.py +++ b/modern_di/registries/providers_registry.py @@ -1,6 +1,6 @@ import typing -from modern_di import errors, types +from modern_di import exceptions, types from modern_di.providers.abstract import AbstractProvider @@ -21,7 +21,7 @@ def find_provider(self, dependency_type: type[types.T]) -> AbstractProvider[type def register(self, provider_type: type, provider: AbstractProvider[typing.Any]) -> None: if provider_type in self._providers: - raise RuntimeError(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type)) + raise exceptions.DuplicateProviderTypeError(provider_type=provider_type) self._providers[provider_type] = provider diff --git a/tests/providers/test_context_provider.py b/tests/providers/test_context_provider.py index 59e15b6..bc665c4 100644 --- a/tests/providers/test_context_provider.py +++ b/tests/providers/test_context_provider.py @@ -4,6 +4,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ArgumentResolutionError request_context_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=datetime.datetime) @@ -44,8 +45,12 @@ def test_context_provider_not_found() -> None: def test_context_provider_not_found_but_required() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match=r"Argument arg1 of type cannot be resolved"): + with pytest.raises( + ArgumentResolutionError, match=r"Argument arg1 of type cannot be resolved" + ) as exc: app_container.resolve(SomeFactory) + assert exc.value.arg_name == "arg1" + assert exc.value.arg_type is datetime.datetime def test_context_provider_in_request_scope() -> None: diff --git a/tests/providers/test_factory.py b/tests/providers/test_factory.py index ed767aa..0e63602 100644 --- a/tests/providers/test_factory.py +++ b/tests/providers/test_factory.py @@ -5,6 +5,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ArgumentResolutionError, ScopeNotInitializedError @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) @@ -62,8 +63,10 @@ def test_app_factory_skip_creator_parsing() -> None: def test_app_factory_unresolvable() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match="Argument dep1 of type cannot be resolved"): + with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type cannot be resolved") as exc: app_container.validate_provider(MyGroup.app_factory_unresolvable) + assert exc.value.arg_name == "dep1" + assert exc.value.arg_type is str def test_func_with_union_factory() -> None: @@ -74,7 +77,7 @@ def test_func_with_union_factory() -> None: def test_func_with_broken_annotation() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match="Argument dep1 of type None cannot be resolved"): + with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type None cannot be resolved"): app_container.validate_provider(MyGroup.func_with_broken_annotation) @@ -156,8 +159,13 @@ def test_factory_overridden_request_scope() -> None: def test_factory_scope_is_not_initialized() -> None: app_container = Container(groups=[MyGroup]) - with pytest.raises(RuntimeError, match=r"Provider of scope REQUEST cannot be resolved in container of scope APP."): + with pytest.raises( + ScopeNotInitializedError, + match=r"Provider of scope REQUEST cannot be resolved in container of scope APP.", + ) as exc: app_container.resolve_provider(MyGroup.request_factory) + assert exc.value.provider_scope == Scope.REQUEST + assert exc.value.container_scope == Scope.APP def test_factory_self_reference() -> None: diff --git a/tests/providers/test_singleton.py b/tests/providers/test_singleton.py index 908375c..814a00e 100644 --- a/tests/providers/test_singleton.py +++ b/tests/providers/test_singleton.py @@ -6,6 +6,7 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import FinalizerError @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) @@ -113,8 +114,10 @@ class BrokenGroup(Group): app_container.resolve_provider(BrokenGroup.first) app_container.resolve_provider(BrokenGroup.second) - with pytest.raises(RuntimeError, match="Errors during sync cleanup"): + with pytest.raises(FinalizerError, match="Errors during sync cleanup") as exc: app_container.close_sync() + assert exc.value.is_async is False + assert len(exc.value.finalizer_errors) == 1 assert cleaned_up == ["done"] @@ -146,8 +149,10 @@ class BrokenAsyncGroup(Group): app_container.resolve_provider(BrokenAsyncGroup.first) app_container.resolve_provider(BrokenAsyncGroup.second) - with pytest.raises(RuntimeError, match="Errors during async cleanup"): + with pytest.raises(FinalizerError, match="Errors during async cleanup") as exc: await app_container.close_async() + assert exc.value.is_async is True + assert len(exc.value.finalizer_errors) == 1 assert cleaned_up == ["done"] diff --git a/tests/registries/test_providers_registry.py b/tests/registries/test_providers_registry.py index a47a77b..c37beda 100644 --- a/tests/registries/test_providers_registry.py +++ b/tests/registries/test_providers_registry.py @@ -1,6 +1,7 @@ import pytest from modern_di import providers +from modern_di.exceptions import DuplicateProviderTypeError from modern_di.registries.providers_registry import ProvidersRegistry @@ -15,7 +16,6 @@ def test_providers_registry_add_provider_duplicates() -> None: providers_registry = ProvidersRegistry() providers_registry.add_providers(str_factory) - with ( - pytest.raises(RuntimeError, match="Provider is duplicated by type "), - ): + with pytest.raises(DuplicateProviderTypeError, match="Provider is duplicated by type ") as exc: providers_registry.add_providers(str_factory) + assert exc.value.provider_type is str diff --git a/tests/test_container.py b/tests/test_container.py index 1af0b1f..5b56fa7 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -4,6 +4,13 @@ import pytest from modern_di import Container, Group, Scope, providers +from modern_di.exceptions import ( + CircularDependencyError, + InvalidChildScopeError, + MaxScopeReachedError, + ProviderNotRegisteredError, + ScopeSkippedError, +) def test_container_prevent_copy() -> None: @@ -16,8 +23,9 @@ def test_container_prevent_copy() -> None: def test_container_scope_skipped() -> None: app_factory = providers.Factory(creator=lambda: "test") container = Container(scope=Scope.REQUEST) - with pytest.raises(RuntimeError, match=r"Provider of scope APP is skipped in the chain of containers."): + with pytest.raises(ScopeSkippedError, match=r"Provider of scope APP is skipped in the chain of containers.") as exc: container.resolve_provider(app_factory) + assert exc.value.provider_scope == Scope.APP def test_container_build_child() -> None: @@ -29,20 +37,27 @@ def test_container_build_child() -> None: def test_container_scope_limit_reached() -> None: step_container = Container(scope=Scope.STEP) - with pytest.raises(RuntimeError, match=r"Max scope of STEP is reached."): + with pytest.raises(MaxScopeReachedError, match=r"Max scope of STEP is reached.") as exc: step_container.build_child_container() + assert exc.value.parent_scope == Scope.STEP def test_container_build_child_wrong_scope() -> None: app_container = Container() - with pytest.raises(RuntimeError, match="Scope of child container cannot be"): + with pytest.raises(InvalidChildScopeError, match="Scope of child container cannot be") as exc: app_container.build_child_container(scope=Scope.APP) + assert exc.value.parent_scope == Scope.APP + assert exc.value.child_scope == Scope.APP def test_container_resolve_missing_provider() -> None: app_container = Container() - with pytest.raises(RuntimeError, match=r"Provider of type is not registered in providers registry."): + with pytest.raises( + ProviderNotRegisteredError, + match=r"Provider of type is not registered in providers registry.", + ) as exc: assert app_container.resolve(str) is None + assert exc.value.provider_type is str def test_container_sync_context_manager() -> None: @@ -82,14 +97,17 @@ class CycleGroup(Group): def test_validate_on_creation() -> None: - with pytest.raises(RuntimeError, match="Circular dependency detected"): + with pytest.raises(CircularDependencyError, match="Circular dependency detected"): Container(groups=[CycleGroup], validate=True) def test_validate_detects_cycle() -> None: container = Container(groups=[CycleGroup]) - with pytest.raises(RuntimeError, match="Circular dependency detected: CycleA -> CycleB -> CycleA"): + with pytest.raises( + CircularDependencyError, match="Circular dependency detected: CycleA -> CycleB -> CycleA" + ) as exc: container.validate() + assert exc.value.cycle_path == ["CycleA", "CycleB", "CycleA"] def test_validate_passes_for_valid_graph() -> None: diff --git a/tests/test_group.py b/tests/test_group.py index 3499e6e..06275fa 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,10 +1,12 @@ import pytest from modern_di import Group +from modern_di.exceptions import GroupInstantiationError def test_group_cannot_be_instantiated() -> None: class Dependencies(Group): ... - with pytest.raises(RuntimeError, match="Dependencies cannot be instantiated"): + with pytest.raises(GroupInstantiationError, match="Dependencies cannot be instantiated") as exc: Dependencies() + assert exc.value.group_name == "Dependencies"