ARROW-12686: [C++][Python][FlightRPC] Convert Flight reader into a regular reader
This provides compatibility with APIs that expect regular RecordBatchReaders, e.g. exporting via the C Data Interface. The Flight interface itself cannot implement RecordBatchReader because getting the schema is not an infallible operation.
Closes #10267 from lidavidm/arrow-12686
Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 099d416..35993f1 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -427,12 +427,20 @@
std::unique_ptr<FlightStreamReader> stream;
ASSERT_OK(client_->DoGet(ticket, &stream));
+ std::unique_ptr<FlightStreamReader> stream2;
+ ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
+
FlightStreamChunk chunk;
+ std::shared_ptr<RecordBatch> batch;
for (int i = 0; i < num_batches; ++i) {
ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
ASSERT_NE(nullptr, chunk.data);
+ ASSERT_NE(nullptr, batch);
#if !defined(__MINGW32__)
ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
#else
// In MINGW32, the following code does not have the reproducibility at the LSB
// even when this is called twice with the same seed.
@@ -444,12 +452,15 @@
// [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
// /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
+ ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
#endif
}
// Stream exhausted
ASSERT_OK(stream->Next(&chunk));
+ ASSERT_OK(reader->ReadNext(&batch));
ASSERT_EQ(nullptr, chunk.data);
+ ASSERT_EQ(nullptr, batch);
}
protected:
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 84973f0..8139b21 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -284,6 +284,42 @@
return Begin(schema, ipc::IpcWriteOptions::Defaults());
}
+namespace {
+class MetadataRecordBatchReaderAdapter : public RecordBatchReader {
+ public:
+ explicit MetadataRecordBatchReaderAdapter(
+ std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate)
+ : schema_(std::move(schema)), delegate_(std::move(delegate)) {}
+ std::shared_ptr<Schema> schema() const override { return schema_; }
+ Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
+ FlightStreamChunk next;
+ while (true) {
+ RETURN_NOT_OK(delegate_->Next(&next));
+ if (!next.data && !next.app_metadata) {
+ // EOS
+ *batch = nullptr;
+ return Status::OK();
+ } else if (next.data) {
+ *batch = std::move(next.data);
+ return Status::OK();
+ }
+ // Got metadata, but no data (which is valid) - read the next message
+ }
+ }
+
+ private:
+ std::shared_ptr<Schema> schema_;
+ std::shared_ptr<MetadataRecordBatchReader> delegate_;
+};
+}; // namespace
+
+arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
+ std::shared_ptr<MetadataRecordBatchReader> reader) {
+ ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
+ return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema),
+ std::move(reader));
+}
+
SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights)
: position_(0), flights_(flights) {}
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 7538e4b..cd37318 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -474,6 +474,11 @@
virtual Status ReadAll(std::shared_ptr<Table>* table);
};
+/// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader.
+ARROW_FLIGHT_EXPORT
+arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
+ std::shared_ptr<MetadataRecordBatchReader> reader);
+
/// \brief An interface to write IPC payloads with metadata.
class ARROW_FLIGHT_EXPORT MetadataRecordBatchWriter : public ipc::RecordBatchWriter {
public:
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 7a8dcdb..e5d80df 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -874,6 +874,16 @@
return chunk
+ def to_reader(self):
+ """Convert this reader into a regular RecordBatchReader.
+
+ This may fail if the schema cannot be read from the remote end.
+ """
+ cdef RecordBatchReader reader
+ reader = RecordBatchReader.__new__(RecordBatchReader)
+ reader.reader = GetResultValue(MakeRecordBatchReader(self.reader))
+ return reader
+
cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
"""The virtual base class for readers for Flight streams."""
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 161a804..737babb 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -151,6 +151,10 @@
CStatus Next(CFlightStreamChunk* out)
CStatus ReadAll(shared_ptr[CTable]* table)
+ CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\
+ " arrow::flight::MakeRecordBatchReader"(
+ shared_ptr[CMetadataRecordBatchReader])
+
cdef cppclass CMetadataRecordBatchWriter \
" arrow::flight::MetadataRecordBatchWriter"(CRecordBatchWriter):
CStatus Begin(shared_ptr[CSchema] schema,
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 45ba5c2..585fdb2 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -857,6 +857,10 @@
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
+ # Also test via RecordBatchReader interface
+ data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all()
+ assert data.equals(table)
+
with pytest.raises(flight.FlightServerError,
match="expected IpcWriteOptions, got <class 'int'>"):
with ConstantFlightServer(options=42) as server: