blob: cbd4c5534c1e38fdd13b37f94227f31797a7eb1e [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <algorithm>
#include <variant>
#include "exec/common/hash_table/hash_key_type.h"
#include "exec/common/hash_table/hash_map_context.h"
#include "exec/common/hash_table/join_hash_table.h"
namespace doris {
// Resolve the concrete column type behind `col` once and invoke `func` with it.
// ASOF JOIN only supports DateV2, DateTimeV2, and TimestampTZ, so those are the
// only types probed here. Handing `func` a pointer to the concrete column type
// lets compare_at calls made inside `func` bind to that type instead of going
// through the IColumn interface.
//
// `func` is called exactly once with a single argument: a const pointer to the
// concrete column, or the original `const IColumn*` when the column is none of
// the expected types. Because the return type is deduced with decltype(auto),
// `func` must yield the same type for every one of those pointer types.
template <typename Func>
decltype(auto) asof_column_dispatch(const IColumn* col, Func&& func) {
    if (const auto* date_col = check_and_get_column<ColumnDateV2>(col)) {
        return std::forward<Func>(func)(date_col);
    }
    if (const auto* datetime_col = check_and_get_column<ColumnDateTimeV2>(col)) {
        return std::forward<Func>(func)(datetime_col);
    }
    if (const auto* timestamp_tz_col = check_and_get_column<ColumnTimeStampTz>(col)) {
        return std::forward<Func>(func)(timestamp_tz_col);
    }
    // Unexpected column type: fall back to the type-erased pointer.
    return std::forward<Func>(func)(col);
}
// Variant with one std::integral_constant alternative per TJoinOp value listed
// below, so that a join operator can be carried as a distinct type rather than
// as a plain enum value.
// NOTE(review): presumably consumed through std::visit to pick templates
// specialised on the join op; the consumers are not visible in this header.
using JoinOpVariants =
std::variant<std::integral_constant<TJoinOp::type, TJoinOp::INNER_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::LEFT_SEMI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::LEFT_ANTI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::LEFT_OUTER_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::FULL_OUTER_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::RIGHT_OUTER_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::CROSS_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::RIGHT_SEMI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::RIGHT_ANTI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::ASOF_LEFT_INNER_JOIN>,
std::integral_constant<TJoinOp::type, TJoinOp::ASOF_LEFT_OUTER_JOIN>>;
// Runtime check: returns true when `join_op` is one of the two ASOF join
// operators (ASOF_LEFT_INNER_JOIN or ASOF_LEFT_OUTER_JOIN), false otherwise.
inline bool is_asof_join(TJoinOp::type join_op) {
    switch (join_op) {
    case TJoinOp::ASOF_LEFT_INNER_JOIN:
    case TJoinOp::ASOF_LEFT_OUTER_JOIN:
        return true;
    default:
        return false;
    }
}
// Compile-time check: true when JoinOpType equals ASOF_LEFT_INNER_JOIN or
// ASOF_LEFT_OUTER_JOIN. The parameter is an int, so the TJoinOp enumerators are
// compared by their integer value.
template <int JoinOpType>
inline constexpr bool is_asof_join_op_v =
JoinOpType == TJoinOp::ASOF_LEFT_INNER_JOIN || JoinOpType == TJoinOp::ASOF_LEFT_OUTER_JOIN;
// Compile-time check: true only when JoinOpType equals ASOF_LEFT_OUTER_JOIN.
template <int JoinOpType>
inline constexpr bool is_asof_outer_join_op_v = JoinOpType == TJoinOp::ASOF_LEFT_OUTER_JOIN;
// Hash-method aliases used by the join build/probe code. Each one pairs a key
// handling method with a JoinHashMap instantiation.
// NOTE(review): the trailing bool of JoinHashMap is `true` only for the direct
// variant below; presumably it enables direct (index-based) mapping. Confirm
// against join_hash_table.h.

// Single numeric key of type T, hashed with HashCRC32<T>.
template <class T>
using PrimaryTypeHashTableContext = MethodOneNumber<T, JoinHashMap<T, HashCRC32<T>, false>>;
// Single numeric key of type T handled by MethodOneNumberDirect; instances are
// constructed from a (max_key, min_key) pair by primary_to_direct_mapping().
template <class T>
using DirectPrimaryTypeHashTableContext =
MethodOneNumberDirect<T, JoinHashMap<T, HashCRC32<T>, true>>;
// Fixed-width key of type Key handled by MethodKeysFixed; instances are
// constructed from the key sizes of the join key columns.
template <class Key>
using FixedKeyHashTableContext = MethodKeysFixed<JoinHashMap<Key, HashCRC32<Key>, false>>;
// Keys handled by MethodSerialized, stored as StringRef and hashed with DefaultHash.
using SerializedHashTableContext =
MethodSerialized<JoinHashMap<StringRef, DefaultHash<StringRef>, false>>;
// Keys handled by MethodStringNoCache, stored as StringRef and hashed with DefaultHash.
using MethodOneString = MethodStringNoCache<JoinHashMap<StringRef, DefaultHash<StringRef>, false>>;
// Closed set of hash methods a join can use. std::monostate is the first
// alternative, so a default-constructed variant holds no method yet.
// Direct-mapped alternatives exist for UInt8 through UInt128 only; there is no
// direct-mapped UInt256 alternative.
using HashTableVariants = std::variant<
std::monostate, SerializedHashTableContext, PrimaryTypeHashTableContext<UInt8>,
PrimaryTypeHashTableContext<UInt16>, PrimaryTypeHashTableContext<UInt32>,
PrimaryTypeHashTableContext<UInt64>, PrimaryTypeHashTableContext<UInt128>,
PrimaryTypeHashTableContext<UInt256>, DirectPrimaryTypeHashTableContext<UInt8>,
DirectPrimaryTypeHashTableContext<UInt16>, DirectPrimaryTypeHashTableContext<UInt32>,
DirectPrimaryTypeHashTableContext<UInt64>, DirectPrimaryTypeHashTableContext<UInt128>,
FixedKeyHashTableContext<UInt64>, FixedKeyHashTableContext<UInt72>,
FixedKeyHashTableContext<UInt96>, FixedKeyHashTableContext<UInt104>,
FixedKeyHashTableContext<UInt128>, FixedKeyHashTableContext<UInt136>,
FixedKeyHashTableContext<UInt256>, MethodOneString>;
struct JoinDataVariants {
HashTableVariants method_variant;
void init(const std::vector<DataTypePtr>& data_types, HashKeyType type) {
switch (type) {
case HashKeyType::serialized:
method_variant.emplace<SerializedHashTableContext>();
break;
case HashKeyType::int8_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt8>>();
break;
case HashKeyType::int16_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt16>>();
break;
case HashKeyType::int32_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt32>>();
break;
case HashKeyType::int64_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt64>>();
break;
case HashKeyType::int128_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt128>>();
break;
case HashKeyType::int256_key:
method_variant.emplace<PrimaryTypeHashTableContext<UInt256>>();
break;
case HashKeyType::string_key:
method_variant.emplace<MethodOneString>();
break;
case HashKeyType::fixed64:
method_variant.emplace<FixedKeyHashTableContext<UInt64>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed72:
method_variant.emplace<FixedKeyHashTableContext<UInt72>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed96:
method_variant.emplace<FixedKeyHashTableContext<UInt96>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed104:
method_variant.emplace<FixedKeyHashTableContext<UInt104>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed128:
method_variant.emplace<FixedKeyHashTableContext<UInt128>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed136:
method_variant.emplace<FixedKeyHashTableContext<UInt136>>(get_key_sizes(data_types));
break;
case HashKeyType::fixed256:
method_variant.emplace<FixedKeyHashTableContext<UInt256>>(get_key_sizes(data_types));
break;
default:
throw Exception(ErrorCode::INTERNAL_ERROR,
"JoinDataVariants meet invalid key type, type={}", type);
}
}
};
template <typename Method>
void primary_to_direct_mapping(Method* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
using FieldType = typename Method::Base::Key;
FieldType max_key = std::numeric_limits<FieldType>::min();
FieldType min_key = std::numeric_limits<FieldType>::max();
size_t num_rows = key_columns[0]->size();
if (key_columns[0]->is_nullable()) {
const FieldType* input_keys = (FieldType*)assert_cast<const ColumnNullable*>(key_columns[0])
->get_nested_column_ptr()
->get_raw_data()
.data;
const NullMap& null_map =
assert_cast<const ColumnNullable*>(key_columns[0])->get_null_map_data();
// skip first mocked row
for (size_t i = 1; i < num_rows; i++) {
if (null_map[i]) {
continue;
}
max_key = std::max(max_key, input_keys[i]);
min_key = std::min(min_key, input_keys[i]);
}
} else {
const FieldType* input_keys = (FieldType*)key_columns[0]->get_raw_data().data;
// skip first mocked row
for (size_t i = 1; i < num_rows; i++) {
max_key = std::max(max_key, input_keys[i]);
min_key = std::min(min_key, input_keys[i]);
}
}
constexpr auto MAX_MAPPING_RANGE = 1 << 23;
bool allow_direct_mapping = (max_key >= min_key && max_key - min_key < MAX_MAPPING_RANGE - 1);
if (allow_direct_mapping) {
for (const auto& variant_ptr : variant_ptrs) {
variant_ptr->method_variant.emplace<DirectPrimaryTypeHashTableContext<FieldType>>(
max_key, min_key);
}
}
}
// Generic fallback: intentionally a no-op. Any hash method without a dedicated
// non-template overload of try_convert_to_direct_mapping resolves to this
// template and is left unchanged.
template <typename Method>
void try_convert_to_direct_mapping(
Method* method, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {}
// UInt8 single-key method: forwards to primary_to_direct_mapping(), which
// switches `variant_ptrs` to the direct-mapped method when the key range allows.
inline void try_convert_to_direct_mapping(
PrimaryTypeHashTableContext<UInt8>* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
primary_to_direct_mapping(context, key_columns, variant_ptrs);
}
// UInt16 single-key method: forwards to primary_to_direct_mapping(), which
// switches `variant_ptrs` to the direct-mapped method when the key range allows.
inline void try_convert_to_direct_mapping(
PrimaryTypeHashTableContext<UInt16>* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
primary_to_direct_mapping(context, key_columns, variant_ptrs);
}
// UInt32 single-key method: forwards to primary_to_direct_mapping(), which
// switches `variant_ptrs` to the direct-mapped method when the key range allows.
inline void try_convert_to_direct_mapping(
PrimaryTypeHashTableContext<UInt32>* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
primary_to_direct_mapping(context, key_columns, variant_ptrs);
}
// UInt64 single-key method: forwards to primary_to_direct_mapping(), which
// switches `variant_ptrs` to the direct-mapped method when the key range allows.
inline void try_convert_to_direct_mapping(
PrimaryTypeHashTableContext<UInt64>* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
primary_to_direct_mapping(context, key_columns, variant_ptrs);
}
// UInt128 single-key method: forwards to primary_to_direct_mapping(), which
// switches `variant_ptrs` to the direct-mapped method when the key range allows.
inline void try_convert_to_direct_mapping(
PrimaryTypeHashTableContext<UInt128>* context, const ColumnRawPtrs& key_columns,
const std::vector<std::shared_ptr<JoinDataVariants>>& variant_ptrs) {
primary_to_direct_mapping(context, key_columns, variant_ptrs);
}
// ASOF JOIN index for one group of build rows, searched with a branchless
// binary search over inline values.
// IntType is the integer representation of the ASOF column value:
// uint32_t for DateV2, uint64_t for DateTimeV2 and TimestampTZ.
// Lifecycle: add_row() stages (value, row) pairs; sort_and_finalize() orders
// them by value and splits them into two parallel arrays, so the probe-side
// search only touches the ASOF values.
template <typename IntType>
struct AsofIndexGroup {
    using int_type = IntType;

    // Staging record, only populated between add_row() and sort_and_finalize().
    struct Entry {
        IntType asof_value;
        uint32_t row_index; // 1-based, 0 = invalid/padding
    };

    std::vector<Entry> entries;        // staging area; released by sort_and_finalize()
    std::vector<IntType> asof_values;  // ASOF values in ascending order
    std::vector<uint32_t> row_indexes; // row_indexes[i] is the build row of asof_values[i]

    // Stage one build row. Takes effect for lookups only after sort_and_finalize().
    void add_row(IntType value, uint32_t row_idx) { entries.push_back({value, row_idx}); }

    // Sort the staged rows by ASOF value, copy them into the parallel arrays and
    // free the staging buffer. Does nothing when no rows are staged.
    // NOTE(review): only asof_value is compared, so the relative order of rows
    // with equal values depends on pdqsort; confirm whether that matters.
    void sort_and_finalize() {
        const size_t count = entries.size();
        if (count == 0) {
            return;
        }
        if (count > 1) {
            pdqsort(entries.begin(), entries.end(), [](const Entry& lhs, const Entry& rhs) {
                return lhs.asof_value < rhs.asof_value;
            });
        }

        asof_values.resize(count);
        row_indexes.resize(count);
        size_t slot = 0;
        for (const Entry& staged : entries) {
            asof_values[slot] = staged.asof_value;
            row_indexes[slot] = staged.row_index;
            ++slot;
        }

        // Swap with an empty temporary so the staging memory is actually released.
        std::vector<Entry>().swap(entries);
    }

    // Raw pointer to the sorted ASOF values.
    const IntType* values_data() const { return asof_values.data(); }

    // Branchless lower_bound: first i where asof_values[i] >= target,
    // or asof_values.size() when there is no such element.
    ALWAYS_INLINE size_t lower_bound(IntType target) const {
        const size_t total = asof_values.size();
        size_t base = 0;
        // Halve the window each round; the comparison result (0 or 1) scales the
        // step instead of steering a branch.
        for (size_t span = total; span > 1;) {
            const size_t step = span / 2;
            base += step * (asof_values[base + step] < target);
            span -= step;
        }
        // One candidate is left (none for an empty index): step past it if it is
        // still below the target.
        if (base < total) {
            base += (asof_values[base] < target);
        }
        return base;
    }

    // Branchless upper_bound: first i where asof_values[i] > target,
    // or asof_values.size() when there is no such element.
    ALWAYS_INLINE size_t upper_bound(IntType target) const {
        const size_t total = asof_values.size();
        size_t base = 0;
        for (size_t span = total; span > 1;) {
            const size_t step = span / 2;
            base += step * (asof_values[base + step] <= target);
            span -= step;
        }
        if (base < total) {
            base += (asof_values[base] <= target);
        }
        return base;
    }

    // Semantics by (IsGreater, IsStrict):
    // (true, false): probe >= build -> find largest build value <= probe
    // (true, true): probe > build -> find largest build value < probe
    // (false, false): probe <= build -> find smallest build value >= probe
    // (false, true): probe < build -> find smallest build value > probe
    // Returns the build row index of the best match, or 0 if no match.
    template <bool IsGreater, bool IsStrict>
    ALWAYS_INLINE uint32_t find_best_match(IntType probe_value) const {
        if (asof_values.empty()) {
            return 0;
        }
        if constexpr (IsGreater) {
            // The match is the element just before the first one that is too large.
            const size_t end_pos = IsStrict ? lower_bound(probe_value) : upper_bound(probe_value);
            if (end_pos == 0) {
                return 0;
            }
            return row_indexes[end_pos - 1];
        } else {
            // The match is the first element that is large enough.
            const size_t first_pos = IsStrict ? upper_bound(probe_value) : lower_bound(probe_value);
            if (first_pos >= asof_values.size()) {
                return 0;
            }
            return row_indexes[first_pos];
        }
    }
};
// Container for all ASOF index groups of a join, as a vector with one
// AsofIndexGroup per element.
// The alternative chosen follows the ASOF column's integer representation:
// DateV2 -> uint32_t, DateTimeV2/TimestampTZ -> uint64_t.
// std::monostate is the first alternative, so a default-constructed variant
// holds no index.
using AsofIndexVariant = std::variant<std::monostate, std::vector<AsofIndexGroup<uint32_t>>,
std::vector<AsofIndexGroup<uint64_t>>>;
} // namespace doris