feat: implement truncate max for literals (#585)
diff --git a/src/iceberg/test/truncate_util_test.cc b/src/iceberg/test/truncate_util_test.cc
index 61010fc..849f67d 100644
--- a/src/iceberg/test/truncate_util_test.cc
+++ b/src/iceberg/test/truncate_util_test.cc
@@ -22,6 +22,7 @@
#include <gtest/gtest.h>
#include "iceberg/expression/literal.h"
+#include "iceberg/test/matchers.h"
namespace iceberg {
@@ -50,4 +51,141 @@
Literal::Binary(std::vector<uint8_t>(expected.begin(), expected.end())));
}
+// Exercises TruncateUtils::TruncateLiteralMax on binary literals: the input is cut to
+// `width` bytes and the result must be an upper bound built by incrementing the last
+// byte of that prefix that is below 0xFF and dropping the bytes after it.
+TEST(TruncateUtilTest, TruncateBinaryMax) {
+ std::vector<uint8_t> test1{1, 1, 2};
+ std::vector<uint8_t> test2{1, 1, 0xFF, 2};
+ std::vector<uint8_t> test3{0xFF, 0xFF, 0xFF, 2};
+ std::vector<uint8_t> test4{1, 1, 0};
+ // Shared expectation for every case below that truncates successfully.
+ std::vector<uint8_t> expected_output{1, 2};
+
+ // Test1: truncate {1, 1, 2} to 2 bytes -> {1, 2}
+ ICEBERG_UNWRAP_OR_FAIL(auto result1,
+ TruncateUtils::TruncateLiteralMax(Literal::Binary(test1), 2));
+ EXPECT_EQ(result1, Literal::Binary(expected_output));
+
+ // Test2: truncate {1, 1, 0xFF, 2} to 2 bytes -> {1, 2}
+ // (the 0xFF lies beyond the 2-byte prefix and therefore plays no role)
+ ICEBERG_UNWRAP_OR_FAIL(auto result2,
+ TruncateUtils::TruncateLiteralMax(Literal::Binary(test2), 2));
+ EXPECT_EQ(result2, Literal::Binary(expected_output));
+
+ // Test2b: truncate {1, 1, 0xFF, 2} to 3 bytes -> {1, 2}
+ // (the trailing 0xFF of the prefix cannot be incremented, so it is dropped and the
+ // byte before it is incremented instead; the result is shorter than the width)
+ ICEBERG_UNWRAP_OR_FAIL(auto result2b,
+ TruncateUtils::TruncateLiteralMax(Literal::Binary(test2), 3));
+ EXPECT_EQ(result2b, Literal::Binary(expected_output));
+
+ // Test3: no truncation needed when length >= input size; input is returned as is
+ ICEBERG_UNWRAP_OR_FAIL(auto result3,
+ TruncateUtils::TruncateLiteralMax(Literal::Binary(test3), 5));
+ EXPECT_EQ(result3, Literal::Binary(test3));
+
+ // Test3b: cannot truncate when first bytes are all 0xFF (no byte can be incremented)
+ EXPECT_THAT(TruncateUtils::TruncateLiteralMax(Literal::Binary(test3), 2),
+ IsError(ErrorKind::kInvalidArgument));
+
+ // Test4: truncate {1, 1, 0} to 2 bytes -> {1, 2}
+ ICEBERG_UNWRAP_OR_FAIL(auto result4,
+ TruncateUtils::TruncateLiteralMax(Literal::Binary(test4), 2));
+ EXPECT_EQ(result4, Literal::Binary(expected_output));
+}
+
+// Exercises TruncateUtils::TruncateLiteralMax on string literals. Widths count UTF-8
+// code points, not bytes. Fixtures are spelled as raw UTF-8 byte escapes; where a
+// comment shows "\uXXXX" escapes, those are UTF-16 code units (surrogate pairs for
+// code points above U+FFFF) and the trailing comment gives the actual code points.
+TEST(TruncateUtilTest, TruncateStringMax) {
+ // Test1: Japanese characters "イロハニホヘト" (7 code points, 3 bytes each)
+ std::string test1 =
+ "\xE3\x82\xA4\xE3\x83\xAD\xE3\x83\x8F\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\x98\xE3\x83"
+ "\x88";
+ std::string test1_2_expected = "\xE3\x82\xA4\xE3\x83\xAE"; // "イヮ"
+ std::string test1_3_expected = "\xE3\x82\xA4\xE3\x83\xAD\xE3\x83\x90"; // "イロバ"
+
+ // Last code point of the 2-code-point prefix is incremented by one.
+ ICEBERG_UNWRAP_OR_FAIL(auto result1_2,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test1), 2));
+ EXPECT_EQ(result1_2, Literal::String(test1_2_expected));
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result1_3,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test1), 3));
+ EXPECT_EQ(result1_3, Literal::String(test1_3_expected));
+
+ // No truncation needed when length >= input size (in code points)
+ ICEBERG_UNWRAP_OR_FAIL(auto result1_7,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test1), 7));
+ EXPECT_EQ(result1_7, Literal::String(test1));
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result1_8,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test1), 8));
+ EXPECT_EQ(result1_8, Literal::String(test1));
+
+ // Test2: Mixed characters "щщаεはчωいにπάほхεろへσκζ" (2- and 3-byte sequences)
+ std::string test2 =
+ "\xD1\x89\xD1\x89\xD0\xB0\xCE\xB5\xE3\x81\xAF\xD1\x87\xCF\x89\xE3\x81\x84\xE3\x81"
+ "\xAB\xCF\x80\xCE\xAC\xE3\x81\xBB\xD1\x85\xCE\xB5\xE3\x82\x8D\xE3\x81\xB8\xCF\x83"
+ "\xCE\xBA\xCE\xB6";
+ std::string test2_7_expected =
+ "\xD1\x89\xD1\x89\xD0\xB0\xCE\xB5\xE3\x81\xAF\xD1\x87\xCF\x8A"; // "щщаεはчϊ"
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result2_7,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test2), 7));
+ EXPECT_EQ(result2_7, Literal::String(test2_7_expected));
+
+ // Test3: String with max 3-byte UTF-8 character "aनि\uFFFF\uFFFF"
+ // (the U+FFFF characters lie beyond the 3-code-point prefix)
+ std::string test3 = "a\xE0\xA4\xA8\xE0\xA4\xBF\xEF\xBF\xBF\xEF\xBF\xBF";
+ std::string test3_3_expected = "a\xE0\xA4\xA8\xE0\xA5\x80"; // "aनी"
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result3_3,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test3), 3));
+ EXPECT_EQ(result3_3, Literal::String(test3_3_expected));
+
+ // Test4: Max 3-byte UTF-8 character "\uFFFF\uFFFF"
+ // (incrementing U+FFFF crosses into the 4-byte encoding range)
+ std::string test4 = "\xEF\xBF\xBF\xEF\xBF\xBF";
+ std::string test4_1_expected = "\xF0\x90\x80\x80"; // U+10000 (first 4-byte UTF-8 char)
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result4_1,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test4), 1));
+ EXPECT_EQ(result4_1, Literal::String(test4_1_expected));
+
+ // Test5: Max 4-byte UTF-8 characters "\uDBFF\uDFFF\uDBFF\uDFFF"
+ // (U+10FFFF cannot be incremented and nothing precedes it, so this is an error)
+ std::string test5 = "\xF4\x8F\xBF\xBF\xF4\x8F\xBF\xBF"; // U+10FFFF U+10FFFF
+ EXPECT_THAT(TruncateUtils::TruncateLiteralMax(Literal::String(test5), 1),
+ IsError(ErrorKind::kInvalidArgument));
+
+ // Test6: 4-byte UTF-8 character "\uD800\uDFFF\uD800\uDFFF"
+ std::string test6 = "\xF0\x90\x8F\xBF\xF0\x90\x8F\xBF"; // U+103FF U+103FF
+ std::string test6_1_expected = "\xF0\x90\x90\x80"; // U+10400
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result6_1,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test6), 1));
+ EXPECT_EQ(result6_1, Literal::String(test6_1_expected));
+
+ // Test7: Emoji "\uD83D\uDE02\uD83D\uDE02\uD83D\uDE02"
+ std::string test7 = "\xF0\x9F\x98\x82\xF0\x9F\x98\x82\xF0\x9F\x98\x82"; // 😂😂😂
+ std::string test7_2_expected = "\xF0\x9F\x98\x82\xF0\x9F\x98\x83"; // 😂😃
+ std::string test7_1_expected = "\xF0\x9F\x98\x83"; // 😃
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result7_2,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test7), 2));
+ EXPECT_EQ(result7_2, Literal::String(test7_2_expected));
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result7_1,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test7), 1));
+ EXPECT_EQ(result7_1, Literal::String(test7_1_expected));
+
+ // Test8: Overflow case "a\uDBFF\uDFFFc"
+ // (the prefix "a" + U+10FFFF ends in a code point that cannot be incremented, so it
+ // is dropped and the preceding 'a' becomes 'b')
+ std::string test8 =
+ "a\xF4\x8F\xBF\xBF"
+ "c"; // a U+10FFFF c
+ std::string test8_2_expected = "b";
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result8_2,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test8), 2));
+ EXPECT_EQ(result8_2, Literal::String(test8_2_expected));
+
+ // Test9: Skip surrogate range: "a" + U+D7FF (the code point just below the first
+ // surrogate) + "b"; incrementing U+D7FF must jump over U+D800..U+DFFF to U+E000
+ std::string test9 =
+ "a\xED\x9F\xBF"
+ "b"; // a U+D7FF b
+ std::string test9_2_expected = "a\xEE\x80\x80"; // a U+E000
+
+ ICEBERG_UNWRAP_OR_FAIL(auto result9_2,
+ TruncateUtils::TruncateLiteralMax(Literal::String(test9), 2));
+ EXPECT_EQ(result9_2, Literal::String(test9_2_expected));
+}
+
} // namespace iceberg
diff --git a/src/iceberg/util/truncate_util.cc b/src/iceberg/util/truncate_util.cc
index 9d0c6e7..aba22d1 100644
--- a/src/iceberg/util/truncate_util.cc
+++ b/src/iceberg/util/truncate_util.cc
@@ -29,11 +29,105 @@
namespace iceberg {
namespace {
-template <TypeId type_id>
-Literal TruncateLiteralImpl(const Literal& literal, int32_t width) {
- std::unreachable();
+// Largest code point accepted by the decoder and produced when incrementing.
+constexpr uint32_t kUtf8MaxCodePoint{0x10FFFF};
+// Inclusive bounds of the surrogate range; values inside it are rejected when decoding
+// and skipped when a code point is incremented.
+constexpr uint32_t kUtf8MinSurrogate{0xD800};
+constexpr uint32_t kUtf8MaxSurrogate{0xDFFF};
+
+/// \brief Decode the UTF-8 sequence that begins at the first byte of `source`.
+///
+/// Only the bytes belonging to that one sequence are inspected; anything after it is
+/// ignored.
+///
+/// \param source Bytes starting with the lead byte of the sequence to decode.
+/// \return The decoded code point, or std::nullopt when `source` is empty, the lead
+/// byte is not a valid start byte, the sequence is cut short, a continuation byte is
+/// malformed, the encoding is overlong, or the value is a surrogate or above U+10FFFF.
+std::optional<uint32_t> DecodeUtf8CodePoint(std::string_view source) {
+  if (source.empty()) {
+    return std::nullopt;
+  }
+
+  const auto lead = static_cast<uint8_t>(source.front());
+  if (lead < 0x80) {
+    // 0xxxxxxx: ASCII, the byte is the code point.
+    return lead;
+  }
+
+  // Classify the lead byte: total sequence length, the payload bits it carries, and
+  // the smallest code point that genuinely needs that many bytes (anything lower is
+  // an overlong encoding).
+  size_t length = 0;
+  uint32_t code_point = 0;
+  uint32_t min_code_point = 0;
+  if ((lead & 0xE0) == 0xC0) {  // 110xxxxx 10xxxxxx
+    length = 2;
+    code_point = lead & 0x1F;
+    min_code_point = 0x80;
+  } else if ((lead & 0xF0) == 0xE0) {  // 1110xxxx 10xxxxxx 10xxxxxx
+    length = 3;
+    code_point = lead & 0x0F;
+    min_code_point = 0x800;
+  } else if ((lead & 0xF8) == 0xF0) {  // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+    length = 4;
+    code_point = lead & 0x07;
+    min_code_point = 0x10000;
+  } else {
+    // Stray continuation byte or a lead byte outside the UTF-8 range.
+    return std::nullopt;
+  }
+
+  if (source.size() < length) {
+    return std::nullopt;
+  }
+
+  // Fold in six payload bits per continuation byte (10xxxxxx).
+  for (size_t i = 1; i < length; ++i) {
+    const auto trail = static_cast<uint8_t>(source[i]);
+    if ((trail & 0xC0) != 0x80) {
+      return std::nullopt;
+    }
+    code_point = (code_point << 6) | (trail & 0x3F);
+  }
+
+  const bool overlong = code_point < min_code_point;
+  const bool out_of_range = code_point > kUtf8MaxCodePoint;
+  const bool surrogate =
+      code_point >= kUtf8MinSurrogate && code_point <= kUtf8MaxSurrogate;
+  if (overlong || out_of_range || surrogate) {
+    return std::nullopt;
+  }
+  return code_point;
}
+/// \brief Append the UTF-8 encoding of `code_point` to `target`.
+///
+/// The encoded length is picked from the value alone: 1 byte up to 0x7F, 2 bytes up to
+/// 0x7FF, 3 bytes up to 0xFFFF and 4 bytes otherwise. No validation is performed here.
+///
+/// \param code_point The code point to encode.
+/// \param target The string that receives the encoded bytes.
+void AppendUtf8CodePoint(uint32_t code_point, std::string& target) {
+  // Builds a continuation byte (10xxxxxx) from the six bits starting at `shift`.
+  const auto continuation = [code_point](uint32_t shift) {
+    return static_cast<char>(0x80 | ((code_point >> shift) & 0x3F));
+  };
+
+  if (code_point > 0xFFFF) {
+    target += static_cast<char>(0xF0 | (code_point >> 18));
+    target += continuation(12);
+    target += continuation(6);
+    target += continuation(0);
+  } else if (code_point > 0x7FF) {
+    target += static_cast<char>(0xE0 | (code_point >> 12));
+    target += continuation(6);
+    target += continuation(0);
+  } else if (code_point > 0x7F) {
+    target += static_cast<char>(0xC0 | (code_point >> 6));
+    target += continuation(0);
+  } else {
+    target += static_cast<char>(code_point);
+  }
+}
+
+// Primary template is deleted: TruncateLiteralImpl is only usable through its explicit
+// specializations, so selecting it for a TypeId without one is a compile-time error.
+template <TypeId type_id>
+Literal TruncateLiteralImpl(const Literal& literal, int32_t width) = delete;
+
template <>
Literal TruncateLiteralImpl<TypeId::kInt>(const Literal& literal, int32_t width) {
int32_t v = std::get<int32_t>(literal.value());
@@ -72,8 +166,80 @@
return Literal::Binary(std::vector<uint8_t>(data.begin(), data.begin() + width));
}
+// Primary template is deleted: upper-bound truncation is only usable through its
+// explicit specializations, so selecting it for a TypeId without one is a compile-time
+// error.
+template <TypeId type_id>
+Result<Literal> TruncateLiteralMaxImpl(const Literal& literal, int32_t width) = delete;
+
+/// \brief Upper-bound truncation for string literals.
+///
+/// Delegates to TruncateUtils::TruncateUTF8Max and wraps its result in a string
+/// Literal; errors from that call are propagated unchanged.
+///
+/// \param literal A literal holding a std::string value.
+/// \param width Maximum number of code points to keep; must not be negative.
+/// \return The truncated upper bound, or an InvalidArgument error for a negative width.
+template <>
+Result<Literal> TruncateLiteralMaxImpl<TypeId::kString>(const Literal& literal,
+                                                        int32_t width) {
+  // TruncateUTF8Max takes an unsigned length. Without this guard a negative width is
+  // implicitly converted to a huge value, which silently disables truncation.
+  if (width < 0) {
+    return InvalidArgument(
+        "Cannot truncate upper bound for string: width must not be negative");
+  }
+
+  const auto& str = std::get<std::string>(literal.value());
+  ICEBERG_ASSIGN_OR_RAISE(
+      std::string truncated,
+      TruncateUtils::TruncateUTF8Max(str, static_cast<size_t>(width)));
+  return Literal::String(std::move(truncated));
+}
+
+/// \brief Upper-bound truncation for binary literals.
+///
+/// Values that already fit in `width` bytes are returned unchanged. Otherwise the
+/// first `width` bytes are taken, the last byte of that prefix that is below 0xFF is
+/// incremented, and every byte after it is dropped (so the result may be shorter than
+/// `width`).
+///
+/// \param literal A literal holding a std::vector<uint8_t> value.
+/// \param width Maximum number of bytes to keep; must be positive when truncation is
+/// required.
+/// \return The original literal, the truncated upper bound, or an InvalidArgument
+/// error when the width is not positive or every byte of the prefix is 0xFF.
+template <>
+Result<Literal> TruncateLiteralMaxImpl<TypeId::kBinary>(const Literal& literal,
+                                                        int32_t width) {
+  const auto& data = std::get<std::vector<uint8_t>>(literal.value());
+  // Compare in size_t so that very large inputs are not narrowed to int32_t.
+  if (width >= 0 && data.size() <= static_cast<size_t>(width)) {
+    return literal;
+  }
+  // A negative width would form an iterator before begin() below, and a zero width
+  // leaves no byte to increment; report both explicitly.
+  if (width <= 0) {
+    return InvalidArgument(
+        "Cannot truncate upper bound for binary: width must be positive");
+  }
+
+  std::vector<uint8_t> truncated(data.begin(), data.begin() + width);
+  // Scan from the end for a byte that can be incremented without overflowing.
+  for (auto it = truncated.rbegin(); it != truncated.rend(); ++it) {
+    if (*it < 0xFF) {
+      ++(*it);
+      // it.base() refers to the element right after *it: drop the trailing 0xFF bytes.
+      truncated.erase(it.base(), truncated.end());
+      return Literal::Binary(std::move(truncated));
+    }
+  }
+  return InvalidArgument("Cannot truncate upper bound for binary: all bytes are 0xFF");
+}
+
} // namespace
+// Computes an upper bound for `source` that is at most L code points long: truncate to
+// L code points, then increment the last code point of that prefix that is below
+// U+10FFFF (skipping the surrogate range) and drop everything after it. Strings that
+// are not shortened by the truncation are returned unchanged. Fails with
+// InvalidArgument when the prefix is not valid UTF-8 or no code point in it can be
+// incremented.
+Result<std::string> TruncateUtils::TruncateUTF8Max(const std::string& source, size_t L) {
+  std::string truncated = TruncateUTF8(source, L);
+  if (truncated == source) {
+    return truncated;
+  }
+
+  // A zero length leaves no code point to increment; say so explicitly instead of
+  // falling through to the "all code points are 0x10FFFF" error below.
+  if (L == 0) {
+    return InvalidArgument(
+        "Cannot truncate upper bound for string: length must be positive");
+  }
+
+  // Walk the truncated prefix backwards, one code point at a time.
+  size_t cp_end = truncated.size();
+  while (cp_end > 0) {
+    // Step back over continuation bytes (10xxxxxx) to the lead byte.
+    size_t cp_start = cp_end - 1;
+    while (cp_start > 0 && (static_cast<uint8_t>(truncated[cp_start]) & 0xC0) == 0x80) {
+      --cp_start;
+    }
+
+    const auto decoded = DecodeUtf8CodePoint(
+        std::string_view(truncated).substr(cp_start, cp_end - cp_start));
+    if (!decoded.has_value()) {
+      return InvalidArgument("Invalid UTF-8 in string literal");
+    }
+
+    if (*decoded < kUtf8MaxCodePoint) {
+      uint32_t next_code_point = *decoded + 1;
+      // Surrogates are not encodable in UTF-8: jump to the first value after them.
+      if (next_code_point >= kUtf8MinSurrogate && next_code_point <= kUtf8MaxSurrogate) {
+        next_code_point = kUtf8MaxSurrogate + 1;
+      }
+      // Replace this code point and everything after it with its successor.
+      truncated.resize(cp_start);
+      AppendUtf8CodePoint(next_code_point, truncated);
+      return truncated;
+    }
+
+    // U+10FFFF cannot be incremented; try the code point before it.
+    cp_end = cp_start;
+  }
+  return InvalidArgument(
+      "Cannot truncate upper bound for string: all code points are 0x10FFFF");
+}
+
// Subtracts from `decimal` its remainder modulo `width`. The remainder is computed as
// ((decimal % width) + width) % width, which presumably keeps it non-negative for
// negative inputs so the result rounds downwards - confirm against Decimal::operator%.
Decimal TruncateUtils::TruncateDecimal(const Decimal& decimal, int32_t width) {
return decimal - (((decimal % width) + width) % width);
}
@@ -104,4 +270,27 @@
}
}
+// Expands to a `case` label that forwards to the TruncateLiteralMaxImpl specialization
+// for TYPE_ID. Expects `literal` and `width` to be in scope at the point of use.
+#define DISPATCH_TRUNCATE_LITERAL_MAX(TYPE_ID) \
+  case TYPE_ID:                                \
+    return TruncateLiteralMaxImpl<TYPE_ID>(literal, width);
+
+// Entry point for upper-bound truncation of a Literal. Null literals pass through
+// unchanged, the AboveMax/BelowMin sentinels are rejected, and only string and binary
+// literals are dispatched to an implementation; every other type is NotSupported.
+Result<Literal> TruncateUtils::TruncateLiteralMax(const Literal& literal, int32_t width) {
+  // Nothing to truncate in a null value: hand it back untouched.
+  if (literal.IsNull()) [[unlikely]] {
+    return literal;
+  }
+
+  // Sentinel literals carry no value that could be truncated.
+  const bool is_sentinel = literal.IsAboveMax() || literal.IsBelowMin();
+  if (is_sentinel) [[unlikely]] {
+    return NotSupported("Cannot truncate {}", literal.ToString());
+  }
+
+  const auto& type = literal.type();
+  switch (type->type_id()) {
+    DISPATCH_TRUNCATE_LITERAL_MAX(TypeId::kString);
+    DISPATCH_TRUNCATE_LITERAL_MAX(TypeId::kBinary);
+    default:
+      break;
+  }
+  return NotSupported("Truncate max is not supported for type: {}", type->ToString());
+}
+
} // namespace iceberg
diff --git a/src/iceberg/util/truncate_util.h b/src/iceberg/util/truncate_util.h
index e24cae3..1a1824a 100644
--- a/src/iceberg/util/truncate_util.h
+++ b/src/iceberg/util/truncate_util.h
@@ -61,6 +61,20 @@
return source;
}
+ /// \brief Truncate a UTF-8 string to at most L code points for use as an
+ /// upper-bound value.
+ ///
+ /// If truncating to L code points leaves the string unchanged, it is returned
+ /// as is. Otherwise the truncated prefix is turned into an upper bound: its
+ /// last code point that is below U+10FFFF is incremented (skipping the
+ /// surrogate range U+D800..U+DFFF) and every code point after it is dropped,
+ /// so the result may contain fewer than L code points.
+ ///
+ /// \param source The input string to truncate.
+ /// \param L The maximum number of code points allowed in the output string.
+ /// \return A Result containing the original string (if no truncation is
+ /// needed) or the incremented prefix described above, or an InvalidArgument
+ /// error if the truncated prefix is not valid UTF-8 or every code point in it
+ /// is U+10FFFF.
+ static Result<std::string> TruncateUTF8Max(const std::string& source, size_t L);
+
/// \brief Truncate an integer v, either int32_t or int64_t, to v - (v % W).
///
/// The remainder, v % W, must be positive. For languages where % can produce negative
@@ -86,6 +100,19 @@
/// - [Truncate Transform
/// Details](https://iceberg.apache.org/spec/#truncate-transform-details)
static Result<Literal> TruncateLiteral(const Literal& literal, int32_t width);
+
+ /// \brief Truncate a Literal to a specified width for use as an upper-bound value.
+ ///
+ /// Only string and binary literals are supported; the width counts code points for
+ /// strings and bytes for binary. A value that already fits in the width is returned
+ /// unchanged. Otherwise the truncated prefix is turned into an upper bound by
+ /// incrementing its last element that is not already at its maximum (byte 0xFF, code
+ /// point U+10FFFF) and dropping everything after it, so the result may be shorter than
+ /// the width. A null Literal is returned as is.
+ ///
+ /// \param value The input Literal maximum value to truncate.
+ /// \param width The width to truncate to.
+ /// \return A Result containing either the original Literal (if no truncation is needed)
+ /// or the incremented prefix described above; a NotSupported error for other types and
+ /// for the AboveMax/BelowMin sentinels; or an InvalidArgument error if no upper bound
+ /// exists for the prefix or a string prefix is not valid UTF-8.
+ static Result<Literal> TruncateLiteralMax(const Literal& value, int32_t width);
};
} // namespace iceberg