Spark 3.3: Fix rewrite_position_deletes for certain partition types (#8059) (#8069)

diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java
new file mode 100644
index 0000000..622b83c
--- /dev/null
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFiles.java
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT;
+import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT_DEFAULT;
+import static org.apache.spark.sql.functions.expr;
+import static org.apache.spark.sql.functions.lit;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import org.apache.iceberg.ContentFile;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.DeleteFile;
+import org.apache.iceberg.FileFormat;
+import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.MetadataTableType;
+import org.apache.iceberg.MetadataTableUtils;
+import org.apache.iceberg.PositionDeletesScanTask;
+import org.apache.iceberg.RowDelta;
+import org.apache.iceberg.ScanTask;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.actions.RewritePositionDeleteFiles.FileGroupRewriteResult;
+import org.apache.iceberg.actions.RewritePositionDeleteFiles.Result;
+import org.apache.iceberg.actions.SizeBasedFileRewriter;
+import org.apache.iceberg.data.GenericAppenderFactory;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.deletes.PositionDelete;
+import org.apache.iceberg.deletes.PositionDeleteWriter;
+import org.apache.iceberg.encryption.EncryptedFiles;
+import org.apache.iceberg.encryption.EncryptedOutputFile;
+import org.apache.iceberg.encryption.EncryptionKeyMetadata;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.FileAppenderFactory;
+import org.apache.iceberg.io.OutputFile;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.actions.SparkActions;
+import org.apache.iceberg.util.Pair;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructType;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runners.Parameterized;
+
+/**
+ * Tests that {@code rewritePositionDeletes} removes dangling position delete files from tables
+ * partitioned by various column types and partition transforms.
+ *
+ * <p>Each test appends data files, rewrites all of them so the original files are no longer live,
+ * and then writes position deletes against the original files. Rewriting position deletes must
+ * then drop every such delete file without adding new ones or changing the table's rows.
+ */
+public class TestRewritePositionDeleteFiles extends SparkExtensionsTestBase {
+
+  private static final Map<String, String> CATALOG_PROPS =
+      ImmutableMap.of(
+          "type", "hive",
+          "default-namespace", "default",
+          "cache-enabled", "false");
+
+  /** Number of data files appended by {@link #insertData(Function)}. */
+  private static final int NUM_DATA_FILES = 5;
+
+  private static final int ROWS_PER_DATA_FILE = 100;
+
+  /** Number of position delete files written for each partition. */
+  private static final int DELETE_FILES_PER_PARTITION = 2;
+
+  /** Number of position deletes written against each data file. */
+  private static final int DELETE_FILE_SIZE = 10;
+
+  // Each row supplies the constructor arguments: catalogName, implementation, config.
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        CATALOG_PROPS
+      }
+    };
+  }
+
+  @Rule public TemporaryFolder temp = new TemporaryFolder();
+
+  public TestRewritePositionDeleteFiles(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void cleanup() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testDatePartition() throws Exception {
+    createTable("date");
+    Date baseDate = Date.valueOf("2023-01-01");
+    insertData(i -> Date.valueOf(baseDate.toLocalDate().plusDays(i)));
+    testDanglingDelete();
+  }
+
+  // A boolean column has only two values, so only two data files (one per partition) are written.
+  @Test
+  public void testBooleanPartition() throws Exception {
+    createTable("boolean");
+    insertData(i -> i % 2 == 0, 2);
+    testDanglingDelete(2);
+  }
+
+  @Test
+  public void testTimestampPartition() throws Exception {
+    createTable("timestamp");
+    Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00");
+    insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i)));
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testBytePartition() throws Exception {
+    createTable("byte");
+    insertData(i -> i);
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testDecimalPartition() throws Exception {
+    createTable("decimal(18, 10)");
+    BigDecimal baseDecimal = new BigDecimal("1.0");
+    insertData(i -> baseDecimal.add(new BigDecimal(i)));
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testBinaryPartition() throws Exception {
+    createTable("binary");
+    insertData(i -> java.nio.ByteBuffer.allocate(4).putInt(i).array());
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testCharPartition() throws Exception {
+    createTable("char(10)");
+    insertData(Object::toString);
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testVarcharPartition() throws Exception {
+    createTable("varchar(10)");
+    insertData(Object::toString);
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testIntPartition() throws Exception {
+    createTable("int");
+    insertData(i -> i);
+    testDanglingDelete();
+  }
+
+  @Test
+  public void testDaysPartitionTransform() throws Exception {
+    createTable("timestamp", "days(partition_col)");
+    Timestamp baseTimestamp = Timestamp.valueOf("2023-01-01 15:30:00");
+    insertData(i -> Timestamp.valueOf(baseTimestamp.toLocalDateTime().plusDays(i)));
+    testDanglingDelete();
+  }
+
+  // Writes one data file with a null partition value and one with the value 1.
+  @Test
+  public void testNullTransform() throws Exception {
+    createTable("int");
+    insertData(i -> i == 0 ? null : 1, 2);
+    testDanglingDelete(2);
+  }
+
+  /** Runs {@link #testDanglingDelete(int)} with the default {@link #NUM_DATA_FILES}. */
+  private void testDanglingDelete() throws Exception {
+    testDanglingDelete(NUM_DATA_FILES);
+  }
+
+  /**
+   * Makes every position delete dangling and checks that rewriting position deletes removes them.
+   *
+   * <p>All current data files are rewritten first, then position deletes are written against the
+   * pre-rewrite data files. The position delete rewrite must remove all of those delete files,
+   * add none, and leave the table's rows unchanged.
+   *
+   * @param numDataFiles expected number of data files (one per partition) before the rewrite
+   */
+  private void testDanglingDelete(int numDataFiles) throws Exception {
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+
+    List<DataFile> dataFiles = dataFiles(table);
+    Assert.assertEquals(numDataFiles, dataFiles.size());
+
+    // Rewrite every data file so that the files captured above are no longer live.
+    SparkActions.get(spark)
+        .rewriteDataFiles(table)
+        .option(SizeBasedFileRewriter.REWRITE_ALL, "true")
+        .execute();
+
+    // write dangling delete files for 'old data files'
+    writePosDeletesForFiles(table, dataFiles);
+    List<DeleteFile> deleteFiles = deleteFiles(table);
+    Assert.assertEquals(numDataFiles * DELETE_FILES_PER_PARTITION, deleteFiles.size());
+
+    List<Object[]> expectedRecords = records(tableName);
+
+    Result result =
+        SparkActions.get(spark)
+            .rewritePositionDeletes(table)
+            .option(SizeBasedFileRewriter.REWRITE_ALL, "true")
+            .execute();
+
+    List<DeleteFile> newDeleteFiles = deleteFiles(table);
+    Assert.assertEquals("Should have removed all dangling delete files", 0, newDeleteFiles.size());
+    checkResult(result, deleteFiles, Lists.newArrayList(), numDataFiles);
+
+    List<Object[]> actualRecords = records(tableName);
+    assertEquals("Rows must match", expectedRecords, actualRecords);
+  }
+
+  /** Creates the test table identity-partitioned by {@code partition_col}. */
+  private void createTable(String partitionType) {
+    createTable(partitionType, "partition_col");
+  }
+
+  /**
+   * Creates a format version 2 table whose {@code partition_col} column has the given type.
+   *
+   * @param partitionType Spark SQL type of {@code partition_col}
+   * @param partitionCol partition expression, such as a column name or a transform
+   */
+  private void createTable(String partitionType, String partitionCol) {
+    sql(
+        "CREATE TABLE %s (id long, partition_col %s, c1 string, c2 string) "
+            + "USING iceberg "
+            + "PARTITIONED BY (%s) "
+            + "TBLPROPERTIES('format-version'='2')",
+        tableName, partitionType, partitionCol);
+  }
+
+  /** Appends {@link #NUM_DATA_FILES} batches; see {@link #insertData(Function, int)}. */
+  private void insertData(Function<Integer, ?> partitionValueFunction) throws Exception {
+    insertData(partitionValueFunction, NUM_DATA_FILES);
+  }
+
+  /**
+   * Appends {@code numDataFiles} batches of {@link #ROWS_PER_DATA_FILE} rows, one at a time.
+   *
+   * @param partitionValue maps a batch index to the {@code partition_col} value for all its rows
+   * @param numDataFiles number of batches to append
+   */
+  private void insertData(Function<Integer, ?> partitionValue, int numDataFiles) throws Exception {
+    for (int i = 0; i < numDataFiles; i++) {
+      Dataset<Row> df =
+          spark
+              .range(0, ROWS_PER_DATA_FILE)
+              .withColumn("partition_col", lit(partitionValue.apply(i)))
+              .withColumn("c1", expr("CAST(id AS STRING)"))
+              .withColumn("c2", expr("CAST(id AS STRING)"));
+      appendAsFile(df);
+    }
+  }
+
+  /**
+   * Appends {@code df} to the test table using a single write task ({@code coalesce(1)}).
+   *
+   * <p>The rows are re-wrapped with the table's own Spark schema so the written column types come
+   * from the table rather than from the types Spark inferred for the partition literal.
+   */
+  private void appendAsFile(Dataset<Row> df) throws Exception {
+    StructType sparkSchema = spark.table(tableName).schema();
+    spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(tableName).append();
+  }
+
+  /**
+   * Writes {@link #DELETE_FILES_PER_PARTITION} position delete files for each partition of the
+   * given data files and commits all of them in one row delta.
+   *
+   * <p>Each data file receives {@link #DELETE_FILE_SIZE} deletes at positions starting from 0,
+   * and the deletes of a partition are split evenly across its delete files.
+   */
+  private void writePosDeletesForFiles(Table table, List<DataFile> files) throws IOException {
+    Map<StructLike, List<DataFile>> filesByPartition =
+        files.stream().collect(Collectors.groupingBy(ContentFile::partition));
+    List<DeleteFile> deleteFiles =
+        Lists.newArrayListWithCapacity(DELETE_FILES_PER_PARTITION * filesByPartition.size());
+
+    for (Map.Entry<StructLike, List<DataFile>> filesByPartitionEntry :
+        filesByPartition.entrySet()) {
+
+      StructLike partition = filesByPartitionEntry.getKey();
+      List<DataFile> partitionFiles = filesByPartitionEntry.getValue();
+
+      int deletesForPartition = partitionFiles.size() * DELETE_FILE_SIZE;
+      // A delete file is only written once it is full, so any remainder would be silently lost.
+      Assert.assertEquals(
+          "Number of deletes per partition should be "
+              + "evenly divisible by the number of delete files per partition",
+          0,
+          deletesForPartition % DELETE_FILES_PER_PARTITION);
+      int deleteFileSize = deletesForPartition / DELETE_FILES_PER_PARTITION;
+
+      int counter = 0;
+      List<Pair<CharSequence, Long>> deletes = Lists.newArrayList();
+      for (DataFile partitionFile : partitionFiles) {
+        for (int deletePos = 0; deletePos < DELETE_FILE_SIZE; deletePos++) {
+          deletes.add(Pair.of(partitionFile.path(), (long) deletePos));
+          counter++;
+          if (counter == deleteFileSize) {
+            // Dump to file and reset variables
+            OutputFile output = Files.localOutput(temp.newFile());
+            deleteFiles.add(writeDeleteFile(table, output, partition, deletes));
+            counter = 0;
+            deletes.clear();
+          }
+        }
+      }
+    }
+
+    RowDelta rowDelta = table.newRowDelta();
+    deleteFiles.forEach(rowDelta::addDeletes);
+    rowDelta.commit();
+  }
+
+  /**
+   * Writes the given (data file path, position) pairs as one position delete file for the given
+   * partition, in the table's default file format.
+   *
+   * @return the written delete file; it is not committed to the table
+   */
+  private DeleteFile writeDeleteFile(
+      Table table, OutputFile out, StructLike partition, List<Pair<CharSequence, Long>> deletes)
+      throws IOException {
+    FileFormat format = defaultFormat(table.properties());
+    FileAppenderFactory<Record> factory = new GenericAppenderFactory(table.schema(), table.spec());
+
+    PositionDeleteWriter<Record> writer =
+        factory.newPosDeleteWriter(encrypt(out), format, partition);
+    PositionDelete<Record> posDelete = PositionDelete.create();
+    try (Closeable toClose = writer) {
+      for (Pair<CharSequence, Long> delete : deletes) {
+        writer.write(posDelete.set(delete.first(), delete.second(), null));
+      }
+    }
+
+    return writer.toDeleteFile();
+  }
+
+  /** Wraps {@code out} as an {@link EncryptedOutputFile} with empty key metadata. */
+  private static EncryptedOutputFile encrypt(OutputFile out) {
+    return EncryptedFiles.encryptedOutput(out, EncryptionKeyMetadata.EMPTY);
+  }
+
+  /** Returns the table's default file format, falling back to the Iceberg default. */
+  private static FileFormat defaultFormat(Map<String, String> properties) {
+    String formatString = properties.getOrDefault(DEFAULT_FILE_FORMAT, DEFAULT_FILE_FORMAT_DEFAULT);
+    return FileFormat.fromString(formatString);
+  }
+
+  /** Reads all rows of {@code table}, ordered by {@code partition_col} and then {@code id}. */
+  private List<Object[]> records(String table) {
+    return rowsToJava(
+        spark.read().format("iceberg").load(table).sort("partition_col", "id").collectAsList());
+  }
+
+  /** Returns the total size in bytes of the given delete files. */
+  private long size(List<DeleteFile> deleteFiles) {
+    return deleteFiles.stream().mapToLong(DeleteFile::fileSizeInBytes).sum();
+  }
+
+  /** Returns the data files planned by a scan of {@code table}, with column stats included. */
+  private List<DataFile> dataFiles(Table table) throws IOException {
+    try (CloseableIterable<FileScanTask> tasks = table.newScan().includeColumnStats().planFiles()) {
+      return Lists.newArrayList(CloseableIterable.transform(tasks, FileScanTask::file));
+    }
+  }
+
+  /** Returns the delete files found by scanning the POSITION_DELETES metadata table. */
+  private List<DeleteFile> deleteFiles(Table table) throws IOException {
+    Table deletesTable =
+        MetadataTableUtils.createMetadataTableInstance(table, MetadataTableType.POSITION_DELETES);
+    try (CloseableIterable<ScanTask> tasks = deletesTable.newBatchScan().planFiles()) {
+      return Lists.newArrayList(
+          CloseableIterable.transform(tasks, t -> ((PositionDeletesScanTask) t).file()));
+    }
+  }
+
+  /**
+   * Checks the file and byte counts of a position delete rewrite result, both the totals and the
+   * sums over its per-group results.
+   *
+   * @param result result of the rewrite
+   * @param rewrittenDeletes delete files expected to be removed by the rewrite
+   * @param newDeletes delete files expected to be added by the rewrite
+   * @param expectedGroups expected number of rewritten file groups
+   */
+  private void checkResult(
+      Result result,
+      List<DeleteFile> rewrittenDeletes,
+      List<DeleteFile> newDeletes,
+      int expectedGroups) {
+    Assert.assertEquals(
+        "Expected rewritten delete file count does not match",
+        rewrittenDeletes.size(),
+        result.rewrittenDeleteFilesCount());
+    Assert.assertEquals(
+        "Expected new delete file count does not match",
+        newDeletes.size(),
+        result.addedDeleteFilesCount());
+    Assert.assertEquals(
+        "Expected rewritten delete byte count does not match",
+        size(rewrittenDeletes),
+        result.rewrittenBytesCount());
+    Assert.assertEquals(
+        "Expected new delete byte count does not match",
+        size(newDeletes),
+        result.addedBytesCount());
+
+    Assert.assertEquals(
+        "Expected rewrite group count does not match",
+        expectedGroups,
+        result.rewriteResults().size());
+    Assert.assertEquals(
+        "Expected rewritten delete file count in all groups to match",
+        rewrittenDeletes.size(),
+        result.rewriteResults().stream()
+            .mapToInt(FileGroupRewriteResult::rewrittenDeleteFilesCount)
+            .sum());
+    Assert.assertEquals(
+        "Expected added delete file count in all groups to match",
+        newDeletes.size(),
+        result.rewriteResults().stream()
+            .mapToInt(FileGroupRewriteResult::addedDeleteFilesCount)
+            .sum());
+    Assert.assertEquals(
+        "Expected rewritten delete bytes in all groups to match",
+        size(rewrittenDeletes),
+        result.rewriteResults().stream()
+            .mapToLong(FileGroupRewriteResult::rewrittenBytesCount)
+            .sum());
+    Assert.assertEquals(
+        "Expected added delete bytes in all groups to match",
+        size(newDeletes),
+        result.rewriteResults().stream().mapToLong(FileGroupRewriteResult::addedBytesCount).sum());
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
index 687d9f4..c168f77 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java
@@ -28,6 +28,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.ByteBuffers;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 
@@ -120,4 +121,55 @@
     }
     return record;
   }
+
+  /**
+   * Converts a value in Iceberg's internal representation into the Java object Spark expects for
+   * the corresponding Spark type, e.g. so that it can be passed to {@code functions.lit}.
+   *
+   * @param type the Iceberg type of the value
+   * @param object the value to convert; may be null
+   * @return the converted value, or null if {@code object} is null
+   * @throws UnsupportedOperationException if {@code type} is a struct, list, map, or another
+   *     unsupported type
+   */
+  public static Object convertToSpark(Type type, Object object) {
+    if (object == null) {
+      return null;
+    }
+
+    switch (type.typeId()) {
+      case STRUCT:
+      case LIST:
+      case MAP:
+        throw new UnsupportedOperationException("Complex types currently not supported");
+      case DATE:
+        return DateTimeUtils.daysToLocalDate((int) object);
+      case TIMESTAMP:
+        Types.TimestampType ts = (Types.TimestampType) type.asPrimitiveType();
+        if (ts.shouldAdjustToUTC()) {
+          return DateTimeUtils.microsToInstant((long) object);
+        } else {
+          return DateTimeUtils.microsToLocalDateTime((long) object);
+        }
+      case BINARY:
+      case FIXED:
+        // Spark binary values are byte[]; unwrap ByteBuffer inputs for both binary and fixed.
+        if (object instanceof ByteBuffer) {
+          return ByteBuffers.toByteArray((ByteBuffer) object);
+        }
+        return object;
+      case STRING:
+        // Values may arrive as a non-String CharSequence; Spark literals require a String.
+        return object.toString();
+      case INTEGER:
+      case BOOLEAN:
+      case LONG:
+      case FLOAT:
+      case DOUBLE:
+      case DECIMAL:
+        return object;
+      default:
+        throw new UnsupportedOperationException("Not a supported type: " + type);
+    }
+  }
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
index 51c4cc6..565acd4 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/actions/SparkBinPackPositionDeletesRewriter.java
@@ -20,6 +20,7 @@
 
 import static org.apache.iceberg.MetadataTableType.POSITION_DELETES;
 import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.lit;
 
 import java.util.List;
 import java.util.Optional;
@@ -40,7 +41,9 @@
 import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.SparkTableCache;
 import org.apache.iceberg.spark.SparkTableUtil;
+import org.apache.iceberg.spark.SparkValueConverter;
 import org.apache.iceberg.spark.SparkWriteOptions;
+import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
@@ -125,10 +128,11 @@
         IntStream.range(0, fields.size())
             .mapToObj(
                 i -> {
-                  Class<?> type = fields.get(i).type().typeId().javaClass();
-                  Object value = partition.get(i, type);
+                  Type type = fields.get(i).type();
+                  Object value = partition.get(i, type.typeId().javaClass());
+                  Object convertedValue = SparkValueConverter.convertToSpark(type, value);
                   Column col = col("partition." + fields.get(i).name());
-                  return col.equalTo(value);
+                  return col.eqNullSafe(lit(convertedValue));
                 })
             .reduce(Column::and);
     if (condition.isPresent()) {