Spark 3.4: Fix system function pushdown in CoW row-level commands (#10119) (#10171)

Co-authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
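
The ReplaceStaticInvoke rule rewrites Spark's StaticInvoke representation of
Iceberg system function calls into a form that data source V2 filter pushdown
can translate (ApplyFunctionExpression). Previously the rule only rewrote
Filter conditions; it now fires on plans containing commands, joins, or
filters, rewrites Join conditions as well, and handles IN and INSET predicates
in addition to binary comparisons. The Spark function implementations also
gain a shared BaseScalarFunction base class whose equals/hashCode are derived
from canonicalName(), so equivalent function invocations compare equal during
planning. The new test suite verifies pushdown by making all irrelevant data
files unreachable before running DELETE, UPDATE, and MERGE with
partition-transform predicates.

For illustration, statements of the shape exercised by the new tests now have
their system function calls rewritten and pushed to the Iceberg scan. This is
a sketch only: the table name is a placeholder, and the table is assumed to be
an Iceberg table partitioned by bucket(4, dep), as in the tests.

    // Placeholders: "db.tbl" is hypothetical; "system" resolves against the
    // current Iceberg catalog, as configured in the new test suite.
    spark.sql("DELETE FROM db.tbl WHERE system.bucket(4, dep) IN (2, 3)");
    spark.sql("UPDATE db.tbl SET salary = -1 WHERE system.bucket(4, dep) = 2");
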
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
index 1f0e164..655a93a 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
@@ -22,12 +22,20 @@
 import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
 import org.apache.spark.sql.catalyst.expressions.BinaryComparison
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.InSet
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
 import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
@@ -40,21 +48,36 @@
 object ReplaceStaticInvoke extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
-      case filter @ Filter(condition, _) =>
-        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
-          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
-            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+    plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
+      case join @ Join(_, _, _, Some(cond), _) =>
+        replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))
 
-          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
-            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
-        }
+      case filter @ Filter(cond, _) =>
+        replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
+  }
 
-        if (newCondition fastEquals condition) {
-          filter
-        } else {
-          filter.copy(condition = newCondition)
-        }
+  private def replaceStaticInvoke[T <: LogicalPlan](
+      node: T,
+      condition: Expression,
+      copy: Expression => T): T = {
+    val newCondition = replaceStaticInvoke(condition)
+    if (newCondition fastEquals condition) node else copy(newCondition)
+  }
+
+  private def replaceStaticInvoke(condition: Expression): Expression = {
+    condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
+      case in @ In(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(value = replaceStaticInvoke(value))
+
+      case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(child = replaceStaticInvoke(value))
+
+      case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+        c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+      case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+        c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+    }
   }
 
   private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
index 4f7c3eb..830d07d 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
@@ -20,12 +20,17 @@
 
 import static scala.collection.JavaConverters.seqAsJavaListConverter;
 
+import java.util.Collection;
 import java.util.List;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.execution.CommandResultExec;
 import org.apache.spark.sql.execution.SparkPlan;
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
+import scala.PartialFunction;
 import scala.collection.Seq;
 
 public class SparkPlanUtil {
@@ -53,6 +58,49 @@
     }
   }
 
+  public static List<Expression> collectExprs(
+      SparkPlan sparkPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> seq =
+        SPARK_HELPER.collect(
+            sparkPlan,
+            new PartialFunction<SparkPlan, List<Expression>>() {
+              @Override
+              public List<Expression> apply(SparkPlan plan) {
+                List<Expression> exprs = Lists.newArrayList();
+
+                for (Expression expr : toJavaList(plan.expressions())) {
+                  exprs.addAll(collectExprs(expr, predicate));
+                }
+
+                return exprs;
+              }
+
+              @Override
+              public boolean isDefinedAt(SparkPlan plan) {
+                return true;
+              }
+            });
+    return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectExprs(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> seq =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+    return toJavaList(seq);
+  }
+
   private static <T> List<T> toJavaList(Seq<T> seq) {
     return seqAsJavaListConverter(seq).asJava();
   }
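
The new collectExprs helper walks every node of a physical plan and gathers
the expressions matching a caller-supplied predicate. A minimal sketch of how
a test might use it, mirroring the test suite below; the helper class here is
hypothetical, and the SparkPlan is assumed to be captured elsewhere (for
example via executeAndKeepPlan in the extensions test base):

    import java.util.List;
    import org.apache.iceberg.spark.extensions.SparkPlanUtil;
    import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
    import org.apache.spark.sql.catalyst.expressions.Expression;
    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
    import org.apache.spark.sql.execution.SparkPlan;

    // Hypothetical helper, for illustration only.
    class FunctionCallCollector {
      // Returns every system function call expression that survived planning.
      static List<Expression> remainingFunctionCalls(SparkPlan plan) {
        return SparkPlanUtil.collectExprs(
            plan,
            expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression);
      }
    }
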
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
new file mode 100644
index 0000000..db4d106
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
+import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.execution.CommandResultExec;
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Parameterized.Parameters;
+
+public class TestSystemFunctionPushDownInRowLevelOperations extends SparkExtensionsTestBase {
+
+  private static final String CHANGES_TABLE_NAME = "changes";
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        SparkCatalogConfig.HIVE.properties()
+      }
+    };
+  }
+
+  public TestSystemFunctionPushDownInRowLevelOperations(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void beforeEach() {
+    sql("USE %s", catalogName);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s PURGE", tableName);
+    sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME));
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteBucketTransformInPredicate() {
+    initTable("bucket(4, dep)");
+    checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteBucketTransformInPredicate() {
+    initTable("bucket(4, dep)");
+    checkDelete(MERGE_ON_READ, "system.bucket(4, dep) IN (2, 3)");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteBucketTransformEqPredicate() {
+    initTable("bucket(4, dep)");
+    checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) = 2");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteBucketTransformEqPredicate() {
+    initTable("bucket(4, dep)");
+    checkDelete(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteYearsTransform() {
+    initTable("years(ts)");
+    checkDelete(COPY_ON_WRITE, "system.years(ts) > 30");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteYearsTransform() {
+    initTable("years(ts)");
+    checkDelete(MERGE_ON_READ, "system.years(ts) <= 30");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteMonthsTransform() {
+    initTable("months(ts)");
+    checkDelete(COPY_ON_WRITE, "system.months(ts) <= 250");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteMonthsTransform() {
+    initTable("months(ts)");
+    checkDelete(MERGE_ON_READ, "system.months(ts) > 250");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteDaysTransform() {
+    initTable("days(ts)");
+    checkDelete(COPY_ON_WRITE, "system.days(ts) <= date('2000-01-03 00:00:00')");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteDaysTransform() {
+    initTable("days(ts)");
+    checkDelete(MERGE_ON_READ, "system.days(ts) > date('2000-01-03 00:00:00')");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteHoursTransform() {
+    initTable("hours(ts)");
+    checkDelete(COPY_ON_WRITE, "system.hours(ts) <= 100000");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteHoursTransform() {
+    initTable("hours(ts)");
+    checkDelete(MERGE_ON_READ, "system.hours(ts) > 100000");
+  }
+
+  @Test
+  public void testCopyOnWriteDeleteTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkDelete(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
+  }
+
+  @Test
+  public void testMergeOnReadDeleteTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkDelete(MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
+  }
+
+  @Test
+  public void testCopyOnWriteUpdateBucketTransform() {
+    initTable("bucket(4, dep)");
+    checkUpdate(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+  }
+
+  @Test
+  public void testMergeOnReadUpdateBucketTransform() {
+    initTable("bucket(4, dep)");
+    checkUpdate(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+  }
+
+  @Test
+  public void testCopyOnWriteUpdateYearsTransform() {
+    initTable("years(ts)");
+    checkUpdate(COPY_ON_WRITE, "system.years(ts) > 30");
+  }
+
+  @Test
+  public void testMergeOnReadUpdateYearsTransform() {
+    initTable("years(ts)");
+    checkUpdate(MERGE_ON_READ, "system.years(ts) <= 30");
+  }
+
+  @Test
+  public void testCopyOnWriteMergeBucketTransform() {
+    initTable("bucket(4, dep)");
+    checkMerge(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+  }
+
+  @Test
+  public void testMergeOnReadMergeBucketTransform() {
+    initTable("bucket(4, dep)");
+    checkMerge(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+  }
+
+  @Test
+  public void testCopyOnWriteMergeYearsTransform() {
+    initTable("years(ts)");
+    checkMerge(COPY_ON_WRITE, "system.years(ts) > 30");
+  }
+
+  @Test
+  public void testMergeOnReadMergeYearsTransform() {
+    initTable("years(ts)");
+    checkMerge(MERGE_ON_READ, "system.years(ts) <= 30");
+  }
+
+  @Test
+  public void testCopyOnWriteMergeTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkMerge(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
+  }
+
+  @Test
+  public void testMergeOnReadMergeTruncateTransform() {
+    initTable("truncate(1, dep)");
+    checkMerge(MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
+  }
+
+  private void checkDelete(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.DELETE_MODE,
+              mode.modeName(),
+              TableProperties.DELETE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME));
+          // CoW planning currently does not optimize post-scan filters in DELETE
+          int expectedCallCount = mode == COPY_ON_WRITE ? 1 : 0;
+          assertThat(calls).hasSize(expectedCallCount);
+
+          assertEquals(
+              "Should have no matching rows",
+              ImmutableList.of(),
+              sql(
+                  "SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME)));
+        });
+  }
+
+  private void checkUpdate(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.UPDATE_MODE,
+              mode.modeName(),
+              TableProperties.UPDATE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)",
+                  tableName, cond, tableName(CHANGES_TABLE_NAME));
+          // CoW planning currently does not optimize post-scan filters in UPDATE
+          int expectedCallCount = mode == COPY_ON_WRITE ? 2 : 0;
+          assertThat(calls).hasSize(expectedCallCount);
+
+          assertEquals(
+              "Should have correct updates",
+              sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+              sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+        });
+  }
+
+  private void checkMerge(RowLevelOperationMode mode, String cond) {
+    withUnavailableLocations(
+        findIrrelevantFileLocations(cond),
+        () -> {
+          sql(
+              "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+              tableName,
+              TableProperties.MERGE_MODE,
+              mode.modeName(),
+              TableProperties.MERGE_DISTRIBUTION_MODE,
+              DistributionMode.NONE.modeName());
+
+          Dataset<Row> changeDF =
+              spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id");
+          changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+          List<Expression> calls =
+              executeAndCollectFunctionCalls(
+                  "MERGE INTO %s t USING %s s "
+                      + "ON t.id == s.id AND %s "
+                      + "WHEN MATCHED THEN "
+                      + "  UPDATE SET salary = -1 "
+                      + "WHEN NOT MATCHED AND s.id = 2 THEN "
+                      + "  INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)",
+                  tableName, tableName(CHANGES_TABLE_NAME), cond);
+          assertThat(calls).isEmpty();
+
+          assertEquals(
+              "Should have correct updates",
+              sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+              sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+        });
+  }
+
+  private List<Expression> executeAndCollectFunctionCalls(String query, Object... args) {
+    CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args);
+    V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan();
+    System.out.println("!!! WRITE PLAN !!!");
+    System.out.println(write.toString());
+    return SparkPlanUtil.collectExprs(
+        write.query(),
+        expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression);
+  }
+
+  private List<String> findIrrelevantFileLocations(String cond) {
+    return spark
+        .table(tableName)
+        .where("NOT " + cond)
+        .select(MetadataColumns.FILE_PATH.name())
+        .distinct()
+        .as(Encoders.STRING())
+        .collectAsList();
+  }
+
+  private void initTable(String transform) {
+    sql(
+        "CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)"
+            + "USING iceberg "
+            + "PARTITIONED BY (%s) "
+            + "TBLPROPERTIES ('%s' 'true')",
+        tableName, transform, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    append(
+        tableName,
+        "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+        "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+        "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+        "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }");
+  }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
new file mode 100644
index 0000000..5ec44f3
--- /dev/null
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+
+abstract class BaseScalarFunction<R> implements ScalarFunction<R> {
+  @Override
+  public int hashCode() {
+    return canonicalName().hashCode();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (this == other) {
+      return true;
+    } else if (!(other instanceof ScalarFunction)) {
+      return false;
+    }
+
+    ScalarFunction<?> that = (ScalarFunction<?>) other;
+    return canonicalName().equals(that.canonicalName());
+  }
+}
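
BaseScalarFunction gives every Iceberg scalar function equality and hashing
based on canonicalName(), so two independently created instances of the same
function (and, in turn, equivalent function-call expressions wrapping them)
compare equal during planning. A minimal sketch using TimestampToHoursFunction,
which this patch makes extend the new base class; the example class name is
hypothetical and the implicit no-arg constructor is assumed:

    import org.apache.iceberg.spark.functions.HoursFunction;

    // Illustration only: equality and hash code now come from canonicalName().
    public class FunctionEqualityExample {
      public static void main(String[] args) {
        HoursFunction.TimestampToHoursFunction a = new HoursFunction.TimestampToHoursFunction();
        HoursFunction.TimestampToHoursFunction b = new HoursFunction.TimestampToHoursFunction();
        System.out.println(a.equals(b));                  // expected: true
        System.out.println(a.hashCode() == b.hashCode()); // expected: true
      }
    }
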
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
index af3c67a..c3de3d4 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
@@ -25,7 +25,6 @@
 import org.apache.iceberg.util.BucketUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.ByteType;
@@ -115,7 +114,7 @@
     return "bucket";
   }
 
-  public abstract static class BucketBase implements ScalarFunction<Integer> {
+  public abstract static class BucketBase extends BaseScalarFunction<Integer> {
     public static int apply(int numBuckets, int hashedValue) {
       return (hashedValue & Integer.MAX_VALUE) % numBuckets;
     }
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
index b8d28b7..f52edd9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
     return "days";
   }
 
-  private abstract static class BaseToDaysFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToDaysFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "days";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
index 18697e1..660a182 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.TimestampNTZType;
@@ -58,7 +57,7 @@
     return "hours";
   }
 
-  public static class TimestampToHoursFunction implements ScalarFunction<Integer> {
+  public static class TimestampToHoursFunction extends BaseScalarFunction<Integer> {
     // magic method used in codegen
     public static int invoke(long micros) {
       return DateTimeUtil.microsToHours(micros);
@@ -91,7 +90,7 @@
     }
   }
 
-  public static class TimestampNtzToHoursFunction implements ScalarFunction<Integer> {
+  public static class TimestampNtzToHoursFunction extends BaseScalarFunction<Integer> {
     // magic method used in codegen
     public static int invoke(long micros) {
       return DateTimeUtil.microsToHours(micros);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
index 9cd0593..689a0f4 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.IcebergBuild;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
@@ -55,7 +54,7 @@
 
   // Implementing class cannot be private, otherwise Spark is unable to access the static invoke
   // function during code-gen and calling the function fails
-  static class IcebergVersionFunctionImpl implements ScalarFunction<UTF8String> {
+  static class IcebergVersionFunctionImpl extends BaseScalarFunction<UTF8String> {
     private static final UTF8String VERSION = UTF8String.fromString(IcebergBuild.version());
 
     // magic function used in code-gen. must be named `invoke`.
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
index 1d38014..353d850 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
     return "months";
   }
 
-  private abstract static class BaseToMonthsFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToMonthsFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "months";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
index 8cfb529..fac90c9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
@@ -27,7 +27,6 @@
 import org.apache.iceberg.util.TruncateUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.ByteType;
@@ -108,7 +107,7 @@
     return "truncate";
   }
 
-  public abstract static class TruncateBase<T> implements ScalarFunction<T> {
+  public abstract static class TruncateBase<T> extends BaseScalarFunction<T> {
     @Override
     public String name() {
       return "truncate";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
index 02642e6..cfd1b0e 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
@@ -21,7 +21,6 @@
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
     return "years";
   }
 
-  private abstract static class BaseToYearsFunction implements ScalarFunction<Integer> {
+  private abstract static class BaseToYearsFunction extends BaseScalarFunction<Integer> {
     @Override
     public String name() {
       return "years";