blob: b417b427af5a636ed576965f141f4b6e7c4c55d7 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <assert.h>
#include <glog/logging.h>
#include <string.h>
#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/common/assert_cast.h"
#include "vec/common/pod_array_fwd.h"
#include "vec/common/string_buffer.hpp"
#include "vec/common/string_ref.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/io/io_helper.h"
#include "vec/io/var_int.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
/// collect_set state for fixed-width element types (the primary template; string types
/// use the specialization below).
///
/// Distinct values live in a phmap flat hash set. When HasLimit is true, `max_size` caps
/// the number of distinct values kept by merge(); -1 means the limit is not known yet.
template <PrimitiveType T, bool HasLimit>
struct AggregateFunctionCollectSetData {
    static constexpr PrimitiveType PType = T;
    using ElementType = typename PrimitiveTypeTraits<T>::ColumnItemType;
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using SelfType = AggregateFunctionCollectSetData;
    using Set = phmap::flat_hash_set<ElementType>;

    Set data_set;
    // Element limit (only meaningful when HasLimit); -1 = not yet initialised.
    Int64 max_size = -1;

    AggregateFunctionCollectSetData(const DataTypes& argument_types) {}

    /// Number of distinct values currently held.
    size_t size() const { return data_set.size(); }

    /// Inserts the value at `row_num`; `column` must be a ColVecType (unchecked in release).
    void add(const IColumn& column, size_t row_num) {
        data_set.insert(assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(column)
                                .get_data()[row_num]);
    }

    /// Merges `rhs` into this state. With HasLimit, stops once `max_size` values are held
    /// and adopts rhs's limit only if this state has none yet.
    void merge(const SelfType& rhs) {
        if constexpr (HasLimit) {
            if (max_size == -1) {
                max_size = rhs.max_size;
            }
            for (const auto& rhs_elem : rhs.data_set) {
                // A negative max_size never stops the merge (matches the former
                // unsigned comparison, where -1 became SIZE_MAX).
                if (max_size >= 0 && size() >= static_cast<size_t>(max_size)) {
                    return;
                }
                data_set.insert(rhs_elem);
            }
        } else {
            // Range-insert directly; building a temporary `Set(rhs.data_set)` just to
            // merge it would copy the whole rhs set first.
            data_set.insert(rhs.data_set.begin(), rhs.data_set.end());
        }
    }

    /// Serialises as: var-uint count, each value via write_binary, then var-int max_size.
    void write(BufferWritable& buf) const {
        buf.write_var_uint(data_set.size());
        for (const auto& value : data_set) {
            buf.write_binary(value);
        }
        write_var_int(max_size, buf);
    }

    /// Inverse of write(); values are added to whatever is already in the set.
    void read(BufferReadable& buf) {
        uint64_t new_size = 0;
        buf.read_var_uint(new_size);
        ElementType x;
        for (size_t i = 0; i < new_size; ++i) {
            buf.read_binary(x);
            data_set.insert(x);
        }
        read_var_int(max_size, buf);
    }

    /// Appends every distinct value to `to` (a ColVecType), in hash-set iteration order.
    void insert_result_into(IColumn& to) const {
        auto& vec = assert_cast<ColVecType&>(to).get_data();
        vec.reserve(size());
        for (const auto& item : data_set) {
            vec.push_back(item);
        }
    }

    void reset() { data_set.clear(); }
};
/// collect_set state for string-like types (is_string_type(T)).
///
/// The set stores StringRefs. Keys added through add()/merge() have their bytes copied
/// into the caller-supplied Arena, so they do not depend on the source column's lifetime.
template <PrimitiveType T, bool HasLimit>
requires(is_string_type(T))
struct AggregateFunctionCollectSetData<T, HasLimit> {
    static constexpr PrimitiveType PType = T;
    using ElementType = StringRef;
    using ColVecType = ColumnString;
    using SelfType = AggregateFunctionCollectSetData<T, HasLimit>;
    using Set = phmap::flat_hash_set<ElementType>;

    Set data_set;
    // Element limit (only meaningful when HasLimit); -1 = not yet initialised.
    Int64 max_size = -1;

    AggregateFunctionCollectSetData(const DataTypes& argument_types) {}

    /// Number of distinct strings currently held.
    size_t size() const { return data_set.size(); }

    /// Inserts the string at `row_num`, copying its bytes into `arena` only when it is new.
    void add(const IColumn& column, size_t row_num, Arena& arena) {
        auto key = column.get_data_at(row_num);
        // Copying duplicates into the arena would grow it with bytes nothing references.
        if (data_set.find(key) != data_set.end()) {
            return;
        }
        key.data = arena.insert(key.data, key.size);
        data_set.insert(key);
    }

    /// Merges `rhs` into this state, copying only previously unseen strings into `arena`.
    /// Adopts rhs's limit only if this state has none yet; with HasLimit, stops once
    /// `max_size` strings are held.
    void merge(const SelfType& rhs, Arena& arena) {
        if (max_size == -1) {
            max_size = rhs.max_size;
        }
        for (const auto& rhs_elem : rhs.data_set) {
            if constexpr (HasLimit) {
                // A negative max_size never stops the merge (matches the former
                // unsigned comparison, where -1 became SIZE_MAX).
                if (max_size >= 0 && size() >= static_cast<size_t>(max_size)) {
                    return;
                }
            }
            if (data_set.find(rhs_elem) != data_set.end()) {
                continue;
            }
            StringRef key = rhs_elem;
            key.data = arena.insert(key.data, key.size);
            data_set.insert(key);
        }
    }

    /// Serialises as: var-uint count, each string via write_binary, then var-int max_size.
    void write(BufferWritable& buf) const {
        buf.write_var_uint(size());
        for (const auto& elem : data_set) {
            buf.write_binary(elem);
        }
        write_var_int(max_size, buf);
    }

    /// Inverse of write(); strings are added to whatever is already in the set.
    void read(BufferReadable& buf) {
        UInt64 num_elems = 0;
        buf.read_var_uint(num_elems);
        StringRef ref;
        for (size_t i = 0; i < num_elems; ++i) {
            buf.read_binary(ref);
            // NOTE(review): `ref` is inserted without an arena copy; presumably it points
            // into `buf`'s backing memory, so this state must not outlive that buffer
            // unless it is merged (which copies) first — confirm with callers.
            data_set.insert(ref);
        }
        read_var_int(max_size, buf);
    }

    /// Appends every distinct string to `to` (a ColumnString), in hash-set iteration order.
    void insert_result_into(IColumn& to) const {
        auto& vec = assert_cast<ColVecType&>(to);
        vec.reserve(size());
        for (const auto& item : data_set) {
            vec.insert_data(item.data, item.size);
        }
    }

    void reset() { data_set.clear(); }
};
/// collect_list state for fixed-width element types (the primary template).
///
/// Values are appended, duplicates included, to a PaddedPODArray in arrival order. When
/// HasLimit is true, `max_size` caps how many values merge() will keep; -1 = unknown.
template <PrimitiveType T, bool HasLimit>
struct AggregateFunctionCollectListData {
    static constexpr PrimitiveType PType = T;
    using ElementType = typename PrimitiveTypeTraits<T>::ColumnItemType;
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using SelfType = AggregateFunctionCollectListData<T, HasLimit>;

    PaddedPODArray<ElementType> data;
    // Element limit (only meaningful when HasLimit); -1 = not yet initialised.
    Int64 max_size = -1;

    AggregateFunctionCollectListData(const DataTypes& argument_types) {}

    size_t size() const { return data.size(); }

    /// Appends the value at `row_num`; `column` must be a ColVecType (unchecked in release).
    void add(const IColumn& column, size_t row_num) {
        const auto& vec =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(column).get_data();
        data.push_back(vec[row_num]);
    }

    /// Appends rhs's values. With HasLimit, appends at most enough to reach `max_size`
    /// and adopts rhs's limit only if this state has none yet.
    void merge(const SelfType& rhs) {
        if constexpr (HasLimit) {
            if (max_size == -1) {
                max_size = rhs.max_size;
            }
            size_t to_copy = rhs.size();
            // A negative max_size means no limit (matches the former unsigned compare).
            if (max_size >= 0) {
                const auto limit = static_cast<size_t>(max_size);
                if (size() >= limit) {
                    to_copy = 0;
                } else {
                    to_copy = std::min(limit - size(), to_copy);
                }
            }
            data.insert(rhs.data.begin(), rhs.data.begin() + to_copy);
        } else {
            data.insert(rhs.data.begin(), rhs.data.end());
        }
    }

    /// Serialises as: var-uint count, the raw element bytes, then var-int max_size.
    void write(BufferWritable& buf) const {
        buf.write_var_uint(size());
        buf.write(data.raw_data(), size() * sizeof(ElementType));
        write_var_int(max_size, buf);
    }

    /// Inverse of write(); replaces the current contents with the deserialised values.
    void read(BufferReadable& buf) {
        UInt64 rows = 0;
        buf.read_var_uint(rows);
        data.resize(rows);
        buf.read(reinterpret_cast<char*>(data.data()), rows * sizeof(ElementType));
        read_var_int(max_size, buf);
    }

    void reset() { data.clear(); }

    /// Appends all collected values to `to` (a ColVecType) with a single memcpy.
    void insert_result_into(IColumn& to) const {
        auto& vec = assert_cast<ColVecType&>(to).get_data();
        size_t old_size = vec.size();
        vec.resize(old_size + size());
        memcpy(vec.data() + old_size, data.data(), size() * sizeof(ElementType));
    }
};
/// collect_list state for string-like types (is_string_type(T)).
///
/// Strings are appended, duplicates included, to an owned ColumnString in arrival order.
template <PrimitiveType T, bool HasLimit>
requires(is_string_type(T))
struct AggregateFunctionCollectListData<T, HasLimit> {
    static constexpr PrimitiveType PType = T;
    using ElementType = StringRef;
    using ColVecType = ColumnString;

    MutableColumnPtr data;
    // Element limit (only meaningful when HasLimit); -1 = not yet initialised.
    Int64 max_size = -1;

    AggregateFunctionCollectListData(const DataTypes& argument_types) {
        data = ColVecType::create();
    }

    size_t size() const { return data->size(); }

    void add(const IColumn& column, size_t row_num) { data->insert_from(column, row_num); }

    /// Appends rhs's strings. With HasLimit, appends at most enough to reach `max_size`
    /// and adopts rhs's limit only if this state has none yet.
    void merge(const AggregateFunctionCollectListData& rhs) {
        size_t rows_to_copy = rhs.size();
        if constexpr (HasLimit) {
            if (max_size == -1) {
                max_size = rhs.max_size;
            }
            // A negative max_size means no limit. Otherwise clamp at zero: the previous
            // `max_size - size()` was evaluated unsigned and would wrap to a huge count
            // (copying all of rhs) whenever size() already exceeded max_size.
            if (max_size >= 0) {
                const auto limit = static_cast<size_t>(max_size);
                if (size() >= limit) {
                    rows_to_copy = 0;
                } else {
                    rows_to_copy = std::min(limit - size(), rows_to_copy);
                }
            }
        }
        data->insert_range_from(*rhs.data, 0, rows_to_copy);
    }

    /// Serialises as: var-uint row count, raw offsets, var-uint byte count, raw chars,
    /// then var-int max_size.
    void write(BufferWritable& buf) const {
        const auto& col = assert_cast<const ColVecType&>(*data);
        buf.write_var_uint(col.size());
        buf.write(col.get_offsets().raw_data(), col.size() * sizeof(IColumn::Offset));
        buf.write_var_uint(col.get_chars().size());
        buf.write(col.get_chars().raw_data(), col.get_chars().size());
        write_var_int(max_size, buf);
    }

    /// Inverse of write(); replaces the column's offsets and chars with the serialised ones.
    void read(BufferReadable& buf) {
        auto& col = assert_cast<ColVecType&>(*data);
        UInt64 offs_size = 0;
        buf.read_var_uint(offs_size);
        col.get_offsets().resize(offs_size);
        buf.read(reinterpret_cast<char*>(col.get_offsets().data()),
                 offs_size * sizeof(IColumn::Offset));
        UInt64 chars_size = 0;
        buf.read_var_uint(chars_size);
        col.get_chars().resize(chars_size);
        buf.read(reinterpret_cast<char*>(col.get_chars().data()), chars_size);
        read_var_int(max_size, buf);
    }

    void reset() { data->clear(); }

    /// Appends all collected strings to `to` (a ColumnString).
    void insert_result_into(IColumn& to) const {
        auto& to_str = assert_cast<ColVecType&>(to);
        to_str.insert_range_from(*data, 0, size());
    }
};
/// collect_list state for the remaining (complex) element types — everything that is not
/// string, int/bool, float/double, decimal, date, IP or timestamptz.
///
/// Values are kept in a column created from argument_types[0]. Serialisation goes through
/// the type's serde, one text cell per element.
template <PrimitiveType T, bool HasLimit>
requires(!is_string_type(T) && !is_int_or_bool(T) && !is_float_or_double(T) && !is_decimal(T) &&
         !is_date_type(T) && !is_ip(T) && !is_timestamptz_type(T))
struct AggregateFunctionCollectListData<T, HasLimit> {
    static constexpr PrimitiveType PType = T;
    using ElementType = StringRef;
    using Self = AggregateFunctionCollectListData<T, HasLimit>;

    DataTypeSerDeSPtr serde; // for complex serialize && deserialize from multi BE
    MutableColumnPtr column_data;
    // Element limit (only meaningful when HasLimit); -1 = not yet initialised.
    Int64 max_size = -1;

    AggregateFunctionCollectListData(const DataTypes& argument_types) {
        DataTypePtr column_type = argument_types[0];
        column_data = column_type->create_column();
        serde = column_type->get_serde();
    }

    size_t size() const { return column_data->size(); }

    void add(const IColumn& column, size_t row_num) { column_data->insert_from(column, row_num); }

    /// Appends rhs's values. With HasLimit, appends at most enough to reach `max_size`
    /// and adopts rhs's limit only if this state has none yet.
    void merge(const AggregateFunctionCollectListData& rhs) {
        size_t rows_to_copy = rhs.size();
        if constexpr (HasLimit) {
            if (max_size == -1) {
                max_size = rhs.max_size;
            }
            // A negative max_size means no limit. Otherwise clamp at zero: the previous
            // `max_size - size()` was evaluated unsigned and would wrap to a huge count
            // (copying all of rhs) whenever size() already exceeded max_size.
            if (max_size >= 0) {
                const auto limit = static_cast<size_t>(max_size);
                if (size() >= limit) {
                    rows_to_copy = 0;
                } else {
                    rows_to_copy = std::min(limit - size(), rows_to_copy);
                }
            }
        }
        column_data->insert_range_from(*rhs.column_data, 0, rows_to_copy);
    }

    /// Serialises as: the element count (write_binary of size_t), then each element as
    /// the text produced by serialize_one_cell_to_json, then var-int max_size.
    /// @throws doris::Exception(INTERNAL_ERROR) if the serde fails on any element.
    void write(BufferWritable& buf) const {
        const size_t size = column_data->size();
        buf.write_binary(size);
        DataTypeSerDe::FormatOptions opt;
        auto tmp_str = ColumnString::create();
        VectorBufferWriter tmp_buf(*tmp_str.get());
        for (size_t i = 0; i < size; i++) {
            // Reuse one scratch column: clear it, write a single cell, then read row 0.
            tmp_str->clear();
            if (Status st = serde->serialize_one_cell_to_json(*column_data, i, tmp_buf, opt); !st) {
                throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                       "Failed to serialize data for " + column_data->get_name() +
                                               " error: " + st.to_string());
            }
            tmp_buf.commit();
            buf.write_binary(tmp_str->get_data_at(0));
        }
        write_var_int(max_size, buf);
    }

    /// Inverse of write(); clears the column first, then deserialises each element.
    /// @throws doris::Exception(INTERNAL_ERROR) if the serde fails on any element.
    void read(BufferReadable& buf) {
        size_t size = 0;
        buf.read_binary(size);
        column_data->clear();
        column_data->reserve(size);
        StringRef s;
        DataTypeSerDe::FormatOptions opt;
        for (size_t i = 0; i < size; i++) {
            buf.read_binary(s);
            Slice slice(s.data, s.size);
            if (Status st = serde->deserialize_one_cell_from_json(*column_data, slice, opt); !st) {
                throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                       "Failed to deserialize data for " + column_data->get_name() +
                                               " error: " + st.to_string());
            }
        }
        read_var_int(max_size, buf);
    }

    void reset() { column_data->clear(); }

    /// Appends all collected values to `to`, which must match argument_types[0]'s column.
    void insert_result_into(IColumn& to) const { to.insert_range_from(*column_data, 0, size()); }
};
/// Aggregate function implementing collect_list / collect_set on top of one of the Data
/// states above. The result type is Array(Nullable(argument_types_[0])).
/// When HasLimit is true, columns[1] (read as a ColumnInt32) supplies the element limit.
template <typename Data, bool HasLimit>
class AggregateFunctionCollect
        : public IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data, HasLimit>, true>,
          VarargsExpression,
          NotNullableAggregateFunction {
    // Only the string collect_set states need an Arena in add()/merge().
    static constexpr bool ENABLE_ARENA =
            std::is_same_v<Data, AggregateFunctionCollectSetData<TYPE_STRING, HasLimit>> ||
            std::is_same_v<Data, AggregateFunctionCollectSetData<TYPE_CHAR, HasLimit>> ||
            std::is_same_v<Data, AggregateFunctionCollectSetData<TYPE_VARCHAR, HasLimit>>;

public:
    AggregateFunctionCollect(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data, HasLimit>, true>(
                      {argument_types_}),
              return_type(std::make_shared<DataTypeArray>(make_nullable(argument_types_[0]))) {}

    std::string get_name() const override {
        if constexpr (std::is_same_v<AggregateFunctionCollectListData<Data::PType, HasLimit>,
                                     Data>) {
            return "collect_list";
        } else {
            return "collect_set";
        }
    }

    DataTypePtr get_return_type() const override { return return_type; }

    /// Adds row `row_num` of columns[0]. With HasLimit, the limit is taken from columns[1]
    /// the first time a state sees a row, and further rows are dropped once it is reached.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena& arena) const override {
        auto& data = this->data(place);
        if constexpr (HasLimit) {
            if (data.max_size == -1) {
                // Int32 -> Int64 keeps the signed value; identical to the former
                // `(UInt64)` round-trip into the Int64 field.
                data.max_size = static_cast<Int64>(
                        assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(columns[1])
                                ->get_element(row_num));
            }
            // A negative limit never stops collection, exactly as the former unsigned
            // comparison (negative -> huge size_t) behaved.
            if (data.max_size >= 0 && data.size() >= static_cast<size_t>(data.max_size)) {
                return;
            }
        }
        if constexpr (ENABLE_ARENA) {
            data.add(*columns[0], row_num, arena);
        } else {
            data.add(*columns[0], row_num);
        }
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena& arena) const override {
        auto& data = this->data(place);
        const auto& rhs_data = this->data(rhs);
        if constexpr (ENABLE_ARENA) {
            data.merge(rhs_data, arena);
        } else {
            data.merge(rhs_data);
        }
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        this->data(place).read(buf);
    }

    /// Writes the state as one array row: the elements go into the nullable nested
    /// column (all marked non-null), then the array's end offset is appended.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast<ColumnArray&>(to);
        auto& to_nested_col = to_arr.get_data();
        auto* col_null = assert_cast<ColumnNullable*>(&to_nested_col);
        this->data(place).insert_result_into(col_null->get_nested_column());
        col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(), 0);
        to_arr.get_offsets().push_back(to_nested_col.size());
    }

private:
    DataTypePtr return_type;
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"