blob: 007c236da73d0fbf2d8f7ed1808948b9251e46e6 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>
#include <boost/iterator/iterator_facade.hpp>
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "core/assert_cast.h"
#include "core/column/column.h"
#include "core/column/column_array.h"
#include "core/column/column_nullable.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_array.h"
#include "core/data_type/data_type_nullable.h"
#include "core/data_type/data_type_number.h"
#include "core/pod_array.h"
#include "core/pod_array_fwd.h"
#include "core/types.h"
#include "exprs/aggregate/aggregate_function.h"
#include "util/percentile_util.h"
#include "util/tdigest.h"
namespace doris {
class Arena;
class BufferReadable;
// Aggregation state for percentile_approx: a lazily created t-digest plus the target
// quantile and compression it was configured with.
struct PercentileApproxState {
    // Sentinel meaning "no target quantile recorded yet" (see merge()).
    static constexpr double INIT_QUANTILE = -1.0;

    PercentileApproxState() = default;
    ~PercentileApproxState() = default;

    // Initializes the state on the first call; later calls are no-ops, so the first
    // row's quantile/compression win.
    // @param quantile    target quantile, validated by check_quantile().
    // @param compression t-digest compression; clamped to the default when out of range.
    void init(double quantile, float compression = 10000) {
        if (init_flag) {
            return;
        }
        //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
        // The compression parameter setting range is [2048, 10000].
        // If the compression parameter is not specified, or is outside the range of
        // [2048, 10000], the default value of 10000 is used.
        // The negated range test also rejects NaN, which compares false against both
        // bounds and would otherwise reach TDigest unchanged.
        if (!(compression >= 2048 && compression <= 10000)) {
            compression = 10000;
        }
        // Validate before allocating so a rejected quantile leaves no digest behind.
        check_quantile(quantile);
        digest = TDigest::create_unique(compression);
        target_quantile = quantile;
        compressions = compression;
        init_flag = true;
    }

    // Serialized layout: init_flag, then (only if initialized) target_quantile,
    // compressions and the digest bytes as a length-prefixed string.
    void write(BufferWritable& buf) const {
        buf.write_binary(init_flag);
        if (!init_flag) {
            return;
        }
        DCHECK(digest.get() != nullptr);
        buf.write_binary(target_quantile);
        buf.write_binary(compressions);
        uint32_t serialize_size = digest->serialized_size();
        std::string result(serialize_size, '0');
        // std::string::data() is writable since C++17; writing through c_str() is not allowed.
        digest->serialize(reinterpret_cast<uint8_t*>(result.data()));
        buf.write_binary(result);
    }

    // Inverse of write(); rebuilds the digest with the serialized compression.
    void read(BufferReadable& buf) {
        buf.read_binary(init_flag);
        if (!init_flag) {
            return;
        }
        buf.read_binary(target_quantile);
        buf.read_binary(compressions);
        std::string str;
        buf.read_binary(str);
        digest = TDigest::create_unique(compressions);
        digest->unserialize(reinterpret_cast<uint8_t*>(str.data()));
    }

    // Returns the estimated quantile, or NaN when no row was ever added.
    double get() const {
        if (init_flag) {
            return digest->quantile(static_cast<float>(target_quantile));
        } else {
            return std::nan("");
        }
    }

    // Folds rhs into this state. An uninitialized state adopts rhs's configuration
    // (compression and, if unset, target quantile), so it re-serializes the same settings.
    void merge(const PercentileApproxState& rhs) {
        if (!rhs.init_flag) {
            return;
        }
        if (init_flag) {
            DCHECK(digest.get() != nullptr);
            digest->merge(rhs.digest.get());
        } else {
            compressions = rhs.compressions;
            digest = TDigest::create_unique(compressions);
            digest->merge(rhs.digest.get());
            init_flag = true;
        }
        if (target_quantile == PercentileApproxState::INIT_QUANTILE) {
            target_quantile = rhs.target_quantile;
        }
    }

    // Callers must have called init() first; digest is dereferenced unconditionally.
    void add(double source) { digest->add(static_cast<float>(source)); }

    void add_with_weight(double source, double weight) {
        // the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0);
        if (weight <= 0) {
            return;
        }
        digest->add(static_cast<float>(source), static_cast<float>(weight));
    }

    void reset() {
        target_quantile = INIT_QUANTILE;
        init_flag = false;
        digest = TDigest::create_unique(compressions);
    }

    bool init_flag = false;
    std::unique_ptr<TDigest> digest;
    double target_quantile = INIT_QUANTILE;
    float compressions = 10000;
};
// CRTP base shared by every percentile_approx variant: it owns the state plumbing
// (reset / merge / serialize / deserialize) while Derived supplies add() and the result.
template <typename Derived>
class AggregateFunctionPercentileApproxBase
        : public IAggregateFunctionDataHelper<PercentileApproxState, Derived> {
    using Helper = IAggregateFunctionDataHelper<PercentileApproxState, Derived>;

public:
    AggregateFunctionPercentileApproxBase(const DataTypes& argument_types_)
            : Helper(argument_types_) {}

    String get_name() const override { return "percentile_approx"; }

    void reset(AggregateDataPtr __restrict place) const override {
        auto& state = this->data(place);
        state.reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        auto& dst = this->data(place);
        const auto& src = this->data(rhs);
        dst.merge(src);
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& state = this->data(place);
        state.write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& state = this->data(place);
        state.read(buf);
    }
};
// percentile_approx(value, quantile): default compression, unweighted.
class AggregateFunctionPercentileApproxTwoParams final
        : public AggregateFunctionPercentileApproxBase<AggregateFunctionPercentileApproxTwoParams>,
          public MultiExpression,
          public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApproxBase<AggregateFunctionPercentileApproxTwoParams>(
                      argument_types_) {}

    // The quantile is read from row 0; the state only consumes it on the first row.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const double value =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_element(row_num);
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_element(0);
        auto& state = this->data(place);
        state.init(level);
        state.add(value);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // A NaN result (state never initialized) becomes the column's default value.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        if (const double result = this->data(place).get(); !std::isnan(result)) {
            col.get_data().push_back(result);
        } else {
            col.insert_default();
        }
    }
};
// percentile_approx(value, quantile, compression): unweighted with explicit compression.
class AggregateFunctionPercentileApproxThreeParams final
        : public AggregateFunctionPercentileApproxBase<
                  AggregateFunctionPercentileApproxThreeParams>,
          public MultiExpression,
          public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApproxBase<AggregateFunctionPercentileApproxThreeParams>(
                      argument_types_) {}

    // Quantile and compression come from row 0 and are only used on the first row.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const double value =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_element(row_num);
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_element(0);
        const auto compression_value = static_cast<float>(
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
                        .get_element(0));
        auto& state = this->data(place);
        state.init(level, compression_value);
        state.add(value);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // A NaN result (state never initialized) becomes the column's default value.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        if (const double result = this->data(place).get(); !std::isnan(result)) {
            col.get_data().push_back(result);
        } else {
            col.insert_default();
        }
    }
};
// percentile_approx_weighted(value, weight, quantile): weighted, default compression.
class AggregateFunctionPercentileApproxWeightedThreeParams final
        : public AggregateFunctionPercentileApproxBase<
                  AggregateFunctionPercentileApproxWeightedThreeParams>,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApproxBase<
                      AggregateFunctionPercentileApproxWeightedThreeParams>(argument_types_) {}

    // Non-positive weights are dropped by PercentileApproxState::add_with_weight.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const double value =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_element(row_num);
        const double row_weight =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_element(row_num);
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
                        .get_element(0);
        auto& state = this->data(place);
        state.init(level);
        state.add_with_weight(value, row_weight);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // A NaN result (state never initialized) becomes the column's default value.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        if (const double result = this->data(place).get(); !std::isnan(result)) {
            col.get_data().push_back(result);
        } else {
            col.insert_default();
        }
    }
};
// percentile_approx_weighted(value, weight, quantile, compression).
class AggregateFunctionPercentileApproxWeightedFourParams final
        : public AggregateFunctionPercentileApproxBase<
                  AggregateFunctionPercentileApproxWeightedFourParams>,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedFourParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApproxBase<
                      AggregateFunctionPercentileApproxWeightedFourParams>(argument_types_) {}

    // Quantile/compression are taken from row 0; value and weight from row_num.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const double value =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_element(row_num);
        const double row_weight =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_element(row_num);
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
                        .get_element(0);
        const auto compression_value = static_cast<float>(
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[3])
                        .get_element(0));
        auto& state = this->data(place);
        state.init(level, compression_value);
        state.add_with_weight(value, row_weight);
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // A NaN result (state never initialized) becomes the column's default value.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        if (const double result = this->data(place).get(); !std::isnan(result)) {
            col.get_data().push_back(result);
        } else {
            col.insert_default();
        }
    }
};
// State for percentile / percentile_array: one Counts histogram per requested quantile,
// plus the quantile values themselves (-1 marks "not yet known").
template <PrimitiveType T>
struct PercentileState {
    mutable std::vector<Counts<typename PrimitiveTypeTraits<T>::CppType>> vec_counts;
    std::vector<double> vec_quantile {-1};
    bool inited_flag = false;

    // Layout: inited_flag, then (if set) an int count, the quantiles, and each Counts.
    void write(BufferWritable& buf) const {
        buf.write_binary(inited_flag);
        if (inited_flag) {
            const int quantile_count = cast_set<int>(vec_quantile.size());
            buf.write_binary(quantile_count);
            for (const double& level : vec_quantile) {
                buf.write_binary(level);
            }
            for (auto& histogram : vec_counts) {
                histogram.serialize(buf);
            }
        }
    }

    // Inverse of write().
    void read(BufferReadable& buf) {
        buf.read_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        int quantile_count = 0;
        buf.read_binary(quantile_count);
        vec_quantile.clear();
        for (int idx = 0; idx < quantile_count; ++idx) {
            double level = 0.0;
            buf.read_binary(level);
            vec_quantile.emplace_back(level);
        }
        vec_counts.clear();
        vec_counts.resize(quantile_count);
        for (auto& histogram : vec_counts) {
            histogram.unserialize(buf);
        }
    }

    // Adds one value to every histogram. On the first call, records arg_size quantiles
    // from quantiles[0, arg_size), rejecting nulls and out-of-range values.
    void add(typename PrimitiveTypeTraits<T>::CppType source,
             const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps, int64_t arg_size) {
        if (!inited_flag) {
            vec_counts.resize(arg_size);
            vec_quantile.resize(arg_size, -1);
            inited_flag = true;
            for (int idx = 0; idx < arg_size; ++idx) {
                // throw Exception func call percentile_array(id, [1,0,null])
                if (null_maps[idx]) {
                    throw Exception(ErrorCode::INVALID_ARGUMENT,
                                    "quantiles in func percentile_array should not have null");
                }
                const double level = quantiles[idx];
                check_quantile(level);
                vec_quantile[idx] = level;
            }
        }
        for (int idx = 0; idx < arg_size; ++idx) {
            vec_counts[idx].increment(source);
        }
    }

    // Single-quantile fast path: bulk-adds all of source into histogram 0.
    void add_batch(const PaddedPODArray<typename PrimitiveTypeTraits<T>::CppType>& source,
                   const Float64& q) {
        if (!inited_flag) {
            inited_flag = true;
            vec_counts.resize(1);
            vec_quantile.resize(1);
            check_quantile(q);
            vec_quantile.front() = q;
        }
        vec_counts.front().increment_batch(source);
    }

    // Merges rhs histogram-by-histogram, filling in any still-unknown quantiles.
    void merge(const PercentileState& rhs) {
        if (!rhs.inited_flag) {
            return;
        }
        const int quantile_count = cast_set<int>(rhs.vec_quantile.size());
        if (!inited_flag) {
            vec_counts.resize(quantile_count);
            vec_quantile.resize(quantile_count, -1);
            inited_flag = true;
        }
        for (int idx = 0; idx < quantile_count; ++idx) {
            double& level = vec_quantile[idx];
            if (level == -1.0) {
                level = rhs.vec_quantile[idx];
            }
            vec_counts[idx].merge(&rhs.vec_counts[idx]);
        }
    }

    void reset() {
        vec_counts.clear();
        vec_quantile.clear();
        inited_flag = false;
    }

    // Result for the first quantile, or 0 when nothing was accumulated.
    double get() const {
        if (vec_counts.empty()) {
            return 0;
        }
        return vec_counts[0].terminate(vec_quantile[0]);
    }

    // Appends one result per histogram to `to` (expected to be a ColumnFloat64).
    void insert_result_into(IColumn& to) const {
        auto& out = assert_cast<ColumnFloat64&>(to).get_data();
        for (size_t idx = 0; idx < vec_counts.size(); ++idx) {
            out.push_back(vec_counts[idx].terminate(vec_quantile[idx]));
        }
    }
};
// Exact-percentile state (percentile_v2 / percentile_array_v2): buffers every input value
// and, at finalize time, selects order statistics with std::nth_element and linearly
// interpolates between neighbouring values.
template <PrimitiveType T>
struct PercentileExactState {
    using ValueType = typename PrimitiveTypeTraits<T>::CppType;
    // Inline capacity for the value buffer; presumably sized so the array object plus its
    // stack storage stays around 64 bytes — TODO confirm against PODArrayWithStackMemory.
    static constexpr size_t bytes_in_arena = 64 - sizeof(PODArray<ValueType>);
    using Array = PODArrayWithStackMemory<ValueType, bytes_in_arena>;

    // Appends `count` values. On the first call, records `quantile` as the single level.
    void add_single_range(const ValueType* data, size_t count, double quantile) {
        if (!inited_flag) {
            _set_single_level(quantile);
            inited_flag = true;
        }
        _append(data, count);
    }

    // Appends `count` values. On the first call, records the levels
    // quantiles_data[start, start + arg_size). That call throws INVALID_ARGUMENT if any
    // level is null. With no levels (empty quantile array), values are not buffered at all.
    void add_many_range(const ValueType* data, size_t count,
                        const PaddedPODArray<Float64>& quantiles_data, const NullMap& null_maps,
                        size_t start, int64_t arg_size) {
        if (!inited_flag) {
            _set_many_levels(quantiles_data, null_maps, start, arg_size);
            inited_flag = true;
        }
        if (levels.empty()) {
            return;
        }
        _append(data, count);
    }

    // Layout: inited_flag, then (if set) levels, a size_t value count, and the raw value bytes.
    void write(BufferWritable& buf) const {
        buf.write_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        levels.write(buf);
        size_t size = values.size();
        buf.write_binary(size);
        if (size > 0) {
            buf.write(reinterpret_cast<const char*>(values.data()), sizeof(ValueType) * size);
        }
    }

    // Inverse of write(). Resets first so no previously buffered data survives.
    void read(BufferReadable& buf) {
        reset();
        buf.read_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        levels.read(buf);
        size_t size = 0;
        buf.read_binary(size);
        values.resize(size);
        if (size > 0) {
            auto raw = buf.read(sizeof(ValueType) * size);
            memcpy(values.data(), raw.data, raw.size);
        }
    }

    // Appends rhs's buffered values. An uninitialized state copies rhs's levels; otherwise
    // PercentileLevels::merge decides how levels combine (not visible here).
    void merge(const PercentileExactState& rhs) {
        if (!rhs.inited_flag) {
            return;
        }
        if (!inited_flag) {
            levels = rhs.levels;
            inited_flag = true;
        } else {
            levels.merge(rhs.levels);
        }
        _append(rhs.values.data(), rhs.values.size());
    }

    void reset() {
        values.clear();
        levels.clear();
        inited_flag = false;
    }

    // Single-level result; 0.0 when there is nothing to compute. Expects exactly one level.
    double get() const {
        if (!inited_flag || levels.empty() || values.empty()) {
            return 0.0;
        }
        DCHECK_EQ(levels.quantiles.size(), 1);
        return _get_result(levels.quantiles[0]);
    }

    // Appends one double per level to `to`, in the levels' original order. Appends nothing
    // when there are no levels or no values. Reorders `values` in place (it is mutable).
    void insert_result_into(IColumn& to) const {
        auto& column_data = assert_cast<ColumnFloat64&>(to).get_data();
        if (!inited_flag || levels.empty() || values.empty()) {
            return;
        }
        size_t old_size = column_data.size();
        size_t size = levels.quantiles.size();
        column_data.resize(old_size + size);
        auto* result = column_data.data() + old_size;
        // A single value is every quantile.
        if (values.size() == 1) {
            for (size_t i = 0; i < size; ++i) {
                result[i] = static_cast<double>(values.front());
            }
            return;
        }
        // Levels are visited in permutation order. The assumption (TODO confirm in
        // PercentileLevels::get_permutation) is that this order is ascending. Then each
        // nth_element only needs to partition [prev_index, end): the previous call already
        // left every element before prev_index no larger than values[prev_index].
        size_t prev_index = 0;
        const auto& quantiles = levels.quantiles;
        const auto& permutation = levels.get_permutation();
        for (size_t i = 0; i < size; ++i) {
            auto level_index = permutation[i];
            auto level = quantiles[level_index];
            // Fractional rank: interpolate between order statistics floor(u) and floor(u)+1.
            double u = static_cast<double>(values.size() - 1) * level;
            auto index = static_cast<size_t>(u);
            if (index + 1 >= values.size()) {
                // Rank at (or past) the last element: the answer is the maximum.
                result[level_index] =
                        static_cast<double>(*std::max_element(values.begin(), values.end()));
            } else {
                std::nth_element(values.begin() + prev_index, values.begin() + index, values.end());
                // After nth_element, the next order statistic is the minimum of the tail.
                auto* nth_elem = std::min_element(values.begin() + index + 1, values.end());
                result[level_index] =
                        static_cast<double>(values[index]) +
                        (u - static_cast<double>(index)) * (static_cast<double>(*nth_elem) -
                                                            static_cast<double>(values[index]));
                prev_index = index;
            }
        }
    }

private:
    // Records one validated level with identity permutation.
    void _set_single_level(double quantile) {
        DCHECK(levels.empty());
        check_quantile(quantile);
        levels.quantiles.push_back(quantile);
        levels.permutation.push_back(0);
    }

    // Records arg_size validated levels from quantiles_data[start, start + arg_size),
    // with an identity permutation; throws INVALID_ARGUMENT on a null entry.
    void _set_many_levels(const PaddedPODArray<Float64>& quantiles_data, const NullMap& null_maps,
                          size_t start, int64_t arg_size) {
        DCHECK(levels.empty());
        size_t size = cast_set<size_t>(arg_size);
        levels.quantiles.resize(size);
        levels.permutation.resize(size);
        for (size_t i = 0; i < size; ++i) {
            if (null_maps[start + i]) {
                throw Exception(ErrorCode::INVALID_ARGUMENT,
                                "quantiles in func percentile_array should not have null");
            }
            check_quantile(quantiles_data[start + i]);
            levels.quantiles[i] = quantiles_data[start + i];
            levels.permutation[i] = i;
        }
    }

    // Appends data[0, count) to the value buffer.
    void _append(const ValueType* data, size_t count) {
        if (count == 0) {
            return;
        }
        values.reserve(values.size() + count);
        values.insert_assume_reserved(data, data + count);
    }

    // Interpolated quantile over all buffered values (same formula as insert_result_into).
    // Callers guarantee values is non-empty.
    double _get_result(double quantile) const {
        if (values.size() == 1) {
            return static_cast<double>(values.front());
        }
        double u = static_cast<double>(values.size() - 1) * quantile;
        auto index = static_cast<size_t>(u);
        if (index + 1 >= values.size()) {
            return static_cast<double>(*std::max_element(values.begin(), values.end()));
        }
        std::nth_element(values.begin(), values.begin() + index, values.end());
        auto* nth_elem = std::min_element(values.begin() + index + 1, values.end());
        return static_cast<double>(values[index]) +
               (u - static_cast<double>(index)) *
                       (static_cast<double>(*nth_elem) - static_cast<double>(values[index]));
    }

    // mutable: finalization reorders values in place via nth_element.
    mutable Array values;
    mutable PercentileLevels levels;
    bool inited_flag = false;
};
// percentile(value, quantile): single quantile computed from a Counts histogram.
template <PrimitiveType T>
class AggregateFunctionPercentile final
        : public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>,
          MultiExpression,
          NullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base = IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>;

    AggregateFunctionPercentile(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // Adds one row. The quantile comes from row 0 of the quantile column (only read while
    // the state is still uninitialized).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        auto& state = AggregateFunctionPercentile::data(place);
        // PercentileState::add reads null_maps only while !inited_flag. Building the one-element
        // "no nulls" map costs a heap allocation, so do it only for a group's first row;
        // afterwards a default-constructed (empty, non-allocating) NullMap is enough.
        if (state.inited_flag) {
            state.add(sources.get_data()[row_num], quantile.get_data(), NullMap(), 1);
        } else {
            state.add(sources.get_data()[row_num], quantile.get_data(), NullMap(1, 0), 1);
        }
    }

    // Bulk path for a single aggregation place; the quantile is taken from row 0.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        DCHECK_EQ(sources.get_data().size(), batch_size);
        AggregateFunctionPercentile::data(place).add_batch(sources.get_data(),
                                                           quantile.get_data()[0]);
    }

    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentile::data(place).reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentile::data(place).merge(AggregateFunctionPercentile::data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentile::data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentile::data(place).read(buf);
    }

    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& col = assert_cast<ColumnFloat64&>(to);
        col.insert_value(AggregateFunctionPercentile::data(place).get());
    }
};
// percentile_array(value, [q1, q2, ...]): one Counts histogram per requested quantile.
template <PrimitiveType T>
class AggregateFunctionPercentileArray final
        : public IAggregateFunctionDataHelper<PercentileState<T>,
                                              AggregateFunctionPercentileArray<T>>,
          MultiExpression,
          NotNullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base =
            IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>;

    AggregateFunctionPercentileArray(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile_array"; }

    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
    }

    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();
        const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                        quantile_array.get_data())
                                        .get_null_map_data();
        const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                            quantile_array.get_data())
                                            .get_nested_column();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);
        // NOTE(review): PercentileState::add reads quantiles from index 0 of the nested data
        // (not from this row's offset). That is only right if the quantile array is the same
        // on every row (a constant argument) — confirm. The V2 variant passes an explicit start.
        // offset[-1] relies on the padded offsets array reading 0 before the first element.
        AggregateFunctionPercentileArray::data(place).add(
                sources.get_element(row_num), nested_column_data.get_data(), null_maps,
                offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]);
    }

    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentileArray::data(place).reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentileArray::data(place).merge(
                AggregateFunctionPercentileArray::data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentileArray::data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentileArray::data(place).read(buf);
    }

    // Appends one array row. Its elements are written into the nested column, and when
    // that column is Nullable the null map is extended with "not null" for each new element.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast<ColumnArray&>(to);
        auto& to_nested_col = to_arr.get_data();
        if (to_nested_col.is_nullable()) {
            // Checked cast instead of reinterpret_cast: a type mismatch is reported
            // rather than silently reinterpreting memory.
            auto& col_null = assert_cast<ColumnNullable&>(to_nested_col);
            AggregateFunctionPercentileArray::data(place).insert_result_into(
                    col_null.get_nested_column());
            col_null.get_null_map_data().resize_fill(col_null.get_nested_column().size(), 0);
        } else {
            AggregateFunctionPercentileArray::data(place).insert_result_into(to_nested_col);
        }
        to_arr.get_offsets().push_back(to_nested_col.size());
    }
};
// percentile_v2(value, quantile): exact single-quantile percentile over buffered values.
template <PrimitiveType T>
class AggregateFunctionPercentileV2 final
        : public IAggregateFunctionDataHelper<PercentileExactState<T>,
                                              AggregateFunctionPercentileV2<T>>,
          MultiExpression,
          NullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base =
            IAggregateFunctionDataHelper<PercentileExactState<T>, AggregateFunctionPercentileV2<T>>;

    AggregateFunctionPercentileV2(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile_v2"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    // Adds a single row; the quantile is read from row 0 of the quantile column.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto* first =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_data()
                        .data();
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_data()[0];
        auto& state = AggregateFunctionPercentileV2::data(place);
        state.add_single_range(first + row_num, 1, level);
    }

    // Adds all batch_size rows of the value column at once.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena&) const override {
        const auto& value_data =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_data();
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_data()[0];
        DCHECK_EQ(value_data.size(), batch_size);
        auto& state = AggregateFunctionPercentileV2::data(place);
        state.add_single_range(value_data.data(), batch_size, level);
    }

    // Adds rows [batch_begin, batch_end] — batch_end is inclusive, hence the +1.
    void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
                         const IColumn** columns, Arena&, bool has_null) override {
        const auto* first =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0])
                        .get_data()
                        .data();
        const double level =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
                        .get_data()[0];
        DCHECK(!has_null);
        const size_t row_count = batch_end - batch_begin + 1;
        auto& state = AggregateFunctionPercentileV2::data(place);
        state.add_single_range(first + batch_begin, row_count, level);
    }

    void reset(AggregateDataPtr __restrict place) const override {
        auto& state = AggregateFunctionPercentileV2::data(place);
        state.reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        auto& dst = AggregateFunctionPercentileV2::data(place);
        const auto& src = AggregateFunctionPercentileV2::data(rhs);
        dst.merge(src);
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& state = AggregateFunctionPercentileV2::data(place);
        state.write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& state = AggregateFunctionPercentileV2::data(place);
        state.read(buf);
    }

    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        const double result = AggregateFunctionPercentileV2::data(place).get();
        assert_cast<ColumnFloat64&>(to).insert_value(result);
    }
};
// percentile_array_v2(value, [q1, q2, ...]): exact multi-quantile percentile.
template <PrimitiveType T>
class AggregateFunctionPercentileArrayV2 final
        : public IAggregateFunctionDataHelper<PercentileExactState<T>,
                                              AggregateFunctionPercentileArrayV2<T>>,
          MultiExpression,
          NotNullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base = IAggregateFunctionDataHelper<PercentileExactState<T>,
                                              AggregateFunctionPercentileArrayV2<T>>;

    AggregateFunctionPercentileArrayV2(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile_array_v2"; }

    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
    }

    // Adds one row; the levels come from this row's slice of the quantile array
    // (they are only consumed while the state is uninitialized).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();
        const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                        quantile_array.get_data())
                                        .get_null_map_data();
        const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                            quantile_array.get_data())
                                            .get_nested_column();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);
        size_t start = row_num == 0 ? 0 : offset_column_data[row_num - 1];
        AggregateFunctionPercentileArrayV2::data(place).add_many_range(
                &sources.get_data()[row_num], 1, nested_column_data.get_data(), null_maps, start,
                cast_set<int64_t>(offset_column_data[row_num] - start));
    }

    // Adds batch_size rows at once; the levels are taken from row 0's quantile array.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();
        const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                        quantile_array.get_data())
                                        .get_null_map_data();
        const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                            quantile_array.get_data())
                                            .get_nested_column();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);
        DCHECK_EQ(sources.get_data().size(), batch_size);
        AggregateFunctionPercentileArrayV2::data(place).add_many_range(
                sources.get_data().data(), batch_size, nested_column_data.get_data(), null_maps, 0,
                cast_set<int64_t>(offset_column_data[0]));
    }

    // Adds rows [batch_begin, batch_end] (inclusive end, hence +1). The levels come from
    // batch_begin's quantile array.
    void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
                         const IColumn** columns, Arena&, bool has_null) override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();
        const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                        quantile_array.get_data())
                                        .get_null_map_data();
        const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                                            quantile_array.get_data())
                                            .get_nested_column();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);
        DCHECK(!has_null);
        size_t start = batch_begin == 0 ? 0 : offset_column_data[batch_begin - 1];
        AggregateFunctionPercentileArrayV2::data(place).add_many_range(
                sources.get_data().data() + batch_begin, batch_end - batch_begin + 1,
                nested_column_data.get_data(), null_maps, start,
                cast_set<int64_t>(offset_column_data[batch_begin] - start));
    }

    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentileArrayV2::data(place).reset();
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentileArrayV2::data(place).merge(
                AggregateFunctionPercentileArrayV2::data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentileArrayV2::data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentileArrayV2::data(place).read(buf);
    }

    // Appends one array row. Elements go into the nested column, and when that column is
    // Nullable the null map is extended with "not null" for every new element.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast<ColumnArray&>(to);
        auto& to_nested_col = to_arr.get_data();
        if (to_nested_col.is_nullable()) {
            // Checked cast instead of reinterpret_cast: a type mismatch is reported
            // rather than silently reinterpreting memory.
            auto& col_null = assert_cast<ColumnNullable&>(to_nested_col);
            AggregateFunctionPercentileArrayV2::data(place).insert_result_into(
                    col_null.get_nested_column());
            col_null.get_null_map_data().resize_fill(col_null.get_nested_column().size(), 0);
        } else {
            AggregateFunctionPercentileArrayV2::data(place).insert_result_into(to_nested_col);
        }
        to_arr.get_offsets().push_back(to_nested_col.size());
    }
};
} // namespace doris