blob: c661369dbb3f189f4f0534b2c0546da792c47f15 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::error::PlanSerDeError;
use crate::protobuf::scalar_type;
use datafusion::arrow::datatypes::{
DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode,
};
use datafusion::logical_expr::{BuiltInWindowFunction, BuiltinScalarFunction};
use datafusion::logical_plan::{JoinConstraint, Operator};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::prelude::JoinType;
use datafusion::scalar::ScalarValue;
// Include the generated protobuf source as a submodule. The file is produced at
// build time into OUT_DIR (NOTE(review): presumably by a prost-based build script —
// confirm against build.rs).
// Clippy is silenced wholesale because this is generated code, not hand-written.
#[allow(clippy::all)]
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/plan.protobuf.rs"));
}
// Error type (`PlanSerDeError`) used by every conversion in this file.
pub mod error;
// NOTE(review): presumably the protobuf -> DataFusion direction of the conversions.
pub mod from_proto;
/// Builds a [`PlanSerDeError::General`] carrying `message`.
///
/// Accepts anything convertible into a `String` (`&str`, `String`, `format!` output),
/// so call sites can pass literals and formatted messages alike.
pub(crate) fn proto_error<S: Into<String>>(message: S) -> PlanSerDeError {
    let text: String = message.into();
    PlanSerDeError::General(text)
}
/// Converts a required (`Option`-wrapped) protobuf field with `TryInto`.
///
/// Expands to the `Result` of `field.try_into()` when the field is present and to
/// `Err(proto_error(..))` when it is `None`. `proto_error` is referenced
/// unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! convert_required {
    ($PB:expr) => {{
        match $PB.as_ref() {
            Some(field) => field.try_into(),
            None => Err(proto_error("Missing required field in protobuf")),
        }
    }};
}
/// Converts a required (`Option`-wrapped) protobuf field with an infallible `Into`.
///
/// Expands to `Ok(field.into())` when present, otherwise to
/// `Err(proto_error(..))`; `proto_error` must be in scope at the call site.
#[macro_export]
macro_rules! into_required {
    ($PB:expr) => {{
        match $PB.as_ref() {
            Some(field) => Ok(field.into()),
            None => Err(proto_error("Missing required field in protobuf")),
        }
    }};
}
/// Like `convert_required!`, but for a required field stored as `Option<Box<T>>`:
/// the box is borrowed through (`as_ref`) before calling `try_into`.
///
/// `proto_error` must be in scope at the call site.
#[macro_export]
macro_rules! convert_box_required {
    ($PB:expr) => {{
        match $PB.as_ref() {
            Some(field) => field.as_ref().try_into(),
            None => Err(proto_error("Missing required field in protobuf")),
        }
    }};
}
/// Maps the textual name of a binary operator, as stored in the protobuf plan,
/// to the corresponding DataFusion [`Operator`].
///
/// # Errors
/// Returns a [`PlanSerDeError::General`] for any name not listed below.
pub(crate) fn from_proto_binary_op(op: &str) -> Result<Operator, PlanSerDeError> {
    let operator = match op {
        "And" => Operator::And,
        "Or" => Operator::Or,
        "Eq" => Operator::Eq,
        "NotEq" => Operator::NotEq,
        "LtEq" => Operator::LtEq,
        "Lt" => Operator::Lt,
        "Gt" => Operator::Gt,
        "GtEq" => Operator::GtEq,
        "Plus" => Operator::Plus,
        "Minus" => Operator::Minus,
        "Multiply" => Operator::Multiply,
        "Divide" => Operator::Divide,
        "Modulo" => Operator::Modulo,
        "Like" => Operator::Like,
        "NotLike" => Operator::NotLike,
        other => {
            // `{}` rather than `{:?}`: Debug-formatting a &str adds its own quotes,
            // which produced messages like `'"Foo"'`.
            return Err(proto_error(format!(
                "Unsupported binary operator '{}'",
                other
            )));
        }
    };
    Ok(operator)
}
impl From<protobuf::AggregateFunction> for AggregateFunction {
fn from(agg_fun: protobuf::AggregateFunction) -> AggregateFunction {
match agg_fun {
protobuf::AggregateFunction::Min => AggregateFunction::Min,
protobuf::AggregateFunction::Max => AggregateFunction::Max,
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
protobuf::AggregateFunction::ApproxDistinct => {
AggregateFunction::ApproxDistinct
}
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
protobuf::AggregateFunction::Variance => AggregateFunction::Variance,
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
protobuf::AggregateFunction::Covariance => AggregateFunction::Covariance,
protobuf::AggregateFunction::CovariancePop => {
AggregateFunction::CovariancePop
}
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
}
}
}
impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
match built_in_function {
protobuf::BuiltInWindowFunction::RowNumber => {
BuiltInWindowFunction::RowNumber
}
protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
protobuf::BuiltInWindowFunction::PercentRank => {
BuiltInWindowFunction::PercentRank
}
protobuf::BuiltInWindowFunction::DenseRank => {
BuiltInWindowFunction::DenseRank
}
protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
protobuf::BuiltInWindowFunction::FirstValue => {
BuiltInWindowFunction::FirstValue
}
protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
protobuf::BuiltInWindowFunction::LastValue => {
BuiltInWindowFunction::LastValue
}
}
}
}
impl protobuf::TimeUnit {
pub fn from_arrow_time_unit(val: &TimeUnit) -> Self {
match val {
TimeUnit::Second => protobuf::TimeUnit::Second,
TimeUnit::Millisecond => protobuf::TimeUnit::TimeMillisecond,
TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond,
TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond,
}
}
pub fn from_i32_to_arrow(time_unit_i32: i32) -> Result<TimeUnit, PlanSerDeError> {
let pb_time_unit = protobuf::TimeUnit::from_i32(time_unit_i32);
match pb_time_unit {
Some(time_unit) => Ok(match time_unit {
protobuf::TimeUnit::Second => TimeUnit::Second,
protobuf::TimeUnit::TimeMillisecond => TimeUnit::Millisecond,
protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond,
protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
}),
None => Err(proto_error(
"Error converting i32 to TimeUnit: Passed invalid variant",
)),
}
}
}
impl protobuf::IntervalUnit {
    /// Converts an Arrow [`IntervalUnit`] into its protobuf counterpart.
    pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
        match interval_unit {
            IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth,
            IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
            IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
        }
    }

    /// Decodes a raw protobuf enum value into an Arrow [`IntervalUnit`].
    ///
    /// # Errors
    /// Returns a [`PlanSerDeError`] when `interval_unit_i32` is not a known variant.
    pub fn from_i32_to_arrow(
        interval_unit_i32: i32,
    ) -> Result<IntervalUnit, PlanSerDeError> {
        let pb_interval_unit = protobuf::IntervalUnit::from_i32(interval_unit_i32);
        match pb_interval_unit {
            Some(interval_unit) => Ok(match interval_unit {
                protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth,
                protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime,
                protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
            }),
            // The message previously named a non-existent "DateUnit" type.
            None => Err(proto_error(
                "Error converting i32 to IntervalUnit: Passed invalid variant",
            )),
        }
    }
}
impl TryInto<datafusion::arrow::datatypes::DataType>
    for &protobuf::arrow_type::ArrowTypeEnum
{
    type Error = PlanSerDeError;

    /// Decodes a protobuf `ArrowTypeEnum` into an Arrow [`DataType`], recursing into
    /// the element/child types of list, struct, union and dictionary types.
    ///
    /// An empty timestamp timezone string decodes as `None`.
    ///
    /// # Errors
    /// Returns a [`PlanSerDeError`] when:
    /// - a time unit, interval unit or union mode is not a known enum value;
    /// - a nested message is missing a required field;
    /// - a nested field fails to decode.
    fn try_into(self) -> Result<datafusion::arrow::datatypes::DataType, Self::Error> {
        use protobuf::arrow_type;
        Ok(match self {
            arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
            arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
            arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
            arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
            arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
            arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
            arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
            arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
            arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
            arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
            arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
            arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
            arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
            arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
            arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
            arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
            arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
                DataType::FixedSizeBinary(*size)
            }
            arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
            arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
            arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
            arrow_type::ArrowTypeEnum::Duration(time_unit) => {
                DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
            }
            arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp {
                time_unit,
                timezone,
            }) => DataType::Timestamp(
                protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?,
                // protobuf strings cannot be absent; "" stands for "no timezone".
                match timezone.len() {
                    0 => None,
                    _ => Some(timezone.to_owned()),
                },
            ),
            arrow_type::ArrowTypeEnum::Time32(time_unit) => {
                DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
            }
            arrow_type::ArrowTypeEnum::Time64(time_unit) => {
                DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
            }
            arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval(
                protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?,
            ),
            arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal {
                whole,
                fractional,
            }) => DataType::Decimal(*whole as usize, *fractional as usize),
            arrow_type::ArrowTypeEnum::List(list) => {
                let list_type: &protobuf::Field = list
                    .as_ref()
                    .field_type
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
                    .as_ref();
                DataType::List(Box::new(list_type.try_into()?))
            }
            arrow_type::ArrowTypeEnum::LargeList(list) => {
                // LargeList is carried in the same protobuf `List` message as List.
                let list_type: &protobuf::Field = list
                    .as_ref()
                    .field_type
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
                    .as_ref();
                DataType::LargeList(Box::new(list_type.try_into()?))
            }
            arrow_type::ArrowTypeEnum::FixedSizeList(list) => {
                let list_type: &protobuf::Field = list
                    .as_ref()
                    .field_type
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: FixedSizeList message missing required field 'field_type'"))?
                    .as_ref();
                let list_size = list.list_size;
                DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size)
            }
            arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
                strct
                    .sub_field_types
                    .iter()
                    .map(|field| field.try_into())
                    .collect::<Result<Vec<_>, _>>()?,
            ),
            arrow_type::ArrowTypeEnum::Union(union) => {
                let union_mode = protobuf::UnionMode::from_i32(union.union_mode)
                    .ok_or_else(|| {
                        proto_error(
                            "Protobuf deserialization error: Unknown union mode type",
                        )
                    })?;
                let union_mode = match union_mode {
                    protobuf::UnionMode::Dense => UnionMode::Dense,
                    protobuf::UnionMode::Sparse => UnionMode::Sparse,
                };
                let union_types = union
                    .union_types
                    .iter()
                    .map(|field| field.try_into())
                    .collect::<Result<Vec<_>, _>>()?;
                DataType::Union(union_types, union_mode)
            }
            arrow_type::ArrowTypeEnum::Dictionary(dict) => {
                let pb_key_datatype = dict
                    .as_ref()
                    .key
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
                // The message previously (wrongly) blamed the 'key' field here.
                let pb_value_datatype = dict
                    .as_ref()
                    .value
                    .as_ref()
                    .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'value'"))?;
                let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?;
                let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?;
                DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype))
            }
        })
    }
}
#[allow(clippy::from_over_into)]
impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarType {
fn into(self) -> datafusion::arrow::datatypes::DataType {
match self {
protobuf::PrimitiveScalarType::Bool => DataType::Boolean,
protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8,
protobuf::PrimitiveScalarType::Int8 => DataType::Int8,
protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16,
protobuf::PrimitiveScalarType::Int16 => DataType::Int16,
protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32,
protobuf::PrimitiveScalarType::Int32 => DataType::Int32,
protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64,
protobuf::PrimitiveScalarType::Int64 => DataType::Int64,
protobuf::PrimitiveScalarType::Float32 => DataType::Float32,
protobuf::PrimitiveScalarType::Float64 => DataType::Float64,
protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8,
protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8,
protobuf::PrimitiveScalarType::Date32 => DataType::Date32,
protobuf::PrimitiveScalarType::TimeMicrosecond => {
DataType::Time64(TimeUnit::Microsecond)
}
protobuf::PrimitiveScalarType::TimeNanosecond => {
DataType::Time64(TimeUnit::Nanosecond)
}
protobuf::PrimitiveScalarType::Null => DataType::Null,
}
}
}
impl From<protobuf::JoinType> for JoinType {
fn from(t: protobuf::JoinType) -> Self {
match t {
protobuf::JoinType::Inner => JoinType::Inner,
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
protobuf::JoinType::Semi => JoinType::Semi,
protobuf::JoinType::Anti => JoinType::Anti,
}
}
}
impl From<JoinType> for protobuf::JoinType {
    /// Encodes DataFusion's [`JoinType`] as the protobuf join type.
    fn from(t: JoinType) -> Self {
        match t {
            JoinType::Inner => Self::Inner,
            JoinType::Left => Self::Left,
            JoinType::Right => Self::Right,
            JoinType::Full => Self::Full,
            JoinType::Semi => Self::Semi,
            JoinType::Anti => Self::Anti,
        }
    }
}
impl From<protobuf::JoinConstraint> for JoinConstraint {
fn from(t: protobuf::JoinConstraint) -> Self {
match t {
protobuf::JoinConstraint::On => JoinConstraint::On,
protobuf::JoinConstraint::Using => JoinConstraint::Using,
}
}
}
impl From<JoinConstraint> for protobuf::JoinConstraint {
    /// Encodes DataFusion's join constraint (`ON` vs `USING`) for protobuf.
    fn from(t: JoinConstraint) -> Self {
        match t {
            JoinConstraint::On => Self::On,
            JoinConstraint::Using => Self::Using,
        }
    }
}
impl TryFrom<&DataType> for protobuf::ScalarType {
    type Error = PlanSerDeError;

    /// Wraps the `scalar_type::Datatype` encoding of `value` in a `ScalarType`.
    ///
    /// # Errors
    /// Propagates any error from encoding `value` as a scalar datatype.
    fn try_from(value: &DataType) -> Result<Self, Self::Error> {
        protobuf::scalar_type::Datatype::try_from(value).map(|encoded| protobuf::ScalarType {
            datatype: Some(encoded),
        })
    }
}
impl TryFrom<&DataType> for protobuf::scalar_type::Datatype {
type Error = PlanSerDeError;
/// Encodes an Arrow [`DataType`] as the protobuf scalar-type descriptor.
///
/// Encoding rules:
/// - Supported primitive types map to `Datatype::Scalar` carrying a
///   `PrimitiveScalarType` tag.
/// - `DataType::List`, including nested lists, maps to `Datatype::List`. It records
///   the field name of every nesting level (outermost first) plus the innermost
///   (leaf) primitive type.
///
/// # Errors
/// Returns a [`PlanSerDeError`] when:
/// - `val` has no scalar encoding here (the explicit list of variants at the bottom);
/// - `val` is a `Time64` whose unit is neither microsecond nor nanosecond;
/// - `val` is a list whose leaf type is not a supported primitive.
fn try_from(val: &DataType) -> Result<Self, Self::Error> {
use protobuf::PrimitiveScalarType;
let scalar_value = match val {
DataType::Boolean => scalar_type::Datatype::Scalar(PrimitiveScalarType::Bool as i32),
DataType::Int8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int8 as i32),
DataType::Int16 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int16 as i32),
DataType::Int32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int32 as i32),
DataType::Int64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Int64 as i32),
DataType::UInt8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint8 as i32),
DataType::UInt16 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint16 as i32),
DataType::UInt32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint32 as i32),
DataType::UInt64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Uint64 as i32),
DataType::Float32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Float32 as i32),
DataType::Float64 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Float64 as i32),
DataType::Date32 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Date32 as i32),
// Only two Time64 units have a PrimitiveScalarType tag.
DataType::Time64(time_unit) => match time_unit {
TimeUnit::Microsecond => scalar_type::Datatype::Scalar(PrimitiveScalarType::TimeMicrosecond as i32),
TimeUnit::Nanosecond => scalar_type::Datatype::Scalar(PrimitiveScalarType::TimeNanosecond as i32),
_ => {
return Err(proto_error(format!(
"Found invalid time unit for scalar value, only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units: {:?}",
time_unit
)))
}
},
DataType::Utf8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::Utf8 as i32),
DataType::LargeUtf8 => scalar_type::Datatype::Scalar(PrimitiveScalarType::LargeUtf8 as i32),
DataType::List(field_type) => {
// Names of the list fields at each nesting level, outermost first.
let mut field_names: Vec<String> = Vec::new();
let mut curr_field = field_type.as_ref();
field_names.push(curr_field.name().to_owned());
// Walk down through nested lists. DataFusion scalars only support (possibly nested)
// lists whose leaf is a primitive scalar type; any other compound type is an error.
// `is_valid_scalar_type_no_list_check` accepts List, so inside the loop this only
// rejects a non-list, non-primitive level; the leaf is checked again below.
while let DataType::List(nested_field_type) = curr_field.data_type() {
curr_field = nested_field_type.as_ref();
field_names.push(curr_field.name().to_owned());
if !is_valid_scalar_type_no_list_check(curr_field.data_type()) {
return Err(proto_error(format!("{:?} is an invalid scalar type", curr_field)));
}
}
// After the loop `curr_field` is the first non-list level, i.e. the leaf.
let deepest_datatype = curr_field.data_type();
if !is_valid_scalar_type_no_list_check(deepest_datatype) {
return Err(proto_error(format!("The list nested type {:?} is an invalid scalar type", curr_field)));
}
// Same primitive mapping as the top-level arms above, producing the bare tag.
let pb_deepest_type: PrimitiveScalarType = match deepest_datatype {
DataType::Boolean => PrimitiveScalarType::Bool,
DataType::Int8 => PrimitiveScalarType::Int8,
DataType::Int16 => PrimitiveScalarType::Int16,
DataType::Int32 => PrimitiveScalarType::Int32,
DataType::Int64 => PrimitiveScalarType::Int64,
DataType::UInt8 => PrimitiveScalarType::Uint8,
DataType::UInt16 => PrimitiveScalarType::Uint16,
DataType::UInt32 => PrimitiveScalarType::Uint32,
DataType::UInt64 => PrimitiveScalarType::Uint64,
DataType::Float32 => PrimitiveScalarType::Float32,
DataType::Float64 => PrimitiveScalarType::Float64,
DataType::Date32 => PrimitiveScalarType::Date32,
DataType::Time64(time_unit) => match time_unit {
TimeUnit::Microsecond => PrimitiveScalarType::TimeMicrosecond,
TimeUnit::Nanosecond => PrimitiveScalarType::TimeNanosecond,
_ => {
return Err(proto_error(format!(
"Found invalid time unit for scalar value, only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units: {:?}",
time_unit
)))
}
},
DataType::Utf8 => PrimitiveScalarType::Utf8,
DataType::LargeUtf8 => PrimitiveScalarType::LargeUtf8,
// NOTE(review): this reports the outer `val`, not the offending leaf type.
_ => {
return Err(proto_error(format!(
"Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
val
)))
}
};
protobuf::scalar_type::Datatype::List(protobuf::ScalarListType {
field_names,
deepest_type: pb_deepest_type as i32,
})
}
// Variants listed explicitly (no `_` arm) so that a new DataType variant forces a
// decision here at compile time.
DataType::Null
| DataType::Float16
| DataType::Timestamp(_, _)
| DataType::Date64
| DataType::Time32(_)
| DataType::Duration(_)
| DataType::Interval(_)
| DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_)
| DataType::Struct(_)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::Map(_, _)
| DataType::Decimal(_, _) => {
return Err(proto_error(format!(
"Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
val
)))
}
};
Ok(scalar_value)
}
}
/// Reports whether `datatype` is acceptable as a scalar type for serialization.
///
/// `DataType::List` is accepted unconditionally: the list's element type is NOT
/// inspected, so callers must validate nested list types themselves. `Time64` is
/// accepted only with microsecond or nanosecond units.
fn is_valid_scalar_type_no_list_check(datatype: &DataType) -> bool {
    matches!(
        datatype,
        DataType::Boolean
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64
            | DataType::LargeUtf8
            | DataType::Utf8
            | DataType::Date32
            | DataType::Time64(TimeUnit::Microsecond)
            | DataType::Time64(TimeUnit::Nanosecond)
            | DataType::List(_)
    )
}
impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
    type Error = PlanSerDeError;

    /// Serializes a DataFusion [`ScalarValue`] into its protobuf form.
    ///
    /// Primitive scalars become either a typed value or, when null, a `NullValue`
    /// tagged with the primitive type (see `create_proto_scalar`). Lists are encoded
    /// element-by-element after checking each element against the list's declared
    /// type. A null list becomes `NullListValue`.
    ///
    /// # Errors
    /// Returns a [`PlanSerDeError`] when:
    /// - the scalar variant has no protobuf encoding here;
    /// - a list element is inconsistent with the declared type;
    /// - a non-empty list's declared type is not itself a `DataType::List`. This case
    ///   previously panicked via `todo!()`.
    fn try_from(
        val: &datafusion::scalar::ScalarValue,
    ) -> Result<protobuf::ScalarValue, Self::Error> {
        use datafusion::scalar;
        use protobuf::scalar_value::Value;
        use protobuf::PrimitiveScalarType;
        let scalar_val = match val {
            scalar::ScalarValue::Boolean(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Bool, |s| Value::BoolValue(*s))
            }
            scalar::ScalarValue::Float32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Float32, |s| {
                    Value::Float32Value(*s)
                })
            }
            scalar::ScalarValue::Float64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Float64, |s| {
                    Value::Float64Value(*s)
                })
            }
            // Narrow integer types are widened to the 32-bit protobuf fields.
            scalar::ScalarValue::Int8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int8, |s| {
                    Value::Int8Value(*s as i32)
                })
            }
            scalar::ScalarValue::Int16(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int16, |s| {
                    Value::Int16Value(*s as i32)
                })
            }
            scalar::ScalarValue::Int32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int32, |s| Value::Int32Value(*s))
            }
            scalar::ScalarValue::Int64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Int64, |s| Value::Int64Value(*s))
            }
            scalar::ScalarValue::UInt8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint8, |s| {
                    Value::Uint8Value(*s as u32)
                })
            }
            scalar::ScalarValue::UInt16(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint16, |s| {
                    Value::Uint16Value(*s as u32)
                })
            }
            scalar::ScalarValue::UInt32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint32, |s| Value::Uint32Value(*s))
            }
            scalar::ScalarValue::UInt64(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Uint64, |s| Value::Uint64Value(*s))
            }
            scalar::ScalarValue::Utf8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Utf8, |s| {
                    Value::Utf8Value(s.to_owned())
                })
            }
            scalar::ScalarValue::LargeUtf8(val) => {
                create_proto_scalar(val, PrimitiveScalarType::LargeUtf8, |s| {
                    Value::LargeUtf8Value(s.to_owned())
                })
            }
            scalar::ScalarValue::List(value, datatype) => match value {
                Some(values) => {
                    if values.is_empty() {
                        protobuf::ScalarValue {
                            value: Some(protobuf::scalar_value::Value::ListValue(
                                protobuf::ScalarListValue {
                                    datatype: Some(datatype.as_ref().try_into()?),
                                    values: Vec::new(),
                                },
                            )),
                        }
                    } else {
                        // The element type is read from the declared `DataType::List`
                        // field; any other declared type cannot be type-checked and is
                        // reported instead of panicking.
                        let scalar_type = match datatype.as_ref() {
                            DataType::List(field) => field.as_ref().data_type(),
                            other => {
                                return Err(proto_error(format!(
                                    "Protobuf serialization error: expected a List data type for a non-empty list scalar, found {:?}",
                                    other
                                )))
                            }
                        };
                        let type_checked_values: Vec<protobuf::ScalarValue> = values
                            .iter()
                            .map(|scalar| match (scalar, scalar_type) {
                                (scalar::ScalarValue::List(_, list_type), DataType::List(field)) => {
                                    if let DataType::List(list_field) = list_type.as_ref() {
                                        let scalar_datatype = field.data_type();
                                        let list_datatype = list_field.data_type();
                                        // Only the variant is compared, not its parameters.
                                        if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) {
                                            return Err(proto_error(format!(
                                                "Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list",
                                                list_datatype, scalar_datatype
                                            )));
                                        }
                                        scalar.try_into()
                                    } else {
                                        Err(proto_error(format!(
                                            "Protobuf serialization error, {:?} was inconsistent with designated type {:?}",
                                            scalar, datatype
                                        )))
                                    }
                                }
                                (scalar::ScalarValue::Boolean(_), DataType::Boolean) => scalar.try_into(),
                                (scalar::ScalarValue::Float32(_), DataType::Float32) => scalar.try_into(),
                                (scalar::ScalarValue::Float64(_), DataType::Float64) => scalar.try_into(),
                                (scalar::ScalarValue::Int8(_), DataType::Int8) => scalar.try_into(),
                                (scalar::ScalarValue::Int16(_), DataType::Int16) => scalar.try_into(),
                                (scalar::ScalarValue::Int32(_), DataType::Int32) => scalar.try_into(),
                                (scalar::ScalarValue::Int64(_), DataType::Int64) => scalar.try_into(),
                                (scalar::ScalarValue::UInt8(_), DataType::UInt8) => scalar.try_into(),
                                (scalar::ScalarValue::UInt16(_), DataType::UInt16) => scalar.try_into(),
                                (scalar::ScalarValue::UInt32(_), DataType::UInt32) => scalar.try_into(),
                                (scalar::ScalarValue::UInt64(_), DataType::UInt64) => scalar.try_into(),
                                (scalar::ScalarValue::Utf8(_), DataType::Utf8) => scalar.try_into(),
                                (scalar::ScalarValue::LargeUtf8(_), DataType::LargeUtf8) => scalar.try_into(),
                                _ => Err(proto_error(format!(
                                    "Protobuf serialization error, {:?} was inconsistent with designated type {:?}",
                                    scalar, datatype
                                ))),
                            })
                            .collect::<Result<Vec<_>, _>>()?;
                        protobuf::ScalarValue {
                            value: Some(protobuf::scalar_value::Value::ListValue(
                                protobuf::ScalarListValue {
                                    datatype: Some(datatype.as_ref().try_into()?),
                                    values: type_checked_values,
                                },
                            )),
                        }
                    }
                }
                None => protobuf::ScalarValue {
                    value: Some(protobuf::scalar_value::Value::NullListValue(
                        datatype.as_ref().try_into()?,
                    )),
                },
            },
            datafusion::scalar::ScalarValue::Date32(val) => {
                create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
            }
            // NOTE(review): timestamps are encoded with the Time* tags/values and the
            // second variant field (presumably the timezone) is dropped — confirm this
            // matches the decoder.
            datafusion::scalar::ScalarValue::TimestampMicrosecond(val, _) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
                    Value::TimeMicrosecondValue(*s)
                })
            }
            datafusion::scalar::ScalarValue::TimestampNanosecond(val, _) => {
                create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
                    Value::TimeNanosecondValue(*s)
                })
            }
            _ => {
                return Err(proto_error(format!(
                    "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
                    val
                )))
            }
        };
        Ok(scalar_val)
    }
}
impl From<&Field> for protobuf::Field {
    /// Encodes an Arrow [`Field`] (name, data type, nullability).
    ///
    /// `children` is always left empty; nested element types travel inside
    /// `arrow_type` instead.
    fn from(field: &Field) -> Self {
        let encoded_type = protobuf::ArrowType::from(field.data_type());
        Self {
            name: field.name().to_owned(),
            arrow_type: Some(Box::new(encoded_type)),
            nullable: field.is_nullable(),
            children: Vec::new(),
        }
    }
}
impl From<&DataType> for protobuf::ArrowType {
    /// Wraps the `ArrowTypeEnum` encoding of `val` in an `ArrowType` message.
    fn from(val: &DataType) -> protobuf::ArrowType {
        let type_enum = protobuf::arrow_type::ArrowTypeEnum::from(val);
        protobuf::ArrowType {
            arrow_type_enum: Some(type_enum),
        }
    }
}
impl From<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
/// Encodes an Arrow [`DataType`] as the protobuf `ArrowTypeEnum` oneof.
///
/// Parameterless types become their variant wrapping an `EmptyMessage`. Nested
/// types (lists, struct, union, dictionary) recursively encode their child
/// fields/types. A `None` timestamp timezone is encoded as the empty string.
///
/// # Panics
/// Panics (`unimplemented!`) on `DataType::Map`, which has no encoding here.
fn from(val: &DataType) -> protobuf::arrow_type::ArrowTypeEnum {
use protobuf::arrow_type::ArrowTypeEnum;
use protobuf::EmptyMessage;
match val {
DataType::Null => ArrowTypeEnum::None(EmptyMessage {}),
DataType::Boolean => ArrowTypeEnum::Bool(EmptyMessage {}),
DataType::Int8 => ArrowTypeEnum::Int8(EmptyMessage {}),
DataType::Int16 => ArrowTypeEnum::Int16(EmptyMessage {}),
DataType::Int32 => ArrowTypeEnum::Int32(EmptyMessage {}),
DataType::Int64 => ArrowTypeEnum::Int64(EmptyMessage {}),
DataType::UInt8 => ArrowTypeEnum::Uint8(EmptyMessage {}),
DataType::UInt16 => ArrowTypeEnum::Uint16(EmptyMessage {}),
DataType::UInt32 => ArrowTypeEnum::Uint32(EmptyMessage {}),
DataType::UInt64 => ArrowTypeEnum::Uint64(EmptyMessage {}),
DataType::Float16 => ArrowTypeEnum::Float16(EmptyMessage {}),
DataType::Float32 => ArrowTypeEnum::Float32(EmptyMessage {}),
DataType::Float64 => ArrowTypeEnum::Float64(EmptyMessage {}),
DataType::Timestamp(time_unit, timezone) => {
ArrowTypeEnum::Timestamp(protobuf::Timestamp {
time_unit: protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
// No timezone is written as "", which the decoder maps back to None.
timezone: timezone.to_owned().unwrap_or_default(),
})
}
DataType::Date32 => ArrowTypeEnum::Date32(EmptyMessage {}),
DataType::Date64 => ArrowTypeEnum::Date64(EmptyMessage {}),
DataType::Time32(time_unit) => ArrowTypeEnum::Time32(
protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
),
DataType::Time64(time_unit) => ArrowTypeEnum::Time64(
protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
),
DataType::Duration(time_unit) => ArrowTypeEnum::Duration(
protobuf::TimeUnit::from_arrow_time_unit(time_unit) as i32,
),
DataType::Interval(interval_unit) => ArrowTypeEnum::Interval(
protobuf::IntervalUnit::from_arrow_interval_unit(interval_unit) as i32,
),
DataType::Binary => ArrowTypeEnum::Binary(EmptyMessage {}),
DataType::FixedSizeBinary(size) => ArrowTypeEnum::FixedSizeBinary(*size),
DataType::LargeBinary => ArrowTypeEnum::LargeBinary(EmptyMessage {}),
DataType::Utf8 => ArrowTypeEnum::Utf8(EmptyMessage {}),
DataType::LargeUtf8 => ArrowTypeEnum::LargeUtf8(EmptyMessage {}),
DataType::List(item_type) => ArrowTypeEnum::List(Box::new(protobuf::List {
field_type: Some(Box::new(item_type.as_ref().into())),
})),
DataType::FixedSizeList(item_type, size) => {
ArrowTypeEnum::FixedSizeList(Box::new(protobuf::FixedSizeList {
field_type: Some(Box::new(item_type.as_ref().into())),
list_size: *size,
}))
}
// LargeList reuses the `protobuf::List` message; the oneof variant distinguishes it.
DataType::LargeList(item_type) => {
ArrowTypeEnum::LargeList(Box::new(protobuf::List {
field_type: Some(Box::new(item_type.as_ref().into())),
}))
}
DataType::Struct(struct_fields) => ArrowTypeEnum::Struct(protobuf::Struct {
sub_field_types: struct_fields
.iter()
.map(|field| field.into())
.collect::<Vec<_>>(),
}),
DataType::Union(union_types, union_mode) => {
let union_mode = match union_mode {
UnionMode::Sparse => protobuf::UnionMode::Sparse,
UnionMode::Dense => protobuf::UnionMode::Dense,
};
ArrowTypeEnum::Union(protobuf::Union {
union_types: union_types
.iter()
.map(|field| field.into())
.collect::<Vec<_>>(),
union_mode: union_mode.into(),
})
}
DataType::Dictionary(key_type, value_type) => {
ArrowTypeEnum::Dictionary(Box::new(protobuf::Dictionary {
key: Some(Box::new(key_type.as_ref().into())),
value: Some(Box::new(value_type.as_ref().into())),
}))
}
DataType::Decimal(whole, fractional) => {
ArrowTypeEnum::Decimal(protobuf::Decimal {
whole: *whole as u64,
fractional: *fractional as u64,
})
}
// `From` cannot fail, so an unsupported type panics here.
// NOTE(review): a TryFrom would let callers handle this.
DataType::Map(_, _) => {
unimplemented!("Ballista does not yet support Map data type")
}
}
}
}
impl TryInto<DataType> for &protobuf::ArrowType {
    type Error = PlanSerDeError;

    /// Decodes the inner `ArrowTypeEnum` into an Arrow [`DataType`].
    ///
    /// # Errors
    /// Fails if the oneof is unset or the inner type fails to decode.
    fn try_into(self) -> Result<DataType, Self::Error> {
        match self.arrow_type_enum.as_ref() {
            Some(pb_arrow_type) => pb_arrow_type.try_into(),
            None => Err(proto_error(
                "Protobuf deserialization error: ArrowType missing required field 'data_type'",
            )),
        }
    }
}
impl TryInto<DataType> for &Box<protobuf::List> {
    type Error = PlanSerDeError;

    /// Decodes a protobuf `List` message into `DataType::List` of its element field.
    ///
    /// # Errors
    /// Fails if `field_type` is unset or the element field fails to decode.
    fn try_into(self) -> Result<DataType, Self::Error> {
        let pb_field = self
            .field_type
            .as_deref()
            .ok_or_else(|| proto_error("List message missing required field 'field_type'"))?;
        let arrow_field: Field = pb_field.try_into()?;
        Ok(DataType::List(Box::new(arrow_field)))
    }
}
/// Builds a protobuf scalar from an optional Rust value.
///
/// `Some(v)` is encoded with `constructor`. `None` becomes a `NullValue` carrying
/// `null_arrow_type`, so a null still records which primitive type it belongs to.
fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
    v: &Option<I>,
    null_arrow_type: protobuf::PrimitiveScalarType,
    constructor: T,
) -> protobuf::ScalarValue {
    let encoded = match v {
        Some(inner) => constructor(inner),
        None => protobuf::scalar_value::Value::NullValue(null_arrow_type as i32),
    };
    protobuf::ScalarValue {
        value: Some(encoded),
    }
}
#[allow(clippy::from_over_into)]
impl Into<protobuf::Schema> for &Schema {
    /// Encodes every field of the schema, preserving order, as a protobuf `Field`.
    fn into(self) -> protobuf::Schema {
        let mut columns = Vec::with_capacity(self.fields().len());
        for field in self.fields() {
            columns.push(protobuf::Field::from(field));
        }
        protobuf::Schema { columns }
    }
}
#[allow(clippy::from_over_into)]
impl Into<protobuf::Schema> for SchemaRef {
fn into(self) -> protobuf::Schema {
protobuf::Schema {
columns: self
.fields()
.iter()
.map(protobuf::Field::from)
.collect::<Vec<_>>(),
}
}
}
impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
type Error = PlanSerDeError;
fn try_from(scalar: &BuiltinScalarFunction) -> Result<Self, Self::Error> {
let scalar_function = match scalar {
BuiltinScalarFunction::Sqrt => Self::Sqrt,
BuiltinScalarFunction::Sin => Self::Sin,
BuiltinScalarFunction::Cos => Self::Cos,
BuiltinScalarFunction::Tan => Self::Tan,
BuiltinScalarFunction::Asin => Self::Asin,
BuiltinScalarFunction::Acos => Self::Acos,
BuiltinScalarFunction::Atan => Self::Atan,
BuiltinScalarFunction::Exp => Self::Exp,
BuiltinScalarFunction::Log => Self::Log,
BuiltinScalarFunction::Ln => Self::Ln,
BuiltinScalarFunction::Log10 => Self::Log10,
BuiltinScalarFunction::Floor => Self::Floor,
BuiltinScalarFunction::Ceil => Self::Ceil,
BuiltinScalarFunction::Round => Self::Round,
BuiltinScalarFunction::Trunc => Self::Trunc,
BuiltinScalarFunction::Abs => Self::Abs,
BuiltinScalarFunction::OctetLength => Self::OctetLength,
BuiltinScalarFunction::Concat => Self::Concat,
BuiltinScalarFunction::Lower => Self::Lower,
BuiltinScalarFunction::Upper => Self::Upper,
BuiltinScalarFunction::Trim => Self::Trim,
BuiltinScalarFunction::Ltrim => Self::Ltrim,
BuiltinScalarFunction::Rtrim => Self::Rtrim,
BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
BuiltinScalarFunction::Array => Self::Array,
BuiltinScalarFunction::NullIf => Self::NullIf,
BuiltinScalarFunction::DatePart => Self::DatePart,
BuiltinScalarFunction::DateTrunc => Self::DateTrunc,
BuiltinScalarFunction::MD5 => Self::Md5,
BuiltinScalarFunction::SHA224 => Self::Sha224,
BuiltinScalarFunction::SHA256 => Self::Sha256,
BuiltinScalarFunction::SHA384 => Self::Sha384,
BuiltinScalarFunction::SHA512 => Self::Sha512,
BuiltinScalarFunction::Digest => Self::Digest,
BuiltinScalarFunction::ToTimestampMillis => Self::ToTimestampMillis,
BuiltinScalarFunction::Log2 => Self::Log2,
BuiltinScalarFunction::Signum => Self::Signum,
BuiltinScalarFunction::Ascii => Self::Ascii,
BuiltinScalarFunction::BitLength => Self::BitLength,
BuiltinScalarFunction::Btrim => Self::Btrim,
BuiltinScalarFunction::CharacterLength => Self::CharacterLength,
BuiltinScalarFunction::Chr => Self::Chr,
BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator,
BuiltinScalarFunction::InitCap => Self::InitCap,
BuiltinScalarFunction::Left => Self::Left,
BuiltinScalarFunction::Lpad => Self::Lpad,
BuiltinScalarFunction::Random => Self::Random,
BuiltinScalarFunction::RegexpReplace => Self::RegexpReplace,
BuiltinScalarFunction::Repeat => Self::Repeat,
BuiltinScalarFunction::Replace => Self::Replace,
BuiltinScalarFunction::Reverse => Self::Reverse,
BuiltinScalarFunction::Right => Self::Right,
BuiltinScalarFunction::Rpad => Self::Rpad,
BuiltinScalarFunction::SplitPart => Self::SplitPart,
BuiltinScalarFunction::StartsWith => Self::StartsWith,
BuiltinScalarFunction::Strpos => Self::Strpos,
BuiltinScalarFunction::Substr => Self::Substr,
BuiltinScalarFunction::ToHex => Self::ToHex,
BuiltinScalarFunction::ToTimestampMicros => Self::ToTimestampMicros,
BuiltinScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds,
BuiltinScalarFunction::Now => Self::Now,
BuiltinScalarFunction::Translate => Self::Translate,
BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch,
BuiltinScalarFunction::Coalesce => Self::Coalesce,
BuiltinScalarFunction::Power | BuiltinScalarFunction::Struct => todo!(),
};
Ok(scalar_function)
}
}
impl TryInto<Field> for &protobuf::Field {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Field, Self::Error> {
let pb_datatype = self.arrow_type.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: Field message missing required field 'arrow_type'",
)
})?;
Ok(Field::new(
self.name.as_str(),
pb_datatype.as_ref().try_into()?,
self.nullable,
))
}
}
impl TryInto<Schema> for &protobuf::Schema {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Schema, PlanSerDeError> {
let fields = self
.columns
.iter()
.map(|c| {
let pb_arrow_type_res = c
.arrow_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'"));
let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res {
Ok(res) => res,
Err(e) => return Err(e),
};
Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Schema::new(fields))
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
    type Error = PlanSerDeError;
    /// Deserializes a protobuf `ScalarValue` message into a DataFusion
    /// [`ScalarValue`].
    ///
    /// Primitive variants are converted directly. The narrow integer variants
    /// are converted with `as` casts, which truncate rather than reject
    /// out-of-range values. `NullValue` yields a typed null of the encoded
    /// primitive type. `ListValue` elements are converted recursively and are
    /// not checked against the declared element type.
    ///
    /// # Errors
    /// Returns a [`PlanSerDeError`] when the `value` oneof is unset, when a
    /// `DecimalValue` lacks its `decimal` descriptor, when a list message lacks
    /// its `datatype`, when an enum discriminant is invalid, or when a nested
    /// conversion fails.
    fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
        let value = self.value.as_ref().ok_or_else(|| {
            proto_error("Protobuf deserialization error: missing required field 'value'")
        })?;
        Ok(match value {
            protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
            protobuf::scalar_value::Value::Utf8Value(v) => {
                ScalarValue::Utf8(Some(v.to_owned()))
            }
            protobuf::scalar_value::Value::LargeUtf8Value(v) => {
                ScalarValue::LargeUtf8(Some(v.to_owned()))
            }
            protobuf::scalar_value::Value::Int8Value(v) => {
                ScalarValue::Int8(Some(*v as i8))
            }
            protobuf::scalar_value::Value::Int16Value(v) => {
                ScalarValue::Int16(Some(*v as i16))
            }
            protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
            protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
            protobuf::scalar_value::Value::Uint8Value(v) => {
                ScalarValue::UInt8(Some(*v as u8))
            }
            protobuf::scalar_value::Value::Uint16Value(v) => {
                ScalarValue::UInt16(Some(*v as u16))
            }
            protobuf::scalar_value::Value::Uint32Value(v) => {
                ScalarValue::UInt32(Some(*v))
            }
            protobuf::scalar_value::Value::Uint64Value(v) => {
                ScalarValue::UInt64(Some(*v))
            }
            protobuf::scalar_value::Value::Float32Value(v) => {
                ScalarValue::Float32(Some(*v))
            }
            protobuf::scalar_value::Value::Float64Value(v) => {
                ScalarValue::Float64(Some(*v))
            }
            protobuf::scalar_value::Value::Date32Value(v) => {
                ScalarValue::Date32(Some(*v))
            }
            protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
                ScalarValue::TimestampMicrosecond(Some(*v), None)
            }
            protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
                ScalarValue::TimestampNanosecond(Some(*v), None)
            }
            protobuf::scalar_value::Value::DecimalValue(v) => {
                // The message comes off the wire, so the precision/scale
                // descriptor may be absent. Report that instead of panicking.
                let decimal = v.decimal.as_ref().ok_or_else(|| {
                    proto_error(
                        "Protobuf deserialization error: DecimalValue message missing required field 'decimal'",
                    )
                })?;
                ScalarValue::Decimal128(
                    Some(v.long_value as i128),
                    decimal.whole as usize,
                    decimal.fractional as usize,
                )
            }
            protobuf::scalar_value::Value::ListValue(scalar_list) => {
                let protobuf::ScalarListValue {
                    values,
                    datatype: opt_scalar_type,
                } = scalar_list;
                let pb_scalar_type = opt_scalar_type.as_ref().ok_or_else(|| {
                    proto_error(
                        "Protobuf deserialization error: ScalarListValue missing required field 'datatype'",
                    )
                })?;
                let typechecked_values: Vec<ScalarValue> = values
                    .iter()
                    .map(|val| val.try_into())
                    .collect::<Result<Vec<_>, _>>()?;
                let scalar_type: DataType = pb_scalar_type.try_into()?;
                ScalarValue::List(Some(Box::new(typechecked_values)), Box::new(scalar_type))
            }
            protobuf::scalar_value::Value::NullListValue(v) => {
                let pb_datatype = v.datatype.as_ref().ok_or_else(|| {
                    proto_error(
                        "Protobuf deserialization error: NullListValue message missing required field 'datatype'",
                    )
                })?;
                let pb_datatype = Box::new(pb_datatype.try_into()?);
                ScalarValue::List(None, pb_datatype)
            }
            protobuf::scalar_value::Value::NullValue(v) => {
                let null_type_enum = protobuf::PrimitiveScalarType::from_i32(*v)
                    .ok_or_else(|| proto_error("Protobuf deserialization error found invalid enum variant for DatafusionScalar"))?;
                null_type_enum.try_into()?
            }
        })
    }
}
impl TryInto<DataType> for &protobuf::ScalarType {
    type Error = PlanSerDeError;
    /// Unwraps the mandatory `datatype` oneof and converts it into an Arrow
    /// [`DataType`].
    ///
    /// # Errors
    /// Fails when `datatype` is unset or when the inner conversion fails.
    fn try_into(self) -> Result<DataType, Self::Error> {
        match &self.datatype {
            Some(inner) => inner.try_into(),
            None => Err(proto_error(
                "ScalarType message missing required field 'datatype'",
            )),
        }
    }
}
impl TryInto<DataType> for &protobuf::scalar_type::Datatype {
    type Error = PlanSerDeError;
    /// Converts a `scalar_type::Datatype` oneof into an Arrow [`DataType`].
    ///
    /// `Scalar` maps a primitive enum discriminant to its Arrow type. `List`
    /// builds nested `List` types:
    /// - the last entry of `field_names` names the innermost list, whose items
    ///   have `deepest_type`;
    /// - each earlier name wraps the previous type in one more `List` layer, so
    ///   the first name belongs to the outermost list.
    ///
    /// Every list field is created as nullable.
    ///
    /// # Errors
    /// Fails on an unknown primitive discriminant or an empty `field_names`.
    fn try_into(self) -> Result<DataType, Self::Error> {
        use protobuf::scalar_type::Datatype;
        match self {
            Datatype::Scalar(raw_type) => {
                let primitive = protobuf::PrimitiveScalarType::from_i32(*raw_type)
                    .ok_or_else(|| {
                        proto_error(format!(
                            "Protobuf deserialization error, scalar_type::Datatype missing was provided invalid enum variant: {}",
                            *raw_type
                        ))
                    })?;
                Ok(primitive.into())
            }
            Datatype::List(protobuf::ScalarListType {
                deepest_type,
                field_names,
            }) => {
                let (innermost_name, outer_names) = match field_names.split_last() {
                    Some(parts) => parts,
                    None => {
                        return Err(proto_error(
                            "Protobuf deserialization error: found no field names in ScalarListType message which requires at least one",
                        ));
                    }
                };
                let leaf = protobuf::PrimitiveScalarType::from_i32(*deepest_type)
                    .ok_or_else(|| {
                        proto_error(format!(
                            "Protobuf deserialization error: invalid i32 for scalar enum: {}",
                            *deepest_type
                        ))
                    })?;
                let innermost = DataType::List(Box::new(Field::new(
                    innermost_name.as_str(),
                    leaf.into(),
                    true,
                )));
                // Wrap outward: walk the remaining names from inner to outer.
                let nested = outer_names.iter().rev().fold(innermost, |child, name| {
                    DataType::List(Box::new(Field::new(name.as_str(), child, true)))
                });
                Ok(nested)
            }
        }
    }
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
use protobuf::PrimitiveScalarType;
let scalar = match self {
protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
protobuf::scalar_value::Value::Utf8Value(v) => {
ScalarValue::Utf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::LargeUtf8Value(v) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::Int8Value(v) => {
ScalarValue::Int8(Some(*v as i8))
}
protobuf::scalar_value::Value::Int16Value(v) => {
ScalarValue::Int16(Some(*v as i16))
}
protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
protobuf::scalar_value::Value::Uint8Value(v) => {
ScalarValue::UInt8(Some(*v as u8))
}
protobuf::scalar_value::Value::Uint16Value(v) => {
ScalarValue::UInt16(Some(*v as u16))
}
protobuf::scalar_value::Value::Uint32Value(v) => {
ScalarValue::UInt32(Some(*v))
}
protobuf::scalar_value::Value::Uint64Value(v) => {
ScalarValue::UInt64(Some(*v))
}
protobuf::scalar_value::Value::Float32Value(v) => {
ScalarValue::Float32(Some(*v))
}
protobuf::scalar_value::Value::Float64Value(v) => {
ScalarValue::Float64(Some(*v))
}
protobuf::scalar_value::Value::Date32Value(v) => {
ScalarValue::Date32(Some(*v))
}
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
protobuf::scalar_value::Value::NullListValue(v) => {
ScalarValue::List(None, Box::new(v.try_into()?))
}
protobuf::scalar_value::Value::NullValue(null_enum) => {
PrimitiveScalarType::from_i32(*null_enum)
.ok_or_else(|| proto_error("Invalid scalar type"))?
.try_into()?
}
protobuf::scalar_value::Value::DecimalValue(v) => {
let decimal = v.decimal.as_ref().unwrap();
ScalarValue::Decimal128(
Some(v.long_value as i128),
decimal.whole as usize,
decimal.fractional as usize,
)
}
};
Ok(scalar)
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarListValue {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
use protobuf::scalar_type::Datatype;
use protobuf::PrimitiveScalarType;
let protobuf::ScalarListValue { datatype, values } = self;
let pb_scalar_type = datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: ScalarListValue messsage missing required field 'datatype'"))?;
let scalar_type = pb_scalar_type
.datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: ScalarListValue.Datatype messsage missing required field 'datatype'"))?;
let scalar_values = match scalar_type {
Datatype::Scalar(scalar_type_i32) => {
let leaf_scalar_type =
protobuf::PrimitiveScalarType::from_i32(*scalar_type_i32)
.ok_or_else(|| {
proto_error("Error converting i32 to basic scalar type")
})?;
let typechecked_values: Vec<datafusion::scalar::ScalarValue> = values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: missing required field 'value'",
)
})?;
typechecked_scalar_value_conversion(value, leaf_scalar_type)
})
.collect::<Result<Vec<_>, _>>()?;
datafusion::scalar::ScalarValue::List(
Some(Box::new(typechecked_values)),
Box::new(leaf_scalar_type.into()),
)
}
Datatype::List(list_type) => {
let protobuf::ScalarListType {
deepest_type,
field_names,
} = &list_type;
let leaf_type =
PrimitiveScalarType::from_i32(*deepest_type).ok_or_else(|| {
proto_error("Error converting i32 to basic scalar type")
})?;
let depth = field_names.len();
let typechecked_values: Vec<datafusion::scalar::ScalarValue> = if depth
== 0
{
return Err(proto_error(
"Protobuf deserialization error, ScalarListType had no field names, requires at least one",
));
} else if depth == 1 {
values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: missing required field 'value'"))?;
typechecked_scalar_value_conversion(value, leaf_type)
})
.collect::<Result<Vec<_>, _>>()?
} else {
values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: missing required field 'value'"))?;
value.try_into()
})
.collect::<Result<Vec<_>, _>>()?
};
datafusion::scalar::ScalarValue::List(
match typechecked_values.len() {
0 => None,
_ => Some(Box::new(typechecked_values)),
},
Box::new(list_type.try_into()?),
)
}
};
Ok(scalar_values)
}
}
impl TryInto<DataType> for &protobuf::ScalarListType {
    type Error = PlanSerDeError;
    /// Builds the nested Arrow `List` type described by a `ScalarListType`.
    ///
    /// The last entry of `field_names` names the innermost list, whose items
    /// have `deepest_type`. Each earlier name wraps the result in one more
    /// nullable `List` layer, so the first name is the outermost list.
    ///
    /// # Errors
    /// Fails when `field_names` is empty or `deepest_type` is not a valid
    /// primitive discriminant.
    fn try_into(self) -> Result<DataType, Self::Error> {
        use protobuf::PrimitiveScalarType;
        let (innermost_name, outer_names) = match self.field_names.split_last() {
            Some(parts) => parts,
            None => {
                return Err(proto_error(
                    "Protobuf deserialization error: Found a ScalarListType message with no field names, at least one is required",
                ));
            }
        };
        let leaf: DataType = PrimitiveScalarType::from_i32(self.deepest_type)
            .ok_or_else(|| proto_error("Could not convert to datafusion scalar type"))?
            .into();
        let innermost =
            DataType::List(Box::new(Field::new(innermost_name.as_str(), leaf, true)));
        // Wrap outward: walk the remaining names from inner to outer.
        let nested = outer_names.iter().rev().fold(innermost, |child, name| {
            DataType::List(Box::new(Field::new(name.as_str(), child, true)))
        });
        Ok(nested)
    }
}
//Does not typecheck lists
fn typechecked_scalar_value_conversion(
tested_type: &protobuf::scalar_value::Value,
required_type: protobuf::PrimitiveScalarType,
) -> Result<datafusion::scalar::ScalarValue, PlanSerDeError> {
use protobuf::scalar_value::Value;
use protobuf::PrimitiveScalarType;
Ok(match (tested_type, &required_type) {
(Value::BoolValue(v), PrimitiveScalarType::Bool) => {
ScalarValue::Boolean(Some(*v))
}
(Value::Int8Value(v), PrimitiveScalarType::Int8) => {
ScalarValue::Int8(Some(*v as i8))
}
(Value::Int16Value(v), PrimitiveScalarType::Int16) => {
ScalarValue::Int16(Some(*v as i16))
}
(Value::Int32Value(v), PrimitiveScalarType::Int32) => {
ScalarValue::Int32(Some(*v))
}
(Value::Int64Value(v), PrimitiveScalarType::Int64) => {
ScalarValue::Int64(Some(*v))
}
(Value::Uint8Value(v), PrimitiveScalarType::Uint8) => {
ScalarValue::UInt8(Some(*v as u8))
}
(Value::Uint16Value(v), PrimitiveScalarType::Uint16) => {
ScalarValue::UInt16(Some(*v as u16))
}
(Value::Uint32Value(v), PrimitiveScalarType::Uint32) => {
ScalarValue::UInt32(Some(*v))
}
(Value::Uint64Value(v), PrimitiveScalarType::Uint64) => {
ScalarValue::UInt64(Some(*v))
}
(Value::Float32Value(v), PrimitiveScalarType::Float32) => {
ScalarValue::Float32(Some(*v))
}
(Value::Float64Value(v), PrimitiveScalarType::Float64) => {
ScalarValue::Float64(Some(*v))
}
(Value::Date32Value(v), PrimitiveScalarType::Date32) => {
ScalarValue::Date32(Some(*v))
}
(Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
(Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
(Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
ScalarValue::Utf8(Some(v.to_owned()))
}
(Value::LargeUtf8Value(v), PrimitiveScalarType::LargeUtf8) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
(Value::NullValue(i32_enum), required_scalar_type) => {
if *i32_enum == *required_scalar_type as i32 {
let pb_scalar_type = PrimitiveScalarType::from_i32(*i32_enum).ok_or_else(|| {
PlanSerDeError::General(format!(
"Invalid i32_enum={} when converting with PrimitiveScalarType::from_i32()",
*i32_enum
))
})?;
let scalar_value: ScalarValue = match pb_scalar_type {
PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimestampMicrosecond(None, None)
}
PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimestampNanosecond(None, None)
}
PrimitiveScalarType::Null => {
return Err(proto_error(
"Untyped scalar null is not a valid scalar value",
))
}
};
scalar_value
} else {
return Err(proto_error("Could not convert to the proper type"));
}
}
_ => return Err(proto_error("Could not convert to the proper type")),
})
}
impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
Ok(match self {
protobuf::PrimitiveScalarType::Null => {
return Err(proto_error("Untyped null is an invalid scalar value"))
}
protobuf::PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
protobuf::PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
protobuf::PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
protobuf::PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
protobuf::PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
protobuf::PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
protobuf::PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
protobuf::PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
protobuf::PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
protobuf::PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
protobuf::PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
protobuf::PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
protobuf::PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimestampMicrosecond(None, None)
}
protobuf::PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimestampNanosecond(None, None)
}
})
}
}
/// Interprets `s` as a single-byte CSV delimiter.
///
/// # Errors
/// Returns `PlanSerDeError::General("Invalid CSV delimiter")` unless `s` is
/// exactly one byte long. Empty strings, multi-character strings and
/// multi-byte UTF-8 characters are all rejected.
fn str_to_byte(s: &str) -> Result<u8, PlanSerDeError> {
    match s.as_bytes() {
        [delimiter] => Ok(*delimiter),
        _ => Err(PlanSerDeError::General("Invalid CSV delimiter".to_owned())),
    }
}