| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include <vector> |
| #include <AggregateFunctions/AggregateFunctionFactory.h> |
| #include <AggregateFunctions/IAggregateFunction.h> |
| #include <AggregateFunctions/IAggregateFunction_fwd.h> |
| #include <Columns/ColumnArray.h> |
| #include <Core/Field.h> |
| #include <Core/Settings.h> |
| #include <DataTypes/DataTypeArray.h> |
| #include <DataTypes/DataTypeTuple.h> |
| #include <DataTypes/DataTypesNumber.h> |
| #include <DataTypes/IDataType.h> |
| #include <DataTypes/Serializations/ISerialization.h> |
| #include <IO/VarInt.h> |
| #include <Interpreters/Context.h> |
| #include <Parsers/ASTExpressionList.h> |
| #include <Parsers/ASTIdentifier.h> |
| #include <Parsers/ASTOrderByElement.h> |
| #include <Parsers/ExpressionListParsers.h> |
| #include <Parsers/parseQuery.h> |
| #include <Common/Exception.h> |
| #include <Common/logger_useful.h> |
| |
| namespace DB::ErrorCodes |
| { |
| extern const int LOGICAL_ERROR; |
| extern const int BAD_ARGUMENTS; |
| } |
| |
| namespace local_engine |
| { |
| |
| struct SortOrderField |
| { |
| size_t pos = 0; |
| Int8 direction = 0; |
| Int8 nulls_direction = 0; |
| }; |
| using SortOrderFields = std::vector<SortOrderField>; |
| |
| struct RowNumGroupArraySortedData |
| { |
| public: |
| using Data = DB::Tuple; |
| std::vector<Data> values; |
| |
| static bool compare(const Data & lhs, const Data & rhs, const SortOrderFields & sort_orders) |
| { |
| for (const auto & sort_order : sort_orders) |
| { |
| const auto & pos = sort_order.pos; |
| const auto & asc = sort_order.direction; |
| const auto & nulls_first = sort_order.nulls_direction; |
| bool l_is_null = lhs[pos].isNull(); |
| bool r_is_null = rhs[pos].isNull(); |
| if (l_is_null && r_is_null) |
| continue; |
| else if (l_is_null) |
| return nulls_first; |
| else if (r_is_null) |
| return !nulls_first; |
| else if (lhs[pos] < rhs[pos]) |
| return asc; |
| else if (lhs[pos] > rhs[pos]) |
| return !asc; |
| } |
| return false; |
| } |
| |
| ALWAYS_INLINE void heapReplaceTop(const SortOrderFields & sort_orders) |
| { |
| size_t size = values.size(); |
| if (size < 2) |
| return; |
| size_t child_index = 1; |
| if (size > 2 && compare(values[1], values[2], sort_orders)) |
| ++child_index; |
| |
| if (compare(values[child_index], values[0], sort_orders)) |
| return; |
| |
| size_t current_index = 0; |
| auto current = values[current_index]; |
| do |
| { |
| values[current_index] = values[child_index]; |
| current_index = child_index; |
| |
| child_index = 2 * child_index + 1; |
| |
| if (child_index >= size) |
| break; |
| |
| if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1], sort_orders)) |
| ++child_index; |
| } while (!compare(values[child_index], current, sort_orders)); |
| |
| values[current_index] = current; |
| } |
| |
| ALWAYS_INLINE void addElement(const Data && data, const SortOrderFields & sort_orders, size_t max_elements) |
| { |
| if (values.size() >= max_elements) |
| { |
| if (!compare(data, values[0], sort_orders)) |
| return; |
| values[0] = data; |
| heapReplaceTop(sort_orders); |
| return; |
| } |
| values.emplace_back(std::move(data)); |
| auto cmp = [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); }; |
| std::push_heap(values.begin(), values.end(), cmp); |
| } |
| |
| ALWAYS_INLINE void sortAndLimit(size_t max_elements, const SortOrderFields & sort_orders) |
| { |
| ::sort(values.begin(), values.end(), [&sort_orders](const Data & a, const Data & b) { return compare(a, b, sort_orders); }); |
| if (values.size() > max_elements) |
| values.resize(max_elements); |
| } |
| |
| ALWAYS_INLINE void insertResultInto(DB::IColumn & to, size_t max_elements, const SortOrderFields & sort_orders) |
| { |
| auto & result_array = assert_cast<DB::ColumnArray &>(to); |
| auto & result_array_offsets = result_array.getOffsets(); |
| |
| sortAndLimit(max_elements, sort_orders); |
| |
| result_array_offsets.push_back(result_array_offsets.back() + values.size()); |
| |
| if (values.empty()) |
| return; |
| auto & result_array_data = result_array.getData(); |
| for (int i = 0, sz = static_cast<int>(values.size()); i < sz; ++i) |
| { |
| auto & value = values[i]; |
| value.push_back(i + 1); |
| result_array_data.insert(value); |
| } |
| } |
| }; |
| |
| static DB::DataTypePtr getRowNumReultDataType(DB::DataTypePtr data_type) |
| { |
| const auto * tuple_type = typeid_cast<const DB::DataTypeTuple *>(data_type.get()); |
| if (!tuple_type) |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", data_type->getName()); |
| DB::DataTypes element_types = tuple_type->getElements(); |
| std::vector<String> element_names = tuple_type->getElementNames(); |
| element_types.push_back(std::make_shared<DB::DataTypeInt32>()); |
| element_names.push_back("row_num"); |
| auto nested_tuple_type = std::make_shared<DB::DataTypeTuple>(element_types, element_names); |
| return std::make_shared<DB::DataTypeArray>(nested_tuple_type); |
| } |
| |
| // usage: rowNumGroupArraySorted(1, "a asc nulls first, b desc nulls last")(tuple(a,b)) |
| class RowNumGroupArraySorted final : public DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted> |
| { |
| public: |
| explicit RowNumGroupArraySorted(DB::DataTypePtr data_type, const DB::Array & parameters_) |
| : DB::IAggregateFunctionDataHelper<RowNumGroupArraySortedData, RowNumGroupArraySorted>( |
| {data_type}, parameters_, getRowNumReultDataType(data_type)) |
| { |
| if (parameters_.size() != 2) |
| throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs two parameters: limit and order clause", getName()); |
| const auto * tuple_type = typeid_cast<const DB::DataTypeTuple *>(data_type.get()); |
| if (!tuple_type) |
| throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Tuple type is expected, but got: {}", data_type->getName()); |
| |
| limit = parameters_[0].safeGet<UInt64>(); |
| |
| String order_by_clause = parameters_[1].safeGet<String>(); |
| sort_order_fields = parseSortOrderFields(order_by_clause); |
| |
| serialization = data_type->getDefaultSerialization(); |
| } |
| |
| String getName() const override { return "rowNumGroupArraySorted"; } |
| |
| void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena * /*arena*/) const override |
| { |
| auto & data = this->data(place); |
| DB::Tuple data_tuple = (*columns[0])[row_num].safeGet<DB::Tuple>(); |
| this->data(place).addElement(std::move(data_tuple), sort_order_fields, limit); |
| } |
| |
| void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena * /*arena*/) const override |
| { |
| auto & rhs_values = this->data(rhs).values; |
| for (auto & rhs_element : rhs_values) |
| this->data(place).addElement(std::move(rhs_element), sort_order_fields, limit); |
| } |
| |
| void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override |
| { |
| auto & values = this->data(place).values; |
| size_t size = values.size(); |
| DB::writeVarUInt(size, buf); |
| |
| for (const auto & value : values) |
| serialization->serializeBinary(value, buf, {}); |
| } |
| |
| void deserialize( |
| DB::AggregateDataPtr __restrict place, DB::ReadBuffer & buf, std::optional<size_t> /* version */, DB::Arena *) const override |
| { |
| size_t size = 0; |
| DB::readVarUInt(size, buf); |
| |
| auto & values = this->data(place).values; |
| values.reserve(size); |
| for (size_t i = 0; i < size; ++i) |
| { |
| DB::Field data; |
| serialization->deserializeBinary(data, buf, {}); |
| values.emplace_back(data.safeGet<DB::Tuple>()); |
| } |
| } |
| |
| void insertResultInto(DB::AggregateDataPtr __restrict place, DB::IColumn & to, DB::Arena * /*arena*/) const override |
| { |
| this->data(place).insertResultInto(to, limit, sort_order_fields); |
| } |
| |
| bool allocatesMemoryInArena() const override { return true; } |
| |
| private: |
| size_t limit = 0; |
| SortOrderFields sort_order_fields; |
| DB::SerializationPtr serialization; |
| |
| SortOrderFields parseSortOrderFields(const String & order_by_clause) const |
| { |
| DB::ParserOrderByExpressionList order_by_parser; |
| auto order_by_ast = DB::parseQuery(order_by_parser, order_by_clause, 1000, 1000, 1000); |
| SortOrderFields fields; |
| const auto expression_list_ast = assert_cast<const DB::ASTExpressionList *>(order_by_ast.get()); |
| const auto & tuple_element_names = assert_cast<const DB::DataTypeTuple *>(argument_types[0].get())->getElementNames(); |
| for (const auto & child : expression_list_ast->children) |
| { |
| const auto * order_by_element_ast = assert_cast<const DB::ASTOrderByElement *>(child.get()); |
| const auto * ident_ast = assert_cast<const DB::ASTIdentifier *>(order_by_element_ast->children[0].get()); |
| const auto & ident_name = ident_ast->shortName(); |
| |
| |
| SortOrderField field; |
| field.direction = order_by_element_ast->direction == 1; |
| field.nulls_direction |
| = field.direction ? order_by_element_ast->nulls_direction == -1 : order_by_element_ast->nulls_direction == 1; |
| |
| auto name_pos = std::find(tuple_element_names.begin(), tuple_element_names.end(), ident_name); |
| if (name_pos == tuple_element_names.end()) |
| { |
| throw DB::Exception( |
| DB::ErrorCodes::BAD_ARGUMENTS, "Not found column {} in tuple {}", ident_name, argument_types[0]->getName()); |
| } |
| field.pos = std::distance(tuple_element_names.begin(), name_pos); |
| fields.push_back(field); |
| } |
| return fields; |
| } |
| }; |
| |
| |
| DB::AggregateFunctionPtr createAggregateFunctionRowNumGroupArray( |
| const std::string & name, const DB::DataTypes & argument_types, const DB::Array & parameters, const DB::Settings *) |
| { |
| if (argument_types.size() != 1 || !typeid_cast<const DB::DataTypeTuple *>(argument_types[0].get())) |
| throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, " {} Nees only one tuple argument", name); |
| return std::make_shared<RowNumGroupArraySorted>(argument_types[0], parameters); |
| } |
| |
| void registerAggregateFunctionRowNumGroup(DB::AggregateFunctionFactory & factory) |
| { |
| DB::AggregateFunctionProperties properties = {.returns_default_when_only_null = false, .is_order_dependent = false}; |
| |
| factory.registerFunction("rowNumGroupArraySorted", {createAggregateFunctionRowNumGroupArray, properties}); |
| } |
| } |