Spark 3.3: Fix rewrite_position_deletes for certain partition types (#8059) (#8069)
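
The rewriter previously filtered the position deletes metadata table by
comparing the partition column against Iceberg's internal partition values
with equalTo. That fails for partition types whose internal representation
differs from Spark's (dates are int days from epoch, timestamps are long
microseconds, binary is a ByteBuffer), and equalTo never matches a null
partition value. SparkBinPackPositionDeletesRewriter now converts each
partition value via the new SparkValueConverter.convertToSpark and compares
with eqNullSafe. A new parameterized test covers date, timestamp, byte, int,
boolean, decimal, binary, char, varchar, and null partition values, plus the
days() transform.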
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java
new file mode 100644
index 0000000..622b83c
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT;
+import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT;
+import static org.apache.spark.sql.functions.expr;
+import static org.apache.spark.sql.functions.lit;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import org.apache.iceberg.ContentFile;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.MetadataTableType;
+import org.apache.iceberg.MetadataTableUtils;
+import org.apache.iceberg.PositionDeletesScanTask;
+import org.apache.iceberg.RowDelta;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.actions.RewritePositionDeleteFiles.FileGroupRewriteResult;
+import org.apache.iceberg.actions.RewritePositionDeleteFiles.Result;
+import org.apache.iceberg.actions.SizeBasedFileRewriter;
+import org.apache.iceberg.data.GenericAppenderFactory;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.deletes.PositionDelete;
+import org.apache.iceberg.deletes.PositionDeleteWriter;
+import org.apache.iceberg.encryption.EncryptedFiles;
+import org.apache.iceberg.encryption.EncryptedOutputFile;
+import org.apache.iceberg.encryption.EncryptionKeyMetadata;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.FileAppenderFactory;
+import org.apache.iceberg.io.OutputFile;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.actions.SparkActions;
+import org.apache.iceberg.util.Pair;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructType;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runners.Parameterized;
+
+public class TestRewritePositionDeleteFiles extends SparkExtensionsTestBase {
+
+ private static final Map<String, String> CATALOG_PROPS =
+ ImmutableMap.of(
+ "type", "hive",
+ "default-namespace", "default",
+ "cache-enabled", "false");
+
+ private static final int NUM_DATA_FILES = 5;
+ private static final int ROWS_PER_DATA_FILE = 100;
+ private static final int DELETE_FILES_PER_PARTITION = 2;
+ private static final int DELETE_FILE_SIZE = 10;
+
+  @Parameterized.Parameters(
+      name = "catalogName = {0}, implementation = {1}, config = {2}")
+ public static Object[][] parameters() {
+ return new Object[][] {
+ {
+ SparkCatalogConfig.HIVE.catalogName(),
+ SparkCatalogConfig.HIVE.implementation(),
+ CATALOG_PROPS
+ }
+ };
+ }
+
+ @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+ public TestRewritePositionDeleteFiles(
+ String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @After
+ public void cleanup() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testDatePartition() throws Exception {
+ createTable("date");
+ Date baseDate = Date.valueOf("2023-01-01");
+ insertData(i -> Date.valueOf(baseDate.toLocalDate().plusDays(i)));
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testBooleanPartition() throws Exception {
+ createTable("boolean");
+ insertData(i -> i % 2 == 0, 2);
+ testDanglingDelete(2);
+ }
+
+ @Test
+ public void testTimestampPartition() throws Exception {
+ createTable("timestamp");
+ Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00");
+ insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i)));
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testBytePartition() throws Exception {
+ createTable("byte");
+ insertData(i -> i);
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testDecimalPartition() throws Exception {
+ createTable("decimal(18, 10)");
+ BigDecimal baseDecimal = new BigDecimal("1.0");
+ insertData(i -> baseDecimal.add(new BigDecimal(i)));
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testBinaryPartition() throws Exception {
+ createTable("binary");
+ insertData(i -> java.nio.ByteBuffer.allocate(4).putInt(i).array());
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testCharPartition() throws Exception {
+ createTable("char(10)");
+ insertData(Object::toString);
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testVarcharPartition() throws Exception {
+ createTable("varchar(10)");
+ insertData(Object::toString);
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testIntPartition() throws Exception {
+ createTable("int");
+ insertData(i -> i);
+ testDanglingDelete();
+ }
+
+ @Test
+ public void testDaysPartitionTransform() throws Exception {
+ createTable("timestamp", "days(partition_col)");
+ Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00");
+ insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i)));
+ testDanglingDelete();
+ }
+
+ @Test
+  public void testNullPartition() throws Exception {
+ createTable("int");
+ insertData(i -> i == 0 ? null : 1, 2);
+ testDanglingDelete(2);
+ }
+
+  private void testDanglingDelete() throws Exception {
+ testDanglingDelete(NUM_DATA_FILES);
+ }
+
+  private void testDanglingDelete(int numDataFiles) throws Exception {
+ Table table = Spark3Util.loadIcebergTable(spark, tableName);
+
+ List<DataFile> dataFiles = dataFiles(table);
+ Assert.assertEquals(numDataFiles, dataFiles.size());
+
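+    // rewrite all data files so that the original files are no longer live in the table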
+ SparkActions.get(spark)
+ .rewriteDataFiles(table)
+ .option(SizeBasedFileRewriter.REWRITE_ALL, "true")
+ .execute();
+
+    // write position deletes against the replaced data files; these deletes are now dangling
+ writePosDeletesForFiles(table, dataFiles);
+ List<DeleteFile> deleteFiles = deleteFiles(table);
+ Assert.assertEquals(numDataFiles * DELETE_FILES_PER_PARTITION, deleteFiles.size());
+
+ List<Object[]> expectedRecords = records(tableName);
+
+ Result result =
+ SparkActions.get(spark)
+ .rewritePositionDeletes(table)
+ .option(SizeBasedFileRewriter.REWRITE_ALL, "true")
+ .execute();
+
+ List<DeleteFile> newDeleteFiles = deleteFiles(table);
+ Assert.assertEquals("Should have removed all dangling delete files", 0, newDeleteFiles.size());
+ checkResult(result, deleteFiles, Lists.newArrayList(), numDataFiles);
+
+ List<Object[]> actualRecords = records(tableName);
+ assertEquals("Rows must match", expectedRecords, actualRecords);
+ }
+
+ private void createTable(String partitionType) {
+ createTable(partitionType, "partition_col");
+ }
+
+ private void createTable(String partitionType, String partitionCol) {
+ sql(
+ "CREATE TABLE %s (id long, partition_col %s, c1 string, c2 string) "
+ + "USING iceberg "
+ + "PARTITIONED BY (%s) "
+ + "TBLPROPERTIES('format-version'='2')",
+ tableName, partitionType, partitionCol);
+ }
+
+  private void insertData(Function<Integer, ?> partitionValue) throws Exception {
+    insertData(partitionValue, NUM_DATA_FILES);
+ }
+
+  private void insertData(Function<Integer, ?> partitionValue, int numDataFiles)
+ throws Exception {
+ for (int i = 0; i < numDataFiles; i++) {
+ Dataset<Row> df =
+ spark
+ .range(0, ROWS_PER_DATA_FILE)
+ .withColumn("partition_col", lit(partitionValue.apply(i)))
+ .withColumn("c1", expr("CAST(id AS STRING)"))
+ .withColumn("c2", expr("CAST(id AS STRING)"));
+ appendAsFile(df);
+ }
+ }
+
+ private void appendAsFile(Dataset<Row> df) throws Exception {
+    // apply the table schema so the written column types match the table exactly
+ StructType sparkSchema = spark.table(tableName).schema();
+ spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(tableName).append();
+ }
+
+ private void writePosDeletesForFiles(Table table, List<DataFile> files) throws IOException {
+
+ Map<StructLike, List<DataFile>> filesByPartition =
+ files.stream().collect(Collectors.groupingBy(ContentFile::partition));
+ List<DeleteFile> deleteFiles =
+ Lists.newArrayListWithCapacity(DELETE_FILES_PER_PARTITION * filesByPartition.size());
+
+ for (Map.Entry<StructLike, List<DataFile>> filesByPartitionEntry :
+ filesByPartition.entrySet()) {
+
+ StructLike partition = filesByPartitionEntry.getKey();
+ List<DataFile> partitionFiles = filesByPartitionEntry.getValue();
+
+ int deletesForPartition = partitionFiles.size() * DELETE_FILE_SIZE;
+      Assert.assertEquals(
+          "Total deletes for this partition should be evenly divisible "
+              + "by the number of delete files per partition",
+          0,
+          deletesForPartition % DELETE_FILES_PER_PARTITION);
+ int deleteFileSize = deletesForPartition / DELETE_FILES_PER_PARTITION;
+
+ int counter = 0;
+ List<Pair<CharSequence, Long>> deletes = Lists.newArrayList();
+ for (DataFile partitionFile : partitionFiles) {
+ for (int deletePos = 0; deletePos < DELETE_FILE_SIZE; deletePos++) {
+ deletes.add(Pair.of(partitionFile.path(), (long) deletePos));
+ counter++;
+ if (counter == deleteFileSize) {
+ // Dump to file and reset variables
+ OutputFile output = Files.localOutput(temp.newFile());
+ deleteFiles.add(writeDeleteFile(table, output, partition, deletes));
+ counter = 0;
+ deletes.clear();
+ }
+ }
+ }
+ }
+
+ RowDelta rowDelta = table.newRowDelta();
+ deleteFiles.forEach(rowDelta::addDeletes);
+ rowDelta.commit();
+ }
+
+ private DeleteFile writeDeleteFile(
+ Table table, OutputFile out, StructLike partition, List<Pair<CharSequence, Long>> deletes)
+ throws IOException {
+ FileFormat format = defaultFormat(table.properties());
+ FileAppenderFactory<Record> factory = new GenericAppenderFactory(table.schema(), table.spec());
+
+ PositionDeleteWriter<Record> writer =
+ factory.newPosDeleteWriter(encrypt(out), format, partition);
+ PositionDelete<Record> posDelete = PositionDelete.create();
+ try (Closeable toClose = writer) {
+ for (Pair<CharSequence, Long> delete : deletes) {
+ writer.write(posDelete.set(delete.first(), delete.second(), null));
+ }
+ }
+
+ return writer.toDeleteFile();
+ }
+
+ private static EncryptedOutputFile encrypt(OutputFile out) {
+ return EncryptedFiles.encryptedOutput(out, EncryptionKeyMetadata.EMPTY);
+ }
+
+ private static FileFormat defaultFormat(Map<String, String> properties) {
+ String formatString = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT);
+ return FileFormat.fromString(formatString);
+ }
+
+ private List<Object[]> records(String table) {
+ return rowsToJava(
+ spark.read().format("iceberg").load(table).sort("partition_col", "id").collectAsList());
+ }
+
+ private long size(List<DeleteFile> deleteFiles) {
+ return deleteFiles.stream().mapToLong(DeleteFile::fileSizeInBytes).sum();
+ }
+
+ private List<DataFile> dataFiles(Table table) {
+ CloseableIterable<FileScanTask> tasks = table.newScan().includeColumnStats().planFiles();
+ return Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file));
+ }
+
+ private List<DeleteFile> deleteFiles(Table table) {
+ Table deletesTable =
+ MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES);
+ CloseableIterable<ScanTask> tasks = deletesTable.newBatchScan().planFiles();
+ return Lists.newArrayList(
+ CloseableIterable.transform(tasks, t -> ((PositionDeletesScanTask) t).file()));
+ }
+
+ private void checkResult(
+ Result result,
+ List<DeleteFile> rewrittenDeletes,
+ List<DeleteFile> newDeletes,
+ int expectedGroups) {
+ Assert.assertEquals(
+ "Expected rewritten delete file count does not match",
+ rewrittenDeletes.size(),
+ result.rewrittenDeleteFilesCount());
+ Assert.assertEquals(
+ "Expected new delete file count does not match",
+ newDeletes.size(),
+ result.addedDeleteFilesCount());
+ Assert.assertEquals(
+ "Expected rewritten delete byte count does not match",
+ size(rewrittenDeletes),
+ result.rewrittenBytesCount());
+ Assert.assertEquals(
+ "Expected new delete byte count does not match",
+ size(newDeletes),
+ result.addedBytesCount());
+
+ Assert.assertEquals(
+ "Expected rewrite group count does not match",
+ expectedGroups,
+ result.rewriteResults().size());
+ Assert.assertEquals(
+ "Expected rewritten delete file count in all groups to match",
+ rewrittenDeletes.size(),
+ result.rewriteResults().stream()
+ .mapToInt(FileGroupRewriteResult::rewrittenDeleteFilesCount)
+ .sum());
+ Assert.assertEquals(
+ "Expected added delete file count in all groups to match",
+ newDeletes.size(),
+ result.rewriteResults().stream()
+ .mapToInt(FileGroupRewriteResult::addedDeleteFilesCount)
+ .sum());
+ Assert.assertEquals(
+ "Expected rewritten delete bytes in all groups to match",
+ size(rewrittenDeletes),
+ result.rewriteResults().stream()
+ .mapToLong(FileGroupRewriteResult::rewrittenBytesCount)
+ .sum());
+ Assert.assertEquals(
+ "Expected added delete bytes in all groups to match",
+ size(newDeletes),
+ result.rewriteResults().stream().mapToLong(FileGroupRewriteResult::addedBytesCount).sum());
+ }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
index 687d9f4..c168f77 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
@@ -28,6 +28,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.ByteBuffers;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
@@ -120,4 +121,44 @@
}
return record;
}
+
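+  /** Converts an Iceberg internal value to an object that Spark can use, e.g. as a filter literal. */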
+ public static Object convertToSpark(Type type, Object object) {
+ if (object == null) {
+ return null;
+ }
+
+ switch (type.typeId()) {
+ case STRUCT:
+ case LIST:
+ case MAP:
+        throw new UnsupportedOperationException("Complex types are currently not supported");
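+      // Iceberg stores dates as int days from the Unix epoch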
+ case DATE:
+ return DateTimeUtils.daysToLocalDate((int) object);
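+      // Iceberg stores timestamps as long microseconds from the Unix epoch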
+ case TIMESTAMP:
+ Types.TimestampType ts = (Types.TimestampType) type.asPrimitiveType();
+ if (ts.shouldAdjustToUTC()) {
+ return DateTimeUtils.microsToInstant((long) object);
+ } else {
+ return DateTimeUtils.microsToLocalDateTime((long) object);
+ }
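+      // Iceberg returns binary values as a ByteBuffer; Spark expects byte[]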
+ case BINARY:
+ return ByteBuffers.toByteArray((ByteBuffer) object);
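+      // the remaining primitives already use Java types that Spark accepts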
+ case INTEGER:
+ case BOOLEAN:
+ case LONG:
+ case FLOAT:
+ case DOUBLE:
+ case DECIMAL:
+ case STRING:
+ case FIXED:
+ return object;
+ default:
+ throw new UnsupportedOperationException("Not a supported type: " + type);
+ }
+ }
}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
index 51c4cc6..565acd4 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
@@ -20,6 +20,7 @@
import static org.apache.iceberg.MetadataTableType.POSITION_DELETES;
import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.lit;
import java.util.List;
import java.util.Optional;
@@ -40,7 +41,9 @@
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.SparkTableCache;
import org.apache.iceberg.spark.SparkTableUtil;
+import org.apache.iceberg.spark.SparkValueConverter;
import org.apache.iceberg.spark.SparkWriteOptions;
+import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
@@ -125,10 +128,13 @@
IntStream.range(0, fields.size())
.mapToObj(
i -> {
- Class<?> type = fields.get(i).type().typeId().javaClass();
- Object value = partition.get(i, type);
+ Type type = fields.get(i).type();
+ Object value = partition.get(i, type.typeId().javaClass());
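+              // convert Iceberg's internal value (e.g. int days, long micros) for Spark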
+ Object convertedValue = SparkValueConverter.convertToSpark(type, value);
Column col = col("partition." + fields.get(i).name());
- return col.equalTo(value);
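+              // eqNullSafe so partitions with a null partition value still match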
+ return col.eqNullSafe(lit(convertedValue));
})
.reduce(Column::and);
if (condition.isPresent()) {