Implement a `Vec<RecordBatch>` wrapper for `pyarrow.Table` convenience (#8790)
# Rationale for this change
When a Parquet file contains an exceedingly large amount of Binary or UTF8
data in a single row group, returning it as one RecordBatch can fail
because of index overflows
(https://github.com/apache/arrow-rs/issues/7973).
In `pyarrow` this is usually solved by representing the data as a
`pyarrow.Table`, whose columns are `ChunkedArray`s, essentially lists of
Arrow arrays; equivalently, a `pyarrow.Table` can be viewed as a
representation of a list of `RecordBatch`es.
I'd like to build a function in PyO3 that returns a `pyarrow.Table`,
very similar to [pyarrow's read_row_group
method](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile.read_row_group).
This would give us feature parity with `pyarrow` in situations where index
overflows are possible, without resorting to type changes (such as reading
the data as `LargeString` or `StringView` columns).
Currently, as far as I can see, there is no way in `arrow-pyarrow` to export
a `pyarrow.Table` directly; in particular, convenience conversions from
`Vec<RecordBatch>` are missing. This PR adds a convenience wrapper that
allows exporting a `pyarrow.Table` directly.
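To make this concrete, here is a minimal sketch (not part of this PR) of what such a `read_row_group`-style function could look like once the `Table` wrapper exists. The parquet reader plumbing, the function name, and its parameters are illustrative assumptions, not APIs added by this change:

```rust
use std::fs::File;

use arrow::pyarrow::{PyArrowType, Table};
use arrow::record_batch::RecordBatchReader;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Hypothetical PyO3 function: read one row group of a Parquet file and
/// return it to Python as a `pyarrow.Table`.
#[pyfunction]
fn read_row_group(path: String, row_group: usize) -> PyResult<PyArrowType<Table>> {
    let file = File::open(&path).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let reader = ParquetRecordBatchReaderBuilder::try_new(file)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .with_row_groups(vec![row_group])
        .build()
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let schema = reader.schema();
    // The row group stays split across several RecordBatches instead of being
    // concatenated into one huge batch, which is what avoids the index overflow.
    let batches = reader
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let table = Table::try_new(batches, schema)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PyArrowType(table))
}
```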
# What changes are included in this PR?
A new struct `Table` is added to the `arrow-pyarrow` crate. It can be
constructed from a `Vec<RecordBatch>` plus a `SchemaRef` via `Table::try_new`,
or from a `Box<dyn RecordBatchReader>` (such as a boxed
`ArrowArrayStreamReader`) via `TryFrom`. It implements `FromPyArrow` and
`IntoPyArrow`.
`FromPyArrow` accepts anything that implements the Arrow stream PyCapsule
protocol (`__arrow_c_stream__`), is a `RecordBatchReader`, or has a
`to_reader()` method returning one. `pyarrow.Table` implements the stream
protocol and also has `to_reader()`.
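For the import direction, here is a short sketch of a PyO3 function that consumes a `pyarrow.Table` (or any other stream-exportable object) through the new wrapper; the function name and the row-counting logic are made up purely for illustration:

```rust
use arrow::pyarrow::{PyArrowType, Table};
use pyo3::prelude::*;

/// Hypothetical PyO3 function: count the rows of a table passed in from Python.
#[pyfunction]
fn total_rows(table: PyArrowType<Table>) -> PyResult<usize> {
    // `FromPyArrow` drains the Python object through the Arrow C stream
    // interface, so any `__arrow_c_stream__`-capable object is accepted here,
    // not only `pyarrow.Table`.
    Ok(table.0.record_batches().iter().map(|rb| rb.num_rows()).sum())
}
```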
`IntoPyArrow` results in a `pyarrow.Table` on the Python side, constructed
through `pyarrow.Table.from_batches(...)`.
# Are these changes tested?
Yes, in `arrow-pyarrow-integration-tests`.
# Are there any user-facing changes?
A new `Table` convenience wrapper is added!
diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 7d5d63c..a5690b3 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;
fn to_py_err(err: ArrowError) -> PyErr {
@@ -141,6 +141,26 @@
}
#[pyfunction]
+fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
+ Ok(obj)
+}
+
+/// Builds a Table from a list of RecordBatches and a Schema.
+#[pyfunction]
+pub fn build_table(
+ record_batches: Vec<PyArrowType<RecordBatch>>,
+ schema: PyArrowType<Schema>,
+) -> PyResult<PyArrowType<Table>> {
+ Ok(PyArrowType(
+ Table::try_new(
+ record_batches.into_iter().map(|rb| rb.0).collect(),
+ Arc::new(schema.0),
+ )
+ .map_err(to_py_err)?,
+ ))
+}
+
+#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
// This makes sure we can correctly consume a RBR and return the error,
// ensuring the error can live beyond the lifetime of the RBR.
@@ -178,6 +198,8 @@
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+ m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
+ m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index f5d5315..79220fb 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -20,6 +20,7 @@
import datetime
import decimal
import string
+from typing import Union, Tuple, Protocol
import pytest
import pyarrow as pa
@@ -130,28 +131,50 @@
# This defines that Arrow consumers should allow any object that has specific "dunder"
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
# _any_ class, without pyarrow-specific handling.
-class SchemaWrapper:
- def __init__(self, schema):
+
+
+class ArrowSchemaExportable(Protocol):
+ def __arrow_c_schema__(self) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+ def __arrow_c_array__(
+ self,
+ requested_schema: Union[object, None] = None
+ ) -> Tuple[object, object]:
+ ...
+
+
+class ArrowStreamExportable(Protocol):
+ def __arrow_c_stream__(
+ self,
+ requested_schema: Union[object, None] = None
+ ) -> object:
+ ...
+
+
+class SchemaWrapper(ArrowSchemaExportable):
+ def __init__(self, schema: ArrowSchemaExportable) -> None:
self.schema = schema
- def __arrow_c_schema__(self):
+ def __arrow_c_schema__(self) -> object:
return self.schema.__arrow_c_schema__()
-class ArrayWrapper:
- def __init__(self, array):
+class ArrayWrapper(ArrowArrayExportable):
+ def __init__(self, array: ArrowArrayExportable) -> None:
self.array = array
- def __arrow_c_array__(self):
- return self.array.__arrow_c_array__()
+ def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
+ return self.array.__arrow_c_array__(requested_schema=requested_schema)
-class StreamWrapper:
- def __init__(self, stream):
+class StreamWrapper(ArrowStreamExportable):
+ def __init__(self, stream: ArrowStreamExportable) -> None:
self.stream = stream
- def __arrow_c_stream__(self):
- return self.stream.__arrow_c_stream__()
+ def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
+ return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -632,6 +655,67 @@
assert len(table.to_batches()) == len(new_table.to_batches())
+def test_table_empty():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ table = pa.Table.from_batches([], schema=schema)
+ new_table = rust.build_table([], schema=schema)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_roundtrip():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))])
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ table = pa.Table.from_batches(batches, schema=schema)
+ new_table = rust.round_trip_table(table)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_from_batches():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ table = pa.Table.from_batches(batches)
+ new_table = rust.build_table(batches, schema)
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_error_inconsistent_schema():
+ """
+ Python -> Rust -> Python
+ """
+ schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
+ schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema_1),
+ pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
+ ]
+ with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
+ rust.build_table(batches, schema_1)
+
+
def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs
index d4bbb20..1f8941e 100644
--- a/arrow-pyarrow/src/lib.rs
+++ b/arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
+//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
+//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
+//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
+//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
+//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
+//! forcing eager reading into `Vec<RecordBatch>`.
use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@
make_array,
};
use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};
import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,100 @@
}
}
+/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
+/// and to `pyarrow.Table`.
+///
+/// Use this when you want to consume a `pyarrow.Table` directly (although technically, since
+/// `pyarrow.Table` implements the Arrow stream PyCapsule interface, one could also consume a
+/// `PyArrowType<ArrowArrayStreamReader>` instead) or, more importantly, when you want to
+/// export a `pyarrow.Table` from a `Vec<RecordBatch>` on the Rust
+/// side.
+///
+/// ```ignore
+/// #[pyfunction]
+/// fn return_table() -> PyResult<PyArrowType<Table>> {
+///     let batches: Vec<RecordBatch> = vec![];
+///     let schema: SchemaRef = Arc::new(Schema::empty());
+///     Ok(PyArrowType(Table::try_new(batches, schema).map_err(|err| PyValueError::new_err(err.to_string()))?))
+/// }
+/// ```
+#[derive(Clone)]
+pub struct Table {
+ record_batches: Vec<RecordBatch>,
+ schema: SchemaRef,
+}
+
+impl Table {
+ pub fn try_new(
+ record_batches: Vec<RecordBatch>,
+ schema: SchemaRef,
+ ) -> Result<Self, ArrowError> {
+ for record_batch in &record_batches {
+ if schema != record_batch.schema() {
+ return Err(ArrowError::SchemaError(format!(
+ "All record batches must have the same schema. \
+ Expected schema: {:?}, got schema: {:?}",
+ schema,
+ record_batch.schema()
+ )));
+ }
+ }
+ Ok(Self {
+ record_batches,
+ schema,
+ })
+ }
+
+ pub fn record_batches(&self) -> &[RecordBatch] {
+ &self.record_batches
+ }
+
+ pub fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
+ (self.record_batches, self.schema)
+ }
+}
+
+impl TryFrom<Box<dyn RecordBatchReader>> for Table {
+ type Error = ArrowError;
+
+ fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
+ let schema = value.schema();
+ let batches = value.collect::<Result<Vec<_>, _>>()?;
+ Self::try_new(batches, schema)
+ }
+}
+
+/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
+impl FromPyArrow for Table {
+ fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
+ let reader: Box<dyn RecordBatchReader> =
+ Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
+ Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
+ }
+}
+
+/// Convert a [`Table`] into `pyarrow.Table`.
+impl IntoPyArrow for Table {
+ fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
+ let module = py.import(intern!(py, "pyarrow"))?;
+ let class = module.getattr(intern!(py, "Table"))?;
+
+ let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
+ let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
+
+ let kwargs = PyDict::new(py);
+ kwargs.set_item("schema", py_schema)?;
+
+        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
+
+        Ok(table)
+ }
+}
+
/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it