blob: a181f98b6eb6c3ae706977a5475577b4ea719e76 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Serde code to convert Arrow schemas and DataFusion logical plans to Ballista protocol
//! buffer format, allowing DataFusion logical plans to be serialized and transmitted between
//! processes.
use std::{
boxed,
convert::{TryFrom, TryInto},
};
use crate::datasource::DFTableAdapter;
use crate::serde::{protobuf, BallistaError};
use arrow::datatypes::{DataType, Schema};
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields};
use protobuf::{
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, Field,
PrimitiveScalarType, ScalarListValue, ScalarType,
};
use super::super::proto_error;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
impl protobuf::IntervalUnit {
    /// Converts an Arrow interval unit into its protobuf representation.
    pub fn from_arrow_interval_unit(
        interval_unit: &arrow::datatypes::IntervalUnit,
    ) -> Self {
        match interval_unit {
            arrow::datatypes::IntervalUnit::YearMonth => {
                protobuf::IntervalUnit::YearMonth
            }
            arrow::datatypes::IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
        }
    }

    /// Decodes an `i32` protobuf enum tag back into an Arrow interval unit.
    ///
    /// Returns an error if the tag does not correspond to a known
    /// `protobuf::IntervalUnit` variant.
    pub fn from_i32_to_arrow(
        interval_unit_i32: i32,
    ) -> Result<arrow::datatypes::IntervalUnit, BallistaError> {
        let pb_interval_unit = protobuf::IntervalUnit::from_i32(interval_unit_i32);
        use arrow::datatypes::IntervalUnit;
        match pb_interval_unit {
            Some(interval_unit) => Ok(match interval_unit {
                protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth,
                protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime,
            }),
            // Fixed: this message previously said "DateUnit" (copy-paste from
            // the removed DateUnit conversion) which made failures misleading.
            None => Err(proto_error(
                "Error converting i32 to IntervalUnit: Passed invalid variant",
            )),
        }
    }
}
/* Arrow changed dates to no longer have date unit
impl protobuf::DateUnit {
pub fn from_arrow_date_unit(val: &arrow::datatypes::DateUnit) -> Self {
match val {
arrow::datatypes::DateUnit::Day => protobuf::DateUnit::Day,
arrow::datatypes::DateUnit::Millisecond => protobuf::DateUnit::DateMillisecond,
}
}
pub fn from_i32_to_arrow(date_unit_i32: i32) -> Result<arrow::datatypes::DateUnit, BallistaError> {
let pb_date_unit = protobuf::DateUnit::from_i32(date_unit_i32);
use arrow::datatypes::DateUnit;
match pb_date_unit {
Some(date_unit) => Ok(match date_unit {
protobuf::DateUnit::Day => DateUnit::Day,
protobuf::DateUnit::DateMillisecond => DateUnit::Millisecond,
}),
None => Err(proto_error("Error converting i32 to DateUnit: Passed invalid variant")),
}
}
}*/
impl protobuf::TimeUnit {
pub fn from_arrow_time_unit(val: &arrow::datatypes::TimeUnit) -> Self {
match val {
arrow::datatypes::TimeUnit::Second => protobuf::TimeUnit::Second,
arrow::datatypes::TimeUnit::Millisecond => {
protobuf::TimeUnit::TimeMillisecond
}
arrow::datatypes::TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond,
arrow::datatypes::TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond,
}
}
pub fn from_i32_to_arrow(
time_unit_i32: i32,
) -> Result<arrow::datatypes::TimeUnit, BallistaError> {
let pb_time_unit = protobuf::TimeUnit::from_i32(time_unit_i32);
use arrow::datatypes::TimeUnit;
match pb_time_unit {
Some(time_unit) => Ok(match time_unit {
protobuf::TimeUnit::Second => TimeUnit::Second,
protobuf::TimeUnit::TimeMillisecond => TimeUnit::Millisecond,
protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond,
protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
}),
None => Err(proto_error(
"Error converting i32 to TimeUnit: Passed invalid variant",
)),
}
}
}
impl From<&arrow::datatypes::Field> for protobuf::Field {
    /// Builds a protobuf field from an Arrow field. `children` is always
    /// empty here; nested type information travels inside `arrow_type`.
    fn from(field: &arrow::datatypes::Field) -> Self {
        let arrow_type: protobuf::ArrowType = field.data_type().into();
        Self {
            name: field.name().clone(),
            arrow_type: Some(Box::new(arrow_type)),
            nullable: field.is_nullable(),
            children: Vec::new(),
        }
    }
}
impl From<&arrow::datatypes::DataType> for protobuf::ArrowType {
    /// Wraps the oneof type encoding of `val` in an `ArrowType` message.
    fn from(val: &arrow::datatypes::DataType) -> protobuf::ArrowType {
        let arrow_type_enum: protobuf::arrow_type::ArrowTypeEnum = val.into();
        protobuf::ArrowType {
            arrow_type_enum: Some(arrow_type_enum),
        }
    }
}
impl TryInto<arrow::datatypes::DataType> for &protobuf::ArrowType {
    type Error = BallistaError;
    /// Converts a protobuf `ArrowType` message back into an Arrow `DataType`.
    ///
    /// Fails if the required oneof field is unset, if an embedded enum tag
    /// (time/interval unit) is invalid, or if any nested field fails to
    /// convert.
    fn try_into(self) -> Result<arrow::datatypes::DataType, Self::Error> {
        let pb_arrow_type = self.arrow_type_enum.as_ref().ok_or_else(|| {
            proto_error(
                "Protobuf deserialization error: ArrowType missing required field 'data_type'",
            )
        })?;
        use arrow::datatypes::DataType;
        Ok(match pb_arrow_type {
            // Primitive types carry an empty message; only the variant matters.
            protobuf::arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
            protobuf::arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
            protobuf::arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
            protobuf::arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
            protobuf::arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
            protobuf::arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
            protobuf::arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
            protobuf::arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
            protobuf::arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
            protobuf::arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
            protobuf::arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
            protobuf::arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
            protobuf::arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
            protobuf::arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
            protobuf::arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
            protobuf::arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
            protobuf::arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
                DataType::FixedSizeBinary(*size)
            }
            protobuf::arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
            protobuf::arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
            protobuf::arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
            protobuf::arrow_type::ArrowTypeEnum::Duration(time_unit_i32) => {
                DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?)
            }
            protobuf::arrow_type::ArrowTypeEnum::Timestamp(timestamp) => {
                DataType::Timestamp(
                    protobuf::TimeUnit::from_i32_to_arrow(timestamp.time_unit)?,
                    // Empty string is the wire encoding for "no timezone".
                    match timestamp.timezone.is_empty() {
                        true => None,
                        false => Some(timestamp.timezone.to_owned()),
                    },
                )
            }
            protobuf::arrow_type::ArrowTypeEnum::Time32(time_unit_i32) => {
                DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?)
            }
            protobuf::arrow_type::ArrowTypeEnum::Time64(time_unit_i32) => {
                DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit_i32)?)
            }
            protobuf::arrow_type::ArrowTypeEnum::Interval(interval_unit_i32) => {
                DataType::Interval(protobuf::IntervalUnit::from_i32_to_arrow(
                    *interval_unit_i32,
                )?)
            }
            protobuf::arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal {
                whole,
                fractional,
            }) => DataType::Decimal(*whole as usize, *fractional as usize),
            protobuf::arrow_type::ArrowTypeEnum::List(boxed_list) => {
                let field_ref = boxed_list
                    .field_type
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: List message was missing required field 'field_type'"))?
                    .as_ref();
                arrow::datatypes::DataType::List(Box::new(field_ref.try_into()?))
            }
            protobuf::arrow_type::ArrowTypeEnum::LargeList(boxed_list) => {
                let field_ref = boxed_list
                    .field_type
                    .as_ref()
                    // Fixed: this message previously said "List", making
                    // LargeList decode failures indistinguishable from List.
                    .ok_or_else(|| proto_error("Protobuf deserialization error: LargeList message was missing required field 'field_type'"))?
                    .as_ref();
                arrow::datatypes::DataType::LargeList(Box::new(field_ref.try_into()?))
            }
            protobuf::arrow_type::ArrowTypeEnum::FixedSizeList(boxed_list) => {
                let fsl_ref = boxed_list.as_ref();
                let pb_fieldtype = fsl_ref
                    .field_type
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: FixedSizeList message was missing required field 'field_type'"))?;
                arrow::datatypes::DataType::FixedSizeList(
                    Box::new(pb_fieldtype.as_ref().try_into()?),
                    fsl_ref.list_size,
                )
            }
            protobuf::arrow_type::ArrowTypeEnum::Struct(struct_type) => {
                let fields = struct_type
                    .sub_field_types
                    .iter()
                    .map(|field| field.try_into())
                    .collect::<Result<Vec<_>, _>>()?;
                arrow::datatypes::DataType::Struct(fields)
            }
            protobuf::arrow_type::ArrowTypeEnum::Union(union) => {
                let union_types = union
                    .union_types
                    .iter()
                    .map(|field| field.try_into())
                    .collect::<Result<Vec<_>, _>>()?;
                arrow::datatypes::DataType::Union(union_types)
            }
            protobuf::arrow_type::ArrowTypeEnum::Dictionary(boxed_dict) => {
                let dict_ref = boxed_dict.as_ref();
                let pb_key = dict_ref
                    .key
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message was missing required field 'key'"))?;
                let pb_value = dict_ref
                    .value
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message was missing required field 'value'"))?;
                arrow::datatypes::DataType::Dictionary(
                    Box::new(pb_key.as_ref().try_into()?),
                    Box::new(pb_value.as_ref().try_into()?),
                )
            }
        })
    }
}
impl TryInto<arrow::datatypes::DataType> for &Box<protobuf::List> {
    type Error = BallistaError;
    /// Converts a boxed protobuf `List` message into an Arrow `List` data
    /// type, erroring when the element field is absent.
    fn try_into(self) -> Result<arrow::datatypes::DataType, Self::Error> {
        let pb_field = self.as_ref().field_type.as_ref().ok_or_else(|| {
            proto_error("List message missing required field 'field_type'")
        })?;
        let arrow_field: arrow::datatypes::Field = pb_field.as_ref().try_into()?;
        Ok(arrow::datatypes::DataType::List(Box::new(arrow_field)))
    }
}
impl From<&arrow::datatypes::DataType> for protobuf::arrow_type::ArrowTypeEnum {
    /// Encodes an Arrow `DataType` as the protobuf oneof variant.
    ///
    /// This is the inverse of the `TryInto<DataType> for &protobuf::ArrowType`
    /// conversion; it is infallible because every Arrow type handled here has
    /// a protobuf representation.
    fn from(val: &arrow::datatypes::DataType) -> protobuf::arrow_type::ArrowTypeEnum {
        use protobuf::arrow_type::ArrowTypeEnum;
        use protobuf::ArrowType;
        use protobuf::EmptyMessage;
        match val {
            // Primitive types carry no payload, only the variant tag.
            DataType::Null => ArrowTypeEnum::None(EmptyMessage {}),
            DataType::Boolean => ArrowTypeEnum::Bool(EmptyMessage {}),
            DataType::Int8 => ArrowTypeEnum::Int8(EmptyMessage {}),
            DataType::Int16 => ArrowTypeEnum::Int16(EmptyMessage {}),
            DataType::Int32 => ArrowTypeEnum::Int32(EmptyMessage {}),
            DataType::Int64 => ArrowTypeEnum::Int64(EmptyMessage {}),
            DataType::UInt8 => ArrowTypeEnum::Uint8(EmptyMessage {}),
            DataType::UInt16 => ArrowTypeEnum::Uint16(EmptyMessage {}),
            DataType::UInt32 => ArrowTypeEnum::Uint32(EmptyMessage {}),
            DataType::UInt64 => ArrowTypeEnum::Uint64(EmptyMessage {}),
            DataType::Float16 => ArrowTypeEnum::Float16(EmptyMessage {}),
            DataType::Float32 => ArrowTypeEnum::Float32(EmptyMessage {}),
            DataType::Float64 => ArrowTypeEnum::Float64(EmptyMessage {}),
            DataType::Timestamp(time_unit, timezone) => {
                ArrowTypeEnum::Timestamp(protobuf::Timestamp {
                    time_unit: protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
                    // No timezone is encoded as the empty string on the wire.
                    timezone: timezone.to_owned().unwrap_or_else(String::new),
                })
            }
            DataType::Date32 => ArrowTypeEnum::Date32(EmptyMessage {}),
            DataType::Date64 => ArrowTypeEnum::Date64(EmptyMessage {}),
            // Time/duration/interval units are stored as raw enum tags (i32).
            DataType::Time32(time_unit) => ArrowTypeEnum::Time32(
                protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
            ),
            DataType::Time64(time_unit) => ArrowTypeEnum::Time64(
                protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
            ),
            DataType::Duration(time_unit) => ArrowTypeEnum::Duration(
                protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
            ),
            DataType::Interval(interval_unit) => ArrowTypeEnum::Interval(
                protobuf::IntervalUnit::from_arrow_interval_unit(interval_unit) as i32,
            ),
            DataType::Binary => ArrowTypeEnum::Binary(EmptyMessage {}),
            DataType::FixedSizeBinary(size) => ArrowTypeEnum::FixedSizeBinary(*size),
            DataType::LargeBinary => ArrowTypeEnum::LargeBinary(EmptyMessage {}),
            DataType::Utf8 => ArrowTypeEnum::Utf8(EmptyMessage {}),
            DataType::LargeUtf8 => ArrowTypeEnum::LargeUtf8(EmptyMessage {}),
            // Nested types recurse via the `From<&Field>` conversion.
            DataType::List(item_type) => ArrowTypeEnum::List(Box::new(protobuf::List {
                field_type: Some(Box::new(item_type.as_ref().into())),
            })),
            DataType::FixedSizeList(item_type, size) => {
                ArrowTypeEnum::FixedSizeList(Box::new(protobuf::FixedSizeList {
                    field_type: Some(Box::new(item_type.as_ref().into())),
                    list_size: *size,
                }))
            }
            DataType::LargeList(item_type) => {
                ArrowTypeEnum::LargeList(Box::new(protobuf::List {
                    field_type: Some(Box::new(item_type.as_ref().into())),
                }))
            }
            DataType::Struct(struct_fields) => ArrowTypeEnum::Struct(protobuf::Struct {
                sub_field_types: struct_fields
                    .iter()
                    .map(|field| field.into())
                    .collect::<Vec<_>>(),
            }),
            DataType::Union(union_types) => ArrowTypeEnum::Union(protobuf::Union {
                union_types: union_types
                    .iter()
                    .map(|field| field.into())
                    .collect::<Vec<_>>(),
            }),
            DataType::Dictionary(key_type, value_type) => {
                ArrowTypeEnum::Dictionary(Box::new(protobuf::Dictionary {
                    key: Some(Box::new(key_type.as_ref().into())),
                    value: Some(Box::new(value_type.as_ref().into())),
                }))
            }
            DataType::Decimal(whole, fractional) => {
                ArrowTypeEnum::Decimal(protobuf::Decimal {
                    whole: *whole as u64,
                    fractional: *fractional as u64,
                })
            }
        }
    }
}
/// Returns true when `datatype` is representable as a DataFusion scalar.
///
/// For `List` types only the outer constructor is accepted; element types
/// are NOT validated here — callers must check nested types themselves.
fn is_valid_scalar_type_no_list_check(datatype: &arrow::datatypes::DataType) -> bool {
    use arrow::datatypes::TimeUnit;
    matches!(
        datatype,
        DataType::Boolean
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64
            | DataType::LargeUtf8
            | DataType::Utf8
            | DataType::Date32
            | DataType::List(_)
            | DataType::Time64(TimeUnit::Microsecond)
            | DataType::Time64(TimeUnit::Nanosecond)
    )
}
impl TryFrom<&arrow::datatypes::DataType> for protobuf::scalar_type::Datatype {
    type Error = BallistaError;
    /// Converts an Arrow `DataType` into the protobuf scalar-type encoding.
    ///
    /// Primitive types map to `Datatype::Scalar` enum tags. `List` types are
    /// flattened into a `ScalarListType`: the field name at each nesting level
    /// plus the single leaf scalar type — DataFusion scalars only support
    /// recursive lists whose leaf is a primitive scalar. All other compound
    /// types (struct, union, dictionary, …) are rejected with an error.
    fn try_from(val: &arrow::datatypes::DataType) -> Result<Self, Self::Error> {
        use protobuf::scalar_type;
        use protobuf::Field;
        use protobuf::{List, PrimitiveScalarType};
        let scalar_value = match val {
            DataType::Boolean => scalar_type::Datatype::Scalar(PrimitiveScalarType::Bool as i32),
            DataType::Int8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int8 as i32),
            DataType::Int16 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int16 as i32),
            DataType::Int32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int32 as i32),
            DataType::Int64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int64 as i32),
            DataType::UInt8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint8 as i32),
            DataType::UInt16 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint16 as i32),
            DataType::UInt32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint32 as i32),
            DataType::UInt64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint64 as i32),
            DataType::Float32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Float32 as i32),
            DataType::Float64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Float64 as i32),
            DataType::Date32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Date32 as i32),
            // Only microsecond/nanosecond Time64 values are valid scalars.
            DataType::Time64(time_unit) => match time_unit {
                arrow::datatypes::TimeUnit::Microsecond => scalar_type::Datatype::Scalar(PrimitiveScalarType::TimeMicrosecond as i32),
                arrow::datatypes::TimeUnit::Nanosecond => scalar_type::Datatype::Scalar(PrimitiveScalarType::TimeNanosecond as i32),
                _ => {
                    return Err(proto_error(format!(
                        "Found invalid time unit for scalar value, only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units: {:?}",
                        time_unit
                    )))
                }
            },
            DataType::Utf8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Utf8 as i32),
            DataType::LargeUtf8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::LargeUtf8 as i32),
            DataType::List(field_type) => {
                // Walk down nested lists, collecting each level's field name.
                let mut field_names: Vec<String> = Vec::new();
                let mut curr_field: &arrow::datatypes::Field = field_type.as_ref();
                field_names.push(curr_field.name().to_owned());
                //For each nested field check nested datatype, since datafusion scalars only support recursive lists with a leaf scalar type
                // any other compound types are errors.
                while let DataType::List(nested_field_type) = curr_field.data_type() {
                    curr_field = nested_field_type.as_ref();
                    field_names.push(curr_field.name().to_owned());
                    if !is_valid_scalar_type_no_list_check(curr_field.data_type()) {
                        return Err(proto_error(format!("{:?} is an invalid scalar type", curr_field)));
                    }
                }
                // The loop stops at the first non-list type: the leaf.
                let deepest_datatype = curr_field.data_type();
                if !is_valid_scalar_type_no_list_check(deepest_datatype) {
                    return Err(proto_error(format!("The list nested type {:?} is an invalid scalar type", curr_field)));
                }
                let pb_deepest_type: PrimitiveScalarType = match deepest_datatype {
                    DataType::Boolean => PrimitiveScalarType::Bool,
                    DataType::Int8 => PrimitiveScalarType::Int8,
                    DataType::Int16 => PrimitiveScalarType::Int16,
                    DataType::Int32 => PrimitiveScalarType::Int32,
                    DataType::Int64 => PrimitiveScalarType::Int64,
                    DataType::UInt8 => PrimitiveScalarType::Uint8,
                    DataType::UInt16 => PrimitiveScalarType::Uint16,
                    DataType::UInt32 => PrimitiveScalarType::Uint32,
                    DataType::UInt64 => PrimitiveScalarType::Uint64,
                    DataType::Float32 => PrimitiveScalarType::Float32,
                    DataType::Float64 => PrimitiveScalarType::Float64,
                    DataType::Date32 => PrimitiveScalarType::Date32,
                    DataType::Time64(time_unit) => match time_unit {
                        arrow::datatypes::TimeUnit::Microsecond => PrimitiveScalarType::TimeMicrosecond,
                        arrow::datatypes::TimeUnit::Nanosecond => PrimitiveScalarType::TimeNanosecond,
                        _ => {
                            return Err(proto_error(format!(
                                "Found invalid time unit for scalar value, only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units: {:?}",
                                time_unit
                            )))
                        }
                    },
                    DataType::Utf8 => PrimitiveScalarType::Utf8,
                    DataType::LargeUtf8 => PrimitiveScalarType::LargeUtf8,
                    _ => {
                        return Err(proto_error(format!(
                            "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
                            val
                        )))
                    }
                };
                protobuf::scalar_type::Datatype::List(protobuf::ScalarListType {
                    field_names,
                    deepest_type: pb_deepest_type as i32,
                })
            }
            // Every remaining Arrow type has no DataFusion scalar encoding.
            DataType::Null
            | DataType::Float16
            | DataType::Timestamp(_, _)
            | DataType::Date64
            | DataType::Time32(_)
            | DataType::Duration(_)
            | DataType::Interval(_)
            | DataType::Binary
            | DataType::FixedSizeBinary(_)
            | DataType::LargeBinary
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_)
            | DataType::Struct(_)
            | DataType::Union(_)
            | DataType::Dictionary(_, _)
            | DataType::Decimal(_, _) => {
                return Err(proto_error(format!(
                    "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
                    val
                )))
            }
        };
        Ok(scalar_value)
    }
}
impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
    type Error = BallistaError;
    /// Serializes a DataFusion `ScalarValue` into its protobuf representation.
    ///
    /// `None` payloads are encoded as typed null values so the type survives
    /// the round trip. List scalars are type-checked element by element and
    /// an error is returned for inconsistent element types or unsupported
    /// scalar variants.
    ///
    /// Changes from the previous revision: removed two leftover debug
    /// `println!` statements, and replaced the `todo!()` on a non-list
    /// datatype inside `ScalarValue::List` with a proper error return so
    /// serialization of a malformed scalar no longer panics.
    fn try_from(
        val: &datafusion::scalar::ScalarValue,
    ) -> Result<protobuf::ScalarValue, Self::Error> {
        use datafusion::scalar;
        use protobuf::scalar_value::Value;
        use protobuf::PrimitiveScalarType;
        let scalar_val = match val {
            scalar::ScalarValue::Boolean(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Bool, |s| Value::BoolValue(*s))
            }
            scalar::ScalarValue::Float32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Float32, |s| {
                    Value::Float32Value(*s)
                })
            }
            scalar::ScalarValue::Float64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Float64, |s| {
                    Value::Float64Value(*s)
                })
            }
            // 8/16-bit integers are widened: protobuf has no narrow int fields.
            scalar::ScalarValue::Int8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int8, |s| {
                    Value::Int8Value(*s as i32)
                })
            }
            scalar::ScalarValue::Int16(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int16, |s| {
                    Value::Int16Value(*s as i32)
                })
            }
            scalar::ScalarValue::Int32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int32, |s| Value::Int32Value(*s))
            }
            scalar::ScalarValue::Int64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int64, |s| Value::Int64Value(*s))
            }
            scalar::ScalarValue::UInt8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint8, |s| {
                    Value::Uint8Value(*s as u32)
                })
            }
            scalar::ScalarValue::UInt16(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint16, |s| {
                    Value::Uint16Value(*s as u32)
                })
            }
            scalar::ScalarValue::UInt32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint32, |s| Value::Uint32Value(*s))
            }
            scalar::ScalarValue::UInt64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint64, |s| Value::Uint64Value(*s))
            }
            scalar::ScalarValue::Utf8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Utf8, |s| {
                    Value::Utf8Value(s.to_owned())
                })
            }
            scalar::ScalarValue::LargeUtf8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::LargeUtf8, |s| {
                    Value::LargeUtf8Value(s.to_owned())
                })
            }
            scalar::ScalarValue::List(value, datatype) => {
                match value {
                    Some(values) => {
                        if values.is_empty() {
                            // Empty list: only the element type needs encoding.
                            protobuf::ScalarValue {
                                value: Some(protobuf::scalar_value::Value::ListValue(
                                    protobuf::ScalarListValue {
                                        datatype: Some(datatype.try_into()?),
                                        values: Vec::new(),
                                    },
                                )),
                            }
                        } else {
                            let scalar_type = match datatype {
                                DataType::List(field) => field.as_ref().data_type(),
                                // A ScalarValue::List must carry a List datatype;
                                // anything else is a malformed scalar (this used
                                // to be a `todo!()` panic).
                                _ => {
                                    return Err(proto_error(format!(
                                        "Protobuf serialization error: ScalarValue::List datatype {:?} is not a List",
                                        datatype
                                    )))
                                }
                            };
                            // Verify every element matches the declared element
                            // type before serializing it.
                            let type_checked_values: Vec<protobuf::ScalarValue> = values
                                .iter()
                                .map(|scalar| match (scalar, scalar_type) {
                                    (scalar::ScalarValue::List(_, arrow::datatypes::DataType::List(list_field)), arrow::datatypes::DataType::List(field)) => {
                                        let scalar_datatype = field.data_type();
                                        let list_datatype = list_field.data_type();
                                        if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) {
                                            return Err(proto_error(format!(
                                                "Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list",
                                                list_datatype, scalar_datatype
                                            )));
                                        }
                                        scalar.try_into()
                                    }
                                    (scalar::ScalarValue::Boolean(_), arrow::datatypes::DataType::Boolean) => scalar.try_into(),
                                    (scalar::ScalarValue::Float32(_), arrow::datatypes::DataType::Float32) => scalar.try_into(),
                                    (scalar::ScalarValue::Float64(_), arrow::datatypes::DataType::Float64) => scalar.try_into(),
                                    (scalar::ScalarValue::Int8(_), arrow::datatypes::DataType::Int8) => scalar.try_into(),
                                    (scalar::ScalarValue::Int16(_), arrow::datatypes::DataType::Int16) => scalar.try_into(),
                                    (scalar::ScalarValue::Int32(_), arrow::datatypes::DataType::Int32) => scalar.try_into(),
                                    (scalar::ScalarValue::Int64(_), arrow::datatypes::DataType::Int64) => scalar.try_into(),
                                    (scalar::ScalarValue::UInt8(_), arrow::datatypes::DataType::UInt8) => scalar.try_into(),
                                    (scalar::ScalarValue::UInt16(_), arrow::datatypes::DataType::UInt16) => scalar.try_into(),
                                    (scalar::ScalarValue::UInt32(_), arrow::datatypes::DataType::UInt32) => scalar.try_into(),
                                    (scalar::ScalarValue::UInt64(_), arrow::datatypes::DataType::UInt64) => scalar.try_into(),
                                    (scalar::ScalarValue::Utf8(_), arrow::datatypes::DataType::Utf8) => scalar.try_into(),
                                    (scalar::ScalarValue::LargeUtf8(_), arrow::datatypes::DataType::LargeUtf8) => scalar.try_into(),
                                    _ => Err(proto_error(format!(
                                        "Protobuf serialization error, {:?} was inconsistent with designated type {:?}",
                                        scalar, datatype
                                    ))),
                                })
                                .collect::<Result<Vec<_>, _>>()?;
                            protobuf::ScalarValue {
                                value: Some(protobuf::scalar_value::Value::ListValue(
                                    protobuf::ScalarListValue {
                                        datatype: Some(datatype.try_into()?),
                                        values: type_checked_values,
                                    },
                                )),
                            }
                        }
                    }
                    // Null list: encode only the type, as a NullListValue.
                    None => protobuf::ScalarValue {
                        value: Some(protobuf::scalar_value::Value::NullListValue(
                            datatype.try_into()?,
                        )),
                    },
                }
            }
            datafusion::scalar::ScalarValue::Date32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
            }
            datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
                    Value::TimeMicrosecondValue(*s)
                })
            }
            datafusion::scalar::ScalarValue::TimestampNanosecond(val) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
                    Value::TimeNanosecondValue(*s)
                })
            }
            _ => {
                return Err(proto_error(format!(
                    "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
                    val
                )))
            }
        };
        Ok(scalar_val)
    }
}
impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
    type Error = BallistaError;
    /// Recursively serializes a DataFusion `LogicalPlan` into a Ballista
    /// protobuf `LogicalPlanNode`.
    ///
    /// Each plan variant maps to one protobuf node type; child plans and
    /// expressions are converted via their own `TryInto` impls. Unsupported
    /// variants (`Extension`, `Union`) and unknown table providers produce
    /// errors (or `unimplemented!` panics for the former two).
    fn try_into(self) -> Result<protobuf::LogicalPlanNode, Self::Error> {
        use protobuf::logical_plan_node::LogicalPlanType;
        match self {
            LogicalPlan::TableScan {
                table_name,
                source,
                filters,
                projection,
                ..
            } => {
                let schema = source.schema();
                // unwrap the DFTableAdapter to get to the real TableProvider
                let source = if let Some(adapter) =
                    source.as_any().downcast_ref::<DFTableAdapter>()
                {
                    match &adapter.logical_plan {
                        LogicalPlan::TableScan { source, .. } => Ok(source.as_any()),
                        _ => Err(BallistaError::General(
                            "Invalid LogicalPlan::TableScan".to_owned(),
                        )),
                    }
                } else {
                    Ok(source.as_any())
                }?;
                // Projection indices are serialized as column names, resolved
                // against the provider schema.
                let projection = match projection {
                    None => None,
                    Some(columns) => {
                        let column_names = columns
                            .iter()
                            .map(|i| schema.field(*i).name().to_owned())
                            .collect();
                        Some(protobuf::ProjectionColumns {
                            columns: column_names,
                        })
                    }
                };
                let schema: protobuf::Schema = schema.as_ref().into();
                let filters: Vec<protobuf::LogicalExprNode> = filters
                    .iter()
                    .map(|filter| filter.try_into())
                    .collect::<Result<Vec<_>, _>>()?;
                // Only Parquet and CSV providers are serializable; other
                // TableProvider implementations are rejected below.
                if let Some(parquet) = source.downcast_ref::<ParquetTable>() {
                    Ok(protobuf::LogicalPlanNode {
                        logical_plan_type: Some(LogicalPlanType::ParquetScan(
                            protobuf::ParquetTableScanNode {
                                table_name: table_name.to_owned(),
                                path: parquet.path().to_owned(),
                                projection,
                                schema: Some(schema),
                                filters,
                            },
                        )),
                    })
                } else if let Some(csv) = source.downcast_ref::<CsvFile>() {
                    // The CSV delimiter byte must be valid UTF-8 to be carried
                    // as a protobuf string.
                    let delimiter = [csv.delimiter()];
                    let delimiter = std::str::from_utf8(&delimiter).map_err(|_| {
                        BallistaError::General("Invalid CSV delimiter".to_owned())
                    })?;
                    Ok(protobuf::LogicalPlanNode {
                        logical_plan_type: Some(LogicalPlanType::CsvScan(
                            protobuf::CsvTableScanNode {
                                table_name: table_name.to_owned(),
                                path: csv.path().to_owned(),
                                projection,
                                schema: Some(schema),
                                has_header: csv.has_header(),
                                delimiter: delimiter.to_string(),
                                file_extension: csv.file_extension().to_string(),
                                filters,
                            },
                        )),
                    })
                } else {
                    Err(BallistaError::General(format!(
                        "logical plan to_proto unsupported table provider {:?}",
                        source
                    )))
                }
            }
            LogicalPlan::Projection { expr, input, .. } => {
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Projection(Box::new(
                        protobuf::ProjectionNode {
                            input: Some(Box::new(input.as_ref().try_into()?)),
                            expr: expr
                                .iter()
                                .map(|expr| expr.try_into())
                                .collect::<Result<Vec<_>, BallistaError>>()?,
                        },
                    ))),
                })
            }
            LogicalPlan::Filter { predicate, input } => {
                let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Selection(Box::new(
                        protobuf::SelectionNode {
                            input: Some(Box::new(input)),
                            expr: Some(predicate.try_into()?),
                        },
                    ))),
                })
            }
            LogicalPlan::Aggregate {
                input,
                group_expr,
                aggr_expr,
                ..
            } => {
                let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Aggregate(Box::new(
                        protobuf::AggregateNode {
                            input: Some(Box::new(input)),
                            group_expr: group_expr
                                .iter()
                                .map(|expr| expr.try_into())
                                .collect::<Result<Vec<_>, BallistaError>>()?,
                            aggr_expr: aggr_expr
                                .iter()
                                .map(|expr| expr.try_into())
                                .collect::<Result<Vec<_>, BallistaError>>()?,
                        },
                    ))),
                })
            }
            LogicalPlan::Join {
                left,
                right,
                on,
                join_type,
                ..
            } => {
                let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
                let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?;
                let join_type = match join_type {
                    JoinType::Inner => protobuf::JoinType::Inner,
                    JoinType::Left => protobuf::JoinType::Left,
                    JoinType::Right => protobuf::JoinType::Right,
                };
                // `on` is a list of (left column, right column) equi-join pairs;
                // the two sides are serialized as parallel name vectors.
                let left_join_column = on.iter().map(|on| on.0.to_owned()).collect();
                let right_join_column = on.iter().map(|on| on.1.to_owned()).collect();
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Join(Box::new(
                        protobuf::JoinNode {
                            left: Some(Box::new(left)),
                            right: Some(Box::new(right)),
                            join_type: join_type.into(),
                            left_join_column,
                            right_join_column,
                        },
                    ))),
                })
            }
            LogicalPlan::Limit { input, n } => {
                let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Limit(Box::new(
                        protobuf::LimitNode {
                            input: Some(Box::new(input)),
                            limit: *n as u32,
                        },
                    ))),
                })
            }
            LogicalPlan::Sort { input, expr } => {
                let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
                let selection_expr: Vec<protobuf::LogicalExprNode> = expr
                    .iter()
                    .map(|expr| expr.try_into())
                    .collect::<Result<Vec<_>, BallistaError>>()?;
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Sort(Box::new(
                        protobuf::SortNode {
                            input: Some(Box::new(input)),
                            expr: selection_expr,
                        },
                    ))),
                })
            }
            LogicalPlan::Repartition {
                input,
                partitioning_scheme,
            } => {
                use datafusion::logical_plan::Partitioning;
                let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
                //Assumed common usize field was batch size
                //Used u64 to avoid any nastyness involving large values, most data clusters are probably uniformly 64 bits any ways
                use protobuf::repartition_node::PartitionMethod;
                let pb_partition_method = match partitioning_scheme {
                    Partitioning::Hash(exprs, partition_count) => {
                        PartitionMethod::Hash(protobuf::HashRepartition {
                            hash_expr: exprs
                                .iter()
                                .map(|expr| expr.try_into())
                                .collect::<Result<Vec<_>, BallistaError>>()?,
                            partition_count: *partition_count as u64,
                        })
                    }
                    Partitioning::RoundRobinBatch(batch_size) => {
                        PartitionMethod::RoundRobin(*batch_size as u64)
                    }
                };
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Repartition(Box::new(
                        protobuf::RepartitionNode {
                            input: Some(Box::new(input)),
                            partition_method: Some(pb_partition_method),
                        },
                    ))),
                })
            }
            LogicalPlan::EmptyRelation {
                produce_one_row, ..
            } => Ok(protobuf::LogicalPlanNode {
                logical_plan_type: Some(LogicalPlanType::EmptyRelation(
                    protobuf::EmptyRelationNode {
                        produce_one_row: *produce_one_row,
                    },
                )),
            }),
            LogicalPlan::CreateExternalTable {
                name,
                location,
                file_type,
                has_header,
                schema: df_schema,
            } => {
                use datafusion::sql::parser::FileType;
                // DFSchema is converted to an Arrow Schema first, then to proto.
                let schema: Schema = df_schema.as_ref().clone().into();
                let pb_schema: protobuf::Schema = (&schema).try_into().map_err(|e| {
                    BallistaError::General(format!(
                        "Could not convert schema into protobuf: {:?}",
                        e
                    ))
                })?;
                let pb_file_type: protobuf::FileType = match file_type {
                    FileType::NdJson => protobuf::FileType::NdJson,
                    FileType::Parquet => protobuf::FileType::Parquet,
                    FileType::CSV => protobuf::FileType::Csv,
                };
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::CreateExternalTable(
                        protobuf::CreateExternalTableNode {
                            name: name.clone(),
                            location: location.clone(),
                            file_type: pb_file_type as i32,
                            has_header: *has_header,
                            schema: Some(pb_schema),
                        },
                    )),
                })
            }
            LogicalPlan::Explain { verbose, plan, .. } => {
                let input: protobuf::LogicalPlanNode = plan.as_ref().try_into()?;
                Ok(protobuf::LogicalPlanNode {
                    logical_plan_type: Some(LogicalPlanType::Explain(Box::new(
                        protobuf::ExplainNode {
                            input: Some(Box::new(input)),
                            verbose: *verbose,
                        },
                    ))),
                })
            }
            LogicalPlan::Extension { .. } => unimplemented!(),
            LogicalPlan::Union { .. } => unimplemented!(),
        }
    }
}
/// Builds a protobuf `ScalarValue` from an optional payload: `Some` values
/// are serialized via `constructor`, while `None` becomes a typed null
/// carrying `null_arrow_type` so the scalar's type survives the round trip.
fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
    v: &Option<I>,
    null_arrow_type: protobuf::PrimitiveScalarType,
    constructor: T,
) -> protobuf::ScalarValue {
    let value = match v {
        Some(inner) => constructor(inner),
        None => protobuf::scalar_value::Value::NullValue(null_arrow_type as i32),
    };
    protobuf::ScalarValue { value: Some(value) }
}
impl TryInto<protobuf::LogicalExprNode> for &Expr {
type Error = BallistaError;
fn try_into(self) -> Result<protobuf::LogicalExprNode, Self::Error> {
use datafusion::scalar::ScalarValue;
use protobuf::scalar_value::Value;
match self {
Expr::Column(name) => {
let expr = protobuf::LogicalExprNode {
expr_type: Some(ExprType::ColumnName(name.clone())),
};
Ok(expr)
}
Expr::Alias(expr, alias) => {
let alias = Box::new(protobuf::AliasNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
alias: alias.to_owned(),
});
let expr = protobuf::LogicalExprNode {
expr_type: Some(ExprType::Alias(alias)),
};
Ok(expr)
}
Expr::Literal(value) => {
let pb_value: protobuf::ScalarValue = value.try_into()?;
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::Literal(pb_value)),
})
}
Expr::BinaryExpr { left, op, right } => {
let binary_expr = Box::new(protobuf::BinaryExprNode {
l: Some(Box::new(left.as_ref().try_into()?)),
r: Some(Box::new(right.as_ref().try_into()?)),
op: format!("{:?}", op),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::BinaryExpr(binary_expr)),
})
}
Expr::AggregateFunction {
ref fun, ref args, ..
} => {
let aggr_function = match fun {
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
AggregateFunction::Count => protobuf::AggregateFunction::Count,
};
let arg = &args[0];
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
expr: Some(Box::new(arg.try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
})
}
Expr::ScalarVariable(_) => unimplemented!(),
Expr::ScalarFunction { ref fun, ref args } => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
let expr: Vec<protobuf::LogicalExprNode> = args
.iter()
.map(|e| Ok(e.try_into()?))
.collect::<Result<Vec<protobuf::LogicalExprNode>, BallistaError>>()?;
Ok(protobuf::LogicalExprNode {
expr_type: Some(
protobuf::logical_expr_node::ExprType::ScalarFunction(
protobuf::ScalarFunctionNode {
fun: fun.into(),
expr,
},
),
),
})
}
Expr::ScalarUDF { .. } => unimplemented!(),
Expr::AggregateUDF { .. } => unimplemented!(),
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
expr: Some(Box::new(expr.as_ref().try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::NotExpr(expr)),
})
}
Expr::IsNull(expr) => {
let expr = Box::new(protobuf::IsNull {
expr: Some(Box::new(expr.as_ref().try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::IsNullExpr(expr)),
})
}
Expr::IsNotNull(expr) => {
let expr = Box::new(protobuf::IsNotNull {
expr: Some(Box::new(expr.as_ref().try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::IsNotNullExpr(expr)),
})
}
Expr::Between {
expr,
negated,
low,
high,
} => {
let expr = Box::new(protobuf::BetweenNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
negated: *negated,
low: Some(Box::new(low.as_ref().try_into()?)),
high: Some(Box::new(high.as_ref().try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::Between(expr)),
})
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
let when_then_expr = when_then_expr
.iter()
.map(|(w, t)| {
Ok(protobuf::WhenThen {
when_expr: Some(w.as_ref().try_into()?),
then_expr: Some(t.as_ref().try_into()?),
})
})
.collect::<Result<Vec<protobuf::WhenThen>, BallistaError>>()?;
let expr = Box::new(protobuf::CaseNode {
expr: match expr {
Some(e) => Some(Box::new(e.as_ref().try_into()?)),
None => None,
},
when_then_expr,
else_expr: match else_expr {
Some(e) => Some(Box::new(e.as_ref().try_into()?)),
None => None,
},
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::Case(expr)),
})
}
Expr::Cast { expr, data_type } => {
let expr = Box::new(protobuf::CastNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
arrow_type: Some(data_type.into()),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::Cast(expr)),
})
}
Expr::Sort {
expr,
asc,
nulls_first,
} => {
let expr = Box::new(protobuf::SortExprNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
asc: *asc,
nulls_first: *nulls_first,
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::Sort(expr)),
})
}
Expr::Negative(expr) => {
let expr = Box::new(protobuf::NegativeNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(protobuf::logical_expr_node::ExprType::Negative(
expr,
)),
})
}
Expr::InList {
expr,
list,
negated,
} => {
let expr = Box::new(protobuf::InListNode {
expr: Some(Box::new(expr.as_ref().try_into()?)),
list: list.iter().map(|expr| expr.try_into()).collect::<Result<
Vec<_>,
BallistaError,
>>(
)?,
negated: *negated,
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(protobuf::logical_expr_node::ExprType::InList(expr)),
})
}
Expr::Wildcard => Ok(protobuf::LogicalExprNode {
expr_type: Some(protobuf::logical_expr_node::ExprType::Wildcard(true)),
}),
Expr::TryCast { .. } => unimplemented!(),
}
}
}
/// Convert an Arrow schema into its protobuf form, field by field.
///
/// Implemented as `From` (rather than a hand-written `Into`) per Rust
/// convention: the std blanket impls then provide `Into<protobuf::Schema>`
/// and an infallible `TryFrom`/`TryInto` for free, so all existing
/// `.into()` / `.try_into()` call sites keep working unchanged.
impl From<&Schema> for protobuf::Schema {
    fn from(schema: &Schema) -> Self {
        protobuf::Schema {
            columns: schema
                .fields()
                .iter()
                .map(protobuf::Field::from)
                .collect::<Vec<_>>(),
        }
    }
}
impl TryFrom<&arrow::datatypes::DataType> for protobuf::ScalarType {
type Error = BallistaError;
fn try_from(value: &arrow::datatypes::DataType) -> Result<Self, Self::Error> {
let datatype = protobuf::scalar_type::Datatype::try_from(value)?;
Ok(protobuf::ScalarType {
datatype: Some(datatype),
})
}
}
impl TryInto<protobuf::ScalarFunction> for &BuiltinScalarFunction {
type Error = BallistaError;
fn try_into(self) -> Result<protobuf::ScalarFunction, Self::Error> {
match self {
BuiltinScalarFunction::Sqrt => Ok(protobuf::ScalarFunction::Sqrt),
BuiltinScalarFunction::Sin => Ok(protobuf::ScalarFunction::Sin),
BuiltinScalarFunction::Cos => Ok(protobuf::ScalarFunction::Cos),
BuiltinScalarFunction::Tan => Ok(protobuf::ScalarFunction::Tan),
BuiltinScalarFunction::Asin => Ok(protobuf::ScalarFunction::Asin),
BuiltinScalarFunction::Acos => Ok(protobuf::ScalarFunction::Acos),
BuiltinScalarFunction::Atan => Ok(protobuf::ScalarFunction::Atan),
BuiltinScalarFunction::Exp => Ok(protobuf::ScalarFunction::Exp),
BuiltinScalarFunction::Log => Ok(protobuf::ScalarFunction::Log),
BuiltinScalarFunction::Log10 => Ok(protobuf::ScalarFunction::Log10),
BuiltinScalarFunction::Floor => Ok(protobuf::ScalarFunction::Floor),
BuiltinScalarFunction::Ceil => Ok(protobuf::ScalarFunction::Ceil),
BuiltinScalarFunction::Round => Ok(protobuf::ScalarFunction::Round),
BuiltinScalarFunction::Trunc => Ok(protobuf::ScalarFunction::Trunc),
BuiltinScalarFunction::Abs => Ok(protobuf::ScalarFunction::Abs),
BuiltinScalarFunction::OctetLength => {
Ok(protobuf::ScalarFunction::Octetlength)
}
BuiltinScalarFunction::Concat => Ok(protobuf::ScalarFunction::Concat),
BuiltinScalarFunction::Lower => Ok(protobuf::ScalarFunction::Lower),
BuiltinScalarFunction::Upper => Ok(protobuf::ScalarFunction::Upper),
BuiltinScalarFunction::Trim => Ok(protobuf::ScalarFunction::Trim),
BuiltinScalarFunction::Ltrim => Ok(protobuf::ScalarFunction::Ltrim),
BuiltinScalarFunction::Rtrim => Ok(protobuf::ScalarFunction::Rtrim),
BuiltinScalarFunction::ToTimestamp => {
Ok(protobuf::ScalarFunction::Totimestamp)
}
BuiltinScalarFunction::Array => Ok(protobuf::ScalarFunction::Array),
BuiltinScalarFunction::NullIf => Ok(protobuf::ScalarFunction::Nullif),
BuiltinScalarFunction::DateTrunc => Ok(protobuf::ScalarFunction::Datetrunc),
BuiltinScalarFunction::MD5 => Ok(protobuf::ScalarFunction::Md5),
BuiltinScalarFunction::SHA224 => Ok(protobuf::ScalarFunction::Sha224),
BuiltinScalarFunction::SHA256 => Ok(protobuf::ScalarFunction::Sha256),
BuiltinScalarFunction::SHA384 => Ok(protobuf::ScalarFunction::Sha384),
BuiltinScalarFunction::SHA512 => Ok(protobuf::ScalarFunction::Sha512),
_ => Err(BallistaError::General(format!(
"logical_plan::to_proto() unsupported scalar function {:?}",
self
))),
}
}
}