| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #pragma once |
| |
| #include <map> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "common/exception.h" |
| #include "common/status.h" |
| #include "vec/aggregate_functions/aggregate_function.h" |
| #include "vec/aggregate_functions/aggregate_function_simple_factory.h" |
| #include "vec/columns/column.h" |
| #include "vec/columns/column_decimal.h" |
| #include "vec/columns/column_string.h" |
| #include "vec/common/assert_cast.h" |
| #include "vec/common/string_ref.h" |
| #include "vec/core/types.h" |
| #include "vec/data_types/data_type_string.h" |
| #include "vec/io/io_helper.h" |
| #include "vec/utils/histogram_helpers.hpp" |
| |
| namespace doris { |
| #include "common/compile_check_begin.h" |
| } // namespace doris |
| |
| namespace doris::vectorized { |
| |
| template <PrimitiveType T> |
| struct AggregateFunctionHistogramData { |
| static constexpr auto Ptype = T; |
| using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType; |
| const static size_t DEFAULT_BUCKET_NUM = 128; |
| const static size_t BUCKET_NUM_INIT_VALUE = 0; |
| |
| void set_parameters(size_t input_max_num_buckets) { max_num_buckets = input_max_num_buckets; } |
| |
| void reset() { ordered_map.clear(); } |
| |
| void add(const StringRef& value, const UInt64& number = 1) { |
| std::string data = value.to_string(); |
| auto it = ordered_map.find(data); |
| if (it != ordered_map.end()) { |
| it->second = it->second + number; |
| } else { |
| ordered_map.insert({data, number}); |
| } |
| } |
| |
| void add(const typename PrimitiveTypeTraits<T>::ColumnItemType& value, |
| const UInt64& number = 1) { |
| auto it = ordered_map.find(value); |
| if (it != ordered_map.end()) { |
| it->second = it->second + number; |
| } else { |
| ordered_map.insert({value, number}); |
| } |
| } |
| |
| void merge(const AggregateFunctionHistogramData& rhs) { |
| // if rhs.max_num_buckets == 0, it means the input block for serialization is all null |
| // we should discard this data, because histogram only fouce on the not-null data |
| if (!rhs.max_num_buckets) { |
| return; |
| } |
| |
| max_num_buckets = rhs.max_num_buckets; |
| |
| for (auto rhs_it : rhs.ordered_map) { |
| auto lhs_it = ordered_map.find(rhs_it.first); |
| if (lhs_it != ordered_map.end()) { |
| lhs_it->second += rhs_it.second; |
| } else { |
| ordered_map.insert({rhs_it.first, rhs_it.second}); |
| } |
| } |
| } |
| |
| void write(BufferWritable& buf) const { |
| buf.write_binary(max_num_buckets); |
| auto element_number = (size_t)ordered_map.size(); |
| buf.write_binary(element_number); |
| |
| auto pair_vector = map_to_vector(); |
| |
| for (auto i = 0; i < element_number; i++) { |
| auto element = pair_vector[i]; |
| buf.write_binary(element.second); |
| buf.write_binary(element.first); |
| } |
| } |
| |
| void read(BufferReadable& buf) { |
| buf.read_binary(max_num_buckets); |
| |
| size_t element_number = 0; |
| buf.read_binary(element_number); |
| |
| ordered_map.clear(); |
| std::pair<typename PrimitiveTypeTraits<T>::ColumnItemType, size_t> element; |
| for (auto i = 0; i < element_number; i++) { |
| buf.read_binary(element.first); |
| buf.read_binary(element.second); |
| ordered_map.insert(element); |
| } |
| } |
| |
| void insert_result_into(IColumn& to) const { |
| auto pair_vector = map_to_vector(); |
| for (auto i = 0; i < pair_vector.size(); i++) { |
| const auto& element = pair_vector[i]; |
| if constexpr (is_string_type(T)) { |
| assert_cast<ColumnString&>(to).insert_data(element.second.c_str(), |
| element.second.length()); |
| } else { |
| assert_cast<ColVecType&>(to).get_data().push_back(element.second); |
| } |
| } |
| } |
| |
| std::string get(const DataTypePtr& data_type) const { |
| std::vector<Bucket<typename PrimitiveTypeTraits<T>::ColumnItemType>> buckets; |
| rapidjson::StringBuffer buffer; |
| // NOTE: We need an extral branch for to handle max_num_buckets == 0, |
| // when target column is nullable, and input block is all null, |
| // set_parameters will not be called because of the short-circuit in |
| // AggregateFunctionNullVariadicInline, so max_num_buckets will be 0 in this situation. |
| build_histogram( |
| buckets, ordered_map, |
| max_num_buckets == BUCKET_NUM_INIT_VALUE ? DEFAULT_BUCKET_NUM : max_num_buckets); |
| histogram_to_json(buffer, buckets, data_type); |
| return {buffer.GetString()}; |
| } |
| |
| std::vector<std::pair<size_t, typename PrimitiveTypeTraits<T>::ColumnItemType>> map_to_vector() |
| const { |
| std::vector<std::pair<size_t, typename PrimitiveTypeTraits<T>::ColumnItemType>> pair_vector; |
| for (auto it : ordered_map) { |
| pair_vector.emplace_back(it.second, it.first); |
| } |
| return pair_vector; |
| } |
| |
| private: |
| size_t max_num_buckets = BUCKET_NUM_INIT_VALUE; |
| std::map<typename PrimitiveTypeTraits<T>::ColumnItemType, size_t> ordered_map; |
| }; |
| |
| template <typename Data, bool has_input_param> |
| class AggregateFunctionHistogram final |
| : public IAggregateFunctionDataHelper<Data, |
| AggregateFunctionHistogram<Data, has_input_param>>, |
| VarargsExpression, |
| NotNullableAggregateFunction { |
| public: |
| AggregateFunctionHistogram() = default; |
| AggregateFunctionHistogram(const DataTypes& argument_types_) |
| : IAggregateFunctionDataHelper<Data, AggregateFunctionHistogram<Data, has_input_param>>( |
| argument_types_), |
| _argument_type(argument_types_[0]) {} |
| |
| std::string get_name() const override { return "histogram"; } |
| |
| DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); } |
| |
| void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, |
| Arena&) const override { |
| if constexpr (has_input_param) { |
| Int32 input_max_num_buckets = |
| assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num); |
| if (input_max_num_buckets <= 0 || input_max_num_buckets > 1000000) { |
| throw doris::Exception( |
| ErrorCode::INVALID_ARGUMENT, |
| "Invalid max_num_buckets {}, row_num {}, should be in (0, 1000000]", |
| input_max_num_buckets, row_num); |
| } |
| this->data(place).set_parameters(input_max_num_buckets); |
| } else { |
| this->data(place).set_parameters(Data::DEFAULT_BUCKET_NUM); |
| } |
| |
| if constexpr (is_string_type(Data::Ptype)) { |
| this->data(place).add( |
| assert_cast<const ColumnString&, TypeCheckOnRelease::DISABLE>(*columns[0]) |
| .get_data_at(row_num)); |
| } else { |
| this->data(place).add( |
| assert_cast<const typename Data::ColVecType&, TypeCheckOnRelease::DISABLE>( |
| *columns[0]) |
| .get_data()[row_num]); |
| } |
| } |
| |
| void reset(AggregateDataPtr place) const override { this->data(place).reset(); } |
| |
| void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, |
| Arena&) const override { |
| this->data(place).merge(this->data(rhs)); |
| } |
| |
| void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { |
| this->data(place).write(buf); |
| } |
| |
| void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, |
| Arena&) const override { |
| this->data(place).read(buf); |
| } |
| |
| void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { |
| const std::string bucket_json = this->data(place).get(_argument_type); |
| assert_cast<ColumnString&>(to).insert_data(bucket_json.c_str(), bucket_json.length()); |
| } |
| |
| private: |
| DataTypePtr _argument_type; |
| }; |
| |
| } // namespace doris::vectorized |
| |
| #include "common/compile_check_end.h" |