blob: 7dbef1a17fb3ea752754237b8814a96d6f6f2335 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/AggregateFunctionWindowFunnel.h
// and modified by Doris
#pragma once
#include <gen_cpp/data.pb.h>
#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "common/cast_set.h"
#include "common/exception.h"
#include "common/status.h"
#include "util/binary_cast.hpp"
#include "util/simd/bits.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_string.h"
#include "vec/common/assert_cast.h"
#include "vec/core/sort_block.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/var_int.h"
#include "vec/runtime/vdatetime_value.h"
namespace doris {
#include "common/compile_check_begin.h"
namespace vectorized {
class Arena;
class BufferReadable;
class BufferWritable;
class IColumn;
} // namespace vectorized
} // namespace doris
namespace doris::vectorized {
/// Matching strategy selected by the `mode` argument of window_funnel().
/// The numeric value is serialized by WindowFunnelState::write(), so the values below must
/// never be renumbered or reordered.
enum class WindowFunnelMode : Int64 {
    INVALID = 0,       // mode string not recognized
    DEFAULT = 1,       // "default"
    DEDUPLICATION = 2, // "deduplication"
    FIXED = 3,         // "fixed"
    INCREASE = 4       // "increase"
};
/// Maps the textual `mode` argument of window_funnel() to a WindowFunnelMode.
/// The comparison is exact (case-sensitive); any unrecognized string yields
/// WindowFunnelMode::INVALID.
/// Declared `inline` because this definition lives in a header: without it every translation
/// unit that includes the header emits its own external definition, which breaks the
/// one-definition rule and fails to link as soon as two TUs include it.
inline WindowFunnelMode string_to_window_funnel_mode(const String& string) {
    if (string == "default") {
        return WindowFunnelMode::DEFAULT;
    }
    if (string == "deduplication") {
        return WindowFunnelMode::DEDUPLICATION;
    }
    if (string == "fixed") {
        return WindowFunnelMode::FIXED;
    }
    if (string == "increase") {
        return WindowFunnelMode::INCREASE;
    }
    return WindowFunnelMode::INVALID;
}
/// Column-oriented buffer of the rows collected by one window_funnel state: `dt[i]` is the
/// packed DateTimeV2 timestamp of row i and `event_columns_data[k][i]` is the UInt8 flag of
/// event condition k on row i.
struct DataValue {
    using TimestampEvent = std::vector<ColumnUInt8::Container>;
    std::vector<UInt64> dt;
    TimestampEvent event_columns_data;
    /// Lexicographic comparison of the timestamp vectors only; event flags are ignored.
    bool operator<(const DataValue& other) const { return dt < other.dt; }
    /// Drops every row while keeping the number of event columns.
    void clear() {
        dt.clear();
        for (auto& data : event_columns_data) {
            data.clear();
        }
    }
    /// Number of buffered rows.
    auto size() const { return dt.size(); }
    bool empty() const { return dt.empty(); }
    /// Human-readable dump: a header with the row count and the length of the first event
    /// column, then one line per row with the raw timestamp, its decoded DateTimeV2 form and
    /// every event flag. Safe to call with no event columns or with columns shorter than `dt`.
    std::string debug_string() const {
        // The original indexed event_columns_data[0] unconditionally (UB when empty).
        const size_t first_column_rows =
                event_columns_data.empty() ? 0 : event_columns_data[0].size();
        std::string result = "\n" + std::to_string(dt.size()) + " " +
                             std::to_string(first_column_rows) + "\n";
        for (size_t i = 0; i < dt.size(); ++i) {
            result += std::to_string(dt[i]) + " VS " +
                      binary_cast<UInt64, DateV2Value<DateTimeV2ValueType>>(dt[i]).debug_string() +
                      " ,";
            for (const auto& event : event_columns_data) {
                // A debug helper must not read past a column that is shorter than `dt`.
                result += (i < event.size() ? std::to_string(event[i]) : std::string("?")) + ",";
            }
            result += "\n";
        }
        return result;
    }
};
struct WindowFunnelState {
static constexpr PrimitiveType PType = PrimitiveType::TYPE_DATETIMEV2;
using NativeType = UInt64;
using DateValueType = DateV2Value<DateTimeV2ValueType>;
int event_count = 0;
int64_t window;
bool enable_mode;
WindowFunnelMode window_funnel_mode;
DataValue events_list;
WindowFunnelState() {
event_count = 0;
window = 0;
window_funnel_mode = WindowFunnelMode::INVALID;
}
WindowFunnelState(int arg_event_count) : WindowFunnelState() {
event_count = arg_event_count;
events_list.event_columns_data.resize(event_count);
}
void reset() { events_list.clear(); }
void add(const IColumn** arg_columns, ssize_t row_num, int64_t win, WindowFunnelMode mode) {
window = win;
window_funnel_mode = enable_mode ? mode : WindowFunnelMode::DEFAULT;
events_list.dt.emplace_back(
assert_cast<const ColumnVector<PType>&>(*arg_columns[2]).get_data()[row_num]);
for (int i = 0; i < event_count; i++) {
events_list.event_columns_data[i].emplace_back(
assert_cast<const ColumnUInt8&>(*arg_columns[3 + i]).get_data()[row_num]);
}
}
// todo: rethink thid sort method.
void sort() {
auto num = events_list.size();
std::vector<size_t> indices(num);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(),
[this](size_t i1, size_t i2) { return events_list.dt[i1] < events_list.dt[i2]; });
auto reorder = [&indices, &num](auto& vec) {
std::decay_t<decltype(vec)> temp;
temp.resize(num);
for (auto i = 0; i < num; i++) {
temp[i] = vec[indices[i]];
}
std::swap(vec, temp);
};
reorder(events_list.dt);
for (auto& inner_vec : events_list.event_columns_data) {
reorder(inner_vec);
}
}
template <WindowFunnelMode WINDOW_FUNNEL_MODE>
int _match_event_list(size_t& start_row, size_t row_count) const {
int matched_count = 0;
DateValueType start_timestamp;
DateValueType end_timestamp;
if (window < 0) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"the sliding time window must be a positive integer, but got: {}",
window);
}
TimeInterval interval(SECOND, window, false);
int column_idx = 0;
const auto& timestamp_data = events_list.dt;
const auto& first_event_data = events_list.event_columns_data[column_idx].data();
auto match_row = simd::find_one(first_event_data, start_row, row_count);
start_row = match_row + 1;
if (match_row < row_count) {
auto prev_timestamp = binary_cast<NativeType, DateValueType>(timestamp_data[match_row]);
end_timestamp = prev_timestamp;
end_timestamp.template date_add_interval<SECOND>(interval);
matched_count++;
column_idx++;
auto last_match_row = match_row;
++match_row;
for (; column_idx < event_count && match_row < row_count; column_idx++, match_row++) {
const auto& event_data = events_list.event_columns_data[column_idx];
if constexpr (WINDOW_FUNNEL_MODE == WindowFunnelMode::FIXED) {
if (event_data[match_row] == 1) {
auto current_timestamp =
binary_cast<NativeType, DateValueType>(timestamp_data[match_row]);
if (current_timestamp <= end_timestamp) {
matched_count++;
continue;
}
}
break;
}
match_row = simd::find_one(event_data.data(), match_row, row_count);
if (match_row < row_count) {
auto current_timestamp =
binary_cast<NativeType, DateValueType>(timestamp_data[match_row]);
bool is_matched = current_timestamp <= end_timestamp;
if (is_matched) {
if constexpr (WINDOW_FUNNEL_MODE == WindowFunnelMode::INCREASE) {
is_matched = current_timestamp > prev_timestamp;
}
}
if (!is_matched) {
break;
}
if constexpr (WINDOW_FUNNEL_MODE == WindowFunnelMode::INCREASE) {
prev_timestamp =
binary_cast<NativeType, DateValueType>(timestamp_data[match_row]);
}
if constexpr (WINDOW_FUNNEL_MODE == WindowFunnelMode::DEDUPLICATION) {
bool is_dup = false;
if (match_row != last_match_row + 1) {
for (int tmp_column_idx = 0; tmp_column_idx < column_idx;
tmp_column_idx++) {
const auto& tmp_event_data =
events_list.event_columns_data[tmp_column_idx].data();
auto dup_match_row = simd::find_one(tmp_event_data,
last_match_row + 1, match_row);
if (dup_match_row < match_row) {
is_dup = true;
break;
}
}
}
if (is_dup) {
break;
}
last_match_row = match_row;
}
matched_count++;
} else {
break;
}
}
}
return matched_count;
}
template <WindowFunnelMode WINDOW_FUNNEL_MODE>
int _get_internal() const {
size_t start_row = 0;
int max_found_event_count = 0;
auto row_count = events_list.size();
while (start_row < row_count) {
auto found_event_count = _match_event_list<WINDOW_FUNNEL_MODE>(start_row, row_count);
if (found_event_count == event_count) {
return found_event_count;
}
max_found_event_count = std::max(max_found_event_count, found_event_count);
}
return max_found_event_count;
}
int get() const {
auto row_count = events_list.size();
if (event_count == 0 || row_count == 0) {
return 0;
}
switch (window_funnel_mode) {
case WindowFunnelMode::DEFAULT:
return _get_internal<WindowFunnelMode::DEFAULT>();
case WindowFunnelMode::DEDUPLICATION:
return _get_internal<WindowFunnelMode::DEDUPLICATION>();
case WindowFunnelMode::FIXED:
return _get_internal<WindowFunnelMode::FIXED>();
case WindowFunnelMode::INCREASE:
return _get_internal<WindowFunnelMode::INCREASE>();
default:
throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Invalid window_funnel mode");
return 0;
}
}
void merge(const WindowFunnelState& other) {
if (other.events_list.empty()) {
return;
}
events_list.dt.insert(std::end(events_list.dt), std::begin(other.events_list.dt),
std::end(other.events_list.dt));
for (size_t i = 0; i < event_count; i++) {
events_list.event_columns_data[i].insert(
std::end(events_list.event_columns_data[i]),
std::begin(other.events_list.event_columns_data[i]),
std::end(other.events_list.event_columns_data[i]));
}
event_count = event_count > 0 ? event_count : other.event_count;
window = window > 0 ? window : other.window;
if (enable_mode) {
window_funnel_mode = window_funnel_mode == WindowFunnelMode::INVALID
? other.window_funnel_mode
: window_funnel_mode;
} else {
window_funnel_mode = WindowFunnelMode::DEFAULT;
}
}
void write(BufferWritable& out) const {
write_var_int(event_count, out);
write_var_int(window, out);
if (enable_mode) {
write_var_int(static_cast<std::underlying_type_t<WindowFunnelMode>>(window_funnel_mode),
out);
}
auto size = events_list.size();
write_var_int(size, out);
for (const auto& timestamp : events_list.dt) {
write_var_int(timestamp, out);
}
for (int64_t i = 0; i < event_count; i++) {
const auto& event_columns_data = events_list.event_columns_data[i];
for (auto event : event_columns_data) {
write_var_int(event, out);
}
}
}
void read(BufferReadable& in) {
int64_t event_level;
read_var_int(event_level, in);
event_count = (int)event_level;
read_var_int(window, in);
window_funnel_mode = WindowFunnelMode::DEFAULT;
if (enable_mode) {
int64_t mode;
read_var_int(mode, in);
window_funnel_mode = static_cast<WindowFunnelMode>(mode);
}
int64_t size = 0;
read_var_int(size, in);
events_list.clear();
events_list.dt.resize(size);
for (auto i = 0; i < size; i++) {
read_var_int(*reinterpret_cast<Int64*>(&events_list.dt[i]), in);
}
events_list.event_columns_data.resize(event_count);
for (int64_t i = 0; i < event_count; i++) {
auto& event_columns_data = events_list.event_columns_data[i];
event_columns_data.resize(size);
for (auto j = 0; j < size; j++) {
Int64 temp_value;
read_var_int(temp_value, in);
event_columns_data[j] = static_cast<UInt8>(temp_value);
}
}
}
};
/// window_funnel(window, mode, timestamp, event_1, ..., event_N).
/// Argument columns: 0 = window (Int64), 1 = mode name, 2 = DateTimeV2 timestamp,
/// 3.. = one UInt8 flag per event condition. Returns (Int32) the number of funnel levels
/// matched; all buffering and matching is delegated to WindowFunnelState.
class AggregateFunctionWindowFunnel
        : public IAggregateFunctionDataHelper<WindowFunnelState, AggregateFunctionWindowFunnel>,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionWindowFunnel(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<WindowFunnelState, AggregateFunctionWindowFunnel>(
                      argument_types_) {}

    /// Constructs the state in place; every argument after (window, mode, timestamp) is an
    /// event-condition column.
    void create(AggregateDataPtr __restrict place) const override {
        const size_t event_columns = IAggregateFunction::get_argument_types().size() - 3;
        WindowFunnelState* state = new (place) WindowFunnelState(cast_set<int>(event_columns));
        // The mode argument is honored only from BE exec version 3 (Doris 2.0) onward.
        // See `BeExecVersionManager::max_be_exec_version`.
        state->enable_mode = IAggregateFunction::version >= 3;
    }

    String get_name() const override { return "window_funnel"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt32>(); }

    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

    /// Reads the per-row window and mode, then lets the state buffer the row.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const Int64 window_seconds =
                assert_cast<const ColumnInt64&>(*columns[0]).get_data()[row_num];
        const WindowFunnelMode mode =
                string_to_window_funnel_mode(columns[1]->get_data_at(row_num).to_string());
        this->data(place).add(columns, row_num, window_seconds, mode);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        auto& target = this->data(place);
        target.merge(this->data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& state = this->data(place);
        state.write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& state = this->data(place);
        state.read(buf);
    }

    /// Sorts the buffered rows by timestamp and appends the funnel depth to `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        // `place` really addresses mutable state handed over as const; sorting mutates it.
        auto& state = this->data(const_cast<AggregateDataPtr>(place));
        state.sort();
        assert_cast<ColumnInt32&>(to).get_data().push_back(state.get());
    }

protected:
    using IAggregateFunction::version;
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"