blob: b526e7259ef61b5ecc166fd3f1df227b386a2212 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Some of these functions reference the Postgres documentation
// or implementation to ensure compatibility and are subject to
// the Postgres license.
//! Regex expressions
use std::any::type_name;
use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait};
use arrow::compute;
use hashbrown::HashMap;
use regex::Regex;
/// Downcasts the array `$ARG` to a `GenericStringArray<$T>`.
///
/// * `$ARG`  - expression yielding an array that exposes `as_any()`
/// * `$NAME` - human-readable argument name, used only in the error message
/// * `$T`    - offset type parameter of the target `GenericStringArray`
///
/// The expansion ends with `?`, so it must be used inside a function returning
/// `Result`: when the downcast fails, the enclosing function returns
/// `DataFusionError::Internal` naming the argument and the expected type.
macro_rules! downcast_string_arg {
    ($ARG:expr, $NAME:expr, $T:ident) => {{
        $ARG.as_any()
            // Use the caller-supplied `$T` rather than relying on a type
            // parameter that happens to be called `T` at the call site.
            .downcast_ref::<GenericStringArray<$T>>()
            .ok_or_else(|| {
                DataFusionError::Internal(format!(
                    "could not cast {} to {}",
                    $NAME,
                    type_name::<GenericStringArray<$T>>()
                ))
            })?
    }};
}
/// Extract a specific group from a string column, using a regular expression.
///
/// Thin wrapper that downcasts the arguments to string arrays and delegates to
/// `arrow::compute::regexp_match`.
///
/// # Arguments
///
/// * `args[0]` - the strings to search
/// * `args[1]` - the regular expression patterns
/// * `args[2]` - (optional) the regular expression flags
///
/// # Errors
///
/// * `DataFusionError::Internal` when called with fewer than 2 or more than 3
///   arguments, or when an argument cannot be downcast to `GenericStringArray<T>`.
/// * `DataFusionError::ArrowError` when the underlying arrow kernel fails.
pub fn regexp_match<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        2 => compute::regexp_match(
            downcast_string_arg!(args[0], "string", T),
            downcast_string_arg!(args[1], "pattern", T),
            None,
        )
        .map_err(DataFusionError::ArrowError),
        3 => compute::regexp_match(
            downcast_string_arg!(args[0], "string", T),
            downcast_string_arg!(args[1], "pattern", T),
            // The flags are the third argument; reading `args[1]` here would
            // pass the pattern text as the flags.
            Some(downcast_string_arg!(args[2], "flags", T)),
        )
        .map_err(DataFusionError::ArrowError),
        other => Err(DataFusionError::Internal(format!(
            "regexp_match was called with {} arguments. It requires at least 2 and at most 3.",
            other
        ))),
    }
}
/// Rewrites POSIX-style capture group references in a replacement string
/// (such as `\1`) into the Rust `regex` form (such as `${1}`).
///
/// Every backslash followed by zero or more digits is replaced by `${digits}`.
/// Used by `regexp_replace` to prepare its replacement argument.
fn regex_replace_posix_groups(replacement: &str) -> String {
    lazy_static! {
        // Group 1 is the backslash, group 2 the (possibly empty) run of digits.
        static ref CAPTURE_GROUPS_RE: Regex = Regex::new("(\\\\)(\\d*)").unwrap();
    }

    // `$$` emits a literal `$`; `$2` expands to the captured digits.
    let rewritten = CAPTURE_GROUPS_RE.replace_all(replacement, "$${$2}");
    rewritten.into_owned()
}
/// Replaces substring(s) matching a POSIX regular expression.
///
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
pub fn regexp_replace<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// creating Regex is expensive so create hashmap for memoization
let mut patterns: HashMap<String, Regex> = HashMap::new();
match args.len() {
3 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);
// if patterns hashmap already has regexp then use else else create and return
let re = match patterns.get(pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::Execution(err.to_string())),
}
}
};
Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let flags_array = downcast_string_arg!(args[3], "flags", T);
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);
// format flags into rust pattern
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(format!("(?{}){}", flags.to_string().replace("g", ""), pattern), true)
} else {
(format!("(?{}){}", flags, pattern), false)
};
// if patterns hashmap already has regexp then use else else create and return
let re = match patterns.get(&pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern, re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::Execution(err.to_string())),
}
}
};
Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
})).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
other => Err(DataFusionError::Internal(format!(
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
other
))),
}
}