blob: 56cd4a6bf61f3e25775007c3c5c65a6901edcb4a [file] [log] [blame]
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit};
use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue};
use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType};
use crate::error::PlanSerDeError;
// include the generated protobuf source as a submodule
#[allow(clippy::all)]
pub mod protobuf {
include!(concat!(env!("OUT_DIR"), "/plan.protobuf.rs"));
}
pub mod error;
pub mod from_proto;
pub(crate) fn proto_error<S: Into<String>>(message: S) -> PlanSerDeError {
PlanSerDeError::General(message.into())
}
#[macro_export]
macro_rules! convert_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
field.try_into()
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
#[macro_export]
macro_rules! into_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
Ok(field.into())
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
#[macro_export]
macro_rules! convert_box_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
field.as_ref().try_into()
} else {
Err(proto_error("Missing required field in protobuf"))
}
}};
}
pub fn from_proto_binary_op(op: &str) -> Result<Operator, PlanSerDeError> {
match op {
"And" => Ok(Operator::And),
"Or" => Ok(Operator::Or),
"Eq" => Ok(Operator::Eq),
"NotEq" => Ok(Operator::NotEq),
"LtEq" => Ok(Operator::LtEq),
"Lt" => Ok(Operator::Lt),
"Gt" => Ok(Operator::Gt),
"GtEq" => Ok(Operator::GtEq),
"Plus" => Ok(Operator::Plus),
"Minus" => Ok(Operator::Minus),
"Multiply" => Ok(Operator::Multiply),
"Divide" => Ok(Operator::Divide),
"Modulo" => Ok(Operator::Modulo),
"IsDistinctFrom" => Ok(Operator::IsDistinctFrom),
"IsNotDistinctFrom" => Ok(Operator::IsNotDistinctFrom),
"BitwiseAnd" => Ok(Operator::BitwiseAnd),
"BitwiseOr" => Ok(Operator::BitwiseOr),
"BitwiseXor" => Ok(Operator::BitwiseXor),
"BitwiseShiftLeft" => Ok(Operator::BitwiseShiftLeft),
"BitwiseShiftRight" => Ok(Operator::BitwiseShiftRight),
"RegexIMatch" => Ok(Operator::RegexIMatch),
"RegexMatch" => Ok(Operator::RegexMatch),
"RegexNotIMatch" => Ok(Operator::RegexNotIMatch),
"RegexNotMatch" => Ok(Operator::RegexNotMatch),
"StringConcat" => Ok(Operator::StringConcat),
other => Err(proto_error(format!(
"Unsupported binary operator '{:?}'",
other
))),
}
}
impl From<protobuf::JoinType> for JoinType {
fn from(t: protobuf::JoinType) -> Self {
match t {
protobuf::JoinType::Inner => JoinType::Inner,
protobuf::JoinType::Left => JoinType::Left,
protobuf::JoinType::Right => JoinType::Right,
protobuf::JoinType::Full => JoinType::Full,
protobuf::JoinType::Semi => JoinType::LeftSemi,
protobuf::JoinType::Anti => JoinType::LeftAnti,
protobuf::JoinType::Existence => JoinType::Existence,
}
}
}
impl From<protobuf::JoinSide> for JoinSide {
fn from(t: protobuf::JoinSide) -> Self {
match t {
protobuf::JoinSide::LeftSide => JoinSide::Left,
protobuf::JoinSide::RightSide => JoinSide::Right,
}
}
}
impl From<protobuf::AggFunction> for AggFunction {
fn from(agg_fun: protobuf::AggFunction) -> AggFunction {
match agg_fun {
protobuf::AggFunction::Min => AggFunction::Min,
protobuf::AggFunction::Max => AggFunction::Max,
protobuf::AggFunction::Sum => AggFunction::Sum,
protobuf::AggFunction::Avg => AggFunction::Avg,
protobuf::AggFunction::Count => AggFunction::Count,
protobuf::AggFunction::CollectList => AggFunction::CollectList,
protobuf::AggFunction::CollectSet => AggFunction::CollectSet,
protobuf::AggFunction::First => AggFunction::First,
protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull,
protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect,
protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique,
}
}
}
impl protobuf::TimeUnit {
pub fn from_arrow_time_unit(val: &TimeUnit) -> Self {
match val {
TimeUnit::Second => protobuf::TimeUnit::Second,
TimeUnit::Millisecond => protobuf::TimeUnit::Millisecond,
TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond,
TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond,
}
}
pub fn try_from_to_arrow(time_unit_i32: i32) -> Result<TimeUnit, PlanSerDeError> {
let pb_time_unit = protobuf::TimeUnit::try_from(time_unit_i32).expect("invalid TimeUnit");
Ok(match pb_time_unit {
protobuf::TimeUnit::Second => TimeUnit::Second,
protobuf::TimeUnit::Millisecond => TimeUnit::Millisecond,
protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond,
protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond,
})
}
}
impl protobuf::IntervalUnit {
pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
match interval_unit {
IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth,
IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
}
}
pub fn try_from_to_arrow(interval_unit_i32: i32) -> Result<IntervalUnit, PlanSerDeError> {
let pb_interval_unit =
protobuf::IntervalUnit::try_from(interval_unit_i32).expect("invalid IntervalUnit");
Ok(match pb_interval_unit {
protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth,
protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime,
protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano,
})
}
}
impl TryInto<arrow::datatypes::DataType> for &protobuf::arrow_type::ArrowTypeEnum {
type Error = PlanSerDeError;
fn try_into(self) -> Result<arrow::datatypes::DataType, Self::Error> {
use protobuf::arrow_type;
Ok(match self {
arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => DataType::FixedSizeBinary(*size),
arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
arrow_type::ArrowTypeEnum::Duration(time_unit) => {
DataType::Duration(protobuf::TimeUnit::try_from_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp {
time_unit,
timezone,
}) => DataType::Timestamp(
protobuf::TimeUnit::try_from_to_arrow(*time_unit)?,
match timezone.len() {
0 => None,
_ => Some(timezone.to_owned().into()),
},
),
arrow_type::ArrowTypeEnum::Time32(time_unit) => {
DataType::Time32(protobuf::TimeUnit::try_from_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Time64(time_unit) => {
DataType::Time64(protobuf::TimeUnit::try_from_to_arrow(*time_unit)?)
}
arrow_type::ArrowTypeEnum::Interval(interval_unit) => {
DataType::Interval(protobuf::IntervalUnit::try_from_to_arrow(*interval_unit)?)
}
arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { whole, fractional }) => {
DataType::Decimal128(*whole as u8, *fractional as i8)
}
arrow_type::ArrowTypeEnum::List(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
DataType::List(Arc::new(list_type.try_into()?))
}
arrow_type::ArrowTypeEnum::LargeList(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
DataType::LargeList(Arc::new(list_type.try_into()?))
}
arrow_type::ArrowTypeEnum::FixedSizeList(list) => {
let list_type: &protobuf::Field = list
.as_ref()
.field_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
.as_ref();
let list_size = list.list_size;
DataType::FixedSizeList(Arc::new(list_type.try_into()?), list_size)
}
arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
strct
.sub_field_types
.iter()
.map(|field| Ok(Arc::new(field.try_into()?)))
.collect::<Result<Fields, Self::Error>>()?,
),
arrow_type::ArrowTypeEnum::Union(_union) => {
// let union_mode = protobuf::UnionMode::try_from(union.union_mode)
// .ok_or_else(|| {
// proto_error(
// "Protobuf deserialization error: Unknown union mode type",
// )
// })?;
// let union_mode = match union_mode {
// protobuf::UnionMode::Dense => UnionMode::Dense,
// protobuf::UnionMode::Sparse => UnionMode::Sparse,
// };
// let union_types = union
// .union_types
// .iter()
// .map(|field| field.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// DataType::Union(union_types, _, union_mode)
unimplemented!()
}
arrow_type::ArrowTypeEnum::Map(map) => {
let key_type: &protobuf::Field = map
.as_ref()
.key_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Map message missing required field 'key_type'"))?
.as_ref();
let value_type: &protobuf::Field = map
.as_ref()
.value_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Map message missing required field 'value_type'"))?
.as_ref();
let vec_field = vec![
Arc::new(key_type.try_into()?),
Arc::new(value_type.try_into()?),
];
let fields = Arc::new(Field::new(
"entries",
DataType::Struct(vec_field.into()),
false,
));
DataType::Map(fields, false)
}
arrow_type::ArrowTypeEnum::Dictionary(dict) => {
let pb_key_datatype = dict
.as_ref()
.key
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
let pb_value_datatype = dict
.as_ref()
.value
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?;
let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?;
DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype))
}
})
}
}
#[allow(clippy::from_over_into)]
impl Into<arrow::datatypes::DataType> for protobuf::PrimitiveScalarType {
fn into(self) -> arrow::datatypes::DataType {
match self {
protobuf::PrimitiveScalarType::Bool => DataType::Boolean,
protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8,
protobuf::PrimitiveScalarType::Int8 => DataType::Int8,
protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16,
protobuf::PrimitiveScalarType::Int16 => DataType::Int16,
protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32,
protobuf::PrimitiveScalarType::Int32 => DataType::Int32,
protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64,
protobuf::PrimitiveScalarType::Int64 => DataType::Int64,
protobuf::PrimitiveScalarType::Float32 => DataType::Float32,
protobuf::PrimitiveScalarType::Float64 => DataType::Float64,
protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8,
protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8,
protobuf::PrimitiveScalarType::Date32 => DataType::Date32,
protobuf::PrimitiveScalarType::Null => DataType::Null,
protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal128(1, 0),
protobuf::PrimitiveScalarType::Date64 => DataType::Date64,
protobuf::PrimitiveScalarType::TimestampSecond => {
DataType::Timestamp(TimeUnit::Second, None)
}
protobuf::PrimitiveScalarType::TimestampMillisecond => {
DataType::Timestamp(TimeUnit::Millisecond, None)
}
protobuf::PrimitiveScalarType::TimestampMicrosecond => {
DataType::Timestamp(TimeUnit::Microsecond, None)
}
protobuf::PrimitiveScalarType::TimestampNanosecond => {
DataType::Timestamp(TimeUnit::Nanosecond, None)
}
protobuf::PrimitiveScalarType::IntervalYearmonth => {
DataType::Interval(IntervalUnit::YearMonth)
}
protobuf::PrimitiveScalarType::IntervalDaytime => {
DataType::Interval(IntervalUnit::DayTime)
}
}
}
}
impl TryInto<DataType> for &protobuf::ArrowType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let pb_arrow_type = self.arrow_type_enum.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: ArrowType missing required field 'data_type'",
)
})?;
pb_arrow_type.try_into()
}
}
impl TryInto<DataType> for &Box<protobuf::List> {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let list_ref = self.as_ref();
match &list_ref.field_type {
Some(pb_field) => {
let pb_field_ref = pb_field.as_ref();
let arrow_field: Field = pb_field_ref.try_into()?;
Ok(DataType::List(Arc::new(arrow_field)))
}
None => Err(proto_error(
"List message missing required field 'field_type'",
)),
}
}
}
impl TryInto<Field> for &protobuf::Field {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Field, Self::Error> {
let pb_datatype = self.arrow_type.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: Field message missing required field 'arrow_type'",
)
})?;
Ok(Field::new(
self.name.as_str(),
pb_datatype.as_ref().try_into()?,
self.nullable,
))
}
}
impl TryInto<Schema> for &protobuf::Schema {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Schema, PlanSerDeError> {
let fields = self
.columns
.iter()
.map(|c| {
let pb_arrow_type_res = c
.arrow_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'"));
let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res {
Ok(res) => res,
Err(e) => return Err(e),
};
Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Schema::new(fields))
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
let value = self.value.as_ref().ok_or_else(|| {
proto_error("Protobuf deserialization error: missing required field 'value'")
})?;
Ok(match value {
protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
protobuf::scalar_value::Value::Utf8Value(v) => ScalarValue::Utf8(Some(v.to_owned())),
protobuf::scalar_value::Value::LargeUtf8Value(v) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::Int8Value(v) => ScalarValue::Int8(Some(*v as i8)),
protobuf::scalar_value::Value::Int16Value(v) => ScalarValue::Int16(Some(*v as i16)),
protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
protobuf::scalar_value::Value::Uint8Value(v) => ScalarValue::UInt8(Some(*v as u8)),
protobuf::scalar_value::Value::Uint16Value(v) => ScalarValue::UInt16(Some(*v as u16)),
protobuf::scalar_value::Value::Uint32Value(v) => ScalarValue::UInt32(Some(*v)),
protobuf::scalar_value::Value::Uint64Value(v) => ScalarValue::UInt64(Some(*v)),
protobuf::scalar_value::Value::Float32Value(v) => ScalarValue::Float32(Some(*v)),
protobuf::scalar_value::Value::Float64Value(v) => ScalarValue::Float64(Some(*v)),
protobuf::scalar_value::Value::Date32Value(v) => ScalarValue::Date32(Some(*v)),
protobuf::scalar_value::Value::TimestampSecondValue(v) => {
ScalarValue::TimestampSecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampMillisecondValue(v) => {
ScalarValue::TimestampMillisecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampMicrosecondValue(v) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampNanosecondValue(v) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
protobuf::scalar_value::Value::DecimalValue(v) => {
let decimal = v.decimal.as_ref().unwrap();
ScalarValue::Decimal128(
Some(v.long_value as i128),
decimal.whole as u8,
decimal.fractional as i8,
)
}
protobuf::scalar_value::Value::ListValue(scalar_list) => {
let protobuf::ScalarListValue {
values,
datatype: opt_scalar_type,
} = &scalar_list;
let pb_scalar_type = opt_scalar_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization err: ScalaListValue missing required field 'datatype'"))?;
let typechecked_values: Vec<ScalarValue> = values
.iter()
.map(|val| val.try_into())
.collect::<Result<Vec<_>, _>>()?;
let scalar_type: DataType = pb_scalar_type.try_into()?;
ScalarValue::List(ScalarValue::new_list(&typechecked_values, &scalar_type))
}
protobuf::scalar_value::Value::NullValue(v) => {
match v.datatype.as_ref().expect("missing scalar data type") {
protobuf::scalar_type::Datatype::Scalar(scalar) => {
let null_type_enum = protobuf::PrimitiveScalarType::try_from(*scalar)
.expect("invalid PrimitiveScalarType");
null_type_enum.try_into()?
}
protobuf::scalar_type::Datatype::List(list) => {
let pb_scalar_type = list
.element_type
.as_ref()
.expect("missing list element type");
let scalar_type: DataType = pb_scalar_type.as_ref().try_into()?;
ScalarValue::try_from(DataType::new_list(scalar_type, true))?
}
}
}
})
}
}
impl TryInto<DataType> for &protobuf::ScalarType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let pb_scalartype = self.datatype.as_ref().expect("missing data type");
pb_scalartype.try_into()
}
}
impl TryInto<DataType> for &protobuf::scalar_type::Datatype {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
use protobuf::scalar_type::Datatype;
Ok(match self {
Datatype::Scalar(scalar_type) => {
let pb_scalar_enum = protobuf::PrimitiveScalarType::try_from(*scalar_type)
.expect("invalid PrimitiveScalarType");
pb_scalar_enum.into()
}
Datatype::List(list_type) => {
let element_scalar_type: DataType = list_type
.element_type
.as_ref()
.expect("missing element type")
.as_ref()
.try_into()?;
DataType::new_list(element_scalar_type, true)
}
})
}
}
impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
use protobuf::PrimitiveScalarType;
let scalar = match self {
protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
protobuf::scalar_value::Value::Utf8Value(v) => ScalarValue::Utf8(Some(v.to_owned())),
protobuf::scalar_value::Value::LargeUtf8Value(v) => {
ScalarValue::LargeUtf8(Some(v.to_owned()))
}
protobuf::scalar_value::Value::Int8Value(v) => ScalarValue::Int8(Some(*v as i8)),
protobuf::scalar_value::Value::Int16Value(v) => ScalarValue::Int16(Some(*v as i16)),
protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
protobuf::scalar_value::Value::Uint8Value(v) => ScalarValue::UInt8(Some(*v as u8)),
protobuf::scalar_value::Value::Uint16Value(v) => ScalarValue::UInt16(Some(*v as u16)),
protobuf::scalar_value::Value::Uint32Value(v) => ScalarValue::UInt32(Some(*v)),
protobuf::scalar_value::Value::Uint64Value(v) => ScalarValue::UInt64(Some(*v)),
protobuf::scalar_value::Value::Float32Value(v) => ScalarValue::Float32(Some(*v)),
protobuf::scalar_value::Value::Float64Value(v) => ScalarValue::Float64(Some(*v)),
protobuf::scalar_value::Value::Date32Value(v) => ScalarValue::Date32(Some(*v)),
protobuf::scalar_value::Value::TimestampSecondValue(v) => {
ScalarValue::TimestampSecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampMillisecondValue(v) => {
ScalarValue::TimestampMillisecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampMicrosecondValue(v) => {
ScalarValue::TimestampMicrosecond(Some(*v), None)
}
protobuf::scalar_value::Value::TimestampNanosecondValue(v) => {
ScalarValue::TimestampNanosecond(Some(*v), None)
}
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
protobuf::scalar_value::Value::NullValue(v) => {
match v.datatype.as_ref().expect("missing null value type") {
protobuf::scalar_type::Datatype::Scalar(scalar) => {
PrimitiveScalarType::try_from(*scalar)
.expect("invalid PrimitiveScalarType")
.try_into()?
}
protobuf::scalar_type::Datatype::List(list) => {
let element_scalar_type: DataType = list
.element_type
.as_ref()
.expect("missing list element type")
.as_ref()
.try_into()?;
ScalarValue::try_from(DataType::new_list(element_scalar_type, true))?
}
}
}
protobuf::scalar_value::Value::DecimalValue(v) => {
let decimal = v.decimal.as_ref().unwrap();
ScalarValue::Decimal128(
Some(v.long_value as i128),
decimal.whole as u8,
decimal.fractional as i8,
)
}
};
Ok(scalar)
}
}
impl TryInto<ScalarValue> for &protobuf::ScalarListValue {
type Error = PlanSerDeError;
fn try_into(self) -> Result<ScalarValue, Self::Error> {
let element_scalar_type: DataType = self
.datatype
.as_ref()
.expect("missing list data type")
.try_into()?;
let values: Vec<ScalarValue> = self
.values
.iter()
.map(|value| Ok(value.try_into()?))
.collect::<Result<_, Self::Error>>()?;
Ok(ScalarValue::List(ScalarValue::new_list(
&values,
&element_scalar_type,
)))
}
}
impl TryInto<DataType> for &protobuf::ScalarListType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<DataType, Self::Error> {
let element_scalar_type: DataType = self
.element_type
.as_ref()
.expect("missing list element type")
.as_ref()
.try_into()?;
Ok(DataType::new_list(element_scalar_type, true))
}
}
impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType {
type Error = PlanSerDeError;
fn try_into(self) -> Result<datafusion::scalar::ScalarValue, Self::Error> {
Ok(match self {
// protobuf::PrimitiveScalarType::Null => {
// return Err(proto_error("Untyped null is an invalid scalar value"))
// }
protobuf::PrimitiveScalarType::Null => ScalarValue::Null,
protobuf::PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
protobuf::PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
protobuf::PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
protobuf::PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
protobuf::PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
protobuf::PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
protobuf::PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
protobuf::PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
protobuf::PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
protobuf::PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
protobuf::PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
protobuf::PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
protobuf::PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 1, 0),
protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
protobuf::PrimitiveScalarType::TimestampSecond => {
ScalarValue::TimestampSecond(None, None)
}
protobuf::PrimitiveScalarType::TimestampMillisecond => {
ScalarValue::TimestampMillisecond(None, None)
}
protobuf::PrimitiveScalarType::TimestampMicrosecond => {
ScalarValue::TimestampMicrosecond(None, None)
}
protobuf::PrimitiveScalarType::TimestampNanosecond => {
ScalarValue::TimestampNanosecond(None, None)
}
protobuf::PrimitiveScalarType::IntervalYearmonth => {
ScalarValue::IntervalYearMonth(None)
}
protobuf::PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
})
}
}