PARQUET-1166: Add GetRecordBatchReader in parquet/arrow/reader
Author: Xianjin YE <advancedxy@gmail.com>
Closes #445 from advancedxy/PARQUET-1166 and squashes the following commits:
03461ed [Xianjin YE] Add comments and auto-formatted code.
0abbb78 [Xianjin YE] declare next_row_group_ as size_t
f947668 [Xianjin YE] [WIP] PARQUET-1166: Add GetRecordBatchReader in parquet/arrow/reader
diff --git a/src/parquet/arrow/arrow-reader-writer-test.cc b/src/parquet/arrow/arrow-reader-writer-test.cc
index f06f4a8..8051ff1 100644
--- a/src/parquet/arrow/arrow-reader-writer-test.cc
+++ b/src/parquet/arrow/arrow-reader-writer-test.cc
@@ -1504,6 +1504,38 @@
ASSERT_TRUE(table->Equals(*concatenated));
}
+TEST(TestArrowReadWrite, GetRecordBatchReader) {  // PARQUET-1166: batch-wise reading of selected row groups
+ const int num_columns = 20;
+ const int num_rows = 1000;
+
+ std::shared_ptr<Table> table;
+ MakeDoubleTable(num_columns, num_rows, 1, &table);
+
+ std::shared_ptr<Buffer> buffer;
+ WriteTableToBuffer(table, 1, num_rows / 2, default_arrow_writer_properties(), &buffer);  // chunk size 500 -> file has 2 row groups
+
+ std::unique_ptr<FileReader> reader;
+ ASSERT_OK_NO_THROW(OpenFile(std::make_shared<BufferReader>(buffer),
+ ::arrow::default_memory_pool(),
+ ::parquet::default_reader_properties(), nullptr, &reader));
+
+ std::shared_ptr<::arrow::RecordBatchReader> rb_reader;
+ ASSERT_OK_NO_THROW(reader->GetRecordBatchReader({0, 1}, &rb_reader));  // select both row groups
+
+ std::shared_ptr<::arrow::RecordBatch> batch;
+
+ ASSERT_OK(rb_reader->ReadNext(&batch));  // first row group
+ ASSERT_EQ(500, batch->num_rows());
+ ASSERT_EQ(20, batch->num_columns());
+
+ ASSERT_OK(rb_reader->ReadNext(&batch));  // second row group
+ ASSERT_EQ(500, batch->num_rows());
+ ASSERT_EQ(20, batch->num_columns());
+
+ ASSERT_OK(rb_reader->ReadNext(&batch));  // exhausted: end-of-stream is signalled by a null batch
+ ASSERT_EQ(nullptr, batch);
+}
+
TEST(TestArrowReadWrite, ScanContents) {
const int num_columns = 20;
const int num_rows = 1000;
diff --git a/src/parquet/arrow/reader.cc b/src/parquet/arrow/reader.cc
index 78c3225..dd58d7a 100644
--- a/src/parquet/arrow/reader.cc
+++ b/src/parquet/arrow/reader.cc
@@ -57,6 +57,7 @@
// Help reduce verbosity
using ParquetReader = parquet::ParquetFileReader;
using arrow::ParallelFor;
+using arrow::RecordBatchReader;
using parquet::internal::RecordReader;
@@ -152,6 +153,59 @@
bool done_;
};
+class RowGroupRecordBatchReader : public ::arrow::RecordBatchReader {
+ public:
+ explicit RowGroupRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::Schema> schema,
+ FileReader* reader)
+ : row_group_indices_(row_group_indices),
+ column_indices_(column_indices),
+ schema_(schema),
+ file_reader_(reader),
+ next_row_group_(0) {}
+
+ ~RowGroupRecordBatchReader() {}
+
+ std::shared_ptr<::arrow::Schema> schema() const override { return schema_; }
+
+ Status ReadNext(std::shared_ptr<::arrow::RecordBatch>* out) override {
+ if (table_ != nullptr) { // a previously read row group is still being drained
+ std::shared_ptr<::arrow::RecordBatch> tmp;
+ RETURN_NOT_OK(table_batch_reader_->ReadNext(&tmp));
+ if (tmp != nullptr) { // current table still has batches to emit
+ *out = tmp;
+ return Status::OK();
+ } else { // current table fully consumed; drop it and advance
+ table_batch_reader_.reset();
+ table_.reset();
+ }
+ }
+
+ // all requested row groups have been consumed: signal end-of-stream
+ if (next_row_group_ == row_group_indices_.size()) {
+ *out = nullptr;
+ return Status::OK();
+ }
+
+ RETURN_NOT_OK(file_reader_->ReadRowGroup(row_group_indices_[next_row_group_],
+ column_indices_, &table_));
+
+ next_row_group_++;
+ table_batch_reader_.reset(new ::arrow::TableBatchReader(*table_.get()));
+ return table_batch_reader_->ReadNext(out);
+ }
+
+ private:
+ std::vector<int> row_group_indices_;
+ std::vector<int> column_indices_;
+ std::shared_ptr<::arrow::Schema> schema_;
+ FileReader* file_reader_; // non-owning; NOTE(review): caller must keep the FileReader alive
+ size_t next_row_group_; // index into row_group_indices_, not a file row-group number
+ std::shared_ptr<::arrow::Table> table_;
+ std::unique_ptr<::arrow::TableBatchReader> table_batch_reader_;
+};
+
// ----------------------------------------------------------------------
// File reader implementation
@@ -188,6 +242,8 @@
int num_row_groups() const { return reader_->metadata()->num_row_groups(); }
+ int num_columns() const { return reader_->metadata()->num_columns(); }  // column count from the Parquet file metadata
+
void set_num_threads(int num_threads) { num_threads_ = num_threads; }
ParquetFileReader* reader() { return reader_.get(); }
@@ -520,6 +576,11 @@
return impl_->GetColumn(i, out);
}
+Status FileReader::GetSchema(const std::vector<int>& indices,
+ std::shared_ptr<::arrow::Schema>* out) {
+ return impl_->GetSchema(indices, out); // thin forwarder to the Impl
+}
+
Status FileReader::ReadColumn(int i, std::shared_ptr<Array>* out) {
try {
return impl_->ReadColumn(i, out);
@@ -536,6 +597,40 @@
}
}
+Status FileReader::GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ std::vector<int> indices(impl_->num_columns()); // default overload: select every column, in order
+
+ for (size_t j = 0; j < indices.size(); ++j) {
+ indices[j] = static_cast<int>(j);
+ }
+
+ return GetRecordBatchReader(row_group_indices, indices, out);
+}
+
+Status FileReader::GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<RecordBatchReader>* out) {
+ // column indices check: building the projected schema fails on any bad index
+ std::shared_ptr<::arrow::Schema> schema;
+ RETURN_NOT_OK(GetSchema(column_indices, &schema));
+
+ // row group indices check: every index must be in [0, num_row_groups)
+ int max_num = num_row_groups();
+ for (auto row_group_index : row_group_indices) {
+ if (row_group_index < 0 || row_group_index >= max_num) {
+ std::ostringstream ss;
+ ss << "Some index in row_group_indices is " << row_group_index
+ << ", which is either < 0 or >= num_row_groups(" << max_num << ")";
+ return Status::Invalid(ss.str());
+ }
+ }
+
+ *out = std::make_shared<RowGroupRecordBatchReader>(row_group_indices, column_indices,
+ schema, this);
+ return Status::OK();
+}
+
Status FileReader::ReadTable(std::shared_ptr<Table>* out) {
try {
return impl_->ReadTable(out);
diff --git a/src/parquet/arrow/reader.h b/src/parquet/arrow/reader.h
index 95b2186..4d68c61 100644
--- a/src/parquet/arrow/reader.h
+++ b/src/parquet/arrow/reader.h
@@ -30,7 +30,7 @@
class Array;
class MemoryPool;
-class RowBatch;
+class RecordBatchReader;
class Status;
class Table;
} // namespace arrow
@@ -112,6 +112,11 @@
// Returns error status if the column of interest is not flat.
::arrow::Status GetColumn(int i, std::unique_ptr<ColumnReader>* out);
+ /// \brief Return the Arrow schema, selecting only the columns given by indices.
+ /// \returns error status if any index is invalid.
+ ::arrow::Status GetSchema(const std::vector<int>& indices,
+ std::shared_ptr<::arrow::Schema>* out);
+
// Read column as a whole into an Array.
::arrow::Status ReadColumn(int i, std::shared_ptr<::arrow::Array>* out);
@@ -149,6 +154,21 @@
::arrow::Status ReadSchemaField(int i, const std::vector<int>& indices,
std::shared_ptr<::arrow::Array>* out);
+ /// \brief Return a RecordBatchReader of row groups selected from row_group_indices;
+ /// the ordering in row_group_indices matters.
+ /// \returns error Status if row_group_indices contains an invalid index
+ ::arrow::Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ std::shared_ptr<::arrow::RecordBatchReader>* out);
+
+ /// \brief Return a RecordBatchReader of row groups selected from row_group_indices,
+ /// whose columns are selected by column_indices. The ordering in row_group_indices
+ /// and column_indices matters.
+ /// \returns error Status if either row_group_indices or column_indices contains an
+ /// invalid index
+ ::arrow::Status GetRecordBatchReader(const std::vector<int>& row_group_indices,
+ const std::vector<int>& column_indices,
+ std::shared_ptr<::arrow::RecordBatchReader>* out);
+
// Read a table of columns into a Table
::arrow::Status ReadTable(std::shared_ptr<::arrow::Table>* out);
diff --git a/src/parquet/arrow/writer.h b/src/parquet/arrow/writer.h
index a432850..06008d2 100644
--- a/src/parquet/arrow/writer.h
+++ b/src/parquet/arrow/writer.h
@@ -31,7 +31,6 @@
class Array;
class MemoryPool;
class PrimitiveArray;
-class RowBatch;
class Schema;
class Status;
class StringArray;