diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a220b193f1..fc258e6d7c 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -171,10 +171,18 @@ def is_field_noneable(field: "FieldInfo") -> bool: return False +def unwrap_newtype(tp: Any) -> Any: + # returns wrapped type of newtype recursivly + while hasattr(tp, "__supertype__"): + tp = tp.__supertype__ + return tp + + def get_sa_type_from_type_annotation(annotation: Any) -> Any: # Resolve Optional fields if annotation is None: raise ValueError("Missing field type") + annotation = unwrap_newtype(annotation) origin = get_origin(annotation) if origin is None: return annotation diff --git a/tests/test_field_newtype.py b/tests/test_field_newtype.py new file mode 100644 index 0000000000..5721c3a929 --- /dev/null +++ b/tests/test_field_newtype.py @@ -0,0 +1,37 @@ +from typing import Annotated, NewType +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +def test_field_is_newtype() -> None: + NewId = NewType("NewId", UUID) + + class Item(SQLModel, table=True): + id: NewId = Field(default_factory=uuid4, primary_key=True) + + item = Item() + assert isinstance(item.id, UUID) + + +def test_field_is_recursive_newtype() -> None: + NewId1 = NewType("NewId1", int) + NewId2 = NewType("NewId2", NewId1) + NewId3 = NewType("NewId3", NewId2) + + class Item(SQLModel, table=True): + id: NewId3 = Field(primary_key=True) + + item = Item(id=NewId3(NewId2(NewId1(3)))) + assert isinstance(item.id, int) + assert item.id == 3, item.id + + +def test_field_is_newtype_and_annotated() -> None: + NewId = NewType("NewId", UUID) + + class Item(SQLModel, table=True): + id: Annotated[NewId, Field(primary_key=True)] = NewId(uuid4()) + + item = Item() + assert isinstance(item.id, UUID)