[bug] fix reading with `to_arrow_batch_reader` and `limit` (#1042)
* fix `project_batches` with limit
* add integration tests
* lint and readability
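Root cause: the old loop only broke out of the per-file batch iterator after yielding a sliced batch (and never counted those rows), so a scan spanning several files could keep yielding rows past the limit — hence the new multi-file test below. The following is a minimal sketch of the fixed control flow, decoupled from pyiceberg internals; `limited_batches` and `batch_iterables` are illustrative names standing in for `project_batches` and the per-task generators from `_task_to_record_batches`.

```python
# Sketch of the limiting pattern this fix applies (not the library code itself).
from typing import Iterable, Iterator, Optional

import pyarrow as pa


def limited_batches(
    batch_iterables: Iterable[Iterable[pa.RecordBatch]],
    limit: Optional[int] = None,
) -> Iterator[pa.RecordBatch]:
    total_row_count = 0
    for batches in batch_iterables:
        # Early exit added by this fix: without it, every per-file iterator could
        # contribute rows even after the limit had already been reached.
        if limit is not None and total_row_count >= limit:
            break
        for batch in batches:
            if limit is not None:
                if total_row_count >= limit:
                    break
                elif total_row_count + len(batch) >= limit:
                    # Trim only the batch that crosses the limit boundary.
                    batch = batch.slice(0, limit - total_row_count)
            yield batch
            total_row_count += len(batch)


# Two "files" of ten rows each: limit=1 now yields exactly one row in total.
files = [[pa.RecordBatch.from_pydict({"idx": list(range(10))})] for _ in range(2)]
assert sum(len(b) for b in limited_batches(files, limit=1)) == 1
```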
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 5218845..b8471ee 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1409,6 +1409,9 @@
     total_row_count = 0
 
     for task in tasks:
+        # stop early if limit is satisfied
+        if limit is not None and total_row_count >= limit:
+            break
         batches = _task_to_record_batches(
             fs,
             task,
@@ -1421,9 +1424,10 @@
         )
         for batch in batches:
             if limit is not None:
-                if total_row_count + len(batch) >= limit:
-                    yield batch.slice(0, limit - total_row_count)
+                if total_row_count >= limit:
                     break
+                elif total_row_count + len(batch) >= limit:
+                    batch = batch.slice(0, limit - total_row_count)
             yield batch
             total_row_count += len(batch)
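For reviewers, a quick check of the user-visible behavior this restores (existing API, nothing new): a limited scan read through `to_arrow_batch_reader` should return exactly the requested number of rows, matching `to_arrow`. The snippet assumes a catalog named "default" and the pre-populated "default.test_limit" table used by the integration tests; adjust those names for your setup.

```python
from pyiceberg.catalog import load_catalog

# Placeholder catalog/table names; any table with more rows than the limit works.
catalog = load_catalog("default")
table = catalog.load_table("default.test_limit")

# With the fix, the streaming reader honors the limit exactly, matching to_arrow().
reader = table.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader()
assert len(reader.read_all()) == 1
```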
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 8271267..cee8839 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -240,6 +240,54 @@
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
     assert len(full_result) == 10
 
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
+    table_name = "default.test_pyarrow_limit_with_multiple_files"
+    try:
+        catalog.drop_table(table_name)
+    except NoSuchTableError:
+        pass
+    reference_table = catalog.load_table("default.test_limit")
+    data = reference_table.scan().to_arrow()
+    table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
+
+    n_files = 2
+    for _ in range(n_files):
+        table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == n_files
+
+    # test with multiple files
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+    assert len(full_result) == 10 * n_files
+
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10 * n_files
+
 
 @pytest.mark.integration
 @pytest.mark.filterwarnings("ignore")