blob: cefb0ad0aeba7fd2108ab8bc53153ef3f1e8d127 [file] [log] [blame]
/*
* Copyright 2024-present Alibaba Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "paimon/common/predicate/literal_converter.h"
#include <optional>
#include "arrow/array/array_base.h"
#include "arrow/array/array_binary.h"
#include "arrow/array/array_decimal.h"
#include "arrow/array/array_dict.h"
#include "arrow/array/array_primitive.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
#include "fmt/format.h"
#include "paimon/common/data/binary_string.h"
#include "paimon/common/data/internal_row.h"
#include "paimon/common/utils/date_time_utils.h"
#include "paimon/common/utils/field_type_utils.h"
#include "paimon/common/utils/string_utils.h"
#include "paimon/data/decimal.h"
#include "paimon/data/timestamp.h"
#include "paimon/defs.h"
#include "paimon/memory/bytes.h"
#include "paimon/status.h"
namespace paimon {
// Parses `value_str` into a Literal of the requested `type`.
//
// BOOLEAN and the numeric types (TINYINT, SMALLINT, INT, BIGINT, FLOAT, DOUBLE) are parsed with
// StringUtils::StringToValue; if parsing fails, Status::Invalid names the target type.
// STRING and BINARY wrap the bytes of `value_str` unchanged. Any other FieldType is rejected with
// Status::Invalid.
Result<Literal> LiteralConverter::ConvertLiteralsFromString(const FieldType& type,
                                                            const std::string& value_str) {
    // Parses value_str as the C++ type of `tag` (the tag's value is ignored); `type_name` only
    // appears in the error message.
    auto parse_as = [&value_str](auto tag, const char* type_name) -> Result<Literal> {
        using ValueT = decltype(tag);
        auto parsed = StringUtils::StringToValue<ValueT>(value_str);
        if (!parsed.has_value()) {
            return Status::Invalid(
                fmt::format("cannot convert {} to {} type", value_str, type_name));
        }
        return Literal(*parsed);
    };

    switch (type) {
        case FieldType::BOOLEAN:
            return parse_as(bool{}, "BOOLEAN");
        case FieldType::TINYINT:
            return parse_as(int8_t{}, "TINYINT");
        case FieldType::SMALLINT:
            return parse_as(int16_t{}, "SMALLINT");
        case FieldType::INT:
            return parse_as(int32_t{}, "INT");
        case FieldType::BIGINT:
            return parse_as(int64_t{}, "BIGINT");
        case FieldType::FLOAT:
            return parse_as(float{}, "FLOAT");
        case FieldType::DOUBLE:
            return parse_as(double{}, "DOUBLE");
        case FieldType::STRING:
        case FieldType::BINARY:
            return Literal(type, value_str.data(), value_str.size());
        default:
            return Status::Invalid(
                fmt::format("Do not support type {}", FieldTypeUtils::FieldTypeToString(type)));
    }
}
// Builds a Literal from column `field_idx` of `row`, interpreting the value as `type`.
//
// A null cell yields a typed null Literal. BOOLEAN, integer, floating-point, STRING, BINARY and
// DATE values are read directly from the row. TIMESTAMP and DECIMAL also need the arrow field
// type from `schema` to obtain precision (and scale). That field must exist and be an arrow
// TIMESTAMP or DECIMAL128 respectively; otherwise Status::Invalid is returned. ARRAY, MAP,
// STRUCT and any other type are rejected with Status::Invalid.
Result<Literal> LiteralConverter::ConvertLiteralsFromRow(
    const std::shared_ptr<arrow::Schema>& schema, const InternalRow& row, int32_t field_idx,
    const FieldType& type) {
    if (row.IsNullAt(field_idx)) {
        return Literal(type);
    }

    // Returns the arrow type of `field_idx` if the schema has that field and its type id equals
    // `expected_id`; otherwise nullptr. The id check must come before the checked casts below.
    // In release builds arrow's checked casts are plain static casts, so they never produce
    // nullptr for a mismatched type and cannot serve as a validity test.
    auto field_type_if = [&schema, field_idx](
                             arrow::Type::type expected_id) -> std::shared_ptr<arrow::DataType> {
        if (schema == nullptr || field_idx < 0 || field_idx >= schema->num_fields()) {
            return nullptr;
        }
        const std::shared_ptr<arrow::DataType>& field_type = schema->field(field_idx)->type();
        if (field_type == nullptr || field_type->id() != expected_id) {
            return nullptr;
        }
        return field_type;
    };
    // Error returned when `schema` cannot describe the TIMESTAMP/DECIMAL column.
    auto invalid_schema = [&schema, field_idx]() {
        return Status::Invalid(
            fmt::format("Convert literal from row not valid for schema {}, field_idx {}",
                        schema ? schema->ToString() : std::string("null"), field_idx));
    };

    switch (type) {
        case FieldType::BOOLEAN:
            return Literal(row.GetBoolean(field_idx));
        case FieldType::TINYINT:
            return Literal(static_cast<int8_t>(row.GetByte(field_idx)));
        case FieldType::SMALLINT:
            return Literal(row.GetShort(field_idx));
        case FieldType::INT:
            return Literal(row.GetInt(field_idx));
        case FieldType::BIGINT:
            return Literal(row.GetLong(field_idx));
        case FieldType::FLOAT:
            return Literal(row.GetFloat(field_idx));
        case FieldType::DOUBLE:
            return Literal(row.GetDouble(field_idx));
        case FieldType::STRING: {
            std::string field = row.GetString(field_idx).ToString();
            return Literal(type, field.data(), field.size());
        }
        case FieldType::BINARY: {
            auto field = row.GetBinary(field_idx);
            return Literal(type, field->data(), field->size());
        }
        case FieldType::TIMESTAMP: {
            auto field_type = field_type_if(arrow::Type::type::TIMESTAMP);
            if (!field_type) {
                return invalid_schema();
            }
            auto timestamp_type =
                arrow::internal::checked_pointer_cast<arrow::TimestampType>(field_type);
            int32_t precision = DateTimeUtils::GetPrecisionFromType(timestamp_type);
            Timestamp field = row.GetTimestamp(field_idx, precision);
            return Literal(field);
        }
        case FieldType::DECIMAL: {
            auto field_type = field_type_if(arrow::Type::type::DECIMAL128);
            if (!field_type) {
                return invalid_schema();
            }
            const auto& decimal_type =
                arrow::internal::checked_cast<const arrow::Decimal128Type&>(*field_type);
            auto precision = decimal_type.precision();
            auto scale = decimal_type.scale();
            Decimal field = row.GetDecimal(field_idx, precision, scale);
            return Literal(field);
        }
        case FieldType::DATE:
            return Literal(FieldType::DATE, row.GetInt(field_idx));
        case FieldType::ARRAY:
        case FieldType::MAP:
        case FieldType::STRUCT:
        default:
            return Status::Invalid(fmt::format("Convert literal from row not valid for {}",
                                               FieldTypeUtils::FieldTypeToString(type)));
    }
}
// Converts every slot of the arrow `array` into a Literal; null slots become typed null Literals.
//
// Supported arrow types are BOOL, INT8, INT16, INT32, INT64, FLOAT, DOUBLE, STRING, BINARY,
// TIMESTAMP, DECIMAL128 and DATE32, plus DICTIONARY arrays whose (value, index) types are
// (STRING, INT32) or (LARGE_STRING, INT64). `own_data` is passed through to the STRING/BINARY and
// dictionary helpers. Any other arrow type yields Status::Invalid.
Result<std::vector<Literal>> LiteralConverter::ConvertLiteralsFromArray(const arrow::Array& array,
                                                                        bool own_data) {
    using TypeId = arrow::Type::type;
    switch (array.type_id()) {
        case TypeId::BOOL:
            return GetLiteralFromGenericArray<arrow::BooleanType>(array, FieldType::BOOLEAN);
        case TypeId::INT8:
            return GetLiteralFromGenericArray<arrow::Int8Type>(array, FieldType::TINYINT);
        case TypeId::INT16:
            return GetLiteralFromGenericArray<arrow::Int16Type>(array, FieldType::SMALLINT);
        case TypeId::INT32:
            return GetLiteralFromGenericArray<arrow::Int32Type>(array, FieldType::INT);
        case TypeId::INT64:
            return GetLiteralFromGenericArray<arrow::Int64Type>(array, FieldType::BIGINT);
        case TypeId::FLOAT:
            return GetLiteralFromGenericArray<arrow::FloatType>(array, FieldType::FLOAT);
        case TypeId::DOUBLE:
            return GetLiteralFromGenericArray<arrow::DoubleType>(array, FieldType::DOUBLE);
        case TypeId::STRING:
            return GetLiteralFromStringArray<arrow::StringType>(array, FieldType::STRING, own_data);
        case TypeId::BINARY:
            return GetLiteralFromStringArray<arrow::BinaryType>(array, FieldType::BINARY, own_data);
        case TypeId::TIMESTAMP:
            return GetLiteralFromTimestampArray(array);
        case TypeId::DECIMAL128:
            return GetLiteralFromDecimalArray(array);
        case TypeId::DATE32:
            return GetLiteralFromDateArray(array);
        case TypeId::DICTIONARY: {
            const auto& dictionary =
                arrow::internal::checked_cast<const arrow::DictionaryArray&>(array);
            const auto* dictionary_type = arrow::internal::checked_cast<const arrow::DictionaryType*>(
                dictionary.type().get());
            const TypeId values = dictionary_type->value_type()->id();
            const TypeId indices = dictionary_type->index_type()->id();
            if (values == TypeId::STRING && indices == TypeId::INT32) {
                return GetLiteralFromDictionaryArray<arrow::StringArray, arrow::Int32Array>(
                    dictionary, FieldType::STRING, own_data);
            }
            if (values == TypeId::LARGE_STRING && indices == TypeId::INT64) {
                return GetLiteralFromDictionaryArray<arrow::LargeStringArray, arrow::Int64Array>(
                    dictionary, FieldType::STRING, own_data);
            }
            return Status::Invalid(
                "only support [STRING, INT32] or [LARGE_STRING, INT64] for DictionaryArray");
        }
        default:
            return Status::Invalid(
                fmt::format("Not support literal on arrow {} type", array.type()->ToString()));
    }
}
// Converts a Decimal128 arrow array into DECIMAL Literals that carry the precision and scale of
// the array's arrow type. Null slots become null DECIMAL Literals.
std::vector<Literal> LiteralConverter::GetLiteralFromDecimalArray(const arrow::Array& array) {
    using ArrayType = typename arrow::TypeTraits<arrow::Decimal128Type>::ArrayType;
    const auto& array_(arrow::internal::checked_cast<const ArrayType&>(array));
    const auto* arrow_type =
        arrow::internal::checked_cast<const arrow::Decimal128Type*>(array.type().get());
    const int32_t precision = arrow_type->precision();
    const int32_t scale = arrow_type->scale();
    // 2^64 as a 128-bit integer; it places the high 64-bit word above the low word.
    const Decimal::int128_t two_pow_64 = static_cast<Decimal::int128_t>(1) << 64;
    std::vector<Literal> literals;
    literals.reserve(array_.length());
    for (int64_t i = 0; i < array_.length(); i++) {
        if (array_.IsNull(i)) {
            literals.emplace_back(FieldType::DECIMAL);
            continue;
        }
        const arrow::Decimal128 decimal(array_.GetValue(i));
        // value = high * 2^64 + low gives the same bit pattern as (high << 64) | low.
        // Left-shifting a negative signed integer (any negative decimal) is undefined behaviour
        // before C++20. This product and sum always stay within the signed 128-bit range.
        const Decimal::int128_t value =
            static_cast<Decimal::int128_t>(decimal.high_bits()) * two_pow_64 +
            static_cast<Decimal::int128_t>(decimal.low_bits());
        literals.emplace_back(Decimal(precision, scale, value));
    }
    return literals;
}
// Converts a Date32 arrow array into DATE Literals holding each slot's raw int32 value.
// Null slots become null DATE Literals.
std::vector<Literal> LiteralConverter::GetLiteralFromDateArray(const arrow::Array& array) {
    const auto& dates = arrow::internal::checked_cast<const arrow::Date32Array&>(array);
    const int64_t count = dates.length();
    std::vector<Literal> result;
    result.reserve(count);
    for (int64_t slot = 0; slot < count; ++slot) {
        if (!dates.IsNull(slot)) {
            result.emplace_back(FieldType::DATE, dates.Value(slot));
        } else {
            result.emplace_back(FieldType::DATE);
        }
    }
    return result;
}
// Converts a Timestamp arrow array into TIMESTAMP Literals. Each raw value is in the array type's
// time unit; DateTimeUtils::TimestampConverter splits it into a millisecond part and a nanosecond
// part, which together build the Timestamp. Null slots become null TIMESTAMP Literals.
std::vector<Literal> LiteralConverter::GetLiteralFromTimestampArray(const arrow::Array& array) {
    const auto& timestamps = arrow::internal::checked_cast<const arrow::TimestampArray&>(array);
    auto ts_type = arrow::internal::checked_pointer_cast<arrow::TimestampType>(timestamps.type());
    assert(ts_type);
    DateTimeUtils::TimeType source_unit = DateTimeUtils::GetTimeTypeFromArrowType(ts_type);
    const int64_t count = timestamps.length();
    std::vector<Literal> result;
    result.reserve(count);
    for (int64_t slot = 0; slot < count; ++slot) {
        if (timestamps.IsNull(slot)) {
            result.emplace_back(FieldType::TIMESTAMP);
            continue;
        }
        int64_t raw_value = timestamps.Value(slot);
        auto [millis, nanos] = DateTimeUtils::TimestampConverter(
            raw_value, source_unit, DateTimeUtils::TimeType::MILLISECOND,
            DateTimeUtils::TimeType::NANOSECOND);
        result.emplace_back(Timestamp(millis, nanos));
    }
    return result;
}
} // namespace paimon