From 0a7a4cc5c286b4a265dcc04fd3811416297117d6 Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Mon, 11 May 2026 15:29:31 -0600 Subject: [PATCH 1/2] initial commit --- src/s2_sdk/_batching.py | 4 ++++ tests/test_batching.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/s2_sdk/_batching.py b/src/s2_sdk/_batching.py index 3c742fe..17d9611 100644 --- a/src/s2_sdk/_batching.py +++ b/src/s2_sdk/_batching.py @@ -88,6 +88,10 @@ async def append_record_batches( acc.add(record) yield acc.take() + except Exception: + if not acc.is_empty(): + yield acc.take() + raise finally: if pending_next is not None: pending_next.cancel() diff --git a/tests/test_batching.py b/tests/test_batching.py index 0a09b14..73f73d1 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -84,6 +84,23 @@ async def delayed_records(): assert len(batches[1]) == 2 # r2, r3 +@pytest.mark.asyncio +async def test_batch_accumulator_flushes_when_source_iter_raises(): + async def records(): + yield Record(body=b"r1") + yield Record(body=b"r2") + raise RuntimeError("err") + + batches = [] + with pytest.raises(RuntimeError, match="err"): + async for batch in append_record_batches( + records(), batching=Batching(linger=timedelta(0)) + ): + batches.append(batch) + + assert batches == [[Record(body=b"r1"), Record(body=b"r2")]] + + @pytest.mark.asyncio async def test_append_inputs_skips_empty_batches(): inputs = [] From 56cf2e1464e6c3152ffef4d64107cc437665cafe Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Mon, 11 May 2026 15:38:19 -0600 Subject: [PATCH 2/2] cover both zero and non-zero linger in test --- tests/test_batching.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_batching.py b/tests/test_batching.py index 73f73d1..900e76f 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -85,7 +85,12 @@ async def delayed_records(): @pytest.mark.asyncio -async def test_batch_accumulator_flushes_when_source_iter_raises(): +@pytest.mark.parametrize( + "linger", + [timedelta(0), timedelta(milliseconds=10)], + ids=["zero_linger", "non_zero_linger"], +) +async def test_batch_accumulator_flushes_when_source_iter_raises(linger): async def records(): yield Record(body=b"r1") yield Record(body=b"r2") @@ -94,7 +99,7 @@ async def records(): batches = [] with pytest.raises(RuntimeError, match="err"): async for batch in append_record_batches( - records(), batching=Batching(linger=timedelta(0)) + records(), batching=Batching(linger=linger) ): batches.append(batch)