blob: b29e8e7fd205b17c63d20e75858e6329d5f7cae9 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>

#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
/**
 * Accumulator shared by the regr_* aggregate functions.
 *
 * Keeps the row count `n` plus only the moments selected by the template
 * arguments, so every function stores and serializes just what it reads.
 *
 * Template parameters:
 *   T        primitive type whose column item type is accepted by add()
 *   NeedSxy  track Sxy; this requires Sx and Sy to be tracked as well
 *   SxLevel  0: nothing for X, 1: Sx, 2: Sx and Sxx
 *   SyLevel  0: nothing for Y, 1: Sy, 2: Sy and Syy
 */
template <PrimitiveType T, bool NeedSxy, size_t SxLevel = size_t {NeedSxy},
          size_t SyLevel = size_t {NeedSxy}>
struct AggregateFunctionRegrData {
    static constexpr PrimitiveType Type = T;

    static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0),
                  "NeedSxy requires SxLevel > 0 and SyLevel > 0");
    static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2");

    // Which moments this instantiation actually carries.
    static constexpr bool need_sx = SxLevel >= 1;
    static constexpr bool need_sy = SyLevel >= 1;
    static constexpr bool need_sxx = SxLevel >= 2;
    static constexpr bool need_syy = SyLevel >= 2;
    static constexpr bool need_sxy = NeedSxy;

    // Each level contributes one slot per enabled moment; Sxy adds one more.
    static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy};
    static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array");

    /**
     * Enabled moments, packed without gaps in the order Sx, Sy, Sxx, Syy, Sxy:
     *   Sx  = sum(X)
     *   Sy  = sum(Y)
     *   Sxx = sum((X-Sx/N)^2)
     *   Syy = sum((Y-Sy/N)^2)
     *   Sxy = sum((X-Sx/N)*(Y-Sy/N))
     */
    std::array<Float64, kMomentSize> moments {};
    // Number of rows accumulated so far.
    UInt64 n {};

    // Slot of each moment inside `moments`: the number of enabled moments that
    // precede it in the packing order. Asking for a disabled moment is a
    // compile-time error.
    static constexpr size_t idx_sx() {
        static_assert(need_sx, "sx not enabled");
        return size_t {0};
    }
    static constexpr size_t idx_sy() {
        static_assert(need_sy, "sy not enabled");
        return size_t {need_sx};
    }
    static constexpr size_t idx_sxx() {
        static_assert(need_sxx, "sxx not enabled");
        return size_t {need_sx} + size_t {need_sy};
    }
    static constexpr size_t idx_syy() {
        static_assert(need_syy, "syy not enabled");
        return size_t {need_sx} + size_t {need_sy} + size_t {need_sxx};
    }
    static constexpr size_t idx_sxy() {
        static_assert(need_sxy, "sxy not enabled");
        return size_t {need_sx} + size_t {need_sy} + size_t {need_sxx} + size_t {need_syy};
    }

    // Mutable accessors.
    Float64& sx() { return moments[idx_sx()]; }
    Float64& sy() { return moments[idx_sy()]; }
    Float64& sxx() { return moments[idx_sxx()]; }
    Float64& syy() { return moments[idx_syy()]; }
    Float64& sxy() { return moments[idx_sxy()]; }

    // Read-only accessors.
    const Float64& sx() const { return moments[idx_sx()]; }
    const Float64& sy() const { return moments[idx_sy()]; }
    const Float64& sxx() const { return moments[idx_sxx()]; }
    const Float64& syy() const { return moments[idx_syy()]; }
    const Float64& sxy() const { return moments[idx_sxy()]; }

    /**
     * Serializes the enabled moments followed by the row count.
     *
     * `moments` holds exactly the enabled moments, already in the order
     * Sx, Sy, Sxx, Syy, Sxy, so writing the array front to back emits them in
     * that order.
     */
    void write(BufferWritable& buf) const {
        for (const Float64& moment : moments) {
            buf.write_binary(moment);
        }
        buf.write_binary(n);
    }

    /// Restores a state produced by write(): the enabled moments, then the row count.
    void read(BufferReadable& buf) {
        for (Float64& moment : moments) {
            buf.read_binary(moment);
        }
        buf.read_binary(n);
    }

    /// Returns the state to "no rows seen": all moments and the count become zero.
    void reset() {
        moments = {};
        n = 0;
    }

    /**
     * Folds another partial state into this one (Youngs–Cramer):
     *   N   = N1 + N2
     *   Sx  = Sx1 + Sx2
     *   Sy  = Sy1 + Sy2
     *   Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
     *   Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
     *   Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
     */
    void merge(const AggregateFunctionRegrData& rhs) {
        // Nothing to take over from an empty right-hand side.
        if (rhs.n == 0) {
            return;
        }
        // An empty left-hand side simply becomes a copy of the other state;
        // this also keeps the divisions below away from a zero count.
        if (n == 0) {
            *this = rhs;
            return;
        }

        const auto lhs_rows = static_cast<Float64>(n);
        const auto rhs_rows = static_cast<Float64>(rhs.n);
        const auto all_rows = lhs_rows + rhs_rows;

        // Gap between the two partial means. It has to be taken now, while
        // sx()/sy() still hold this side's own sums.
        Float64 mean_gap_x {};
        Float64 mean_gap_y {};
        if constexpr (need_sxx || need_sxy) {
            mean_gap_x = sx() / lhs_rows - rhs.sx() / rhs_rows;
        }
        if constexpr (need_syy || need_sxy) {
            mean_gap_y = sy() / lhs_rows - rhs.sy() / rhs_rows;
        }

        // First-order sums.
        if constexpr (need_sx) {
            sx() += rhs.sx();
        }
        if constexpr (need_sy) {
            sy() += rhs.sy();
        }

        // Second-order moments, each with its cross-partition correction term.
        if constexpr (need_sxx) {
            sxx() += rhs.sxx() + lhs_rows * rhs_rows * mean_gap_x * mean_gap_x / all_rows;
        }
        if constexpr (need_syy) {
            syy() += rhs.syy() + lhs_rows * rhs_rows * mean_gap_y * mean_gap_y / all_rows;
        }
        if constexpr (need_sxy) {
            sxy() += rhs.sxy() + lhs_rows * rhs_rows * mean_gap_x * mean_gap_y / all_rows;
        }

        n += rhs.n;
    }

    /**
     * Accumulates one (y, x) row.
     *
     * Sx and Sy are plain running sums. From the second row on, with N and
     * Sx/Sy already including the new row, the centered moments grow by
     *   Sxx += (x*N - Sx)^2 / (N * (N-1))
     *   Syy += (y*N - Sy)^2 / (N * (N-1))
     *   Sxy += (x*N - Sx) * (y*N - Sy) / (N * (N-1))
     * The very first row leaves them at zero.
     */
    void add(typename PrimitiveTypeTraits<T>::ColumnItemType value_y,
             typename PrimitiveTypeTraits<T>::ColumnItemType value_x) {
        const auto y = static_cast<Float64>(value_y);
        const auto x = static_cast<Float64>(value_x);

        if constexpr (need_sx) {
            sx() += x;
        }
        if constexpr (need_sy) {
            sy() += y;
        }

        const UInt64 rows_before = n;
        n = rows_before + 1;
        if (rows_before == 0) [[unlikely]] {
            // First row: only the sums and the count change.
            return;
        }

        const auto prev_rows = static_cast<Float64>(rows_before);
        const auto cur_rows = prev_rows + 1;
        const auto scale = 1.0 / (cur_rows * prev_rows);

        // N times the new value minus the (already updated) sum.
        Float64 dev_x {};
        Float64 dev_y {};
        if constexpr (need_sxx || need_sxy) {
            dev_x = x * cur_rows - sx();
        }
        if constexpr (need_syy || need_sxy) {
            dev_y = y * cur_rows - sy();
        }

        if constexpr (need_sxx) {
            sxx() += dev_x * dev_x * scale;
        }
        if constexpr (need_syy) {
            syy() += dev_y * dev_y * scale;
        }
        if constexpr (need_sxy) {
            sxy() += dev_x * dev_y * scale;
        }
    }
};
/**
 * State and finalizer for regr_slope: Sxy / Sxx.
 * Tracks Sx, Sy, Sxx and Sxy.
 */
template <PrimitiveType T>
struct RegrSlopeFunc : AggregateFunctionRegrData<T, true, 2, 1> {
    static constexpr const char* name = "regr_slope";

    /// Returns Sxy / Sxx, or NaN when no row was accumulated or Sxx is exactly zero.
    Float64 get_result() const {
        const bool has_slope = this->n >= 1 && this->sxx() != 0.0;
        if (has_slope) {
            return this->sxy() / this->sxx();
        }
        return std::numeric_limits<Float64>::quiet_NaN();
    }
};
/**
 * State and finalizer for regr_intercept: (Sy - Sx * Sxy / Sxx) / N.
 *
 * NOTE(review): SyLevel = 2 also stores and serializes Syy, which get_result()
 * never reads. Lowering it to 1 would shrink the state but change the
 * serialized layout, so it is left as is here.
 */
template <PrimitiveType T>
struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
    static constexpr const char* name = "regr_intercept";

    /// Returns the intercept, or NaN when no row was accumulated or Sxx is exactly zero.
    Float64 get_result() const {
        const bool has_intercept = this->n >= 1 && this->sxx() != 0.0;
        if (!has_intercept) {
            return std::numeric_limits<Float64>::quiet_NaN();
        }
        const auto rows = static_cast<Float64>(this->n);
        return (this->sy() - this->sx() * this->sxy() / this->sxx()) / rows;
    }
};
/**
 * State and finalizer for regr_sxx: the centered sum of squares of X.
 * Tracks only Sx and Sxx.
 */
template <PrimitiveType T>
struct RegrSxxFunc : AggregateFunctionRegrData<T, false, 2, 0> {
    static constexpr const char* name = "regr_sxx";

    /// Returns Sxx, or NaN when no row was accumulated.
    Float64 get_result() const {
        return this->n >= 1 ? this->sxx() : std::numeric_limits<Float64>::quiet_NaN();
    }
};
/**
 * State and finalizer for regr_syy: the centered sum of squares of Y.
 * Tracks only Sy and Syy.
 */
template <PrimitiveType T>
struct RegrSyyFunc : AggregateFunctionRegrData<T, false, 0, 2> {
    static constexpr const char* name = "regr_syy";

    /// Returns Syy, or NaN when no row was accumulated.
    Float64 get_result() const {
        return this->n >= 1 ? this->syy() : std::numeric_limits<Float64>::quiet_NaN();
    }
};
/**
 * State and finalizer for regr_sxy: the centered sum of products of X and Y.
 * Tracks Sx, Sy and Sxy.
 */
template <PrimitiveType T>
struct RegrSxyFunc : AggregateFunctionRegrData<T, true, 1, 1> {
    static constexpr const char* name = "regr_sxy";

    /// Returns Sxy, or NaN when no row was accumulated.
    Float64 get_result() const {
        return this->n >= 1 ? this->sxy() : std::numeric_limits<Float64>::quiet_NaN();
    }
};
/**
 * Aggregate function adapter for the two-argument regr_* family.
 *
 * The per-group state and the final formula come from RegrFunc. This class
 * only extracts the (y, x) pair of each row, skips rows where either value is
 * NULL, and maps a NaN result to SQL NULL.
 *
 * Template parameters:
 *   RegrFunc    state type providing add/merge/read/write/reset/get_result,
 *               plus the static members `name` and `Type`
 *   y_nullable  true when columns[0] (y) is a ColumnNullable
 *   x_nullable  true when columns[1] (x) is a ColumnNullable
 */
template <typename RegrFunc, bool y_nullable, bool x_nullable>
class AggregateFunctionRegrSimple
        : public IAggregateFunctionDataHelper<
                  RegrFunc, AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>> {
public:
    using InputCol = typename PrimitiveTypeTraits<RegrFunc::Type>::ColumnType;
    using ResultCol = ColumnFloat64;

    explicit AggregateFunctionRegrSimple(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<
                      RegrFunc, AggregateFunctionRegrSimple<RegrFunc, y_nullable, x_nullable>>(
                      argument_types_) {
        DCHECK(!argument_types_.empty());
    }

    String get_name() const override { return RegrFunc::name; }

    /// The result is always Nullable(Float64); NULL stands for "undefined".
    DataTypePtr get_return_type() const override {
        return make_nullable(std::make_shared<DataTypeFloat64>());
    }

    /**
     * Feeds row `row_num` into the state at `place`.
     * columns[0] is y and columns[1] is x; the row is ignored when either
     * value is NULL.
     */
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto* y_col = nested_or_null<y_nullable>(columns[0], row_num);
        if constexpr (y_nullable) {
            if (y_col == nullptr) {
                return;
            }
        }
        const auto* x_col = nested_or_null<x_nullable>(columns[1], row_num);
        if constexpr (x_nullable) {
            if (x_col == nullptr) {
                return;
            }
        }
        this->data(place).add(y_col->get_data()[row_num], x_col->get_data()[row_num]);
    }

    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        this->data(place).merge(this->data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        this->data(place).read(buf);
    }

    /**
     * Appends the final value of `place` to `to`, which must be a
     * Nullable(Float64) column. A NaN from get_result() becomes NULL with a
     * default nested value; anything else is stored as a non-NULL value.
     */
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        const auto& data = this->data(place);
        auto& dst_column_with_nullable = assert_cast<ColumnNullable&>(to);
        auto& dst_column = assert_cast<ResultCol&>(dst_column_with_nullable.get_nested_column());
        Float64 result = data.get_result();
        if (std::isnan(result)) {
            dst_column_with_nullable.get_null_map_data().push_back(1);
            dst_column.insert_default();
        } else {
            dst_column_with_nullable.get_null_map_data().push_back(0);
            dst_column.get_data().push_back(result);
        }
    }

private:
    /**
     * Resolves `col` to its concrete value column.
     *
     * For a nullable input, returns nullptr when row `row_num` is NULL and the
     * nested column otherwise. For a non-nullable input, always returns `col`
     * itself, cast to InputCol.
     */
    template <bool Nullable>
    static ALWAYS_INLINE const InputCol* nested_or_null(const IColumn* col, ssize_t row_num) {
        if constexpr (Nullable) {
            const auto& c = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*col);
            if (c.is_null_at(row_num)) {
                return nullptr;
            }
            return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(
                    c.get_nested_column_ptr().get());
        } else {
            // Cast the raw pointer directly. Going through get_ptr() only to
            // call .get() on it would build and drop a temporary owning
            // pointer for every row on this hot path.
            return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(col);
        }
    }
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"