blob: 0b46b76a4c04d01de9f252c7e2839f1fb0a202bf [file] [log] [blame]
/*
* Copyright 2024-present Alibaba Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "paimon/format/parquet/predicate_converter.h"
#include <cstdint>
#include <utility>
#include "arrow/compute/expression.h"
#include "arrow/scalar.h"
#include "arrow/type_fwd.h"
#include "arrow/util/decimal.h"
#include "fmt/format.h"
#include "paimon/data/decimal.h"
#include "paimon/defs.h"
#include "paimon/predicate/compound_predicate.h"
#include "paimon/predicate/function.h"
#include "paimon/predicate/leaf_predicate.h"
#include "paimon/predicate/literal.h"
#include "paimon/predicate/predicate.h"
namespace paimon::parquet {
/// Returns the boolean literal `true` as an Arrow compute expression.
///
/// The converter falls back to this expression whenever a predicate is absent
/// or cannot be translated.
arrow::compute::Expression PredicateConverter::AlwaysTrue() {
    // Constructed once on first use; every caller receives a copy.
    static const arrow::compute::Expression kAlwaysTrueExpr = arrow::compute::literal(true);
    return kAlwaysTrueExpr;
}
/// Converts a paimon predicate tree into an Arrow compute expression.
///
/// @param predicate        Root of the predicate tree; a null pointer yields AlwaysTrue().
/// @param node_count_limit Maximum number of nodes (as measured by CollectNodeCount)
///                         that will be converted. Larger trees yield AlwaysTrue().
/// @return The converted expression, AlwaysTrue() for the fallbacks above, or the
///         error status reported by InnerConvert().
Result<arrow::compute::Expression> PredicateConverter::Convert(
    const std::shared_ptr<Predicate>& predicate, uint32_t node_count_limit) {
    if (predicate == nullptr) {
        return AlwaysTrue();
    }

    uint32_t total_nodes = 0;
    CollectNodeCount(predicate, &total_nodes);

    // Only translate trees that stay within the caller-supplied budget.
    if (total_nodes <= node_count_limit) {
        return InnerConvert(predicate);
    }
    return AlwaysTrue();
}
/// Accumulates into `*node_count` the number of expression nodes that converting
/// `predicate` would produce.
///
/// Counting rules:
///  - a null node counts as one node (InnerConvert maps it to a single literal);
///  - a leaf counts as one node, plus one per literal for IN / NOT_IN because those
///    are expanded to Or(Equals) / And(NotEqual);
///  - a compound counts as one node plus the count of every child;
///  - any other predicate kind contributes nothing.
///
/// @param predicate  Node to measure; may be null.
/// @param node_count Output accumulator; must not be null.
void PredicateConverter::CollectNodeCount(const std::shared_ptr<Predicate>& predicate,
                                          uint32_t* node_count) {
    // InnerConvert tolerates null nodes, so the counting pass must not dereference them.
    if (!predicate) {
        *node_count += 1;
        return;
    }
    if (auto leaf_predicate = std::dynamic_pointer_cast<LeafPredicate>(predicate)) {
        const auto function_type = leaf_predicate->GetFunction().GetType();
        if (function_type == Function::Type::IN || function_type == Function::Type::NOT_IN) {
            // IN and NOT_IN will be converted to Or(Equals) and And(NotEqual)
            *node_count += leaf_predicate->Literals().size();
        }
        *node_count += 1;
        return;
    }
    if (auto compound_predicate = std::dynamic_pointer_cast<CompoundPredicate>(predicate)) {
        *node_count += 1;
        for (const auto& child : compound_predicate->Children()) {
            CollectNodeCount(child, node_count);
        }
    }
}
/// Dispatches conversion of a single predicate node by its concrete kind.
///
/// @param predicate Node to convert; a null pointer yields AlwaysTrue().
/// @return The expression from ConvertLeaf() or ConvertCompound(), or an Invalid
///         status when the node is neither a LeafPredicate nor a CompoundPredicate.
Result<arrow::compute::Expression> PredicateConverter::InnerConvert(
    const std::shared_ptr<Predicate>& predicate) {
    if (predicate == nullptr) {
        return AlwaysTrue();
    }

    const auto as_leaf = std::dynamic_pointer_cast<LeafPredicate>(predicate);
    if (as_leaf != nullptr) {
        return ConvertLeaf(as_leaf);
    }

    const auto as_compound = std::dynamic_pointer_cast<CompoundPredicate>(predicate);
    if (as_compound != nullptr) {
        return ConvertCompound(as_compound);
    }

    return Status::Invalid("invalid predicate, must be leaf or compound");
}
/// Converts an AND / OR compound predicate by converting each child and joining
/// the results with arrow::compute::and_ / arrow::compute::or_.
///
/// @param compound_predicate Compound node to convert.
/// @return The combined expression, the first child conversion error, or an Invalid
///         status when the function type is neither AND nor OR.
Result<arrow::compute::Expression> PredicateConverter::ConvertCompound(
    const std::shared_ptr<CompoundPredicate>& compound_predicate) {
    const auto function_type = compound_predicate->GetFunction().GetType();
    const bool is_conjunction = (function_type == Function::Type::AND);

    // Reject unsupported functions before touching any child, so that this error
    // takes precedence over child conversion errors.
    if (!is_conjunction && function_type != Function::Type::OR) {
        return Status::Invalid(
            fmt::format("invalid predicate type {}", static_cast<int32_t>(function_type)));
    }

    const auto& children = compound_predicate->Children();
    std::vector<arrow::compute::Expression> converted_children;
    converted_children.reserve(children.size());
    for (const auto& child : children) {
        PAIMON_ASSIGN_OR_RAISE(auto child_expr, InnerConvert(child));
        converted_children.push_back(std::move(child_expr));
    }

    if (is_conjunction) {
        return arrow::compute::and_(converted_children);
    }
    return arrow::compute::or_(converted_children);
}
/// Verifies that a predicate carries at least one literal.
///
/// @param literals   Literals attached to the predicate.
/// @param function   Predicate function; used only for the error message.
/// @param field_name Field the predicate applies to; used only for the error message.
/// @return Status::OK() when `literals` is non-empty, otherwise an Invalid status.
Status PredicateConverter::CheckLiteralNotEmpty(const std::vector<Literal>& literals,
                                                const Function& function,
                                                const std::string& field_name) {
    if (!literals.empty()) {
        return Status::OK();
    }
    return Status::Invalid(fmt::format("predicate [{}] need literal on field {}",
                                       function.ToString(), field_name));
}
// Converts LITERAL with ConvertToArrowLiteral() and introduces two locals into the
// enclosing scope: `arrow_literal_result` (the raw result) and `arrow_literal` (the
// converted expression). On failure the macro returns from the enclosing function:
// AlwaysTrue() when the status is NotImplemented, otherwise the failing status.
// Must be used inside its own braced scope because it declares variables.
#define CONVERT_TO_ARROW_LITERAL(LITERAL)                           \
    auto arrow_literal_result = ConvertToArrowLiteral(LITERAL);     \
    if (!arrow_literal_result.ok()) {                               \
        if (arrow_literal_result.status().IsNotImplemented()) {     \
            return AlwaysTrue();                                    \
        }                                                           \
        return arrow_literal_result.status();                       \
    }                                                               \
    auto arrow_literal = std::move(arrow_literal_result).value();
/// Converts a leaf predicate (a function applied to one field and zero or more
/// literals) into an Arrow compute expression.
///
/// Supported functions: IS_NULL, IS_NOT_NULL, the six binary comparisons (which use
/// the first literal), and IN / NOT_IN (expanded to Or(Equals) / And(NotEqual)).
///
/// @param leaf_predicate Leaf node to convert.
/// @return The converted expression; AlwaysTrue() when a literal has a type whose
///         conversion is reported as NotImplemented; an Invalid status when a
///         required literal is missing, a literal cannot be converted, or the
///         function type is not supported.
Result<arrow::compute::Expression> PredicateConverter::ConvertLeaf(
    const std::shared_ptr<LeafPredicate>& leaf_predicate) {
    using arrow::compute::Expression;
    // Explicit pointer types pick the intended overloads (and_/or_ also have binary forms).
    using CompareFn = Expression (*)(Expression, Expression);
    using CombineFn = Expression (*)(const std::vector<Expression>&);

    const auto& field_name = leaf_predicate->FieldName();
    const auto& literals = leaf_predicate->Literals();
    const auto& function = leaf_predicate->GetFunction();
    const auto function_type = function.GetType();

    // Builds `compare(field, literals[0])`. CONVERT_TO_ARROW_LITERAL may return early
    // from this lambda (AlwaysTrue() or an error), which the caller forwards unchanged.
    auto compare_first_literal = [&](CompareFn compare) -> Result<Expression> {
        PAIMON_RETURN_NOT_OK(CheckLiteralNotEmpty(literals, function, field_name));
        CONVERT_TO_ARROW_LITERAL(literals[0]);
        return compare(arrow::compute::field_ref(field_name), std::move(arrow_literal));
    };

    // Builds `combine({compare(field, literal) for each literal})`. As above, a single
    // unsupported or invalid literal ends the whole conversion early.
    auto compare_each_literal = [&](CompareFn compare,
                                    CombineFn combine) -> Result<Expression> {
        PAIMON_RETURN_NOT_OK(CheckLiteralNotEmpty(literals, function, field_name));
        std::vector<Expression> sub_exprs;
        sub_exprs.reserve(literals.size());
        for (const auto& literal : literals) {
            CONVERT_TO_ARROW_LITERAL(literal);
            sub_exprs.push_back(
                compare(arrow::compute::field_ref(field_name), std::move(arrow_literal)));
        }
        return combine(sub_exprs);
    };

    switch (function_type) {
        case Function::Type::IS_NULL:
            return arrow::compute::is_null(arrow::compute::field_ref(field_name),
                                           /*nan_is_null=*/false);
        case Function::Type::IS_NOT_NULL:
            return arrow::compute::not_(
                arrow::compute::is_null(arrow::compute::field_ref(field_name),
                                        /*nan_is_null=*/false));
        case Function::Type::EQUAL:
            return compare_first_literal(&arrow::compute::equal);
        case Function::Type::NOT_EQUAL:
            return compare_first_literal(&arrow::compute::not_equal);
        case Function::Type::GREATER_THAN:
            return compare_first_literal(&arrow::compute::greater);
        case Function::Type::GREATER_OR_EQUAL:
            return compare_first_literal(&arrow::compute::greater_equal);
        case Function::Type::LESS_THAN:
            return compare_first_literal(&arrow::compute::less);
        case Function::Type::LESS_OR_EQUAL:
            return compare_first_literal(&arrow::compute::less_equal);
        // Note: java paimon does not support pushing IN and NOT_IN down to parquet.
        case Function::Type::IN:
            // IN is converted to Or(Equals).
            return compare_each_literal(&arrow::compute::equal, &arrow::compute::or_);
        case Function::Type::NOT_IN:
            // NOT_IN is converted to And(NotEqual).
            return compare_each_literal(&arrow::compute::not_equal, &arrow::compute::and_);
        default:
            return Status::Invalid(
                fmt::format("invalid predicate type {}", static_cast<int32_t>(function_type)));
    }
}
/// Wraps a paimon literal in an Arrow literal expression backed by the scalar type
/// that matches the literal's FieldType.
///
/// @param literal Literal to convert; must not be null.
/// @return The literal expression; an Invalid status for a null literal or an
///         unrecognised type; a NotImplemented status for TIMESTAMP and BINARY.
Result<arrow::compute::Expression> PredicateConverter::ConvertToArrowLiteral(
    const Literal& literal) {
    if (literal.IsNull()) {
        return Status::Invalid("literal cannot be null in predicate");
    }

    const auto literal_type = literal.GetType();
    switch (literal_type) {
        case FieldType::BOOLEAN: {
            const bool value = literal.GetValue<bool>();
            return arrow::compute::literal(std::make_shared<arrow::BooleanScalar>(value));
        }
        case FieldType::TINYINT: {
            const int8_t value = literal.GetValue<int8_t>();
            return arrow::compute::literal(std::make_shared<arrow::Int8Scalar>(value));
        }
        case FieldType::SMALLINT: {
            const int16_t value = literal.GetValue<int16_t>();
            return arrow::compute::literal(std::make_shared<arrow::Int16Scalar>(value));
        }
        case FieldType::INT: {
            const int32_t value = literal.GetValue<int32_t>();
            return arrow::compute::literal(std::make_shared<arrow::Int32Scalar>(value));
        }
        case FieldType::DATE: {
            // DATE literals are read as int32 and wrapped in a Date32 scalar.
            const int32_t value = literal.GetValue<int32_t>();
            return arrow::compute::literal(std::make_shared<arrow::Date32Scalar>(value));
        }
        case FieldType::BIGINT: {
            const int64_t value = literal.GetValue<int64_t>();
            return arrow::compute::literal(std::make_shared<arrow::Int64Scalar>(value));
        }
        case FieldType::FLOAT: {
            const float value = literal.GetValue<float>();
            return arrow::compute::literal(std::make_shared<arrow::FloatScalar>(value));
        }
        case FieldType::DOUBLE: {
            const double value = literal.GetValue<double>();
            return arrow::compute::literal(std::make_shared<arrow::DoubleScalar>(value));
        }
        case FieldType::STRING:
            return arrow::compute::literal(
                std::make_shared<arrow::StringScalar>(literal.GetValue<std::string>()));
        case FieldType::DECIMAL: {
            const auto decimal_value = literal.GetValue<Decimal>();
            // Rebuild the 128-bit value from its two halves and carry over the
            // literal's own precision and scale.
            const arrow::Decimal128 raw_value(decimal_value.HighBits(), decimal_value.LowBits());
            auto decimal_type =
                arrow::decimal128(decimal_value.Precision(), decimal_value.Scale());
            return arrow::compute::literal(
                std::make_shared<arrow::Decimal128Scalar>(raw_value, std::move(decimal_type)));
        }
        // TODO(lisizhuo.lsz): java paimon does not support BINARY, TIMESTAMP and DECIMAL.
        // Here DECIMAL is converted above; TIMESTAMP and BINARY are reported as
        // NotImplemented.
        case FieldType::TIMESTAMP:
        case FieldType::BINARY:
            return Status::NotImplemented(
                "Not support Binary and Timestamp predicate push down in parquet file "
                "format");
        default:
            return Status::Invalid(
                fmt::format("invalid literal type {}", static_cast<int32_t>(literal_type)));
    }
}
} // namespace paimon::parquet