[BEAM-10470] Handle null state from waitUntilFinish
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 988921e..f7c74c0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -337,26 +337,30 @@
      */
     @Override
     public State waitUntilFinish(Duration duration) {
-      State startState = this.state;
-      if (!startState.isTerminal()) {
-        try {
-          state = executor.waitUntilFinish(duration);
-        } catch (UserCodeException uce) {
-          // Emulates the behavior of Pipeline#run(), where a stack trace caused by a
-          // UserCodeException is truncated and replaced with the stack starting at the call to
-          // waitToFinish
-          throw new Pipeline.PipelineExecutionException(uce.getCause());
-        } catch (Exception e) {
-          if (e instanceof InterruptedException) {
-            Thread.currentThread().interrupt();
-          }
-          if (e instanceof RuntimeException) {
-            throw (RuntimeException) e;
-          }
-          throw new RuntimeException(e);
-        }
+      if (this.state.isTerminal()) {
+        return this.state;
       }
-      return this.state;
+      final State endState;
+      try {
+        endState = executor.waitUntilFinish(duration);
+      } catch (UserCodeException uce) {
+        // Emulates the behavior of Pipeline#run(), where a stack trace caused by a
+        // UserCodeException is truncated and replaced with the stack starting at the call to
+        // waitToFinish
+        throw new Pipeline.PipelineExecutionException(uce.getCause());
+      } catch (Exception e) {
+        if (e instanceof InterruptedException) {
+          Thread.currentThread().interrupt();
+        }
+        if (e instanceof RuntimeException) {
+          throw (RuntimeException) e;
+        }
+        throw new RuntimeException(e);
+      }
+      if (endState != null) {
+        this.state = endState;
+      }
+      return endState;
     }
   }
 
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
index 8054a07..fbcf0c0 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java
@@ -328,8 +328,9 @@
     // The pipeline should never complete;
     assertThat(result.getState(), is(State.RUNNING));
     // Must time out, otherwise this test will never complete
-    result.waitUntilFinish(Duration.millis(1L));
-    assertEquals(null, result.getState());
+    assertEquals(null, result.waitUntilFinish(Duration.millis(1L)));
+    // Ensure multiple calls complete
+    assertEquals(null, result.waitUntilFinish(Duration.millis(1L)));
   }
 
   private static final AtomicLong TEARDOWN_CALL = new AtomicLong(-1);