blob: 43c7cecfa10e076cc589efa4be57365e5178a094 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark;
import static org.apache.iceberg.expressions.Expressions.and;
import static org.apache.iceberg.expressions.Expressions.equal;
import static org.apache.iceberg.expressions.Expressions.greaterThan;
import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.in;
import static org.apache.iceberg.expressions.Expressions.isNaN;
import static org.apache.iceberg.expressions.Expressions.isNull;
import static org.apache.iceberg.expressions.Expressions.lessThan;
import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.not;
import static org.apache.iceberg.expressions.Expressions.notIn;
import static org.apache.iceberg.expressions.Expressions.notNull;
import static org.apache.iceberg.expressions.Expressions.or;
import static org.apache.iceberg.expressions.Expressions.startsWith;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expression.Operation;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.util.NaNUtil;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.And;
import org.apache.spark.sql.connector.expressions.filter.Not;
import org.apache.spark.sql.connector.expressions.filter.Or;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;
/**
 * Converts Spark V2 data source {@link Predicate} filters to Iceberg {@link Expression}s.
 *
 * <p>A {@code null} return from {@link #convert(Predicate)} means the predicate cannot be
 * translated; callers must not push it down and should let Spark evaluate the filter instead.
 */
public class SparkV2Filters {

  // Spark V2 predicate names, as reported by Predicate#name()
  private static final String TRUE = "ALWAYS_TRUE";
  private static final String FALSE = "ALWAYS_FALSE";
  private static final String EQ = "=";
  private static final String EQ_NULL_SAFE = "<=>";
  private static final String GT = ">";
  private static final String GT_EQ = ">=";
  private static final String LT = "<";
  private static final String LT_EQ = "<=";
  private static final String IN = "IN";
  private static final String IS_NULL = "IS_NULL";
  private static final String NOT_NULL = "IS_NOT_NULL";
  private static final String AND = "AND";
  private static final String OR = "OR";
  private static final String NOT = "NOT";
  private static final String STARTS_WITH = "STARTS_WITH";

  // maps Spark predicate names to Iceberg operations; names not listed here are unsupported
  private static final Map<String, Operation> FILTERS =
      ImmutableMap.<String, Operation>builder()
          .put(TRUE, Operation.TRUE)
          .put(FALSE, Operation.FALSE)
          .put(EQ, Operation.EQ)
          .put(EQ_NULL_SAFE, Operation.EQ)
          .put(GT, Operation.GT)
          .put(GT_EQ, Operation.GT_EQ)
          .put(LT, Operation.LT)
          .put(LT_EQ, Operation.LT_EQ)
          .put(IN, Operation.IN)
          .put(IS_NULL, Operation.IS_NULL)
          .put(NOT_NULL, Operation.NOT_NULL)
          .put(AND, Operation.AND)
          .put(OR, Operation.OR)
          .put(NOT, Operation.NOT)
          .put(STARTS_WITH, Operation.STARTS_WITH)
          .buildOrThrow();

  private SparkV2Filters() {}

  /**
   * Converts a Spark V2 predicate to an Iceberg expression.
   *
   * @param predicate the Spark predicate to convert
   * @return an equivalent Iceberg expression, or {@code null} if the predicate is unsupported
   */
  @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"})
  public static Expression convert(Predicate predicate) {
    Operation op = FILTERS.get(predicate.name());
    if (op == null) {
      return null;
    }

    switch (op) {
      case TRUE:
        return Expressions.alwaysTrue();

      case FALSE:
        return Expressions.alwaysFalse();

      case IS_NULL:
        return isRef(child(predicate)) ? isNull(SparkUtil.toColumnName(child(predicate))) : null;

      case NOT_NULL:
        return isRef(child(predicate)) ? notNull(SparkUtil.toColumnName(child(predicate))) : null;

      case LT:
        return convertComparison(
            predicate, (col, val) -> lessThan(col, val), (col, val) -> greaterThan(col, val));

      case LT_EQ:
        return convertComparison(
            predicate,
            (col, val) -> lessThanOrEqual(col, val),
            (col, val) -> greaterThanOrEqual(col, val));

      case GT:
        return convertComparison(
            predicate, (col, val) -> greaterThan(col, val), (col, val) -> lessThan(col, val));

      case GT_EQ:
        return convertComparison(
            predicate,
            (col, val) -> greaterThanOrEqual(col, val),
            (col, val) -> lessThanOrEqual(col, val));

      case EQ: // used for both eq and null-safe-eq
        String columnName;
        Object value;
        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
          columnName = SparkUtil.toColumnName(leftChild(predicate));
          value = convertLiteral(rightChild(predicate));
        } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
          columnName = SparkUtil.toColumnName(rightChild(predicate));
          value = convertLiteral(leftChild(predicate));
        } else {
          return null;
        }

        if (predicate.name().equals(EQ)) {
          // comparison with null in normal equality is always null. this is probably a mistake.
          Preconditions.checkNotNull(
              value, "Expression is always false (eq is not null-safe): %s", predicate);
          return handleEqual(columnName, value);
        }

        // null-safe equality ("<=>"): null compares equal only to null
        return value == null ? isNull(columnName) : handleEqual(columnName, value);

      case IN:
        return isSupportedInPredicate(predicate)
            ? in(SparkUtil.toColumnName(childAtIndex(predicate, 0)), literalValues(predicate))
            : null;

      case NOT:
        Not notPredicate = (Not) predicate;
        Predicate childPredicate = notPredicate.child();
        if (childPredicate.name().equals(IN) && isSupportedInPredicate(childPredicate)) {
          // infer an extra notNull predicate for Spark NOT IN filters
          // as Iceberg expressions don't follow the 3-value SQL boolean logic
          // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg
          String column = SparkUtil.toColumnName(childAtIndex(childPredicate, 0));
          return and(notNull(column), notIn(column, literalValues(childPredicate)));
        } else if (hasNoInFilter(childPredicate)) {
          Expression child = convert(childPredicate);
          if (child != null) {
            return not(child);
          }
        }
        return null;

      case AND:
        {
          And andPredicate = (And) predicate;
          Expression left = convert(andPredicate.left());
          Expression right = convert(andPredicate.right());
          // both sides must convert; otherwise the whole conjunction is unsupported
          if (left != null && right != null) {
            return and(left, right);
          }
          return null;
        }

      case OR:
        {
          Or orPredicate = (Or) predicate;
          Expression left = convert(orPredicate.left());
          Expression right = convert(orPredicate.right());
          if (left != null && right != null) {
            return or(left, right);
          }
          return null;
        }

      case STARTS_WITH:
        // validate child shapes before casting; previously a predicate whose children were not
        // (ref, literal) threw ClassCastException instead of being rejected with null, and a
        // null prefix literal caused an NPE on toString()
        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
          Object prefix = convertLiteral(rightChild(predicate));
          return prefix != null
              ? startsWith(SparkUtil.toColumnName(leftChild(predicate)), prefix.toString())
              : null;
        }
        return null;
    }

    return null;
  }

  /**
   * Converts a two-child comparison predicate where one side is a column reference and the other
   * is a literal.
   *
   * @param direct factory applied for the {@code col OP literal} shape
   * @param flipped mirrored factory applied for the {@code literal OP col} shape
   * @return the converted expression, or {@code null} if neither shape matches
   */
  private static Expression convertComparison(
      Predicate predicate,
      BiFunction<String, Object, Expression> direct,
      BiFunction<String, Object, Expression> flipped) {
    if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
      return direct.apply(
          SparkUtil.toColumnName(leftChild(predicate)), convertLiteral(rightChild(predicate)));
    } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
      return flipped.apply(
          SparkUtil.toColumnName(rightChild(predicate)), convertLiteral(leftChild(predicate)));
    }

    return null;
  }

  /**
   * Converts the literal children of an IN/NOT IN predicate (all children after the column at
   * index 0). Null literal values are dropped.
   */
  private static List<Object> literalValues(Predicate predicate) {
    return Arrays.stream(predicate.children())
        .skip(1) // index 0 is the column reference
        .map(val -> convertLiteral((Literal<?>) val))
        .filter(Objects::nonNull)
        .collect(Collectors.toList());
  }

  /** Returns the only child of a unary predicate, failing if there is not exactly one. */
  @SuppressWarnings("unchecked")
  private static <T> T child(Predicate predicate) {
    org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
    Preconditions.checkArgument(
        children.length == 1, "Predicate should have one child: %s", predicate);
    return (T) children[0];
  }

  /** Returns the left child of a binary predicate, failing if there are not exactly two. */
  @SuppressWarnings("unchecked")
  private static <T> T leftChild(Predicate predicate) {
    org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
    Preconditions.checkArgument(
        children.length == 2, "Predicate should have two children: %s", predicate);
    return (T) children[0];
  }

  /** Returns the right child of a binary predicate, failing if there are not exactly two. */
  @SuppressWarnings("unchecked")
  private static <T> T rightChild(Predicate predicate) {
    org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
    Preconditions.checkArgument(
        children.length == 2, "Predicate should have two children: %s", predicate);
    return (T) children[1];
  }

  /** Returns the child at {@code index}; bounds are only checked by the array access itself. */
  @SuppressWarnings("unchecked")
  private static <T> T childAtIndex(Predicate predicate, int index) {
    return (T) predicate.children()[index];
  }

  /** Returns true if the expression is a column reference. */
  private static boolean isRef(org.apache.spark.sql.connector.expressions.Expression expr) {
    return expr instanceof NamedReference;
  }

  /** Returns true if the expression is a literal value. */
  private static boolean isLiteral(org.apache.spark.sql.connector.expressions.Expression expr) {
    return expr instanceof Literal;
  }

  /**
   * Unwraps a Spark literal to a plain Java value, converting Spark-internal representations
   * (UTF8String, Decimal) to their Java equivalents.
   */
  private static Object convertLiteral(Literal<?> literal) {
    if (literal.value() instanceof UTF8String) {
      return ((UTF8String) literal.value()).toString();
    } else if (literal.value() instanceof Decimal) {
      return ((Decimal) literal.value()).toJavaBigDecimal();
    }

    return literal.value();
  }

  /** Maps equality on a NaN literal to isNaN, since NaN != NaN under regular comparison. */
  private static Expression handleEqual(String attribute, Object value) {
    if (NaNUtil.isNaN(value)) {
      return isNaN(attribute);
    } else {
      return equal(attribute, value);
    }
  }

  /** Returns true if the predicate tree contains no IN filter (NOT IN needs special handling). */
  private static boolean hasNoInFilter(Predicate predicate) {
    Operation op = FILTERS.get(predicate.name());

    if (op != null) {
      switch (op) {
        case AND:
          And andPredicate = (And) predicate;
          return hasNoInFilter(andPredicate.left()) && hasNoInFilter(andPredicate.right());
        case OR:
          Or orPredicate = (Or) predicate;
          return hasNoInFilter(orPredicate.left()) && hasNoInFilter(orPredicate.right());
        case NOT:
          Not notPredicate = (Not) predicate;
          return hasNoInFilter(notPredicate.child());
        case IN:
          return false;
        default:
          return true;
      }
    }

    return false;
  }

  /** Returns true if the IN predicate has a column reference first and only literals after it. */
  private static boolean isSupportedInPredicate(Predicate predicate) {
    if (!isRef(childAtIndex(predicate, 0))) {
      return false;
    } else {
      return Arrays.stream(predicate.children()).skip(1).allMatch(SparkV2Filters::isLiteral);
    }
  }
}