| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| // This file is copied from |
| // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/isIPAddressContainedIn.cpp |
| // and modified by Doris |
| |
| #pragma once |
| |
| #include "util/sse_util.hpp" |
| #include "vec/common/format_ip.h" |
| #include "vec/common/ipv6_to_binary.h" |
| namespace doris { |
| #include "common/compile_check_begin.h" |
| |
| namespace vectorized { |
| static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) { |
| if (bits_to_keep >= 8 * sizeof(UInt32)) { |
| return {src, src}; |
| } |
| if (bits_to_keep == 0) { |
| return {static_cast<UInt32>(0), static_cast<UInt32>(-1)}; |
| } |
| UInt32 mask = static_cast<UInt32>(-1) << (8 * sizeof(UInt32) - bits_to_keep); |
| UInt32 lower = src & mask; |
| UInt32 upper = lower | ~mask; |
| |
| return {lower, upper}; |
| } |
| |
| static inline void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower, |
| char* __restrict dst_upper, UInt8 bits_to_keep) { |
| // little-endian mask |
| const auto& mask = get_cidr_mask_ipv6(bits_to_keep); |
| |
| for (int8_t i = IPV6_BINARY_LENGTH - 1; i >= 0; --i) { |
| dst_lower[i] = src[i] & mask[i]; |
| dst_upper[i] = char(dst_lower[i] | ~mask[i]); |
| } |
| } |
| |
| static inline void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower, |
| UInt8 bits_to_keep) { |
| // little-endian mask |
| const auto& mask = get_cidr_mask_ipv6(bits_to_keep); |
| |
| for (int8_t i = IPV6_BINARY_LENGTH - 1; i >= 0; --i) { |
| dst_lower[i] = src[i] & mask[i]; |
| } |
| } |
| |
| } // namespace vectorized |
| |
| class IPAddressVariant { |
| public: |
| explicit IPAddressVariant(std::string_view address_str) { |
| vectorized::Int64 v4 = 0; |
| if (vectorized::parse_ipv4_whole(address_str.begin(), address_str.end(), |
| reinterpret_cast<unsigned char*>(&v4))) { |
| _addr = static_cast<vectorized::UInt32>(v4); |
| } else { |
| _addr = IPv6AddrType(); |
| // parse ipv6 in little-endian |
| if (!vectorized::parse_ipv6_whole(address_str.begin(), address_str.end(), |
| std::get<IPv6AddrType>(_addr).data())) { |
| throw Exception(ErrorCode::INVALID_ARGUMENT, "Neither IPv4 nor IPv6 address: '{}'", |
| address_str); |
| } |
| } |
| } |
| |
| vectorized::UInt32 as_v4() const { |
| if (const auto* val = std::get_if<IPv4AddrType>(&_addr)) { |
| return *val; |
| } |
| return 0; |
| } |
| |
| const vectorized::UInt8* as_v6() const { |
| if (const auto* val = std::get_if<IPv6AddrType>(&_addr)) { |
| return val->data(); |
| } |
| return nullptr; |
| } |
| |
| private: |
| using IPv4AddrType = vectorized::UInt32; |
| using IPv6AddrType = std::array<vectorized::UInt8, IPV6_BINARY_LENGTH>; |
| |
| std::variant<IPv4AddrType, IPv6AddrType> _addr; |
| }; |
| |
| struct IPAddressCIDR { |
| IPAddressVariant _address; |
| vectorized::UInt8 _prefix; |
| }; |
| |
/// Tests whether `addr` and `cidr_addr` agree on their `prefix` most significant bits.
///
/// @param addr       IPv4 address to test, as a 32-bit integer.
/// @param cidr_addr  Base address of the CIDR block, as a 32-bit integer.
/// @param prefix     Prefix length; any value >= 32 requires an exact match.
/// @return true when `addr` lies inside the block.
inline bool match_ipv4_subnet(uint32_t addr, uint32_t cidr_addr, uint8_t prefix) {
    // Shifting a 32-bit value by 32 or more is undefined, so full-width prefixes
    // are handled as a plain equality test.
    if (prefix >= 32) {
        return addr == cidr_addr;
    }
    const uint32_t host_bits = 0xffffffffU >> prefix;
    // Only differences inside the network part of the address matter.
    return ((addr ^ cidr_addr) & ~host_bits) == 0;
}
| |
| #if defined(__SSE2__) || defined(__aarch64__) |
| |
| inline bool match_ipv6_subnet(const uint8_t* addr, const uint8_t* cidr_addr, uint8_t prefix) { |
| uint16_t mask = (uint16_t)_mm_movemask_epi8( |
| _mm_cmpeq_epi8(_mm_loadu_si128(reinterpret_cast<const __m128i*>(addr)), |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(cidr_addr)))); |
| mask = ~mask; |
| |
| if (mask) { |
| const auto offset = std::countl_zero(mask); |
| if (prefix / 8 != offset) { |
| return prefix / 8 < offset; |
| } |
| auto cmpmask = ~(0xff >> (prefix % 8)); |
| return (addr[IPV6_BINARY_LENGTH - 1 - offset] & cmpmask) == |
| (cidr_addr[IPV6_BINARY_LENGTH - 1 - offset] & cmpmask); |
| } else { |
| // All the bytes are equal. |
| } |
| return true; |
| } |
| |
| #else |
| // ipv6 liitle-endian input |
| inline bool match_ipv6_subnet(const uint8_t* addr, const uint8_t* cidr_addr, uint8_t prefix) { |
| if (prefix > IPV6_BINARY_LENGTH * 8U) { |
| prefix = IPV6_BINARY_LENGTH * 8U; |
| } |
| size_t i = IPV6_BINARY_LENGTH - 1; |
| |
| for (; prefix >= 8; --i, prefix -= 8) { |
| if (addr[i] != cidr_addr[i]) { |
| return false; |
| } |
| } |
| |
| if (prefix == 0) { |
| return true; |
| } |
| |
| auto mask = ~(0xff >> prefix); |
| return (addr[i] & mask) == (cidr_addr[i] & mask); |
| } |
| #endif |
| |
| inline IPAddressCIDR parse_ip_with_cidr(std::string_view cidr_str) { |
| size_t pos_slash = cidr_str.find('/'); |
| |
| if (pos_slash == 0) { |
| throw Exception(ErrorCode::INVALID_ARGUMENT, "Error parsing IP address with prefix: {}", |
| std::string(cidr_str)); |
| } |
| |
| if (pos_slash == std::string_view::npos) { |
| throw Exception(ErrorCode::INVALID_ARGUMENT, "The text does not contain '/': {}", |
| std::string(cidr_str)); |
| } |
| |
| std::string_view addr_str = cidr_str.substr(0, pos_slash); |
| IPAddressVariant addr(addr_str); |
| |
| uint8_t prefix = 0; |
| auto prefix_str = cidr_str.substr(pos_slash + 1); |
| |
| const auto* prefix_str_end = prefix_str.data() + prefix_str.size(); |
| auto [parse_end, parse_error] = std::from_chars(prefix_str.data(), prefix_str_end, prefix); |
| uint8_t max_prefix = (addr.as_v6() ? IPV6_BINARY_LENGTH : IPV4_BINARY_LENGTH) * 8; |
| |
| if (parse_error != std::errc() || parse_end != prefix_str_end || prefix > max_prefix) { |
| throw Exception(ErrorCode::INVALID_ARGUMENT, "The CIDR has a malformed prefix bits: {}", |
| std::string(cidr_str)); |
| } |
| |
| return {addr, static_cast<uint8_t>(prefix)}; |
| } |
| |
| inline bool is_address_in_range(const IPAddressVariant& address, const IPAddressCIDR& cidr) { |
| const auto* cidr_v6 = cidr._address.as_v6(); |
| const auto* addr_v6 = address.as_v6(); |
| if (cidr_v6) { |
| if (addr_v6) { |
| return match_ipv6_subnet(addr_v6, cidr_v6, cidr._prefix); |
| } |
| } else { |
| if (!addr_v6) { |
| return match_ipv4_subnet(address.as_v4(), cidr._address.as_v4(), cidr._prefix); |
| } |
| } |
| return false; |
| } |
| |
| } // namespace doris |
| #include "common/compile_check_end.h" |