blob: 7eb3ed90bed2690e05446b16c84e4118283c0a2f [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <tuple>
#include <type_traits>
#include "core/assert_cast.h"
#include "core/column/column_nullable.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type.h"
#include "core/data_type/data_type_nullable.h"
#include "core/data_type/data_type_number.h"
#include "core/field.h"
#include "core/types.h"
#include "exprs/aggregate/aggregate_function.h"
namespace doris {
/// Identifies which regr_* aggregate a template instantiation implements.
/// The trailing comments give the value each kind reports once it has a defined
/// result (N = number of accumulated rows; see AggregateFunctionRegrData::get_result).
enum class RegrFunctionKind : uint8_t {
    regr_avgx,      // Sx / N
    regr_avgy,      // Sy / N
    regr_count,     // N
    regr_slope,     // Sxy / Sxx
    regr_intercept, // (Sy - Sx * Sxy / Sxx) / N
    regr_sxx,       // Sxx
    regr_syy,       // Syy
    regr_sxy,       // Sxy
    regr_r2,        // Sxy^2 / (Sxx * Syy), or 1 when Syy is 0
};

/// Bundles the three compile-time knobs every RegrTraits specialization exposes.
///   SxLevel / SyLevel: 0 = nothing stored for that variable,
///                      1 = running sum only,
///                      2 = running sum plus centered sum of squares.
///   UseSxy:            whether the centered cross product Sxy is stored.
template <size_t SxLevel, size_t SyLevel, bool UseSxy>
struct RegrMomentLevels {
    static constexpr size_t sx_level = SxLevel;
    static constexpr size_t sy_level = SyLevel;
    static constexpr bool use_sxy = UseSxy;
};

/// Per-kind description: SQL function name plus the moments its state has to carry.
/// Only the explicit specializations below are defined.
template <RegrFunctionKind kind>
struct RegrTraits;

template <>
struct RegrTraits<RegrFunctionKind::regr_avgx> : RegrMomentLevels<1, 0, false> {
    static constexpr auto name = "regr_avgx";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_avgy> : RegrMomentLevels<0, 1, false> {
    static constexpr auto name = "regr_avgy";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_count> : RegrMomentLevels<0, 0, false> {
    static constexpr auto name = "regr_count";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_slope> : RegrMomentLevels<2, 1, true> {
    static constexpr auto name = "regr_slope";
};

// The y level stays at 2 (i.e. `syy` is kept) to preserve regr_intercept's
// historical serialized state layout.
template <>
struct RegrTraits<RegrFunctionKind::regr_intercept> : RegrMomentLevels<2, 2, true> {
    static constexpr auto name = "regr_intercept";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_sxx> : RegrMomentLevels<2, 0, false> {
    static constexpr auto name = "regr_sxx";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_syy> : RegrMomentLevels<0, 2, false> {
    static constexpr auto name = "regr_syy";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_sxy> : RegrMomentLevels<1, 1, true> {
    static constexpr auto name = "regr_sxy";
};

template <>
struct RegrTraits<RegrFunctionKind::regr_r2> : RegrMomentLevels<2, 2, true> {
    static constexpr auto name = "regr_r2";
};
template <PrimitiveType T, RegrFunctionKind kind>
struct AggregateFunctionRegrData {
using Traits = RegrTraits<kind>;
static constexpr PrimitiveType Type = T;
static_assert(Traits::sx_level <= 2 && Traits::sy_level <= 2, "sx/sy level must be <= 2");
static_assert(!Traits::use_sxy || (Traits::sx_level > 0 && Traits::sy_level > 0),
"sxy requires sx_level > 0 and sy_level > 0");
static constexpr bool has_sx = Traits::sx_level > 0;
static constexpr bool has_sy = Traits::sy_level > 0;
static constexpr bool has_sxx = Traits::sx_level > 1;
static constexpr bool has_syy = Traits::sy_level > 1;
static constexpr bool has_sxy = Traits::use_sxy;
static constexpr size_t k_num_moments = Traits::sx_level + Traits::sy_level + size_t {has_sxy};
static constexpr bool has_moments = k_num_moments > 0;
static_assert(k_num_moments <= 5, "Unexpected size of regr moment array");
using State = std::conditional_t<has_moments,
// (n, moments)
std::tuple<UInt64, std::array<Float64, k_num_moments>>,
// (n)
std::tuple<UInt64>>;
/**
* The aggregate state stores:
* N = count
*
* The following moments are stored only when enabled by RegrTraits<kind>:
* Sx = sum(X)
* Sy = sum(Y)
* Sxx = sum((X-Sx/N)^2)
* Syy = sum((Y-Sy/N)^2)
* Sxy = sum((X-Sx/N)*(Y-Sy/N))
*/
State state {};
auto& moments() {
static_assert(has_moments, "moments not enabled");
return std::get<1>(state);
}
const auto& moments() const {
static_assert(has_moments, "moments not enabled");
return std::get<1>(state);
}
static constexpr size_t sx_index() {
static_assert(has_sx, "sx not enabled");
return 0;
}
static constexpr size_t sy_index() {
static_assert(has_sy, "sy not enabled");
return size_t {has_sx};
}
static constexpr size_t sxx_index() {
static_assert(has_sxx, "sxx not enabled");
return size_t {has_sx + has_sy};
}
static constexpr size_t syy_index() {
static_assert(has_syy, "syy not enabled");
return size_t {has_sx + has_sy + has_sxx};
}
static constexpr size_t sxy_index() {
static_assert(has_sxy, "sxy not enabled");
return size_t {has_sx + has_sy + has_sxx + has_syy};
}
UInt64& n() { return std::get<0>(state); }
Float64& sx() { return moments()[sx_index()]; }
Float64& sy() { return moments()[sy_index()]; }
Float64& sxx() { return moments()[sxx_index()]; }
Float64& syy() { return moments()[syy_index()]; }
Float64& sxy() { return moments()[sxy_index()]; }
const UInt64& n() const { return std::get<0>(state); }
const Float64& sx() const { return moments()[sx_index()]; }
const Float64& sy() const { return moments()[sy_index()]; }
const Float64& sxx() const { return moments()[sxx_index()]; }
const Float64& syy() const { return moments()[syy_index()]; }
const Float64& sxy() const { return moments()[sxy_index()]; }
void write(BufferWritable& buf) const {
if constexpr (has_sx) {
buf.write_binary(sx());
}
if constexpr (has_sy) {
buf.write_binary(sy());
}
if constexpr (has_sxx) {
buf.write_binary(sxx());
}
if constexpr (has_syy) {
buf.write_binary(syy());
}
if constexpr (has_sxy) {
buf.write_binary(sxy());
}
buf.write_binary(n());
}
void read(BufferReadable& buf) {
if constexpr (has_sx) {
buf.read_binary(sx());
}
if constexpr (has_sy) {
buf.read_binary(sy());
}
if constexpr (has_sxx) {
buf.read_binary(sxx());
}
if constexpr (has_syy) {
buf.read_binary(syy());
}
if constexpr (has_sxy) {
buf.read_binary(sxy());
}
buf.read_binary(n());
}
void reset() {
if constexpr (has_moments) {
moments().fill({});
}
n() = {};
}
/**
* The merge function uses the Youngs-Cramer algorithm:
* N = N1 + N2
* Sx = Sx1 + Sx2
* Sy = Sy1 + Sy2
* Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
* Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
* Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
*/
void merge(const AggregateFunctionRegrData& rhs) {
if (rhs.n() == 0) {
return;
}
if (n() == 0) {
*this = rhs;
return;
}
const auto n1 = static_cast<Float64>(n());
const auto n2 = static_cast<Float64>(rhs.n());
const auto nsum = n1 + n2;
Float64 dx {};
Float64 dy {};
if constexpr (has_sxx || has_sxy) {
dx = sx() / n1 - rhs.sx() / n2;
}
if constexpr (has_syy || has_sxy) {
dy = sy() / n1 - rhs.sy() / n2;
}
n() += rhs.n();
if constexpr (has_sx) {
sx() += rhs.sx();
}
if constexpr (has_sy) {
sy() += rhs.sy();
}
if constexpr (has_sxx) {
sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum;
}
if constexpr (has_syy) {
syy() += rhs.syy() + n1 * n2 * dy * dy / nsum;
}
if constexpr (has_sxy) {
sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum;
}
}
/**
* N = count
* Sx = sum(X)
* Sy = sum(Y)
* Sxx = sum((X-Sx/N)^2)
* Syy = sum((Y-Sy/N)^2)
* Sxy = sum((X-Sx/N)*(Y-Sy/N))
*/
void add(typename PrimitiveTypeTraits<Type>::CppType value_y,
typename PrimitiveTypeTraits<Type>::CppType value_x) {
const auto x = static_cast<Float64>(value_x);
const auto y = static_cast<Float64>(value_y);
if constexpr (has_sx) {
sx() += x;
}
if constexpr (has_sy) {
sy() += y;
}
if (n() == 0) [[unlikely]] {
n() = 1;
return;
}
const auto n_old = static_cast<Float64>(n());
const auto n_new = n_old + 1;
const auto scale = 1.0 / (n_new * n_old);
n() += 1;
Float64 tmp_x {};
Float64 tmp_y {};
if constexpr (has_sxx || has_sxy) {
tmp_x = x * n_new - sx();
}
if constexpr (has_syy || has_sxy) {
tmp_y = y * n_new - sy();
}
if constexpr (has_sxx) {
sxx() += tmp_x * tmp_x * scale;
}
if constexpr (has_syy) {
syy() += tmp_y * tmp_y * scale;
}
if constexpr (has_sxy) {
sxy() += tmp_x * tmp_y * scale;
}
}
auto get_result() const {
if constexpr (kind == RegrFunctionKind::regr_count) {
return static_cast<Int64>(n());
} else if constexpr (kind == RegrFunctionKind::regr_avgx) {
if (n() < 1) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return sx() / static_cast<Float64>(n());
} else if constexpr (kind == RegrFunctionKind::regr_avgy) {
if (n() < 1) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return sy() / static_cast<Float64>(n());
} else if constexpr (kind == RegrFunctionKind::regr_slope) {
if (n() < 1 || sxx() == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return sxy() / sxx();
} else if constexpr (kind == RegrFunctionKind::regr_intercept) {
if (n() < 1 || sxx() == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return (sy() - sx() * sxy() / sxx()) / static_cast<Float64>(n());
} else if constexpr (kind == RegrFunctionKind::regr_sxx) {
if (n() < 1) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return sxx();
} else if constexpr (kind == RegrFunctionKind::regr_syy) {
if (n() < 1) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return syy();
} else if constexpr (kind == RegrFunctionKind::regr_sxy) {
if (n() < 1) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return sxy();
} else if constexpr (kind == RegrFunctionKind::regr_r2) {
if (n() < 1 || sxx() == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
if (syy() == 0.0) {
return 1.0;
}
return (sxy() * sxy()) / (sxx() * syy());
} else {
__builtin_unreachable();
}
}
};
/// Common implementation of the regr_* aggregate functions. Everything except the
/// result type and result insertion lives here; those are supplied by `Derived`.
///
/// Template parameters:
///   T              - primitive type of both input columns.
///   kind           - which regr_* function is computed.
///   y_is_nullable  - whether columns[0] (y) is a ColumnNullable.
///   x_is_nullable  - whether columns[1] (x) is a ColumnNullable.
///   Derived        - the concrete function class, forwarded to
///                    IAggregateFunctionDataHelper.
template <PrimitiveType T, RegrFunctionKind kind, bool y_is_nullable, bool x_is_nullable,
          typename Derived>
class AggregateFunctionRegrBase
        : public IAggregateFunctionDataHelper<AggregateFunctionRegrData<T, kind>, Derived> {
public:
    using Data = AggregateFunctionRegrData<T, kind>;
    using InputCol = typename PrimitiveTypeTraits<Data::Type>::ColumnType;

    /// Expects exactly two argument types (y, x); checked with DCHECK only.
    explicit AggregateFunctionRegrBase(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<Data, Derived>(argument_types_) {
        DCHECK(argument_types_.size() == 2);
    }

    /// Function name as declared in RegrTraits<kind>.
    String get_name() const override { return RegrTraits<kind>::name; }

    // Regr aggregates only consume rows where both y and x are non-null.
    // columns[0] is y and columns[1] is x.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto* y_column = get_nested_column_or_null<y_is_nullable>(columns[0], row_num);
        if constexpr (y_is_nullable) {
            if (y_column == nullptr) {
                return;
            }
        }
        const auto* x_column = get_nested_column_or_null<x_is_nullable>(columns[1], row_num);
        if constexpr (x_is_nullable) {
            if (x_column == nullptr) {
                return;
            }
        }
        this->data(place).add(y_column->get_data()[row_num], x_column->get_data()[row_num]);
    }

    /// Clears the state at `place`.
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

    /// Folds the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        this->data(place).merge(this->data(rhs));
    }

    /// Writes the state at `place` to `buf` (format defined by Data::write).
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    /// Overwrites the state at `place` with one read from `buf` (Data::read).
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        this->data(place).read(buf);
    }

protected:
    using IAggregateFunctionDataHelper<Data, Derived>::data;

private:
    /// Resolves `col` to the typed value column for one row.
    ///
    /// With is_nullable == true, `col` is treated as a ColumnNullable: returns
    /// nullptr when `row_num` is null there, otherwise its nested column.
    /// With is_nullable == false, `col` itself is the value column and the result
    /// is never nullptr (`row_num` is unused).
    /// The casts skip the release-mode type check (TypeCheckOnRelease::DISABLE).
    template <bool is_nullable>
    static ALWAYS_INLINE const InputCol* get_nested_column_or_null(const IColumn* col,
                                                                   ssize_t row_num) {
        if constexpr (is_nullable) {
            const auto& nullable_column =
                    assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*col);
            if (nullable_column.is_null_at(row_num)) {
                return nullptr;
            }
            return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(
                    nullable_column.get_nested_column_ptr().get());
        } else {
            // Cast the pointer we already hold. This runs once per input row, so
            // do not go through col->get_ptr().get(): get_ptr() presumably hands
            // back a ref-counted handle to this same object, which would be
            // created and destroyed just to read the same address back.
            return assert_cast<const InputCol*, TypeCheckOnRelease::DISABLE>(col);
        }
    }
};
/// Concrete regr_* aggregate for every kind that yields a floating-point value.
/// The result column is Nullable(Float64): a NaN from the state is emitted as NULL.
/// regr_count is handled by a dedicated partial specialization instead.
template <PrimitiveType T, RegrFunctionKind kind, bool y_is_nullable, bool x_is_nullable>
class AggregateFunctionRegr final
        : public AggregateFunctionRegrBase<
                  T, kind, y_is_nullable, x_is_nullable,
                  AggregateFunctionRegr<T, kind, y_is_nullable, x_is_nullable>> {
public:
    static_assert(kind != RegrFunctionKind::regr_count,
                  "regr_count uses the AggregateFunctionRegr specialization");

    using ResultDataType = ColumnFloat64;
    using Base = AggregateFunctionRegrBase<
            T, kind, y_is_nullable, x_is_nullable,
            AggregateFunctionRegr<T, kind, y_is_nullable, x_is_nullable>>;

    explicit AggregateFunctionRegr(const DataTypes& argument_types_) : Base(argument_types_) {}

    /// Always Nullable(Float64), independent of the input nullability.
    DataTypePtr get_return_type() const override {
        auto float64_type = std::make_shared<DataTypeFloat64>();
        return make_nullable(float64_type);
    }

    /// Appends one row to `to`, which must be a ColumnNullable wrapping ColumnFloat64.
    /// A NaN result becomes a NULL row (null flag 1 plus a default nested value);
    /// anything else is appended as a non-null value.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& nullable_column = assert_cast<ColumnNullable&>(to);
        auto& nested_column = assert_cast<ResultDataType&>(nullable_column.get_nested_column());
        const Float64 result = this->data(place).get_result();

        if (!std::isnan(result)) {
            nullable_column.get_null_map_data().push_back(0);
            nested_column.get_data().push_back(result);
            return;
        }

        // Undefined result: flag the row as NULL and keep the nested column aligned.
        nullable_column.get_null_map_data().push_back(1);
        nested_column.insert_default();
    }
};
/// regr_count: reports how many rows were accumulated as a plain, non-nullable Int64.
template <PrimitiveType T, bool y_is_nullable, bool x_is_nullable>
class AggregateFunctionRegr<T, RegrFunctionKind::regr_count, y_is_nullable, x_is_nullable> final
        : public AggregateFunctionRegrBase<T, RegrFunctionKind::regr_count, y_is_nullable,
                                           x_is_nullable,
                                           AggregateFunctionRegr<T, RegrFunctionKind::regr_count,
                                                                 y_is_nullable, x_is_nullable>> {
public:
    using ResultDataType = ColumnInt64;
    using Base = AggregateFunctionRegrBase<
            T, RegrFunctionKind::regr_count, y_is_nullable, x_is_nullable,
            AggregateFunctionRegr<T, RegrFunctionKind::regr_count, y_is_nullable, x_is_nullable>>;

    explicit AggregateFunctionRegr(const DataTypes& argument_types_) : Base(argument_types_) {}

    /// Int64, never nullable.
    DataTypePtr get_return_type() const override {
        auto int64_type = std::make_shared<DataTypeInt64>();
        return int64_type;
    }

    /// Appends the accumulated row count to `to`, which must be a ColumnInt64.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& counts = assert_cast<ResultDataType&>(to).get_data();
        const auto row_count = this->data(place).get_result();
        counts.push_back(row_count);
    }
};
} // namespace doris