blob: 61bc3be951d3319748afc3114a2a53f92dea577c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::basic::{ConvertedType, Repetition};
use crate::errors::ParquetError::General;
use crate::errors::Result;
use crate::schema::types::{Type, TypePtr};
/// A utility trait to help user to traverse against parquet type.
pub trait TypeVisitor<R, C> {
/// Called when a primitive type hit.
fn visit_primitive(&mut self, primitive_type: TypePtr, context: C) -> Result<R>;
/// Default implementation when visiting a list.
///
/// It checks list type definition and calls `visit_list_with_item` with extracted
/// item type.
///
/// To fully understand this algorithm, please refer to
/// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md).
fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result<R> {
match list_type.as_ref() {
Type::PrimitiveType { .. } => panic!(
"{:?} is a list type and can't be processed as primitive.",
list_type
),
Type::GroupType {
basic_info: _,
fields,
} if fields.len() == 1 => {
let list_item = fields.first().unwrap();
match list_item.as_ref() {
Type::PrimitiveType { .. } => {
if list_item.get_basic_info().repetition() == Repetition::REPEATED
{
self.visit_list_with_item(
list_type.clone(),
list_item.clone(),
context,
)
} else {
Err(General(
"Primitive element type of list must be repeated."
.to_string(),
))
}
}
Type::GroupType {
basic_info: _,
fields,
} => {
if fields.len() == 1
&& list_item.name() != "array"
&& list_item.name() != format!("{}_tuple", list_type.name())
{
self.visit_list_with_item(
list_type.clone(),
fields.first().unwrap().clone(),
context,
)
} else {
self.visit_list_with_item(
list_type.clone(),
list_item.clone(),
context,
)
}
}
}
}
_ => Err(General(
"Group element type of list can only contain one field.".to_string(),
)),
}
}
/// Called when a struct type hit.
fn visit_struct(&mut self, struct_type: TypePtr, context: C) -> Result<R>;
/// Called when a map type hit.
fn visit_map(&mut self, map_type: TypePtr, context: C) -> Result<R>;
/// A utility method which detects input type and calls corresponding method.
fn dispatch(&mut self, cur_type: TypePtr, context: C) -> Result<R> {
if cur_type.is_primitive() {
self.visit_primitive(cur_type, context)
} else {
match cur_type.get_basic_info().converted_type() {
ConvertedType::LIST => self.visit_list(cur_type, context),
ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => {
self.visit_map(cur_type, context)
}
_ => self.visit_struct(cur_type, context),
}
}
}
/// Called by `visit_list`.
fn visit_list_with_item(
&mut self,
list_type: TypePtr,
item_type: TypePtr,
context: C,
) -> Result<R>;
}
#[cfg(test)]
mod tests {
use super::TypeVisitor;
use crate::basic::Type as PhysicalType;
use crate::errors::Result;
use crate::schema::parser::parse_message_type;
use crate::schema::types::TypePtr;
use std::sync::Arc;
struct TestVisitorContext {}
struct TestVisitor {
primitive_visited: bool,
struct_visited: bool,
list_visited: bool,
root_type: TypePtr,
}
impl TypeVisitor<bool, TestVisitorContext> for TestVisitor {
fn visit_primitive(
&mut self,
primitive_type: TypePtr,
_context: TestVisitorContext,
) -> Result<bool> {
assert_eq!(
self.get_field_by_name(primitive_type.name()).as_ref(),
primitive_type.as_ref()
);
self.primitive_visited = true;
Ok(true)
}
fn visit_struct(
&mut self,
struct_type: TypePtr,
_context: TestVisitorContext,
) -> Result<bool> {
assert_eq!(
self.get_field_by_name(struct_type.name()).as_ref(),
struct_type.as_ref()
);
self.struct_visited = true;
Ok(true)
}
fn visit_map(
&mut self,
_map_type: TypePtr,
_context: TestVisitorContext,
) -> Result<bool> {
unimplemented!()
}
fn visit_list_with_item(
&mut self,
list_type: TypePtr,
item_type: TypePtr,
_context: TestVisitorContext,
) -> Result<bool> {
assert_eq!(
self.get_field_by_name(list_type.name()).as_ref(),
list_type.as_ref()
);
assert_eq!("element", item_type.name());
assert_eq!(PhysicalType::INT32, item_type.get_physical_type());
self.list_visited = true;
Ok(true)
}
}
impl TestVisitor {
fn new(root: TypePtr) -> Self {
Self {
primitive_visited: false,
struct_visited: false,
list_visited: false,
root_type: root,
}
}
fn get_field_by_name(&self, name: &str) -> TypePtr {
self.root_type
.get_fields()
.iter()
.find(|t| t.name() == name)
.cloned()
.unwrap()
}
}
#[test]
fn test_visitor() {
let message_type = "
message spark_schema {
REQUIRED INT32 a;
OPTIONAL group inner_schema {
REQUIRED INT32 b;
REQUIRED DOUBLE c;
}
OPTIONAL group e (LIST) {
REPEATED group list {
REQUIRED INT32 element;
}
}
";
let parquet_type = Arc::new(parse_message_type(&message_type).unwrap());
let mut visitor = TestVisitor::new(parquet_type.clone());
for f in parquet_type.get_fields() {
let c = TestVisitorContext {};
assert!(visitor.dispatch(f.clone(), c).unwrap());
}
assert!(visitor.struct_visited);
assert!(visitor.primitive_visited);
assert!(visitor.list_visited);
}
}