| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use std::sync::Arc; |
| |
| use arrow::{ |
| array::{as_dictionary_array, as_largestring_array, as_string_array}, |
| datatypes::Int32Type, |
| }; |
| use arrow_array::StringArray; |
| use arrow_schema::DataType; |
| use datafusion::logical_expr::ColumnarValue; |
| use datafusion_common::{ |
| cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, |
| exec_err, DataFusionError, |
| }; |
| use std::fmt::Write; |
| |
/// Formats `num` as upper-case hexadecimal without a prefix or zero padding.
///
/// Negative values are rendered in their 64-bit two's-complement form, so `-1`
/// becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    format!("{num:X}")
}
| |
/// Encodes every byte of `data` as two hexadecimal digits.
///
/// Digits `a`-`f` are emitted in lower case when `lower_case` is `true` and in
/// upper case otherwise. An empty input yields an empty string.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    // Each input byte expands to exactly two output characters.
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        let appended = if lower_case {
            write!(encoded, "{byte:02x}")
        } else {
            write!(encoded, "{byte:02X}")
        };
        // Writing to a string never errors, so we can unwrap here.
        appended.unwrap();
    }
    encoded
}
| |
| #[inline(always)] |
| pub(crate) fn hex_strings<T: AsRef<[u8]>>(data: T) -> String { |
| hex_encode(data, true) |
| } |
| |
| #[inline(always)] |
| fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> { |
| let hex_string = hex_encode(bytes, false); |
| Ok(hex_string) |
| } |
| |
| /// Spark-compatible `hex` function |
| pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { |
| if args.len() != 1 { |
| return Err(DataFusionError::Internal( |
| "hex expects exactly one argument".to_string(), |
| )); |
| } |
| |
| match &args[0] { |
| ColumnarValue::Array(array) => match array.data_type() { |
| DataType::Int64 => { |
| let array = as_int64_array(array)?; |
| |
| let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect(); |
| |
| Ok(ColumnarValue::Array(Arc::new(hexed_array))) |
| } |
| DataType::Utf8 => { |
| let array = as_string_array(array); |
| |
| let hexed: StringArray = array |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?; |
| |
| Ok(ColumnarValue::Array(Arc::new(hexed))) |
| } |
| DataType::LargeUtf8 => { |
| let array = as_largestring_array(array); |
| |
| let hexed: StringArray = array |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?; |
| |
| Ok(ColumnarValue::Array(Arc::new(hexed))) |
| } |
| DataType::Binary => { |
| let array = as_binary_array(array)?; |
| |
| let hexed: StringArray = array |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?; |
| |
| Ok(ColumnarValue::Array(Arc::new(hexed))) |
| } |
| DataType::FixedSizeBinary(_) => { |
| let array = as_fixed_size_binary_array(array)?; |
| |
| let hexed: StringArray = array |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?; |
| |
| Ok(ColumnarValue::Array(Arc::new(hexed))) |
| } |
| DataType::Dictionary(_, value_type) => { |
| let dict = as_dictionary_array::<Int32Type>(&array); |
| |
| let values = match **value_type { |
| DataType::Int64 => as_int64_array(dict.values())? |
| .iter() |
| .map(|v| v.map(hex_int64)) |
| .collect::<Vec<_>>(), |
| DataType::Utf8 => as_string_array(dict.values()) |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?, |
| DataType::Binary => as_binary_array(dict.values())? |
| .iter() |
| .map(|v| v.map(hex_bytes).transpose()) |
| .collect::<Result<_, _>>()?, |
| _ => exec_err!( |
| "hex got an unexpected argument type: {:?}", |
| array.data_type() |
| )?, |
| }; |
| |
| let new_values: Vec<Option<String>> = dict |
| .keys() |
| .iter() |
| .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) |
| .collect(); |
| |
| let string_array_values = StringArray::from(new_values); |
| |
| Ok(ColumnarValue::Array(Arc::new(string_array_values))) |
| } |
| _ => exec_err!( |
| "hex got an unexpected argument type: {:?}", |
| array.data_type() |
| ), |
| }, |
| _ => exec_err!("native hex does not support scalar values at this time"), |
| } |
| } |
| |
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder,
            StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use arrow_array::{ArrayRef, BinaryArray, Int32Array, Int64Array, StringArray};
    use datafusion::logical_expr::ColumnarValue;

    /// Runs `spark_hex` on a single array argument and returns the resulting
    /// string array, panicking if the call fails or yields a non-array value.
    fn hex_array(input: ArrayRef) -> StringArray {
        match super::spark_hex(&[ColumnarValue::Array(input)]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = input_builder.finish();

        let expected =
            StringArray::from(vec![Some("6869"), Some("627965"), None, Some("72757374")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = input_builder.finish();

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = input_builder.finish();

        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_hex_int64() {
        let num = 1234;
        let hexed = super::hex_int64(num);
        assert_eq!(hexed, "4D2".to_string());

        let num = -1;
        let hexed = super::hex_int64(num);
        assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_hex_letter_case() {
        // `hex_strings` is the lower-case variant, `hex_bytes` the upper-case one.
        assert_eq!(super::hex_strings("j"), "6a".to_string());
        assert_eq!(super::hex_bytes("j"), Ok("6A".to_string()));
        assert_eq!(super::hex_encode([0xAB_u8, 0x0F], true), "ab0f".to_string());
        assert_eq!(super::hex_encode([0xAB_u8, 0x0F], false), "AB0F".to_string());
        assert_eq!(super::hex_encode("", false), String::new());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_spark_hex_utf8() {
        let input = StringArray::from(vec![Some("hi"), None, Some("rust"), Some("")]);

        let expected = StringArray::from(vec![Some("6869"), None, Some("72757374"), Some("")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_spark_hex_binary() {
        let raw: [u8; 2] = [0x00, 0xFF];
        let input = BinaryArray::from(vec![Some("j".as_bytes()), None, Some(&raw[..])]);

        let expected = StringArray::from(vec![Some("6A"), None, Some("00FF")]);

        assert_eq!(hex_array(Arc::new(input)), expected);
    }

    #[test]
    fn test_spark_hex_rejects_bad_input() {
        // Wrong number of arguments.
        assert!(super::spark_hex(&[]).is_err());

        // Unsupported element type.
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        assert!(super::spark_hex(&[ColumnarValue::Array(input)]).is_err());
    }
}