blob: 441001099639a30beff80ddd48f7de22abf93a60 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::down_cast_any_ref;
use arrow::array::*;
use arrow::compute::{eq_dyn_binary_scalar, eq_dyn_bool_scalar, eq_dyn_scalar, eq_dyn_utf8_scalar};
use arrow::datatypes::Field;
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion::common::DataFusionError;
use datafusion::common::Result;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};
/// Physical expression that extracts the value associated with a constant key
/// from a map-typed column; rows whose map does not contain the key produce
/// null (see `evaluate`, which `take`s with a `None` index for such rows).
#[derive(Debug, Hash)]
pub struct GetMapValueExpr {
    arg: Arc<dyn PhysicalExpr>, // child expression yielding the map array
    key: ScalarValue,           // literal lookup key; must match the map's key type
}
impl GetMapValueExpr {
    /// Builds a map-value accessor over `arg` with constant lookup key `key`.
    pub fn new(arg: Arc<dyn PhysicalExpr>, key: ScalarValue) -> Self {
        Self { arg, key }
    }

    /// The child (map-producing) expression.
    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
        &self.arg
    }

    /// The constant key being looked up.
    pub fn key(&self) -> &ScalarValue {
        &self.key
    }
}
impl PartialEq<dyn Any> for GetMapValueExpr {
    // Equal only to another GetMapValueExpr with the same child and the same
    // key; any other dynamic type compares unequal.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(that) => self.arg.eq(&that.arg) && self.key == that.key,
            None => false,
        }
    }
}
impl std::fmt::Display for GetMapValueExpr {
    // Renders as `(<arg>).[<key>]`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { arg, key } = self;
        write!(f, "({}).[{}]", arg, key)
    }
}
impl PhysicalExpr for GetMapValueExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
let data_type = self.arg.data_type(input_schema)?;
get_data_type_field(&data_type).map(|f| f.data_type().clone())
}
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
let data_type = self.arg.data_type(input_schema)?;
get_data_type_field(&data_type).map(|f| f.is_nullable())
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let array = self.arg.evaluate(batch)?.into_array(1);
match (array.data_type(), &self.key) {
(DataType::Map(_, _), _) if self.key.is_null() => {
Err(DataFusionError::NotImplemented("map key not support Null Type".to_string()))
}
(DataType::Map(_, _), _) => {
let as_map_array = array.as_any().downcast_ref::<MapArray>().unwrap();
if !as_map_array.key_type().equals_datatype(&self.key.get_datatype()) {
return Err(DataFusionError::Execution("MapArray key type must equal to GetMapValue key type".to_string()))
}
macro_rules! get_boolean_value {
($keyarrowty:ident, $scalar:expr) => {{
type A = paste::paste! {[< $keyarrowty Array >]};
let key_array = as_map_array.keys().as_any().downcast_ref::<A>().unwrap();
let ans_boolean = eq_dyn_bool_scalar(key_array, $scalar)?;
let ans_index = ans_boolean.iter().enumerate()
.filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false })
.map(|(idx, _)|idx as i32)
.collect::<Vec<_>>();
let mut indices = vec![];
if ans_index.len() == 0 {
for _i in 0..as_map_array.len() {
indices.push(None);
}
} else {
let mut cur_offset = 0;
for &idx in as_map_array.value_offsets().into_iter().skip(1) {
if cur_offset >= ans_index.len() {
indices.push(None);
} else if idx <= ans_index[cur_offset] {
indices.push(None);
} else {
indices.push(Some(ans_index[cur_offset] as u32));
cur_offset += 1;
}
}
}
let indice_array = UInt32Array::from(indices);
let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?;
Ok(ColumnarValue::Array(ans_array))
}};
}
macro_rules! get_prim_value {
($keyarrowty:ident, $scalar:expr) => {{
type A = paste::paste! {[< $keyarrowty Array >]};
let key_array = as_map_array.keys().as_any().downcast_ref::<A>().unwrap();
let ans_boolean = eq_dyn_scalar(key_array, $scalar)?;
let ans_index = ans_boolean.iter().enumerate()
.filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false })
.map(|(idx, _)|idx as i32)
.collect::<Vec<_>>();
let mut indices = vec![];
if ans_index.len() == 0 {
for _i in 0..as_map_array.len() {
indices.push(None);
}
} else {
let mut cur_offset = 0;
for &idx in as_map_array.value_offsets().into_iter().skip(1) {
if cur_offset >= ans_index.len() {
indices.push(None);
} else if idx <= ans_index[cur_offset] {
indices.push(None);
} else {
indices.push(Some(ans_index[cur_offset] as u32));
cur_offset += 1;
}
}
}
let indice_array = UInt32Array::from(indices);
let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?;
Ok(ColumnarValue::Array(ans_array))
}};
}
macro_rules! get_str_value {
($keyarrowty:ident, $scalar:expr) => {{
type A = paste::paste! {[< $keyarrowty Array >]};
let key_array = as_map_array.keys().as_any().downcast_ref::<A>().unwrap();
let ans_boolean = eq_dyn_utf8_scalar(key_array, $scalar)?;
let ans_index = ans_boolean.iter().enumerate()
.filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false })
.map(|(idx, _)|idx as i32)
.collect::<Vec<_>>();
let mut indices = vec![];
if ans_index.len() == 0 {
for _i in 0..as_map_array.len() {
indices.push(None);
}
} else {
let mut cur_offset = 0;
for &idx in as_map_array.value_offsets().into_iter().skip(1) {
if cur_offset >= ans_index.len() {
indices.push(None);
} else if idx <= ans_index[cur_offset] {
indices.push(None);
} else {
indices.push(Some(ans_index[cur_offset] as u32));
cur_offset += 1;
}
}
}
let indice_array = UInt32Array::from(indices);
let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?;
Ok(ColumnarValue::Array(ans_array))
}};
}
macro_rules! get_binary_value {
($keyarrowty:ident, $scalar:expr) => {{
type A = paste::paste! {[< $keyarrowty Array >]};
let key_array = as_map_array.keys().as_any().downcast_ref::<A>().unwrap();
let ans_boolean = eq_dyn_binary_scalar(key_array, $scalar)?;
let ans_index = ans_boolean.iter().enumerate()
.filter(|(_, ans)| if let Some(res) = ans { res.clone() } else { false })
.map(|(idx, _)|idx as i32)
.collect::<Vec<_>>();
let mut indices = vec![];
if ans_index.len() == 0 {
for _i in 0..as_map_array.len() {
indices.push(None);
}
} else {
let mut cur_offset = 0;
for &idx in as_map_array.value_offsets().into_iter().skip(1) {
if cur_offset >= ans_index.len() {
indices.push(None);
} else if idx <= ans_index[cur_offset] {
indices.push(None);
} else {
indices.push(Some(ans_index[cur_offset] as u32));
cur_offset += 1;
}
}
}
let indice_array = UInt32Array::from(indices);
let ans_array = arrow::compute::take(as_map_array.values(), &indice_array, None)?;
Ok(ColumnarValue::Array(ans_array))
}};
}
match &self.key {
ScalarValue::Boolean(Some(i)) => get_boolean_value!(Boolean, *i),
ScalarValue::Float32(Some(i)) => get_prim_value!(Float32, *i),
ScalarValue::Float64(Some(i)) => get_prim_value!(Float64, *i),
ScalarValue::Int8(Some(i)) => get_prim_value!(Int8, *i),
ScalarValue::Int16(Some(i)) => get_prim_value!(Int16, *i),
ScalarValue::Int32(Some(i)) => get_prim_value!(Int32, *i),
ScalarValue::Int64(Some(i)) => get_prim_value!(Int64, *i),
ScalarValue::UInt8(Some(i)) => get_prim_value!(UInt8, *i),
ScalarValue::UInt16(Some(i)) => get_prim_value!(UInt16, *i),
ScalarValue::UInt32(Some(i)) => get_prim_value!(UInt32, *i),
ScalarValue::UInt64(Some(i)) => get_prim_value!(UInt64, *i),
ScalarValue::Utf8(Some(i)) => get_str_value!(String, i.as_str()),
ScalarValue::LargeUtf8(Some(i)) => get_str_value!(LargeString, i.as_str()),
ScalarValue::Binary(Some(i)) => get_binary_value!(Binary, i.as_slice()),
ScalarValue::LargeBinary(Some(i)) => get_binary_value!(LargeBinary, i.as_slice()),
t => {
Err(DataFusionError::Execution(
format!("get map value (Map) not support {} as key type", t)))
},
}
}
(dt, key) => {
Err(DataFusionError::Execution(format!("get map value (Map) is only possible on map with no-null key. Tried {:?} with {:?} key", dt, key)))
},
}
}
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.arg.clone()]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(Self::new(children[0].clone(), self.key.clone())))
}
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}
}
/// Resolves the value `Field` of a map type.
///
/// Arrow maps are `Map(Struct(key_field, value_field), _)`; position 1 of the
/// entries struct is the value field. Non-map inputs yield a plan error, and a
/// map whose entries are not a struct yields a not-implemented error.
fn get_data_type_field(data_type: &DataType) -> Result<Field> {
    if let DataType::Map(entries, _) = data_type {
        match entries.data_type() {
            // values field sits at index 1 of the entries struct
            DataType::Struct(fields) => Ok(fields[1].as_ref().clone()),
            _ => Err(DataFusionError::NotImplemented(
                "Map field only support Struct".to_string(),
            )),
        }
    } else {
        Err(DataFusionError::Plan(
            "The expression to get map value is only valid for `Map` types".to_string(),
        ))
    }
}
#[cfg(test)]
mod test {
    use super::GetMapValueExpr;
    use arrow::array::*;
    use arrow::buffer::Buffer;
    use arrow::datatypes::{DataType, Field, ToByteSlice};
    use arrow::record_batch::RecordBatch;
    use datafusion::assert_batches_eq;
    use datafusion::common::ScalarValue;
    use datafusion::physical_plan::expressions::Column;
    use datafusion::physical_plan::PhysicalExpr;
    use std::sync::Arc;

    // Lookup of an Int32 key in a map built from raw ArrayData, with some null
    // values. Logical rows: {0:null,1:10,2:20}, {3:null,4:40,5:null}, {6:60,7:70};
    // looking up key 7 should yield [null, null, 70].
    #[test]
    fn test_map_1() -> Result<(), Box<dyn std::error::Error>> {
        // Flattened keys 0..8, no nulls.
        let key_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice()))
            .build()
            .unwrap();
        // Flattened values; validity bitmap 0b11010110 marks entries 0, 3 and 5
        // as null (their 0 payloads are placeholders).
        let value_data = ArrayData::builder(DataType::UInt32)
            .len(8)
            .add_buffer(Buffer::from(
                &[0u32, 10, 20, 0, 40, 0, 60, 70].to_byte_slice(),
            ))
            .null_bit_buffer(Some(Buffer::from(&[0b11010110])))
            .build()
            .unwrap();
        // Row boundaries into the flattened entries: rows of 3, 3 and 2 entries.
        let entry_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice());
        let keys_field = Arc::new(Field::new("keys", DataType::Int32, false));
        let values_field = Arc::new(Field::new("values", DataType::UInt32, true));
        let entry_struct = StructArray::from(vec![
            (keys_field.clone(), make_array(key_data)),
            (values_field.clone(), make_array(value_data.clone())),
        ]);
        // Construct a map array from the above two
        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                true,
            )),
            false, // keys are not sorted
        );
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array: ArrayRef = Arc::new(MapArray::from(map_data));
        let input_batch = RecordBatch::try_from_iter_with_nullable(vec![("col", map_array, true)])?;
        // Evaluate `col[7]` over the 3-row batch.
        let get_indexed = Arc::new(GetMapValueExpr::new(
            Arc::new(Column::new("col", 0)),
            ScalarValue::from(7_i32),
        ));
        let output_array = get_indexed.evaluate(&input_batch)?.into_array(0);
        let output_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("col", output_array, true)])?;
        // Key 7 only exists in the last row.
        let expected =
            vec!["+-----+", "| col |", "+-----+", "| |", "| |", "| 70 |", "+-----+"];
        assert_batches_eq!(expected, &[output_batch]);
        Ok(())
    }

    // Lookup of a Utf8 key in a map built via the MapArray string helper.
    // Logical rows: {a:0,b:10,c:20}, {d:30,e:40,f:50}, {g:60,h:70};
    // looking up "e" should yield [null, 40, null].
    #[test]
    fn test_map_2() -> Result<(), Box<dyn std::error::Error>> {
        let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
        let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]);
        // Construct a buffer for value offsets, for the nested array:
        // [[a, b, c], [d, e, f], [g, h]]
        let entry_offsets = [0, 3, 6, 8];
        let map_array: ArrayRef = Arc::new(
            MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets)
                .unwrap(),
        );
        let input_batch = RecordBatch::try_from_iter_with_nullable(vec![("col", map_array, true)])?;
        // Evaluate `col["e"]` over the 3-row batch.
        let get_indexed = Arc::new(GetMapValueExpr::new(
            Arc::new(Column::new("col", 0)),
            ScalarValue::from("e"),
        ));
        let output_array = get_indexed.evaluate(&input_batch)?.into_array(0);
        let output_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("col", output_array, true)])?;
        // "e" only exists in the middle row.
        let expected =
            vec!["+-----+", "| col |", "+-----+", "| |", "| 40 |", "| |", "+-----+"];
        assert_batches_eq!(expected, &[output_batch]);
        Ok(())
    }
}