// blob: 20ef632ee6a99e2fed08d705a8fcf3905964ff11 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/IAggregateFunction_fwd.h>
#include <Columns/ColumnArray.h>
#include <Core/Field.h>
#include <Core/Settings.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <IO/VarInt.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/parseQuery.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
/// Error codes thrown by the code in this file. They are only declared here (`extern`);
/// the definitions live elsewhere.
namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
namespace local_engine
{
/// One sort key of the ordering applied to the aggregated tuples.
struct SortOrderField
{
    /// Index of the tuple element this key compares.
    size_t pos{0};
    /// Non-zero: smaller values come first (ascending). Zero: larger values come first.
    Int8 direction{0};
    /// Non-zero: NULL values come before non-NULL values. Zero: NULL values come last.
    Int8 nulls_direction{0};
};

/// Sort keys in priority order: the first key that distinguishes two tuples decides.
using SortOrderFields = std::vector<SortOrderField>;
struct RowNumGroupArraySortedData
{
public:
using Data = DB::Tuple;
std::vector<Data> values;
static bool compare(const Data & lhs, const Data & rhs, const SortOrderFields & sort_orders)
{
for (const auto & sort_order : sort_orders)
{
const auto & pos = sort_order.pos;
const auto & asc = sort_order.direction;
const auto & nulls_first = sort_order.nulls_direction;
bool l_is_null = lhs[pos].isNull();
bool r_is_null = rhs[pos].isNull();
if (l_is_null && r_is_null)
continue;
else if (l_is_null)
return nulls_first;
else if (r_is_null)
return !nulls_first;
else if (lhs[pos] < rhs[pos])
return asc;
else if (lhs[pos] > rhs[pos])
return !asc;
}
return false;
}
ALWAYS_INLINE void heapReplaceTop(const SortOrderFields & sort_orders)
{
size_t size = values.size();
if (size < 2)
return;
size_t child_index = 1;
if (size > 2 && compare(values[1], values[2], sort_orders))
++child_index;
if (compare(values[child_index], values[0], sort_orders))
return;
size_t current_index = 0;
auto current = values[current_index];
do
{
values[current_index] = values[child_index];
current_index = child_index;
child_index = 2 * child_index + 1;
if (child_index >= size)
break;
if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1], sort_orders))
++child_index;
} while (!compare(values[child_index], current, sort_orders));
values[current_index] = current;
}
ALWAYS_INLINE void addElement(const Data && data, const SortOrderFields & sort_orders, size_t max_elements)
{
if (values.size() >= max_elements)
{
if (!compare(data, values[0], sort_orders))
return;
values[0] = data;
heapReplaceTop(sort_orders);
return;
}
values.emplace_back(std::move(data));
auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); };
std::push_heap(values.begin(), values.end(), cmp);
}
ALWAYS_INLINE void sortAndLimit(size_t max_elements, const SortOrderFields & sort_orders)
{
::sort(values.begin(), values.end(), [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); });
if (values.size() > max_elements)
values.resize(max_elements);
}
ALWAYS_INLINE void insertResultInto(DB::IColumn & to, size_t max_elements, const SortOrderFields & sort_orders)
{
auto & result_array = assert_cast<DB::ColumnArray &>(to);
auto & result_array_offsets = result_array.getOffsets();
sortAndLimit(max_elements, sort_orders);
result_array_offsets.push_back(result_array_offsets.back() + values.size());
if (values.empty())
return;
auto & result_array_data = result_array.getData();
for (int i = 0, sz = static_cast<int>(values.size()); i < sz; ++i)
{
auto & value = values[i];
value.push_back(i + 1);
result_array_data.insert(value);
}
}
};
/// Derives the aggregate function's result type from its tuple argument type:
/// an array of tuples that carry all input elements plus a trailing Int32 element named "row_num".
/// Throws LOGICAL_ERROR when `data_type` is not a tuple type.
static DB::DataTypePtr getRowNumReultDataType(DB::DataTypePtr data_type)
{
    const auto * input_tuple = typeid_cast<const DB::DataTypeTuple *>(data_type.get());
    if (input_tuple == nullptr)
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", data_type->getName());

    /// Start from copies of the input layout and append the row number element to both lists.
    std::vector<String> names = input_tuple->getElementNames();
    names.push_back("row_num");

    DB::DataTypes types = input_tuple->getElements();
    types.push_back(std::make_shared<DB::DataTypeInt32>());

    return std::make_shared<DB::DataTypeArray>(std::make_shared<DB::DataTypeTuple>(types, names));
}
/// Aggregate function that keeps, per group, the first `limit` input tuples under a given
/// ordering and returns them as a sorted array, each tuple extended with a 1-based `row_num`.
///
/// Parameters: (limit, order clause). The order clause is parsed as an ORDER BY expression
/// list whose keys must be names of elements of the tuple argument.
/// usage: rowNumGroupArraySorted(1, "a asc nulls first, b desc nulls last")(tuple(a,b))
class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>
{
public:
    /// `data_type` must be a tuple type; `parameters_` must hold exactly the limit (UInt64) and
    /// the order clause (String). Throws BAD_ARGUMENTS on a wrong parameter count or an invalid
    /// order clause, and LOGICAL_ERROR when `data_type` is not a tuple.
    explicit RowNumGroupArraySorted(DB::DataTypePtr data_type, const DB::Array & parameters_)
        : DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>(
              {data_type}, parameters_, getRowNumReultDataType(data_type))
    {
        if (parameters_.size() != 2)
            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName());
        if (!typeid_cast<const DB::DataTypeTuple *>(data_type.get()))
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", data_type->getName());

        limit = parameters_[0].safeGet<UInt64>();
        const String order_by_clause = parameters_[1].safeGet<String>();
        sort_order_fields = parseSortOrderFields(order_by_clause);
        serialization = data_type->getDefaultSerialization();
    }

    String getName() const override { return "rowNumGroupArraySorted"; }

    /// Offers row `row_num` of the (single) tuple argument column to the state.
    void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena * /*arena*/) const override
    {
        auto & state = this->data(place);
        DB::Tuple data_tuple = (*columns[0])[row_num].safeGet<DB::Tuple>();
        state.addElement(std::move(data_tuple), sort_order_fields, limit);
    }

    /// Offers every tuple retained by `rhs` to `place`.
    void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override
    {
        /// `rhs` is const, so its tuples cannot be stolen: a retained tuple is copied.
        auto & rhs_values = this->data(rhs).values;
        auto & state = this->data(place);
        for (auto & rhs_element : rhs_values)
            state.addElement(std::move(rhs_element), sort_order_fields, limit);
    }

    /// Writes the number of retained tuples followed by each tuple in the argument type's
    /// default binary serialization.
    void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        auto & values = this->data(place).values;
        size_t size = values.size();
        DB::writeVarUInt(size, buf);
        for (const auto & value : values)
            serialization->serializeBinary(value, buf, {});
    }

    /// Reads back what serialize() wrote, appending the tuples to the state in stored order.
    void deserialize(
        DB::AggregateDataPtr __restrict place, DB::ReadBuffer & buf, std::optional<size_t> /* version */, DB::Arena *) const override
    {
        size_t size = 0;
        DB::readVarUInt(size, buf);
        auto & values = this->data(place).values;
        values.reserve(size);
        for (size_t i = 0; i < size; ++i)
        {
            DB::Field data;
            serialization->deserializeBinary(data, buf, {});
            values.emplace_back(data.safeGet<DB::Tuple>());
        }
    }

    /// Appends the sorted, row-numbered result array for this state to `to`.
    void insertResultInto(DB::AggregateDataPtr __restrict place, DB::IColumn & to, DB::Arena * /*arena*/) const override
    {
        this->data(place).insertResultInto(to, limit, sort_order_fields);
    }

    /// NOTE(review): no method of this class uses its arena argument; returning true looks
    /// merely conservative. Confirm against the engine's expectations before changing it.
    bool allocatesMemoryInArena() const override { return true; }

private:
    size_t limit = 0;
    SortOrderFields sort_order_fields;
    DB::SerializationPtr serialization;

    /// Parses `order_by_clause` as an ORDER BY expression list and maps every key to the index
    /// of the tuple element with the same name.
    /// Throws BAD_ARGUMENTS when a key is not a plain identifier or names no tuple element.
    SortOrderFields parseSortOrderFields(const String & order_by_clause) const
    {
        DB::ParserOrderByExpressionList order_by_parser;
        auto order_by_ast = DB::parseQuery(order_by_parser, order_by_clause, 1000, 1000, 1000);
        const auto * expression_list_ast = assert_cast<const DB::ASTExpressionList *>(order_by_ast.get());
        const auto & tuple_element_names = assert_cast<const DB::DataTypeTuple *>(argument_types[0].get())->getElementNames();

        SortOrderFields fields;
        fields.reserve(expression_list_ast->children.size());
        for (const auto & child : expression_list_ast->children)
        {
            /// The clause is caller-provided text, so the node types are checked at runtime
            /// instead of being assumed: a key such as `a + 1` is not an identifier.
            const auto * order_by_element_ast = typeid_cast<const DB::ASTOrderByElement *>(child.get());
            const DB::ASTIdentifier * ident_ast = nullptr;
            if (order_by_element_ast && !order_by_element_ast->children.empty())
                ident_ast = typeid_cast<const DB::ASTIdentifier *>(order_by_element_ast->children[0].get());
            if (!ident_ast)
            {
                throw DB::Exception(
                    DB::ErrorCodes::BAD_ARGUMENTS,
                    "{} can only order by names of tuple elements, but got order clause: {}",
                    getName(),
                    order_by_clause);
            }

            const auto & ident_name = ident_ast->shortName();
            const auto name_pos = std::find(tuple_element_names.begin(), tuple_element_names.end(), ident_name);
            if (name_pos == tuple_element_names.end())
            {
                throw DB::Exception(
                    DB::ErrorCodes::BAD_ARGUMENTS, "Not found column {} in tuple {}", ident_name, argument_types[0]->getName());
            }

            SortOrderField field;
            field.pos = std::distance(tuple_element_names.begin(), name_pos);
            /// Stored as flags: direction = "ascending", nulls_direction = "NULLs first".
            /// The AST's nulls_direction is relative to the sort direction, hence the two cases.
            field.direction = order_by_element_ast->direction == 1;
            field.nulls_direction
                = field.direction ? order_by_element_ast->nulls_direction == -1 : order_by_element_ast->nulls_direction == 1;
            fields.push_back(field);
        }
        return fields;
    }
};
/// Factory callback: creates a RowNumGroupArraySorted instance for the given argument types
/// and parameters. Throws BAD_ARGUMENTS unless there is exactly one argument and it is a tuple.
DB::AggregateFunctionPtr createAggregateFunctionRowNumGroupArray(
    const std::string & name, const DB::DataTypes & argument_types, const DB::Array & parameters, const DB::Settings *)
{
    if (argument_types.size() != 1 || !typeid_cast<const DB::DataTypeTuple *>(argument_types[0].get()))
        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs exactly one tuple argument", name);
    return std::make_shared<RowNumGroupArraySorted>(argument_types[0], parameters);
}
/// Registers the aggregate function with `factory` under the name "rowNumGroupArraySorted".
void registerAggregateFunctionRowNumGroup(DB::AggregateFunctionFactory & factory)
{
    DB::AggregateFunctionProperties function_properties{
        .returns_default_when_only_null = false,
        .is_order_dependent = false,
    };
    factory.registerFunction("rowNumGroupArraySorted", {createAggregateFunctionRowNumGroupArray, function_properties});
}
}