blob: edbc1b75f433c446f31c351f54578e7258316379 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import unittest
from typing import List, Optional
from pypaimon.read.reader.iface.record_iterator import RecordIterator
from pypaimon.read.reader.iface.record_reader import RecordReader
from pypaimon.read.reader.limited_record_reader import LimitedRecordReader
class _ListIterator(RecordIterator):
def __init__(self, items: List):
self._items = items
self._idx = 0
def next(self):
if self._idx >= len(self._items):
return None
v = self._items[self._idx]
self._idx += 1
return v
class _StaticReader(RecordReader):
"""Hands back batches one at a time, tracks close calls and the
number of times ``read_batch`` was invoked (so tests can prove the
limiter actually short-circuited instead of draining the inner)."""
def __init__(self, batches: List[List]):
self._batches = batches
self._idx = 0
self.closed = False
self.read_batch_calls = 0
def read_batch(self) -> Optional[RecordIterator]:
self.read_batch_calls += 1
if self._idx >= len(self._batches):
return None
batch = self._batches[self._idx]
self._idx += 1
return _ListIterator(batch)
def close(self):
self.closed = True
def _drain(reader: RecordReader) -> List:
out = []
while True:
batch = reader.read_batch()
if batch is None:
break
while True:
v = batch.next()
if v is None:
break
out.append(v)
return out
class LimitedRecordReaderTest(unittest.TestCase):
def test_limit_within_first_batch(self):
reader = LimitedRecordReader(
_StaticReader([[1, 2, 3, 4, 5]]), limit=3)
self.assertEqual(_drain(reader), [1, 2, 3])
def test_limit_spans_multiple_batches(self):
reader = LimitedRecordReader(
_StaticReader([[1, 2], [3, 4], [5, 6]]), limit=5)
self.assertEqual(_drain(reader), [1, 2, 3, 4, 5])
def test_limit_larger_than_total_returns_everything(self):
reader = LimitedRecordReader(
_StaticReader([[1, 2, 3]]), limit=999)
self.assertEqual(_drain(reader), [1, 2, 3])
def test_limit_zero_returns_nothing(self):
reader = LimitedRecordReader(
_StaticReader([[1, 2, 3]]), limit=0)
self.assertEqual(_drain(reader), [])
# read_batch should short-circuit immediately rather than peek.
self.assertIsNone(reader.read_batch())
def test_negative_limit_rejected(self):
with self.assertRaises(ValueError):
LimitedRecordReader(_StaticReader([]), limit=-1)
def test_close_propagates(self):
inner = _StaticReader([[1, 2]])
reader = LimitedRecordReader(inner, limit=10)
reader.close()
self.assertTrue(inner.closed)
def test_iterator_stops_mid_batch(self):
# Limit cuts halfway through a batch; the next() call past the limit
# must return None even though the inner batch still has items.
reader = LimitedRecordReader(
_StaticReader([[1, 2, 3, 4, 5]]), limit=2)
batch = reader.read_batch()
self.assertEqual(batch.next(), 1)
self.assertEqual(batch.next(), 2)
self.assertIsNone(batch.next())
# Subsequent read_batch is also None.
self.assertIsNone(reader.read_batch())
def test_count_visible_for_observability(self):
reader = LimitedRecordReader(
_StaticReader([[1, 2, 3, 4]]), limit=10)
_drain(reader)
self.assertEqual(reader.count, 4)
def test_does_not_drain_inner_when_limit_met_within_first_batch(self):
"""Direct proof of the short-circuit: once the limiter has handed
out ``limit`` rows the next ``read_batch`` short-circuits at the
entry guard and never pulls a second batch from the inner."""
inner = _StaticReader([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
reader = LimitedRecordReader(inner, limit=3)
self.assertEqual(_drain(reader), [1, 2, 3])
# Only the first batch was fetched; the second is never asked for.
self.assertEqual(inner.read_batch_calls, 1)
if __name__ == '__main__':
unittest.main()