blob: 4f9808ccf8e401b266859261fb5649f4ccbeb1c0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#ifdef __AVX2__
#include <immintrin.h>
#endif
#include <unistd.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "util/simd/lower_upper_impl.h"
#include "util/sse_util.hpp"
#include "vec/common/string_ref.h"
namespace doris {
// Length in bytes of a UTF-8 sequence, indexed by its lead byte:
//   0x00-0x7F -> 1 (ASCII)
//   0x80-0xBF -> 1 (continuation bytes; mapped to 1 so byte-wise iteration still advances)
//   0xC0-0xDF -> 2
//   0xE0-0xEF -> 3
//   0xF0-0xF7 -> 4
//   0xF8-0xFB -> 5, 0xFC-0xFF -> 6 (5/6-byte forms of the original UTF-8 design, which
//                                   RFC 3629 no longer permits)
// The table is built at compile time, so lookups cost the same as a hand-written literal.
static constexpr std::array<uint8_t, 256> UTF8_BYTE_LENGTH = [] {
    std::array<uint8_t, 256> table {};
    for (std::size_t lead = 0; lead < table.size(); ++lead) {
        uint8_t width = 1;
        if (lead >= 0xFC) {
            width = 6;
        } else if (lead >= 0xF8) {
            width = 5;
        } else if (lead >= 0xF0) {
            width = 4;
        } else if (lead >= 0xE0) {
            width = 3;
        } else if (lead >= 0xC0) {
            width = 2;
        }
        table[lead] = width;
    }
    return table;
}();
// Returns the UTF-8 sequence length implied by lead byte `character`, looked up in
// UTF8_BYTE_LENGTH. Continuation bytes (0x80-0xBF) report 1.
inline uint8_t get_utf8_byte_length(uint8_t character) {
    const std::size_t lead_index = character;
    const uint8_t width = UTF8_BYTE_LENGTH[lead_index];
    return width;
}
// copy from https://github.com/lemire/fastvalidate-utf-8/blob/master/include/simdasciicheck.h
// Returns true iff every byte of [src, src + len) is a 7-bit value (0x00..0x7F); an empty
// range is trivially ASCII. Full 16-byte chunks are OR-ed into one SSE register and its high
// bits are checked once at the end. The remaining tail bytes are OR-ed one at a time.
inline bool validate_ascii_fast(const char* src, size_t len) {
    __m128i accumulated = _mm_setzero_si128();
    size_t offset = 0;
    while (offset + 16 <= len) {
        const __m128i chunk = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + offset));
        accumulated = _mm_or_si128(accumulated, chunk);
        offset += 16;
    }
    unsigned char tail_bits = 0;
    while (offset < len) {
        tail_bits |= static_cast<unsigned char>(src[offset]);
        ++offset;
    }
    const bool chunks_clean = _mm_movemask_epi8(accumulated) == 0;
    const bool tail_clean = (tail_bits & 0x80) == 0;
    return chunks_clean && tail_clean;
}
#ifdef __AVX2__
#include <x86intrin.h>
// The function returns true (1) if all chars passed in src are
// 7-bit values (0x00..0x7F). Otherwise, it returns false (0).
inline bool validate_ascii_fast_avx(const char* src, size_t len) {
size_t i = 0;
__m256i has_error = _mm256_setzero_si256();
if (len >= 32) {
for (; i <= len - 32; i += 32) {
__m256i current_bytes = _mm256_loadu_si256((const __m256i*)(src + i));
has_error = _mm256_or_si256(has_error, current_bytes);
}
}
int error_mask = _mm256_movemask_epi8(has_error);
char tail_has_error = 0;
for (; i < len; i++) {
tail_has_error |= src[i];
}
error_mask |= (tail_has_error & 0x80);
return !error_mask;
}
#endif
namespace simd {
class VStringFunctions {
public:
#if defined(__SSE2__) || defined(__aarch64__)
    /// Width in bytes of a 128-bit SIMD register (16 chars).
    static constexpr auto REGISTER_SIZE = sizeof(__m128i);
#endif

    /// Strips trailing occurrences of `remove_str` from [begin, end) and returns the new end.
    /// An empty `remove_str` removes nothing (returns `end`).
    /// When `trim_single` is true only the first byte of `remove_str` is used, and every
    /// trailing byte equal to it is removed (vectorised with AVX2 when available).
    /// Otherwise whole copies of `remove_str` are removed from the tail one at a time.
    template <bool trim_single>
    static inline const unsigned char* rtrim(const unsigned char* begin, const unsigned char* end,
                                             const StringRef& remove_str) {
        if (remove_str.size == 0) {
            return end;
        }
        const auto* p = end;
        if constexpr (trim_single) {
            // Compare as unsigned bytes: `*p` is unsigned char, and comparing it against a
            // (possibly signed) char would never match bytes >= 0x80.
            const auto ch = static_cast<unsigned char>(remove_str.data[0]);
#if defined(__AVX2__)
            constexpr auto AVX2_BYTES = sizeof(__m256i);
            const auto spaces = _mm256_set1_epi8(static_cast<char>(ch));
            // Step backwards one full 32-byte window at a time. The remaining length is checked
            // before each load, so no pointer before `begin` is ever formed.
            while (static_cast<size_t>(p - begin) >= AVX2_BYTES) {
                const uint32_t masks = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
                        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p - AVX2_BYTES)),
                        spaces));
                if (~masks) {
                    // The window holds a non-matching byte; finish byte-by-byte from here.
                    break;
                }
                p -= AVX2_BYTES;
            }
#endif
            while (p > begin && *(p - 1) == ch) {
                --p;
            }
            return p;
        }
        const auto remove_size = remove_str.size;
        const auto* const remove_data = remove_str.data;
        while (static_cast<size_t>(p - begin) >= remove_size &&
               std::memcmp(p - remove_size, remove_data, remove_size) == 0) {
            p -= remove_size;
        }
        return p;
    }

    /// Strips leading occurrences of `remove_str` from [begin, end) and returns the new begin.
    /// An empty `remove_str` removes nothing (returns `begin`).
    /// When `trim_single` is true only the first byte of `remove_str` is used, and every
    /// leading byte equal to it is removed (vectorised with AVX2 when available).
    /// Otherwise whole copies of `remove_str` are removed from the head one at a time.
    template <bool trim_single>
    static inline const unsigned char* ltrim(const unsigned char* begin, const unsigned char* end,
                                             const StringRef& remove_str) {
        if (remove_str.size == 0) {
            return begin;
        }
        const auto* p = begin;
        if constexpr (trim_single) {
            // Compare as unsigned bytes so bytes >= 0x80 match on signed-char platforms too.
            const auto ch = static_cast<unsigned char>(remove_str.data[0]);
#if defined(__AVX2__)
            constexpr auto AVX2_BYTES = sizeof(__m256i);
            const auto spaces = _mm256_set1_epi8(static_cast<char>(ch));
            while (static_cast<size_t>(end - p) >= AVX2_BYTES) {
                const uint32_t masks = _mm256_movemask_epi8(_mm256_cmpeq_epi8(
                        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p)), spaces));
                if (~masks) {
                    // The window holds a non-matching byte; finish byte-by-byte from here.
                    break;
                }
                p += AVX2_BYTES;
            }
#endif
            while (p < end && *p == ch) {
                ++p;
            }
            return p;
        }
        const auto remove_size = remove_str.size;
        const auto* const remove_data = remove_str.data;
        while (static_cast<size_t>(end - p) >= remove_size &&
               std::memcmp(p, remove_data, remove_size) == 0) {
            p += remove_size;
        }
        return p;
    }

    // Iterate a UTF-8 string without exceeding a given length n.
    // The function returns two values:
    // the first represents the byte length traversed, and the second represents the char
    // length traversed. The byte length never exceeds `end - begin`: a truncated multi-byte
    // sequence at the tail is clamped to the bytes actually present.
    static inline std::pair<size_t, size_t> iterate_utf8_with_limit_length(const char* begin,
                                                                           const char* end,
                                                                           size_t n) {
        const char* p = begin;
        size_t i = 0;
        for (; i < n && p < end; ++i) {
            const size_t char_size = UTF8_BYTE_LENGTH[static_cast<uint8_t>(*p)];
            // An invalid or truncated lead byte near the end can claim more bytes than remain.
            // Without this clamp `p` would run past `end` and the caller would over-read.
            p += std::min(char_size, static_cast<size_t>(end - p));
        }
        return {static_cast<size_t>(p - begin), i};
    }

    // Gcc will do auto simd in this function
    // if input empty, return true
    static bool is_ascii(const StringRef& str) {
#ifdef __AVX2__
        return validate_ascii_fast_avx(str.data, str.size);
#else
        return validate_ascii_fast(str.data, str.size);
#endif
    }

    /// Writes the character-wise reversal of `str` into `dst`.
    /// Pure-ASCII input is reversed byte by byte. Otherwise each UTF-8 sequence (sized by
    /// its lead byte) is copied as a unit to the mirrored position.
    /// NOTE(review): assumes dst.size == str.size and that dst.data is writable — confirm
    /// at call sites; the ASCII path indexes from dst.size, the UTF-8 path from str.size.
    static void reverse(const StringRef& str, StringRef dst) {
        if (is_ascii(str)) {
            int64_t begin = 0;
            int64_t end = str.size;
            int64_t result_end = dst.size - 1;
            // auto SIMD here
            auto* __restrict l = const_cast<char*>(dst.data);
            auto* __restrict r = str.data;
            for (; begin < end; ++begin, --result_end) {
                l[result_end] = r[begin];
            }
        } else {
            char* dst_data = const_cast<char*>(dst.data);
            for (size_t i = 0, char_size = 0; i < str.size; i += char_size) {
                char_size = UTF8_BYTE_LENGTH[static_cast<unsigned char>(str.data[i])];
                // The last character may be an illegal UTF-8 lead byte whose declared size is
                // larger than the bytes left (e.g. size=4, i=3, char_size=2). Clamp so the copy
                // never reads past str.data + str.size.
                const size_t offset = std::min<size_t>(i + char_size, str.size);
                std::copy(str.data + i, str.data + offset, dst_data + str.size - offset);
            }
        }
    }

    /// Writes the uppercase hexadecimal encoding of [src_str, src_str + length) to dst_str.
    /// Each input byte produces two chars (high nibble first). Exactly 2 * length bytes are
    /// written and no terminator is appended, so dst_str must have room for 2 * length chars.
    static void hex_encode(const unsigned char* src_str, size_t length, char* dst_str) {
        static constexpr auto hex_table = "0123456789ABCDEF";
        const auto* const src_str_end = src_str + length;
#if defined(__SSSE3__) || defined(__aarch64__)
        // `_mm_shuffle_epi8` is an SSSE3 instruction, so an SSE2-only guard is not enough.
        // On aarch64 it is assumed to be supplied by util/sse_util.hpp — TODO confirm.
        // Each iteration turns 8 input bytes into 16 output chars.
        constexpr auto step = sizeof(uint64_t);
        if (length >= step) {
            const auto hex_map = _mm_loadu_si128(reinterpret_cast<const __m128i*>(hex_table));
            const auto mask_map = _mm_set1_epi8(0x0F);
            while (static_cast<size_t>(src_str_end - src_str) >= step) {
                const auto data = _mm_loadu_si64(src_str);
                // Interleave (byte >> 4, byte) pairs, keep the low nibble of each, then map
                // nibble values 0..15 to their hex digit through the shuffle table.
                const auto hex_loc = _mm_and_si128(
                        _mm_unpacklo_epi8(_mm_srli_epi64(data, 4), data), mask_map);
                _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_str),
                                 _mm_shuffle_epi8(hex_map, hex_loc));
                src_str += step;
                dst_str += step * 2;
            }
        }
#endif
        // hex(str) str length is n, result must be 2 * n length
        for (; src_str < src_str_end; ++src_str, dst_str += 2) {
            dst_str[0] = hex_table[*src_str >> 4];   // high 4 bits
            dst_str[1] = hex_table[*src_str & 0x0F]; // low 4 bits
        }
    }

    /// Writes `len` bytes of `src` into `dst` via LowerUpperImpl<'A', 'Z'>. No-op for len <= 0.
    static void to_lower(const uint8_t* src, int64_t len, uint8_t* dst) {
        if (len <= 0) {
            return;
        }
        LowerUpperImpl<'A', 'Z'> lowerUpper;
        lowerUpper.transfer(src, src + len, dst);
    }

    /// Writes `len` bytes of `src` into `dst` via LowerUpperImpl<'a', 'z'>. No-op for len <= 0.
    static void to_upper(const uint8_t* src, int64_t len, uint8_t* dst) {
        if (len <= 0) {
            return;
        }
        LowerUpperImpl<'a', 'z'> lowerUpper;
        lowerUpper.transfer(src, src + len, dst);
    }

    /// Counts UTF-8 characters in [src, src + len), stepping by each lead byte's declared
    /// length. Appends the byte offset of every character start to `str_index`.
    static inline size_t get_char_len(const char* src, size_t len, std::vector<size_t>& str_index) {
        size_t char_len = 0;
        for (size_t i = 0, char_size = 0; i < len; i += char_size) {
            char_size = UTF8_BYTE_LENGTH[static_cast<unsigned char>(src[i])];
            str_index.push_back(i);
            ++char_len;
        }
        return char_len;
    }

    // utf8-encoding:
    // - 1-byte: 0xxx_xxxx;
    // - 2-byte: 110x_xxxx 10xx_xxxx;
    // - 3-byte: 1110_xxxx 10xx_xxxx 10xx_xxxx;
    // - 4-byte: 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx.
    // Counting utf8 chars in a byte string is equivalent to counting the first byte of each
    // char, i.e. counting bytes that do not match the 10xx_xxxx pattern.
    // Compared as int8_t, all of 0xxx_xxxx, 110x_xxxx, 1110_xxxx and 1111_0xxx are greater
    // than 1011_1111, so counting bytes greater than 1011_1111 gives the utf8_length.
    // get_char_len is used to return the UTF-8 length of a string.
    // The return value will never exceed len.
    template <typename T>
    static inline T get_char_len(const char* src, T len) {
        T char_len = 0;
        const char* p = src;
        const char* end = p + len;
#if defined(__SSE2__) || defined(__aarch64__)
        constexpr auto bytes_sse2 = sizeof(__m128i);
        const auto src_end_sse2 = p + (len & ~(bytes_sse2 - 1));
        // threshold = 1011_1111 (-65 as int8_t)
        const auto threshold = _mm_set1_epi8(static_cast<char>(0xBF));
        for (; p < src_end_sse2; p += bytes_sse2) {
            char_len += __builtin_popcount(_mm_movemask_epi8(_mm_cmpgt_epi8(
                    _mm_loadu_si128(reinterpret_cast<const __m128i*>(p)), threshold)));
        }
#endif
        // Process the remaining tail bytes (fewer than bytes_sse2 when SIMD is used) one by one.
        for (; p < end; ++p) {
            char_len += static_cast<int8_t>(*p) > static_cast<int8_t>(0xBF);
        }
        return char_len;
    }
};
} // namespace simd
} // namespace doris