blob: 1cb34efdb1e33ae0522418546bbc994124e45cb2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <AggregateFunctions/AggregateFunctionAvg.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <Core/Settings.h>
#include <Common/CHUtil.h>
#include <Common/GlutenDecimalUtils.h>
#include <Common/GlutenSettings.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}
namespace local_engine
{
using namespace DB;
DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
{
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}
template <typename T, bool SPARK35>
requires is_decimal<T>
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
public:
using Base = AggregateFunctionAvg<T>;
explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
: Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_)
, num_scale(num_scale_)
, round_scale(round_scale_)
{
}
DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 /*round_scale_*/)
{
const DataTypePtr & data_type = argument_types_[0];
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(num_scale_ + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
const DataTypePtr & result_type = this->getResultType();
auto result_scale = getDecimalScale(*result_type);
WhichDataType which(result_type);
if (which.isDecimal32())
{
assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal64())
{
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal128())
{
assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else
{
assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
}
String getName() const override { return "sparkAvg"; }
private:
Int128 NO_SANITIZE_UNDEFINED
divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const
{
auto value = avg.numerator.value;
if (result_scale > num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale);
value = value * diff;
}
else if (result_scale < num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale);
value = value / diff;
}
auto result = value / avg.denominator;
if constexpr (SPARK35)
return result;
if (round_scale > result_scale)
return result;
auto round_diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale);
return (result + round_diff / 2) / round_diff * round_diff;
}
private:
UInt32 num_scale;
UInt32 round_scale;
};
template <bool Data, typename... TArgs>
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionSparkAvg<Decimal32, Data>(args...);
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionSparkAvg<Decimal64, Data>(args...);
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionSparkAvg<Decimal128, Data>(args...);
if (which.idx == TypeIndex::Decimal256) return new AggregateFunctionSparkAvg<Decimal256, Data>(args...);
if constexpr (AggregateFunctionSparkAvg<DateTime64, Data>::DateTime64Supported)
if (which.idx == TypeIndex::DateTime64) return new AggregateFunctionSparkAvg<DateTime64, Data>(args...);
return nullptr;
}
AggregateFunctionPtr createAggregateFunctionSparkAvg(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res;
const DataTypePtr & data_type = argument_types[0];
if (!isDecimal(data_type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);
std::string version;
if (tryGetString(*settings, "spark_version", version) && version.starts_with("3.5"))
{
res.reset(createWithDecimalType<true>(*data_type, argument_types, getDecimalScale(*data_type), 0));
return res;
}
bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss);
res.reset(createWithDecimalType<false>(*data_type, argument_types, getDecimalScale(*data_type), round_scale));
return res;
}
void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory)
{
factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg);
}
}