| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| """End-to-end coverage for ``with_limit`` after row-level pushdown. |
| |
| Locks the contract: ``with_limit(N)`` returns at most ``N`` rows, and |
| the reader actually stops at that boundary instead of reading every |
| split / merge output to completion and trimming at the consumer. |
| """ |
| |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| |
| import pyarrow as pa |
| |
| from pypaimon import CatalogFactory, Schema |
| |
| |
class LimitPushdownTest(unittest.TestCase):
    """Checks that ``with_limit`` caps the number of rows a read returns.

    Each test creates its own table in a catalog whose warehouse lives in
    a temporary directory shared by the whole class, writes a few commits,
    and reads back through a read builder configured with ``with_limit``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the temporary warehouse, the catalog and the ``default`` database."""
        cls.tempdir = tempfile.mkdtemp()
        try:
            cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
            cls.catalog_options = {'warehouse': cls.warehouse}
            cls.catalog = CatalogFactory.create(cls.catalog_options)
            cls.catalog.create_database('default', False)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises,
            # so remove the directory here instead of leaving it behind.
            shutil.rmtree(cls.tempdir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    @staticmethod
    def _ao_schema() -> pa.Schema:
        """Return the Arrow schema used by the append-only tables: id, val, dt."""
        return pa.schema([
            pa.field('id', pa.int64(), nullable=False),
            ('val', pa.int64()),
            pa.field('dt', pa.string(), nullable=False),
        ])

    @staticmethod
    def _pk_schema() -> pa.Schema:
        """Return the Arrow schema used by the primary-key tables: id, val."""
        return pa.schema([
            pa.field('id', pa.int64(), nullable=False),
            ('val', pa.int64()),
        ])

    def _create_ao_table(self, name: str):
        """Create and return ``default.<name>`` as a parquet table partitioned by ``dt``."""
        identifier = 'default.' + name
        schema = Schema.from_pyarrow_schema(
            self._ao_schema(),
            partition_keys=['dt'],
            options={'file.format': 'parquet'},
        )
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier)

    def _create_pk_table(self, name: str, *, num_buckets: int = 1):
        """Create and return ``default.<name>`` as a parquet table keyed on ``id``.

        ``num_buckets`` is passed through as the table's ``bucket`` option.
        """
        identifier = 'default.' + name
        schema = Schema.from_pyarrow_schema(
            self._pk_schema(),
            primary_keys=['id'],
            options={'bucket': str(num_buckets), 'file.format': 'parquet'},
        )
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier)

    @staticmethod
    def _commit_arrow(table, data):
        """Write ``data`` to ``table`` as a single batch commit.

        Both the write handle and the commit handle are closed even when
        writing or committing raises, so a failing test does not leave
        them open.
        """
        builder = table.new_batch_write_builder()
        writer = builder.new_write()
        committer = builder.new_commit()
        try:
            writer.write_arrow(data)
            committer.commit(writer.prepare_commit())
        finally:
            writer.close()
            committer.close()

    def _write_ao_partitions(self, table, partitions):
        """Commit one batch per ``(dt, ids)`` pair.

        Each id ``r`` becomes the row ``{'id': r, 'val': r * 10, 'dt': dt}``.
        """
        for dt, rows in partitions:
            data = pa.Table.from_pylist(
                [{'id': r, 'val': r * 10, 'dt': dt} for r in rows],
                schema=self._ao_schema())
            self._commit_arrow(table, data)

    def _write_pk_snapshots(self, table, snapshots):
        """Commit one batch per entry of ``snapshots``.

        Each entry is an iterable of ``(id, val)`` pairs.
        """
        for rows in snapshots:
            data = pa.Table.from_pylist(
                [{'id': i, 'val': v} for i, v in rows],
                schema=self._pk_schema())
            self._commit_arrow(table, data)

    # ---- append-only -----------------------------------------------------

    def test_append_only_limit_stops_within_first_split(self):
        """With limit=3 on a partitioned append-only table, the result is
        exactly 3 rows even though each partition holds 5 rows."""
        table = self._create_ao_table('limit_ao_within_split')
        self._write_ao_partitions(table, [
            ('p1', list(range(5))),      # 5 rows
            ('p2', list(range(5, 10))),  # 5 rows
        ])
        rb = table.new_read_builder().with_limit(3)
        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
        self.assertEqual(result.num_rows, 3)

    def test_append_only_limit_spans_multiple_splits(self):
        """Limit larger than any single partition: 5 of the 6 rows written
        across three 2-row partitions are returned."""
        table = self._create_ao_table('limit_ao_span_splits')
        self._write_ao_partitions(table, [
            ('p1', [1, 2]),
            ('p2', [3, 4]),
            ('p3', [5, 6]),
        ])
        rb = table.new_read_builder().with_limit(5)
        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
        self.assertEqual(result.num_rows, 5)

    def test_append_only_limit_zero_returns_empty(self):
        """``with_limit(0)`` yields a result with no rows."""
        table = self._create_ao_table('limit_ao_zero')
        self._write_ao_partitions(table, [('p1', [1, 2, 3])])
        rb = table.new_read_builder().with_limit(0)
        splits = rb.new_scan().plan().splits()
        result = rb.new_read().to_arrow(splits)
        self.assertEqual(result.num_rows, 0)

    def test_append_only_limit_larger_than_total(self):
        """Limit greater than the total returns the total, not the limit."""
        table = self._create_ao_table('limit_ao_oversize')
        self._write_ao_partitions(table, [('p1', [1, 2, 3])])
        rb = table.new_read_builder().with_limit(100)
        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
        self.assertEqual(result.num_rows, 3)

    # ---- PK merge-on-read ------------------------------------------------

    def test_pk_merge_limit_stops_within_first_split(self):
        """PK + multiple snapshots forces the merge-read path. The reader
        must stop at limit rows instead of running every section to
        completion and trimming at the consumer."""
        table = self._create_pk_table('limit_pk_within_split')
        # Two snapshots over the same key range → merge path; total
        # post-merge unique rows = 20.
        self._write_pk_snapshots(table, [
            [(i, i) for i in range(20)],
            [(i, i + 1000) for i in range(0, 20, 2)],
        ])
        for limit in (1, 5, 10, 19):
            # subTest reports every failing limit instead of stopping at
            # the first one.
            with self.subTest(limit=limit):
                rb = table.new_read_builder().with_limit(limit)
                result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
                self.assertEqual(
                    result.num_rows, limit,
                    "with_limit(%d) must short-circuit at the row level" % limit)

    def test_pk_merge_limit_equals_total(self):
        """Limit equal to total post-merge row count: returns everything."""
        table = self._create_pk_table('limit_pk_equals_total')
        self._write_pk_snapshots(table, [
            [(i, i) for i in range(10)],
            [(i, i + 100) for i in range(5)],
        ])
        rb = table.new_read_builder().with_limit(10)
        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
        self.assertEqual(result.num_rows, 10)

    def test_pk_merge_limit_with_predicate(self):
        """``with_limit`` plus ``with_filter``: only rows with
        ``val >= 1000`` (written by the second snapshot for the even
        ids) may be returned, and exactly 3 of them are."""
        table = self._create_pk_table('limit_pk_with_filter')
        self._write_pk_snapshots(table, [
            [(i, i) for i in range(20)],
            [(i, i + 1000) for i in range(0, 20, 2)],  # update evens
        ])
        rb = table.new_read_builder()
        pred = rb.new_predicate_builder().greater_or_equal('val', 1000)
        rb = rb.with_filter(pred).with_limit(3)
        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
        self.assertEqual(result.num_rows, 3)
        for v in result.column('val').to_pylist():
            self.assertGreaterEqual(v, 1000)

    # ---- to_iterator path ------------------------------------------------

    def test_to_iterator_limit_short_circuits(self):
        """``to_iterator`` with limit=7 over 100 written rows yields 7 items."""
        table = self._create_ao_table('limit_iter')
        self._write_ao_partitions(table, [
            ('p1', list(range(50))),
            ('p2', list(range(50, 100))),
        ])
        rb = table.new_read_builder().with_limit(7)
        it = rb.new_read().to_iterator(rb.new_scan().plan().splits())
        rows = list(it)
        self.assertEqual(len(rows), 7)

    # ---- SplitRead-level limit pushdown verification ---------------------

    def test_append_only_split_read_creates_limited_batch_reader(self):
        """Verify that the split read's ``create_reader()`` returns a
        LimitedRecordBatchReader, which is also a RecordBatchReader,
        when a limit is set."""
        from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
        from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader

        table = self._create_ao_table('limit_ao_split_read')
        self._write_ao_partitions(table, [('p1', list(range(10)))])
        rb = table.new_read_builder().with_limit(3)
        table_read = rb.new_read()
        splits = rb.new_scan().plan().splits()
        self.assertGreater(len(splits), 0)
        for split in splits:
            split_read = table_read._create_split_read(split)
            self.assertEqual(split_read.limit, 3)
            reader = split_read.create_reader()
            try:
                self.assertIsInstance(
                    reader, LimitedRecordBatchReader,
                    "RawFileSplitRead.create_reader() should wrap with LimitedRecordBatchReader")
                self.assertIsInstance(
                    reader, RecordBatchReader,
                    "LimitedRecordBatchReader should be a RecordBatchReader")
            finally:
                # Close the reader even when an assertion above fails.
                reader.close()

    def test_append_only_split_read_limit_truncates_within_split(self):
        """Directly read from a single split's reader with limit and verify
        the reader itself stops at the limit boundary, not relying on
        TableRead-level truncation."""
        table = self._create_ao_table('limit_ao_split_truncate')
        self._write_ao_partitions(table, [('p1', list(range(20)))])
        rb = table.new_read_builder().with_limit(5)
        table_read = rb.new_read()
        splits = rb.new_scan().plan().splits()
        self.assertEqual(len(splits), 1)
        split_read = table_read._create_split_read(splits[0])
        reader = split_read.create_reader()
        # Drain the reader directly, bypassing TableRead-level control
        total_rows = 0
        try:
            while True:
                batch = reader.read_arrow_batch()
                if batch is None:
                    break
                total_rows += batch.num_rows
        finally:
            # Close the reader even when reading a batch raises.
            reader.close()
        self.assertEqual(total_rows, 5,
                         "SplitRead-level reader should stop at limit=5, got %d" % total_rows)
| |
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()