Merge pull request #26794: #26789 Fix auto schema update when schema order has changed. (#26810)

Co-authored-by: reuvenlax <relax@google.com>
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
index f9f9b1b..46b542a 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java
@@ -38,6 +38,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
@@ -685,16 +686,22 @@
 
       void postFlush() {
         // If we got a response indicating an updated schema, recreate the client.
-        if (this.appendClientInfo != null) {
+        if (this.appendClientInfo != null && autoUpdateSchema) {
           @Nullable
           StreamAppendClient streamAppendClient = appendClientInfo.getStreamAppendClient();
           @Nullable
-          TableSchema updatedTableSchema =
+          TableSchema updatedTableSchemaReturned =
               (streamAppendClient != null) ? streamAppendClient.getUpdatedSchema() : null;
-          if (updatedTableSchema != null && autoUpdateSchema) {
-            invalidateWriteStream();
-            appendClientInfo =
-                Preconditions.checkStateNotNull(getAppendClientInfo(false, updatedTableSchema));
+          if (updatedTableSchemaReturned != null) {
+            Optional<TableSchema> updatedTableSchema =
+                TableSchemaUpdateUtils.getUpdatedSchema(
+                    this.initialTableSchema, updatedTableSchemaReturned);
+            if (updatedTableSchema.isPresent()) {
+              invalidateWriteStream();
+              appendClientInfo =
+                  Preconditions.checkStateNotNull(
+                      getAppendClientInfo(false, updatedTableSchema.get()));
+            }
           }
         }
       }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
index cd23b7b..e0353bf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java
@@ -36,6 +36,7 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -751,16 +752,23 @@
         if (autoUpdateSchema) {
           @Nullable
           StreamAppendClient streamAppendClient = appendClientInfo.get().getStreamAppendClient();
+          TableSchema originalSchema = appendClientInfo.get().getTableSchema();
+          ;
           @Nullable
-          TableSchema newSchema =
+          TableSchema updatedSchemaReturned =
               (streamAppendClient != null) ? streamAppendClient.getUpdatedSchema() : null;
           // Update the table schema and clear the append client.
-          if (newSchema != null) {
-            appendClientInfo.set(
-                AppendClientInfo.of(newSchema, appendClientInfo.get().getCloseAppendClient()));
-            APPEND_CLIENTS.invalidate(element.getKey());
-            APPEND_CLIENTS.put(element.getKey(), appendClientInfo.get());
-            updatedSchema.write(newSchema);
+          if (updatedSchemaReturned != null) {
+            Optional<TableSchema> newSchema =
+                TableSchemaUpdateUtils.getUpdatedSchema(originalSchema, updatedSchemaReturned);
+            if (newSchema.isPresent()) {
+              appendClientInfo.set(
+                  AppendClientInfo.of(
+                      newSchema.get(), appendClientInfo.get().getCloseAppendClient()));
+              APPEND_CLIENTS.invalidate(element.getKey());
+              APPEND_CLIENTS.put(element.getKey(), appendClientInfo.get());
+              updatedSchema.write(newSchema.get());
+            }
           }
         }
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtils.java
new file mode 100644
index 0000000..cba394a
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtils.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import com.google.auto.value.AutoValue;
+import com.google.cloud.bigquery.storage.v1.TableFieldSchema;
+import com.google.cloud.bigquery.storage.v1.TableSchema;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+
+/** Helper utilities for reconciling an original table schema with an updated one. */
+public class TableSchemaUpdateUtils {
+  /**
+   * Given an original and an updated schema, returns the schema to use for future records.
+   *
+   * <p>Returns {@code Optional.empty()} if a field of the old schema is missing from the new one
+   * (field deletion is not supported), or if the two schemas are equivalent, i.e. equal modulo
+   * field ordering. Otherwise the returned schema lists the old schema's fields first, in their
+   * original order and with the new schema's definitions (STRUCT subfields are merged the same
+   * way), followed by the fields that exist only in the new schema.
+   */
+  public static Optional<TableSchema> getUpdatedSchema(
+      TableSchema oldSchema, TableSchema newSchema) {
+    Result updatedFields = getUpdatedSchema(oldSchema.getFieldsList(), newSchema.getFieldsList());
+    if (updatedFields.isEquivalent()) {
+      return Optional.empty();
+    } else {
+      return updatedFields
+          .getFields()
+          .map(
+              tableFieldSchemas ->
+                  TableSchema.newBuilder().addAllFields(tableFieldSchemas).build());
+    }
+  }
+
+  @AutoValue
+  abstract static class Result {
+    abstract Optional<List<TableFieldSchema>> getFields();
+
+    abstract boolean isEquivalent();
+
+    static Result of(List<TableFieldSchema> fields, boolean isEquivalent) {
+      return new AutoValue_TableSchemaUpdateUtils_Result(Optional.of(fields), isEquivalent);
+    }
+
+    static Result empty() {
+      return new AutoValue_TableSchemaUpdateUtils_Result(Optional.empty(), false);
+    }
+  }
+
+  private static Result getUpdatedSchema(
+      @Nullable List<TableFieldSchema> oldSchema, @Nullable List<TableFieldSchema> newSchema) {
+    if (newSchema == null) {
+      return Result.empty();
+    }
+    if (oldSchema == null) {
+      return Result.of(newSchema, false);
+    }
+
+    Map<String, TableFieldSchema> newSchemaMap =
+        newSchema.stream().collect(Collectors.toMap(TableFieldSchema::getName, x -> x));
+    Set<String> fieldNamesPopulated = Sets.newHashSet();
+    List<TableFieldSchema> updatedSchema = Lists.newArrayList();
+    boolean isEquivalent = oldSchema.size() == newSchema.size();
+    for (TableFieldSchema tableFieldSchema : oldSchema) {
+      @Nullable TableFieldSchema newTableFieldSchema = newSchemaMap.get(tableFieldSchema.getName());
+      if (newTableFieldSchema == null) {
+        // Field deletion is unsupported: an old field absent from the new schema is incompatible.
+        return Result.empty();
+      }
+      TableFieldSchema.Builder updatedTableFieldSchema = newTableFieldSchema.toBuilder();
+      updatedTableFieldSchema.clearFields();
+      if (tableFieldSchema.getType().equals(TableFieldSchema.Type.STRUCT)) {
+        Result updatedTableFields =
+            getUpdatedSchema(tableFieldSchema.getFieldsList(), newTableFieldSchema.getFieldsList());
+        if (!updatedTableFields.getFields().isPresent()) {
+          return updatedTableFields;
+        }
+        updatedTableFieldSchema.addAllFields(updatedTableFields.getFields().get());
+        isEquivalent = isEquivalent && updatedTableFields.isEquivalent();
+        isEquivalent =
+            isEquivalent
+                && tableFieldSchema
+                    .toBuilder()
+                    .clearFields()
+                    .build()
+                    .equals(newTableFieldSchema.toBuilder().clearFields().build());
+      } else {
+        isEquivalent = isEquivalent && tableFieldSchema.equals(newTableFieldSchema);
+      }
+      updatedSchema.add(updatedTableFieldSchema.build());
+      fieldNamesPopulated.add(updatedTableFieldSchema.getName());
+    }
+
+    // Append the fields that exist only in the new schema, in the order they appear there.
+    newSchema.stream()
+        .filter(f -> !fieldNamesPopulated.contains(f.getName()))
+        .forEach(updatedSchema::add);
+    return Result.of(updatedSchema, isEquivalent);
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 13a06b9..d0fddaf 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -1985,9 +1985,13 @@
         new TableSchema()
             .setFields(
                 ImmutableList.of(
-                    new TableFieldSchema().setName("name").setType("STRING"),
                     new TableFieldSchema().setName("number").setType("INTEGER"),
+                    new TableFieldSchema().setName("name").setType("STRING"),
                     new TableFieldSchema().setName("req").setType("STRING").setMode("REQUIRED")));
+
+    // Add new fields to the updated schema. Also reorder some of the existing fields relative to
+    // the original schema, to validate that field reordering in a schema update is handled
+    // correctly.
     TableSchema tableSchemaUpdated =
         new TableSchema()
             .setFields(
@@ -2018,8 +2022,8 @@
                 new TableRow()
                     .setF(
                         ImmutableList.of(
-                            new TableCell().setV("name" + i),
                             new TableCell().setV(Long.toString(i)),
+                            new TableCell().setV("name" + i),
                             new TableCell().setV(i > 5 ? null : "foo"),
                             new TableCell().setV(Long.toString(i * 2))));
 
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtilsTest.java
new file mode 100644
index 0000000..bdbd207
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableSchemaUpdateUtilsTest.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigquery;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.bigquery.storage.v1.TableFieldSchema;
+import com.google.cloud.bigquery.storage.v1.TableSchema;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for the {@link TableSchemaUpdateUtils} class. */
+@RunWith(JUnit4.class)
+public class TableSchemaUpdateUtilsTest {
+  @Test
+  public void testSchemaUpdate() {
+    TableSchema baseSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema schema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(baseSchema.getFieldsList()))
+            .build();
+    TableSchema topSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(schema.getFieldsList()))
+            .build();
+
+    TableSchema newBaseSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("d").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema newSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(newBaseSchema.getFieldsList()))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("d").setType(TableFieldSchema.Type.STRING))
+            .build();
+
+    TableSchema newTopSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(newSchema.getFieldsList()))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .build();
+
+    TableSchema expectedSchemaBaseSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("d").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema expectedSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(expectedSchemaBaseSchema.getFieldsList()))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("d").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema expectedTopSchema =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(expectedSchema.getFieldsList()))
+            .build();
+
+    TableSchema updatedTopSchema =
+        TableSchemaUpdateUtils.getUpdatedSchema(topSchema, newTopSchema).get();
+    assertEquals(expectedTopSchema, updatedTopSchema);
+  }
+
+  @Test
+  public void testEquivalentSchema() {
+    TableSchema baseSchema1 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema schema1 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(baseSchema1.getFieldsList()))
+            .build();
+
+    TableSchema baseSchema2 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema schema2 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(baseSchema2.getFieldsList()))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .build();
+    assertFalse(TableSchemaUpdateUtils.getUpdatedSchema(schema1, schema2).isPresent());
+  }
+
+  @Test
+  public void testNonEquivalentSchema() {
+    TableSchema baseSchema1 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .build();
+    TableSchema schema1 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(baseSchema1.getFieldsList()))
+            .build();
+    TableSchema baseSchema2 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.INT64))
+            .build();
+    TableSchema schema2 =
+        TableSchema.newBuilder()
+            .addFields(
+                TableFieldSchema.newBuilder().setName("a").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("b").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder().setName("c").setType(TableFieldSchema.Type.STRING))
+            .addFields(
+                TableFieldSchema.newBuilder()
+                    .setName("nested")
+                    .setType(TableFieldSchema.Type.STRUCT)
+                    .addAllFields(baseSchema2.getFieldsList()))
+            .build();
+    assertTrue(TableSchemaUpdateUtils.getUpdatedSchema(schema1, schema2).isPresent());
+  }
+}