// (source-viewer residue) blob: 8e1db0f897dfed884cc4666cca96511ac77e598d [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <utility>
#include "cast_to_array.h"
#include "cast_to_boolean.h"
#include "cast_to_date.h"
#include "cast_to_decimal.h"
#include "cast_to_float.h"
#include "cast_to_int.h"
#include "cast_to_ip.h"
#include "cast_to_jsonb.h"
#include "cast_to_map.h"
#include "cast_to_string.h"
#include "cast_to_struct.h"
#include "cast_to_timestamptz.h"
#include "cast_to_variant.h"
#include "runtime/primitive_type.h"
#include "vec/data_types/data_type_agg_state.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_number.h" // IWYU pragma: keep
#include "vec/functions/simple_function_factory.h"
namespace doris::vectorized {
namespace CastWrapper {
/// Picks the wrapper used for CAST(... AS HLL).
///
/// Only a String source is accepted: it is converted by parsing through
/// cast_from_string_to_generic. Every other source type gets the wrapper built by
/// create_unsupport_wrapper with an explanatory message.
/// `context` and `to_type` are not consulted.
WrapperType create_hll_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped,
                               const DataTypeHLL& to_type) {
    const bool source_is_string =
            check_and_get_data_type<DataTypeString>(from_type_untyped.get()) != nullptr;
    if (!source_is_string) {
        return CastWrapper::create_unsupport_wrapper("Cast to HLL only support from String type");
    }
    return cast_from_string_to_generic;
}
/// Picks the wrapper used for CAST(... AS BITMAP).
///
/// Only a String source is accepted: it is converted by parsing through
/// cast_from_string_to_generic. Every other source type gets the wrapper built by
/// create_unsupport_wrapper with an explanatory message.
/// `context` and `to_type` are not consulted.
WrapperType create_bitmap_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped,
                                  const DataTypeBitMap& to_type) {
    const bool source_is_string =
            check_and_get_data_type<DataTypeString>(from_type_untyped.get()) != nullptr;
    if (!source_is_string) {
        return CastWrapper::create_unsupport_wrapper(
                "Cast to BitMap only support from String type");
    }
    return cast_from_string_to_generic;
}
/// Picks the wrapper used for CAST(... AS VARBINARY).
///
/// Only a String source is accepted: it is converted by parsing through
/// cast_from_string_to_generic. Every other source type gets the wrapper built by
/// create_unsupport_wrapper with an explanatory message.
WrapperType create_varbinary_wrapper(const DataTypePtr& from_type_untyped) {
    const bool source_is_string =
            check_and_get_data_type<DataTypeString>(from_type_untyped.get()) != nullptr;
    if (!source_is_string) {
        return CastWrapper::create_unsupport_wrapper(
                "Cast to Varbinary only support from String type");
    }
    return cast_from_string_to_generic;
}
/// Entry point for building a cast wrapper from `from_type` to `to_type`.
///
/// A NULL-literal source is handled here: it can only be cast to a nullable target, and the
/// result column is then filled from the target type's default constant. Every other source
/// is delegated to prepare_remove_nullable().
WrapperType prepare_unpack_dictionaries(FunctionContext* context, const DataTypePtr& from_type,
                                        const DataTypePtr& to_type) {
    // Common case first: anything that is not a NULL literal goes down the regular path.
    if (!from_type->is_null_literal()) {
        return prepare_remove_nullable(context, from_type, to_type);
    }

    if (!to_type->is_nullable()) {
        return CastWrapper::create_unsupport_wrapper("Cannot convert NULL to a non-nullable type");
    }

    return [](FunctionContext* context, Block& block, const ColumnNumbers&, uint32_t result,
              size_t input_rows_count, const NullMap::value_type* null_map = nullptr) {
        /// TODO: remove this in the future.
        auto& result_entry = block.get_by_position(result);
        auto default_const_column =
                result_entry.type->create_column_const_with_default_value(input_rows_count);
        result_entry.column = default_const_column->convert_to_full_column_if_const();
        return Status::OK();
    };
}
/// Decides the flag that prepare_remove_nullable() passes to unnest_nullable() for the source
/// column (presumably: replace the nested data under NULL rows with default values before the
/// cast runs, so stale values there cannot trip overflow checks -- confirm against
/// unnest_nullable()).
///
/// Returns false when:
///  - the two types are equal;
///  - the target or the source is not a number, decimal, date-like (V1/V2) or TimeV2 type;
///  - neither side is a decimal.
/// Otherwise returns true when the conversion may narrow the integral part or when the scale-up
/// multiplication may overflow, as computed below. Both conditions depend on
/// context->check_overflow_for_decimal(); `context` is dereferenced without a null check.
bool need_replace_null_data_to_default(FunctionContext* context, const DataTypePtr& from_type,
const DataTypePtr& to_type) {
// Identical types: nothing to decide.
if (from_type->equals(*to_type)) {
return false;
}
// Outer dispatch callback: invoked with the concrete target data type.
auto make_default_wrapper = [&](const auto& types) -> bool {
using Types = std::decay_t<decltype(types)>;
using ToDataType = typename Types::LeftType;
// Only number / decimal / date-like / TimeV2 targets are considered.
if constexpr (!(IsDataTypeDecimalOrNumber<ToDataType> || IsDatelikeV1Types<ToDataType> ||
IsDatelikeV2Types<ToDataType> ||
std::is_same_v<ToDataType, DataTypeTimeV2>)) {
return false;
}
// Inner dispatch: resolve the concrete source data type.
return call_on_index_and_data_type<
ToDataType>(from_type->get_primitive_type(), [&](const auto& types2) -> bool {
using Types2 = std::decay_t<decltype(types2)>;
using FromDataType = typename Types2::LeftType;
// Same restriction for the source side.
if constexpr (!(IsDataTypeDecimalOrNumber<FromDataType> ||
IsDatelikeV1Types<FromDataType> || IsDatelikeV2Types<FromDataType> ||
std::is_same_v<FromDataType, DataTypeTimeV2>)) {
return false;
}
// The digit analysis below only applies when at least one side is a decimal.
if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
// Non-decimal source: digit count comes from the field type, scale is 0.
UInt32 from_precision = NumberTraits::max_ascii_len<FromFieldType>();
UInt32 from_scale = 0;
if constexpr (IsDataTypeDecimal<FromDataType>) {
const auto* from_decimal_type =
check_and_get_data_type<FromDataType>(from_type.get());
// NOTE(review): this uses the digit capacity of the decimal's native type rather than
// the declared precision of from_type -- presumably a deliberately conservative bound;
// confirm.
from_precision =
NumberTraits::max_ascii_len<typename FromFieldType::NativeType>();
from_scale = from_decimal_type->get_scale();
}
UInt32 to_max_digits = 0;
UInt32 to_precision = 0;
UInt32 to_scale = 0;
if constexpr (IsDataTypeDecimal<ToDataType>) {
// Decimal target: capacity from the native type, precision/scale from the declared type.
to_max_digits = NumberTraits::max_ascii_len<typename ToFieldType::NativeType>();
const auto* to_decimal_type =
check_and_get_data_type<ToDataType>(to_type.get());
to_precision = to_decimal_type->get_precision();
ToDataType::check_type_precision(to_precision);
to_scale = to_decimal_type->get_scale();
ToDataType::check_type_scale(to_scale);
}
if constexpr (IsIntegralV<ToFieldType> || std::is_floating_point_v<ToFieldType>) {
// Integer / floating-point target: precision equals the type's digit capacity, scale stays 0.
to_max_digits = NumberTraits::max_ascii_len<ToFieldType>();
to_precision = to_max_digits;
}
// With overflow checking on: the target has no more integral digits than the source.
// NOTE(review): UInt32 subtraction -- would wrap if a scale exceeded its precision;
// presumably excluded by the type checks above, confirm.
bool narrow_integral = context->check_overflow_for_decimal() &&
(to_precision - to_scale) <= (from_precision - from_scale);
// With overflow checking on, assume the scale adjustment may overflow ...
bool multiply_may_overflow = context->check_overflow_for_decimal();
if (to_scale > from_scale) {
// ... unless, when scaling up, the widened digit count stays below the target capacity.
multiply_may_overflow &=
(from_precision + to_scale - from_scale) >= to_max_digits;
}
return narrow_integral || multiply_may_overflow;
}
return false;
});
};
// Resolve the concrete target data type and run the analysis above.
return call_on_index_and_data_type<void>(to_type->get_primitive_type(), make_default_wrapper);
}
/// Builds the cast wrapper that takes care of a Nullable target type.
///
/// When `to_type` is not nullable the work is handed straight to prepare_impl(). Otherwise the
/// returned wrapper, on each invocation:
///   1. appends two temporary columns to `block`: the un-nested result slot and the un-nested
///      source argument (optionally with the data under NULL rows replaced by defaults, see
///      need_replace_null_data_to_default());
///   2. runs the non-nullable cast between them, passing the source null map when the argument
///      column is a ColumnNullable;
///   3. on success, wraps the nested result back into a nullable column via wrap_in_nullable();
///   4. removes the temporary columns again.
///
/// The temporary columns are removed on the error path as well, so a failing cast no longer
/// leaves the caller's block with two extra columns.
WrapperType prepare_remove_nullable(FunctionContext* context, const DataTypePtr& from_type,
                                    const DataTypePtr& to_type) {
    /// No pre-processing or post-processing is needed for a non-nullable target.
    if (!to_type->is_nullable()) {
        return prepare_impl(context, from_type, to_type);
    }

    return [from_type, to_type](FunctionContext* context, Block& block,
                                const ColumnNumbers& arguments, uint32_t result,
                                size_t input_rows_count,
                                const NullMap::value_type* null_map = nullptr) {
        auto from_type_not_nullable = remove_nullable(from_type);
        auto to_type_not_nullable = remove_nullable(to_type);

        bool replace_null_data_to_default = need_replace_null_data_to_default(
                context, from_type_not_nullable, to_type_not_nullable);

        // Temporary slots are appended at the end of the block; remember their positions.
        auto nested_result_index = block.columns();
        block.insert(block.get_by_position(result).unnest_nullable());
        auto nested_source_index = block.columns();
        block.insert(block.get_by_position(arguments[0])
                             .unnest_nullable(replace_null_data_to_default));

        // Hand the original null map to the nested cast when the argument is nullable.
        const auto& arg_col = block.get_by_position(arguments[0]);
        const NullMap::value_type* arg_null_map = nullptr;
        if (const auto* nullable = check_and_get_column<ColumnNullable>(*arg_col.column)) {
            arg_null_map = nullable->get_null_map_data().data();
        }

        Status status = prepare_impl(context, from_type_not_nullable, to_type_not_nullable)(
                context, block, {nested_source_index}, nested_result_index, input_rows_count,
                arg_null_map);

        if (status.ok()) {
            block.get_by_position(result).column =
                    wrap_in_nullable(block.get_by_position(nested_result_index).column, block,
                                     arguments, input_rows_count);
        }

        // Always drop the temporaries, also when the nested cast failed. The source slot was
        // appended last, so erase it first to keep the result slot's index valid.
        block.erase(nested_source_index);
        block.erase(nested_result_index);
        return status;
    };
}
/// Chooses the conversion wrapper for one (source type, target type) pair.
///
/// When the target is Nullable, prepare_remove_nullable() strips Nullable from both types before
/// calling this function. Both types are first mapped through get_serialized_type(). Resolution
/// order:
///   1. equal types            -> identity wrapper;
///   2. VARIANT target         -> cast-to-variant wrapper;
///   3. VARIANT source         -> cast-from-variant wrapper;
///   4. JSONB source           -> cast-from-jsonb wrapper;
///   5. otherwise              -> dispatch on the target's primitive type.
/// A target that matches none of the cases gets the wrapper built by create_unsupport_wrapper
/// from the two type names. `context` may be null; it is only dereferenced here behind a null
/// check (the per-type factories receive it as is).
WrapperType prepare_impl(FunctionContext* context, const DataTypePtr& origin_from_type,
                         const DataTypePtr& origin_to_type) {
    auto dst_type = get_serialized_type(origin_to_type);
    auto src_type = get_serialized_type(origin_from_type);

    if (src_type->equals(*dst_type)) {
        return create_identity_wrapper(src_type);
    }

    const auto dst_primitive = dst_type->get_primitive_type();
    const auto src_primitive = src_type->get_primitive_type();

    // VARIANT has to be checked before anything else, target side first.
    if (dst_primitive == PrimitiveType::TYPE_VARIANT) {
        return create_cast_to_variant_wrapper(src_type,
                                              static_cast<const DataTypeVariant&>(*dst_type));
    }
    if (src_primitive == PrimitiveType::TYPE_VARIANT) {
        return create_cast_from_variant_wrapper(static_cast<const DataTypeVariant&>(*src_type),
                                                dst_type);
    }

    // A JSONB source is handled uniformly, whatever the target is.
    if (src_primitive == PrimitiveType::TYPE_JSONB) {
        return create_cast_from_jsonb_wrapper(static_cast<const DataTypeJsonb&>(*src_type),
                                              dst_type,
                                              context ? context->jsonb_string_as_string() : false);
    }

    switch (dst_primitive) {
    // Boolean and integers.
    case PrimitiveType::TYPE_BOOLEAN:
        return create_boolean_wrapper(context, src_type);
    case PrimitiveType::TYPE_TINYINT:
        return create_int_wrapper<DataTypeInt8>(context, src_type);
    case PrimitiveType::TYPE_SMALLINT:
        return create_int_wrapper<DataTypeInt16>(context, src_type);
    case PrimitiveType::TYPE_INT:
        return create_int_wrapper<DataTypeInt32>(context, src_type);
    case PrimitiveType::TYPE_BIGINT:
        return create_int_wrapper<DataTypeInt64>(context, src_type);
    case PrimitiveType::TYPE_LARGEINT:
        return create_int_wrapper<DataTypeInt128>(context, src_type);

    // Floating point.
    case PrimitiveType::TYPE_FLOAT:
        return create_float_wrapper<DataTypeFloat32>(context, src_type);
    case PrimitiveType::TYPE_DOUBLE:
        return create_float_wrapper<DataTypeFloat64>(context, src_type);

    // Date and time.
    case PrimitiveType::TYPE_DATE:
        return create_datelike_wrapper<DataTypeDate>(context, src_type);
    case PrimitiveType::TYPE_DATETIME:
        return create_datelike_wrapper<DataTypeDateTime>(context, src_type);
    case PrimitiveType::TYPE_DATEV2:
        return create_datelike_wrapper<DataTypeDateV2>(context, src_type);
    case PrimitiveType::TYPE_DATETIMEV2:
        return create_datelike_wrapper<DataTypeDateTimeV2>(context, src_type);
    case PrimitiveType::TYPE_TIMESTAMPTZ:
        return create_timestamptz_wrapper(context, src_type);
    case PrimitiveType::TYPE_TIMEV2:
        return create_datelike_wrapper<DataTypeTimeV2>(context, src_type);

    // IP addresses.
    case PrimitiveType::TYPE_IPV4:
        return create_ip_wrapper<DataTypeIPv4>(context, src_type);
    case PrimitiveType::TYPE_IPV6:
        return create_ip_wrapper<DataTypeIPv6>(context, src_type);

    // Decimals.
    case PrimitiveType::TYPE_DECIMALV2:
        return create_decimal_wrapper<DataTypeDecimalV2>(context, src_type);
    case PrimitiveType::TYPE_DECIMAL32:
        return create_decimal_wrapper<DataTypeDecimal32>(context, src_type);
    case PrimitiveType::TYPE_DECIMAL64:
        return create_decimal_wrapper<DataTypeDecimal64>(context, src_type);
    case PrimitiveType::TYPE_DECIMAL128I:
        return create_decimal_wrapper<DataTypeDecimal128>(context, src_type);
    case PrimitiveType::TYPE_DECIMAL256:
        return create_decimal_wrapper<DataTypeDecimal256>(context, src_type);

    // Strings.
    case PrimitiveType::TYPE_CHAR:
    case PrimitiveType::TYPE_VARCHAR:
    case PrimitiveType::TYPE_STRING:
        return create_string_wrapper(src_type);

    // Nested types.
    case PrimitiveType::TYPE_ARRAY:
        return create_array_wrapper(context, src_type,
                                    static_cast<const DataTypeArray&>(*dst_type));
    case PrimitiveType::TYPE_STRUCT:
        return create_struct_wrapper(context, src_type,
                                     static_cast<const DataTypeStruct&>(*dst_type));
    case PrimitiveType::TYPE_MAP:
        return create_map_wrapper(context, src_type, static_cast<const DataTypeMap&>(*dst_type));

    // Remaining targets.
    case PrimitiveType::TYPE_HLL:
        return create_hll_wrapper(context, src_type, static_cast<const DataTypeHLL&>(*dst_type));
    case PrimitiveType::TYPE_BITMAP:
        return create_bitmap_wrapper(context, src_type,
                                     static_cast<const DataTypeBitMap&>(*dst_type));
    case PrimitiveType::TYPE_JSONB:
        return create_cast_to_jsonb_wrapper(src_type,
                                            static_cast<const DataTypeJsonb&>(*dst_type),
                                            context ? context->string_as_jsonb_string() : false);
    case PrimitiveType::TYPE_VARBINARY:
        return create_varbinary_wrapper(src_type);
    default:
        break;
    }

    // No rule matched the target type.
    return create_unsupport_wrapper(src_type->get_name(), dst_type->get_name());
}
} // namespace CastWrapper
/// Executable form of CAST: holds the conversion wrapper selected at prepare time and forwards
/// every execution to it.
class PreparedFunctionCast : public PreparedFunctionImpl {
public:
    /// Takes ownership of the wrapper; `name_` is stored as a raw pointer and not copied.
    explicit PreparedFunctionCast(CastWrapper::WrapperType&& wrapper_function_, const char* name_)
            : _cast_wrapper(std::move(wrapper_function_)), _function_name(name_) {}

    String get_name() const override { return _function_name; }

protected:
    /// Runs the stored wrapper on the block; no external null map is supplied.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count) const override {
        return _cast_wrapper(context, block, arguments, result, input_rows_count, nullptr);
    }

    /// Opts out of the default NULL handling.
    bool use_default_implementation_for_nulls() const override { return false; }

    /// Argument 1 is declared as always constant.
    ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }

private:
    CastWrapper::WrapperType _cast_wrapper;
    const char* _function_name;
};
class FunctionCast final : public IFunctionBase {
public:
FunctionCast(const char* name_, DataTypes argument_types_, DataTypePtr return_type_)
: name(name_),
argument_types(std::move(argument_types_)),
return_type(std::move(return_type_)) {}
const DataTypes& get_argument_types() const override { return argument_types; }
const DataTypePtr& get_return_type() const override { return return_type; }
PreparedFunctionPtr prepare(FunctionContext* context, const Block& /*sample_block*/,
const ColumnNumbers& /*arguments*/,
uint32_t /*result*/) const override {
return std::make_shared<PreparedFunctionCast>(
CastWrapper::prepare_unpack_dictionaries(context, get_argument_types()[0],
get_return_type()),
name);
}
String get_name() const override { return name; }
bool is_use_default_implementation_for_constants() const override { return true; }
private:
const char* name = nullptr;
DataTypes argument_types;
DataTypePtr return_type;
};
/// Builder for the two-argument "CAST" function.
class FunctionBuilderCast : public FunctionBuilderImpl {
public:
    static constexpr auto name = "CAST";

    FunctionBuilderCast() = default;

    static FunctionBuilderPtr create() { return std::make_shared<FunctionBuilderCast>(); }

    String get_name() const override { return name; }

    size_t get_number_of_arguments() const override { return 2; }

    /// Argument 1 is declared as always constant.
    ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }

protected:
    /// Return-type checking is skipped; get_return_type_impl() below yields a null type and the
    /// return type handed to build_impl() is used as given.
    bool skip_return_type_check() const override { return true; }

    DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override {
        return nullptr;
    }

    /// Opts out of the default NULL handling.
    bool use_default_implementation_for_nulls() const override { return false; }

    /// Creates the FunctionCast for the given argument types and return type.
    FunctionBasePtr build_impl(const ColumnsWithTypeAndName& arguments,
                               const DataTypePtr& return_type) const override {
        DataTypes collected_types;
        collected_types.reserve(arguments.size());
        for (const auto& argument : arguments) {
            collected_types.push_back(argument.type);
        }
        return std::make_shared<FunctionCast>(name, collected_types, return_type);
    }
};
/// Registers the CAST function builder (FunctionBuilderCast) with the given factory.
void register_function_cast(SimpleFunctionFactory& factory) {
factory.register_function<FunctionBuilderCast>();
}
} // namespace doris::vectorized