ARROW-7215: [C++][Gandiva] Implement castVARCHAR(numeric_type) functions
This PR implements castVARCHAR for numeric values in Gandiva.
It replaces the logic of PR https://github.com/apache/arrow/pull/8158 so that the function output matches the formatting used by Java.
Closes #9816 from anthonylouisbsb/feature/fix-castvarchar-to-match-java-impl and squashes the following commits:
7df55a58d <Anthony Louis> Apply formatting changes
7a724c0fd <Anthony Louis> Remove unnecessary macros
4fb8a7f44 <Anthony Louis> Refactor if chain
e78705136 <Anthony Louis> Add test to infinity case
b62b856a0 <Anthony Louis> Add comments for changes
cec11bbf0 <Anthony Louis> Add tests to check Java compatibility
302139c78 <Anthony Louis> Add emit trailing point tests
efb94b901 <Anthony Louis> Add -0.0 inside cast test
523e60a56 <Anthony Louis> Add custom constructor inside the class
34f2f926d <Anthony Louis> Add class to print in formatted way
e244502b3 <Anthony Louis> Fix tests to consider java formatting
33bc5b2de <Projjal Chanda> added castvarchar(numeric_types) functions
Lead-authored-by: Anthony Louis <anthony@simbioseventures.com>
Co-authored-by: Projjal Chanda <iam@pchanda.com>
Signed-off-by: Praveen <praveen@dremio.com>
diff --git a/cpp/src/arrow/util/formatting.cc b/cpp/src/arrow/util/formatting.cc
index 9e4d25c..c16d42c 100644
--- a/cpp/src/arrow/util/formatting.cc
+++ b/cpp/src/arrow/util/formatting.cc
@@ -43,11 +43,29 @@
: converter_(DoubleToStringConverter::EMIT_POSITIVE_EXPONENT_SIGN, "inf", "nan",
'e', -6, 10, 6, 0) {}
+ Impl(int flags, const char* inf_symbol, const char* nan_symbol, char exp_character,
+ int decimal_in_shortest_low, int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : converter_(flags, inf_symbol, nan_symbol, exp_character, decimal_in_shortest_low,
+ decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode) {}  // custom configuration: every argument is passed, unchanged and in order, to the DoubleToStringConverter member
+
DoubleToStringConverter converter_;
};
FloatToStringFormatter::FloatToStringFormatter() : impl_(new Impl()) {}
+FloatToStringFormatter::FloatToStringFormatter(
+ int flags, const char* inf_symbol, const char* nan_symbol, char exp_character,
+ int decimal_in_shortest_low, int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : impl_(new Impl(flags, inf_symbol, nan_symbol, exp_character,
+ decimal_in_shortest_low, decimal_in_shortest_high,
+ max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode)) {}  // allocates the Impl with the caller-supplied converter settings instead of the defaults
+
FloatToStringFormatter::~FloatToStringFormatter() {}
int FloatToStringFormatter::FormatFloat(float v, char* out_buffer, int out_size) {
diff --git a/cpp/src/arrow/util/formatting.h b/cpp/src/arrow/util/formatting.h
index 5f4b251..566c979 100644
--- a/cpp/src/arrow/util/formatting.h
+++ b/cpp/src/arrow/util/formatting.h
@@ -31,6 +31,7 @@
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
+#include "arrow/util/double_conversion.h"
#include "arrow/util/string_view.h"
#include "arrow/util/time.h"
#include "arrow/util/visibility.h"
@@ -219,6 +220,11 @@
class ARROW_EXPORT FloatToStringFormatter {
public:
FloatToStringFormatter();
+ FloatToStringFormatter(int flags, const char* inf_symbol, const char* nan_symbol,
+ char exp_character, int decimal_in_shortest_low,
+ int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode);  // custom formatting: the arguments configure the underlying DoubleToStringConverter
~FloatToStringFormatter();
// Returns the number of characters written
@@ -239,6 +245,16 @@
explicit FloatToStringFormatterMixin(const std::shared_ptr<DataType>& = NULLPTR) {}
+ FloatToStringFormatterMixin(int flags, const char* inf_symbol, const char* nan_symbol,
+ char exp_character, int decimal_in_shortest_low,
+ int decimal_in_shortest_high,
+ int max_leading_padding_zeroes_in_precision_mode,
+ int max_trailing_padding_zeroes_in_precision_mode)
+ : FloatToStringFormatter(flags, inf_symbol, nan_symbol, exp_character,
+ decimal_in_shortest_low, decimal_in_shortest_high,
+ max_leading_padding_zeroes_in_precision_mode,
+ max_trailing_padding_zeroes_in_precision_mode) {}  // forwards all arguments to the FloatToStringFormatter constructor with the same parameter list
+
template <typename Appender>
Return<Appender> operator()(value_type value, Appender&& append) {
char buffer[buffer_size];
diff --git a/cpp/src/arrow/vendored/double-conversion/double-conversion.cc b/cpp/src/arrow/vendored/double-conversion/double-conversion.cc
index 5d5d6f1..27e70b4 100644
--- a/cpp/src/arrow/vendored/double-conversion/double-conversion.cc
+++ b/cpp/src/arrow/vendored/double-conversion/double-conversion.cc
@@ -84,7 +84,25 @@
StringBuilder* result_builder) const {
ASSERT(length != 0);
result_builder->AddCharacter(decimal_digits[0]);
- if (length != 1) {
+
+ /* If the mantissa of the scientific notation representation is an integer number,
+ * the EMIT_TRAILING_DECIMAL_POINT flag will add a '.' character at the end of the
+ * representation:
+ * - With EMIT_TRAILING_DECIMAL_POINT enabled -> 0.0009 => 9.E-4
+ * - With EMIT_TRAILING_DECIMAL_POINT disabled -> 0.0009 => 9E-4
+ *
+ * If the mantissa is an integer and the EMIT_TRAILING_ZERO_AFTER_POINT flag is enabled,
+ * a '0' character is also added after the '.'. Note that this flag only takes effect
+ * when the EMIT_TRAILING_DECIMAL_POINT flag is enabled as well.*/
+ if(length == 1){
+ if ((flags_ & EMIT_TRAILING_DECIMAL_POINT) != 0) {
+ result_builder->AddCharacter('.');
+
+ if ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) {
+ result_builder->AddCharacter('0');
+ }
+ }
+ } else {
result_builder->AddCharacter('.');
result_builder->AddSubstring(&decimal_digits[1], length-1);
}
diff --git a/cpp/src/arrow/vendored/double-conversion/double-conversion.h b/cpp/src/arrow/vendored/double-conversion/double-conversion.h
index 6dbc099..9dc3ebd 100644
--- a/cpp/src/arrow/vendored/double-conversion/double-conversion.h
+++ b/cpp/src/arrow/vendored/double-conversion/double-conversion.h
@@ -104,6 +104,17 @@
// ToPrecision(230.0, 2) -> "230"
// ToPrecision(230.0, 2) -> "230." with EMIT_TRAILING_DECIMAL_POINT.
// ToPrecision(230.0, 2) -> "2.3e2" with EMIT_TRAILING_ZERO_AFTER_POINT.
+ //
+ // When converting numbers to scientific notation representation, if the mantissa of
+ // the representation is an integer number, the EMIT_TRAILING_DECIMAL_POINT flag will
+ // add a '.' character at the end of the representation:
+ // - With EMIT_TRAILING_DECIMAL_POINT enabled -> 0.0009 => 9.E-4
+ // - With EMIT_TRAILING_DECIMAL_POINT disabled -> 0.0009 => 9E-4
+ //
+ // If the mantissa is an integer and the EMIT_TRAILING_ZERO_AFTER_POINT flag is enabled,
+ // a '0' character is also added after the '.'. Note that this flag only takes effect
+ // when the EMIT_TRAILING_DECIMAL_POINT flag is enabled as well.
+ // - With EMIT_TRAILING_ZERO_AFTER_POINT enabled -> 0.0009 => 9.0E-4
DoubleToStringConverter(int flags,
const char* infinity_symbol,
const char* nan_symbol,
diff --git a/cpp/src/gandiva/formatting_utils.h b/cpp/src/gandiva/formatting_utils.h
new file mode 100644
index 0000000..7bc6a49
--- /dev/null
+++ b/cpp/src/gandiva/formatting_utils.h
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "arrow/type.h"
+#include "arrow/util/formatting.h"
+#include "arrow/vendored/double-conversion/double-conversion.h"
+
+namespace gandiva {
+
+/// \brief The entry point for conversion to strings.
+template <typename ARROW_TYPE, typename Enable = void>
+class GdvStringFormatter;
+
+using double_conversion::DoubleToStringConverter;
+
+template <typename ARROW_TYPE>
+class FloatToStringGdvMixin
+ : public arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE> {
+ public:
+ using arrow::internal::FloatToStringFormatterMixin<
+ ARROW_TYPE>::FloatToStringFormatterMixin;
+
+ // This mixin reuses the existing FloatToStringFormatterMixin, but configures it
+ // with specific parameters so that floating-point numbers are converted to
+ // strings using the same format as Java.
+ //
+ // Java represents real numbers in one of two ways, following these rules:
+ // - If the magnitude is greater than or equal to 10^7, or less than 10^(-3),
+ // the number is represented using scientific notation, e.g.:
+ // - 0.000012 -> 1.2E-5
+ // - 10000002.3 -> 1.00000023E7
+ // - Otherwise (magnitude inside that interval) the number is shown as is.
+ explicit FloatToStringGdvMixin(const std::shared_ptr<arrow::DataType>& = NULLPTR)
+ : arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE>(
+ DoubleToStringConverter::EMIT_TRAILING_ZERO_AFTER_POINT |
+ DoubleToStringConverter::EMIT_TRAILING_DECIMAL_POINT,
+ "Infinity", "NaN", 'E', -3, 7, 3, 1) {}  // inf symbol, NaN symbol, exponent char, decimal_in_shortest_low/high, max leading/trailing padding zeroes
+};
+
+template <>
+class GdvStringFormatter<arrow::FloatType>
+ : public FloatToStringGdvMixin<arrow::FloatType> {
+ public:
+ using FloatToStringGdvMixin::FloatToStringGdvMixin;  // float32 formatter: inherits the mixin's constructors, adds nothing else
+};
+
+template <>
+class GdvStringFormatter<arrow::DoubleType>
+ : public FloatToStringGdvMixin<arrow::DoubleType> {
+ public:
+ using FloatToStringGdvMixin::FloatToStringGdvMixin;  // float64 formatter: inherits the mixin's constructors, adds nothing else
+};
+} // namespace gandiva
diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc
index 3c0d714..d1f97cd 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -92,6 +92,22 @@
kResultNullIfNull, "castVARCHAR_utf8_int64",
NativeFunction::kNeedsContext),
+ NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_int32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{int64(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_int64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_float32_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
+ NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(),
+ kResultNullIfNull, "gdv_fn_castVARCHAR_float64_int64",
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
+
NativeFunction("castVARCHAR", {}, DataTypeVector{decimal128(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_decimal128_int64",
NativeFunction::kNeedsContext),
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc
index 2d0e1a7..832eebc 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -20,12 +20,15 @@
#include <string>
#include <vector>
+#include "arrow/util/formatting.h"
#include "arrow/util/value_parsing.h"
#include "gandiva/engine.h"
#include "gandiva/exported_funcs.h"
+#include "gandiva/formatting_utils.h"
#include "gandiva/hash_utils.h"
#include "gandiva/in_holder.h"
#include "gandiva/like_holder.h"
+#include "gandiva/precompiled/types.h"
#include "gandiva/random_generator_holder.h"
#include "gandiva/to_date_holder.h"
@@ -303,6 +306,86 @@
CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8)
#undef CAST_NUMERIC_FROM_STRING
+
+// Defines gdv_fn_castVARCHAR_<IN_TYPE>_int64: formats an integer `value` as text and
+// returns at most `len` bytes of it in arena memory owned by the execution context.
+//   context  opaque execution-context handle (used for errors and arena memory)
+//   value    the number to format
+//   len      maximum number of bytes to return; negative is an error, 0 yields ""
+//   out_len  receives the number of bytes returned (0 on error)
+// The output buffer is allocated only once the formatted size is known, so it is
+// exactly min(len, formatted size) bytes. Sizing it from `len` directly would
+// truncate lengths above INT32_MAX to a too-small (or negative) allocation that
+// the copy below could overrun, and would waste arena memory for large lengths.
+#define GDV_FN_CAST_VARCHAR_INTEGER(IN_TYPE, ARROW_TYPE) \
+  GANDIVA_EXPORT \
+  const char* gdv_fn_castVARCHAR_##IN_TYPE##_int64(int64_t context, gdv_##IN_TYPE value, \
+                                                   int64_t len, int32_t * out_len) { \
+    if (len < 0) { \
+      gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    if (len == 0) { \
+      *out_len = 0; \
+      return ""; \
+    } \
+    arrow::internal::StringFormatter<arrow::ARROW_TYPE> formatter; \
+    char* ret = nullptr; \
+    bool alloc_failed = false; \
+    arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \
+      int64_t size = static_cast<int64_t>(v.size()); \
+      /* copy_len <= size, and size is a short formatted number, so it fits int32 */ \
+      int32_t copy_len = static_cast<int32_t>(len < size ? len : size); \
+      ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, copy_len)); \
+      if (ret == nullptr) { \
+        alloc_failed = true; \
+        return arrow::Status::OutOfMemory("Could not allocate memory"); \
+      } \
+      memcpy(ret, v.data(), copy_len); \
+      *out_len = copy_len; \
+      return arrow::Status::OK(); \
+    }); \
+    if (alloc_failed) { \
+      gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    if (!status.ok()) { \
+      std::string err = "Could not cast " + std::to_string(value) + " to string"; \
+      gdv_fn_context_set_error_msg(context, err.c_str()); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    return ret; \
+  }
+
+// Defines gdv_fn_castVARCHAR_<IN_TYPE>_int64: formats a floating-point `value` with
+// gandiva::GdvStringFormatter and returns at most `len` bytes of the text in arena
+// memory owned by the execution context.
+//   context  opaque execution-context handle (used for errors and arena memory)
+//   value    the number to format
+//   len      maximum number of bytes to return; negative is an error, 0 yields ""
+//   out_len  receives the number of bytes returned (0 on error)
+// The output buffer is allocated only once the formatted size is known, so it is
+// exactly min(len, formatted size) bytes. Sizing it from `len` directly would
+// truncate lengths above INT32_MAX to a too-small (or negative) allocation that
+// the copy below could overrun, and would waste arena memory for large lengths.
+#define GDV_FN_CAST_VARCHAR_REAL(IN_TYPE, ARROW_TYPE) \
+  GANDIVA_EXPORT \
+  const char* gdv_fn_castVARCHAR_##IN_TYPE##_int64(int64_t context, gdv_##IN_TYPE value, \
+                                                   int64_t len, int32_t * out_len) { \
+    if (len < 0) { \
+      gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    if (len == 0) { \
+      *out_len = 0; \
+      return ""; \
+    } \
+    gandiva::GdvStringFormatter<arrow::ARROW_TYPE> formatter; \
+    char* ret = nullptr; \
+    bool alloc_failed = false; \
+    arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \
+      int64_t size = static_cast<int64_t>(v.size()); \
+      /* copy_len <= size, and size is a short formatted number, so it fits int32 */ \
+      int32_t copy_len = static_cast<int32_t>(len < size ? len : size); \
+      ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, copy_len)); \
+      if (ret == nullptr) { \
+        alloc_failed = true; \
+        return arrow::Status::OutOfMemory("Could not allocate memory"); \
+      } \
+      memcpy(ret, v.data(), copy_len); \
+      *out_len = copy_len; \
+      return arrow::Status::OK(); \
+    }); \
+    if (alloc_failed) { \
+      gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    if (!status.ok()) { \
+      std::string err = "Could not cast " + std::to_string(value) + " to string"; \
+      gdv_fn_context_set_error_msg(context, err.c_str()); \
+      *out_len = 0; \
+      return ""; \
+    } \
+    return ret; \
+  }
+
+GDV_FN_CAST_VARCHAR_INTEGER(int32, Int32Type)
+GDV_FN_CAST_VARCHAR_INTEGER(int64, Int64Type)
+GDV_FN_CAST_VARCHAR_REAL(float32, FloatType)
+GDV_FN_CAST_VARCHAR_REAL(float64, DoubleType)
+
+#undef GDV_FN_CAST_VARCHAR_INTEGER
+#undef GDV_FN_CAST_VARCHAR_REAL
}
namespace gandiva {
@@ -471,6 +554,42 @@
engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args,
reinterpret_cast<void*>(gdv_fn_castFLOAT8_utf8));
+ // gdv_fn_castVARCHAR_int32_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->i32_type(), // int32_t value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_int32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_int32_int64));
+
+ // gdv_fn_castVARCHAR_int64_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->i64_type(), // int64_t value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_int64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_int64_int64));
+
+ // gdv_fn_castVARCHAR_float32_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->float_type(), // float value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_float32_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_float32_int64));
+
+ // gdv_fn_castVARCHAR_float64_int64
+ args = {types->i64_type(), // int64_t execution_context
+ types->double_type(), // double value
+ types->i64_type(), // int64_t len
+ types->i32_ptr_type()}; // int32_t* out_len
+ engine->AddGlobalMappingForFunc(
+ "gdv_fn_castVARCHAR_float64_int64", types->i8_ptr_type() /*return_type*/, args,
+ reinterpret_cast<void*>(gdv_fn_castVARCHAR_float64_int64));
+
// gdv_fn_sha1_int8
args = {
types->i64_type(), // context
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index 255e9af..0a6cd70 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -95,4 +95,17 @@
GANDIVA_EXPORT
double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len);
+
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t len,
+ int32_t* out_len);  // int32 -> text, at most `len` bytes; *out_len receives the byte count
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_int64_int64(int64_t context, int64_t value, int64_t len,
+ int32_t* out_len);  // int64 -> text, at most `len` bytes; *out_len receives the byte count
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_float32_int64(int64_t context, float value, int64_t len,
+ int32_t* out_len);  // float -> text, at most `len` bytes; *out_len receives the byte count
+GANDIVA_EXPORT
+const char* gdv_fn_castVARCHAR_float64_int64(int64_t context, double value, int64_t len,
+ int32_t* out_len);  // double -> text, at most `len` bytes; *out_len receives the byte count
}
diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc
index 90ac1df..8f44ce2 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -160,4 +160,134 @@
ctx.Reset();
}
+// castVARCHAR(int32, len): plain values, both int32 limits, truncation to `len`,
+// a zero length and a rejected negative length.
+TEST(TestGdvFnStubs, TestCastVARCHARFromInt32) {
+  gandiva::ExecutionContext ctx;
+  // The stub takes the context as an int64_t handle, so keep it signed end to end.
+  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+  int32_t out_len = 0;
+
+  // Runs the cast and returns exactly the bytes the stub reported through out_len.
+  auto cast = [&](int32_t value, int64_t len) {
+    const char* out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, value, len, &out_len);
+    return std::string(out_str, out_len);
+  };
+
+  EXPECT_EQ(cast(-46, 100), "-46");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(2147483647, 100), "2147483647");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(-2147483647 - 1, 100), "-2147483648");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(0, 100), "0");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A requested length shorter than the formatted value truncates the output.
+  EXPECT_EQ(cast(34567, 3), "345");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A zero length yields an empty string without raising an error.
+  EXPECT_EQ(cast(347, 0), "");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A negative length is rejected: empty result plus an error on the context.
+  EXPECT_EQ(cast(347, -1), "");
+  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative"));
+  ctx.Reset();
+}
+
+
+// castVARCHAR(int64, len): both int64 limits, zero, and truncation to `len`.
+TEST(TestGdvFnStubs, TestCastVARCHARFromInt64) {
+  gandiva::ExecutionContext ctx;
+  // The stub takes the context as an int64_t handle, so keep it signed end to end.
+  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+  int32_t out_len = 0;
+
+  // Runs the cast and returns exactly the bytes the stub reported through out_len.
+  auto cast = [&](int64_t value, int64_t len) {
+    const char* out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, value, len, &out_len);
+    return std::string(out_str, out_len);
+  };
+
+  EXPECT_EQ(cast(9223372036854775807LL, 100), "9223372036854775807");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(-9223372036854775807LL - 1, 100), "-9223372036854775808");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(0, 100), "0");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A requested length shorter than the formatted value truncates the output.
+  EXPECT_EQ(cast(12345, 3), "123");
+  EXPECT_FALSE(ctx.has_error());
+}
+
+
+// castVARCHAR(float32, len): decimal and scientific notation in the Java style
+// ("1.0E-5", "0.0", "10.0") and truncation to `len`.
+TEST(TestGdvFnStubs, TestCastVARCHARFromFloat) {
+  gandiva::ExecutionContext ctx;
+  // The stub takes the context as an int64_t handle, so keep it signed end to end.
+  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+  int32_t out_len = 0;
+
+  // Runs the cast and returns exactly the bytes the stub reported through out_len.
+  auto cast = [&](float value, int64_t len) {
+    const char* out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, value, len, &out_len);
+    return std::string(out_str, out_len);
+  };
+
+  EXPECT_EQ(cast(4.567f, 100), "4.567");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(-3.4567f, 100), "-3.4567");
+  EXPECT_FALSE(ctx.has_error());
+
+  // Values below 10^-3 switch to scientific notation.
+  EXPECT_EQ(cast(0.00001f, 100), "1.0E-5");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(0.00099999f, 100), "9.9999E-4");
+  EXPECT_FALSE(ctx.has_error());
+
+  // Integral values keep a trailing ".0".
+  EXPECT_EQ(cast(0.0f, 100), "0.0");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(10.00000f, 100), "10.0");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A requested length shorter than the formatted value truncates the output.
+  EXPECT_EQ(cast(1.2345f, 3), "1.2");
+  EXPECT_FALSE(ctx.has_error());
+}
+
+
+// castVARCHAR(float64, len): decimal and scientific notation in the Java style
+// ("1.0E-5", "0.0", "10.0") and truncation to `len`. Every case goes through the
+// float64 stub.
+TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
+  gandiva::ExecutionContext ctx;
+  // The stub takes the context as an int64_t handle, so keep it signed end to end.
+  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+  int32_t out_len = 0;
+
+  // Runs the cast and returns exactly the bytes the stub reported through out_len.
+  auto cast = [&](double value, int64_t len) {
+    const char* out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, value, len, &out_len);
+    return std::string(out_str, out_len);
+  };
+
+  EXPECT_EQ(cast(4.567, 100), "4.567");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(-3.4567, 100), "-3.4567");
+  EXPECT_FALSE(ctx.has_error());
+
+  // Values below 10^-3 switch to scientific notation.
+  EXPECT_EQ(cast(0.00001, 100), "1.0E-5");
+  EXPECT_FALSE(ctx.has_error());
+
+  // Previously this case called the float32 stub by mistake.
+  EXPECT_EQ(cast(0.00099999, 100), "9.9999E-4");
+  EXPECT_FALSE(ctx.has_error());
+
+  // Integral values keep a trailing ".0".
+  EXPECT_EQ(cast(0.0, 100), "0.0");
+  EXPECT_FALSE(ctx.has_error());
+
+  EXPECT_EQ(cast(10.0000000000, 100), "10.0");
+  EXPECT_FALSE(ctx.has_error());
+
+  // A requested length shorter than the formatted value truncates the output.
+  EXPECT_EQ(cast(1.2345, 3), "1.2");
+  EXPECT_FALSE(ctx.has_error());
+}
+
+
} // namespace gandiva
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index 446efd1..606c1a9 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -1193,7 +1193,7 @@
Field c1 = Field.nullable("c1", int32);
TreeNode inExpr =
- TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
+ TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1208,10 +1208,10 @@
ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
- new ArrowRecordBatch(
- numRows,
- Lists.newArrayList(fieldNode, fieldNode),
- Lists.newArrayList(c1Validity, c1Data, c2Validity));
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode, fieldNode),
+ Lists.newArrayList(c1Validity, c1Data, c2Validity));
BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);
@@ -1297,7 +1297,7 @@
List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
TreeNode inExpr =
- TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
+ TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1305,8 +1305,8 @@
int numRows = 16;
byte[] validity = new byte[]{(byte) 255, 0};
String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
- "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
- "sixteen"};
+ "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+ "sixteen"};
ArrowBuf c1Validity = buf(validity);
List<ArrowBuf> dataBufsX = stringBufs(c1Values);
@@ -1314,10 +1314,10 @@
ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
- new ArrowRecordBatch(
- numRows,
- Lists.newArrayList(fieldNode, fieldNode),
- Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode, fieldNode),
+ Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
bitVector.allocateNew(numRows);
@@ -1509,9 +1509,9 @@
Field resultField = Field.nullable("result", date64);
List<ExpressionTree> exprs =
- Lists.newArrayList(
- TreeBuilder.makeExpression(dateToYear, resultField),
- TreeBuilder.makeExpression(dateToMonth, resultField));
+ Lists.newArrayList(
+ TreeBuilder.makeExpression(dateToYear, resultField),
+ TreeBuilder.makeExpression(dateToMonth, resultField));
Schema schema = new Schema(Lists.newArrayList(dateField));
Projector eval = Projector.make(schema, exprs);
@@ -1544,10 +1544,10 @@
ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
ArrowRecordBatch batch =
- new ArrowRecordBatch(
- numRows,
- Lists.newArrayList(fieldNode),
- Lists.newArrayList(bufValidity, millisData));
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode),
+ Lists.newArrayList(bufValidity, millisData));
List<ValueVector> output = new ArrayList<ValueVector>();
for (int i = 0; i < exprs.size(); i++) {
@@ -2044,6 +2044,194 @@
releaseRecordBatch(batch);
releaseValueVectors(output);
eval.close();
+ }
+
+  @Test
+  public void testCastVarcharFromInteger() throws Exception {
+    // Inputs: the int32 value to cast and the int64 maximum output length.
+    Field valueField = Field.nullable("input", int32);
+    Field maxLenField = Field.nullable("outLength", int64);
+
+    List<TreeNode> castArgs =
+        Lists.newArrayList(TreeBuilder.makeField(valueField), TreeBuilder.makeField(maxLenField));
+    TreeNode castNode = TreeBuilder.makeFunction("castVARCHAR", castArgs, new ArrowType.Utf8());
+
+    ExpressionTree castExpr =
+        TreeBuilder.makeExpression(castNode, Field.nullable("result", new ArrowType.Utf8()));
+    List<ExpressionTree> exprs = Lists.newArrayList(castExpr);
+
+    Projector eval = Projector.make(new Schema(Lists.newArrayList(valueField, maxLenField)), exprs);
+
+    int numRows = 5;
+    // One validity byte with every bit set covers all five rows.
+    byte[] validity = new byte[] {(byte) 255};
+    int[] values = new int[] {2345, 2345, 2345, 2345, -2345};
+    long[] lenValues = new long[] {0L, 4L, 2L, 6L, 5L};
+
+    // Each expected value is the decimal text of the input cut to the requested length.
+    String[] expValues = new String[] {"", "2345", "23", "2345", "-2345"};
+
+    ArrowBuf bufValidity = buf(validity);
+    ArrowBuf bufData = intBuf(values);
+    ArrowBuf lenValidity = buf(validity);
+    ArrowBuf lenData = longBuf(lenValues);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(fieldNode, fieldNode),
+            Lists.newArrayList(bufValidity, bufData, lenValidity, lenData));
+
+    // One output vector per expression.
+    List<ValueVector> output = new ArrayList<>();
+    for (int exprIdx = 0; exprIdx < exprs.size(); exprIdx++) {
+      VarCharVector resultVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+      resultVector.allocateNew(numRows * 5, numRows);
+      output.add(resultVector);
+    }
+    eval.evaluate(batch, output);
+    eval.close();
+
+    for (ValueVector vector : output) {
+      VarCharVector resultVector = (VarCharVector) vector;
+      for (int row = 0; row < numRows; row++) {
+        assertFalse(resultVector.isNull(row));
+        assertEquals(expValues[row], new String(resultVector.get(row)));
+      }
+    }
+
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+  }
+
+  @Test
+  public void testCastVarcharFromFloat() throws Exception {
+    Field inField = Field.nullable("input", float64);
+    Field lenField = Field.nullable("outLength", int64);
+
+    TreeNode inNode = TreeBuilder.makeField(inField);
+    TreeNode lenNode = TreeBuilder.makeField(lenField);
+
+    TreeNode castNode = TreeBuilder.makeFunction("castVARCHAR", Lists.newArrayList(inNode, lenNode),
+        new ArrowType.Utf8());
+
+    Field resultField = Field.nullable("result", new ArrowType.Utf8());
+    List<ExpressionTree> exprs =
+        Lists.newArrayList(
+            TreeBuilder.makeExpression(castNode, resultField));
+
+    Schema schema = new Schema(Lists.newArrayList(inField, lenField));
+    Projector eval = Projector.make(schema, exprs);
+
+    double[] values =
+        new double[] {
+            0.0,
+            -0.0,
+            1.0,
+            0.001,
+            0.0009,
+            0.00099893,
+            999999.9999,
+            10000000.0,
+            23943410000000.343434,
+            Double.POSITIVE_INFINITY,
+            Double.NEGATIVE_INFINITY,
+            Double.NaN,
+            23.45,
+            23.45,
+            -23.45,
+        };
+    long[] lenValues =
+        new long[] {
+            6L, 6L, 6L, 6L, 10L, 15L, 15L, 15L, 30L,
+            15L, 15L, 15L, 0L, 6L, 6L
+        };
+
+    // Evaluate and check every input row, not just a prefix of them.
+    int numRows = values.length;
+    // One validity bit per row, all set: 15 rows need two bytes.
+    byte[] validity = new byte[] {(byte) 255, (byte) 255};
+
+    /* Java represents real numbers in one of two ways and Gandiva must
+     * follow the same rules:
+     * - If the magnitude is greater than or equal to 10^7, or less than 10^(-3),
+     *   the number is represented using scientific notation, e.g.:
+     *     - 0.000012 -> 1.2E-5
+     *     - 10000002.3 -> 1.00000023E7
+     * - Otherwise the number is shown as is.
+     *
+     * The test checks that the Gandiva function produces the same notation as
+     * Double.toString.
+     */
+    String[] expValues =
+        new String[] {
+            Double.toString(0.0), // "0.0"
+            Double.toString(-0.0), // "-0.0"
+            Double.toString(1.0), // "1.0"
+            Double.toString(0.001), // "0.001"
+            Double.toString(0.0009), // "9.0E-4"
+            Double.toString(0.00099893), // "9.9893E-4"
+            Double.toString(999999.9999), // "999999.9999"
+            Double.toString(10000000.0), // "1.0E7"
+            Double.toString(23943410000000.343434), // scientific notation (>= 10^7)
+            Double.toString(Double.POSITIVE_INFINITY), // "Infinity"
+            Double.toString(Double.NEGATIVE_INFINITY), // "-Infinity"
+            Double.toString(Double.NaN), // "NaN"
+            "", // requested length 0
+            Double.toString(23.45), // "23.45"
+            Double.toString(-23.45) // "-23.45"
+        };
+
+    // Guard against the three parallel arrays drifting apart.
+    assertEquals(numRows, lenValues.length);
+    assertEquals(numRows, expValues.length);
+
+    // No row can produce more bytes than its requested length, so the sum of the
+    // lengths is a safe capacity for the output data buffer.
+    int maxOutputBytes = 0;
+    for (long len : lenValues) {
+      maxOutputBytes += (int) len;
+    }
+
+    ArrowBuf bufValidity = buf(validity);
+    ArrowBuf bufData = doubleBuf(values);
+    ArrowBuf lenValidity = buf(validity);
+    ArrowBuf lenData = longBuf(lenValues);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(fieldNode, fieldNode),
+            Lists.newArrayList(bufValidity, bufData, lenValidity, lenData));
+
+    List<ValueVector> output = new ArrayList<>();
+    for (int i = 0; i < exprs.size(); i++) {
+      VarCharVector charVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+
+      charVector.allocateNew(maxOutputBytes, numRows);
+      output.add(charVector);
+    }
+    eval.evaluate(batch, output);
+    eval.close();
+
+    for (ValueVector valueVector : output) {
+      VarCharVector charVector = (VarCharVector) valueVector;
+
+      for (int j = 0; j < numRows; j++) {
+        assertFalse(charVector.isNull(j));
+        assertEquals(expValues[j], new String(charVector.get(j)));
+      }
+    }
+
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
}
}