ARROW-12146: [C++][Gandiva] Implement CONVERT_FROM(expression, replacement char) function
Implement CONVERT_FROM(expression, ‘UTF8’, replacement char)
Converts the byte data in expression to UTF-8. Expression can be a literal string or a field name. Will replace any invalid UTF-8 characters with the replacement character.
Obs.: Actually we will only support a single byte replacement char
Closes #9844 from jpedroantunes/feature/convert-replace-utf8 and squashes the following commits:
bef6eafda <João Pedro> Add optimization for returning original string if no invalid chars were found
e7c6a71db <João Pedro> Refactor memcpy unnecessary for single byte
7aac875e7 <João Pedro> Add handler for cases with 0 char len on replace char
6544583f0 <João Pedro> Apply proper identation on types.h and string_ops.cc in gandiva
c66efb8e4 <João Pedro> Apply corrections and optimization on convert replace function
d815f854c <João Pedro> Add validation for MSBs on convert replace utf8 Gandiva function
8e44d413d <João Pedro> Add validation for defined char length greater than 1 on convert replace
a2ea61bee <João Pedro> Adapt convert_from method to support single char on replacement (defined with dremio team)
7d4cec02c <João Pedro> Adapt convert_from method to support multiple char on replacement
1a1734b9a <João Pedro> Change string ops test for defining int variables instead of size_t
b96dfc750 <João Pedro> Fix lint problems on string ops and test files
8f9a4bde0 <João Pedro> Fix identation on string files on gandiva module
875a1dd87 <João Pedro> Add integration test for convert replace utf8 method
536fd3a63 <João Pedro> Add definition of convert replace str method to types.h
c950c8a45 <João Pedro> Add base tests for convert replace invalid chars
2a5fe944e <João Pedro> Add base logic for convert replace utf8 invalid chars
Authored-by: João Pedro <joaop@simbioseventures.com>
Signed-off-by: Praveen <praveen@dremio.com>
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index ff438db..3c0d714 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -204,6 +204,11 @@
utf8(), kResultNullIfNull, "convert_fromUTF8_binary",
NativeFunction::kNeedsContext),
+ NativeFunction("convert_replaceUTF8", {"convert_replaceutf8"},
+ DataTypeVector{binary(), utf8()}, utf8(), kResultNullIfNull,
+ "convert_replace_invalid_fromUTF8_binary",
+ NativeFunction::kNeedsContext),
+
NativeFunction("locate", {"position"}, DataTypeVector{utf8(), utf8(), int32()},
int32(), kResultNullIfNull, "locate_utf8_utf8_int32",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc
index 92071ac..fa9164b 100644
--- a/cpp/src/gandiva/precompiled/string_ops.cc
+++ b/cpp/src/gandiva/precompiled/string_ops.cc
@@ -156,6 +156,17 @@
free(error);
}
+FORCE_INLINE
+bool validate_utf8_following_bytes(const char* data, int32_t data_len,
+ int32_t char_index) {
+ for (int j = 1; j < data_len; ++j) {
+ if ((data[char_index + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph
+ return false;
+ }
+ }
+ return true;
+}
+
// Count the number of utf8 characters
// return 0 for invalid/incomplete input byte sequences
FORCE_INLINE
@@ -1246,6 +1257,59 @@
return ret;
}
+FORCE_INLINE
+const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in,
+ int32_t text_len,
+ const char* char_to_replace,
+ int32_t char_to_replace_len,
+ int32_t* out_len) {
+ if (char_to_replace_len == 0) {
+ *out_len = text_len;
+ return text_in;
+ } else if (char_to_replace_len != 1) {
+ gdv_fn_context_set_error_msg(context, "Replacement of multiple bytes not supported");
+ *out_len = 0;
+ return "";
+ }
+ // actually the convert_replace function replaces invalid chars with an ASCII
+ // character so the output length will be the same as the input length
+ *out_len = text_len;
+ char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
+ if (ret == nullptr) {
+ gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
+ *out_len = 0;
+ return "";
+ }
+ int32_t valid_bytes_to_cpy = 0;
+ int32_t out_byte_counter = 0;
+ int32_t char_len;
+ // scan the base text from left to right and increment the start pointer till
+ // looking for invalid chars to substitute
+ for (int text_index = 0; text_index < text_len; text_index += char_len) {
+ char_len = utf8_char_length(text_in[text_index]);
+ // only memory copy the bytes when detect invalid char
+ if (char_len == 0 || text_index + char_len > text_len ||
+ !validate_utf8_following_bytes(text_in, char_len, text_index)) {
+ // define char_len = 1 to increase text_index by 1 (as ASCII char fits in 1 byte)
+ char_len = 1;
+ // first copy the valid bytes until now and then replace the invalid character
+ memcpy(ret + out_byte_counter, text_in + out_byte_counter, valid_bytes_to_cpy);
+ ret[out_byte_counter + valid_bytes_to_cpy] = char_to_replace[0];
+ out_byte_counter += valid_bytes_to_cpy + char_len;
+ valid_bytes_to_cpy = 0;
+ continue;
+ }
+ valid_bytes_to_cpy += char_len;
+ }
+ // if invalid chars were not found, return the original string
+ if (out_byte_counter == 0) return text_in;
+ // if there are still valid bytes to copy, do it
+ if (valid_bytes_to_cpy != 0) {
+ memcpy(ret + out_byte_counter, text_in + out_byte_counter, valid_bytes_to_cpy);
+ }
+ return ret;
+}
+
// Search for a string within another string
FORCE_INLINE
gdv_int32 locate_utf8_utf8(gdv_int64 context, const char* sub_str, gdv_int32 sub_str_len,
diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc
index b8f467a..9326aac 100644
--- a/cpp/src/gandiva/precompiled/string_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/string_ops_test.cc
@@ -115,6 +115,64 @@
ctx.Reset();
}
+TEST(TestStringOps, TestConvertReplaceInvalidUtf8Char) {
+ gandiva::ExecutionContext ctx;
+ uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
+
+ // invalid utf8 (xf8 is invalid but x28 is not - x28 = '(')
+ std::string a(
+ "ok-\xf8\x28"
+ "-a");
+ auto a_in_out_len = static_cast<int>(a.length());
+ const char* a_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, a.data(), a_in_out_len, "a", 1, &a_in_out_len);
+ EXPECT_EQ(std::string(a_str, a_in_out_len), "ok-a(-a");
+ EXPECT_FALSE(ctx.has_error());
+
+ // invalid utf8 (xa0 and xa1 are invalid)
+ std::string b("ok-\xa0\xa1-valid");
+ auto b_in_out_len = static_cast<int>(b.length());
+ const char* b_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, b.data(), b_in_out_len, "b", 1, &b_in_out_len);
+ EXPECT_EQ(std::string(b_str, b_in_out_len), "ok-bb-valid");
+ EXPECT_FALSE(ctx.has_error());
+
+ // full valid utf8
+ std::string c("all-valid");
+ auto c_in_out_len = static_cast<int>(c.length());
+ const char* c_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, c.data(), c_in_out_len, "c", 1, &c_in_out_len);
+ EXPECT_EQ(std::string(c_str, c_in_out_len), "all-valid");
+ EXPECT_FALSE(ctx.has_error());
+
+ // valid utf8 (महसुस is 4-char string, each char of which is likely a multibyte char)
+ std::string d("ok-महसुस-valid-new");
+ auto d_in_out_len = static_cast<int>(d.length());
+ const char* d_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, d.data(), d_in_out_len, "d", 1, &d_in_out_len);
+ EXPECT_EQ(std::string(d_str, d_in_out_len), "ok-महसुस-valid-new");
+ EXPECT_FALSE(ctx.has_error());
+
+ // full valid utf8, but invalid replacement char length
+ std::string e("all-valid");
+ auto e_in_out_len = static_cast<int>(e.length());
+ const char* e_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, e.data(), e_in_out_len, "ee", 2, &e_in_out_len);
+ EXPECT_EQ(std::string(e_str, e_in_out_len), "");
+ EXPECT_TRUE(ctx.has_error());
+ ctx.Reset();
+
+ // full valid utf8, but invalid replacement char length
+ std::string f("ok-\xa0\xa1-valid");
+ auto f_in_out_len = static_cast<int>(f.length());
+ const char* f_str = convert_replace_invalid_fromUTF8_binary(
+ ctx_ptr, f.data(), f_in_out_len, "", 0, &f_in_out_len);
+ EXPECT_EQ(std::string(f_str, f_in_out_len), "ok-\xa0\xa1-valid");
+ EXPECT_FALSE(ctx.has_error());
+
+ ctx.Reset();
+}
+
TEST(TestStringOps, TestCastBoolToVarchar) {
gandiva::ExecutionContext ctx;
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h
index bc17208..eefc1f7 100644
--- a/cpp/src/gandiva/precompiled/types.h
+++ b/cpp/src/gandiva/precompiled/types.h
@@ -417,6 +417,12 @@
gdv_int32 from_str_len, const char* to_str,
gdv_int32 to_str_len, gdv_int32* out_len);
+const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in,
+ int32_t text_len,
+ const char* char_to_replace,
+ int32_t char_to_replace_len,
+ int32_t* out_len);
+
const char* split_part(gdv_int64 context, const char* text, gdv_int32 text_len,
const char* splitter, gdv_int32 split_len, gdv_int32 index,
gdv_int32* out_len);
diff --git a/cpp/src/gandiva/tests/utf8_test.cc b/cpp/src/gandiva/tests/utf8_test.cc
index 103992d..29ce81f 100644
--- a/cpp/src/gandiva/tests/utf8_test.cc
+++ b/cpp/src/gandiva/tests/utf8_test.cc
@@ -539,6 +539,56 @@
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}
+TEST_F(TestUtf8, TestConvertUtf8) {
+ // schema for input fields
+ auto field_a = field("a", arrow::binary());
+ auto field_c = field("c", utf8());
+ auto schema = arrow::schema({field_a, field_c});
+
+ // output fields
+ auto res = field("res", boolean());
+
+ // build expressions.
+ auto node_a = TreeExprBuilder::MakeField(field_a);
+ auto node_c = TreeExprBuilder::MakeField(field_c);
+
+ // define char to replace
+ auto node_b = TreeExprBuilder::MakeStringLiteral("z");
+
+ auto convert_replace_utf8 =
+ TreeExprBuilder::MakeFunction("convert_replaceUTF8", {node_a, node_b}, utf8());
+ auto equals =
+ TreeExprBuilder::MakeFunction("equal", {convert_replace_utf8, node_c}, boolean());
+ auto expr = TreeExprBuilder::MakeExpression(equals, res);
+
+ // Build a projector for the expressions.
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 3;
+ auto array_a = MakeArrowArrayUtf8({"ok-\xf8\x28"
+ "-a",
+ "all-valid", "ok-\xa0\xa1-valid"},
+ {true, true, true});
+
+ auto array_b =
+ MakeArrowArrayUtf8({"ok-z(-a", "all-valid", "ok-zz-valid"}, {true, true, true});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ auto exp = MakeArrowArrayBool({true, true, true}, {true, true, true});
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]);
+}
+
TEST_F(TestUtf8, TestCastVarChar) {
// schema for input fields
auto field_a = field("a", utf8());