blob: 78a1e0f81763955e970f4ff328cbaf7e893655f3 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
any::Any,
fmt::{Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};
use arrow::{datatypes::*, record_batch::RecordBatch};
use datafusion::{
common::Result, logical_expr::ColumnarValue, physical_expr::PhysicalExpr, scalar::ScalarValue,
};
use crate::down_cast_any_ref;
/// Physical cast expression intended to be compatible with Spark.
///
/// Evaluating it evaluates the wrapped child expression and converts the
/// result to [`TryCastExpr::cast_type`]; the conversion itself is delegated
/// to `datafusion_ext_commons::cast::cast`.
#[derive(Debug, Hash)]
pub struct TryCastExpr {
    /// Child expression whose evaluated value is cast.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Target data type; also what `data_type()` reports for this expression.
    pub cast_type: DataType,
}
impl PartialEq<dyn Any> for TryCastExpr {
    /// Compares this expression against a type-erased value.
    ///
    /// Returns `true` only when `other` resolves to a `TryCastExpr` whose
    /// child expression and target type both equal this one's; any other
    /// concrete type compares unequal.
    fn eq(&self, other: &dyn Any) -> bool {
        // NOTE(review): `down_cast_any_ref` is a crate-local helper; presumably
        // it unwraps smart-pointer-wrapped expressions before the downcast —
        // confirm against its definition.
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(that) => self.expr.eq(&that.expr) && self.cast_type == that.cast_type,
            None => false,
        }
    }
}
impl TryCastExpr {
pub fn new(expr: Arc<dyn PhysicalExpr>, cast_type: DataType) -> Self {
Self { expr, cast_type }
}
}
impl Display for TryCastExpr {
    /// Renders the expression as `cast(<child> AS <target type>)`, using the
    /// child's `Display` output and the target type's `Debug` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { expr, cast_type } = self;
        write!(f, "cast({} AS {:?})", expr, cast_type)
    }
}
impl PhysicalExpr for TryCastExpr {
    /// Exposes `self` as `Any` so callers can downcast to `TryCastExpr`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The output type is always the configured target type; the input
    /// schema is not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        let target = self.cast_type.clone();
        Ok(target)
    }

    /// Always reports the output as nullable, regardless of the input schema.
    // NOTE(review): presumably because a failed conversion yields null (the
    // tests in this file expect that for unparseable strings) — confirm.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates the child expression on `batch` and casts its result to
    /// `cast_type`.
    ///
    /// An array result is cast as a whole. A scalar result is turned into an
    /// array, cast, and element 0 of the cast array is converted back into a
    /// scalar, so the output keeps the same `ColumnarValue` variant as the
    /// child's output.
    ///
    /// # Errors
    ///
    /// Propagates any error from the child's evaluation, the scalar/array
    /// conversions, or the cast itself.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let input = self.expr.evaluate(batch)?;
        let output = match input {
            ColumnarValue::Array(values) => {
                let casted = datafusion_ext_commons::cast::cast(&values, &self.cast_type)?;
                ColumnarValue::Array(casted)
            }
            ColumnarValue::Scalar(value) => {
                // Reuse the array cast for scalars by going through an array.
                let single = value.to_array()?;
                let casted = datafusion_ext_commons::cast::cast(&single, &self.cast_type)?;
                ColumnarValue::Scalar(ScalarValue::try_from_array(&casted, 0)?)
            }
        };
        Ok(output)
    }

    /// Returns the single child expression.
    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![Arc::clone(&self.expr)]
    }

    /// Rebuilds this expression around a replacement child, keeping the
    /// target type.
    ///
    /// Only the first entry of `children` is used.
    ///
    /// # Panics
    ///
    /// Panics if `children` is empty, because the first entry is indexed
    /// directly.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let replacement = Self::new(Arc::clone(&children[0]), self.cast_type.clone());
        Ok(Arc::new(replacement))
    }

    /// Feeds this expression's derived `Hash` implementation into a
    /// type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // `Hash::hash` needs a sized hasher; `&mut dyn Hasher` itself
        // implements `Hasher`, so hand over a mutable reference to it.
        let mut hasher = state;
        Hash::hash(self, &mut hasher);
    }
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::{
        array::{ArrayRef, Float32Array, Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };
    use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr};

    use crate::cast::TryCastExpr;

    /// Wraps `column` in a batch with one nullable field named "col".
    fn single_column_batch(data_type: DataType, column: ArrayRef) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("col", data_type, true)]));
        RecordBatch::try_new(schema, vec![column]).expect("Error creating RecordBatch")
    }

    /// The string column shared by the Utf8 tests.
    fn string_column() -> ArrayRef {
        Arc::new(StringArray::from(vec![
            Some("123"),
            Some("321.9"),
            Some("-098"),
            Some("sda"),
            None, // null
        ]))
    }

    /// Evaluates `expr` on `batch` and materializes the result as an array
    /// with one element per batch row.
    fn evaluate_to_array(expr: &TryCastExpr, batch: &RecordBatch) -> ArrayRef {
        expr.evaluate(batch)
            .expect("Error evaluating expr")
            .into_array(batch.num_rows())
            .unwrap()
    }

    /// Array input: Float32 column cast to Int32.
    #[test]
    fn test_ok_1() {
        let floats: ArrayRef = Arc::new(Float32Array::from(vec![
            Some(7.6),
            Some(9.0),
            Some(3.4),
            Some(-0.0),
            Some(-99.9),
            None,
        ]));
        let batch = single_column_batch(DataType::Float32, floats);

        let column = phys_expr::col("col", &batch.schema()).unwrap();
        let expr = TryCastExpr::new(column, DataType::Int32);
        let actual = evaluate_to_array(&expr, &batch);

        let expected: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(7),
            Some(9),
            Some(3),
            Some(0),
            Some(-99),
            None,
        ]));
        assert_eq!(&actual, &expected);
    }

    /// Array input: Utf8 column cast to Float32; the non-numeric string is
    /// expected to come back as null.
    #[test]
    fn test_ok_2() {
        let batch = single_column_batch(DataType::Utf8, string_column());

        let column = phys_expr::col("col", &batch.schema()).unwrap();
        let expr = TryCastExpr::new(column, DataType::Float32);
        let actual = evaluate_to_array(&expr, &batch);

        let expected: ArrayRef = Arc::new(Float32Array::from(vec![
            Some(123.0),
            Some(321.9),
            Some(-98.0),
            None,
            None,
        ]));
        assert_eq!(&actual, &expected);
    }

    /// Scalar input: a Utf8 literal cast to Float32, expanded to one value
    /// per row of the batch.
    #[test]
    fn test_ok_3() {
        // The batch contents are irrelevant here; only its row count matters.
        let batch = single_column_batch(DataType::Utf8, string_column());

        let expr = TryCastExpr::new(phys_expr::lit("123.4"), DataType::Float32);
        let actual = evaluate_to_array(&expr, &batch);

        let expected: ArrayRef = Arc::new(Float32Array::from(vec![Some(123.4); 5]));
        assert_eq!(&actual, &expected);
    }
}