diff --git a/src/s2_sdk/_batching.py b/src/s2_sdk/_batching.py index 3c742fe..17d9611 100644 --- a/src/s2_sdk/_batching.py +++ b/src/s2_sdk/_batching.py @@ -88,6 +88,10 @@ async def append_record_batches( acc.add(record) yield acc.take() + except Exception: + if not acc.is_empty(): + yield acc.take() + raise finally: if pending_next is not None: pending_next.cancel() diff --git a/tests/test_batching.py b/tests/test_batching.py index 0a09b14..900e76f 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -84,6 +84,28 @@ async def delayed_records(): assert len(batches[1]) == 2 # r2, r3 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "linger", + [timedelta(0), timedelta(milliseconds=10)], + ids=["zero_linger", "non_zero_linger"], +) +async def test_batch_accumulator_flushes_when_source_iter_raises(linger): + async def records(): + yield Record(body=b"r1") + yield Record(body=b"r2") + raise RuntimeError("err") + + batches = [] + with pytest.raises(RuntimeError, match="err"): + async for batch in append_record_batches( + records(), batching=Batching(linger=linger) + ): + batches.append(batch) + + assert batches == [[Record(body=b"r1"), Record(body=b"r2")]] + + @pytest.mark.asyncio async def test_append_inputs_skips_empty_batches(): inputs = []