Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@ e2e-tests = "uv run pytest tests/ -v -s -m 'account or basin or stream'"
e2e-account-tests = "uv run pytest tests/ -v -s -m account"
e2e-basin-tests = "uv run pytest tests/ -v -s -m basin"
e2e-stream-tests = "uv run pytest tests/ -v -s -m stream"
correctness-tests = "uv run pytest tests/ -v -s -m correctness"
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ markers =
stream: tests for stream operations
metrics: tests for metrics operations
access_tokens: tests for access token operations
correctness: correctness tests
41 changes: 30 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import pytest_asyncio

from s2_sdk import S2, Compression, Endpoints, S2Basin, S2Stream
from s2_sdk import S2, Compression, Endpoints, Retry, S2Basin, S2Stream

pytest_plugins = ["pytest_asyncio"]

Expand Down Expand Up @@ -50,22 +50,35 @@ def endpoints() -> Endpoints | None:
return None


@pytest.fixture(scope="session")
def retry() -> Retry | None:
    """Default retry policy for the session-scoped S2 client.

    Returns None so the SDK's default retry behavior applies. Test modules
    may override this session fixture (the correctness tests do) to supply
    a custom ``Retry``.
    """
    return None


@pytest_asyncio.fixture(scope="session")
async def s2(
    access_token: str,
    compression: Compression,
    endpoints: Endpoints | None,
    retry: Retry | None,
) -> AsyncGenerator[S2, None]:
    """Yield one session-wide S2 client built from the sibling fixtures."""
    client = S2(
        access_token,
        endpoints=endpoints,
        compression=compression,
        retry=retry,
    )
    async with client as s2:
        yield s2


@pytest.fixture
def basin_name(basin_prefix: str) -> str:
    """Return a fresh, unique basin name derived from the session's prefix."""
    # The pasted diff retained the superseded zero-arg definition above the
    # updated one; only the prefix-aware version is kept.
    return _basin_name(basin_prefix)


@pytest.fixture
def basin_names(basin_prefix: str) -> list[str]:
    """Return three fresh, unique basin names sharing the session's prefix."""
    # The pasted diff retained the superseded zero-arg definition above the
    # updated one; only the prefix-aware version is kept.
    return [_basin_name(basin_prefix) for _ in range(3)]


@pytest.fixture
Expand Down Expand Up @@ -94,8 +107,10 @@ async def basin(s2: S2, basin_name: str) -> AsyncGenerator[S2Basin, None]:


@pytest_asyncio.fixture(scope="class")
async def shared_basin(s2: S2) -> AsyncGenerator[S2Basin, None]:
basin_name = _basin_name()
async def shared_basin(
s2: S2, basin_prefix: str
) -> AsyncGenerator[S2Basin, None]:
basin_name = _basin_name(basin_prefix)
await s2.create_basin(name=basin_name)

try:
Expand All @@ -117,8 +132,12 @@ async def stream(
await basin.delete_stream(stream_name)


def _basin_name() -> str:
return f"{BASIN_PREFIX}-{uuid.uuid4().hex[:8]}"
def _basin_name(prefix: str) -> str:
suffix = uuid.uuid4().hex[:8]
prefix = prefix.strip("-")[: 48 - len(suffix) - 1].strip("-")
if not prefix:
return suffix
return f"{prefix}-{suffix}"


def _stream_name() -> str:
Expand Down
74 changes: 74 additions & 0 deletions tests/test_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
import sys

import pytest

from s2_sdk import Batching, Record, Retry, S2Stream, SeqNum

TOTAL_RECORDS = 64


@pytest.fixture(scope="session")
def retry() -> Retry:
    """Override the conftest ``retry`` fixture with effectively unbounded retries.

    ``max_attempts=sys.maxsize`` means the client keeps retrying transient
    failures, so the correctness test exercises gapless delivery rather than
    failing on flaky infrastructure.
    """
    return Retry(max_attempts=sys.maxsize)


@pytest.fixture(scope="session")
def basin_prefix() -> str:
    """Basin-name prefix identifying basins created by the correctness suite."""
    return "python-correctness"


@pytest.mark.correctness
@pytest.mark.asyncio
async def test_concurrent_producer_and_consumer_remain_gapless(stream: S2Stream):
    """Append and read concurrently, asserting the stream stays gapless.

    A producer appends TOTAL_RECORDS records whose bodies encode their own
    indices while a reader tails the stream from SeqNum(0). The reader checks
    that sequence numbers are strictly contiguous from 0 and that bodies never
    jump past the highest contiguously observed index. Duplicate bodies are
    tolerated (retries may redeliver a payload), which is why
    ``observed_records`` may exceed TOTAL_RECORDS.
    """

    async def read_records() -> tuple[int, int]:
        # Highest index i such that every body 0..i has been observed.
        highest_contiguous_index = -1
        last_seq_num: int | None = None
        observed_records = 0

        async for batch in stream.read_session(start=SeqNum(0)):
            for record in batch.records:
                seq_num = record.seq_num
                # Sequence numbers must be gapless starting at 0.
                if last_seq_num is None:
                    assert seq_num == 0
                else:
                    assert seq_num == last_seq_num + 1
                last_seq_num = seq_num

                body = record.body.decode()
                index = int(body)
                assert 0 <= index < TOTAL_RECORDS
                # A body may repeat, but must never skip ahead of the
                # contiguous frontier.
                assert index <= highest_contiguous_index + 1

                if index == highest_contiguous_index + 1:
                    highest_contiguous_index = index
                observed_records += 1

                if highest_contiguous_index + 1 >= TOTAL_RECORDS:
                    assert highest_contiguous_index == TOTAL_RECORDS - 1
                    assert observed_records >= TOTAL_RECORDS
                    return highest_contiguous_index, observed_records

        raise AssertionError("read session ended before all records were observed")

    async def append_records() -> None:
        # Submit everything first, then await the acks, so batching
        # (max_records=4) can actually coalesce submissions.
        async with stream.producer(batching=Batching(max_records=4)) as producer:
            tickets = []
            for i in range(TOTAL_RECORDS):
                ticket = await producer.submit(Record(body=str(i).encode()))
                tickets.append(ticket)

            for ticket in tickets:
                ack = await ticket
                assert ack.seq_num >= 0

    read_task = asyncio.create_task(read_records())
    append_task = asyncio.create_task(append_records())
    try:
        read_result, _ = await asyncio.gather(read_task, append_task)
    finally:
        pending = [task for task in (read_task, append_task) if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            # Await the cancellations so no task is destroyed while still
            # pending when the event loop tears down.
            await asyncio.gather(*pending, return_exceptions=True)

    assert read_result[0] == TOTAL_RECORDS - 1
Loading