blob: 96c2a8f62daf78af4125a9a731aa9908155ef267 [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::error::PlanSerDeError;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::logical_expr::AggregateFunction;
use datafusion::logical_plan::Operator;
use datafusion::physical_plan::join_utils::JoinSide;
use datafusion::prelude::JoinType;
use datafusion::scalar::ScalarValue;
// include the generated protobuf source as a submodule
#[allow(clippy::all)]
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/plan.protobuf.rs"));
}
pub mod error;
pub mod from_proto;
pub(crate) fn proto_error<S: Into<String>>(message: S) -> PlanSerDeError {
PlanSerDeError::General(message.into())
}
#[macro_export]
macro_rules! convert_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
field.try_into()
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
#[macro_export]
macro_rules! into_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
Ok(field.into())
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
#[macro_export]
macro_rules! convert_box_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
field.as_ref().try_into()
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
pub(crate) fn from_proto_binary_op(op: &str) -> Result<Operator, PlanSerDeError> {
match op {
"And" => Ok(Operator::And),
"Or" => Ok(Operator::Or),
"Eq" => Ok(Operator::Eq),
"NotEq" => Ok(Operator::NotEq),
"LtEq" => Ok(Operator::LtEq),
"Lt" => Ok(Operator::Lt),
"Gt" => Ok(Operator::Gt),
"GtEq" => Ok(Operator::GtEq),
"Plus" => Ok(Operator::Plus),
"Minus" => Ok(Operator::Minus),
"Multiply" => Ok(Operator::Multiply),
"Divide" => Ok(Operator::Divide),
"Modulo" => Ok(Operator::Modulo),
"Like" => Ok(Operator::Like),
"NotLike" => Ok(Operator::NotLike),
other => Err(proto_error(format!(
"Unsupported binary operator '{:?}'",
other
))),
}
}
impl From<protobuf::JoinType> for JoinType {
fn from(t: protobuf::JoinType) -> Self {
match t {
protobuf::JoinType::Inner => JoinType::Inner,
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
protobuf::JoinType::Semi => JoinType::Semi,
protobuf::JoinType::Anti => JoinType::Anti,
}
}
}
impl From<protobuf::JoinSide> for JoinSide {
fn from(t: protobuf::JoinSide) -> Self {
match t {
protobuf::JoinSide::LeftSide => JoinSide::Left,
protobuf::JoinSide::RightSide => JoinSide::Right,
}
}
}
impl From<protobuf::AggregateFunction> for AggregateFunction {
fn from(agg_fun: protobuf::AggregateFunction) -> AggregateFunction {
match agg_fun {
protobuf::AggregateFunction::Min => AggregateFunction::Min,
protobuf::AggregateFunction::Max => AggregateFunction::Max,
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
protobuf::AggregateFunction::ApproxDistinct => {
AggregateFunction::ApproxDistinct
}
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
protobuf::AggregateFunction::Variance => AggregateFunction::Variance,
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
protobuf::AggregateFunction::Covariance => AggregateFunction::Covariance,
protobuf::AggregateFunction::CovariancePop => {
AggregateFunction::CovariancePop
}
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
}
}
}
impl protobuf::TimeUnit {
pub fn from_arrow_time_unit(val: &TimeUnit) -> Self {
match val {
TimeUnit::Second => protobuf::TimeUnit::Second,
TimeUnit::Millisecond => protobuf::TimeUnit::TimeMillisecond,
TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond,
TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond,
}
}
pub fn from_i32_to_arrow(time_unit_i32: i32) -> Result<TimeUnit, PlanSerDeError> {
let pb_time_unit = protobuf::TimeUnit::from_i32(time_unit_i32);
match pb_time_unit {
Some(time_unit) => Ok(match time_unit {
protobuf::TimeUnit::Second => TimeUnit::Second,
protobuf::TimeUnit::TimeMillisecond => TimeUnit::Millisecond,
protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond,
protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
}),
None => Err(proto_error(
"Error converting i32 to TimeUnit: Passed invalid variant",
)),
}
}
}
impl protobuf::IntervalUnit {
pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
match interval_unit {
IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth,
IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
}
}
pub fn from_i32_to_arrow(
interval_unit_i32: i32,
) -> Result<IntervalUnit, PlanSerDeError> {
let pb_interval_unit = protobuf::IntervalUnit::from_i32(interval_unit_i32);
match pb_interval_unit {
Some(interval_unit) => Ok(match interval_unit {
protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth,
protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime,
protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
}),
None => Err(proto_error(
"Error converting i32 to DateUnit: Passed invalid variant",
)),
}
}
}
impl TryInto<datafusion::arrow::datatypes::DataType>
for &protobuf::arrow_type::ArrowTypeEnum
{
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::arrow::datatypes::DataType, Self::Error> {
use protobuf::arrow_type;
Ok(match self {
arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
DataType::FixedSizeBinary(*size)
}
arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
arrow_type::ArrowTypeEnum::Duration(time_unit) => {
DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp {
time_unit,
timezone,
}) => DataType::Timestamp(
protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?,
match timezone.len() {
0 => None,
_ => Some(timezone.to_owned()),
},
),
arrow_type::ArrowTypeEnum::Time32(time_unit) => {
DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Time64(time_unit) => {
DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval(
protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?,
),
arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal {
whole,
fractional,
}) => DataType::Decimal(*whole as usize, *fractional as usize),
arrow_type::ArrowTypeEnum::List(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
DataType::List(Box::new(list_type.try_into()?))
}
arrow_type::ArrowTypeEnum::LargeList(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
DataType::LargeList(Box::new(list_type.try_into()?))
}
arrow_type::ArrowTypeEnum::FixedSizeList(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
let list_size = list.list_size;
DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size)
}
arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
strct
.sub_field_types
.iter()
.map(|field| field.try_into())
.collect::<Result<Vec<_>, _>>()?,
),
arrow_type::ArrowTypeEnum::Union(_union) => {
// let union_mode = protobuf::UnionMode::from_i32(union.union_mode)
// .ok_or_else(|| {
// proto_error(
// "Protobuf deserialization error: Unknown union mode type",
// )
// })?;
// let union_mode = match union_mode {
// protobuf::UnionMode::Dense => UnionMode::Dense,
// protobuf::UnionMode::Sparse => UnionMode::Sparse,
// };
// let union_types = union
// .union_types
// .iter()
// .map(|field| field.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// DataType::Union(union_types, _, union_mode)
unimplemented!()
}
arrow_type::ArrowTypeEnum::Dictionary(dict) => {
let pb_key_datatype = dict
.as_ref()
.key
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
let pb_value_datatype = dict
.as_ref()
.value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?;
let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?;
DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype))
}
})
}
}
#[allow(clippy::from_over_into)]
impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarType {
fn into(self) -> datafusion::arrow::datatypes::DataType {
match self {
protobuf::PrimitiveScalarType::Bool => DataType::Boolean,
protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8,
protobuf::PrimitiveScalarType::Int8 => DataType::Int8,
protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16,
protobuf::PrimitiveScalarType::Int16 => DataType::Int16,
protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32,
protobuf::PrimitiveScalarType::Int32 => DataType::Int32,
protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64,
protobuf::PrimitiveScalarType::Int64 => DataType::Int64,
protobuf::PrimitiveScalarType::Float32 => DataType::Float32,
protobuf::PrimitiveScalarType::Float64 => DataType::Float64,
protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8,
protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8,
protobuf::PrimitiveScalarType::Date32 => DataType::Date32,
protobuf::PrimitiveScalarType::TimeMicrosecond => {
DataType::Time64(TimeUnit::Microsecond)
}
protobuf::PrimitiveScalarType::TimeNanosecond => {
DataType::Time64(TimeUnit::Nanosecond)
}
protobuf::PrimitiveScalarType::Null => DataType::Null,
protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal(0, 0),
protobuf::PrimitiveScalarType::Date64 => DataType::Date64,
protobuf::PrimitiveScalarType::TimeSecond => {
DataType::Timestamp(TimeUnit::Second, None)
}
protobuf::PrimitiveScalarType::TimeMillisecond => {
DataType::Timestamp(TimeUnit::Millisecond, None)
}
protobuf::PrimitiveScalarType::IntervalYearmonth => {
DataType::Interval(IntervalUnit::YearMonth)
}
protobuf::PrimitiveScalarType::IntervalDaytime => {
DataType::Interval(IntervalUnit::DayTime)
}
}
}
}
impl TryInto<DataType> for &protobuf::ArrowType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let pb_arrow_type = self.arrow_type_enum.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: ArrowType missing required field 'data_type'",
)
})?;
pb_arrow_type.try_into()
}
}
impl TryInto<DataType> for &Box<protobuf::List> {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let list_ref = self.as_ref();
match &list_ref.field_type {
Some(pb_field) => {
let pb_field_ref = pb_field.as_ref();
let arrow_field: Field = pb_field_ref.try_into()?;
Ok(DataType::List(Box::new(arrow_field)))
}
None => Err(proto_error(
"List message missing required field 'field_type'",
)),
}
}
}
impl TryInto<Field> for &protobuf::Field {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Field, Self::Error> {
let pb_datatype = self.arrow_type.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: Field message missing required field 'arrow_type'",
)
})?;
Ok(Field::new(
self.name.as_str(),
pb_datatype.as_ref().try_into()?,
self.nullable,
))
}
}
impl TryInto<Schema> for &protobuf::Schema {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Schema, PlanSerDeError> {
let fields = self
.columns
.iter()
.map(|c| {
let pb_arrow_type_res = c
.arrow_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'"));
let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res {
Ok(res) => res,
Err(e) => return Err(e),
};
Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Schema::new(fields))
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
let value = self.value.as_ref().ok_or_else(|| {
proto_error("Protobuf deserialization error: missing required field 'value'")
})?;
Ok(match value {
protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
protobuf::scalar_value::Value::Utf8Value(v) => {
ScalarValue::Utf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::LargeUtf8Value(v) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::Int8Value(v) => {
ScalarValue::Int8(Some(*v as i8))
}
protobuf::scalar_value::Value::Int16Value(v) => {
ScalarValue::Int16(Some(*v as i16))
}
protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
protobuf::scalar_value::Value::Uint8Value(v) => {
ScalarValue::UInt8(Some(*v as u8))
}
protobuf::scalar_value::Value::Uint16Value(v) => {
ScalarValue::UInt16(Some(*v as u16))
}
protobuf::scalar_value::Value::Uint32Value(v) => {
ScalarValue::UInt32(Some(*v))
}
protobuf::scalar_value::Value::Uint64Value(v) => {
ScalarValue::UInt64(Some(*v))
}
protobuf::scalar_value::Value::Float32Value(v) => {
ScalarValue::Float32(Some(*v))
}
protobuf::scalar_value::Value::Float64Value(v) => {
ScalarValue::Float64(Some(*v))
}
protobuf::scalar_value::Value::Date32Value(v) => {
ScalarValue::Date32(Some(*v))
}
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
protobuf::scalar_value::Value::DecimalValue(v) => {
let decimal = v.decimal.as_ref().unwrap();
ScalarValue::Decimal128(
Some(v.long_value as i128),
decimal.whole as usize,
decimal.fractional as usize,
)
}
protobuf::scalar_value::Value::ListValue(scalar_list) => {
let protobuf::ScalarListValue {
values,
datatype: opt_scalar_type,
} = &scalar_list;
let pb_scalar_type = opt_scalar_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization err: ScalaListValue missing required field 'datatype'"))?;
let typechecked_values: Vec<ScalarValue> = values
.iter()
.map(|val| val.try_into())
.collect::<Result<Vec<_>, _>>()?;
let scalar_type: DataType = pb_scalar_type.try_into()?;
let scalar_type = Box::new(scalar_type);
ScalarValue::List(Some(typechecked_values), scalar_type)
}
protobuf::scalar_value::Value::NullListValue(v) => {
let pb_datatype = v
.datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: NullListValue message missing required field 'datatyp'"))?;
let pb_datatype = Box::new(pb_datatype.try_into()?);
ScalarValue::List(None, pb_datatype)
}
protobuf::scalar_value::Value::NullValue(v) => {
let null_type_enum = protobuf::PrimitiveScalarType::from_i32(*v)
.ok_or_else(|| proto_error("Protobuf deserialization error found invalid enum variant for DatafusionScalar"))?;
null_type_enum.try_into()?
}
})
}
}
impl TryInto<DataType> for &protobuf::ScalarType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let pb_scalartype = self.datatype.as_ref().ok_or_else(|| {
proto_error("ScalarType message missing required field 'datatype'")
})?;
pb_scalartype.try_into()
}
}
impl TryInto<DataType> for &protobuf::scalar_type::Datatype {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
use protobuf::scalar_type::Datatype;
Ok(match self {
Datatype::Scalar(scalar_type) => {
let pb_scalar_enum = protobuf::PrimitiveScalarType::from_i32(*scalar_type).ok_or_else(|| {
proto_error(format!(
"Protobuf deserialization error, scalar_type::Datatype missing was provided invalid enum variant: {}",
*scalar_type
))
})?;
pb_scalar_enum.into()
}
Datatype::List(protobuf::ScalarListType {
deepest_type,
field_names,
}) => {
if field_names.is_empty() {
return Err(proto_error(
"Protobuf deserialization error: found no field names in ScalarListType message which requires at least one",
));
}
let pb_scalar_type = protobuf::PrimitiveScalarType::from_i32(
*deepest_type,
)
.ok_or_else(|| {
proto_error(format!(
"Protobuf deserialization error: invalid i32 for scalar enum: {}",
*deepest_type
))
})?;
//Because length is checked above it is safe to unwrap .last()
let mut scalar_type = DataType::List(Box::new(Field::new(
field_names.last().unwrap().as_str(),
pb_scalar_type.into(),
true,
)));
//Iterate over field names in reverse order except for the last item in the vector
for name in field_names.iter().rev().skip(1) {
let new_datatype = DataType::List(Box::new(Field::new(
name.as_str(),
scalar_type,
true,
)));
scalar_type = new_datatype;
}
scalar_type
}
})
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
use protobuf::PrimitiveScalarType;
let scalar = match self {
protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
protobuf::scalar_value::Value::Utf8Value(v) => {
ScalarValue::Utf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::LargeUtf8Value(v) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::Int8Value(v) => {
ScalarValue::Int8(Some(*v as i8))
}
protobuf::scalar_value::Value::Int16Value(v) => {
ScalarValue::Int16(Some(*v as i16))
}
protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
protobuf::scalar_value::Value::Uint8Value(v) => {
ScalarValue::UInt8(Some(*v as u8))
}
protobuf::scalar_value::Value::Uint16Value(v) => {
ScalarValue::UInt16(Some(*v as u16))
}
protobuf::scalar_value::Value::Uint32Value(v) => {
ScalarValue::UInt32(Some(*v))
}
protobuf::scalar_value::Value::Uint64Value(v) => {
ScalarValue::UInt64(Some(*v))
}
protobuf::scalar_value::Value::Float32Value(v) => {
ScalarValue::Float32(Some(*v))
}
protobuf::scalar_value::Value::Float64Value(v) => {
ScalarValue::Float64(Some(*v))
}
protobuf::scalar_value::Value::Date32Value(v) => {
ScalarValue::Date32(Some(*v))
}
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
protobuf::scalar_value::Value::NullListValue(v) => {
ScalarValue::List(None, Box::new(v.try_into()?))
}
protobuf::scalar_value::Value::NullValue(null_enum) => {
PrimitiveScalarType::from_i32(*null_enum)
.ok_or_else(|| proto_error("Invalid scalar type"))?
.try_into()?
}
protobuf::scalar_value::Value::DecimalValue(v) => {
let decimal = v.decimal.as_ref().unwrap();
ScalarValue::Decimal128(
Some(v.long_value as i128),
decimal.whole as usize,
decimal.fractional as usize,
)
}
};
Ok(scalar)
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarListValue {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
use protobuf::scalar_type::Datatype;
use protobuf::PrimitiveScalarType;
let protobuf::ScalarListValue { datatype, values } = self;
let pb_scalar_type = datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: ScalarListValue messsage missing required field 'datatype'"))?;
let scalar_type = pb_scalar_type
.datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: ScalarListValue.Datatype messsage missing required field 'datatype'"))?;
let scalar_values = match scalar_type {
Datatype::Scalar(scalar_type_i32) => {
let leaf_scalar_type =
protobuf::PrimitiveScalarType::from_i32(*scalar_type_i32)
.ok_or_else(|| {
proto_error("Error converting i32 to basic scalar type")
})?;
let typechecked_values: Vec<datafusion::scalar::ScalarValue> = values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: missing required field 'value'",
)
})?;
typechecked_scalar_value_conversion(value, leaf_scalar_type)
})
.collect::<Result<Vec<_>, _>>()?;
datafusion::scalar::ScalarValue::List(
Some(typechecked_values),
Box::new(leaf_scalar_type.into()),
)
}
Datatype::List(list_type) => {
let protobuf::ScalarListType {
deepest_type,
field_names,
} = &list_type;
let leaf_type =
PrimitiveScalarType::from_i32(*deepest_type).ok_or_else(|| {
proto_error("Error converting i32 to basic scalar type")
})?;
let depth = field_names.len();
let typechecked_values: Vec<datafusion::scalar::ScalarValue> = if depth
== 0
{
return Err(proto_error(
"Protobuf deserialization error, ScalarListType had no field names, requires at least one",
));
} else if depth == 1 {
values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: missing required field 'value'"))?;
typechecked_scalar_value_conversion(value, leaf_type)
})
.collect::<Result<Vec<_>, _>>()?
} else {
values
.iter()
.map(|protobuf::ScalarValue { value: opt_value }| {
let value = opt_value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: missing required field 'value'"))?;
value.try_into()
})
.collect::<Result<Vec<_>, _>>()?
};
datafusion::scalar::ScalarValue::List(
match typechecked_values.len() {
0 => None,
_ => Some(typechecked_values),
},
Box::new(list_type.try_into()?),
)
}
};
Ok(scalar_values)
}
}
impl TryInto<DataType> for &protobuf::ScalarListType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
use protobuf::PrimitiveScalarType;
let protobuf::ScalarListType {
deepest_type,
field_names,
} = self;
let depth = field_names.len();
if depth == 0 {
return Err(proto_error(
"Protobuf deserialization error: Found a ScalarListType message with no field names, at least one is required",
));
}
let mut curr_type = DataType::List(Box::new(Field::new(
//Since checked vector is not empty above this is safe to unwrap
field_names.last().unwrap(),
PrimitiveScalarType::from_i32(*deepest_type)
.ok_or_else(|| {
proto_error("Could not convert to datafusion scalar type")
})?
.into(),
true,
)));
//Iterates over field names in reverse order except for the last item in the vector
for name in field_names.iter().rev().skip(1) {
let temp_curr_type =
DataType::List(Box::new(Field::new(name, curr_type, true)));
curr_type = temp_curr_type;
}
Ok(curr_type)
}
}
//Does not typecheck lists
fn typechecked_scalar_value_conversion(
tested_type: &protobuf::scalar_value::Value,
required_type: protobuf::PrimitiveScalarType,
) -> Result<datafusion::scalar::ScalarValue, PlanSerDeError> {
use protobuf::scalar_value::Value;
use protobuf::PrimitiveScalarType;
Ok(match (tested_type, &required_type) {
(Value::BoolValue(v), PrimitiveScalarType::Bool) => {
ScalarValue::Boolean(Some(*v))
}
(Value::Int8Value(v), PrimitiveScalarType::Int8) => {
ScalarValue::Int8(Some(*v as i8))
}
(Value::Int16Value(v), PrimitiveScalarType::Int16) => {
ScalarValue::Int16(Some(*v as i16))
}
(Value::Int32Value(v), PrimitiveScalarType::Int32) => {
ScalarValue::Int32(Some(*v))
}
(Value::Int64Value(v), PrimitiveScalarType::Int64) => {
ScalarValue::Int64(Some(*v))
}
(Value::Uint8Value(v), PrimitiveScalarType::Uint8) => {
ScalarValue::UInt8(Some(*v as u8))
}
(Value::Uint16Value(v), PrimitiveScalarType::Uint16) => {
ScalarValue::UInt16(Some(*v as u16))
}
(Value::Uint32Value(v), PrimitiveScalarType::Uint32) => {
ScalarValue::UInt32(Some(*v))
}
(Value::Uint64Value(v), PrimitiveScalarType::Uint64) => {
ScalarValue::UInt64(Some(*v))
}
(Value::Float32Value(v), PrimitiveScalarType::Float32) => {
ScalarValue::Float32(Some(*v))
}
(Value::Float64Value(v), PrimitiveScalarType::Float64) => {
ScalarValue::Float64(Some(*v))
}
(Value::Date32Value(v), PrimitiveScalarType::Date32) => {
ScalarValue::Date32(Some(*v))
}
(Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
(Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
(Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
ScalarValue::Utf8(Some(v.to_owned()))
}
(Value::LargeUtf8Value(v), PrimitiveScalarType::LargeUtf8) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
(Value::NullValue(i32_enum), required_scalar_type) => {
if *i32_enum == *required_scalar_type as i32 {
let pb_scalar_type = PrimitiveScalarType::from_i32(*i32_enum).ok_or_else(|| {
PlanSerDeError::General(format!(
"Invalid i32_enum={} when converting with PrimitiveScalarType::from_i32()",
*i32_enum
))
})?;
let scalar_value: ScalarValue = match pb_scalar_type {
PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimestampMicrosecond(None, None)
}
PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimestampNanosecond(None, None)
}
PrimitiveScalarType::Null => {
return Err(proto_error(
"Untyped scalar null is not a valid scalar value",
))
}
PrimitiveScalarType::Decimal128 => {
ScalarValue::Decimal128(None, 0, 0)
}
PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
PrimitiveScalarType::TimeSecond => {
ScalarValue::TimestampSecond(None, None)
}
PrimitiveScalarType::TimeMillisecond => {
ScalarValue::TimestampMillisecond(None, None)
}
PrimitiveScalarType::IntervalYearmonth => {
ScalarValue::IntervalYearMonth(None)
}
PrimitiveScalarType::IntervalDaytime => {
ScalarValue::IntervalDayTime(None)
}
};
scalar_value
} else {
return Err(proto_error("Could not convert to the proper type"));
}
}
_ => return Err(proto_error("Could not convert to the proper type")),
})
}
impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
Ok(match self {
protobuf::PrimitiveScalarType::Null => {
return Err(proto_error("Untyped null is an invalid scalar value"))
}
protobuf::PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
protobuf::PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
protobuf::PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
protobuf::PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
protobuf::PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
protobuf::PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
protobuf::PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
protobuf::PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
protobuf::PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
protobuf::PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
protobuf::PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
protobuf::PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
protobuf::PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimestampMicrosecond(None, None)
}
protobuf::PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimestampNanosecond(None, None)
}
protobuf::PrimitiveScalarType::Decimal128 => {
ScalarValue::Decimal128(None, 0, 0)
}
protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
protobuf::PrimitiveScalarType::TimeSecond => {
ScalarValue::TimestampSecond(None, None)
}
protobuf::PrimitiveScalarType::TimeMillisecond => {
ScalarValue::TimestampMillisecond(None, None)
}
protobuf::PrimitiveScalarType::IntervalYearmonth => {
ScalarValue::IntervalYearMonth(None)
}
protobuf::PrimitiveScalarType::IntervalDaytime => {
ScalarValue::IntervalDayTime(None)
}
})
}
}