blob: d77623c6bf93eceaacaa0767fa2440dc7c7b36d6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/array/arrayCumSum.cpp
// and modified by Doris
#include "common/logging.h"
#include "common/status.h"
#include "runtime/define_primitive_type.h"
#include "runtime/primitive_type.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/core/call_on_type_index.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/functions/function.h"
#include "vec/functions/simple_function_factory.h"
#include "vec/utils/util.hpp"
namespace doris::vectorized {
// array_cum_sum([1, 2, 3, 4, 5]) -> [1, 3, 6, 10, 15]
// array_cum_sum([1, NULL, 3, NULL, 5]) -> [1, 1, 4, 4, 9]
// array_cum_sum([NULL, 1, NULL, 2]) -> [NULL, 1, 1, 3]
//
// NULL elements contribute 0 to the running sum. Only the leading NULLs of each
// array (those before that array's first non-NULL element) are NULL in the result;
// every array of the column is handled independently.
//
// PType is the accumulator type used for decimalv3 element types; all other element
// types use a fixed accumulator (BIGINT, LARGEINT, DOUBLE or DECIMALV2).
template <PrimitiveType PType>
class FunctionArrayCumSum : public IFunction {
public:
    using NullMapType = PaddedPODArray<UInt8>;
    static constexpr auto name = "array_cum_sum";

    // The constructor argument is currently unused.
    explicit FunctionArrayCumSum(DataTypePtr result_type) {}

    static FunctionPtr create() { return std::make_shared<FunctionArrayCumSum>(nullptr); }

    String get_name() const override { return name; }

    bool is_variadic() const override { return false; }

    size_t get_number_of_arguments() const override { return 1; }

    // Maps ARRAY<T> to ARRAY<Nullable(R)>, where R widens T:
    // BOOLEAN/TINYINT/SMALLINT/INT/BIGINT -> BIGINT, LARGEINT -> LARGEINT,
    // FLOAT/DOUBLE -> DOUBLE, and each decimal type -> the same decimal type at its
    // max precision, keeping the input scale.
    // Throws INVALID_ARGUMENT for any other element type.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        DCHECK(arguments[0]->get_primitive_type() == TYPE_ARRAY)
                << "argument for function: " << name << " should be DataTypeArray but it has type "
                << arguments[0]->get_name() << ".";
        auto nested_type = assert_cast<const DataTypeArray&>(*(arguments[0])).get_nested_type();
        DataTypePtr return_type = nullptr;
        switch (nested_type->get_primitive_type()) {
        case PrimitiveType::TYPE_BOOLEAN:
        case PrimitiveType::TYPE_TINYINT:
        case PrimitiveType::TYPE_SMALLINT:
        case PrimitiveType::TYPE_INT:
        case PrimitiveType::TYPE_BIGINT: {
            return_type = std::make_shared<DataTypeInt64>();
            break;
        }
        case PrimitiveType::TYPE_LARGEINT: {
            return_type = std::make_shared<DataTypeInt128>();
            break;
        }
        case PrimitiveType::TYPE_FLOAT:
        case PrimitiveType::TYPE_DOUBLE: {
            return_type = std::make_shared<DataTypeFloat64>();
            break;
        }
        case PrimitiveType::TYPE_DECIMALV2:
            return_type = std::make_shared<DataTypeDecimalV2>(DataTypeDecimalV2::max_precision(),
                                                              nested_type->get_scale());
            break;
        case PrimitiveType::TYPE_DECIMAL32:
            return_type = std::make_shared<DataTypeDecimal32>(DataTypeDecimal32::max_precision(),
                                                              nested_type->get_scale());
            break;
        case PrimitiveType::TYPE_DECIMAL64:
            return_type = std::make_shared<DataTypeDecimal64>(DataTypeDecimal64::max_precision(),
                                                              nested_type->get_scale());
            break;
        case PrimitiveType::TYPE_DECIMAL128I:
            return_type = std::make_shared<DataTypeDecimal128>(DataTypeDecimal128::max_precision(),
                                                               nested_type->get_scale());
            break;
        case PrimitiveType::TYPE_DECIMAL256:
            return_type = std::make_shared<DataTypeDecimal256>(DataTypeDecimal256::max_precision(),
                                                               nested_type->get_scale());
            break;
        default:
            break;
        }
        if (return_type) {
            return std::make_shared<DataTypeArray>(make_nullable(return_type));
        }
        throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
                               "Function of {}, return type get wrong: and input argument is: {}",
                               name, arguments[0]->get_name());
    }

    // Computes the per-array running sums of the array argument and stores an
    // ARRAY<Nullable(R)> column (sharing the input offsets) at `result`.
    // Returns RuntimeError if the argument is not an array column, and InvalidArgument
    // if the element type is unsupported.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        const uint32_t result, size_t input_rows_count) const override {
        // Bind by reference: copying the ColumnWithTypeAndName is unnecessary.
        const auto& src_arg = block.get_by_position(arguments[0]);
        ColumnPtr src_column = src_arg.column->convert_to_full_column_if_const();
        const auto* src_column_array = check_and_get_column<ColumnArray>(src_column.get());
        if (!src_column_array) {
            return Status::RuntimeError(
                    fmt::format("unsupported types for function {}({})", get_name(),
                                block.get_by_position(arguments[0]).type->get_name()));
        }
        const auto& src_offsets = src_column_array->get_offsets();
        const IColumn* src_nested_column = &src_column_array->get_data();
        auto src_nested_type = assert_cast<const DataTypeArray&>(*src_arg.type).get_nested_type();

        // Split the nested column into values and null map. A non-nullable nested column
        // is treated as having no NULL elements instead of being dereferenced as nullable.
        NullMapType all_valid_null_map;
        const NullMapType* src_null_map = nullptr;
        if (const auto* src_nested_nullable_col =
                    check_and_get_column<ColumnNullable>(*src_nested_column)) {
            src_nested_column = src_nested_nullable_col->get_nested_column_ptr().get();
            src_null_map = &src_nested_nullable_col->get_null_map_column().get_data();
        } else {
            all_valid_null_map.resize_fill(src_nested_column->size(), 0);
            src_null_map = &all_valid_null_map;
        }

        ColumnPtr res_nested_ptr;
        auto res_val = _execute_by_type(src_nested_type, *src_nested_column, src_offsets,
                                        *src_null_map, res_nested_ptr);
        if (!res_val) {
            return Status::InvalidArgument(
                    "execute failed or unsupported types for function {}({})", get_name(),
                    block.get_by_position(arguments[0]).type->get_name());
        }
        ColumnPtr res_array_ptr =
                ColumnArray::create(res_nested_ptr, src_column_array->get_offsets_ptr());
        block.replace_by_position(result, std::move(res_array_ptr));
        return Status::OK();
    }

private:
    // Dispatches on the element type to the matching (Element, Result) instantiation.
    // Returns false for element types that are not handled.
    bool _execute_by_type(const DataTypePtr& src_nested_type, const IColumn& src_column,
                          const ColumnArray::Offsets64& src_offsets,
                          const NullMapType& src_null_map, ColumnPtr& res_nested_ptr) const {
        bool res = false;
        switch (src_nested_type->get_primitive_type()) {
        case TYPE_BOOLEAN:
            res = _execute_number<TYPE_BOOLEAN, TYPE_BIGINT>(src_column, src_offsets, src_null_map,
                                                             res_nested_ptr);
            break;
        case TYPE_TINYINT:
            res = _execute_number<TYPE_TINYINT, TYPE_BIGINT>(src_column, src_offsets, src_null_map,
                                                             res_nested_ptr);
            break;
        case TYPE_SMALLINT:
            res = _execute_number<TYPE_SMALLINT, TYPE_BIGINT>(src_column, src_offsets, src_null_map,
                                                              res_nested_ptr);
            break;
        case TYPE_INT:
            res = _execute_number<TYPE_INT, TYPE_BIGINT>(src_column, src_offsets, src_null_map,
                                                         res_nested_ptr);
            break;
        case TYPE_BIGINT:
            res = _execute_number<TYPE_BIGINT, TYPE_BIGINT>(src_column, src_offsets, src_null_map,
                                                            res_nested_ptr);
            break;
        case TYPE_LARGEINT:
            res = _execute_number<TYPE_LARGEINT, TYPE_LARGEINT>(src_column, src_offsets,
                                                                src_null_map, res_nested_ptr);
            break;
        case TYPE_FLOAT:
            res = _execute_number<TYPE_FLOAT, TYPE_DOUBLE>(src_column, src_offsets, src_null_map,
                                                           res_nested_ptr);
            break;
        case TYPE_DOUBLE:
            res = _execute_number<TYPE_DOUBLE, TYPE_DOUBLE>(src_column, src_offsets, src_null_map,
                                                            res_nested_ptr);
            break;
        case TYPE_DECIMAL32:
            res = _execute_number<TYPE_DECIMAL32, PType>(src_column, src_offsets, src_null_map,
                                                         res_nested_ptr);
            break;
        case TYPE_DECIMAL64:
            res = _execute_number<TYPE_DECIMAL64, PType>(src_column, src_offsets, src_null_map,
                                                         res_nested_ptr);
            break;
        case TYPE_DECIMAL128I:
            res = _execute_number<TYPE_DECIMAL128I, PType>(src_column, src_offsets, src_null_map,
                                                           res_nested_ptr);
            break;
        case TYPE_DECIMAL256:
            res = _execute_number<TYPE_DECIMAL256, PType>(src_column, src_offsets, src_null_map,
                                                          res_nested_ptr);
            break;
        case TYPE_DECIMALV2:
            res = _execute_number<TYPE_DECIMALV2, TYPE_DECIMALV2>(src_column, src_offsets,
                                                                  src_null_map, res_nested_ptr);
            break;
        default:
            break;
        }
        return res;
    }

    // Builds the Nullable(Result) nested result column for one (Element, Result) pair.
    // Returns false when the combination is rejected or `src_column` is not the
    // expected concrete column type.
    template <PrimitiveType Element, PrimitiveType Result>
    bool _execute_number(const IColumn& src_column, const ColumnArray::Offsets64& src_offsets,
                         const NullMapType& src_null_map, ColumnPtr& res_nested_ptr) const {
        if constexpr (is_decimalv3(Element) &&
                      (TYPE_DECIMAL128I != Result && TYPE_DECIMAL256 != Result)) {
            // decimalv3 elements are only accumulated into a DECIMAL128I/DECIMAL256 result.
            return false;
        } else {
            using ColVecType = typename PrimitiveTypeTraits<Element>::ColumnType;
            using ColVecResult = typename PrimitiveTypeTraits<Result>::ColumnType;

            // 1. Get the concrete source column. check_and_get_column (unlike assert_cast)
            //    returns nullptr on mismatch, so this failure path is actually reachable.
            const auto* src_column_concrete = check_and_get_column<ColVecType>(src_column);
            if (!src_column_concrete) {
                return false;
            }

            // 2. Construct the result data column; decimal results keep the source scale.
            typename ColVecResult::MutablePtr res_nested_mut_ptr = nullptr;
            if constexpr (is_decimal(Result)) {
                res_nested_mut_ptr = ColVecResult::create(0, src_column_concrete->get_scale());
            } else {
                res_nested_mut_ptr = ColVecResult::create();
            }
            const size_t size = src_column.size();
            auto& res_datas = res_nested_mut_ptr->get_data();
            res_datas.resize(size);
            auto res_null_map_col = ColumnUInt8::create(size, 0);

            // 3. Compute running sums and the result null map, one array at a time.
            _compute_cum_sum<Result>(src_column_concrete->get_data(), src_offsets, src_null_map,
                                     res_datas, res_null_map_col->get_data());

            res_nested_ptr = ColumnNullable::create(std::move(res_nested_mut_ptr),
                                                    std::move(res_null_map_col));
            return true;
        }
    }

    // For every array delimited by `src_offsets`, writes the running sum into `res_datas`.
    // NULL inputs add 0. `res_null_map` (expected zero-initialized) is set to 1 only for the
    // NULLs that precede the first non-NULL element of the same array.
    template <PrimitiveType Result>
    void _compute_cum_sum(const auto& src_datas, const ColumnArray::Offsets64& src_offsets,
                          const NullMapType& src_null_map, auto& res_datas,
                          NullMapType& res_null_map) const {
        using ResultItem = typename PrimitiveTypeTraits<Result>::ColumnItemType;
        size_t prev_offset = 0;
        for (auto cur_offset : src_offsets) {
            // [1, null, 2, 3] -> [1, 1, 3, 6]
            // [1, null, null, 3] -> [1, 1, 1, 4]
            // [null, null, 1, 2, 3] -> [null, null, 1, 3, 6]
            // [null, 1, null, 2, 3] -> [null, 1, 1, 3, 6]
            // [null, null, null, null] -> [null, null, null, null]
            ResultItem accumulated {};
            // Reset per array: leading NULLs are judged within each array, not the column.
            bool seen_non_null = false;
            for (size_t pos = prev_offset; pos < cur_offset; ++pos) {
                if (src_null_map[pos]) {
                    // NULL counts as 0; it stays NULL only before this array's first value.
                    res_null_map[pos] = !seen_non_null;
                } else {
                    accumulated += ResultItem(src_datas[pos]);
                    seen_non_null = true;
                }
                res_datas[pos] = accumulated;
            }
            prev_offset = cur_offset;
        }
    }
};
// Registers array_cum_sum with the factory.
// The creator inspects the result type: when it is an array, the element type of the
// (nullable-stripped) array decides the instantiation; otherwise the result type itself does.
// decimalv3 element types go through the decimalv3 builder helper, integer and floating-point
// types are dispatched to FunctionArrayCumSum<T>, and anything else throws INTERNAL_ERROR.
void register_function_array_cum_sum(SimpleFunctionFactory& factory) {
    ArrayAggFunctionCreator creator = [](const DataTypePtr& result_type) {
        const PrimitiveType element_type =
                result_type->get_primitive_type() == PrimitiveType::TYPE_ARRAY
                        ? static_cast<const DataTypeArray*>(remove_nullable(result_type).get())
                                  ->get_nested_type()
                                  ->get_primitive_type()
                        : result_type->get_primitive_type();

        if (is_decimalv3(element_type)) {
            return DefaultFunctionBuilder::create_array_agg_function_decimalv3<FunctionArrayCumSum>(
                    result_type);
        }

        FunctionBuilderPtr builder;
        auto build_for = [&builder](const auto& dispatched) -> bool {
            using Dispatched = std::decay_t<decltype(dispatched)>;
            if constexpr (is_decimalv3(Dispatched::PType)) {
                return false;
            } else {
                builder = std::make_shared<DefaultFunctionBuilder>(
                        FunctionArrayCumSum<Dispatched::PType>::create());
                return true;
            }
        };

        const bool handled = dispatch_switch_int(element_type, build_for) ||
                             dispatch_switch_float(element_type, build_for);
        if (!handled) {
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                   "array function array_cum_sum error, result type {}",
                                   result_type->get_name());
        }
        return builder;
    };
    factory.register_array_agg_function("array_cum_sum", creator);
}
} // namespace doris::vectorized