| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.iceberg.spark.extensions; |
| |
| import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE; |
| import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; |
| import static org.assertj.core.api.Assertions.assertThat; |
| |
| import java.util.List; |
| import java.util.Map; |
| import org.apache.iceberg.DistributionMode; |
| import org.apache.iceberg.MetadataColumns; |
| import org.apache.iceberg.RowLevelOperationMode; |
| import org.apache.iceberg.TableProperties; |
| import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; |
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.SparkPlanUtil;
| import org.apache.spark.sql.Dataset; |
| import org.apache.spark.sql.Encoders; |
| import org.apache.spark.sql.Row; |
| import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression; |
| import org.apache.spark.sql.catalyst.expressions.Expression; |
| import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; |
| import org.apache.spark.sql.execution.CommandResultExec; |
| import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec; |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Test; |
| import org.junit.runners.Parameterized.Parameters; |
| |
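/**
 * Verifies that Iceberg system function calls (bucket, truncate, years, months, days, hours) used
 * in DELETE, UPDATE, and MERGE conditions are replaced during planning instead of being evaluated
 * row by row. Each test makes the files that cannot match the condition unavailable, so the
 * operation succeeds only if those files were pruned at planning time.
 */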
| public class TestSystemFunctionPushDownInRowLevelOperations extends SparkExtensionsTestBase { |
| |
| private static final String CHANGES_TABLE_NAME = "changes"; |
| |
| @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") |
| public static Object[][] parameters() { |
| return new Object[][] { |
| { |
| SparkCatalogConfig.HIVE.catalogName(), |
| SparkCatalogConfig.HIVE.implementation(), |
| SparkCatalogConfig.HIVE.properties() |
| } |
| }; |
| } |
| |
| public TestSystemFunctionPushDownInRowLevelOperations( |
| String catalogName, String implementation, Map<String, String> config) { |
| super(catalogName, implementation, config); |
| } |
| |
| @Before |
| public void beforeEach() { |
| sql("USE %s", catalogName); |
| } |
| |
| @After |
| public void removeTables() { |
| sql("DROP TABLE IF EXISTS %s PURGE", tableName); |
| sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME)); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteBucketTransformInPredicate() { |
| initTable("bucket(4, dep)"); |
| checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteBucketTransformInPredicate() { |
| initTable("bucket(4, dep)"); |
| checkDelete(MERGE_ON_READ, "system.bucket(4, dep) IN (2, 3)"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteBucketTransformEqPredicate() { |
| initTable("bucket(4, dep)"); |
| checkDelete(COPY_ON_WRITE, "system.bucket(4, dep) = 2"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteBucketTransformEqPredicate() { |
| initTable("bucket(4, dep)"); |
| checkDelete(MERGE_ON_READ, "system.bucket(4, dep) = 2"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteYearsTransform() { |
| initTable("years(ts)"); |
| checkDelete(COPY_ON_WRITE, "system.years(ts) > 30"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteYearsTransform() { |
| initTable("years(ts)"); |
| checkDelete(MERGE_ON_READ, "system.years(ts) <= 30"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteMonthsTransform() { |
| initTable("months(ts)"); |
| checkDelete(COPY_ON_WRITE, "system.months(ts) <= 250"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteMonthsTransform() { |
| initTable("months(ts)"); |
| checkDelete(MERGE_ON_READ, "system.months(ts) > 250"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteDaysTransform() { |
| initTable("days(ts)"); |
| checkDelete(COPY_ON_WRITE, "system.days(ts) <= date('2000-01-03 00:00:00')"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteDaysTransform() { |
| initTable("days(ts)"); |
| checkDelete(MERGE_ON_READ, "system.days(ts) > date('2000-01-03 00:00:00')"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteHoursTransform() { |
| initTable("hours(ts)"); |
| checkDelete(COPY_ON_WRITE, "system.hours(ts) <= 100000"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteHoursTransform() { |
| initTable("hours(ts)"); |
| checkDelete(MERGE_ON_READ, "system.hours(ts) > 100000"); |
| } |
| |
| @Test |
| public void testCopyOnWriteDeleteTruncateTransform() { |
| initTable("truncate(1, dep)"); |
| checkDelete(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'"); |
| } |
| |
| @Test |
| public void testMergeOnReadDeleteTruncateTransform() { |
| initTable("truncate(1, dep)"); |
| checkDelete(MERGE_ON_READ, "system.truncate(1, dep) = 'i'"); |
| } |
| |
| @Test |
| public void testCopyOnWriteUpdateBucketTransform() { |
| initTable("bucket(4, dep)"); |
| checkUpdate(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); |
| } |
| |
| @Test |
| public void testMergeOnReadUpdateBucketTransform() { |
| initTable("bucket(4, dep)"); |
| checkUpdate(MERGE_ON_READ, "system.bucket(4, dep) = 2"); |
| } |
| |
| @Test |
| public void testCopyOnWriteUpdateYearsTransform() { |
| initTable("years(ts)"); |
| checkUpdate(COPY_ON_WRITE, "system.years(ts) > 30"); |
| } |
| |
| @Test |
| public void testMergeOnReadUpdateYearsTransform() { |
| initTable("years(ts)"); |
| checkUpdate(MERGE_ON_READ, "system.years(ts) <= 30"); |
| } |
| |
| @Test |
| public void testCopyOnWriteMergeBucketTransform() { |
| initTable("bucket(4, dep)"); |
| checkMerge(COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)"); |
| } |
| |
| @Test |
| public void testMergeOnReadMergeBucketTransform() { |
| initTable("bucket(4, dep)"); |
| checkMerge(MERGE_ON_READ, "system.bucket(4, dep) = 2"); |
| } |
| |
| @Test |
| public void testCopyOnWriteMergeYearsTransform() { |
| initTable("years(ts)"); |
| checkMerge(COPY_ON_WRITE, "system.years(ts) > 30"); |
| } |
| |
| @Test |
| public void testMergeOnReadMergeYearsTransform() { |
| initTable("years(ts)"); |
| checkMerge(MERGE_ON_READ, "system.years(ts) <= 30"); |
| } |
| |
| @Test |
| public void testCopyOnWriteMergeTruncateTransform() { |
| initTable("truncate(1, dep)"); |
| checkMerge(COPY_ON_WRITE, "system.truncate(1, dep) = 'i'"); |
| } |
| |
| @Test |
| public void testMergeOnReadMergeTruncateTransform() { |
| initTable("truncate(1, dep)"); |
| checkMerge(MERGE_ON_READ, "system.truncate(1, dep) = 'i'"); |
| } |
| |
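  // Runs a DELETE with the given mode, asserts how many function calls survive in the write plan,
  // and verifies that no matching rows remain.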
| private void checkDelete(RowLevelOperationMode mode, String cond) { |
| withUnavailableLocations( |
| findIrrelevantFileLocations(cond), |
| () -> { |
| sql( |
| "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", |
| tableName, |
| TableProperties.DELETE_MODE, |
| mode.modeName(), |
| TableProperties.DELETE_DISTRIBUTION_MODE, |
| DistributionMode.NONE.modeName()); |
| |
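          // stage a small changes table with ids from matching rows for the IN subquery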
| Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id"); |
| changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); |
| |
| List<Expression> calls = |
| executeAndCollectFunctionCalls( |
| "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)", |
| tableName, cond, tableName(CHANGES_TABLE_NAME)); |
| // CoW planning currently does not optimize post-scan filters in DELETE |
| int expectedCallCount = mode == COPY_ON_WRITE ? 1 : 0; |
| assertThat(calls).hasSize(expectedCallCount); |
| |
| assertEquals( |
| "Should have no matching rows", |
| ImmutableList.of(), |
| sql( |
| "SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)", |
| tableName, cond, tableName(CHANGES_TABLE_NAME))); |
| }); |
| } |
| |
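  // Runs an UPDATE with the given mode, asserts how many function calls survive in the write plan,
  // and verifies that every matching row was updated.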
| private void checkUpdate(RowLevelOperationMode mode, String cond) { |
| withUnavailableLocations( |
| findIrrelevantFileLocations(cond), |
| () -> { |
| sql( |
| "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", |
| tableName, |
| TableProperties.UPDATE_MODE, |
| mode.modeName(), |
| TableProperties.UPDATE_DISTRIBUTION_MODE, |
| DistributionMode.NONE.modeName()); |
| |
| Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id"); |
| changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); |
| |
| List<Expression> calls = |
| executeAndCollectFunctionCalls( |
| "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)", |
| tableName, cond, tableName(CHANGES_TABLE_NAME)); |
| // CoW planning currently does not optimize post-scan filters in UPDATE |
| int expectedCallCount = mode == COPY_ON_WRITE ? 2 : 0; |
| assertThat(calls).hasSize(expectedCallCount); |
| |
| assertEquals( |
| "Should have correct updates", |
| sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)), |
| sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond)); |
| }); |
| } |
| |
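  // Runs a MERGE with the given mode, asserts that no function calls survive in the write plan,
  // and verifies that every matching row was updated.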
| private void checkMerge(RowLevelOperationMode mode, String cond) { |
| withUnavailableLocations( |
| findIrrelevantFileLocations(cond), |
| () -> { |
| sql( |
| "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')", |
| tableName, |
| TableProperties.MERGE_MODE, |
| mode.modeName(), |
| TableProperties.MERGE_DISTRIBUTION_MODE, |
| DistributionMode.NONE.modeName()); |
| |
| Dataset<Row> changeDF = |
| spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id"); |
| changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create(); |
| |
| List<Expression> calls = |
| executeAndCollectFunctionCalls( |
| "MERGE INTO %s t USING %s s " |
| + "ON t.id == s.id AND %s " |
| + "WHEN MATCHED THEN " |
| + " UPDATE SET salary = -1 " |
| + "WHEN NOT MATCHED AND s.id = 2 THEN " |
| + " INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)", |
| tableName, tableName(CHANGES_TABLE_NAME), cond); |
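          // both CoW and MoR MERGE planning are expected to replace all function calls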
| assertThat(calls).isEmpty(); |
| |
| assertEquals( |
| "Should have correct updates", |
| sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)), |
| sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond)); |
| }); |
| } |
| |
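  // Executes the query, keeps the physical plan, and collects the function call expressions
  // (StaticInvoke or ApplyFunctionExpression) that survived optimization in the write query.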
| private List<Expression> executeAndCollectFunctionCalls(String query, Object... args) { |
| CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args); |
| V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan(); |
| return SparkPlanUtil.collectExprs( |
| write.query(), |
| expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression); |
| } |
| |
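  // Finds the locations of data files whose rows cannot match the condition; tests make these
  // files unavailable to prove that planning prunes them instead of scanning them.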
| private List<String> findIrrelevantFileLocations(String cond) { |
| return spark |
| .table(tableName) |
| .where("NOT " + cond) |
| .select(MetadataColumns.FILE_PATH.name()) |
| .distinct() |
| .as(Encoders.STRING()) |
| .collectAsList(); |
| } |
| |
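  // Creates a table partitioned by the given transform and seeds it with two groups of rows
  // ('hr' in 1975 and 'it' in 2020) so each transform yields multiple partitions.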
| private void initTable(String transform) { |
| sql( |
| "CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)" |
| + "USING iceberg " |
| + "PARTITIONED BY (%s) " |
| + "TBLPROPERTIES ('%s' 'true')", |
| tableName, transform, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED); |
| |
| append( |
| tableName, |
| "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", |
| "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", |
| "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }", |
| "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }", |
| "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }", |
| "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }"); |
| } |
| } |