Improve StringView support for SUBSTR (#12044)
* operate stringview instead of generating string in SUBSTR
* treat Utf8View as Text in sqllogictests output
* add bench to see enhancement of utf8view against utf8 and large_utf8
* fix a tiny bug
* make clippy happy
* add tests to cover stringview larger than 12B and correct the code
* better comments
* fix lint
* correct feature setting
* avoid expensive utf8 and some other checks
* fix lint
* remove unnecessary indirection
* add optimized_utf8_to_str_type
* Simplify type check
* Use ByteView
* update datafusion-cli.lock
* Remove duration override
* format toml
* refactor the code, using append_view_u128 from arrow
* manually collect the views and nulls
* remove bench file and fix some comments
* fix tiny mistake
* Update Cargo.lock
---------
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs
index db218f9..8a70b38 100644
--- a/datafusion/functions/src/unicode/substr.rs
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -19,18 +19,18 @@
use std::cmp::max;
use std::sync::Arc;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
- ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
+ make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
+ GenericStringArray, OffsetSizeTrait, StringViewArray,
};
use arrow::datatypes::DataType;
-
+use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use datafusion_common::cast::as_int64_array;
use datafusion_common::{exec_datafusion_err, exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
#[derive(Debug)]
pub struct SubstrFunc {
signature: Signature,
@@ -77,7 +77,11 @@
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- utf8_to_str_type(&arg_types[0], "substr")
+ if arg_types[0] == DataType::Utf8View {
+ Ok(DataType::Utf8View)
+ } else {
+ utf8_to_str_type(&arg_types[0], "substr")
+ }
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -89,29 +93,188 @@
}
}
-pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
- match args[0].data_type() {
- DataType::Utf8 => {
- let string_array = args[0].as_string::<i32>();
- calculate_substr::<_, i32>(string_array, &args[1..])
- }
- DataType::LargeUtf8 => {
- let string_array = args[0].as_string::<i64>();
- calculate_substr::<_, i64>(string_array, &args[1..])
- }
- DataType::Utf8View => {
- let string_array = args[0].as_string_view();
- calculate_substr::<_, i32>(string_array, &args[1..])
- }
- other => exec_err!("Unsupported data type {other:?} for function substr"),
- }
-}
-
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
/// substr('alphabet', 3) = 'phabet'
/// substr('alphabet', 3, 2) = 'ph'
/// The implementation uses UTF-8 code points as characters
-fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
+pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match args[0].data_type() {
+ DataType::Utf8 => {
+ let string_array = args[0].as_string::<i32>();
+ string_substr::<_, i32>(string_array, &args[1..])
+ }
+ DataType::LargeUtf8 => {
+ let string_array = args[0].as_string::<i64>();
+ string_substr::<_, i64>(string_array, &args[1..])
+ }
+ DataType::Utf8View => {
+ let string_array = args[0].as_string_view();
+ string_view_substr(string_array, &args[1..])
+ }
+ other => exec_err!(
+ "Unsupported data type {other:?} for function substr,\
+ expected Utf8View, Utf8 or LargeUtf8."
+ ),
+ }
+}
+
+// Return the exact byte index for [start, end), set count to -1 to ignore count
+fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
+ let (mut st, mut ed) = (input.len(), input.len());
+ let mut start_counting = false;
+ let mut cnt = 0;
+ for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
+ if char_cnt == start {
+ st = byte_cnt;
+ if count != -1 {
+ start_counting = true;
+ } else {
+ break;
+ }
+ }
+ if start_counting {
+ if cnt == count {
+ ed = byte_cnt;
+ break;
+ }
+ cnt += 1;
+ }
+ }
+ (st, ed)
+}
+
+/// Make a `u128` based on the given substr, start(offset to view.offset), and
+/// push into to the given buffers
+fn make_and_append_view(
+ views_buffer: &mut Vec<u128>,
+ null_builder: &mut NullBufferBuilder,
+ raw: &u128,
+ substr: &str,
+ start: u32,
+) {
+ let substr_len = substr.len();
+ if substr_len == 0 {
+ null_builder.append_null();
+ views_buffer.push(0);
+ } else {
+ let sub_view = if substr_len > 12 {
+ let view = ByteView::from(*raw);
+ make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
+ } else {
+ // inline value does not need block id or offset
+ make_view(substr.as_bytes(), 0, 0)
+ };
+ views_buffer.push(sub_view);
+ null_builder.append_non_null();
+ }
+}
+
+// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
+// From<u128> for ByteView
+fn string_view_substr(
+ string_view_array: &StringViewArray,
+ args: &[ArrayRef],
+) -> Result<ArrayRef> {
+ let mut views_buf = Vec::with_capacity(string_view_array.len());
+ let mut null_builder = NullBufferBuilder::new(string_view_array.len());
+
+ let start_array = as_int64_array(&args[0])?;
+
+ match args.len() {
+ 1 => {
+ for (idx, (raw, start)) in string_view_array
+ .views()
+ .iter()
+ .zip(start_array.iter())
+ .enumerate()
+ {
+ if let Some(start) = start {
+ let start = (start - 1).max(0) as usize;
+
+ // Safety:
+ // idx is always smaller or equal to string_view_array.views.len()
+ unsafe {
+ let str = string_view_array.value_unchecked(idx);
+ let (start, end) = get_true_start_end(str, start, -1);
+ let substr = &str[start..end];
+
+ make_and_append_view(
+ &mut views_buf,
+ &mut null_builder,
+ raw,
+ substr,
+ start as u32,
+ );
+ }
+ } else {
+ null_builder.append_null();
+ views_buf.push(0);
+ }
+ }
+ }
+ 2 => {
+ let count_array = as_int64_array(&args[1])?;
+ for (idx, ((raw, start), count)) in string_view_array
+ .views()
+ .iter()
+ .zip(start_array.iter())
+ .zip(count_array.iter())
+ .enumerate()
+ {
+ if let (Some(start), Some(count)) = (start, count) {
+ let start = (start - 1).max(0) as usize;
+ if count < 0 {
+ return exec_err!(
+ "negative substring length not allowed: substr(<str>, {start}, {count})"
+ );
+ } else {
+ // Safety:
+ // idx is always smaller or equal to string_view_array.views.len()
+ unsafe {
+ let str = string_view_array.value_unchecked(idx);
+ let (start, end) = get_true_start_end(str, start, count);
+ let substr = &str[start..end];
+
+ make_and_append_view(
+ &mut views_buf,
+ &mut null_builder,
+ raw,
+ substr,
+ start as u32,
+ );
+ }
+ }
+ } else {
+ null_builder.append_null();
+ views_buf.push(0);
+ }
+ }
+ }
+ other => {
+ return exec_err!(
+ "substr was called with {other} arguments. It requires 2 or 3."
+ )
+ }
+ }
+
+ let views_buf = ScalarBuffer::from(views_buf);
+ let nulls_buf = null_builder.finish();
+
+ // Safety:
+ // (1) The blocks of the given views are all provided
+ // (2) Each of the range `view.offset+start..end` of view in views_buf is within
+ // the bounds of each of the blocks
+ unsafe {
+ let array = StringViewArray::new_unchecked(
+ views_buf,
+ string_view_array.data_buffers().to_vec(),
+ nulls_buf,
+ );
+ Ok(Arc::new(array) as ArrayRef)
+ }
+}
+
+fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
T: OffsetSizeTrait,
@@ -174,8 +337,8 @@
#[cfg(test)]
mod tests {
- use arrow::array::{Array, StringArray};
- use arrow::datatypes::DataType::Utf8;
+ use arrow::array::{Array, StringArray, StringViewArray};
+ use arrow::datatypes::DataType::{Utf8, Utf8View};
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
@@ -193,8 +356,8 @@
],
Ok(None),
&str,
- Utf8,
- StringArray
+ Utf8View,
+ StringViewArray
);
test_function!(
SubstrFunc::new(),
@@ -206,8 +369,35 @@
],
Ok(Some("alphabet")),
&str,
- Utf8,
- StringArray
+ Utf8View,
+ StringViewArray
+ );
+ test_function!(
+ SubstrFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+ "this és longer than 12B"
+ )))),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::from(2i64)),
+ ],
+ Ok(Some(" é")),
+ &str,
+ Utf8View,
+ StringViewArray
+ );
+ test_function!(
+ SubstrFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
+ "this is longer than 12B"
+ )))),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ Ok(Some(" is longer than 12B")),
+ &str,
+ Utf8View,
+ StringViewArray
);
test_function!(
SubstrFunc::new(),
@@ -219,8 +409,8 @@
],
Ok(Some("ésoj")),
&str,
- Utf8,
- StringArray
+ Utf8View,
+ StringViewArray
);
test_function!(
SubstrFunc::new(),
@@ -233,8 +423,8 @@
],
Ok(Some("ph")),
&str,
- Utf8,
- StringArray
+ Utf8View,
+ StringViewArray
);
test_function!(
SubstrFunc::new(),
@@ -247,8 +437,8 @@
],
Ok(Some("phabet")),
&str,
- Utf8,
- StringArray
+ Utf8View,
+ StringViewArray
);
test_function!(
SubstrFunc::new(),