// blob: a5690b3070409d83d9e3c5b106c3839785b4a6a3
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! This library demonstrates minimal usage of the Arrow C data interface from
//! Rust to pass arrays to and from Python.
#![warn(missing_docs)]
use std::sync::Arc;
use arrow::array::new_empty_array;
use arrow::record_batch::{RecordBatchIterator, RecordBatchReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use arrow::array::{Array, ArrayData, ArrayRef, Int64Array, make_array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;
/// Wraps an [`ArrowError`] in a `PyArrowException`, using the error's
/// `to_string()` text as the exception message.
fn to_py_err(err: ArrowError) -> PyErr {
    let message = err.to_string();
    PyArrowException::new_err(message)
}
/// Returns `array + array` of an int64 array.
#[pyfunction]
fn double<'py>(array: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
let py = array.py();
// import
let array = make_array(ArrayData::from_pyarrow_bound(array)?);
// perform some operation
let array = array
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| ArrowError::ParseError("Expects an int64".to_string()))
.map_err(to_py_err)?;
let array = kernels::numeric::add(array, array).map_err(to_py_err)?;
// export
array.to_data().to_pyarrow(py)
}
/// Calls a Python callable with the int64 array `[1, null, 3]` and reports
/// whether the array it returns equals `[2, null, 6]`.
///
/// Errors raised while exporting the input, calling `lambda`, or importing
/// its return value are propagated to the caller.
#[pyfunction]
fn double_py(lambda: &Bound<PyAny>) -> PyResult<bool> {
    let py = lambda.py();

    // Fixture passed to Python and the result the callable is expected to produce.
    let input = Int64Array::from(vec![Some(1), None, Some(3)]);
    let expected: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)]));

    // Rust -> Python -> callable -> Rust.
    let outgoing = input.to_data().to_pyarrow(py)?;
    let returned = lambda.call1((outgoing,))?;
    let actual = make_array(ArrayData::from_pyarrow_bound(&returned)?);

    Ok(actual == expected)
}
#[pyfunction]
fn make_empty_array<'py>(
datatype: PyArrowType<DataType>,
py: Python<'py>,
) -> PyResult<Bound<'py, PyAny>> {
let array = new_empty_array(&datatype.0);
array.to_data().to_pyarrow(py)
}
/// Returns the substring
#[pyfunction]
fn substring(array: PyArrowType<ArrayData>, start: i64) -> PyResult<PyArrowType<ArrayData>> {
// import
let array = make_array(array.0);
// substring
let array = kernels::substring::substring(array.as_ref(), start, None).map_err(to_py_err)?;
Ok(array.to_data().into())
}
/// Returns the concatenate
#[pyfunction]
fn concatenate<'py>(array: PyArrowType<ArrayData>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let array = make_array(array.0);
// concat
let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).map_err(to_py_err)?;
array.to_data().to_pyarrow(py)
}
/// Accepts a `DataType` from Python and returns it unchanged.
#[pyfunction]
fn round_trip_type(obj: PyArrowType<DataType>) -> PyResult<PyArrowType<DataType>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Accepts a `Field` from Python and returns it unchanged.
#[pyfunction]
fn round_trip_field(obj: PyArrowType<Field>) -> PyResult<PyArrowType<Field>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Accepts a `Schema` from Python and returns it unchanged.
#[pyfunction]
fn round_trip_schema(obj: PyArrowType<Schema>) -> PyResult<PyArrowType<Schema>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Accepts array data from Python and returns it unchanged.
#[pyfunction]
fn round_trip_array(obj: PyArrowType<ArrayData>) -> PyResult<PyArrowType<ArrayData>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Accepts a `RecordBatch` from Python and returns it unchanged.
#[pyfunction]
fn round_trip_record_batch(obj: PyArrowType<RecordBatch>) -> PyResult<PyArrowType<RecordBatch>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Accepts a record batch stream reader from Python and returns it without
/// reading from it.
#[pyfunction]
fn round_trip_record_batch_reader(
    obj: PyArrowType<ArrowArrayStreamReader>,
) -> PyResult<PyArrowType<ArrowArrayStreamReader>> {
    // The reader is passed back as received; no batches are pulled here.
    Ok(obj)
}
/// Accepts a `Table` from Python and returns it unchanged.
#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
    // No work is done here beyond passing the argument back.
    Ok(obj)
}
/// Builds a Table from a list of RecordBatches and a Schema.
#[pyfunction]
pub fn build_table(
record_batches: Vec<PyArrowType<RecordBatch>>,
schema: PyArrowType<Schema>,
) -> PyResult<PyArrowType<Table>> {
Ok(PyArrowType(
Table::try_new(
record_batches.into_iter().map(|rb| rb.0).collect(),
Arc::new(schema.0),
)
.map_err(to_py_err)?,
))
}
/// Drains a record batch reader and raises the first error it yields as a
/// Python `ValueError`; returns nothing on success.
#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
    // Consuming the reader here checks that an error it produces can be
    // returned after the reader itself is gone, i.e. the error outlives the
    // reader.
    let outcome: Result<Vec<RecordBatch>, ArrowError> = obj.0.collect();

    outcome
        .map(|_| ())
        .map_err(|err| PyValueError::new_err(err.to_string()))
}
#[pyfunction]
fn boxed_reader_roundtrip(
obj: PyArrowType<ArrowArrayStreamReader>,
) -> PyArrowType<Box<dyn RecordBatchReader + Send>> {
let schema = obj.0.schema();
let batches = obj
.0
.collect::<Result<Vec<RecordBatch>, ArrowError>>()
.unwrap();
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
PyArrowType(reader)
}
/// Python module definition: registers every `#[pyfunction]` defined in this
/// file on the module `m`.
///
/// Returns an error as soon as any registration fails.
#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
// Array helpers that run a kernel or build data on the Rust side.
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_wrapped(wrap_pyfunction!(double_py))?;
m.add_wrapped(wrap_pyfunction!(make_empty_array))?;
m.add_wrapped(wrap_pyfunction!(substring))?;
m.add_wrapped(wrap_pyfunction!(concatenate))?;
// Identity functions that hand their argument straight back.
m.add_wrapped(wrap_pyfunction!(round_trip_type))?;
m.add_wrapped(wrap_pyfunction!(round_trip_field))?;
m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
// Table construction and record-batch-reader helpers.
m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
}