ARROW-12134: [C++] Add match_substring_regex kernel
For consistency with match_substring, this is the equivalent of Python's re.search(), not re.match(): the pattern may match at any position in the input string, not only at its start.
Closes #9838 from lidavidm/arrow-12134
Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 730836b..f59426d 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -45,7 +45,7 @@
struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
explicit MatchSubstringOptions(std::string pattern) : pattern(std::move(pattern)) {}
- /// The exact substring to look for inside input values.
+ /// The exact substring (or regex, depending on kernel) to look for inside input values.
std::string pattern;
};
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 3986987..9ec1fe0 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -368,83 +368,130 @@
}
}
-template <typename offset_type>
-void TransformMatchSubstring(const uint8_t* pattern, int64_t pattern_length,
- const offset_type* offsets, const uint8_t* data,
- int64_t length, int64_t output_offset, uint8_t* output) {
- // This is an implementation of the Knuth-Morris-Pratt algorithm
-
- // Phase 1: Build the prefix table
- std::vector<offset_type> prefix_table(pattern_length + 1);
- offset_type prefix_length = -1;
- prefix_table[0] = -1;
- for (offset_type pos = 0; pos < pattern_length; ++pos) {
- // The prefix cannot be expanded, reset.
- while (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) {
- prefix_length = prefix_table[prefix_length];
- }
- prefix_length++;
- prefix_table[pos + 1] = prefix_length;
- }
-
- // Phase 2: Find the prefix in the data
- FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
- for (int64_t i = 0; i < length; ++i) {
- const uint8_t* current_data = data + offsets[i];
- int64_t current_length = offsets[i + 1] - offsets[i];
-
- int64_t pattern_pos = 0;
- for (int64_t k = 0; k < current_length; k++) {
- while ((pattern_pos >= 0) && (pattern[pattern_pos] != current_data[k])) {
- pattern_pos = prefix_table[pattern_pos];
- }
- pattern_pos++;
- if (pattern_pos == pattern_length) {
- bitmap_writer.Set();
- break;
- }
- }
- bitmap_writer.Next();
- }
- bitmap_writer.Finish();
-}
-
using MatchSubstringState = OptionsWrapper<MatchSubstringOptions>;
-template <typename Type>
+template <typename Type, typename Matcher>
struct MatchSubstring {
using offset_type = typename Type::offset_type;
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- MatchSubstringOptions arg = MatchSubstringState::Get(ctx);
- const uint8_t* pat = reinterpret_cast<const uint8_t*>(arg.pattern.c_str());
- const int64_t pat_size = arg.pattern.length();
+ // TODO Cache matcher across invocations (for regex compilation)
+ Matcher matcher(ctx, MatchSubstringState::Get(ctx));
+ if (ctx->HasError()) return;
StringBoolTransform<Type>(
ctx, batch,
- [pat, pat_size](const void* offsets, const uint8_t* data, int64_t length,
- int64_t output_offset, uint8_t* output) {
- TransformMatchSubstring<offset_type>(
- pat, pat_size, reinterpret_cast<const offset_type*>(offsets), data, length,
- output_offset, output);
+ [&matcher](const void* raw_offsets, const uint8_t* data, int64_t length,
+ int64_t output_offset, uint8_t* output) {
+ const offset_type* offsets = reinterpret_cast<const offset_type*>(raw_offsets);
+ FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
+ for (int64_t i = 0; i < length; ++i) {
+ const char* current_data = reinterpret_cast<const char*>(data + offsets[i]);
+ int64_t current_length = offsets[i + 1] - offsets[i];
+ if (matcher.Match(util::string_view(current_data, current_length))) {
+ bitmap_writer.Set();
+ }
+ bitmap_writer.Next();
+ }
+ bitmap_writer.Finish();
},
out);
}
};
+// This is an implementation of the Knuth-Morris-Pratt algorithm
+struct PlainSubstringMatcher {
+ const MatchSubstringOptions& options_;
+ std::vector<int64_t> prefix_table;
+
+ PlainSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options)
+ : options_(options) {
+ // Phase 1: Build the prefix table
+ const auto pattern_length = options_.pattern.size();
+ prefix_table.resize(pattern_length + 1, /*value=*/0);
+ int64_t prefix_length = -1;
+ prefix_table[0] = -1;
+ for (size_t pos = 0; pos < pattern_length; ++pos) {
+ // The prefix cannot be expanded, reset.
+ while (prefix_length >= 0 &&
+ options_.pattern[pos] != options_.pattern[prefix_length]) {
+ prefix_length = prefix_table[prefix_length];
+ }
+ prefix_length++;
+ prefix_table[pos + 1] = prefix_length;
+ }
+ }
+
+ bool Match(util::string_view current) {
+ // Phase 2: Find the prefix in the data
+ const auto pattern_length = options_.pattern.size();
+ int64_t pattern_pos = 0;
+ for (const auto c : current) {
+ while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) {
+ pattern_pos = prefix_table[pattern_pos];
+ }
+ pattern_pos++;
+ if (static_cast<size_t>(pattern_pos) == pattern_length) {
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
const FunctionDoc match_substring_doc(
"Match strings against literal pattern",
("For each string in `strings`, emit true iff it contains a given pattern.\n"
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
{"strings"}, "MatchSubstringOptions");
+#ifdef ARROW_WITH_RE2
+struct RegexSubstringMatcher {
+ const MatchSubstringOptions& options_;
+ const RE2 regex_match_;
+
+ RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options)
+ : options_(options), regex_match_(options_.pattern) {
+ if (!regex_match_.ok()) {
+ ctx->SetStatus(Status::Invalid("Regular expression error"));
+ }
+ }
+
+ bool Match(util::string_view current) {
+ auto piece = re2::StringPiece(current.data(), current.length());
+ return re2::RE2::PartialMatch(piece, regex_match_);
+ }
+};
+
+const FunctionDoc match_substring_regex_doc(
+ "Match strings against regex pattern",
+ ("For each string in `strings`, emit true iff it matches a given pattern at any "
+ "position.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+#endif
+
void AddMatchSubstring(FunctionRegistry* registry) {
- auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(),
- &match_substring_doc);
- auto exec_32 = MatchSubstring<StringType>::Exec;
- auto exec_64 = MatchSubstring<LargeStringType>::Exec;
- DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
- DCHECK_OK(
- func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
- DCHECK_OK(registry->AddFunction(std::move(func)));
+ {
+ auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(),
+ &match_substring_doc);
+ auto exec_32 = MatchSubstring<StringType, PlainSubstringMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, PlainSubstringMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#ifdef ARROW_WITH_RE2
+ {
+ auto func = std::make_shared<ScalarFunction>("match_substring_regex", Arity::Unary(),
+ &match_substring_regex_doc);
+ auto exec_32 = MatchSubstring<StringType, RegexSubstringMatcher>::Exec;
+ auto exec_64 = MatchSubstring<LargeStringType, RegexSubstringMatcher>::Exec;
+ DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#endif
}
// IsAlpha/Digit etc
@@ -1246,7 +1293,7 @@
using State = OptionsWrapper<ReplaceSubstringOptions>;
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- // TODO Cache replacer accross invocations (for regex compilation)
+ // TODO Cache replacer across invocations (for regex compilation)
Replacer replacer{ctx, State::Get(ctx)};
if (!ctx->HasError()) {
Replace(ctx, batch, &replacer, out);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 88622e8..2dd0a4d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -348,6 +348,27 @@
&options_double_char_2);
}
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, MatchSubstringRegex) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("match_substring_regex", "[]", boolean(), "[]", &options);
+ this->CheckUnary("match_substring_regex", R"(["abc", "acb", "cab", null, "bac"])",
+ boolean(), "[true, false, true, null, false]", &options);
+ MatchSubstringOptions options_repeated{"(ab){2}"};
+ this->CheckUnary("match_substring_regex", R"(["abab", "ab", "cababc", null, "bac"])",
+ boolean(), "[true, false, true, null, false]", &options_repeated);
+ MatchSubstringOptions options_digit{"\\d"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "a2ab", "", "24"])", boolean(),
+ "[false, true, false, true]", &options_digit);
+ MatchSubstringOptions options_star{"a*b"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
+ boolean(), "[true, true, true, true, true, false]", &options_star);
+ MatchSubstringOptions options_plus{"a+b"};
+ this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
+ boolean(), "[false, true, true, true, false, false]", &options_plus);
+}
+#endif
+
TYPED_TEST(TestStringKernels, SplitBasics) {
SplitPatternOptions options{" "};
// basics
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 065b807..715d503 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -522,26 +522,31 @@
Containment tests
~~~~~~~~~~~~~~~~~
-+--------------------+------------+------------------------------------+---------------+----------------------------------------+
-| Function name | Arity | Input types | Output type | Options class |
-+====================+============+====================================+===============+========================================+
-| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` |
-+--------------------+------------+------------------------------------+---------------+----------------------------------------+
-| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (2) | :struct:`SetLookupOptions` |
-| | | Binary- and String-like | | |
-+--------------------+------------+------------------------------------+---------------+----------------------------------------+
-| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (3) | :struct:`SetLookupOptions` |
-| | | Binary- and String-like | | |
-+--------------------+------------+------------------------------------+---------------+----------------------------------------+
++---------------------------+------------+------------------------------------+---------------+----------------------------------------+
+| Function name | Arity | Input types | Output type | Options class |
++===========================+============+====================================+===============+========================================+
+| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` |
++---------------------------+------------+------------------------------------+---------------+----------------------------------------+
+| match_substring_regex | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
++---------------------------+------------+------------------------------------+---------------+----------------------------------------+
+| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (3) | :struct:`SetLookupOptions` |
+| | | Binary- and String-like | | |
++---------------------------+------------+------------------------------------+---------------+----------------------------------------+
+| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (4) | :struct:`SetLookupOptions` |
+| | | Binary- and String-like | | |
++---------------------------+------------+------------------------------------+---------------+----------------------------------------+
* \(1) Output is true iff :member:`MatchSubstringOptions::pattern`
is a substring of the corresponding input element.
-* \(2) Output is the index of the corresponding input element in
+* \(2) Output is true iff :member:`MatchSubstringOptions::pattern`
+ matches the corresponding input element at any position.
+
+* \(3) Output is the index of the corresponding input element in
:member:`SetLookupOptions::value_set`, if found there. Otherwise,
output is null.
-* \(3) Output is true iff the corresponding input element is equal to one
+* \(4) Output is true iff the corresponding input element is equal to one
of the elements in :member:`SetLookupOptions::value_set`.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 2dafbd2..d6efc6a 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -155,6 +155,7 @@
index_in
is_in
match_substring
+ match_substring_regex
Conversions
-----------
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 1b46a08..3928b9c 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -306,6 +306,24 @@
MatchSubstringOptions(pattern))
+def match_substring_regex(array, pattern):
+ """
+ Test if regex *pattern* matches at any position a value of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ regex pattern to search
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("match_substring_regex", [array],
+ MatchSubstringOptions(pattern))
+
+
def sum(array):
"""
Sum the values in a numerical (chunked) array.
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 160375f..94a6189 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -279,6 +279,13 @@
assert expected.equals(result)
+def test_match_substring_regex():
+ arr = pa.array(["ab", "abc", "ba", "c", None])
+ result = pc.match_substring_regex(arr, "^a?b")
+ expected = pa.array([True, True, True, False, None])
+ assert expected.equals(result)
+
+
def test_trim():
# \u3000 is unicode whitespace
arr = pa.array([" foo", None, " \u3000foo bar \t"])