ARROW-10354: [Rust][DataFusion] regexp_extract function to select regex groups from strings

Adds a regexp_extract compute kernel to select a substring based on a regular expression.

Some things I did that I may be doing wrong:

* I exposed `GenericStringBuilder`
* I build the resulting Array using a builder - this looks quite different from e.g. the substring kernel. Should I change it accordingly, e.g. because of performance considerations?
* In order to apply the new function in datafusion, I did not see a better solution than to handle the pattern string as `StringArray` and take the first record to compile the regex pattern from it and apply it to all values. Is there a way to define that an argument has to be a literal/scalar and cannot be filled by e.g. another column? I consider my current implementation quite error prone and would like to make this a bit more robust.

Closes #9428 from sweb/ARROW-10354/regexp_extract

Authored-by: Florian Müller <florian@tomueller.de>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs
index c0073c0..65cf308 100644
--- a/rust/arrow/src/array/mod.rs
+++ b/rust/arrow/src/array/mod.rs
@@ -216,6 +216,7 @@
 pub use self::builder::DecimalBuilder;
 pub use self::builder::FixedSizeBinaryBuilder;
 pub use self::builder::FixedSizeListBuilder;
+pub use self::builder::GenericStringBuilder;
 pub use self::builder::LargeBinaryBuilder;
 pub use self::builder::LargeListBuilder;
 pub use self::builder::LargeStringBuilder;
diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs
index a8d2497..862f55f 100644
--- a/rust/arrow/src/compute/kernels/mod.rs
+++ b/rust/arrow/src/compute/kernels/mod.rs
@@ -28,6 +28,7 @@
 pub mod filter;
 pub mod length;
 pub mod limit;
+pub mod regexp;
 pub mod sort;
 pub mod substring;
 pub mod take;
diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs
new file mode 100644
index 0000000..446d71d
--- /dev/null
+++ b/rust/arrow/src/compute/kernels/regexp.rs
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines kernel to extract substrings based on a regular
+//! expression of a \[Large\]StringArray
+
+use crate::array::{
+    ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder,
+    StringOffsetSizeTrait,
+};
+use crate::error::{ArrowError, Result};
+use std::collections::HashMap;
+
+use std::sync::Arc;
+
+use regex::Regex;
+
+/// Extract all groups matched by a regular expression for a given String array.
+pub fn regexp_match<OffsetSize: StringOffsetSizeTrait>(
+    array: &GenericStringArray<OffsetSize>,
+    regex_array: &GenericStringArray<OffsetSize>,
+    flags_array: Option<&GenericStringArray<OffsetSize>>,
+) -> Result<ArrayRef> {
+    let mut patterns: HashMap<String, Regex> = HashMap::new();
+    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::new(0);
+    let mut list_builder = ListBuilder::new(builder);
+
+    let complete_pattern = match flags_array {
+        Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
+            |(pattern, flags)| {
+                pattern.map(|pattern| match flags {
+                    Some(value) => format!("(?{}){}", value, pattern),
+                    None => pattern.to_string(),
+                })
+            },
+        )) as Box<dyn Iterator<Item = Option<String>>>,
+        None => Box::new(
+            regex_array
+                .iter()
+                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
+        ),
+    };
+    array
+        .iter()
+        .zip(complete_pattern)
+        .map(|(value, pattern)| {
+            match (value, pattern) {
+                // Required for Postgres compatibility:
+                // SELECT regexp_match('foobarbequebaz', ''); = {""}
+                (Some(_), Some(pattern)) if pattern == *"" => {
+                    list_builder.values().append_value("")?;
+                    list_builder.append(true)?;
+                }
+                (Some(value), Some(pattern)) => {
+                    let existing_pattern = patterns.get(&pattern);
+                    let re = match existing_pattern {
+                        Some(re) => re.clone(),
+                        None => {
+                            let re = Regex::new(pattern.as_str()).map_err(|e| {
+                                ArrowError::ComputeError(format!(
+                                    "Regular expression did not compile: {:?}",
+                                    e
+                                ))
+                            })?;
+                            patterns.insert(pattern, re.clone());
+                            re
+                        }
+                    };
+                    match re.captures(value) {
+                        Some(caps) => {
+                            for m in caps.iter().skip(1) {
+                                if let Some(v) = m {
+                                    list_builder.values().append_value(v.as_str())?;
+                                }
+                            }
+                            list_builder.append(true)?
+                        }
+                        None => list_builder.append(false)?,
+                    }
+                }
+                _ => list_builder.append(false)?,
+            }
+            Ok(())
+        })
+        .collect::<Result<Vec<()>>>()?;
+    Ok(Arc::new(list_builder.finish()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::array::{ListArray, StringArray};
+
+    #[test]
+    fn match_single_group() -> Result<()> {
+        let values = vec![
+            Some("abc-005-def"),
+            Some("X-7-5"),
+            Some("X545"),
+            None,
+            Some("foobarbequebaz"),
+            Some("foobarbequebaz"),
+        ];
+        let array = StringArray::from(values);
+        let mut pattern_values = vec![r".*-(\d*)-.*"; 4];
+        pattern_values.push(r"(bar)(bequ1e)");
+        pattern_values.push("");
+        let pattern = StringArray::from(pattern_values);
+        let actual = regexp_match(&array, &pattern, None)?;
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.values().append_value("005")?;
+        expected_builder.append(true)?;
+        expected_builder.values().append_value("7")?;
+        expected_builder.append(true)?;
+        expected_builder.append(false)?;
+        expected_builder.append(false)?;
+        expected_builder.append(false)?;
+        expected_builder.values().append_value("")?;
+        expected_builder.append(true)?;
+        let expected = expected_builder.finish();
+        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(&expected, result);
+        Ok(())
+    }
+
+    #[test]
+    fn match_single_group_with_flags() -> Result<()> {
+        let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
+        let array = StringArray::from(values);
+        let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]);
+        let flags = StringArray::from(vec!["i"; 4]);
+        let actual = regexp_match(&array, &pattern, Some(&flags))?;
+        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
+        let mut expected_builder = ListBuilder::new(elem_builder);
+        expected_builder.append(false)?;
+        expected_builder.values().append_value("7")?;
+        expected_builder.append(true)?;
+        expected_builder.append(false)?;
+        expected_builder.append(false)?;
+        let expected = expected_builder.finish();
+        let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(&expected, result);
+        Ok(())
+    }
+}
diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs
index 9de0738..be1aa27 100644
--- a/rust/arrow/src/compute/mod.rs
+++ b/rust/arrow/src/compute/mod.rs
@@ -29,6 +29,7 @@
 pub use self::kernels::concat::*;
 pub use self::kernels::filter::*;
 pub use self::kernels::limit::*;
+pub use self::kernels::regexp::*;
 pub use self::kernels::sort::*;
 pub use self::kernels::take::*;
 pub use self::kernels::temporal::*;
diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index 314f5d4..991b160 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -1090,6 +1090,7 @@
 unary_scalar_expr!(Ltrim, ltrim);
 unary_scalar_expr!(MD5, md5);
 unary_scalar_expr!(OctetLength, octet_length);
+unary_scalar_expr!(RegexpMatch, regexp_match);
 unary_scalar_expr!(RegexpReplace, regexp_replace);
 unary_scalar_expr!(Replace, replace);
 unary_scalar_expr!(Repeat, repeat);
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index 0e7e619..f9be1ff 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -37,10 +37,10 @@
     ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
     count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
     initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
-    octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad,
-    rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with,
-    strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr,
-    ExprRewriter, ExpressionVisitor, Literal, Recursion,
+    octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right,
+    round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
+    starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when,
+    Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
 };
 pub use extension::UserDefinedLogicalNode;
 pub use operators::Operator;
diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs
index 9dc54a4..56365fe 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -198,6 +198,8 @@
     Trim,
     /// upper
     Upper,
+    /// regexp_match
+    RegexpMatch,
 }
 
 impl fmt::Display for BuiltinScalarFunction {
@@ -271,7 +273,7 @@
             "translate" => BuiltinScalarFunction::Translate,
             "trim" => BuiltinScalarFunction::Trim,
             "upper" => BuiltinScalarFunction::Upper,
-
+            "regexp_match" => BuiltinScalarFunction::RegexpMatch,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -607,6 +609,20 @@
                 ));
             }
         }),
+        BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] {
+            DataType::LargeUtf8 => {
+                DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true)))
+            }
+            DataType::Utf8 => {
+                DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+            }
+            _ => {
+                // this error is internal as `data_types` should have captured this.
+                return Err(DataFusionError::Internal(
+                    "The regexp_extract function can only accept strings.".to_string(),
+                ));
+            }
+        }),
 
         BuiltinScalarFunction::Abs
         | BuiltinScalarFunction::Acos
@@ -853,6 +869,28 @@
                 _ => unreachable!(),
             },
         },
+        BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() {
+            DataType::Utf8 => {
+                let func = invoke_if_regex_expressions_feature_flag!(
+                    regexp_match,
+                    i32,
+                    "regexp_match"
+                );
+                make_scalar_function(func)(args)
+            }
+            DataType::LargeUtf8 => {
+                let func = invoke_if_regex_expressions_feature_flag!(
+                    regexp_match,
+                    i64,
+                    "regexp_match"
+                );
+                make_scalar_function(func)(args)
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function regexp_match",
+                other
+            ))),
+        },
         BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() {
             DataType::Utf8 => {
                 let func = invoke_if_regex_expressions_feature_flag!(
@@ -1229,6 +1267,12 @@
         BuiltinScalarFunction::NullIf => {
             Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec())
         }
+        BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![
+            Signature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+            Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
+            Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
+            Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
+        ]),
         // math expressions expect 1 argument of type f64 or f32
         // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
         // return the best approximation for it (in f64).
@@ -1386,7 +1430,7 @@
     use arrow::{
         array::{
             Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array,
-            Int32Array, StringArray, UInt32Array, UInt64Array,
+            Int32Array, ListArray, StringArray, UInt32Array, UInt64Array,
         },
         datatypes::Field,
         record_batch::RecordBatch,
@@ -3646,4 +3690,78 @@
             "PrimitiveArray<UInt64>\n[\n  1,\n  1,\n]",
         )
     }
+
+    #[test]
+    #[cfg(feature = "regex_expressions")]
+    fn test_regexp_match() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
+
+        // concat(value, value)
+        let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"]));
+        let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
+        let columns: Vec<ArrayRef> = vec![col_value];
+        let expr = create_physical_expr(
+            &BuiltinScalarFunction::RegexpMatch,
+            &[col("a"), pattern],
+            &schema,
+        )?;
+
+        // type is correct
+        assert_eq!(
+            expr.data_type(&schema)?,
+            DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+        );
+
+        // evaluate works
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+
+        // downcast works
+        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+        let first_row = result.value(0);
+        let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();
+
+        // value is correct
+        let expected = "555".to_string();
+        assert_eq!(first_row.value(0), expected);
+
+        Ok(())
+    }
+
+    #[test]
+    #[cfg(feature = "regex_expressions")]
+    fn test_regexp_match_all_literals() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+        // concat(value, value)
+        let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string())));
+        let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
+        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
+        let expr = create_physical_expr(
+            &BuiltinScalarFunction::RegexpMatch,
+            &[col_value, pattern],
+            &schema,
+        )?;
+
+        // type is correct
+        assert_eq!(
+            expr.data_type(&schema)?,
+            DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
+        );
+
+        // evaluate works
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
+
+        // downcast works
+        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+        let first_row = result.value(0);
+        let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();
+
+        // value is correct
+        let expected = "555".to_string();
+        assert_eq!(first_row.value(0), expected);
+
+        Ok(())
+    }
 }
diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs
index 6482424..b526e72 100644
--- a/rust/datafusion/src/physical_plan/regex_expressions.rs
+++ b/rust/datafusion/src/physical_plan/regex_expressions.rs
@@ -26,6 +26,7 @@
 
 use crate::error::{DataFusionError, Result};
 use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
+use arrow::compute;
 use hashbrown::HashMap;
 use regex::Regex;
 
@@ -43,6 +44,20 @@
     }};
 }
 
+/// extract a specific group from a string column, using a regular expression
+pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None)
+        .map_err(DataFusionError::ArrowError),
+        3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T),  Some(downcast_string_arg!(args[1], "flags", T)))
+        .map_err(DataFusionError::ArrowError),
+        other => Err(DataFusionError::Internal(format!(
+            "regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
+            other
+        ))),
+    }
+}
+
 /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1})
 /// used by regexp_replace
 fn regex_replace_posix_groups(replacement: &str) -> String {
diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs
index f0c7acf..b236775 100644
--- a/rust/datafusion/src/scalar.rs
+++ b/rust/datafusion/src/scalar.rs
@@ -115,7 +115,7 @@
                     for scalar_value in values {
                         match scalar_value {
                             ScalarValue::$SCALAR_TY(Some(v)) => {
-                                builder.values().append_value(*v).unwrap()
+                                builder.values().append_value(v.clone()).unwrap()
                             }
                             ScalarValue::$SCALAR_TY(None) => {
                                 builder.values().append_null().unwrap();
@@ -335,6 +335,10 @@
                 DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size),
                 DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size),
                 DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size),
+                DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size),
+                DataType::LargeUtf8 => {
+                    build_list!(LargeStringBuilder, LargeUtf8, values, size)
+                }
                 _ => panic!("Unexpected DataType for list"),
             }),
             ScalarValue::Date32(e) => match e {
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 15234a8..8c2c35e 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -2560,6 +2560,17 @@
     test_expression!("'2' IN ('a','b',NULL,1)", "NULL");
     test_expression!("'1' NOT IN ('a','b',NULL,1)", "false");
     test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL");
+    test_expression!("regexp_match('foobarbequebaz', '')", "[]");
+    test_expression!(
+        "regexp_match('foobarbequebaz', '(bar)(beque)')",
+        "[bar, beque]"
+    );
+    test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL");
+    test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]");
+    test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]");
+    test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL");
+    test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL");
+    test_expression!("regexp_match('aaa-0', NULL)", "NULL");
     Ok(())
 }