blob: cba7ee62a76e714263bc4371e283e78defcf91a6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <variant>
#include <vector>
#include "core/arena.h"
#include "exec/common/hash_table/hash_map_context.h"
#include "exec/common/hash_table/hash_map_util.h"
#include "exec/common/hash_table/ph_hash_map.h"
#include "exec/common/hash_table/string_hash_map.h"
namespace doris {
// Hash-map building blocks for aggregation. Every map below stores an
// AggregateDataPtr as its mapped value.

/// Map keyed by T and hashed with HashCRC32<T>.
template <typename T>
using AggData = PHHashMap<T, AggregateDataPtr, HashCRC32<T>>;

/// AggData<T> wrapped in DataWithNullKey (presumably adds a dedicated slot for a NULL key).
template <typename T>
using AggDataNullable = DataWithNullKey<AggData<T>>;

/// A bare aggregate-state pointer, named for the key-less (without_key) case.
using AggregatedDataWithoutKey = AggregateDataPtr;

// String keys: a PHHashMap over StringRef, or the StringHashMap container,
// each with a DataWithNullKey-wrapped nullable counterpart.
using AggregatedDataWithStringKey = PHHashMap<StringRef, AggregateDataPtr>;
using AggregatedDataWithShortStringKey = StringHashMap<AggregateDataPtr>;
using AggregatedDataWithNullableStringKey = DataWithNullKey<AggregatedDataWithStringKey>;
using AggregatedDataWithNullableShortStringKey =
        DataWithNullKey<AggregatedDataWithShortStringKey>;

// "Phase2" integer maps hash with HashMixWrapper<T> instead of HashCRC32<T>.
using AggregatedDataWithUInt32KeyPhase2 =
        PHHashMap<UInt32, AggregateDataPtr, HashMixWrapper<UInt32>>;
using AggregatedDataWithUInt64KeyPhase2 =
        PHHashMap<UInt64, AggregateDataPtr, HashMixWrapper<UInt64>>;
using AggregatedDataWithNullableUInt32KeyPhase2 =
        DataWithNullKey<AggregatedDataWithUInt32KeyPhase2>;
using AggregatedDataWithNullableUInt64KeyPhase2 =
        DataWithNullKey<AggregatedDataWithUInt64KeyPhase2>;
/// Hash-method variant shared by the aggregation data-variant structs.
///
/// Every instantiation has the same alternatives in the same order; only the
/// two single-string-key alternatives depend on the template parameters:
///   - StringData         -> MethodStringNoCache<StringData>
///   - NullableStringData -> MethodSingleNullableColumn<MethodStringNoCache<NullableStringData>>
/// NOTE: std::variant::index() follows declaration order, so reordering the
/// alternatives below changes the index of each one.
template <typename StringData, typename NullableStringData>
using AggMethodVariantsBase = std::variant<
        std::monostate,
        MethodSerialized<AggregatedDataWithStringKey>,
        // Non-nullable single-column keys.
        MethodOneNumber<UInt8, AggData<UInt8>>,
        MethodOneNumber<UInt16, AggData<UInt16>>,
        MethodOneNumber<UInt32, AggData<UInt32>>,
        MethodOneNumber<UInt64, AggData<UInt64>>,
        MethodStringNoCache<StringData>,
        MethodOneNumber<UInt128, AggData<UInt128>>,
        MethodOneNumber<UInt256, AggData<UInt256>>,
        MethodOneNumber<UInt32, AggregatedDataWithUInt32KeyPhase2>,
        MethodOneNumber<UInt64, AggregatedDataWithUInt64KeyPhase2>,
        // Nullable single-column keys.
        MethodSingleNullableColumn<MethodOneNumber<UInt8, AggDataNullable<UInt8>>>,
        MethodSingleNullableColumn<MethodOneNumber<UInt16, AggDataNullable<UInt16>>>,
        MethodSingleNullableColumn<MethodOneNumber<UInt32, AggDataNullable<UInt32>>>,
        MethodSingleNullableColumn<MethodOneNumber<UInt64, AggDataNullable<UInt64>>>,
        MethodSingleNullableColumn<
                MethodOneNumber<UInt32, AggregatedDataWithNullableUInt32KeyPhase2>>,
        MethodSingleNullableColumn<
                MethodOneNumber<UInt64, AggregatedDataWithNullableUInt64KeyPhase2>>,
        MethodSingleNullableColumn<MethodOneNumber<UInt128, AggDataNullable<UInt128>>>,
        MethodSingleNullableColumn<MethodOneNumber<UInt256, AggDataNullable<UInt256>>>,
        MethodSingleNullableColumn<MethodStringNoCache<NullableStringData>>,
        // Fixed-width composite keys.
        MethodKeysFixed<AggData<UInt64>>,
        MethodKeysFixed<AggData<UInt72>>,
        MethodKeysFixed<AggData<UInt96>>,
        MethodKeysFixed<AggData<UInt104>>,
        MethodKeysFixed<AggData<UInt128>>,
        MethodKeysFixed<AggData<UInt136>>,
        MethodKeysFixed<AggData<UInt256>>>;

/// Regular aggregation: single string keys use StringHashMap
/// (AggregatedDataWithShortStringKey and its nullable wrapper).
using AggregatedMethodVariants = AggMethodVariantsBase<AggregatedDataWithShortStringKey,
                                                       AggregatedDataWithNullableShortStringKey>;

/// Bucketed aggregation: single string keys use PHHashMap<StringRef>
/// (AggregatedDataWithStringKey) instead of StringHashMap. This sidesteps
/// StringHashMap's sub-tables and keeps the 3-argument PHHashMap::emplace
/// interface uniform, while HashMethodString still extracts the single-column
/// string key.
using BucketedAggMethodVariants =
        AggMethodVariantsBase<AggregatedDataWithStringKey, AggregatedDataWithNullableStringKey>;
/// Common base for the aggregation data-variant structs.
///
/// Provides init_agg_data(), which emplaces into `method_variant` the hash
/// method that matches a HashKeyType. Every key type is handled the same way
/// for all instantiations except HashKeyType::string_key, where the
/// StringData / NullableStringData parameters pick the hash map type.
template <typename MethodVariants, typename StringData, typename NullableStringData>
struct AggDataVariantsBase : public DataVariants<MethodVariants, MethodSingleNullableColumn,
                                                 MethodOneNumber, DataWithNullKey> {
    /// Emplaces the hash method for `type` into `method_variant`.
    ///
    /// @param data_types Types of the group-by key columns. The nullable flag
    ///        handed to the single-key methods is true only when there is
    ///        exactly one key column and that column is nullable. The
    ///        fixed-key methods are constructed from get_key_sizes(data_types).
    /// @param type Key layout to build the hash method for. without_key leaves
    ///        `method_variant` untouched.
    /// @throws Exception with ErrorCode::INTERNAL_ERROR for an unhandled `type`.
    void init_agg_data(const std::vector<DataTypePtr>& data_types, HashKeyType type) {
        const bool one_nullable_key = data_types.size() == 1 && data_types[0]->is_nullable();

        switch (type) {
        case HashKeyType::without_key:
            return;
        case HashKeyType::serialized:
            this->method_variant.template emplace<MethodSerialized<AggregatedDataWithStringKey>>();
            return;

        // Single numeric key; emplace_single (from DataVariants) receives the nullable flag.
        case HashKeyType::int8_key:
            this->template emplace_single<UInt8, AggData<UInt8>>(one_nullable_key);
            return;
        case HashKeyType::int16_key:
            this->template emplace_single<UInt16, AggData<UInt16>>(one_nullable_key);
            return;
        case HashKeyType::int32_key:
            this->template emplace_single<UInt32, AggData<UInt32>>(one_nullable_key);
            return;
        case HashKeyType::int32_key_phase2:
            this->template emplace_single<UInt32, AggregatedDataWithUInt32KeyPhase2>(
                    one_nullable_key);
            return;
        case HashKeyType::int64_key:
            this->template emplace_single<UInt64, AggData<UInt64>>(one_nullable_key);
            return;
        case HashKeyType::int64_key_phase2:
            this->template emplace_single<UInt64, AggregatedDataWithUInt64KeyPhase2>(
                    one_nullable_key);
            return;
        case HashKeyType::int128_key:
            this->template emplace_single<UInt128, AggData<UInt128>>(one_nullable_key);
            return;
        case HashKeyType::int256_key:
            this->template emplace_single<UInt256, AggData<UInt256>>(one_nullable_key);
            return;

        // Single string key: the only case that depends on StringData / NullableStringData.
        case HashKeyType::string_key:
            if (!one_nullable_key) {
                this->method_variant.template emplace<MethodStringNoCache<StringData>>();
            } else {
                this->method_variant.template emplace<
                        MethodSingleNullableColumn<MethodStringNoCache<NullableStringData>>>();
            }
            return;

        // Fixed-width composite keys (MethodKeysFixed), sized by get_key_sizes().
        case HashKeyType::fixed64:
            _emplace_keys_fixed<UInt64>(data_types);
            return;
        case HashKeyType::fixed72:
            _emplace_keys_fixed<UInt72>(data_types);
            return;
        case HashKeyType::fixed96:
            _emplace_keys_fixed<UInt96>(data_types);
            return;
        case HashKeyType::fixed104:
            _emplace_keys_fixed<UInt104>(data_types);
            return;
        case HashKeyType::fixed128:
            _emplace_keys_fixed<UInt128>(data_types);
            return;
        case HashKeyType::fixed136:
            _emplace_keys_fixed<UInt136>(data_types);
            return;
        case HashKeyType::fixed256:
            _emplace_keys_fixed<UInt256>(data_types);
            return;
        default:
            throw Exception(ErrorCode::INTERNAL_ERROR, "meet invalid agg key type, type={}", type);
        }
    }

private:
    /// Emplaces MethodKeysFixed<AggData<FixedKey>> built from the key sizes of `data_types`.
    template <typename FixedKey>
    void _emplace_keys_fixed(const std::vector<DataTypePtr>& data_types) {
        this->method_variant.template emplace<MethodKeysFixed<AggData<FixedKey>>>(
                get_key_sizes(data_types));
    }
};
/// Data variants for regular (non-bucketed) hash aggregation.
/// Single string keys use StringHashMap (AggregatedDataWithShortStringKey and
/// its nullable wrapper) via AggregatedMethodVariants.
struct AggregatedDataVariants
        : public AggDataVariantsBase<AggregatedMethodVariants, AggregatedDataWithShortStringKey,
                                     AggregatedDataWithNullableShortStringKey> {
    /// Aggregate-state pointer for the key-less case (presumably used with
    /// HashKeyType::without_key); only initialized to nullptr here.
    AggregatedDataWithoutKey without_key {nullptr};

    /// Forwards to the base's init_agg_data().
    void init(const std::vector<DataTypePtr>& data_types, HashKeyType type) {
        init_agg_data(data_types, type);
    }
};

using AggregatedDataVariantsUPtr = std::unique_ptr<AggregatedDataVariants>;
using ArenaUPtr = std::unique_ptr<Arena>;
/// Data variants for bucketed hash aggregation.
/// Single string keys use PHHashMap<StringRef> (AggregatedDataWithStringKey and
/// its nullable wrapper) via BucketedAggMethodVariants. This struct declares
/// no `without_key` member of its own.
struct BucketedAggDataVariants
        : public AggDataVariantsBase<BucketedAggMethodVariants, AggregatedDataWithStringKey,
                                     AggregatedDataWithNullableStringKey> {
    /// Forwards to the base's init_agg_data().
    void init(const std::vector<DataTypePtr>& data_types, HashKeyType type) {
        init_agg_data(data_types, type);
    }
};

using BucketedAggDataVariantsUPtr = std::unique_ptr<BucketedAggDataVariants>;
/// Append-only store of (key, aggregate state) pairs backed by an Arena.
///
/// Entries live in sub-containers of SUB_CONTAINER_CAPACITY slots each. For
/// sub-container i, `_key_containers[i]` holds the keys (`_size_of_key` bytes
/// each) and `_value_containers[i]` holds the aggregate states
/// (`_size_of_aggregate_states` bytes each). Entry n sits at slot
/// `n % SUB_CONTAINER_CAPACITY` of sub-container `n / SUB_CONTAINER_CAPACITY`.
/// All sub-container memory is allocated from the member `_arena_pool`.
struct AggregateDataContainer {
public:
    AggregateDataContainer(size_t size_of_key, size_t size_of_aggregate_states)
            : _size_of_key(size_of_key), _size_of_aggregate_states(size_of_aggregate_states) {}

    /// Memory currently reported by the backing arena (`_arena_pool.size()`).
    int64_t memory_usage() const { return _arena_pool.size(); }

    /// Copies `key` into the next free slot and returns the aggregate-state
    /// slot paired with it. The state bytes are not initialized here.
    /// `sizeof(KeyType)` must equal the `size_of_key` given to the constructor.
    /// Propagates any exception thrown while allocating a new sub-container.
    template <typename KeyType>
    AggregateDataPtr append_data(const KeyType& key) {
        DCHECK_EQ(sizeof(KeyType), _size_of_key);
        // The in-sub-container index is 0 before the first append (and after a
        // failed expansion), and equals SUB_CONTAINER_CAPACITY once the current
        // sub-container is full. Both cases need a fresh sub-container.
        if (UNLIKELY(_index_in_sub_container % SUB_CONTAINER_CAPACITY == 0)) {
            _expand();
        }

        *reinterpret_cast<KeyType*>(_current_keys) = key;
        auto* aggregate_data = _current_agg_data;
        ++_total_count;
        ++_index_in_sub_container;
        _current_agg_data += _size_of_aggregate_states;
        _current_keys += _size_of_key;
        return aggregate_data;
    }

    /// Forward iterator over stored entries in insertion order.
    /// `Derived` is the concrete iterator type returned by operator++ (CRTP).
    /// `IsConst` selects a const or mutable container pointer.
    template <typename Derived, bool IsConst>
    class IteratorBase {
        using Container =
                std::conditional_t<IsConst, const AggregateDataContainer, AggregateDataContainer>;
        // Value-initialized so a default-constructed iterator (e.g. the public
        // `iterator` member before init_once()) has a well-defined state.
        Container* container = nullptr;
        uint32_t index {};
        uint32_t sub_container_index {};
        uint32_t index_in_sub_container {};
        friend class HashTable;

    public:
        IteratorBase() = default;
        IteratorBase(Container* container_, uint32_t index_)
                : container(container_),
                  index(index_),
                  sub_container_index(index_ / SUB_CONTAINER_CAPACITY),
                  index_in_sub_container(index_ % SUB_CONTAINER_CAPACITY) {}

        // Iterators compare by global position only.
        bool operator==(const IteratorBase& rhs) const { return index == rhs.index; }
        bool operator!=(const IteratorBase& rhs) const { return index != rhs.index; }

        Derived& operator++() {
            index++;
            index_in_sub_container++;
            if (index_in_sub_container == SUB_CONTAINER_CAPACITY) {
                index_in_sub_container = 0;
                sub_container_index++;
            }
            return static_cast<Derived&>(*this);
        }

        /// Returns a copy of the current entry's key.
        /// `sizeof(KeyType)` must equal the container's key size.
        template <typename KeyType>
        KeyType get_key() {
            DCHECK_EQ(sizeof(KeyType), container->_size_of_key);
            return reinterpret_cast<KeyType*>(
                    container->_key_containers[sub_container_index])[index_in_sub_container];
        }

        /// Returns a pointer to the current entry's aggregate state.
        AggregateDataPtr get_aggregate_data() {
            return &(container->_value_containers[sub_container_index]
                                                 [container->_size_of_aggregate_states *
                                                  index_in_sub_container]);
        }
    };

    class Iterator : public IteratorBase<Iterator, false> {
    public:
        using IteratorBase<Iterator, false>::IteratorBase;
    };

    class ConstIterator : public IteratorBase<ConstIterator, true> {
    public:
        using IteratorBase<ConstIterator, true>::IteratorBase;
    };

    ConstIterator begin() const { return {this, 0}; }
    ConstIterator cbegin() const { return begin(); }
    Iterator begin() { return {this, 0}; }
    ConstIterator end() const { return {this, _total_count}; }
    ConstIterator cend() const { return end(); }
    Iterator end() { return {this, _total_count}; }

    [[nodiscard]] uint32_t total_count() const { return _total_count; }

    /// Estimates the bytes that appending `rows` more entries would allocate,
    /// counted in whole sub-containers (keys + aggregate states). Returns 0
    /// when the rows fit into the current, non-empty sub-container.
    size_t estimate_memory(size_t rows) const {
        bool need_to_expand = false;
        if (_total_count == 0) {
            need_to_expand = true;
        } else if ((_index_in_sub_container + rows) > SUB_CONTAINER_CAPACITY) {
            need_to_expand = true;
            // The remaining free slots of the current sub-container absorb part of `rows`.
            rows -= (SUB_CONTAINER_CAPACITY - _index_in_sub_container);
        }
        if (!need_to_expand) {
            return 0;
        }
        size_t count = (rows + SUB_CONTAINER_CAPACITY - 1) / SUB_CONTAINER_CAPACITY;
        size_t size = _size_of_key * SUB_CONTAINER_CAPACITY;
        size += _size_of_aggregate_states * SUB_CONTAINER_CAPACITY;
        size *= count;
        return size;
    }

    /// Positions the public `iterator` at begin() on the first call; later calls are no-ops.
    void init_once() {
        if (_inited) {
            return;
        }
        _inited = true;
        iterator = begin();
    }

    /// Cursor set by init_once(); presumably advanced by callers that consume
    /// the container incrementally. TODO confirm against callers.
    Iterator iterator;

private:
    /// Starts a new sub-container for both keys and aggregate states.
    ///
    /// Exception safety: if any allocation or push throws, `_key_containers`
    /// and `_value_containers` are left exactly as before (so they stay the
    /// same length), the current cursors are null, and `_index_in_sub_container`
    /// is 0. The next append_data() therefore retries the expansion.
    void _expand() {
        _index_in_sub_container = 0;
        _current_keys = nullptr;
        _current_agg_data = nullptr;

        // Allocate both buffers before touching the bookkeeping vectors.
        char* keys = _arena_pool.alloc(_size_of_key * SUB_CONTAINER_CAPACITY);
        auto values = static_cast<AggregateDataPtr>(
                _arena_pool.alloc(_size_of_aggregate_states * SUB_CONTAINER_CAPACITY));

        // emplace_back has the strong exception guarantee, so a throwing push
        // leaves its own vector unchanged. Only the key push needs rolling back
        // if the value push fails. Do not infer "was pushed" from a non-null
        // pointer: emplace_back can throw after the allocation above succeeded,
        // and popping then would discard an unrelated, valid sub-container.
        _key_containers.emplace_back(keys);
        try {
            _value_containers.emplace_back(values);
        } catch (...) {
            _key_containers.pop_back();
            throw;
        }

        _current_keys = keys;
        _current_agg_data = values;
    }

    static constexpr uint32_t SUB_CONTAINER_CAPACITY = 8192;
    Arena _arena_pool;
    std::vector<char*> _key_containers;
    std::vector<AggregateDataPtr> _value_containers;
    AggregateDataPtr _current_agg_data = nullptr;
    char* _current_keys = nullptr;
    size_t _size_of_key {};
    size_t _size_of_aggregate_states {};
    uint32_t _index_in_sub_container {};
    uint32_t _total_count {};
    bool _inited = false;
};
} // namespace doris