ARROW-10354: [Rust][DataFusion] regexp_extract function to select regex groups from strings
Adds a `regexp_match` compute kernel (originally proposed as `regexp_extract`) that returns the capture groups matched by a regular expression for each string.
Some things I did that I may be doing wrong:
* I exposed `GenericStringBuilder`
* I build the resulting Array using a builder - this looks quite different from e.g. the substring kernel. Should I change it accordingly, e.g. because of performance considerations?
* In order to apply the new function in datafusion, I did not see a better solution than to handle the pattern string as `StringArray` and take the first record to compile the regex pattern from it and apply it to all values. Is there a way to define that an argument has to be a literal/scalar and cannot be filled by e.g. another column? I consider my current implementation quite error prone and would like to make this a bit more robust.
Closes #9428 from sweb/ARROW-10354/regexp_extract
Authored-by: Florian Müller <florian@tomueller.de>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs
index c0073c0..65cf308 100644
--- a/rust/arrow/src/array/mod.rs
+++ b/rust/arrow/src/array/mod.rs
@@ -216,6 +216,7 @@
pub use self::builder::DecimalBuilder;
pub use self::builder::FixedSizeBinaryBuilder;
pub use self::builder::FixedSizeListBuilder;
+pub use self::builder::GenericStringBuilder;
pub use self::builder::LargeBinaryBuilder;
pub use self::builder::LargeListBuilder;
pub use self::builder::LargeStringBuilder;
diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs
index a8d2497..862f55f 100644
--- a/rust/arrow/src/compute/kernels/mod.rs
+++ b/rust/arrow/src/compute/kernels/mod.rs
@@ -28,6 +28,7 @@
pub mod filter;
pub mod length;
pub mod limit;
+pub mod regexp;
pub mod sort;
pub mod substring;
pub mod take;
diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs
new file mode 100644
index 0000000..446d71d
--- /dev/null
+++ b/rust/arrow/src/compute/kernels/regexp.rs
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines kernel to extract substrings based on a regular
+//! expression of a \[Large\]StringArray
+
+use crate::array::{
+ ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder,
+ StringOffsetSizeTrait,
+};
+use crate::error::{ArrowError, Result};
+use std::collections::HashMap;
+
+use std::sync::Arc;
+
+use regex::Regex;
+
+/// Extract all groups matched by a regular expression for a given String array.
+///
+/// The three inputs are walked in lockstep, one row at a time:
+///
+/// * `array` - the strings to search.
+/// * `regex_array` - the pattern to apply to the corresponding row of `array`.
+/// * `flags_array` - optional per-row regex flags; a non-null flag value `f`
+///   turns the pattern `p` into `(?f)p` before compilation.
+///
+/// Returns a list array of strings with one entry per row:
+///
+/// * null when the value or the pattern is null, or when the pattern does
+///   not match the value;
+/// * a list holding a single empty string when the (flag-prefixed) pattern is
+///   the empty string, for Postgres compatibility;
+/// * otherwise the text of every capture group that took part in the first
+///   match (group 0, the whole match, is not included; groups that did not
+///   participate are skipped).
+///
+/// Iteration stops at the shortest of the input arrays.
+///
+/// # Errors
+///
+/// Returns `ArrowError::ComputeError` when a pattern fails to compile, and
+/// propagates any error reported by the underlying builders.
+pub fn regexp_match<OffsetSize: StringOffsetSizeTrait>(
+    array: &GenericStringArray<OffsetSize>,
+    regex_array: &GenericStringArray<OffsetSize>,
+    flags_array: Option<&GenericStringArray<OffsetSize>>,
+) -> Result<ArrayRef> {
+    // Every distinct pattern text is compiled once and then reused.
+    let mut compiled: HashMap<String, Regex> = HashMap::new();
+    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::new(0);
+    let mut list_builder = ListBuilder::new(builder);
+
+    // Per-row pattern text with the optional `(?flags)` prefix applied.
+    let complete_pattern = match flags_array {
+        Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
+            |(pattern, flags)| {
+                pattern.map(|pattern| match flags {
+                    Some(value) => format!("(?{}){}", value, pattern),
+                    None => pattern.to_string(),
+                })
+            },
+        )) as Box<dyn Iterator<Item = Option<String>>>,
+        None => Box::new(
+            regex_array
+                .iter()
+                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
+        ),
+    };
+
+    for (value, pattern) in array.iter().zip(complete_pattern) {
+        let (value, pattern) = match (value, pattern) {
+            (Some(value), Some(pattern)) => (value, pattern),
+            // A null value or a null pattern produces a null list entry.
+            _ => {
+                list_builder.append(false)?;
+                continue;
+            }
+        };
+
+        // Required for Postgres compatibility:
+        // SELECT regexp_match('foobarbequebaz', ''); = {""}
+        if pattern.is_empty() {
+            list_builder.values().append_value("")?;
+            list_builder.append(true)?;
+            continue;
+        }
+
+        // Single lookup: borrow the cached regex or compile and cache it.
+        let re = match compiled.entry(pattern) {
+            std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
+            std::collections::hash_map::Entry::Vacant(slot) => {
+                let re = Regex::new(slot.key().as_str()).map_err(|e| {
+                    ArrowError::ComputeError(format!(
+                        "Regular expression did not compile: {:?}",
+                        e
+                    ))
+                })?;
+                slot.insert(re)
+            }
+        };
+
+        match re.captures(value) {
+            Some(caps) => {
+                // Skip group 0 (the whole match) and groups that did not participate.
+                for group in caps.iter().skip(1).flatten() {
+                    list_builder.values().append_value(group.as_str())?;
+                }
+                list_builder.append(true)?
+            }
+            None => list_builder.append(false)?,
+        }
+    }
+    Ok(Arc::new(list_builder.finish()))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::array::{ListArray, StringArray};
+
+ // Exercises `regexp_match` without flags: one pattern per row, covering a
+ // match, a non-match, a null value and the empty-pattern special case.
+ #[test]
+ fn match_single_group() -> Result<()> {
+ let values = vec![
+ Some("abc-005-def"),
+ Some("X-7-5"),
+ Some("X545"),
+ None,
+ Some("foobarbequebaz"),
+ Some("foobarbequebaz"),
+ ];
+ let array = StringArray::from(values);
+ // Rows 0-3 share the same single-group pattern; rows 4 and 5 get their own.
+ let mut pattern_values = vec![r".*-(\d*)-.*"; 4];
+ pattern_values.push(r"(bar)(bequ1e)");
+ pattern_values.push("");
+ let pattern = StringArray::from(pattern_values);
+ let actual = regexp_match(&array, &pattern, None)?;
+ let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
+ let mut expected_builder = ListBuilder::new(elem_builder);
+ // Row 0: "abc-005-def" -> ["005"]
+ expected_builder.values().append_value("005")?;
+ expected_builder.append(true)?;
+ // Row 1: "X-7-5" -> ["7"]
+ expected_builder.values().append_value("7")?;
+ expected_builder.append(true)?;
+ // Row 2: "X545" does not match -> null
+ expected_builder.append(false)?;
+ // Row 3: null value -> null
+ expected_builder.append(false)?;
+ // Row 4: "(bar)(bequ1e)" does not match "foobarbequebaz" -> null
+ expected_builder.append(false)?;
+ // Row 5: empty pattern -> [""]
+ expected_builder.values().append_value("")?;
+ expected_builder.append(true)?;
+ let expected = expected_builder.finish();
+ let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+ assert_eq!(&expected, result);
+ Ok(())
+ }
+
+ // Exercises `regexp_match` with the "i" flag supplied for every row, so the
+ // lower-case `x` in the pattern is compiled as `(?i)x...`.
+ #[test]
+ fn match_single_group_with_flags() -> Result<()> {
+ let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
+ let array = StringArray::from(values);
+ let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]);
+ let flags = StringArray::from(vec!["i"; 4]);
+ let actual = regexp_match(&array, &pattern, Some(&flags))?;
+ let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
+ let mut expected_builder = ListBuilder::new(elem_builder);
+ // Row 0: "abc-005-def" contains no "x"/"X" -> null
+ expected_builder.append(false)?;
+ // Row 1: "X-7-5" -> ["7"]
+ expected_builder.values().append_value("7")?;
+ expected_builder.append(true)?;
+ // Row 2: "X545" has no "-<digits>-" part -> null
+ expected_builder.append(false)?;
+ // Row 3: null value -> null
+ expected_builder.append(false)?;
+ let expected = expected_builder.finish();
+ let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+ assert_eq!(&expected, result);
+ Ok(())
+ }
+}
diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs
index 9de0738..be1aa27 100644
--- a/rust/arrow/src/compute/mod.rs
+++ b/rust/arrow/src/compute/mod.rs
@@ -29,6 +29,7 @@
pub use self::kernels::concat::*;
pub use self::kernels::filter::*;
pub use self::kernels::limit::*;
+pub use self::kernels::regexp::*;
pub use self::kernels::sort::*;
pub use self::kernels::take::*;
pub use self::kernels::temporal::*;
diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index 314f5d4..991b160 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -1090,6 +1090,7 @@
unary_scalar_expr!(Ltrim, ltrim);
unary_scalar_expr!(MD5, md5);
unary_scalar_expr!(OctetLength, octet_length);
+unary_scalar_expr!(RegexpMatch, regexp_match);
unary_scalar_expr!(RegexpReplace, regexp_replace);
unary_scalar_expr!(Replace, replace);
unary_scalar_expr!(Repeat, repeat);
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index 0e7e619..f9be1ff 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -37,10 +37,10 @@
ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
- octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad,
- rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with,
- strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr,
- ExprRewriter, ExpressionVisitor, Literal, Recursion,
+ octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right,
+ round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
+ starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when,
+ Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs
index 9dc54a4..56365fe 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -198,6 +198,8 @@
Trim,
/// upper
Upper,
+ /// regexp_match
+ RegexpMatch,
}
impl fmt::Display for BuiltinScalarFunction {
@@ -271,7 +273,7 @@
"translate" => BuiltinScalarFunction::Translate,
"trim" => BuiltinScalarFunction::Trim,
"upper" => BuiltinScalarFunction::Upper,
-
+ "regexp_match" => BuiltinScalarFunction::RegexpMatch,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
@@ -607,6 +609,20 @@
));
}
}),
+ BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] {
+ DataType::LargeUtf8 => {
+ DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true)))
+ }
+ DataType::Utf8 => {
+ DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+ }
+ _ => {
+ // this error is internal as `data_types` should have captured this.
+ return Err(DataFusionError::Internal(
+ "The regexp_extract function can only accept strings.".to_string(),
+ ));
+ }
+ }),
BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
@@ -853,6 +869,28 @@
_ => unreachable!(),
},
},
+ BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() {
+ DataType::Utf8 => {
+ let func = invoke_if_regex_expressions_feature_flag!(
+ regexp_match,
+ i32,
+ "regexp_match"
+ );
+ make_scalar_function(func)(args)
+ }
+ DataType::LargeUtf8 => {
+ let func = invoke_if_regex_expressions_feature_flag!(
+ regexp_match,
+ i64,
+ "regexp_match"
+ );
+ make_scalar_function(func)(args)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function regexp_match",
+ other
+ ))),
+ },
BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
@@ -1229,6 +1267,12 @@
BuiltinScalarFunction::NullIf => {
Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec())
}
+ BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![
+ Signature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+ Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
+ Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
+ Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
+ ]),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
@@ -1386,7 +1430,7 @@
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array,
- Int32Array, StringArray, UInt32Array, UInt64Array,
+ Int32Array, ListArray, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
@@ -3646,4 +3690,78 @@
"PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
)
}
+
+ // Checks the physical `RegexpMatch` expression with a string column as the
+ // value and a literal as the pattern.
+ #[test]
+ #[cfg(feature = "regex_expressions")]
+ fn test_regexp_match() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
+
+ // regexp_match(a, '.*-(\d*)')
+ let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"]));
+ let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
+ let columns: Vec<ArrayRef> = vec![col_value];
+ let expr = create_physical_expr(
+ &BuiltinScalarFunction::RegexpMatch,
+ &[col("a"), pattern],
+ &schema,
+ )?;
+
+ // type is correct
+ assert_eq!(
+ expr.data_type(&schema)?,
+ DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+ );
+
+ // evaluate works
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+ let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+
+ // downcast works
+ let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+ let first_row = result.value(0);
+ let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();
+
+ // value is correct
+ let expected = "555".to_string();
+ assert_eq!(first_row.value(0), expected);
+
+ Ok(())
+ }
+
+ // Checks the physical `RegexpMatch` expression when both the value and the
+ // pattern are literals.
+ #[test]
+ #[cfg(feature = "regex_expressions")]
+ fn test_regexp_match_all_literals() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+ // regexp_match('aaa-555', '.*-(\d*)')
+ let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string())));
+ let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
+ // The Int32 column is not referenced by the expression; it only gives
+ // the batch one row to evaluate against.
+ let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
+ let expr = create_physical_expr(
+ &BuiltinScalarFunction::RegexpMatch,
+ &[col_value, pattern],
+ &schema,
+ )?;
+
+ // type is correct
+ assert_eq!(
+ expr.data_type(&schema)?,
+ DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+ );
+
+ // evaluate works
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+ let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+
+ // downcast works
+ let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+ let first_row = result.value(0);
+ let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();
+
+ // value is correct
+ let expected = "555".to_string();
+ assert_eq!(first_row.value(0), expected);
+
+ Ok(())
+ }
}
diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs
index 6482424..b526e72 100644
--- a/rust/datafusion/src/physical_plan/regex_expressions.rs
+++ b/rust/datafusion/src/physical_plan/regex_expressions.rs
@@ -26,6 +26,7 @@
use crate::error::{DataFusionError, Result};
use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
+use arrow::compute;
use hashbrown::HashMap;
use regex::Regex;
@@ -43,6 +44,20 @@
}};
}
+/// Extracts the capture groups matched by a regular expression from each value
+/// of a string column by delegating to the arrow `compute::regexp_match` kernel.
+///
+/// `args` must hold either two arrays (`string`, `pattern`) or three arrays
+/// (`string`, `pattern`, `flags`). Each one is passed through
+/// `downcast_string_arg!` with the offset type `T` before the kernel is called.
+///
+/// NOTE(review): all arguments are downcast with the same `T`; confirm that the
+/// declared function signature supplies pattern/flags arrays of the matching
+/// string type.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Internal` for any other argument count, and wraps
+/// kernel failures in `DataFusionError::ArrowError`.
+pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        2 => compute::regexp_match(
+            downcast_string_arg!(args[0], "string", T),
+            downcast_string_arg!(args[1], "pattern", T),
+            None,
+        )
+        .map_err(DataFusionError::ArrowError),
+        3 => compute::regexp_match(
+            downcast_string_arg!(args[0], "string", T),
+            downcast_string_arg!(args[1], "pattern", T),
+            // The flags are the third argument; passing args[1] here would
+            // reuse the pattern text as the flags.
+            Some(downcast_string_arg!(args[2], "flags", T)),
+        )
+        .map_err(DataFusionError::ArrowError),
+        other => Err(DataFusionError::Internal(format!(
+            "regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
+            other
+        ))),
+    }
+}
+
/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1})
/// used by regexp_replace
fn regex_replace_posix_groups(replacement: &str) -> String {
diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs
index f0c7acf..b236775 100644
--- a/rust/datafusion/src/scalar.rs
+++ b/rust/datafusion/src/scalar.rs
@@ -115,7 +115,7 @@
for scalar_value in values {
match scalar_value {
ScalarValue::$SCALAR_TY(Some(v)) => {
- builder.values().append_value(*v).unwrap()
+ builder.values().append_value(v.clone()).unwrap()
}
ScalarValue::$SCALAR_TY(None) => {
builder.values().append_null().unwrap();
@@ -335,6 +335,10 @@
DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size),
DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size),
DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size),
+ DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size),
+ DataType::LargeUtf8 => {
+ build_list!(LargeStringBuilder, LargeUtf8, values, size)
+ }
_ => panic!("Unexpected DataType for list"),
}),
ScalarValue::Date32(e) => match e {
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 15234a8..8c2c35e 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -2560,6 +2560,17 @@
test_expression!("'2' IN ('a','b',NULL,1)", "NULL");
test_expression!("'1' NOT IN ('a','b',NULL,1)", "false");
test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL");
+ test_expression!("regexp_match('foobarbequebaz', '')", "[]");
+ test_expression!(
+ "regexp_match('foobarbequebaz', '(bar)(beque)')",
+ "[bar, beque]"
+ );
+ test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL");
+ test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]");
+ test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]");
+ test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL");
+ test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL");
+ test_expression!("regexp_match('aaa-0', NULL)", "NULL");
Ok(())
}