ARROW-9731: [C++][Python][R][Dataset] Implement Scanner::Head
This ports the head() method from R to C++ and exposes it in Python.
Closes #10047 from lidavidm/arrow-9731-2
Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 5095c2e..f7bd3c0 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -547,5 +547,21 @@
return out.table();
}
+// Materialize the first `num_rows` rows of the scan as a Table.
+//
+// Batches are pulled from ScanBatches() one at a time and iteration stops as soon
+// as enough rows have been gathered; each batch is sliced to the rows still
+// wanted, so the result never holds more than `num_rows` rows. If the scan yields
+// fewer rows, all of them are returned. A non-positive `num_rows` yields an empty
+// table with the projected schema without starting a scan.
+Result<std::shared_ptr<Table>> Scanner::Head(int64_t num_rows) {
+  // Treat negative counts like zero: a negative slice length is never meaningful,
+  // so bail out before any batch is requested.
+  if (num_rows <= 0) {
+    return Table::FromRecordBatches(options()->projected_schema, {});
+  }
+  ARROW_ASSIGN_OR_RAISE(auto batch_iterator, ScanBatches());
+  RecordBatchVector batches;
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(auto batch, batch_iterator.Next());
+    if (IsIterationEnd(batch)) break;
+    // num_rows is the count still wanted; this relies on Slice() clamping the
+    // length when the batch is shorter than that.
+    batches.push_back(batch.record_batch->Slice(0, num_rows));
+    num_rows -= batch.record_batch->num_rows();
+    if (num_rows <= 0) break;
+  }
+  return Table::FromRecordBatches(options()->projected_schema, batches);
+}
+
} // namespace dataset
} // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 9720346..956fbbb 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -289,6 +289,8 @@
///
/// Will only consume as many batches as needed from ScanBatches().
virtual Result<std::shared_ptr<Table>> TakeRows(const Array& indices);
+ /// \brief Get the first num_rows rows as a Table, scanning only the batches needed.
+ virtual Result<std::shared_ptr<Table>> Head(int64_t num_rows);
/// \brief Get the options for this scan.
const std::shared_ptr<ScanOptions>& options() const { return scan_options_; }
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 3a2d37f..b4e374a 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -301,6 +301,41 @@
}
}
+TEST_P(TestScanner, Head) {
+  // Head(n) must equal the leading n rows of the scan, including across batch
+  // boundaries and when n exceeds the number of rows available.
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto scanner = MakeScanner(batch);
+  const auto total_rows = kBatchSize * kNumberBatches * kNumberChildDatasets;
+
+  {
+    // No rows requested: an empty table that still carries the schema.
+    ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(schema_, {}));
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(0));
+    AssertTablesEqual(*expected, *actual);
+  }
+  {
+    // Exactly one full batch.
+    ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(schema_, {batch}));
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(kBatchSize));
+    AssertTablesEqual(*expected, *actual);
+  }
+  {
+    // Fewer rows than one batch: the batch is sliced.
+    ASSERT_OK_AND_ASSIGN(auto expected,
+                         Table::FromRecordBatches(schema_, {batch->Slice(0, 1)}));
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(1));
+    AssertTablesEqual(*expected, *actual);
+  }
+  {
+    // One row past a batch boundary: a full batch plus a one-row slice.
+    ASSERT_OK_AND_ASSIGN(
+        auto expected, Table::FromRecordBatches(schema_, {batch, batch->Slice(0, 1)}));
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(kBatchSize + 1));
+    AssertTablesEqual(*expected, *actual);
+  }
+  {
+    // Every row in the scan.
+    ASSERT_OK_AND_ASSIGN(auto expected, scanner->ToTable());
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(total_rows));
+    AssertTablesEqual(*expected, *actual);
+  }
+  {
+    // More rows than the scan holds: still the full table.
+    ASSERT_OK_AND_ASSIGN(auto expected, scanner->ToTable());
+    ASSERT_OK_AND_ASSIGN(auto actual, scanner->Head(total_rows + 100));
+    AssertTablesEqual(*expected, *actual);
+  }
+}
+
INSTANTIATE_TEST_SUITE_P(TestScannerThreading, TestScanner, ::testing::Bool());
class TestScannerBuilder : public ::testing::Test {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 46f78d4..6199428 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -457,6 +457,17 @@
"""
return self._scanner(**kwargs).to_table()
+    def head(self, int num_rows, **kwargs):
+        """Return the leading rows of the dataset as a Table.
+
+        A scanner is built from the keyword arguments and asked for its
+        first ``num_rows`` rows.
+
+        Parameters
+        ----------
+        num_rows : int
+            Maximum number of rows to load.
+        **kwargs : dict, optional
+            Forwarded to the scanner. See scan method parameters
+            documentation.
+
+        Returns
+        -------
+        table : Table instance
+        """
+        scanner = self._scanner(**kwargs)
+        return scanner.head(num_rows)
+
@property
def schema(self):
"""The common schema of the full Dataset"""
@@ -989,6 +1000,17 @@
"""
return self._scanner(schema=schema, **kwargs).to_table()
+    def head(self, int num_rows, **kwargs):
+        """Return the leading rows of the fragment as a Table.
+
+        A scanner is built from the keyword arguments and asked for its
+        first ``num_rows`` rows.
+
+        Parameters
+        ----------
+        num_rows : int
+            Maximum number of rows to load.
+        **kwargs : dict, optional
+            Forwarded to the scanner. See scan method parameters
+            documentation.
+
+        Returns
+        -------
+        table : Table instance
+        """
+        scanner = self._scanner(**kwargs)
+        return scanner.head(num_rows)
+
cdef class FileFragment(Fragment):
"""A Fragment representing a data file."""
@@ -2883,6 +2905,18 @@
result = self.scanner.TakeRows(deref(c_indices))
return pyarrow_wrap_table(GetResultValue(result))
+    def head(self, int num_rows):
+        """Return the leading rows produced by this scanner as a Table.
+
+        Parameters
+        ----------
+        num_rows : int
+            Maximum number of rows to load.
+
+        Returns
+        -------
+        table : Table instance
+        """
+        cdef CResult[shared_ptr[CTable]] c_table
+        # Release the GIL for the duration of the C++ call.
+        with nogil:
+            c_table = self.scanner.Head(num_rows)
+        return pyarrow_wrap_table(GetResultValue(c_table))
+
def _get_partition_keys(Expression partition_expression):
"""
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index 4da2978..16f6c5c 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -100,6 +100,7 @@
CResult[CTaggedRecordBatchIterator] ScanBatches()
CResult[shared_ptr[CTable]] ToTable()
CResult[shared_ptr[CTable]] TakeRows(const CArray& indices)
+ CResult[shared_ptr[CTable]] Head(int64_t num_rows)
CResult[CFragmentIterator] GetFragments()
const shared_ptr[CScanOptions]& options()
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 26c14e1..6ca6b09 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -371,6 +371,28 @@
scanner.take(pa.array([table.num_rows]))
+def test_head(dataset):
+ result = dataset.head(0)
+ assert result == pa.Table.from_batches([], schema=dataset.schema)
+
+ result = dataset.head(1, columns=['i64']).to_pydict()
+ assert result == {'i64': [0]}
+
+ result = dataset.head(2, columns=['i64'],
+ filter=ds.field('i64') > 1).to_pydict()
+ assert result == {'i64': [2, 3]}
+
+ result = dataset.head(1024, columns=['i64']).to_pydict()
+ assert result == {'i64': list(range(5)) * 2}
+
+ fragment = next(dataset.get_fragments())
+ result = fragment.head(1, columns=['i64']).to_pydict()
+ assert result == {'i64': [0]}
+
+ result = fragment.head(1024, columns=['i64']).to_pydict()
+ assert result == {'i64': list(range(5))}
+
+
def test_abstract_classes():
classes = [
ds.FileFormat,
diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp
index af321d7..c7ef39b 100644
--- a/r/src/dataset.cpp
+++ b/r/src/dataset.cpp
@@ -438,16 +438,7 @@
std::shared_ptr<arrow::Table> dataset___Scanner__head(
const std::shared_ptr<ds::Scanner>& scanner, int n) {
// TODO: make this a full Slice with offset > 0
- auto it = ValueOrStop(scanner->ScanBatches());
- std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
- while (true) {
- auto current_batch = ValueOrStop(it.Next());
- if (arrow::IsIterationEnd(current_batch)) break;
- batches.push_back(current_batch.record_batch->Slice(0, n));
- n -= current_batch.record_batch->num_rows();
- if (n < 0) break;
- }
- return ValueOrStop(arrow::Table::FromRecordBatches(std::move(batches)));
+ return ValueOrStop(scanner->Head(n));
}
// TODO (ARROW-11782) Remove calls to Scan()