Merge pull request #10079 [BEAM-8621] Fix dependency checking job.
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
if (config.streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' // BEAM-8598
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
} else {
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 915acfb..eed52e9 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -94,6 +94,8 @@
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.watermark.Watermark;
@@ -246,7 +248,8 @@
() -> UUID.randomUUID().toString(),
keyedStateInternals,
getKeyedStateBackend(),
- stateBackendLock));
+ stateBackendLock,
+ keyCoder));
} else {
userStateRequestHandler = StateRequestHandler.unsupported();
}
@@ -263,7 +266,12 @@
private final StateInternals stateInternals;
private final KeyedStateBackend<ByteBuffer> keyedStateBackend;
+ /** Lock to hold whenever accessing the state backend. */
private final Lock stateBackendLock;
+ /** For debugging: The key coder used by the Runner. */
+ @Nullable private final Coder runnerKeyCoder;
+ /** For debugging: Same as keyedStateBackend but downcast to access key group meta info. */
+ @Nullable private final AbstractKeyedStateBackend<ByteBuffer> keyStateBackendWithKeyGroupInfo;
/** Holds the valid cache token for user state for this operator. */
private final ByteString cacheToken;
@@ -271,10 +279,20 @@
IdGenerator cacheTokenGenerator,
StateInternals stateInternals,
KeyedStateBackend<ByteBuffer> keyedStateBackend,
- Lock stateBackendLock) {
+ Lock stateBackendLock,
+ @Nullable Coder runnerKeyCoder) {
this.stateInternals = stateInternals;
this.keyedStateBackend = keyedStateBackend;
this.stateBackendLock = stateBackendLock;
+ if (keyedStateBackend instanceof AbstractKeyedStateBackend) {
+ // This cast succeeds unless a custom state backend is used which does not extend
+ // AbstractKeyedStateBackend. That is unlikely, but we still handle the case.
+ this.keyStateBackendWithKeyGroupInfo =
+ (AbstractKeyedStateBackend<ByteBuffer>) keyedStateBackend;
+ } else {
+ this.keyStateBackendWithKeyGroupInfo = null;
+ }
+ this.runnerKeyCoder = runnerKeyCoder;
this.cacheToken = ByteString.copyFrom(cacheTokenGenerator.getId().getBytes(Charsets.UTF_8));
}
@@ -368,6 +386,19 @@
// Key for state request is shipped encoded with NESTED context.
ByteBuffer encodedKey = FlinkKeyUtils.fromEncodedKey(key);
keyedStateBackend.setCurrentKey(encodedKey);
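+ // Sanity check (for debugging): the key, as encoded by the SDK harness, must map to a
+ // key group owned by this operator; a mismatch indicates that the runner and the SDK
+ // harness encode the key differently.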
+ if (keyStateBackendWithKeyGroupInfo != null) {
+ int currentKeyGroupIndex = keyStateBackendWithKeyGroupInfo.getCurrentKeyGroupIndex();
+ KeyGroupRange keyGroupRange = keyStateBackendWithKeyGroupInfo.getKeyGroupRange();
+ Preconditions.checkState(
+ keyGroupRange.contains(currentKeyGroupIndex),
+ "The current key '%s' with key group index '%s' does not belong to the key group range '%s'. Runner keyCoder: %s. PTransform id: %s. User state id: %s",
+ Arrays.toString(key.toByteArray()),
+ currentKeyGroupIndex,
+ keyGroupRange,
+ runnerKeyCoder,
+ pTransformId,
+ userStateId);
+ }
}
};
}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index 0d7c99f..0cbf6b7 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -659,7 +659,7 @@
ExecutableStageDoFnOperator.BagUserStateFactory<ByteString, Integer, GlobalWindow>
bagUserStateFactory =
new ExecutableStageDoFnOperator.BagUserStateFactory<>(
- cacheTokenGenerator, test, stateBackend, NoopLock.get());
+ cacheTokenGenerator, test, stateBackend, NoopLock.get(), null);
ByteString key1 = ByteString.copyFrom("key1", Charsets.UTF_8);
ByteString key2 = ByteString.copyFrom("key2", Charsets.UTF_8);
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index b324cb1..a2129c4 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -57,6 +57,7 @@
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -135,6 +136,10 @@
forTimerSpecs(
dataEndpoint, stage, components, inputDestinationsBuilder, remoteOutputCodersBuilder);
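+ // Stages with user state or timers are keyed by the runner-side key bytes; length-prefix
+ // the key coder so those bytes match the key bytes the SDK harness sends with state
+ // requests and timers.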
+ if (bagUserStateSpecs.size() > 0 || timerSpecs.size() > 0) {
+ lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components);
+ }
+
// Copy data from components to ProcessBundleDescriptor.
ProcessBundleDescriptor.Builder bundleDescriptorBuilder =
ProcessBundleDescriptor.newBuilder().setId(id);
@@ -158,6 +163,29 @@
timerSpecs);
}
+ /**
+ * Patches the input coder of a stateful executable stage to ensure that the byte representation
+ * of a key used to partition the input element at the Runner matches the key byte representation
+ * received for state requests and timers from the SDK Harness. Stateful transforms always have a
+ * KvCoder as input.
+ */
+ private static void lengthPrefixKeyCoder(
+ String inputColId, Components.Builder componentsBuilder) {
+ RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId);
+ RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId());
+ Preconditions.checkState(
+ ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()),
+ "Stateful executable stages must use a KV coder, but found: %s",
+ kvCoder.getSpec().getUrn());
+ String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId();
+ // Retain the original coder, but wrap in LengthPrefixCoder
+ String newKeyCoderId =
+ LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false);
+ // Replace old key coder with LengthPrefixCoder<old_key_coder>
+ kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build();
+ componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder);
+ }
+
private static Map<String, Coder<WindowedValue<?>>> addStageOutputs(
ApiServiceDescriptor dataEndpoint,
Collection<PCollectionNode> outputPCollections,
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
new file mode 100644
index 0000000..b558b00
--- /dev/null
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.Map;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.FusedPipeline;
+import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+
+/** Tests for {@link ProcessBundleDescriptors}. */
+public class ProcessBundleDescriptorsTest implements Serializable {
+
+ /**
+ * Tests that a stateful executable stage wraps the key coder of a stateful transform in a
+ * LengthPrefixCoder.
+ */
+ @Test
+ public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
+ // Create a stateful stage with a non-standard key coder
+ Pipeline p = Pipeline.create();
+ Coder<Void> keycoder = VoidCoder.of();
+ assertThat(ModelCoderRegistrar.isKnownCoder(keycoder), is(false));
+ p.apply("impulse", Impulse.create())
+ .apply(
+ "create",
+ ParDo.of(
+ new DoFn<byte[], KV<Void, String>>() {
+ @ProcessElement
+ public void process(ProcessContext ctxt) {}
+ }))
+ .setCoder(KvCoder.of(keycoder, StringUtf8Coder.of()))
+ .apply(
+ "userState",
+ ParDo.of(
+ new DoFn<KV<Void, String>, KV<Void, String>>() {
+
+ @StateId("stateId")
+ private final StateSpec<BagState<String>> bufferState =
+ StateSpecs.bag(StringUtf8Coder.of());
+
+ @ProcessElement
+ public void processElement(
+ @Element KV<Void, String> element,
+ @StateId("stateId") BagState<String> state,
+ OutputReceiver<KV<Void, String>> r) {
+ for (String value : state.read()) {
+ r.output(KV.of(element.getKey(), value));
+ }
+ state.add(element.getValue());
+ }
+ }))
+ // Force the output to be materialized
+ .apply("gbk", GroupByKey.create());
+
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+ Optional<ExecutableStage> optionalStage =
+ Iterables.tryFind(
+ fused.getFusedStages(),
+ (ExecutableStage stage) ->
+ stage.getUserStates().stream()
+ .anyMatch(spec -> spec.localName().equals("stateId")));
+ checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+ ExecutableStage stage = optionalStage.get();
+
+ ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+ ProcessBundleDescriptors.fromExecutableStage(
+ "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+
+ BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
+ String inputColId = stage.getInputPCollection().getId();
+ String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+
+ Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
+ RunnerApi.Coder coder = codersMap.get(inputCoderId);
+ String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
+
+ assertThat(
+ codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+
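+ // The original components of the stage must remain untouched: the stage's key coder is
+ // still the Java-serialized (non-length-prefixed) coder.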
+ RunnerApi.Coder originalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
+ String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+ assertThat(
+ stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
+ is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+ }
+}
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 0925767..5d4c8f0 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -18,10 +18,10 @@
package org.apache.beam.runners.fnexecution.control;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -946,8 +946,6 @@
@Test
public void testExecutionWithTimer() throws Exception {
Pipeline p = Pipeline.create();
- final String timerId = "foo";
- final String timerId2 = "foo2";
p.apply("impulse", Impulse.create())
.apply(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some Runners do not properly support quiescence in the way that {@link TestStream}
+ * demands it.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.testing;
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.TestStream.Builder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Never;
import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
@@ -263,7 +277,7 @@
FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
PCollection<String> windowedValues =
p.apply(stream)
- .apply(Window.into(windows))
+ .apply(into(windows))
.apply(WithKeys.of(1))
.apply(GroupByKey.create())
.apply(Values.create())
@@ -387,6 +401,74 @@
}
@Test
+ @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+ public void testMultiStage() throws Exception {
+ TestStream<String> testStream =
+ TestStream.create(StringUtf8Coder.of())
+ .addElements("before") // before
+ .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+ .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+ .advanceWatermarkToInfinity(); // AFTER
+
+ PCollection<String> input = p.apply(testStream);
+
+ PCollection<String> grouped =
+ input
+ .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+ .apply(
+ MapElements.into(
+ TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+ .via(e -> KV.of(e, e)))
+ .apply(GroupByKey.create())
+ .apply(Keys.create())
+ .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+ .apply("Rewindow", Window.into(new GlobalWindows()));
+
+ PCollection<String> result =
+ PCollectionList.of(ImmutableList.of(input, grouped))
+ .apply(Flatten.pCollections())
+ .apply(
+ "Key",
+ MapElements.into(
+ TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+ .via(e -> KV.of("key", e)))
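+ // Stateful DoFn which buffers all values it sees for the key and emits them,
+ // comma-separated, when the event-time timer fires.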
+ .apply(
+ ParDo.of(
+ new DoFn<KV<String, String>, String>() {
+ @StateId("seen")
+ private final StateSpec<ValueState<String>> seenSpec =
+ StateSpecs.value(StringUtf8Coder.of());
+
+ @TimerId("emit")
+ private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+ @ProcessElement
+ public void process(
+ ProcessContext context,
+ @StateId("seen") ValueState<String> seenState,
+ @TimerId("emit") Timer emitTimer) {
+ String element = context.element().getValue();
+ if (seenState.read() == null) {
+ seenState.write(element);
+ } else {
+ seenState.write(seenState.read() + "," + element);
+ }
+ emitTimer.set(Instant.ofEpochSecond(100));
+ }
+
+ @OnTimer("emit")
+ public void onEmit(
+ OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+ context.output(seenState.read());
+ }
+ }));
+
+ PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+ p.run().waitUntilFinish();
+ }
+
+ @Test
@Category(UsesTestStreamWithProcessingTime.class)
public void testCoder() throws Exception {
TestStream<String> testStream =
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 24c6b87..992aa08 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -45,7 +45,10 @@
from apache_beam.runners.portability.portable_runner import PortableRunner
from apache_beam.runners.worker import worker_pool_main
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
from apache_beam.transforms import environments
+from apache_beam.transforms import userstate
class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -188,6 +191,46 @@
def test_metrics(self):
self.skipTest('Metrics not supported.')
+ def test_pardo_state_with_custom_key_coder(self):
+ """Tests that state requests work correctly when the key coder is an
+ SDK-specific coder, i.e. a non-standard coder. This is additionally
+ enforced by Java's ProcessBundleDescriptorsTest and by Flink's
+ ExecutableStageDoFnOperator, which detects an invalid encoding by checking
+ that the encoded key falls into the correct key group."""
+ index_state_spec = userstate.CombiningValueStateSpec('index', sum)
+
+ # Test params
+ # Ensure enough elements to reach all partitions
+ n = 200
+ duplicates = 1
+
+ split = n // (duplicates + 1)
+ inputs = [(i % split, str(i % split)) for i in range(0, n)]
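+ # Each key occurs (duplicates + 1) times, so state for a key is updated more than once.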
+
+ # Use a DoFn which has to use FastPrimitivesCoder because the type cannot
+ # be inferred
+ class Input(beam.DoFn):
+ def process(self, impulse):
+ for i in inputs:
+ yield i
+
+ class AddIndex(beam.DoFn):
+ def process(self, kv,
+ index=beam.DoFn.StateParam(index_state_spec)):
+ k, v = kv
+ index.add(1)
+ yield k, v, index.read()
+
+ expected = [(i % split, str(i % split), i // split + 1)
+ for i in range(0, n)]
+
+ with self.create_pipeline() as p:
+ assert_that(p
+ | beam.Impulse()
+ | beam.ParDo(Input())
+ | beam.ParDo(AddIndex()),
+ equal_to(expected))
+
# Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 610a1a8..9d9284c 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -31,6 +31,8 @@
from apache_beam import coders
from apache_beam import core
from apache_beam import pvalue
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import PTransform
from apache_beam.transforms import window
from apache_beam.transforms.window import TimestampedValue
@@ -66,6 +68,28 @@
# TODO(BEAM-5949): Needed for Python 2 compatibility.
return not self == other
+ @abstractmethod
+ def to_runner_api(self, element_coder):
+ raise NotImplementedError
+
+ @staticmethod
+ def from_runner_api(proto, element_coder):
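+ # The proto events carry millisecond values, while Timestamp/Duration use
+ # microseconds, hence the factor-of-1000 conversions below.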
+ if proto.HasField('element_event'):
+ return ElementEvent(
+ [TimestampedValue(
+ element_coder.decode(tv.encoded_element),
+ timestamp.Timestamp(micros=1000 * tv.timestamp))
+ for tv in proto.element_event.elements])
+ elif proto.HasField('watermark_event'):
+ return WatermarkEvent(timestamp.Timestamp(
+ micros=1000 * proto.watermark_event.new_watermark))
+ elif proto.HasField('processing_time_event'):
+ return ProcessingTimeEvent(timestamp.Duration(
+ micros=1000 * proto.processing_time_event.advance_duration))
+ else:
+ raise ValueError(
+ 'Unknown TestStream Event type: %s' % proto.WhichOneof('event'))
+
class ElementEvent(Event):
"""Element-producing test stream event."""
@@ -82,6 +106,15 @@
def __lt__(self, other):
return self.timestamped_values < other.timestamped_values
+ def to_runner_api(self, element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
+ elements=[
+ beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+ encoded_element=element_coder.encode(tv.value),
+ timestamp=tv.timestamp.micros // 1000)
+ for tv in self.timestamped_values]))
+
class WatermarkEvent(Event):
"""Watermark-advancing test stream event."""
@@ -98,6 +131,11 @@
def __lt__(self, other):
return self.new_watermark < other.new_watermark
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ watermark_event=beam_runner_api_pb2.TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=self.new_watermark.micros // 1000))
class ProcessingTimeEvent(Event):
"""Processing time-advancing test stream event."""
@@ -114,6 +152,12 @@
def __lt__(self, other):
return self.advance_by < other.advance_by
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=self.advance_by.micros // 1000))
+
class TestStream(PTransform):
"""Test stream that generates events on an unbounded PCollection of elements.
@@ -123,11 +167,12 @@
output.
"""
- def __init__(self, coder=coders.FastPrimitivesCoder()):
+ def __init__(self, coder=coders.FastPrimitivesCoder(), events=()):
+ super(TestStream, self).__init__()
assert coder is not None
self.coder = coder
self.current_watermark = timestamp.MIN_TIMESTAMP
- self.events = []
+ self.events = list(events)
def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
@@ -206,3 +251,19 @@
"""
self._add(ProcessingTimeEvent(advance_by))
return self
+
+ def to_runner_api_parameter(self, context):
+ return (
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload(
+ coder_id=context.coders.get_id(self.coder),
+ events=[e.to_runner_api(self.coder) for e in self.events]))
+
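+ # Inverse of to_runner_api_parameter: rebuilds the TestStream (coder and
+ # events) from its portable TestStreamPayload representation.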
+ @PTransform.register_urn(
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload)
+ def from_runner_api_parameter(payload, context):
+ coder = context.coders.get_by_id(payload.coder_id)
+ return TestStream(
+ coder=coder,
+ events=[Event.from_runner_api(e, coder) for e in payload.events])
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index dbc4bcd..22ecda3 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -20,6 +20,7 @@
from __future__ import absolute_import
import collections
+import json
import os.path
import pickle
import unittest
@@ -31,6 +32,7 @@
import yaml
import apache_beam as beam
+from apache_beam import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners import pipeline_context
@@ -502,7 +504,10 @@
while hasattr(cls, unique_name):
counter += 1
unique_name = 'test_%s_%d' % (name, counter)
- setattr(cls, unique_name, lambda self: self._run_log_test(spec))
+ test_method = lambda self: self._run_log_test(spec)
+ test_method.__name__ = unique_name
+ test_method.__test__ = True
+ setattr(cls, unique_name, test_method)
# We must prepend an underscore to this name so that the open-source unittest
# runner does not execute this method directly as a test.
@@ -606,24 +611,25 @@
window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
transcript, spec)
- def _windowed_value_info(self, windowed_value):
- # Currently some runners operate at the millisecond level, and some at the
- # microsecond level. Trigger transcript timestamps are expressed as
- # integral units of the finest granularity, whatever that may be.
- # In these tests we interpret them as integral seconds and then truncate
- # the results to integral seconds to allow for portability across
- # different sub-second resolutions.
- window, = windowed_value.windows
- return {
- 'window': [int(window.start), int(window.max_timestamp())],
- 'values': sorted(windowed_value.value),
- 'timestamp': int(windowed_value.timestamp),
- 'index': windowed_value.pane_info.index,
- 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
- 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
- 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
- 'final': windowed_value.pane_info.is_last,
- }
+
+def _windowed_value_info(windowed_value):
+ # Currently some runners operate at the millisecond level, and some at the
+ # microsecond level. Trigger transcript timestamps are expressed as
+ # integral units of the finest granularity, whatever that may be.
+ # In these tests we interpret them as integral seconds and then truncate
+ # the results to integral seconds to allow for portability across
+ # different sub-second resolutions.
+ window, = windowed_value.windows
+ return {
+ 'window': [int(window.start), int(window.max_timestamp())],
+ 'values': sorted(windowed_value.value),
+ 'timestamp': int(windowed_value.timestamp),
+ 'index': windowed_value.pane_info.index,
+ 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
+ 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
+ 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
+ 'final': windowed_value.pane_info.is_last,
+ }
class TriggerDriverTranscriptTest(TranscriptTest):
@@ -645,7 +651,7 @@
for timer_window, (name, time_domain, t_timestamp) in to_fire:
for wvalue in driver.process_timer(
timer_window, name, time_domain, t_timestamp, state):
- output.append(self._windowed_value_info(wvalue))
+ output.append(_windowed_value_info(wvalue))
to_fire = state.get_and_clear_timers(watermark)
for action, params in transcript:
@@ -661,7 +667,7 @@
WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
for t in params]
output = [
- self._windowed_value_info(wv)
+ _windowed_value_info(wv)
for wv in driver.process_elements(state, bundle, watermark)]
fire_timers()
@@ -690,7 +696,7 @@
self.assertEqual([], output, msg='Unexpected output: %s' % output)
-class TestStreamTranscriptTest(TranscriptTest):
+class BaseTestStreamTranscriptTest(TranscriptTest):
"""A suite of TestStream-based tests based on trigger transcript entries.
"""
@@ -702,14 +708,17 @@
if runner_name in spec.get('broken_on', ()):
self.skipTest('Known to be broken on %s' % runner_name)
- test_stream = TestStream()
+ # Elements are encoded as JSON strings to allow other languages to
+ # decode elements while executing the test stream.
+ # TODO(BEAM-8600): Eliminate these gymnastics.
+ test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
for action, params in transcript:
if action == 'expect':
- test_stream.add_elements([('expect', params)])
+ test_stream.add_elements([json.dumps(('expect', params))])
else:
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
if action == 'input':
- test_stream.add_elements([('input', e) for e in params])
+ test_stream.add_elements([json.dumps(('input', e)) for e in params])
elif action == 'watermark':
test_stream.advance_watermark_to(params)
elif action == 'clock':
@@ -718,7 +727,9 @@
pass # Requires inspection of implementation details.
else:
raise ValueError('Unexpected action: %s' % action)
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
+
+ read_test_stream = test_stream | beam.Map(json.loads)
class Check(beam.DoFn):
"""A StatefulDoFn that verifies outputs are produced as expected.
@@ -731,12 +742,40 @@
The key is ignored, but all items must be on the same key to share state.
"""
+ def __init__(self, allow_out_of_order=True):
+ # Some runners don't support cross-stage TestStream semantics.
+ self.allow_out_of_order = allow_out_of_order
+
def process(
- self, element, seen=beam.DoFn.StateParam(
+ self,
+ element,
+ seen=beam.DoFn.StateParam(
beam.transforms.userstate.BagStateSpec(
'seen',
+ beam.coders.FastPrimitivesCoder())),
+ expected=beam.DoFn.StateParam(
+ beam.transforms.userstate.BagStateSpec(
+ 'expected',
beam.coders.FastPrimitivesCoder()))):
_, (action, data) = element
+
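+ # When out-of-order delivery is allowed, buffer 'expect' records while no
+ # actuals are pending and accumulate 'actual' records until a full batch has
+ # arrived; then dequeue the oldest expectation and fall through to the
+ # regular checking logic below.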
+ if self.allow_out_of_order:
+ if action == 'expect' and not list(seen.read()):
+ if data:
+ expected.add(data)
+ return
+ elif action == 'actual' and list(expected.read()):
+ seen.add(data)
+ all_data = list(seen.read())
+ all_expected = list(expected.read())
+ if len(all_data) == len(all_expected[0]):
+ expected.clear()
+ for expect in all_expected[1:]:
+ expected.add(expect)
+ action, data = 'expect', all_expected[0]
+ else:
+ return
+
if action == 'actual':
seen.add(data)
@@ -768,12 +807,14 @@
else:
raise ValueError('Unexpected action: %s' % action)
- with TestPipeline(options=PipelineOptions(streaming=True)) as p:
+ with TestPipeline() as p:
+ # TODO(BEAM-8601): Pass this during pipeline construction.
+ p.options.view_as(StandardOptions).streaming = True
# Split the test stream into a branch of to-be-processed elements, and
# a branch of expected results.
inputs, expected = (
p
- | test_stream
+ | read_test_stream
| beam.MapTuple(
lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
).with_outputs('input', 'expect'))
@@ -794,7 +835,7 @@
t=beam.DoFn.TimestampParam,
p=beam.DoFn.PaneInfoParam: (
k,
- self._windowed_value_info(WindowedValue(
+ _windowed_value_info(WindowedValue(
vs, windows=[window], timestamp=t, pane_info=p))))
# Place outputs back into the global window to allow flattening
# and share a single state in Check.
@@ -805,7 +846,17 @@
tagged_outputs = (
outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
# pylint: disable=expression-not-assigned
- (tagged_expected, tagged_outputs) | beam.Flatten() | beam.ParDo(Check())
+ ([tagged_expected, tagged_outputs]
+ | beam.Flatten()
+ | beam.ParDo(Check(self.allow_out_of_order)))
+
+
+class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = False
+
+
+class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = True
TRANSCRIPT_TEST_FILE = os.path.join(
@@ -814,6 +865,7 @@
if os.path.exists(TRANSCRIPT_TEST_FILE):
TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+ WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
if __name__ == '__main__':
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index f04d28d..3cb4362 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -79,3 +79,22 @@
task flinkValidatesRunner() {
dependsOn 'flinkCompatibilityMatrixLoopback'
}
+
+// TODO(BEAM-8598): Enable on pre-commit.
+task flinkTriggerTranscript() {
+ dependsOn 'setupVirtualenv'
+ dependsOn ':runners:flink:1.9:job-server:shadowJar'
+ doLast {
+ exec {
+ executable 'sh'
+ args '-c', """
+ . ${envdir}/bin/activate \\
+ && cd ${pythonRootDir} \\
+ && pip install -e .[test] \\
+ && python setup.py nosetests \\
+ --tests apache_beam.transforms.trigger_test:WeakTestStreamTranscriptTest \\
+ --test-pipeline-options='--runner=FlinkRunner --environment_type=LOOPBACK --flink_job_server_jar=${project(":runners:flink:1.9:job-server:").shadowJar.archivePath}'
+ """
+ }
+ }
+}