Merge pull request #14741 from y1chi/beam-12294

[BEAM-12294] Implement close function for BeamFnStatusClient to shutd…
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
index 231d2b8..0ca4581 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java
@@ -272,12 +272,14 @@
               finalizeBundleHandler,
               metricsShortIds);
 
+      BeamFnStatusClient beamFnStatusClient = null;
       if (statusApiServiceDescriptor != null) {
-        new BeamFnStatusClient(
-            statusApiServiceDescriptor,
-            channelFactory::forDescriptor,
-            processBundleHandler.getBundleProcessorCache(),
-            options);
+        beamFnStatusClient =
+            new BeamFnStatusClient(
+                statusApiServiceDescriptor,
+                channelFactory::forDescriptor,
+                processBundleHandler.getBundleProcessorCache(),
+                options);
       }
 
       // TODO(BEAM-9729): Remove once runners no longer send this instruction.
@@ -337,6 +339,9 @@
               executorService,
               handlers);
       control.waitForTermination();
+      if (beamFnStatusClient != null) {
+        beamFnStatusClient.close();
+      }
       processBundleHandler.shutdown();
     } finally {
       System.out.println("Shutting SDK harness down.");
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
index 4c01c04..e059471 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/status/BeamFnStatusClient.java
@@ -25,6 +25,8 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.StringJoiner;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache;
@@ -42,9 +44,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class BeamFnStatusClient {
+public class BeamFnStatusClient implements AutoCloseable {
+  private static final Object COMPLETED = new Object();
   private final StreamObserver<WorkerStatusResponse> outboundObserver;
   private final BundleProcessorCache processBundleCache;
+  private final ManagedChannel channel;
+  private final CompletableFuture<Object> inboundObserverCompletion;
   private static final Logger LOG = LoggerFactory.getLogger(BeamFnStatusClient.class);
   private final MemoryMonitor memoryMonitor;
 
@@ -53,11 +58,12 @@
       Function<ApiServiceDescriptor, ManagedChannel> channelFactory,
       BundleProcessorCache processBundleCache,
       PipelineOptions options) {
-    BeamFnWorkerStatusGrpc.BeamFnWorkerStatusStub stub =
-        BeamFnWorkerStatusGrpc.newStub(channelFactory.apply(apiServiceDescriptor));
-    this.outboundObserver = stub.workerStatus(new InboundObserver());
+    this.channel = channelFactory.apply(apiServiceDescriptor);
+    this.outboundObserver =
+        BeamFnWorkerStatusGrpc.newStub(channel).workerStatus(new InboundObserver());
     this.processBundleCache = processBundleCache;
     this.memoryMonitor = MemoryMonitor.fromOptions(options);
+    this.inboundObserverCompletion = new CompletableFuture<>();
     Thread thread = new Thread(memoryMonitor);
     thread.setDaemon(true);
     thread.setPriority(Thread.MIN_PRIORITY);
@@ -65,6 +71,22 @@
     thread.start();
   }
 
+  /** Waits up to 1 minute for the inbound stream to finish, then shuts the channel down. */
+  @Override
+  public void close() throws Exception {
+    try {
+      inboundObserverCompletion.get(1, TimeUnit.MINUTES);
+    } catch (java.util.concurrent.ExecutionException e) {
+      // onError completes the future exceptionally, so get() throws instead of returning.
+      LOG.warn("InboundObserver for BeamFnStatusClient completed with exception.", e.getCause());
+    } finally {
+      channel.shutdown(); // Graceful first; force it below if it does not finish in 10 seconds.
+      if (!channel.awaitTermination(10, TimeUnit.SECONDS)) {
+        channel.shutdownNow();
+      }
+    }
+  }
+
   /**
    * Class representing the execution state of a thread.
    *
@@ -222,9 +244,12 @@
     @Override
     public void onError(Throwable t) {
       LOG.error("Error getting SDK harness status", t);
+      inboundObserverCompletion.completeExceptionally(t); // Unblocks close() with the failure.
     }
 
     @Override
-    public void onCompleted() {}
+    public void onCompleted() {
+      inboundObserverCompletion.complete(COMPLETED); // Lets close() stop waiting.
+    }
   }
 }