Spark 3.4: Fix system function pushdown in CoW row-level commands (#10119) (#10171)
Co-authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
index 1f0e164..655a93a 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
@@ -22,12 +22,20 @@
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
import org.apache.spark.sql.catalyst.expressions.BinaryComparison
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.InSet
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
@@ -40,21 +48,39 @@
object ReplaceStaticInvoke extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
- plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
- case filter @ Filter(condition, _) =>
- val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
- case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
- c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+ plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
+ case replace @ ReplaceData(_, cond, _, _, _, _) =>
+ replaceStaticInvoke(replace, cond, newCond => replace.copy(condition = newCond))
+
+ case join @ Join(_, _, _, Some(cond), _) =>
+ replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))
- case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
- c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
- }
+ case filter @ Filter(cond, _) =>
+ replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
+ }
- if (newCondition fastEquals condition) {
- filter
- } else {
- filter.copy(condition = newCondition)
- }
+ private def replaceStaticInvoke[T <: LogicalPlan](
+ node: T,
+ condition: Expression,
+ copy: Expression => T): T = {
+ val newCondition = replaceStaticInvoke(condition)
+ if (newCondition fastEquals condition) node else copy(newCondition)
+ }
+
+ private def replaceStaticInvoke(condition: Expression): Expression = {
+ condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
+ case in @ In(value: StaticInvoke, _) if canReplace(value) =>
+ in.copy(value = replaceStaticInvoke(value))
+
+ case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
+ in.copy(child = replaceStaticInvoke(value))
+
+ case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+ c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+ case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+ c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+ }
}
private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
index 4f7c3eb..830d07d 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkPlanUtil.java
@@ -20,12 +20,17 @@
import static scala.collection.JavaConverters.seqAsJavaListConverter;
+import java.util.Collection;
import java.util.List;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.execution.CommandResultExec;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
+import scala.PartialFunction;
import scala.collection.Seq;
public class SparkPlanUtil {
@@ -53,6 +58,49 @@
}
}
+ public static List<Expression> collectExprs(
+ SparkPlan sparkPlan, Predicate<Expression> predicate) {
+ Seq<List<Expression>> seq =
+ SPARK_HELPER.collect(
+ sparkPlan,
+ new PartialFunction<SparkPlan, List<Expression>>() {
+ @Override
+ public List<Expression> apply(SparkPlan plan) {
+ List<Expression> exprs = Lists.newArrayList();
+
+ for (Expression expr : toJavaList(plan.expressions())) {
+ exprs.addAll(collectExprs(expr, predicate));
+ }
+
+ return exprs;
+ }
+
+ @Override
+ public boolean isDefinedAt(SparkPlan plan) {
+ return true;
+ }
+ });
+ return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
+ }
+
+ private static List<Expression> collectExprs(
+ Expression expression, Predicate<Expression> predicate) {
+ Seq<Expression> seq =
+ expression.collect(
+ new PartialFunction<Expression, Expression>() {
+ @Override
+ public Expression apply(Expression expr) {
+ return expr;
+ }
+
+ @Override
+ public boolean isDefinedAt(Expression expr) {
+ return predicate.test(expr);
+ }
+ });
+ return toJavaList(seq);
+ }
+
private static <T> List<T> toJavaList(Seq<T> seq) {
return seqAsJavaListConverter(seq).asJava();
}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
new file mode 100644
index 0000000..db4d106
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownInRowLevelOperations.java
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
+import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.MetadataColumns;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.execution.CommandResultExec;
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Parameterized.Parameters;
+
+public class TestSystemFunctionPushDownInRowLevelOperations extends SparkExtensionsTestBase {
+
+ private static final String CHANGES_TABLE_NAME = "changes";
+
+ @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+ public static Object[][] parameters() {
+ return new Object[][] {
+ {
+ SparkCatalogConfig.HIVE.catalogName(),
+ SparkCatalogConfig.HIVE.implementation(),
+ SparkCatalogConfig.HIVE.properties()
+ }
+ };
+ }
+
+ public TestSystemFunctionPushDownInRowLevelOperations(
+ String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+ @Before
+ public void beforeEach() {
+ sql("USE %s", catalogName);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s PURGE", tableName);
+ sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME));
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteBucketTransformInPredicate() {
+ initTable("bucket(4, dep)");
+ checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteBucketTransformInPredicate() {
+ initTable("bucket(4, dep)");
+ checkDelete(MERGE_ON_READ, "system.bucket(4, dep) IN (2, 3)");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteBucketTransformEqPredicate() {
+ initTable("bucket(4, dep)");
+ checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) = 2");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteBucketTransformEqPredicate() {
+ initTable("bucket(4, dep)");
+ checkDelete(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteYearsTransform() {
+ initTable("years(ts)");
+ checkDelete(COPY_ON_WRITE, "system.years(ts) > 30");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteYearsTransform() {
+ initTable("years(ts)");
+ checkDelete(MERGE_ON_READ, "system.years(ts) <= 30");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteMonthsTransform() {
+ initTable("months(ts)");
+ checkDelete(COPY_ON_WRITE, "system.months(ts) <= 250");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteMonthsTransform() {
+ initTable("months(ts)");
+ checkDelete(MERGE_ON_READ, "system.months(ts) > 250");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteDaysTransform() {
+ initTable("days(ts)");
+ checkDelete(COPY_ON_WRITE, "system.days(ts) <= date('2000-01-03 00:00:00')");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteDaysTransform() {
+ initTable("days(ts)");
+ checkDelete(MERGE_ON_READ, "system.days(ts) > date('2000-01-03 00:00:00')");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteHoursTransform() {
+ initTable("hours(ts)");
+ checkDelete(COPY_ON_WRITE, "system.hours(ts) <= 100000");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteHoursTransform() {
+ initTable("hours(ts)");
+ checkDelete(MERGE_ON_READ, "system.hours(ts) > 100000");
+ }
+
+ @Test
+ public void testCopyOnWriteDeleteTruncateTransform() {
+ initTable("truncate(1, dep)");
+ checkDelete(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
+ }
+
+ @Test
+ public void testMergeOnReadDeleteTruncateTransform() {
+ initTable("truncate(1, dep)");
+ checkDelete(MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
+ }
+
+ @Test
+ public void testCopyOnWriteUpdateBucketTransform() {
+ initTable("bucket(4, dep)");
+ checkUpdate(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+ }
+
+ @Test
+ public void testMergeOnReadUpdateBucketTransform() {
+ initTable("bucket(4, dep)");
+ checkUpdate(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+ }
+
+ @Test
+ public void testCopyOnWriteUpdateYearsTransform() {
+ initTable("years(ts)");
+ checkUpdate(COPY_ON_WRITE, "system.years(ts) > 30");
+ }
+
+ @Test
+ public void testMergeOnReadUpdateYearsTransform() {
+ initTable("years(ts)");
+ checkUpdate(MERGE_ON_READ, "system.years(ts) <= 30");
+ }
+
+ @Test
+ public void testCopyOnWriteMergeBucketTransform() {
+ initTable("bucket(4, dep)");
+ checkMerge(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
+ }
+
+ @Test
+ public void testMergeOnReadMergeBucketTransform() {
+ initTable("bucket(4, dep)");
+ checkMerge(MERGE_ON_READ, "system.bucket(4, dep) = 2");
+ }
+
+ @Test
+ public void testCopyOnWriteMergeYearsTransform() {
+ initTable("years(ts)");
+ checkMerge(COPY_ON_WRITE, "system.years(ts) > 30");
+ }
+
+ @Test
+ public void testMergeOnReadMergeYearsTransform() {
+ initTable("years(ts)");
+ checkMerge(MERGE_ON_READ, "system.years(ts) <= 30");
+ }
+
+ @Test
+ public void testCopyOnWriteMergeTruncateTransform() {
+ initTable("truncate(1, dep)");
+ checkMerge(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
+ }
+
+ @Test
+ public void testMergeOnReadMergeTruncateTransform() {
+ initTable("truncate(1, dep)");
+ checkMerge(MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
+ }
+
+ private void checkDelete(RowLevelOperationMode mode, String cond) {
+ withUnavailableLocations(
+ findIrrelevantFileLocations(cond),
+ () -> {
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+ tableName,
+ TableProperties.DELETE_MODE,
+ mode.modeName(),
+ TableProperties.DELETE_DISTRIBUTION_MODE,
+ DistributionMode.NONE.modeName());
+
+ Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+ changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+ List<Expression> calls =
+ executeAndCollectFunctionCalls(
+ "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)",
+ tableName, cond, tableName(CHANGES_TABLE_NAME));
+ // CoW planning currently does not optimize post-scan filters in DELETE
+ int expectedCallCount = mode == COPY_ON_WRITE ? 1 : 0;
+ assertThat(calls).hasSize(expectedCallCount);
+
+ assertEquals(
+ "Should have no matching rows",
+ ImmutableList.of(),
+ sql(
+ "SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)",
+ tableName, cond, tableName(CHANGES_TABLE_NAME)));
+ });
+ }
+
+ private void checkUpdate(RowLevelOperationMode mode, String cond) {
+ withUnavailableLocations(
+ findIrrelevantFileLocations(cond),
+ () -> {
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+ tableName,
+ TableProperties.UPDATE_MODE,
+ mode.modeName(),
+ TableProperties.UPDATE_DISTRIBUTION_MODE,
+ DistributionMode.NONE.modeName());
+
+ Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
+ changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+ List<Expression> calls =
+ executeAndCollectFunctionCalls(
+ "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)",
+ tableName, cond, tableName(CHANGES_TABLE_NAME));
+ // CoW planning currently does not optimize post-scan filters in UPDATE
+ int expectedCallCount = mode == COPY_ON_WRITE ? 2 : 0;
+ assertThat(calls).hasSize(expectedCallCount);
+
+ assertEquals(
+ "Should have correct updates",
+ sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+ sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+ });
+ }
+
+ private void checkMerge(RowLevelOperationMode mode, String cond) {
+ withUnavailableLocations(
+ findIrrelevantFileLocations(cond),
+ () -> {
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
+ tableName,
+ TableProperties.MERGE_MODE,
+ mode.modeName(),
+ TableProperties.MERGE_DISTRIBUTION_MODE,
+ DistributionMode.NONE.modeName());
+
+ Dataset<Row> changeDF =
+ spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id");
+ changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+
+ List<Expression> calls =
+ executeAndCollectFunctionCalls(
+ "MERGE INTO %s t USING %s s "
+ + "ON t.id == s.id AND %s "
+ + "WHEN MATCHED THEN "
+ + " UPDATE SET salary = -1 "
+ + "WHEN NOT MATCHED AND s.id = 2 THEN "
+ + " INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)",
+ tableName, tableName(CHANGES_TABLE_NAME), cond);
+ assertThat(calls).isEmpty();
+
+ assertEquals(
+ "Should have correct updates",
+ sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
+ sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
+ });
+ }
+
+ private List<Expression> executeAndCollectFunctionCalls(String query, Object... args) {
+ CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args);
+ V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan();
+ // Collect function invocation expressions (StaticInvoke before pushdown
+ // replacement, ApplyFunctionExpression after) from the write's query plan.
+ return SparkPlanUtil.collectExprs(
+ write.query(),
+ expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression);
+ }
+
+ private List<String> findIrrelevantFileLocations(String cond) {
+ return spark
+ .table(tableName)
+ .where("NOT " + cond)
+ .select(MetadataColumns.FILE_PATH.name())
+ .distinct()
+ .as(Encoders.STRING())
+ .collectAsList();
+ }
+
+ private void initTable(String transform) {
+ sql(
+ "CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)"
+ + "USING iceberg "
+ + "PARTITIONED BY (%s) "
+ + "TBLPROPERTIES ('%s' 'true')",
+ tableName, transform, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+ append(
+ tableName,
+ "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+ "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+ "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
+ "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+ "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
+ "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }");
+ }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
new file mode 100644
index 0000000..5ec44f3
--- /dev/null
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BaseScalarFunction.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+
+abstract class BaseScalarFunction<R> implements ScalarFunction<R> {
+ @Override
+ public int hashCode() {
+ return canonicalName().hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ } else if (!(other instanceof ScalarFunction)) {
+ return false;
+ }
+
+ ScalarFunction<?> that = (ScalarFunction<?>) other;
+ return canonicalName().equals(that.canonicalName());
+ }
+}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
index af3c67a..c3de3d4 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
@@ -25,7 +25,6 @@
import org.apache.iceberg.util.BucketUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.ByteType;
@@ -115,7 +114,7 @@
return "bucket";
}
- public abstract static class BucketBase implements ScalarFunction<Integer> {
+ public abstract static class BucketBase extends BaseScalarFunction<Integer> {
public static int apply(int numBuckets, int hashedValue) {
return (hashedValue & Integer.MAX_VALUE) % numBuckets;
}
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
index b8d28b7..f52edd9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/DaysFunction.java
@@ -21,7 +21,6 @@
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
return "days";
}
- private abstract static class BaseToDaysFunction implements ScalarFunction<Integer> {
+ private abstract static class BaseToDaysFunction extends BaseScalarFunction<Integer> {
@Override
public String name() {
return "days";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
index 18697e1..660a182 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/HoursFunction.java
@@ -21,7 +21,6 @@
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.TimestampNTZType;
@@ -58,7 +57,7 @@
return "hours";
}
- public static class TimestampToHoursFunction implements ScalarFunction<Integer> {
+ public static class TimestampToHoursFunction extends BaseScalarFunction<Integer> {
// magic method used in codegen
public static int invoke(long micros) {
return DateTimeUtil.microsToHours(micros);
@@ -91,7 +90,7 @@
}
}
- public static class TimestampNtzToHoursFunction implements ScalarFunction<Integer> {
+ public static class TimestampNtzToHoursFunction extends BaseScalarFunction<Integer> {
// magic method used in codegen
public static int invoke(long micros) {
return DateTimeUtil.microsToHours(micros);
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
index 9cd0593..689a0f4 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/IcebergVersionFunction.java
@@ -21,7 +21,6 @@
import org.apache.iceberg.IcebergBuild;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
@@ -55,7 +54,7 @@
// Implementing class cannot be private, otherwise Spark is unable to access the static invoke
// function during code-gen and calling the function fails
- static class IcebergVersionFunctionImpl implements ScalarFunction<UTF8String> {
+ static class IcebergVersionFunctionImpl extends BaseScalarFunction<UTF8String> {
private static final UTF8String VERSION = UTF8String.fromString(IcebergBuild.version());
// magic function used in code-gen. must be named `invoke`.
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
index 1d38014..353d850 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/MonthsFunction.java
@@ -21,7 +21,6 @@
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
return "months";
}
- private abstract static class BaseToMonthsFunction implements ScalarFunction<Integer> {
+ private abstract static class BaseToMonthsFunction extends BaseScalarFunction<Integer> {
@Override
public String name() {
return "months";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
index 8cfb529..fac90c9 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/TruncateFunction.java
@@ -27,7 +27,6 @@
import org.apache.iceberg.util.TruncateUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.ByteType;
@@ -108,7 +107,7 @@
return "truncate";
}
- public abstract static class TruncateBase<T> implements ScalarFunction<T> {
+ public abstract static class TruncateBase<T> extends BaseScalarFunction<T> {
@Override
public String name() {
return "truncate";
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
index 02642e6..cfd1b0e 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/YearsFunction.java
@@ -21,7 +21,6 @@
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
@@ -61,7 +60,7 @@
return "years";
}
- private abstract static class BaseToYearsFunction implements ScalarFunction<Integer> {
+ private abstract static class BaseToYearsFunction extends BaseScalarFunction<Integer> {
@Override
public String name() {
return "years";