Merge pull request #10079 [BEAM-8621] Fix dependency checking job.

diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
       excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
       if (config.streaming) {
         excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+        excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'  // BEAM-8598
         excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
       } else {
         excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 915acfb..eed52e9 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -94,6 +94,8 @@
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -246,7 +248,8 @@
                   () -> UUID.randomUUID().toString(),
                   keyedStateInternals,
                   getKeyedStateBackend(),
-                  stateBackendLock));
+                  stateBackendLock,
+                  keyCoder));
     } else {
       userStateRequestHandler = StateRequestHandler.unsupported();
     }
@@ -263,7 +266,12 @@
 
     private final StateInternals stateInternals;
     private final KeyedStateBackend<ByteBuffer> keyedStateBackend;
+    /** Lock to hold whenever accessing the state backend. */
     private final Lock stateBackendLock;
+    /** For debugging: the key coder used by the runner. */
+    @Nullable private final Coder runnerKeyCoder;
+    /** For debugging: same as keyedStateBackend but downcast to access key group meta info. */
+    @Nullable private final AbstractKeyedStateBackend<ByteBuffer> keyStateBackendWithKeyGroupInfo;
     /** Holds the valid cache token for user state for this operator. */
     private final ByteString cacheToken;
 
@@ -271,10 +279,20 @@
         IdGenerator cacheTokenGenerator,
         StateInternals stateInternals,
         KeyedStateBackend<ByteBuffer> keyedStateBackend,
-        Lock stateBackendLock) {
+        Lock stateBackendLock,
+        @Nullable Coder runnerKeyCoder) {
       this.stateInternals = stateInternals;
       this.keyedStateBackend = keyedStateBackend;
       this.stateBackendLock = stateBackendLock;
+      if (keyedStateBackend instanceof AbstractKeyedStateBackend) {
+        // This cast succeeds unless a custom state backend is used which does not extend
+        // AbstractKeyedStateBackend. That is unlikely, but we handle it by skipping the check.
+        this.keyStateBackendWithKeyGroupInfo =
+            (AbstractKeyedStateBackend<ByteBuffer>) keyedStateBackend;
+      } else {
+        this.keyStateBackendWithKeyGroupInfo = null;
+      }
+      this.runnerKeyCoder = runnerKeyCoder;
       this.cacheToken = ByteString.copyFrom(cacheTokenGenerator.getId().getBytes(Charsets.UTF_8));
     }
 
@@ -368,6 +386,19 @@
           // Key for state request is shipped encoded with NESTED context.
           ByteBuffer encodedKey = FlinkKeyUtils.fromEncodedKey(key);
           keyedStateBackend.setCurrentKey(encodedKey);
+          if (keyStateBackendWithKeyGroupInfo != null) {
+            int currentKeyGroupIndex = keyStateBackendWithKeyGroupInfo.getCurrentKeyGroupIndex();
+            KeyGroupRange keyGroupRange = keyStateBackendWithKeyGroupInfo.getKeyGroupRange();
+            Preconditions.checkState(
+                keyGroupRange.contains(currentKeyGroupIndex),
+                "The current key '%s' with key group index '%s' does not belong to the key group range '%s'. Runner key coder: %s. PTransform id: %s, user state id: %s",
+                Arrays.toString(key.toByteArray()),
+                currentKeyGroupIndex,
+                keyGroupRange,
+                runnerKeyCoder,
+                pTransformId,
+                userStateId);
+          }
         }
       };
     }
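
The check above works because Flink derives a key's key group deterministically from the hash of the key object (here the ByteBuffer holding the encoded key bytes), and each operator instance owns a contiguous KeyGroupRange. If the SDK harness encodes a key differently than the runner did when partitioning (for example, with or without a length prefix), the bytes differ, the hash differs, and the resulting key group can fall outside the operator's range. A minimal sketch of that invariant, using Flink's KeyGroupRangeAssignment utility (the class and method below are illustrative, not part of this change):

```java
import java.nio.ByteBuffer;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

public class KeyGroupCheckSketch {
  /** Returns true iff the encoded key hashes into the operator's key group range. */
  static boolean belongsToOperator(
      ByteBuffer encodedKey, int maxParallelism, KeyGroupRange operatorRange) {
    // Flink hashes the key object and maps it onto [0, maxParallelism) key groups.
    int keyGroup =
        KeyGroupRangeAssignment.assignToKeyGroup(encodedKey, maxParallelism);
    return operatorRange.contains(keyGroup);
  }
}
```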
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index 0d7c99f..0cbf6b7 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -659,7 +659,7 @@
       ExecutableStageDoFnOperator.BagUserStateFactory<ByteString, Integer, GlobalWindow>
           bagUserStateFactory =
               new ExecutableStageDoFnOperator.BagUserStateFactory<>(
-                  cacheTokenGenerator, test, stateBackend, NoopLock.get());
+                  cacheTokenGenerator, test, stateBackend, NoopLock.get(), null);
 
       ByteString key1 = ByteString.copyFrom("key1", Charsets.UTF_8);
       ByteString key2 = ByteString.copyFrom("key2", Charsets.UTF_8);
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index b324cb1..a2129c4 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -57,6 +57,7 @@
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -135,6 +136,10 @@
         forTimerSpecs(
             dataEndpoint, stage, components, inputDestinationsBuilder, remoteOutputCodersBuilder);
 
+    if (!bagUserStateSpecs.isEmpty() || !timerSpecs.isEmpty()) {
+      lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components);
+    }
+
     // Copy data from components to ProcessBundleDescriptor.
     ProcessBundleDescriptor.Builder bundleDescriptorBuilder =
         ProcessBundleDescriptor.newBuilder().setId(id);
@@ -158,6 +163,29 @@
         timerSpecs);
   }
 
+  /**
+   * Patches the input coder of a stateful executable stage to ensure that the byte representation
+   * of the key used to partition the input elements at the runner matches the key byte
+   * representation received for state requests and timers from the SDK harness. Stateful
+   * transforms always have a KvCoder as input.
+   */
+  private static void lengthPrefixKeyCoder(
+      String inputColId, Components.Builder componentsBuilder) {
+    RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId);
+    RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId());
+    Preconditions.checkState(
+        ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()),
+        "Stateful executable stages must use a KV coder, but the coder URN was: %s",
+        kvCoder.getSpec().getUrn());
+    String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId();
+    // Retain the original coder, but wrap in LengthPrefixCoder
+    String newKeyCoderId =
+        LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false);
+    // Replace old key coder with LengthPrefixCoder<old_key_coder>
+    kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build();
+    componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder);
+  }
+
   private static Map<String, Coder<WindowedValue<?>>> addStageOutputs(
       ApiServiceDescriptor dataEndpoint,
       Collection<PCollectionNode> outputPCollections,
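
lengthPrefixKeyCoder works because a LengthPrefixCoder makes the key's byte representation self-delimiting: a varint length precedes the payload, so the same bytes are produced whether the key is encoded in NESTED or OUTER context, and the runner and the SDK harness agree on the partitioning bytes. A small sketch of the effect using the Java SDK (illustrative only; not part of this change):

```java
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class LengthPrefixSketch {
  public static void main(String[] args) throws Exception {
    // Unwrapped, StringUtf8Coder writes just the raw UTF-8 bytes in OUTER context.
    byte[] raw = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "key");
    // Wrapped, a varint length prefix precedes the same payload, making the
    // encoding independent of the surrounding context.
    byte[] prefixed = CoderUtils.encodeToByteArray(
        LengthPrefixCoder.of(StringUtf8Coder.of()), "key");
    System.out.println(raw.length + " vs " + prefixed.length); // prints "3 vs 4"
  }
}
```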
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
new file mode 100644
index 0000000..b558b00
--- /dev/null
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.Map;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.FusedPipeline;
+import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+
+/** Tests for {@link ProcessBundleDescriptors}. */
+public class ProcessBundleDescriptorsTest implements Serializable {
+
+  /**
+   * Tests that a stateful executable stage wraps the key coder of a stateful transform in a
+   * LengthPrefixCoder.
+   */
+  @Test
+  public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
+    // Create a stateful stage with a non-standard (SDK-specific) key coder.
+    Pipeline p = Pipeline.create();
+    Coder<Void> keyCoder = VoidCoder.of();
+    assertThat(ModelCoderRegistrar.isKnownCoder(keyCoder), is(false));
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], KV<Void, String>>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {}
+                }))
+        .setCoder(KvCoder.of(keyCoder, StringUtf8Coder.of()))
+        .apply(
+            "userState",
+            ParDo.of(
+                new DoFn<KV<Void, String>, KV<Void, String>>() {
+
+                  @StateId("stateId")
+                  private final StateSpec<BagState<String>> bufferState =
+                      StateSpecs.bag(StringUtf8Coder.of());
+
+                  @ProcessElement
+                  public void processElement(
+                      @Element KV<Void, String> element,
+                      @StateId("stateId") BagState<String> state,
+                      OutputReceiver<KV<Void, String>> r) {
+                    for (String value : state.read()) {
+                      r.output(KV.of(element.getKey(), value));
+                    }
+                    state.add(element.getValue());
+                  }
+                }))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(),
+            (ExecutableStage stage) ->
+                stage.getUserStates().stream()
+                    .anyMatch(spec -> spec.localName().equals("stateId")));
+    checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+    ExecutableStage stage = optionalStage.get();
+
+    ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+
+    BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
+    String inputColId = stage.getInputPCollection().getId();
+    String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+
+    Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
+    RunnerApi.Coder coder = codersMap.get(inputCoderId);
+    String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
+
+    assertThat(
+        codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+
+    RunnerApi.Coder originalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
+    String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+    assertThat(
+        stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
+        is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+  }
+}
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 0925767..5d4c8f0 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -18,10 +18,10 @@
 package org.apache.beam.runners.fnexecution.control;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -946,8 +946,6 @@
   @Test
   public void testExecutionWithTimer() throws Exception {
     Pipeline p = Pipeline.create();
-    final String timerId = "foo";
-    final String timerId2 = "foo2";
 
     p.apply("impulse", Impulse.create())
         .apply(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some runners do not properly support quiescence in the way that {@link TestStream}
+ * demands it.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.testing;
 
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.TestStream.Builder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.Never;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -263,7 +277,7 @@
     FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
     PCollection<String> windowedValues =
         p.apply(stream)
-            .apply(Window.into(windows))
+            .apply(into(windows))
             .apply(WithKeys.of(1))
             .apply(GroupByKey.create())
             .apply(Values.create())
@@ -387,6 +401,74 @@
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+  public void testMultiStage() throws Exception {
+    TestStream<String> testStream =
+        TestStream.create(StringUtf8Coder.of())
+            .addElements("before") // before
+            .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+            .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+            .advanceWatermarkToInfinity(); // AFTER
+
+    PCollection<String> input = p.apply(testStream);
+
+    PCollection<String> grouped =
+        input
+            .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+            .apply(
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of(e, e)))
+            .apply(GroupByKey.create())
+            .apply(Keys.create())
+            .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+            .apply("Rewindow", Window.into(new GlobalWindows()));
+
+    PCollection<String> result =
+        PCollectionList.of(ImmutableList.of(input, grouped))
+            .apply(Flatten.pCollections())
+            .apply(
+                "Key",
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of("key", e)))
+            .apply(
+                ParDo.of(
+                    new DoFn<KV<String, String>, String>() {
+                      @StateId("seen")
+                      private final StateSpec<ValueState<String>> seenSpec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @TimerId("emit")
+                      private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+                      @ProcessElement
+                      public void process(
+                          ProcessContext context,
+                          @StateId("seen") ValueState<String> seenState,
+                          @TimerId("emit") Timer emitTimer) {
+                        String element = context.element().getValue();
+                        if (seenState.read() == null) {
+                          seenState.write(element);
+                        } else {
+                          seenState.write(seenState.read() + "," + element);
+                        }
+                        emitTimer.set(Instant.ofEpochSecond(100));
+                      }
+
+                      @OnTimer("emit")
+                      public void onEmit(
+                          OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+                        context.output(seenState.read());
+                      }
+                    }));
+
+    PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
   @Category(UsesTestStreamWithProcessingTime.class)
   public void testCoder() throws Exception {
     TestStream<String> testStream =
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 24c6b87..992aa08 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -45,7 +45,10 @@
 from apache_beam.runners.portability.portable_runner import PortableRunner
 from apache_beam.runners.worker import worker_pool_main
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
 from apache_beam.transforms import environments
+from apache_beam.transforms import userstate
 
 
 class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -188,6 +191,46 @@
   def test_metrics(self):
     self.skipTest('Metrics not supported.')
 
+  def test_pardo_state_with_custom_key_coder(self):
+    """Tests that state requests work correctly when the key coder is an
+    SDK-specific coder, i.e. a non-standard coder. This is additionally
+    enforced by Java's ProcessBundleDescriptorsTest and by Flink's
+    ExecutableStageDoFnOperator, which detects invalid encodings by checking
+    that the encoded key falls into the correct key group."""
+    index_state_spec = userstate.CombiningValueStateSpec('index', sum)
+
+    # Test params
+    # Ensure a decent number of elements so that all partitions receive data.
+    n = 200
+    duplicates = 1
+
+    split = n // (duplicates + 1)
+    inputs = [(i % split, str(i % split)) for i in range(0, n)]
+
+    # Use a DoFn which has to use FastPrimitivesCoder because the type cannot
+    # be inferred
+    class Input(beam.DoFn):
+      def process(self, impulse):
+        for i in inputs:
+          yield i
+
+    class AddIndex(beam.DoFn):
+      def process(self, kv,
+                  index=beam.DoFn.StateParam(index_state_spec)):
+        k, v = kv
+        index.add(1)
+        yield k, v, index.read()
+
+    expected = [(i % split, str(i % split), i // split + 1)
+                for i in range(0, n)]
+
+    with self.create_pipeline() as p:
+      assert_that(p
+                  | beam.Impulse()
+                  | beam.ParDo(Input())
+                  | beam.ParDo(AddIndex()),
+                  equal_to(expected))
+
   # Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
 
 
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 610a1a8..9d9284c 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -31,6 +31,8 @@
 from apache_beam import coders
 from apache_beam import core
 from apache_beam import pvalue
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import window
 from apache_beam.transforms.window import TimestampedValue
@@ -66,6 +68,28 @@
     # TODO(BEAM-5949): Needed for Python 2 compatibility.
     return not self == other
 
+  @abstractmethod
+  def to_runner_api(self, element_coder):
+    raise NotImplementedError
+
+  @staticmethod
+  def from_runner_api(proto, element_coder):
+    if proto.HasField('element_event'):
+      return ElementEvent(
+          [TimestampedValue(
+              element_coder.decode(tv.encoded_element),
+              timestamp.Timestamp(micros=1000 * tv.timestamp))
+           for tv in proto.element_event.elements])
+    elif proto.HasField('watermark_event'):
+      return WatermarkEvent(timestamp.Timestamp(
+          micros=1000 * proto.watermark_event.new_watermark))
+    elif proto.HasField('processing_time_event'):
+      return ProcessingTimeEvent(timestamp.Duration(
+          micros=1000 * proto.processing_time_event.advance_duration))
+    else:
+      raise ValueError(
+          'Unknown TestStream Event type: %s' % proto.WhichOneof('event'))
+
 
 class ElementEvent(Event):
   """Element-producing test stream event."""
@@ -82,6 +106,15 @@
   def __lt__(self, other):
     return self.timestamped_values < other.timestamped_values
 
+  def to_runner_api(self, element_coder):
+    return beam_runner_api_pb2.TestStreamPayload.Event(
+        element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
+            elements=[
+                beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+                    encoded_element=element_coder.encode(tv.value),
+                    timestamp=tv.timestamp.micros // 1000)
+                for tv in self.timestamped_values]))
+
 
 class WatermarkEvent(Event):
   """Watermark-advancing test stream event."""
@@ -98,6 +131,11 @@
   def __lt__(self, other):
     return self.new_watermark < other.new_watermark
 
+  def to_runner_api(self, unused_element_coder):
+    return beam_runner_api_pb2.TestStreamPayload.Event(
+        watermark_event=(
+            beam_runner_api_pb2.TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=self.new_watermark.micros // 1000)))
+
 
 class ProcessingTimeEvent(Event):
   """Processing time-advancing test stream event."""
@@ -114,6 +152,12 @@
   def __lt__(self, other):
     return self.advance_by < other.advance_by
 
+  def to_runner_api(self, unused_element_coder):
+    return beam_runner_api_pb2.TestStreamPayload.Event(
+        processing_time_event=(
+            beam_runner_api_pb2.TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=self.advance_by.micros // 1000)))
+
 
 class TestStream(PTransform):
   """Test stream that generates events on an unbounded PCollection of elements.
@@ -123,11 +167,12 @@
   output.
   """
 
-  def __init__(self, coder=coders.FastPrimitivesCoder()):
+  def __init__(self, coder=coders.FastPrimitivesCoder(), events=()):
+    super(TestStream, self).__init__()
     assert coder is not None
     self.coder = coder
     self.current_watermark = timestamp.MIN_TIMESTAMP
-    self.events = []
+    self.events = list(events)
 
   def get_windowing(self, unused_inputs):
     return core.Windowing(window.GlobalWindows())
@@ -206,3 +251,19 @@
     """
     self._add(ProcessingTimeEvent(advance_by))
     return self
+
+  def to_runner_api_parameter(self, context):
+    return (
+        common_urns.primitives.TEST_STREAM.urn,
+        beam_runner_api_pb2.TestStreamPayload(
+            coder_id=context.coders.get_id(self.coder),
+            events=[e.to_runner_api(self.coder) for e in self.events]))
+
+  @PTransform.register_urn(
+      common_urns.primitives.TEST_STREAM.urn,
+      beam_runner_api_pb2.TestStreamPayload)
+  def from_runner_api_parameter(payload, context):
+    coder = context.coders.get_by_id(payload.coder_id)
+    return TestStream(
+        coder=coder,
+        events=[Event.from_runner_api(e, coder) for e in payload.events])
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index dbc4bcd..22ecda3 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -20,6 +20,7 @@
 from __future__ import absolute_import
 
 import collections
+import json
 import os.path
 import pickle
 import unittest
@@ -31,6 +32,7 @@
 import yaml
 
 import apache_beam as beam
+from apache_beam import coders
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners import pipeline_context
@@ -502,7 +504,10 @@
     while hasattr(cls, unique_name):
       counter += 1
       unique_name = 'test_%s_%d' % (name, counter)
-    setattr(cls, unique_name, lambda self: self._run_log_test(spec))
+    test_method = lambda self: self._run_log_test(spec)
+    test_method.__name__ = unique_name
+    test_method.__test__ = True
+    setattr(cls, unique_name, test_method)
 
   # We must prepend an underscore to this name so that the open-source unittest
   # runner does not execute this method directly as a test.
@@ -606,24 +611,25 @@
         window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
         transcript, spec)
 
-  def _windowed_value_info(self, windowed_value):
-    # Currently some runners operate at the millisecond level, and some at the
-    # microsecond level.  Trigger transcript timestamps are expressed as
-    # integral units of the finest granularity, whatever that may be.
-    # In these tests we interpret them as integral seconds and then truncate
-    # the results to integral seconds to allow for portability across
-    # different sub-second resolutions.
-    window, = windowed_value.windows
-    return {
-        'window': [int(window.start), int(window.max_timestamp())],
-        'values': sorted(windowed_value.value),
-        'timestamp': int(windowed_value.timestamp),
-        'index': windowed_value.pane_info.index,
-        'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
-        'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
-        'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
-        'final': windowed_value.pane_info.is_last,
-    }
+
+def _windowed_value_info(windowed_value):
+  # Currently some runners operate at the millisecond level, and some at the
+  # microsecond level.  Trigger transcript timestamps are expressed as
+  # integral units of the finest granularity, whatever that may be.
+  # In these tests we interpret them as integral seconds and then truncate
+  # the results to integral seconds to allow for portability across
+  # different sub-second resolutions.
+  window, = windowed_value.windows
+  return {
+      'window': [int(window.start), int(window.max_timestamp())],
+      'values': sorted(windowed_value.value),
+      'timestamp': int(windowed_value.timestamp),
+      'index': windowed_value.pane_info.index,
+      'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
+      'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
+      'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
+      'final': windowed_value.pane_info.is_last,
+  }
 
 
 class TriggerDriverTranscriptTest(TranscriptTest):
@@ -645,7 +651,7 @@
         for timer_window, (name, time_domain, t_timestamp) in to_fire:
           for wvalue in driver.process_timer(
               timer_window, name, time_domain, t_timestamp, state):
-            output.append(self._windowed_value_info(wvalue))
+            output.append(_windowed_value_info(wvalue))
         to_fire = state.get_and_clear_timers(watermark)
 
     for action, params in transcript:
@@ -661,7 +667,7 @@
             WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
             for t in params]
         output = [
-            self._windowed_value_info(wv)
+            _windowed_value_info(wv)
             for wv in driver.process_elements(state, bundle, watermark)]
         fire_timers()
 
@@ -690,7 +696,7 @@
     self.assertEqual([], output, msg='Unexpected output: %s' % output)
 
 
-class TestStreamTranscriptTest(TranscriptTest):
+class BaseTestStreamTranscriptTest(TranscriptTest):
   """A suite of TestStream-based tests based on trigger transcript entries.
   """
 
@@ -702,14 +708,17 @@
     if runner_name in spec.get('broken_on', ()):
       self.skipTest('Known to be broken on %s' % runner_name)
 
-    test_stream = TestStream()
+    # Elements are encoded as JSON strings to allow other languages to
+    # decode elements while executing the test stream.
+    # TODO(BEAM-8600): Eliminate these gymnastics.
+    test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
     for action, params in transcript:
       if action == 'expect':
-        test_stream.add_elements([('expect', params)])
+        test_stream.add_elements([json.dumps(('expect', params))])
       else:
-        test_stream.add_elements([('expect', [])])
+        test_stream.add_elements([json.dumps(('expect', []))])
         if action == 'input':
-          test_stream.add_elements([('input', e) for e in params])
+          test_stream.add_elements([json.dumps(('input', e)) for e in params])
         elif action == 'watermark':
           test_stream.advance_watermark_to(params)
         elif action == 'clock':
@@ -718,7 +727,9 @@
           pass  # Requires inspection of implementation details.
         else:
           raise ValueError('Unexpected action: %s' % action)
-    test_stream.add_elements([('expect', [])])
+    test_stream.add_elements([json.dumps(('expect', []))])
+
+    read_test_stream = test_stream | beam.Map(json.loads)
 
     class Check(beam.DoFn):
       """A StatefulDoFn that verifies outputs are produced as expected.
@@ -731,12 +742,40 @@
 
       The key is ignored, but all items must be on the same key to share state.
       """
+      def __init__(self, allow_out_of_order=True):
+        # Some runners don't support cross-stage TestStream semantics.
+        self.allow_out_of_order = allow_out_of_order
+
       def process(
-          self, element, seen=beam.DoFn.StateParam(
+          self,
+          element,
+          seen=beam.DoFn.StateParam(
               beam.transforms.userstate.BagStateSpec(
                   'seen',
+                  beam.coders.FastPrimitivesCoder())),
+          expected=beam.DoFn.StateParam(
+              beam.transforms.userstate.BagStateSpec(
+                  'expected',
                   beam.coders.FastPrimitivesCoder()))):
         _, (action, data) = element
+
+        if self.allow_out_of_order:
+          if action == 'expect' and not list(seen.read()):
+            if data:
+              expected.add(data)
+            return
+          elif action == 'actual' and list(expected.read()):
+            seen.add(data)
+            all_data = list(seen.read())
+            all_expected = list(expected.read())
+            if len(all_data) == len(all_expected[0]):
+              expected.clear()
+              for expect in all_expected[1:]:
+                expected.add(expect)
+              action, data = 'expect', all_expected[0]
+            else:
+              return
+
         if action == 'actual':
           seen.add(data)
 
@@ -768,12 +807,14 @@
         else:
           raise ValueError('Unexpected action: %s' % action)
 
-    with TestPipeline(options=PipelineOptions(streaming=True)) as p:
+    with TestPipeline() as p:
+      # TODO(BEAM-8601): Pass this during pipeline construction.
+      p.options.view_as(StandardOptions).streaming = True
       # Split the test stream into a branch of to-be-processed elements, and
       # a branch of expected results.
       inputs, expected = (
           p
-          | test_stream
+          | read_test_stream
           | beam.MapTuple(
               lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
               ).with_outputs('input', 'expect'))
@@ -794,7 +835,7 @@
                      t=beam.DoFn.TimestampParam,
                      p=beam.DoFn.PaneInfoParam: (
                          k,
-                         self._windowed_value_info(WindowedValue(
+                         _windowed_value_info(WindowedValue(
                              vs, windows=[window], timestamp=t, pane_info=p))))
           # Place outputs back into the global window to allow flattening
           # and share a single state in Check.
@@ -805,7 +846,17 @@
       tagged_outputs = (
           outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
       # pylint: disable=expression-not-assigned
-      (tagged_expected, tagged_outputs) | beam.Flatten() | beam.ParDo(Check())
+      ([tagged_expected, tagged_outputs]
+       | beam.Flatten()
+       | beam.ParDo(Check(self.allow_out_of_order)))
+
+
+class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+  allow_out_of_order = False
+
+
+class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+  allow_out_of_order = True
 
 
 TRANSCRIPT_TEST_FILE = os.path.join(
@@ -814,6 +865,7 @@
 if os.path.exists(TRANSCRIPT_TEST_FILE):
   TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
   TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+  WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index f04d28d..3cb4362 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -79,3 +79,22 @@
 task flinkValidatesRunner() {
   dependsOn 'flinkCompatibilityMatrixLoopback'
 }
+
+// TODO(BEAM-8598): Enable in pre-commit.
+task flinkTriggerTranscript() {
+  dependsOn 'setupVirtualenv'
+  dependsOn ':runners:flink:1.9:job-server:shadowJar'
+  doLast {
+    exec {
+      executable 'sh'
+      args '-c', """
+          . ${envdir}/bin/activate \\
+          && cd ${pythonRootDir} \\
+          && pip install -e .[test] \\
+          && python setup.py nosetests \\
+              --tests apache_beam.transforms.trigger_test:WeakTestStreamTranscriptTest \\
+              --test-pipeline-options='--runner=FlinkRunner --environment_type=LOOPBACK --flink_job_server_jar=${project(":runners:flink:1.9:job-server:").shadowJar.archivePath}'
+          """
+    }
+  }
+}