blob: c3f8707957ec680373cd23aba0488cf66fb33942 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "runtime/runtime_predicate.h"
#include <memory>
#include "common/compiler_util.h" // IWYU pragma: keep
#include "common/exception.h"
#include "common/status.h"
#include "olap/accept_null_predicate.h"
#include "olap/column_predicate.h"
#include "olap/predicate_creator.h"
namespace doris::vectorized {
RuntimePredicate::RuntimePredicate(const TTopnFilterDesc& desc)
: _nulls_first(desc.null_first), _is_asc(desc.is_asc) {
DCHECK(!desc.target_node_id_to_target_expr.empty());
for (auto p : desc.target_node_id_to_target_expr) {
_contexts[p.first].expr = p.second;
}
PrimitiveType type = thrift_to_type(desc.target_node_id_to_target_expr.begin()
->second.nodes[0]
.type.types[0]
.scalar_type.type);
if (!_init(type)) {
std::stringstream ss;
desc.target_node_id_to_target_expr.begin()->second.nodes[0].printTo(ss);
throw Exception(ErrorCode::INTERNAL_ERROR, "meet invalid type, type={}, expr={}", int(type),
ss.str());
}
// For ASC sort, create runtime predicate col_name <= max_top_value
// since values that > min_top_value are large than any value in current topn values
// For DESC sort, create runtime predicate col_name >= min_top_value
// since values that < min_top_value are less than any value in current topn values
_pred_constructor = _is_asc ? create_comparison_predicate<PredicateType::LE>
: create_comparison_predicate<PredicateType::GE>;
}
void RuntimePredicate::init_target(
int32_t target_node_id, phmap::flat_hash_map<int, SlotDescriptor*> slot_id_to_slot_desc) {
std::unique_lock<std::shared_mutex> wlock(_rwlock);
check_target_node_id(target_node_id);
if (target_is_slot(target_node_id)) {
_contexts[target_node_id].col_name =
slot_id_to_slot_desc[get_texpr(target_node_id).nodes[0].slot_ref.slot_id]
->col_name();
}
_detected_target = true;
}
template <PrimitiveType type>
std::string get_normal_value(const Field& field) {
using ValueType = typename PrimitiveTypeTraits<type>::CppType;
return cast_to_string<type, ValueType>(field.get<ValueType>(), 0);
}
std::string get_date_value(const Field& field) {
using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType;
ValueType value;
Int64 v = field.get<Int64>();
auto* p = (VecDateTimeValue*)&v;
value.from_olap_date(p->to_olap_date());
value.cast_to_date();
return cast_to_string<TYPE_DATE, ValueType>(value, 0);
}
std::string get_datetime_value(const Field& field) {
using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType;
ValueType value;
Int64 v = field.get<Int64>();
auto* p = (VecDateTimeValue*)&v;
value.from_olap_datetime(p->to_olap_datetime());
value.to_datetime();
return cast_to_string<TYPE_DATETIME, ValueType>(value, 0);
}
std::string get_time_value(const Field& field) {
using ValueType = typename PrimitiveTypeTraits<TYPE_TIMEV2>::CppType;
ValueType value = field.get<ValueType>();
return cast_to_string<TYPE_TIMEV2, ValueType>(value, 0);
}
std::string get_decimalv2_value(const Field& field) {
// can NOT use PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since
// it is DecimalV2Value and Decimal128V2 can not convert to it implicitly
using ValueType = Decimal128V2::NativeType;
auto v = field.get<DecimalField<Decimal128V2>>();
// use TYPE_DECIMAL128I instead of TYPE_DECIMALV2 since v.get_scale()
// is always 9 for DECIMALV2
return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(), v.get_scale());
}
template <PrimitiveType type>
std::string get_decimal_value(const Field& field) {
using ValueType = typename PrimitiveTypeTraits<type>::CppType;
auto v = field.get<DecimalField<ValueType>>();
return cast_to_string<type, ValueType>(v.get_value(), v.get_scale());
}
bool RuntimePredicate::_init(PrimitiveType type) {
// set get value function
switch (type) {
case PrimitiveType::TYPE_BOOLEAN: {
_get_value_fn = get_normal_value<TYPE_BOOLEAN>;
break;
}
case PrimitiveType::TYPE_TINYINT: {
_get_value_fn = get_normal_value<TYPE_TINYINT>;
break;
}
case PrimitiveType::TYPE_SMALLINT: {
_get_value_fn = get_normal_value<TYPE_SMALLINT>;
break;
}
case PrimitiveType::TYPE_INT: {
_get_value_fn = get_normal_value<TYPE_INT>;
break;
}
case PrimitiveType::TYPE_BIGINT: {
_get_value_fn = get_normal_value<TYPE_BIGINT>;
break;
}
case PrimitiveType::TYPE_LARGEINT: {
_get_value_fn = get_normal_value<TYPE_LARGEINT>;
break;
}
case PrimitiveType::TYPE_CHAR:
case PrimitiveType::TYPE_VARCHAR:
case PrimitiveType::TYPE_STRING: {
_get_value_fn = [](const Field& field) { return field.get<String>(); };
break;
}
case PrimitiveType::TYPE_DATEV2: {
_get_value_fn = get_normal_value<TYPE_DATEV2>;
break;
}
case PrimitiveType::TYPE_DATETIMEV2: {
_get_value_fn = get_normal_value<TYPE_DATETIMEV2>;
break;
}
case PrimitiveType::TYPE_DATE: {
_get_value_fn = get_date_value;
break;
}
case PrimitiveType::TYPE_DATETIME: {
_get_value_fn = get_datetime_value;
break;
}
case PrimitiveType::TYPE_TIMEV2: {
_get_value_fn = get_time_value;
break;
}
case PrimitiveType::TYPE_DECIMAL32: {
_get_value_fn = get_decimal_value<TYPE_DECIMAL32>;
break;
}
case PrimitiveType::TYPE_DECIMAL64: {
_get_value_fn = get_decimal_value<TYPE_DECIMAL64>;
break;
}
case PrimitiveType::TYPE_DECIMALV2: {
_get_value_fn = get_decimalv2_value;
break;
}
case PrimitiveType::TYPE_DECIMAL128I: {
_get_value_fn = get_decimal_value<TYPE_DECIMAL128I>;
break;
}
case PrimitiveType::TYPE_DECIMAL256: {
_get_value_fn = get_decimal_value<TYPE_DECIMAL256>;
break;
}
case PrimitiveType::TYPE_IPV4: {
_get_value_fn = get_normal_value<TYPE_IPV4>;
break;
}
case PrimitiveType::TYPE_IPV6: {
_get_value_fn = get_normal_value<TYPE_IPV6>;
break;
}
default:
return false;
}
return true;
}
Status RuntimePredicate::update(const Field& value) {
std::unique_lock<std::shared_mutex> wlock(_rwlock);
// skip null value
if (value.is_null()) {
return Status::OK();
}
bool updated = false;
if (UNLIKELY(_orderby_extrem.is_null())) {
_orderby_extrem = value;
updated = true;
} else {
if ((_is_asc && value < _orderby_extrem) || (!_is_asc && value > _orderby_extrem)) {
_orderby_extrem = value;
updated = true;
}
}
_has_value = true;
if (!updated) {
return Status::OK();
}
for (auto p : _contexts) {
auto ctx = p.second;
if (!ctx.tablet_schema) {
continue;
}
const auto& column = *DORIS_TRY(ctx.tablet_schema->column(ctx.col_name));
std::unique_ptr<ColumnPredicate> pred {
_pred_constructor(column.get_vec_type(), ctx.predicate->column_id(),
_get_value_fn(_orderby_extrem), false, _predicate_arena)};
// For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL
// since ORDER BY ASC/DESC should get NULL first but pred returns NULL
// and NULL in where predicate will be treated as FALSE
if (_nulls_first) {
pred = AcceptNullPredicate::create_unique(pred.release());
}
((SharedPredicate*)ctx.predicate.get())->set_nested(pred.release());
}
return Status::OK();
}
} // namespace doris::vectorized