blob: 653c27313ec33627e2bc24db95a6d07d17444816 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import org.apache.iceberg.Schema;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Type.TypeID;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;
/**
 * Prunes an Iceberg schema down to the columns requested by a Spark {@link StructType}.
 *
 * <p>Fields are matched to the requested Spark type by name. The result keeps the field order of
 * the Iceberg schema being visited. The requested type only decides which fields survive and is
 * used to validate that each requested type is compatible with the corresponding Iceberg type.
 *
 * <p>A field whose ID is in {@code filterRefs} is kept with its full, unpruned type even when it
 * is not requested. This ensures filter columns are projected.
 *
 * <p>Incompatible projections are rejected with {@link IllegalArgumentException}. Instances keep
 * mutable traversal state in {@code current}, so one instance should not be used for concurrent
 * visits.
 */
public class PruneColumnsWithoutReordering extends TypeUtil.CustomOrderSchemaVisitor<Type> {
  /**
   * Spark type class accepted for each Iceberg primitive type. Primitives that are not listed here
   * (for example {@code TIME}) cannot be projected.
   */
  private static final ImmutableMap<TypeID, Class<? extends DataType>> TYPES = ImmutableMap
      .<TypeID, Class<? extends DataType>>builder()
      .put(TypeID.BOOLEAN, BooleanType.class)
      .put(TypeID.INTEGER, IntegerType.class)
      .put(TypeID.LONG, LongType.class)
      .put(TypeID.FLOAT, FloatType.class)
      .put(TypeID.DOUBLE, DoubleType.class)
      .put(TypeID.DATE, DateType.class)
      .put(TypeID.TIMESTAMP, TimestampType.class)
      .put(TypeID.DECIMAL, DecimalType.class)
      .put(TypeID.UUID, StringType.class)
      .put(TypeID.STRING, StringType.class)
      .put(TypeID.FIXED, BinaryType.class)
      .put(TypeID.BINARY, BinaryType.class)
      .build();

  private final StructType requestedType;
  private final Set<Integer> filterRefs;
  // Spark type that lines up with the Iceberg type currently being visited.
  // It is null before a schema visit starts and is reset to null after one ends.
  private DataType current = null;

  /**
   * @param requestedType Spark struct describing the requested columns
   * @param filterRefs IDs of fields referenced by filters; these fields are kept even if not requested
   */
  PruneColumnsWithoutReordering(StructType requestedType, Set<Integer> filterRefs) {
    this.requestedType = requestedType;
    this.filterRefs = filterRefs;
  }

  /**
   * Starts a visit by aligning the root Iceberg struct with {@code requestedType}.
   *
   * @return the pruned root struct type
   */
  @Override
  public Type schema(Schema schema, Supplier<Type> structResult) {
    this.current = requestedType;
    try {
      return structResult.get();
    } finally {
      this.current = null;
    }
  }

  /**
   * Rebuilds a struct from its per-field results. A {@code null} result drops the field. A
   * different type replaces the field's type, keeping its ID, name and optionality.
   *
   * @return {@code struct} itself when no field changed, otherwise a new struct type
   */
  @Override
  public Type struct(Types.StructType struct, Iterable<Type> fieldResults) {
    Preconditions.checkNotNull(struct, "Cannot prune null struct. Pruning must start with a schema.");
    Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current);

    List<Types.NestedField> fields = struct.fields();
    // field results are produced lazily by the visitor; materializing them visits every field
    List<Type> types = Lists.newArrayList(fieldResults);

    boolean changed = false;
    List<Types.NestedField> newFields = Lists.newArrayListWithExpectedSize(types.size());
    for (int i = 0; i < fields.size(); i += 1) {
      Types.NestedField field = fields.get(i);
      Type type = types.get(i);

      if (type == null) {
        // field was pruned
        changed = true;
      } else if (field.type() == type) {
        newFields.add(field);
      } else if (field.isOptional()) {
        changed = true;
        newFields.add(Types.NestedField.optional(field.fieldId(), field.name(), type));
      } else {
        changed = true;
        newFields.add(Types.NestedField.required(field.fieldId(), field.name(), type));
      }
    }

    if (changed) {
      return Types.StructType.of(newFields);
    }

    return struct;
  }

  /**
   * Resolves one Iceberg field against the requested Spark struct by name.
   *
   * @return {@code null} if the field is neither requested nor referenced by a filter, the full
   *     field type if it is referenced only by a filter, otherwise the pruned field type
   * @throws IllegalArgumentException if a non-null Spark field is requested for an optional
   *     Iceberg field, or if projecting the field's type fails (the message is prefixed with the
   *     field name)
   */
  @Override
  public Type field(Types.NestedField field, Supplier<Type> fieldResult) {
    Preconditions.checkArgument(current instanceof StructType, "Not a struct: %s", current);
    StructType requestedStruct = (StructType) current;

    // fields are resolved by name because Spark only sees the current table schema.
    if (requestedStruct.getFieldIndex(field.name()).isEmpty()) {
      // make sure that filter fields are projected even if they aren't in the requested schema.
      if (filterRefs.contains(field.fieldId())) {
        return field.type();
      }
      return null;
    }

    int fieldIndex = requestedStruct.fieldIndex(field.name());
    StructField requestedField = requestedStruct.fields()[fieldIndex];

    Preconditions.checkArgument(requestedField.nullable() || field.isRequired(),
        "Cannot project an optional field as non-null: %s", field.name());

    this.current = requestedField.dataType();
    try {
      return fieldResult.get();
    } catch (IllegalArgumentException e) {
      // prefix with this field's name; nested fields each add their own prefix
      throw new IllegalArgumentException(
          "Invalid projection for field " + field.name() + ": " + e.getMessage(), e);
    } finally {
      this.current = requestedStruct;
    }
  }

  /**
   * Projects a list by visiting its element type against the requested Spark array's element type.
   *
   * @return {@code list} itself if the element type is unchanged, otherwise a new list type with
   *     the same element ID and optionality
   * @throws IllegalArgumentException if the requested array does not allow nulls but the Iceberg
   *     list elements are optional
   */
  @Override
  public Type list(Types.ListType list, Supplier<Type> elementResult) {
    Preconditions.checkArgument(current instanceof ArrayType, "Not an array: %s", current);
    ArrayType requestedArray = (ArrayType) current;

    Preconditions.checkArgument(requestedArray.containsNull() || !list.isElementOptional(),
        "Cannot project an array of optional elements as required elements: %s", requestedArray);

    this.current = requestedArray.elementType();
    try {
      Type elementType = elementResult.get();
      if (list.elementType() == elementType) {
        return list;
      }

      // must be a projected element type, create a new list
      if (list.isElementOptional()) {
        return Types.ListType.ofOptional(list.elementId(), elementType);
      } else {
        return Types.ListType.ofRequired(list.elementId(), elementType);
      }
    } finally {
      this.current = requestedArray;
    }
  }

  /**
   * Projects a map by visiting its value type against the requested Spark map's value type.
   *
   * <p>{@code keyResult} is never invoked. Keys are therefore neither pruned nor validated against
   * the requested key type, and the Iceberg key type is kept as-is.
   *
   * @return {@code map} itself if the value type is unchanged, otherwise a new map type with the
   *     same key/value IDs, key type and value optionality
   * @throws IllegalArgumentException if the requested map does not allow null values but the
   *     Iceberg map values are optional
   */
  @Override
  public Type map(Types.MapType map, Supplier<Type> keyResult, Supplier<Type> valueResult) {
    Preconditions.checkArgument(current instanceof MapType, "Not a map: %s", current);
    MapType requestedMap = (MapType) current;

    // report the requested Spark map (as list() does for arrays): its string form shows the
    // valueContainsNull flag that causes this failure, unlike the Iceberg map's string form
    Preconditions.checkArgument(requestedMap.valueContainsNull() || !map.isValueOptional(),
        "Cannot project a map of optional values as required values: %s", requestedMap);

    this.current = requestedMap.valueType();
    try {
      Type valueType = valueResult.get();
      if (map.valueType() == valueType) {
        return map;
      }

      if (map.isValueOptional()) {
        return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueType);
      } else {
        return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueType);
      }
    } finally {
      this.current = requestedMap;
    }
  }

  /**
   * Validates that the requested Spark type can represent the Iceberg primitive.
   *
   * <p>Decimals must have the same scale and a requested precision at least as large. Iceberg
   * timestamps are accepted only when they adjust to UTC.
   *
   * @return {@code primitive} unchanged
   * @throws IllegalArgumentException if the requested Spark type is incompatible
   */
  @Override
  public Type primitive(Type.PrimitiveType primitive) {
    Class<? extends DataType> expectedType = TYPES.get(primitive.typeId());
    Preconditions.checkArgument(expectedType != null && expectedType.isInstance(current),
        "Cannot project %s to incompatible type: %s", primitive, current);

    // additional checks based on type
    switch (primitive.typeId()) {
      case DECIMAL:
        Types.DecimalType decimal = (Types.DecimalType) primitive;
        DecimalType requestedDecimal = (DecimalType) current;
        Preconditions.checkArgument(requestedDecimal.scale() == decimal.scale(),
            "Cannot project decimal with incompatible scale: %s != %s", requestedDecimal.scale(), decimal.scale());
        Preconditions.checkArgument(requestedDecimal.precision() >= decimal.precision(),
            "Cannot project decimal with incompatible precision: %s < %s",
            requestedDecimal.precision(), decimal.precision());
        break;
      case TIMESTAMP:
        Types.TimestampType timestamp = (Types.TimestampType) primitive;
        Preconditions.checkArgument(timestamp.shouldAdjustToUTC(),
            "Cannot project timestamp (without time zone) as timestamptz (with time zone)");
        break;
      default:
        // no additional checks for other primitive types
        break;
    }

    return primitive;
  }
}