Merge pull request #15480: [BEAM-12356] Make sure DatasetService is always closed

diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
index 484fe3b..dfe797e 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchedStreamingWrite.java
@@ -86,7 +86,6 @@
   private final SerializableFunction<ElementT, TableRow> toTableRow;
   private final SerializableFunction<ElementT, TableRow> toFailsafeTableRow;
   private final Set<String> allowedMetricUrns;
-  private @Nullable DatasetService datasetService;
 
   /** Tracks bytes written, exposed as "ByteCount" Counter. */
   private Counter byteCounter = SinkMetrics.bytesWritten();
@@ -222,6 +221,15 @@
     /** The list of unique ids for each BigQuery table row. */
     private transient Map<String, List<String>> uniqueIdsForTableRows;
 
+    private transient @Nullable DatasetService datasetService;
+
+    private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+      if (datasetService == null) {
+        datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+      }
+      return datasetService;
+    }
+
     /** Prepares a target BigQuery table. */
     @StartBundle
     public void startBundle() {
@@ -257,10 +265,10 @@
           tableRows.entrySet()) {
         TableReference tableReference = BigQueryHelpers.parseTableSpec(entry.getKey());
         flushRows(
+            getDatasetService(options),
             tableReference,
             entry.getValue(),
             uniqueIdsForTableRows.get(entry.getKey()),
-            options,
             failedInserts,
             successfulInserts);
       }
@@ -272,6 +280,18 @@
       }
       reportStreamingApiLogging(options);
     }
+
+    @Teardown
+    public void onTeardown() {
+      try {
+        if (datasetService != null) {
+          datasetService.close();
+          datasetService = null;
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
   }
 
   // The max duration input records are allowed to be buffered in the state, if using ViaStateful.
@@ -325,13 +345,22 @@
   // shuffling.
   private class InsertBatchedElements
       extends DoFn<KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>>, Void> {
+    private transient @Nullable DatasetService datasetService;
+
+    private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
+      if (datasetService == null) {
+        datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
+      }
+      return datasetService;
+    }
+
     @ProcessElement
     public void processElement(
         @Element KV<ShardedKey<String>, Iterable<TableRowInfo<ElementT>>> input,
         BoundedWindow window,
         ProcessContext context,
         MultiOutputReceiver out)
-        throws InterruptedException {
+        throws InterruptedException, IOException {
       List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows = new ArrayList<>();
       List<String> uniqueIds = new ArrayList<>();
       for (TableRowInfo<ElementT> row : input.getValue()) {
@@ -347,7 +376,13 @@
       TableReference tableReference = BigQueryHelpers.parseTableSpec(input.getKey().getKey());
       List<ValueInSingleWindow<ErrorT>> failedInserts = Lists.newArrayList();
       List<ValueInSingleWindow<TableRow>> successfulInserts = Lists.newArrayList();
-      flushRows(tableReference, tableRows, uniqueIds, options, failedInserts, successfulInserts);
+      flushRows(
+          getDatasetService(options),
+          tableReference,
+          tableRows,
+          uniqueIds,
+          failedInserts,
+          successfulInserts);
 
       for (ValueInSingleWindow<ErrorT> row : failedInserts) {
         out.get(failedOutputTag).output(row.getValue());
@@ -357,44 +392,43 @@
       }
       reportStreamingApiLogging(options);
     }
-  }
 
-  @Teardown
-  public void onTeardown() {
-    try {
-      if (datasetService != null) {
-        datasetService.close();
-        datasetService = null;
+    @Teardown
+    public void onTeardown() {
+      try {
+        if (datasetService != null) {
+          datasetService.close();
+          datasetService = null;
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(e);
       }
-    } catch (Exception e) {
-      throw new RuntimeException(e);
     }
   }
 
   /** Writes the accumulated rows into BigQuery with streaming API. */
   private void flushRows(
+      DatasetService datasetService,
       TableReference tableReference,
       List<FailsafeValueInSingleWindow<TableRow, TableRow>> tableRows,
       List<String> uniqueIds,
-      BigQueryOptions options,
       List<ValueInSingleWindow<ErrorT>> failedInserts,
       List<ValueInSingleWindow<TableRow>> successfulInserts)
       throws InterruptedException {
     if (!tableRows.isEmpty()) {
       try {
         long totalBytes =
-            getDatasetService(options)
-                .insertAll(
-                    tableReference,
-                    tableRows,
-                    uniqueIds,
-                    retryPolicy,
-                    failedInserts,
-                    errorContainer,
-                    skipInvalidRows,
-                    ignoreUnknownValues,
-                    ignoreInsertIds,
-                    successfulInserts);
+            datasetService.insertAll(
+                tableReference,
+                tableRows,
+                uniqueIds,
+                retryPolicy,
+                failedInserts,
+                errorContainer,
+                skipInvalidRows,
+                ignoreUnknownValues,
+                ignoreInsertIds,
+                successfulInserts);
         byteCounter.inc(totalBytes);
       } catch (IOException e) {
         throw new RuntimeException(e);
@@ -402,13 +436,6 @@
     }
   }
 
-  private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException {
-    if (datasetService == null) {
-      datasetService = bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class));
-    }
-    return datasetService;
-  }
-
   private void reportStreamingApiLogging(BigQueryOptions options) {
     MetricsContainer processWideContainer = MetricsEnvironment.getProcessWideContainer();
     if (processWideContainer instanceof MetricsLogger) {
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 1d3d894..8b9b705 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -965,49 +965,53 @@
       // earlier stages of the pipeline or if a query depends on earlier stages of a pipeline.
       // For these cases the withoutValidation method can be used to disable the check.
       if (getValidate()) {
-        if (table != null) {
-          checkArgument(table.isAccessible(), "Cannot call validate if table is dynamically set.");
-        }
-        if (table != null && table.get().getProjectId() != null) {
-          // Check for source table presence for early failure notification.
-          DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
-          BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
-          BigQueryHelpers.verifyTablePresence(datasetService, table.get());
-        } else if (getQuery() != null) {
-          checkArgument(
-              getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
-          JobService jobService = getBigQueryServices().getJobService(bqOptions);
-          try {
-            jobService.dryRunQuery(
-                bqOptions.getBigQueryProject() == null
-                    ? bqOptions.getProject()
-                    : bqOptions.getBigQueryProject(),
-                new JobConfigurationQuery()
-                    .setQuery(getQuery().get())
-                    .setFlattenResults(getFlattenResults())
-                    .setUseLegacySql(getUseLegacySql()),
-                getQueryLocation());
-          } catch (Exception e) {
-            throw new IllegalArgumentException(
-                String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+        try (DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions)) {
+          if (table != null) {
+            checkArgument(
+                table.isAccessible(), "Cannot call validate if table is dynamically set.");
           }
+          if (table != null && table.get().getProjectId() != null) {
+            // Check for source table presence for early failure notification.
+            BigQueryHelpers.verifyDatasetPresence(datasetService, table.get());
+            BigQueryHelpers.verifyTablePresence(datasetService, table.get());
+          } else if (getQuery() != null) {
+            checkArgument(
+                getQuery().isAccessible(), "Cannot call validate if query is dynamically set.");
+            JobService jobService = getBigQueryServices().getJobService(bqOptions);
+            try {
+              jobService.dryRunQuery(
+                  bqOptions.getBigQueryProject() == null
+                      ? bqOptions.getProject()
+                      : bqOptions.getBigQueryProject(),
+                  new JobConfigurationQuery()
+                      .setQuery(getQuery().get())
+                      .setFlattenResults(getFlattenResults())
+                      .setUseLegacySql(getUseLegacySql()),
+                  getQueryLocation());
+            } catch (Exception e) {
+              throw new IllegalArgumentException(
+                  String.format(QUERY_VALIDATION_FAILURE_ERROR, getQuery().get()), e);
+            }
 
-          DatasetService datasetService = getBigQueryServices().getDatasetService(bqOptions);
-          // If the user provided a temp dataset, check if the dataset exists before launching the
-          // query
-          if (getQueryTempDataset() != null) {
-            // The temp table is only used for dataset and project id validation, not for table name
-            // validation
-            TableReference tempTable =
-                new TableReference()
-                    .setProjectId(
-                        bqOptions.getBigQueryProject() == null
-                            ? bqOptions.getProject()
-                            : bqOptions.getBigQueryProject())
-                    .setDatasetId(getQueryTempDataset())
-                    .setTableId("dummy table");
-            BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+            // If the user provided a temp dataset, check if the dataset exists before launching the
+            // query
+            if (getQueryTempDataset() != null) {
+              // The temp table is only used for dataset and project id validation, not for table
+              // name
+              // validation
+              TableReference tempTable =
+                  new TableReference()
+                      .setProjectId(
+                          bqOptions.getBigQueryProject() == null
+                              ? bqOptions.getProject()
+                              : bqOptions.getBigQueryProject())
+                      .setDatasetId(getQueryTempDataset())
+                      .setTableId("dummy table");
+              BigQueryHelpers.verifyDatasetPresence(datasetService, tempTable);
+            }
           }
+        } catch (Exception e) {
+          throw new RuntimeException(e);
         }
       }
     }
@@ -1401,15 +1405,17 @@
                           options.getJobName(), jobUuid, JobType.QUERY),
                       queryTempDataset);
 
-              DatasetService datasetService = getBigQueryServices().getDatasetService(options);
-              LOG.info("Deleting temporary table with query results {}", tempTable);
-              datasetService.deleteTable(tempTable);
-              // Delete dataset only if it was created by Beam
-              boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
-              if (datasetCreatedByBeam) {
-                LOG.info(
-                    "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
-                datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+              try (DatasetService datasetService =
+                  getBigQueryServices().getDatasetService(options)) {
+                LOG.info("Deleting temporary table with query results {}", tempTable);
+                datasetService.deleteTable(tempTable);
+                // Delete dataset only if it was created by Beam
+                boolean datasetCreatedByBeam = !queryTempDataset.isPresent();
+                if (datasetCreatedByBeam) {
+                  LOG.info(
+                      "Deleting temporary dataset with query results {}", tempTable.getDatasetId());
+                  datasetService.deleteDataset(tempTable.getProjectId(), tempTable.getDatasetId());
+                }
               }
             }
           };
@@ -2484,17 +2490,20 @@
       // The user specified a table.
       if (getJsonTableRef() != null && getJsonTableRef().isAccessible() && getValidate()) {
         TableReference table = getTableWithDefaultProject(options).get();
-        DatasetService datasetService = getBigQueryServices().getDatasetService(options);
-        // Check for destination table presence and emptiness for early failure notification.
-        // Note that a presence check can fail when the table or dataset is created by an earlier
-        // stage of the pipeline. For these cases the #withoutValidation method can be used to
-        // disable the check.
-        BigQueryHelpers.verifyDatasetPresence(datasetService, table);
-        if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
-          BigQueryHelpers.verifyTablePresence(datasetService, table);
-        }
-        if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
-          BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+        try (DatasetService datasetService = getBigQueryServices().getDatasetService(options)) {
+          // Check for destination table presence and emptiness for early failure notification.
+          // Note that a presence check can fail when the table or dataset is created by an earlier
+          // stage of the pipeline. For these cases the #withoutValidation method can be used to
+          // disable the check.
+          BigQueryHelpers.verifyDatasetPresence(datasetService, table);
+          if (getCreateDisposition() == BigQueryIO.Write.CreateDisposition.CREATE_NEVER) {
+            BigQueryHelpers.verifyTablePresence(datasetService, table);
+          }
+          if (getWriteDisposition() == BigQueryIO.Write.WriteDisposition.WRITE_EMPTY) {
+            BigQueryHelpers.verifyTableNotExistOrEmpty(datasetService, table);
+          }
+        } catch (Exception e) {
+          throw new RuntimeException(e);
         }
       }
     }