From 236970d18a72f6a25c7051ea3e61bc517568c13d Mon Sep 17 00:00:00 2001
From: quettabit <27509167+quettabit@users.noreply.github.com>
Date: Mon, 11 May 2026 21:11:10 -0600
Subject: [PATCH] wip

---
Review notes (this free-text area is ignored by git-am). The line structure
of this patch was restored from a whitespace-collapsed copy; hunk line counts
and the diffstat were re-checked, but context-line indentation should be
confirmed against the tree before applying.

* Registers a `correctness` pytest marker and a `correctness-tests` script.
* tests/conftest.py: adds a session-scoped `retry` fixture (default None)
  that is passed to S2(...). The basin-name fixtures now take their prefix
  from a `basin_prefix` fixture, and _basin_name() trims that prefix so that
  "<prefix>-<8 hex chars>" is at most 48 characters.
* tests/test_correctness.py: overrides `retry` and `basin_prefix`, then runs
  a producer and a reader concurrently and asserts that the reader observes
  seq_nums 0, 1, 2, ... with no gaps until all TOTAL_RECORDS bodies are seen.
* NOTE(review): none of the conftest.py hunks below define `basin_prefix`;
  only tests/test_correctness.py does. Confirm conftest.py provides it
  outside these hunks, otherwise other test modules that request
  basin_name / basin_names / shared_basin cannot resolve the fixture.
* NOTE(review): the new test has no timeout and uses
  Retry(max_attempts=sys.maxsize); presumably a stalled stream would hang
  the run instead of failing it. Confirm this is intended.

 pyproject.toml            |  1 +
 pytest.ini                |  1 +
 tests/conftest.py         | 41 ++++++++++++++++------
 tests/test_correctness.py | 74 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 106 insertions(+), 11 deletions(-)
 create mode 100644 tests/test_correctness.py

diff --git a/pyproject.toml b/pyproject.toml
index 8b162ea..3c7a814 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,3 +75,4 @@ e2e-tests = "uv run pytest tests/ -v -s -m 'account or basin or stream'"
 e2e-account-tests = "uv run pytest tests/ -v -s -m account"
 e2e-basin-tests = "uv run pytest tests/ -v -s -m basin"
 e2e-stream-tests = "uv run pytest tests/ -v -s -m stream"
+correctness-tests = "uv run pytest tests/ -v -s -m correctness"
diff --git a/pytest.ini b/pytest.ini
index 7fa0aad..e656437 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -18,3 +18,4 @@ markers =
     stream: tests for stream operations
     metrics: tests for metrics operations
     access_tokens: tests for access token operations
+    correctness: correctness tests
diff --git a/tests/conftest.py b/tests/conftest.py
index 5081684..17abcfe 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@
 import pytest
 import pytest_asyncio
 
-from s2_sdk import S2, Compression, Endpoints, S2Basin, S2Stream
+from s2_sdk import S2, Compression, Endpoints, Retry, S2Basin, S2Stream
 
 pytest_plugins = ["pytest_asyncio"]
 
@@ -50,22 +50,35 @@ def endpoints() -> Endpoints | None:
     return None
 
 
+@pytest.fixture(scope="session")
+def retry() -> Retry | None:
+    return None
+
+
 @pytest_asyncio.fixture(scope="session")
 async def s2(
-    access_token: str, compression: Compression, endpoints: Endpoints | None
+    access_token: str,
+    compression: Compression,
+    endpoints: Endpoints | None,
+    retry: Retry | None,
 ) -> AsyncGenerator[S2, None]:
-    async with S2(access_token, endpoints=endpoints, compression=compression) as s2:
+    async with S2(
+        access_token,
+        endpoints=endpoints,
+        compression=compression,
+        retry=retry,
+    ) as s2:
         yield s2
 
 
 @pytest.fixture
-def basin_name() -> str:
-    return _basin_name()
+def basin_name(basin_prefix: str) -> str:
+    return _basin_name(basin_prefix)
 
 
 @pytest.fixture
-def basin_names() -> list[str]:
-    return [_basin_name() for _ in range(3)]
+def basin_names(basin_prefix: str) -> list[str]:
+    return [_basin_name(basin_prefix) for _ in range(3)]
 
 
 @pytest.fixture
@@ -94,8 +107,10 @@ async def basin(s2: S2, basin_name: str) -> AsyncGenerator[S2Basin, None]:
 
 
 @pytest_asyncio.fixture(scope="class")
-async def shared_basin(s2: S2) -> AsyncGenerator[S2Basin, None]:
-    basin_name = _basin_name()
+async def shared_basin(
+    s2: S2, basin_prefix: str
+) -> AsyncGenerator[S2Basin, None]:
+    basin_name = _basin_name(basin_prefix)
     await s2.create_basin(name=basin_name)
 
     try:
@@ -117,8 +132,12 @@ async def stream(
         await basin.delete_stream(stream_name)
 
 
-def _basin_name() -> str:
-    return f"{BASIN_PREFIX}-{uuid.uuid4().hex[:8]}"
+def _basin_name(prefix: str) -> str:
+    suffix = uuid.uuid4().hex[:8]
+    prefix = prefix.strip("-")[: 48 - len(suffix) - 1].strip("-")
+    if not prefix:
+        return suffix
+    return f"{prefix}-{suffix}"
 
 
 def _stream_name() -> str:
diff --git a/tests/test_correctness.py b/tests/test_correctness.py
new file mode 100644
index 0000000..204d6ab
--- /dev/null
+++ b/tests/test_correctness.py
@@ -0,0 +1,74 @@
+import asyncio
+import sys
+
+import pytest
+
+from s2_sdk import Batching, Record, Retry, S2Stream, SeqNum
+
+TOTAL_RECORDS = 64
+
+
+@pytest.fixture(scope="session")
+def retry() -> Retry:
+    return Retry(max_attempts=sys.maxsize)
+
+
+@pytest.fixture(scope="session")
+def basin_prefix() -> str:
+    return "python-correctness"
+
+
+@pytest.mark.correctness
+@pytest.mark.asyncio
+async def test_concurrent_producer_and_consumer_remain_gapless(stream: S2Stream):
+    async def read_records() -> tuple[int, int]:
+        highest_contiguous_index = -1
+        last_seq_num: int | None = None
+        observed_records = 0
+
+        async for batch in stream.read_session(start=SeqNum(0)):
+            for record in batch.records:
+                seq_num = record.seq_num
+                if last_seq_num is None:
+                    assert seq_num == 0
+                else:
+                    assert seq_num == last_seq_num + 1
+                last_seq_num = seq_num
+
+                body = record.body.decode()
+                index = int(body)
+                assert 0 <= index < TOTAL_RECORDS
+                assert index <= highest_contiguous_index + 1
+
+                if index == highest_contiguous_index + 1:
+                    highest_contiguous_index = index
+                observed_records += 1
+
+            if highest_contiguous_index + 1 >= TOTAL_RECORDS:
+                assert highest_contiguous_index == TOTAL_RECORDS - 1
+                assert observed_records >= TOTAL_RECORDS
+                return highest_contiguous_index, observed_records
+
+        raise AssertionError("read session ended before all records were observed")
+
+    async def append_records() -> None:
+        async with stream.producer(batching=Batching(max_records=4)) as producer:
+            tickets = []
+            for i in range(TOTAL_RECORDS):
+                ticket = await producer.submit(Record(body=str(i).encode()))
+                tickets.append(ticket)
+
+            for ticket in tickets:
+                ack = await ticket
+                assert ack.seq_num >= 0
+
+    read_task = asyncio.create_task(read_records())
+    append_task = asyncio.create_task(append_records())
+    try:
+        read_result, _ = await asyncio.gather(read_task, append_task)
+    finally:
+        for task in (read_task, append_task):
+            if not task.done():
+                task.cancel()
+
+    assert read_result[0] == TOTAL_RECORDS - 1