[BEAM-8598] Test triggering BEAM-8598 on FlinkRunner. (#10049)

diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
       excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
       if (config.streaming) {
         excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+        excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'  // BEAM-8598
         excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
       } else {
         excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} # across multiple
+ * stages. Some Runners do not properly support quiescence in a way that {@link TestStream} demands
+ * it.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.testing;
 
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.TestStream.Builder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.Never;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -263,7 +277,7 @@
     FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
     PCollection<String> windowedValues =
         p.apply(stream)
-            .apply(Window.into(windows))
+            .apply(into(windows))
             .apply(WithKeys.of(1))
             .apply(GroupByKey.create())
             .apply(Values.create())
@@ -387,6 +401,74 @@
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+  public void testMultiStage() throws Exception {
+    TestStream<String> testStream =
+        TestStream.create(StringUtf8Coder.of())
+            .addElements("before") // before
+            .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+            .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+            .advanceWatermarkToInfinity(); // AFTER
+
+    PCollection<String> input = p.apply(testStream);
+
+    PCollection<String> grouped =
+        input
+            .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+            .apply(
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of(e, e)))
+            .apply(GroupByKey.create())
+            .apply(Keys.create())
+            .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+            .apply("Rewindow", Window.into(new GlobalWindows()));
+
+    PCollection<String> result =
+        PCollectionList.of(ImmutableList.of(input, grouped))
+            .apply(Flatten.pCollections())
+            .apply(
+                "Key",
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of("key", e)))
+            .apply(
+                ParDo.of(
+                    new DoFn<KV<String, String>, String>() {
+                      @StateId("seen")
+                      private final StateSpec<ValueState<String>> seenSpec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @TimerId("emit")
+                      private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+                      @ProcessElement
+                      public void process(
+                          ProcessContext context,
+                          @StateId("seen") ValueState<String> seenState,
+                          @TimerId("emit") Timer emitTimer) {
+                        String element = context.element().getValue();
+                        if (seenState.read() == null) {
+                          seenState.write(element);
+                        } else {
+                          seenState.write(seenState.read() + "," + element);
+                        }
+                        emitTimer.set(Instant.ofEpochSecond(100));
+                      }
+
+                      @OnTimer("emit")
+                      public void onEmit(
+                          OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+                        context.output(seenState.read());
+                      }
+                    }));
+
+    PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
   @Category(UsesTestStreamWithProcessingTime.class)
   public void testCoder() throws Exception {
     TestStream<String> testStream =