blob: fad2b6ab3bd29364dcf1bc021b45861f3ca7410d [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
#pragma once
#include <glog/logging.h>
#include <stddef.h>
#include <algorithm>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "common/status.h"
#include "vec/columns/column.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_struct.h"
#include "vec/columns/column_vector.h"
#include "vec/core/block.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/functions/function.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
struct ColumnRowRef {
ENABLE_FACTORY_CREATOR(ColumnRowRef);
ColumnPtr column;
size_t row_idx;
// equals when call set insert, this operator will be used
bool operator==(const ColumnRowRef& other) const {
return column->compare_at(row_idx, other.row_idx, *other.column, 0) == 0;
}
// compare
bool operator<(const ColumnRowRef& other) const {
return column->compare_at(row_idx, other.row_idx, *other.column, 0) < 0;
}
// when call set find, will use hash to find
size_t operator()(const ColumnRowRef& a) const {
uint32_t hash_val = 0;
a.column->update_crc_with_value(a.row_idx, a.row_idx + 1, hash_val, nullptr);
return hash_val;
}
};
struct CollectionInState {
ENABLE_FACTORY_CREATOR(CollectionInState)
std::unordered_set<ColumnRowRef, ColumnRowRef> args_set;
bool null_in_set = false;
};
template <bool negative>
class FunctionCollectionIn : public IFunction {
public:
static constexpr auto name = negative ? "collection_not_in" : "collection_in";
static FunctionPtr create() { return std::make_shared<FunctionCollectionIn>(); }
String get_name() const override { return name; }
bool is_variadic() const override { return true; }
size_t get_number_of_arguments() const override { return 0; }
DataTypePtr get_return_type_impl(const DataTypes& args) const override {
for (const auto& arg : args) {
if (arg->is_nullable()) {
return make_nullable(std::make_shared<DataTypeUInt8>());
}
}
return std::make_shared<DataTypeUInt8>();
}
bool use_default_implementation_for_nulls() const override { return false; }
// make data in context into a set
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
if (scope == FunctionContext::THREAD_LOCAL) {
return Status::OK();
}
int num_args = context->get_num_args();
DCHECK(num_args >= 1);
std::shared_ptr<CollectionInState> state = std::make_shared<CollectionInState>();
context->set_function_state(scope, state);
DataTypePtr args_type = remove_nullable(context->get_arg_type(0));
MutableColumnPtr args_column_ptr = args_type->create_column();
for (int i = 1; i < num_args; i++) {
// FE should make element type consistent and
// equalize the length of the elements in struct
const auto& const_column_ptr = context->get_constant_col(i);
// Types like struct, array, and map only support constant expressions.
DCHECK(const_column_ptr != nullptr);
const auto& [col, _] = unpack_if_const(const_column_ptr->column_ptr);
if (col->is_nullable()) {
const auto* null_col =
vectorized::check_and_get_column<vectorized::ColumnNullable>(col.get());
if (null_col->has_null()) {
state->null_in_set = true;
} else {
args_column_ptr->insert_from(null_col->get_nested_column(), 0);
}
} else {
args_column_ptr->insert_from(*col, 0);
}
}
ColumnPtr column_ptr = std::move(args_column_ptr);
// make collection ref into set
auto col_size = column_ptr->size();
for (size_t i = 0; i < col_size; i++) {
state->args_set.insert({column_ptr, i});
}
return Status::OK();
}
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const override {
auto in_state = reinterpret_cast<CollectionInState*>(
context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
if (!in_state) {
return Status::RuntimeError("function context for function '{}' must have Set;",
get_name());
}
const auto& args_set = in_state->args_set;
const bool null_in_set = in_state->null_in_set;
auto res = ColumnUInt8::create();
ColumnUInt8::Container& vec_res = res->get_data();
vec_res.resize(input_rows_count);
ColumnUInt8::MutablePtr col_null_map_to;
col_null_map_to = ColumnUInt8::create(input_rows_count, false);
auto& vec_null_map_to = col_null_map_to->get_data();
const ColumnWithTypeAndName& left_arg = block.get_by_position(arguments[0]);
const auto& [materialized_column, col_const] = unpack_if_const(left_arg.column);
auto materialized_column_not_null = materialized_column;
if (materialized_column_not_null->is_nullable()) {
materialized_column_not_null = assert_cast<ColumnPtr>(
vectorized::check_and_get_column<vectorized::ColumnNullable>(
materialized_column_not_null.get())
->get_nested_column_ptr());
}
for (size_t i = 0; i < input_rows_count; ++i) {
bool find = args_set.find({materialized_column_not_null, i}) != args_set.end();
if constexpr (negative) {
vec_res[i] = !find;
} else {
vec_res[i] = find;
}
if (null_in_set) {
vec_null_map_to[i] = negative == vec_res[i];
} else {
vec_null_map_to[i] = false;
}
}
if (block.get_by_position(result).type->is_nullable()) {
block.replace_by_position(
result, ColumnNullable::create(std::move(res), std::move(col_null_map_to)));
} else {
block.replace_by_position(result, std::move(res));
}
return Status::OK();
}
};
} // namespace doris::vectorized
#include "common/compile_check_end.h"