blob: 39d3eaf31ddbeda7f3cb78eec78c0f29efecce8f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "core/types.h"
#include "exprs/function/cast/cast_base.h"
#include "util/io_helper.h"
namespace doris {
// Scalar helpers that convert one source value into a boolean stored as UInt8.
// Only declarations live in the struct; from_number / from_decimal are provided
// through explicit specializations per source type, and from_string is defined
// out of line. Every definition visible in this header reports its result
// through the `to` out-parameter and returns a bool.
struct CastToBool {
    // Numeric source -> bool. `params` is part of the common cast signature.
    template <typename SRC>
    static inline bool from_number(const SRC& from, UInt8& to, CastParameters& params);

    // Decimal source -> bool. `precision` and `scale` describe the source decimal type.
    template <typename SRC>
    static inline bool from_decimal(const SRC& from, UInt8& to, UInt32 precision, UInt32 scale,
                                    CastParameters& params);

    // Text source -> bool.
    static inline bool from_string(const StringRef& from, UInt8& to, CastParameters& params);
};
// UInt8 -> bool: the source byte is copied through unchanged (it is NOT
// normalized to 0/1 here). Always returns true.
template <>
inline bool CastToBool::from_number(const UInt8& from, UInt8& to, CastParameters& /*params*/) {
    to = from;
    return true;
}
// Int8 -> bool: zero maps to 0, every other value maps to 1. Always returns true.
template <>
inline bool CastToBool::from_number(const Int8& from, UInt8& to, CastParameters& /*params*/) {
    const bool is_nonzero = from != 0;
    to = is_nonzero;
    return true;
}
// Int16 -> bool: zero maps to 0, every other value maps to 1. Always returns true.
template <>
inline bool CastToBool::from_number(const Int16& from, UInt8& to, CastParameters& /*params*/) {
    const bool is_nonzero = from != 0;
    to = is_nonzero;
    return true;
}
// Int32 -> bool: zero maps to 0, every other value maps to 1. Always returns true.
template <>
inline bool CastToBool::from_number(const Int32& from, UInt8& to, CastParameters& /*params*/) {
    const bool is_nonzero = from != 0;
    to = is_nonzero;
    return true;
}
// Int64 -> bool: zero maps to 0, every other value maps to 1. Always returns true.
template <>
inline bool CastToBool::from_number(const Int64& from, UInt8& to, CastParameters& /*params*/) {
    const bool is_nonzero = from != 0;
    to = is_nonzero;
    return true;
}
// Int128 -> bool: zero maps to 0, every other value maps to 1. Always returns true.
template <>
inline bool CastToBool::from_number(const Int128& from, UInt8& to, CastParameters& /*params*/) {
    const bool is_nonzero = from != 0;
    to = is_nonzero;
    return true;
}
// Float32 -> bool: the result is the outcome of `from != 0`. Always returns true.
// NOTE(review): assuming Float32 is an IEEE-754 float, NaN compares unequal to
// zero and therefore yields 1, while -0.0 compares equal and yields 0.
template <>
inline bool CastToBool::from_number(const Float32& from, UInt8& to, CastParameters& /*params*/) {
    const bool differs_from_zero = from != 0;
    to = differs_from_zero;
    return true;
}
// Float64 -> bool: the result is the outcome of `from != 0`. Always returns true.
// NOTE(review): assuming Float64 is an IEEE-754 double, NaN compares unequal to
// zero and therefore yields 1, while -0.0 compares equal and yields 0.
template <>
inline bool CastToBool::from_number(const Float64& from, UInt8& to, CastParameters& /*params*/) {
    const bool differs_from_zero = from != 0;
    to = differs_from_zero;
    return true;
}
// Decimal32 -> bool: true iff the stored `value` member is non-zero.
// Precision and scale are not consulted. Always returns true.
template <>
inline bool CastToBool::from_decimal(const Decimal32& from, UInt8& to, UInt32 /*precision*/,
                                     UInt32 /*scale*/, CastParameters& /*params*/) {
    const bool is_nonzero = from.value != 0;
    to = is_nonzero;
    return true;
}
// Decimal64 -> bool: true iff the stored `value` member is non-zero.
// Precision and scale are not consulted. Always returns true.
template <>
inline bool CastToBool::from_decimal(const Decimal64& from, UInt8& to, UInt32 /*precision*/,
                                     UInt32 /*scale*/, CastParameters& /*params*/) {
    const bool is_nonzero = from.value != 0;
    to = is_nonzero;
    return true;
}
// DecimalV2Value -> bool: true iff the result of the `value()` accessor is non-zero.
// Precision and scale are not consulted. Always returns true.
template <>
inline bool CastToBool::from_decimal(const DecimalV2Value& from, UInt8& to, UInt32 /*precision*/,
                                     UInt32 /*scale*/, CastParameters& /*params*/) {
    const bool is_nonzero = from.value() != 0;
    to = is_nonzero;
    return true;
}
// Decimal128V3 -> bool: true iff the stored `value` member is non-zero.
// Precision and scale are not consulted. Always returns true.
template <>
inline bool CastToBool::from_decimal(const Decimal128V3& from, UInt8& to, UInt32 /*precision*/,
                                     UInt32 /*scale*/, CastParameters& /*params*/) {
    const bool is_nonzero = from.value != 0;
    to = is_nonzero;
    return true;
}
// Decimal256 -> bool: true iff the stored `value` member is non-zero.
// Precision and scale are not consulted. Always returns true.
template <>
inline bool CastToBool::from_decimal(const Decimal256& from, UInt8& to, UInt32 /*precision*/,
                                     UInt32 /*scale*/, CastParameters& /*params*/) {
    const bool is_nonzero = from.value != 0;
    to = is_nonzero;
    return true;
}
// String -> bool: delegates entirely to try_read_bool_text, which writes the
// parsed value into `to`; its result is returned unchanged. The cast
// parameters are not used.
inline bool CastToBool::from_string(const StringRef& from, UInt8& to,
                                    CastParameters& /*params*/) {
    const bool parsed = try_read_bool_text(to, from);
    return parsed;
}
// Column-level cast STRING -> BOOLEAN.
//
// The parsing itself is delegated to the serde of the (non-nullable) target
// type. The result column is always a ColumnNullable whose null map starts
// out as all zeros:
//   - NonStrictMode: from_string_batch receives the nullable column and may
//     mark rows as NULL.
//   - StrictMode: from_string_strict_mode_batch receives only the nested
//     column plus the incoming `null_map`; failures surface as an error Status.
// Any other mode yields Status::InternalError.
template <CastModeType Mode>
class CastToImpl<Mode, DataTypeString, DataTypeBool> : public CastToBase {
public:
    // Reads the source column at arguments[0] and stores the converted column
    // at position `result` of `block`. Returns a non-OK Status if the source
    // column has an unexpected type, the serde reports an error, or the cast
    // mode is unsupported.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count,
                        const NullMap::value_type* null_map = nullptr) const override {
        const auto* col_from = check_and_get_column<DataTypeString::ColumnType>(
                block.get_by_position(arguments[0]).column.get());
        // check_and_get_column is expected to yield nullptr when the column is
        // not of the requested type; fail with a Status instead of
        // dereferencing a null pointer below.
        if (col_from == nullptr) {
            return Status::InternalError(
                    "CastToImpl<String, Bool>: argument column is not a string column");
        }

        // Bind by reference: the type pointer is only read before the block is modified.
        const auto& to_type = block.get_by_position(result).type;
        auto serde = remove_nullable(to_type)->get_serde();

        // By framework convention, to_type has already had its nullable wrapper removed.
        MutableColumnPtr column_to = to_type->create_column();
        ColumnNullable::MutablePtr nullable_col_to = ColumnNullable::create(
                std::move(column_to), ColumnUInt8::create(input_rows_count, 0));

        if constexpr (Mode == CastModeType::NonStrictMode) {
            // May write nulls to nullable_col_to.
            RETURN_IF_ERROR(serde->from_string_batch(*col_from, *nullable_col_to, {}));
        } else if constexpr (Mode == CastModeType::StrictMode) {
            // WON'T write nulls to nullable_col_to, just raises errors.
            // null_map is only used to skip invalid rows.
            RETURN_IF_ERROR(serde->from_string_strict_mode_batch(
                    *col_from, nullable_col_to->get_nested_column(), {}, null_map));
        } else {
            return Status::InternalError("Unsupported cast mode");
        }

        block.get_by_position(result).column = std::move(nullable_col_to);
        return Status::OK();
    }
};
// Column-level cast <numeric type> -> BOOLEAN, shared by every cast mode.
// Each row is converted with CastToBool::from_number; the result is a plain
// (non-nullable) boolean column with input_rows_count rows.
template <CastModeType AllMode, typename NumberType>
    requires(IsDataTypeNumber<NumberType>)
class CastToImpl<AllMode, NumberType, DataTypeBool> : public CastToBase {
public:
    // Reads the source column at arguments[0] and stores the converted column
    // at position `result` of `block`. Returns Status::InternalError if the
    // source column does not have NumberType's column type, Status::OK otherwise.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count,
                        const NullMap::value_type* null_map = nullptr) const override {
        const auto* col_from = check_and_get_column<typename NumberType::ColumnType>(
                block.get_by_position(arguments[0]).column.get());
        // check_and_get_column is expected to yield nullptr when the column is
        // not of the requested type; fail with a Status instead of
        // dereferencing a null pointer in the loop below.
        if (col_from == nullptr) {
            return Status::InternalError(
                    "CastToImpl<Number, Bool>: argument column does not match the source "
                    "number type");
        }

        DataTypeBool::ColumnType::MutablePtr col_to =
                DataTypeBool::ColumnType::create(input_rows_count);

        CastParameters params;
        params.is_strict = (AllMode == CastModeType::StrictMode);

        // The bool result of from_number is not inspected here.
        for (size_t i = 0; i < input_rows_count; ++i) {
            CastToBool::from_number(col_from->get_element(i), col_to->get_element(i), params);
        }

        block.get_by_position(result).column = std::move(col_to);
        return Status::OK();
    }
};
// Column-level cast <decimal type> -> BOOLEAN, shared by every cast mode.
// Each row is converted with CastToBool::from_decimal, which also receives the
// precision and scale of the source type; the result is a plain (non-nullable)
// boolean column with input_rows_count rows.
template <CastModeType AllMode, typename DecimalType>
    requires(IsDataTypeDecimal<DecimalType>)
class CastToImpl<AllMode, DecimalType, DataTypeBool> : public CastToBase {
public:
    // Reads the source column at arguments[0] and stores the converted column
    // at position `result` of `block`. Returns Status::InternalError if the
    // source column does not have DecimalType's column type, Status::OK otherwise.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count,
                        const NullMap::value_type* null_map = nullptr) const override {
        const auto* col_from = check_and_get_column<typename DecimalType::ColumnType>(
                block.get_by_position(arguments[0]).column.get());
        // check_and_get_column is expected to yield nullptr when the column is
        // not of the requested type; fail with a Status instead of
        // dereferencing a null pointer in the loop below.
        if (col_from == nullptr) {
            return Status::InternalError(
                    "CastToImpl<Decimal, Bool>: argument column does not match the source "
                    "decimal type");
        }

        // Bind by reference: the type pointer is only read, so no copy is needed.
        const auto& type_from = block.get_by_position(arguments[0]).type;

        DataTypeBool::ColumnType::MutablePtr col_to =
                DataTypeBool::ColumnType::create(input_rows_count);

        CastParameters params;
        params.is_strict = (AllMode == CastModeType::StrictMode);

        // Loop-invariant properties of the source decimal type.
        const auto precision = type_from->get_precision();
        const auto scale = type_from->get_scale();

        // The bool result of from_decimal is not inspected here.
        for (size_t i = 0; i < input_rows_count; ++i) {
            CastToBool::from_decimal(col_from->get_element(i), col_to->get_element(i), precision,
                                     scale, params);
        }

        block.get_by_position(result).column = std::move(col_to);
        return Status::OK();
    }
};
} // namespace doris