Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import typing_extensions

from modern_di import errors, types
from modern_di import exceptions, types
from modern_di.group import Group
from modern_di.providers.abstract import AbstractProvider
from modern_di.providers.container_provider import container_provider
Expand Down Expand Up @@ -62,39 +62,31 @@ def build_child_container(
self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
) -> "typing_extensions.Self":
if scope and scope <= self.scope:
raise RuntimeError(
errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format(
parent_scope=self.scope.name,
child_scope=scope.name,
allowed_scopes=[x.name for x in Scope if x > self.scope],
)
raise exceptions.InvalidChildScopeError(
parent_scope=self.scope,
child_scope=scope,
allowed_scopes=[x.name for x in Scope if x > self.scope],
)

if not scope:
try:
scope = self.scope.__class__(self.scope.value + 1)
except ValueError as exc:
raise RuntimeError(
errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=self.scope.name)
) from exc
raise exceptions.MaxScopeReachedError(parent_scope=self.scope) from exc

return self.__class__(scope=scope, parent_container=self, context=context)

def find_container(self, scope: Scope) -> "typing_extensions.Self":
if scope not in self.scope_map:
if scope > self.scope:
raise RuntimeError(
errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format(
provider_scope=scope.name, container_scope=self.scope.name
)
)
raise RuntimeError(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=scope.name))
raise exceptions.ScopeNotInitializedError(provider_scope=scope, container_scope=self.scope)
raise exceptions.ScopeSkippedError(provider_scope=scope)
return self.scope_map[scope]

def resolve(self, dependency_type: type[types.T]) -> types.T:
provider = self.providers_registry.find_provider(dependency_type)
if not provider:
raise RuntimeError(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=dependency_type))
raise exceptions.ProviderNotRegisteredError(provider_type=dependency_type)

return self.resolve_provider(provider)

Expand Down Expand Up @@ -123,7 +115,7 @@ def _visit(provider: AbstractProvider[typing.Any]) -> None:
cycle_start = next(i for i, p in enumerate(path) if p.provider_id == pid)
cycle_names = [p.bound_type.__name__ if p.bound_type else repr(p) for p in path[cycle_start:]]
cycle_names.append(cycle_names[0])
raise RuntimeError(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_names)))
raise exceptions.CircularDependencyError(cycle_path=cycle_names)

visiting.add(pid)
path.append(provider)
Expand Down
104 changes: 104 additions & 0 deletions modern_di/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import typing

from modern_di import errors
from modern_di.scope import Scope


class ModernDIError(RuntimeError):
"""Base class for all modern-di errors. Inherits from RuntimeError for backwards compatibility."""


class ContainerError(ModernDIError):
"""Base class for container and scope errors."""


class InvalidChildScopeError(ContainerError):
def __init__(self, *, parent_scope: Scope, child_scope: Scope, allowed_scopes: list[str]) -> None:
self.parent_scope = parent_scope
self.child_scope = child_scope
self.allowed_scopes = allowed_scopes
super().__init__(
errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format(
parent_scope=parent_scope.name,
child_scope=child_scope.name,
allowed_scopes=allowed_scopes,
)
)


class MaxScopeReachedError(ContainerError):
def __init__(self, *, parent_scope: Scope) -> None:
self.parent_scope = parent_scope
super().__init__(errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=parent_scope.name))


class ScopeNotInitializedError(ContainerError):
def __init__(self, *, provider_scope: Scope, container_scope: Scope) -> None:
self.provider_scope = provider_scope
self.container_scope = container_scope
super().__init__(
errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format(
provider_scope=provider_scope.name,
container_scope=container_scope.name,
)
)


class ScopeSkippedError(ContainerError):
def __init__(self, *, provider_scope: Scope) -> None:
self.provider_scope = provider_scope
super().__init__(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=provider_scope.name))


class ResolutionError(ModernDIError):
"""Base class for errors raised while resolving a provider."""


class ProviderNotRegisteredError(ResolutionError):
def __init__(self, *, provider_type: type) -> None:
self.provider_type = provider_type
super().__init__(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=provider_type))


class ArgumentResolutionError(ResolutionError):
def __init__(self, *, arg_name: str, arg_type: typing.Any, bound_type: typing.Any) -> None: # noqa: ANN401
self.arg_name = arg_name
self.arg_type = arg_type
self.bound_type = bound_type
super().__init__(
errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format(
arg_name=arg_name,
arg_type=arg_type,
bound_type=bound_type,
)
)


class CircularDependencyError(ResolutionError):
def __init__(self, *, cycle_path: list[str]) -> None:
self.cycle_path = cycle_path
super().__init__(errors.CYCLE_DEPENDENCY_ERROR.format(cycle_path=" -> ".join(cycle_path)))


class RegistrationError(ModernDIError):
"""Base class for errors raised while registering providers."""


class DuplicateProviderTypeError(RegistrationError):
def __init__(self, *, provider_type: type) -> None:
self.provider_type = provider_type
super().__init__(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type))


class FinalizerError(ModernDIError):
def __init__(self, *, finalizer_errors: list[BaseException], is_async: bool) -> None:
self.finalizer_errors = finalizer_errors
self.is_async = is_async
kind = "async" if is_async else "sync"
super().__init__(f"Errors during {kind} cleanup: {finalizer_errors}")


class GroupInstantiationError(ModernDIError):
def __init__(self, *, group_name: str) -> None:
self.group_name = group_name
super().__init__(f"{group_name} cannot be instantiated")
4 changes: 2 additions & 2 deletions modern_di/group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from modern_di import exceptions
from modern_di.providers.abstract import AbstractProvider


Expand All @@ -15,8 +16,7 @@ class Group:
providers: list[AbstractProvider[typing.Any]]

def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": # noqa: ANN401
msg = f"{cls.__name__} cannot be instantiated"
raise RuntimeError(msg)
raise exceptions.GroupInstantiationError(group_name=cls.__name__)

@classmethod
def get_providers(cls) -> list[AbstractProvider[typing.Any]]:
Expand Down
14 changes: 5 additions & 9 deletions modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
import warnings

from modern_di import errors, types
from modern_di import exceptions, types
from modern_di.providers import ContextProvider
from modern_di.providers.abstract import AbstractProvider
from modern_di.scope import Scope
Expand Down Expand Up @@ -82,18 +82,14 @@ def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]:
if provider:
result[k] = provider
if is_kwarg_not_found and isinstance(provider, ContextProvider) and provider.resolve(container) is None:
raise RuntimeError(
errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
)
raise exceptions.ArgumentResolutionError(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
)
continue

if v.default == types.UNSET and is_kwarg_not_found:
raise RuntimeError(
errors.FACTORY_ARGUMENT_RESOLUTION_ERROR.format(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
)
raise exceptions.ArgumentResolutionError(
arg_name=k, arg_type=v.arg_type, bound_type=self.bound_type or self._creator
)

if self._kwargs:
Expand Down
20 changes: 9 additions & 11 deletions modern_di/registries/cache_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings

from modern_di import types
from modern_di import exceptions, types
from modern_di.providers import CacheSettings, Factory


Expand Down Expand Up @@ -52,23 +52,21 @@ def fetch_cache_item(self, provider: Factory[types.T_co]) -> CacheItem:
return self._items.setdefault(provider.provider_id, CacheItem(settings=provider.cache_settings))

async def close_async(self) -> None:
errors: list[BaseException] = []
finalizer_errors: list[BaseException] = []
for cache_item in self._items.values():
try:
await cache_item.close_async()
except Exception as e: # noqa: BLE001, PERF203
errors.append(e)
if errors:
msg = f"Errors during async cleanup: {errors}"
raise RuntimeError(msg)
finalizer_errors.append(e)
if finalizer_errors:
raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=True)

def close_sync(self) -> None:
errors: list[BaseException] = []
finalizer_errors: list[BaseException] = []
for cache_item in self._items.values():
try:
cache_item.close_sync()
except Exception as e: # noqa: BLE001, PERF203
errors.append(e)
if errors:
msg = f"Errors during sync cleanup: {errors}"
raise RuntimeError(msg)
finalizer_errors.append(e)
if finalizer_errors:
raise exceptions.FinalizerError(finalizer_errors=finalizer_errors, is_async=False)
4 changes: 2 additions & 2 deletions modern_di/registries/providers_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from modern_di import errors, types
from modern_di import exceptions, types
from modern_di.providers.abstract import AbstractProvider


Expand All @@ -21,7 +21,7 @@ def find_provider(self, dependency_type: type[types.T]) -> AbstractProvider[type

def register(self, provider_type: type, provider: AbstractProvider[typing.Any]) -> None:
if provider_type in self._providers:
raise RuntimeError(errors.PROVIDER_DUPLICATE_TYPE_ERROR.format(provider_type=provider_type))
raise exceptions.DuplicateProviderTypeError(provider_type=provider_type)

self._providers[provider_type] = provider

Expand Down
7 changes: 6 additions & 1 deletion tests/providers/test_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import ArgumentResolutionError


request_context_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=datetime.datetime)
Expand Down Expand Up @@ -44,8 +45,12 @@ def test_context_provider_not_found() -> None:

def test_context_provider_not_found_but_required() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match=r"Argument arg1 of type <class 'datetime.datetime'> cannot be resolved"):
with pytest.raises(
ArgumentResolutionError, match=r"Argument arg1 of type <class 'datetime.datetime'> cannot be resolved"
) as exc:
app_container.resolve(SomeFactory)
assert exc.value.arg_name == "arg1"
assert exc.value.arg_type is datetime.datetime


def test_context_provider_in_request_scope() -> None:
Expand Down
14 changes: 11 additions & 3 deletions tests/providers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import ArgumentResolutionError, ScopeNotInitializedError


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
Expand Down Expand Up @@ -62,8 +63,10 @@ def test_app_factory_skip_creator_parsing() -> None:

def test_app_factory_unresolvable() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 of type <class 'str'> cannot be resolved"):
with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type <class 'str'> cannot be resolved") as exc:
app_container.validate_provider(MyGroup.app_factory_unresolvable)
assert exc.value.arg_name == "dep1"
assert exc.value.arg_type is str


def test_func_with_union_factory() -> None:
Expand All @@ -74,7 +77,7 @@ def test_func_with_union_factory() -> None:

def test_func_with_broken_annotation() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 of type None cannot be resolved"):
with pytest.raises(ArgumentResolutionError, match="Argument dep1 of type None cannot be resolved"):
app_container.validate_provider(MyGroup.func_with_broken_annotation)


Expand Down Expand Up @@ -156,8 +159,13 @@ def test_factory_overridden_request_scope() -> None:

def test_factory_scope_is_not_initialized() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match=r"Provider of scope REQUEST cannot be resolved in container of scope APP."):
with pytest.raises(
ScopeNotInitializedError,
match=r"Provider of scope REQUEST cannot be resolved in container of scope APP.",
) as exc:
app_container.resolve_provider(MyGroup.request_factory)
assert exc.value.provider_scope == Scope.REQUEST
assert exc.value.container_scope == Scope.APP


def test_factory_self_reference() -> None:
Expand Down
9 changes: 7 additions & 2 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from modern_di import Container, Group, Scope, providers
from modern_di.exceptions import FinalizerError


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
Expand Down Expand Up @@ -113,8 +114,10 @@ class BrokenGroup(Group):
app_container.resolve_provider(BrokenGroup.first)
app_container.resolve_provider(BrokenGroup.second)

with pytest.raises(RuntimeError, match="Errors during sync cleanup"):
with pytest.raises(FinalizerError, match="Errors during sync cleanup") as exc:
app_container.close_sync()
assert exc.value.is_async is False
assert len(exc.value.finalizer_errors) == 1

assert cleaned_up == ["done"]

Expand Down Expand Up @@ -146,8 +149,10 @@ class BrokenAsyncGroup(Group):
app_container.resolve_provider(BrokenAsyncGroup.first)
app_container.resolve_provider(BrokenAsyncGroup.second)

with pytest.raises(RuntimeError, match="Errors during async cleanup"):
with pytest.raises(FinalizerError, match="Errors during async cleanup") as exc:
await app_container.close_async()
assert exc.value.is_async is True
assert len(exc.value.finalizer_errors) == 1

assert cleaned_up == ["done"]

Expand Down
6 changes: 3 additions & 3 deletions tests/registries/test_providers_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from modern_di import providers
from modern_di.exceptions import DuplicateProviderTypeError
from modern_di.registries.providers_registry import ProvidersRegistry


Expand All @@ -15,7 +16,6 @@ def test_providers_registry_add_provider_duplicates() -> None:
providers_registry = ProvidersRegistry()
providers_registry.add_providers(str_factory)

with (
pytest.raises(RuntimeError, match="Provider is duplicated by type <class 'str'>"),
):
with pytest.raises(DuplicateProviderTypeError, match="Provider is duplicated by type <class 'str'>") as exc:
providers_registry.add_providers(str_factory)
assert exc.value.provider_type is str
Loading
Loading