Officially maintained Arrow2 branch
Co-authored-by: Jorge C. Leitao <jorgecarleitao@gmail.com>
Co-authored-by: Yijie Shen <henry.yijieshen@gmail.com>
Co-authored-by: Guillaume Balaine <igosuki@gmail.com>
Co-authored-by: Guillaume Balaine <igosuki.github@gmail.com>
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 1b34d44..322aacb 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -383,8 +383,7 @@
run: |
cargo miri setup
cargo clean
- # Ignore MIRI errors until we can get a clean run
- cargo miri test || true
+ cargo miri test
# Check answers are correct when hash values collide
hash-collisions:
diff --git a/Cargo.toml b/Cargo.toml
index f7e9c03..0aab116 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,3 +36,7 @@
[profile.release]
lto = true
codegen-units = 1
+
+[patch.crates-io]
+arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "v0.10.0" }
+parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", rev = "v0.10.1" }
diff --git a/ballista-examples/src/bin/ballista-dataframe.rs b/ballista-examples/src/bin/ballista-dataframe.rs
index 8399324..345b698 100644
--- a/ballista-examples/src/bin/ballista-dataframe.rs
+++ b/ballista-examples/src/bin/ballista-dataframe.rs
@@ -27,7 +27,7 @@
.build()?;
let ctx = BallistaContext::remote("localhost", 50050, &config);
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
let filename = &format!("{}/alltypes_plain.parquet", testdata);
diff --git a/ballista-examples/src/bin/ballista-sql.rs b/ballista-examples/src/bin/ballista-sql.rs
index 3e0df21..25fc333 100644
--- a/ballista-examples/src/bin/ballista-sql.rs
+++ b/ballista-examples/src/bin/ballista-sql.rs
@@ -27,7 +27,7 @@
.build()?;
let ctx = BallistaContext::remote("localhost", 50050, &config);
- let testdata = datafusion::arrow::util::test_util::arrow_test_data();
+ let testdata = datafusion::test_util::arrow_test_data();
// register csv file with the execution context
ctx.register_csv(
diff --git a/ballista/rust/client/README.md b/ballista/rust/client/README.md
index c27b838..f3bbcee 100644
--- a/ballista/rust/client/README.md
+++ b/ballista/rust/client/README.md
@@ -95,7 +95,7 @@
```rust,no_run
use ballista::prelude::*;
-use datafusion::arrow::util::pretty;
+use datafusion::arrow::io::print;
use datafusion::prelude::CsvReadOptions;
#[tokio::main]
@@ -125,7 +125,7 @@
// collect the results and print them to stdout
let results = df.collect().await?;
- pretty::print_batches(&results)?;
+ print::print(&results);
Ok(())
}
```
diff --git a/ballista/rust/client/src/columnar_batch.rs b/ballista/rust/client/src/columnar_batch.rs
index 3431f56..5177261 100644
--- a/ballista/rust/client/src/columnar_batch.rs
+++ b/ballista/rust/client/src/columnar_batch.rs
@@ -21,9 +21,11 @@
use datafusion::arrow::{
array::ArrayRef,
+ compute::aggregate::estimated_bytes_size,
datatypes::{DataType, Schema},
- record_batch::RecordBatch,
};
+use datafusion::field_util::{FieldExt, SchemaExt};
+use datafusion::record_batch::RecordBatch;
use datafusion::scalar::ScalarValue;
pub type MaybeColumnarBatch = Result<Option<ColumnarBatch>>;
@@ -43,14 +45,14 @@
.enumerate()
.map(|(i, array)| {
(
- batch.schema().field(i).name().clone(),
+ batch.schema().field(i).name().to_string(),
ColumnarValue::Columnar(array.clone()),
)
})
.collect();
Self {
- schema: batch.schema(),
+ schema: batch.schema().clone(),
columns,
}
}
@@ -60,7 +62,7 @@
.fields()
.iter()
.enumerate()
- .map(|(i, f)| (f.name().clone(), values[i].clone()))
+ .map(|(i, f)| (f.name().to_string(), values[i].clone()))
.collect();
Self {
@@ -156,7 +158,7 @@
pub fn memory_size(&self) -> usize {
match self {
- ColumnarValue::Columnar(array) => array.get_array_memory_size(),
+ ColumnarValue::Columnar(array) => estimated_bytes_size(array.as_ref()),
_ => 0,
}
}
diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index 0c374b3..83cf199 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -46,7 +46,9 @@
clap = { version = "3", features = ["derive", "cargo"] }
parse_arg = "0.1.3"
-arrow-flight = { version = "10.0" }
+arrow-format = { version = "0.4", features = ["flight-data", "flight-service"] }
+arrow = { package = "arrow2", version="0.10", features = ["io_ipc", "io_flight"] }
+
datafusion = { path = "../../../datafusion", version = "7.0.0" }
datafusion-proto = { path = "../../../datafusion-proto", version = "7.0.0" }
diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs
index 5441888..ed8886f 100644
--- a/ballista/rust/core/src/client.rs
+++ b/ballista/rust/core/src/client.rs
@@ -17,10 +17,12 @@
//! Client API for sending requests to executors.
+use arrow::io::flight::deserialize_schemas;
+use arrow::io::ipc::IpcSchema;
+use std::collections::HashMap;
use std::sync::Arc;
-
use std::{
- convert::{TryFrom, TryInto},
+ convert::TryInto,
task::{Context, Poll},
};
@@ -28,16 +30,16 @@
use crate::serde::protobuf::{self};
use crate::serde::scheduler::Action;
-use arrow_flight::utils::flight_data_to_arrow_batch;
-use arrow_flight::Ticket;
-use arrow_flight::{flight_service_client::FlightServiceClient, FlightData};
+use arrow_format::flight::data::{FlightData, Ticket};
+use arrow_format::flight::service::flight_service_client::FlightServiceClient;
use datafusion::arrow::{
- datatypes::{Schema, SchemaRef},
+ datatypes::SchemaRef,
error::{ArrowError, Result as ArrowResult},
- record_batch::RecordBatch,
};
-
-use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use datafusion::field_util::SchemaExt;
+use datafusion::physical_plan::RecordBatchStream;
+use datafusion::physical_plan::SendableRecordBatchStream;
+use datafusion::record_batch::RecordBatch;
use futures::{Stream, StreamExt};
use log::debug;
use prost::Message;
@@ -116,10 +118,12 @@
{
Some(flight_data) => {
// convert FlightData to a stream
- let schema = Arc::new(Schema::try_from(&flight_data)?);
+ let (schema, ipc_schema) =
+ deserialize_schemas(flight_data.data_header.as_slice())?;
+ let schema = Arc::new(schema);
// all the remaining stream messages should be dictionary and record batches
- Ok(Box::pin(FlightDataStream::new(stream, schema)))
+ Ok(Box::pin(FlightDataStream::new(stream, schema, ipc_schema)))
}
None => Err(ballista_error(
"Did not receive schema batch from flight server",
@@ -131,11 +135,20 @@
struct FlightDataStream {
stream: Streaming<FlightData>,
schema: SchemaRef,
+ ipc_schema: IpcSchema,
}
impl FlightDataStream {
- pub fn new(stream: Streaming<FlightData>, schema: SchemaRef) -> Self {
- Self { stream, schema }
+ pub fn new(
+ stream: Streaming<FlightData>,
+ schema: SchemaRef,
+ ipc_schema: IpcSchema,
+ ) -> Self {
+ Self {
+ stream,
+ schema,
+ ipc_schema,
+ }
}
}
@@ -151,12 +164,16 @@
let converted_chunk = flight_data_chunk_result
.map_err(|e| ArrowError::from_external_error(Box::new(e)))
.and_then(|flight_data_chunk| {
- flight_data_to_arrow_batch(
+ let hm = HashMap::new();
+
+ arrow::io::flight::deserialize_batch(
&flight_data_chunk,
- self.schema.clone(),
- &[],
+ self.schema.fields(),
+ &self.ipc_schema,
+ &hm,
)
- });
+ })
+ .map(|c| RecordBatch::new_with_chunk(&self.schema, c));
Some(converted_chunk)
}
None => None,
diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs
index 8cdaf1f..7cf4e0e 100644
--- a/ballista/rust/core/src/config.rs
+++ b/ballista/rust/core/src/config.rs
@@ -135,7 +135,7 @@
.map_err(|e| format!("{:?}", e))?;
}
_ => {
- return Err(format!("not support data type: {}", data_type));
+ return Err(format!("not support data type: {:?}", data_type));
}
}
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index 3bebcd1..9762d64 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -24,7 +24,6 @@
use crate::utils::WrappedStream;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
-
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::metrics::{
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index b80fc84..55925ed 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -20,8 +20,6 @@
//! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query
//! will use the ShuffleReaderExec to read these results.
-use datafusion::physical_plan::expressions::PhysicalSortExpr;
-
use std::any::Any;
use std::iter::Iterator;
use std::path::PathBuf;
@@ -33,16 +31,12 @@
use crate::serde::protobuf::ShuffleWritePartition;
use crate::serde::scheduler::PartitionStats;
use async_trait::async_trait;
-use datafusion::arrow::array::{
- Array, ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder,
- UInt64Builder,
-};
+use datafusion::arrow::array::*;
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-
-use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::field_util::SchemaExt;
use datafusion::physical_plan::common::IPCWriter;
use datafusion::physical_plan::hash_utils::create_hashes;
use datafusion::physical_plan::memory::MemoryStream;
@@ -53,8 +47,10 @@
use datafusion::physical_plan::{
DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
};
+use datafusion::record_batch::RecordBatch;
use futures::StreamExt;
+use datafusion::physical_plan::expressions::PhysicalSortExpr;
use log::{debug, info};
/// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and
@@ -230,21 +226,24 @@
for (output_partition, partition_indices) in
indices.into_iter().enumerate()
{
- let indices = partition_indices.into();
-
+ // Build the take indices once per output partition and
+ // reuse them for every column, instead of re-copying
+ // `partition_indices` into a new array per column.
+ let indices = PrimitiveArray::<u64>::from_slice(&partition_indices);
+
// Produce batches based on indices
let columns = input_batch
.columns()
.iter()
.map(|c| {
- take(c.as_ref(), &indices, None).map_err(|e| {
- DataFusionError::Execution(e.to_string())
- })
+ take::take(c.as_ref(), &indices)
+ .map_err(|e| DataFusionError::Execution(e.to_string()))
+ .map(ArrayRef::from)
})
.collect::<Result<Vec<Arc<dyn Array>>>>()?;
let output_batch =
- RecordBatch::try_new(input_batch.schema(), columns)?;
+ RecordBatch::try_new(input_batch.schema().clone(), columns)?;
// write non-empty batch out
@@ -364,36 +363,34 @@
// build metadata result batch
let num_writers = part_loc.len();
- let mut partition_builder = UInt32Builder::new(num_writers);
- let mut path_builder = StringBuilder::new(num_writers);
- let mut num_rows_builder = UInt64Builder::new(num_writers);
- let mut num_batches_builder = UInt64Builder::new(num_writers);
- let mut num_bytes_builder = UInt64Builder::new(num_writers);
+ let mut partition_builder = UInt32Vec::with_capacity(num_writers);
+ let mut path_builder = MutableUtf8Array::<i32>::with_capacity(num_writers);
+ let mut num_rows_builder = UInt64Vec::with_capacity(num_writers);
+ let mut num_batches_builder = UInt64Vec::with_capacity(num_writers);
+ let mut num_bytes_builder = UInt64Vec::with_capacity(num_writers);
for loc in &part_loc {
- path_builder.append_value(loc.path.clone())?;
- partition_builder.append_value(loc.partition_id as u32)?;
- num_rows_builder.append_value(loc.num_rows)?;
- num_batches_builder.append_value(loc.num_batches)?;
- num_bytes_builder.append_value(loc.num_bytes)?;
+ path_builder.push(Some(loc.path.clone()));
+ partition_builder.push(Some(loc.partition_id as u32));
+ num_rows_builder.push(Some(loc.num_rows));
+ num_batches_builder.push(Some(loc.num_batches));
+ num_bytes_builder.push(Some(loc.num_bytes));
}
// build arrays
- let partition_num: ArrayRef = Arc::new(partition_builder.finish());
- let path: ArrayRef = Arc::new(path_builder.finish());
- let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
- Box::new(num_rows_builder),
- Box::new(num_batches_builder),
- Box::new(num_bytes_builder),
+ let partition_num: ArrayRef = partition_builder.into_arc();
+ let path: ArrayRef = path_builder.into_arc();
+ let field_builders: Vec<Arc<dyn Array>> = vec![
+ num_rows_builder.into_arc(),
+ num_batches_builder.into_arc(),
+ num_bytes_builder.into_arc(),
];
- let mut stats_builder = StructBuilder::new(
- PartitionStats::default().arrow_struct_fields(),
+ let stats_builder = StructArray::from_data(
+ DataType::Struct(PartitionStats::default().arrow_struct_fields()),
field_builders,
+ None,
);
- for _ in 0..num_writers {
- stats_builder.append(true)?;
- }
- let stats = Arc::new(stats_builder.finish());
+ let stats = Arc::new(stats_builder);
// build result batch containing metadata
let schema = result_schema();
@@ -443,9 +440,11 @@
#[cfg(test)]
mod tests {
use super::*;
- use datafusion::arrow::array::{StringArray, StructArray, UInt32Array, UInt64Array};
+ use datafusion::arrow::array::{StructArray, UInt32Array, UInt64Array, Utf8Array};
+ use datafusion::field_util::StructArrayExt;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::expressions::Column;
+ use std::iter::FromIterator;
use datafusion::physical_plan::memory::MemoryExec;
use tempfile::TempDir;
@@ -473,7 +472,7 @@
assert_eq!(2, batch.num_rows());
let path = batch.columns()[1]
.as_any()
- .downcast_ref::<StringArray>()
+ .downcast_ref::<Utf8Array<i32>>()
.unwrap();
let file0 = path.value(0);
@@ -551,8 +550,8 @@
let batch = RecordBatch::try_new(
schema.clone(),
vec![
- Arc::new(UInt32Array::from(vec![Some(1), Some(2)])),
- Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
+ Arc::new(UInt32Array::from_iter(vec![Some(1), Some(2)])),
+ Arc::new(Utf8Array::<i32>::from(vec![Some("hello"), Some("world")])),
],
)?;
let partition = vec![batch.clone(), batch];
diff --git a/ballista/rust/core/src/lib.rs b/ballista/rust/core/src/lib.rs
index c452a45..cfab525 100644
--- a/ballista/rust/core/src/lib.rs
+++ b/ballista/rust/core/src/lib.rs
@@ -18,6 +18,9 @@
#![doc = include_str!("../README.md")]
pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");
+#[macro_use]
+extern crate async_trait;
+
pub fn print_version() {
println!("Ballista version: {}", BALLISTA_VERSION)
}
diff --git a/ballista/rust/core/src/memory_stream.rs b/ballista/rust/core/src/memory_stream.rs
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/ballista/rust/core/src/memory_stream.rs
@@ -0,0 +1 @@
+
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index 4970cd6..9d9113b 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -38,6 +38,7 @@
};
use datafusion::prelude::ExecutionContext;
+use datafusion::field_util::{FieldExt, SchemaExt};
use prost::bytes::BufMut;
use prost::Message;
use protobuf::listing_table_scan_node::FileFormatType;
@@ -858,6 +859,7 @@
FileMetaStream, ListEntryStream, ObjectReader, ObjectStore, SizedFile,
};
use datafusion::error::DataFusionError;
+ use datafusion::field_util::SchemaExt;
use datafusion::{
arrow::datatypes::{DataType, Field, Schema},
datasource::object_store::local::LocalFileSystem,
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 83607ae..96b2810 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -31,10 +31,11 @@
PhysicalExtensionCodec,
};
use crate::{convert_box_required, convert_required, into_physical_plan, into_required};
-use datafusion::arrow::compute::SortOptions;
+use datafusion::arrow::compute::sort::SortOptions;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datasource::object_store::local::LocalFileSystem;
use datafusion::datasource::PartitionedFile;
+use datafusion::field_util::FieldExt;
use datafusion::logical_plan::window_frames::WindowFrame;
use datafusion::physical_plan::aggregates::create_aggregate_expr;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
@@ -640,7 +641,7 @@
.aggr_expr()
.iter()
.map(|expr| match expr.field() {
- Ok(field) => Ok(field.name().clone()),
+ Ok(field) => Ok(field.name().to_string()),
Err(e) => Err(BallistaError::DataFusionError(e)),
})
.collect::<Result<_, BallistaError>>()?;
@@ -939,11 +940,12 @@
use crate::serde::{AsExecutionPlan, BallistaCodec};
use datafusion::datasource::object_store::local::LocalFileSystem;
use datafusion::datasource::PartitionedFile;
+ use datafusion::field_util::SchemaExt;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::prelude::ExecutionContext;
use datafusion::{
arrow::{
- compute::kernels::sort::SortOptions,
+ compute::sort::SortOptions,
datatypes::{DataType, Field, Schema},
},
logical_plan::{JoinType, Operator},
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index c304382..9e91340 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -17,11 +17,8 @@
use std::{collections::HashMap, fmt, sync::Arc};
-use datafusion::arrow::array::{
- ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
-};
+use datafusion::arrow::array::*;
use datafusion::arrow::datatypes::{DataType, Field};
-
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use serde::Serialize;
@@ -293,52 +290,29 @@
]
}
- pub fn to_arrow_arrayref(self) -> Result<Arc<StructArray>, BallistaError> {
- let mut field_builders = Vec::new();
+ pub fn to_arrow_arrayref(&self) -> Result<Arc<StructArray>, BallistaError> {
+ let num_rows = Arc::new(UInt64Array::from(&[self.num_rows])) as ArrayRef;
+ let num_batches = Arc::new(UInt64Array::from(&[self.num_batches])) as ArrayRef;
+ let num_bytes = Arc::new(UInt64Array::from(&[self.num_bytes])) as ArrayRef;
+ let values = vec![num_rows, num_batches, num_bytes];
- let mut num_rows_builder = UInt64Builder::new(1);
- match self.num_rows {
- Some(n) => num_rows_builder.append_value(n)?,
- None => num_rows_builder.append_null()?,
- }
- field_builders.push(Box::new(num_rows_builder) as Box<dyn ArrayBuilder>);
-
- let mut num_batches_builder = UInt64Builder::new(1);
- match self.num_batches {
- Some(n) => num_batches_builder.append_value(n)?,
- None => num_batches_builder.append_null()?,
- }
- field_builders.push(Box::new(num_batches_builder) as Box<dyn ArrayBuilder>);
-
- let mut num_bytes_builder = UInt64Builder::new(1);
- match self.num_bytes {
- Some(n) => num_bytes_builder.append_value(n)?,
- None => num_bytes_builder.append_null()?,
- }
- field_builders.push(Box::new(num_bytes_builder) as Box<dyn ArrayBuilder>);
-
- let mut struct_builder =
- StructBuilder::new(self.arrow_struct_fields(), field_builders);
- struct_builder.append(true)?;
- Ok(Arc::new(struct_builder.finish()))
+ Ok(Arc::new(StructArray::from_data(
+ DataType::Struct(self.arrow_struct_fields()),
+ values,
+ None,
+ )))
}
pub fn from_arrow_struct_array(struct_array: &StructArray) -> PartitionStats {
- let num_rows = struct_array
- .column_by_name("num_rows")
- .expect("from_arrow_struct_array expected a field num_rows")
+ let num_rows = struct_array.values()[0]
.as_any()
.downcast_ref::<UInt64Array>()
.expect("from_arrow_struct_array expected num_rows to be a UInt64Array");
- let num_batches = struct_array
- .column_by_name("num_batches")
- .expect("from_arrow_struct_array expected a field num_batches")
+ let num_batches = struct_array.values()[1]
.as_any()
.downcast_ref::<UInt64Array>()
.expect("from_arrow_struct_array expected num_batches to be a UInt64Array");
- let num_bytes = struct_array
- .column_by_name("num_bytes")
- .expect("from_arrow_struct_array expected a field num_bytes")
+ let num_bytes = struct_array.values()[2]
.as_any()
.downcast_ref::<UInt64Array>()
.expect("from_arrow_struct_array expected num_bytes to be a UInt64Array");
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index 560d459..3bee937 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -30,18 +30,19 @@
use crate::config::BallistaConfig;
use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
+use arrow::chunk::Chunk;
use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::Result as ArrowResult;
-use datafusion::arrow::{
- datatypes::SchemaRef, ipc::writer::FileWriter, record_batch::RecordBatch,
-};
+use datafusion::arrow::io::ipc::write::FileWriter;
+use datafusion::arrow::io::ipc::write::WriteOptions;
use datafusion::error::DataFusionError;
use datafusion::execution::context::{
ExecutionConfig, ExecutionContext, ExecutionContextState, QueryPlanner,
};
+use datafusion::field_util::SchemaExt;
use datafusion::logical_plan::LogicalPlan;
-
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::common::batch_byte_size;
@@ -54,6 +55,7 @@
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream};
+use datafusion::record_batch::RecordBatch;
use futures::{Stream, StreamExt};
/// Stream data to disk in Arrow IPC format
@@ -63,7 +65,7 @@
path: &str,
disk_write_metric: &metrics::Time,
) -> Result<PartitionStats> {
- let file = File::create(&path).map_err(|e| {
+ let mut file = File::create(&path).map_err(|e| {
BallistaError::General(format!(
"Failed to create partition file at {}: {:?}",
path, e
@@ -73,7 +75,12 @@
let mut num_rows = 0;
let mut num_batches = 0;
let mut num_bytes = 0;
- let mut writer = FileWriter::try_new(file, stream.schema().as_ref())?;
+ let mut writer = FileWriter::try_new(
+ &mut file,
+ stream.schema().as_ref(),
+ None,
+ WriteOptions::default(),
+ )?;
while let Some(result) = stream.next().await {
let batch = result?;
@@ -84,7 +91,8 @@
num_bytes += batch_size_bytes;
let timer = disk_write_metric.timer();
- writer.write(&batch)?;
+ let chunk = Chunk::new(batch.columns().to_vec());
+ writer.write(&chunk, None)?;
timer.done();
}
let timer = disk_write_metric.timer();
diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index c45e57b..241af8e 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -29,8 +29,8 @@
snmalloc = ["snmalloc-rs"]
[dependencies]
-arrow = { version = "10.0" }
-arrow-flight = { version = "10.0" }
+arrow-format = { version = "0.4", features = ["flight-data", "flight-service"] }
+arrow = { package = "arrow2", version="0.10", features = ["io_ipc"] }
anyhow = "1"
async-trait = "0.1.41"
ballista-core = { path = "../core", version = "0.6.0" }
diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs
index 37a7f7b..72fa1ac 100644
--- a/ballista/rust/executor/src/collect.rs
+++ b/ballista/rust/executor/src/collect.rs
@@ -23,15 +23,14 @@
use std::{any::Any, pin::Pin};
use async_trait::async_trait;
-use datafusion::arrow::{
- datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
-};
+use datafusion::arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::{
DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
};
+use datafusion::record_batch::RecordBatch;
use datafusion::{error::Result, physical_plan::RecordBatchStream};
use futures::stream::SelectAll;
use futures::Stream;
diff --git a/ballista/rust/executor/src/flight_service.rs b/ballista/rust/executor/src/flight_service.rs
index cf5ab17..a936768 100644
--- a/ballista/rust/executor/src/flight_service.rs
+++ b/ballista/rust/executor/src/flight_service.rs
@@ -17,28 +17,28 @@
//! Implementation of the Apache Arrow Flight protocol that wraps an executor.
+use arrow::array::ArrayRef;
+use arrow::chunk::Chunk;
use std::fs::File;
use std::pin::Pin;
use std::sync::Arc;
use crate::executor::Executor;
-use arrow_flight::SchemaAsIpc;
use ballista_core::error::BallistaError;
use ballista_core::serde::decode_protobuf;
use ballista_core::serde::scheduler::Action as BallistaAction;
-use arrow_flight::{
- flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
- FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
- PutResult, SchemaResult, Ticket,
+use arrow::io::ipc::read::read_file_metadata;
+use arrow_format::flight::data::{
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+ HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
};
+use arrow_format::flight::service::flight_service_server::FlightService;
use datafusion::arrow::{
- error::ArrowError, ipc::reader::FileReader, ipc::writer::IpcWriteOptions,
- record_batch::RecordBatch,
+ error::ArrowError, io::ipc::read::FileReader, io::ipc::write::WriteOptions,
};
use futures::{Stream, StreamExt};
use log::{info, warn};
-use std::io::{Read, Seek};
use tokio::sync::mpsc::channel;
use tokio::{
sync::mpsc::{Receiver, Sender},
@@ -68,7 +68,7 @@
#[tonic::async_trait]
impl FlightService for BallistaFlightService {
- type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+ type DoActionStream = BoxedFlightStream<arrow_format::flight::data::Result>;
type DoExchangeStream = BoxedFlightStream<FlightData>;
type DoGetStream = BoxedFlightStream<FlightData>;
type DoPutStream = BoxedFlightStream<PutResult>;
@@ -88,22 +88,12 @@
match &action {
BallistaAction::FetchPartition { path, .. } => {
info!("FetchPartition reading {}", &path);
- let file = File::open(&path)
- .map_err(|e| {
- BallistaError::General(format!(
- "Failed to open partition file at {}: {:?}",
- path, e
- ))
- })
- .map_err(|e| from_ballista_err(&e))?;
- let reader = FileReader::try_new(file).map_err(|e| from_arrow_err(&e))?;
-
let (tx, rx): (FlightDataSender, FlightDataReceiver) = channel(2);
-
+ let path = path.clone();
// Arrow IPC reader does not implement Sync + Send so we need to use a channel
// to communicate
task::spawn(async move {
- if let Err(e) = stream_flight_data(reader, tx).await {
+ if let Err(e) = stream_flight_data(path, tx).await {
warn!("Error streaming results: {:?}", e);
}
});
@@ -186,11 +176,11 @@
/// Convert a single RecordBatch into an iterator of FlightData (containing
/// dictionaries and batches)
fn create_flight_iter(
- batch: &RecordBatch,
- options: &IpcWriteOptions,
+ chunk: &Chunk<ArrayRef>,
+ options: &WriteOptions,
) -> Box<dyn Iterator<Item = Result<FlightData, Status>>> {
let (flight_dictionaries, flight_batch) =
- arrow_flight::utils::flight_data_from_arrow_batch(batch, options);
+ arrow::io::flight::serialize_batch(chunk, &[], options);
Box::new(
flight_dictionaries
.into_iter()
@@ -199,21 +189,26 @@
)
}
-async fn stream_flight_data<T>(
- reader: FileReader<T>,
- tx: FlightDataSender,
-) -> Result<(), Status>
-where
- T: Read + Seek,
-{
- let options = arrow::ipc::writer::IpcWriteOptions::default();
- let schema_flight_data = SchemaAsIpc::new(reader.schema().as_ref(), &options).into();
+async fn stream_flight_data(path: String, tx: FlightDataSender) -> Result<(), Status> {
+ let mut file = File::open(&path)
+ .map_err(|e| {
+ BallistaError::General(format!(
+ "Failed to open partition file at {}: {:?}",
+ path, e
+ ))
+ })
+ .map_err(|e| from_ballista_err(&e))?;
+ let file_meta = read_file_metadata(&mut file).map_err(|e| from_arrow_err(&e))?;
+ let reader = FileReader::new(&mut file, file_meta, None);
+
+ let options = WriteOptions::default();
+ let schema_flight_data = arrow::io::flight::serialize_schema(reader.schema(), None);
send_response(&tx, Ok(schema_flight_data)).await?;
let mut row_count = 0;
for batch in reader {
if let Ok(x) = &batch {
- row_count += x.num_rows();
+ row_count += x.len();
}
let batch_flight_data: Vec<_> = batch
.map(|b| create_flight_iter(&b, &options).collect())
diff --git a/ballista/rust/executor/src/main.rs b/ballista/rust/executor/src/main.rs
index 6b270a2..37c7d2d 100644
--- a/ballista/rust/executor/src/main.rs
+++ b/ballista/rust/executor/src/main.rs
@@ -22,7 +22,7 @@
use std::time::Duration as Core_Duration;
use anyhow::{Context, Result};
-use arrow_flight::flight_service_server::FlightServiceServer;
+use arrow_format::flight::service::flight_service_server::FlightServiceServer;
use ballista_executor::{execution_loop, executor_server};
use log::{error, info};
use tempfile::TempDir;
diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs
index 0bc2503..dcbc2b2 100644
--- a/ballista/rust/executor/src/standalone.rs
+++ b/ballista/rust/executor/src/standalone.rs
@@ -17,8 +17,7 @@
use std::sync::Arc;
-use arrow_flight::flight_service_server::FlightServiceServer;
-
+use arrow_format::flight::service::flight_service_server::FlightServiceServer;
use ballista_core::serde::scheduler::ExecutorSpecification;
use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec};
use ballista_core::{
diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs
index b9d7ee4..18c4710 100644
--- a/ballista/rust/scheduler/src/test_utils.rs
+++ b/ballista/rust/scheduler/src/test_utils.rs
@@ -19,6 +19,7 @@
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
+use datafusion::field_util::SchemaExt;
use datafusion::prelude::CsvReadOptions;
pub const TPCH_TABLES: &[&str] = &[
diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore
index 6320cd2..1269488 100644
--- a/benchmarks/.gitignore
+++ b/benchmarks/.gitignore
@@ -1 +1 @@
-data
\ No newline at end of file
+data
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index 5f457ca..1b4c089 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -32,6 +32,7 @@
snmalloc = ["snmalloc-rs"]
[dependencies]
+arrow = { package = "arrow2", version="0.10", features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "io_print", "ahash", "compute_merge_sort", "compute", "regex"] }
datafusion = { path = "../datafusion" }
ballista = { path = "../ballista/rust/client" }
structopt = { version = "0.3", default-features = false }
diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs
index 49679f4..0da5f89 100644
--- a/benchmarks/src/bin/nyctaxi.rs
+++ b/benchmarks/src/bin/nyctaxi.rs
@@ -17,17 +17,20 @@
//! Apache Arrow Rust Benchmarks
+use arrow::array::ArrayRef;
+use arrow::chunk::Chunk;
use std::collections::HashMap;
use std::path::PathBuf;
use std::process;
use std::time::Instant;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
-use datafusion::arrow::util::pretty;
+use datafusion::arrow::io::print;
use datafusion::error::Result;
use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
+use datafusion::field_util::SchemaExt;
use datafusion::physical_plan::collect;
use datafusion::prelude::CsvReadOptions;
use structopt::StructOpt;
@@ -125,7 +128,12 @@
let physical_plan = ctx.create_physical_plan(&plan).await?;
let result = collect(physical_plan, runtime).await?;
if debug {
- pretty::print_batches(&result)?;
+ let fields = result
+ .first()
+ .map(|b| b.schema().field_names())
+ .unwrap_or_default();
+ let chunks: Vec<Chunk<ArrayRef>> = result.iter().map(|rb| rb.into()).collect();
+ println!("{}", print::write(&chunks, &fields));
}
Ok(())
}
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 1cc6687..7671c78 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -17,6 +17,8 @@
//! Benchmark derived from TPC-H. This is not an official TPC-H benchmark.
+use arrow::array::ArrayRef;
+use arrow::chunk::Chunk;
use futures::future::join_all;
use rand::prelude::*;
use std::ops::Div;
@@ -29,14 +31,15 @@
time::{Instant, SystemTime},
};
-use ballista::context::BallistaContext;
-use ballista::prelude::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS};
+use datafusion::arrow::io::print;
+use datafusion::datasource::{
+ listing::{ListingOptions, ListingTable},
+ object_store::local::LocalFileSystem,
+};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_plan::LogicalPlan;
-use datafusion::parquet::basic::Compression;
-use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::physical_plan::{collect, displayable};
use datafusion::prelude::*;
@@ -46,26 +49,25 @@
DATAFUSION_VERSION,
};
use datafusion::{
- arrow::record_batch::RecordBatch, datasource::file_format::parquet::ParquetFormat,
-};
-use datafusion::{
- arrow::util::pretty,
- datasource::{
- listing::{ListingOptions, ListingTable, ListingTableConfig},
- object_store::local::LocalFileSystem,
- },
+ datasource::file_format::parquet::ParquetFormat, record_batch::RecordBatch,
};
+use arrow::io::parquet::write::{Compression, Version, WriteOptions};
+use ballista::prelude::{
+ BallistaConfig, BallistaContext, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS,
+};
use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION;
use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION;
+use datafusion::datasource::listing::ListingTableConfig;
+use datafusion::field_util::SchemaExt;
use serde::Serialize;
use structopt::StructOpt;
-#[cfg(feature = "snmalloc")]
+// The allocator features are mutually exclusive: only one `#[global_allocator]`
+// may exist in a binary, so each static below is compiled only when its feature
+// is the sole one enabled. With both features on, neither static applies and
+// Rust's default system allocator is used.
+#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))]
#[global_allocator]
static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
-#[cfg(feature = "mimalloc")]
+#[cfg(all(feature = "mimalloc", not(feature = "snmalloc")))]
#[global_allocator]
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
@@ -379,7 +381,7 @@
);
benchmark_run.add_result(elapsed, row_count);
if opt.debug {
- pretty::print_batches(&batches)?;
+ println!("{}", datafusion::arrow_print::write(&batches));
}
}
@@ -493,7 +495,7 @@
&client_id, &i, query_id, elapsed
);
if opt.debug {
- pretty::print_batches(&batches).unwrap();
+ println!("{}", datafusion::arrow_print::write(&batches));
}
}
});
@@ -615,7 +617,12 @@
"=== Physical plan with metrics ===\n{}\n",
DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent()
);
- pretty::print_batches(&result)?;
+ let fields = result
+ .first()
+ .map(|b| b.schema().field_names())
+ .unwrap_or(vec![]);
+ let chunks: Vec<Chunk<ArrayRef>> = result.iter().map(|rb| rb.into()).collect();
+ println!("{}", print::write(&chunks, &fields));
}
Ok(result)
}
@@ -659,13 +666,13 @@
"csv" => ctx.write_csv(csv, output_path).await?,
"parquet" => {
let compression = match opt.compression.as_str() {
- "none" => Compression::UNCOMPRESSED,
- "snappy" => Compression::SNAPPY,
- "brotli" => Compression::BROTLI,
- "gzip" => Compression::GZIP,
- "lz4" => Compression::LZ4,
- "lz0" => Compression::LZO,
- "zstd" => Compression::ZSTD,
+ "none" => Compression::Uncompressed,
+ "snappy" => Compression::Snappy,
+ "brotli" => Compression::Brotli,
+ "gzip" => Compression::Gzip,
+ "lz4" => Compression::Lz4,
+ "lz0" => Compression::Lzo,
+ "zstd" => Compression::Zstd,
other => {
return Err(DataFusionError::NotImplemented(format!(
"Invalid compression format: {}",
@@ -673,10 +680,13 @@
)))
}
};
- let props = WriterProperties::builder()
- .set_compression(compression)
- .build();
- ctx.write_parquet(csv, output_path, Some(props)).await?
+
+ let options = WriteOptions {
+ compression,
+ write_statistics: false,
+ version: Version::V1,
+ };
+ ctx.write_parquet(csv, output_path, options).await?
}
other => {
return Err(DataFusionError::NotImplemented(format!(
@@ -893,8 +903,9 @@
use std::env;
use std::sync::Arc;
+ use arrow::array::get_display;
use datafusion::arrow::array::*;
- use datafusion::arrow::util::display::array_value_to_string;
+ use datafusion::field_util::FieldExt;
use datafusion::logical_plan::Expr;
use datafusion::logical_plan::Expr::Cast;
@@ -1069,7 +1080,7 @@
}
/// Specialised String representation
- fn col_str(column: &ArrayRef, row_index: usize) -> String {
+ fn col_str(column: &dyn Array, row_index: usize) -> String {
if column.is_null(row_index) {
return "NULL".to_string();
}
@@ -1084,12 +1095,13 @@
let mut r = Vec::with_capacity(*n as usize);
for i in 0..*n {
- r.push(col_str(&array, i as usize));
+ r.push(col_str(array.as_ref(), i as usize));
}
return format!("[{}]", r.join(","));
}
-
- array_value_to_string(column, row_index).unwrap()
+ let mut string = String::new();
+ get_display(column, "null")(&mut string, row_index).unwrap();
+ string
}
/// Converts the results into a 2d array of strings, `result[row][column]`
@@ -1101,7 +1113,7 @@
let row_vec = batch
.columns()
.iter()
- .map(|column| col_str(column, row_index))
+ .map(|column| col_str(column.as_ref(), row_index))
.collect();
result.push(row_vec);
}
@@ -1263,7 +1275,7 @@
// convert the schema to the same but with all columns set to nullable=true.
// this allows direct schema comparison ignoring nullable.
- fn nullable_schema(schema: Arc<Schema>) -> Schema {
+ fn nullable_schema(schema: &Schema) -> Schema {
Schema::new(
schema
.fields()
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 2302827..e546c5b 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -32,7 +32,7 @@
rustyline = "9.0"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
datafusion = { path = "../datafusion", version = "7.0.0" }
-arrow = { version = "10.0" }
+arrow = { package = "arrow2", version="0.10", features = ["io_print"] }
ballista = { path = "../ballista/rust/client", version = "0.6.0", optional=true }
env_logger = "0.9"
mimalloc = { version = "*", default-features = false }
diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs
index 0fd43a3..f6bedc2 100644
--- a/datafusion-cli/src/command.rs
+++ b/datafusion-cli/src/command.rs
@@ -22,14 +22,17 @@
use crate::print_format::PrintFormat;
use crate::print_options::PrintOptions;
use clap::ArgEnum;
-use datafusion::arrow::array::{ArrayRef, StringArray};
+use datafusion::arrow::array::{ArrayRef, Utf8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
-use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
+use datafusion::field_util::SchemaExt;
+use datafusion::record_batch::RecordBatch;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;
+type StringArray = Utf8Array<i32>;
+
/// Command
#[derive(Debug)]
pub enum Command {
@@ -147,7 +150,7 @@
schema,
[names, description]
.into_iter()
- .map(|i| Arc::new(StringArray::from(i)) as ArrayRef)
+ .map(|i| Arc::new(StringArray::from_slice(i)) as ArrayRef)
.collect::<Vec<_>>(),
)
.expect("This should not fail")
diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs
index 98b698a..224f990 100644
--- a/datafusion-cli/src/functions.rs
+++ b/datafusion-cli/src/functions.rs
@@ -16,15 +16,18 @@
// under the License.
//! Functions that are query-able and searchable via the `\h` command
-use arrow::array::StringArray;
+use arrow::array::{ArrayRef, Utf8Array};
+use arrow::chunk::Chunk;
use arrow::datatypes::{DataType, Field, Schema};
-use arrow::record_batch::RecordBatch;
-use arrow::util::pretty::pretty_format_batches;
+use datafusion::arrow::io::print;
use datafusion::error::Result;
+use datafusion::field_util::SchemaExt;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
+type StringArray = Utf8Array<i32>;
+
#[derive(Debug)]
pub enum Function {
Select,
@@ -185,14 +188,14 @@
pub fn display_all_functions() -> Result<()> {
     println!("Available help:");
-    let array = StringArray::from(
+    // One row per entry of ALL_FUNCTIONS, rendered through its `Display` impl.
+    let array: ArrayRef = Arc::new(StringArray::from_slice(
         ALL_FUNCTIONS
             .iter()
             .map(|f| format!("{}", f))
             .collect::<Vec<String>>(),
-    );
+    ));
     let schema = Schema::new(vec![Field::new("Function", DataType::Utf8, false)]);
-    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?;
-    println!("{}", pretty_format_batches(&[batch]).unwrap());
+    // The chunk is built from the arrays alone, so the column header is taken
+    // from `schema` and handed to `print::write` separately.
+    let batch = Chunk::try_new(vec![array])?;
+    println!("{}", print::write(&[batch], &schema.field_names()));
     Ok(())
}
diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs
index 05a1ef7..076c068 100644
--- a/datafusion-cli/src/print_format.rs
+++ b/datafusion-cli/src/print_format.rs
@@ -16,11 +16,13 @@
// under the License.
//! Print format variants
-use arrow::csv::writer::WriterBuilder;
-use arrow::json::{ArrayWriter, LineDelimitedWriter};
-use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::arrow::util::pretty;
+use arrow::io::csv::write::SerializeOptions;
+use arrow::io::ndjson::write::FallibleStreamingIterator;
+use datafusion::arrow::io::csv::write;
use datafusion::error::{DataFusionError, Result};
+use datafusion::field_util::SchemaExt;
+use datafusion::record_batch::RecordBatch;
+use std::io::Write;
use std::str::FromStr;
/// Allow records to be printed in different formats
@@ -41,27 +43,69 @@
}
}
-macro_rules! batches_to_json {
- ($WRITER: ident, $batches: expr) => {{
- let mut bytes = vec![];
- {
- let mut writer = $WRITER::new(&mut bytes);
- writer.write_batches($batches)?;
- writer.finish()?;
+/// Renders `batches` as a single JSON array, or `"{}"` when there are no batches.
+fn print_batches_to_json(batches: &[RecordBatch]) -> Result<String> {
+    use arrow::io::json::write as json_write;
+
+    if batches.is_empty() {
+        return Ok("{}".to_string());
+    }
+
+    // One serializer and one `write` call for all batches: `write` wraps its
+    // output in `[`/`]`, so calling it once per batch would yield `[..][..]`,
+    // which is not valid JSON as soon as there is more than one batch.
+    let arrays = batches.iter().flat_map(|batch| batch.columns().iter());
+    let blocks = json_write::Serializer::new(arrays.map(Ok), vec![]);
+    let mut bytes = vec![];
+    json_write::write(&mut bytes, blocks)?;
+
+    let formatted = String::from_utf8(bytes)
+        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+    Ok(formatted)
+}
+
+/// Renders `batches` as newline-delimited JSON by appending every serialized
+/// block of every batch, in order, to one buffer; returns `"{}"` when there are
+/// no batches.
+///
+/// # Errors
+/// Fails if serialization or writing fails, or if the output is not valid UTF-8.
+fn print_batches_to_ndjson(batches: &[RecordBatch]) -> Result<String> {
+    use arrow::io::ndjson::write as json_write;
+
+    if batches.is_empty() {
+        return Ok("{}".to_string());
+    }
+    let mut bytes = vec![];
+    for batch in batches {
+        // Each batch's columns feed one serializer; its blocks are appended
+        // to `bytes` as-is, with no extra separators.
+        let mut blocks = json_write::Serializer::new(
+            batch.columns().into_iter().map(|r| Ok(r)),
+            vec![],
+        );
+        while let Some(block) = blocks.next()? {
+            bytes.write_all(block)?;
     }
-        String::from_utf8(bytes).map_err(|e| DataFusionError::Execution(e.to_string()))?
-    }};
+    }
+    let formatted = String::from_utf8(bytes)
+        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+    Ok(formatted)
}
+/// Renders `batches` as delimiter-separated text: one header line taken from the
+/// first batch's schema, then the rows of every batch, all separated by `delimiter`.
fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result<String> {
     let mut bytes = vec![];
     {
-        let builder = WriterBuilder::new()
-            .has_headers(true)
-            .with_delimiter(delimiter);
-        let mut writer = builder.build(&mut bytes);
+        // Header and rows share one set of options so the chosen delimiter
+        // (e.g. a tab for TSV) applies to every line, not only the header.
+        let options = SerializeOptions {
+            delimiter,
+            ..SerializeOptions::default()
+        };
+        // The header is written once, from the first batch's schema.
+        if let Some(first) = batches.first() {
+            write::write_header(&mut bytes, &first.schema().field_names(), &options)?;
+        }
         for batch in batches {
-            writer.write(batch)?;
+            write::write_chunk(&mut bytes, &batch.into(), &options)?;
         }
     }
     let formatted = String::from_utf8(bytes)
@@ -75,10 +119,12 @@
match self {
Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?),
Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?),
- Self::Table => pretty::print_batches(batches)?,
- Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)),
+ Self::Table => println!("{}", datafusion::arrow_print::write(batches)),
+ Self::Json => {
+ println!("{}", print_batches_to_json(batches)?)
+ }
Self::NdJson => {
- println!("{}", batches_to_json!(LineDelimitedWriter, batches))
+ println!("{}", print_batches_to_ndjson(batches)?)
}
}
Ok(())
@@ -88,9 +134,8 @@
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::Int32Array;
- use arrow::datatypes::{DataType, Field, Schema};
- use datafusion::from_slice::FromSlice;
+ use datafusion::arrow::array::Int32Array;
+ use datafusion::arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
#[test]
@@ -122,11 +167,11 @@
#[test]
fn test_print_batches_to_json_empty() -> Result<()> {
let batches = vec![];
- let r = batches_to_json!(ArrayWriter, &batches);
- assert_eq!("", r);
+ let r = print_batches_to_json(&batches)?;
+ assert_eq!("{}", r);
- let r = batches_to_json!(LineDelimitedWriter, &batches);
- assert_eq!("", r);
+ let r = print_batches_to_ndjson(&batches)?;
+ assert_eq!("{}", r);
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
@@ -145,10 +190,10 @@
.unwrap();
let batches = vec![batch];
- let r = batches_to_json!(ArrayWriter, &batches);
+ let r = print_batches_to_json(&batches)?;
assert_eq!("[{\"a\":1,\"b\":4,\"c\":7},{\"a\":2,\"b\":5,\"c\":8},{\"a\":3,\"b\":6,\"c\":9}]", r);
- let r = batches_to_json!(LineDelimitedWriter, &batches);
+ let r = print_batches_to_ndjson(&batches)?;
assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r);
Ok(())
}
diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs
index 5e37926..bebd498 100644
--- a/datafusion-cli/src/print_options.rs
+++ b/datafusion-cli/src/print_options.rs
@@ -16,8 +16,8 @@
// under the License.
use crate::print_format::PrintFormat;
-use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
+use datafusion::record_batch::RecordBatch;
use std::time::Instant;
#[derive(Debug, Clone)]
diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml
index 111bb26..069fa7e 100644
--- a/datafusion-common/Cargo.toml
+++ b/datafusion-common/Cargo.toml
@@ -33,14 +33,12 @@
path = "src/lib.rs"
[features]
-avro = ["avro-rs"]
pyarrow = ["pyo3"]
jit = ["cranelift-module"]
[dependencies]
-arrow = { version = "10.0", features = ["prettyprint"] }
-parquet = { version = "10.0", features = ["arrow"], optional = true }
-avro-rs = { version = "0.13", features = ["snappy"], optional = true }
+arrow = { package = "arrow2", version = "0.10", default-features = false }
+parquet = { package = "parquet2", version = "0.10", default_features = false, features = ["stream"], optional = true }
pyo3 = { version = "0.16", optional = true }
sqlparser = "0.15"
ordered-float = "2.10"
diff --git a/datafusion-common/src/dfschema.rs b/datafusion-common/src/dfschema.rs
index 6a3dcb0..5b5d2de 100644
--- a/datafusion-common/src/dfschema.rs
+++ b/datafusion-common/src/dfschema.rs
@@ -22,12 +22,29 @@
use std::convert::TryFrom;
use std::sync::Arc;
-use crate::error::{DataFusionError, Result};
use crate::Column;
+use crate::{DataFusionError, Result};
+use crate::field_util::{FieldExt, SchemaExt};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use std::fmt::{Display, Formatter};
+/// Metadata attached to a [`DFSchema`], as string key/value pairs.
+pub type DFMetadata = HashMap<String, String>;
+
+/// Converts a string-to-string metadata map into another map type, e.g. between
+/// arrow2's `BTreeMap`-based schema metadata and [`DFMetadata`].
+///
+/// `metadata` is cloned and the clone's (already owned) entries are collected
+/// into `M2`; the input itself is left untouched.
+pub fn convert_metadata<M1, M2>(metadata: &M1) -> M2
+where
+    M1: Clone + IntoIterator<Item = (String, String)>,
+    M2: FromIterator<(String, String)>,
+{
+    metadata.clone().into_iter().collect()
+}
+
/// A reference-counted reference to a `DFSchema`.
pub type DFSchemaRef = Arc<DFSchema>;
@@ -37,7 +54,7 @@
/// Fields
fields: Vec<DFField>,
/// Additional metadata in form of key value pairs
- metadata: HashMap<String, String>,
+ metadata: DFMetadata,
}
impl DFSchema {
@@ -45,14 +62,14 @@
pub fn empty() -> Self {
Self {
fields: vec![],
- metadata: HashMap::new(),
+ metadata: DFMetadata::new(),
}
}
#[deprecated(since = "7.0.0", note = "please use `new_with_metadata` instead")]
/// Create a new `DFSchema`
pub fn new(fields: Vec<DFField>) -> Result<Self> {
- Self::new_with_metadata(fields, HashMap::new())
+ Self::new_with_metadata(fields, DFMetadata::new())
}
/// Create a new `DFSchema`
@@ -84,8 +101,8 @@
// deterministic
let mut qualified_names = qualified_names
.iter()
- .map(|(l, r)| (l.to_owned(), r.to_owned()))
- .collect::<Vec<(&String, &String)>>();
+ .map(|(l, r)| (l.as_str(), r.to_owned()))
+ .collect::<Vec<(&str, &str)>>();
qualified_names.sort_by(|a, b| {
let a = format!("{}.{}", a.0, a.1);
let b = format!("{}.{}", b.0, b.1);
@@ -111,7 +128,7 @@
.iter()
.map(|f| DFField::from_qualified(qualifier, f.clone()))
.collect(),
- schema.metadata().clone(),
+ convert_metadata(schema.metadata()),
)
}
@@ -331,17 +348,13 @@
.into_iter()
.map(|f| {
if f.qualifier().is_some() {
- Field::new(
- f.name().as_str(),
- f.data_type().to_owned(),
- f.is_nullable(),
- )
+ Field::new(f.name(), f.data_type().to_owned(), f.is_nullable())
} else {
f.field
}
})
.collect(),
- df_schema.metadata,
+ convert_metadata(&df_schema.metadata),
)
}
}
@@ -351,7 +364,7 @@
fn from(df_schema: &DFSchema) -> Self {
Schema::new_with_metadata(
df_schema.fields.iter().map(|f| f.field.clone()).collect(),
- df_schema.metadata.clone(),
+ convert_metadata(&df_schema.metadata),
)
}
}
@@ -366,7 +379,7 @@
.iter()
.map(|f| DFField::from(f.clone()))
.collect(),
- schema.metadata().clone(),
+ convert_metadata(schema.metadata()),
)
}
}
@@ -414,7 +427,7 @@
impl ToDFSchema for Vec<DFField> {
fn to_dfschema(self) -> Result<DFSchema> {
- DFSchema::new_with_metadata(self, HashMap::new())
+ DFSchema::new_with_metadata(self, DFMetadata::new())
}
}
@@ -507,7 +520,7 @@
}
/// Returns an immutable reference to the `DFField`'s unqualified name
- pub fn name(&self) -> &String {
+ pub fn name(&self) -> &str {
self.field.name()
}
@@ -602,9 +615,10 @@
fn from_qualified_schema_into_arrow_schema() -> Result<()> {
let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
let arrow_schema: Schema = schema.into();
- let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
- Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }";
- assert_eq!(expected, arrow_schema.to_string());
+ let expected =
+ "[Field { name: \"c0\", data_type: Boolean, is_nullable: true, metadata: {} }, \
+ Field { name: \"c1\", data_type: Boolean, is_nullable: true, metadata: {} }]";
+ assert_eq!(expected, format!("{:?}", arrow_schema.fields));
Ok(())
}
@@ -718,7 +732,7 @@
let metadata = test_metadata();
let arrow_schema = Schema::new_with_metadata(
vec![Field::new("c0", DataType::Int64, true)],
- metadata.clone(),
+ convert_metadata(&metadata),
);
let arrow_schema_ref = Arc::new(arrow_schema.clone());
diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs
index 4a82ac3..5aa63c1 100644
--- a/datafusion-common/src/error.rs
+++ b/datafusion-common/src/error.rs
@@ -28,7 +28,7 @@
#[cfg(feature = "jit")]
use cranelift_module::ModuleError;
#[cfg(feature = "parquet")]
-use parquet::errors::ParquetError;
+use parquet::error::ParquetError;
use sqlparser::parser::ParserError;
/// Result type for operations that could result in an [DataFusionError]
@@ -94,8 +94,8 @@
fn from(e: DataFusionError) -> Self {
match e {
DataFusionError::ArrowError(e) => e,
- DataFusionError::External(e) => ArrowError::ExternalError(e),
- other => ArrowError::ExternalError(Box::new(other)),
+ DataFusionError::External(e) => ArrowError::External(String::new(), e),
+ other => ArrowError::External(String::new(), Box::new(other)),
}
}
}
@@ -212,7 +212,9 @@
#[allow(clippy::try_err)]
fn return_datafusion_error() -> crate::error::Result<()> {
// Expect the '?' to work
- let _bar = Err(ArrowError::SchemaError("bar".to_string()))?;
+ let _bar = Err(ArrowError::InvalidArgumentError(
+ "bad schema bar".to_string(),
+ ))?;
Ok(())
}
}
diff --git a/datafusion-common/src/field_util.rs b/datafusion-common/src/field_util.rs
new file mode 100644
index 0000000..639e484
--- /dev/null
+++ b/datafusion-common/src/field_util.rs
@@ -0,0 +1,490 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utility functions for complex field access
+
+use arrow::array::{ArrayRef, StructArray};
+use arrow::datatypes::{DataType, Field, Metadata, Schema};
+use arrow::error::ArrowError;
+use std::borrow::Borrow;
+use std::collections::BTreeMap;
+
+use crate::ScalarValue;
+use crate::{DataFusionError, Result};
+
+/// Resolves the [`Field`] obtained by indexing a value of `data_type` with `key`.
+///
+/// * [`DataType::List`] takes a non-negative `Int64` index; the result is a
+///   non-nullable field named after the index and typed like the list items.
+/// * [`DataType::Struct`] takes a non-empty `Utf8` key naming one of its fields.
+/// # Error
+/// Returns [`DataFusionError::Plan`] for any other data type, a key of the wrong
+/// kind, a negative index, an empty key, or a struct field name that does not exist.
+pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Field> {
+    let plan_error = |message: String| Err(DataFusionError::Plan(message));
+    match (data_type, key) {
+        (DataType::List(_), ScalarValue::Int64(Some(index))) if *index < 0 => {
+            plan_error(format!(
+                "List based indexed access requires a positive int, was {0}",
+                index
+            ))
+        }
+        (DataType::List(item), ScalarValue::Int64(Some(index))) => {
+            Ok(Field::new(&index.to_string(), item.data_type().clone(), false))
+        }
+        (DataType::Struct(_), ScalarValue::Utf8(Some(name))) if name.is_empty() => {
+            plan_error(
+                "Struct based indexed access requires a non empty string".to_string(),
+            )
+        }
+        (DataType::Struct(children), ScalarValue::Utf8(Some(name))) => children
+            .iter()
+            .find(|child| child.name() == name)
+            .cloned()
+            .map_or_else(
+                || plan_error(format!("Field {} not found in struct", name)),
+                Ok,
+            ),
+        (DataType::Struct(_), _) => plan_error(
+            "Only utf8 strings are valid as an indexed field in a struct".to_string(),
+        ),
+        (DataType::List(_), _) => plan_error(
+            "Only ints are valid as an indexed field in a list".to_string(),
+        ),
+        _ => plan_error(
+            "The expression to get an indexed field is only valid for `List` types"
+                .to_string(),
+        ),
+    }
+}
+
+/// arrow-rs style accessors for arrow2's [`StructArray`].
+pub trait StructArrayExt {
+    /// Names of the child fields, in field order.
+    fn column_names(&self) -> Vec<&str>;
+    /// The first child array whose field is called `column_name`, if any.
+    fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef>;
+    /// How many child fields (columns) the struct has.
+    fn num_columns(&self) -> usize;
+    /// A clone of the child array handle at `pos`; panics if `pos` is out of range.
+    fn column(&self, pos: usize) -> ArrayRef;
+}
+
+impl StructArrayExt for StructArray {
+    fn column_names(&self) -> Vec<&str> {
+        let fields = self.fields();
+        fields.iter().map(|field| field.name.as_str()).collect()
+    }
+
+    fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> {
+        let index = self.fields().iter().position(|f| f.name() == column_name)?;
+        Some(self.values()[index].borrow())
+    }
+
+    fn num_columns(&self) -> usize {
+        self.fields().len()
+    }
+
+    fn column(&self, pos: usize) -> ArrayRef {
+        let values = self.values();
+        values[pos].clone()
+    }
+}
+
+/// Converts a list of field / array pairs into a [`StructArray`] without a validity
+/// bitmap. The pairs are consumed, so fields and arrays are moved rather than cloned.
+pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray {
+    let (fields, values): (Vec<Field>, Vec<ArrayRef>) = pairs.into_iter().unzip();
+    StructArray::from_data(DataType::Struct(fields), values, None)
+}
+
+/// arrow-rs style constructors and accessors for arrow2's [`Schema`].
+pub trait SchemaExt {
+    /// Creates a new [`Schema`] from a sequence of [`Field`] values.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use arrow::datatypes::{Field, DataType, Schema};
+    /// use datafusion_common::field_util::SchemaExt;
+    /// let field_a = Field::new("a", DataType::Int64, false);
+    /// let field_b = Field::new("b", DataType::Boolean, false);
+    ///
+    /// let schema = Schema::new(vec![field_a, field_b]);
+    /// ```
+    fn new(fields: Vec<Field>) -> Self;
+
+    /// Creates a new [`Schema`] from a sequence of [`Field`] values and [`arrow::datatypes::Metadata`]
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::collections::BTreeMap;
+    /// use arrow::datatypes::{Field, DataType, Schema};
+    /// use datafusion_common::field_util::SchemaExt;
+    ///
+    /// let field_a = Field::new("a", DataType::Int64, false);
+    /// let field_b = Field::new("b", DataType::Boolean, false);
+    ///
+    /// let schema_metadata: BTreeMap<String, String> =
+    ///     vec![("baz".to_string(), "barf".to_string())]
+    ///         .into_iter()
+    ///         .collect();
+    /// let schema = Schema::new_with_metadata(vec![field_a, field_b], schema_metadata);
+    /// ```
+    fn new_with_metadata(fields: Vec<Field>, metadata: Metadata) -> Self;
+
+    /// Creates an empty [`Schema`].
+    fn empty() -> Self;
+
+    /// Looks up the first field named `name` and returns its index together with an
+    /// immutable reference to it, or `None` if no field has that name.
+    fn column_with_name(&self, name: &str) -> Option<(usize, &Field)>;
+
+    /// Returns the first [`Field`] named `name`, or an error if there is none.
+    fn field_with_name(&self, name: &str) -> Result<&Field>;
+
+    /// Returns the index of the first field named `name`, or an error if there is none.
+    fn index_of(&self, name: &str) -> Result<usize>;
+
+    /// Returns the [`Field`] at position `index`.
+    /// # Panics
+    /// Panics if `index` is not less than the number of fields in this [`Schema`].
+    fn field(&self, index: usize) -> &Field;
+
+    /// Returns all [`Field`]s in this schema.
+    fn fields(&self) -> &[Field];
+
+    /// Returns an immutable reference to the Map of custom metadata key-value pairs.
+    fn metadata(&self) -> &BTreeMap<String, String>;
+
+    /// Merges `schemas` into one if they are compatible; struct fields merge recursively.
+    ///
+    /// Example:
+    ///
+    /// ```
+    /// use arrow::datatypes::*;
+    /// use datafusion_common::field_util::SchemaExt;
+    ///
+    /// let merged = Schema::try_merge(vec![
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, false),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///     ]),
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, true),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///         Field::new("c3", DataType::Utf8, false),
+    ///     ]),
+    /// ]).unwrap();
+    ///
+    /// assert_eq!(
+    ///     merged,
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, true),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///         Field::new("c3", DataType::Utf8, false),
+    ///     ]),
+    /// );
+    /// ```
+    fn try_merge(schemas: impl IntoIterator<Item = Self>) -> Result<Self>
+    where
+        Self: Sized;
+
+    /// Returns the names of all fields, in order.
+    fn field_names(&self) -> Vec<String>;
+
+    /// Returns a new schema with only the fields at `indices`, in that order, keeping
+    /// this schema's metadata. Errors if any index is out of bounds.
+    fn project(&self, indices: &[usize]) -> Result<Schema>;
+}
+
+impl SchemaExt for Schema {
+    fn new(fields: Vec<Field>) -> Self {
+        Self::from(fields)
+    }
+
+    fn new_with_metadata(fields: Vec<Field>, metadata: Metadata) -> Self {
+        Self::new(fields).with_metadata(metadata)
+    }
+
+    fn empty() -> Self {
+        Self::from(vec![])
+    }
+
+    fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> {
+        self.fields.iter().enumerate().find(|(_, f)| f.name == name) // first match wins
+    }
+
+    fn field_with_name(&self, name: &str) -> Result<&Field> {
+        Ok(&self.fields[self.index_of(name)?])
+    }
+
+    fn index_of(&self, name: &str) -> Result<usize> {
+        self.column_with_name(name).map(|(i, _f)| i).ok_or_else(|| {
+            DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!(
+                "Unable to get field named \"{}\". Valid fields: {:?}",
+                name,
+                self.field_names()
+            )))
+        })
+    }
+
+    fn field(&self, index: usize) -> &Field {
+        &self.fields[index]
+    }
+
+    #[inline]
+    fn fields(&self) -> &[Field] {
+        &self.fields
+    }
+
+    #[inline]
+    fn metadata(&self) -> &BTreeMap<String, String> {
+        &self.metadata
+    }
+
+    fn try_merge(schemas: impl IntoIterator<Item = Self>) -> Result<Self> {
+        schemas
+            .into_iter()
+            .try_fold(Self::empty(), |mut merged, schema| { // stops at the first error
+                let Schema { metadata, fields } = schema;
+                for (key, value) in metadata.into_iter() {
+                    // merge metadata; a repeated key must have the same value
+                    if let Some(old_val) = merged.metadata.get(&key) {
+                        if old_val != &value {
+                            return Err(DataFusionError::ArrowError(
+                                ArrowError::InvalidArgumentError(
+                                    "Fail to merge schema due to conflicting metadata."
+                                        .to_string(),
+                                ),
+                            ));
+                        }
+                    }
+                    merged.metadata.insert(key, value);
+                }
+                // merge fields; same-named fields merge in place via try_merge
+                for field in fields.into_iter() {
+                    let mut new_field = true;
+                    for merged_field in &mut merged.fields {
+                        if field.name() != merged_field.name() {
+                            continue;
+                        }
+                        new_field = false;
+                        merged_field.try_merge(&field)?
+                    }
+                    // new name: append it, so fields keep first-seen order
+                    if new_field {
+                        merged.fields.push(field);
+                    }
+                }
+                Ok(merged)
+            })
+    }
+
+    fn field_names(&self) -> Vec<String> {
+        self.fields.iter().map(|f| f.name.to_string()).collect()
+    }
+
+    fn project(&self, indices: &[usize]) -> Result<Schema> {
+        let new_fields = indices
+            .iter()
+            .map(|i| {
+                self.fields.get(*i).cloned().ok_or_else(|| {
+                    DataFusionError::ArrowError(ArrowError::InvalidArgumentError(
+                        format!(
+                            "project index {} out of bounds, max field {}",
+                            i,
+                            self.fields().len()
+                        ),
+                    ))
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) // metadata kept whole
+    }
+}
+
+/// arrow-rs style accessors and merging for arrow2's [`Field`].
+pub trait FieldExt {
+    /// The field name
+    fn name(&self) -> &str;
+
+    /// Whether the field is nullable
+    fn is_nullable(&self) -> bool;
+
+    /// Returns the field metadata
+    fn metadata(&self) -> &BTreeMap<String, String>;
+
+    /// Merges `from` into `self` if compatible; struct children merge recursively.
+    /// NOTE: on failure `self` may already be partially updated (e.g. its metadata).
+    ///
+    /// Example:
+    ///
+    /// ```
+    /// use arrow::datatypes::*;
+    /// use datafusion_common::field_util::FieldExt;
+    /// let mut field = Field::new("c1", DataType::Int64, false);
+    /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok());
+    /// assert!(field.is_nullable());
+    /// ```
+    fn try_merge(&mut self, from: &Field) -> Result<()>;
+
+    /// Sets the `Field`'s optional custom metadata.
+    /// The metadata is set as `None` for empty map.
+    fn set_metadata(&mut self, metadata: Option<BTreeMap<String, String>>);
+}
+
+impl FieldExt for Field {
+    #[inline]
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    #[inline]
+    fn is_nullable(&self) -> bool {
+        self.is_nullable
+    }
+
+    #[inline]
+    fn metadata(&self) -> &BTreeMap<String, String> {
+        &self.metadata
+    }
+
+    fn try_merge(&mut self, from: &Field) -> Result<()> {
+        // merge metadata first; a key present on both sides must agree
+        for (key, from_value) in from.metadata() {
+            if let Some(self_value) = self.metadata.get(key) {
+                if self_value != from_value {
+                    return Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!(
+                        "Fail to merge field due to conflicting metadata data value for key {}",
+                        key
+                    ))));
+                }
+            } else {
+                self.metadata.insert(key.clone(), from_value.clone());
+            }
+        }
+
+        match &mut self.data_type {
+            DataType::Struct(nested_fields) => match &from.data_type {
+                DataType::Struct(from_nested_fields) => {
+                    for from_field in from_nested_fields {
+                        let mut is_new_field = true;
+                        for self_field in nested_fields.iter_mut() {
+                            if self_field.name != from_field.name {
+                                continue;
+                            }
+                            is_new_field = false;
+                            self_field.try_merge(from_field)?;
+                        }
+                        if is_new_field {
+                            nested_fields.push(from_field.clone());
+                        }
+                    }
+                }
+                _ => {
+                    return Err(DataFusionError::ArrowError(
+                        ArrowError::InvalidArgumentError(
+                            "Fail to merge schema Field due to conflicting datatype"
+                                .to_string(),
+                        ),
+                    ));
+                }
+            },
+            DataType::Union(nested_fields, _, _) => match &from.data_type {
+                DataType::Union(from_nested_fields, _, _) => {
+                    for from_field in from_nested_fields {
+                        let mut is_new_field = true;
+                        for self_field in nested_fields.iter_mut() {
+                            if from_field == self_field {
+                                is_new_field = false;
+                                break;
+                            }
+                        }
+                        if is_new_field {
+                            nested_fields.push(from_field.clone());
+                        }
+                    }
+                }
+                _ => {
+                    return Err(DataFusionError::ArrowError(
+                        ArrowError::InvalidArgumentError(
+                            "Fail to merge schema Field due to conflicting datatype"
+                                .to_string(),
+                        ),
+                    ));
+                }
+            },
+            DataType::Null
+            | DataType::Boolean
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Float16
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Timestamp(_, _)
+            | DataType::Date32
+            | DataType::Date64
+            | DataType::Time32(_)
+            | DataType::Time64(_)
+            | DataType::Duration(_)
+            | DataType::Binary
+            | DataType::LargeBinary
+            | DataType::Interval(_)
+            | DataType::LargeList(_)
+            | DataType::List(_)
+            | DataType::Dictionary(_, _, _)
+            | DataType::FixedSizeList(_, _)
+            | DataType::FixedSizeBinary(_)
+            | DataType::Utf8
+            | DataType::LargeUtf8
+            | DataType::Extension(_, _, _)
+            | DataType::Map(_, _)
+            | DataType::Decimal(_, _) => {
+                if self.data_type != from.data_type {
+                    return Err(DataFusionError::ArrowError(
+                        ArrowError::InvalidArgumentError(
+                            "Fail to merge schema Field due to conflicting datatype"
+                                .to_string(),
+                        ),
+                    ));
+                }
+            }
+        }
+        if from.is_nullable { // nullability can only widen, never narrow
+            self.is_nullable = from.is_nullable;
+        }
+
+        Ok(())
+    }
+
+    #[inline]
+    fn set_metadata(&mut self, metadata: Option<BTreeMap<String, String>>) {
+        // arrow2 stores metadata as a plain map, so "no metadata" is an empty
+        // map. Following arrow-rs, both `None` and an empty map clear whatever
+        // metadata was set before rather than leaving it in place; a non-empty
+        // map replaces it wholesale.
+        self.metadata = metadata.unwrap_or_default();
+    }
+}
diff --git a/datafusion-common/src/lib.rs b/datafusion-common/src/lib.rs
index d39020f..6b8c075 100644
--- a/datafusion-common/src/lib.rs
+++ b/datafusion-common/src/lib.rs
@@ -18,11 +18,15 @@
mod column;
mod dfschema;
mod error;
+pub mod field_util;
#[cfg(feature = "pyarrow")]
mod pyarrow;
+pub mod record_batch;
mod scalar;
pub use column::Column;
-pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
+pub use dfschema::{
+ convert_metadata, DFField, DFMetadata, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema,
+};
pub use error::{DataFusionError, Result};
-pub use scalar::{ScalarType, ScalarValue};
+pub use scalar::{ScalarValue, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
diff --git a/datafusion-common/src/record_batch.rs b/datafusion-common/src/record_batch.rs
new file mode 100644
index 0000000..a1fa310
--- /dev/null
+++ b/datafusion-common/src/record_batch.rs
@@ -0,0 +1,473 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Contains [`RecordBatch`].
+use std::sync::Arc;
+
+use crate::field_util::SchemaExt;
+use arrow::array::*;
+use arrow::chunk::Chunk;
+use arrow::compute::filter::{build_filter, filter};
+use arrow::datatypes::*;
+use arrow::error::{ArrowError, Result};
+
+/// A two-dimensional dataset with a number of
+/// columns ([`Array`]) and rows and defined [`Schema`].
+/// # Implementation
+/// Cloning is `O(C)` where `C` is the number of columns.
+#[derive(Clone, Debug, PartialEq)]
+pub struct RecordBatch {
+    schema: Arc<Schema>,
+    columns: Vec<Arc<dyn Array>>,
+}
+
+impl RecordBatch {
+    /// Creates a [`RecordBatch`] from a schema and columns.
+    /// # Errors
+    /// This function errors iff
+    /// * `columns` is empty
+    /// * the number of columns differs from the number of fields in the schema
+    /// * the schema and column data types do not match
+    /// * `columns` have a different length
+    /// # Example
+    ///
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::array::PrimitiveArray;
+    /// # use arrow::datatypes::{Schema, Field, DataType};
+    /// # use datafusion_common::record_batch::RecordBatch;
+    /// # use datafusion_common::field_util::SchemaExt;
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]);
+    /// let schema = Arc::new(Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]));
+    ///
+    /// let batch = RecordBatch::try_new(
+    ///     schema,
+    ///     vec![Arc::new(id_array)]
+    /// )?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn try_new(schema: Arc<Schema>, columns: Vec<Arc<dyn Array>>) -> Result<Self> {
+        let options = RecordBatchOptions::default();
+        Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
+        Ok(RecordBatch { schema, columns })
+    }
+
+    /// Creates a [`RecordBatch`] from a schema and columns, with additional options,
+    /// such as whether to strictly validate field names.
+    ///
+    /// See [`Self::try_new()`] for the expected conditions.
+    pub fn try_new_with_options(
+        schema: Arc<Schema>,
+        columns: Vec<Arc<dyn Array>>,
+        options: &RecordBatchOptions,
+    ) -> Result<Self> {
+        Self::validate_new_batch(&schema, &columns, options)?;
+        Ok(RecordBatch { schema, columns })
+    }
+
+    /// Creates a new empty [`RecordBatch`].
+    pub fn new_empty(schema: Arc<Schema>) -> Self {
+        let columns = schema
+            .fields()
+            .iter()
+            .map(|field| new_empty_array(field.data_type().clone()).into())
+            .collect();
+        RecordBatch { schema, columns }
+    }
+
+    /// Creates a new [`RecordBatch`] from a [`arrow::chunk::Chunk`].
+    ///
+    /// Unlike [`RecordBatch::try_new`], `chunk` is not validated against `schema`.
+    pub fn new_with_chunk(schema: &Arc<Schema>, chunk: Chunk<ArrayRef>) -> Self {
+        Self {
+            schema: schema.clone(),
+            columns: chunk.into_arrays(),
+        }
+    }
+
+    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
+    /// if any validation check fails.
+    ///
+    /// When [`RecordBatchOptions::match_field_names`] is `false`, the names of nested
+    /// list, map and struct fields are ignored when comparing a column's data type with
+    /// the schema; their data types and nullability must still match.
+    fn validate_new_batch(
+        schema: &Schema,
+        columns: &[Arc<dyn Array>],
+        options: &RecordBatchOptions,
+    ) -> Result<()> {
+        // check that there are some columns
+        if columns.is_empty() {
+            return Err(ArrowError::InvalidArgumentError(
+                "at least one column must be defined to create a record batch"
+                    .to_string(),
+            ));
+        }
+        // check that number of fields in schema match column length
+        if schema.fields().len() != columns.len() {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "number of columns({}) must match number of fields({}) in schema",
+                columns.len(),
+                schema.fields().len(),
+            )));
+        }
+        // check that all columns have the same row count, and match the schema
+        let len = columns[0].len();
+        for (i, column) in columns.iter().enumerate() {
+            if column.len() != len {
+                return Err(ArrowError::InvalidArgumentError(
+                    "all columns in a record batch must have the same length"
+                        .to_string(),
+                ));
+            }
+            let expected = schema.field(i).data_type();
+            let types_match = if options.match_field_names {
+                column.data_type() == expected
+            } else {
+                Self::equal_ignoring_field_names(column.data_type(), expected)
+            };
+            if !types_match {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "column types must match schema types, expected {:?} but found {:?} at column index {}",
+                    expected,
+                    column.data_type(),
+                    i)));
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Compares two data types recursively, ignoring the names of nested list, map
+    /// and struct fields (their data types and nullability must still be equal).
+    fn equal_ignoring_field_names(a: &DataType, b: &DataType) -> bool {
+        fn fields_equal(x: &Field, y: &Field) -> bool {
+            x.is_nullable == y.is_nullable
+                && RecordBatch::equal_ignoring_field_names(&x.data_type, &y.data_type)
+        }
+        match (a, b) {
+            (DataType::List(x), DataType::List(y))
+            | (DataType::LargeList(x), DataType::LargeList(y)) => fields_equal(x, y),
+            (DataType::FixedSizeList(x, x_size), DataType::FixedSizeList(y, y_size)) => {
+                x_size == y_size && fields_equal(x, y)
+            }
+            (DataType::Map(x, x_sorted), DataType::Map(y, y_sorted)) => {
+                x_sorted == y_sorted && fields_equal(x, y)
+            }
+            (DataType::Struct(xs), DataType::Struct(ys)) => {
+                xs.len() == ys.len() && xs.iter().zip(ys).all(|(x, y)| fields_equal(x, y))
+            }
+            _ => a == b,
+        }
+    }
+
+    /// Returns the [`Schema`] of the record batch.
+    pub fn schema(&self) -> &Arc<Schema> {
+        &self.schema
+    }
+
+    /// Returns the number of columns in the record batch.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::array::PrimitiveArray;
+    /// # use arrow::datatypes::{Schema, Field, DataType};
+    /// # use datafusion_common::record_batch::RecordBatch;
+    /// # use datafusion_common::field_util::SchemaExt;
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]);
+    /// let schema = Arc::new(Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]));
+    ///
+    /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?;
+    ///
+    /// assert_eq!(batch.num_columns(), 1);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn num_columns(&self) -> usize {
+        self.columns.len()
+    }
+
+    /// Returns the number of rows in each column.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `RecordBatch` contains no columns.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::array::PrimitiveArray;
+    /// # use arrow::datatypes::{Schema, Field, DataType};
+    /// # use datafusion_common::record_batch::RecordBatch;
+    /// # use datafusion_common::field_util::SchemaExt;
+    /// # fn main() -> arrow2::error::Result<()> {
+    /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]);
+    /// let schema = Arc::new(Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false)
+    /// ]));
+    ///
+    /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?;
+    ///
+    /// assert_eq!(batch.num_rows(), 5);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn num_rows(&self) -> usize {
+        self.columns[0].len()
+    }
+
+    /// Get a reference to a column's array by index.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `index` is outside of `0..num_columns`.
+    pub fn column(&self, index: usize) -> &Arc<dyn Array> {
+        &self.columns[index]
+    }
+
+    /// Get a reference to all columns in the record batch.
+    pub fn columns(&self) -> &[Arc<dyn Array>] {
+        &self.columns[..]
+    }
+
+    /// Create a `RecordBatch` from an iterable list of pairs of the
+    /// form `(field_name, array)`, with the same requirements on
+    /// fields and arrays as [`RecordBatch::try_new`]. This method is
+    /// often used to create a single `RecordBatch` from arrays,
+    /// e.g. for testing.
+    ///
+    /// The resulting schema is marked as nullable for each column if
+    /// the array for that column has any nulls. To explicitly
+    /// specify nullability, use [`RecordBatch::try_from_iter_with_nullable`]
+    ///
+    /// Example:
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow::array::*;
+    /// use arrow::datatypes::DataType;
+    /// use datafusion_common::record_batch::RecordBatch;
+    ///
+    /// let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2]));
+    /// let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b"]));
+    ///
+    /// let record_batch = RecordBatch::try_from_iter(vec![
+    ///   ("a", a),
+    ///   ("b", b),
+    /// ]);
+    /// ```
+    pub fn try_from_iter<I, F>(value: I) -> Result<Self>
+    where
+        I: IntoIterator<Item = (F, Arc<dyn Array>)>,
+        F: AsRef<str>,
+    {
+        // TODO: implement `TryFrom` trait, once
+        // https://github.com/rust-lang/rust/issues/50133 is no longer an
+        // issue
+        let iter = value.into_iter().map(|(field_name, array)| {
+            let nullable = array.null_count() > 0;
+            (field_name, array, nullable)
+        });
+
+        Self::try_from_iter_with_nullable(iter)
+    }
+
+    /// Create a `RecordBatch` from an iterable list of tuples of the
+    /// form `(field_name, array, nullable)`, with the same requirements on
+    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
+    /// used to create a single `RecordBatch` from arrays, e.g. for
+    /// testing.
+    ///
+    /// Example:
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow::array::*;
+    /// use arrow::datatypes::DataType;
+    /// use datafusion_common::record_batch::RecordBatch;
+    ///
+    /// let a: Arc<dyn Array> = Arc::new(Int32Array::from_slice(&[1, 2]));
+    /// let b: Arc<dyn Array> = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b"]));
+    ///
+    /// // Note neither `a` nor `b` has any actual nulls, but we mark
+    /// // b as nullable
+    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
+    ///   ("a", a, false),
+    ///   ("b", b, true),
+    /// ]);
+    /// ```
+    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self>
+    where
+        I: IntoIterator<Item = (F, Arc<dyn Array>, bool)>,
+        F: AsRef<str>,
+    {
+        // TODO: implement `TryFrom` trait, once
+        // https://github.com/rust-lang/rust/issues/50133 is no longer an
+        // issue
+        let (fields, columns) = value
+            .into_iter()
+            .map(|(field_name, array, nullable)| {
+                let field_name = field_name.as_ref();
+                let field = Field::new(field_name, array.data_type().clone(), nullable);
+                (field, array)
+            })
+            .unzip();
+
+        let schema = Arc::new(Schema::new(fields));
+        RecordBatch::try_new(schema, columns)
+    }
+
+    /// Deconstructs itself into its internal components `(columns, schema)`.
+    pub fn into_inner(self) -> (Vec<Arc<dyn Array>>, Arc<Schema>) {
+        let Self { columns, schema } = self;
+        (columns, schema)
+    }
+
+    /// Projects the schema onto the specified columns
+    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch> {
+        let projected_schema = self.schema.project(indices)?;
+        let batch_fields = indices
+            .iter()
+            .map(|f| {
+                self.columns.get(*f).cloned().ok_or_else(|| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "project index {} out of bounds, max field {}",
+                        f,
+                        self.columns.len()
+                    ))
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields)
+    }
+
+    /// Return a new RecordBatch where each column is sliced
+    /// according to `offset` and `length`
+    ///
+    /// # Panics
+    ///
+    /// Panics if `offset + length` is greater than the column length, or if
+    /// the sum overflows `usize`.
+    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
+        // `checked_add` keeps an overflowing range from slipping past the asserts.
+        let end = offset
+            .checked_add(length)
+            .expect("slice offset + length overflows usize");
+        if self.schema.fields().is_empty() {
+            assert_eq!(end, 0);
+            return RecordBatch::new_empty(self.schema.clone());
+        }
+        assert!(end <= self.num_rows());
+
+        let columns = self
+            .columns()
+            .iter()
+            .map(|column| Arc::from(column.slice(offset, length)))
+            .collect();
+
+        Self {
+            schema: self.schema.clone(),
+            columns,
+        }
+    }
+}
+
+/// Options that control the behaviour used when creating a [`RecordBatch`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct RecordBatchOptions {
+    /// Match field names of structs and lists. If set to `true`, the names must match.
+    pub match_field_names: bool,
+}
+
+impl Default for RecordBatchOptions {
+    fn default() -> Self {
+        Self {
+            match_field_names: true,
+        }
+    }
+}
+
+impl From<StructArray> for RecordBatch {
+    /// Uses the children of `array` as columns; panics if `array` has any null slots.
+    fn from(array: StructArray) -> Self {
+        assert_eq!(array.null_count(), 0, "StructArray must not contain nulls");
+        let (fields, values, _) = array.into_data();
+        RecordBatch {
+            schema: Arc::new(Schema::new(fields)),
+            columns: values,
+        }
+    }
+}
+
+impl From<RecordBatch> for StructArray {
+    /// Converts the batch into a validity-free [`StructArray`], one child per column.
+    fn from(batch: RecordBatch) -> Self {
+        let mut fields = Vec::with_capacity(batch.columns.len());
+        let mut values = Vec::with_capacity(batch.columns.len());
+        for (field, column) in batch.schema.fields.iter().zip(batch.columns.iter()) {
+            fields.push(field.clone());
+            values.push(column.clone());
+        }
+        StructArray::from_data(DataType::Struct(fields), values, None)
+    }
+}
+
+impl From<RecordBatch> for Chunk<ArrayRef> {
+    fn from(batch: RecordBatch) -> Self {
+        Self::new(batch.into_inner().0) // keep the columns, drop the schema
+    }
+}
+
+impl From<&RecordBatch> for Chunk<ArrayRef> {
+    fn from(batch: &RecordBatch) -> Self {
+        Self::new(batch.columns().to_vec()) // clones only the `Arc` column handles
+    }
+}
+
+/// Returns a new [`RecordBatch`] with arrays containing only values matching the filter.
+/// WARNING: the nulls of `filter_values` are ignored and the value on its slot is considered.
+/// Therefore, it is considered undefined behavior to pass `filter` with null values.
+///
+/// With several columns, the filter is built once (`build_filter`) and reused per column.
+///
+/// # Errors
+/// Propagates errors from the filter kernels and from [`RecordBatch::try_new`].
+pub fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter_values: &BooleanArray,
+) -> Result<RecordBatch> {
+    let columns = record_batch.columns();
+    let filtered_arrays = if columns.len() == 1 {
+        vec![filter(columns[0].as_ref(), filter_values)?.into()]
+    } else {
+        let filter_fn = build_filter(filter_values)?;
+        columns
+            .iter()
+            .map(|column| filter_fn(column.as_ref()).into())
+            .collect()
+    };
+    RecordBatch::try_new(record_batch.schema().clone(), filtered_arrays)
+}
diff --git a/datafusion-common/src/scalar.rs b/datafusion-common/src/scalar.rs
index 4a1dde1..45be8a8 100644
--- a/datafusion-common/src/scalar.rs
+++ b/datafusion-common/src/scalar.rs
@@ -17,17 +17,16 @@
//! This module provides ScalarValue, an enum that can be used for storage of single elements
-use crate::error::{DataFusionError, Result};
+use crate::field_util::{FieldExt, StructArrayExt};
+use crate::{DataFusionError, Result};
+use arrow::bitmap::Bitmap;
+use arrow::buffer::Buffer;
+use arrow::compute::concatenate;
use arrow::{
array::*,
- compute::kernels::cast::cast,
- datatypes::{
- ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type,
- Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit,
- TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
- TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
- DECIMAL_MAX_PRECISION,
- },
+ datatypes::{DataType, Field, IntegerType, IntervalUnit, TimeUnit},
+ scalar::{PrimitiveScalar, Scalar},
+ types::{days_ms, NativeType},
};
use ordered_float::OrderedFloat;
use std::cmp::Ordering;
@@ -35,6 +34,17 @@
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+type SmallBinaryArray = BinaryArray<i32>;
+type LargeBinaryArray = BinaryArray<i64>;
+type MutableStringArray = MutableUtf8Array<i32>;
+type MutableLargeStringArray = MutableUtf8Array<i64>;
+/// The max precision for decimal128
+pub const DECIMAL_MAX_PRECISION: usize = 38;
+/// The max scale for decimal128
+pub const DECIMAL_MAX_SCALE: usize = 38;
+
/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
@@ -89,7 +99,7 @@
/// Interval with YearMonth unit
IntervalYearMonth(Option<i32>),
/// Interval with DayTime unit
- IntervalDayTime(Option<i64>),
+ IntervalDayTime(Option<days_ms>),
/// Interval with MonthDayNano unit
IntervalMonthDayNano(Option<i128>),
/// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue))
@@ -258,7 +268,10 @@
(TimestampNanosecond(_, _), _) => None,
(IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2),
(IntervalYearMonth(_), _) => None,
- (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2),
+ (IntervalDayTime(v1), IntervalDayTime(v2)) => v1
+ .map(|d| d.to_le_bytes())
+ .partial_cmp(&v2.map(|d| d.to_le_bytes())),
+ (_, IntervalDayTime(_)) => None,
(IntervalDayTime(_), _) => None,
(IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2),
(IntervalMonthDayNano(_), _) => None,
@@ -333,7 +346,7 @@
// as a reference to the dictionary values array. Returns None for the
// index if the array is NULL at index
#[inline]
-fn get_dict_value<K: ArrowDictionaryKeyType>(
+fn get_dict_value<K: DictionaryKey>(
array: &ArrayRef,
index: usize,
) -> Result<(&ArrayRef, Option<usize>)> {
@@ -355,8 +368,8 @@
}
macro_rules! typed_cast_tz {
- ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{
- let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
+ ($array:expr, $index:expr, $SCALAR:ident, $TZ:expr) => {{
+ let array = $array.as_any().downcast_ref::<Int64Array>().unwrap();
ScalarValue::$SCALAR(
match array.is_null($index) {
true => None,
@@ -379,68 +392,59 @@
macro_rules! build_list {
($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
+ let dt = DataType::List(Box::new(Field::new("item", DataType::$SCALAR_TY, true))); // list type shared by both branches
match $VALUES {
// the return on the macro is necessary, to short-circuit and return ArrayRef
None => {
- return new_null_array(
- &DataType::List(Box::new(Field::new(
- "item",
- DataType::$SCALAR_TY,
- true,
- ))),
- $SIZE,
- )
+ return Arc::from(new_null_array(dt, $SIZE)); // $SIZE null list slots
}
Some(values) => {
- build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE)
+ let mut array = MutableListArray::<i32, $VALUE_BUILDER_TY>::new_from(
+ <$VALUE_BUILDER_TY>::default(), // empty child values builder
+ dt,
+ $SIZE, // capacity: build_values_list! pushes $SIZE lists
+ );
+ build_values_list!(array, $SCALAR_TY, values.as_ref(), $SIZE) // same list repeated $SIZE times
}
}
}};
}
macro_rules! build_timestamp_list {
- ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{
+ ($TIME_UNIT:expr, $VALUES:expr, $SIZE:expr, $TZ:expr) => {{
+ let child_dt = DataType::Timestamp($TIME_UNIT, $TZ.clone());
match $VALUES {
// the return on the macro is necessary, to short-circuit and return ArrayRef
None => {
- return new_null_array(
- &DataType::List(Box::new(Field::new(
- "item",
- DataType::Timestamp($TIME_UNIT, $TIME_ZONE),
- true,
- ))),
+ let null_array: ArrayRef = new_null_array(
+ DataType::List(Box::new(Field::new("item", child_dt, true))),
$SIZE,
)
+ .into();
+ null_array
}
Some(values) => {
let values = values.as_ref();
+ let empty_arr = <Int64Vec>::default().to(child_dt.clone());
+ let mut array = MutableListArray::<i32, Int64Vec>::new_from(
+ empty_arr,
+ DataType::List(Box::new(Field::new("item", child_dt, true))),
+ $SIZE,
+ );
+
match $TIME_UNIT {
TimeUnit::Second => {
- build_values_list_tz!(
- TimestampSecondBuilder,
- TimestampSecond,
- values,
- $SIZE
- )
+ build_values_list_tz!(array, TimestampSecond, values, $SIZE)
}
- TimeUnit::Microsecond => build_values_list_tz!(
- TimestampMillisecondBuilder,
- TimestampMillisecond,
- values,
- $SIZE
- ),
- TimeUnit::Millisecond => build_values_list_tz!(
- TimestampMicrosecondBuilder,
- TimestampMicrosecond,
- values,
- $SIZE
- ),
- TimeUnit::Nanosecond => build_values_list_tz!(
- TimestampNanosecondBuilder,
- TimestampNanosecond,
- values,
- $SIZE
- ),
+ TimeUnit::Microsecond => {
+ build_values_list_tz!(array, TimestampMicrosecond, values, $SIZE) // variant must match the unit
+ }
+ TimeUnit::Millisecond => {
+ build_values_list_tz!(array, TimestampMillisecond, values, $SIZE) // variant must match the unit
+ }
+ TimeUnit::Nanosecond => {
+ build_values_list_tz!(array, TimestampNanosecond, values, $SIZE)
+ }
}
}
}
@@ -448,74 +452,52 @@
}
macro_rules! build_values_list {
- ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
- let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len()));
-
+ ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
for _ in 0..$SIZE {
+ let mut vec = vec![]; // elements of one list (one output row)
for scalar_value in $VALUES {
match scalar_value {
- ScalarValue::$SCALAR_TY(Some(v)) => {
- builder.values().append_value(v.clone()).unwrap()
- }
- ScalarValue::$SCALAR_TY(None) => {
- builder.values().append_null().unwrap();
+ ScalarValue::$SCALAR_TY(v) => {
+ vec.push(v.clone()); // `v` is the Option payload, so null elements are kept
}
_ => panic!("Incompatible ScalarValue for list"),
};
}
- builder.append(true).unwrap();
+ $MUTABLE_ARR.try_push(Some(vec)).unwrap(); // each of the $SIZE rows gets the same non-null list
}
- builder.finish()
+ let array: ListArray<i32> = $MUTABLE_ARR.into(); // convert the mutable array into an immutable ListArray
+ Arc::new(array)
+ }};
+}
+
+macro_rules! dyn_to_array {
+ ($self:expr, $value:expr, $size:expr, $ty:ty) => {{
+ Arc::new(PrimitiveArray::<$ty>::from_data( // array of $size copies of *$value
+ $self.get_datatype(), // logical type taken from `$self` (presumably the ScalarValue)
+ Buffer::<$ty>::from_iter(repeat(*$value).take($size)),
+ None, // no validity bitmap: every slot is valid
+ ))
}};
}
macro_rules! build_values_list_tz {
- ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
- let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len()));
-
+ ($MUTABLE_ARR:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
for _ in 0..$SIZE {
+ let mut vec = vec![]; // elements of one list (one output row)
for scalar_value in $VALUES {
match scalar_value {
- ScalarValue::$SCALAR_TY(Some(v), _) => {
- builder.values().append_value(v.clone()).unwrap()
- }
- ScalarValue::$SCALAR_TY(None, _) => {
- builder.values().append_null().unwrap();
+ ScalarValue::$SCALAR_TY(v, _) => { // second field (presumably the timezone) is not used here
+ vec.push(v.clone()); // `v` is the Option payload, so null elements are kept
}
_ => panic!("Incompatible ScalarValue for list"),
};
}
- builder.append(true).unwrap();
+ $MUTABLE_ARR.try_push(Some(vec)).unwrap(); // each of the $SIZE rows gets the same non-null list
}
- builder.finish()
- }};
-}
-
-macro_rules! build_array_from_option {
- ($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{
- match $EXPR {
- Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)),
- None => new_null_array(&DataType::$DATA_TYPE, $SIZE),
- }
- }};
- ($DATA_TYPE:ident, $ENUM:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{
- match $EXPR {
- Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)),
- None => new_null_array(&DataType::$DATA_TYPE($ENUM), $SIZE),
- }
- }};
- ($DATA_TYPE:ident, $ENUM:expr, $ENUM2:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{
- match $EXPR {
- Some(value) => {
- let array: ArrayRef = Arc::new($ARRAY_TYPE::from_value(*value, $SIZE));
- // Need to call cast to cast to final data type with timezone/extra param
- cast(&array, &DataType::$DATA_TYPE($ENUM, $ENUM2))
- .expect("cannot do temporal cast")
- }
- None => new_null_array(&DataType::$DATA_TYPE($ENUM, $ENUM2), $SIZE),
- }
+ let array: ListArray<i32> = $MUTABLE_ARR.into(); // convert the mutable array into an immutable ListArray
+ Arc::new(array)
}};
}
@@ -661,8 +643,8 @@
///
/// Example
/// ```
- /// use datafusion_common::ScalarValue;
- /// use arrow::array::{ArrayRef, BooleanArray};
+ /// use datafusion::scalar::ScalarValue;
+ /// use arrow::array::{BooleanArray, Array};
///
/// let scalars = vec![
/// ScalarValue::Boolean(Some(true)),
@@ -674,8 +656,8 @@
/// let array = ScalarValue::iter_to_array(scalars.into_iter())
/// .unwrap();
///
- /// let expected: ArrayRef = std::sync::Arc::new(
- /// BooleanArray::from(vec![
+ /// let expected: Box<dyn Array> = Box::new(
+ /// BooleanArray::from_slice(vec![
/// Some(true),
/// None,
/// Some(false)
@@ -702,218 +684,203 @@
/// Creates an array of $ARRAY_TY by unpacking values of
/// SCALAR_TY for primitive types
macro_rules! build_array_primitive {
- ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
- {
- let array = scalars
- .map(|sv| {
- if let ScalarValue::$SCALAR_TY(v) = sv {
- Ok(v)
- } else {
- Err(DataFusionError::Internal(format!(
- "Inconsistent types in ScalarValue::iter_to_array. \
+ ($TY:ty, $SCALAR_TY:ident, $DT:ident) => {{
+ {
+ Arc::new(scalars
+ .map(|sv| {
+ if let ScalarValue::$SCALAR_TY(v) = sv {
+ Ok(v)
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
- data_type, sv
- )))
- }
- })
- .collect::<Result<$ARRAY_TY>>()?;
-
- Arc::new(array)
+ data_type, sv
+ )))
+ }
+ }).collect::<Result<PrimitiveArray<$TY>>>()?.to($DT)
+ ) as Arc<dyn Array>
+ }
+ }};
}
- }};
- }
macro_rules! build_array_primitive_tz {
- ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
- {
- let array = scalars
- .map(|sv| {
- if let ScalarValue::$SCALAR_TY(v, _) = sv {
- Ok(v)
- } else {
- Err(DataFusionError::Internal(format!(
- "Inconsistent types in ScalarValue::iter_to_array. \
+ ($SCALAR_TY:ident) => {{
+ {
+ let array = scalars
+ .map(|sv| {
+ if let ScalarValue::$SCALAR_TY(v, _) = sv {
+ Ok(v)
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
- data_type, sv
- )))
- }
- })
- .collect::<Result<$ARRAY_TY>>()?;
+ data_type, sv
+ )))
+ }
+ })
+ .collect::<Result<Int64Array>>()?;
- Arc::new(array)
+ Arc::new(array)
+ }
+ }};
}
- }};
- }
/// Creates an array of $ARRAY_TY by unpacking values of
/// SCALAR_TY for "string-like" types.
macro_rules! build_array_string {
- ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
- {
- let array = scalars
- .map(|sv| {
- if let ScalarValue::$SCALAR_TY(v) = sv {
- Ok(v)
- } else {
- Err(DataFusionError::Internal(format!(
- "Inconsistent types in ScalarValue::iter_to_array. \
+ ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
+ {
+ let array = scalars
+ .map(|sv| {
+ if let ScalarValue::$SCALAR_TY(v) = sv {
+ Ok(v)
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
- data_type, sv
- )))
- }
- })
- .collect::<Result<$ARRAY_TY>>()?;
- Arc::new(array)
- }
- }};
- }
-
- macro_rules! build_array_list_primitive {
- ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{
- Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
- scalars.into_iter().map(|x| match x {
- ScalarValue::List(xs, _) => xs.map(|x| {
- x.iter()
- .map(|x| match x {
- ScalarValue::$SCALAR_TY(i) => *i,
- sv => panic!(
- "Inconsistent types in ScalarValue::iter_to_array. \
- Expected {:?}, got {:?}",
- data_type, sv
- ),
- })
- .collect::<Vec<Option<$NATIVE_TYPE>>>()
- }),
- sv => panic!(
- "Inconsistent types in ScalarValue::iter_to_array. \
- Expected {:?}, got {:?}",
- data_type, sv
- ),
- }),
- ))
- }};
- }
-
- macro_rules! build_array_list_string {
- ($BUILDER:ident, $SCALAR_TY:ident) => {{
- let mut builder = ListBuilder::new($BUILDER::new(0));
-
- for scalar in scalars.into_iter() {
- match scalar {
- ScalarValue::List(Some(xs), _) => {
- let xs = *xs;
- for s in xs {
- match s {
- ScalarValue::$SCALAR_TY(Some(val)) => {
- builder.values().append_value(val)?;
- }
- ScalarValue::$SCALAR_TY(None) => {
- builder.values().append_null()?;
- }
- sv => {
- return Err(DataFusionError::Internal(format!(
- "Inconsistent types in ScalarValue::iter_to_array. \
- Expected Utf8, got {:?}",
- sv
- )))
- }
+ data_type, sv
+ )))
+ }
+ })
+ .collect::<Result<$ARRAY_TY>>()?;
+ Arc::new(array)
}
- }
- builder.append(true)?;
- }
- ScalarValue::List(None, _) => {
- builder.append(false)?;
- }
- sv => {
- return Err(DataFusionError::Internal(format!(
- "Inconsistent types in ScalarValue::iter_to_array. \
- Expected List, got {:?}",
- sv
- )))
- }
- }
+ }};
}
- Arc::new(builder.finish())
- }};
- }
+ macro_rules! build_array_list {
+ ($MUTABLE_TY:ty, $SCALAR_TY:ident) => {{
+ let mut array = MutableListArray::<i32, $MUTABLE_TY>::new();
+ for scalar in scalars.into_iter() {
+ match scalar {
+ ScalarValue::List(Some(xs), _) => {
+ let xs = *xs;
+ let mut vec = vec![];
+ for s in xs {
+ match s {
+ ScalarValue::$SCALAR_TY(o) => { vec.push(o) }
+ sv => return Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
+ Expected Utf8, got {:?}",
+ sv
+ ))),
+ }
+ }
+ array.try_push(Some(vec))?;
+ }
+ ScalarValue::List(None, _) => {
+ array.push_null();
+ }
+ sv => {
+ return Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
+ Expected List, got {:?}",
+ sv
+ )))
+ }
+ }
+ }
- let array: ArrayRef = match &data_type {
+ let array: ListArray<i32> = array.into();
+ Arc::new(array)
+ }}
+ }
+
+ use DataType::*;
+ let array: Arc<dyn Array> = match &data_type {
DataType::Decimal(precision, scale) => {
let decimal_array =
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
Arc::new(decimal_array)
}
- DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
- DataType::Float32 => build_array_primitive!(Float32Array, Float32),
- DataType::Float64 => build_array_primitive!(Float64Array, Float64),
- DataType::Int8 => build_array_primitive!(Int8Array, Int8),
- DataType::Int16 => build_array_primitive!(Int16Array, Int16),
- DataType::Int32 => build_array_primitive!(Int32Array, Int32),
- DataType::Int64 => build_array_primitive!(Int64Array, Int64),
- DataType::UInt8 => build_array_primitive!(UInt8Array, UInt8),
- DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16),
- DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32),
- DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64),
- DataType::Utf8 => build_array_string!(StringArray, Utf8),
- DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8),
- DataType::Binary => build_array_string!(BinaryArray, Binary),
- DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary),
- DataType::Date32 => build_array_primitive!(Date32Array, Date32),
- DataType::Date64 => build_array_primitive!(Date64Array, Date64),
- DataType::Timestamp(TimeUnit::Second, _) => {
- build_array_primitive_tz!(TimestampSecondArray, TimestampSecond)
+ DataType::Boolean => Arc::new(
+ scalars
+ .map(|sv| {
+ if let ScalarValue::Boolean(v) = sv {
+ Ok(v)
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Inconsistent types in ScalarValue::iter_to_array. \
+ Expected {:?}, got {:?}",
+ data_type, sv
+ )))
+ }
+ })
+ .collect::<Result<BooleanArray>>()?,
+ ),
+ Float32 => {
+ build_array_primitive!(f32, Float32, Float32)
}
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- build_array_primitive_tz!(TimestampMillisecondArray, TimestampMillisecond)
+ Float64 => {
+ build_array_primitive!(f64, Float64, Float64)
}
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- build_array_primitive_tz!(TimestampMicrosecondArray, TimestampMicrosecond)
+ Int8 => build_array_primitive!(i8, Int8, Int8),
+ Int16 => build_array_primitive!(i16, Int16, Int16),
+ Int32 => build_array_primitive!(i32, Int32, Int32),
+ Int64 => build_array_primitive!(i64, Int64, Int64),
+ UInt8 => build_array_primitive!(u8, UInt8, UInt8),
+ UInt16 => build_array_primitive!(u16, UInt16, UInt16),
+ UInt32 => build_array_primitive!(u32, UInt32, UInt32),
+ UInt64 => build_array_primitive!(u64, UInt64, UInt64),
+ Utf8 => build_array_string!(StringArray, Utf8),
+ LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8),
+ Binary => build_array_string!(SmallBinaryArray, Binary),
+ LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary),
+ Date32 => build_array_primitive!(i32, Date32, Date32),
+ Date64 => build_array_primitive!(i64, Date64, Date64),
+ Timestamp(TimeUnit::Second, _) => {
+ build_array_primitive_tz!(TimestampSecond)
}
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- build_array_primitive_tz!(TimestampNanosecondArray, TimestampNanosecond)
+ Timestamp(TimeUnit::Millisecond, _) => {
+ build_array_primitive_tz!(TimestampMillisecond)
}
- DataType::Interval(IntervalUnit::DayTime) => {
- build_array_primitive!(IntervalDayTimeArray, IntervalDayTime)
+ Timestamp(TimeUnit::Microsecond, _) => {
+ build_array_primitive_tz!(TimestampMicrosecond)
}
- DataType::Interval(IntervalUnit::YearMonth) => {
- build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth)
+ Timestamp(TimeUnit::Nanosecond, _) => {
+ build_array_primitive_tz!(TimestampNanosecond)
+ }
+ Interval(IntervalUnit::DayTime) => {
+ build_array_primitive!(days_ms, IntervalDayTime, data_type)
+ }
+ Interval(IntervalUnit::YearMonth) => {
+ build_array_primitive!(i32, IntervalYearMonth, data_type)
}
DataType::List(fields) if fields.data_type() == &DataType::Int8 => {
- build_array_list_primitive!(Int8Type, Int8, i8)
+ build_array_list!(Int8Vec, Int8)
}
DataType::List(fields) if fields.data_type() == &DataType::Int16 => {
- build_array_list_primitive!(Int16Type, Int16, i16)
+ build_array_list!(Int16Vec, Int16)
}
DataType::List(fields) if fields.data_type() == &DataType::Int32 => {
- build_array_list_primitive!(Int32Type, Int32, i32)
+ build_array_list!(Int32Vec, Int32)
}
DataType::List(fields) if fields.data_type() == &DataType::Int64 => {
- build_array_list_primitive!(Int64Type, Int64, i64)
+ build_array_list!(Int64Vec, Int64)
}
DataType::List(fields) if fields.data_type() == &DataType::UInt8 => {
- build_array_list_primitive!(UInt8Type, UInt8, u8)
+ build_array_list!(UInt8Vec, UInt8)
}
DataType::List(fields) if fields.data_type() == &DataType::UInt16 => {
- build_array_list_primitive!(UInt16Type, UInt16, u16)
+ build_array_list!(UInt16Vec, UInt16)
}
DataType::List(fields) if fields.data_type() == &DataType::UInt32 => {
- build_array_list_primitive!(UInt32Type, UInt32, u32)
+ build_array_list!(UInt32Vec, UInt32)
}
DataType::List(fields) if fields.data_type() == &DataType::UInt64 => {
- build_array_list_primitive!(UInt64Type, UInt64, u64)
+ build_array_list!(UInt64Vec, UInt64)
}
DataType::List(fields) if fields.data_type() == &DataType::Float32 => {
- build_array_list_primitive!(Float32Type, Float32, f32)
+ build_array_list!(Float32Vec, Float32)
}
DataType::List(fields) if fields.data_type() == &DataType::Float64 => {
- build_array_list_primitive!(Float64Type, Float64, f64)
+ build_array_list!(Float64Vec, Float64)
}
DataType::List(fields) if fields.data_type() == &DataType::Utf8 => {
- build_array_list_string!(StringBuilder, Utf8)
+ build_array_list!(MutableStringArray, Utf8)
}
DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => {
- build_array_list_string!(LargeStringBuilder, LargeUtf8)
+ build_array_list!(MutableLargeStringArray, LargeUtf8)
}
DataType::List(_) => {
// Fallback case handling homogeneous lists with any ScalarValue element type
@@ -954,15 +921,12 @@
}
// Call iter_to_array recursively to convert the scalars for each column into Arrow arrays
- let field_values = fields
+ let field_values = columns
.iter()
- .zip(columns)
- .map(|(field, column)| -> Result<(Field, ArrayRef)> {
- Ok((field.clone(), Self::iter_to_array(column)?))
- })
+ .map(|c| Self::iter_to_array(c.clone()).map(Arc::from))
.collect::<Result<Vec<_>>>()?;
- Arc::new(StructArray::from(field_values))
+ Arc::new(StructArray::from_data(data_type, field_values, None))
}
_ => {
return Err(DataFusionError::Internal(format!(
@@ -980,29 +944,31 @@
scalars: impl IntoIterator<Item = ScalarValue>,
precision: &usize,
scale: &usize,
- ) -> Result<DecimalArray> {
+ ) -> Result<Int128Array> {
+ // collect the value as Option<i128>
let array = scalars
.into_iter()
.map(|element: ScalarValue| match element {
ScalarValue::Decimal128(v1, _, _) => v1,
_ => unreachable!(),
})
- .collect::<DecimalArray>()
- .with_precision_and_scale(*precision, *scale)?;
- Ok(array)
+ .collect::<Vec<Option<i128>>>();
+
+ // build the decimal array using the Decimal Builder
+ Ok(Int128Vec::from(array)
+ .to(DataType::Decimal(*precision, *scale))
+ .into())
}
fn iter_to_array_list(
scalars: impl IntoIterator<Item = ScalarValue>,
data_type: &DataType,
- ) -> Result<GenericListArray<i32>> {
- let mut offsets = Int32Array::builder(0);
- if let Err(err) = offsets.append_value(0) {
- return Err(DataFusionError::ArrowError(err));
- }
+ ) -> Result<ListArray<i32>> {
+ let mut offsets: Vec<i32> = vec![0];
let mut elements: Vec<ArrayRef> = Vec::new();
- let mut valid = BooleanBufferBuilder::new(0);
+ let mut valid: Vec<bool> = vec![];
+
let mut flat_len = 0i32;
for scalar in scalars {
if let ScalarValue::List(values, _) = scalar {
@@ -1012,23 +978,19 @@
// Add new offset index
flat_len += element_array.len() as i32;
- if let Err(err) = offsets.append_value(flat_len) {
- return Err(DataFusionError::ArrowError(err));
- }
+ offsets.push(flat_len);
elements.push(element_array);
// Element is valid
- valid.append(true);
+ valid.push(true);
}
None => {
// Repeat previous offset index
- if let Err(err) = offsets.append_value(flat_len) {
- return Err(DataFusionError::ArrowError(err));
- }
+ offsets.push(flat_len);
// Element is null
- valid.append(false);
+ valid.push(false);
}
}
} else {
@@ -1042,212 +1004,167 @@
// Concatenate element arrays to create single flat array
let element_arrays: Vec<&dyn Array> =
elements.iter().map(|a| a.as_ref()).collect();
- let flat_array = match arrow::compute::concat(&element_arrays) {
+ let flat_array = match concatenate::concatenate(&element_arrays) {
Ok(flat_array) => flat_array,
Err(err) => return Err(DataFusionError::ArrowError(err)),
};
- // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices
- let offsets_array = offsets.finish();
- let array_data = ArrayDataBuilder::new(data_type.clone())
- .len(offsets_array.len() - 1)
- .null_bit_buffer(valid.finish())
- .add_buffer(offsets_array.data().buffers()[0].clone())
- .add_child_data(flat_array.data().clone());
+ let list_array = ListArray::<i32>::from_data(
+ data_type.clone(),
+ Buffer::from(offsets),
+ flat_array.into(),
+ Some(Bitmap::from(valid)),
+ );
- let list_array = ListArray::from(array_data.build()?);
Ok(list_array)
}
- fn build_decimal_array(
- value: &Option<i128>,
- precision: &usize,
- scale: &usize,
- size: usize,
- ) -> DecimalArray {
- std::iter::repeat(value)
- .take(size)
- .collect::<DecimalArray>()
- .with_precision_and_scale(*precision, *scale)
- .unwrap()
- }
-
/// Converts a scalar value into an array of `size` rows.
+ ///
+ /// # Panics
+ /// Panics if the generic fallback used for `List` scalars fails to build
+ /// the list array (its `Result` is unwrapped).
pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
match self {
ScalarValue::Decimal128(e, precision, scale) => {
- Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size))
+ Int128Vec::from_iter(repeat(e).take(size))
+ .to(DataType::Decimal(*precision, *scale))
+ .into_arc()
}
ScalarValue::Boolean(e) => {
- Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
+ Arc::new(BooleanArray::from_iter(vec![*e; size])) as ArrayRef
}
- ScalarValue::Float64(e) => {
- build_array_from_option!(Float64, Float64Array, e, size)
- }
- ScalarValue::Float32(e) => {
- build_array_from_option!(Float32, Float32Array, e, size)
- }
- ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size),
- ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size),
- ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size),
- ScalarValue::Int64(e) => build_array_from_option!(Int64, Int64Array, e, size),
- ScalarValue::UInt8(e) => build_array_from_option!(UInt8, UInt8Array, e, size),
- ScalarValue::UInt16(e) => {
- build_array_from_option!(UInt16, UInt16Array, e, size)
- }
- ScalarValue::UInt32(e) => {
- build_array_from_option!(UInt32, UInt32Array, e, size)
- }
- ScalarValue::UInt64(e) => {
- build_array_from_option!(UInt64, UInt64Array, e, size)
- }
- ScalarValue::TimestampSecond(e, tz_opt) => build_array_from_option!(
- Timestamp,
- TimeUnit::Second,
- tz_opt.clone(),
- TimestampSecondArray,
- e,
- size
- ),
- ScalarValue::TimestampMillisecond(e, tz_opt) => build_array_from_option!(
- Timestamp,
- TimeUnit::Millisecond,
- tz_opt.clone(),
- TimestampMillisecondArray,
- e,
- size
- ),
-
- ScalarValue::TimestampMicrosecond(e, tz_opt) => build_array_from_option!(
- Timestamp,
- TimeUnit::Microsecond,
- tz_opt.clone(),
- TimestampMicrosecondArray,
- e,
- size
- ),
- ScalarValue::TimestampNanosecond(e, tz_opt) => build_array_from_option!(
- Timestamp,
- TimeUnit::Nanosecond,
- tz_opt.clone(),
- TimestampNanosecondArray,
- e,
- size
- ),
- ScalarValue::Utf8(e) => match e {
+ ScalarValue::Float64(e) => match e {
Some(value) => {
- Arc::new(StringArray::from_iter_values(repeat(value).take(size)))
+ dyn_to_array!(self, value, size, f64)
}
- None => new_null_array(&DataType::Utf8, size),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Float32(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, f32),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Int8(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i8),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Int16(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i16),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Int32(e)
+ | ScalarValue::Date32(e)
+ | ScalarValue::IntervalYearMonth(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i32),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::IntervalMonthDayNano(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i128),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::UInt8(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, u8),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::UInt16(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, u16),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::UInt32(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, u32),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::UInt64(e) => match e {
+ Some(value) => dyn_to_array!(self, value, size, u64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::TimestampSecond(e, _) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::TimestampMillisecond(e, _) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+
+ ScalarValue::TimestampMicrosecond(e, _) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::TimestampNanosecond(e, _) => match e {
+ Some(value) => dyn_to_array!(self, value, size, i64),
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Utf8(e) => match e {
+ Some(value) => Arc::new(Utf8Array::<i32>::from_trusted_len_values_iter(
+ repeat(&value).take(size),
+ )),
+ None => new_null_array(self.get_datatype(), size).into(),
},
ScalarValue::LargeUtf8(e) => match e {
- Some(value) => {
- Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size)))
- }
- None => new_null_array(&DataType::LargeUtf8, size),
+ Some(value) => Arc::new(Utf8Array::<i64>::from_trusted_len_values_iter(
+ repeat(&value).take(size),
+ )),
+ None => new_null_array(self.get_datatype(), size).into(),
},
ScalarValue::Binary(e) => match e {
Some(value) => Arc::new(
repeat(Some(value.as_slice()))
.take(size)
- .collect::<BinaryArray>(),
+ .collect::<BinaryArray<i32>>(),
),
- None => {
- Arc::new(repeat(None::<&str>).take(size).collect::<BinaryArray>())
- }
+ None => new_null_array(self.get_datatype(), size).into(),
},
ScalarValue::LargeBinary(e) => match e {
Some(value) => Arc::new(
repeat(Some(value.as_slice()))
.take(size)
- .collect::<LargeBinaryArray>(),
+ .collect::<BinaryArray<i64>>(),
),
- None => Arc::new(
- repeat(None::<&str>)
- .take(size)
- .collect::<LargeBinaryArray>(),
- ),
+ None => new_null_array(self.get_datatype(), size).into(),
},
- ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() {
- DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size),
- DataType::Int8 => build_list!(Int8Builder, Int8, values, size),
- DataType::Int16 => build_list!(Int16Builder, Int16, values, size),
- DataType::Int32 => build_list!(Int32Builder, Int32, values, size),
- DataType::Int64 => build_list!(Int64Builder, Int64, values, size),
- DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size),
- DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size),
- DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size),
- DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size),
- DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size),
- DataType::Float32 => build_list!(Float32Builder, Float32, values, size),
- DataType::Float64 => build_list!(Float64Builder, Float64, values, size),
+ ScalarValue::List(values, data_type) => match data_type.as_ref() {
+ DataType::Boolean => {
+ build_list!(MutableBooleanArray, Boolean, values, size)
+ }
+ DataType::Int8 => build_list!(Int8Vec, Int8, values, size),
+ DataType::Int16 => build_list!(Int16Vec, Int16, values, size),
+ DataType::Int32 => build_list!(Int32Vec, Int32, values, size),
+ DataType::Int64 => build_list!(Int64Vec, Int64, values, size),
+ DataType::UInt8 => build_list!(UInt8Vec, UInt8, values, size),
+ DataType::UInt16 => build_list!(UInt16Vec, UInt16, values, size),
+ DataType::UInt32 => build_list!(UInt32Vec, UInt32, values, size),
+ DataType::UInt64 => build_list!(UInt64Vec, UInt64, values, size),
+ DataType::Float32 => build_list!(Float32Vec, Float32, values, size),
+ DataType::Float64 => build_list!(Float64Vec, Float64, values, size),
DataType::Timestamp(unit, tz) => {
- build_timestamp_list!(unit.clone(), tz.clone(), values, size)
+ build_timestamp_list!(*unit, values, size, tz.clone())
}
- &DataType::LargeUtf8 => {
- build_list!(LargeStringBuilder, LargeUtf8, values, size)
+ DataType::Utf8 => build_list!(MutableStringArray, Utf8, values, size),
+ DataType::LargeUtf8 => {
+ build_list!(MutableLargeStringArray, LargeUtf8, values, size)
}
- _ => ScalarValue::iter_to_array_list(
- repeat(self.clone()).take(size),
- &DataType::List(Box::new(Field::new(
- "item",
- data_type.as_ref().clone(),
- true,
- ))),
- )
- .unwrap(),
- }),
- ScalarValue::Date32(e) => {
- build_array_from_option!(Date32, Date32Array, e, size)
- }
- ScalarValue::Date64(e) => {
- build_array_from_option!(Date64, Date64Array, e, size)
- }
- ScalarValue::IntervalDayTime(e) => build_array_from_option!(
- Interval,
- IntervalUnit::DayTime,
- IntervalDayTimeArray,
- e,
- size
- ),
- ScalarValue::IntervalYearMonth(e) => build_array_from_option!(
- Interval,
- IntervalUnit::YearMonth,
- IntervalYearMonthArray,
- e,
- size
- ),
- ScalarValue::IntervalMonthDayNano(e) => build_array_from_option!(
- Interval,
- IntervalUnit::MonthDayNano,
- IntervalMonthDayNanoArray,
- e,
- size
- ),
- ScalarValue::Struct(values, fields) => match values {
+ // Element types without a dedicated builder above (for example nested
+ // lists, structs, dates or decimals) go through the generic list
+ // builder, as the pre-arrow2 code did, instead of panicking.
+ _ => Arc::new(
+ ScalarValue::iter_to_array_list(
+ repeat(self.clone()).take(size),
+ &self.get_datatype(),
+ )
+ .unwrap(),
+ ) as ArrayRef,
+ },
+ ScalarValue::IntervalDayTime(e) => match e {
+ Some(value) => {
+ Arc::new(PrimitiveArray::<days_ms>::from_trusted_len_values_iter(
+ std::iter::repeat(*value).take(size),
+ ))
+ }
+ None => new_null_array(self.get_datatype(), size).into(),
+ },
+ ScalarValue::Struct(values, _) => match values {
Some(values) => {
- let field_values: Vec<_> = fields
- .iter()
- .zip(values.iter())
- .map(|(field, value)| {
- (field.clone(), value.to_array_of_size(size))
- })
- .collect();
-
- Arc::new(StructArray::from(field_values))
+ let field_values =
+ values.iter().map(|v| v.to_array_of_size(size)).collect();
+ Arc::new(StructArray::from_data(
+ self.get_datatype(),
+ field_values,
+ None,
+ ))
}
- None => {
- let field_values: Vec<_> = fields
- .iter()
- .map(|field| {
- let none_field = Self::try_from(field.data_type())
- .expect("Failed to construct null ScalarValue from Struct field type");
- (field.clone(), none_field.to_array_of_size(size))
- })
- .collect();
-
- Arc::new(StructArray::from(field_values))
- }
+ None => Arc::new(StructArray::new_null(self.get_datatype(), size)),
},
}
}
@@ -1258,7 +1175,7 @@
precision: &usize,
scale: &usize,
) -> ScalarValue {
- let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ let array = array.as_any().downcast_ref::<Int128Array>().unwrap();
if array.is_null(index) {
ScalarValue::Decimal128(None, *precision, *scale)
} else {
@@ -1288,15 +1205,17 @@
DataType::Int32 => typed_cast!(array, index, Int32Array, Int32),
DataType::Int16 => typed_cast!(array, index, Int16Array, Int16),
DataType::Int8 => typed_cast!(array, index, Int8Array, Int8),
- DataType::Binary => typed_cast!(array, index, BinaryArray, Binary),
+ DataType::Binary => typed_cast!(array, index, SmallBinaryArray, Binary),
DataType::LargeBinary => {
typed_cast!(array, index, LargeBinaryArray, LargeBinary)
}
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
DataType::List(nested_type) => {
- let list_array =
- array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
+ let list_array = array
+ .as_any()
+ .downcast_ref::<ListArray<i32>>()
+ .ok_or_else(|| {
DataFusionError::Internal(
"Failed to downcast ListArray".to_string(),
)
@@ -1304,7 +1223,7 @@
let value = match list_array.is_null(index) {
true => None,
false => {
- let nested_array = list_array.value(index);
+ let nested_array = ArrayRef::from(list_array.value(index));
let scalar_vec = (0..nested_array.len())
.map(|i| ScalarValue::try_from_array(&nested_array, i))
.collect::<Result<Vec<_>>>()?;
@@ -1316,63 +1235,33 @@
ScalarValue::List(value, data_type)
}
DataType::Date32 => {
- typed_cast!(array, index, Date32Array, Date32)
+ typed_cast!(array, index, Int32Array, Date32)
}
DataType::Date64 => {
- typed_cast!(array, index, Date64Array, Date64)
+ typed_cast!(array, index, Int64Array, Date64)
}
DataType::Timestamp(TimeUnit::Second, tz_opt) => {
- typed_cast_tz!(
- array,
- index,
- TimestampSecondArray,
- TimestampSecond,
- tz_opt
- )
+ typed_cast_tz!(array, index, TimestampSecond, tz_opt)
}
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
- typed_cast_tz!(
- array,
- index,
- TimestampMillisecondArray,
- TimestampMillisecond,
- tz_opt
- )
+ typed_cast_tz!(array, index, TimestampMillisecond, tz_opt)
}
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
- typed_cast_tz!(
- array,
- index,
- TimestampMicrosecondArray,
- TimestampMicrosecond,
- tz_opt
- )
+ typed_cast_tz!(array, index, TimestampMicrosecond, tz_opt)
}
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
- typed_cast_tz!(
- array,
- index,
- TimestampNanosecondArray,
- TimestampNanosecond,
- tz_opt
- )
+ typed_cast_tz!(array, index, TimestampNanosecond, tz_opt)
}
- DataType::Dictionary(index_type, _) => {
- let (values, values_index) = match **index_type {
- DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
- DataType::Int16 => get_dict_value::<Int16Type>(array, index)?,
- DataType::Int32 => get_dict_value::<Int32Type>(array, index)?,
- DataType::Int64 => get_dict_value::<Int64Type>(array, index)?,
- DataType::UInt8 => get_dict_value::<UInt8Type>(array, index)?,
- DataType::UInt16 => get_dict_value::<UInt16Type>(array, index)?,
- DataType::UInt32 => get_dict_value::<UInt32Type>(array, index)?,
- DataType::UInt64 => get_dict_value::<UInt64Type>(array, index)?,
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Index type not supported while creating scalar from dictionary: {}",
- array.data_type(),
- )));
- }
+ DataType::Dictionary(index_type, _, _) => {
+ let (values, values_index) = match index_type {
+ IntegerType::Int8 => get_dict_value::<i8>(array, index)?,
+ IntegerType::Int16 => get_dict_value::<i16>(array, index)?,
+ IntegerType::Int32 => get_dict_value::<i32>(array, index)?,
+ IntegerType::Int64 => get_dict_value::<i64>(array, index)?,
+ IntegerType::UInt8 => get_dict_value::<u8>(array, index)?,
+ IntegerType::UInt16 => get_dict_value::<u16>(array, index)?,
+ IntegerType::UInt32 => get_dict_value::<u32>(array, index)?,
+ IntegerType::UInt64 => get_dict_value::<u64>(array, index)?,
};
match values_index {
@@ -1393,7 +1282,7 @@
})?;
let mut field_values: Vec<ScalarValue> = Vec::new();
for col_index in 0..array.num_columns() {
- let col_array = array.column(col_index);
+ let col_array = &array.values()[col_index];
let col_scalar = ScalarValue::try_from_array(col_array, index)?;
field_values.push(col_scalar);
}
@@ -1415,9 +1304,14 @@
precision: usize,
scale: usize,
) -> bool {
- let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
- if array.precision() != precision || array.scale() != scale {
- return false;
+ let array = array.as_any().downcast_ref::<Int128Array>().unwrap();
+ match array.data_type() {
+ DataType::Decimal(pre, sca) => {
+ if *pre != precision || *sca != scale {
+ return false;
+ }
+ }
+ _ => return false,
}
match value {
None => array.is_null(index),
@@ -1443,7 +1337,7 @@
/// comparisons where comparing a single row at a time is necessary.
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool {
- if let DataType::Dictionary(key_type, _) = array.data_type() {
+ if let DataType::Dictionary(key_type, _, _) = array.data_type() {
return self.eq_array_dictionary(array, index, key_type);
}
@@ -1479,38 +1373,38 @@
eq_array_primitive!(array, index, LargeStringArray, val)
}
ScalarValue::Binary(val) => {
- eq_array_primitive!(array, index, BinaryArray, val)
+ eq_array_primitive!(array, index, SmallBinaryArray, val)
}
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
}
ScalarValue::List(_, _) => unimplemented!(),
ScalarValue::Date32(val) => {
- eq_array_primitive!(array, index, Date32Array, val)
+ eq_array_primitive!(array, index, Int32Array, val)
}
ScalarValue::Date64(val) => {
- eq_array_primitive!(array, index, Date64Array, val)
+ eq_array_primitive!(array, index, Int64Array, val)
}
ScalarValue::TimestampSecond(val, _) => {
- eq_array_primitive!(array, index, TimestampSecondArray, val)
+ eq_array_primitive!(array, index, Int64Array, val)
}
ScalarValue::TimestampMillisecond(val, _) => {
- eq_array_primitive!(array, index, TimestampMillisecondArray, val)
+ eq_array_primitive!(array, index, Int64Array, val)
}
ScalarValue::TimestampMicrosecond(val, _) => {
- eq_array_primitive!(array, index, TimestampMicrosecondArray, val)
+ eq_array_primitive!(array, index, Int64Array, val)
}
ScalarValue::TimestampNanosecond(val, _) => {
- eq_array_primitive!(array, index, TimestampNanosecondArray, val)
+ eq_array_primitive!(array, index, Int64Array, val)
}
ScalarValue::IntervalYearMonth(val) => {
- eq_array_primitive!(array, index, IntervalYearMonthArray, val)
+ eq_array_primitive!(array, index, Int32Array, val)
}
ScalarValue::IntervalDayTime(val) => {
- eq_array_primitive!(array, index, IntervalDayTimeArray, val)
+ eq_array_primitive!(array, index, DaysMsArray, val)
}
ScalarValue::IntervalMonthDayNano(val) => {
- eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
+ eq_array_primitive!(array, index, Int128Array, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
}
@@ -1522,18 +1416,17 @@
&self,
array: &ArrayRef,
index: usize,
- key_type: &DataType,
+ key_type: &IntegerType,
) -> bool {
let (values, values_index) = match key_type {
- DataType::Int8 => get_dict_value::<Int8Type>(array, index).unwrap(),
- DataType::Int16 => get_dict_value::<Int16Type>(array, index).unwrap(),
- DataType::Int32 => get_dict_value::<Int32Type>(array, index).unwrap(),
- DataType::Int64 => get_dict_value::<Int64Type>(array, index).unwrap(),
- DataType::UInt8 => get_dict_value::<UInt8Type>(array, index).unwrap(),
- DataType::UInt16 => get_dict_value::<UInt16Type>(array, index).unwrap(),
- DataType::UInt32 => get_dict_value::<UInt32Type>(array, index).unwrap(),
- DataType::UInt64 => get_dict_value::<UInt64Type>(array, index).unwrap(),
- _ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
+ IntegerType::Int8 => get_dict_value::<i8>(array, index).unwrap(),
+ IntegerType::Int16 => get_dict_value::<i16>(array, index).unwrap(),
+ IntegerType::Int32 => get_dict_value::<i32>(array, index).unwrap(),
+ IntegerType::Int64 => get_dict_value::<i64>(array, index).unwrap(),
+ IntegerType::UInt8 => get_dict_value::<u8>(array, index).unwrap(),
+ IntegerType::UInt16 => get_dict_value::<u16>(array, index).unwrap(),
+ IntegerType::UInt32 => get_dict_value::<u32>(array, index).unwrap(),
+ IntegerType::UInt64 => get_dict_value::<u64>(array, index).unwrap(),
};
match values_index {
@@ -1689,6 +1582,123 @@
impl_try_from!(Float64, f64);
impl_try_from!(Boolean, bool);
+/// Converts a [`ScalarValue`] reference into a boxed arrow2 scalar (`Box<dyn Scalar>`).
+///
+/// Boolean, integer, float, string, binary, date, timestamp (timezone kept)
+/// and year-month interval variants are converted. `Decimal128`, `List`,
+/// `Struct`, `IntervalDayTime` and `IntervalMonthDayNano` are not handled and
+/// yield `DataFusionError::Internal`.
+impl TryInto<Box<dyn Scalar>> for &ScalarValue {
+ type Error = DataFusionError;
+
+ fn try_into(self) -> Result<Box<dyn Scalar>> {
+ use arrow::scalar::*;
+ match self {
+ ScalarValue::Boolean(b) => Ok(Box::new(BooleanScalar::new(*b))),
+ ScalarValue::Float32(f) => {
+ Ok(Box::new(PrimitiveScalar::<f32>::new(DataType::Float32, *f)))
+ }
+ ScalarValue::Float64(f) => {
+ Ok(Box::new(PrimitiveScalar::<f64>::new(DataType::Float64, *f)))
+ }
+ ScalarValue::Int8(i) => {
+ Ok(Box::new(PrimitiveScalar::<i8>::new(DataType::Int8, *i)))
+ }
+ ScalarValue::Int16(i) => {
+ Ok(Box::new(PrimitiveScalar::<i16>::new(DataType::Int16, *i)))
+ }
+ ScalarValue::Int32(i) => {
+ Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Int32, *i)))
+ }
+ ScalarValue::Int64(i) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Int64, *i)))
+ }
+ ScalarValue::UInt8(u) => {
+ Ok(Box::new(PrimitiveScalar::<u8>::new(DataType::UInt8, *u)))
+ }
+ ScalarValue::UInt16(u) => {
+ Ok(Box::new(PrimitiveScalar::<u16>::new(DataType::UInt16, *u)))
+ }
+ ScalarValue::UInt32(u) => {
+ Ok(Box::new(PrimitiveScalar::<u32>::new(DataType::UInt32, *u)))
+ }
+ ScalarValue::UInt64(u) => {
+ Ok(Box::new(PrimitiveScalar::<u64>::new(DataType::UInt64, *u)))
+ }
+ ScalarValue::Utf8(s) => Ok(Box::new(Utf8Scalar::<i32>::new(s.clone()))),
+ ScalarValue::LargeUtf8(s) => Ok(Box::new(Utf8Scalar::<i64>::new(s.clone()))),
+ ScalarValue::Binary(b) => Ok(Box::new(BinaryScalar::<i32>::new(b.clone()))),
+ ScalarValue::LargeBinary(b) => {
+ Ok(Box::new(BinaryScalar::<i64>::new(b.clone())))
+ }
+ ScalarValue::Date32(i) => {
+ Ok(Box::new(PrimitiveScalar::<i32>::new(DataType::Date32, *i)))
+ }
+ ScalarValue::Date64(i) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(DataType::Date64, *i)))
+ }
+ ScalarValue::TimestampSecond(i, tz) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(
+ DataType::Timestamp(TimeUnit::Second, tz.clone()),
+ *i,
+ )))
+ }
+ ScalarValue::TimestampMillisecond(i, tz) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(
+ DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
+ *i,
+ )))
+ }
+ ScalarValue::TimestampMicrosecond(i, tz) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(
+ DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
+ *i,
+ )))
+ }
+ ScalarValue::TimestampNanosecond(i, tz) => {
+ Ok(Box::new(PrimitiveScalar::<i64>::new(
+ DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
+ *i,
+ )))
+ }
+ ScalarValue::IntervalYearMonth(i) => {
+ Ok(Box::new(PrimitiveScalar::<i32>::new(
+ DataType::Interval(IntervalUnit::YearMonth),
+ *i,
+ )))
+ }
+
+ // Remaining variants (Decimal128, List, Struct, IntervalDayTime and
+ // IntervalMonthDayNano) have no conversion implemented here.
+ _ => Err(DataFusionError::Internal(
+ "Conversion not possible in arrow2".to_owned(),
+ )),
+ }
+ }
+}
+
+/// Converts an arrow2 `PrimitiveScalar` back into a [`ScalarValue`].
+///
+/// Only timestamp scalars are handled (all four time units; the timezone is
+/// carried over). Any other logical type yields `DataFusionError::Internal`.
+impl<T: NativeType> TryFrom<PrimitiveScalar<T>> for ScalarValue {
+ type Error = DataFusionError;
+
+ fn try_from(s: PrimitiveScalar<T>) -> Result<ScalarValue> {
+ let (unit, tz) = match s.data_type() {
+ DataType::Timestamp(unit, tz) => (unit, tz.clone()),
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}",
+ s
+ )));
+ }
+ };
+ // Timestamp values are stored as i64, so this downcast is expected to succeed.
+ let value = s
+ .as_any()
+ .downcast_ref::<PrimitiveScalar<i64>>()
+ .unwrap()
+ .value();
+ Ok(match unit {
+ TimeUnit::Second => ScalarValue::TimestampSecond(value, tz),
+ TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz),
+ TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz),
+ TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz),
+ })
+ }
+}
+
impl TryFrom<&DataType> for ScalarValue {
type Error = DataFusionError;
@@ -1725,7 +1735,7 @@
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
ScalarValue::TimestampNanosecond(None, tz_opt.clone())
}
- DataType::Dictionary(_index_type, value_type) => {
+ DataType::Dictionary(_index_type, value_type, _) => {
value_type.as_ref().try_into()?
}
DataType::List(ref nested_type) => {
@@ -1896,39 +1906,3 @@
}
}
}
-
-/// Trait used to map a NativeTime to a ScalarType.
-pub trait ScalarType<T: ArrowNativeType> {
- /// returns a scalar from an optional T
- fn scalar(r: Option<T>) -> ScalarValue;
-}
-
-impl ScalarType<f32> for Float32Type {
- fn scalar(r: Option<f32>) -> ScalarValue {
- ScalarValue::Float32(r)
- }
-}
-
-impl ScalarType<i64> for TimestampSecondType {
- fn scalar(r: Option<i64>) -> ScalarValue {
- ScalarValue::TimestampSecond(r, None)
- }
-}
-
-impl ScalarType<i64> for TimestampMillisecondType {
- fn scalar(r: Option<i64>) -> ScalarValue {
- ScalarValue::TimestampMillisecond(r, None)
- }
-}
-
-impl ScalarType<i64> for TimestampMicrosecondType {
- fn scalar(r: Option<i64>) -> ScalarValue {
- ScalarValue::TimestampMicrosecond(r, None)
- }
-}
-
-impl ScalarType<i64> for TimestampNanosecondType {
- fn scalar(r: Option<i64>) -> ScalarValue {
- ScalarValue::TimestampNanosecond(r, None)
- }
-}
diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index e61e044..599dc09 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -34,7 +34,8 @@
required-features = ["datafusion/avro"]
[dev-dependencies]
-arrow-flight = { version = "10.0" }
+arrow-format = { version = "0.4", features = ["flight-service", "flight-data"] }
+arrow = { package = "arrow2", version="0.10", features = ["io_ipc", "io_flight"] }
datafusion = { path = "../datafusion" }
prost = "0.9"
tonic = "0.6"
diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs
index f08c12b..b819f2b 100644
--- a/datafusion-examples/examples/avro_sql.rs
+++ b/datafusion-examples/examples/avro_sql.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion::arrow::util::pretty;
+use datafusion::arrow_print;
use datafusion::error::Result;
use datafusion::prelude::*;
@@ -27,7 +27,7 @@
// create local execution context
let mut ctx = ExecutionContext::new();
- let testdata = datafusion::arrow::util::test_util::arrow_test_data();
+ let testdata = datafusion::test_util::arrow_test_data();
// register avro file with the execution context
let avro_file = &format!("{}/avro/alltypes_plain.avro", testdata);
@@ -45,7 +45,7 @@
let results = df.collect().await?;
// print the results
- pretty::print_batches(&results)?;
+ println!("{}", arrow_print::write(&results));
Ok(())
}
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index aad153a..6dadb05 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::array::{MutableArray, UInt64Vec, UInt8Vec};
use async_trait::async_trait;
-use datafusion::arrow::array::{Array, UInt64Builder, UInt8Builder};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::dataframe_impl::DataFrameImpl;
use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::field_util::SchemaExt;
use datafusion::logical_plan::{Expr, LogicalPlanBuilder};
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::memory::MemoryStream;
@@ -30,6 +30,7 @@
project_schema, ExecutionPlan, SendableRecordBatchStream, Statistics,
};
use datafusion::prelude::*;
+use datafusion::record_batch::RecordBatch;
use std::any::Any;
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
@@ -242,21 +243,18 @@
db.data.values().cloned().collect()
};
- let mut id_array = UInt8Builder::new(users.len());
- let mut account_array = UInt64Builder::new(users.len());
+ let mut id_array = UInt8Vec::with_capacity(users.len());
+ let mut account_array = UInt64Vec::with_capacity(users.len());
for user in users {
- id_array.append_value(user.id)?;
- account_array.append_value(user.bank_account)?;
+ id_array.push(Some(user.id));
+ account_array.push(Some(user.bank_account));
}
return Ok(Box::pin(MemoryStream::try_new(
vec![RecordBatch::try_new(
self.projected_schema.clone(),
- vec![
- Arc::new(id_array.finish()),
- Arc::new(account_array.finish()),
- ],
+ vec![id_array.as_arc(), account_array.as_arc()],
)?],
self.schema(),
None,
diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs
index 6fd3461..1d5b496 100644
--- a/datafusion-examples/examples/dataframe.rs
+++ b/datafusion-examples/examples/dataframe.rs
@@ -25,7 +25,7 @@
// create local execution context
let mut ctx = ExecutionContext::new();
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
let filename = &format!("{}/alltypes_plain.parquet", testdata);
diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs
index e17c69e..b00bfda 100644
--- a/datafusion-examples/examples/dataframe_in_memory.rs
+++ b/datafusion-examples/examples/dataframe_in_memory.rs
@@ -17,12 +17,13 @@
use std::sync::Arc;
-use datafusion::arrow::array::{Int32Array, StringArray};
+use datafusion::arrow::array::{Int32Array, Utf8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
-use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::record_batch::RecordBatch;
+
use datafusion::datasource::MemTable;
use datafusion::error::Result;
-use datafusion::from_slice::FromSlice;
+use datafusion::field_util::SchemaExt;
use datafusion::prelude::*;
/// This example demonstrates how to use the DataFrame API against in-memory data.
@@ -38,8 +39,8 @@
let batch = RecordBatch::try_new(
schema.clone(),
vec![
- Arc::new(StringArray::from_slice(&["a", "b", "c", "d"])),
- Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])),
+ Arc::new(Utf8Array::<i32>::from_slice(&["a", "b", "c", "d"])),
+ Arc::new(Int32Array::from_values(vec![1, 10, 10, 100])),
],
)?;
diff --git a/datafusion-examples/examples/flight_client.rs b/datafusion-examples/examples/flight_client.rs
index 6fc8014..5b8304c 100644
--- a/datafusion-examples/examples/flight_client.rs
+++ b/datafusion-examples/examples/flight_client.rs
@@ -15,23 +15,22 @@
// specific language governing permissions and limitations
// under the License.
-use std::convert::TryFrom;
use std::sync::Arc;
-use datafusion::arrow::datatypes::Schema;
-
-use arrow_flight::flight_descriptor;
-use arrow_flight::flight_service_client::FlightServiceClient;
-use arrow_flight::utils::flight_data_to_arrow_batch;
-use arrow_flight::{FlightDescriptor, Ticket};
-use datafusion::arrow::util::pretty;
+use arrow::io::flight::deserialize_schemas;
+use arrow_format::flight::data::{flight_descriptor, FlightDescriptor, Ticket};
+use arrow_format::flight::service::flight_service_client::FlightServiceClient;
+use datafusion::arrow_print;
+use datafusion::field_util::SchemaExt;
+use datafusion::record_batch::RecordBatch;
+use std::collections::HashMap;
/// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for
/// Parquet files and executing SQL queries against them on a remote server.
/// This example is run along-side the example `flight_server`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
// Create Flight client
let mut client = FlightServiceClient::connect("http://localhost:50051").await?;
@@ -44,7 +43,8 @@
});
let schema_result = client.get_schema(request).await?.into_inner();
- let schema = Schema::try_from(&schema_result)?;
+ let (schema, _) = deserialize_schemas(schema_result.schema.as_slice()).unwrap();
+ let schema = Arc::new(schema);
println!("Schema: {:?}", schema);
// Call do_get to execute a SQL query and receive results
@@ -57,23 +57,26 @@
// the schema should be the first message returned, else client should error
let flight_data = stream.message().await?.unwrap();
// convert FlightData to a stream
- let schema = Arc::new(Schema::try_from(&flight_data)?);
+ let (schema, ipc_schema) =
+ deserialize_schemas(flight_data.data_body.as_slice()).unwrap();
+ let schema = Arc::new(schema);
println!("Schema: {:?}", schema);
// all the remaining stream messages should be dictionary and record batches
let mut results = vec![];
- let dictionaries_by_field = vec![None; schema.fields().len()];
+ let dictionaries_by_field = HashMap::new();
while let Some(flight_data) = stream.message().await? {
- let record_batch = flight_data_to_arrow_batch(
+ let chunk = arrow::io::flight::deserialize_batch(
&flight_data,
- schema.clone(),
+ schema.fields(),
+ &ipc_schema,
&dictionaries_by_field,
)?;
- results.push(record_batch);
+ results.push(RecordBatch::new_with_chunk(&schema, chunk));
}
// print the results
- pretty::print_batches(&results)?;
+ println!("{}", arrow_print::write(&results));
Ok(())
}
diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs
index c26dcce..b616cfb 100644
--- a/datafusion-examples/examples/flight_server.rs
+++ b/datafusion-examples/examples/flight_server.rs
@@ -15,10 +15,10 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::chunk::Chunk;
use std::pin::Pin;
use std::sync::Arc;
-use arrow_flight::SchemaAsIpc;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::datasource::object_store::local::LocalFileSystem;
@@ -28,11 +28,14 @@
use datafusion::prelude::*;
-use arrow_flight::{
- flight_service_server::FlightService, flight_service_server::FlightServiceServer,
+use arrow::io::ipc::write::WriteOptions;
+use arrow_format::flight::data::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
};
+use arrow_format::flight::service::flight_service_server::{
+ FlightService, FlightServiceServer,
+};
#[derive(Clone)]
pub struct FlightServiceImpl {}
@@ -50,7 +53,7 @@
Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + Sync + 'static>>;
type DoActionStream = Pin<
Box<
- dyn Stream<Item = Result<arrow_flight::Result, Status>>
+ dyn Stream<Item = Result<arrow_format::flight::data::Result, Status>>
+ Send
+ Sync
+ 'static,
@@ -74,8 +77,8 @@
.await
.unwrap();
- let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default();
- let schema_result = SchemaAsIpc::new(&schema, &options).into();
+ let schema_result =
+ arrow::io::flight::serialize_schema_to_result(schema.as_ref(), None);
Ok(Response::new(schema_result))
}
@@ -92,7 +95,7 @@
// create local execution context
let mut ctx = ExecutionContext::new();
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
// register parquet file with the execution context
ctx.register_parquet(
@@ -112,20 +115,21 @@
}
// add an initial FlightData message that sends schema
- let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default();
- let schema_flight_data =
- SchemaAsIpc::new(&df.schema().clone().into(), &options).into();
+ let options = WriteOptions::default();
+ let schema_flight_data = arrow::io::flight::serialize_schema(
+ &df.schema().clone().into(),
+ None,
+ );
let mut flights: Vec<Result<FlightData, Status>> =
vec![Ok(schema_flight_data)];
let mut batches: Vec<Result<FlightData, Status>> = results
- .iter()
+ .into_iter()
.flat_map(|batch| {
+ let chunk = Chunk::new(batch.columns().to_vec());
let (flight_dictionaries, flight_batch) =
- arrow_flight::utils::flight_data_from_arrow_batch(
- batch, &options,
- );
+ arrow::io::flight::serialize_batch(&chunk, &[], &options);
flight_dictionaries
.into_iter()
.chain(std::iter::once(flight_batch))
diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs
index e113d98..4c63520 100644
--- a/datafusion-examples/examples/memtable.rs
+++ b/datafusion-examples/examples/memtable.rs
@@ -17,10 +17,11 @@
use datafusion::arrow::array::{UInt64Array, UInt8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
+use datafusion::field_util::SchemaExt;
use datafusion::prelude::ExecutionContext;
+use datafusion::record_batch::RecordBatch;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
@@ -56,8 +57,8 @@
}
fn create_record_batch() -> Result<RecordBatch> {
- let id_array = UInt8Array::from(vec![1]);
- let account_array = UInt64Array::from(vec![9000]);
+ let id_array = UInt8Array::from_slice(vec![1]);
+ let account_array = UInt64Array::from_slice(vec![9000]);
Result::Ok(
RecordBatch::try_new(
diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs
index e74ed39..7f7a976 100644
--- a/datafusion-examples/examples/parquet_sql.rs
+++ b/datafusion-examples/examples/parquet_sql.rs
@@ -25,7 +25,7 @@
// create local execution context
let mut ctx = ExecutionContext::new();
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
// register parquet file with the execution context
ctx.register_parquet(
diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs
index 7485bc7..a8c9b64 100644
--- a/datafusion-examples/examples/parquet_sql_multiple_files.rs
+++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs
@@ -30,7 +30,7 @@
// create local execution context
let mut ctx = ExecutionContext::new();
- let testdata = datafusion::arrow::util::test_util::parquet_test_data();
+ let testdata = datafusion::test_util::parquet_test_data();
// Configure listing options
let file_format = ParquetFormat::default().with_enable_pruning(true);
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index 3acace2..15c85bc 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -17,12 +17,11 @@
/// In this example we will declare a single-type, single return type UDAF that computes the geometric mean.
/// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean
-use datafusion::arrow::{
- array::ArrayRef, array::Float32Array, array::Float64Array, datatypes::DataType,
- record_batch::RecordBatch,
-};
+use datafusion::arrow::{array::Float32Array, array::Float64Array, datatypes::DataType};
+use datafusion::record_batch::RecordBatch;
-use datafusion::from_slice::FromSlice;
+use arrow::array::ArrayRef;
+use datafusion::field_util::SchemaExt;
use datafusion::physical_plan::functions::Volatility;
use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
use datafusion::{prelude::*, scalar::ScalarValue};
diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs
index 33242c7..e30bd39 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -15,17 +15,16 @@
// specific language governing permissions and limitations
// under the License.
+use datafusion::field_util::SchemaExt;
+use datafusion::prelude::*;
+use datafusion::record_batch::RecordBatch;
use datafusion::{
arrow::{
array::{ArrayRef, Float32Array, Float64Array},
datatypes::DataType,
- record_batch::RecordBatch,
},
physical_plan::functions::Volatility,
};
-
-use datafusion::from_slice::FromSlice;
-use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use std::sync::Arc;
@@ -43,8 +42,8 @@
let batch = RecordBatch::try_new(
schema.clone(),
vec![
- Arc::new(Float32Array::from_slice(&[2.1, 3.1, 4.1, 5.1])),
- Arc::new(Float64Array::from_slice(&[1.0, 2.0, 3.0, 4.0])),
+ Arc::new(Float32Array::from_values(vec![2.1, 3.1, 4.1, 5.1])),
+ Arc::new(Float64Array::from_values(vec![1.0, 2.0, 3.0, 4.0])),
],
)?;
@@ -92,7 +91,7 @@
match (base, exponent) {
// in arrow, any value can be null.
// Here we decide to make our UDF to return null when either base or exponent is null.
- (Some(base), Some(exponent)) => Some(base.powf(exponent)),
+ (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
_ => None,
}
})
diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml
index 4609d1b..a021197 100644
--- a/datafusion-expr/Cargo.toml
+++ b/datafusion-expr/Cargo.toml
@@ -36,6 +36,6 @@
[dependencies]
datafusion-common = { path = "../datafusion-common", version = "7.0.0" }
-arrow = { version = "10.0", features = ["prettyprint"] }
+arrow = { package = "arrow2", version = "0.10", default-features = false }
sqlparser = "0.15"
ahash = { version = "0.7", default-features = false }
diff --git a/datafusion-expr/src/columnar_value.rs b/datafusion-expr/src/columnar_value.rs
index 4867c0e..f78964a 100644
--- a/datafusion-expr/src/columnar_value.rs
+++ b/datafusion-expr/src/columnar_value.rs
@@ -17,12 +17,14 @@
//! Columnar value module contains a set of types that represent a columnar value.
+use std::sync::Arc;
+
use arrow::array::ArrayRef;
use arrow::array::NullArray;
use arrow::datatypes::DataType;
-use arrow::record_batch::RecordBatch;
+
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
-use std::sync::Arc;
/// Represents the result from an expression
#[derive(Clone)]
@@ -57,6 +59,9 @@
impl From<&RecordBatch> for NullColumnarValue {
fn from(batch: &RecordBatch) -> Self {
let num_rows = batch.num_rows();
- ColumnarValue::Array(Arc::new(NullArray::new(num_rows)))
+        // Only the row count of `batch` is carried by this value. Its logical
+        // type must be `DataType::Null` to agree with the physical `NullArray`;
+        // a `Struct` type here could mislead code that dispatches on `data_type()`.
+        ColumnarValue::Array(Arc::new(NullArray::new_null(DataType::Null, num_rows)))
}
}
diff --git a/datafusion-physical-expr/Cargo.toml b/datafusion-physical-expr/Cargo.toml
index 90a560e..fc3f225 100644
--- a/datafusion-physical-expr/Cargo.toml
+++ b/datafusion-physical-expr/Cargo.toml
@@ -41,7 +41,7 @@
[dependencies]
datafusion-common = { path = "../datafusion-common", version = "7.0.0" }
datafusion-expr = { path = "../datafusion-expr", version = "7.0.0" }
-arrow = { version = "10.0", features = ["prettyprint"] }
+arrow = { package = "arrow2", version = "0.10" }
paste = "^1.0"
ahash = { version = "0.7", default-features = false }
ordered-float = "2.10"
diff --git a/datafusion-physical-expr/src/array_expressions.rs b/datafusion-physical-expr/src/array_expressions.rs
index ca396d0..19d7535 100644
--- a/datafusion-physical-expr/src/array_expressions.rs
+++ b/datafusion-physical-expr/src/array_expressions.rs
@@ -21,66 +21,92 @@
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
-use std::sync::Arc;
-macro_rules! downcast_vec {
- ($ARGS:expr, $ARRAY_TYPE:ident) => {{
- $ARGS
- .iter()
- .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
- Some(array) => Ok(array),
- _ => Err(DataFusionError::Internal("failed to downcast".to_string())),
- })
- }};
-}
+fn array_array(arrays: &[&dyn Array]) -> Result<ArrayRef> {
+ assert!(!arrays.is_empty());
+ let first = arrays[0];
+ assert!(arrays.iter().all(|x| x.len() == first.len()));
+ assert!(arrays.iter().all(|x| x.data_type() == first.data_type()));
-macro_rules! array {
- ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
- // downcast all arguments to their common format
- let args =
- downcast_vec!($ARGS, $ARRAY_TYPE).collect::<Result<Vec<&$ARRAY_TYPE>>>()?;
+ let size = arrays.len();
- let mut builder = FixedSizeListBuilder::<$BUILDER_TYPE>::new(
- <$BUILDER_TYPE>::new(args[0].len()),
- args.len() as i32,
- );
- // for each entry in the array
- for index in 0..args[0].len() {
- for arg in &args {
- if arg.is_null(index) {
- builder.values().append_null()?;
- } else {
- builder.values().append_value(arg.value(index))?;
- }
- }
- builder.append(true)?;
- }
- Ok(Arc::new(builder.finish()))
- }};
-}
-
-fn array_array(args: &[&dyn Array]) -> Result<ArrayRef> {
- // do not accept 0 arguments.
- if args.is_empty() {
- return Err(DataFusionError::Internal(
- "array requires at least one argument".to_string(),
- ));
+ macro_rules! array {
+ ($PRIMITIVE: ty, $ARRAY: ty, $DATA_TYPE: path) => {{
+ let array = MutablePrimitiveArray::<$PRIMITIVE>::with_capacity_from(
+ first.len() * size,
+ $DATA_TYPE,
+ );
+ let mut array = MutableFixedSizeListArray::new(array, size);
+ array.try_extend(
+ // for each entry in the array
+ (0..first.len()).map(|idx| {
+ Some(arrays.iter().map(move |arg| {
+ let arg = arg.as_any().downcast_ref::<$ARRAY>().unwrap();
+ if arg.is_null(idx) {
+ None
+ } else {
+ Some(arg.value(idx))
+ }
+ }))
+ }),
+ )?;
+ Ok(array.as_arc())
+ }};
}
- match args[0].data_type() {
- DataType::Utf8 => array!(args, StringArray, StringBuilder),
- DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder),
- DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
- DataType::Float32 => array!(args, Float32Array, Float32Builder),
- DataType::Float64 => array!(args, Float64Array, Float64Builder),
- DataType::Int8 => array!(args, Int8Array, Int8Builder),
- DataType::Int16 => array!(args, Int16Array, Int16Builder),
- DataType::Int32 => array!(args, Int32Array, Int32Builder),
- DataType::Int64 => array!(args, Int64Array, Int64Builder),
- DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
- DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
- DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
- DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
+ macro_rules! array_string {
+ ($OFFSET: ty) => {{
+ let array = MutableUtf8Array::<$OFFSET>::with_capacity(first.len() * size);
+ let mut array = MutableFixedSizeListArray::new(array, size);
+ array.try_extend(
+ // for each entry in the array
+ (0..first.len()).map(|idx| {
+ Some(arrays.iter().map(move |arg| {
+ let arg =
+ arg.as_any().downcast_ref::<Utf8Array<$OFFSET>>().unwrap();
+ if arg.is_null(idx) {
+ None
+ } else {
+ Some(arg.value(idx))
+ }
+ }))
+ }),
+ )?;
+ Ok(array.as_arc())
+ }};
+ }
+
+ match first.data_type() {
+ DataType::Boolean => {
+ let array = MutableBooleanArray::with_capacity(first.len() * size);
+ let mut array = MutableFixedSizeListArray::new(array, size);
+ array.try_extend(
+ // for each entry in the array
+ (0..first.len()).map(|idx| {
+ Some(arrays.iter().map(move |arg| {
+ let arg = arg.as_any().downcast_ref::<BooleanArray>().unwrap();
+ if arg.is_null(idx) {
+ None
+ } else {
+ Some(arg.value(idx))
+ }
+ }))
+ }),
+ )?;
+ Ok(array.as_arc())
+ }
+ DataType::UInt8 => array!(u8, PrimitiveArray<u8>, DataType::UInt8),
+ DataType::UInt16 => array!(u16, PrimitiveArray<u16>, DataType::UInt16),
+ DataType::UInt32 => array!(u32, PrimitiveArray<u32>, DataType::UInt32),
+ DataType::UInt64 => array!(u64, PrimitiveArray<u64>, DataType::UInt64),
+ DataType::Int8 => array!(i8, PrimitiveArray<i8>, DataType::Int8),
+ DataType::Int16 => array!(i16, PrimitiveArray<i16>, DataType::Int16),
+ DataType::Int32 => array!(i32, PrimitiveArray<i32>, DataType::Int32),
+ DataType::Int64 => array!(i64, PrimitiveArray<i64>, DataType::Int64),
+ DataType::Float32 => array!(f32, PrimitiveArray<f32>, DataType::Float32),
+ DataType::Float64 => array!(f64, PrimitiveArray<f64>, DataType::Float64),
+ DataType::Utf8 => array_string!(i32),
+ DataType::LargeUtf8 => array_string!(i64),
data_type => Err(DataFusionError::NotImplemented(format!(
"Array is not implemented for type '{:?}'.",
data_type
@@ -109,6 +135,8 @@
/// Currently supported types by the array function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
+// `array` supports all types, but we do not have a signature to correctly
+// coerce them.
pub static SUPPORTED_ARRAY_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
diff --git a/datafusion-physical-expr/src/arrow_temporal_util.rs b/datafusion-physical-expr/src/arrow_temporal_util.rs
new file mode 100644
index 0000000..fdc8418
--- /dev/null
+++ b/datafusion-physical-expr/src/arrow_temporal_util.rs
@@ -0,0 +1,349 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::error::{ArrowError, Result};
+use chrono::{prelude::*, LocalResult};
+
+/// Accepts a string in RFC3339 / ISO8601 standard format and some
+/// variants and converts it to a nanosecond precision timestamp.
+///
+/// Implements the `to_timestamp` function to convert a string to a
+/// timestamp, following the model of Spark SQL's `to_timestamp`.
+///
+/// In addition to RFC3339 / ISO8601 standard timestamps, it also
+/// accepts strings that use a space ` ` to separate the date and time
+/// as well as strings that have no explicit timezone offset.
+///
+/// Examples of accepted inputs:
+/// * `1997-01-31T09:26:56.123Z` # RFC3339
+/// * `1997-01-31T09:26:56.123-05:00` # RFC3339
+/// * `1997-01-31 09:26:56.123-05:00` # close to RFC3339 but with a space rather than T
+/// * `1997-01-31T09:26:56.123` # close to RFC3339 but no timezone offset specified
+/// * `1997-01-31 09:26:56.123` # close to RFC3339 but uses a space and no timezone offset
+/// * `1997-01-31 09:26:56` # close to RFC3339, no fractional seconds
+///
+/// Fractional seconds are always a decimal fraction of a second
+/// (`.123` is 123 milliseconds), whether or not a timezone offset is
+/// present.
+///
+/// Internally, this function uses the `chrono` library for the
+/// datetime parsing.
+///
+/// We hope to extend this function in the future with a second
+/// parameter to specify the format string.
+///
+/// ## Timestamp Precision
+///
+/// Function uses the maximum precision timestamps supported by
+/// Arrow (nanoseconds stored as a 64-bit integer) timestamps. This
+/// means the range of dates that timestamps can represent is ~1677 AD
+/// to 2262 AD. Inputs outside of that range are reported as an error.
+///
+/// ## Timezone / Offset Handling
+///
+/// Numerical values of timestamps are stored compared to offset UTC.
+///
+/// This function interprets strings without an explicit time zone as
+/// timestamps with offsets of the local time on the machine.
+///
+/// For example, `1997-01-31 09:26:56.123Z` is interpreted as UTC, as
+/// it has an explicit timezone specifier ("Z" for Zulu/UTC).
+///
+/// `1997-01-31T09:26:56.123` is interpreted as a local timestamp in
+/// the timezone of the machine. For example, if
+/// the system timezone is set to America/New_York (UTC-5) the
+/// timestamp will be interpreted as though it were
+/// `1997-01-31T09:26:56.123-05:00`
+///
+/// TODO: remove this hack and redesign DataFusion's time related API, with regard to timezone.
+#[inline]
+pub(crate) fn string_to_timestamp_nanos(s: &str) -> Result<i64> {
+    // Fast path: RFC3339 timestamp (with a T)
+    // Example: 2020-09-08T13:42:29.190855Z
+    if let Ok(ts) = DateTime::parse_from_rfc3339(s) {
+        return datetime_to_nanos(s, &ts);
+    }
+
+    // Implement quasi-RFC3339 support by trying to parse the
+    // timestamp with various other format specifiers to support
+    // separating the date and time with a space ' ' rather than 'T' to be
+    // (more) compatible with Apache Spark SQL
+
+    // timezone offset, using ' ' as a separator
+    // Example: 2020-09-08 13:42:29.190855-05:00
+    if let Ok(ts) = DateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f%:z") {
+        return datetime_to_nanos(s, &ts);
+    }
+
+    // with an explicit Z, using ' ' as a separator
+    // Example: 2020-09-08 13:42:29Z
+    if let Ok(ts) = Utc.datetime_from_str(s, "%Y-%m-%d %H:%M:%S%.fZ") {
+        return datetime_to_nanos(s, &ts);
+    }
+
+    // Support timestamps without an explicit timezone offset, again
+    // to be compatible with what Apache Spark SQL does. These are
+    // interpreted as local time.
+    //
+    // `%.f` (rather than `.%f`) is essential: `%f` reads the digits as an
+    // integer count of nanoseconds, so `.190855` would become 190855ns
+    // instead of 190855000ns and disagree with the formats above. `%.f`
+    // also accepts a missing fractional part (e.g. `2020-09-08T13:42:29`).
+
+    // using T as a separator
+    // Example: 2020-09-08T13:42:29.190855
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // using ' ' as a separator
+    // Example: 2020-09-08 13:42:29.190855
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // Note we don't pass along the error message from the underlying
+    // chrono parsing because we tried several different format
+    // strings and we don't know which the user was trying to
+    // match. Thus any of the specific error messages is likely to
+    // be more confusing than helpful
+    Err(ArrowError::OutOfSpec(format!(
+        "Error parsing '{}' as timestamp",
+        s
+    )))
+}
+
+/// Converts a timezone-aware datetime to nanoseconds since the Unix epoch.
+///
+/// Returns an error (mentioning the original input `s`) if the instant
+/// cannot be represented as an `i64` number of nanoseconds.
+fn datetime_to_nanos<Tz: TimeZone>(s: &str, datetime: &DateTime<Tz>) -> Result<i64> {
+    // Compute in i128 so the multiplication itself can never overflow;
+    // only the final narrowing to i64 needs a range check.
+    let nanos = i128::from(datetime.timestamp()) * 1_000_000_000
+        + i128::from(datetime.timestamp_subsec_nanos());
+    let representable = i128::from(i64::MIN)..=i128::from(i64::MAX);
+    if !representable.contains(&nanos) {
+        return Err(ArrowError::OutOfSpec(format!(
+            "Error parsing '{}' as timestamp: out of range for nanosecond precision",
+            s
+        )));
+    }
+    // Lossless: `nanos` was just checked to fit in an i64.
+    Ok(nanos as i64)
+}
+
+/// Converts the naive datetime (which has no specific timezone) to a
+/// nanosecond epoch timestamp relative to UTC, interpreting it in the
+/// local timezone of the machine.
+fn naive_datetime_to_timestamp(s: &str, datetime: NaiveDateTime) -> Result<i64> {
+    match Local.from_local_datetime(&datetime) {
+        LocalResult::None => Err(ArrowError::OutOfSpec(format!(
+            "Error parsing '{}' as timestamp: local time representation is invalid",
+            s
+        ))),
+        LocalResult::Single(local_datetime) => datetime_to_nanos(s, &local_datetime),
+        // Ambiguous times can happen if the timestamp is exactly when
+        // a daylight savings time transition occurs, for example, and
+        // so the datetime could validly be said to be in two
+        // potential offsets. Either reading is valid; the first
+        // candidate reported by chrono is used.
+        LocalResult::Ambiguous(local_datetime, _) => {
+            datetime_to_nanos(s, &local_datetime)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn string_to_timestamp_timezone() -> Result<()> {
+        // Explicit timezone
+        assert_eq!(
+            1599572549190855000,
+            parse_timestamp("2020-09-08T13:42:29.190855+00:00")?
+        );
+        assert_eq!(
+            1599572549190855000,
+            parse_timestamp("2020-09-08T13:42:29.190855Z")?
+        );
+        assert_eq!(
+            1599572549000000000,
+            parse_timestamp("2020-09-08T13:42:29Z")?
+        ); // no fractional part
+        assert_eq!(
+            1599590549190855000,
+            parse_timestamp("2020-09-08T13:42:29.190855-05:00")?
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn string_to_timestamp_timezone_space() -> Result<()> {
+        // Ensure space rather than T between time and date is accepted
+        assert_eq!(
+            1599572549190855000,
+            parse_timestamp("2020-09-08 13:42:29.190855+00:00")?
+        );
+        assert_eq!(
+            1599572549190855000,
+            parse_timestamp("2020-09-08 13:42:29.190855Z")?
+        );
+        assert_eq!(
+            1599572549000000000,
+            parse_timestamp("2020-09-08 13:42:29Z")?
+        ); // no fractional part
+        assert_eq!(
+            1599590549190855000,
+            parse_timestamp("2020-09-08 13:42:29.190855-05:00")?
+        );
+        Ok(())
+    }
+
+    /// Interprets a naive_datetime (with no explicit timezone offset)
+    /// using the local timezone and returns the timestamp in UTC (0
+    /// offset)
+    fn expected_local_timestamp_nanos(naive_datetime: &NaiveDateTime) -> i64 {
+        // Note: Use chrono APIs that are different than
+        // naive_datetime_to_timestamp to compute the utc offset to
+        // try and double check the logic
+        let utc_offset_secs = match Local.offset_from_local_datetime(naive_datetime) {
+            LocalResult::Single(local_offset) => {
+                i64::from(local_offset.fix().local_minus_utc())
+            }
+            _ => panic!("Unexpected failure converting to local datetime"),
+        };
+        let utc_offset_nanos = utc_offset_secs * 1_000_000_000;
+        naive_datetime.timestamp_nanos() - utc_offset_nanos
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime
+    fn string_to_timestamp_no_timezone() -> Result<()> {
+        // This test is designed to succeed regardless of the local
+        // timezone the test machine is running. Thus it is still
+        // somewhat susceptible to bugs in the use of chrono
+        let naive_datetime = NaiveDateTime::new(
+            NaiveDate::from_ymd(2020, 9, 8),
+            NaiveTime::from_hms_nano(13, 42, 29, 190_855_000),
+        );
+
+        // Ensure both T and ' ' variants work
+        assert_eq!(
+            expected_local_timestamp_nanos(&naive_datetime),
+            parse_timestamp("2020-09-08T13:42:29.190855")?
+        );
+
+        assert_eq!(
+            expected_local_timestamp_nanos(&naive_datetime),
+            parse_timestamp("2020-09-08 13:42:29.190855")?
+        );
+
+        // Also ensure that parsing timestamps with no fractional
+        // second part works as well
+        let naive_datetime_whole_secs = NaiveDateTime::new(
+            NaiveDate::from_ymd(2020, 9, 8),
+            NaiveTime::from_hms(13, 42, 29),
+        );
+
+        // Ensure both T and ' ' variants work
+        assert_eq!(
+            expected_local_timestamp_nanos(&naive_datetime_whole_secs),
+            parse_timestamp("2020-09-08T13:42:29")?
+        );
+
+        assert_eq!(
+            expected_local_timestamp_nanos(&naive_datetime_whole_secs),
+            parse_timestamp("2020-09-08 13:42:29")?
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime
+    fn string_to_timestamp_no_timezone_fraction() -> Result<()> {
+        // Fractional seconds mean the same with or without an explicit
+        // offset: `.190855` is 190855 microseconds, not nanoseconds.
+        let whole_secs = parse_timestamp("2020-09-08T13:42:29")?;
+        assert_eq!(
+            whole_secs + 190_855_000,
+            parse_timestamp("2020-09-08T13:42:29.190855")?
+        );
+        assert_eq!(
+            whole_secs + 123_000_000,
+            parse_timestamp("2020-09-08 13:42:29.123")?
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn string_to_timestamp_out_of_range() -> Result<()> {
+        // The extremes of the i64 nanosecond range are still accepted...
+        assert_eq!(i64::MAX, parse_timestamp("2262-04-11T23:47:16.854775807Z")?);
+        assert_eq!(i64::MIN, parse_timestamp("1677-09-21T00:12:43.145224192Z")?);
+
+        // ...but anything beyond them is an error rather than a panic
+        expect_timestamp_parse_error(
+            "2262-04-11T23:47:16.854775808Z",
+            "Error parsing '2262-04-11T23:47:16.854775808Z' as timestamp",
+        );
+        expect_timestamp_parse_error(
+            "3000-01-01 00:00:00Z",
+            "Error parsing '3000-01-01 00:00:00Z' as timestamp",
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn string_to_timestamp_invalid() {
+        // Test parsing invalid formats
+
+        // It would be nice to make these messages better
+        expect_timestamp_parse_error("", "Error parsing '' as timestamp");
+        expect_timestamp_parse_error("SS", "Error parsing 'SS' as timestamp");
+        expect_timestamp_parse_error(
+            "Wed, 18 Feb 2015 23:16:09 GMT",
+            "Error parsing 'Wed, 18 Feb 2015 23:16:09 GMT' as timestamp",
+        );
+    }
+
+    // Parse a timestamp to timestamp int with a useful human readable error message
+    fn parse_timestamp(s: &str) -> Result<i64> {
+        let result = string_to_timestamp_nanos(s);
+        if let Err(e) = &result {
+            eprintln!("Error parsing timestamp '{}': {:?}", s, e);
+        }
+        result
+    }
+
+    fn expect_timestamp_parse_error(s: &str, expected_err: &str) {
+        match string_to_timestamp_nanos(s) {
+            Ok(v) => panic!(
+                "Expected error '{}' while parsing '{}', but parsed {} instead",
+                expected_err, s, v
+            ),
+            Err(e) => {
+                assert!(e.to_string().contains(expected_err),
+                    "Can not find expected error '{}' while parsing '{}'. Actual error '{}'",
+                    expected_err, s, e);
+            }
+        }
+    }
+}
diff --git a/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs b/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
index 279fe7d..b486b97 100644
--- a/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
+++ b/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
@@ -223,7 +223,7 @@
// min and max support the dictionary data type
// unpack the dictionary to get the value
match &input_types[0] {
- DataType::Dictionary(_, dict_value_type) => {
+ DataType::Dictionary(_, dict_value_type, _) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
diff --git a/datafusion-physical-expr/src/coercion_rule/binary_rule.rs b/datafusion-physical-expr/src/coercion_rule/binary_rule.rs
index ac23f2b..c1941ff 100644
--- a/datafusion-physical-expr/src/coercion_rule/binary_rule.rs
+++ b/datafusion-physical-expr/src/coercion_rule/binary_rule.rs
@@ -17,9 +17,9 @@
//! Coercion rules for matching argument types for binary operators
-use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
-use datafusion_common::DataFusionError;
+use arrow::datatypes::DataType;
use datafusion_common::Result;
+use datafusion_common::{DataFusionError, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
use datafusion_expr::Operator;
/// Coercion rules for all binary operators. Returns the output type
@@ -356,13 +356,13 @@
fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
match (lhs_type, rhs_type) {
(
- DataType::Dictionary(_lhs_index_type, lhs_value_type),
- DataType::Dictionary(_rhs_index_type, rhs_value_type),
+ DataType::Dictionary(_lhs_index_type, lhs_value_type, _),
+ DataType::Dictionary(_rhs_index_type, rhs_value_type, _),
) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
- (DataType::Dictionary(_index_type, value_type), _) => {
+ (DataType::Dictionary(_index_type, value_type, _), _) => {
dictionary_value_coercion(value_type, rhs_type)
}
- (_, DataType::Dictionary(_index_type, value_type)) => {
+ (_, DataType::Dictionary(_index_type, value_type, _)) => {
dictionary_value_coercion(lhs_type, value_type)
}
_ => None,
@@ -429,7 +429,7 @@
(TimeUnit::Nanosecond, TimeUnit::Microsecond) => TimeUnit::Microsecond,
(l, r) => {
assert_eq!(l, r);
- l.clone()
+ *l
}
};
@@ -440,7 +440,7 @@
}
pub(crate) fn is_dictionary(t: &DataType) -> bool {
- matches!(t, DataType::Dictionary(_, _))
+ matches!(t, DataType::Dictionary(_, _, _))
}
/// Coercion rule for numerical types: The type that both lhs and rhs
@@ -494,7 +494,7 @@
#[cfg(test)]
mod tests {
use super::*;
- use arrow::datatypes::DataType;
+ use arrow::datatypes::{DataType, IntegerType};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_expr::Operator;
@@ -628,20 +628,20 @@
use DataType::*;
// TODO: In the future, this would ideally return Dictionary types and avoid unpacking
- let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
- let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
+ let lhs_type = Dictionary(IntegerType::Int8, Box::new(Int32), false);
+ let rhs_type = Dictionary(IntegerType::Int8, Box::new(Int16), false);
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));
- let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
- let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
+ let lhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false);
+ let rhs_type = Dictionary(IntegerType::Int8, Box::new(Int16), false);
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);
- let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
+ let lhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false);
let rhs_type = Utf8;
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
let lhs_type = Utf8;
- let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
+ let rhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false);
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
}
}
diff --git a/datafusion-physical-expr/src/crypto_expressions.rs b/datafusion-physical-expr/src/crypto_expressions.rs
index 95bedd4..f786e6e 100644
--- a/datafusion-physical-expr/src/crypto_expressions.rs
+++ b/datafusion-physical-expr/src/crypto_expressions.rs
@@ -17,11 +17,10 @@
//! Crypto expressions
+use arrow::array::Utf8Array;
+use arrow::types::Offset;
use arrow::{
- array::{
- Array, ArrayRef, BinaryArray, GenericStringArray, StringArray,
- StringOffsetSizeTrait,
- },
+ array::{Array, BinaryArray},
datatypes::DataType,
};
use blake2::{Blake2b512, Blake2s256, Digest};
@@ -81,7 +80,7 @@
macro_rules! digest_to_array {
($METHOD:ident, $INPUT:expr) => {{
- let binary_array: BinaryArray = $INPUT
+ let binary_array: BinaryArray<i32> = $INPUT
.iter()
.map(|x| {
x.map(|x| {
@@ -127,18 +126,19 @@
/// digest a string array to their hash values
fn digest_array<T>(self, value: &dyn Array) -> Result<ColumnarValue>
where
- T: StringOffsetSizeTrait,
+ T: Offset,
{
- let input_value = value
- .as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast value to {}",
- type_name::<GenericStringArray<T>>()
- ))
- })?;
- let array: ArrayRef = match self {
+ let input_value =
+ value
+ .as_any()
+ .downcast_ref::<Utf8Array<T>>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "could not cast value to {}",
+ type_name::<Utf8Array<T>>()
+ ))
+ })?;
+ let array: Arc<dyn Array> = match self {
Self::Md5 => digest_to_array!(Md5, input_value),
Self::Sha224 => digest_to_array!(Sha224, input_value),
Self::Sha256 => digest_to_array!(Sha256, input_value),
@@ -147,7 +147,7 @@
Self::Blake2b => digest_to_array!(Blake2b512, input_value),
Self::Blake2s => digest_to_array!(Blake2s256, input_value),
Self::Blake3 => {
- let binary_array: BinaryArray = input_value
+ let binary_array: BinaryArray<i32> = input_value
.iter()
.map(|opt| {
opt.map(|x| {
@@ -251,13 +251,13 @@
let binary_array = array
.as_ref()
.as_any()
- .downcast_ref::<BinaryArray>()
+ .downcast_ref::<BinaryArray<i32>>()
.ok_or_else(|| {
DataFusionError::Internal(
"Impossibly got non-binary array data from digest".into(),
)
})?;
- let string_array: StringArray = binary_array
+ let string_array: Utf8Array<i32> = binary_array
.iter()
.map(|opt| opt.map(hex_encode::<_>))
.collect();
diff --git a/datafusion-physical-expr/src/datetime_expressions.rs b/datafusion-physical-expr/src/datetime_expressions.rs
index 9a8351d..1f53ac8 100644
--- a/datafusion-physical-expr/src/datetime_expressions.rs
+++ b/datafusion-physical-expr/src/datetime_expressions.rs
@@ -17,27 +17,21 @@
//! DateTime expressions
+use crate::arrow_temporal_util::string_to_timestamp_nanos;
+use arrow::compute::temporal;
+use arrow::scalar::PrimitiveScalar;
+use arrow::temporal_conversions::timestamp_ns_to_datetime;
+use arrow::types::NativeType;
use arrow::{
- array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait},
- compute::kernels::cast_utils::string_to_timestamp_nanos,
- datatypes::{
- ArrowPrimitiveType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
- TimestampNanosecondType, TimestampSecondType,
- },
+ array::*,
+ compute::cast,
+ datatypes::{DataType, TimeUnit},
};
-use arrow::{
- array::{
- Date32Array, Date64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
- TimestampNanosecondArray, TimestampSecondArray,
- },
- compute::kernels::temporal,
- datatypes::TimeUnit,
- temporal_conversions::timestamp_ns_to_datetime,
-};
-use chrono::prelude::*;
+use chrono::prelude::{DateTime, Utc};
use chrono::Duration;
-use datafusion_common::{DataFusionError, Result};
-use datafusion_common::{ScalarType, ScalarValue};
+use chrono::Timelike;
+use chrono::{Datelike, NaiveDateTime};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use std::borrow::Borrow;
use std::sync::Arc;
@@ -48,7 +42,7 @@
/// # Errors
/// This function errors iff:
/// * the number of arguments is not 1 or
-/// * the first argument is not castable to a `GenericStringArray` or
+/// * the first argument is not castable to a `Utf8Array` or
/// * the function `op` errors
pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>(
args: &[&'a dyn Array],
@@ -56,9 +50,9 @@
name: &str,
) -> Result<PrimitiveArray<O>>
where
- O: ArrowPrimitiveType,
- T: StringOffsetSizeTrait,
- F: Fn(&'a str) -> Result<O::Native>,
+ O: NativeType,
+ T: Offset,
+ F: Fn(&'a str) -> Result<O>,
{
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
@@ -70,7 +64,7 @@
let array = args[0]
.as_any()
- .downcast_ref::<GenericStringArray<T>>()
+ .downcast_ref::<Utf8Array<T>>()
.ok_or_else(|| {
DataFusionError::Internal("failed to downcast to string".to_string())
})?;
@@ -85,23 +79,26 @@
// given an function that maps a `&str` to a arrow native type,
// returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue`
// depending on the `args`'s variant.
-fn handle<'a, O, F, S>(
+fn handle<'a, O, F>(
args: &'a [ColumnarValue],
op: F,
name: &str,
+ data_type: DataType,
) -> Result<ColumnarValue>
where
- O: ArrowPrimitiveType,
- S: ScalarType<O::Native>,
- F: Fn(&'a str) -> Result<O::Native>,
+ O: NativeType,
+ ScalarValue: From<Option<O>>,
+ F: Fn(&'a str) -> Result<O>,
{
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(
- unary_string_to_primitive_function::<i32, O, _>(&[a.as_ref()], op, name)?,
+ unary_string_to_primitive_function::<i32, O, _>(&[a.as_ref()], op, name)?
+ .to(data_type),
))),
DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new(
- unary_string_to_primitive_function::<i64, O, _>(&[a.as_ref()], op, name)?,
+ unary_string_to_primitive_function::<i64, O, _>(&[a.as_ref()], op, name)?
+ .to(data_type),
))),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function {}",
@@ -109,14 +106,15 @@
))),
},
ColumnarValue::Scalar(scalar) => match scalar {
- ScalarValue::Utf8(a) => {
- let result = a.as_ref().map(|x| (op)(x)).transpose()?;
- Ok(ColumnarValue::Scalar(S::scalar(result)))
- }
- ScalarValue::LargeUtf8(a) => {
- let result = a.as_ref().map(|x| (op)(x)).transpose()?;
- Ok(ColumnarValue::Scalar(S::scalar(result)))
- }
+ ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(match a {
+ Some(s) => {
+ let s = PrimitiveScalar::<O>::new(data_type, Some((op)(s)?));
+ ColumnarValue::Scalar(s.try_into()?)
+ }
+ None => ColumnarValue::Scalar(
+ PrimitiveScalar::<O>::new(data_type, None).try_into()?,
+ ),
+ }),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function {}",
other, name
@@ -125,44 +123,48 @@
}
}
-/// Calls string_to_timestamp_nanos and converts the error type
+/// Calls cast::string_to_timestamp_nanos and converts the error type
fn string_to_timestamp_nanos_shim(s: &str) -> Result<i64> {
string_to_timestamp_nanos(s).map_err(|e| e.into())
}
/// to_timestamp SQL function
pub fn to_timestamp(args: &[ColumnarValue]) -> Result<ColumnarValue> {
- handle::<TimestampNanosecondType, _, TimestampNanosecondType>(
+ handle::<i64, _>(
args,
string_to_timestamp_nanos_shim,
"to_timestamp",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
)
}
/// to_timestamp_millis SQL function
pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result<ColumnarValue> {
- handle::<TimestampMillisecondType, _, TimestampMillisecondType>(
+ handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000),
"to_timestamp_millis",
+ DataType::Timestamp(TimeUnit::Millisecond, None),
)
}
/// to_timestamp_micros SQL function
pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result<ColumnarValue> {
- handle::<TimestampMicrosecondType, _, TimestampMicrosecondType>(
+ handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000),
"to_timestamp_micros",
+ DataType::Timestamp(TimeUnit::Microsecond, None),
)
}
/// to_timestamp_seconds SQL function
pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result<ColumnarValue> {
- handle::<TimestampSecondType, _, TimestampSecondType>(
+ handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000_000),
"to_timestamp_seconds",
+ DataType::Timestamp(TimeUnit::Second, None),
)
}
@@ -246,24 +248,22 @@
));
};
- let f = |x: Option<i64>| x.map(|x| date_trunc_single(granularity, x)).transpose();
+ let f = |x: Option<&i64>| x.map(|x| date_trunc_single(granularity, *x)).transpose();
Ok(match array {
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
- (f)(*v)?,
+ (f)(v.as_ref())?,
tz_opt.clone(),
))
}
ColumnarValue::Array(array) => {
- let array = array
- .as_any()
- .downcast_ref::<TimestampNanosecondArray>()
- .unwrap();
+ let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
let array = array
.iter()
.map(f)
- .collect::<Result<TimestampNanosecondArray>>()?;
+ .collect::<Result<PrimitiveArray<i64>>>()?
+ .to(DataType::Timestamp(TimeUnit::Nanosecond, None));
ColumnarValue::Array(Arc::new(array))
}
@@ -275,52 +275,11 @@
})
}
-macro_rules! extract_date_part {
+macro_rules! cast_array_u32_i32 {
($ARRAY: expr, $FN:expr) => {
- match $ARRAY.data_type() {
- DataType::Date32 => {
- let array = $ARRAY.as_any().downcast_ref::<Date32Array>().unwrap();
- Ok($FN(array)?)
- }
- DataType::Date64 => {
- let array = $ARRAY.as_any().downcast_ref::<Date64Array>().unwrap();
- Ok($FN(array)?)
- }
- DataType::Timestamp(time_unit, None) => match time_unit {
- TimeUnit::Second => {
- let array = $ARRAY
- .as_any()
- .downcast_ref::<TimestampSecondArray>()
- .unwrap();
- Ok($FN(array)?)
- }
- TimeUnit::Millisecond => {
- let array = $ARRAY
- .as_any()
- .downcast_ref::<TimestampMillisecondArray>()
- .unwrap();
- Ok($FN(array)?)
- }
- TimeUnit::Microsecond => {
- let array = $ARRAY
- .as_any()
- .downcast_ref::<TimestampMicrosecondArray>()
- .unwrap();
- Ok($FN(array)?)
- }
- TimeUnit::Nanosecond => {
- let array = $ARRAY
- .as_any()
- .downcast_ref::<TimestampNanosecondArray>()
- .unwrap();
- Ok($FN(array)?)
- }
- },
- datatype => Err(DataFusionError::Internal(format!(
- "Extract does not support datatype {:?}",
- datatype
- ))),
- }
+ $FN($ARRAY.as_ref())
+ .map(|x| cast::primitive_to_primitive::<u32, i32>(&x, &DataType::Int32))
+ .map_err(|e| e.into())
};
}
@@ -349,13 +308,13 @@
};
let arr = match date_part.to_lowercase().as_str() {
- "year" => extract_date_part!(array, temporal::year),
- "month" => extract_date_part!(array, temporal::month),
- "week" => extract_date_part!(array, temporal::week),
- "day" => extract_date_part!(array, temporal::day),
- "hour" => extract_date_part!(array, temporal::hour),
- "minute" => extract_date_part!(array, temporal::minute),
- "second" => extract_date_part!(array, temporal::second),
+ "year" => temporal::year(array.as_ref()).map_err(|e| e.into()),
+ "month" => cast_array_u32_i32!(array, temporal::month),
+ "week" => cast_array_u32_i32!(array, temporal::iso_week),
+ "day" => cast_array_u32_i32!(array, temporal::day),
+ "hour" => cast_array_u32_i32!(array, temporal::hour),
+ "minute" => cast_array_u32_i32!(array, temporal::minute),
+ "second" => cast_array_u32_i32!(array, temporal::second),
_ => Err(DataFusionError::Execution(format!(
"Date part '{}' not supported",
date_part
@@ -376,7 +335,8 @@
mod tests {
use std::sync::Arc;
- use arrow::array::{ArrayRef, Int64Array, StringBuilder};
+ use arrow::array::*;
+ use arrow::datatypes::*;
use super::*;
@@ -384,18 +344,15 @@
fn to_timestamp_arrays_and_nulls() -> Result<()> {
// ensure that arrow array implementation is wired up and handles nulls correctly
- let mut string_builder = StringBuilder::new(2);
- let mut ts_builder = TimestampNanosecondArray::builder(2);
-
- string_builder.append_value("2020-09-08T13:42:29.190855Z")?;
- ts_builder.append_value(1599572549190855000)?;
-
- string_builder.append_null()?;
- ts_builder.append_null()?;
- let expected_timestamps = &ts_builder.finish() as &dyn Array;
-
let string_array =
- ColumnarValue::Array(Arc::new(string_builder.finish()) as ArrayRef);
+ Utf8Array::<i32>::from(&[Some("2020-09-08T13:42:29.190855Z"), None]);
+
+ let ts_array = Int64Array::from(&[Some(1599572549190855000), None])
+ .to(DataType::Timestamp(TimeUnit::Nanosecond, None));
+
+ let expected_timestamps = &ts_array as &dyn Array;
+
+ let string_array = ColumnarValue::Array(Arc::new(string_array) as ArrayRef);
let parsed_timestamps = to_timestamp(&[string_array])
.expect("that to_timestamp parsed values without error");
if let ColumnarValue::Array(parsed_array) = parsed_timestamps {
@@ -507,9 +464,8 @@
// pass the wrong type of input array to to_timestamp and test
// that we get an error.
- let mut builder = Int64Array::builder(1);
- builder.append_value(1)?;
- let int64array = ColumnarValue::Array(Arc::new(builder.finish()));
+ let array = Int64Array::from_slice(&[1]);
+ let int64array = ColumnarValue::Array(Arc::new(array));
let expected_err =
"Internal error: Unsupported data type Int64 for function to_timestamp";
diff --git a/datafusion-physical-expr/src/expressions/approx_distinct.rs b/datafusion-physical-expr/src/expressions/approx_distinct.rs
index 610f381..725a075 100644
--- a/datafusion-physical-expr/src/expressions/approx_distinct.rs
+++ b/datafusion-physical-expr/src/expressions/approx_distinct.rs
@@ -19,14 +19,9 @@
use super::format_state_name;
use crate::{hyperloglog::HyperLogLog, AggregateExpr, PhysicalExpr};
-use arrow::array::{
- ArrayRef, BinaryArray, BinaryOffsetSizeTrait, GenericBinaryArray, GenericStringArray,
- PrimitiveArray, StringOffsetSizeTrait,
-};
-use arrow::datatypes::{
- ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type,
- UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
+use arrow::array::{ArrayRef, BinaryArray, Offset, PrimitiveArray, Utf8Array};
+use arrow::datatypes::{DataType, Field};
+use arrow::types::NativeType;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
@@ -88,21 +83,21 @@
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
// TODO support for boolean (trivial case)
// https://github.com/apache/arrow-datafusion/issues/1109
- DataType::UInt8 => Box::new(NumericHLLAccumulator::<UInt8Type>::new()),
- DataType::UInt16 => Box::new(NumericHLLAccumulator::<UInt16Type>::new()),
- DataType::UInt32 => Box::new(NumericHLLAccumulator::<UInt32Type>::new()),
- DataType::UInt64 => Box::new(NumericHLLAccumulator::<UInt64Type>::new()),
- DataType::Int8 => Box::new(NumericHLLAccumulator::<Int8Type>::new()),
- DataType::Int16 => Box::new(NumericHLLAccumulator::<Int16Type>::new()),
- DataType::Int32 => Box::new(NumericHLLAccumulator::<Int32Type>::new()),
- DataType::Int64 => Box::new(NumericHLLAccumulator::<Int64Type>::new()),
+ DataType::UInt8 => Box::new(NumericHLLAccumulator::<u8>::new()),
+ DataType::UInt16 => Box::new(NumericHLLAccumulator::<u16>::new()),
+ DataType::UInt32 => Box::new(NumericHLLAccumulator::<u32>::new()),
+ DataType::UInt64 => Box::new(NumericHLLAccumulator::<u64>::new()),
+ DataType::Int8 => Box::new(NumericHLLAccumulator::<i8>::new()),
+ DataType::Int16 => Box::new(NumericHLLAccumulator::<i16>::new()),
+ DataType::Int32 => Box::new(NumericHLLAccumulator::<i32>::new()),
+ DataType::Int64 => Box::new(NumericHLLAccumulator::<i64>::new()),
DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()),
DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()),
DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()),
DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()),
other => {
return Err(DataFusionError::NotImplemented(format!(
- "Support for 'approx_distinct' for data type {} is not implemented",
+ "Support for 'approx_distinct' for data type {:?} is not implemented",
other
)))
}
@@ -118,7 +113,7 @@
#[derive(Debug)]
struct BinaryHLLAccumulator<T>
where
- T: BinaryOffsetSizeTrait,
+ T: Offset,
{
hll: HyperLogLog<Vec<u8>>,
phantom_data: PhantomData<T>,
@@ -126,7 +121,7 @@
impl<T> BinaryHLLAccumulator<T>
where
- T: BinaryOffsetSizeTrait,
+ T: Offset,
{
/// new approx_distinct accumulator
pub fn new() -> Self {
@@ -140,7 +135,7 @@
#[derive(Debug)]
struct StringHLLAccumulator<T>
where
- T: StringOffsetSizeTrait,
+ T: Offset,
{
hll: HyperLogLog<String>,
phantom_data: PhantomData<T>,
@@ -148,7 +143,7 @@
impl<T> StringHLLAccumulator<T>
where
- T: StringOffsetSizeTrait,
+ T: Offset,
{
/// new approx_distinct accumulator
pub fn new() -> Self {
@@ -162,16 +157,14 @@
#[derive(Debug)]
struct NumericHLLAccumulator<T>
where
- T: ArrowPrimitiveType,
- T::Native: Hash,
+ T: NativeType + Hash,
{
- hll: HyperLogLog<T::Native>,
+ hll: HyperLogLog<T>,
}
impl<T> NumericHLLAccumulator<T>
where
- T: ArrowPrimitiveType,
- T::Native: Hash,
+ T: NativeType + Hash,
{
/// new approx_distinct accumulator
pub fn new() -> Self {
@@ -218,7 +211,10 @@
() => {
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
assert_eq!(1, states.len(), "expect only 1 element in the states");
- let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
+ let binary_array = states[0]
+ .as_any()
+ .downcast_ref::<BinaryArray<i32>>()
+ .unwrap();
for v in binary_array.iter() {
let v = v.ok_or_else(|| {
DataFusionError::Internal(
@@ -258,11 +254,10 @@
impl<T> Accumulator for BinaryHLLAccumulator<T>
where
- T: BinaryOffsetSizeTrait,
+ T: Offset,
{
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let array: &GenericBinaryArray<T> =
- downcast_value!(values, GenericBinaryArray, T);
+ let array: &BinaryArray<T> = downcast_value!(values, BinaryArray, T);
// flatten because we would skip nulls
self.hll
.extend(array.into_iter().flatten().map(|v| v.to_vec()));
@@ -274,11 +269,10 @@
impl<T> Accumulator for StringHLLAccumulator<T>
where
- T: StringOffsetSizeTrait,
+ T: Offset,
{
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let array: &GenericStringArray<T> =
- downcast_value!(values, GenericStringArray, T);
+ let array: &Utf8Array<T> = downcast_value!(values, Utf8Array, T);
// flatten because we would skip nulls
self.hll
.extend(array.into_iter().flatten().map(|i| i.to_string()));
@@ -290,8 +284,7 @@
impl<T> Accumulator for NumericHLLAccumulator<T>
where
- T: ArrowPrimitiveType + std::fmt::Debug,
- T::Native: Hash,
+ T: NativeType + Hash,
{
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array: &PrimitiveArray<T> = downcast_value!(values, PrimitiveArray, T);
diff --git a/datafusion-physical-expr/src/expressions/approx_percentile_cont.rs b/datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
index 77d82cf..59ddae1 100644
--- a/datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
+++ b/datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
@@ -171,7 +171,7 @@
}
other => {
return Err(DataFusionError::NotImplemented(format!(
- "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
+ "Support for 'APPROX_PERCENTILE_CONT' for data type {:?} is not implemented",
other
)))
}
diff --git a/datafusion-physical-expr/src/expressions/array_agg.rs b/datafusion-physical-expr/src/expressions/array_agg.rs
index e187930..3d57b83 100644
--- a/datafusion-physical-expr/src/expressions/array_agg.rs
+++ b/datafusion-physical-expr/src/expressions/array_agg.rs
@@ -158,18 +158,18 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
- use arrow::record_batch::RecordBatch;
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn array_agg_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
let list = ScalarValue::List(
Some(Box::new(vec![
@@ -254,7 +254,7 @@
)))),
);
- let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+ let array: ArrayRef = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
generic_test_op!(
array,
diff --git a/datafusion-physical-expr/src/expressions/average.rs b/datafusion-physical-expr/src/expressions/average.rs
index 8888ee9..3a87f51 100644
--- a/datafusion-physical-expr/src/expressions/average.rs
+++ b/datafusion-physical-expr/src/expressions/average.rs
@@ -23,13 +23,13 @@
use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
-use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
+use arrow::datatypes::DataType;
use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
-use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{ScalarValue, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
use datafusion_expr::Accumulator;
use super::{format_state_name, sum};
@@ -173,7 +173,7 @@
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
- self.count += (values.len() - values.data().null_count()) as u64;
+ self.count += (values.len() - values.null_count()) as u64;
self.sum = sum::sum(&self.sum, &sum::sum_batch(values)?)?;
Ok(())
}
@@ -181,7 +181,7 @@
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
// counts are summed
- self.count += compute::sum(counts).unwrap_or(0);
+ self.count += compute::aggregate::sum_primitive(counts).unwrap_or(0);
// sums are summed
self.sum = sum::sum(&self.sum, &sum::sum_batch(&states[1])?)?;
@@ -214,10 +214,10 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::generic_test_op;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
@@ -235,12 +235,12 @@
#[test]
fn avg_decimal() -> Result<()> {
// test agg
- let array: ArrayRef = Arc::new(
- (1..7)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(6).to(DataType::Decimal(10, 0));
+ for i in 1..7 {
+ decimal_builder.push(Some(i as i128));
+ }
+ let array = decimal_builder.as_arc();
generic_test_op!(
array,
@@ -253,12 +253,16 @@
#[test]
fn avg_decimal_with_nulls() -> Result<()> {
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.push_null();
+ } else {
+ decimal_builder.push(Some(i));
+ }
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -271,12 +275,12 @@
#[test]
fn avg_decimal_all_nulls() -> Result<()> {
// test agg
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(6)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for _i in 1..6 {
+ decimal_builder.push_null();
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -288,7 +292,7 @@
#[test]
fn avg_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -300,8 +304,8 @@
#[test]
fn avg_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(1),
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![
+ Some(1_i32),
None,
Some(3),
Some(4),
@@ -318,7 +322,7 @@
#[test]
fn avg_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
generic_test_op!(
a,
DataType::Int32,
@@ -330,8 +334,9 @@
#[test]
fn avg_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -343,8 +348,9 @@
#[test]
fn avg_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -356,8 +362,9 @@
#[test]
fn avg_f64() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
diff --git a/datafusion-physical-expr/src/expressions/binary.rs b/datafusion-physical-expr/src/expressions/binary.rs
index 6b40c8f..ab04790 100644
--- a/datafusion-physical-expr/src/expressions/binary.rs
+++ b/datafusion-physical-expr/src/expressions/binary.rs
@@ -15,380 +15,125 @@
// specific language governing permissions and limitations
// under the License.
-use std::convert::TryInto;
-use std::{any::Any, sync::Arc};
+use std::{any::Any, convert::TryInto, sync::Arc};
-use arrow::array::TimestampMillisecondArray;
use arrow::array::*;
-use arrow::compute::kernels::arithmetic::{
- add, add_scalar, divide, divide_scalar, modulus, modulus_scalar, multiply,
- multiply_scalar, subtract, subtract_scalar,
-};
-use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
-use arrow::compute::kernels::comparison::{
- eq_bool_scalar, gt_bool_scalar, gt_eq_bool_scalar, lt_bool_scalar, lt_eq_bool_scalar,
- neq_bool_scalar,
-};
-use arrow::compute::kernels::comparison::{
- eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar,
- lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar,
-};
-use arrow::compute::kernels::comparison::{
- eq_dyn_scalar, gt_dyn_scalar, gt_eq_dyn_scalar, lt_dyn_scalar, lt_eq_dyn_scalar,
- neq_dyn_scalar,
-};
-use arrow::compute::kernels::comparison::{
- eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar,
- lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar,
-};
-use arrow::compute::kernels::comparison::{
- eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
-};
-use arrow::compute::kernels::comparison::{
- eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar,
- lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar,
- regexp_is_match_utf8_scalar,
-};
-use arrow::compute::kernels::comparison::{like_utf8, nlike_utf8, regexp_is_match_utf8};
-use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit};
-use arrow::error::ArrowError::DivideByZero;
-use arrow::record_batch::RecordBatch;
+use arrow::compute;
+use arrow::datatypes::DataType::Decimal;
+use arrow::datatypes::{DataType, Schema};
+use arrow::scalar::Scalar;
+use arrow::types::NativeType;
use crate::coercion_rule::binary_rule::coerce_types;
use crate::expressions::try_cast;
use crate::PhysicalExpr;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::Operator;
-// TODO move to arrow_rs
-// https://github.com/apache/arrow-rs/issues/1312
-fn as_decimal_array(arr: &dyn Array) -> &DecimalArray {
- arr.as_any()
- .downcast_ref::<DecimalArray>()
- .expect("Unable to downcast to typed array to DecimalArray")
-}
+// fn as_decimal_array(arr: &dyn Array) -> &Int128Array {
+// arr.as_any()
+// .downcast_ref::<Int128Array>()
+// .expect("Unable to downcast to typed array to DecimalArray")
+// }
-/// create a `dyn_op` wrapper function for the specified operation
-/// that call the underlying dyn_op arrow kernel if the type is
-/// supported, and translates ArrowError to DataFusionError
-macro_rules! make_dyn_comp_op {
- ($OP:tt) => {
- paste::paste! {
- /// wrapper over arrow compute kernel that maps Error types and
- /// patches missing support in arrow
- fn [<$OP _dyn>] (left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
- match (left.data_type(), right.data_type()) {
- // Call `op_decimal` (e.g. `eq_decimal) until
- // arrow has native support
- // https://github.com/apache/arrow-rs/issues/1200
- (DataType::Decimal(_, _), DataType::Decimal(_, _)) => {
- [<$OP _decimal>](as_decimal_array(left), as_decimal_array(right))
- },
- // By default call the arrow kernel
- _ => {
- arrow::compute::kernels::comparison::[<$OP _dyn>](left, right)
- .map_err(|e| e.into())
- }
- }
- .map(|a| Arc::new(a) as ArrayRef)
- }
- }
- };
-}
-
-// create eq_dyn, gt_dyn, wrappers etc
-make_dyn_comp_op!(eq);
-make_dyn_comp_op!(gt);
-make_dyn_comp_op!(gt_eq);
-make_dyn_comp_op!(lt);
-make_dyn_comp_op!(lt_eq);
-make_dyn_comp_op!(neq);
+// /// create a `dyn_op` wrapper function for the specified operation
+// /// that call the underlying dyn_op arrow kernel if the type is
+// /// supported, and translates ArrowError to DataFusionError
+// macro_rules! make_dyn_comp_op {
+// ($OP:tt) => {
+// paste::paste! {
+// /// wrapper over arrow compute kernel that maps Error types and
+// /// patches missing support in arrow
+// fn [<$OP _dyn>] (left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
+// match (left.data_type(), right.data_type()) {
+// // Call `op_decimal` (e.g. `eq_decimal) until
+// // arrow has native support
+// // https://github.com/apache/arrow-rs/issues/1200
+// (DataType::Decimal(_, _), DataType::Decimal(_, _)) => {
+// [<$OP _decimal>](as_decimal_array(left), as_decimal_array(right))
+// },
+// // By default call the arrow kernel
+// _ => {
+// arrow::compute::comparison::[<$OP _dyn>](left, right)
+// .map_err(|e| e.into())
+// }
+// }
+// .map(|a| Arc::new(a) as ArrayRef)
+// }
+// }
+// };
+// }
+//
+// // create eq_dyn, gt_dyn, wrappers etc
+// make_dyn_comp_op!(eq);
+// make_dyn_comp_op!(gt);
+// make_dyn_comp_op!(gt_eq);
+// make_dyn_comp_op!(lt);
+// make_dyn_comp_op!(lt_eq);
+// make_dyn_comp_op!(neq);
// Simple (low performance) kernels until optimized kernels are added to arrow
// See https://github.com/apache/arrow-rs/issues/960
-fn is_distinct_from_bool(
- left: &BooleanArray,
- right: &BooleanArray,
-) -> Result<BooleanArray> {
+fn is_distinct_from_bool(left: &dyn Array, right: &dyn Array) -> BooleanArray {
// Different from `neq_bool` because `null is distinct from null` is false and not null
- Ok(left
- .iter()
+ let left = left
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .expect("distinct_from op failed to downcast to boolean array");
+ let right = right
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .expect("distinct_from op failed to downcast to boolean array");
+ left.iter()
.zip(right.iter())
.map(|(left, right)| Some(left != right))
- .collect())
+ .collect()
}
-fn is_not_distinct_from_bool(
- left: &BooleanArray,
- right: &BooleanArray,
-) -> Result<BooleanArray> {
- Ok(left
- .iter()
+fn is_not_distinct_from_bool(left: &dyn Array, right: &dyn Array) -> BooleanArray {
+ let left = left
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .expect("not_distinct_from op failed to downcast to boolean array");
+ let right = right
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .expect("not_distinct_from op failed to downcast to boolean array");
+ left.iter()
.zip(right.iter())
.map(|(left, right)| Some(left == right))
- .collect())
+ .collect()
}
-// TODO move decimal kernels to to arrow-rs
-// https://github.com/apache/arrow-rs/issues/1200
-
-// TODO use iter added for for decimal array in
-// https://github.com/apache/arrow-rs/issues/1083
-pub(super) fn eq_decimal_scalar(
- left: &DecimalArray,
- right: i128,
-) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
+/// The binary_bitwise_array_scalar macro only evaluates for integer types
+/// like int64, int32.
+/// It is used to do bitwise operation on an array with a scalar.
+macro_rules! binary_bitwise_array_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{
+ let len = $LEFT.len();
+ let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
+ let scalar = $RIGHT;
+ if scalar.is_null() {
+ Ok(new_null_array(array.data_type().clone(), len).into())
} else {
- bool_builder.append_value(left.value(i) == right)?;
+ let right: $TYPE = scalar.try_into().unwrap();
+ let result = (0..len)
+ .into_iter()
+ .map(|i| {
+ if array.is_null(i) {
+ None
+ } else {
+ Some(array.value(i) $OP right)
+ }
+ })
+ .collect::<$ARRAY_TYPE>();
+ Ok(Arc::new(result) as ArrayRef)
}
- }
- Ok(bool_builder.finish())
-}
-
-pub(super) fn eq_decimal(
- left: &DecimalArray,
- right: &DecimalArray,
-) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) == right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn neq_decimal_scalar(left: &DecimalArray, right: i128) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) != right)?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn neq_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) != right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn lt_decimal_scalar(left: &DecimalArray, right: i128) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) < right)?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn lt_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) < right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn lt_eq_decimal_scalar(left: &DecimalArray, right: i128) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) <= right)?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn lt_eq_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) <= right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn gt_decimal_scalar(left: &DecimalArray, right: i128) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) > right)?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn gt_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) > right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn gt_eq_decimal_scalar(left: &DecimalArray, right: i128) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) >= right)?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn gt_eq_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- bool_builder.append_null()?;
- } else {
- bool_builder.append_value(left.value(i) >= right.value(i))?;
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn is_distinct_from_decimal(
- left: &DecimalArray,
- right: &DecimalArray,
-) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- match (left.is_null(i), right.is_null(i)) {
- (true, true) => bool_builder.append_value(false)?,
- (true, false) | (false, true) => bool_builder.append_value(true)?,
- (_, _) => bool_builder.append_value(left.value(i) != right.value(i))?,
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn is_not_distinct_from_decimal(
- left: &DecimalArray,
- right: &DecimalArray,
-) -> Result<BooleanArray> {
- let mut bool_builder = BooleanBuilder::new(left.len());
- for i in 0..left.len() {
- match (left.is_null(i), right.is_null(i)) {
- (true, true) => bool_builder.append_value(true)?,
- (true, false) | (false, true) => bool_builder.append_value(false)?,
- (_, _) => bool_builder.append_value(left.value(i) == right.value(i))?,
- }
- }
- Ok(bool_builder.finish())
-}
-
-fn add_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
- let mut decimal_builder =
- DecimalBuilder::new(left.len(), left.precision(), left.scale());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- decimal_builder.append_null()?;
- } else {
- decimal_builder.append_value(left.value(i) + right.value(i))?;
- }
- }
- Ok(decimal_builder.finish())
-}
-
-fn subtract_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
- let mut decimal_builder =
- DecimalBuilder::new(left.len(), left.precision(), left.scale());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- decimal_builder.append_null()?;
- } else {
- decimal_builder.append_value(left.value(i) - right.value(i))?;
- }
- }
- Ok(decimal_builder.finish())
-}
-
-fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
- let mut decimal_builder =
- DecimalBuilder::new(left.len(), left.precision(), left.scale());
- let divide = 10_i128.pow(left.scale() as u32);
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- decimal_builder.append_null()?;
- } else {
- decimal_builder.append_value(left.value(i) * right.value(i) / divide)?;
- }
- }
- Ok(decimal_builder.finish())
-}
-
-fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
- let mut decimal_builder =
- DecimalBuilder::new(left.len(), left.precision(), left.scale());
- let mul = 10_f64.powi(left.scale() as i32);
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- decimal_builder.append_null()?;
- } else if right.value(i) == 0 {
- return Err(DataFusionError::ArrowError(DivideByZero));
- } else {
- let l_value = left.value(i) as f64;
- let r_value = right.value(i) as f64;
- let result = ((l_value / r_value) * mul) as i128;
- decimal_builder.append_value(result)?;
- }
- }
- Ok(decimal_builder.finish())
-}
-
-fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result<DecimalArray> {
- let mut decimal_builder =
- DecimalBuilder::new(left.len(), left.precision(), left.scale());
- for i in 0..left.len() {
- if left.is_null(i) || right.is_null(i) {
- decimal_builder.append_null()?;
- } else if right.value(i) == 0 {
- return Err(DataFusionError::ArrowError(DivideByZero));
- } else {
- decimal_builder.append_value(left.value(i) % right.value(i))?;
- }
- }
- Ok(decimal_builder.finish())
+ }};
}
/// The binary_bitwise_array_op macro only evaluates for integer types
@@ -413,34 +158,7 @@
}};
}
-/// The binary_bitwise_array_op macro only evaluates for integer types
-/// like int64, int32.
-/// It is used to do bitwise operation on an array with a scalar.
-macro_rules! binary_bitwise_array_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{
- let len = $LEFT.len();
- let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
- let scalar = $RIGHT;
- if scalar.is_null() {
- Ok(new_null_array(array.data_type(), len))
- } else {
- let right: $TYPE = scalar.try_into().unwrap();
- let result = (0..len)
- .into_iter()
- .map(|i| {
- if array.is_null(i) {
- None
- } else {
- Some(array.value(i) $OP right)
- }
- })
- .collect::<$ARRAY_TYPE>();
- Ok(Arc::new(result) as ArrayRef)
- }
- }};
-}
-
-fn bitwise_and(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> {
+fn bitwise_and(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
match &left.data_type() {
DataType::Int8 => {
binary_bitwise_array_op!(left, right, &, Int8Array, i8)
@@ -462,7 +180,7 @@
}
}
-fn bitwise_or(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> {
+fn bitwise_or(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
match &left.data_type() {
DataType::Int8 => {
binary_bitwise_array_op!(left, right, |, Int8Array, i8)
@@ -573,465 +291,349 @@
}
}
-macro_rules! compute_decimal_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
- Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}(
- ll,
- $RIGHT.try_into()?,
- )?))
- }};
-}
-
-macro_rules! compute_decimal_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
- let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap();
- Ok(Arc::new(paste::expr! {[<$OP _decimal>]}(ll, rr)?))
- }};
-}
-
-/// Invoke a compute kernel on a pair of binary data arrays
-macro_rules! compute_utf8_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
- }};
-}
-
-/// Invoke a compute kernel on a data array and a scalar value
-macro_rules! compute_utf8_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
- Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}(
- &ll,
- &string_value,
- )?))
- } else {
- Err(DataFusionError::Internal(format!(
- "compute_utf8_op_scalar for '{}' failed to cast literal value {}",
- stringify!($OP),
- $RIGHT
- )))
- }
- }};
-}
-
-/// Invoke a compute kernel on a data array and a scalar value
-macro_rules! compute_utf8_op_dyn_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- if let Some(string_value) = $RIGHT {
- Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}(
- $LEFT,
- &string_value,
- )?))
- } else {
- Err(DataFusionError::Internal(format!(
- "compute_utf8_op_scalar for '{}' failed with literal 'none' value",
- stringify!($OP),
- )))
- }
- }};
-}
-
-/// Invoke a compute kernel on a boolean data array and a scalar value
-macro_rules! compute_bool_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- use std::convert::TryInto;
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- // generate the scalar function name, such as lt_scalar, from the $OP parameter
- // (which could have a value of lt) and the suffix _scalar
- Ok(Arc::new(paste::expr! {[<$OP _bool_scalar>]}(
- &ll,
- $RIGHT.try_into()?,
- )?))
- }};
-}
-
-/// Invoke a compute kernel on a boolean data array and a scalar value
-macro_rules! compute_bool_op_dyn_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- // generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter
- // (which could have a value of lt) and the suffix _scalar
- if let Some(b) = $RIGHT {
- Ok(Arc::new(paste::expr! {[<$OP _dyn_bool_scalar>]}(
- $LEFT,
- b,
- )?))
- } else {
- Err(DataFusionError::Internal(format!(
- "compute_utf8_op_scalar for '{}' failed with literal 'none' value",
- stringify!($OP),
- )))
- }
- }};
-}
-
-/// Invoke a bool compute kernel on array(s)
-macro_rules! compute_bool_op {
- // invoke binary operator
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast left side array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast right side array");
- Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&ll, &rr)?))
- }};
- // invoke unary operator
- ($OPERAND:expr, $OP:ident, $DT:ident) => {{
- let operand = $OPERAND
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast operant array");
- Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&operand)?))
- }};
-}
-
-/// Invoke a compute kernel on a data array and a scalar value
-/// LEFT is array, RIGHT is scalar value
-macro_rules! compute_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- // generate the scalar function name, such as lt_scalar, from the $OP parameter
- // (which could have a value of lt) and the suffix _scalar
- Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
- &ll,
- $RIGHT.try_into()?,
- )?))
- }};
-}
-
-/// Invoke a dyn compute kernel on a data array and a scalar value
-/// LEFT is Primitive or Dictionart array of numeric values, RIGHT is scalar value
-macro_rules! compute_op_dyn_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter
- // (which could have a value of lt_dyn) and the suffix _scalar
- if let Some(value) = $RIGHT {
- Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}(
- $LEFT,
- value,
- )?))
- } else {
- Err(DataFusionError::Internal(format!(
- "compute_utf8_op_scalar for '{}' failed with literal 'none' value",
- stringify!($OP),
- )))
- }
- }};
-}
-
-/// Invoke a compute kernel on array(s)
-macro_rules! compute_op {
- // invoke binary operator
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- Ok(Arc::new($OP(&ll, &rr)?))
- }};
- // invoke unary operator
- ($OPERAND:expr, $OP:ident, $DT:ident) => {{
- let operand = $OPERAND
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- Ok(Arc::new($OP(&operand)?))
- }};
-}
-
-macro_rules! binary_string_array_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
- DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for scalar operation '{}' on string array",
- other, stringify!($OP)
- ))),
- };
- Some(result)
- }};
-}
-
-macro_rules! binary_string_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- match $LEFT.data_type() {
- DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for binary operation '{}' on string arrays",
- other, stringify!($OP)
- ))),
- }
- }};
-}
-
-/// Invoke a compute kernel on a pair of arrays
-/// The binary_primitive_array_op macro only evaluates for primitive types
-/// like integers and floats.
-macro_rules! binary_primitive_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- match $LEFT.data_type() {
- // TODO support decimal type
- // which is not the primitive type
- DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray),
- DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
- DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
- DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
- DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
- DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
- DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
- DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
- DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for binary operation '{}' on primitive arrays",
- other, stringify!($OP)
- ))),
- }
- }};
-}
-
-/// Invoke a compute kernel on an array and a scalar
-/// The binary_primitive_array_op_scalar macro only evaluates for primitive
-/// types like integers and floats.
-macro_rules! binary_primitive_array_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
- DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
- DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
- DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
- DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
- DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
- DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
- DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
- DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for scalar operation '{}' on primitive array",
- other, stringify!($OP)
- ))),
- };
- Some(result)
- }};
-}
-
-/// The binary_array_op_scalar macro includes types that extend beyond the primitive,
-/// such as Utf8 strings.
-#[macro_export]
-macro_rules! binary_array_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
- DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
- DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
- DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
- DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
- DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
- DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
- DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
- DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
- DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
- DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray)
- }
- DataType::Date32 => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
- }
- DataType::Date64 => {
- compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
- }
- DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, BooleanArray),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for scalar operation '{}' on dyn array",
- other, stringify!($OP)
- ))),
- };
- Some(result)
- }};
-}
-
-/// The binary_array_op macro includes types that extend beyond the primitive,
-/// such as Utf8 strings.
-#[macro_export]
-macro_rules! binary_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- match $LEFT.data_type() {
- DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray),
- DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
- DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
- DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
- DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
- DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
- DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
- DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
- DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
- DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
- DataType::Timestamp(TimeUnit::Nanosecond, _) => {
- compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
- }
- DataType::Timestamp(TimeUnit::Microsecond, _) => {
- compute_op!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray)
- }
- DataType::Timestamp(TimeUnit::Millisecond, _) => {
- compute_op!($LEFT, $RIGHT, $OP, TimestampMillisecondArray)
- }
- DataType::Timestamp(TimeUnit::Second, _) => {
- compute_op!($LEFT, $RIGHT, $OP, TimestampSecondArray)
- }
- DataType::Date32 => {
- compute_op!($LEFT, $RIGHT, $OP, Date32Array)
- }
- DataType::Date64 => {
- compute_op!($LEFT, $RIGHT, $OP, Date64Array)
- }
- DataType::Boolean => compute_bool_op!($LEFT, $RIGHT, $OP, BooleanArray),
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for binary operation '{}' on dyn arrays",
- other, stringify!($OP)
- ))),
- }
- }};
-}
-
/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+ ($LEFT:expr, $RIGHT:expr, $OP:expr) => {{
let ll = $LEFT
.as_any()
- .downcast_ref::<BooleanArray>()
+ .downcast_ref()
.expect("boolean_op failed to downcast array");
let rr = $RIGHT
.as_any()
- .downcast_ref::<BooleanArray>()
+ .downcast_ref()
.expect("boolean_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}
-macro_rules! binary_string_array_flag_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
- match $LEFT.data_type() {
- DataType::Utf8 => {
- compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
+#[inline]
+fn evaluate_regex<O: Offset>(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
+ Ok(compute::regex_match::regex_match::<O>(
+ lhs.as_any().downcast_ref().unwrap(),
+ rhs.as_any().downcast_ref().unwrap(),
+ )?)
+}
+
+#[inline]
+fn evaluate_regex_case_insensitive<O: Offset>(
+ lhs: &dyn Array,
+ rhs: &dyn Array,
+) -> Result<BooleanArray> {
+ let patterns_arr = rhs.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
+ // TODO: avoid this pattern array iteration by building the new regex pattern in the match
+ // loop. We need to roll our own regex compute kernel instead of using the ones from arrow for
+ // postgresql compatibility.
+ let patterns = patterns_arr
+ .iter()
+ .map(|pattern| pattern.map(|s| format!("(?i){}", s)))
+ .collect::<Vec<_>>();
+ Ok(compute::regex_match::regex_match::<O>(
+ lhs.as_any().downcast_ref().unwrap(),
+ &Utf8Array::<O>::from(patterns),
+ )?)
+}
+
+fn evaluate(lhs: &dyn Array, op: &Operator, rhs: &dyn Array) -> Result<Arc<dyn Array>> {
+ use Operator::*;
+ if matches!(op, Plus) {
+ let arr: ArrayRef = match (lhs.data_type(), rhs.data_type()) {
+ (Decimal(p1, s1), Decimal(p2, s2)) => {
+ let left_array =
+ lhs.as_any().downcast_ref::<PrimitiveArray<i128>>().unwrap();
+ let right_array =
+ rhs.as_any().downcast_ref::<PrimitiveArray<i128>>().unwrap();
+ Arc::new(if *p1 == *p2 && *s1 == *s2 {
+ compute::arithmetics::decimal::add(left_array, right_array)
+ } else {
+ compute::arithmetics::decimal::adaptive_add(left_array, right_array)?
+ })
}
- DataType::LargeUtf8 => {
- compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
+ _ => compute::arithmetics::add(lhs, rhs).into(),
+ };
+ Ok(arr)
+ } else if matches!(op, Minus | Divide | Multiply | Modulo) {
+ let arr = match op {
+ Operator::Minus => compute::arithmetics::sub(lhs, rhs),
+ Operator::Divide => compute::arithmetics::div(lhs, rhs),
+ Operator::Multiply => compute::arithmetics::mul(lhs, rhs),
+ Operator::Modulo => compute::arithmetics::rem(lhs, rhs),
+ // TODO: show proper error message
+ _ => unreachable!(),
+ };
+ Ok(Arc::<dyn Array>::from(arr))
+ } else if matches!(op, Eq | NotEq | Lt | LtEq | Gt | GtEq) {
+ let arr = match op {
+ Operator::Eq => compute::comparison::eq(lhs, rhs),
+ Operator::NotEq => compute::comparison::neq(lhs, rhs),
+ Operator::Lt => compute::comparison::lt(lhs, rhs),
+ Operator::LtEq => compute::comparison::lt_eq(lhs, rhs),
+ Operator::Gt => compute::comparison::gt(lhs, rhs),
+ Operator::GtEq => compute::comparison::gt_eq(lhs, rhs),
+ // TODO: show proper error message
+ _ => unreachable!(),
+ };
+ Ok(Arc::new(arr) as Arc<dyn Array>)
+ } else if matches!(op, IsDistinctFrom) {
+ is_distinct_from(lhs, rhs)
+ } else if matches!(op, IsNotDistinctFrom) {
+ is_not_distinct_from(lhs, rhs)
+ } else if matches!(op, Or) {
+ boolean_op!(lhs, rhs, compute::boolean_kleene::or)
+ } else if matches!(op, And) {
+ boolean_op!(lhs, rhs, compute::boolean_kleene::and)
+ } else if matches!(op, BitwiseOr) {
+ bitwise_or(lhs, rhs)
+ } else if matches!(op, BitwiseAnd) {
+ bitwise_and(lhs, rhs)
+ } else {
+ match (lhs.data_type(), op, rhs.data_type()) {
+ (DataType::Utf8, Like, DataType::Utf8) => {
+ Ok(compute::like::like_utf8::<i32>(
+ lhs.as_any().downcast_ref().unwrap(),
+ rhs.as_any().downcast_ref().unwrap(),
+ )
+ .map(Arc::new)?)
}
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array",
- other, stringify!($OP)
+ (DataType::LargeUtf8, Like, DataType::LargeUtf8) => {
+ Ok(compute::like::like_utf8::<i64>(
+ lhs.as_any().downcast_ref().unwrap(),
+ rhs.as_any().downcast_ref().unwrap(),
+ )
+ .map(Arc::new)?)
+ }
+ (DataType::Utf8, NotLike, DataType::Utf8) => {
+ Ok(compute::like::nlike_utf8::<i32>(
+ lhs.as_any().downcast_ref().unwrap(),
+ rhs.as_any().downcast_ref().unwrap(),
+ )
+ .map(Arc::new)?)
+ }
+ (DataType::LargeUtf8, NotLike, DataType::LargeUtf8) => {
+ Ok(compute::like::nlike_utf8::<i64>(
+ lhs.as_any().downcast_ref().unwrap(),
+ rhs.as_any().downcast_ref().unwrap(),
+ )
+ .map(Arc::new)?)
+ }
+ (DataType::Utf8, RegexMatch, DataType::Utf8) => {
+ Ok(Arc::new(evaluate_regex::<i32>(lhs, rhs)?))
+ }
+ (DataType::Utf8, RegexIMatch, DataType::Utf8) => {
+ Ok(Arc::new(evaluate_regex_case_insensitive::<i32>(lhs, rhs)?))
+ }
+ (DataType::Utf8, RegexNotMatch, DataType::Utf8) => {
+ let re = evaluate_regex::<i32>(lhs, rhs)?;
+ Ok(Arc::new(compute::boolean::not(&re)))
+ }
+ (DataType::Utf8, RegexNotIMatch, DataType::Utf8) => {
+ let re = evaluate_regex_case_insensitive::<i32>(lhs, rhs)?;
+ Ok(Arc::new(compute::boolean::not(&re)))
+ }
+ (DataType::LargeUtf8, RegexMatch, DataType::LargeUtf8) => {
+ Ok(Arc::new(evaluate_regex::<i64>(lhs, rhs)?))
+ }
+ (DataType::LargeUtf8, RegexIMatch, DataType::LargeUtf8) => {
+ Ok(Arc::new(evaluate_regex_case_insensitive::<i64>(lhs, rhs)?))
+ }
+ (DataType::LargeUtf8, RegexNotMatch, DataType::LargeUtf8) => {
+ let re = evaluate_regex::<i64>(lhs, rhs)?;
+ Ok(Arc::new(compute::boolean::not(&re)))
+ }
+ (DataType::LargeUtf8, RegexNotIMatch, DataType::LargeUtf8) => {
+ let re = evaluate_regex_case_insensitive::<i64>(lhs, rhs)?;
+ Ok(Arc::new(compute::boolean::not(&re)))
+ }
+ (lhs, op, rhs) => Err(DataFusionError::Internal(format!(
+ "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
+ op, lhs, rhs
))),
}
+ }
+}
+
+macro_rules! dyn_compute_scalar {
+ ($lhs:expr, $op:ident, $rhs:expr, $ty:ty) => {{
+ Arc::new(compute::arithmetics::basic::$op::<$ty>(
+ $lhs.as_any().downcast_ref().unwrap(),
+ &$rhs.clone().try_into().unwrap(),
+ ))
}};
}
-/// Invoke a compute kernel on a pair of binary data arrays with flags
-macro_rules! compute_utf8_flag_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$ARRAYTYPE>()
- .expect("compute_utf8_flag_op failed to downcast array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<$ARRAYTYPE>()
- .expect("compute_utf8_flag_op failed to downcast array");
-
- let flag = if $FLAG {
- Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
- } else {
- None
- };
- let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?;
- if $NOT {
- array = not(&array).unwrap();
+#[inline]
+fn evaluate_regex_scalar<O: Offset>(
+ values: &dyn Array,
+ regex: &ScalarValue,
+) -> Result<BooleanArray> {
+ let values = values.as_any().downcast_ref().unwrap();
+ let regex = match regex {
+ ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => s.as_str(),
+ _ => {
+ return Err(DataFusionError::Plan(format!(
+ "Regex pattern is not a valid string, got: {:?}",
+ regex,
+ )));
}
- Ok(Arc::new(array))
- }};
+ };
+ Ok(compute::regex_match::regex_match_scalar::<O>(
+ values, regex,
+ )?)
}
-macro_rules! binary_string_array_flag_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
- let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
- DataType::Utf8 => {
- compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
+#[inline]
+fn evaluate_regex_scalar_case_insensitive<O: Offset>(
+ values: &dyn Array,
+ regex: &ScalarValue,
+) -> Result<BooleanArray> {
+ let values = values.as_any().downcast_ref().unwrap();
+ let regex = match regex {
+ ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => s.as_str(),
+ _ => {
+ return Err(DataFusionError::Plan(format!(
+ "Regex pattern is not a valid string, got: {:?}",
+ regex,
+ )));
+ }
+ };
+ Ok(compute::regex_match::regex_match_scalar::<O>(
+ values,
+ &format!("(?i){}", regex),
+ )?)
+}
+
+macro_rules! with_match_primitive_type {(
+ $key_type:expr, | $_:tt $T:ident | $($body:tt)*
+) => ({
+ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
+ match $key_type {
+ DataType::Int8 => Some(__with_ty__! { i8 }),
+ DataType::Int16 => Some(__with_ty__! { i16 }),
+ DataType::Int32 => Some(__with_ty__! { i32 }),
+ DataType::Int64 => Some(__with_ty__! { i64 }),
+ DataType::UInt8 => Some(__with_ty__! { u8 }),
+ DataType::UInt16 => Some(__with_ty__! { u16 }),
+ DataType::UInt32 => Some(__with_ty__! { u32 }),
+ DataType::UInt64 => Some(__with_ty__! { u64 }),
+ DataType::Float32 => Some(__with_ty__! { f32 }),
+ DataType::Float64 => Some(__with_ty__! { f64 }),
+ _ => None,
+ }
+})}
+
+fn evaluate_scalar(
+ lhs: &dyn Array,
+ op: &Operator,
+ rhs: &ScalarValue,
+) -> Result<Option<Arc<dyn Array>>> {
+ use Operator::*;
+ if matches!(op, Plus | Minus | Divide | Multiply | Modulo) {
+ Ok(match op {
+ Plus => {
+ with_match_primitive_type!(lhs.data_type(), |$T| {
+ dyn_compute_scalar!(lhs, add_scalar, rhs, $T)
+ })
}
- DataType::LargeUtf8 => {
- compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
+ Minus => {
+ with_match_primitive_type!(lhs.data_type(), |$T| {
+ dyn_compute_scalar!(lhs, sub_scalar, rhs, $T)
+ })
}
- other => Err(DataFusionError::Internal(format!(
- "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array",
- other, stringify!($OP)
+ Divide => {
+ with_match_primitive_type!(lhs.data_type(), |$T| {
+ dyn_compute_scalar!(lhs, div_scalar, rhs, $T)
+ })
+ }
+ Multiply => {
+ with_match_primitive_type!(lhs.data_type(), |$T| {
+ dyn_compute_scalar!(lhs, mul_scalar, rhs, $T)
+ })
+ }
+ Modulo => {
+ with_match_primitive_type!(lhs.data_type(), |$T| {
+ dyn_compute_scalar!(lhs, rem_scalar, rhs, $T)
+ })
+ }
+ _ => None, // fall back to default comparison below
+ })
+ } else if matches!(op, Eq | NotEq | Lt | LtEq | Gt | GtEq) {
+ let rhs: Result<Box<dyn Scalar>> = rhs.try_into();
+ match rhs {
+ Ok(rhs) => {
+ let arr = match op {
+ Operator::Eq => compute::comparison::eq_scalar(lhs, &*rhs),
+ Operator::NotEq => compute::comparison::neq_scalar(lhs, &*rhs),
+ Operator::Lt => compute::comparison::lt_scalar(lhs, &*rhs),
+ Operator::LtEq => compute::comparison::lt_eq_scalar(lhs, &*rhs),
+ Operator::Gt => compute::comparison::gt_scalar(lhs, &*rhs),
+ Operator::GtEq => compute::comparison::gt_eq_scalar(lhs, &*rhs),
+ _ => unreachable!(),
+ };
+ Ok(Some(Arc::new(arr) as Arc<dyn Array>))
+ }
+ Err(_) => {
+ // fall back to default comparison below
+ Ok(None)
+ }
+ }
+ } else if matches!(op, Or | And) {
+ // TODO: optimize scalar Or | And
+ Ok(None)
+ } else if matches!(op, BitwiseOr) {
+ bitwise_or_scalar(lhs, rhs.clone()).transpose()
+ } else if matches!(op, BitwiseAnd) {
+ bitwise_and_scalar(lhs, rhs.clone()).transpose()
+ } else {
+ match (lhs.data_type(), op) {
+ (DataType::Utf8, RegexMatch) => {
+ Ok(Some(Arc::new(evaluate_regex_scalar::<i32>(lhs, rhs)?)))
+ }
+ (DataType::Utf8, RegexIMatch) => Ok(Some(Arc::new(
+ evaluate_regex_scalar_case_insensitive::<i32>(lhs, rhs)?,
))),
- };
- Some(result)
- }};
+ (DataType::Utf8, RegexNotMatch) => Ok(Some(Arc::new(compute::boolean::not(
+ &evaluate_regex_scalar::<i32>(lhs, rhs)?,
+ )))),
+ (DataType::Utf8, RegexNotIMatch) => {
+ Ok(Some(Arc::new(compute::boolean::not(
+ &evaluate_regex_scalar_case_insensitive::<i32>(lhs, rhs)?,
+ ))))
+ }
+ (DataType::LargeUtf8, RegexMatch) => {
+ Ok(Some(Arc::new(evaluate_regex_scalar::<i64>(lhs, rhs)?)))
+ }
+ (DataType::LargeUtf8, RegexIMatch) => Ok(Some(Arc::new(
+ evaluate_regex_scalar_case_insensitive::<i64>(lhs, rhs)?,
+ ))),
+ (DataType::LargeUtf8, RegexNotMatch) => Ok(Some(Arc::new(
+ compute::boolean::not(&evaluate_regex_scalar::<i64>(lhs, rhs)?),
+ ))),
+ (DataType::LargeUtf8, RegexNotIMatch) => {
+ Ok(Some(Arc::new(compute::boolean::not(
+ &evaluate_regex_scalar_case_insensitive::<i64>(lhs, rhs)?,
+ ))))
+ }
+ _ => Ok(None),
+ }
+ }
}
-/// Invoke a compute kernel on a data array and a scalar value with flag
-macro_rules! compute_utf8_flag_op_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$ARRAYTYPE>()
- .expect("compute_utf8_flag_op_scalar failed to downcast array");
-
- if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
- let flag = if $FLAG { Some("i") } else { None };
- let mut array =
- paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?;
- if $NOT {
- array = not(&array).unwrap();
- }
- Ok(Arc::new(array))
- } else {
- Err(DataFusionError::Internal(format!(
- "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
- $RIGHT, stringify!($OP)
- )))
- }
- }};
+fn evaluate_inverse_scalar(
+ lhs: &ScalarValue,
+ op: &Operator,
+ rhs: &dyn Array,
+) -> Result<Option<Arc<dyn Array>>> {
+ use Operator::*;
+ match op {
+ Lt => evaluate_scalar(rhs, &Gt, lhs),
+ Gt => evaluate_scalar(rhs, &Lt, lhs),
+ GtEq => evaluate_scalar(rhs, &LtEq, lhs),
+ LtEq => evaluate_scalar(rhs, &GtEq, lhs),
+ Eq => evaluate_scalar(rhs, &Eq, lhs),
+ NotEq => evaluate_scalar(rhs, &NotEq, lhs),
+ Plus => evaluate_scalar(rhs, &Plus, lhs),
+ Multiply => evaluate_scalar(rhs, &Multiply, lhs),
+ _ => Ok(None),
+ }
}
/// Returns the return type of a binary operator or an error when the binary operator cannot
@@ -1110,18 +712,16 @@
// Attempt to use special kernels if one input is scalar and the other is an array
let scalar_result = match (&left_value, &right_value) {
(ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
- // if left is array and right is literal - use scalar operations
- self.evaluate_array_scalar(array, scalar)?
+ evaluate_scalar(array.as_ref(), &self.op, scalar)
}
(ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
- // if right is literal and left is array - reverse operator and parameters
- self.evaluate_scalar_array(scalar, array)?
+ evaluate_inverse_scalar(scalar, &self.op, array.as_ref())
}
- (_, _) => None, // default to array implementation
- };
+ (_, _) => Ok(None),
+ }?;
if let Some(result) = scalar_result {
- return result.map(|a| ColumnarValue::Array(a));
+ return Ok(ColumnarValue::Array(result));
}
// if both arrays or both literals - extract arrays and continue execution
@@ -1129,263 +729,169 @@
left_value.into_array(batch.num_rows()),
right_value.into_array(batch.num_rows()),
);
- self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type)
- .map(|a| ColumnarValue::Array(a))
+
+ let result = evaluate(left.as_ref(), &self.op, right.as_ref());
+ result.map(|a| ColumnarValue::Array(a))
}
}
-/// The binary_array_op_dyn_scalar macro includes types that extend beyond the primitive,
-/// such as Utf8 strings.
-#[macro_export]
-macro_rules! binary_array_op_dyn_scalar {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- let result: Result<Arc<dyn Array>> = match $RIGHT {
- ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP),
- ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
- ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
- ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
- ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
- ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array),
- ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array),
- ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray),
- ScalarValue::TimestampMillisecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray),
- ScalarValue::TimestampMicrosecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray),
- ScalarValue::TimestampNanosecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray),
- other => Err(DataFusionError::Internal(format!("Data type {:?} not supported for scalar operation '{}' on dyn array", other, stringify!($OP))))
- };
- Some(result)
- }}
+fn is_distinct_from_primitive<T: NativeType>(
+ left: &dyn Array,
+ right: &dyn Array,
+) -> BooleanArray {
+ let left = left
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("distinct_from op failed to downcast to primitive array");
+ let right = right
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("distinct_from op failed to downcast to primitive array");
+ left.iter()
+ .zip(right.iter())
+ .map(|(x, y)| Some(x != y))
+ .collect()
}
-impl BinaryExpr {
- /// Evaluate the expression of the left input is an array and
- /// right is literal - use scalar operations
- fn evaluate_array_scalar(
- &self,
- array: &dyn Array,
- scalar: &ScalarValue,
- ) -> Result<Option<Result<ArrayRef>>> {
- let scalar_result = match &self.op {
- Operator::Lt => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), lt)
- }
- Operator::LtEq => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq)
- }
- Operator::Gt => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), gt)
- }
- Operator::GtEq => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq)
- }
- Operator::Eq => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), eq)
- }
- Operator::NotEq => {
- binary_array_op_dyn_scalar!(array, scalar.clone(), neq)
- }
- Operator::Like => {
- binary_string_array_op_scalar!(array, scalar.clone(), like)
- }
- Operator::NotLike => {
- binary_string_array_op_scalar!(array, scalar.clone(), nlike)
- }
- Operator::Plus => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), add)
- }
- Operator::Minus => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), subtract)
- }
- Operator::Multiply => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), multiply)
- }
- Operator::Divide => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), divide)
- }
- Operator::Modulo => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), modulus)
- }
- Operator::RegexMatch => binary_string_array_flag_op_scalar!(
- array,
- scalar.clone(),
- regexp_is_match,
- false,
- false
- ),
- Operator::RegexIMatch => binary_string_array_flag_op_scalar!(
- array,
- scalar.clone(),
- regexp_is_match,
- false,
- true
- ),
- Operator::RegexNotMatch => binary_string_array_flag_op_scalar!(
- array,
- scalar.clone(),
- regexp_is_match,
- true,
- false
- ),
- Operator::RegexNotIMatch => binary_string_array_flag_op_scalar!(
- array,
- scalar.clone(),
- regexp_is_match,
- true,
- true
- ),
- Operator::BitwiseAnd => bitwise_and_scalar(array, scalar.clone()),
- Operator::BitwiseOr => bitwise_or_scalar(array, scalar.clone()),
- // if scalar operation is not supported - fallback to array implementation
- _ => None,
- };
+fn is_not_distinct_from_primitive<T: NativeType>(
+ left: &dyn Array,
+ right: &dyn Array,
+) -> BooleanArray {
+ let left = left
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("not_distinct_from op failed to downcast to primitive array");
+ let right = right
+ .as_any()
+ .downcast_ref::<PrimitiveArray<T>>()
+ .expect("not_distinct_from op failed to downcast to primitive array");
+ left.iter()
+ .zip(right.iter())
+ .map(|(x, y)| Some(x == y))
+ .collect()
+}
- Ok(scalar_result)
- }
+fn is_distinct_from_utf8<O: Offset>(left: &dyn Array, right: &dyn Array) -> BooleanArray {
+ let left = left
+ .as_any()
+ .downcast_ref::<Utf8Array<O>>()
+ .expect("distinct_from op failed to downcast to utf8 array");
+ let right = right
+ .as_any()
+ .downcast_ref::<Utf8Array<O>>()
+ .expect("distinct_from op failed to downcast to utf8 array");
+ left.iter()
+ .zip(right.iter())
+ .map(|(x, y)| Some(x != y))
+ .collect()
+}
- /// Evaluate the expression if the left input is a literal and the
- /// right is an array - reverse operator and parameters
- fn evaluate_scalar_array(
- &self,
- scalar: &ScalarValue,
- array: &ArrayRef,
- ) -> Result<Option<Result<ArrayRef>>> {
- let scalar_result = match &self.op {
- Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt),
- Operator::LtEq => binary_array_op_scalar!(array, scalar.clone(), gt_eq),
- Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt),
- Operator::GtEq => binary_array_op_scalar!(array, scalar.clone(), lt_eq),
- Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
- Operator::NotEq => {
- binary_array_op_scalar!(array, scalar.clone(), neq)
- }
- // if scalar operation is not supported - fallback to array implementation
- _ => None,
- };
- Ok(scalar_result)
- }
+fn is_not_distinct_from_utf8<O: Offset>(
+ left: &dyn Array,
+ right: &dyn Array,
+) -> BooleanArray {
+ let left = left
+ .as_any()
+ .downcast_ref::<Utf8Array<O>>()
+ .expect("not_distinct_from op failed to downcast to utf8 array");
+ let right = right
+ .as_any()
+ .downcast_ref::<Utf8Array<O>>()
+ .expect("not_distinct_from op failed to downcast to utf8 array");
+ left.iter()
+ .zip(right.iter())
+ .map(|(x, y)| Some(x == y))
+ .collect()
+}
- fn evaluate_with_resolved_args(
- &self,
- left: Arc<dyn Array>,
- left_data_type: &DataType,
- right: Arc<dyn Array>,
- right_data_type: &DataType,
- ) -> Result<ArrayRef> {
- match &self.op {
- Operator::Like => binary_string_array_op!(left, right, like),
- Operator::NotLike => binary_string_array_op!(left, right, nlike),
- Operator::Lt => lt_dyn(&left, &right),
- Operator::LtEq => lt_eq_dyn(&left, &right),
- Operator::Gt => gt_dyn(&left, &right),
- Operator::GtEq => gt_eq_dyn(&left, &right),
- Operator::Eq => eq_dyn(&left, &right),
- Operator::NotEq => neq_dyn(&left, &right),
- Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from),
- Operator::IsNotDistinctFrom => {
- binary_array_op!(left, right, is_not_distinct_from)
- }
- Operator::Plus => binary_primitive_array_op!(left, right, add),
- Operator::Minus => binary_primitive_array_op!(left, right, subtract),
- Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
- Operator::Divide => binary_primitive_array_op!(left, right, divide),
- Operator::Modulo => binary_primitive_array_op!(left, right, modulus),
- Operator::And => {
- if left_data_type == &DataType::Boolean {
- boolean_op!(left, right, and_kleene)
- } else {
- return Err(DataFusionError::Internal(format!(
- "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
- self.op,
- left.data_type(),
- right.data_type()
- )));
- }
- }
- Operator::Or => {
- if left_data_type == &DataType::Boolean {
- boolean_op!(left, right, or_kleene)
- } else {
- return Err(DataFusionError::Internal(format!(
- "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
- self.op, left_data_type, right_data_type
- )));
- }
- }
- Operator::RegexMatch => {
- binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
- }
- Operator::RegexIMatch => {
- binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
- }
- Operator::RegexNotMatch => {
- binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
- }
- Operator::RegexNotIMatch => {
- binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
- }
- Operator::BitwiseAnd => bitwise_and(left, right),
- Operator::BitwiseOr => bitwise_or(left, right),
+fn is_distinct_from(left: &dyn Array, right: &dyn Array) -> Result<Arc<dyn Array>> {
+ match (left.data_type(), right.data_type()) {
+ (DataType::Int8, DataType::Int8) => {
+ Ok(Arc::new(is_distinct_from_primitive::<i8>(left, right)))
}
+ (DataType::Int32, DataType::Int32) => {
+ Ok(Arc::new(is_distinct_from_primitive::<i32>(left, right)))
+ }
+ (DataType::Int64, DataType::Int64) => {
+ Ok(Arc::new(is_distinct_from_primitive::<i64>(left, right)))
+ }
+ (DataType::UInt8, DataType::UInt8) => {
+ Ok(Arc::new(is_distinct_from_primitive::<u8>(left, right)))
+ }
+ (DataType::UInt16, DataType::UInt16) => {
+ Ok(Arc::new(is_distinct_from_primitive::<u16>(left, right)))
+ }
+ (DataType::UInt32, DataType::UInt32) => {
+ Ok(Arc::new(is_distinct_from_primitive::<u32>(left, right)))
+ }
+ (DataType::UInt64, DataType::UInt64) => {
+ Ok(Arc::new(is_distinct_from_primitive::<u64>(left, right)))
+ }
+ (DataType::Float32, DataType::Float32) => {
+ Ok(Arc::new(is_distinct_from_primitive::<f32>(left, right)))
+ }
+ (DataType::Float64, DataType::Float64) => {
+ Ok(Arc::new(is_distinct_from_primitive::<f64>(left, right)))
+ }
+ (DataType::Boolean, DataType::Boolean) => {
+ Ok(Arc::new(is_distinct_from_bool(left, right)))
+ }
+ (DataType::Utf8, DataType::Utf8) => {
+ Ok(Arc::new(is_distinct_from_utf8::<i32>(left, right)))
+ }
+ (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ Ok(Arc::new(is_distinct_from_utf8::<i64>(left, right)))
+ }
+ (lhs, rhs) => Err(DataFusionError::Internal(format!(
+ "Cannot evaluate is_distinct_from expression with types {:?} and {:?}",
+ lhs, rhs
+ ))),
}
}
-fn is_distinct_from<T>(
- left: &PrimitiveArray<T>,
- right: &PrimitiveArray<T>,
-) -> Result<BooleanArray>
-where
- T: ArrowNumericType,
-{
- Ok(left
- .iter()
- .zip(right.iter())
- .map(|(x, y)| Some(x != y))
- .collect())
-}
-
-fn is_distinct_from_utf8<OffsetSize: StringOffsetSizeTrait>(
- left: &GenericStringArray<OffsetSize>,
- right: &GenericStringArray<OffsetSize>,
-) -> Result<BooleanArray> {
- Ok(left
- .iter()
- .zip(right.iter())
- .map(|(x, y)| Some(x != y))
- .collect())
-}
-
-fn is_not_distinct_from<T>(
- left: &PrimitiveArray<T>,
- right: &PrimitiveArray<T>,
-) -> Result<BooleanArray>
-where
- T: ArrowNumericType,
-{
- Ok(left
- .iter()
- .zip(right.iter())
- .map(|(x, y)| Some(x == y))
- .collect())
-}
-
-fn is_not_distinct_from_utf8<OffsetSize: StringOffsetSizeTrait>(
- left: &GenericStringArray<OffsetSize>,
- right: &GenericStringArray<OffsetSize>,
-) -> Result<BooleanArray> {
- Ok(left
- .iter()
- .zip(right.iter())
- .map(|(x, y)| Some(x == y))
- .collect())
+fn is_not_distinct_from(left: &dyn Array, right: &dyn Array) -> Result<Arc<dyn Array>> {
+ match (left.data_type(), right.data_type()) {
+ (DataType::Int8, DataType::Int8) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<i8>(left, right)))
+ }
+ (DataType::Int32, DataType::Int32) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<i32>(left, right)))
+ }
+ (DataType::Int64, DataType::Int64) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<i64>(left, right)))
+ }
+ (DataType::UInt8, DataType::UInt8) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<u8>(left, right)))
+ }
+ (DataType::UInt16, DataType::UInt16) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<u16>(left, right)))
+ }
+ (DataType::UInt32, DataType::UInt32) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<u32>(left, right)))
+ }
+ (DataType::UInt64, DataType::UInt64) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<u64>(left, right)))
+ }
+ (DataType::Float32, DataType::Float32) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<f32>(left, right)))
+ }
+ (DataType::Float64, DataType::Float64) => {
+ Ok(Arc::new(is_not_distinct_from_primitive::<f64>(left, right)))
+ }
+ (DataType::Boolean, DataType::Boolean) => {
+ Ok(Arc::new(is_not_distinct_from_bool(left, right)))
+ }
+ (DataType::Utf8, DataType::Utf8) => {
+ Ok(Arc::new(is_not_distinct_from_utf8::<i32>(left, right)))
+ }
+ (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ Ok(Arc::new(is_not_distinct_from_utf8::<i64>(left, right)))
+ }
+ (lhs, rhs) => Err(DataFusionError::Internal(format!(
+ "Cannot evaluate is_not_distinct_from expression with types {:?} and {:?}",
+ lhs, rhs
+ ))),
+ }
}
/// return two physical expressions that are optionally coerced to a
@@ -1422,11 +928,301 @@
#[cfg(test)]
mod tests {
+ use arrow::datatypes::*;
+ use arrow::{array::*, types::NativeType};
+
use super::*;
+
use crate::expressions::{col, lit};
- use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
- use arrow::util::display::array_value_to_string;
- use datafusion_common::Result;
+ use crate::test_util::create_decimal_array;
+ use arrow::datatypes::{Field, SchemaRef};
+ use arrow::error::ArrowError;
+ use datafusion_common::field_util::SchemaExt;
+
+ // TODO add iter for decimal array
+ // TODO move this to arrow-rs
+ // https://github.com/apache/arrow-rs/issues/1083
+ pub(super) fn eq_decimal_scalar(
+ left: &Int128Array,
+ right: i128,
+ ) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) == right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ pub(super) fn eq_decimal(
+ left: &Int128Array,
+ right: &Int128Array,
+ ) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) == right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn neq_decimal_scalar(left: &Int128Array, right: i128) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) != right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn neq_decimal(left: &Int128Array, right: &Int128Array) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) != right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn lt_decimal_scalar(left: &Int128Array, right: i128) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) < right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn lt_decimal(left: &Int128Array, right: &Int128Array) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) < right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn lt_eq_decimal_scalar(left: &Int128Array, right: i128) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) <= right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn lt_eq_decimal(left: &Int128Array, right: &Int128Array) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) <= right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn gt_decimal_scalar(left: &Int128Array, right: i128) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) > right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn gt_decimal(left: &Int128Array, right: &Int128Array) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) > right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn gt_eq_decimal_scalar(left: &Int128Array, right: i128) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) >= right))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn gt_eq_decimal(left: &Int128Array, right: &Int128Array) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ bool_builder.push(None);
+ } else {
+ bool_builder.try_push(Some(left.value(i) >= right.value(i)))?;
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn is_distinct_from_decimal(
+ left: &Int128Array,
+ right: &Int128Array,
+ ) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ match (left.is_null(i), right.is_null(i)) {
+ (true, true) => bool_builder.try_push(Some(false))?,
+ (true, false) | (false, true) => bool_builder.try_push(Some(true))?,
+ (_, _) => bool_builder.try_push(Some(left.value(i) != right.value(i)))?,
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn is_not_distinct_from_decimal(
+ left: &Int128Array,
+ right: &Int128Array,
+ ) -> Result<BooleanArray> {
+ let mut bool_builder = MutableBooleanArray::with_capacity(left.len());
+ for i in 0..left.len() {
+ match (left.is_null(i), right.is_null(i)) {
+ (true, true) => bool_builder.try_push(Some(true))?,
+ (true, false) | (false, true) => bool_builder.try_push(Some(false))?,
+ (_, _) => bool_builder.try_push(Some(left.value(i) == right.value(i)))?,
+ }
+ }
+ Ok(bool_builder.into())
+ }
+
+ fn add_decimal(left: &Int128Array, right: &Int128Array) -> Result<Int128Array> {
+ let mut decimal_builder = Int128Vec::from_data(
+ left.data_type().clone(),
+ Vec::<i128>::with_capacity(left.len()),
+ None,
+ );
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ decimal_builder.push(None);
+ } else {
+ decimal_builder.try_push(Some(left.value(i) + right.value(i)))?;
+ }
+ }
+ Ok(decimal_builder.into())
+ }
+
+ fn subtract_decimal(left: &Int128Array, right: &Int128Array) -> Result<Int128Array> {
+ let mut decimal_builder = Int128Vec::from_data(
+ left.data_type().clone(),
+ Vec::<i128>::with_capacity(left.len()),
+ None,
+ );
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ decimal_builder.push(None);
+ } else {
+ decimal_builder.try_push(Some(left.value(i) - right.value(i)))?;
+ }
+ }
+ Ok(decimal_builder.into())
+ }
+
+ fn multiply_decimal(
+ left: &Int128Array,
+ right: &Int128Array,
+ scale: u32,
+ ) -> Result<Int128Array> {
+ let mut decimal_builder = Int128Vec::from_data(
+ left.data_type().clone(),
+ Vec::<i128>::with_capacity(left.len()),
+ None,
+ );
+ let divide = 10_i128.pow(scale);
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ decimal_builder.push(None);
+ } else {
+ decimal_builder
+ .try_push(Some(left.value(i) * right.value(i) / divide))?;
+ }
+ }
+ Ok(decimal_builder.into())
+ }
+
+ fn divide_decimal(
+ left: &Int128Array,
+ right: &Int128Array,
+ scale: i32,
+ ) -> Result<Int128Array> {
+ let mut decimal_builder = Int128Vec::from_data(
+ left.data_type().clone(),
+ Vec::<i128>::with_capacity(left.len()),
+ None,
+ );
+ let mul = 10_f64.powi(scale);
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ decimal_builder.push(None);
+ } else if right.value(i) == 0 {
+ return Err(DataFusionError::ArrowError(
+ ArrowError::InvalidArgumentError("Cannot divide by zero".to_string()),
+ ));
+ } else {
+ let l_value = left.value(i) as f64;
+ let r_value = right.value(i) as f64;
+ let result = ((l_value / r_value) * mul) as i128;
+ decimal_builder.try_push(Some(result))?;
+ }
+ }
+ Ok(decimal_builder.into())
+ }
+
+ fn modulus_decimal(left: &Int128Array, right: &Int128Array) -> Result<Int128Array> {
+ let mut decimal_builder = Int128Vec::from_data(
+ left.data_type().clone(),
+ Vec::<i128>::with_capacity(left.len()),
+ None,
+ );
+ for i in 0..left.len() {
+ if left.is_null(i) || right.is_null(i) {
+ decimal_builder.push(None);
+ } else if right.value(i) == 0 {
+ return Err(DataFusionError::ArrowError(
+ ArrowError::InvalidArgumentError("Cannot divide by zero".to_string()),
+ ));
+ } else {
+ decimal_builder.try_push(Some(left.value(i) % right.value(i)))?;
+ }
+ }
+ Ok(decimal_builder.into())
+ }
// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
@@ -1445,8 +1241,8 @@
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
- let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
- let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
+ let a = Int32Array::from_slice(vec![1, 2, 3, 4, 5]);
+ let b = Int32Array::from_slice(vec![1, 2, 4, 8, 16]);
// expression: "a < b"
let lt = binary_simple(
@@ -1479,8 +1275,8 @@
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
- let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
- let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
+ let a = Int32Array::from_slice(vec![2, 4, 6, 8, 10]);
+ let b = Int32Array::from_slice(vec![2, 5, 4, 8, 8]);
// expression: "a < b OR a == b"
let expr = binary_simple(
@@ -1527,273 +1323,130 @@
// 4. verify that the resulting expression is of type C
// 5. verify that the results of evaluation are $VEC
macro_rules! test_coercion {
- ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{
+ ($A_ARRAY:ident, $B_ARRAY:ident, $OP:expr, $C_ARRAY:ident) => {{
let schema = Schema::new(vec![
- Field::new("a", $A_TYPE, false),
- Field::new("b", $B_TYPE, false),
+ Field::new("a", $A_ARRAY.data_type().clone(), false),
+ Field::new("b", $B_ARRAY.data_type().clone(), false),
]);
- let a = $A_ARRAY::from($A_VEC);
- let b = $B_ARRAY::from($B_VEC);
-
// verify that we can construct the expression
let expression =
binary(col("a", &schema)?, $OP, col("b", &schema)?, &schema)?;
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
- vec![Arc::new(a), Arc::new(b)],
+ vec![Arc::new($A_ARRAY), Arc::new($B_ARRAY)],
)?;
// verify that the expression's type is correct
- assert_eq!(expression.data_type(&schema)?, $C_TYPE);
+ assert_eq!(&expression.data_type(&schema)?, $C_ARRAY.data_type());
// compute
let result = expression.evaluate(&batch)?.into_array(batch.num_rows());
- // verify that the array's data_type is correct
- assert_eq!(*result.data_type(), $C_TYPE);
-
- // verify that the data itself is downcastable
- let result = result
- .as_any()
- .downcast_ref::<$C_ARRAY>()
- .expect("failed to downcast");
- // verify that the result itself is correct
- for (i, x) in $VEC.iter().enumerate() {
- assert_eq!(result.value(i), *x);
- }
+ // verify that the array is equal
+ assert_eq!($C_ARRAY, result.as_ref());
}};
}
#[test]
fn test_type_coersion() -> Result<()> {
- test_coercion!(
- Int32Array,
- DataType::Int32,
- vec![1i32, 2i32],
- UInt32Array,
- DataType::UInt32,
- vec![1u32, 2u32],
- Operator::Plus,
- Int32Array,
- DataType::Int32,
- vec![2i32, 4i32]
- );
- test_coercion!(
- Int32Array,
- DataType::Int32,
- vec![1i32],
- UInt16Array,
- DataType::UInt16,
- vec![1u16],
- Operator::Plus,
- Int32Array,
- DataType::Int32,
- vec![2i32]
- );
- test_coercion!(
- Float32Array,
- DataType::Float32,
- vec![1f32],
- UInt16Array,
- DataType::UInt16,
- vec![1u16],
- Operator::Plus,
- Float32Array,
- DataType::Float32,
- vec![2f32]
- );
- test_coercion!(
- Float32Array,
- DataType::Float32,
- vec![2f32],
- UInt16Array,
- DataType::UInt16,
- vec![1u16],
- Operator::Multiply,
- Float32Array,
- DataType::Float32,
- vec![2f32]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["hello world", "world"],
- StringArray,
- DataType::Utf8,
- vec!["%hello%", "%hello%"],
- Operator::Like,
- BooleanArray,
- DataType::Boolean,
- vec![true, false]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["1994-12-13", "1995-01-26"],
- Date32Array,
- DataType::Date32,
- vec![9112, 9156],
- Operator::Eq,
- BooleanArray,
- DataType::Boolean,
- vec![true, true]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["1994-12-13", "1995-01-26"],
- Date32Array,
- DataType::Date32,
- vec![9113, 9154],
- Operator::Lt,
- BooleanArray,
- DataType::Boolean,
- vec![true, false]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
- Date64Array,
- DataType::Date64,
- vec![787322096000, 791083425000],
- Operator::Eq,
- BooleanArray,
- DataType::Boolean,
- vec![true, true]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
- Date64Array,
- DataType::Date64,
- vec![787322096001, 791083424999],
- Operator::Lt,
- BooleanArray,
- DataType::Boolean,
- vec![true, false]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["abc"; 5],
- StringArray,
- DataType::Utf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexMatch,
- BooleanArray,
- DataType::Boolean,
- vec![true, false, true, false, false]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["abc"; 5],
- StringArray,
- DataType::Utf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexIMatch,
- BooleanArray,
- DataType::Boolean,
- vec![true, true, true, true, false]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["abc"; 5],
- StringArray,
- DataType::Utf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexNotMatch,
- BooleanArray,
- DataType::Boolean,
- vec![false, true, false, true, true]
- );
- test_coercion!(
- StringArray,
- DataType::Utf8,
- vec!["abc"; 5],
- StringArray,
- DataType::Utf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexNotIMatch,
- BooleanArray,
- DataType::Boolean,
- vec![false, false, false, false, true]
- );
- test_coercion!(
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["abc"; 5],
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexMatch,
- BooleanArray,
- DataType::Boolean,
- vec![true, false, true, false, false]
- );
- test_coercion!(
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["abc"; 5],
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexIMatch,
- BooleanArray,
- DataType::Boolean,
- vec![true, true, true, true, false]
- );
- test_coercion!(
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["abc"; 5],
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexNotMatch,
- BooleanArray,
- DataType::Boolean,
- vec![false, true, false, true, true]
- );
- test_coercion!(
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["abc"; 5],
- LargeStringArray,
- DataType::LargeUtf8,
- vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
- Operator::RegexNotIMatch,
- BooleanArray,
- DataType::Boolean,
- vec![false, false, false, false, true]
- );
- test_coercion!(
- Int16Array,
- DataType::Int16,
- vec![1i16, 2i16, 3i16],
- Int64Array,
- DataType::Int64,
- vec![10i64, 4i64, 5i64],
- Operator::BitwiseAnd,
- Int64Array,
- DataType::Int64,
- vec![0i64, 0i64, 1i64]
- );
- test_coercion!(
- Int16Array,
- DataType::Int16,
- vec![1i16, 2i16, 3i16],
- Int64Array,
- DataType::Int64,
- vec![10i64, 4i64, 5i64],
- Operator::BitwiseOr,
- Int64Array,
- DataType::Int64,
- vec![11i64, 6i64, 7i64]
- );
+ let a = Int32Array::from_slice(&[1, 2]);
+ let b = UInt32Array::from_slice(&[1, 2]);
+ let c = Int32Array::from_slice(&[2, 4]);
+ test_coercion!(a, b, Operator::Plus, c);
+
+ let a = Int32Array::from_slice(&[1]);
+ let b = UInt32Array::from_slice(&[1]);
+ let c = Int32Array::from_slice(&[2]);
+ test_coercion!(a, b, Operator::Plus, c);
+
+ let a = Int32Array::from_slice(&[1]);
+ let b = UInt16Array::from_slice(&[1]);
+ let c = Int32Array::from_slice(&[2]);
+ test_coercion!(a, b, Operator::Plus, c);
+
+ let a = Float32Array::from_slice(&[1.0]);
+ let b = UInt16Array::from_slice(&[1]);
+ let c = Float32Array::from_slice(&[2.0]);
+ test_coercion!(a, b, Operator::Plus, c);
+
+ let a = Float32Array::from_slice(&[1.0]);
+ let b = UInt16Array::from_slice(&[1]);
+ let c = Float32Array::from_slice(&[1.0]);
+ test_coercion!(a, b, Operator::Multiply, c);
+
+ let a = Utf8Array::<i32>::from_slice(&["hello world"]);
+ let b = Utf8Array::<i32>::from_slice(&["%hello%"]);
+ let c = BooleanArray::from_slice(&[true]);
+ test_coercion!(a, b, Operator::Like, c);
+
+ let a = Utf8Array::<i32>::from_slice(&["1994-12-13"]);
+ let b = Int32Array::from_slice(&[9112]).to(DataType::Date32);
+ let c = BooleanArray::from_slice(&[true]);
+ test_coercion!(a, b, Operator::Eq, c);
+
+ let a = Utf8Array::<i32>::from_slice(&["1994-12-13", "1995-01-26"]);
+ let b = Int32Array::from_slice(&[9113, 9154]).to(DataType::Date32);
+ let c = BooleanArray::from_slice(&[true, false]);
+ test_coercion!(a, b, Operator::Lt, c);
+
+ let a =
+ Utf8Array::<i32>::from_slice(&["1994-12-13T12:34:56", "1995-01-26T01:23:45"]);
+ let b =
+ Int64Array::from_slice(&[787322096000, 791083425000]).to(DataType::Date64);
+ let c = BooleanArray::from_slice(&[true, true]);
+ test_coercion!(a, b, Operator::Eq, c);
+
+ let a =
+ Utf8Array::<i32>::from_slice(&["1994-12-13T12:34:56", "1995-01-26T01:23:45"]);
+ let b =
+ Int64Array::from_slice(&[787322096001, 791083424999]).to(DataType::Date64);
+ let c = BooleanArray::from_slice(&[true, false]);
+ test_coercion!(a, b, Operator::Lt, c);
+
+ let a = Utf8Array::<i32>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i32>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[true, false, true, false, false]);
+ test_coercion!(a, b, Operator::RegexMatch, c);
+
+ let a = Utf8Array::<i32>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i32>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[true, true, true, true, false]);
+ test_coercion!(a, b, Operator::RegexIMatch, c);
+
+ let a = Utf8Array::<i32>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i32>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[false, true, false, true, true]);
+ test_coercion!(a, b, Operator::RegexNotMatch, c);
+
+ let a = Utf8Array::<i32>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i32>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[false, false, false, false, true]);
+ test_coercion!(a, b, Operator::RegexNotIMatch, c);
+
+ let a = Utf8Array::<i64>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i64>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[true, false, true, false, false]);
+ test_coercion!(a, b, Operator::RegexMatch, c);
+
+ let a = Utf8Array::<i64>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i64>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[true, true, true, true, false]);
+ test_coercion!(a, b, Operator::RegexIMatch, c);
+
+ let a = Utf8Array::<i64>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i64>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[false, true, false, true, true]);
+ test_coercion!(a, b, Operator::RegexNotMatch, c);
+
+ let a = Utf8Array::<i64>::from_slice(["abc"; 5]);
+ let b = Utf8Array::<i64>::from_slice(["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"]);
+ let c = BooleanArray::from_slice(&[false, false, false, false, true]);
+ test_coercion!(a, b, Operator::RegexNotIMatch, c);
+
+ let a = Int16Array::from_slice(&[1i16, 2i16, 3i16]);
+ let b = Int64Array::from_slice(&[10i64, 4i64, 5i64]);
+ let c = Int64Array::from_slice(&[0i64, 0i64, 1i64]);
+ test_coercion!(a, b, Operator::BitwiseAnd, c);
Ok(())
}
@@ -1805,35 +1458,25 @@
#[test]
fn test_dictionary_type_to_array_coersion() -> Result<()> {
// Test string a string dictionary
- let dict_type =
- DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
- let string_type = DataType::Utf8;
- // build dictionary
- let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
- let values_builder = arrow::array::StringBuilder::new(10);
- let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder);
+ let data = vec![Some("one"), None, Some("three"), Some("four")];
- dict_builder.append("one")?;
- dict_builder.append_null()?;
- dict_builder.append("three")?;
- dict_builder.append("four")?;
- let dict_array = dict_builder.finish();
+ let mut dict_array = MutableDictionaryArray::<i32, MutableUtf8Array<i32>>::new();
+ dict_array.try_extend(data)?;
+ let dict_array = dict_array.into_arc();
let str_array =
- StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]);
+ Utf8Array::<i32>::from(&[Some("not one"), Some("two"), None, Some("four")]);
let schema = Arc::new(Schema::new(vec![
- Field::new("dict", dict_type, true),
- Field::new("str", string_type, true),
+ Field::new("dict", dict_array.data_type().clone(), true),
+ Field::new("str", str_array.data_type().clone(), true),
]));
- let batch = RecordBatch::try_new(
- schema.clone(),
- vec![Arc::new(dict_array), Arc::new(str_array)],
- )?;
+ let batch =
+ RecordBatch::try_new(schema.clone(), vec![dict_array, Arc::new(str_array)])?;
- let expected = "false\n\n\ntrue";
+ let expected = BooleanArray::from(&[Some(false), None, None, Some(true)]);
// Test 1: dict = str
@@ -1851,7 +1494,7 @@
assert_eq!(result.data_type(), &DataType::Boolean);
// verify that the result itself is correct
- assert_eq!(expected, array_to_string(&result)?);
+ assert_eq!(expected, result.as_ref());
// Test 2: now test the other direction
// str = dict
@@ -1870,34 +1513,25 @@
assert_eq!(result.data_type(), &DataType::Boolean);
// verify that the result itself is correct
- assert_eq!(expected, array_to_string(&result)?);
+ assert_eq!(expected, result.as_ref());
Ok(())
}
- // Convert the array to a newline delimited string of pretty printed values
- fn array_to_string(array: &ArrayRef) -> Result<String> {
- let s = (0..array.len())
- .map(|i| array_value_to_string(array, i))
- .collect::<std::result::Result<Vec<_>, arrow::error::ArrowError>>()?
- .join("\n");
- Ok(s)
- }
-
#[test]
fn plus_op() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
- let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
- let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
+ let a = Int32Array::from_slice(vec![1, 2, 3, 4, 5]);
+ let b = Int32Array::from_slice(vec![1, 2, 4, 8, 16]);
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
Arc::new(schema),
vec![Arc::new(a), Arc::new(b)],
Operator::Plus,
- Int32Array::from(vec![2, 4, 7, 12, 21]),
+ Int32Array::from_slice(vec![2, 4, 7, 12, 21]),
)?;
Ok(())
@@ -1909,22 +1543,22 @@
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
- let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
- let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a = Arc::new(Int32Array::from_slice(vec![1, 2, 4, 8, 16]));
+ let b = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
schema.clone(),
vec![a.clone(), b.clone()],
Operator::Minus,
- Int32Array::from(vec![0, 0, 1, 4, 11]),
+ Int32Array::from_slice(vec![0, 0, 1, 4, 11]),
)?;
// should handle have negative values in result (for signed)
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
schema,
vec![b, a],
Operator::Minus,
- Int32Array::from(vec![0, 0, -1, -4, -11]),
+ Int32Array::from_slice(vec![0, 0, -1, -4, -11]),
)?;
Ok(())
@@ -1936,14 +1570,14 @@
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
- let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
- let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
+ let a = Arc::new(Int32Array::from_slice(vec![4, 8, 16, 32, 64]));
+ let b = Arc::new(Int32Array::from_slice(vec![2, 4, 8, 16, 32]));
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
schema,
vec![a, b],
Operator::Multiply,
- Int32Array::from(vec![8, 32, 128, 512, 2048]),
+ Int32Array::from_slice(vec![8, 32, 128, 512, 2048]),
)?;
Ok(())
@@ -1955,70 +1589,70 @@
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
- let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
- let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
+ let a = Arc::new(Int32Array::from_slice(vec![8, 32, 128, 512, 2048]));
+ let b = Arc::new(Int32Array::from_slice(vec![2, 4, 8, 16, 32]));
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
schema,
vec![a, b],
Operator::Divide,
- Int32Array::from(vec![4, 8, 16, 32, 64]),
+ Int32Array::from_slice(vec![4, 8, 16, 32, 64]),
)?;
Ok(())
}
+ fn apply_arithmetic<T: NativeType>(
+ schema: Arc<Schema>,
+ data: Vec<Arc<dyn Array>>,
+ op: Operator,
+ expected: PrimitiveArray<T>,
+ ) -> Result<()> {
+ let arithmetic_op =
+ binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema);
+ let batch = RecordBatch::try_new(schema, data)?;
+ let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+
+ assert_eq!(expected, result.as_ref());
+ Ok(())
+ }
+
+ fn apply_logic_op(
+ schema: &Arc<Schema>,
+ left: &ArrayRef,
+ right: &ArrayRef,
+ op: Operator,
+ expected: ArrayRef,
+ ) -> Result<()> {
+ let arithmetic_op =
+ binary_simple(col("a", schema)?, op, col("b", schema)?, schema);
+ let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
+ let batch = RecordBatch::try_new(schema.clone(), data)?;
+ let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
#[test]
fn modulus_op() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
- let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
- let b = Arc::new(Int32Array::from(vec![2, 4, 7, 14, 32]));
+ let a = Arc::new(Int32Array::from_slice(&[8, 32, 128, 512, 2048]));
+ let b = Arc::new(Int32Array::from_slice(&[2, 4, 7, 14, 32]));
- apply_arithmetic::<Int32Type>(
+ apply_arithmetic::<i32>(
schema,
vec![a, b],
Operator::Modulo,
- Int32Array::from(vec![0, 0, 2, 8, 0]),
+ Int32Array::from_slice(&[0, 0, 2, 8, 0]),
)?;
Ok(())
}
- fn apply_arithmetic<T: ArrowNumericType>(
- schema: SchemaRef,
- data: Vec<ArrayRef>,
- op: Operator,
- expected: PrimitiveArray<T>,
- ) -> Result<()> {
- let arithmetic_op =
- binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema);
- let batch = RecordBatch::try_new(schema, data)?;
- let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
-
- assert_eq!(result.as_ref(), &expected);
- Ok(())
- }
-
- fn apply_logic_op(
- schema: &SchemaRef,
- left: &ArrayRef,
- right: &ArrayRef,
- op: Operator,
- expected: BooleanArray,
- ) -> Result<()> {
- let arithmetic_op =
- binary_simple(col("a", schema)?, op, col("b", schema)?, schema);
- let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
- let batch = RecordBatch::try_new(schema.clone(), data)?;
- let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
-
- assert_eq!(result.as_ref(), &expected);
- Ok(())
- }
-
// Test `scalar <op> arr` produces expected
fn apply_logic_op_scalar_arr(
schema: &SchemaRef,
@@ -2032,7 +1666,7 @@
let arithmetic_op = binary_simple(scalar, op, col("a", schema)?, schema);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
- assert_eq!(result.as_ref(), expected);
+ assert_eq!(result.as_ref(), expected as &dyn Array);
Ok(())
}
@@ -2050,7 +1684,7 @@
let arithmetic_op = binary_simple(col("a", schema)?, op, scalar, schema);
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
- assert_eq!(result.as_ref(), expected);
+ assert_eq!(result.as_ref(), expected as &dyn Array);
Ok(())
}
@@ -2061,7 +1695,7 @@
Field::new("a", DataType::Boolean, true),
Field::new("b", DataType::Boolean, true),
]);
- let a = Arc::new(BooleanArray::from(vec![
+ let a = Arc::new(BooleanArray::from_iter(vec![
Some(true),
Some(false),
None,
@@ -2072,7 +1706,7 @@
Some(false),
None,
])) as ArrayRef;
- let b = Arc::new(BooleanArray::from(vec![
+ let b = Arc::new(BooleanArray::from_iter(vec![
Some(true),
Some(true),
Some(true),
@@ -2084,7 +1718,7 @@
None,
])) as ArrayRef;
- let expected = BooleanArray::from(vec![
+ let expected = BooleanArray::from_iter(vec![
Some(true),
Some(false),
None,
@@ -2095,7 +1729,7 @@
Some(false),
None,
]);
- apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, expected)?;
+ apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, Arc::new(expected))?;
Ok(())
}
@@ -2106,7 +1740,7 @@
Field::new("a", DataType::Boolean, true),
Field::new("b", DataType::Boolean, true),
]);
- let a = Arc::new(BooleanArray::from(vec![
+ let a = Arc::new(BooleanArray::from_iter(vec![
Some(true),
Some(false),
None,
@@ -2117,7 +1751,7 @@
Some(false),
None,
])) as ArrayRef;
- let b = Arc::new(BooleanArray::from(vec![
+ let b = Arc::new(BooleanArray::from_iter(vec![
Some(true),
Some(true),
Some(true),
@@ -2129,7 +1763,7 @@
None,
])) as ArrayRef;
- let expected = BooleanArray::from(vec![
+ let expected = BooleanArray::from_iter(vec![
Some(true),
Some(true),
Some(true),
@@ -2140,7 +1774,7 @@
None,
None,
]);
- apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, expected)?;
+ apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, Arc::new(expected))?;
Ok(())
}
@@ -2193,7 +1827,7 @@
#[test]
fn eq_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = vec![
+ let expected = BooleanArray::from_iter(vec![
Some(true),
None,
Some(false),
@@ -2203,10 +1837,8 @@
Some(false),
None,
Some(true),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::Eq, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::Eq, Arc::new(expected)).unwrap();
}
#[test]
@@ -2252,7 +1884,7 @@
#[test]
fn neq_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(false),
None,
Some(true),
@@ -2262,10 +1894,8 @@
Some(true),
None,
Some(false),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::NotEq, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::NotEq, Arc::new(expected)).unwrap();
}
#[test]
@@ -2311,7 +1941,7 @@
#[test]
fn lt_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(false),
None,
Some(false),
@@ -2321,10 +1951,8 @@
Some(true),
None,
Some(false),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::Lt, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::Lt, Arc::new(expected)).unwrap();
}
#[test]
@@ -2374,7 +2002,7 @@
#[test]
fn lt_eq_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(true),
None,
Some(false),
@@ -2384,10 +2012,8 @@
Some(true),
None,
Some(true),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::LtEq, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::LtEq, Arc::new(expected)).unwrap();
}
#[test]
@@ -2437,7 +2063,7 @@
#[test]
fn gt_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(false),
None,
Some(true),
@@ -2447,16 +2073,14 @@
Some(false),
None,
Some(false),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::Gt, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::Gt, Arc::new(expected)).unwrap();
}
#[test]
fn gt_op_bool_scalar() {
let (schema, a) = scalar_bool_test_array();
- let expected = [Some(false), None, Some(true)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(false), None, Some(true)]);
apply_logic_op_scalar_arr(
&schema,
&ScalarValue::from(true),
@@ -2466,7 +2090,7 @@
)
.unwrap();
- let expected = [Some(false), None, Some(false)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(false), None, Some(false)]);
apply_logic_op_arr_scalar(
&schema,
&a,
@@ -2476,7 +2100,7 @@
)
.unwrap();
- let expected = [Some(false), None, Some(false)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(false), None, Some(false)]);
apply_logic_op_scalar_arr(
&schema,
&ScalarValue::from(false),
@@ -2486,7 +2110,7 @@
)
.unwrap();
- let expected = [Some(true), None, Some(false)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(true), None, Some(false)]);
apply_logic_op_arr_scalar(
&schema,
&a,
@@ -2500,7 +2124,7 @@
#[test]
fn gt_eq_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(true),
None,
Some(true),
@@ -2510,16 +2134,14 @@
Some(false),
None,
Some(true),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::GtEq, expected).unwrap();
+ ]);
+ apply_logic_op(&schema, &a, &b, Operator::GtEq, Arc::new(expected)).unwrap();
}
#[test]
fn gt_eq_op_bool_scalar() {
let (schema, a) = scalar_bool_test_array();
- let expected = [Some(true), None, Some(true)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(true), None, Some(true)]);
apply_logic_op_scalar_arr(
&schema,
&ScalarValue::from(true),
@@ -2529,7 +2151,7 @@
)
.unwrap();
- let expected = [Some(true), None, Some(false)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(true), None, Some(false)]);
apply_logic_op_arr_scalar(
&schema,
&a,
@@ -2539,7 +2161,7 @@
)
.unwrap();
- let expected = [Some(false), None, Some(true)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(false), None, Some(true)]);
apply_logic_op_scalar_arr(
&schema,
&ScalarValue::from(false),
@@ -2549,7 +2171,7 @@
)
.unwrap();
- let expected = [Some(true), None, Some(true)].iter().collect();
+ let expected = BooleanArray::from_iter([Some(true), None, Some(true)]);
apply_logic_op_arr_scalar(
&schema,
&a,
@@ -2563,7 +2185,7 @@
#[test]
fn is_distinct_from_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(false),
Some(true),
Some(true),
@@ -2573,16 +2195,21 @@
Some(true),
Some(true),
Some(false),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::IsDistinctFrom, expected).unwrap();
+ ]);
+ apply_logic_op(
+ &schema,
+ &a,
+ &b,
+ Operator::IsDistinctFrom,
+ Arc::new(expected),
+ )
+ .unwrap();
}
#[test]
fn is_not_distinct_from_op_bool() {
let (schema, a, b) = bool_test_arrays();
- let expected = [
+ let expected = BooleanArray::from_iter([
Some(true),
Some(false),
Some(false),
@@ -2592,10 +2219,15 @@
Some(false),
Some(false),
Some(true),
- ]
- .iter()
- .collect();
- apply_logic_op(&schema, &a, &b, Operator::IsNotDistinctFrom, expected).unwrap();
+ ]);
+ apply_logic_op(
+ &schema,
+ &a,
+ &b,
+ Operator::IsNotDistinctFrom,
+ Arc::new(expected),
+ )
+ .unwrap();
}
#[test]
@@ -2616,7 +2248,7 @@
let expr = (0..tree_depth)
.into_iter()
.map(|_| col("a", schema.as_ref()).unwrap())
- .reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema))
+ .reduce(|l, r| binary_simple(l, Operator::Plus, r, schema))
.unwrap();
let result = expr
@@ -2628,26 +2260,7 @@
.into_iter()
.map(|i| i.map(|i| i * tree_depth))
.collect();
- assert_eq!(result.as_ref(), &expected);
- }
-
- fn create_decimal_array(
- array: &[Option<i128>],
- precision: usize,
- scale: usize,
- ) -> Result<DecimalArray> {
- let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale);
- for value in array {
- match value {
- None => {
- decimal_builder.append_null()?;
- }
- Some(v) => {
- decimal_builder.append_value(*v)?;
- }
- }
- }
- Ok(decimal_builder.finish())
+ assert_eq!(result.as_ref(), &expected as &dyn Array);
}
#[test]
@@ -2666,37 +2279,37 @@
// eq: array = i128
let result = eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(false), Some(false)]),
result
);
// neq: array != i128
let result = neq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(true), Some(true)]),
result
);
// lt: array < i128
let result = lt_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(true), Some(false)]),
result
);
// lt_eq: array <= i128
let result = lt_eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(true), Some(false)]),
result
);
// gt: array > i128
let result = gt_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(false), Some(true)]),
result
);
// gt_eq: array >= i128
let result = gt_eq_decimal_scalar(&decimal_array, value_i128)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(false), Some(true)]),
result
);
@@ -2714,50 +2327,60 @@
// eq: left == right
let result = eq_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(false), Some(true)]),
result
);
// neq: left != right
let result = neq_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(true), Some(false)]),
result
);
// lt: left < right
let result = lt_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(true), Some(false)]),
result
);
// lt_eq: left <= right
let result = lt_eq_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]),
+ BooleanArray::from_iter(vec![Some(false), None, Some(true), Some(true)]),
result
);
// gt: left > right
let result = gt_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(false), Some(false)]),
result
);
// gt_eq: left >= right
let result = gt_eq_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
+ BooleanArray::from_iter(vec![Some(true), None, Some(false), Some(true)]),
result
);
// is_distinct: left distinct right
let result = is_distinct_from_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(true), Some(true), Some(true), Some(false)]),
+ BooleanArray::from_iter(vec![
+ Some(true),
+ Some(true),
+ Some(true),
+ Some(false)
+ ]),
result
);
// is_distinct: left distinct right
let result =
is_not_distinct_from_decimal(&left_decimal_array, &right_decimal_array)?;
assert_eq!(
- BooleanArray::from(vec![Some(false), Some(false), Some(false), Some(true)]),
+ BooleanArray::from_iter(vec![
+ Some(false),
+ Some(false),
+ Some(false),
+ Some(true)
+ ]),
result
);
Ok(())
@@ -2771,39 +2394,42 @@
apply_logic_op_scalar_arr(
&schema,
&decimal_scalar,
- &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef),
+ &(Arc::new(Int64Array::from_iter(vec![Some(124), None])) as ArrayRef),
Operator::Eq,
- &BooleanArray::from(vec![Some(false), None]),
+ &BooleanArray::from_iter(vec![Some(false), None]),
)
.unwrap();
// array != scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef),
+ &(Arc::new(Int64Array::from_iter(vec![Some(123), None, Some(1)]))
+ as ArrayRef),
&decimal_scalar,
Operator::NotEq,
- &BooleanArray::from(vec![Some(true), None, Some(true)]),
+ &BooleanArray::from_iter(vec![Some(true), None, Some(true)]),
)
.unwrap();
// array < scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
+ &(Arc::new(Int64Array::from_iter(vec![Some(123), None, Some(124)]))
+ as ArrayRef),
&decimal_scalar,
Operator::Lt,
- &BooleanArray::from(vec![Some(true), None, Some(false)]),
+ &BooleanArray::from_iter(vec![Some(true), None, Some(false)]),
)
.unwrap();
// array > scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef),
+ &(Arc::new(Int64Array::from_iter(vec![Some(123), None, Some(124)]))
+ as ArrayRef),
&decimal_scalar,
Operator::Gt,
- &BooleanArray::from(vec![Some(false), None, Some(true)]),
+ &BooleanArray::from_iter(vec![Some(false), None, Some(true)]),
)
.unwrap();
@@ -2812,18 +2438,21 @@
// array == scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)]))
- as ArrayRef),
+ &(Arc::new(Float64Array::from_iter(vec![
+ Some(123.456),
+ None,
+ Some(123.457),
+ ])) as ArrayRef),
&decimal_scalar,
Operator::Eq,
- &BooleanArray::from(vec![Some(true), None, Some(false)]),
+ &BooleanArray::from_iter(vec![Some(true), None, Some(false)]),
)
.unwrap();
// array <= scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Float64Array::from(vec![
+ &(Arc::new(Float64Array::from_iter(vec![
Some(123.456),
None,
Some(123.457),
@@ -2831,13 +2460,13 @@
])) as ArrayRef),
&decimal_scalar,
Operator::LtEq,
- &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
+ &BooleanArray::from_iter(vec![Some(true), None, Some(false), Some(true)]),
)
.unwrap();
// array >= scalar
apply_logic_op_arr_scalar(
&schema,
- &(Arc::new(Float64Array::from(vec![
+ &(Arc::new(Float64Array::from_iter(vec![
Some(123.456),
None,
Some(123.457),
@@ -2845,7 +2474,7 @@
])) as ArrayRef),
&decimal_scalar,
Operator::GtEq,
- &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]),
+ &BooleanArray::from_iter(vec![Some(true), None, Some(true), Some(false)]),
)
.unwrap();
@@ -2868,7 +2497,7 @@
0,
)?) as ArrayRef;
- let int64_array = Arc::new(Int64Array::from(vec![
+ let int64_array = Arc::new(Int64Array::from_iter(vec![
Some(value),
Some(value - 1),
Some(value),
@@ -2881,7 +2510,12 @@
&int64_array,
&decimal_array,
Operator::Eq,
- BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(true),
+ None,
+ Some(false),
+ Some(true),
+ ])),
)
.unwrap();
// neq: int64array != decimal array
@@ -2890,7 +2524,12 @@
&int64_array,
&decimal_array,
Operator::NotEq,
- BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(false),
+ None,
+ Some(true),
+ Some(false),
+ ])),
)
.unwrap();
@@ -2910,7 +2549,7 @@
10,
2,
)?) as ArrayRef;
- let float64_array = Arc::new(Float64Array::from(vec![
+ let float64_array = Arc::new(Float64Array::from_iter(vec![
Some(1.23),
Some(1.22),
Some(1.23),
@@ -2922,7 +2561,12 @@
&float64_array,
&decimal_array,
Operator::Lt,
- BooleanArray::from(vec![Some(false), None, Some(false), Some(false)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(false),
+ None,
+ Some(false),
+ Some(false),
+ ])),
)
.unwrap();
// lt_eq: float64array <= decimal array
@@ -2931,7 +2575,12 @@
&float64_array,
&decimal_array,
Operator::LtEq,
- BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(true),
+ None,
+ Some(false),
+ Some(true),
+ ])),
)
.unwrap();
// gt: float64array > decimal array
@@ -2940,7 +2589,12 @@
&float64_array,
&decimal_array,
Operator::Gt,
- BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(false),
+ None,
+ Some(true),
+ Some(false),
+ ])),
)
.unwrap();
apply_logic_op(
@@ -2948,7 +2602,12 @@
&float64_array,
&decimal_array,
Operator::GtEq,
- BooleanArray::from(vec![Some(true), None, Some(true), Some(true)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(true),
+ None,
+ Some(true),
+ Some(true),
+ ])),
)
.unwrap();
// is distinct: float64array is distinct decimal array
@@ -2960,7 +2619,12 @@
&float64_array,
&decimal_array,
Operator::IsDistinctFrom,
- BooleanArray::from(vec![Some(false), Some(true), Some(true), Some(false)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(false),
+ Some(true),
+ Some(true),
+ Some(false),
+ ])),
)
.unwrap();
// is not distinct
@@ -2969,7 +2633,12 @@
&float64_array,
&decimal_array,
Operator::IsNotDistinctFrom,
- BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]),
+ Arc::new(BooleanArray::from_iter(vec![
+ Some(true),
+ Some(false),
+ Some(false),
+ Some(true),
+ ])),
)
.unwrap();
@@ -3009,7 +2678,7 @@
let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?;
assert_eq!(expect, result);
// multiply
- let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?;
+ let result = multiply_decimal(&left_decimal_array, &right_decimal_array, 3)?;
let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?;
assert_eq!(expect, result);
// divide
@@ -3020,7 +2689,7 @@
)?;
let right_decimal_array =
create_decimal_array(&[Some(10), Some(100), Some(55), Some(-123)], 25, 3)?;
- let result = divide_decimal(&left_decimal_array, &right_decimal_array)?;
+ let result = divide_decimal(&left_decimal_array, &right_decimal_array, 3)?;
let expect = create_decimal_array(
&[Some(123456700), None, Some(22446672), Some(-10037130)],
25,
@@ -3069,7 +2738,7 @@
10,
2,
)?) as ArrayRef;
- let int32_array = Arc::new(Int32Array::from(vec![
+ let int32_array = Arc::new(Int32Array::from_iter(vec![
Some(123),
Some(122),
Some(123),
@@ -3171,30 +2840,33 @@
#[test]
fn bitwise_array_test() -> Result<()> {
- let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
+ let left =
+ Arc::new(Int32Array::from_iter(vec![Some(12), None, Some(11)])) as ArrayRef;
let right =
- Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
- let mut result = bitwise_and(left.clone(), right.clone())?;
- let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
- assert_eq!(result.as_ref(), &expected);
+ Arc::new(Int32Array::from_iter(vec![Some(1), Some(3), Some(7)])) as ArrayRef;
+ let result = bitwise_and(left.as_ref(), right.as_ref())?;
+ let expected = Int32Vec::from(vec![Some(0), None, Some(3)]).as_arc();
+ assert_eq!(result.as_ref(), expected.as_ref());
- result = bitwise_or(left.clone(), right.clone())?;
- let expected = Int32Array::from(vec![Some(13), None, Some(15)]);
- assert_eq!(result.as_ref(), &expected);
+ let result = bitwise_or(left.as_ref(), right.as_ref())?;
+ let expected = Int32Vec::from(vec![Some(13), None, Some(15)]).as_arc();
+ assert_eq!(result.as_ref(), expected.as_ref());
Ok(())
}
#[test]
fn bitwise_scalar_test() -> Result<()> {
- let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;
+ let left =
+ Arc::new(Int32Array::from_iter(vec![Some(12), None, Some(11)])) as ArrayRef;
let right = ScalarValue::from(3i32);
- let mut result = bitwise_and_scalar(&left, right.clone()).unwrap()?;
- let expected = Int32Array::from(vec![Some(0), None, Some(3)]);
- assert_eq!(result.as_ref(), &expected);
+ let result = bitwise_and_scalar(left.as_ref(), right.clone()).unwrap()?;
+ let expected = Int32Vec::from(vec![Some(0), None, Some(3)]).as_arc();
+ assert_eq!(result.as_ref(), expected.as_ref());
- result = bitwise_or_scalar(&left, right).unwrap()?;
- let expected = Int32Array::from(vec![Some(15), None, Some(11)]);
- assert_eq!(result.as_ref(), &expected);
+ // 12 | 3 = 15 and 11 | 3 = 11, so this half must exercise the OR kernel
+ let result = bitwise_or_scalar(left.as_ref(), right).unwrap()?;
+ let expected = Int32Vec::from(vec![Some(15), None, Some(11)]).as_arc();
+ assert_eq!(result.as_ref(), expected.as_ref());
Ok(())
}
}
diff --git a/datafusion-physical-expr/src/expressions/case.rs b/datafusion-physical-expr/src/expressions/case.rs
index 3bcb78a..f42191a 100644
--- a/datafusion-physical-expr/src/expressions/case.rs
+++ b/datafusion-physical-expr/src/expressions/case.rs
@@ -17,15 +17,17 @@
use std::{any::Any, sync::Arc};
-use crate::expressions::try_cast;
-use crate::PhysicalExpr;
-use arrow::array::{self, *};
-use arrow::compute::{eq, eq_utf8};
+use arrow::array::*;
+use arrow::compute::{comparison, if_then_else};
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
+
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
+use crate::expressions::try_cast;
+use crate::PhysicalExpr;
+
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -107,208 +109,6 @@
}
}
-macro_rules! if_then_else {
- ($BUILDER_TYPE:ty, $ARRAY_TYPE:ty, $BOOLS:expr, $TRUE:expr, $FALSE:expr) => {{
- let true_values = $TRUE
- .as_ref()
- .as_any()
- .downcast_ref::<$ARRAY_TYPE>()
- .expect("true_values downcast failed");
-
- let false_values = $FALSE
- .as_ref()
- .as_any()
- .downcast_ref::<$ARRAY_TYPE>()
- .expect("false_values downcast failed");
-
- let mut builder = <$BUILDER_TYPE>::new($BOOLS.len());
- for i in 0..$BOOLS.len() {
- if $BOOLS.is_null(i) {
- if false_values.is_null(i) {
- builder.append_null()?;
- } else {
- builder.append_value(false_values.value(i))?;
- }
- } else if $BOOLS.value(i) {
- if true_values.is_null(i) {
- builder.append_null()?;
- } else {
- builder.append_value(true_values.value(i))?;
- }
- } else {
- if false_values.is_null(i) {
- builder.append_null()?;
- } else {
- builder.append_value(false_values.value(i))?;
- }
- }
- }
- Ok(Arc::new(builder.finish()))
- }};
-}
-
-fn if_then_else(
- bools: &BooleanArray,
- true_values: ArrayRef,
- false_values: ArrayRef,
- data_type: &DataType,
-) -> Result<ArrayRef> {
- match data_type {
- DataType::UInt8 => if_then_else!(
- array::UInt8Builder,
- array::UInt8Array,
- bools,
- true_values,
- false_values
- ),
- DataType::UInt16 => if_then_else!(
- array::UInt16Builder,
- array::UInt16Array,
- bools,
- true_values,
- false_values
- ),
- DataType::UInt32 => if_then_else!(
- array::UInt32Builder,
- array::UInt32Array,
- bools,
- true_values,
- false_values
- ),
- DataType::UInt64 => if_then_else!(
- array::UInt64Builder,
- array::UInt64Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Int8 => if_then_else!(
- array::Int8Builder,
- array::Int8Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Int16 => if_then_else!(
- array::Int16Builder,
- array::Int16Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Int32 => if_then_else!(
- array::Int32Builder,
- array::Int32Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Int64 => if_then_else!(
- array::Int64Builder,
- array::Int64Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Float32 => if_then_else!(
- array::Float32Builder,
- array::Float32Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Float64 => if_then_else!(
- array::Float64Builder,
- array::Float64Array,
- bools,
- true_values,
- false_values
- ),
- DataType::Utf8 => if_then_else!(
- array::StringBuilder,
- array::StringArray,
- bools,
- true_values,
- false_values
- ),
- DataType::Boolean => if_then_else!(
- array::BooleanBuilder,
- array::BooleanArray,
- bools,
- true_values,
- false_values
- ),
- other => Err(DataFusionError::Execution(format!(
- "CASE does not support '{:?}'",
- other
- ))),
- }
-}
-
-macro_rules! array_equals {
- ($TY:ty, $L:expr, $R:expr, $eq_fn:expr) => {{
- let when_value = $L
- .as_ref()
- .as_any()
- .downcast_ref::<$TY>()
- .expect("array_equals downcast failed");
-
- let base_value = $R
- .as_ref()
- .as_any()
- .downcast_ref::<$TY>()
- .expect("array_equals downcast failed");
-
- $eq_fn(when_value, base_value).map_err(DataFusionError::from)
- }};
-}
-
-fn array_equals(
- data_type: &DataType,
- when_value: ArrayRef,
- base_value: ArrayRef,
-) -> Result<BooleanArray> {
- match data_type {
- DataType::UInt8 => {
- array_equals!(array::UInt8Array, when_value, base_value, eq)
- }
- DataType::UInt16 => {
- array_equals!(array::UInt16Array, when_value, base_value, eq)
- }
- DataType::UInt32 => {
- array_equals!(array::UInt32Array, when_value, base_value, eq)
- }
- DataType::UInt64 => {
- array_equals!(array::UInt64Array, when_value, base_value, eq)
- }
- DataType::Int8 => {
- array_equals!(array::Int8Array, when_value, base_value, eq)
- }
- DataType::Int16 => {
- array_equals!(array::Int16Array, when_value, base_value, eq)
- }
- DataType::Int32 => {
- array_equals!(array::Int32Array, when_value, base_value, eq)
- }
- DataType::Int64 => {
- array_equals!(array::Int64Array, when_value, base_value, eq)
- }
- DataType::Float32 => {
- array_equals!(array::Float32Array, when_value, base_value, eq)
- }
- DataType::Float64 => {
- array_equals!(array::Float64Array, when_value, base_value, eq)
- }
- DataType::Utf8 => {
- array_equals!(array::StringArray, when_value, base_value, eq_utf8)
- }
- other => Err(DataFusionError::Execution(format!(
- "CASE does not support '{:?}'",
- other
- ))),
- }
-}
-
impl CaseExpr {
/// This function evaluates the form of CASE that matches an expression to fixed values.
///
@@ -318,20 +118,19 @@
/// [ELSE result]
/// END
fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
- let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?;
+ let return_type = self.when_then_expr[0].1.data_type(batch.schema())?;
let expr = self.expr.as_ref().unwrap();
let base_value = expr.evaluate(batch)?;
- let base_type = expr.data_type(&batch.schema())?;
let base_value = base_value.into_array(batch.num_rows());
// start with the else condition, or nulls
- let mut current_value: Option<ArrayRef> = if let Some(e) = &self.else_expr {
+ let mut current_value = if let Some(e) = &self.else_expr {
// keep `else_expr`'s data type and return type consistent
- let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone())
+ let expr = try_cast(e.clone(), &*batch.schema(), return_type)
.unwrap_or_else(|_| e.clone());
- Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
+ expr.evaluate(batch)?.into_array(batch.num_rows())
} else {
- Some(new_null_array(&return_type, batch.num_rows()))
+ new_null_array(return_type, batch.num_rows()).into()
};
// walk backwards through the when/then expressions
@@ -345,17 +144,27 @@
let then_value = then_value.into_array(batch.num_rows());
// build boolean array representing which rows match the "when" value
- let when_match = array_equals(&base_type, when_value, base_value.clone())?;
+ let when_match = comparison::eq(when_value.as_ref(), base_value.as_ref());
+ let when_match = if let Some(validity) = when_match.validity() {
+ // A NULL comparison never matches: AND-ing values with validity clears those bits, so the row keeps the current ("else") value.
+ BooleanArray::from_data(
+ DataType::Boolean,
+ when_match.values() & validity,
+ None,
+ )
+ } else {
+ when_match
+ };
- current_value = Some(if_then_else(
+ current_value = if_then_else::if_then_else(
&when_match,
- then_value,
- current_value.unwrap(),
- &return_type,
- )?);
+ then_value.as_ref(),
+ current_value.as_ref(),
+ )?
+ .into();
}
- Ok(ColumnarValue::Array(current_value.unwrap()))
+ Ok(ColumnarValue::Array(current_value))
}
/// This function evaluates the form of CASE where each WHEN expression is a boolean
@@ -366,15 +175,15 @@
/// [ELSE result]
/// END
fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
- let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?;
+ let return_type = self.when_then_expr[0].1.data_type(batch.schema())?;
// start with the else condition, or nulls
- let mut current_value: Option<ArrayRef> = if let Some(e) = &self.else_expr {
- let expr = try_cast(e.clone(), &*batch.schema(), return_type.clone())
+ let mut current_value = if let Some(e) = &self.else_expr {
+ let expr = try_cast(e.clone(), &*batch.schema(), return_type)
.unwrap_or_else(|_| e.clone());
- Some(expr.evaluate(batch)?.into_array(batch.num_rows()))
+ expr.evaluate(batch)?.into_array(batch.num_rows())
} else {
- Some(new_null_array(&return_type, batch.num_rows()))
+ new_null_array(return_type, batch.num_rows()).into()
};
// walk backwards through the when/then expressions
@@ -387,20 +196,31 @@
.as_ref()
.as_any()
.downcast_ref::<BooleanArray>()
- .expect("WHEN expression did not return a BooleanArray");
+ .expect("WHEN expression did not return a BooleanArray")
+ .clone();
+ let when_value = if let Some(validity) = when_value.validity() {
+ // A NULL WHEN result never matches: AND-ing values with validity clears those bits, so the row keeps the current ("else") value.
+ BooleanArray::from_data(
+ DataType::Boolean,
+ when_value.values() & validity,
+ None,
+ )
+ } else {
+ when_value
+ };
let then_value = self.when_then_expr[i].1.evaluate(batch)?;
let then_value = then_value.into_array(batch.num_rows());
- current_value = Some(if_then_else(
- when_value,
- then_value,
- current_value.unwrap(),
- &return_type,
- )?);
+ current_value = if_then_else::if_then_else(
+ &when_value,
+ then_value.as_ref(),
+ current_value.as_ref(),
+ )?
+ .into();
}
- Ok(ColumnarValue::Array(current_value.unwrap()))
+ Ok(ColumnarValue::Array(current_value))
}
}
@@ -455,11 +275,10 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::binary;
- use crate::expressions::col;
- use crate::expressions::lit;
- use arrow::array::StringArray;
+ use crate::expressions::{binary, col, lit};
+ use arrow::array::Utf8Array;
use arrow::datatypes::*;
+ use datafusion_common::field_util::SchemaExt;
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
@@ -475,7 +294,7 @@
let then2 = lit(ScalarValue::Int32(Some(456)));
let expr = case(
- Some(col("a", &schema)?),
+ Some(col("a", schema)?),
&[(when1, then1), (when2, then2)],
None,
)?;
@@ -485,7 +304,7 @@
.downcast_ref::<Int32Array>()
.expect("failed to downcast to Int32Array");
- let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
+ let expected = &Int32Array::from_iter(vec![Some(123), None, None, Some(456)]);
assert_eq!(expected, result);
@@ -505,7 +324,7 @@
let else_value = lit(ScalarValue::Int32(Some(999)));
let expr = case(
- Some(col("a", &schema)?),
+ Some(col("a", schema)?),
&[(when1, then1), (when2, then2)],
Some(else_value),
)?;
@@ -516,7 +335,7 @@
.expect("failed to downcast to Int32Array");
let expected =
- &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
+ &Int32Array::from_iter(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
@@ -530,17 +349,17 @@
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
let when1 = binary(
- col("a", &schema)?,
+ col("a", schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
- &batch.schema(),
+ batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
- col("a", &schema)?,
+ col("a", schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
- &batch.schema(),
+ batch.schema(),
)?;
let then2 = lit(ScalarValue::Int32(Some(456)));
@@ -551,7 +370,7 @@
.downcast_ref::<Int32Array>()
.expect("failed to downcast to Int32Array");
- let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
+ let expected = &Int32Array::from_iter(vec![Some(123), None, None, Some(456)]);
assert_eq!(expected, result);
@@ -565,17 +384,17 @@
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
let when1 = binary(
- col("a", &schema)?,
+ col("a", schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
- &batch.schema(),
+ batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
- col("a", &schema)?,
+ col("a", schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
- &batch.schema(),
+ batch.schema(),
)?;
let then2 = lit(ScalarValue::Int32(Some(456)));
let else_value = lit(ScalarValue::Int32(Some(999)));
@@ -588,7 +407,7 @@
.expect("failed to downcast to Int32Array");
let expected =
- &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
+ &Int32Array::from_iter(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
@@ -602,10 +421,10 @@
// CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END
let when = binary(
- col("a", &schema)?,
+ col("a", schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
- &batch.schema(),
+ batch.schema(),
)?;
let then = lit(ScalarValue::Float64(Some(123.3)));
let else_value = lit(ScalarValue::Int32(Some(999)));
@@ -617,8 +436,12 @@
.downcast_ref::<Float64Array>()
.expect("failed to downcast to Float64Array");
- let expected =
- &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
+ let expected = &Float64Array::from_iter(vec![
+ Some(123.3),
+ Some(999.0),
+ Some(999.0),
+ Some(999.0),
+ ]);
assert_eq!(expected, result);
@@ -626,7 +449,7 @@
}
fn case_test_batch() -> Result<RecordBatch> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
+ let a = Utf8Array::<i32>::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
Ok(batch)
}
diff --git a/datafusion-physical-expr/src/expressions/cast.rs b/datafusion-physical-expr/src/expressions/cast.rs
index 9144acc..d284180 100644
--- a/datafusion-physical-expr/src/expressions/cast.rs
+++ b/datafusion-physical-expr/src/expressions/cast.rs
@@ -19,19 +19,24 @@
use std::fmt;
use std::sync::Arc;
-use crate::PhysicalExpr;
-use arrow::compute;
-use arrow::compute::kernels;
-use arrow::compute::CastOptions;
+use arrow::array::{Array, Int32Array};
+use arrow::compute::cast;
+use arrow::compute::cast::CastOptions;
+use arrow::compute::take;
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use compute::can_cast_types;
+
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
+use crate::PhysicalExpr;
+
/// provide DataFusion default cast options
-pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions { safe: false };
+pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions {
+ wrapped: false,
+ partial: false,
+};
/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
#[derive(Debug)]
@@ -91,25 +96,62 @@
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
- cast_column(&value, &self.cast_type, &self.cast_options)
+ cast_column(&value, &self.cast_type, self.cast_options)
}
}
+ /// Casts `array` to `cast_type` using `options`, but returns an error instead of
+ /// silently producing nulls when the cast changes the array's null count
+ /// (i.e. some input values could not be converted).
+ ///
+ /// The failing positions are found by comparing the validity bitmaps of the
+ /// input and of the result; the corresponding input values are gathered with
+ /// `take` and included in the error message.
+ pub fn cast_with_error(
+ array: &dyn Array,
+ cast_type: &DataType,
+ options: CastOptions,
+ ) -> Result<Box<dyn Array>> {
+ let result = cast::cast(array, cast_type, options)?;
+ if result.null_count() != array.null_count() {
+ // NOTE(review): `unwrap` panics if the result has no validity bitmap even though its
+ // null count differs from the input's (possibly a cast to DataType::Null) - confirm.
+ let casted_valids = result.validity().unwrap();
+ // slots whose validity differs between input and output are the failed casts
+ let failed_casts = match array.validity() {
+ Some(valids) => valids ^ casted_valids,
+ None => !casted_valids,
+ };
+ let invalid_indices = failed_casts
+ .iter()
+ .enumerate()
+ .filter(|(_, failed)| *failed)
+ .map(|(idx, _)| Some(idx as i32))
+ .collect::<Vec<Option<i32>>>();
+ let invalid_values = take::take(array, &Int32Array::from(&invalid_indices))?;
+ return Err(DataFusionError::Execution(format!(
+ "Could not cast {:?} to value of type {:?}",
+ invalid_values, cast_type
+ )));
+ }
+ Ok(result)
+ }
+
/// Internal cast function for casting ColumnarValue -> ColumnarValue for cast_type
pub fn cast_column(
value: &ColumnarValue,
cast_type: &DataType,
- cast_options: &CastOptions,
+ cast_options: CastOptions,
) -> Result<ColumnarValue> {
match value {
- ColumnarValue::Array(array) => Ok(ColumnarValue::Array(
- kernels::cast::cast_with_options(array, cast_type, cast_options)?,
- )),
+ ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::from(
+ cast_with_error(array.as_ref(), cast_type, cast_options)?,
+ ))),
ColumnarValue::Scalar(scalar) => {
let scalar_array = scalar.to_array();
let cast_array =
- kernels::cast::cast_with_options(&scalar_array, cast_type, cast_options)?;
- let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
+ cast_with_error(scalar_array.as_ref(), cast_type, cast_options)?;
+ let cast_scalar = ScalarValue::try_from_array(&Arc::from(cast_array), 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
}
@@ -128,7 +160,7 @@
let expr_type = expr.data_type(input_schema)?;
if expr_type == cast_type {
Ok(expr.clone())
- } else if can_cast_types(&expr_type, &cast_type) {
+ } else if cast::can_cast_types(&expr_type, &cast_type) {
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
} else {
Err(DataFusionError::Internal(format!(
@@ -158,17 +190,19 @@
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::col;
- use arrow::{
- array::{
- Array, DecimalArray, Float32Array, Float64Array, Int16Array, Int32Array,
- Int64Array, Int8Array, StringArray, Time64NanosecondArray,
- TimestampNanosecondArray, UInt32Array,
- },
- datatypes::*,
+ use crate::test_util::{create_decimal_array, create_decimal_array_from_slice};
+ use arrow::array::{
+ Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
+ UInt32Array,
};
+ use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
use datafusion_common::Result;
+ type StringArray = Utf8Array<i32>;
+
// runs an end-to-end test of physical type cast
// 1. construct a record batch with a column "a" of type A
// 2. construct a physical expression of CAST(a AS B)
@@ -226,7 +260,7 @@
macro_rules! generic_test_cast {
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]);
- let a = $A_ARRAY::from($A_VEC);
+ let a = $A_ARRAY::from_slice($A_VEC);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
@@ -270,18 +304,12 @@
#[test]
fn test_cast_decimal_to_decimal() -> Result<()> {
- let array = vec![1234, 2222, 3, 4000, 5000];
-
- let decimal_array = array
- .iter()
- .map(|v| Some(*v))
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 3)?;
-
+ let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 6),
vec![
Some(1_234_000_i128),
@@ -294,16 +322,11 @@
DEFAULT_DATAFUSION_CAST_OPTIONS
);
- let decimal_array = array
- .iter()
- .map(|v| Some(*v))
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 3)?;
-
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 2),
vec![
Some(123_i128),
@@ -323,10 +346,7 @@
fn test_cast_decimal_to_numeric() -> Result<()> {
let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
// decimal to i8
- let decimal_array = array
- .iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?;
+ let decimal_array = create_decimal_array(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -344,10 +364,7 @@
);
// decimal to i16
- let decimal_array = array
- .iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?;
+ let decimal_array = create_decimal_array(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -365,10 +382,7 @@
);
// decimal to i32
- let decimal_array = array
- .iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?;
+ let decimal_array = create_decimal_array(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -386,10 +400,7 @@
);
// decimal to i64
- let decimal_array = array
- .iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?;
+ let decimal_array = create_decimal_array(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -407,18 +418,11 @@
);
// decimal to float32
- let array = vec![
- Some(1234),
- Some(2222),
- Some(3),
- Some(4000),
- Some(5000),
- None,
- ];
- let decimal_array = array
- .iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 3)?;
+ // Includes a trailing null; the scale must match the DataType::Decimal(10, 3)
+ // passed to the macro below, otherwise the batch schema does not match.
+ let array: Vec<Option<i128>> =
+ vec![Some(1234), Some(2222), Some(3), Some(4000), Some(5000), None];
+ let decimal_array = create_decimal_array(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
@@ -436,10 +440,7 @@
);
// decimal to float64
- let decimal_array = array
- .into_iter()
- .collect::<DecimalArray>()
- .with_precision_and_scale(20, 6)?;
+ let decimal_array = create_decimal_array(&array, 20, 6)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(20, 6),
@@ -465,7 +463,7 @@
Int8Array,
DataType::Int8,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(3, 0),
vec![
Some(1_i128),
@@ -482,7 +480,7 @@
Int16Array,
DataType::Int16,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(5, 0),
vec![
Some(1_i128),
@@ -499,7 +497,7 @@
Int32Array,
DataType::Int32,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 0),
vec![
Some(1_i128),
@@ -516,7 +514,7 @@
Int64Array,
DataType::Int64,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 0),
vec![
Some(1_i128),
@@ -533,7 +531,7 @@
Int64Array,
DataType::Int64,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 2),
vec![
Some(100_i128),
@@ -550,7 +548,7 @@
Float32Array,
DataType::Float32,
vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 2),
vec![
Some(150_i128),
@@ -567,7 +565,7 @@
Float64Array,
DataType::Float64,
vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 4),
vec![
Some(15000_i128),
@@ -586,7 +584,7 @@
generic_test_cast!(
Int32Array,
DataType::Int32,
- vec![1, 2, 3, 4, 5],
+ &[1, 2, 3, 4, 5],
UInt32Array,
DataType::UInt32,
vec![
@@ -606,7 +604,7 @@
generic_test_cast!(
Int32Array,
DataType::Int32,
- vec![1, 2, 3, 4, 5],
+ &[1, 2, 3, 4, 5],
StringArray,
DataType::Utf8,
vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
@@ -618,16 +616,13 @@
#[allow(clippy::redundant_clone)]
#[test]
fn test_cast_i64_t64() -> Result<()> {
- let original = vec![1, 2, 3, 4, 5];
- let expected: Vec<Option<i64>> = original
- .iter()
- .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
- .collect();
+ let original = &[1, 2, 3, 4, 5];
+ let expected: Vec<Option<i64>> = original.iter().map(|i| Some(*i)).collect();
generic_test_cast!(
Int64Array,
DataType::Int64,
- original.clone(),
- TimestampNanosecondArray,
+ original,
+ Int64Array,
DataType::Timestamp(TimeUnit::Nanosecond, None),
expected,
DEFAULT_DATAFUSION_CAST_OPTIONS
@@ -638,17 +633,22 @@
#[test]
fn invalid_cast() {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary);
- result.expect_err("expected Invalid CAST");
+ // Int32 -> Struct is expected to be rejected by `can_cast_types` at plan time
+ let result = cast(
+ col("a", &schema).unwrap(),
+ &schema,
+ DataType::Struct(vec![Field::new("f", DataType::Int32, true)]),
+ );
+ assert!(result.is_err(), "expected Invalid CAST");
}
#[test]
fn invalid_cast_with_options_error() -> Result<()> {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
- let a = StringArray::from(vec!["9.1"]);
+ let a = StringArray::from_slice(vec!["9.1"]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
let expression = cast_with_options(
col("a", &schema)?,
diff --git a/datafusion-physical-expr/src/expressions/column.rs b/datafusion-physical-expr/src/expressions/column.rs
index 3def89f..4e4c361 100644
--- a/datafusion-physical-expr/src/expressions/column.rs
+++ b/datafusion-physical-expr/src/expressions/column.rs
@@ -19,12 +19,10 @@
use std::sync::Arc;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
-
use crate::PhysicalExpr;
+use arrow::datatypes::{DataType, Schema};
+use datafusion_common::field_util::{FieldExt, SchemaExt};
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
diff --git a/datafusion-physical-expr/src/expressions/correlation.rs b/datafusion-physical-expr/src/expressions/correlation.rs
index 3f7b28a..d27d0b6 100644
--- a/datafusion-physical-expr/src/expressions/correlation.rs
+++ b/datafusion-physical-expr/src/expressions/correlation.rs
@@ -230,14 +230,15 @@
use super::*;
use crate::expressions::col;
use crate::generic_test_op2;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn correlation_f64_1() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 7_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 7_f64]));
generic_test_op2!(
a,
@@ -252,8 +253,8 @@
#[test]
fn correlation_f64_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, -5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4_f64, -5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -268,8 +269,8 @@
#[test]
fn correlation_f64_4() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4.1_f64, 5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -284,10 +285,10 @@
#[test]
fn correlation_f64_6() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![
+ let a = Arc::new(Float64Array::from_slice(&[
1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64,
]));
- let b = Arc::new(Float64Array::from(vec![
+ let b = Arc::new(Float64Array::from_slice(&[
4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64,
]));
@@ -304,8 +305,8 @@
#[test]
fn correlation_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3]));
+ let b: ArrayRef = Arc::new(Int32Array::from_slice(vec![4, 5, 6]));
generic_test_op2!(
a,
@@ -320,8 +321,8 @@
#[test]
fn correlation_u32() -> Result<()> {
- let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32]));
- let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![1_u32, 2_u32, 3_u32]));
+ let b: ArrayRef = Arc::new(UInt32Array::from_slice(vec![4_u32, 5_u32, 6_u32]));
generic_test_op2!(
a,
b,
@@ -335,8 +336,8 @@
#[test]
fn correlation_f32() -> Result<()> {
- let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32]));
- let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![1_f32, 2_f32, 3_f32]));
+ let b: ArrayRef = Arc::new(Float32Array::from_slice(vec![4_f32, 5_f32, 6_f32]));
generic_test_op2!(
a,
b,
@@ -362,9 +363,9 @@
#[test]
fn correlation_i32_with_nulls_1() -> Result<()> {
let a: ArrayRef =
- Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(3)]));
+ Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3), Some(3)]));
let b: ArrayRef =
- Arc::new(Int32Array::from(vec![Some(4), None, Some(6), Some(3)]));
+ Arc::new(Int32Array::from_iter(vec![Some(4), None, Some(6), Some(3)]));
generic_test_op2!(
a,
@@ -379,8 +380,9 @@
#[test]
fn correlation_i32_with_nulls_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), Some(5), Some(6)]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3)]));
+ let b: ArrayRef =
+ Arc::new(Int32Array::from_iter(vec![Some(4), Some(5), Some(6)]));
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
@@ -402,8 +404,8 @@
#[test]
fn correlation_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
+ let b: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
@@ -425,10 +427,10 @@
#[test]
fn correlation_f64_merge_1() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
- let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64]));
- let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 9.9_f64]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
+ let c = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2.2_f64, 3.3_f64]));
+ let d = Arc::new(Float64Array::from_slice(vec![4.4_f64, 5.5_f64, 9.9_f64]));
let schema = Schema::new(vec![
Field::new("a", DataType::Float64, false),
@@ -460,10 +462,10 @@
#[test]
fn correlation_f64_merge_2() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
- let c = Arc::new(Float64Array::from(vec![None]));
- let d = Arc::new(Float64Array::from(vec![None]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
+ let c = Arc::new(Float64Array::from_iter(vec![None]));
+ let d = Arc::new(Float64Array::from_iter(vec![None]));
let schema = Schema::new(vec![
Field::new("a", DataType::Float64, false),
diff --git a/datafusion-physical-expr/src/expressions/count.rs b/datafusion-physical-expr/src/expressions/count.rs
index ccc5a8e..4ed0802 100644
--- a/datafusion-physical-expr/src/expressions/count.rs
+++ b/datafusion-physical-expr/src/expressions/count.rs
@@ -109,13 +109,13 @@
impl Accumulator for CountAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
- self.count += (array.len() - array.data().null_count()) as u64;
+ self.count += (array.len() - array.null_count()) as u64;
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
- let delta = &compute::sum(counts);
+ let delta = &compute::aggregate::sum_primitive(counts);
if let Some(d) = delta {
self.count += *d;
}
@@ -134,16 +134,16 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn count_elements() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -155,7 +155,7 @@
#[test]
fn count_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![
Some(1),
Some(2),
None,
@@ -174,7 +174,7 @@
#[test]
fn count_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(BooleanArray::from(vec![
+ let a: ArrayRef = Arc::new(BooleanArray::from_iter(vec![
None, None, None, None, None, None, None, None,
]));
generic_test_op!(
@@ -188,8 +188,7 @@
#[test]
fn count_empty() -> Result<()> {
- let a: Vec<bool> = vec![];
- let a: ArrayRef = Arc::new(BooleanArray::from(a));
+ let a: ArrayRef = Arc::new(BooleanArray::new_empty(DataType::Boolean));
generic_test_op!(
a,
DataType::Boolean,
@@ -201,8 +200,9 @@
#[test]
fn count_utf8() -> Result<()> {
- let a: ArrayRef =
- Arc::new(StringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
+ let a: ArrayRef = Arc::new(Utf8Array::<i32>::from_slice(&[
+ "a", "bb", "ccc", "dddd", "ad",
+ ]));
generic_test_op!(
a,
DataType::Utf8,
@@ -214,8 +214,9 @@
#[test]
fn count_large_utf8() -> Result<()> {
- let a: ArrayRef =
- Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
+ let a: ArrayRef = Arc::new(Utf8Array::<i64>::from_slice(&[
+ "a", "bb", "ccc", "dddd", "ad",
+ ]));
generic_test_op!(
a,
DataType::LargeUtf8,
diff --git a/datafusion-physical-expr/src/expressions/covariance.rs b/datafusion-physical-expr/src/expressions/covariance.rs
index 539a869..ae60b2a 100644
--- a/datafusion-physical-expr/src/expressions/covariance.rs
+++ b/datafusion-physical-expr/src/expressions/covariance.rs
@@ -20,11 +20,11 @@
use std::any::Any;
use std::sync::Arc;
+use crate::expressions::cast::{cast_with_error, DEFAULT_DATAFUSION_CAST_OPTIONS};
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::Float64Array;
use arrow::{
array::{ArrayRef, UInt64Array},
- compute::cast,
datatypes::DataType,
datatypes::Field,
};
@@ -282,8 +282,16 @@
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values1 = &cast(&values[0], &DataType::Float64)?;
- let values2 = &cast(&values[1], &DataType::Float64)?;
+ let values1 = &cast_with_error(
+ values[0].as_ref(),
+ &DataType::Float64,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ )?;
+ let values2 = &cast_with_error(
+ values[1].as_ref(),
+ &DataType::Float64,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ )?;
let mut arr1 = values1
.as_any()
@@ -389,14 +397,15 @@
use super::*;
use crate::expressions::col;
use crate::generic_test_op2;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn covariance_f64_1() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -411,8 +420,8 @@
#[test]
fn covariance_f64_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -427,8 +436,8 @@
#[test]
fn covariance_f64_4() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4.1_f64, 5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -443,8 +452,8 @@
#[test]
fn covariance_f64_5() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
+ let b: ArrayRef = Arc::new(Float64Array::from_slice(vec![4.1_f64, 5_f64, 6_f64]));
generic_test_op2!(
a,
@@ -459,10 +468,10 @@
#[test]
fn covariance_f64_6() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![
+ let a = Arc::new(Float64Array::from_slice(&[
1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64,
]));
- let b = Arc::new(Float64Array::from(vec![
+ let b = Arc::new(Float64Array::from_slice(&[
4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64,
]));
@@ -479,8 +488,8 @@
#[test]
fn covariance_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3]));
+ let b: ArrayRef = Arc::new(Int32Array::from_slice(vec![4, 5, 6]));
generic_test_op2!(
a,
@@ -495,8 +504,8 @@
#[test]
fn covariance_u32() -> Result<()> {
- let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32]));
- let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![1_u32, 2_u32, 3_u32]));
+ let b: ArrayRef = Arc::new(UInt32Array::from_slice(vec![4_u32, 5_u32, 6_u32]));
generic_test_op2!(
a,
b,
@@ -510,8 +519,8 @@
#[test]
fn covariance_f32() -> Result<()> {
- let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32]));
- let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![1_f32, 2_f32, 3_f32]));
+ let b: ArrayRef = Arc::new(Float32Array::from_slice(vec![4_f32, 5_f32, 6_f32]));
generic_test_op2!(
a,
b,
@@ -536,8 +545,8 @@
#[test]
fn covariance_i32_with_nulls_1() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), None, Some(6)]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3)]));
+ let b: ArrayRef = Arc::new(Int32Array::from_iter(vec![Some(4), None, Some(6)]));
generic_test_op2!(
a,
@@ -552,8 +561,9 @@
#[test]
fn covariance_i32_with_nulls_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), Some(5), Some(6)]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3)]));
+ let b: ArrayRef =
+ Arc::new(Int32Array::from_iter(vec![Some(4), Some(5), Some(6)]));
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
@@ -575,8 +585,8 @@
#[test]
fn covariance_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
+ let b: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
@@ -598,10 +608,10 @@
#[test]
fn covariance_f64_merge_1() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
- let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64]));
- let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 6.6_f64]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
+ let c = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2.2_f64, 3.3_f64]));
+ let d = Arc::new(Float64Array::from_slice(vec![4.4_f64, 5.5_f64, 6.6_f64]));
let schema = Schema::new(vec![
Field::new("a", DataType::Float64, false),
@@ -633,10 +643,10 @@
#[test]
fn covariance_f64_merge_2() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
- let c = Arc::new(Float64Array::from(vec![None]));
- let d = Arc::new(Float64Array::from(vec![None]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64, 6_f64]));
+ let c = Arc::new(Float64Array::from_iter(vec![None]));
+ let d = Arc::new(Float64Array::from_iter(vec![None]));
let schema = Schema::new(vec![
Field::new("a", DataType::Float64, false),
diff --git a/datafusion-physical-expr/src/expressions/cume_dist.rs b/datafusion-physical-expr/src/expressions/cume_dist.rs
index 9cd28a3..028679f 100644
--- a/datafusion-physical-expr/src/expressions/cume_dist.rs
+++ b/datafusion-physical-expr/src/expressions/cume_dist.rs
@@ -24,7 +24,7 @@
use arrow::array::ArrayRef;
use arrow::array::Float64Array;
use arrow::datatypes::{DataType, Field};
-use arrow::record_batch::RecordBatch;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use std::any::Any;
use std::iter;
@@ -89,18 +89,18 @@
ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
let scaler = (partition.end - partition.start) as f64;
- let result = Float64Array::from_iter_values(
- ranks_in_partition
- .iter()
- .scan(0_u64, |acc, range| {
- let len = range.end - range.start;
- *acc += len as u64;
- let value: f64 = (*acc as f64) / scaler;
- let result = iter::repeat(value).take(len);
- Some(result)
- })
- .flatten(),
- );
+ let result = ranks_in_partition
+ .iter()
+ .scan(0_u64, |acc, range| {
+ let len = range.end - range.start;
+ *acc += len as u64;
+ let value: f64 = (*acc as f64) / scaler;
+ let result = iter::repeat(value).take(len);
+ Some(result)
+ })
+ .flatten()
+ .collect::<Vec<_>>();
+ let result = Float64Array::from_values(result);
Ok(Arc::new(result))
}
}
@@ -109,6 +109,7 @@
mod tests {
use super::*;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
fn test_i32_result(
expr: &CumeDist,
@@ -117,7 +118,7 @@
ranks: Vec<Range<usize>>,
expected: Vec<f64>,
) -> Result<()> {
- let arr: ArrayRef = Arc::new(Int32Array::from(data));
+ let arr: ArrayRef = Arc::new(Int32Array::from_slice(data));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
@@ -127,7 +128,7 @@
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<Float64Array>().unwrap();
let result = result.values();
- assert_eq!(expected, result);
+ assert_eq!(expected, result.as_slice());
Ok(())
}
diff --git a/datafusion-physical-expr/src/expressions/distinct_expressions.rs b/datafusion-physical-expr/src/expressions/distinct_expressions.rs
index c249ca8..b20b4f5 100644
--- a/datafusion-physical-expr/src/expressions/distinct_expressions.rs
+++ b/datafusion-physical-expr/src/expressions/distinct_expressions.rs
@@ -17,16 +17,19 @@
//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
-use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::fmt::Debug;
-use std::hash::Hash;
use std::sync::Arc;
use ahash::RandomState;
-use arrow::array::{Array, ArrayRef};
+use arrow::array::ArrayRef;
use std::collections::HashSet;
+use arrow::{
+ array::*,
+ datatypes::{DataType, Field},
+};
+
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
@@ -75,7 +78,7 @@
fn state_type(data_type: DataType) -> DataType {
match data_type {
// when aggregating dictionary values, use the underlying value type
- DataType::Dictionary(_key_type, value_type) => *value_type,
+ DataType::Dictionary(_key_type, value_type, _) => *value_type,
t => t,
}
}
@@ -97,11 +100,7 @@
.map(|state_data_type| {
Field::new(
&format_state_name(&self.name, "count distinct"),
- DataType::List(Box::new(Field::new(
- "item",
- state_data_type.clone(),
- true,
- ))),
+ ListArray::<i32>::default_datatype(state_data_type.clone()),
false,
)
})
@@ -363,43 +362,12 @@
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::col;
use crate::expressions::tests::aggregate;
- use arrow::array::{
- ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
- Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array,
- UInt8Array,
- };
- use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
use arrow::datatypes::{DataType, Schema};
- use arrow::record_batch::RecordBatch;
-
- macro_rules! build_list {
- ($LISTS:expr, $BUILDER_TYPE:ident) => {{
- let mut builder = ListBuilder::new($BUILDER_TYPE::new(0));
- for list in $LISTS.iter() {
- match list {
- Some(values) => {
- for value in values.iter() {
- match value {
- Some(v) => builder.values().append_value((*v).into())?,
- None => builder.values().append_null()?,
- }
- }
-
- builder.append(true)?;
- }
- None => {
- builder.append(false)?;
- }
- }
- }
-
- let array = Arc::new(builder.finish()) as ArrayRef;
-
- Ok(array) as Result<ArrayRef>
- }};
- }
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
macro_rules! state_to_vec {
($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -494,7 +462,7 @@
let agg = DistinctCount::new(
arrays
.iter()
- .map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
+ .map(|a| a.as_any().downcast_ref::<ListArray<i32>>().unwrap())
.map(|a| a.values().data_type().clone())
.collect::<Vec<_>>(),
vec![],
@@ -677,14 +645,15 @@
Ok((state_vec, count))
};
- let zero_count_values = BooleanArray::from(Vec::<bool>::new());
+ let zero_count_values = BooleanArray::from_slice(&[]);
- let one_count_values = BooleanArray::from(vec![false, false]);
+ let one_count_values = BooleanArray::from_slice(vec![false, false]);
let one_count_values_with_null =
- BooleanArray::from(vec![Some(true), Some(true), None, None]);
+ BooleanArray::from_iter(vec![Some(true), Some(true), None, None]);
- let two_count_values = BooleanArray::from(vec![true, false, true, false, true]);
- let two_count_values_with_null = BooleanArray::from(vec![
+ let two_count_values =
+ BooleanArray::from_slice(vec![true, false, true, false, true]);
+ let two_count_values_with_null = BooleanArray::from_iter(vec![
Some(true),
Some(false),
None,
@@ -730,7 +699,7 @@
#[test]
fn count_distinct_update_batch_empty() -> Result<()> {
- let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef];
+ let arrays = vec![Arc::new(Int32Array::new_empty(DataType::Int32)) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
@@ -743,8 +712,8 @@
#[test]
fn count_distinct_update_batch_multiple_columns() -> Result<()> {
- let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2]));
- let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4]));
+ let array_int8: ArrayRef = Arc::new(Int8Array::from_slice(vec![1, 1, 2]));
+ let array_int16: ArrayRef = Arc::new(Int16Array::from_slice(vec![3, 3, 4]));
let arrays = vec![array_int8, array_int16];
let (states, result) = run_update_batch(&arrays)?;
@@ -833,23 +802,24 @@
#[test]
fn count_distinct_merge_batch() -> Result<()> {
- let state_in1 = build_list!(
- vec![
- Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]),
- Some(vec![Some(-2_i32), Some(-3_i32)]),
- ],
- Int32Builder
- )?;
+ let state_in1 = vec![
+ Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]),
+ Some(vec![Some(-2_i32), Some(-3_i32)]),
+ ];
+ let mut array = MutableListArray::<i32, MutablePrimitiveArray<i32>>::new();
+ array.try_extend(state_in1)?;
+ let state_in1: ListArray<i32> = array.into();
- let state_in2 = build_list!(
- vec![
- Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
- Some(vec![Some(5_u64), Some(7_u64)]),
- ],
- UInt64Builder
- )?;
+ let state_in2 = vec![
+ Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
+ Some(vec![Some(5_u64), Some(7_u64)]),
+ ];
+ let mut array = MutableListArray::<i32, MutablePrimitiveArray<u64>>::new();
+ array.try_extend(state_in2)?;
+ let state_in2: ListArray<i32> = array.into();
- let (states, result) = run_merge_batch(&[state_in1, state_in2])?;
+ let (states, result) =
+ run_merge_batch(&[Arc::new(state_in1), Arc::new(state_in2)])?;
let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
@@ -908,7 +878,7 @@
#[test]
fn distinct_array_agg_i32() -> Result<()> {
- let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+ let col: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 4, 5, 2]));
let out = ScalarValue::List(
Some(Box::new(vec![
diff --git a/datafusion-physical-expr/src/expressions/get_indexed_field.rs b/datafusion-physical-expr/src/expressions/get_indexed_field.rs
index 26a5cf2..84894d4 100644
--- a/datafusion-physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion-physical-expr/src/expressions/get_indexed_field.rs
@@ -17,22 +17,22 @@
//! get field of a `ListArray`
-use crate::{field_util::get_indexed_field as get_data_type_field, PhysicalExpr};
-use arrow::array::Array;
-use arrow::array::{ListArray, StructArray};
-use arrow::compute::concat;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
-use datafusion_common::DataFusionError;
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
-use datafusion_expr::ColumnarValue;
use std::convert::TryInto;
-use std::fmt::Debug;
use std::{any::Any, sync::Arc};
+use arrow::datatypes::{DataType, Schema};
+use datafusion_common::record_batch::RecordBatch;
+
+use crate::{field_util::get_indexed_field as get_data_type_field, PhysicalExpr};
+use arrow::array::{Array, ListArray, StructArray};
+use arrow::compute::concatenate::concatenate;
+use datafusion_common::field_util::FieldExt;
+use datafusion_common::{
+ field_util::StructArrayExt, DataFusionError, Result, ScalarValue,
+};
+use datafusion_expr::ColumnarValue;
+use std::fmt::Debug;
+
/// expression to get a field of a struct array.
#[derive(Debug)]
pub struct GetIndexedFieldExpr {
@@ -83,18 +83,18 @@
}
(DataType::List(_), ScalarValue::Int64(Some(i))) => {
let as_list_array =
- array.as_any().downcast_ref::<ListArray>().unwrap();
+ array.as_any().downcast_ref::<ListArray<i32>>().unwrap();
if as_list_array.is_empty() {
let scalar_null: ScalarValue = array.data_type().try_into()?;
return Ok(ColumnarValue::Scalar(scalar_null))
}
let sliced_array: Vec<Arc<dyn Array>> = as_list_array
.iter()
- .filter_map(|o| o.map(|list| list.slice(*i as usize, 1)))
+ .filter_map(|o| o.map(|list| list.slice(*i as usize, 1).into()))
.collect();
let vec = sliced_array.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>();
- let iter = concat(vec.as_slice()).unwrap();
- Ok(ColumnarValue::Array(iter))
+ let iter = concatenate(vec.as_slice()).unwrap();
+ Ok(ColumnarValue::Array(iter.into()))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
@@ -103,7 +103,7 @@
Some(col) => Ok(ColumnarValue::Array(col.clone()))
}
}
- (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))),
+ (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {:?} with {} index", dt, key))),
},
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
"field access is not yet implemented for scalar values".to_string(),
@@ -115,30 +115,21 @@
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::{col, lit};
- use arrow::array::GenericListArray;
use arrow::array::{
- Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder,
+ Int64Array, MutableListArray, MutableUtf8Array, StructArray, Utf8Array,
};
- use arrow::{array::StringArray, datatypes::Field};
- use datafusion_common::Result;
+ use arrow::array::{TryExtend, TryPush};
+ use arrow::datatypes::Field;
+ use datafusion_common::field_util::SchemaExt;
- fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) -> GenericListArray<i32> {
- let builder = StringBuilder::new(list_of_lists.len());
- let mut lb = ListBuilder::new(builder);
+ fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) -> ListArray<i32> {
+ let mut array = MutableListArray::<i32, MutableUtf8Array<i32>>::new();
for values in list_of_lists {
- let builder = lb.values();
- for value in values {
- match value {
- None => builder.append_null(),
- Some(v) => builder.append_value(v),
- }
- .unwrap()
- }
- lb.append(true).unwrap();
+ array.try_push(Some(values)).unwrap();
}
-
- lb.finish()
+ array.into()
}
fn get_indexed_field_test(
@@ -155,9 +146,9 @@
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
- .downcast_ref::<StringArray>()
- .expect("failed to downcast to StringArray");
- let expected = &StringArray::from(expected);
+ .downcast_ref::<Utf8Array<i32>>()
+ .expect("failed to downcast to Utf8Array<i32>");
+ let expected = &Utf8Array::<i32>::from(expected);
assert_eq!(expected, result);
Ok(())
}
@@ -192,10 +183,13 @@
#[test]
fn get_indexed_field_empty_list() -> Result<()> {
let schema = list_schema("l");
- let builder = StringBuilder::new(0);
- let mut lb = ListBuilder::new(builder);
let expr = col("l", &schema).unwrap();
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
+ let batch = RecordBatch::try_new(
+ Arc::new(schema.clone()),
+ vec![Arc::new(ListArray::<i32>::new_empty(
+ schema.field(0).data_type.clone(),
+ ))],
+ )?;
let key = ScalarValue::Int64(Some(0));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -209,9 +203,9 @@
key: ScalarValue,
expected: &str,
) -> Result<()> {
- let builder = StringBuilder::new(3);
- let mut lb = ListBuilder::new(builder);
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
+ let mut array = MutableListArray::<i32, MutableUtf8Array<i32>>::new();
+ array.try_extend(vec![Some(vec![Some("a")]), None, None])?;
+ let batch = RecordBatch::try_new(Arc::new(schema), vec![array.into_arc()])?;
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let r = expr.evaluate(&batch).map(|_| ());
assert!(r.is_err());
@@ -230,41 +224,27 @@
fn get_indexed_field_invalid_list_index() -> Result<()> {
let schema = list_schema("l");
let expr = col("l", &schema).unwrap();
- get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index")
+ get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, is_nullable: true, metadata: {} }) with 0 index")
}
fn build_struct(
fields: Vec<Field>,
list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
) -> StructArray {
- let foo_builder = Int64Array::builder(list_of_tuples.len());
- let str_builder = StringBuilder::new(list_of_tuples.len());
- let bar_builder = ListBuilder::new(str_builder);
- let mut builder = StructBuilder::new(
- fields,
- vec![Box::new(foo_builder), Box::new(bar_builder)],
- );
+ let mut foo_values = Vec::new();
+ let mut bar_array = MutableListArray::<i32, MutableUtf8Array<i32>>::new();
+
for (int_value, list_value) in list_of_tuples {
- let fb = builder.field_builder::<Int64Builder>(0).unwrap();
- match int_value {
- None => fb.append_null(),
- Some(v) => fb.append_value(v),
- }
- .unwrap();
- builder.append(true).unwrap();
- let lb = builder
- .field_builder::<ListBuilder<StringBuilder>>(1)
- .unwrap();
- for str_value in list_value {
- match str_value {
- None => lb.values().append_null(),
- Some(v) => lb.values().append_value(v),
- }
- .unwrap();
- }
- lb.append(true).unwrap();
+ foo_values.push(int_value);
+ bar_array.try_push(Some(list_value)).unwrap();
}
- builder.finish()
+
+ let foo = Arc::new(Int64Array::from(foo_values));
+ StructArray::from_data(
+ DataType::Struct(fields),
+ vec![foo, bar_array.into_arc()],
+ None,
+ )
}
fn get_indexed_field_mixed_test(
@@ -312,7 +292,7 @@
let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
- .downcast_ref::<ListArray>()
+ .downcast_ref::<ListArray<i32>>()
.unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result));
let expected =
&build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect());
@@ -328,11 +308,11 @@
.into_array(batch.num_rows());
let result = result
.as_any()
- .downcast_ref::<StringArray>()
+ .downcast_ref::<Utf8Array<i32>>()
.unwrap_or_else(|| {
- panic!("failed to downcast to StringArray : {:?}", result)
+ panic!("failed to downcast to Utf8Array<i32>: {:?}", result)
});
- let expected = &StringArray::from(expected);
+ let expected = &Utf8Array::<i32>::from(expected);
assert_eq!(expected, result);
}
Ok(())
diff --git a/datafusion-physical-expr/src/expressions/in_list.rs b/datafusion-physical-expr/src/expressions/in_list.rs
index 2aee0d8..a378028 100644
--- a/datafusion-physical-expr/src/expressions/in_list.rs
+++ b/datafusion-physical-expr/src/expressions/in_list.rs
@@ -20,46 +20,44 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::GenericStringArray;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
- Int64Array, Int8Array, StringOffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array,
- UInt8Array,
+ Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
-use arrow::datatypes::ArrowPrimitiveType;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
+use arrow::datatypes::{DataType, Schema};
use crate::PhysicalExpr;
-use arrow::array::*;
-use arrow::buffer::{Buffer, MutableBuffer};
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
+use arrow::types::NativeType;
+use arrow::{array::*, bitmap::Bitmap};
+use datafusion_common::record_batch::RecordBatch;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
macro_rules! compare_op_scalar {
($left: expr, $right:expr, $op:expr) => {{
- let null_bit_buffer = $left.data().null_buffer().cloned();
+ let validity = $left.validity();
+ let values =
+ Bitmap::from_trusted_len_iter($left.values_iter().map(|x| $op(x, $right)));
+ Ok(BooleanArray::from_data(
+ DataType::Boolean,
+ values,
+ validity.cloned(),
+ ))
+ }};
+}
- let comparison =
- (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
- // same as $left.len()
- let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };
-
- let data = unsafe {
- ArrayData::new_unchecked(
- DataType::Boolean,
- $left.len(),
- None,
- null_bit_buffer,
- 0,
- vec![Buffer::from(buffer)],
- vec![],
- )
- };
- Ok(BooleanArray::from(data))
+// TODO: primitive array currently doesn't have `values_iter()`; it may be
+// worth adding one there, so that this specialized case could be removed.
+macro_rules! compare_primitive_op_scalar {
+ ($left: expr, $right:expr, $op:expr) => {{
+ let validity = $left.validity();
+ let values =
+ Bitmap::from_trusted_len_iter($left.values().iter().map(|x| $op(x, $right)));
+ Ok(BooleanArray::from_data(
+ DataType::Boolean,
+ values,
+ validity.cloned(),
+ ))
}};
}
@@ -182,39 +180,31 @@
}
// whether each value on the left (can be null) is contained in the non-null list
-fn in_list_primitive<T: ArrowPrimitiveType>(
+fn in_list_primitive<T: NativeType>(
array: &PrimitiveArray<T>,
- values: &[<T as ArrowPrimitiveType>::Native],
+ values: &[T],
) -> Result<BooleanArray> {
- compare_op_scalar!(
- array,
- values,
- |x, v: &[<T as ArrowPrimitiveType>::Native]| v.contains(&x)
- )
+ compare_primitive_op_scalar!(array, values, |x, v: &[T]| v.contains(x))
}
// whether each value on the left (can be null) is contained in the non-null list
-fn not_in_list_primitive<T: ArrowPrimitiveType>(
+fn not_in_list_primitive<T: NativeType>(
array: &PrimitiveArray<T>,
- values: &[<T as ArrowPrimitiveType>::Native],
+ values: &[T],
) -> Result<BooleanArray> {
- compare_op_scalar!(
- array,
- values,
- |x, v: &[<T as ArrowPrimitiveType>::Native]| !v.contains(&x)
- )
+ compare_primitive_op_scalar!(array, values, |x, v: &[T]| !v.contains(x))
}
// whether each value on the left (can be null) is contained in the non-null list
-fn in_list_utf8<OffsetSize: StringOffsetSizeTrait>(
- array: &GenericStringArray<OffsetSize>,
+fn in_list_utf8<OffsetSize: Offset>(
+ array: &Utf8Array<OffsetSize>,
values: &[&str],
) -> Result<BooleanArray> {
compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x))
}
-fn not_in_list_utf8<OffsetSize: StringOffsetSizeTrait>(
- array: &GenericStringArray<OffsetSize>,
+fn not_in_list_utf8<OffsetSize: Offset>(
+ array: &Utf8Array<OffsetSize>,
values: &[&str],
) -> Result<BooleanArray> {
compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
@@ -251,16 +241,13 @@
/// Compare for specific utf8 types
#[allow(clippy::unnecessary_wraps)]
- fn compare_utf8<T: StringOffsetSizeTrait>(
+ fn compare_utf8<T: Offset>(
&self,
array: ArrayRef,
list_values: Vec<ColumnarValue>,
negated: bool,
) -> Result<ColumnarValue> {
- let array = array
- .as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .unwrap();
+ let array = array.as_any().downcast_ref::<Utf8Array<T>>().unwrap();
let contains_null = list_values
.iter()
@@ -470,7 +457,10 @@
#[cfg(test)]
mod tests {
- use arrow::{array::StringArray, datatypes::Field};
+ use arrow::{array::Utf8Array, datatypes::Field};
+ use datafusion_common::field_util::SchemaExt;
+
+ type StringArray = Utf8Array<i32>;
use super::*;
use crate::expressions::{col, lit};
@@ -493,7 +483,7 @@
#[test]
fn in_list_utf8() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let a = StringArray::from(vec![Some("a"), Some("d"), None]);
+ let a = StringArray::from_iter(vec![Some("a"), Some("d"), None]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
@@ -557,7 +547,7 @@
#[test]
fn in_list_int64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
- let a = Int64Array::from(vec![Some(0), Some(2), None]);
+ let a = Int64Array::from_iter(vec![Some(0), Some(2), None]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
@@ -621,7 +611,7 @@
#[test]
fn in_list_float64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
- let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
+ let a = Float64Array::from_iter(vec![Some(0.0), Some(0.2), None]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
@@ -685,7 +675,7 @@
#[test]
fn in_list_bool() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
- let a = BooleanArray::from(vec![Some(true), None]);
+ let a = BooleanArray::from_iter(vec![Some(true), None]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
diff --git a/datafusion-physical-expr/src/expressions/is_not_null.rs b/datafusion-physical-expr/src/expressions/is_not_null.rs
index 6b614f3..2934053 100644
--- a/datafusion-physical-expr/src/expressions/is_not_null.rs
+++ b/datafusion-physical-expr/src/expressions/is_not_null.rs
@@ -21,10 +21,8 @@
use crate::PhysicalExpr;
use arrow::compute;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
+use arrow::datatypes::{DataType, Schema};
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
@@ -72,7 +70,7 @@
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new(
- compute::is_not_null(array.as_ref())?,
+ compute::boolean::is_not_null(array.as_ref()),
))),
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
ScalarValue::Boolean(Some(!scalar.is_null())),
@@ -91,16 +89,19 @@
use super::*;
use crate::expressions::col;
use arrow::{
- array::{BooleanArray, StringArray},
+ array::{BooleanArray, Utf8Array},
datatypes::*,
- record_batch::RecordBatch,
};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use std::sync::Arc;
+ type StringArray = Utf8Array<i32>;
+
#[test]
fn is_not_null_op() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let a = StringArray::from(vec![Some("foo"), None]);
+ let a = StringArray::from_iter(vec![Some("foo"), None]);
let expr = is_not_null(col("a", &schema)?).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
@@ -111,7 +112,7 @@
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
- let expected = &BooleanArray::from(vec![true, false]);
+ let expected = &BooleanArray::from_slice(vec![true, false]);
assert_eq!(expected, result);
diff --git a/datafusion-physical-expr/src/expressions/is_null.rs b/datafusion-physical-expr/src/expressions/is_null.rs
index e5dbfbd..63a7a4f 100644
--- a/datafusion-physical-expr/src/expressions/is_null.rs
+++ b/datafusion-physical-expr/src/expressions/is_null.rs
@@ -20,10 +20,8 @@
use std::{any::Any, sync::Arc};
use arrow::compute;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
+use arrow::datatypes::{DataType, Schema};
+use datafusion_common::record_batch::RecordBatch;
use crate::PhysicalExpr;
use datafusion_common::Result;
@@ -73,7 +71,7 @@
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new(
- compute::is_null(array.as_ref())?,
+ compute::boolean::is_null(array.as_ref()),
))),
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
ScalarValue::Boolean(Some(scalar.is_null())),
@@ -92,16 +90,18 @@
use super::*;
use crate::expressions::col;
use arrow::{
- array::{BooleanArray, StringArray},
+ array::{BooleanArray, Utf8Array},
datatypes::*,
- record_batch::RecordBatch,
};
+ use datafusion_common::field_util::SchemaExt;
use std::sync::Arc;
+ type StringArray = Utf8Array<i32>;
+
#[test]
fn is_null_op() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let a = StringArray::from(vec![Some("foo"), None]);
+ let a = StringArray::from_iter(vec![Some("foo"), None]);
// expression: "a is null"
let expr = is_null(col("a", &schema)?).unwrap();
@@ -113,7 +113,7 @@
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
- let expected = &BooleanArray::from(vec![false, true]);
+ let expected = &BooleanArray::from_slice(vec![false, true]);
assert_eq!(expected, result);
diff --git a/datafusion-physical-expr/src/expressions/lead_lag.rs b/datafusion-physical-expr/src/expressions/lead_lag.rs
index 4e286d5..90828ea 100644
--- a/datafusion-physical-expr/src/expressions/lead_lag.rs
+++ b/datafusion-physical-expr/src/expressions/lead_lag.rs
@@ -18,16 +18,17 @@
//! Defines physical expression for `lead` and `lag` that can evaluated
//! at runtime during query execution
+use crate::expressions::cast::cast_with_error;
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
-use arrow::compute::cast;
+use arrow::compute::{cast, concatenate};
use arrow::datatypes::{DataType, Field};
-use arrow::record_batch::RecordBatch;
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::record_batch::RecordBatch;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::any::Any;
+use std::borrow::Borrow;
use std::ops::Neg;
use std::ops::Range;
use std::sync::Arc;
@@ -128,9 +129,10 @@
let array = value
.as_ref()
.map(|scalar| scalar.to_array_of_size(size))
- .unwrap_or_else(|| new_null_array(data_type, size));
+ .unwrap_or_else(|| ArrayRef::from(new_null_array(data_type.clone(), size)));
if array.data_type() != data_type {
- cast(&array, data_type).map_err(DataFusionError::ArrowError)
+ cast_with_error(array.borrow(), data_type, cast::CastOptions::default())
+ .map(ArrayRef::from)
} else {
Ok(array)
}
@@ -142,11 +144,9 @@
offset: i64,
value: &Option<ScalarValue>,
) -> Result<ArrayRef> {
- use arrow::compute::concat;
-
let value_len = array.len() as i64;
if offset == 0 {
- Ok(arrow::array::make_array(array.data_ref().clone()))
+ Ok(array.clone())
} else if offset == i64::MIN || offset.abs() >= value_len {
create_empty_array(value, array.data_type(), array.len())
} else {
@@ -159,11 +159,13 @@
let default_values = create_empty_array(value, slice.data_type(), nulls)?;
// Concatenate both arrays, add nulls after if shift > 0 else before
if offset > 0 {
- concat(&[default_values.as_ref(), slice.as_ref()])
+ concatenate::concatenate(&[default_values.as_ref(), slice.as_ref()])
.map_err(DataFusionError::ArrowError)
+ .map(ArrayRef::from)
} else {
- concat(&[slice.as_ref(), default_values.as_ref()])
+ concatenate::concatenate(&[slice.as_ref(), default_values.as_ref()])
.map_err(DataFusionError::ArrowError)
+ .map(ArrayRef::from)
}
}
}
@@ -172,20 +174,27 @@
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
let value = &self.values[0];
let value = value.slice(partition.start, partition.end - partition.start);
- shift_with_default_value(&value, self.shift_offset, &self.default_value)
+ shift_with_default_value(
+ ArrayRef::from(value).borrow(),
+ self.shift_offset,
+ &self.default_value,
+ )
}
}
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::Column;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
- let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
+ let arr: ArrayRef =
+ Arc::new(Int32Array::from_slice(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
diff --git a/datafusion-physical-expr/src/expressions/literal.rs b/datafusion-physical-expr/src/expressions/literal.rs
index 6fff67e..e053072 100644
--- a/datafusion-physical-expr/src/expressions/literal.rs
+++ b/datafusion-physical-expr/src/expressions/literal.rs
@@ -20,10 +20,8 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
+use arrow::datatypes::{DataType, Schema};
+use datafusion_common::record_batch::RecordBatch;
use crate::PhysicalExpr;
use datafusion_common::Result;
@@ -81,15 +79,17 @@
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::Int32Array;
+
+ use arrow::array::*;
use arrow::datatypes::*;
+ use datafusion_common::field_util::SchemaExt;
use datafusion_common::Result;
#[test]
fn literal_i32() -> Result<()> {
// create an arbitrary record bacth
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
+ let a = Int32Array::from_iter(vec![Some(1), None, Some(3), Some(4), Some(5)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// create and evaluate a literal expression
diff --git a/datafusion-physical-expr/src/expressions/min_max.rs b/datafusion-physical-expr/src/expressions/min_max.rs
index a599d65..8be2588 100644
--- a/datafusion-physical-expr/src/expressions/min_max.rs
+++ b/datafusion-physical-expr/src/expressions/min_max.rs
@@ -21,32 +21,25 @@
use std::convert::TryFrom;
use std::sync::Arc;
+use arrow::array::*;
+use arrow::compute::aggregate::*;
+use arrow::datatypes::*;
+
use crate::{AggregateExpr, PhysicalExpr};
-use arrow::compute;
-use arrow::datatypes::{DataType, TimeUnit};
-use arrow::{
- array::{
- ArrayRef, Date32Array, Date64Array, Float32Array, Float64Array, Int16Array,
- Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
- TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
- TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
- },
- datatypes::Field,
-};
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+
use super::format_state_name;
-use arrow::array::Array;
-use arrow::array::DecimalArray;
// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
// The reason min/max aggregate produces unpacked output because there is only one
// min/max value per group; there is no needs to keep them Dictionary encode
fn min_max_aggregate_data_type(input_type: DataType) -> DataType {
- if let DataType::Dictionary(_, value_type) = input_type {
+ if let DataType::Dictionary(_, value_type, _) = input_type {
*value_type
} else {
input_type
@@ -117,7 +110,7 @@
macro_rules! typed_min_max_batch_string {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
- let value = compute::$OP(array);
+ let value = $OP(array);
let value = value.and_then(|e| Some(e.to_string()));
ScalarValue::$SCALAR(value)
}};
@@ -127,13 +120,13 @@
macro_rules! typed_min_max_batch {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
- let value = compute::$OP(array);
+ let value = $OP(array);
ScalarValue::$SCALAR(value)
}};
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
- let value = compute::$OP(array);
+ let value = $OP(array);
ScalarValue::$SCALAR(value, $TZ.clone())
}};
}
@@ -147,7 +140,7 @@
if null_count == $VALUES.len() {
ScalarValue::Decimal128(None, *$PRECISION, *$SCALE)
} else {
- let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap();
+ let array = $VALUES.as_any().downcast_ref::<Int128Array>().unwrap();
if null_count == 0 {
// there is no null value
let mut result = array.value(0);
@@ -178,17 +171,10 @@
macro_rules! min_max_batch {
($VALUES:expr, $OP:ident) => {{
match $VALUES.data_type() {
- DataType::Decimal(precision, scale) => {
- typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP)
- }
// all types that have a natural order
- DataType::Float64 => {
- typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
+ DataType::Int64 => {
+ typed_min_max_batch!($VALUES, Int64Array, Int64, $OP)
}
- DataType::Float32 => {
- typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
- }
- DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP),
@@ -197,37 +183,31 @@
DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP),
DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP),
DataType::Timestamp(TimeUnit::Second, tz_opt) => {
- typed_min_max_batch!(
- $VALUES,
- TimestampSecondArray,
- TimestampSecond,
- $OP,
- tz_opt
- )
+ typed_min_max_batch!($VALUES, Int64Array, TimestampSecond, $OP, tz_opt)
}
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!(
$VALUES,
- TimestampMillisecondArray,
+ Int64Array,
TimestampMillisecond,
$OP,
tz_opt
),
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!(
$VALUES,
- TimestampMicrosecondArray,
+ Int64Array,
TimestampMicrosecond,
$OP,
tz_opt
),
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!(
$VALUES,
- TimestampNanosecondArray,
+ Int64Array,
TimestampNanosecond,
$OP,
tz_opt
),
- DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP),
- DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP),
+ DataType::Date32 => typed_min_max_batch!($VALUES, Int32Array, Date32, $OP),
+ DataType::Date64 => typed_min_max_batch!($VALUES, Int64Array, Date64, $OP),
other => {
// This should have been handled before
return Err(DataFusionError::Internal(format!(
@@ -248,7 +228,16 @@
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
}
- _ => min_max_batch!(values, min),
+ DataType::Float64 => {
+ typed_min_max_batch!(values, Float64Array, Float64, min_primitive)
+ }
+ DataType::Float32 => {
+ typed_min_max_batch!(values, Float32Array, Float32, min_primitive)
+ }
+ DataType::Decimal(precision, scale) => {
+ typed_min_max_batch_decimal128!(values, precision, scale, min)
+ }
+ _ => min_max_batch!(values, min_primitive),
})
}
@@ -261,7 +250,16 @@
DataType::LargeUtf8 => {
typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
}
- _ => min_max_batch!(values, max),
+ DataType::Float64 => {
+ typed_min_max_batch!(values, Float64Array, Float64, max_primitive)
+ }
+ DataType::Float32 => {
+ typed_min_max_batch!(values, Float32Array, Float32, max_primitive)
+ }
+ DataType::Decimal(precision, scale) => {
+ typed_min_max_batch_decimal128!(values, precision, scale, max)
+ }
+ _ => min_max_batch!(values, max_primitive),
})
}
macro_rules! typed_min_max_decimal {
@@ -553,14 +551,12 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
- use arrow::datatypes::*;
- use arrow::record_batch::RecordBatch;
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
- use datafusion_common::ScalarValue::Decimal128;
#[test]
fn min_decimal() -> Result<()> {
@@ -572,31 +568,25 @@
// min batch
let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
+ Int128Array::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Decimal(10, 0)),
);
-
let result = min_batch(&array)?;
assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0));
// min batch without values
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(0)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let array: ArrayRef =
+ Arc::new(Int128Array::new_null(DataType::Decimal(10, 0), 0));
+ let result = min_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+ let array: ArrayRef = Arc::new(Int128Array::new_empty(DataType::Decimal(10, 0)));
let result = min_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
// min batch with agg
let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
+ Int128Array::from_iter((1..6).map(Some).collect::<Vec<Option<i128>>>())
+ .to(DataType::Decimal(10, 0)),
);
generic_test_op!(
array,
@@ -610,12 +600,8 @@
#[test]
fn min_decimal_all_nulls() -> Result<()> {
// min batch all nulls
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(6)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let array: ArrayRef =
+ Arc::new(Int128Array::new_null(DataType::Decimal(10, 0), 6));
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -629,12 +615,13 @@
fn min_decimal_with_nulls() -> Result<()> {
// min batch with nulls
let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
+ Int128Array::from_iter(
+ (1..6)
+ .map(|i| if i == 2 { None } else { Some(i) })
+ .collect::<Vec<Option<i128>>>(),
+ )
+ .to(DataType::Decimal(10, 0)),
);
-
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -656,36 +643,28 @@
let result = max(&left, &right);
let expect = DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
- (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3))
+ (DataType::Decimal(10, 2), DataType::Decimal(10, 3))
));
assert_eq!(expect.to_string(), result.unwrap_err().to_string());
// max batch
let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 5)?,
+ Int128Array::from_slice((1..6).collect::<Vec<i128>>())
+ .to(DataType::Decimal(10, 5)),
);
let result = max_batch(&array)?;
assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5));
// max batch without values
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(0)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let array: ArrayRef =
+ Arc::new(Int128Array::new_null(DataType::Decimal(10, 0), 0));
let result = max_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
// max batch with agg
let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
+ Int128Array::from_iter((1..6).map(Some).collect::<Vec<Option<i128>>>())
+ .to(DataType::Decimal(10, 0)),
);
generic_test_op!(
array,
@@ -699,10 +678,12 @@
#[test]
fn max_decimal_with_nulls() -> Result<()> {
let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
+ Int128Array::from_iter(
+ (1..6)
+ .map(|i| if i == 2 { None } else { Some(i) })
+ .collect::<Vec<Option<i128>>>(),
+ )
+ .to(DataType::Decimal(10, 0)),
);
generic_test_op!(
array,
@@ -715,12 +696,8 @@
#[test]
fn max_decimal_all_nulls() -> Result<()> {
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(6)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let array: ArrayRef =
+ Arc::new(Int128Array::new_null(DataType::Decimal(10, 0), 6));
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -732,7 +709,7 @@
#[test]
fn max_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -744,7 +721,7 @@
#[test]
fn min_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -756,7 +733,7 @@
#[test]
fn max_utf8() -> Result<()> {
- let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"]));
+ let a: ArrayRef = Arc::new(StringArray::from_slice(vec!["d", "a", "c", "b"]));
generic_test_op!(
a,
DataType::Utf8,
@@ -768,7 +745,8 @@
#[test]
fn max_large_utf8() -> Result<()> {
- let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"]));
+ let a: ArrayRef =
+ Arc::new(LargeStringArray::from_slice(vec!["d", "a", "c", "b"]));
generic_test_op!(
a,
DataType::LargeUtf8,
@@ -780,7 +758,7 @@
#[test]
fn min_utf8() -> Result<()> {
- let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"]));
+ let a: ArrayRef = Arc::new(StringArray::from_slice(vec!["d", "a", "c", "b"]));
generic_test_op!(
a,
DataType::Utf8,
@@ -792,7 +770,8 @@
#[test]
fn min_large_utf8() -> Result<()> {
- let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"]));
+ let a: ArrayRef =
+ Arc::new(LargeStringArray::from_slice(vec!["d", "a", "c", "b"]));
generic_test_op!(
a,
DataType::LargeUtf8,
@@ -804,7 +783,7 @@
#[test]
fn max_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ let a: ArrayRef = Arc::new(Int32Array::from(&[
Some(1),
None,
Some(3),
@@ -822,7 +801,7 @@
#[test]
fn min_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ let a: ArrayRef = Arc::new(Int32Array::from(&[
Some(1),
None,
Some(3),
@@ -840,7 +819,7 @@
#[test]
fn max_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from(&[None, None]));
generic_test_op!(
a,
DataType::Int32,
@@ -852,7 +831,7 @@
#[test]
fn min_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from(&[None, None]));
generic_test_op!(
a,
DataType::Int32,
@@ -864,8 +843,9 @@
#[test]
fn max_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -877,8 +857,9 @@
#[test]
fn min_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -890,8 +871,9 @@
#[test]
fn max_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -903,8 +885,9 @@
#[test]
fn min_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -916,8 +899,9 @@
#[test]
fn max_f64() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
@@ -929,8 +913,9 @@
#[test]
fn min_f64() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
@@ -942,7 +927,8 @@
#[test]
fn min_date32() -> Result<()> {
- let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef =
+ Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Date32));
generic_test_op!(
a,
DataType::Date32,
@@ -954,7 +940,8 @@
#[test]
fn min_date64() -> Result<()> {
- let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef =
+ Arc::new(Int64Array::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Date64));
generic_test_op!(
a,
DataType::Date64,
@@ -966,7 +953,8 @@
#[test]
fn max_date32() -> Result<()> {
- let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef =
+ Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Date32));
generic_test_op!(
a,
DataType::Date32,
@@ -978,7 +966,8 @@
#[test]
fn max_date64() -> Result<()> {
- let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef =
+ Arc::new(Int64Array::from_slice(&[1, 2, 3, 4, 5]).to(DataType::Date64));
generic_test_op!(
a,
DataType::Date64,
diff --git a/datafusion-physical-expr/src/expressions/mod.rs b/datafusion-physical-expr/src/expressions/mod.rs
index dd0b011..adbee32 100644
--- a/datafusion-physical-expr/src/expressions/mod.rs
+++ b/datafusion-physical-expr/src/expressions/mod.rs
@@ -24,7 +24,7 @@
#[macro_use]
mod binary;
mod case;
-mod cast;
+pub(crate) mod cast;
mod column;
mod count;
mod cume_dist;
@@ -113,7 +113,7 @@
#[cfg(test)]
mod tests {
use crate::AggregateExpr;
- use arrow::record_batch::RecordBatch;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use std::sync::Arc;
@@ -127,7 +127,7 @@
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
let agg = Arc::new(<$OP>::new(
- col("a", &schema)?,
+ $crate::expressions::col("a", &schema)?,
"bla".to_string(),
$EXPECTED_DATATYPE,
));
diff --git a/datafusion-physical-expr/src/expressions/negative.rs b/datafusion-physical-expr/src/expressions/negative.rs
index 4974bdb..2c80653 100644
--- a/datafusion-physical-expr/src/expressions/negative.rs
+++ b/datafusion-physical-expr/src/expressions/negative.rs
@@ -20,13 +20,12 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::ArrayRef;
-use arrow::compute::kernels::arithmetic::negate;
use arrow::{
- array::{Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array},
+ array::*,
+ compute::arithmetics::basic::negate,
datatypes::{DataType, Schema},
- record_batch::RecordBatch,
};
+use datafusion_common::record_batch::RecordBatch;
use crate::coercion_rule::binary_rule::is_signed_numeric;
use crate::PhysicalExpr;
@@ -36,12 +35,12 @@
/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
// invoke unary operator
- ($OPERAND:expr, $OP:ident, $DT:ident) => {{
+ ($OPERAND:expr, $DT:ident) => {{
let operand = $OPERAND
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
- Ok(Arc::new($OP(&operand)?))
+ Ok(Arc::new(negate(operand)))
}};
}
@@ -89,12 +88,12 @@
match arg {
ColumnarValue::Array(array) => {
let result: Result<ArrayRef> = match array.data_type() {
- DataType::Int8 => compute_op!(array, negate, Int8Array),
- DataType::Int16 => compute_op!(array, negate, Int16Array),
- DataType::Int32 => compute_op!(array, negate, Int32Array),
- DataType::Int64 => compute_op!(array, negate, Int64Array),
- DataType::Float32 => compute_op!(array, negate, Float32Array),
- DataType::Float64 => compute_op!(array, negate, Float64Array),
+ DataType::Int8 => compute_op!(array, Int8Array),
+ DataType::Int16 => compute_op!(array, Int16Array),
+ DataType::Int32 => compute_op!(array, Int32Array),
+ DataType::Int64 => compute_op!(array, Int64Array),
+ DataType::Float32 => compute_op!(array, Float32Array),
+ DataType::Float64 => compute_op!(array, Float64Array),
_ => Err(DataFusionError::Internal(format!(
"(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric",
self,
diff --git a/datafusion-physical-expr/src/expressions/not.rs b/datafusion-physical-expr/src/expressions/not.rs
index fd0fbd1..57ec37b 100644
--- a/datafusion-physical-expr/src/expressions/not.rs
+++ b/datafusion-physical-expr/src/expressions/not.rs
@@ -24,9 +24,10 @@
use crate::PhysicalExpr;
use arrow::array::BooleanArray;
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
+
+use datafusion_common::record_batch::RecordBatch;
use datafusion_expr::ColumnarValue;
/// Not expression
@@ -82,7 +83,7 @@
)
})?;
Ok(ColumnarValue::Array(Arc::new(
- arrow::compute::kernels::boolean::not(array)?,
+ arrow::compute::boolean::not(array),
)))
}
ColumnarValue::Scalar(scalar) => {
@@ -118,8 +119,10 @@
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::col;
use arrow::datatypes::*;
+ use datafusion_common::field_util::SchemaExt;
use datafusion_common::Result;
#[test]
@@ -130,8 +133,8 @@
assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
assert!(expr.nullable(&schema)?);
- let input = BooleanArray::from(vec![Some(true), None, Some(false)]);
- let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]);
+ let input = BooleanArray::from_iter(vec![Some(true), None, Some(false)]);
+ let expected = &BooleanArray::from_iter(vec![Some(false), None, Some(true)]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
diff --git a/datafusion-physical-expr/src/expressions/nth_value.rs b/datafusion-physical-expr/src/expressions/nth_value.rs
index e0a6b2b..84e01fc 100644
--- a/datafusion-physical-expr/src/expressions/nth_value.rs
+++ b/datafusion-physical-expr/src/expressions/nth_value.rs
@@ -22,9 +22,9 @@
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::{new_null_array, ArrayRef};
-use arrow::compute::kernels::window::shift;
+use arrow::compute::window::shift;
use arrow::datatypes::{DataType, Field};
-use arrow::record_batch::RecordBatch;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
@@ -175,12 +175,15 @@
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten();
- ScalarValue::iter_to_array(values)
+ ScalarValue::iter_to_array(values).map(ArrayRef::from)
}
NthValueKind::Nth(n) => {
let index = (n as usize) - 1;
if index >= num_rows {
- Ok(new_null_array(arr.data_type(), num_rows))
+ Ok(ArrayRef::from(new_null_array(
+ arr.data_type().clone(),
+ num_rows,
+ )))
} else {
let value =
ScalarValue::try_from_array(arr, partition.start + index)?;
@@ -188,7 +191,9 @@
// because the default window frame is between unbounded preceding and current
// row, hence the shift because for values with indices < index they should be
// null. This changes when window frames other than default is implemented
- shift(arr.as_ref(), index as i64).map_err(DataFusionError::ArrowError)
+ shift(arr.as_ref(), index as i64)
+ .map_err(DataFusionError::ArrowError)
+ .map(ArrayRef::from)
}
}
}
@@ -198,13 +203,17 @@
#[cfg(test)]
mod tests {
use super::*;
+
use crate::expressions::Column;
- use arrow::record_batch::RecordBatch;
+ use datafusion_common::field_util::SchemaExt;
+
use arrow::{array::*, datatypes::*};
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> {
- let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
+ let arr: ArrayRef =
+ Arc::new(Int32Array::from_slice(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
@@ -224,7 +233,7 @@
Arc::new(Column::new("arr", 0)),
DataType::Int32,
);
- test_i32_result(first_value, Int32Array::from_iter_values(vec![1; 8]))?;
+ test_i32_result(first_value, Int32Array::from_values(vec![1; 8]))?;
Ok(())
}
@@ -235,7 +244,7 @@
Arc::new(Column::new("arr", 0)),
DataType::Int32,
);
- test_i32_result(last_value, Int32Array::from_iter_values(vec![8; 8]))?;
+ test_i32_result(last_value, Int32Array::from_values(vec![8; 8]))?;
Ok(())
}
@@ -247,7 +256,7 @@
DataType::Int32,
1,
)?;
- test_i32_result(nth_value, Int32Array::from_iter_values(vec![1; 8]))?;
+ test_i32_result(nth_value, Int32Array::from_values(vec![1; 8]))?;
Ok(())
}
@@ -261,7 +270,7 @@
)?;
test_i32_result(
nth_value,
- Int32Array::from(vec![
+ Int32Array::from(&[
None,
Some(-2),
Some(-2),
diff --git a/datafusion-physical-expr/src/expressions/nullif.rs b/datafusion-physical-expr/src/expressions/nullif.rs
index a078e22..45040ab 100644
--- a/datafusion-physical-expr/src/expressions/nullif.rs
+++ b/datafusion-physical-expr/src/expressions/nullif.rs
@@ -15,57 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-use std::sync::Arc;
-
-use crate::expressions::binary::{eq_decimal, eq_decimal_scalar};
-use arrow::array::Array;
-use arrow::array::*;
-use arrow::compute::kernels::boolean::nullif;
-use arrow::compute::kernels::comparison::{
- eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar,
-};
-use arrow::datatypes::{DataType, TimeUnit};
-use datafusion_common::ScalarValue;
+use arrow::compute::nullif;
+use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
-/// Invoke a compute kernel on a primitive array and a Boolean Array
-macro_rules! compute_bool_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<BooleanArray>()
- .expect("compute_op failed to downcast array");
- Ok(Arc::new($OP(&ll, &rr)?) as ArrayRef)
- }};
-}
-
-/// Binary op between primitive and boolean arrays
-macro_rules! primitive_bool_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- match $LEFT.data_type() {
- DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array),
- DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array),
- DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array),
- DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array),
- DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array),
- DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array),
- DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array),
- DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array),
- other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for NULLIF/primitive/boolean operator",
- other
- ))),
- }
- }};
-}
-
/// Implements NULLIF(expr1, expr2)
/// Args: 0 - left expr is any array
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
@@ -82,20 +36,14 @@
match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
- let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?;
-
- let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
-
- Ok(ColumnarValue::Array(array))
+ Ok(ColumnarValue::Array(
+ nullif::nullif(lhs.as_ref(), rhs.to_array_of_size(lhs.len()).as_ref())
+ .into(),
+ ))
}
- (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
- // Get args0 == args1 evaluated and produce a boolean array
- let cond_array = binary_array_op!(lhs, rhs, eq)?;
-
- // Now, invoke nullif on the result
- let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
- Ok(ColumnarValue::Array(array))
- }
+ (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => Ok(
+ ColumnarValue::Array(nullif::nullif(lhs.as_ref(), rhs.as_ref()).into()),
+ ),
_ => Err(DataFusionError::NotImplemented(
"nullif does not support a literal as first argument".to_string(),
)),
@@ -121,12 +69,15 @@
#[cfg(test)]
mod tests {
+ use arrow::array::Int32Array;
+ use std::sync::Arc;
+
use super::*;
- use datafusion_common::Result;
+ use datafusion_common::{Result, ScalarValue};
#[test]
fn nullif_int32() -> Result<()> {
- let a = Int32Array::from(vec![
+ let a = Int32Array::from_iter(vec![
Some(1),
Some(2),
None,
@@ -144,7 +95,7 @@
let result = nullif_func(&[a, lit_array])?;
let result = result.into_array(0);
- let expected = Arc::new(Int32Array::from(vec![
+ let expected = Int32Array::from_iter(vec![
Some(1),
None,
None,
@@ -154,15 +105,15 @@
None,
Some(4),
Some(5),
- ])) as ArrayRef;
- assert_eq!(expected.as_ref(), result.as_ref());
+ ]);
+ assert_eq!(expected, result.as_ref());
Ok(())
}
#[test]
// Ensure that arrays with no nulls can also invoke NULLIF() correctly
fn nullif_int32_nonulls() -> Result<()> {
- let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]);
+ let a = Int32Array::from_slice(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]);
let a = ColumnarValue::Array(Arc::new(a));
let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32)));
@@ -170,7 +121,7 @@
let result = nullif_func(&[a, lit_array])?;
let result = result.into_array(0);
- let expected = Arc::new(Int32Array::from(vec![
+ let expected = Int32Array::from_iter(vec![
None,
Some(3),
Some(10),
@@ -180,8 +131,8 @@
Some(2),
Some(4),
Some(5),
- ])) as ArrayRef;
- assert_eq!(expected.as_ref(), result.as_ref());
+ ]);
+ assert_eq!(expected, result.as_ref());
Ok(())
}
}
diff --git a/datafusion-physical-expr/src/expressions/rank.rs b/datafusion-physical-expr/src/expressions/rank.rs
index 18bcf26..dc31f4a 100644
--- a/datafusion-physical-expr/src/expressions/rank.rs
+++ b/datafusion-physical-expr/src/expressions/rank.rs
@@ -24,7 +24,7 @@
use arrow::array::ArrayRef;
use arrow::array::{Float64Array, UInt64Array};
use arrow::datatypes::{DataType, Field};
-use arrow::record_batch::RecordBatch;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use std::any::Any;
use std::iter;
@@ -39,6 +39,7 @@
}
#[derive(Debug, Copy, Clone)]
+#[allow(clippy::enum_variant_names)]
pub(crate) enum RankType {
Basic,
Dense,
@@ -122,7 +123,7 @@
) -> Result<ArrayRef> {
// see https://www.postgresql.org/docs/current/functions-window.html
let result: ArrayRef = match self.rank_type {
- RankType::Dense => Arc::new(UInt64Array::from_iter_values(
+ RankType::Dense => Arc::new(UInt64Array::from_values(
ranks_in_partition
.iter()
.zip(1u64..)
@@ -134,7 +135,7 @@
RankType::Percent => {
// Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive.
let denominator = (partition.end - partition.start) as f64;
- Arc::new(Float64Array::from_iter_values(
+ Arc::new(Float64Array::from_values(
ranks_in_partition
.iter()
.scan(0_u64, |acc, range| {
@@ -147,7 +148,7 @@
.flatten(),
))
}
- RankType::Basic => Arc::new(UInt64Array::from_iter_values(
+ RankType::Basic => Arc::new(UInt64Array::from_values(
ranks_in_partition
.iter()
.scan(1_u64, |acc, range| {
@@ -167,6 +168,7 @@
mod tests {
use super::*;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
fn test_with_rank(expr: &Rank, expected: Vec<u64>) -> Result<()> {
test_i32_result(
@@ -188,7 +190,7 @@
ranks: Vec<Range<usize>>,
expected: Vec<f64>,
) -> Result<()> {
- let arr: ArrayRef = Arc::new(Int32Array::from(data));
+ let arr: ArrayRef = Arc::new(Int32Array::from_slice(data.as_slice()));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
@@ -197,7 +199,7 @@
.evaluate_with_rank(vec![range], ranks)?;
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<Float64Array>().unwrap();
- let result = result.values();
+ let result = result.values().as_slice();
assert_eq!(expected, result);
Ok(())
}
@@ -208,7 +210,7 @@
ranks: Vec<Range<usize>>,
expected: Vec<u64>,
) -> Result<()> {
- let arr: ArrayRef = Arc::new(Int32Array::from(data));
+ let arr: ArrayRef = Arc::new(Int32Array::from_values(data));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
@@ -217,8 +219,8 @@
.evaluate_with_rank(vec![0..8], ranks)?;
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
- let result = result.values();
- assert_eq!(expected, result);
+ let expected = UInt64Array::from_values(expected);
+ assert_eq!(expected, *result);
Ok(())
}
diff --git a/datafusion-physical-expr/src/expressions/row_number.rs b/datafusion-physical-expr/src/expressions/row_number.rs
index 8a720d2..90ff378 100644
--- a/datafusion-physical-expr/src/expressions/row_number.rs
+++ b/datafusion-physical-expr/src/expressions/row_number.rs
@@ -22,7 +22,7 @@
use crate::PhysicalExpr;
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
-use arrow::record_batch::RecordBatch;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use std::any::Any;
use std::ops::Range;
@@ -75,22 +75,22 @@
impl PartitionEvaluator for NumRowsEvaluator {
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
let num_rows = partition.end - partition.start;
- Ok(Arc::new(UInt64Array::from_iter_values(
- 1..(num_rows as u64) + 1,
- )))
+ Ok(Arc::new(UInt64Array::from_values(1..(num_rows as u64) + 1)))
}
}
#[cfg(test)]
mod tests {
use super::*;
- use arrow::record_batch::RecordBatch;
+
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn row_number_all_null() -> Result<()> {
- let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
+ let arr: ArrayRef = Arc::new(BooleanArray::from_iter(vec![
None, None, None, None, None, None, None, None,
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
@@ -99,14 +99,14 @@
let result = row_number.create_evaluator(&batch)?.evaluate(vec![0..8])?;
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
- let result = result.values();
+ let result = result.values().as_slice();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
Ok(())
}
#[test]
fn row_number_all_values() -> Result<()> {
- let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
+ let arr: ArrayRef = Arc::new(BooleanArray::from_slice(vec![
true, false, true, false, false, true, false, true,
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
@@ -115,7 +115,7 @@
let result = row_number.create_evaluator(&batch)?.evaluate(vec![0..8])?;
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
- let result = result.values();
+ let result = result.values().as_slice();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
Ok(())
}
diff --git a/datafusion-physical-expr/src/expressions/stddev.rs b/datafusion-physical-expr/src/expressions/stddev.rs
index 8a5d4e8..5cca776 100644
--- a/datafusion-physical-expr/src/expressions/stddev.rs
+++ b/datafusion-physical-expr/src/expressions/stddev.rs
@@ -253,13 +253,14 @@
use super::*;
use crate::expressions::col;
use crate::generic_test_op;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn stddev_f64_1() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64]));
generic_test_op!(
a,
DataType::Float64,
@@ -271,7 +272,7 @@
#[test]
fn stddev_f64_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
generic_test_op!(
a,
DataType::Float64,
@@ -283,8 +284,9 @@
#[test]
fn stddev_f64_3() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
@@ -296,7 +298,7 @@
#[test]
fn stddev_f64_4() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
generic_test_op!(
a,
DataType::Float64,
@@ -308,7 +310,7 @@
#[test]
fn stddev_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -320,8 +322,9 @@
#[test]
fn stddev_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -333,8 +336,9 @@
#[test]
fn stddev_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -357,7 +361,7 @@
#[test]
fn test_stddev_1_input() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
@@ -374,7 +378,7 @@
#[test]
fn stddev_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![
Some(1),
None,
Some(3),
@@ -392,7 +396,7 @@
#[test]
fn stddev_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Int32Vec::from(vec![None, None]).as_arc();
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
@@ -409,8 +413,8 @@
#[test]
fn stddev_f64_merge_1() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
@@ -437,8 +441,10 @@
#[test]
fn stddev_f64_merge_2() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- let b = Arc::new(Float64Array::from(vec![None]));
+ let a = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
+ let b = Arc::new(Float64Array::from_iter(vec![None]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
diff --git a/datafusion-physical-expr/src/expressions/sum.rs b/datafusion-physical-expr/src/expressions/sum.rs
index 9945620..f4b0557 100644
--- a/datafusion-physical-expr/src/expressions/sum.rs
+++ b/datafusion-physical-expr/src/expressions/sum.rs
@@ -23,20 +23,15 @@
use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
-use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION};
use arrow::{
- array::{
- ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
- Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
- },
- datatypes::Field,
+ array::*,
+ datatypes::{DataType, Field},
};
-use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_common::{DataFusionError, Result, ScalarValue, DECIMAL_MAX_PRECISION};
use datafusion_expr::Accumulator;
use super::format_state_name;
use arrow::array::Array;
-use arrow::array::DecimalArray;
/// SUM aggregate expression
#[derive(Debug)]
@@ -158,7 +153,7 @@
macro_rules! typed_sum_delta_batch {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
- let delta = compute::sum(array);
+ let delta = compute::aggregate::sum_primitive(array);
ScalarValue::$SCALAR(delta)
}};
}
@@ -170,7 +165,7 @@
precision: &usize,
scale: &usize,
) -> Result<ScalarValue> {
- let array = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+ let array = values.as_any().downcast_ref::<Int128Array>().unwrap();
if array.null_count() == array.len() {
return Ok(ScalarValue::Decimal128(None, *precision, *scale));
@@ -374,11 +369,10 @@
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::col;
use crate::generic_test_op;
use arrow::datatypes::*;
- use arrow::record_batch::RecordBatch;
- use datafusion_common::Result;
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
#[test]
fn test_sum_return_data_type() -> Result<()> {
@@ -417,22 +411,22 @@
);
// test sum batch
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for i in 1..6 {
+ decimal_builder.push(Some(i as i128));
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
let result = sum_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result);
// test agg
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(Some)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for i in 1..6 {
+ decimal_builder.push(Some(i as i128));
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
generic_test_op!(
array,
@@ -452,22 +446,30 @@
assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result);
// test with batch
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.push_null();
+ } else {
+ decimal_builder.push(Some(i));
+ }
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
let result = sum_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result);
// test agg
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<DecimalArray>()
- .with_precision_and_scale(35, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(35, 0));
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.push_null();
+ } else {
+ decimal_builder.push(Some(i));
+ }
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
generic_test_op!(
array,
DataType::Decimal(35, 0),
@@ -486,16 +488,22 @@
assert_eq!(ScalarValue::Decimal128(None, 10, 2), result);
// test with batch
- let array: ArrayRef = Arc::new(
- std::iter::repeat(None)
- .take(6)
- .collect::<DecimalArray>()
- .with_precision_and_scale(10, 0)?,
- );
+ let mut decimal_builder =
+ Int128Vec::with_capacity(6).to(DataType::Decimal(10, 0));
+ for _i in 1..7 {
+ decimal_builder.push_null();
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
let result = sum_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
// test agg
+ let mut decimal_builder =
+ Int128Vec::with_capacity(5).to(DataType::Decimal(10, 0));
+ for _i in 1..6 {
+ decimal_builder.push_null();
+ }
+ let array: ArrayRef = decimal_builder.as_arc();
generic_test_op!(
array,
DataType::Decimal(10, 0),
@@ -507,7 +515,7 @@
#[test]
fn sum_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -519,7 +527,7 @@
#[test]
fn sum_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(&[
Some(1),
None,
Some(3),
@@ -537,7 +545,7 @@
#[test]
fn sum_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None]));
generic_test_op!(
a,
DataType::Int32,
@@ -549,8 +557,9 @@
#[test]
fn sum_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -562,8 +571,9 @@
#[test]
fn sum_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -575,8 +585,9 @@
#[test]
fn sum_f64() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
diff --git a/datafusion-physical-expr/src/expressions/try_cast.rs b/datafusion-physical-expr/src/expressions/try_cast.rs
index 6b0d3e1..2727ead 100644
--- a/datafusion-physical-expr/src/expressions/try_cast.rs
+++ b/datafusion-physical-expr/src/expressions/try_cast.rs
@@ -19,12 +19,12 @@
use std::fmt;
use std::sync::Arc;
+use crate::expressions::cast::cast_with_error;
use crate::PhysicalExpr;
use arrow::compute;
-use arrow::compute::kernels;
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use compute::can_cast_types;
+use compute::cast;
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
@@ -78,13 +78,22 @@
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
match value {
- ColumnarValue::Array(array) => Ok(ColumnarValue::Array(kernels::cast::cast(
- &array,
- &self.cast_type,
- )?)),
+ ColumnarValue::Array(array) => Ok(ColumnarValue::Array(
+ cast_with_error(
+ array.as_ref(),
+ &self.cast_type,
+ cast::CastOptions::default(),
+ )?
+ .into(),
+ )),
ColumnarValue::Scalar(scalar) => {
let scalar_array = scalar.to_array();
- let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?;
+ let cast_array = cast_with_error(
+ scalar_array.as_ref(),
+ &self.cast_type,
+ cast::CastOptions::default(),
+ )?
+ .into();
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
@@ -104,7 +113,7 @@
let expr_type = expr.data_type(input_schema)?;
if expr_type == cast_type {
Ok(expr.clone())
- } else if can_cast_types(&expr_type, &cast_type) {
+ } else if cast::can_cast_types(&expr_type, &cast_type) {
Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
} else {
Err(DataFusionError::Internal(format!(
@@ -118,18 +127,13 @@
mod tests {
use super::*;
use crate::expressions::col;
- use arrow::array::{
- DecimalArray, DecimalBuilder, StringArray, Time64NanosecondArray,
- };
- use arrow::{
- array::{
- Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
- Int8Array, TimestampNanosecondArray, UInt32Array,
- },
- datatypes::*,
- };
+ use crate::test_util::create_decimal_array_from_slice;
+ use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
use datafusion_common::Result;
+ type StringArray = Utf8Array<i32>;
+
// runs an end-to-end test of physical type cast
// 1. construct a record batch with a column "a" of type A
// 2. construct a physical expression of CAST(a AS B)
@@ -186,7 +190,7 @@
macro_rules! generic_test_cast {
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]);
- let a = $A_ARRAY::from($A_VEC);
+ let a = $A_ARRAY::from_slice(&$A_VEC);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
@@ -231,11 +235,11 @@
fn test_try_cast_decimal_to_decimal() -> Result<()> {
// try cast one decimal data type to another decimal data type
let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
- let decimal_array = create_decimal_array(&array, 10, 3)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 6),
vec![
Some(1_234_000_i128),
@@ -247,11 +251,11 @@
]
);
- let decimal_array = create_decimal_array(&array, 10, 3)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 2),
vec![
Some(123_i128),
@@ -268,14 +272,14 @@
#[test]
fn test_try_cast_decimal_to_numeric() -> Result<()> {
- // TODO we should add function to create DecimalArray with value and metadata
+ // TODO we should add function to create a decimal array (Int128Array + DataType::Decimal) with value and metadata
// https://github.com/apache/arrow-rs/issues/1009
let array: Vec<i128> = vec![1, 2, 3, 4, 5];
- let decimal_array = create_decimal_array(&array, 10, 0)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?;
// decimal to i8
generic_decimal_to_other_test_cast!(
decimal_array,
- DataType::Decimal(10, 0),
+ DataType::Decimal(10, 0),
Int8Array,
DataType::Int8,
vec![
@@ -289,7 +293,7 @@
);
// decimal to i16
- let decimal_array = create_decimal_array(&array, 10, 0)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -306,7 +310,7 @@
);
// decimal to i32
- let decimal_array = create_decimal_array(&array, 10, 0)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -323,7 +327,7 @@
);
// decimal to i64
- let decimal_array = create_decimal_array(&array, 10, 0)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 0)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 0),
@@ -341,7 +345,7 @@
// decimal to float32
let array: Vec<i128> = vec![1234, 2222, 3, 4000, 5000];
- let decimal_array = create_decimal_array(&array, 10, 3)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 10, 3)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(10, 3),
@@ -357,7 +361,7 @@
]
);
// decimal to float64
- let decimal_array = create_decimal_array(&array, 20, 6)?;
+ let decimal_array = create_decimal_array_from_slice(&array, 20, 6)?;
generic_decimal_to_other_test_cast!(
decimal_array,
DataType::Decimal(20, 6),
@@ -383,7 +387,7 @@
Int8Array,
DataType::Int8,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(3, 0),
vec![
Some(1_i128),
@@ -399,7 +403,7 @@
Int16Array,
DataType::Int16,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(5, 0),
vec![
Some(1_i128),
@@ -415,7 +419,7 @@
Int32Array,
DataType::Int32,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 0),
vec![
Some(1_i128),
@@ -431,7 +435,7 @@
Int64Array,
DataType::Int64,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 0),
vec![
Some(1_i128),
@@ -447,7 +451,7 @@
Int64Array,
DataType::Int64,
vec![1, 2, 3, 4, 5],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 2),
vec![
Some(100_i128),
@@ -463,7 +467,7 @@
Float32Array,
DataType::Float32,
vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
- DecimalArray,
+ Int128Array,
DataType::Decimal(10, 2),
vec![
Some(150_i128),
@@ -479,7 +483,7 @@
Float64Array,
DataType::Float64,
vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
- DecimalArray,
+ Int128Array,
DataType::Decimal(20, 4),
vec![
Some(15000_i128),
@@ -497,7 +501,7 @@
generic_test_cast!(
Int32Array,
DataType::Int32,
- vec![1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
UInt32Array,
DataType::UInt32,
vec![
@@ -516,7 +520,7 @@
generic_test_cast!(
Int32Array,
DataType::Int32,
- vec![1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
StringArray,
DataType::Utf8,
vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]
@@ -541,15 +545,12 @@
#[test]
fn test_cast_i64_t64() -> Result<()> {
let original = vec![1, 2, 3, 4, 5];
- let expected: Vec<Option<i64>> = original
- .iter()
- .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
- .collect();
+ let expected: Vec<Option<i64>> = original.iter().map(|i| Some(*i)).collect();
generic_test_cast!(
Int64Array,
DataType::Int64,
original.clone(),
- TimestampNanosecondArray,
+ Int64Array,
DataType::Timestamp(TimeUnit::Nanosecond, None),
expected
);
@@ -559,23 +560,9 @@
#[test]
fn invalid_cast() {
// Ensure a useful error happens at plan time if invalid casts are used
- let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+ let schema = Schema::new(vec![Field::new("a", DataType::Null, false)]);
+ // NOTE(review): relies on try_cast rejecting DataType::Null -> DataType::LargeBinary when the expression is built; confirm against arrow2's cast support
let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary);
result.expect_err("expected Invalid CAST");
}
-
- // create decimal array with the specified precision and scale
- fn create_decimal_array(
- array: &[i128],
- precision: usize,
- scale: usize,
- ) -> Result<DecimalArray> {
- let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale);
- for value in array {
- decimal_builder.append_value(*value)?
- }
- decimal_builder.append_null()?;
- Ok(decimal_builder.finish())
- }
}
diff --git a/datafusion-physical-expr/src/expressions/variance.rs b/datafusion-physical-expr/src/expressions/variance.rs
index 70f25ce..6c3859e 100644
--- a/datafusion-physical-expr/src/expressions/variance.rs
+++ b/datafusion-physical-expr/src/expressions/variance.rs
@@ -20,11 +20,11 @@
use std::any::Any;
use std::sync::Arc;
+use crate::expressions::cast::{cast_with_error, DEFAULT_DATAFUSION_CAST_OPTIONS};
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::Float64Array;
use arrow::{
array::{ArrayRef, UInt64Array},
- compute::cast,
datatypes::DataType,
datatypes::Field,
};
@@ -255,7 +255,11 @@
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = &cast(&values[0], &DataType::Float64)?;
+ let values = &cast_with_error(
+ values[0].as_ref(),
+ &DataType::Float64,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ )?;
let arr = values
.as_any()
.downcast_ref::<Float64Array>()
@@ -334,13 +338,14 @@
use super::*;
use crate::expressions::col;
use crate::generic_test_op;
- use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
+ use datafusion_common::field_util::SchemaExt;
+ use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
#[test]
fn variance_f64_1() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64]));
generic_test_op!(
a,
DataType::Float64,
@@ -352,8 +357,9 @@
#[test]
fn variance_f64_2() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
@@ -365,8 +371,9 @@
#[test]
fn variance_f64_3() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
generic_test_op!(
a,
DataType::Float64,
@@ -378,7 +385,7 @@
#[test]
fn variance_f64_4() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1.1_f64, 2_f64, 3_f64]));
generic_test_op!(
a,
DataType::Float64,
@@ -390,7 +397,7 @@
#[test]
fn variance_i32() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+ let a: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 3, 4, 5]));
generic_test_op!(
a,
DataType::Int32,
@@ -402,8 +409,9 @@
#[test]
fn variance_u32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
+ let a: ArrayRef = Arc::new(UInt32Array::from_slice(vec![
+ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32,
+ ]));
generic_test_op!(
a,
DataType::UInt32,
@@ -415,8 +423,9 @@
#[test]
fn variance_f32() -> Result<()> {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
+ let a: ArrayRef = Arc::new(Float32Array::from_slice(vec![
+ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32,
+ ]));
generic_test_op!(
a,
DataType::Float32,
@@ -439,7 +448,7 @@
#[test]
fn test_variance_1_input() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64]));
+ let a: ArrayRef = Arc::new(Float64Array::from_slice(vec![1_f64]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
@@ -456,13 +465,8 @@
#[test]
fn variance_i32_with_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(1),
- None,
- Some(3),
- Some(4),
- Some(5),
- ]));
+ let a: ArrayRef =
+ Int32Vec::from(vec![Some(1), None, Some(3), Some(4), Some(5)]).as_arc();
generic_test_op!(
a,
DataType::Int32,
@@ -474,7 +478,7 @@
#[test]
fn variance_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
+ let a: ArrayRef = Int32Vec::from(vec![None, None]).as_arc();
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
@@ -491,8 +495,8 @@
#[test]
fn variance_f64_merge_1() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64]));
+ let a = Arc::new(Float64Array::from_slice(vec![1_f64, 2_f64, 3_f64]));
+ let b = Arc::new(Float64Array::from_slice(vec![4_f64, 5_f64]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
@@ -519,8 +523,10 @@
#[test]
fn variance_f64_merge_2() -> Result<()> {
- let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- let b = Arc::new(Float64Array::from(vec![None]));
+ let a = Arc::new(Float64Array::from_slice(vec![
+ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64,
+ ]));
+ let b = Arc::new(Float64Array::from_iter(vec![None]));
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
diff --git a/datafusion-physical-expr/src/field_util.rs b/datafusion-physical-expr/src/field_util.rs
index 2c9411e..f7a5e4b 100644
--- a/datafusion-physical-expr/src/field_util.rs
+++ b/datafusion-physical-expr/src/field_util.rs
@@ -18,6 +18,7 @@
//! Utility functions for complex field access
use arrow::datatypes::{DataType, Field};
+use datafusion_common::field_util::FieldExt;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
diff --git a/datafusion-physical-expr/src/functions.rs b/datafusion-physical-expr/src/functions.rs
index 1350d49..0cc0975 100644
--- a/datafusion-physical-expr/src/functions.rs
+++ b/datafusion-physical-expr/src/functions.rs
@@ -31,7 +31,8 @@
use crate::PhysicalExpr;
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
+
+use datafusion_common::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::ColumnarValue;
diff --git a/datafusion-physical-expr/src/lib.rs b/datafusion-physical-expr/src/lib.rs
index 8a2fe25..71bbcfc 100644
--- a/datafusion-physical-expr/src/lib.rs
+++ b/datafusion-physical-expr/src/lib.rs
@@ -17,6 +17,7 @@
mod aggregate_expr;
pub mod array_expressions;
+mod arrow_temporal_util;
pub mod coercion_rule;
#[cfg(feature = "crypto_expressions")]
pub mod crypto_expressions;
@@ -32,6 +33,8 @@
mod sort_expr;
pub mod string_expressions;
mod tdigest;
+#[cfg(test)]
+mod test_util;
#[cfg(feature = "unicode_expressions")]
pub mod unicode_expressions;
pub mod window;
@@ -39,4 +42,4 @@
pub use aggregate_expr::AggregateExpr;
pub use functions::ScalarFunctionExpr;
pub use physical_expr::PhysicalExpr;
-pub use sort_expr::PhysicalSortExpr;
+pub use sort_expr::{PhysicalSortExpr, SortColumn};
diff --git a/datafusion-physical-expr/src/math_expressions.rs b/datafusion-physical-expr/src/math_expressions.rs
index b16a596..b437efa 100644
--- a/datafusion-physical-expr/src/math_expressions.rs
+++ b/datafusion-physical-expr/src/math_expressions.rs
@@ -17,22 +17,23 @@
//! Math expressions
-use arrow::array::{Float32Array, Float64Array};
-use arrow::datatypes::DataType;
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::ColumnarValue;
use rand::{thread_rng, Rng};
use std::iter;
use std::sync::Arc;
+use arrow::array::{Float32Array, Float64Array};
+use arrow::compute::arity::unary;
+use arrow::datatypes::DataType;
+use datafusion_common::ScalarValue;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::ColumnarValue;
+
macro_rules! downcast_compute_op {
- ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{
+ ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $DT: path) => {{
let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
match n {
Some(array) => {
- let res: $TYPE =
- arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
+ let res: $TYPE = unary(array, |x| x.$FUNC(), $DT);
Ok(Arc::new(res))
}
_ => Err(DataFusionError::Internal(format!(
@@ -48,11 +49,23 @@
match ($VALUE) {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Float32 => {
- let result = downcast_compute_op!(array, $NAME, $FUNC, Float32Array);
+ let result = downcast_compute_op!(
+ array,
+ $NAME,
+ $FUNC,
+ Float32Array,
+ DataType::Float32
+ );
Ok(ColumnarValue::Array(result?))
}
DataType::Float64 => {
- let result = downcast_compute_op!(array, $NAME, $FUNC, Float64Array);
+ let result = downcast_compute_op!(
+ array,
+ $NAME,
+ $FUNC,
+ Float64Array,
+ DataType::Float64
+ );
Ok(ColumnarValue::Array(result?))
}
other => Err(DataFusionError::Internal(format!(
@@ -116,7 +129,7 @@
};
let mut rng = thread_rng();
let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
- let array = Float64Array::from_iter_values(values);
+ let array = Float64Array::from_trusted_len_values_iter(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}
@@ -124,11 +137,17 @@
mod tests {
use super::*;
- use arrow::array::{Float64Array, NullArray};
+ use arrow::{
+ array::{Float64Array, NullArray},
+ datatypes::DataType,
+ };
#[test]
fn test_random_expression() {
- let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))];
+ let args = vec![ColumnarValue::Array(Arc::new(NullArray::from_data(
+ DataType::Null,
+ 1,
+ )))];
let array = random(&args).expect("fail").into_array(1);
let floats = array.as_any().downcast_ref::<Float64Array>().expect("fail");
diff --git a/datafusion-physical-expr/src/physical_expr.rs b/datafusion-physical-expr/src/physical_expr.rs
index 25885b1..0954fe5 100644
--- a/datafusion-physical-expr/src/physical_expr.rs
+++ b/datafusion-physical-expr/src/physical_expr.rs
@@ -17,13 +17,12 @@
use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use std::fmt::{Debug, Display};
+use datafusion_common::record_batch::RecordBatch;
use std::any::Any;
/// Expression that can be evaluated against a RecordBatch
diff --git a/datafusion-physical-expr/src/regex_expressions.rs b/datafusion-physical-expr/src/regex_expressions.rs
index 69de68e..fd8b7a3 100644
--- a/datafusion-physical-expr/src/regex_expressions.rs
+++ b/datafusion-physical-expr/src/regex_expressions.rs
@@ -21,42 +21,44 @@
//! Regex expressions
-use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
-use arrow::compute;
-use datafusion_common::{DataFusionError, Result};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use regex::Regex;
use std::any::type_name;
use std::sync::Arc;
+use arrow::array::*;
+use arrow::error::ArrowError;
+
+use datafusion_common::{DataFusionError, Result};
+
+/// Downcasts `$ARG` (an array reference) to `&Utf8Array<$T>`, where `$T` is the offset type.
+///
+/// If `$ARG` is not a `Utf8Array<$T>`, returns early from the enclosing function (via `?`)
+/// with a `DataFusionError::Internal` that names `$NAME` and the expected array type.
macro_rules! downcast_string_arg {
($ARG:expr, $NAME:expr, $T:ident) => {{
$ARG.as_any()
- .downcast_ref::<GenericStringArray<T>>()
+ .downcast_ref::<Utf8Array<$T>>()
.ok_or_else(|| {
DataFusionError::Internal(format!(
"could not cast {} to {}",
$NAME,
- type_name::<GenericStringArray<T>>()
+ type_name::<Utf8Array<$T>>()
))
})?
}};
}
/// extract a specific group from a string column, using a regular expression
-pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn regexp_match<T: Offset>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
- compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
+ Ok(regexp_matches(values, regex, None).map(|x| Arc::new(x) as Arc<dyn Array>)?)
}
3 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_arg!(args[2], "flags", T));
- compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
+ Ok(regexp_matches(values, regex, flags).map(|x| Arc::new(x) as Arc<dyn Array>)?)
}
other => Err(DataFusionError::Internal(format!(
"regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
@@ -79,7 +81,7 @@
/// Replaces substring(s) matching a POSIX regular expression.
///
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
-pub fn regexp_replace<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn regexp_replace<T: Offset>(args: &[ArrayRef]) -> Result<ArrayRef> {
// creating Regex is expensive so