blob: 54a5cfd9a9e25c4d717d0a4369afb952f367d4fb [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "paimon/common/utils/fields_comparator.h"

#include <cmath>
#include <cstddef>
#include <string>
#include <utility>

#include "arrow/api.h"
#include "arrow/util/checked_cast.h"
#include "fmt/format.h"
#include "paimon/common/data/binary_string.h"
#include "paimon/common/types/data_field.h"
#include "paimon/common/utils/date_time_utils.h"
#include "paimon/data/decimal.h"
#include "paimon/data/timestamp.h"
#include "paimon/memory/bytes.h"
#include "paimon/status.h"
namespace paimon {
// Convenience overload: builds a comparator that uses every field of `input_data_field` as a
// sort key, in declaration order, and delegates to the explicit-sort-fields overload.
Result<std::unique_ptr<FieldsComparator>> FieldsComparator::Create(
    const std::vector<DataField>& input_data_field, bool is_ascending_order) {
    // Identity projection: sort key position k refers to field k.
    std::vector<int32_t> all_field_indices(input_data_field.size());
    int32_t next_index = 0;
    for (auto& slot : all_field_indices) {
        slot = next_index++;
    }
    return Create(input_data_field, all_field_indices, is_ascending_order);
}
// Builds a comparator over the fields of `input_data_field` selected by `sort_fields`; the keys
// are compared in the order they appear in `sort_fields`.
//
// Returns an error status when an index in `sort_fields` does not refer to a field of
// `input_data_field`, or when CompareField cannot build a comparator for a selected field's type.
Result<std::unique_ptr<FieldsComparator>> FieldsComparator::Create(
    const std::vector<DataField>& input_data_field, const std::vector<int32_t>& sort_fields,
    bool is_ascending_order) {
    std::vector<FieldComparatorFunc> comparators;
    comparators.reserve(sort_fields.size());
    for (const int32_t sort_field_idx : sort_fields) {
        // Validate before indexing: operator[] with an out-of-range index is undefined behavior.
        if (sort_field_idx < 0 ||
            static_cast<size_t>(sort_field_idx) >= input_data_field.size()) {
            return Status::Invalid(
                fmt::format("sort field index {} is out of range, input has {} fields",
                            sort_field_idx, input_data_field.size()));
        }
        const auto& type = input_data_field[sort_field_idx].Type();
        PAIMON_ASSIGN_OR_RAISE(FieldComparatorFunc cmp, CompareField(sort_field_idx, type));
        // Move the std::function into place instead of copying its captured state.
        comparators.push_back(std::move(cmp));
    }
    // Constructed with `new` (not std::make_unique), as in the original, since the
    // constructor may not be publicly accessible.
    return std::unique_ptr<FieldsComparator>(
        new FieldsComparator(is_ascending_order, sort_fields, std::move(comparators)));
}
// Lexicographically compares `lhs` and `rhs` on the configured sort fields.
// Returns a negative value, zero, or a positive value when `lhs` sorts before, equal to, or
// after `rhs`. A null field always sorts before a non-null one, independent of
// is_ascending_order_; the direction flag is applied only to comparisons of non-null values.
int32_t FieldsComparator::CompareTo(const InternalRow& lhs, const InternalRow& rhs) const {
    // Result when only the left-hand field is null: nulls come first (they are not treated as
    // the smallest value, so a descending order does not move them to the end).
    constexpr int32_t kLeftNullResult = -1;
    const size_t key_count = sort_fields_.size();
    for (size_t pos = 0; pos < key_count; ++pos) {
        const auto field = sort_fields_[pos];
        const bool left_is_null = lhs.IsNullAt(field);
        const bool right_is_null = rhs.IsNullAt(field);
        if (left_is_null != right_is_null) {
            return left_is_null ? kLeftNullResult : -kLeftNullResult;
        }
        if (left_is_null) {
            // Both null: tie on this key, move on to the next one.
            continue;
        }
        const int32_t result = comparators_[pos](lhs, rhs);
        if (result == 0) {
            continue;
        }
        return is_ascending_order_ ? result : -result;
    }
    return 0;
}
Result<FieldsComparator::FieldComparatorFunc> FieldsComparator::CompareField(
int32_t field_idx, const std::shared_ptr<arrow::DataType>& input_type) {
arrow::Type::type type = input_type->id();
switch (type) {
case arrow::Type::type::BOOL:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
bool lvalue = lhs.GetBoolean(field_idx);
bool rvalue = rhs.GetBoolean(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::INT8:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
int8_t lvalue = lhs.GetByte(field_idx);
int8_t rvalue = rhs.GetByte(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::INT16:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
int16_t lvalue = lhs.GetShort(field_idx);
int16_t rvalue = rhs.GetShort(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::DATE32:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
int32_t lvalue = lhs.GetDate(field_idx);
int32_t rvalue = rhs.GetDate(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::INT32:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
int32_t lvalue = lhs.GetInt(field_idx);
int32_t rvalue = rhs.GetInt(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::INT64:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
int64_t lvalue = lhs.GetLong(field_idx);
int64_t rvalue = rhs.GetLong(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::FLOAT:
// TODO(xinyu.lxy):
// currently in java KeyComparatorSupplier: -inf < -0.0 == +0.0 < +inf = nan
// paimon-cpp: -inf < -0.0 == +0.0 < +inf and nan cannot be compared
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
float lvalue = lhs.GetFloat(field_idx);
float rvalue = rhs.GetFloat(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::DOUBLE:
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
double lvalue = lhs.GetDouble(field_idx);
double rvalue = rhs.GetDouble(field_idx);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
case arrow::Type::type::STRING:
case arrow::Type::type::BINARY: {
return FieldsComparator::FieldComparatorFunc(
[field_idx](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
auto lvalue = lhs.GetStringView(field_idx);
auto rvalue = rhs.GetStringView(field_idx);
int32_t cmp = lvalue.compare(rvalue);
return cmp == 0 ? 0 : (cmp > 0 ? 1 : -1);
});
}
case arrow::Type::type::TIMESTAMP: {
auto timestamp_type =
arrow::internal::checked_pointer_cast<arrow::TimestampType>(input_type);
assert(timestamp_type);
int32_t precision = DateTimeUtils::GetPrecisionFromType(timestamp_type);
return FieldsComparator::FieldComparatorFunc(
[field_idx, precision](const InternalRow& lhs, const InternalRow& rhs) -> int32_t {
Timestamp lvalue = lhs.GetTimestamp(field_idx, precision);
Timestamp rvalue = rhs.GetTimestamp(field_idx, precision);
return lvalue == rvalue ? 0 : (lvalue < rvalue ? -1 : 1);
});
}
case arrow::Type::type::DECIMAL: {
auto* decimal_type =
arrow::internal::checked_cast<arrow::Decimal128Type*>(input_type.get());
assert(decimal_type);
auto precision = decimal_type->precision();
auto scale = decimal_type->scale();
return FieldsComparator::FieldComparatorFunc(
[field_idx, precision, scale](const InternalRow& lhs,
const InternalRow& rhs) -> int32_t {
Decimal lvalue = lhs.GetDecimal(field_idx, precision, scale);
Decimal rvalue = rhs.GetDecimal(field_idx, precision, scale);
int32_t cmp = lvalue.CompareTo(rvalue);
return cmp == 0 ? 0 : (cmp > 0 ? 1 : -1);
});
}
default:
return Status::NotImplemented(fmt::format("Do not support comparing {} type in idx {}",
input_type->ToString(), field_idx));
}
}
} // namespace paimon