blob: 6f9e8b34ecabe97b9bf80f22b0220b9ef8dc74fe [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <rapidjson/encodings.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/common/hash_table/phmap_fwd_decl.h"
#include "vec/common/string_ref.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_string.h"
#include "vec/io/io_helper.h"
namespace doris {
#include "common/compile_check_begin.h"
} // namespace doris
namespace doris::vectorized {
// Counter state shared by the topn aggregate functions (space-saving algorithm).
//
// Every distinct input value is mapped to an accumulated count in `counter_map`.
// `top_num` is the number of entries reported in the final result, and `capacity`
// (top_num * space_expand_rate) is the maximum number of entries emitted by write().
template <PrimitiveType T>
struct AggregateFunctionTopNData {
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using DataType = typename PrimitiveTypeTraits<T>::ColumnItemType;

    // Records the requested result size and derives the serialization capacity from it.
    // (The misspelled name is kept because callers use it.)
    void set_paramenters(int input_top_num, int space_expand_rate = 50) {
        top_num = input_top_num;
        capacity = static_cast<uint64_t>(top_num) * space_expand_rate;
    }

    // Adds `increment` to the counter of a string value, creating the counter if needed.
    void add(const StringRef& value, const UInt64& increment = 1) {
        // operator[] value-initializes a missing counter to 0, so a single hash lookup
        // covers both the "new key" and the "existing key" case.
        counter_map[value.to_string()] += increment;
    }

    // Adds `increment` to the counter of `value`, creating the counter if needed.
    void add(const DataType& value, const UInt64& increment = 1) {
        counter_map[value] += increment;
    }

    // Folds the counters of `rhs` into this state.
    //
    // A side is considered "full" once it holds at least `capacity` entries. The smallest
    // count of a full side is used as the offset for keys that are missing on that side:
    //   * keys only in lhs get rhs_min added,
    //   * keys only in rhs are inserted with lhs_min added,
    //   * keys on both sides end up with lhs_count + rhs_count.
    // A rhs whose top_num is 0 is ignored entirely.
    void merge(const AggregateFunctionTopNData& rhs) {
        if (!rhs.top_num) {
            return;
        }
        top_num = rhs.top_num;
        capacity = rhs.capacity;

        // Smallest counter in a map. An empty map yields 0 rather than UINT64_MAX so that
        // a "full but empty" side (capacity == 0) cannot wrap the other side's counters.
        const auto min_count = [](const auto& counters) -> uint64_t {
            if (counters.empty()) {
                return 0;
            }
            uint64_t smallest = UINT64_MAX;
            for (const auto& entry : counters) {
                smallest = std::min(smallest, entry.second);
            }
            return smallest;
        };

        const bool lhs_full = counter_map.size() >= capacity;
        const bool rhs_full = rhs.counter_map.size() >= capacity;

        uint64_t lhs_min = 0;
        uint64_t rhs_min = 0;
        if (lhs_full) {
            lhs_min = min_count(counter_map);
        }
        if (rhs_full) {
            rhs_min = min_count(rhs.counter_map);
            for (auto& entry : counter_map) {
                entry.second += rhs_min;
            }
        }

        for (const auto& [value, count] : rhs.counter_map) {
            // Insert with the lhs offset when the key is new; otherwise undo the rhs_min
            // that was already added to every lhs entry above.
            auto [lhs_it, inserted] = counter_map.try_emplace(value, count + lhs_min);
            if (!inserted) {
                lhs_it->second += count - rhs_min;
            }
        }
    }

    // Returns all (count, value) pairs sorted in descending order (by count, then by value).
    std::vector<std::pair<uint64_t, typename PrimitiveTypeTraits<T>::ColumnItemType>>
    get_remain_vector() const {
        std::vector<std::pair<uint64_t, typename PrimitiveTypeTraits<T>::ColumnItemType>>
                counter_vector;
        counter_vector.reserve(counter_map.size());
        for (const auto& [value, count] : counter_map) {
            counter_vector.emplace_back(count, value);
        }
        std::sort(counter_vector.begin(), counter_vector.end(), std::greater<> {});
        return counter_vector;
    }

    // Serializes the state as: top_num, capacity, element count, then `element count`
    // (value, count) pairs taken from the front of the descending-sorted vector.
    // At most `capacity` entries are written.
    void write(BufferWritable& buf) const {
        buf.write_binary(top_num);
        buf.write_binary(capacity);

        const uint64_t element_number =
                std::min(capacity, static_cast<uint64_t>(counter_map.size()));
        buf.write_binary(element_number);

        const auto counter_vector = get_remain_vector();
        // The index has the same type as element_number to avoid a signed/unsigned compare.
        for (uint64_t i = 0; i < element_number; ++i) {
            const auto& element = counter_vector[i];
            buf.write_binary(element.second);
            buf.write_binary(element.first);
        }
    }

    // Restores the state from the layout produced by write(), replacing any existing counters.
    void read(BufferReadable& buf) {
        buf.read_binary(top_num);
        buf.read_binary(capacity);

        uint64_t element_number = 0;
        buf.read_binary(element_number);

        counter_map.clear();
        std::pair<DataType, uint64_t> element;
        for (uint64_t i = 0; i < element_number; ++i) {
            buf.read_binary(element.first);
            buf.read_binary(element.second);
            counter_map.insert(element);
        }
    }

    // Renders the first `top_num` entries (descending) as a JSON object {value: count, ...}.
    // Only usable when DataType is a string type.
    std::string get() const {
        const auto counter_vector = get_remain_vector();
        const int limit = std::min(static_cast<int>(counter_vector.size()), top_num);

        rapidjson::StringBuffer buffer;
        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
        writer.StartObject();
        for (int i = 0; i < limit; ++i) {
            const auto& element = counter_vector[i];
            // Pass the explicit length so values containing '\0' are not truncated.
            writer.Key(element.second.data(),
                       static_cast<rapidjson::SizeType>(element.second.size()));
            writer.Uint64(element.first);
        }
        writer.EndObject();
        return buffer.GetString();
    }

    // Appends the values of the first `top_num` entries (descending) to column `to`.
    // `to` must be a ColumnString for string types and a ColVecType otherwise.
    void insert_result_into(IColumn& to) const {
        const auto counter_vector = get_remain_vector();
        const int limit = std::min(static_cast<int>(counter_vector.size()), top_num);
        for (int i = 0; i < limit; ++i) {
            const auto& element = counter_vector[i];
            if constexpr (is_string_type(T)) {
                assert_cast<ColumnString&, TypeCheckOnRelease::DISABLE>(to).insert_data(
                        element.second.data(), element.second.size());
            } else {
                assert_cast<ColVecType&, TypeCheckOnRelease::DISABLE>(to).get_data().push_back(
                        element.second);
            }
        }
    }

    // Drops all counters; top_num and capacity are left untouched.
    void reset() { counter_map.clear(); }

    int top_num = 0;
    uint64_t capacity = 0;
    flat_hash_map<DataType, uint64_t> counter_map;
};
// Row handler for topn(value, top_num) over a string column:
// columns[0] holds the string values and columns[1] the Int32 top_num.
struct AggregateFunctionTopNImplInt {
    using Data = AggregateFunctionTopNData<TYPE_STRING>;

    // Refreshes the parameters from the current row, then counts the row's value once.
    static void add(Data& __restrict place, const IColumn** columns, size_t row_num) {
        const auto* top_num_column =
                assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[1]);
        place.set_paramenters(top_num_column->get_element(row_num));

        const auto& value_column =
                assert_cast<const ColumnString&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        place.add(value_column.get_data_at(row_num));
    }
};
// Row handler for topn(value, top_num, space_expand_rate) over a string column:
// columns[0] holds the string values, columns[1] the Int32 top_num and
// columns[2] the Int32 space_expand_rate.
struct AggregateFunctionTopNImplIntInt {
    using Data = AggregateFunctionTopNData<TYPE_STRING>;

    // Refreshes both parameters from the current row, then counts the row's value once.
    static void add(Data& __restrict place, const IColumn** columns, size_t row_num) {
        place.set_paramenters(
                assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[1])
                        ->get_element(row_num),
                assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[2])
                        ->get_element(row_num));

        // NOTE(review): this cast keeps the default release-mode type check, unlike the
        // casts above and in AggregateFunctionTopNImplInt — confirm whether that is intended.
        const auto& value_column = assert_cast<const ColumnString&>(*columns[0]);
        place.add(value_column.get_data_at(row_num));
    }
};
// Row handler for the topn_array aggregate.
// columns[0] holds the values (string or fixed-width, depending on T), columns[1] the
// Int32 top_num and, when has_default_param is true, columns[2] the Int32 space_expand_rate.
template <PrimitiveType T, bool has_default_param>
struct AggregateFunctionTopNImplArray {
    using Data = AggregateFunctionTopNData<T>;
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;

    static String get_name() { return "topn_array"; }

    // Refreshes the parameters from the current row, then counts the row's value once.
    static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
                    size_t row_num) {
        const auto top_num =
                assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[1])
                        ->get_element(row_num);
        if constexpr (has_default_param) {
            const auto expand_rate =
                    assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[2])
                            ->get_element(row_num);
            place.set_paramenters(top_num, expand_rate);
        } else {
            place.set_paramenters(top_num);
        }

        if constexpr (is_string_type(T)) {
            const auto& string_column =
                    assert_cast<const ColumnString&, TypeCheckOnRelease::DISABLE>(*columns[0]);
            place.add(string_column.get_data_at(row_num));
        } else {
            const auto& values =
                    assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
                            .get_data();
            typename PrimitiveTypeTraits<T>::ColumnItemType current = values[row_num];
            place.add(current);
        }
    }
};
// Row handler for the topn_weighted aggregate.
// columns[0] holds the values (string or fixed-width, depending on T), columns[1] the Int64
// weight, columns[2] the Int32 top_num and, when has_default_param is true, columns[3] the
// Int32 space_expand_rate.
template <PrimitiveType T, bool has_default_param>
struct AggregateFunctionTopNImplWeight {
    using Data = AggregateFunctionTopNData<T>;
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;

    static String get_name() { return "topn_weighted"; }

    // Refreshes the parameters from the current row, then adds the row's weight to the
    // counter of the row's value.
    static void add(AggregateFunctionTopNData<T>& __restrict place, const IColumn** columns,
                    size_t row_num) {
        if constexpr (has_default_param) {
            const auto top_num =
                    assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[2])
                            ->get_element(row_num);
            const auto expand_rate =
                    assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[3])
                            ->get_element(row_num);
            place.set_paramenters(top_num, expand_rate);
        } else {
            // NOTE(review): this cast keeps the default release-mode type check, unlike the
            // other casts in this function — confirm whether that is intended.
            const auto* top_num_column = assert_cast<const ColumnInt32*>(columns[2]);
            place.set_paramenters(top_num_column->get_element(row_num));
        }

        // The weight is read the same way for string and non-string values.
        const auto weight =
                assert_cast<const ColumnInt64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_data()[row_num];

        if constexpr (is_string_type(T)) {
            const auto& string_column =
                    assert_cast<const ColumnString&, TypeCheckOnRelease::DISABLE>(*columns[0]);
            place.add(string_column.get_data_at(row_num), weight);
        } else {
            const auto& values =
                    assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
                            .get_data();
            typename PrimitiveTypeTraits<T>::ColumnItemType current = values[row_num];
            place.add(current, weight);
        }
    }
};
// Common base of the topn aggregate functions.
// It forwards each aggregate-function hook to the state object (Impl::Data) stored at
// `place`; per-row accumulation is delegated to Impl::add.
template <typename Impl>
class AggregateFunctionTopNBase
        : public IAggregateFunctionDataHelper<typename Impl::Data,
                                              AggregateFunctionTopNBase<Impl>> {
public:
    AggregateFunctionTopNBase(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<typename Impl::Data, AggregateFunctionTopNBase<Impl>>(
                      argument_types_) {}

    // Accumulates row `row_num` of `columns` into the state at `place`.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        auto& state = this->data(place);
        Impl::add(state, columns, row_num);
    }

    // Clears the counters of the state at `place`.
    void reset(AggregateDataPtr __restrict place) const override {
        auto& state = this->data(place);
        state.reset();
    }

    // Merges the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        const auto& other = this->data(rhs);
        this->data(place).merge(other);
    }

    // Writes the state at `place` into `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& state = this->data(place);
        state.write(buf);
    }

    // Replaces the state at `place` with the one read from `buf`.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& state = this->data(place);
        state.read(buf);
    }
};
// topn aggregate whose result is a single string: the JSON text produced by the
// state's get() method.
template <typename Impl>
class AggregateFunctionTopN final : public AggregateFunctionTopNBase<Impl>,
                                    MultiExpression,
                                    NullableAggregateFunction {
public:
    AggregateFunctionTopN(const DataTypes& argument_types_)
            : AggregateFunctionTopNBase<Impl>(argument_types_) {}

    String get_name() const override { return "topn"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); }

    // Appends the JSON rendering of the state at `place` as one row of the string column `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        const std::string json = this->data(place).get();
        auto& string_column = assert_cast<ColumnString&>(to);
        string_column.insert_data(json.data(), json.size());
    }
};
// topn aggregate whose result is an array of the top values. The element type is the
// nullable form of the first argument's type.
template <typename Impl>
class AggregateFunctionTopNArray final : public AggregateFunctionTopNBase<Impl>,
                                         MultiExpression,
                                         NullableAggregateFunction {
public:
    AggregateFunctionTopNArray(const DataTypes& argument_types_)
            : AggregateFunctionTopNBase<Impl>(argument_types_),
              _argument_type(argument_types_[0]) {}

    String get_name() const override { return Impl::get_name(); }

    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(make_nullable(_argument_type));
    }

    // Appends the top values of the state at `place` as one array row of `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& array_column = assert_cast<ColumnArray&>(to);
        auto& nested = array_column.get_data();
        const auto& state = this->data(place);

        if (!nested.is_nullable()) {
            state.insert_result_into(nested);
        } else {
            auto* nullable = assert_cast<ColumnNullable*>(&nested);
            state.insert_result_into(nullable->get_nested_column());
            // Grow the null map to the new value count, marking the added entries as non-null.
            nullable->get_null_map_data().resize_fill(nullable->get_nested_column().size(), 0);
        }

        // The new offset closes the array row that was just appended.
        array_column.get_offsets().push_back(nested.size());
    }

private:
    // Type of the first argument; its nullable form is the array element type.
    DataTypePtr _argument_type;
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"