blob: 5a711c39272a46c67bf68acf081663e326e47dbb [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>
#include <boost/iterator/iterator_facade.hpp>
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "util/counts.h"
#include "util/tdigest.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/common/pod_array_fwd.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
class Arena;
class BufferReadable;
/// Validates a quantile argument of the percentile family of functions.
///
/// @param quantile requested quantile; must lie in the closed interval [0, 1].
/// @throws Exception with ErrorCode::INVALID_ARGUMENT when `quantile` is outside
///         [0, 1] or is NaN. NaN is tested explicitly because every ordered
///         comparison with NaN is false, so a plain range test would accept it.
inline void check_quantile(double quantile) {
    if (std::isnan(quantile) || quantile < 0 || quantile > 1) {
        throw Exception(ErrorCode::INVALID_ARGUMENT,
                        "quantile in func percentile should in [0, 1], but real data is:" +
                                std::to_string(quantile));
    }
}
/// Aggregation state of percentile_approx: a t-digest together with the target
/// quantile and the compression the digest was built with.
///
/// The state is uninitialized (init_flag == false) until init() is called, or
/// until an initialized state is merged or deserialized into it.
struct PercentileApproxState {
    /// Sentinel stored in target_quantile while no quantile has been configured.
    static constexpr double INIT_QUANTILE = -1.0;

    PercentileApproxState() = default;
    ~PercentileApproxState() = default;

    /// Configures the state on first use; every later call is a no-op.
    ///
    /// @param quantile    target quantile, validated by check_quantile().
    /// @param compression t-digest compression; values outside [2048, 10000]
    ///                    (and NaN) fall back to 10000.
    /// @throws Exception if `quantile` is invalid. Validation happens before any
    ///         member is modified, so a rejected call leaves the state untouched.
    void init(double quantile, float compression = 10000) {
        if (init_flag) {
            return;
        }
        check_quantile(quantile);
        //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
        //The compression parameter setting range is [2048, 10000].
        //If the value of compression parameter is not specified set, or is outside the range of [2048, 10000],
        //will use the default value of 10000
        // Written as a negated in-range test so that NaN also takes the fallback.
        if (!(compression >= 2048 && compression <= 10000)) {
            compression = 10000;
        }
        digest = TDigest::create_unique(compression);
        target_quantile = quantile;
        compressions = compression;
        init_flag = true;
    }

    /// Serializes init_flag and, only when initialized, target_quantile,
    /// compressions and the digest bytes (written as one binary string).
    void write(BufferWritable& buf) const {
        buf.write_binary(init_flag);
        if (!init_flag) {
            return;
        }
        buf.write_binary(target_quantile);
        buf.write_binary(compressions);
        // Checked before the first dereference of digest.
        DCHECK(digest.get() != nullptr);
        uint32_t serialize_size = digest->serialized_size();
        std::string result(serialize_size, '0');
        // Write through the mutable data() pointer: modifying the buffer
        // returned by c_str() is undefined behavior.
        digest->serialize(reinterpret_cast<uint8_t*>(result.data()));
        buf.write_binary(result);
    }

    /// Inverse of write(): restores the flag and, when set, rebuilds the digest
    /// with the serialized compression before loading its bytes.
    void read(BufferReadable& buf) {
        buf.read_binary(init_flag);
        if (!init_flag) {
            return;
        }
        buf.read_binary(target_quantile);
        buf.read_binary(compressions);
        std::string str;
        buf.read_binary(str);
        digest = TDigest::create_unique(compressions);
        digest->unserialize(reinterpret_cast<uint8_t*>(str.data()));
    }

    /// Returns the digest's estimate for target_quantile, or NaN when the state
    /// was never initialized.
    double get() const {
        if (init_flag) {
            return digest->quantile(static_cast<float>(target_quantile));
        } else {
            return std::nan("");
        }
    }

    /// Folds `rhs` into this state; an uninitialized `rhs` is ignored.
    void merge(const PercentileApproxState& rhs) {
        if (!rhs.init_flag) {
            return;
        }
        if (init_flag) {
            DCHECK(digest.get() != nullptr);
            digest->merge(rhs.digest.get());
        } else {
            // Adopt rhs's compression. Otherwise a state that only ever receives
            // merged data (never init()) would rebuild the digest, and later
            // serialize it, with its own stale/default compression.
            compressions = rhs.compressions;
            digest = TDigest::create_unique(compressions);
            digest->merge(rhs.digest.get());
            init_flag = true;
        }
        if (target_quantile == PercentileApproxState::INIT_QUANTILE) {
            target_quantile = rhs.target_quantile;
        }
    }

    /// Adds one value. Requires a non-null digest; every caller in this file
    /// calls init() first.
    void add(double source) { digest->add(static_cast<float>(source)); }

    /// Adds one weighted value; non-positive weights are silently skipped.
    void add_with_weight(double source, double weight) {
        // the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0);
        if (weight <= 0) {
            return;
        }
        digest->add(static_cast<float>(source), static_cast<float>(weight));
    }

    /// Returns to the uninitialized state. compressions is kept and a fresh
    /// empty digest is allocated with it.
    void reset() {
        target_quantile = INIT_QUANTILE;
        init_flag = false;
        digest = TDigest::create_unique(compressions);
    }

    // True once init(), merge() or read() has configured the state.
    bool init_flag = false;
    // The t-digest; null in a default-constructed state.
    std::unique_ptr<TDigest> digest;
    // Quantile reported by get(); INIT_QUANTILE until configured.
    double target_quantile = INIT_QUANTILE;
    // Compression used to (re)build `digest`.
    float compressions = 10000;
};
/// Common base for all percentile_approx overloads.
///
/// Implements the state plumbing shared by every arity (reset, merge,
/// serialize, deserialize) on top of PercentileApproxState. Subclasses supply
/// argument reading (add), the return type and result materialization.
class AggregateFunctionPercentileApprox
        : public IAggregateFunctionDataHelper<PercentileApproxState,
                                              AggregateFunctionPercentileApprox> {
public:
    AggregateFunctionPercentileApprox(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<PercentileApproxState,
                                           AggregateFunctionPercentileApprox>(argument_types_) {}

    String get_name() const override { return "percentile_approx"; }

    void reset(AggregateDataPtr __restrict place) const override {
        auto& state = AggregateFunctionPercentileApprox::data(place);
        state.reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        const auto& incoming = AggregateFunctionPercentileApprox::data(rhs);
        auto& state = AggregateFunctionPercentileApprox::data(place);
        state.merge(incoming);
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& state = AggregateFunctionPercentileApprox::data(place);
        state.write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& state = AggregateFunctionPercentileApprox::data(place);
        state.read(buf);
    }
};
/// percentile_approx(value, quantile): approximate quantile using the default
/// t-digest compression.
class AggregateFunctionPercentileApproxTwoParams final : public AggregateFunctionPercentileApprox,
                                                         public MultiExpression,
                                                         public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Configures the state from the quantile argument (first call only),
    /// then feeds this row's value into the digest.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& value_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        auto& state = this->data(place);
        // Only row 0 of the quantile column is read; presumably the argument is constant.
        state.init(quantile_column.get_element(0));
        state.add(value_column.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate; a NaN estimate (e.g. a state that never saw a row)
    /// is written as the column's default value instead.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        const double estimate = AggregateFunctionPercentileApprox::data(place).get();
        if (!std::isnan(estimate)) {
            col.get_data().push_back(estimate);
            return;
        }
        col.insert_default();
    }
};
/// percentile_approx(value, quantile, compression): approximate quantile with
/// a caller-chosen t-digest compression.
class AggregateFunctionPercentileApproxThreeParams final : public AggregateFunctionPercentileApprox,
                                                           public MultiExpression,
                                                           public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Configures the state from the quantile and compression arguments
    /// (first call only), then feeds this row's value into the digest.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& value_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& compression_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);
        auto& state = this->data(place);
        // Quantile and compression are taken from row 0 only; presumably constant arguments.
        const auto requested_compression = static_cast<float>(compression_column.get_element(0));
        state.init(quantile_column.get_element(0), requested_compression);
        state.add(value_column.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate; a NaN estimate (e.g. a state that never saw a row)
    /// is written as the column's default value instead.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        const double estimate = AggregateFunctionPercentileApprox::data(place).get();
        if (!std::isnan(estimate)) {
            col.get_data().push_back(estimate);
            return;
        }
        col.insert_default();
    }
};
/// Weighted percentile_approx(value, weight, quantile) using the default
/// t-digest compression.
class AggregateFunctionPercentileApproxWeightedThreeParams final
        : public AggregateFunctionPercentileApprox,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Configures the state from the quantile argument (first call only), then
    /// adds this row's value with its weight (non-positive weights are skipped
    /// by the state).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& value_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& weight_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& quantile_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);
        auto& state = this->data(place);
        // Only row 0 of the quantile column is read; presumably the argument is constant.
        state.init(quantile_column.get_element(0));
        const double value = value_column.get_element(row_num);
        const double row_weight = weight_column.get_element(row_num);
        state.add_with_weight(value, row_weight);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate; a NaN estimate (e.g. a state that never saw a row)
    /// is written as the column's default value instead.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        const double estimate = AggregateFunctionPercentileApprox::data(place).get();
        if (!std::isnan(estimate)) {
            col.get_data().push_back(estimate);
            return;
        }
        col.insert_default();
    }
};
/// Weighted percentile_approx(value, weight, quantile, compression) with a
/// caller-chosen t-digest compression.
class AggregateFunctionPercentileApproxWeightedFourParams final
        : public AggregateFunctionPercentileApprox,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedFourParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Configures the state from the quantile and compression arguments
    /// (first call only), then adds this row's value with its weight
    /// (non-positive weights are skipped by the state).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& value_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& weight_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& quantile_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);
        const auto& compression_column =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[3]);
        auto& state = this->data(place);
        // Quantile and compression are taken from row 0 only; presumably constant arguments.
        const auto requested_compression = static_cast<float>(compression_column.get_element(0));
        state.init(quantile_column.get_element(0), requested_compression);
        const double value = value_column.get_element(row_num);
        const double row_weight = weight_column.get_element(row_num);
        state.add_with_weight(value, row_weight);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate; a NaN estimate (e.g. a state that never saw a row)
    /// is written as the column's default value instead.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        const double estimate = AggregateFunctionPercentileApprox::data(place).get();
        if (!std::isnan(estimate)) {
            col.get_data().push_back(estimate);
            return;
        }
        col.insert_default();
    }
};
/// Exact-percentile state shared by `percentile` and `percentile_array`.
///
/// Keeps one Counts histogram per requested quantile (every histogram is fed
/// the same values) and, in parallel, the quantile each histogram reports.
template <PrimitiveType T>
struct PercentileState {
    // mutable: terminate() is invoked from const member functions; presumably
    // Counts::terminate() is non-const.
    mutable std::vector<Counts<typename PrimitiveTypeTraits<T>::ColumnItemType>> vec_counts;
    // Requested quantiles; -1 marks a slot not yet configured (see merge()).
    std::vector<double> vec_quantile {-1};
    // True once add(), add_batch(), merge() or read() has configured the state.
    bool inited_flag = false;

    /// Serializes inited_flag and, when set, the quantile count, each quantile,
    /// and each Counts histogram in order.
    void write(BufferWritable& buf) const {
        buf.write_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        int size_num = cast_set<int>(vec_quantile.size());
        buf.write_binary(size_num);
        for (const auto& quantile : vec_quantile) {
            buf.write_binary(quantile);
        }
        for (auto& counts : vec_counts) {
            counts.serialize(buf);
        }
    }

    /// Inverse of write().
    void read(BufferReadable& buf) {
        buf.read_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        int size_num = 0;
        buf.read_binary(size_num);
        double data = 0.0;
        vec_quantile.clear();
        for (int i = 0; i < size_num; ++i) {
            buf.read_binary(data);
            vec_quantile.emplace_back(data);
        }
        vec_counts.clear();
        vec_counts.resize(size_num);
        for (int i = 0; i < size_num; ++i) {
            vec_counts[i].unserialize(buf);
        }
    }

    /// Counts `source` into every histogram.
    ///
    /// On the first call, the state is configured from the first `arg_size`
    /// entries of `quantiles` / `null_maps`.
    /// @throws Exception if any of those quantiles is null or outside [0, 1]. All
    ///         of them are validated before the state is modified, so a failed
    ///         call leaves the state uninitialized.
    void add(typename PrimitiveTypeTraits<T>::ColumnItemType source,
             const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps, int64_t arg_size) {
        if (!inited_flag) {
            for (int i = 0; i < arg_size; ++i) {
                // throw Exception func call percentile_array(id, [1,0,null])
                if (null_maps[i]) {
                    throw Exception(ErrorCode::INVALID_ARGUMENT,
                                    "quantiles in func percentile_array should not have null");
                }
                check_quantile(quantiles[i]);
            }
            vec_counts.resize(arg_size);
            vec_quantile.resize(arg_size, -1);
            for (int i = 0; i < arg_size; ++i) {
                vec_quantile[i] = quantiles[i];
            }
            inited_flag = true;
        }
        // Iterate over the histograms actually owned. Indexing by a later row's
        // arg_size would run out of bounds if that row's array were longer.
        for (auto& counts : vec_counts) {
            counts.increment(source);
        }
    }

    /// Single-quantile fast path: counts a whole batch, configuring the state
    /// with quantile `q` on first use.
    /// @throws Exception if `q` is outside [0, 1] on the configuring call.
    void add_batch(const PaddedPODArray<typename PrimitiveTypeTraits<T>::ColumnItemType>& source,
                   const Float64& q) {
        if (!inited_flag) {
            inited_flag = true;
            vec_counts.resize(1);
            vec_quantile.resize(1);
            check_quantile(q);
            vec_quantile[0] = q;
        }
        vec_counts[0].increment_batch(source);
    }

    /// Merges `rhs` histogram-by-histogram; an uninitialized `rhs` is ignored.
    /// Both states are expected to track the same quantile list.
    void merge(const PercentileState& rhs) {
        if (!rhs.inited_flag) {
            return;
        }
        int size_num = cast_set<int>(rhs.vec_quantile.size());
        if (!inited_flag) {
            vec_counts.resize(size_num);
            vec_quantile.resize(size_num, -1);
            inited_flag = true;
        }
        // The loop below indexes this state's vectors with rhs's size.
        DCHECK_LE(rhs.vec_counts.size(), vec_counts.size());
        for (int i = 0; i < size_num; ++i) {
            if (vec_quantile[i] == -1.0) {
                vec_quantile[i] = rhs.vec_quantile[i];
            }
            vec_counts[i].merge(&(rhs.vec_counts[i]));
        }
    }

    /// Drops all histograms and quantiles and marks the state uninitialized.
    void reset() {
        vec_counts.clear();
        vec_quantile.clear();
        inited_flag = false;
    }

    /// Result for the first quantile, or 0 when there are no histograms.
    double get() const { return vec_counts.empty() ? 0 : vec_counts[0].terminate(vec_quantile[0]); }

    /// Appends one result per histogram, in quantile order, to `to`
    /// (must be a ColumnFloat64).
    void insert_result_into(IColumn& to) const {
        auto& column_data = assert_cast<ColumnFloat64&>(to).get_data();
        for (size_t i = 0; i < vec_counts.size(); ++i) {
            column_data.push_back(vec_counts[i].terminate(vec_quantile[i]));
        }
    }
};
/// percentile(value, quantile): exact percentile of a numeric column, computed
/// with a single Counts histogram in PercentileState.
template <PrimitiveType T>
class AggregateFunctionPercentile final
        : public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>,
          MultiExpression,
          NullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base = IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>;
    AggregateFunctionPercentile(const DataTypes& argument_types_) : Base(argument_types_) {}
    String get_name() const override { return "percentile"; }
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Counts one row. Only the first row seen by a state reads the quantile
    /// argument (row 0 of it); later rows only increment the histogram.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        auto& state = AggregateFunctionPercentile::data(place);
        if (state.inited_flag) {
            // Fast path, equivalent to PercentileState::add with arg_size == 1
            // on an initialized state. It avoids constructing a throw-away
            // NullMap (an allocating PaddedPODArray) for every row.
            state.vec_counts[0].increment(sources.get_data()[row_num]);
            return;
        }
        const auto& quantile =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        state.add(sources.get_data()[row_num], quantile.get_data(), NullMap(1, 0), 1);
    }

    /// Counts a whole batch into one state, using row 0 of the quantile column.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        DCHECK_EQ(sources.get_data().size(), batch_size);
        if (batch_size == 0) {
            // Nothing to count. Returning here also avoids reading row 0 of an
            // empty quantile column, which would otherwise configure the state.
            return;
        }
        AggregateFunctionPercentile::data(place).add_batch(sources.get_data(),
                                                           quantile.get_data()[0]);
    }

    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentile::data(place).reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentile::data(place).merge(AggregateFunctionPercentile::data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentile::data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentile::data(place).read(buf);
    }

    /// Appends the state's result (0 for a state with no histogram).
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        col.insert_value(AggregateFunctionPercentile::data(place).get());
    }
};
/// percentile_array(value, [q1, q2, ...]): exact percentiles for several
/// quantiles at once; returns Array(Nullable(Float64)).
template <PrimitiveType T>
class AggregateFunctionPercentileArray final
        : public IAggregateFunctionDataHelper<PercentileState<T>,
                                              AggregateFunctionPercentileArray<T>>,
          MultiExpression,
          NotNullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base =
            IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>;
    AggregateFunctionPercentileArray(const DataTypes& argument_types_) : Base(argument_types_) {}
    String get_name() const override { return "percentile_array"; }
    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
    }

    /// Counts one row against this row's quantile array.
    ///
    /// NOTE(review): PercentileState::add reads the quantiles and null flags
    /// starting at index 0 of the nested column, not at this row's offset. That
    /// is only correct when every row carries the same quantile array
    /// (presumably a constant argument) — confirm.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();
        const auto& nullable_quantiles =
                assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                        quantile_array.get_data());
        const auto& null_maps = nullable_quantiles.get_null_map_data();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(
                        nullable_quantiles.get_nested_column());
        // Length of this row's array. For row 0 this reads offsets[-1], relying
        // on the offsets array's left padding (expected to be 0).
        AggregateFunctionPercentileArray::data(place).add(
                sources.get_element(row_num), nested_column_data.get_data(), null_maps,
                offset_column_data[row_num] - offset_column_data[row_num - 1]);
    }

    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentileArray::data(place).reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentileArray::data(place).merge(
                AggregateFunctionPercentileArray::data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentileArray::data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentileArray::data(place).read(buf);
    }

    /// Appends one array holding one result per configured quantile.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast<ColumnArray&>(to);
        auto& to_nested_col = to_arr.get_data();
        const auto& state = AggregateFunctionPercentileArray::data(place);
        if (to_nested_col.is_nullable()) {
            // Checked downcast; reinterpret_cast is not a valid way to move down
            // a class hierarchy.
            auto& col_null = assert_cast<ColumnNullable&>(to_nested_col);
            state.insert_result_into(col_null.get_nested_column());
            // Results are never NULL: extend the null map with zeros.
            col_null.get_null_map_data().resize_fill(col_null.get_nested_column().size(), 0);
        } else {
            state.insert_result_into(to_nested_col);
        }
        to_arr.get_offsets().push_back(to_nested_col.size());
    }
};
#include "common/compile_check_end.h"
} // namespace doris::vectorized