Merge pull request #10076 from ibzib/java-default-region
[BEAM-8628] use mock GcsUtil in testDefaultGcpTempLocationDoesNotExist
diff --git a/.test-infra/jenkins/job_Dependency_Check.groovy b/.test-infra/jenkins/job_Dependency_Check.groovy
index ac66881..dddd2f7 100644
--- a/.test-infra/jenkins/job_Dependency_Check.groovy
+++ b/.test-infra/jenkins/job_Dependency_Check.groovy
@@ -38,7 +38,7 @@
steps {
gradle {
rootBuildScriptDir(commonJobProperties.checkoutDir)
- tasks(':runBeamDependencyCheck')
+ tasks('runBeamDependencyCheck')
commonJobProperties.setGradleSwitches(delegate)
switches('-Drevision=release')
}
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 3ea3643..557f45f 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -1905,6 +1905,7 @@
mustRunAfter = [
':runners:flink:1.9:job-server-container:docker',
':runners:flink:1.9:job-server:shadowJar',
+ ':runners:spark:job-server:shadowJar',
':sdks:python:container:py2:docker',
':sdks:python:container:py35:docker',
':sdks:python:container:py36:docker',
@@ -1958,6 +1959,7 @@
addPortableWordCountTask(true, "PortableRunner")
addPortableWordCountTask(false, "FlinkRunner")
addPortableWordCountTask(true, "FlinkRunner")
+ addPortableWordCountTask(false, "SparkRunner")
}
}
}
diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ed2f013..0ddc48e 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -42,6 +42,7 @@
import "endpoints.proto";
import "google/protobuf/descriptor.proto";
import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
import "google/protobuf/wrappers.proto";
import "metrics.proto";
@@ -203,13 +204,21 @@
}
// An Application should be scheduled for execution after a delay.
+// Either an absolute timestamp or a relative timestamp can represent a
+// scheduled execution time.
message DelayedBundleApplication {
// Recommended time at which the application should be scheduled to execute
// by the runner. Times in the past may be scheduled to execute immediately.
+ // TODO(BEAM-8536): Migrate usage of absolute time to requested_time_delay.
google.protobuf.Timestamp requested_execution_time = 1;
// (Required) The application that should be scheduled.
BundleApplication application = 2;
+
+ // Recommended time delay after which the application should be scheduled to
+ // execute by the runner. A time delay of zero may be scheduled to execute
+ // immediately. The time delay is expressed in microseconds.
+ google.protobuf.Duration requested_time_delay = 3;
}
// A request to process a given bundle.
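For context, a runner consuming DelayedBundleApplication might prefer the new relative delay and fall back to the absolute timestamp until BEAM-8536 removes it. A minimal sketch, assuming unvendored protobuf types for brevity (Beam vendors them in practice):

    import java.time.Duration;
    import java.time.Instant;
    import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;

    /** Computes when to resume the application, preferring the relative delay. */
    static Instant resumeTime(DelayedBundleApplication app, Instant now) {
      if (app.hasRequestedTimeDelay()) {
        com.google.protobuf.Duration d = app.getRequestedTimeDelay();
        return now.plus(Duration.ofSeconds(d.getSeconds(), d.getNanos()));
      }
      // Fall back to the absolute timestamp; times in the past run immediately.
      com.google.protobuf.Timestamp t = app.getRequestedExecutionTime();
      Instant requested = Instant.ofEpochSecond(t.getSeconds(), t.getNanos());
      return requested.isBefore(now) ? now : requested;
    }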
diff --git a/model/job-management/src/main/proto/beam_job_api.proto b/model/job-management/src/main/proto/beam_job_api.proto
index e9f0eb9..d297d3b 100644
--- a/model/job-management/src/main/proto/beam_job_api.proto
+++ b/model/job-management/src/main/proto/beam_job_api.proto
@@ -213,17 +213,40 @@
// without needing to pass through STARTING.
message JobState {
enum Enum {
+ // The job state reported by a runner cannot be interpreted by the SDK.
UNSPECIFIED = 0;
+
+ // The job has not yet started.
STOPPED = 1;
+
+ // The job is currently running.
RUNNING = 2;
+
+ // The job has successfully completed. (terminal)
DONE = 3;
+
+ // The job has failed. (terminal)
FAILED = 4;
+
+ // The job has been explicitly cancelled. (terminal)
CANCELLED = 5;
+
+ // The job has been updated. (terminal)
UPDATED = 6;
+
+ // The job is draining its data. (optional)
DRAINING = 7;
+
+ // The job has completed draining its data. (terminal)
DRAINED = 8;
+
+ // The job is starting up.
STARTING = 9;
+
+ // The job is cancelling. (optional)
CANCELLING = 10;
+
+ // The job is in the process of being updated. (optional)
UPDATING = 11;
}
}
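For context, the terminal markers documented above let a client decide when to stop polling for job state; a minimal sketch using the generated JobApi classes:

    import org.apache.beam.model.jobmanagement.v1.JobApi.JobState;

    /** Returns true for the states documented as terminal above. */
    static boolean isTerminal(JobState.Enum state) {
      switch (state) {
        case DONE:
        case FAILED:
        case CANCELLED:
        case UPDATED:
        case DRAINED:
          return true;
        default:
          return false;
      }
    }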
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
index 1a6fee4..cecbee9 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java
@@ -50,11 +50,20 @@
PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
GreedyPCollectionFusers::canFuseParDo)
.put(
+ PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN,
+ GreedyPCollectionFusers::canFuseParDo)
+ .put(
+ PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
+ GreedyPCollectionFusers::cannotFuse)
+ .put(
+ PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN,
+ GreedyPCollectionFusers::cannotFuse)
+ .put(
PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
GreedyPCollectionFusers::canFuseParDo)
.put(
PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN,
- GreedyPCollectionFusers::canFuseParDo)
+ GreedyPCollectionFusers::cannotFuse)
.put(
PTransformTranslation.COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN,
GreedyPCollectionFusers::canFuseCompatibleEnvironment)
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
index 5ea867e..cd28cbf 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ProtoOverrides.java
@@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
+import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.ComponentsOrBuilder;
import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
@@ -32,10 +33,9 @@
/**
* A way to apply a Proto-based {@link PTransformOverride}.
*
- * <p>This should generally be used to replace runner-executed transforms with runner-executed
- * composites and simpler runner-executed primitives. It is generically less powerful than the
- * native {@link org.apache.beam.sdk.Pipeline#replaceAll(List)} and more error-prone, so should only
- * be used for relatively simple replacements.
+ * <p>This should generally be used by runners to replace transforms within graphs. SDK construction
+ * code should rely on the more powerful and native {@link
+ * org.apache.beam.sdk.Pipeline#replaceAll(List)}.
*/
@Experimental
public class ProtoOverrides {
@@ -51,6 +51,10 @@
if (pt.getValue().getSpec() != null && urn.equals(pt.getValue().getSpec().getUrn())) {
MessageWithComponents updated =
compositeBuilder.getReplacement(pt.getKey(), originalPipeline.getComponents());
+ if (updated == null) {
+ continue;
+ }
+
checkArgument(
updated.getPtransform().getOutputsMap().equals(pt.getValue().getOutputsMap()),
"A %s must produce all of the outputs of the original %s",
@@ -66,8 +70,8 @@
}
/**
- * Remove all subtransforms of the provided transform recursively.A {@link PTransform} can be the
- * subtransform of only one enclosing transform.
+ * Remove all sub-transforms of the provided transform recursively. A {@link PTransform} can be
+ * the sub-transform of only one enclosing transform.
*/
private static void removeSubtransforms(PTransform pt, Components.Builder target) {
for (String subtransformId : pt.getSubtransformsList()) {
@@ -87,14 +91,16 @@
/**
* Returns the updated composite structure for the provided {@link PTransform}.
*
- * <p>The returned {@link MessageWithComponents} must contain a single {@link PTransform}. The
- * result {@link Components} will be merged into the existing components, and the result {@link
- * PTransform} will be set as a replacement of the original {@link PTransform}. Notably, this
- * does not require that the {@code existingComponents} are present in the returned {@link
+ * <p>If the return is null, then no replacement is performed, otherwise the returned {@link
+ * MessageWithComponents} must contain a single {@link PTransform}. The result {@link
+ * Components} will be merged into the existing components, and the result {@link PTransform}
+ * will be set as a replacement of the original {@link PTransform}. Notably, this does not
+ * require that the {@code existingComponents} are present in the returned {@link
* MessageWithComponents}.
*
* <p>Introduced components must not collide with any components in the existing components.
*/
+ @Nullable
MessageWithComponents getReplacement(
String transformId, ComponentsOrBuilder existingComponents);
}
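To illustrate the new nullable contract: a TransformReplacement can now opt out by returning null, in which case updateTransform leaves the original transform untouched. A sketch with a hypothetical URN:

    TransformReplacement skipUnlessMatching =
        (transformId, existingComponents) -> {
          PTransform original = existingComponents.getTransformsOrThrow(transformId);
          if (!"example:custom:urn".equals(original.getSpec().getUrn())) {
            return null; // No replacement is performed for this transform.
          }
          // Keep the original outputs, as required by the replacement contract.
          return MessageWithComponents.newBuilder()
              .setPtransform(original.toBuilder().clearSubtransforms())
              .build();
        };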
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
index 382f68a..4ed19da 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java
@@ -28,10 +28,12 @@
import static org.apache.beam.runners.core.construction.PTransformTranslation.MAP_WINDOWS_TRANSFORM_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.READ_TRANSFORM_URN;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN;
+import static org.apache.beam.runners.core.construction.PTransformTranslation.SPLITTABLE_SPLIT_RESTRICTION_URN;
import static org.apache.beam.runners.core.construction.PTransformTranslation.TEST_STREAM_TRANSFORM_URN;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
@@ -174,6 +176,8 @@
COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN,
COMBINE_PER_KEY_MERGE_ACCUMULATORS_TRANSFORM_URN,
COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN,
+ SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
+ SPLITTABLE_SPLIT_RESTRICTION_URN,
SPLITTABLE_PROCESS_KEYED_URN,
SPLITTABLE_PROCESS_ELEMENTS_URN,
SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java
new file mode 100644
index 0000000..77f0211
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpander.java
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction.graph;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.function.Predicate;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ComponentsOrBuilder;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.model.pipeline.v1.RunnerApi.MessageWithComponents;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides.TransformReplacement;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+
+/**
+ * A set of transform replacements for expanding a splittable ParDo into various sub components.
+ *
+ * <p>Further details about the expansion can be found at <a
+ * href="https://github.com/apache/beam/blob/cb15994d5228f729dda922419b08520c8be8804e/model/pipeline/src/main/proto/beam_runner_api.proto#L279">
+ * beam_runner_api.proto</a>.
+ */
+public class SplittableParDoExpander {
+
+ /**
+ * Returns a transform replacement which expands a splittable ParDo from:
+ *
+ * <pre>{@code
+ * sideInputA ---------\
+ * sideInputB ---------V
+ * mainInput ---> SplittableParDo --> outputA
+ * \-> outputB
+ * }</pre>
+ *
+ * into:
+ *
+ * <pre>{@code
+ * sideInputA ---------\----------------------\---------------------------\
+ * sideInputB ---------V----------------------V---------------------------V
+ * mainInput ---> PairWithRestriction --> SplitAndSize --> ProcessSizedElementsAndRestrictions --> outputA
+ * \-> outputB
+ * }</pre>
+ *
+ * <p>Specifically, this transform ensures that initial splitting is performed and that the sizing
+ * information is available to the runner if it chooses to inspect it.
+ */
+ public static TransformReplacement createSizedReplacement() {
+ return SizedReplacement.INSTANCE;
+ }
+
+ /** See {@link #createSizedReplacement()} for details. */
+ private static class SizedReplacement implements TransformReplacement {
+
+ private static final SizedReplacement INSTANCE = new SizedReplacement();
+
+ @Override
+ public MessageWithComponents getReplacement(
+ String transformId, ComponentsOrBuilder existingComponents) {
+ try {
+ MessageWithComponents.Builder rval = MessageWithComponents.newBuilder();
+
+ PTransform splittableParDo = existingComponents.getTransformsOrThrow(transformId);
+ ParDoPayload payload = ParDoPayload.parseFrom(splittableParDo.getSpec().getPayload());
+ // Only perform the expansion if this is a splittable DoFn.
+ if (payload.getRestrictionCoderId() == null || payload.getRestrictionCoderId().isEmpty()) {
+ return null;
+ }
+
+ String mainInputName = ParDoTranslation.getMainInputName(splittableParDo);
+ String mainInputPCollectionId = splittableParDo.getInputsOrThrow(mainInputName);
+ PCollection mainInputPCollection =
+ existingComponents.getPcollectionsOrThrow(mainInputPCollectionId);
+ Map<String, String> sideInputs =
+ Maps.filterKeys(
+ splittableParDo.getInputsMap(), input -> payload.containsSideInputs(input));
+
+ String pairWithRestrictionOutCoderId =
+ generateUniqueId(
+ mainInputPCollection.getCoderId() + "/PairWithRestriction",
+ existingComponents::containsCoders);
+ rval.getComponentsBuilder()
+ .putCoders(
+ pairWithRestrictionOutCoderId,
+ ModelCoders.kvCoder(
+ mainInputPCollection.getCoderId(), payload.getRestrictionCoderId()));
+
+ String pairWithRestrictionOutId =
+ generateUniqueId(
+ mainInputPCollectionId + "/PairWithRestriction",
+ existingComponents::containsPcollections);
+ rval.getComponentsBuilder()
+ .putPcollections(
+ pairWithRestrictionOutId,
+ PCollection.newBuilder()
+ .setCoderId(pairWithRestrictionOutCoderId)
+ .setIsBounded(mainInputPCollection.getIsBounded())
+ .setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
+ .setUniqueName(
+ generateUniquePCollectionName(
+ mainInputPCollection.getUniqueName() + "/PairWithRestriction",
+ existingComponents))
+ .build());
+
+ String splitAndSizeOutCoderId =
+ generateUniqueId(
+ mainInputPCollection.getCoderId() + "/SplitAndSize",
+ existingComponents::containsCoders);
+ rval.getComponentsBuilder()
+ .putCoders(
+ splitAndSizeOutCoderId,
+ ModelCoders.kvCoder(
+ pairWithRestrictionOutCoderId, getOrAddDoubleCoder(existingComponents, rval)));
+
+ String splitAndSizeOutId =
+ generateUniqueId(
+ mainInputPCollectionId + "/SplitAndSize", existingComponents::containsPcollections);
+ rval.getComponentsBuilder()
+ .putPcollections(
+ splitAndSizeOutId,
+ PCollection.newBuilder()
+ .setCoderId(splitAndSizeOutCoderId)
+ .setIsBounded(mainInputPCollection.getIsBounded())
+ .setWindowingStrategyId(mainInputPCollection.getWindowingStrategyId())
+ .setUniqueName(
+ generateUniquePCollectionName(
+ mainInputPCollection.getUniqueName() + "/SplitAndSize",
+ existingComponents))
+ .build());
+
+ String pairWithRestrictionId =
+ generateUniqueId(
+ transformId + "/PairWithRestriction", existingComponents::containsTransforms);
+ {
+ PTransform.Builder pairWithRestriction = PTransform.newBuilder();
+ pairWithRestriction.putAllInputs(splittableParDo.getInputsMap());
+ pairWithRestriction.putOutputs("out", pairWithRestrictionOutId);
+ pairWithRestriction.setUniqueName(
+ generateUniquePCollectionName(
+ splittableParDo.getUniqueName() + "/PairWithRestriction", existingComponents));
+ pairWithRestriction.setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+ .setPayload(splittableParDo.getSpec().getPayload()));
+ rval.getComponentsBuilder()
+ .putTransforms(pairWithRestrictionId, pairWithRestriction.build());
+ }
+
+ String splitAndSizeId =
+ generateUniqueId(transformId + "/SplitAndSize", existingComponents::containsTransforms);
+ {
+ PTransform.Builder splitAndSize = PTransform.newBuilder();
+ splitAndSize.putInputs(mainInputName, pairWithRestrictionOutId);
+ splitAndSize.putAllInputs(sideInputs);
+ splitAndSize.putOutputs("out", splitAndSizeOutId);
+ splitAndSize.setUniqueName(
+ generateUniquePCollectionName(
+ splittableParDo.getUniqueName() + "/SplitAndSize", existingComponents));
+ splitAndSize.setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+ .setPayload(splittableParDo.getSpec().getPayload()));
+ rval.getComponentsBuilder().putTransforms(splitAndSizeId, splitAndSize.build());
+ }
+
+ String processSizedElementsAndRestrictionsId =
+ generateUniqueId(
+ transformId + "/ProcessSizedElementsAndRestrictions",
+ existingComponents::containsTransforms);
+ {
+ PTransform.Builder processSizedElementsAndRestrictions = PTransform.newBuilder();
+ processSizedElementsAndRestrictions.putInputs(mainInputName, splitAndSizeOutId);
+ processSizedElementsAndRestrictions.putAllInputs(sideInputs);
+ processSizedElementsAndRestrictions.putAllOutputs(splittableParDo.getOutputsMap());
+ processSizedElementsAndRestrictions.setUniqueName(
+ generateUniquePCollectionName(
+ splittableParDo.getUniqueName() + "/ProcessSizedElementsAndRestrictions",
+ existingComponents));
+ processSizedElementsAndRestrictions.setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(
+ PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+ .setPayload(splittableParDo.getSpec().getPayload()));
+ rval.getComponentsBuilder()
+ .putTransforms(
+ processSizedElementsAndRestrictionsId,
+ processSizedElementsAndRestrictions.build());
+ }
+
+ PTransform.Builder newCompositeRoot =
+ splittableParDo
+ .toBuilder()
+ // Clear the original splittable ParDo spec and add all the new transforms as
+ // children.
+ .clearSpec()
+ .addAllSubtransforms(
+ Arrays.asList(
+ pairWithRestrictionId,
+ splitAndSizeId,
+ processSizedElementsAndRestrictionsId));
+ rval.setPtransform(newCompositeRoot);
+
+ return rval.build();
+ } catch (IOException e) {
+ throw new RuntimeException("Unable to perform expansion for transform " + transformId, e);
+ }
+ }
+ }
+
+ private static String getOrAddDoubleCoder(
+ ComponentsOrBuilder existingComponents, MessageWithComponents.Builder out) {
+ for (Map.Entry<String, Coder> coder : existingComponents.getCodersMap().entrySet()) {
+ if (ModelCoders.DOUBLE_CODER_URN.equals(coder.getValue().getSpec().getUrn())) {
+ return coder.getKey();
+ }
+ }
+ String doubleCoderId = generateUniqueId("DoubleCoder", existingComponents::containsCoders);
+ out.getComponentsBuilder()
+ .putCoders(
+ doubleCoderId,
+ Coder.newBuilder()
+ .setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.DOUBLE_CODER_URN))
+ .build());
+ return doubleCoderId;
+ }
+
+ /**
+ * Returns a PCollection name that uses the supplied prefix that does not exist in {@code
+ * existingComponents}.
+ */
+ private static String generateUniquePCollectionName(
+ String prefix, ComponentsOrBuilder existingComponents) {
+ return generateUniqueId(
+ prefix,
+ input -> {
+ for (PCollection pc : existingComponents.getPcollectionsMap().values()) {
+ if (input.equals(pc.getUniqueName())) {
+ return true;
+ }
+ }
+ return false;
+ });
+ }
+
+ /** Generates a unique id given a prefix and a predicate to compare if the id is already used. */
+ private static String generateUniqueId(String prefix, Predicate<String> isExistingId) {
+ int i = 0;
+ while (isExistingId.test(prefix + i)) {
+ i += 1;
+ }
+ return prefix + i;
+ }
+}
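The intended usage, mirrored by the FlinkPipelineRunner change and the new test later in this patch:

    RunnerApi.Pipeline expanded =
        ProtoOverrides.updateTransform(
            PTransformTranslation.PAR_DO_TRANSFORM_URN,
            pipelineProto,
            SplittableParDoExpander.createSizedReplacement());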
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
index e149e66..7a8a51b 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PipelineTranslationTest.java
@@ -96,7 +96,7 @@
Window.<Long>into(FixedWindows.of(Duration.standardMinutes(7)))
.triggering(
AfterWatermark.pastEndOfWindow()
- .withEarlyFirings(AfterPane.elementCountAtLeast(19)))
+ .withLateFirings(AfterPane.elementCountAtLeast(19)))
.accumulatingFiredPanes()
.withAllowedLateness(Duration.standardMinutes(3L)));
final WindowingStrategy<?, ?> windowedStrategy = windowed.getWindowingStrategy();
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java
new file mode 100644
index 0000000..5a8d125
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/SplittableParDoExpanderTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.core.construction.graph;
+
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume;
+import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop;
+import static org.junit.Assert.assertEquals;
+
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.range.OffsetRange;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class SplittableParDoExpanderTest {
+
+ @UnboundedPerElement
+ static class PairStringWithIndexToLengthBase extends DoFn<String, KV<String, Integer>> {
+ @ProcessElement
+ public ProcessContinuation process(
+ ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+ for (long i = tracker.currentRestriction().getFrom(), numIterations = 0;
+ tracker.tryClaim(i);
+ ++i, ++numIterations) {
+ c.output(KV.of(c.element(), (int) i));
+ if (numIterations % 3 == 0) {
+ return resume();
+ }
+ }
+ return stop();
+ }
+
+ @GetInitialRestriction
+ public OffsetRange getInitialRange(String element) {
+ return new OffsetRange(0, element.length());
+ }
+
+ @SplitRestriction
+ public void splitRange(
+ String element, OffsetRange range, OutputReceiver<OffsetRange> receiver) {
+ receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
+ receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
+ }
+ }
+
+ @Test
+ public void testSizedReplacement() {
+ Pipeline p = Pipeline.create();
+ p.apply(Create.of("1", "2", "3"))
+ .apply("TestSDF", ParDo.of(new PairStringWithIndexToLengthBase()));
+
+ RunnerApi.Pipeline proto = PipelineTranslation.toProto(p);
+ String transformName =
+ Iterables.getOnlyElement(
+ Maps.filterValues(
+ proto.getComponents().getTransformsMap(),
+ (RunnerApi.PTransform transform) ->
+ transform
+ .getUniqueName()
+ .contains(PairStringWithIndexToLengthBase.class.getSimpleName()))
+ .keySet());
+
+ RunnerApi.Pipeline updatedProto =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ proto,
+ SplittableParDoExpander.createSizedReplacement());
+ RunnerApi.PTransform newComposite =
+ updatedProto.getComponents().getTransformsOrThrow(transformName);
+ assertEquals(FunctionSpec.getDefaultInstance(), newComposite.getSpec());
+ assertEquals(3, newComposite.getSubtransformsCount());
+ assertEquals(
+ PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN,
+ updatedProto
+ .getComponents()
+ .getTransformsOrThrow(newComposite.getSubtransforms(0))
+ .getSpec()
+ .getUrn());
+ assertEquals(
+ PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN,
+ updatedProto
+ .getComponents()
+ .getTransformsOrThrow(newComposite.getSubtransforms(1))
+ .getSpec()
+ .getUrn());
+ assertEquals(
+ PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN,
+ updatedProto
+ .getComponents()
+ .getTransformsOrThrow(newComposite.getSubtransforms(2))
+ .getSpec()
+ .getUrn());
+ }
+}
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
if (config.streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' // BEAM-8598
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
} else {
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
index 212c75e..33d2c76 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkPipelineRunner.java
@@ -26,10 +26,13 @@
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.core.construction.graph.PipelineTrimmer;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
import org.apache.beam.runners.core.metrics.MetricsPusher;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineJarUtils;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
@@ -87,8 +90,16 @@
throws Exception {
LOG.info("Translating pipeline to Flink program.");
+ // Expand any splittable ParDos within the graph to enable sizing and splitting of bundles.
+ Pipeline pipelineWithSdfExpanded =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ pipeline,
+ SplittableParDoExpander.createSizedReplacement());
+
// Don't let the fuser fuse any subcomponents of native transforms.
- Pipeline trimmedPipeline = PipelineTrimmer.trim(pipeline, translator.knownUrns());
+ Pipeline trimmedPipeline =
+ PipelineTrimmer.trim(pipelineWithSdfExpanded, translator.knownUrns());
// Fused pipeline proto.
// TODO: Consider supporting partially-fused graphs.
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index 915acfb..eed52e9 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -94,6 +94,8 @@
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.watermark.Watermark;
@@ -246,7 +248,8 @@
() -> UUID.randomUUID().toString(),
keyedStateInternals,
getKeyedStateBackend(),
- stateBackendLock));
+ stateBackendLock,
+ keyCoder));
} else {
userStateRequestHandler = StateRequestHandler.unsupported();
}
@@ -263,7 +266,12 @@
private final StateInternals stateInternals;
private final KeyedStateBackend<ByteBuffer> keyedStateBackend;
+ /** Lock to hold whenever accessing the state backend. */
private final Lock stateBackendLock;
+ /** For debugging: The key coder used by the Runner. */
+ @Nullable private final Coder runnerKeyCoder;
+ /** For debugging: Same as keyedStateBackend but upcasted, to access key group meta info. */
+ @Nullable private final AbstractKeyedStateBackend<ByteBuffer> keyStateBackendWithKeyGroupInfo;
/** Holds the valid cache token for user state for this operator. */
private final ByteString cacheToken;
@@ -271,10 +279,20 @@
IdGenerator cacheTokenGenerator,
StateInternals stateInternals,
KeyedStateBackend<ByteBuffer> keyedStateBackend,
- Lock stateBackendLock) {
+ Lock stateBackendLock,
+ @Nullable Coder runnerKeyCoder) {
this.stateInternals = stateInternals;
this.keyedStateBackend = keyedStateBackend;
this.stateBackendLock = stateBackendLock;
+ if (keyedStateBackend instanceof AbstractKeyedStateBackend) {
+ // This will always succeed unless a custom state backend is used which does not extend
+ // AbstractKeyedStateBackend. This is unlikely, but we should still consider this case.
+ this.keyStateBackendWithKeyGroupInfo =
+ (AbstractKeyedStateBackend<ByteBuffer>) keyedStateBackend;
+ } else {
+ this.keyStateBackendWithKeyGroupInfo = null;
+ }
+ this.runnerKeyCoder = runnerKeyCoder;
this.cacheToken = ByteString.copyFrom(cacheTokenGenerator.getId().getBytes(Charsets.UTF_8));
}
@@ -368,6 +386,19 @@
// Key for state request is shipped encoded with NESTED context.
ByteBuffer encodedKey = FlinkKeyUtils.fromEncodedKey(key);
keyedStateBackend.setCurrentKey(encodedKey);
+ if (keyStateBackendWithKeyGroupInfo != null) {
+ int currentKeyGroupIndex = keyStateBackendWithKeyGroupInfo.getCurrentKeyGroupIndex();
+ KeyGroupRange keyGroupRange = keyStateBackendWithKeyGroupInfo.getKeyGroupRange();
+ Preconditions.checkState(
+ keyGroupRange.contains(currentKeyGroupIndex),
+ "The current key '%s' with key group index '%s' does not belong to the key group range '%s'. Runner keyCoder: %s. PTransform id: %s. User state id: %s.",
+ Arrays.toString(key.toByteArray()),
+ currentKeyGroupIndex,
+ keyGroupRange,
+ runnerKeyCoder,
+ pTransformId,
+ userStateId);
+ }
}
};
}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index 0d7c99f..0cbf6b7 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -659,7 +659,7 @@
ExecutableStageDoFnOperator.BagUserStateFactory<ByteString, Integer, GlobalWindow>
bagUserStateFactory =
new ExecutableStageDoFnOperator.BagUserStateFactory<>(
- cacheTokenGenerator, test, stateBackend, NoopLock.get());
+ cacheTokenGenerator, test, stateBackend, NoopLock.get(), null);
ByteString key1 = ByteString.copyFrom("key1", Charsets.UTF_8);
ByteString key2 = ByteString.copyFrom("key2", Charsets.UTF_8);
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 3206870..dd12c22 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -1453,7 +1453,8 @@
private static class ImpulseTranslator implements TransformTranslator<Impulse> {
@Override
public void translate(Impulse transform, TranslationContext context) {
- if (context.getPipelineOptions().isStreaming() && !context.isFnApi()) {
+ if (context.getPipelineOptions().isStreaming()
+ && (!context.isFnApi() || !context.isStreamingEngine())) {
StepTranslationContext stepContext = context.addStep(transform, "ParallelRead");
stepContext.addInput(PropertyNames.FORMAT, "pubsub");
stepContext.addInput(PropertyNames.PUBSUB_SUBSCRIPTION, "_starting_signal/");
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
index 5cfcb6e..71d49ac 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java
@@ -25,6 +25,7 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
@@ -52,6 +53,13 @@
return experiments != null && experiments.contains("beam_fn_api");
}
+ default boolean isStreamingEngine() {
+ List<String> experiments = getPipelineOptions().getExperiments();
+ return experiments != null
+ && experiments.contains(GcpOptions.STREAMING_ENGINE_EXPERIMENT)
+ && experiments.contains(GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
+ }
+
/** Returns the configured pipeline options. */
DataflowPipelineOptions getPipelineOptions();
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
index 7e298e7..dcf15f4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
@@ -56,7 +56,6 @@
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.slf4j.Logger;
@@ -253,7 +252,7 @@
}
// Check to see if processing the request failed.
- throwIfFailure(processResponse);
+ MoreFutures.get(processResponse);
waitForInboundTermination.awaitCompletion();
WindowedValue<KV<byte[], TargetWindowT>> sideInputWindow = outputValue.poll();
@@ -300,22 +299,10 @@
processBundleDescriptor.toBuilder().setId(descriptorId).build())
.build())
.build());
- throwIfFailure(response);
+ // Check whether the bundle descriptor was registered successfully.
+ MoreFutures.get(response);
processBundleDescriptorId = descriptorId;
}
return processBundleDescriptorId;
}
-
- private static InstructionResponse throwIfFailure(
- CompletionStage<InstructionResponse> responseFuture)
- throws ExecutionException, InterruptedException {
- InstructionResponse response = MoreFutures.get(responseFuture);
- if (!Strings.isNullOrEmpty(response.getError())) {
- throw new IllegalStateException(
- String.format(
- "Client failed to process %s with error [%s].",
- response.getInstructionId(), response.getError()));
- }
- return response;
- }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
index 73bbf4b..bf42c4d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
@@ -297,7 +297,6 @@
.setRegister(registerRequest)
.build();
registerFuture = instructionRequestHandler.handle(request);
- getRegisterResponse(registerFuture);
}
checkState(
@@ -315,7 +314,10 @@
deregisterStateHandler =
beamFnStateDelegator.registerForProcessBundleInstructionId(
getProcessBundleInstructionId(), this::delegateByStateKeyType);
- processBundleResponse = instructionRequestHandler.handle(processBundleRequest);
+ processBundleResponse =
+ getRegisterResponse(registerFuture)
+ .thenCompose(
+ registerResponse -> instructionRequestHandler.handle(processBundleRequest));
}
}
@@ -368,12 +370,8 @@
* elements consumed from the upstream read operation.
*
* <p>May be called at any time, including before start() and after finish().
- *
- * @throws InterruptedException
- * @throws ExecutionException
*/
- public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress()
- throws InterruptedException, ExecutionException {
+ public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress() {
// processBundleId may be reset if this bundle finishes asynchronously.
String processBundleId = this.processBundleId;
@@ -391,13 +389,7 @@
return instructionRequestHandler
.handle(processBundleRequest)
- .thenApply(
- response -> {
- if (!response.getError().isEmpty()) {
- throw new IllegalStateException(response.getError());
- }
- return response.getProcessBundleProgress();
- });
+ .thenApply(InstructionResponse::getProcessBundleProgress);
}
/** Returns the final metrics returned by the SDK harness when it completes the bundle. */
@@ -634,53 +626,36 @@
return true;
}
- private static CompletionStage<BeamFnApi.InstructionResponse> throwIfFailure(
+ private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
CompletionStage<InstructionResponse> responseFuture) {
return responseFuture.thenApply(
response -> {
- if (!response.getError().isEmpty()) {
- throw new IllegalStateException(
- String.format(
- "Client failed to process %s with error [%s].",
- response.getInstructionId(), response.getError()));
+ switch (response.getResponseCase()) {
+ case PROCESS_BUNDLE:
+ return response.getProcessBundle();
+ default:
+ throw new IllegalStateException(
+ String.format(
+ "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
+ TextFormat.printToString(response)));
}
- return response;
});
}
- private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
- CompletionStage<InstructionResponse> responseFuture) {
- return throwIfFailure(responseFuture)
- .thenApply(
- response -> {
- switch (response.getResponseCase()) {
- case PROCESS_BUNDLE:
- return response.getProcessBundle();
- default:
- throw new IllegalStateException(
- String.format(
- "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
- TextFormat.printToString(response)));
- }
- });
- }
-
private static CompletionStage<BeamFnApi.RegisterResponse> getRegisterResponse(
- CompletionStage<InstructionResponse> responseFuture)
- throws ExecutionException, InterruptedException {
- return throwIfFailure(responseFuture)
- .thenApply(
- response -> {
- switch (response.getResponseCase()) {
- case REGISTER:
- return response.getRegister();
- default:
- throw new IllegalStateException(
- String.format(
- "SDK harness returned wrong kind of response to RegisterRequest: %s",
- TextFormat.printToString(response)));
- }
- });
+ CompletionStage<InstructionResponse> responseFuture) {
+ return responseFuture.thenApply(
+ response -> {
+ switch (response.getResponseCase()) {
+ case REGISTER:
+ return response.getRegister();
+ default:
+ throw new IllegalStateException(
+ String.format(
+ "SDK harness returned wrong kind of response to RegisterRequest: %s",
+ TextFormat.printToString(response)));
+ }
+ });
}
private static void cancelIfNotNull(CompletionStage<?> future) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
index 89d4595..aff3f86 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/BeamFnMapTaskExecutorTest.java
@@ -141,9 +141,7 @@
return MoreFutures.supplyAsync(
() -> {
processBundleLatch.await();
- return responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build();
+ return responseFor(request).build();
});
case PROCESS_BUNDLE_PROGRESS:
progressSentLatch.countDown();
@@ -238,9 +236,7 @@
return MoreFutures.supplyAsync(
() -> {
processBundleLatch.await();
- return responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build();
+ return responseFor(request).build();
});
case PROCESS_BUNDLE_PROGRESS:
progressSentTwiceLatch.countDown();
@@ -623,6 +619,20 @@
}
private BeamFnApi.InstructionResponse.Builder responseFor(BeamFnApi.InstructionRequest request) {
- return BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+ BeamFnApi.InstructionResponse.Builder response =
+ BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+ if (request.hasRegister()) {
+ response.setRegister(BeamFnApi.RegisterResponse.getDefaultInstance());
+ } else if (request.hasProcessBundle()) {
+ response.setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
+ } else if (request.hasFinalizeBundle()) {
+ response.setFinalizeBundle(BeamFnApi.FinalizeBundleResponse.getDefaultInstance());
+ } else if (request.hasProcessBundleProgress()) {
+ response.setProcessBundleProgress(
+ BeamFnApi.ProcessBundleProgressResponse.getDefaultInstance());
+ } else if (request.hasProcessBundleSplit()) {
+ response.setProcessBundleSplit(BeamFnApi.ProcessBundleSplitResponse.getDefaultInstance());
+ }
+ return response;
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
index 321a236..eb3d21d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperationTest.java
@@ -192,15 +192,8 @@
requests.add(request);
switch (request.getRequestCase()) {
case REGISTER:
- return CompletableFuture.completedFuture(
- responseFor(request)
- .setRegister(BeamFnApi.RegisterResponse.getDefaultInstance())
- .build());
case PROCESS_BUNDLE:
- return CompletableFuture.completedFuture(
- responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build());
+ return CompletableFuture.completedFuture(responseFor(request).build());
default:
// block forever on other requests
return new CompletableFuture<>();
@@ -277,9 +270,7 @@
return MoreFutures.supplyAsync(
() -> {
processBundleLatch.await();
- return responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build();
+ return responseFor(request).build();
});
case PROCESS_BUNDLE_PROGRESS:
return CompletableFuture.completedFuture(
@@ -459,10 +450,7 @@
requests.add(request);
switch (request.getRequestCase()) {
case REGISTER:
- return CompletableFuture.completedFuture(
- InstructionResponse.newBuilder()
- .setInstructionId(request.getInstructionId())
- .build());
+ return CompletableFuture.completedFuture(responseFor(request).build());
case PROCESS_BUNDLE:
CompletableFuture<InstructionResponse> responseFuture =
new CompletableFuture<>();
@@ -470,12 +458,7 @@
() -> {
// Purposefully sleep simulating SDK harness doing work
Thread.sleep(100);
- responseFuture.complete(
- InstructionResponse.newBuilder()
- .setInstructionId(request.getInstructionId())
- .setProcessBundle(
- BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build());
+ responseFuture.complete(responseFor(request).build());
completeFuture(request, responseFuture);
return null;
});
@@ -591,9 +574,7 @@
MoreFutures.get(stateHandler.handle(clear));
assertNotNull(clearResponse);
- return responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build();
+ return responseFor(request).build();
});
default:
// block forever
@@ -685,9 +666,7 @@
encodeAndConcat(Arrays.asList("X", "Y", "Z"), StringUtf8Coder.of()),
getResponse.getGet().getData());
- return responseFor(request)
- .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
- .build();
+ return responseFor(request).build();
});
default:
// block forever on other request types
@@ -855,15 +834,26 @@
}
private InstructionResponse.Builder responseFor(BeamFnApi.InstructionRequest request) {
- return BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+ BeamFnApi.InstructionResponse.Builder response =
+ BeamFnApi.InstructionResponse.newBuilder().setInstructionId(request.getInstructionId());
+ if (request.hasRegister()) {
+ response.setRegister(BeamFnApi.RegisterResponse.getDefaultInstance());
+ } else if (request.hasProcessBundle()) {
+ response.setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
+ } else if (request.hasFinalizeBundle()) {
+ response.setFinalizeBundle(BeamFnApi.FinalizeBundleResponse.getDefaultInstance());
+ } else if (request.hasProcessBundleProgress()) {
+ response.setProcessBundleProgress(
+ BeamFnApi.ProcessBundleProgressResponse.getDefaultInstance());
+ } else if (request.hasProcessBundleSplit()) {
+ response.setProcessBundleSplit(BeamFnApi.ProcessBundleSplitResponse.getDefaultInstance());
+ }
+ return response;
}
private void completeFuture(
BeamFnApi.InstructionRequest request, CompletableFuture<InstructionResponse> response) {
- response.complete(
- BeamFnApi.InstructionResponse.newBuilder()
- .setInstructionId(request.getInstructionId())
- .build());
+ response.complete(responseFor(request).build());
}
@Test
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
new file mode 100644
index 0000000..1e5fa53
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+
+/**
+ * A handler which is invoked when the SDK returns {@link BeamFnApi.DelayedBundleApplication}s as
+ * part of the bundle completion.
+ *
+ * <p>These bundle applications must be resumed; otherwise, data loss will occur.
+ *
+ * <p>See <a href="https://s.apache.org/beam-breaking-fusion">breaking the fusion barrier</a> for
+ * further details.
+ */
+public interface BundleCheckpointHandler {
+ void onCheckpoint(BeamFnApi.ProcessBundleResponse response);
+}
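A minimal sketch of an implementation, assuming a hypothetical scheduleForResumption hook on the runner side:

    BundleCheckpointHandler checkpointHandler =
        response -> {
          for (BeamFnApi.DelayedBundleApplication residual : response.getResidualRootsList()) {
            scheduleForResumption(residual); // hypothetical runner hook
          }
        };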
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
new file mode 100644
index 0000000..849663b
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+
+/**
+ * A handler for the runner when a finalization request has been received.
+ *
+ * <p>The runner is responsible for finalizing the bundle when all output from the bundle has been
+ * durably persisted.
+ *
+ * <p>See <a href="https://s.apache.org/beam-finalizing-bundles">finalizing bundles</a> for further
+ * details.
+ */
+public interface BundleFinalizationHandler {
+ void requestsFinalization(ProcessBundleResponse response);
+}
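A minimal sketch of an implementation; awaitDurablePersistence, sendFinalizeBundleRequest, and bundleInstructionId are hypothetical runner-side hooks:

    BundleFinalizationHandler finalizationHandler =
        response -> {
          // Finalize only after all outputs from the bundle are durably persisted.
          awaitDurablePersistence();
          sendFinalizeBundleRequest(bundleInstructionId);
        };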
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
index b655732..8a9dc75 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
@@ -20,7 +20,10 @@
import java.util.concurrent.CompletionStage;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
-/** Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. */
+/**
+ * Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. Any
+ * error responses will be converted to exceptionally completed futures.
+ */
public interface InstructionRequestHandler extends AutoCloseable {
CompletionStage<BeamFnApi.InstructionResponse> handle(BeamFnApi.InstructionRequest request);
}
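With this contract, callers observe errors as exceptionally completed futures rather than via response.getError(); for example, assuming an SLF4J LOG:

    handler
        .handle(request)
        .whenComplete(
            (response, t) -> {
              if (t != null) {
                // Error responses surface here as exceptions.
                LOG.error("Instruction failed", t);
              }
            });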
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index b324cb1..a2129c4 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -57,6 +57,7 @@
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -135,6 +136,10 @@
forTimerSpecs(
dataEndpoint, stage, components, inputDestinationsBuilder, remoteOutputCodersBuilder);
+ if (bagUserStateSpecs.size() > 0 || timerSpecs.size() > 0) {
+ lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components);
+ }
+
// Copy data from components to ProcessBundleDescriptor.
ProcessBundleDescriptor.Builder bundleDescriptorBuilder =
ProcessBundleDescriptor.newBuilder().setId(id);
@@ -158,6 +163,29 @@
timerSpecs);
}
+ /**
+   * Patches the input coder of a stateful executable stage to ensure that the byte representation
+   * of a key used to partition the input element at the runner matches the key byte representation
+   * received for state requests and timers from the SDK harness. Stateful stages always have a
+   * KvCoder as input.
+ */
+ private static void lengthPrefixKeyCoder(
+ String inputColId, Components.Builder componentsBuilder) {
+ RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId);
+ RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId());
+ Preconditions.checkState(
+ ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()),
+ "Stateful executable stages must use a KV coder, but is: %s",
+ kvCoder.getSpec().getUrn());
+ String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId();
+ // Retain the original coder, but wrap in LengthPrefixCoder
+ String newKeyCoderId =
+ LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false);
+ // Replace old key coder with LengthPrefixCoder<old_key_coder>
+ kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build();
+ componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder);
+ }
+
private static Map<String, Coder<WindowedValue<?>>> addStageOutputs(
ApiServiceDescriptor dataEndpoint,
Collection<PCollectionNode> outputPCollections,
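For intuition on why both sides must agree on the wrapping, a standalone sketch using the SDK's LengthPrefixCoder and CoderUtils (called from a context that may throw CoderException):

    byte[] raw =
        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "key"); // {'k', 'e', 'y'}
    byte[] prefixed =
        CoderUtils.encodeToByteArray(
            LengthPrefixCoder.of(StringUtf8Coder.of()), "key"); // {0x03, 'k', 'e', 'y'}
    // The runner partitions on bytes it encodes itself, while the SDK harness reports keys
    // for state and timer requests in the length-prefixed form; wrapping the key coder on
    // the runner side keeps the two byte representations identical.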
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index b3f1c2e..86232dd 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -95,6 +95,9 @@
* // send all main input elements ...
* }
* }</pre>
+ *
+ * <p>An exception will be thrown during {@link #close()} if the bundle requests finalization or
+ * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
*/
public ActiveBundle newBundle(
Map<String, RemoteOutputReceiver<?>> outputReceivers,
@@ -122,6 +125,47 @@
* try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
* FnDataReceiver<InputT> inputReceiver =
* (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
+ * // send all main input elements ...
+ * }
+ * }</pre>
+ *
+ * <p>An exception will be thrown during {@link #close()} if the bundle requests finalization or
+ * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
+ */
+ public ActiveBundle newBundle(
+ Map<String, RemoteOutputReceiver<?>> outputReceivers,
+ StateRequestHandler stateRequestHandler,
+ BundleProgressHandler progressHandler) {
+ return newBundle(
+ outputReceivers,
+ stateRequestHandler,
+ progressHandler,
+ request -> {
+ throw new UnsupportedOperationException(
+ String.format(
+ "The %s does not have a registered bundle checkpoint handler.",
+ ActiveBundle.class.getSimpleName()));
+ },
+ request -> {
+ throw new UnsupportedOperationException(
+ String.format(
+ "The %s does not have a registered bundle finalization handler.",
+ ActiveBundle.class.getSimpleName()));
+ });
+ }
+
+ /**
+ * Start a new bundle for the given {@link BeamFnApi.ProcessBundleDescriptor} identifier.
+ *
+ * <p>The input channels for the returned {@link ActiveBundle} are derived from the instructions
+ * in the {@link BeamFnApi.ProcessBundleDescriptor}.
+ *
+ * <p>NOTE: It is important to {@link #close()} each bundle after all elements are emitted.
+ *
+ * <pre>{@code
+ * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
+ * FnDataReceiver<InputT> inputReceiver =
+ * (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
* // send all elements ...
* }
* }</pre>
@@ -129,18 +173,22 @@
public ActiveBundle newBundle(
Map<String, RemoteOutputReceiver<?>> outputReceivers,
StateRequestHandler stateRequestHandler,
- BundleProgressHandler progressHandler) {
+ BundleProgressHandler progressHandler,
+ BundleCheckpointHandler checkpointHandler,
+ BundleFinalizationHandler finalizationHandler) {
String bundleId = idGenerator.getId();
final CompletionStage<BeamFnApi.InstructionResponse> genericResponse =
- fnApiControlClient.handle(
- BeamFnApi.InstructionRequest.newBuilder()
- .setInstructionId(bundleId)
- .setProcessBundle(
- BeamFnApi.ProcessBundleRequest.newBuilder()
- .setProcessBundleDescriptorId(processBundleDescriptor.getId())
- .addAllCacheTokens(stateRequestHandler.getCacheTokens()))
- .build());
+ registrationFuture.thenCompose(
+ registration ->
+ fnApiControlClient.handle(
+ BeamFnApi.InstructionRequest.newBuilder()
+ .setInstructionId(bundleId)
+ .setProcessBundle(
+ BeamFnApi.ProcessBundleRequest.newBuilder()
+ .setProcessBundleDescriptorId(processBundleDescriptor.getId())
+ .addAllCacheTokens(stateRequestHandler.getCacheTokens()))
+ .build()));
LOG.debug(
"Sent {} with ID {} for {} with ID {}",
ProcessBundleRequest.class.getSimpleName(),
@@ -173,7 +221,9 @@
dataReceiversBuilder.build(),
outputClients,
stateDelegator.registerForProcessBundleInstructionId(bundleId, stateRequestHandler),
- progressHandler);
+ progressHandler,
+ checkpointHandler,
+ finalizationHandler);
}
private <OutputT> InboundDataClient attachReceiver(
@@ -191,6 +241,8 @@
private final Map<String, InboundDataClient> outputClients;
private final StateDelegator.Registration stateRegistration;
private final BundleProgressHandler progressHandler;
+ private final BundleCheckpointHandler checkpointHandler;
+ private final BundleFinalizationHandler finalizationHandler;
private ActiveBundle(
String bundleId,
@@ -198,13 +250,17 @@
Map<String, CloseableFnDataReceiver> inputReceivers,
Map<String, InboundDataClient> outputClients,
StateDelegator.Registration stateRegistration,
- BundleProgressHandler progressHandler) {
+ BundleProgressHandler progressHandler,
+ BundleCheckpointHandler checkpointHandler,
+ BundleFinalizationHandler finalizationHandler) {
this.bundleId = bundleId;
this.response = response;
this.inputReceivers = inputReceivers;
this.outputClients = outputClients;
this.stateRegistration = stateRegistration;
this.progressHandler = progressHandler;
+ this.checkpointHandler = checkpointHandler;
+ this.finalizationHandler = finalizationHandler;
}
/** Returns an id used to represent this bundle. */
@@ -254,13 +310,15 @@
BeamFnApi.ProcessBundleResponse completedResponse = MoreFutures.get(response);
progressHandler.onCompleted(completedResponse);
if (completedResponse.getResidualRootsCount() > 0) {
- throw new IllegalStateException(
- "TODO: [BEAM-2939] residual roots in process bundle response not yet supported.");
+ checkpointHandler.onCheckpoint(completedResponse);
+ }
+ if (completedResponse.getRequiresFinalization()) {
+ finalizationHandler.requestsFinalization(completedResponse);
}
} else {
// TODO: [BEAM-3962] Handle aborting the bundle being processed.
throw new IllegalStateException(
- "Processing bundle failed, " + "TODO: [BEAM-3962] abort bundle.");
+ "Processing bundle failed, TODO: [BEAM-3962] abort bundle.");
}
} catch (Exception e) {
if (exception == null) {
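A sketch of wiring both new handlers when starting a bundle; processor, outputReceivers, stateRequestHandler, and progressHandler are as in the javadoc examples above, while scheduleResiduals and finalizeAfterPersist are hypothetical runner helpers:

    try (ActiveBundle bundle =
        processor.newBundle(
            outputReceivers,
            stateRequestHandler,
            progressHandler,
            response -> scheduleResiduals(response.getResidualRootsList()),
            response -> finalizeAfterPersist(response))) {
      // send all main input elements ...
    }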
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
new file mode 100644
index 0000000..b558b00
--- /dev/null
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.Map;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.CoderTranslation;
+import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.FusedPipeline;
+import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+
+/** Tests for {@link ProcessBundleDescriptors}. */
+public class ProcessBundleDescriptorsTest implements Serializable {
+
+ /**
+ * Tests that a stateful Executable stage will wrap a key coder of a stateful transform in a
+ * LengthPrefixCoder.
+ */
+ @Test
+ public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
+ // Add another stateful stage with a non-standard key coder
+ Pipeline p = Pipeline.create();
+    Coder<Void> keyCoder = VoidCoder.of();
+    assertThat(ModelCoderRegistrar.isKnownCoder(keyCoder), is(false));
+ p.apply("impulse", Impulse.create())
+ .apply(
+ "create",
+ ParDo.of(
+ new DoFn<byte[], KV<Void, String>>() {
+ @ProcessElement
+ public void process(ProcessContext ctxt) {}
+ }))
+        .setCoder(KvCoder.of(keyCoder, StringUtf8Coder.of()))
+ .apply(
+ "userState",
+ ParDo.of(
+ new DoFn<KV<Void, String>, KV<Void, String>>() {
+
+ @StateId("stateId")
+ private final StateSpec<BagState<String>> bufferState =
+ StateSpecs.bag(StringUtf8Coder.of());
+
+ @ProcessElement
+ public void processElement(
+ @Element KV<Void, String> element,
+ @StateId("stateId") BagState<String> state,
+ OutputReceiver<KV<Void, String>> r) {
+ for (String value : state.read()) {
+ r.output(KV.of(element.getKey(), value));
+ }
+ state.add(element.getValue());
+ }
+ }))
+ // Force the output to be materialized
+ .apply("gbk", GroupByKey.create());
+
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+ Optional<ExecutableStage> optionalStage =
+ Iterables.tryFind(
+ fused.getFusedStages(),
+ (ExecutableStage stage) ->
+ stage.getUserStates().stream()
+ .anyMatch(spec -> spec.localName().equals("stateId")));
+ checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+ ExecutableStage stage = optionalStage.get();
+
+ ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
+ ProcessBundleDescriptors.fromExecutableStage(
+ "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());
+
+ BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
+ String inputColId = stage.getInputPCollection().getId();
+ String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();
+
+ Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
+ RunnerApi.Coder coder = codersMap.get(inputCoderId);
+ String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
+
+ assertThat(
+ codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
+
+    RunnerApi.Coder originalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
+    String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
+ assertThat(
+ stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
+ is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
+ }
+}
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 0925767..5d4c8f0 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -18,10 +18,10 @@
package org.apache.beam.runners.fnexecution.control;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -946,8 +946,6 @@
@Test
public void testExecutionWithTimer() throws Exception {
Pipeline p = Pipeline.create();
- final String timerId = "foo";
- final String timerId2 = "foo2";
p.apply("impulse", Impulse.create())
.apply(
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
index 5bc0d1c..089c8d1 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
@@ -31,6 +31,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
@@ -41,6 +42,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -159,9 +161,8 @@
@Test
public void testRegisterCachesBundleProcessors() throws Exception {
- CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(registerResponseFuture);
+ .thenReturn(createRegisterResponse());
ProcessBundleDescriptor descriptor1 =
ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
@@ -187,9 +188,8 @@
@Test
public void testRegisterWithStateRequiresStateDelegator() throws Exception {
- CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(registerResponseFuture);
+ .thenReturn(createRegisterResponse());
ProcessBundleDescriptor descriptor =
ProcessBundleDescriptor.newBuilder()
@@ -214,7 +214,7 @@
public void testNewBundleNoDataDoesNotCrash() throws Exception {
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -290,7 +290,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -343,7 +343,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -389,7 +389,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -439,7 +439,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -483,7 +483,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -540,7 +540,7 @@
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(new CompletableFuture<>())
+ .thenReturn(createRegisterResponse())
.thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder =
@@ -582,10 +582,9 @@
}
@Test
- public void verifyCacheTokensAreUsedInNewBundleRequest() {
- CompletableFuture<InstructionResponse> registerResponseFuture = new CompletableFuture<>();
+ public void verifyCacheTokensAreUsedInNewBundleRequest() throws InterruptedException {
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
- .thenReturn(registerResponseFuture);
+ .thenReturn(createRegisterResponse());
ProcessBundleDescriptor descriptor1 =
ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
@@ -626,6 +625,117 @@
assertThat(requests.get(1).getProcessBundle().getCacheTokensList(), is(cacheTokens));
}
+ @Test
+ public void testBundleCheckpointCallback() throws Exception {
+ InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+ CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+ CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+ when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+ .thenReturn(createRegisterResponse())
+ .thenReturn(processBundleResponseFuture);
+
+ FullWindowedValueCoder<String> coder =
+ FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+ BundleProcessor processor =
+ sdkHarnessClient.getProcessor(
+ descriptor,
+ Collections.singletonMap(
+ "inputPC",
+ RemoteInputDestination.of(
+ (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+ when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+ when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+ RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+ BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+ BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+ BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+ ProcessBundleResponse response =
+ ProcessBundleResponse.newBuilder()
+ .addResidualRoots(DelayedBundleApplication.getDefaultInstance())
+ .build();
+
+ try (ActiveBundle activeBundle =
+ processor.newBundle(
+ ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+ (request) -> {
+ throw new UnsupportedOperationException();
+ },
+ mockProgressHandler,
+ mockCheckpointHandler,
+ mockFinalizationHandler)) {
+ processBundleResponseFuture.complete(
+ InstructionResponse.newBuilder().setProcessBundle(response).build());
+ }
+
+ verify(mockProgressHandler).onCompleted(response);
+ verify(mockCheckpointHandler).onCheckpoint(response);
+ verifyZeroInteractions(mockFinalizationHandler);
+ }
+
+ @Test
+ public void testBundleFinalizationCallback() throws Exception {
+ InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+ CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+ CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+ when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+ .thenReturn(createRegisterResponse())
+ .thenReturn(processBundleResponseFuture);
+
+ FullWindowedValueCoder<String> coder =
+ FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+ BundleProcessor processor =
+ sdkHarnessClient.getProcessor(
+ descriptor,
+ Collections.singletonMap(
+ "inputPC",
+ RemoteInputDestination.of(
+ (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+ when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+ when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+ RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+ BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+ BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+ BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+ ProcessBundleResponse response =
+ ProcessBundleResponse.newBuilder().setRequiresFinalization(true).build();
+
+ try (ActiveBundle activeBundle =
+ processor.newBundle(
+ ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+ (request) -> {
+ throw new UnsupportedOperationException();
+ },
+ mockProgressHandler,
+ mockCheckpointHandler,
+ mockFinalizationHandler)) {
+ processBundleResponseFuture.complete(
+ InstructionResponse.newBuilder().setProcessBundle(response).build());
+ }
+
+ verify(mockProgressHandler).onCompleted(response);
+ verify(mockFinalizationHandler).requestsFinalization(response);
+ verifyZeroInteractions(mockCheckpointHandler);
+ }
+
+ private CompletableFuture<InstructionResponse> createRegisterResponse() {
+ return CompletableFuture.completedFuture(
+ InstructionResponse.newBuilder()
+ .setRegister(BeamFnApi.RegisterResponse.getDefaultInstance())
+ .build());
+ }
+
private static class TestFn extends DoFn<String, String> {
@ProcessElement
public void processElement(ProcessContext context) {
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
index 85bd576..4733ef5 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaPipelineRunner.java
@@ -19,7 +19,10 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
import org.apache.beam.runners.core.construction.renderer.PipelineDotRenderer;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineRunner;
@@ -36,8 +39,16 @@
@Override
public PortablePipelineResult run(final Pipeline pipeline, JobInfo jobInfo) {
+ // Expand any splittable DoFns within the graph to enable sizing and splitting of bundles.
+ Pipeline pipelineWithSdfExpanded =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ pipeline,
+ SplittableParDoExpander.createSizedReplacement());
+
// Fused pipeline proto.
- final RunnerApi.Pipeline fusedPipeline = GreedyPipelineFuser.fuse(pipeline).toPipeline();
+ final RunnerApi.Pipeline fusedPipeline =
+ GreedyPipelineFuser.fuse(pipelineWithSdfExpanded).toPipeline();
LOG.info("Portable pipeline to run:");
LOG.info(PipelineDotRenderer.toDotString(fusedPipeline));
// the pipeline option coming from sdk will set the sdk specific runner which will break
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
index 725df75..1d3f92d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
@@ -25,9 +25,12 @@
import java.util.concurrent.Future;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.core.construction.graph.PipelineTrimmer;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
import org.apache.beam.runners.core.metrics.MetricsPusher;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineResult;
import org.apache.beam.runners.fnexecution.jobsubmission.PortablePipelineRunner;
@@ -58,8 +61,16 @@
public PortablePipelineResult run(RunnerApi.Pipeline pipeline, JobInfo jobInfo) {
SparkBatchPortablePipelineTranslator translator = new SparkBatchPortablePipelineTranslator();
+ // Expand any splittable DoFns within the graph to enable sizing and splitting of bundles.
+ Pipeline pipelineWithSdfExpanded =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ pipeline,
+ SplittableParDoExpander.createSizedReplacement());
+
// Don't let the fuser fuse any subcomponents of native transforms.
- Pipeline trimmedPipeline = PipelineTrimmer.trim(pipeline, translator.knownUrns());
+ Pipeline trimmedPipeline =
+ PipelineTrimmer.trim(pipelineWithSdfExpanded, translator.knownUrns());
// Fused pipeline proto.
// TODO: Consider supporting partially-fused graphs.
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index 41283f7..21e515a 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -123,7 +123,7 @@
// words := beam.ParDo(s, &Foo{...}, ...)
// lengths := beam.ParDo(s, func (word string) int {
// return len(word)
-// }, works)
+// }, words)
//
//
// Each output element has the same timestamp and is in the same windows as its
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
index 5e2d0bf..df9a5f4 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
@@ -87,6 +87,8 @@
@VisibleForTesting
static class TextBasedReader extends FileBasedReader<String> {
private static final int READ_BUFFER_SIZE = 8192;
+ private static final ByteString UTF8_BOM =
+ ByteString.copyFrom(new byte[] {(byte) 0xEF, (byte) 0xBB, (byte) 0xBF});
private final ByteBuffer readBuffer = ByteBuffer.allocate(READ_BUFFER_SIZE);
private ByteString buffer;
private int startOfDelimiterInBuffer;
@@ -251,6 +253,10 @@
*/
private void decodeCurrentElement() throws IOException {
ByteString dataToDecode = buffer.substring(0, startOfDelimiterInBuffer);
+      // If present, the UTF-8 byte order mark (BOM) will be removed.
+ if (startOfRecord == 0 && dataToDecode.startsWith(UTF8_BOM)) {
+ dataToDecode = dataToDecode.substring(UTF8_BOM.size());
+ }
currentValue = dataToDecode.toStringUtf8();
elementIsPresent = true;
buffer = buffer.substring(endOfDelimiterInBuffer);
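The same boundary condition, sketched on a plain byte array for clarity (the actual reader operates on the vendored protobuf ByteString, as above; StandardCharsets is from java.nio.charset):

    private static final byte[] UTF8_BOM = {(byte) 0xEF, (byte) 0xBB, (byte) 0xBF};

    static String decode(byte[] line, boolean firstRecordInFile) {
      int offset = 0;
      // Only the very first record of the file may carry a BOM; any later \uFEFF
      // code points are real content and must be preserved.
      if (firstRecordInFile
          && line.length >= UTF8_BOM.length
          && line[0] == UTF8_BOM[0]
          && line[1] == UTF8_BOM[1]
          && line[2] == UTF8_BOM[2]) {
        offset = UTF8_BOM.length;
      }
      return new String(line, offset, line.length - offset, StandardCharsets.UTF_8);
    }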
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some runners do not properly support quiescence in the way that {@link TestStream}
+ * demands.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
index 78efbf7..b53ca13 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
@@ -22,9 +22,12 @@
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.InvalidWindows;
+import org.apache.beam.sdk.transforms.windowing.Never.NeverTrigger;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
@@ -151,9 +154,8 @@
&& windowingStrategy.getTrigger() instanceof DefaultTrigger
&& input.isBounded() != IsBounded.BOUNDED) {
throw new IllegalStateException(
- "GroupByKey cannot be applied to non-bounded PCollection in "
- + "the GlobalWindow without a trigger. Use a Window.into or Window.triggering transform "
- + "prior to GroupByKey.");
+ "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without a"
+ + " trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.");
}
// Validate the window merge function.
@@ -162,6 +164,45 @@
throw new IllegalStateException(
"GroupByKey must have a valid Window merge function. " + "Invalid because: " + cause);
}
+
+ // Validate that the trigger does not finish before garbage collection time
+ if (!triggerIsSafe(windowingStrategy)) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Unsafe trigger may lose data, see"
+ + " https://s.apache.org/finishing-triggers-drop-data: %s",
+ windowingStrategy.getTrigger()));
+ }
+ }
+
+  // Note that the Never trigger finishes *at* GC time, so it is OK, and
+  // AfterWatermark.fromEndOfWindow() finishes at end-of-window time, so it is
+  // OK if there is no allowed lateness.
+ private static boolean triggerIsSafe(WindowingStrategy<?, ?> windowingStrategy) {
+ if (!windowingStrategy.getTrigger().mayFinish()) {
+ return true;
+ }
+
+ if (windowingStrategy.getTrigger() instanceof NeverTrigger) {
+ return true;
+ }
+
+ if (windowingStrategy.getTrigger() instanceof FromEndOfWindow
+ && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+ return true;
+ }
+
+ if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+ && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+ return true;
+ }
+
+ if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+ && ((AfterWatermarkEarlyAndLate) windowingStrategy.getTrigger()).getLateTrigger() != null) {
+ return true;
+ }
+
+ return false;
}
public WindowingStrategy<?, ?> updateWindowingStrategy(WindowingStrategy<?, ?> inputStrategy) {
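A pipeline that previously used a finishing trigger such as AfterPane.elementCountAtLeast(1) can pass the new validation by wrapping it in Repeatedly.forever(...), whose mayFinish() returns false (see Repeatedly below). A sketch, assuming input is a PCollection<KV<String, String>>:

    input
        .apply(
            Window.<KV<String, String>>configure()
                .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                .discardingFiredPanes()
                .withAllowedLateness(Duration.ZERO))
        .apply(GroupByKey.create());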
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
index 2ce2fdf..eb15888 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterEach.java
@@ -68,6 +68,11 @@
}
@Override
+ public boolean mayFinish() {
+ return subTriggers.stream().allMatch(trigger -> trigger.mayFinish());
+ }
+
+ @Override
protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
return Repeatedly.forever(new AfterFirst(continuationTriggers));
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
index 2be41de..339b13b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/AfterWatermark.java
@@ -119,6 +119,12 @@
return window.maxTimestamp();
}
+ /** @return true if there is no late firing set up, otherwise false */
+ @Override
+ public boolean mayFinish() {
+ return lateTrigger == null;
+ }
+
@Override
public String toString() {
StringBuilder builder = new StringBuilder(TO_STRING);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
index 39d5d13..e2aff9f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/DefaultTrigger.java
@@ -44,6 +44,12 @@
return window.maxTimestamp();
}
+ /** @return false; the default trigger never finishes */
+ @Override
+ public boolean mayFinish() {
+ return false;
+ }
+
@Override
public boolean isCompatible(Trigger other) {
// Semantically, all default triggers are identical
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
index a8f6659..d2ea3f1 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/OrFinallyTrigger.java
@@ -60,6 +60,11 @@
}
@Override
+ public boolean mayFinish() {
+ return subTriggers.get(ACTUAL).mayFinish() || subTriggers.get(UNTIL).mayFinish();
+ }
+
+ @Override
protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
// Use OrFinallyTrigger instead of AfterFirst because the continuation of ACTUAL
// may not be a OnceTrigger.
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
index be4dd53..9c54d75 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Repeatedly.java
@@ -69,6 +69,11 @@
}
@Override
+ public boolean mayFinish() {
+ return false;
+ }
+
+ @Override
protected Trigger getContinuationTrigger(List<Trigger> continuationTriggers) {
return new Repeatedly(continuationTriggers.get(REPEATED));
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
index ceb7011..63103e6 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/ReshuffleTrigger.java
@@ -52,6 +52,11 @@
}
@Override
+ public boolean mayFinish() {
+ return false;
+ }
+
+ @Override
public String toString() {
return "ReshuffleTrigger()";
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
index ffddebd..639ba6c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Trigger.java
@@ -137,6 +137,15 @@
/**
* <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
*
+   * <p>Indicates whether this trigger may "finish". A top-level trigger that finishes can cause
+   * data loss, so it is rejected by GroupByKey validation.
+ */
+ @Internal
+ public abstract boolean mayFinish();
+
+ /**
+ * <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
+ *
* <p>Returns whether this performs the same triggering as the given {@link Trigger}.
*/
@Internal
@@ -230,6 +239,11 @@
}
@Override
+ public final boolean mayFinish() {
+ return true;
+ }
+
+ @Override
public final OnceTrigger getContinuationTrigger() {
Trigger continuation = super.getContinuationTrigger();
if (!(continuation instanceof OnceTrigger)) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
index b4772f3..8a73dd0 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/package-info.java
@@ -16,5 +16,9 @@
* limitations under the License.
*/
-/** Defines utilities that can be used by Beam runners. */
+/**
+ * <b>For internal use only; no backwards compatibility guarantees.</b>
+ *
+ * <p>Defines utilities that can be used by Beam runners.
+ */
package org.apache.beam.sdk.util;
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
index e358ff9..2bce31b 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java
@@ -639,6 +639,10 @@
@Test
@Category(NeedsRunner.class)
public void testWindowedWritesWithOnceTrigger() throws Throwable {
+ p.enableAbandonedNodeEnforcement(false);
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("Unsafe trigger");
+
// Tests for https://issues.apache.org/jira/browse/BEAM-3169
PCollection<String> data =
p.apply(Create.of("0", "1", "2"))
@@ -660,17 +664,6 @@
.<Void>withOutputFilenames())
.getPerDestinationOutputFilenames()
.apply(Values.create());
-
- PAssert.that(
- filenames
- .apply(FileIO.matchAll())
- .apply(FileIO.readMatches())
- .apply(TextIO.readFiles()))
- .containsInAnyOrder("0", "1", "2");
-
- PAssert.that(filenames.apply(TextIO.readAll())).containsInAnyOrder("0", "1", "2");
-
- p.run();
}
@Test
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java
new file mode 100644
index 0000000..36a3f68
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.FileIO.ReadableFile;
+import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link TextSource}. */
+@RunWith(JUnit4.class)
+public class TextSourceTest {
+ @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testRemoveUtf8BOM() throws Exception {
+ Path p1 = createTestFile("test_txt_ascii", Charset.forName("US-ASCII"), "1,p1", "2,p1");
+ Path p2 =
+ createTestFile(
+ "test_txt_utf8_no_bom",
+ Charset.forName("UTF-8"),
+ "1,p2-Japanese:テスト",
+ "2,p2-Japanese:テスト");
+ Path p3 =
+ createTestFile(
+ "test_txt_utf8_bom",
+ Charset.forName("UTF-8"),
+ "\uFEFF1,p3-テストBOM",
+ "\uFEFF2,p3-テストBOM");
+ PCollection<String> contents =
+ pipeline
+ .apply("Create", Create.of(p1.toString(), p2.toString(), p3.toString()))
+ .setCoder(StringUtf8Coder.of())
+ // PCollection<String>
+ .apply("Read file", new TextFileReadTransform());
+ // PCollection<KV<String, String>>: tableName, line
+
+    // Validate that the BOM bytes (\uFEFF) are removed only from the beginning of the file;
+    // the BOM at the start of p3's second line is preserved as content.
+ PAssert.that(contents)
+ .containsInAnyOrder(
+ "1,p1",
+ "2,p1",
+ "1,p2-Japanese:テスト",
+ "2,p2-Japanese:テスト",
+ "1,p3-テストBOM",
+ "\uFEFF2,p3-テストBOM");
+
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testPreserveNonBOMBytes() throws Exception {
+ // Contains \uFEFE, not UTF BOM.
+ Path p1 =
+ createTestFile(
+ "test_txt_utf_bom", Charset.forName("UTF-8"), "\uFEFE1,p1テスト", "\uFEFE2,p1テスト");
+ PCollection<String> contents =
+ pipeline
+ .apply("Create", Create.of(p1.toString()))
+ .setCoder(StringUtf8Coder.of())
+ // PCollection<String>
+ .apply("Read file", new TextFileReadTransform());
+
+ PAssert.that(contents).containsInAnyOrder("\uFEFE1,p1テスト", "\uFEFE2,p1テスト");
+
+ pipeline.run();
+ }
+
+ private static class FileReadDoFn extends DoFn<ReadableFile, String> {
+
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ ReadableFile file = c.element();
+ ValueProvider<String> filenameProvider =
+ ValueProvider.StaticValueProvider.of(file.getMetadata().resourceId().getFilename());
+ // Create a TextSource, passing null as the delimiter to use the default
+ // delimiters ('\n', '\r', or '\r\n').
+ TextSource textSource = new TextSource(filenameProvider, null, null);
+ try {
+ BoundedSource.BoundedReader<String> reader =
+ textSource
+ .createForSubrangeOfFile(file.getMetadata(), 0, file.getMetadata().sizeBytes())
+ .createReader(c.getPipelineOptions());
+ for (boolean more = reader.start(); more; more = reader.advance()) {
+ c.output(reader.getCurrent());
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(
+ "Unable to readFile: " + file.getMetadata().resourceId().toString());
+ }
+ }
+ }
+
+ /** A transform that reads CSV file records. */
+ private static class TextFileReadTransform
+ extends PTransform<PCollection<String>, PCollection<String>> {
+ public TextFileReadTransform() {}
+
+ @Override
+ public PCollection<String> expand(PCollection<String> files) {
+ return files
+ // PCollection<String>
+ .apply(FileIO.matchAll().withEmptyMatchTreatment(EmptyMatchTreatment.DISALLOW))
+ // PCollection<Match.Metadata>
+ .apply(FileIO.readMatches())
+ // PCollection<FileIO.ReadableFile>
+ .apply("Read lines", ParDo.of(new FileReadDoFn()));
+ // PCollection<String>: line
+ }
+ }
+
+ private Path createTestFile(String filename, Charset charset, String... lines)
+ throws IOException {
+ Path path = Files.createTempFile(filename, ".csv");
+ try (BufferedWriter writer = Files.newBufferedWriter(path, charset)) {
+ for (String line : lines) {
+ writer.write(line);
+ writer.write('\n');
+ }
+ }
+ return path;
+ }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.testing;
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.TestStream.Builder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Never;
import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
@@ -263,7 +277,7 @@
FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
PCollection<String> windowedValues =
p.apply(stream)
- .apply(Window.into(windows))
+ .apply(into(windows))
.apply(WithKeys.of(1))
.apply(GroupByKey.create())
.apply(Values.create())
@@ -387,6 +401,74 @@
}
@Test
+ @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+ public void testMultiStage() throws Exception {
+ TestStream<String> testStream =
+ TestStream.create(StringUtf8Coder.of())
+ .addElements("before") // before
+ .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+ .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+ .advanceWatermarkToInfinity(); // AFTER
+
+ PCollection<String> input = p.apply(testStream);
+
+ PCollection<String> grouped =
+ input
+ .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+ .apply(
+ MapElements.into(
+ TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+ .via(e -> KV.of(e, e)))
+ .apply(GroupByKey.create())
+ .apply(Keys.create())
+ .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+ .apply("Rewindow", Window.into(new GlobalWindows()));
+
+ PCollection<String> result =
+ PCollectionList.of(ImmutableList.of(input, grouped))
+ .apply(Flatten.pCollections())
+ .apply(
+ "Key",
+ MapElements.into(
+ TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+ .via(e -> KV.of("key", e)))
+ .apply(
+ ParDo.of(
+ new DoFn<KV<String, String>, String>() {
+ @StateId("seen")
+ private final StateSpec<ValueState<String>> seenSpec =
+ StateSpecs.value(StringUtf8Coder.of());
+
+ @TimerId("emit")
+ private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+ @ProcessElement
+ public void process(
+ ProcessContext context,
+ @StateId("seen") ValueState<String> seenState,
+ @TimerId("emit") Timer emitTimer) {
+ String element = context.element().getValue();
+ if (seenState.read() == null) {
+ seenState.write(element);
+ } else {
+ seenState.write(seenState.read() + "," + element);
+ }
+ emitTimer.set(Instant.ofEpochSecond(100));
+ }
+
+ @OnTimer("emit")
+ public void onEmit(
+ OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+ context.output(seenState.read());
+ }
+ }));
+
+ PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+ p.run().waitUntilFinish();
+ }
+
+ @Test
@Category(UsesTestStreamWithProcessingTime.class)
public void testCoder() throws Exception {
TestStream<String> testStream =
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
index 6d8408e..47af8a9 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
@@ -56,7 +56,9 @@
import org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -79,12 +81,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
+import org.junit.experimental.runners.Enclosed;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for GroupByKey. */
@SuppressWarnings({"rawtypes", "unchecked"})
+@RunWith(Enclosed.class)
public class GroupByKeyTest implements Serializable {
/** Shared test base class with setup/teardown helpers. */
public abstract static class SharedTestBase {
@@ -196,6 +200,104 @@
input.apply(GroupByKey.create());
}
+ // AfterPane.elementCountAtLeast(1) is not OK
+ @Test
+ public void testGroupByKeyFinishingTriggerRejected() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(AfterPane.elementCountAtLeast(1)));
+
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Unsafe trigger");
+ input.apply(GroupByKey.create());
+ }
+
+ // AfterWatermark.pastEndOfWindow() is OK with 0 allowed lateness
+ @Test
+ public void testGroupByKeyFinishingEndOfWindowTriggerOk() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(AfterWatermark.pastEndOfWindow())
+ .withAllowedLateness(Duration.ZERO));
+
+ // OK
+ input.apply(GroupByKey.create());
+ }
+
+ // AfterWatermark.pastEndOfWindow().withEarlyFirings() is OK with 0 allowed lateness
+ @Test
+ public void testGroupByKeyFinishingEndOfWindowEarlyFiringsTriggerOk() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(
+ AfterWatermark.pastEndOfWindow()
+ .withEarlyFirings(AfterPane.elementCountAtLeast(1)))
+ .withAllowedLateness(Duration.ZERO));
+
+ // OK
+ input.apply(GroupByKey.create());
+ }
+
+ // AfterWatermark.pastEndOfWindow() is not OK with > 0 allowed lateness
+ @Test
+ public void testGroupByKeyFinishingEndOfWindowTriggerNotOk() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(AfterWatermark.pastEndOfWindow())
+ .withAllowedLateness(Duration.millis(10)));
+
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Unsafe trigger");
+ input.apply(GroupByKey.create());
+ }
+
+ // AfterWatermark.pastEndOfWindow().withEarlyFirings() is not OK with > 0 allowed lateness
+ @Test
+ public void testGroupByKeyFinishingEndOfWindowEarlyFiringsTriggerNotOk() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(
+ AfterWatermark.pastEndOfWindow()
+ .withEarlyFirings(AfterPane.elementCountAtLeast(1)))
+ .withAllowedLateness(Duration.millis(10)));
+
+ thrown.expect(IllegalArgumentException.class);
+ thrown.expectMessage("Unsafe trigger");
+ input.apply(GroupByKey.create());
+ }
+
+ // AfterWatermark.pastEndOfWindow().withLateFirings() is always OK
+ @Test
+ public void testGroupByKeyEndOfWindowLateFiringsOk() {
+ PCollection<KV<String, String>> input =
+ p.apply(Create.of(KV.of("hello", "goodbye")))
+ .apply(
+ Window.<KV<String, String>>configure()
+ .discardingFiredPanes()
+ .triggering(
+ AfterWatermark.pastEndOfWindow()
+ .withLateFirings(AfterPane.elementCountAtLeast(1)))
+ .withAllowedLateness(Duration.millis(10)));
+
+ // OK
+ input.apply(GroupByKey.create());
+ }
+
@Test
@Category(NeedsRunner.class)
public void testRemerge() {
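For contrast, a trigger wrapped in Repeatedly.forever(...) never finishes, so it passes the new GroupByKey validation even with a non-zero allowed lateness. A minimal sketch reusing the fixtures from the tests above:

    PCollection<KV<String, String>> safe =
        p.apply(Create.of(KV.of("hello", "goodbye")))
            .apply(
                Window.<KV<String, String>>configure()
                    .discardingFiredPanes()
                    .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1)))
                    .withAllowedLateness(Duration.millis(10)));

    // Accepted: the repeated trigger never finishes, so no panes can be dropped.
    safe.apply(GroupByKey.create());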
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
index 36037c0..335d967 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/windowing/TriggerTest.java
@@ -68,6 +68,11 @@
public Instant getWatermarkThatGuaranteesFiring(BoundedWindow window) {
return null;
}
+
+ @Override
+ public boolean mayFinish() {
+ return false;
+ }
}
private static class Trigger2 extends Trigger {
@@ -85,5 +90,10 @@
public Instant getWatermarkThatGuaranteesFiring(BoundedWindow window) {
return null;
}
+
+ @Override
+ public boolean mayFinish() {
+ return false;
+ }
}
}
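The mayFinish() override added above is the hook the unsafe-trigger check consults: a trigger that may finish can silently drop data arriving after its last firing. A quick sketch of the distinction, assuming mayFinish() is publicly accessible as it is on the test subclasses:

    // A bare AfterPane trigger may finish after its first firing...
    boolean unsafe = AfterPane.elementCountAtLeast(1).mayFinish(); // true
    // ...while wrapping it in Repeatedly keeps it firing indefinitely.
    boolean safe =
        Repeatedly.forever(AfterPane.elementCountAtLeast(1)).mayFinish(); // false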
diff --git a/sdks/java/extensions/euphoria/build.gradle b/sdks/java/extensions/euphoria/build.gradle
index 93f6ebd..c79915b 100644
--- a/sdks/java/extensions/euphoria/build.gradle
+++ b/sdks/java/extensions/euphoria/build.gradle
@@ -30,7 +30,7 @@
testCompile library.java.hamcrest_library
testCompile library.java.mockito_core
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
test {
diff --git a/sdks/java/extensions/jackson/build.gradle b/sdks/java/extensions/jackson/build.gradle
index 3d1e692..b36343c 100644
--- a/sdks/java/extensions/jackson/build.gradle
+++ b/sdks/java/extensions/jackson/build.gradle
@@ -32,5 +32,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testCompile library.java.junit
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/extensions/join-library/build.gradle b/sdks/java/extensions/join-library/build.gradle
index c3c79e9..1257f8f 100644
--- a/sdks/java/extensions/join-library/build.gradle
+++ b/sdks/java/extensions/join-library/build.gradle
@@ -27,5 +27,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testCompile library.java.junit
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/extensions/kryo/build.gradle b/sdks/java/extensions/kryo/build.gradle
index 24c8e0c..dc8c62a 100644
--- a/sdks/java/extensions/kryo/build.gradle
+++ b/sdks/java/extensions/kryo/build.gradle
@@ -43,7 +43,7 @@
compile "com.esotericsoftware:kryo:${kryoVersion}"
shadow project(path: ':sdks:java:core', configuration: 'shadow')
testCompile project(path: ':sdks:java:core', configuration: 'shadowTest')
- testRuntimeOnly project(':runners:direct-java')
+ testRuntimeOnly project(path: ':runners:direct-java', configuration: 'shadow')
}
test {
diff --git a/sdks/java/extensions/sketching/build.gradle b/sdks/java/extensions/sketching/build.gradle
index 1f403ee..d923501 100644
--- a/sdks/java/extensions/sketching/build.gradle
+++ b/sdks/java/extensions/sketching/build.gradle
@@ -36,5 +36,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testCompile library.java.junit
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/extensions/sorter/build.gradle b/sdks/java/extensions/sorter/build.gradle
index 94cbba5..4994003 100644
--- a/sdks/java/extensions/sorter/build.gradle
+++ b/sdks/java/extensions/sorter/build.gradle
@@ -30,5 +30,5 @@
testCompile library.java.hamcrest_library
testCompile library.java.mockito_core
testCompile library.java.junit
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index 78d7383..b75c339 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -45,7 +45,7 @@
fmppTemplates library.java.vendored_calcite_1_20_0
compile project(":sdks:java:core")
compile project(":sdks:java:extensions:join-library")
- compile project(":runners:direct-java")
+ compile project(path: ":runners:direct-java", configuration: "shadow")
compile library.java.commons_csv
compile library.java.vendored_calcite_1_20_0
compile "com.alibaba:fastjson:1.2.49"
@@ -53,8 +53,10 @@
compile "org.codehaus.janino:commons-compiler:3.0.11"
provided project(":sdks:java:io:kafka")
provided project(":sdks:java:io:google-cloud-platform")
+ compile project(":sdks:java:io:mongodb")
provided project(":sdks:java:io:parquet")
provided library.java.kafka_clients
+ runtimeOnly library.java.hadoop_client
testCompile library.java.vendored_calcite_1_20_0
testCompile library.java.vendored_guava_26_0_jre
testCompile library.java.junit
@@ -62,6 +64,7 @@
testCompile library.java.hamcrest_library
testCompile library.java.mockito_core
testCompile library.java.quickcheck_core
+ testCompile project(path: ":sdks:java:io:mongodb", configuration: "testRuntime")
testRuntimeClasspath library.java.slf4j_jdk14
}
@@ -156,6 +159,7 @@
include '**/*IT.class'
exclude '**/KafkaCSVTableIT.java'
+ exclude '**/MongoDbReadWriteIT.java'
maxParallelForks 4
classpath = project(":sdks:java:extensions:sql")
.sourceSets
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
new file mode 100644
index 0000000..eaa9661
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTable.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import java.io.Serializable;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
+import org.apache.beam.sdk.extensions.sql.meta.SchemaBaseBeamTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.io.mongodb.MongoDbIO;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.JsonToRow;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.annotations.VisibleForTesting;
+import org.bson.Document;
+import org.bson.json.JsonMode;
+import org.bson.json.JsonWriterSettings;
+
+@Experimental
+public class MongoDbTable extends SchemaBaseBeamTable implements Serializable {
+ // Should match: mongodb://username:password@localhost:27017/database/collection
+ @VisibleForTesting
+ final Pattern locationPattern =
+ Pattern.compile(
+ "(?<credsHostPort>mongodb://(?<usernamePassword>.*(?<password>:.*)?@)?.+:\\d+)/(?<database>.+)/(?<collection>.+)");
+
+ @VisibleForTesting final String dbCollection;
+ @VisibleForTesting final String dbName;
+ @VisibleForTesting final String dbUri;
+
+ MongoDbTable(Table table) {
+ super(table.getSchema());
+
+ String location = table.getLocation();
+ Matcher matcher = locationPattern.matcher(location);
+ checkArgument(
+ matcher.matches(),
+ "MongoDb location must be in the following format: 'mongodb://(username:password@)?localhost:27017/database/collection'");
+ this.dbUri = matcher.group("credsHostPort"); // "mongodb://localhost:27017"
+ this.dbName = matcher.group("database");
+ this.dbCollection = matcher.group("collection");
+ }
+
+ @Override
+ public PCollection<Row> buildIOReader(PBegin begin) {
+ // Read MongoDb Documents
+ PCollection<Document> readDocuments =
+ MongoDbIO.read()
+ .withUri(dbUri)
+ .withDatabase(dbName)
+ .withCollection(dbCollection)
+ .expand(begin);
+
+ return readDocuments.apply(DocumentToRow.withSchema(getSchema()));
+ }
+
+ @Override
+ public POutput buildIOWriter(PCollection<Row> input) {
+ throw new UnsupportedOperationException("Writing to MongoDB is not supported");
+ }
+
+ @Override
+ public IsBounded isBounded() {
+ return IsBounded.BOUNDED;
+ }
+
+ @Override
+ public BeamTableStatistics getTableStatistics(PipelineOptions options) {
+ long count =
+ MongoDbIO.read()
+ .withUri(dbUri)
+ .withDatabase(dbName)
+ .withCollection(dbCollection)
+ .getDocumentCount();
+
+ if (count < 0) {
+ return BeamTableStatistics.BOUNDED_UNKNOWN;
+ }
+
+ return BeamTableStatistics.createBoundedTableStatistics((double) count);
+ }
+
+ public static class DocumentToRow extends PTransform<PCollection<Document>, PCollection<Row>> {
+ private final Schema schema;
+
+ private DocumentToRow(Schema schema) {
+ this.schema = schema;
+ }
+
+ public static DocumentToRow withSchema(Schema schema) {
+ return new DocumentToRow(schema);
+ }
+
+ @Override
+ public PCollection<Row> expand(PCollection<Document> input) {
+ // TODO(BEAM-8498): figure out a way to convert Document directly to Row.
+ return input
+ .apply("Convert Document to JSON", ParDo.of(new DocumentToJsonStringConverter()))
+ .apply("Transform JSON to Row", JsonToRow.withSchema(schema))
+ .setRowSchema(schema);
+ }
+
+ // TODO: add support for complex fields (May require modifying how Calcite parses nested
+ // fields).
+ @VisibleForTesting
+ static class DocumentToJsonStringConverter extends DoFn<Document, String> {
+ @DoFn.ProcessElement
+ public void processElement(ProcessContext context) {
+ context.output(
+ context
+ .element()
+ .toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()));
+ }
+ }
+ }
+}
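For reference, locationPattern splits a table location into the three pieces the constructor stores. A small standalone sketch (the sample URI and group values are illustrative):

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    Pattern location =
        Pattern.compile(
            "(?<credsHostPort>mongodb://(?<usernamePassword>.*(?<password>:.*)?@)?.+:\\d+)"
                + "/(?<database>.+)/(?<collection>.+)");
    Matcher m = location.matcher("mongodb://user:pass@localhost:27017/mydb/mycoll");
    if (m.matches()) {
      m.group("credsHostPort"); // "mongodb://user:pass@localhost:27017"
      m.group("database");      // "mydb"
      m.group("collection");    // "mycoll"
    }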
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java
new file mode 100644
index 0000000..ead09f0
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProvider.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import com.google.auto.service.AutoService;
+import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.extensions.sql.meta.provider.InMemoryMetaTableProvider;
+import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
+
+/**
+ * {@link TableProvider} for {@link MongoDbTable}.
+ *
+ * <p>A sample MongoDb table declaration:
+ *
+ * <pre>{@code
+ * CREATE TABLE ORDERS(
+ * name VARCHAR,
+ * favorite_color VARCHAR,
+ * favorite_numbers ARRAY<INTEGER>
+ * )
+ * TYPE 'mongodb'
+ * LOCATION 'mongodb://username:password@localhost:27017/database/collection'
+ * }</pre>
+ */
+@AutoService(TableProvider.class)
+public class MongoDbTableProvider extends InMemoryMetaTableProvider {
+
+ @Override
+ public String getTableType() {
+ return "mongodb";
+ }
+
+ @Override
+ public BeamSqlTable buildBeamSqlTable(Table table) {
+ return new MongoDbTable(table);
+ }
+}
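Because of @AutoService, the provider is discovered automatically from the classpath; it can also be registered explicitly. A usage sketch mirroring the integration test further below (the table name and pipeline variable are illustrative):

    BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new MongoDbTableProvider());
    sqlEnv.executeDdl(
        "CREATE EXTERNAL TABLE users(name VARCHAR) "
            + "TYPE 'mongodb' "
            + "LOCATION 'mongodb://localhost:27017/database/collection'");
    PCollection<Row> rows =
        BeamSqlRelUtils.toPCollection(pipeline, sqlEnv.parseQuery("SELECT * FROM users"));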
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java
new file mode 100644
index 0000000..51c9a74
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/package-info.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Table schema for MongoDb. */
+@DefaultAnnotation(NonNull.class)
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import edu.umd.cs.findbugs.annotations.DefaultAnnotation;
+import edu.umd.cs.findbugs.annotations.NonNull;
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
new file mode 100644
index 0000000..9c4d4cd
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BOOLEAN;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BYTE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.DOUBLE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.FLOAT;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT16;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT64;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.STRING;
+import static org.junit.Assert.assertEquals;
+
+import com.mongodb.MongoClient;
+import java.util.Arrays;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.io.mongodb.MongoDBIOIT.MongoDBPipelineOptions;
+import org.apache.beam.sdk.io.mongodb.MongoDbIO;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.ToJson;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.bson.Document;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * A test of {@link org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbTable} on an
+ * independent Mongo instance.
+ *
+ * <p>This test requires a running instance of MongoDB. Pass in connection information using
+ * PipelineOptions:
+ *
+ * <pre>
+ * ./gradlew integrationTest -p sdks/java/extensions/sql/integrationTest -DintegrationTestPipelineOptions='[
+ * "--mongoDBHostName=1.2.3.4",
+ * "--mongoDBPort=27017",
+ * "--mongoDBDatabaseName=mypass",
+ * "--numberOfRecords=1000" ]'
+ * --tests org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbReadWriteIT
+ * -DintegrationTestRunner=direct
+ * </pre>
+ *
+ * <p>The database specified in the pipeline options is created implicitly if it does not already
+ * exist, and is dropped once the tests complete.
+ *
+ * <p>Please see 'build_rules.gradle' file for instructions regarding running this test using Beam
+ * performance testing framework.
+ */
+@Ignore("https://issues.apache.org/jira/browse/BEAM-8586")
+public class MongoDbReadWriteIT {
+ private static final Schema SOURCE_SCHEMA =
+ Schema.builder()
+ .addNullableField("_id", STRING)
+ .addNullableField("c_bigint", INT64)
+ .addNullableField("c_tinyint", BYTE)
+ .addNullableField("c_smallint", INT16)
+ .addNullableField("c_integer", INT32)
+ .addNullableField("c_float", FLOAT)
+ .addNullableField("c_double", DOUBLE)
+ .addNullableField("c_boolean", BOOLEAN)
+ .addNullableField("c_varchar", STRING)
+ .addNullableField("c_arr", FieldType.array(STRING))
+ .build();
+ private static final String collection = "collection";
+ private static MongoDBPipelineOptions options;
+
+ @Rule public final TestPipeline writePipeline = TestPipeline.create();
+ @Rule public final TestPipeline readPipeline = TestPipeline.create();
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ PipelineOptionsFactory.register(MongoDBPipelineOptions.class);
+ options = TestPipeline.testingPipelineOptions().as(MongoDBPipelineOptions.class);
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ dropDatabase();
+ }
+
+ private static void dropDatabase() throws Exception {
+ new MongoClient(options.getMongoDBHostName())
+ .getDatabase(options.getMongoDBDatabaseName())
+ .drop();
+ }
+
+ @Test
+ public void testWriteAndRead() {
+ final String mongoUrl =
+ String.format("mongodb://%s:%d", options.getMongoDBHostName(), options.getMongoDBPort());
+ final String mongoSqlUrl =
+ String.format(
+ "mongodb://%s:%d/%s/%s",
+ options.getMongoDBHostName(),
+ options.getMongoDBPort(),
+ options.getMongoDBDatabaseName(),
+ collection);
+
+ Row testRow =
+ row(
+ SOURCE_SCHEMA,
+ "object_id",
+ 9223372036854775807L,
+ (byte) 127,
+ (short) 32767,
+ 2147483647,
+ (float) 1.0,
+ 1.0,
+ true,
+ "varchar",
+ Arrays.asList("123", "456"));
+
+ writePipeline
+ .apply(Create.of(testRow))
+ .setRowSchema(SOURCE_SCHEMA)
+ .apply("Transform Rows to JSON", ToJson.of())
+ .apply("Produce documents from JSON", MapElements.via(new ObjectToDocumentFn()))
+ .apply(
+ "Write documents to MongoDB",
+ MongoDbIO.write()
+ .withUri(mongoUrl)
+ .withDatabase(options.getMongoDBDatabaseName())
+ .withCollection(collection));
+ PipelineResult writeResult = writePipeline.run();
+ writeResult.waitUntilFinish();
+
+ String createTableStatement =
+ "CREATE EXTERNAL TABLE TEST( \n"
+ + " _id VARCHAR, \n "
+ + " c_bigint BIGINT, \n "
+ + " c_tinyint TINYINT, \n"
+ + " c_smallint SMALLINT, \n"
+ + " c_integer INTEGER, \n"
+ + " c_float FLOAT, \n"
+ + " c_double DOUBLE, \n"
+ + " c_boolean BOOLEAN, \n"
+ + " c_varchar VARCHAR, \n "
+ + " c_arr ARRAY<VARCHAR> \n"
+ + ") \n"
+ + "TYPE 'mongodb' \n"
+ + "LOCATION '"
+ + mongoSqlUrl
+ + "'";
+
+ BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new MongoDbTableProvider());
+ sqlEnv.executeDdl(createTableStatement);
+
+ PCollection<Row> output =
+ BeamSqlRelUtils.toPCollection(readPipeline, sqlEnv.parseQuery("select * from TEST"));
+
+ assertEquals(SOURCE_SCHEMA, output.getSchema());
+
+ PAssert.that(output).containsInAnyOrder(testRow);
+
+ readPipeline.run().waitUntilFinish();
+ }
+
+ private static class ObjectToDocumentFn extends SimpleFunction<String, Document> {
+ @Override
+ public Document apply(String input) {
+ return Document.parse(input);
+ }
+ }
+
+ private Row row(Schema schema, Object... values) {
+ return Row.withSchema(schema).addValues(values).build();
+ }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java
new file mode 100644
index 0000000..459af56
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableProviderTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.toSchema;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.util.stream.Stream;
+import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MongoDbTableProviderTest {
+ private MongoDbTableProvider provider = new MongoDbTableProvider();
+
+ @Test
+ public void testGetTableType() {
+ assertEquals("mongodb", provider.getTableType());
+ }
+
+ @Test
+ public void testBuildBeamSqlTable() {
+ Table table = fakeTable("TEST", "mongodb://localhost:27017/database/collection");
+ BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+ assertNotNull(sqlTable);
+ assertTrue(sqlTable instanceof MongoDbTable);
+
+ MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+ assertEquals("mongodb://localhost:27017", mongoTable.dbUri);
+ assertEquals("database", mongoTable.dbName);
+ assertEquals("collection", mongoTable.dbCollection);
+ }
+
+ @Test
+ public void testBuildBeamSqlTable_withUsernameOnly() {
+ Table table = fakeTable("TEST", "mongodb://username@localhost:27017/database/collection");
+ BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+ assertNotNull(sqlTable);
+ assertTrue(sqlTable instanceof MongoDbTable);
+
+ MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+ assertEquals("mongodb://username@localhost:27017", mongoTable.dbUri);
+ assertEquals("database", mongoTable.dbName);
+ assertEquals("collection", mongoTable.dbCollection);
+ }
+
+ @Test
+ public void testBuildBeamSqlTable_withUsernameAndPassword() {
+ Table table =
+ fakeTable("TEST", "mongodb://username:pasword@localhost:27017/database/collection");
+ BeamSqlTable sqlTable = provider.buildBeamSqlTable(table);
+
+ assertNotNull(sqlTable);
+ assertTrue(sqlTable instanceof MongoDbTable);
+
+ MongoDbTable mongoTable = (MongoDbTable) sqlTable;
+ assertEquals("mongodb://username:pasword@localhost:27017", mongoTable.dbUri);
+ assertEquals("database", mongoTable.dbName);
+ assertEquals("collection", mongoTable.dbCollection);
+ }
+
+ @Test
+ public void testBuildBeamSqlTable_withBadLocation_throwsException() {
+ ImmutableList<String> badLocations =
+ ImmutableList.of(
+ "mongodb://localhost:27017/database/",
+ "mongodb://localhost:27017/database",
+ "localhost:27017/database/collection",
+ "mongodb://:27017/database/collection",
+ "mongodb://localhost:27017//collection",
+ "mongodb://localhost/database/collection",
+ "mongodb://localhost:/database/collection");
+
+ for (String badLocation : badLocations) {
+ Table table = fakeTable("TEST", badLocation);
+ assertThrows(IllegalArgumentException.class, () -> provider.buildBeamSqlTable(table));
+ }
+ }
+
+ private static Table fakeTable(String name, String location) {
+ return Table.builder()
+ .name(name)
+ .comment(name + " table")
+ .location(location)
+ .schema(
+ Stream.of(
+ Schema.Field.nullable("id", Schema.FieldType.INT32),
+ Schema.Field.nullable("name", Schema.FieldType.STRING))
+ .collect(toSchema()))
+ .type("mongodb")
+ .build();
+ }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java
new file mode 100644
index 0000000..cccac9c
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbTableTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.mongodb;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BOOLEAN;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.BYTE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.DOUBLE;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.FLOAT;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT16;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT64;
+import static org.apache.beam.sdk.schemas.Schema.FieldType.STRING;
+
+import java.util.Arrays;
+import org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbTable.DocumentToRow;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.bson.Document;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class MongoDbTableTest {
+
+ private static final Schema SCHEMA =
+ Schema.builder()
+ .addNullableField("long", INT64)
+ .addNullableField("int32", INT32)
+ .addNullableField("int16", INT16)
+ .addNullableField("byte", BYTE)
+ .addNullableField("bool", BOOLEAN)
+ .addNullableField("double", DOUBLE)
+ .addNullableField("float", FLOAT)
+ .addNullableField("string", STRING)
+ .addRowField("nested", Schema.builder().addNullableField("int32", INT32).build())
+ .addNullableField("arr", FieldType.array(STRING))
+ .build();
+ private static final String JSON_ROW =
+ "{ "
+ + "\"long\" : 9223372036854775807, "
+ + "\"int32\" : 2147483647, "
+ + "\"int16\" : 32767, "
+ + "\"byte\" : 127, "
+ + "\"bool\" : true, "
+ + "\"double\" : 1.0, "
+ + "\"float\" : 1.0, "
+ + "\"string\" : \"string\", "
+ + "\"nested\" : {\"int32\" : 2147483645}, "
+ + "\"arr\" : [\"str1\", \"str2\", \"str3\"]"
+ + " }";
+
+ @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+ @Test
+ public void testDocumentToRowConverter() {
+ PCollection<Row> output =
+ pipeline
+ .apply("Create document from JSON", Create.<Document>of(Document.parse(JSON_ROW)))
+ .apply("CConvert document to Row", DocumentToRow.withSchema(SCHEMA));
+
+ // Make sure proper rows are constructed from JSON.
+ PAssert.that(output)
+ .containsInAnyOrder(
+ row(
+ SCHEMA,
+ 9223372036854775807L,
+ 2147483647,
+ (short) 32767,
+ (byte) 127,
+ true,
+ 1.0,
+ (float) 1.0,
+ "string",
+ row(Schema.builder().addNullableField("int32", INT32).build(), 2147483645),
+ Arrays.asList("str1", "str2", "str3")));
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ private Row row(Schema schema, Object... values) {
+ return Row.withSchema(schema).addValues(values).build();
+ }
+}
diff --git a/sdks/java/extensions/zetasketch/build.gradle b/sdks/java/extensions/zetasketch/build.gradle
index 30e8bc8..e19da15 100644
--- a/sdks/java/extensions/zetasketch/build.gradle
+++ b/sdks/java/extensions/zetasketch/build.gradle
@@ -35,7 +35,7 @@
testCompile library.java.junit
testCompile project(":sdks:java:io:google-cloud-platform")
testRuntimeOnly library.java.slf4j_simple
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
testRuntimeOnly project(":runners:google-cloud-dataflow-java")
}
diff --git a/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java b/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
index 5a975da..e4851bd 100644
--- a/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
+++ b/sdks/java/extensions/zetasketch/src/main/java/org/apache/beam/sdk/extensions/zetasketch/HllCount.java
@@ -18,6 +18,8 @@
package org.apache.beam.sdk.extensions.zetasketch;
import com.google.zetasketch.HyperLogLogPlusPlus;
+import java.nio.ByteBuffer;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
@@ -108,6 +110,23 @@
private HllCount() {}
/**
+ * Converts the passed-in sketch from {@code ByteBuffer} to {@code byte[]}, mapping {@code null
+ * ByteBuffer}s (representing empty sketches) to empty {@code byte[]}s.
+ *
+ * <p>Utility method to convert sketches materialized with ZetaSQL/BigQuery to valid inputs for
+ * Beam {@code HllCount} transforms.
+ */
+ public static byte[] getSketchFromByteBuffer(@Nullable ByteBuffer bf) {
+ if (bf == null) {
+ return new byte[0];
+ } else {
+ byte[] result = new byte[bf.remaining()];
+ bf.get(result);
+ return result;
+ }
+ }
+
+ /**
* Provides {@code PTransform}s to aggregate inputs into HLL++ sketches. The four supported input
* types are {@code Integer}, {@code Long}, {@code String}, and {@code byte[]}.
*
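One reason getSketchFromByteBuffer copies via remaining()/get() rather than calling bf.array(): buffers handed back by the BigQuery/Avro layer may be read-only or sliced, in which case array() throws or exposes bytes outside the view. A short sketch of the contract (buffer contents are illustrative):

    ByteBuffer bf = ByteBuffer.wrap(new byte[] {1, 2, 3}).asReadOnlyBuffer();
    byte[] sketch = HllCount.getSketchFromByteBuffer(bf); // {1, 2, 3}; bf.array() would throw
    byte[] empty = HllCount.getSketchFromByteBuffer(null); // new byte[0], the empty sketch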
diff --git a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
index 3f7927d..462a715 100644
--- a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
+++ b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java
@@ -44,6 +44,7 @@
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -65,23 +66,32 @@
private static final List<String> TEST_DATA =
Arrays.asList("Apple", "Orange", "Banana", "Orange");
- // Data Table: used by testReadSketchFromBigQuery())
+ // Data Table: used by tests reading sketches from BigQuery
// Schema: only one STRING field named "data".
- // Content: prepopulated with 4 rows: "Apple", "Orange", "Banana", "Orange"
- private static final String DATA_TABLE_ID = "hll_data";
private static final String DATA_FIELD_NAME = "data";
private static final String DATA_FIELD_TYPE = "STRING";
private static final String QUERY_RESULT_FIELD_NAME = "sketch";
- private static final Long EXPECTED_COUNT = 3L;
- // Sketch Table: used by testWriteSketchToBigQuery()
+ // Content: prepopulated with 4 rows: "Apple", "Orange", "Banana", "Orange"
+ private static final String DATA_TABLE_ID_NON_EMPTY = "hll_data_non_empty";
+ private static final Long EXPECTED_COUNT_NON_EMPTY = 3L;
+
+ // Content: empty
+ private static final String DATA_TABLE_ID_EMPTY = "hll_data_empty";
+ private static final Long EXPECTED_COUNT_EMPTY = 0L;
+
+ // Sketch Table: used by tests writing sketches to BigQuery
// Schema: only one BYTES field named "sketch".
- // Content: will be overridden by the sketch computed by the test pipeline each time the test runs
- private static final String SKETCH_TABLE_ID = "hll_sketch";
private static final String SKETCH_FIELD_NAME = "sketch";
private static final String SKETCH_FIELD_TYPE = "BYTES";
+
+ // Content: will be overridden by the sketch computed by the test pipeline each time the test runs
+ private static final String SKETCH_TABLE_ID = "hll_sketch";
// SHA-1 hash of string "[3]", the string representation of a row that has only one field 3 in it
- private static final String EXPECTED_CHECKSUM = "f1e31df9806ce94c5bdbbfff9608324930f4d3f1";
+ private static final String EXPECTED_CHECKSUM_NON_EMPTY =
+ "f1e31df9806ce94c5bdbbfff9608324930f4d3f1";
+ // SHA-1 hash of string "[0]", the string representation of a row that has only one field 0 in it
+ private static final String EXPECTED_CHECKSUM_EMPTY = "1184f5b8d4b6dd08709cf1513f26744167065e0d";
static {
ApplicationNameOptions options =
@@ -93,31 +103,40 @@
}
@BeforeClass
- public static void prepareDatasetAndDataTable() throws Exception {
+ public static void prepareDatasetAndDataTables() throws Exception {
BIGQUERY_CLIENT.createNewDataset(PROJECT_ID, DATASET_ID);
- // Create Data Table
TableSchema dataTableSchema =
new TableSchema()
.setFields(
Collections.singletonList(
new TableFieldSchema().setName(DATA_FIELD_NAME).setType(DATA_FIELD_TYPE)));
- Table dataTable =
+
+ Table dataTableNonEmpty =
new Table()
.setSchema(dataTableSchema)
.setTableReference(
new TableReference()
.setProjectId(PROJECT_ID)
.setDatasetId(DATASET_ID)
- .setTableId(DATA_TABLE_ID));
- BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTable);
-
- // Prepopulate test data to Data Table
+ .setTableId(DATA_TABLE_ID_NON_EMPTY));
+ BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTableNonEmpty);
+ // Prepopulates dataTableNonEmpty with TEST_DATA
List<Map<String, Object>> rows =
TEST_DATA.stream()
.map(v -> Collections.singletonMap(DATA_FIELD_NAME, (Object) v))
.collect(Collectors.toList());
- BIGQUERY_CLIENT.insertDataToTable(PROJECT_ID, DATASET_ID, DATA_TABLE_ID, rows);
+ BIGQUERY_CLIENT.insertDataToTable(PROJECT_ID, DATASET_ID, DATA_TABLE_ID_NON_EMPTY, rows);
+
+ Table dataTableEmpty =
+ new Table()
+ .setSchema(dataTableSchema)
+ .setTableReference(
+ new TableReference()
+ .setProjectId(PROJECT_ID)
+ .setDatasetId(DATASET_ID)
+ .setTableId(DATA_TABLE_ID_EMPTY));
+ BIGQUERY_CLIENT.createNewTable(PROJECT_ID, DATASET_ID, dataTableEmpty);
}
@AfterClass
@@ -126,22 +145,41 @@
}
/**
- * Test that HLL++ sketch computed in BigQuery can be processed by Beam. Hll sketch is computed by
- * {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the test verifies that we can run {@link
- * HllCount.MergePartial} and {@link HllCount.Extract} on the sketch in Beam to get the correct
- * estimated count.
+ * Tests that a non-empty HLL++ sketch computed in BigQuery can be processed by Beam.
+ *
+ * <p>The Hll sketch is computed by {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the
+ * test verifies that we can run {@link HllCount.MergePartial} and {@link HllCount.Extract} on the
+ * sketch in Beam to get the correct estimated count.
*/
@Test
- public void testReadSketchFromBigQuery() {
- String tableSpec = String.format("%s.%s", DATASET_ID, DATA_TABLE_ID);
+ public void testReadNonEmptySketchFromBigQuery() {
+ readSketchFromBigQuery(DATA_TABLE_ID_NON_EMPTY, EXPECTED_COUNT_NON_EMPTY);
+ }
+
+ /**
+ * Tests that an empty HLL++ sketch computed in BigQuery can be processed by Beam.
+ *
+ * <p>The Hll sketch is computed by {@code HLL_COUNT.INIT} in BigQuery and read into Beam; the
+ * test verifies that we can run {@link HllCount.MergePartial} and {@link HllCount.Extract} on the
+ * sketch in Beam to get the correct estimated count.
+ */
+ @Test
+ public void testReadEmptySketchFromBigQuery() {
+ readSketchFromBigQuery(DATA_TABLE_ID_EMPTY, EXPECTED_COUNT_EMPTY);
+ }
+
+ private void readSketchFromBigQuery(String tableId, Long expectedCount) {
+ String tableSpec = String.format("%s.%s", DATASET_ID, tableId);
String query =
String.format(
"SELECT HLL_COUNT.INIT(%s) AS %s FROM %s",
DATA_FIELD_NAME, QUERY_RESULT_FIELD_NAME, tableSpec);
+
SerializableFunction<SchemaAndRecord, byte[]> parseQueryResultToByteArray =
- (SchemaAndRecord schemaAndRecord) ->
+ input ->
// BigQuery BYTES type corresponds to Java java.nio.ByteBuffer type
- ((ByteBuffer) schemaAndRecord.getRecord().get(QUERY_RESULT_FIELD_NAME)).array();
+ HllCount.getSketchFromByteBuffer(
+ (ByteBuffer) input.getRecord().get(QUERY_RESULT_FIELD_NAME));
TestPipelineOptions options =
TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class);
@@ -156,17 +194,35 @@
.withCoder(ByteArrayCoder.of()))
.apply(HllCount.MergePartial.globally()) // no-op, only for testing MergePartial
.apply(HllCount.Extract.globally());
- PAssert.thatSingleton(result).isEqualTo(EXPECTED_COUNT);
+ PAssert.thatSingleton(result).isEqualTo(expectedCount);
p.run().waitUntilFinish();
}
/**
- * Test that HLL++ sketch computed in Beam can be processed by BigQuery. Hll sketch is computed by
- * {@link HllCount.Init} in Beam and written to BigQuery; the test verifies that we can run {@code
- * HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the correct estimated count.
+ * Tests that a non-empty HLL++ sketch computed in Beam can be processed by BigQuery.
+ *
+ * <p>The Hll sketch is computed by {@link HllCount.Init} in Beam and written to BigQuery; the
+ * test verifies that we can run {@code HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the
+ * correct estimated count.
*/
@Test
- public void testWriteSketchToBigQuery() {
+ public void testWriteNonEmptySketchToBigQuery() {
+ writeSketchToBigQuery(TEST_DATA, EXPECTED_CHECKSUM_NON_EMPTY);
+ }
+
+ /**
+ * Tests that an empty HLL++ sketch computed in Beam can be processed by BigQuery.
+ *
+ * <p>The Hll sketch is computed by {@link HllCount.Init} in Beam and written to BigQuery; the
+ * test verifies that we can run {@code HLL_COUNT.EXTRACT()} on the sketch in BigQuery to get the
+ * correct estimated count.
+ */
+ @Test
+ public void testWriteEmptySketchToBigQuery() {
+ writeSketchToBigQuery(Collections.emptyList(), EXPECTED_CHECKSUM_EMPTY);
+ }
+
+ private void writeSketchToBigQuery(List<String> testData, String expectedChecksum) {
String tableSpec = String.format("%s.%s", DATASET_ID, SKETCH_TABLE_ID);
String query =
String.format("SELECT HLL_COUNT.EXTRACT(%s) FROM %s", SKETCH_FIELD_NAME, tableSpec);
@@ -181,16 +237,20 @@
// After the pipeline finishes, BigqueryMatcher will send a query to retrieve the estimated
// count and verifies its correctness using checksum.
options.setOnSuccessMatcher(
- BigqueryMatcher.createUsingStandardSql(APP_NAME, PROJECT_ID, query, EXPECTED_CHECKSUM));
+ BigqueryMatcher.createUsingStandardSql(APP_NAME, PROJECT_ID, query, expectedChecksum));
Pipeline p = Pipeline.create(options);
- p.apply(Create.of(TEST_DATA))
+ p.apply(Create.of(testData).withType(TypeDescriptor.of(String.class)))
.apply(HllCount.Init.forStrings().globally())
.apply(
BigQueryIO.<byte[]>write()
.to(tableSpec)
.withSchema(tableSchema)
- .withFormatFunction(sketch -> new TableRow().set(SKETCH_FIELD_NAME, sketch))
+ .withFormatFunction(
+ sketch ->
+ // Empty sketch is represented by empty byte array in Beam and by null in
+ // BigQuery
+ new TableRow().set(SKETCH_FIELD_NAME, sketch.length == 0 ? null : sketch))
.withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
p.run().waitUntilFinish();
}
diff --git a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
index 4ef3e6a..137e976 100644
--- a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
+++ b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/HllCountTest.java
@@ -17,8 +17,11 @@
*/
package org.apache.beam.sdk.extensions.zetasketch;
+import static org.junit.Assert.assertArrayEquals;
+
import com.google.zetasketch.HyperLogLogPlusPlus;
import com.google.zetasketch.shaded.com.google.protobuf.ByteString;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -484,4 +487,15 @@
PAssert.thatSingleton(result).isEqualTo(KV.of("k", 0L));
p.run();
}
+
+ @Test
+ public void testGetSketchFromByteBufferForEmptySketch() {
+ assertArrayEquals(EMPTY_SKETCH, HllCount.getSketchFromByteBuffer(null));
+ }
+
+ @Test
+ public void testGetSketchFromByteBufferForNonEmptySketch() {
+ ByteBuffer bf = ByteBuffer.wrap(INTS1_SKETCH);
+ assertArrayEquals(INTS1_SKETCH, HllCount.getSketchFromByteBuffer(bf));
+ }
}
diff --git a/sdks/java/io/amazon-web-services/build.gradle b/sdks/java/io/amazon-web-services/build.gradle
index ca88447..d7e4139 100644
--- a/sdks/java/io/amazon-web-services/build.gradle
+++ b/sdks/java/io/amazon-web-services/build.gradle
@@ -50,7 +50,7 @@
testCompile 'org.elasticmq:elasticmq-rest-sqs_2.12:0.14.1'
testCompile 'org.testcontainers:localstack:1.11.2'
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
test {
diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle
index eb33c56..a52c827 100644
--- a/sdks/java/io/amazon-web-services2/build.gradle
+++ b/sdks/java/io/amazon-web-services2/build.gradle
@@ -42,7 +42,7 @@
testCompile library.java.junit
testCompile 'org.testcontainers:testcontainers:1.11.3'
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
test {
diff --git a/sdks/java/io/amqp/build.gradle b/sdks/java/io/amqp/build.gradle
index a4c35a3..6a6b3a5 100644
--- a/sdks/java/io/amqp/build.gradle
+++ b/sdks/java/io/amqp/build.gradle
@@ -35,5 +35,5 @@
testCompile library.java.activemq_amqp
testCompile library.java.activemq_junit
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/cassandra/build.gradle b/sdks/java/io/cassandra/build.gradle
index ca3015c4..36dbede 100644
--- a/sdks/java/io/cassandra/build.gradle
+++ b/sdks/java/io/cassandra/build.gradle
@@ -45,5 +45,5 @@
testCompile group: 'info.archinnov', name: 'achilles-junit', version: "$achilles_version"
testCompile library.java.jackson_jaxb_annotations
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/clickhouse/build.gradle b/sdks/java/io/clickhouse/build.gradle
index ee8d382..ec102b4 100644
--- a/sdks/java/io/clickhouse/build.gradle
+++ b/sdks/java/io/clickhouse/build.gradle
@@ -60,5 +60,5 @@
testCompile library.java.hamcrest_library
testCompile "org.testcontainers:clickhouse:$testcontainers_version"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
index 0788f89..482fc7d 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle
@@ -48,5 +48,5 @@
testCompile library.java.vendored_guava_26_0_jre
testCompile "org.elasticsearch:elasticsearch:$elastic_search_version"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
index 543c0db..2e13700 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-5/build.gradle
@@ -65,5 +65,5 @@
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
index 93b59b5..b7bf6d0 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-6/build.gradle
@@ -65,5 +65,5 @@
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:$elastic_search_version"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
index 168fc5b..53a1ff4 100644
--- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
+++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-common/build.gradle
@@ -48,5 +48,5 @@
testCompile library.java.junit
testCompile "org.elasticsearch.client:elasticsearch-rest-client:6.4.0"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/file-based-io-tests/build.gradle b/sdks/java/io/file-based-io-tests/build.gradle
index 845a9a8..fa737cb 100644
--- a/sdks/java/io/file-based-io-tests/build.gradle
+++ b/sdks/java/io/file-based-io-tests/build.gradle
@@ -33,4 +33,5 @@
testCompile library.java.junit
testCompile library.java.hamcrest_core
testCompile library.java.jaxb_api
+ testRuntime library.java.hadoop_client
}
diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle
index 8ebdc93..26f9db3 100644
--- a/sdks/java/io/hadoop-file-system/build.gradle
+++ b/sdks/java/io/hadoop-file-system/build.gradle
@@ -39,5 +39,5 @@
testCompile library.java.hadoop_minicluster
testCompile library.java.hadoop_hdfs_tests
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle
index 20dba8d..d575d40 100644
--- a/sdks/java/io/hadoop-format/build.gradle
+++ b/sdks/java/io/hadoop-format/build.gradle
@@ -82,7 +82,7 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
compile library.java.commons_io_2x
delegate.add("sparkRunner", project(":sdks:java:io:hadoop-format"))
diff --git a/sdks/java/io/hbase/build.gradle b/sdks/java/io/hbase/build.gradle
index 882eb5e..e2dd902 100644
--- a/sdks/java/io/hbase/build.gradle
+++ b/sdks/java/io/hbase/build.gradle
@@ -64,6 +64,6 @@
}
testCompile "org.apache.hbase:hbase-hadoop-compat:$hbase_version:tests"
testCompile "org.apache.hbase:hbase-hadoop2-compat:$hbase_version:tests"
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle
index da0ee24..e32277d 100644
--- a/sdks/java/io/hcatalog/build.gradle
+++ b/sdks/java/io/hcatalog/build.gradle
@@ -67,6 +67,6 @@
testCompile "org.apache.hive:hive-exec:$hive_version"
testCompile "org.apache.hive:hive-common:$hive_version"
testCompile "org.apache.hive:hive-cli:$hive_version"
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle
index d8ce5f2..ad59ab1 100644
--- a/sdks/java/io/jdbc/build.gradle
+++ b/sdks/java/io/jdbc/build.gradle
@@ -40,5 +40,5 @@
testCompile "org.apache.derby:derbyclient:10.14.2.0"
testCompile "org.apache.derby:derbynet:10.14.2.0"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle
index 8056106..3886d23 100644
--- a/sdks/java/io/jms/build.gradle
+++ b/sdks/java/io/jms/build.gradle
@@ -37,5 +37,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index da8655e..d2d79c7 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -46,5 +46,5 @@
testCompile library.java.powermock
testCompile library.java.powermock_mockito
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/kinesis/build.gradle b/sdks/java/io/kinesis/build.gradle
index a73c770..ccc0dce 100644
--- a/sdks/java/io/kinesis/build.gradle
+++ b/sdks/java/io/kinesis/build.gradle
@@ -50,5 +50,5 @@
testCompile library.java.powermock_mockito
testCompile "org.assertj:assertj-core:3.11.1"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/kudu/build.gradle b/sdks/java/io/kudu/build.gradle
index 8619d67..b32a02d 100644
--- a/sdks/java/io/kudu/build.gradle
+++ b/sdks/java/io/kudu/build.gradle
@@ -45,6 +45,6 @@
testCompile library.java.hamcrest_library
testCompile library.java.junit
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/mongodb/build.gradle b/sdks/java/io/mongodb/build.gradle
index d8ac11c..040f88e 100644
--- a/sdks/java/io/mongodb/build.gradle
+++ b/sdks/java/io/mongodb/build.gradle
@@ -38,5 +38,5 @@
testCompile "de.flapdoodle.embed:de.flapdoodle.embed.mongo:2.2.0"
testCompile "de.flapdoodle.embed:de.flapdoodle.embed.process:2.1.2"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
index 4bdbbf4..9fd06d3 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
@@ -323,6 +323,13 @@
return input.apply(org.apache.beam.sdk.io.Read.from(new BoundedMongoDbSource(this)));
}
+ public long getDocumentCount() {
+ checkArgument(uri() != null, "withUri() is required");
+ checkArgument(database() != null, "withDatabase() is required");
+ checkArgument(collection() != null, "withCollection() is required");
+ return new BoundedMongoDbSource(this).getDocumentCount();
+ }
+
@Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
@@ -376,6 +383,38 @@
return new BoundedMongoDbReader(this);
}
+ /**
+ * Returns the number of documents in the collection.
+ *
+ * @return The document count, or -1 if it could not be determined.
+ */
+ long getDocumentCount() {
+ try (MongoClient mongoClient =
+ new MongoClient(
+ new MongoClientURI(
+ spec.uri(),
+ getOptions(
+ spec.maxConnectionIdleTime(),
+ spec.sslEnabled(),
+ spec.sslInvalidHostNameAllowed())))) {
+ return getDocumentCount(mongoClient, spec.database(), spec.collection());
+ } catch (Exception e) {
+ return -1;
+ }
+ }
+
+ private long getDocumentCount(MongoClient mongoClient, String database, String collection) {
+ MongoDatabase mongoDatabase = mongoClient.getDatabase(database);
+
+ // Run the MongoDB collStats command; its result includes the document
+ // count for the entire collection.
+ BasicDBObject stat = new BasicDBObject();
+ stat.append("collStats", collection);
+ Document stats = mongoDatabase.runCommand(stat);
+
+ return stats.get("count", Number.class).longValue();
+ }
+
@Override
public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
try (MongoClient mongoClient =
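The new getDocumentCount() is what MongoDbTable.getTableStatistics relies on for row-count estimates; it can be called the same way before constructing a pipeline. A sketch with illustrative connection values:

    long count =
        MongoDbIO.read()
            .withUri("mongodb://localhost:27017")
            .withDatabase("mydb")
            .withCollection("mycollection")
            .getDocumentCount();
    if (count < 0) {
      // -1 means the collStats command failed; treat the collection size as unknown.
    }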
diff --git a/sdks/java/io/mqtt/build.gradle b/sdks/java/io/mqtt/build.gradle
index 9d7b188..a384274 100644
--- a/sdks/java/io/mqtt/build.gradle
+++ b/sdks/java/io/mqtt/build.gradle
@@ -37,5 +37,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle
index 7a038f3..2e14f77 100644
--- a/sdks/java/io/parquet/build.gradle
+++ b/sdks/java/io/parquet/build.gradle
@@ -32,12 +32,11 @@
compile "org.apache.parquet:parquet-common:$parquet_version"
compile "org.apache.parquet:parquet-hadoop:$parquet_version"
compile library.java.avro
- compile library.java.hadoop_client
- compile library.java.hadoop_common
+ provided library.java.hadoop_client
testCompile project(path: ":sdks:java:core", configuration: "shadowTest")
testCompile library.java.junit
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/rabbitmq/build.gradle b/sdks/java/io/rabbitmq/build.gradle
index 57415f9..cf47712 100644
--- a/sdks/java/io/rabbitmq/build.gradle
+++ b/sdks/java/io/rabbitmq/build.gradle
@@ -36,5 +36,5 @@
testCompile library.java.hamcrest_library
testCompile library.java.slf4j_api
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/redis/build.gradle b/sdks/java/io/redis/build.gradle
index 2626400..e205155 100644
--- a/sdks/java/io/redis/build.gradle
+++ b/sdks/java/io/redis/build.gradle
@@ -32,5 +32,5 @@
testCompile library.java.hamcrest_library
testCompile "com.github.kstyrc:embedded-redis:0.6"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/solr/build.gradle b/sdks/java/io/solr/build.gradle
index 928d6db..c2a9ddf 100644
--- a/sdks/java/io/solr/build.gradle
+++ b/sdks/java/io/solr/build.gradle
@@ -38,5 +38,5 @@
testCompile "org.apache.solr:solr-core:5.5.4"
testCompile "com.carrotsearch.randomizedtesting:randomizedtesting-runner:2.3.2"
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/synthetic/build.gradle b/sdks/java/io/synthetic/build.gradle
index 52c794b..e1d2abc 100644
--- a/sdks/java/io/synthetic/build.gradle
+++ b/sdks/java/io/synthetic/build.gradle
@@ -34,5 +34,5 @@
testCompile library.java.junit
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/tika/build.gradle b/sdks/java/io/tika/build.gradle
index 813a692..1aae4d1 100644
--- a/sdks/java/io/tika/build.gradle
+++ b/sdks/java/io/tika/build.gradle
@@ -36,5 +36,5 @@
testCompile library.java.hamcrest_library
testCompile "org.apache.tika:tika-parsers:$tika_version"
testCompileOnly "biz.aQute:bndlib:$bndlib_version"
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/io/xml/build.gradle b/sdks/java/io/xml/build.gradle
index 2cb9077..7b59671 100644
--- a/sdks/java/io/xml/build.gradle
+++ b/sdks/java/io/xml/build.gradle
@@ -33,5 +33,5 @@
testCompile library.java.hamcrest_core
testCompile library.java.hamcrest_library
testRuntimeOnly library.java.slf4j_jdk14
- testRuntimeOnly project(":runners:direct-java")
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
diff --git a/sdks/java/testing/expansion-service/build.gradle b/sdks/java/testing/expansion-service/build.gradle
index bfbcc23..09d830c 100644
--- a/sdks/java/testing/expansion-service/build.gradle
+++ b/sdks/java/testing/expansion-service/build.gradle
@@ -28,6 +28,7 @@
compile project(path: ":runners:core-construction-java")
compile project(path: ":sdks:java:io:parquet")
compile project(path: ":sdks:java:core", configuration: "shadow")
+ runtimeOnly library.java.hadoop_client
}
task runTestExpansionService (type: JavaExec) {
diff --git a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
index 89b0cc6..aa133e9 100644
--- a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
+++ b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query10.java
@@ -41,6 +41,7 @@
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing;
@@ -314,7 +315,7 @@
name + ".WindowLogFiles",
Window.<KV<Void, OutputFile>>into(
FixedWindows.of(Duration.standardSeconds(configuration.windowSizeSec)))
- .triggering(AfterWatermark.pastEndOfWindow())
+ .triggering(DefaultTrigger.of())
// We expect no late data here, but we'll assume the worst so we can detect any.
.withAllowedLateness(Duration.standardDays(1))
.discardingFiredPanes())
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 5b66730..e21052f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -35,6 +35,7 @@
import logging
import math
import random
+import threading
import uuid
from builtins import object
from builtins import range
@@ -1104,13 +1105,17 @@
class RestrictionTracker(object):
"""Manages concurrent access to a restriction.
- Experimental; no backwards-compatibility guarantees.
-
Keeps track of the claimed part of a restriction for a Splittable DoFn.
+ The restriction may be modified by different threads; however, the system
+ will ensure sufficient locking such that no methods on the restriction
+ tracker will be called concurrently.
+
See following documents for more details.
* https://s.apache.org/splittable-do-fn
* https://s.apache.org/splittable-do-fn-python-sdk
+
+ Experimental; no backwards-compatibility guarantees.
"""
def current_restriction(self):
@@ -1121,54 +1126,22 @@
The current restriction returned by this method may be updated dynamically
due to concurrent invocation of other methods of the
- ``RestrictionTracker``, For example, ``checkpoint()``.
+ ``RestrictionTracker``, for example ``split()``.
- ** Thread safety **
+ This API is required to be implemented.
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+ Returns: a restriction object.
"""
raise NotImplementedError
def current_progress(self):
"""Returns a RestrictionProgress object representing the current progress.
+
+ This API is recommended to be implemented: with better progress signals,
+ the runner can do a better job of parallelizing processing.
"""
raise NotImplementedError
- def current_watermark(self):
- """Returns current watermark. By default, not report watermark.
-
- TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
- """
- return None
-
- def checkpoint(self):
- """Performs a checkpoint of the current restriction.
-
- Signals that the current ``DoFn.process()`` call should terminate as soon as
- possible. After this method returns, the tracker MUST refuse all future
- claim calls, and ``RestrictionTracker.check_done()`` MUST succeed.
-
- This invocation modifies the value returned by ``current_restriction()``
- invocation and returns a restriction representing the rest of the work. The
- old value of ``current_restriction()`` is equivalent to the new value of
- ``current_restriction()`` and the return value of this method invocation
- combined.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
- """
-
- raise NotImplementedError
-
def check_done(self):
"""Checks whether the restriction has been fully processed.
@@ -1179,13 +1152,8 @@
remaining in the restriction when this method is invoked. Exception raised
must have an informative error message.
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+ This API is required to be implemented in order to ensure there is no data
+ loss during SDK processing.
Returns: ``True`` if current restriction has been fully processed.
Raises:
@@ -1215,8 +1183,12 @@
restrictions returned would be [100, 179), [179, 200) (note: current_offset
+ fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
- It is very important for pipeline scaling and end to end pipeline execution
- that try_split is implemented well.
+ ``fraction_of_remainder`` = 0 means a checkpoint is required.
+
+ This API is recommended to be implemented for batch pipelines, given that
+ it is very important for pipeline scaling and end-to-end pipeline
+ execution.
+
+ This API is required to be implemented for streaming pipelines.
Args:
fraction_of_remainder: A hint as to the fraction of work the primary
@@ -1226,19 +1198,11 @@
Returns:
(primary_restriction, residual_restriction) if a split was possible,
otherwise returns ``None``.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
"""
raise NotImplementedError
def try_claim(self, position):
- """ Attempts to claim the block of work in the current restriction
+ """Attempts to claim the block of work in the current restriction
identified by the given position.
If this succeeds, the DoFn MUST execute the entire block of work. If it
@@ -1247,40 +1211,137 @@
work from ``DoFn.process()`` is also not allowed before the first call of
this method).
+ This API is required to be implemented.
+
Args:
position: current position that wants to be claimed.
Returns: ``True`` if the position can be claimed as current_position.
Otherwise, returns ``False``.
-
- ** Thread safety **
-
- Methods of the class ``RestrictionTracker`` including this method may get
- invoked by different threads, hence must be made thread-safe, e.g. by using
- a single lock object.
-
- TODO(BEAM-7473): Remove thread safety requirements from API implementation.
"""
raise NotImplementedError
- def defer_remainder(self, watermark=None):
- """ Invokes checkpoint() in an SDF.process().
- TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
- ``ProcessContinuation``.
+class ThreadsafeRestrictionTracker(object):
+ """A thread-safe wrapper which wraps a `RestritionTracker`.
+
+ This wrapper guarantees synchronized access when the restriction is
+ modified from multiple threads.
+ """
+
+ def __init__(self, restriction_tracker):
+ if not isinstance(restriction_tracker, RestrictionTracker):
+ raise ValueError(
+ 'Initializing ThreadsafeRestrictionTracker requires a '
+ 'RestrictionTracker.')
+ self._restriction_tracker = restriction_tracker
+ # Records an absolute timestamp when defer_remainder is called.
+ self._deferred_timestamp = None
+ self._lock = threading.RLock()
+ self._deferred_residual = None
+ self._deferred_watermark = None
+
+ def current_restriction(self):
+ with self._lock:
+ return self._restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ with self._lock:
+ return self._restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ """Performs self-checkpoint on current processing restriction with an
+ expected resuming time.
+
+ Self-checkpoint could happen during processing elements. When executing an
+ DoFn.process(), you may want to stop processing an element and resuming
+ later if current element has been processed quit a long time or you also
+ want to have some outputs from other elements. ``defer_remainder()`` can be
+ called on per element if needed.
Args:
- watermark
+ deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+ time gap between now and resuming, or an absolute ``timestamp.Timestamp``
+ for the resume time. If deferred_time is None, the deferred work will be
+ executed as soon as possible.
"""
- raise NotImplementedError
+
+ # Record current time for calculating deferred_time later.
+ self._deferred_timestamp = timestamp.Timestamp.now()
+ if (deferred_time and
+ not isinstance(deferred_time, timestamp.Duration) and
+ not isinstance(deferred_time, timestamp.Timestamp)):
+ raise ValueError('The timestamp of defer_remainder() should be a '
+ 'Duration or a Timestamp, or None.')
+ self._deferred_watermark = deferred_time
+ checkpoint = self.try_split(0)
+ if checkpoint:
+ _, self._deferred_residual = checkpoint
+
+ def check_done(self):
+ with self._lock:
+ return self._restriction_tracker.check_done()
+
+ def current_progress(self):
+ with self._lock:
+ return self._restriction_tracker.current_progress()
+
+ def try_split(self, fraction_of_remainder):
+ with self._lock:
+ return self._restriction_tracker.try_split(fraction_of_remainder)
def deferred_status(self):
- """ Returns deferred_residual with deferred_watermark.
+ """Returns deferred work which is produced by ``defer_remainder()``.
- TODO(BEAM-7472): Remove defer_status() once SDF.process() uses
- ``ProcessContinuation``.
+ When a self-checkpoint is performed, the system needs to fulfill the
+ DelayedBundleApplication with the deferred work for a ProcessBundleResponse.
+ The system calls this API to get the deferred residual together with its
+ time delay, so that the runner can schedule the future work.
+
+ Returns: (deferred_residual, time_delay) if there is any residual, else None.
"""
- raise NotImplementedError
+ if self._deferred_residual:
+ # If _deferred_watermark is None, create Duration(0).
+ if not self._deferred_watermark:
+ self._deferred_watermark = timestamp.Duration()
+ # If an absolute timestamp is provided, calculate the delta between
+ # the absolute time and the time deferred_status() is called.
+ elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+ self._deferred_watermark = (self._deferred_watermark -
+ timestamp.Timestamp.now())
+ # If a Duration is provided, the deferred time should be:
+ # provided duration - time spent since defer_remainder() was called.
+ elif isinstance(self._deferred_watermark, timestamp.Duration):
+ self._deferred_watermark -= (timestamp.Timestamp.now() -
+ self._deferred_timestamp)
+ return self._deferred_residual, self._deferred_watermark
+
+
+class RestrictionTrackerView(object):
+ """A DoFn view of thread-safe RestrictionTracker.
+
+ The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+ exposes the APIs that may be called from a ``DoFn.process()``. At
+ execution time, the RestrictionTrackerView is passed into ``DoFn.process``
+ as the restriction_tracker argument.
+ """
+
+ def __init__(self, threadsafe_restriction_tracker):
+ if not isinstance(threadsafe_restriction_tracker,
+ ThreadsafeRestrictionTracker):
+ raise ValueError('Initialize RestrictionTrackerView requires '
+ 'ThreadsafeRestrictionTracker.')
+ self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+ def current_restriction(self):
+ return self._threadsafe_restriction_tracker.current_restriction()
+
+ def try_claim(self, position):
+ return self._threadsafe_restriction_tracker.try_claim(position)
+
+ def defer_remainder(self, deferred_time=None):
+ self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
class RestrictionProgress(object):
@@ -1400,17 +1461,8 @@
SourceBundle(residual_weight, self._source, split_pos,
stop_pos))
- def deferred_status(self):
- return None
-
- def current_watermark(self):
- return None
-
- def get_delegate_range_tracker(self):
- return self._delegate_range_tracker
-
- def get_tracking_source(self):
- return self._source
+ def check_done(self):
+ return self._delegate_range_tracker.fraction_consumed() >= 1.0
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
"""A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1463,8 +1515,13 @@
restriction_tracker=core.DoFn.RestrictionParam(
_SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
source, chunk_size))):
- return restriction_tracker.get_tracking_source().read(
- restriction_tracker.get_delegate_range_tracker())
+ current_restriction = restriction_tracker.current_restriction()
+ assert isinstance(current_restriction, SourceBundle)
+ tracking_source = current_restriction.source
+ start = current_restriction.start_position
+ stop = current_restriction.stop_position
+ return tracking_source.read(tracking_source.get_range_tracker(start,
+ stop))
return SDFBoundedSourceDoFn(self.source)
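A minimal sketch of how the pieces added above compose at runtime (assuming this patch is applied; OffsetRestrictionTracker and the timestamp utilities come from the modules touched below). The DoFn only ever sees the view, while the SDK harness keeps the thread-safe wrapper to split the restriction and collect deferred work:

```python
# Sketch only: compose tracker, thread-safe wrapper, and DoFn-facing view.
from apache_beam.io import iobase
from apache_beam.io.restriction_trackers import OffsetRange
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.utils import timestamp

tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
threadsafe = iobase.ThreadsafeRestrictionTracker(tracker)
view = iobase.RestrictionTrackerView(threadsafe)  # what DoFn.process() sees

assert view.try_claim(0)
# The DoFn self-checkpoints and asks to resume after a relative delay.
view.defer_remainder(timestamp.Duration(100))

# The SDK harness, not the DoFn, reads back the deferred work.
residual, delay = threadsafe.deferred_status()
print(residual, delay)  # OffsetRange(1, 10) and a Duration close to 100
```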
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7adb764..0a6afae 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
+import time
import unittest
import mock
@@ -28,6 +29,9 @@
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io.iobase import SourceBundle
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.utils import timestamp
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
@@ -191,5 +195,87 @@
self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+ def test_defer_remainder_with_wrong_time_type(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ with self.assertRaises(ValueError):
+ threadsafe_tracker.defer_remainder(10)
+
+ def test_self_checkpoint_immediately(self):
+ restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ restriction_tracker)
+ threadsafe_tracker.defer_remainder()
+ deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+ expected_residual = OffsetRange(0, 10)
+ self.assertEqual(deferred_residual, expected_residual)
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ self.assertEqual(deferred_time, 0)
+
+ def test_self_checkpoint_with_relative_time(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation = 100 - 2 - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+ def test_self_checkpoint_with_absolute_time(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ now = timestamp.Timestamp.now()
+ schedule_time = now + timestamp.Duration(100)
+ self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+ threadsafe_tracker.defer_remainder(schedule_time)
+ time.sleep(2)
+ _, deferred_time = threadsafe_tracker.deferred_status()
+ self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+ # The expectation =
+ # schedule_time - the time when deferred_status is called - some_delta
+ self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+
+ def test_initialization(self):
+ with self.assertRaises(ValueError):
+ iobase.RestrictionTrackerView(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+ def test_api_expose(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+ current_restriction = tracker_view.current_restriction()
+ self.assertEqual(current_restriction, OffsetRange(0, 10))
+ self.assertTrue(tracker_view.try_claim(0))
+ tracker_view.defer_remainder()
+ deferred_remainder, deferred_watermark = (
+ threadsafe_tracker.deferred_status())
+ self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+ self.assertEqual(deferred_watermark, timestamp.Duration())
+
+ def test_non_expose_apis(self):
+ threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+ OffsetRestrictionTracker(OffsetRange(0, 10)))
+ tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+ with self.assertRaises(AttributeError):
+ tracker_view.check_done()
+ with self.assertRaises(AttributeError):
+ tracker_view.current_progress()
+ with self.assertRaises(AttributeError):
+ tracker_view.try_split()
+ with self.assertRaises(AttributeError):
+ tracker_view.deferred_status()
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 0ba5b23..20bb5c1 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -19,7 +19,6 @@
from __future__ import absolute_import
from __future__ import division
-import threading
from builtins import object
from apache_beam.io.iobase import RestrictionProgress
@@ -86,104 +85,69 @@
assert isinstance(offset_range, OffsetRange)
self._range = offset_range
self._current_position = None
- self._current_watermark = None
self._last_claim_attempt = None
- self._deferred_residual = None
self._checkpointed = False
- self._lock = threading.RLock()
def check_done(self):
- with self._lock:
- if self._last_claim_attempt < self._range.stop - 1:
- raise ValueError(
- 'OffsetRestrictionTracker is not done since work in range [%s, %s) '
- 'has not been claimed.'
- % (self._last_claim_attempt if self._last_claim_attempt is not None
- else self._range.start,
- self._range.stop))
+ if self._last_claim_attempt < self._range.stop - 1:
+ raise ValueError(
+ 'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+ 'has not been claimed.'
+ % (self._last_claim_attempt if self._last_claim_attempt is not None
+ else self._range.start,
+ self._range.stop))
def current_restriction(self):
- with self._lock:
- return self._range
-
- def current_watermark(self):
- return self._current_watermark
+ return self._range
def current_progress(self):
- with self._lock:
- if self._current_position is None:
- fraction = 0.0
- elif self._range.stop == self._range.start:
- # If self._current_position is not None, we must be done.
- fraction = 1.0
- else:
- fraction = (
- float(self._current_position - self._range.start)
- / (self._range.stop - self._range.start))
+ if self._current_position is None:
+ fraction = 0.0
+ elif self._range.stop == self._range.start:
+ # If self._current_position is not None, we must be done.
+ fraction = 1.0
+ else:
+ fraction = (
+ float(self._current_position - self._range.start)
+ / (self._range.stop - self._range.start))
return RestrictionProgress(fraction=fraction)
def start_position(self):
- with self._lock:
- return self._range.start
+ return self._range.start
def stop_position(self):
- with self._lock:
- return self._range.stop
-
- def default_size(self):
- return self._range.size()
+ return self._range.stop
def try_claim(self, position):
- with self._lock:
- if self._last_claim_attempt and position <= self._last_claim_attempt:
- raise ValueError(
- 'Positions claimed should strictly increase. Trying to claim '
- 'position %d while last claim attempt was %d.'
- % (position, self._last_claim_attempt))
+ if self._last_claim_attempt and position <= self._last_claim_attempt:
+ raise ValueError(
+ 'Positions claimed should strictly increase. Trying to claim '
+ 'position %d while last claim attempt was %d.'
+ % (position, self._last_claim_attempt))
- self._last_claim_attempt = position
- if position < self._range.start:
- raise ValueError(
- 'Position to be claimed cannot be smaller than the start position '
- 'of the range. Tried to claim position %r for the range [%r, %r)'
- % (position, self._range.start, self._range.stop))
+ self._last_claim_attempt = position
+ if position < self._range.start:
+ raise ValueError(
+ 'Position to be claimed cannot be smaller than the start position '
+ 'of the range. Tried to claim position %r for the range [%r, %r)'
+ % (position, self._range.start, self._range.stop))
- if position >= self._range.start and position < self._range.stop:
- self._current_position = position
- return True
+ if position >= self._range.start and position < self._range.stop:
+ self._current_position = position
+ return True
- return False
+ return False
def try_split(self, fraction_of_remainder):
- with self._lock:
- if not self._checkpointed:
- if self._current_position is None:
- cur = self._range.start - 1
- else:
- cur = self._current_position
- split_point = (
- cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
- if split_point < self._range.stop:
- self._range, residual_range = self._range.split_at(split_point)
- return self._range, residual_range
-
- # TODO(SDF): Replace all calls with try_claim(0).
- def checkpoint(self):
- with self._lock:
- # If self._current_position is 'None' no records have been claimed so
- # residual should start from self._range.start.
+ if not self._checkpointed:
if self._current_position is None:
- end_position = self._range.start
+ cur = self._range.start - 1
else:
- end_position = self._current_position + 1
- self._range, residual_range = self._range.split_at(end_position)
- return residual_range
-
- def defer_remainder(self, watermark=None):
- with self._lock:
- self._deferred_watermark = watermark or self._current_watermark
- self._deferred_residual = self.checkpoint()
-
- def deferred_status(self):
- if self._deferred_residual:
- return (self._deferred_residual, self._deferred_watermark)
+ cur = self._current_position
+ split_point = (
+ cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
+ if split_point < self._range.stop:
+ if fraction_of_remainder == 0:
+ self._checkpointed = True
+ self._range, residual_range = self._range.split_at(split_point)
+ return self._range, residual_range
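The split-point arithmetic in try_split above is worth tracing by hand. Here is a standalone sketch of the same computation (plain Python, numbers taken from the try_split docstring example in iobase.py):

```python
# Sketch of the split-point formula used by OffsetRestrictionTracker.try_split.
def split_point(start, stop, current_position, fraction_of_remainder):
  # If nothing has been claimed yet, pretend we sit just before the range.
  cur = start - 1 if current_position is None else current_position
  return cur + int(max(1, (stop - cur) * fraction_of_remainder))

# Docstring example: range [100, 200), position 130, fraction 0.7
# -> 130 + 0.7 * 70 = 179, so the primary keeps [100, 179).
assert split_point(100, 200, 130, 0.7) == 179
# fraction_of_remainder == 0 is a checkpoint: keep the claimed prefix plus
# the minimum step of one position, matching the updated tests below.
assert split_point(100, 200, 110, 0) == 111
```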
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
index 459b039..4a57d98 100644
--- a/sdks/python/apache_beam/io/restriction_trackers_test.py
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -81,14 +81,14 @@
def test_checkpoint_unstarted(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 100), tracker.current_restriction())
self.assertEqual(OffsetRange(100, 200), checkpoint)
def test_checkpoint_just_started(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(100))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 101), tracker.current_restriction())
self.assertEqual(OffsetRange(101, 200), checkpoint)
@@ -96,7 +96,7 @@
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
self.assertTrue(tracker.try_claim(105))
self.assertTrue(tracker.try_claim(110))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 111), tracker.current_restriction())
self.assertEqual(OffsetRange(111, 200), checkpoint)
@@ -105,9 +105,9 @@
self.assertTrue(tracker.try_claim(105))
self.assertTrue(tracker.try_claim(110))
self.assertTrue(tracker.try_claim(199))
- checkpoint = tracker.checkpoint()
+ checkpoint = tracker.try_split(0)
self.assertEqual(OffsetRange(100, 200), tracker.current_restriction())
- self.assertEqual(OffsetRange(200, 200), checkpoint)
+ self.assertEqual(None, checkpoint)
def test_checkpoint_after_failed_claim(self):
tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
@@ -116,7 +116,7 @@
self.assertTrue(tracker.try_claim(160))
self.assertFalse(tracker.try_claim(240))
- checkpoint = tracker.checkpoint()
+ _, checkpoint = tracker.try_split(0)
self.assertTrue(OffsetRange(100, 161), tracker.current_restriction())
self.assertTrue(OffsetRange(161, 200), checkpoint)
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
new file mode 100644
index 0000000..0204da8
--- /dev/null
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+
+cdef class MetricCell(object):
+ cdef object _lock
+ cpdef bint update(self, value) except -1
+
+
+cdef class CounterCell(MetricCell):
+ cdef readonly libc.stdint.int64_t value
+
+ @cython.locals(ivalue=libc.stdint.int64_t)
+ cpdef bint update(self, value) except -1
+
+
+cdef class DistributionCell(MetricCell):
+ cdef readonly DistributionData data
+
+ @cython.locals(ivalue=libc.stdint.int64_t)
+ cdef inline bint _update(self, value) except -1
+
+
+cdef class GaugeCell(MetricCell):
+ cdef readonly object data
+
+
+cdef class DistributionData(object):
+ cdef readonly libc.stdint.int64_t sum
+ cdef readonly libc.stdint.int64_t count
+ cdef readonly libc.stdint.int64_t min
+ cdef readonly libc.stdint.int64_t max
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index e7336e4..d30dd2a 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -30,12 +30,16 @@
from google.protobuf import timestamp_pb2
-from apache_beam.metrics.metricbase import Counter
-from apache_beam.metrics.metricbase import Distribution
-from apache_beam.metrics.metricbase import Gauge
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import metrics_pb2
+try:
+ import cython
+except ImportError:
+ class fake_cython:
+ compiled = False
+ globals()['cython'] = fake_cython
+
__all__ = ['DistributionResult', 'GaugeResult']
@@ -52,11 +56,17 @@
def __init__(self):
self._lock = threading.Lock()
+ def update(self, value):
+ raise NotImplementedError
+
def get_cumulative(self):
raise NotImplementedError
+ def __reduce__(self):
+ raise NotImplementedError
-class CounterCell(Counter, MetricCell):
+
+class CounterCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta of a counter metric.
@@ -80,27 +90,41 @@
return result
def inc(self, n=1):
- with self._lock:
- self.value += n
+ self.update(n)
+
+ def dec(self, n=1):
+ self.update(-n)
+
+ def update(self, value):
+ if cython.compiled:
+ ivalue = value
+ # We hold the GIL, no need for another lock.
+ self.value += ivalue
+ else:
+ with self._lock:
+ self.value += value
def get_cumulative(self):
with self._lock:
return self.value
- def to_runner_api_monitoring_info(self):
- """Returns a Metric with this counter value for use in a MonitoringInfo."""
- # TODO(ajamato): Update this code to be consistent with Gauges
- # and Distributions. Since there is no CounterData class this method
- # was added to CounterCell. Consider adding a CounterData class or
- # removing the GaugeData and DistributionData classes.
- return metrics_pb2.Metric(
- counter_data=metrics_pb2.CounterData(
- int64_value=self.get_cumulative()
- )
- )
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
+ value=self.value))
+
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_counter(
+ name.namespace, name.name,
+ metrics_pb2.Metric(
+ counter_data=metrics_pb2.CounterData(
+ int64_value=self.get_cumulative())),
+ ptransform=transform_id)
-class DistributionCell(Distribution, MetricCell):
+class DistributionCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta for a distribution metric.
@@ -124,26 +148,43 @@
return result
def update(self, value):
- with self._lock:
+ if cython.compiled:
+ # We will hold the GIL throughout the entire _update.
self._update(value)
+ else:
+ with self._lock:
+ self._update(value)
def _update(self, value):
- value = int(value)
- self.data.count += 1
- self.data.sum += value
- self.data.min = (value
- if self.data.min is None or self.data.min > value
- else self.data.min)
- self.data.max = (value
- if self.data.max is None or self.data.max < value
- else self.data.max)
+ if cython.compiled:
+ ivalue = value
+ else:
+ ivalue = int(value)
+ self.data.count = self.data.count + 1
+ self.data.sum = self.data.sum + ivalue
+ if ivalue < self.data.min:
+ self.data.min = ivalue
+ if ivalue > self.data.max:
+ self.data.max = ivalue
def get_cumulative(self):
with self._lock:
return self.data.get_cumulative()
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ distribution_data=self.get_cumulative().to_runner_api())
-class GaugeCell(Gauge, MetricCell):
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_distribution(
+ name.namespace, name.name,
+ self.get_cumulative().to_runner_api_monitoring_info(),
+ ptransform=transform_id)
+
+
+class GaugeCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta for a gauge metric.
@@ -167,6 +208,9 @@
return result
def set(self, value):
+ self.update(value)
+
+ def update(self, value):
value = int(value)
with self._lock:
# Set the value directly without checking timestamp, because
@@ -178,6 +222,18 @@
with self._lock:
return self.data.get_cumulative()
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ gauge_data=self.get_cumulative().to_runner_api())
+
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_gauge(
+ name.namespace, name.name,
+ self.get_cumulative().to_runner_api_monitoring_info(),
+ ptransform=transform_id)
+
class DistributionResult(object):
"""The result of a Distribution metric."""
@@ -198,7 +254,7 @@
return not self == other
def __repr__(self):
- return '<DistributionResult(sum={}, count={}, min={}, max={})>'.format(
+ return 'DistributionResult(sum={}, count={}, min={}, max={})'.format(
self.sum,
self.count,
self.min,
@@ -206,11 +262,11 @@
@property
def max(self):
- return self.data.max
+ return self.data.max if self.data.count else None
@property
def min(self):
- return self.data.min
+ return self.data.min if self.data.count else None
@property
def count(self):
@@ -340,10 +396,15 @@
by other than the DistributionCell that contains it.
"""
def __init__(self, sum, count, min, max):
- self.sum = sum
- self.count = count
- self.min = min
- self.max = max
+ if count:
+ self.sum = sum
+ self.count = count
+ self.min = min
+ self.max = max
+ else:
+ self.sum = self.count = 0
+ self.min = 2**63 - 1
+ self.max = -2**63
def __eq__(self, other):
return (self.sum == other.sum and
@@ -359,7 +420,7 @@
return not self == other
def __repr__(self):
- return '<DistributionData(sum={}, count={}, min={}, max={})>'.format(
+ return 'DistributionData(sum={}, count={}, min={}, max={})'.format(
self.sum,
self.count,
self.min,
@@ -372,15 +433,11 @@
if other is None:
return self
- new_min = (None if self.min is None and other.min is None else
- min(x for x in (self.min, other.min) if x is not None))
- new_max = (None if self.max is None and other.max is None else
- max(x for x in (self.max, other.max) if x is not None))
return DistributionData(
self.sum + other.sum,
self.count + other.count,
- new_min,
- new_max)
+ self.min if self.min < other.min else other.min,
+ self.max if self.max > other.max else other.max)
@staticmethod
def singleton(value):
@@ -449,7 +506,7 @@
"""
@staticmethod
def identity_element():
- return DistributionData(0, 0, None, None)
+ return DistributionData(0, 0, 2**63 - 1, -2**63)
def combine(self, x, y):
return x.combine(y)
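The sentinel-based identity element replaces the old None handling for empty distributions; combining the identity with any real sample must leave the sample unchanged. A quick plain-Python check of that invariant (the tuple mirrors DistributionData's fields):

```python
# Sketch: the int64 sentinels behave as an identity for combine().
INT64_MAX = 2**63 - 1
INT64_MIN = -2**63

def combine(a, b):
  # (sum, count, min, max), as in DistributionData.combine above.
  return (a[0] + b[0], a[1] + b[1], min(a[2], b[2]), max(a[3], b[3]))

identity = (0, 0, INT64_MAX, INT64_MIN)
sample = (12, 2, 2, 10)  # the DistributionData(12, 2, 2, 10) from the tests
assert combine(identity, sample) == sample
assert combine(sample, identity) == sample
```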
diff --git a/sdks/python/apache_beam/metrics/execution.pxd b/sdks/python/apache_beam/metrics/execution.pxd
index 74b34fb..6e1cbb0 100644
--- a/sdks/python/apache_beam/metrics/execution.pxd
+++ b/sdks/python/apache_beam/metrics/execution.pxd
@@ -16,10 +16,30 @@
#
cimport cython
+cimport libc.stdint
+
+from apache_beam.metrics.cells cimport MetricCell
+
+
+cdef object get_current_tracker
+
+
+cdef class _TypedMetricName(object):
+ cdef readonly object cell_type
+ cdef readonly object metric_name
+ cdef readonly object fast_name
+ cdef libc.stdint.int64_t _hash
+
+
+cdef object _DEFAULT
+
+
+cdef class MetricUpdater(object):
+ cdef _TypedMetricName typed_metric_name
+ cdef object default
cdef class MetricsContainer(object):
cdef object step_name
- cdef public object counters
- cdef public object distributions
- cdef public object gauges
+ cdef public dict metrics
+ cpdef MetricCell get_metric_cell(self, metric_key)
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 91fe2f8..6918914 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -33,14 +33,13 @@
from __future__ import absolute_import
from builtins import object
-from collections import defaultdict
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.cells import CounterCell
from apache_beam.metrics.cells import DistributionCell
from apache_beam.metrics.cells import GaugeCell
-from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker import statesampler
+from apache_beam.runners.worker.statesampler import get_current_tracker
class MetricKey(object):
@@ -150,88 +149,117 @@
MetricsEnvironment = _MetricsEnvironment()
+class _TypedMetricName(object):
+ """Like MetricName, but also stores the cell type of the metric."""
+ def __init__(self, cell_type, metric_name):
+ self.cell_type = cell_type
+ self.metric_name = metric_name
+ if isinstance(metric_name, str):
+ self.fast_name = metric_name
+ else:
+ self.fast_name = '%d_%s%s' % (
+ len(metric_name.name), metric_name.name, metric_name.namespace)
+ # Cached for speed, as this is used as a key for every counter update.
+ self._hash = hash((cell_type, self.fast_name))
+
+ def __eq__(self, other):
+ return self is other or (
+ self.cell_type == other.cell_type and self.fast_name == other.fast_name)
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return self._hash
+
+ def __reduce__(self):
+ return _TypedMetricName, (self.cell_type, self.metric_name)
+
+
+_DEFAULT = None
+
+
+class MetricUpdater(object):
+ """A callable that updates the metric as quickly as possible."""
+ def __init__(self, cell_type, metric_name, default=None):
+ self.typed_metric_name = _TypedMetricName(cell_type, metric_name)
+ self.default = default
+
+ def __call__(self, value=_DEFAULT):
+ if value is _DEFAULT:
+ if self.default is _DEFAULT:
+ raise ValueError(
+ 'Missing value for update of %s' % self.typed_metric_name.metric_name)
+ value = self.default
+ tracker = get_current_tracker()
+ if tracker is not None:
+ tracker.update_metric(self.typed_metric_name, value)
+
+ def __reduce__(self):
+ return MetricUpdater, (
+ self.typed_metric_name.cell_type,
+ self.typed_metric_name.metric_name,
+ self.default)
+
+
class MetricsContainer(object):
"""Holds the metrics of a single step and a single bundle."""
def __init__(self, step_name):
self.step_name = step_name
- self.counters = defaultdict(lambda: CounterCell())
- self.distributions = defaultdict(lambda: DistributionCell())
- self.gauges = defaultdict(lambda: GaugeCell())
+ self.metrics = dict()
def get_counter(self, metric_name):
- return self.counters[metric_name]
+ return self.get_metric_cell(_TypedMetricName(CounterCell, metric_name))
def get_distribution(self, metric_name):
- return self.distributions[metric_name]
+ return self.get_metric_cell(_TypedMetricName(DistributionCell, metric_name))
def get_gauge(self, metric_name):
- return self.gauges[metric_name]
+ return self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name))
+
+ def get_metric_cell(self, typed_metric_name):
+ cell = self.metrics.get(typed_metric_name, None)
+ if cell is None:
+ cell = self.metrics[typed_metric_name] = typed_metric_name.cell_type()
+ return cell
def get_cumulative(self):
"""Return MetricUpdates with cumulative values of all metrics in container.
This returns all the cumulative values for all metrics.
"""
- counters = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.counters.items()}
+ counters = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == CounterCell}
- distributions = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.distributions.items()}
+ distributions = {
+ MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == DistributionCell}
- gauges = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.gauges.items()}
+ gauges = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == GaugeCell}
return MetricUpdates(counters, distributions, gauges)
def to_runner_api(self):
- return (
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
- value=v.get_cumulative()))
- for k, v in self.counters.items()] +
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- distribution_data=v.get_cumulative().to_runner_api())
- for k, v in self.distributions.items()] +
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- gauge_data=v.get_cumulative().to_runner_api())
- for k, v in self.gauges.items()]
- )
+ return [cell.to_runner_api_user_metric(key.metric_name)
+ for key, cell in self.metrics.items()]
def to_runner_api_monitoring_infos(self, transform_id):
"""Returns a list of MonitoringInfos for the metrics in this container."""
- all_user_metrics = []
- for k, v in self.counters.items():
- all_user_metrics.append(monitoring_infos.int64_user_counter(
- k.namespace, k.name,
- v.to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
-
- for k, v in self.distributions.items():
- all_user_metrics.append(monitoring_infos.int64_user_distribution(
- k.namespace, k.name,
- v.get_cumulative().to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
-
- for k, v in self.gauges.items():
- all_user_metrics.append(monitoring_infos.int64_user_gauge(
- k.namespace, k.name,
- v.get_cumulative().to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
+ all_user_metrics = [
+ cell.to_runner_api_monitoring_info(key.metric_name, transform_id)
+ for key, cell in self.metrics.items()]
return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
def reset(self):
- for counter in self.counters.values():
- counter.reset()
- for distribution in self.distributions.values():
- distribution.reset()
- for gauge in self.gauges.values():
- gauge.reset()
+ for metric in self.metrics.values():
+ metric.reset()
+
+ def __reduce__(self):
+ raise NotImplementedError
class MetricUpdates(object):
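With the container collapsed to a single dict keyed by _TypedMetricName, a counter update becomes one cached-hash dict lookup instead of a per-type defaultdict access. A short sketch of the user-visible behavior, assuming this patch is applied:

```python
# Sketch: all cell types now live in one dict keyed by _TypedMetricName.
from apache_beam.metrics.execution import MetricsContainer
from apache_beam.metrics.metricbase import MetricName

mc = MetricsContainer('astep')
counter = mc.get_counter(MetricName('my_namespace', 'my_name'))
counter.inc(10)
counter.dec(3)
assert counter.get_cumulative() == 7
# One entry: the typed name (CounterCell, my_namespace/my_name).
assert len(mc.metrics) == 1
```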
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index 9af1696..fc363a4 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -73,12 +73,6 @@
class TestMetricsContainer(unittest.TestCase):
- def test_create_new_counter(self):
- mc = MetricsContainer('astep')
- self.assertFalse(MetricName('namespace', 'name') in mc.counters)
- mc.get_counter(MetricName('namespace', 'name'))
- self.assertTrue(MetricName('namespace', 'name') in mc.counters)
-
def test_add_to_counter(self):
mc = MetricsContainer('astep')
counter = mc.get_counter(MetricName('namespace', 'name'))
diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py
index acd4771..8bbe191 100644
--- a/sdks/python/apache_beam/metrics/metric.py
+++ b/sdks/python/apache_beam/metrics/metric.py
@@ -29,7 +29,8 @@
import inspect
from builtins import object
-from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics import cells
+from apache_beam.metrics.execution import MetricUpdater
from apache_beam.metrics.metricbase import Counter
from apache_beam.metrics.metricbase import Distribution
from apache_beam.metrics.metricbase import Gauge
@@ -101,11 +102,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingCounter, self).__init__()
self.metric_name = metric_name
-
- def inc(self, n=1):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_counter(self.metric_name).inc(n)
+ self.inc = MetricUpdater(cells.CounterCell, metric_name, default=1)
class DelegatingDistribution(Distribution):
"""Metrics Distribution Delegates functionality to MetricsEnvironment."""
@@ -113,11 +110,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingDistribution, self).__init__()
self.metric_name = metric_name
-
- def update(self, value):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_distribution(self.metric_name).update(value)
+ self.update = MetricUpdater(cells.DistributionCell, metric_name)
class DelegatingGauge(Gauge):
"""Metrics Gauge that Delegates functionality to MetricsEnvironment."""
@@ -125,11 +118,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingGauge, self).__init__()
self.metric_name = metric_name
-
- def set(self, value):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_gauge(self.metric_name).set(value)
+ self.set = MetricUpdater(cells.GaugeCell, metric_name)
class MetricResults(object):
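User code is unchanged by this delegation rewrite; inc, update, and set are now MetricUpdater instances bound as attributes, so each call skips the per-call MetricsEnvironment lookup. A minimal usage sketch (names are illustrative):

```python
# Sketch: metrics usage is unchanged; only the dispatch path got faster.
import apache_beam as beam
from apache_beam.metrics.metric import Metrics

class CountingDoFn(beam.DoFn):
  def __init__(self):
    self.elements = Metrics.counter('my_namespace', 'elements')

  def process(self, element):
    # self.elements.inc is a MetricUpdater: it resolves the current state
    # tracker and updates the cell, with no container lookup at this line.
    self.elements.inc()
    yield element
```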
diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py
index 6e8ee08..cb18dc7 100644
--- a/sdks/python/apache_beam/metrics/metric_test.py
+++ b/sdks/python/apache_beam/metrics/metric_test.py
@@ -130,31 +130,36 @@
statesampler.set_current_tracker(sampler)
state1 = sampler.scoped_state('mystep', 'myState',
metrics_container=MetricsContainer('mystep'))
- sampler.start()
- with state1:
- counter_ns = 'aCounterNamespace'
- distro_ns = 'aDistributionNamespace'
- name = 'a_name'
- counter = Metrics.counter(counter_ns, name)
- distro = Metrics.distribution(distro_ns, name)
- counter.inc(10)
- counter.dec(3)
- distro.update(10)
- distro.update(2)
- self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
- self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
- del distro
- del counter
+ try:
+ sampler.start()
+ with state1:
+ counter_ns = 'aCounterNamespace'
+ distro_ns = 'aDistributionNamespace'
+ name = 'a_name'
+ counter = Metrics.counter(counter_ns, name)
+ distro = Metrics.distribution(distro_ns, name)
+ counter.inc(10)
+ counter.dec(3)
+ distro.update(10)
+ distro.update(2)
+ self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
+ self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
- container = MetricsEnvironment.current_container()
- self.assertEqual(
- container.counters[MetricName(counter_ns, name)].get_cumulative(),
- 7)
- self.assertEqual(
- container.distributions[MetricName(distro_ns, name)].get_cumulative(),
- DistributionData(12, 2, 2, 10))
- sampler.stop()
+ del distro
+ del counter
+
+ container = MetricsEnvironment.current_container()
+ self.assertEqual(
+ container.get_counter(
+ MetricName(counter_ns, name)).get_cumulative(),
+ 7)
+ self.assertEqual(
+ container.get_distribution(
+ MetricName(distro_ns, name)).get_cumulative(),
+ DistributionData(12, 2, 2, 10))
+ finally:
+ sampler.stop()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 2ffe432..37e05bf 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -42,6 +42,8 @@
cdef object key_arg_name
cdef object restriction_provider
cdef object restriction_provider_arg_name
+ cdef object watermark_estimator
+ cdef object watermark_estimator_arg_name
cdef class DoFnSignature(object):
@@ -91,7 +93,9 @@
cdef bint cache_globally_windowed_args
cdef object process_method
cdef bint is_splittable
- cdef object restriction_tracker
+ cdef object threadsafe_restriction_tracker
+ cdef object watermark_estimator
+ cdef object watermark_estimator_param
cdef WindowedValue current_windowed_value
cdef bint is_key_param_required
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3e14f3b..8632cfd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
# cython: profile=True
"""Worker operations executor.
@@ -167,6 +166,8 @@
self.key_arg_name = None
self.restriction_provider = None
self.restriction_provider_arg_name = None
+ self.watermark_estimator = None
+ self.watermark_estimator_arg_name = None
for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
if isinstance(v, core.DoFn.StateParam):
@@ -184,6 +185,9 @@
elif isinstance(v, core.DoFn.RestrictionParam):
self.restriction_provider = v.restriction_provider
self.restriction_provider_arg_name = kw
+ elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
+ self.watermark_estimator = v.watermark_estimator
+ self.watermark_estimator_arg_name = kw
def invoke_timer_callback(self,
user_state_context,
@@ -264,6 +268,9 @@
def get_restriction_provider(self):
return self.process_method.restriction_provider
+ def get_watermark_estimator(self):
+ return self.process_method.watermark_estimator
+
def _validate(self):
self._validate_process()
self._validate_bundle_method(self.start_bundle_method)
@@ -458,7 +465,11 @@
signature.is_stateful_dofn())
self.user_state_context = user_state_context
self.is_splittable = signature.is_splittable_dofn()
- self.restriction_tracker = None
+ self.watermark_estimator = self.signature.get_watermark_estimator()
+ self.watermark_estimator_param = (
+ self.signature.process_method.watermark_estimator_arg_name
+ if self.watermark_estimator else None)
+ self.threadsafe_restriction_tracker = None
self.current_windowed_value = None
self.bundle_finalizer_param = bundle_finalizer_param
self.is_key_param_required = False
@@ -569,15 +580,24 @@
raise ValueError(
'A RestrictionTracker %r was provided but DoFn does not have a '
'RestrictionTrackerParam defined' % restriction_tracker)
- additional_kwargs[restriction_tracker_param] = restriction_tracker
+ from apache_beam.io import iobase
+ self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+ restriction_tracker)
+ additional_kwargs[restriction_tracker_param] = (
+ iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+
+ if self.watermark_estimator:
+ # The watermark estimator needs to be reset for every element.
+ self.watermark_estimator.reset()
+ additional_kwargs[self.watermark_estimator_param] = (
+ self.watermark_estimator)
try:
self.current_windowed_value = windowed_value
- self.restriction_tracker = restriction_tracker
return self._invoke_process_per_window(
windowed_value, additional_args, additional_kwargs,
output_processor)
finally:
- self.restriction_tracker = None
+ self.threadsafe_restriction_tracker = None
self.current_windowed_value = windowed_value
elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
@@ -664,24 +684,34 @@
windowed_value, self.process_method(*args_for_process))
if self.is_splittable:
- deferred_status = self.restriction_tracker.deferred_status()
+ # TODO: Consider calling check_done right after SDF.process() finishes.
+ # In order to do this, we need to know that the DoFn currently being
+ # invoked is ProcessSizedElementAndRestriction.
+ self.threadsafe_restriction_tracker.check_done()
+ deferred_status = self.threadsafe_restriction_tracker.deferred_status()
+ output_watermark = None
+ if self.watermark_estimator:
+ output_watermark = self.watermark_estimator.current_watermark()
if deferred_status:
deferred_restriction, deferred_watermark = deferred_status
element = windowed_value.value
size = self.signature.get_restriction_provider().restriction_size(
element, deferred_restriction)
- return (
+ return ((
windowed_value.with_value(((element, deferred_restriction), size)),
- deferred_watermark)
+ output_watermark), deferred_watermark)
def try_split(self, fraction):
- restriction_tracker = self.restriction_tracker
+ restriction_tracker = self.threadsafe_restriction_tracker
current_windowed_value = self.current_windowed_value
if restriction_tracker and current_windowed_value:
# Temporary workaround for [BEAM-7473]: get current_watermark before
# split, in case watermark gets advanced before getting split results.
# In worst case, current_watermark is always stale, which is ok.
- current_watermark = restriction_tracker.current_watermark()
+ if self.watermark_estimator:
+ current_watermark = self.watermark_estimator.current_watermark()
+ else:
+ current_watermark = None
split = restriction_tracker.try_split(fraction)
if split:
primary, residual = split
@@ -690,15 +720,13 @@
primary_size = restriction_provider.restriction_size(element, primary)
residual_size = restriction_provider.restriction_size(element, residual)
return (
- (self.current_windowed_value.with_value(
- ((element, primary), primary_size)),
- None),
- (self.current_windowed_value.with_value(
- ((element, residual), residual_size)),
- current_watermark))
+ ((self.current_windowed_value.with_value((
+ (element, primary), primary_size)), None), None),
+ ((self.current_windowed_value.with_value((
+ (element, residual), residual_size)), current_watermark), None))
def current_element_progress(self):
- restriction_tracker = self.restriction_tracker
+ restriction_tracker = self.threadsafe_restriction_tracker
if restriction_tracker:
return restriction_tracker.current_progress()
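The nested tuple returned for a deferred splittable call is easy to misread. An illustrative unpacking of its shape, ((residual_windowed_value, output_watermark), deferred_watermark), with placeholder values rather than Beam API objects:

```python
# Illustrative only: the result shape after this change is
# ((residual_windowed_value, output_watermark), deferred_watermark).
residual_value = (('element', 'residual_restriction'), 42)  # ((elem, rest), size)
result = ((residual_value, 'output_watermark'), 'deferred_watermark')

(residual, output_watermark), deferred_watermark = result
assert residual == residual_value
assert output_watermark == 'output_watermark'
assert deferred_watermark == 'deferred_watermark'
```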
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index 4a5fef4..50b79e8 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -58,6 +58,7 @@
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.runners.runner import PValueCache
+from apache_beam.runners.utils import is_interactive
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayData
from apache_beam.typehints import typehints
@@ -364,6 +365,16 @@
def run_pipeline(self, pipeline, options):
"""Remotely executes entire pipeline or parts reachable from node."""
+    # Add a goog-dataflow-notebook label if the job is started from a notebook.
+ _, is_in_notebook = is_interactive()
+ if is_in_notebook:
+ notebook_version = ('goog-dataflow-notebook=' +
+ beam.version.__version__.replace('.', '_'))
+ if options.view_as(GoogleCloudOptions).labels:
+ options.view_as(GoogleCloudOptions).labels.append(notebook_version)
+ else:
+ options.view_as(GoogleCloudOptions).labels = [notebook_version]
+
# Import here to avoid adding the dependency for local running scenarios.
try:
# pylint: disable=wrong-import-order, wrong-import-position
@@ -622,9 +633,15 @@
debug_options = options.view_as(DebugOptions)
use_fn_api = (debug_options.experiments and
'beam_fn_api' in debug_options.experiments)
+ use_streaming_engine = (
+ debug_options.experiments and
+ 'enable_streaming_engine' in debug_options.experiments and
+ 'enable_windmill_service' in debug_options.experiments)
+
step = self._add_step(
TransformNames.READ, transform_node.full_label, transform_node)
- if standard_options.streaming and not use_fn_api:
+ if (standard_options.streaming and
+ (not use_fn_api or not use_streaming_engine)):
step.add_property(PropertyNames.FORMAT, 'pubsub')
step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/')
else:
@@ -1195,6 +1212,52 @@
PropertyNames.STEP_NAME: input_step.proto.name,
PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)})
+ def run_TestStream(self, transform_node, options):
+ from apache_beam.portability.api import beam_runner_api_pb2
+ from apache_beam.testing.test_stream import ElementEvent
+ from apache_beam.testing.test_stream import ProcessingTimeEvent
+ from apache_beam.testing.test_stream import WatermarkEvent
+ standard_options = options.view_as(StandardOptions)
+ if not standard_options.streaming:
+ raise ValueError('TestStream is currently available for use '
+ 'only in streaming pipelines.')
+
+ transform = transform_node.transform
+ step = self._add_step(TransformNames.READ, transform_node.full_label,
+ transform_node)
+ step.add_property(PropertyNames.FORMAT, 'test_stream')
+ test_stream_payload = beam_runner_api_pb2.TestStreamPayload()
+ # TestStream source doesn't do any decoding of elements,
+ # so we won't set test_stream_payload.coder_id.
+ output_coder = transform._infer_output_coder() # pylint: disable=protected-access
+ for event in transform.events:
+ new_event = test_stream_payload.events.add()
+ if isinstance(event, ElementEvent):
+ for tv in event.timestamped_values:
+ element = new_event.element_event.elements.add()
+ element.encoded_element = output_coder.encode(tv.value)
+ element.timestamp = tv.timestamp.micros
+ elif isinstance(event, ProcessingTimeEvent):
+ new_event.processing_time_event.advance_duration = (
+ event.advance_by.micros)
+ elif isinstance(event, WatermarkEvent):
+ new_event.watermark_event.new_watermark = event.new_watermark.micros
+ serialized_payload = self.byte_array_to_json_string(
+ test_stream_payload.SerializeToString())
+ step.add_property(PropertyNames.SERIALIZED_TEST_STREAM, serialized_payload)
+
+ step.encoding = self._get_encoded_output_coder(transform_node)
+ step.add_property(PropertyNames.OUTPUT_INFO, [{
+ PropertyNames.USER_NAME:
+ ('%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
+ PropertyNames.ENCODING: step.encoding,
+ PropertyNames.OUTPUT_NAME: PropertyNames.OUT
+ }])
+
+  # We must mark this method as not a test, or else its name matches the
+  # pattern nose uses to discover tests.
+ run_TestStream.__test__ = False
+
@classmethod
def serialize_windowing_strategy(cls, windowing):
from apache_beam.runners import pipeline_context
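
For reference, the notebook label value is the SDK version with dots replaced by
underscores, so the value is safe to use as a label. A small sketch, assuming an
illustrative version string:

    version = '2.17.0'  # assumed for illustration only
    notebook_version = 'goog-dataflow-notebook=' + version.replace('.', '_')
    # notebook_version == 'goog-dataflow-notebook=2_17_0'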
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py
index eacce15..ec6f988 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/names.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py
@@ -45,8 +45,8 @@
# TODO(BEAM-5939): Remove these shared names once Dataflow worker is updated.
PICKLED_MAIN_SESSION_FILE = 'pickled_main_session'
-STAGED_PIPELINE_FILENAME = "pipeline.pb"
-STAGED_PIPELINE_URL_METADATA_FIELD = "pipeline_url"
+STAGED_PIPELINE_FILENAME = 'pipeline.pb'
+STAGED_PIPELINE_URL_METADATA_FIELD = 'pipeline_url'
# Package names for different distributions
BEAM_PACKAGE_NAME = 'apache-beam'
@@ -61,7 +61,8 @@
class TransformNames(object):
"""For internal use only; no backwards-compatibility guarantees.
- Transform strings as they are expected in the CloudWorkflow protos."""
+ Transform strings as they are expected in the CloudWorkflow protos.
+ """
COLLECTION_TO_SINGLETON = 'CollectionToSingleton'
COMBINE = 'CombineValues'
CREATE_PCOLLECTION = 'CreateCollection'
@@ -75,7 +76,8 @@
class PropertyNames(object):
"""For internal use only; no backwards-compatibility guarantees.
- Property strings as they are expected in the CloudWorkflow protos."""
+ Property strings as they are expected in the CloudWorkflow protos.
+ """
BIGQUERY_CREATE_DISPOSITION = 'create_disposition'
BIGQUERY_DATASET = 'dataset'
BIGQUERY_EXPORT_FORMAT = 'bigquery_export_format'
@@ -113,6 +115,7 @@
SERIALIZED_FN = 'serialized_fn'
SHARD_NAME_TEMPLATE = 'shard_template'
SOURCE_STEP_INPUT = 'custom_source_step_input'
+ SERIALIZED_TEST_STREAM = 'serialized_test_stream'
STEP_NAME = 'step_name'
USER_FN = 'user_fn'
USER_NAME = 'user_name'
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index 946ef34..fd04d4c 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -51,6 +51,9 @@
def create_tracker(self, restriction):
return OffsetRestrictionTracker(restriction)
+ def restriction_size(self, element, restriction):
+ return restriction.size()
+
class ReadFiles(DoFn):
@@ -63,12 +66,11 @@
restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
*args, **kwargs):
file_name = element
- assert isinstance(restriction_tracker, OffsetRestrictionTracker)
with open(file_name, 'rb') as file:
- pos = restriction_tracker.start_position()
- if restriction_tracker.start_position() > 0:
- file.seek(restriction_tracker.start_position() - 1)
+ pos = restriction_tracker.current_restriction().start
+ if restriction_tracker.current_restriction().start > 0:
+ file.seek(restriction_tracker.current_restriction().start - 1)
line = file.readline()
pos = pos - 1 + len(line)
@@ -104,6 +106,9 @@
def split(self, element, restriction):
return [restriction,]
+ def restriction_size(self, element, restriction):
+ return restriction.size()
+
class ExpandStrings(DoFn):
@@ -118,10 +123,9 @@
side.extend(side1)
side.extend(side2)
side.extend(side3)
- assert isinstance(restriction_tracker, OffsetRestrictionTracker)
side = list(side)
- for i in range(restriction_tracker.start_position(),
- restriction_tracker.stop_position()):
+ for i in range(restriction_tracker.current_restriction().start,
+ restriction_tracker.current_restriction().stop):
if restriction_tracker.try_claim(i):
if not side:
yield (
diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
index 84025f9..c6ead9d 100644
--- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
@@ -32,17 +32,18 @@
try:
import IPython # pylint: disable=import-error
+ from IPython import get_ipython # pylint: disable=import-error
+ from IPython.display import display as ip_display # pylint: disable=import-error
# _display_progress defines how outputs are printed on the frontend.
- _display_progress = IPython.display.display
+ _display_progress = ip_display
def _formatter(string, pp, cycle): # pylint: disable=unused-argument
pp.text(string)
- plain = get_ipython().display_formatter.formatters['text/plain'] # pylint: disable=undefined-variable
- plain.for_type(str, _formatter)
+ if get_ipython():
+ plain = get_ipython().display_formatter.formatters['text/plain'] # pylint: disable=undefined-variable
+ plain.for_type(str, _formatter)
-# NameError is added here because get_ipython() throws "not defined" NameError
-# if not started with IPython.
-except (ImportError, NameError):
+except ImportError:
IPython = None
_display_progress = print
diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
index 23605b5..aabf959 100644
--- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py
@@ -44,12 +44,12 @@
Examples:
graph = pipeline_graph.PipelineGraph(pipeline_proto)
- graph.display_graph()
+ graph.get_dot()
or
graph = pipeline_graph.PipelineGraph(pipeline)
- graph.display_graph()
+ graph.get_dot()
Args:
pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index 2b0dc7a..2dbc102 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -30,6 +30,7 @@
import apache_beam as beam
from apache_beam.runners import runner
+from apache_beam.runners.utils import is_interactive
_interactive_beam_env = None
@@ -93,17 +94,7 @@
'install apache-beam[interactive]` to install necessary '
'dependencies to enable all data visualization features.')
- self._is_in_ipython = False
- self._is_in_notebook = False
- # Check if the runtime is within an interactive environment, i.e., ipython.
- try:
- from IPython import get_ipython # pylint: disable=import-error
- if get_ipython():
- self._is_in_ipython = True
- if 'IPKernelApp' in get_ipython().config:
- self._is_in_notebook = True
- except ImportError:
- pass
+ self._is_in_ipython, self._is_in_notebook = is_interactive()
if not self._is_in_ipython:
logging.warning('You cannot use Interactive Beam features when you are '
'not in an interactive environment such as a Jupyter '
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
index 342f400..6fa257b 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
@@ -126,7 +126,6 @@
def test_determine_terminal_state(self):
for state in (runner.PipelineState.DONE,
- runner.PipelineState.STOPPED,
runner.PipelineState.FAILED,
runner.PipelineState.CANCELLED,
runner.PipelineState.UPDATED,
@@ -136,7 +135,7 @@
self.assertTrue(ie.current_env().is_terminated(self._p))
for state in (runner.PipelineState.UNKNOWN,
runner.PipelineState.STARTING,
-
+ runner.PipelineState.STOPPED,
runner.PipelineState.RUNNING,
runner.PipelineState.DRAINING,
runner.PipelineState.PENDING,
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner.py b/sdks/python/apache_beam/runners/interactive/interactive_runner.py
index 94c0de7..7b6df70 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner.py
@@ -48,7 +48,8 @@
underlying_runner=None,
cache_dir=None,
cache_format='text',
- render_option=None):
+ render_option=None,
+ skip_display=False):
"""Constructor of InteractiveRunner.
Args:
@@ -58,12 +59,16 @@
PCollection caches. Available options are 'text' and 'tfrecord'.
render_option: (str) this parameter decides how the pipeline graph is
rendered. See display.pipeline_graph_renderer for available options.
+ skip_display: (bool) whether to skip display operations when running the
+ pipeline. Useful if running large pipelines when display is not
+ needed.
"""
self._underlying_runner = (underlying_runner
or direct_runner.DirectRunner())
self._cache_manager = cache.FileBasedCacheManager(cache_dir, cache_format)
self._renderer = pipeline_graph_renderer.get_renderer(render_option)
self._in_session = False
+ self._skip_display = skip_display
def is_fnapi_compatible(self):
# TODO(BEAM-8436): return self._underlying_runner.is_fnapi_compatible()
@@ -140,15 +145,19 @@
self._underlying_runner,
options)
- display = display_manager.DisplayManager(
- pipeline_proto=pipeline_proto,
- pipeline_analyzer=analyzer,
- cache_manager=self._cache_manager,
- pipeline_graph_renderer=self._renderer)
- display.start_periodic_update()
+ if not self._skip_display:
+ display = display_manager.DisplayManager(
+ pipeline_proto=pipeline_proto,
+ pipeline_analyzer=analyzer,
+ cache_manager=self._cache_manager,
+ pipeline_graph_renderer=self._renderer)
+ display.start_periodic_update()
+
result = pipeline_to_execute.run()
result.wait_until_finish()
- display.stop_periodic_update()
+
+ if not self._skip_display:
+ display.stop_periodic_update()
return PipelineResult(result, self, self._analyzer.pipeline_info(),
self._cache_manager, pcolls_to_pcoll_id)
diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py
index 982fad1..5dd497a 100644
--- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py
@@ -23,13 +23,6 @@
from apache_beam.portability.api import beam_job_api_pb2
from apache_beam.portability.api import beam_job_api_pb2_grpc
-TERMINAL_STATES = [
- beam_job_api_pb2.JobState.DONE,
- beam_job_api_pb2.JobState.STOPPED,
- beam_job_api_pb2.JobState.FAILED,
- beam_job_api_pb2.JobState.CANCELLED,
-]
-
class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer):
"""Manages one or more pipelines, possibly concurrently.
@@ -131,6 +124,11 @@
def get_pipeline(self):
return self._pipeline_proto
+ @staticmethod
+ def is_terminal_state(state):
+ from apache_beam.runners.portability import portable_runner
+ return state in portable_runner.TERMINAL_STATES
+
def to_runner_api(self):
return beam_job_api_pb2.JobInfo(
job_id=self._job_id,
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index da2ebdb..377ceb7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -319,10 +319,10 @@
line = f.readline()
self.assertSetEqual(lines_actual, lines_expected)
- def test_sdf_with_sdf_initiated_checkpointing(self):
+ def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("BEAM-2939")
- def test_sdf_synthetic_source(self):
+ def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("BEAM-2939")
def test_callbacks_with_exception(self):
diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
index b69da66..d2a890f 100644
--- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
+++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py
@@ -216,7 +216,7 @@
'IN_PROGRESS': beam_job_api_pb2.JobState.RUNNING,
'COMPLETED': beam_job_api_pb2.JobState.DONE,
}.get(flink_status, beam_job_api_pb2.JobState.UNSPECIFIED)
- if beam_state in abstract_job_service.TERMINAL_STATES:
+ if self.is_terminal_state(beam_state):
self.delete_jar()
return beam_state
@@ -224,7 +224,7 @@
sleep_secs = 1.0
current_state = self.get_state()
yield current_state
- while current_state not in abstract_job_service.TERMINAL_STATES:
+ while not self.is_terminal_state(current_state):
sleep_secs = min(60, sleep_secs * 1.2)
time.sleep(sleep_secs)
previous_state, current_state = current_state, self.get_state()
@@ -233,7 +233,7 @@
def get_message_stream(self):
for state in self.get_state_stream():
- if state in abstract_job_service.TERMINAL_STATES:
+ if self.is_terminal_state(state):
response = self.get('v1/jobs/%s/exceptions' % self._flink_job_id)
for ix, exc in enumerate(response['all-exceptions']):
yield beam_job_api_pb2.JobMessage(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 2204a24..b7929cb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -41,6 +41,7 @@
from tenacity import stop_after_attempt
import apache_beam as beam
+from apache_beam.io import iobase
from apache_beam.io import restriction_trackers
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricKey
@@ -56,9 +57,11 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.tools import utils
+from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms import userstate
from apache_beam.transforms import window
+from apache_beam.utils import timestamp
if statesampler.FAST_SAMPLER:
DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -423,17 +426,14 @@
assert_that(actual, is_buffered_correctly)
def test_sdf(self):
-
class ExpandingStringsDoFn(beam.DoFn):
def process(
self,
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(
- restriction_tracker,
- restriction_trackers.OffsetRestrictionTracker), restriction_tracker
- cur = restriction_tracker.start_position()
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
yield element[cur]
cur += 1
@@ -446,6 +446,56 @@
| beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
+ def test_sdf_with_check_done_failed(self):
+ class ExpandingStringsDoFn(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider())):
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
+ while restriction_tracker.try_claim(cur):
+ yield element[cur]
+ cur += 1
+ return
+ with self.assertRaises(Exception):
+ with self.create_pipeline() as p:
+ data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+ _ = (
+ p
+ | beam.Create(data)
+ | beam.ParDo(ExpandingStringsDoFn()))
+
+ def test_sdf_with_watermark_tracking(self):
+
+ class ExpandingStringsDoFn(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider()),
+ watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
+ core.WatermarkEstimator())):
+ cur = restriction_tracker.current_restriction().start
+ start = cur
+ while restriction_tracker.try_claim(cur):
+ watermark_estimator.set_watermark(timestamp.Timestamp(micros=cur))
+ assert watermark_estimator.current_watermark().micros == start
+ yield element[cur]
+ if cur % 2 == 1:
+ restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
+ return
+ cur += 1
+
+ with self.create_pipeline() as p:
+ data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+ actual = (
+ p
+ | beam.Create(data)
+ | beam.ParDo(ExpandingStringsDoFn()))
+ assert_that(actual, equal_to(list(''.join(data))))
+
def test_sdf_with_sdf_initiated_checkpointing(self):
counter = beam.metrics.Metrics.counter('ns', 'my_counter')
@@ -456,10 +506,8 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(
ExpandStringsProvider())):
- assert isinstance(
- restriction_tracker,
- restriction_trackers.OffsetRestrictionTracker), restriction_tracker
- cur = restriction_tracker.start_position()
+ assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
counter.inc()
yield element[cur]
@@ -1123,6 +1171,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
@@ -1142,6 +1193,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
@@ -1172,6 +1226,9 @@
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_sdf_with_watermark_tracking(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
class FnApiRunnerSplitTest(unittest.TestCase):
@@ -1340,7 +1397,7 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
to_emit = []
- cur = restriction_tracker.start_position()
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
to_emit.append((element, cur))
element_counter.increment()
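
All of the test updates above follow one migration: user process() code now
receives an iobase.RestrictionTrackerView, so positions are read from the
current restriction instead of the tracker itself. A sketch of the updated claim
loop (element and restriction_tracker are placeholders supplied by the
framework):

    def process(element, restriction_tracker):
        # Old API -> new API:
        #   restriction_tracker.start_position() -> current_restriction().start
        #   restriction_tracker.stop_position()  -> current_restriction().stop
        cur = restriction_tracker.current_restriction().start
        while restriction_tracker.try_claim(cur):
            yield element[cur]
            cur += 1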
diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py
index 4305810..6aad1af 100644
--- a/sdks/python/apache_beam/runners/portability/local_job_service.py
+++ b/sdks/python/apache_beam/runners/portability/local_job_service.py
@@ -195,7 +195,7 @@
self._state = None
self._state_queues = []
self._log_queues = []
- self.state = beam_job_api_pb2.JobState.STARTING
+ self.state = beam_job_api_pb2.JobState.STOPPED
self.daemon = True
self.result = None
@@ -220,10 +220,12 @@
return self._artifact_staging_endpoint
def run(self):
+ self.state = beam_job_api_pb2.JobState.STARTING
self._run_thread = threading.Thread(target=self._run_job)
self._run_thread.start()
def _run_job(self):
+ self.state = beam_job_api_pb2.JobState.RUNNING
with JobLogHandler(self._log_queues):
try:
result = fn_api_runner.FnApiRunner(
@@ -239,7 +241,7 @@
raise
def cancel(self):
- if self.state not in abstract_job_service.TERMINAL_STATES:
+ if not self.is_terminal_state(self.state):
self.state = beam_job_api_pb2.JobState.CANCELLING
# TODO(robertwb): Actually cancel...
self.state = beam_job_api_pb2.JobState.CANCELLED
@@ -253,7 +255,7 @@
while True:
current_state = state_queue.get(block=True)
yield current_state
- if current_state in abstract_job_service.TERMINAL_STATES:
+ if self.is_terminal_state(current_state):
break
def get_message_stream(self):
@@ -264,7 +266,7 @@
current_state = self.state
yield current_state
- while current_state not in abstract_job_service.TERMINAL_STATES:
+ while not self.is_terminal_state(current_state):
msg = log_queue.get(block=True)
yield msg
if isinstance(msg, int):
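
With these changes a local job reports STOPPED (i.e. not yet started, per the
JobState comments in beam_job_api.proto) from construction until run() is
called. A sketch of the state sequence expected for a successful job:

    from apache_beam.portability.api import beam_job_api_pb2

    expected_lifecycle = [
        beam_job_api_pb2.JobState.STOPPED,   # constructed
        beam_job_api_pb2.JobState.STARTING,  # run() invoked
        beam_job_api_pb2.JobState.RUNNING,   # _run_job executing
        beam_job_api_pb2.JobState.DONE,      # terminal
    ]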
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 2cffd47..6de9e5c 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -58,7 +58,7 @@
TERMINAL_STATES = [
beam_job_api_pb2.JobState.DONE,
- beam_job_api_pb2.JobState.STOPPED,
+ beam_job_api_pb2.JobState.DRAINED,
beam_job_api_pb2.JobState.FAILED,
beam_job_api_pb2.JobState.CANCELLED,
]
@@ -178,10 +178,9 @@
del transform_proto.subtransforms[:]
# Preemptively apply combiner lifting, until all runners support it.
- # Also apply sdf expansion.
# These optimizations commute and are idempotent.
pre_optimize = options.view_as(DebugOptions).lookup_experiment(
- 'pre_optimize', 'lift_combiners,expand_sdf').lower()
+ 'pre_optimize', 'lift_combiners').lower()
if not options.view_as(StandardOptions).streaming:
flink_known_urns = frozenset([
common_urns.composites.RESHUFFLE.urn,
@@ -210,7 +209,7 @@
phases = []
for phase_name in pre_optimize.split(','):
# For now, these are all we allow.
- if phase_name in ('lift_combiners', 'expand_sdf'):
+        if phase_name in ('lift_combiners',):
phases.append(getattr(fn_api_runner_transforms, phase_name))
else:
raise ValueError(
@@ -274,10 +273,12 @@
for k, v in all_options.items()
if v is not None}
+ prepare_request = beam_job_api_pb2.PrepareJobRequest(
+ job_name='job', pipeline=proto_pipeline,
+ pipeline_options=job_utils.dict_to_struct(p_options))
+ logging.debug('PrepareJobRequest: %s', prepare_request)
prepare_response = job_service.Prepare(
- beam_job_api_pb2.PrepareJobRequest(
- job_name='job', pipeline=proto_pipeline,
- pipeline_options=job_utils.dict_to_struct(p_options)),
+ prepare_request,
timeout=portable_options.job_server_timeout)
if prepare_response.artifact_staging_endpoint.url:
stager = portable_stager.PortableStager(
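
With expand_sdf removed from the allowed phases, the pre_optimize experiment
defaults to combiner lifting only. A sketch of the lookup, mirroring the
lookup_experiment call above:

    from apache_beam.options.pipeline_options import DebugOptions
    from apache_beam.options.pipeline_options import PipelineOptions

    # With no pre_optimize experiment set, the default applies.
    pre_optimize = PipelineOptions([]).view_as(DebugOptions).lookup_experiment(
        'pre_optimize', 'lift_combiners').lower()
    # pre_optimize == 'lift_combiners'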
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 24c6b87..992aa08 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -45,7 +45,10 @@
from apache_beam.runners.portability.portable_runner import PortableRunner
from apache_beam.runners.worker import worker_pool_main
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
from apache_beam.transforms import environments
+from apache_beam.transforms import userstate
class PortableRunnerTest(fn_api_runner_test.FnApiRunnerTest):
@@ -188,6 +191,46 @@
def test_metrics(self):
self.skipTest('Metrics not supported.')
+ def test_pardo_state_with_custom_key_coder(self):
+ """Tests that state requests work correctly when the key coder is an
+    SDK-specific coder, i.e., a non-standard coder. This is additionally enforced
+ by Java's ProcessBundleDescriptorsTest and by Flink's
+ ExecutableStageDoFnOperator which detects invalid encoding by checking for
+ the correct key group of the encoded key."""
+ index_state_spec = userstate.CombiningValueStateSpec('index', sum)
+
+    # Test params
+    # Ensure a decent number of elements to serve all partitions
+ n = 200
+ duplicates = 1
+
+ split = n // (duplicates + 1)
+ inputs = [(i % split, str(i % split)) for i in range(0, n)]
+
+ # Use a DoFn which has to use FastPrimitivesCoder because the type cannot
+ # be inferred
+ class Input(beam.DoFn):
+ def process(self, impulse):
+ for i in inputs:
+ yield i
+
+ class AddIndex(beam.DoFn):
+ def process(self, kv,
+ index=beam.DoFn.StateParam(index_state_spec)):
+ k, v = kv
+ index.add(1)
+ yield k, v, index.read()
+
+ expected = [(i % split, str(i % split), i // split + 1)
+ for i in range(0, n)]
+
+ with self.create_pipeline() as p:
+ assert_that(p
+ | beam.Impulse()
+ | beam.ParDo(Input())
+ | beam.ParDo(AddIndex()),
+ equal_to(expected))
+
# Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py
new file mode 100644
index 0000000..ca03310
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/spark_runner.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A runner for executing portable pipelines on Spark."""
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import re
+
+from apache_beam.options import pipeline_options
+from apache_beam.runners.portability import job_server
+from apache_beam.runners.portability import portable_runner
+
+# https://spark.apache.org/docs/latest/submitting-applications.html#master-urls
+LOCAL_MASTER_PATTERN = r'^local(\[.+\])?$'
+
+
+class SparkRunner(portable_runner.PortableRunner):
+ def run_pipeline(self, pipeline, options):
+ spark_options = options.view_as(SparkRunnerOptions)
+ portable_options = options.view_as(pipeline_options.PortableOptions)
+ if (re.match(LOCAL_MASTER_PATTERN, spark_options.spark_master_url)
+ and not portable_options.environment_type
+ and not portable_options.output_executable_path):
+ portable_options.environment_type = 'LOOPBACK'
+ return super(SparkRunner, self).run_pipeline(pipeline, options)
+
+ def default_job_server(self, options):
+ # TODO(BEAM-8139) submit a Spark jar to a cluster
+ return job_server.StopOnExitJobServer(SparkJarJobServer(options))
+
+
+class SparkRunnerOptions(pipeline_options.PipelineOptions):
+ @classmethod
+ def _add_argparse_args(cls, parser):
+ parser.add_argument('--spark_master_url',
+ default='local[4]',
+ help='Spark master URL (spark://HOST:PORT). '
+ 'Use "local" (single-threaded) or "local[*]" '
+ '(multi-threaded) to start a local cluster for '
+ 'the execution.')
+ parser.add_argument('--spark_job_server_jar',
+ help='Path or URL to a Beam Spark jobserver jar.')
+ parser.add_argument('--artifacts_dir', default=None)
+
+
+class SparkJarJobServer(job_server.JavaJarJobServer):
+ def __init__(self, options):
+ super(SparkJarJobServer, self).__init__()
+ options = options.view_as(SparkRunnerOptions)
+ self._jar = options.spark_job_server_jar
+ self._master_url = options.spark_master_url
+ self._artifacts_dir = options.artifacts_dir
+
+ def path_to_jar(self):
+ if self._jar:
+ return self._jar
+ else:
+ return self.path_to_beam_jar('runners:spark:job-server:shadowJar')
+
+ def java_arguments(self, job_port, artifacts_dir):
+ return [
+ '--spark-master-url', self._master_url,
+ '--artifacts-dir', (self._artifacts_dir
+ if self._artifacts_dir else artifacts_dir),
+ '--job-port', job_port,
+ '--artifact-port', 0,
+ '--expansion-port', 0
+ ]
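
A minimal usage sketch for the new runner (the pipeline itself is illustrative;
with a local[...] master and no explicit environment, the LOOPBACK environment
is selected automatically, as implemented in run_pipeline above):

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions

    options = PipelineOptions([
        '--runner=SparkRunner',
        '--spark_master_url=local[4]',
    ])
    with beam.Pipeline(options=options) as p:
        _ = p | beam.Create(['a', 'b', 'c'])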
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index 1000480..fe9c492 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -38,6 +38,7 @@
'apache_beam.runners.interactive.interactive_runner.InteractiveRunner',
'apache_beam.runners.portability.flink_runner.FlinkRunner',
'apache_beam.runners.portability.portable_runner.PortableRunner',
+ 'apache_beam.runners.portability.spark_runner.SparkRunner',
'apache_beam.runners.test.TestDirectRunner',
'apache_beam.runners.test.TestDataflowRunner',
)
@@ -328,7 +329,7 @@
@classmethod
def is_terminal(cls, state):
- return state in [cls.STOPPED, cls.DONE, cls.FAILED, cls.CANCELLED,
+ return state in [cls.DONE, cls.FAILED, cls.CANCELLED,
cls.UPDATED, cls.DRAINED]
diff --git a/sdks/python/apache_beam/runners/utils.py b/sdks/python/apache_beam/runners/utils.py
new file mode 100644
index 0000000..8952423
--- /dev/null
+++ b/sdks/python/apache_beam/runners/utils.py
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Common utility module shared by runners.
+
+For internal use only; no backwards-compatibility guarantees.
+"""
+from __future__ import absolute_import
+
+
+def is_interactive():
+ """Determines if current code execution is in interactive environment.
+
+ Returns:
+ is_in_ipython: (bool) tells if current code is executed within an ipython
+ session.
+ is_in_notebook: (bool) tells if current code is executed from an ipython
+ notebook.
+
+ If is_in_notebook is True, then is_in_ipython must also be True.
+ """
+ is_in_ipython = False
+ is_in_notebook = False
+ # Check if the runtime is within an interactive environment, i.e., ipython.
+ try:
+ from IPython import get_ipython # pylint: disable=import-error
+ if get_ipython():
+ is_in_ipython = True
+ if 'IPKernelApp' in get_ipython().config:
+ is_in_notebook = True
+ except ImportError:
+    pass  # If IPython is not available, this is not an interactive environment.
+ return is_in_ipython, is_in_notebook
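
Callers destructure the returned pair, as dataflow_runner.py and
interactive_environment.py now do. Expected values, as a sketch:

    from apache_beam.runners.utils import is_interactive

    # (False, False) in a plain Python process, (True, False) in an ipython
    # shell, (True, True) inside a notebook kernel.
    is_in_ipython, is_in_notebook = is_interactive()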
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 8439c8f..b3440df 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -32,6 +32,7 @@
from builtins import object
from future.utils import itervalues
+from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
import apache_beam as beam
@@ -704,8 +705,7 @@
) = split
if element_primary:
split_response.primary_roots.add().CopyFrom(
- self.delayed_bundle_application(
- *element_primary).application)
+ self.bundle_application(*element_primary))
if element_residual:
split_response.residual_roots.add().CopyFrom(
self.delayed_bundle_application(*element_residual))
@@ -718,22 +718,39 @@
return split_response
def delayed_bundle_application(self, op, deferred_remainder):
- transform_id, main_input_tag, main_input_coder, outputs = op.input_info
# TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
- element_and_restriction, watermark = deferred_remainder
- if watermark:
- proto_watermark = timestamp_pb2.Timestamp()
- proto_watermark.FromMicroseconds(watermark.micros)
- output_watermarks = {output: proto_watermark for output in outputs}
+ ((element_and_restriction, output_watermark),
+ deferred_watermark) = deferred_remainder
+ if deferred_watermark:
+ assert isinstance(deferred_watermark, timestamp.Duration)
+ proto_deferred_watermark = duration_pb2.Duration()
+ proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
+ else:
+ proto_deferred_watermark = None
+ return beam_fn_api_pb2.DelayedBundleApplication(
+ requested_time_delay=proto_deferred_watermark,
+ application=self.construct_bundle_application(
+ op, output_watermark, element_and_restriction))
+
+ def bundle_application(self, op, primary):
+ ((element_and_restriction, output_watermark),
+ _) = primary
+ return self.construct_bundle_application(
+ op, output_watermark, element_and_restriction)
+
+ def construct_bundle_application(self, op, output_watermark, element):
+ transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+ if output_watermark:
+ proto_output_watermark = timestamp_pb2.Timestamp()
+ proto_output_watermark.FromMicroseconds(output_watermark.micros)
+ output_watermarks = {output: proto_output_watermark for output in outputs}
else:
output_watermarks = None
- return beam_fn_api_pb2.DelayedBundleApplication(
- application=beam_fn_api_pb2.BundleApplication(
- transform_id=transform_id,
- input_id=main_input_tag,
- output_watermarks=output_watermarks,
- element=main_input_coder.get_impl().encode_nested(
- element_and_restriction)))
+ return beam_fn_api_pb2.BundleApplication(
+ transform_id=transform_id,
+ input_id=main_input_tag,
+ output_watermarks=output_watermarks,
+ element=main_input_coder.get_impl().encode_nested(element))
def metrics(self):
# DEPRECATED
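
The deferral delay is now serialized as a protobuf Duration instead of an
absolute Timestamp. A sketch of the conversion performed in
delayed_bundle_application, with an illustrative delay:

    from google.protobuf import duration_pb2
    from apache_beam.utils import timestamp

    deferred_watermark = timestamp.Duration(micros=5)
    proto_deferred_watermark = duration_pb2.Duration()
    proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)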
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 208fe75..74a3e99 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -182,17 +182,7 @@
self._responses.put(response)
def _request_register(self, request):
-
- def task():
- for process_bundle_descriptor in getattr(
- request, request.WhichOneof('request')).process_bundle_descriptor:
- self._fns[process_bundle_descriptor.id] = process_bundle_descriptor
-
- return beam_fn_api_pb2.InstructionResponse(
- instruction_id=request.instruction_id,
- register=beam_fn_api_pb2.RegisterResponse())
-
- self._execute(task, request)
+ self._request_execute(request)
def _request_process_bundle(self, request):
@@ -241,6 +231,9 @@
self._progress_thread_pool.submit(task)
def _request_finalize_bundle(self, request):
+ self._request_execute(request)
+
+ def _request_execute(self, request):
def task():
# Get one available worker.
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
index 799bd0d..aebf9f6 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
@@ -43,6 +43,9 @@
cdef int32_t current_state_index
+ cpdef ScopedState current_state(self)
+ cdef inline ScopedState current_state_c(self)
+
cpdef _scoped_state(
self, counter_name, name_context, output_counter, metrics_container)
@@ -56,7 +59,7 @@
cdef readonly object name_context
cdef readonly int64_t _nsecs
cdef int32_t old_state_index
- cdef readonly MetricsContainer _metrics_container
+ cdef readonly MetricsContainer metrics_container
cpdef __enter__(self)
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 325ec99..8d2346a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -159,8 +159,12 @@
(<ScopedState>state)._nsecs = 0
self.started = self.finished = False
- def current_state(self):
- return self.scoped_states_by_index[self.current_state_index]
+ cpdef ScopedState current_state(self):
+ return self.current_state_c()
+
+ cdef inline ScopedState current_state_c(self):
+    # Faster than the cpdef version, because self is always a Python subclass.
+ return <ScopedState>self.scoped_states_by_index[self.current_state_index]
cpdef _scoped_state(self, counter_name, name_context, output_counter,
metrics_container):
@@ -189,6 +193,11 @@
pythread.PyThread_release_lock(self.lock)
return scoped_state
+ def update_metric(self, typed_metric_name, value):
+ # Each of these is a cdef lookup.
+ self.current_state_c().metrics_container.get_metric_cell(
+ typed_metric_name).update(value)
+
cdef class ScopedState(object):
"""Context manager class managing transitions for a given sampler state."""
@@ -205,7 +214,7 @@
self.name_context = step_name_context
self.state_index = state_index
self.counter = counter
- self._metrics_container = metrics_container
+ self.metrics_container = metrics_container
@property
def nsecs(self):
@@ -232,7 +241,3 @@
self.sampler.current_state_index = self.old_state_index
self.sampler.state_transition_count += 1
pythread.PyThread_release_lock(self.sampler.lock)
-
- @property
- def metrics_container(self):
- return self._metrics_container
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 0091828..fb2592c 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -50,6 +50,10 @@
return ScopedState(
self, counter_name, name_context, output_counter, metrics_container)
+ def update_metric(self, typed_metric_name, value):
+ self.current_state().metrics_container.get_metric_cell(
+ typed_metric_name).update(value)
+
def _enter_state(self, state):
self.state_transition_count += 1
self._state_stack.append(state)
diff --git a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
index 0e01ad9..cac0c74 100644
--- a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
+++ b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml
@@ -139,6 +139,32 @@
final: false, index: 3, nonspeculative_index: 1}
---
+name: discarding_early_fixed
+window_fn: FixedWindows(10)
+trigger_fn: AfterWatermark(early=AfterCount(2))
+timestamp_combiner: OUTPUT_AT_EOW
+accumulation_mode: discarding
+transcript:
+- input: [1, 2, 3]
+- expect:
+ - {window: [0, 9], values: [1, 2, 3], timestamp: 9, early: true, index: 0}
+- input: [4] # no output
+- input: [14] # no output
+- input: [5]
+- expect:
+ - {window: [0, 9], values: [4, 5], timestamp: 9, early: true, index: 1}
+- input: [18]
+- expect:
+ - {window: [10, 19], values: [14, 18], timestamp: 19, early: true, index: 0}
+- input: [6]
+- watermark: 100
+- expect:
+  - {window: [0, 9], values: [6], timestamp: 9, early: false, late: false,
+     final: true, index: 2, nonspeculative_index: 0}
+  - {window: [10, 19], values: [], timestamp: 19, early: false, late: false,
+     final: true, index: 1, nonspeculative_index: 0}
+
+---
name: garbage_collection
broken_on:
- SwitchingDirectRunner # claims pipeline stall
diff --git a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
index 16ffee9..ec17ef6 100644
--- a/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
+++ b/sdks/python/apache_beam/testing/pipeline_verifiers_test.py
@@ -66,12 +66,11 @@
def test_pipeline_state_matcher_fails(self):
"""Test PipelineStateMatcher fails when using default expected state
- and job actually finished in CANCELLED/DRAINED/FAILED/STOPPED/UNKNOWN
+ and job actually finished in CANCELLED/DRAINED/FAILED/UNKNOWN
"""
failed_state = [PipelineState.CANCELLED,
PipelineState.DRAINED,
PipelineState.FAILED,
- PipelineState.STOPPED,
PipelineState.UNKNOWN]
for state in failed_state:
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index 50740ba..fbef112 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -523,7 +523,7 @@
element,
restriction_tracker=beam.DoFn.RestrictionParam(
SyntheticSDFSourceRestrictionProvider())):
- cur = restriction_tracker.start_position()
+ cur = restriction_tracker.current_restriction().start
while restriction_tracker.try_claim(cur):
r = np.random.RandomState(cur)
time.sleep(element['sleep_per_input_record_sec'])
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 02a8607..9d9284c 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -31,6 +31,8 @@
from apache_beam import coders
from apache_beam import core
from apache_beam import pvalue
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import PTransform
from apache_beam.transforms import window
from apache_beam.transforms.window import TimestampedValue
@@ -66,6 +68,28 @@
# TODO(BEAM-5949): Needed for Python 2 compatibility.
return not self == other
+ @abstractmethod
+ def to_runner_api(self, element_coder):
+ raise NotImplementedError
+
+ @staticmethod
+ def from_runner_api(proto, element_coder):
+ if proto.HasField('element_event'):
+ return ElementEvent(
+ [TimestampedValue(
+ element_coder.decode(tv.encoded_element),
+ timestamp.Timestamp(micros=1000 * tv.timestamp))
+ for tv in proto.element_event.elements])
+ elif proto.HasField('watermark_event'):
+ return WatermarkEvent(timestamp.Timestamp(
+ micros=1000 * proto.watermark_event.new_watermark))
+ elif proto.HasField('processing_time_event'):
+ return ProcessingTimeEvent(timestamp.Duration(
+ micros=1000 * proto.processing_time_event.advance_duration))
+ else:
+ raise ValueError(
+ 'Unknown TestStream Event type: %s' % proto.WhichOneof('event'))
+
class ElementEvent(Event):
"""Element-producing test stream event."""
@@ -82,6 +106,15 @@
def __lt__(self, other):
return self.timestamped_values < other.timestamped_values
+ def to_runner_api(self, element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+ element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
+ elements=[
+ beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+ encoded_element=element_coder.encode(tv.value),
+ timestamp=tv.timestamp.micros // 1000)
+ for tv in self.timestamped_values]))
+
class WatermarkEvent(Event):
"""Watermark-advancing test stream event."""
@@ -98,6 +131,11 @@
def __lt__(self, other):
return self.new_watermark < other.new_watermark
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+        watermark_event=beam_runner_api_pb2.TestStreamPayload.Event
+        .AdvanceWatermark(new_watermark=self.new_watermark.micros // 1000))
class ProcessingTimeEvent(Event):
"""Processing time-advancing test stream event."""
@@ -114,6 +152,12 @@
def __lt__(self, other):
return self.advance_by < other.advance_by
+ def to_runner_api(self, unused_element_coder):
+ return beam_runner_api_pb2.TestStreamPayload.Event(
+        processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event
+        .AdvanceProcessingTime(advance_duration=self.advance_by.micros // 1000))
+
class TestStream(PTransform):
"""Test stream that generates events on an unbounded PCollection of elements.
@@ -123,11 +167,12 @@
output.
"""
- def __init__(self, coder=coders.FastPrimitivesCoder):
+ def __init__(self, coder=coders.FastPrimitivesCoder(), events=()):
+ super(TestStream, self).__init__()
assert coder is not None
self.coder = coder
self.current_watermark = timestamp.MIN_TIMESTAMP
- self.events = []
+ self.events = list(events)
def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
@@ -206,3 +251,19 @@
"""
self._add(ProcessingTimeEvent(advance_by))
return self
+
+ def to_runner_api_parameter(self, context):
+ return (
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload(
+ coder_id=context.coders.get_id(self.coder),
+ events=[e.to_runner_api(self.coder) for e in self.events]))
+
+ @PTransform.register_urn(
+ common_urns.primitives.TEST_STREAM.urn,
+ beam_runner_api_pb2.TestStreamPayload)
+ def from_runner_api_parameter(payload, context):
+ coder = context.coders.get_by_id(payload.coder_id)
+ return TestStream(
+ coder=coder,
+ events=[Event.from_runner_api(e, coder) for e in payload.events])
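
A round-trip sketch of the new (de)serialization hooks; note that the payload
stores millisecond timestamps, so sub-millisecond precision is dropped on the
way through:

    from apache_beam import coders
    from apache_beam.testing.test_stream import Event
    from apache_beam.testing.test_stream import TestStream

    coder = coders.FastPrimitivesCoder()
    stream = (TestStream(coder=coder)
              .advance_watermark_to(10)
              .add_elements(['a', 'b']))
    protos = [event.to_runner_api(coder) for event in stream.events]
    events = [Event.from_runner_api(proto, coder) for proto in protos]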
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 3297f63..c8bc9ff 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -101,9 +101,11 @@
.advance_processing_time(10)
.advance_watermark_to(300)
.add_elements([TimestampedValue('late', 12)])
- .add_elements([TimestampedValue('last', 310)]))
+ .add_elements([TimestampedValue('last', 310)])
+ .advance_watermark_to_infinity())
class RecordFn(beam.DoFn):
+
def process(self, element=beam.DoFn.ElementParam,
timestamp=beam.DoFn.TimestampParam):
yield (element, timestamp)
@@ -135,7 +137,8 @@
.advance_processing_time(10)
.advance_watermark_to(300)
.add_elements([TimestampedValue('late', 12)])
- .add_elements([TimestampedValue('last', 310)]))
+ .add_elements([TimestampedValue('last', 310)])
+ .advance_watermark_to_infinity())
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
@@ -175,7 +178,8 @@
test_stream = (TestStream()
.advance_watermark_to(10)
.add_elements(['a'])
- .advance_watermark_to(20))
+ .advance_watermark_to(20)
+ .advance_watermark_to_infinity())
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
@@ -217,7 +221,8 @@
test_stream = (TestStream()
.advance_watermark_to(10)
.add_elements(['a'])
- .advance_processing_time(5.1))
+ .advance_processing_time(5.1)
+ .advance_watermark_to_infinity())
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
@@ -255,12 +260,14 @@
main_stream = (p
| 'main TestStream' >> TestStream()
.advance_watermark_to(10)
- .add_elements(['e']))
+ .add_elements(['e'])
+ .advance_watermark_to_infinity())
side = (p
| beam.Create([2, 1, 4])
| beam.Map(lambda t: window.TimestampedValue(t, t)))
class RecordFn(beam.DoFn):
+
def process(self,
elm=beam.DoFn.ElementParam,
ts=beam.DoFn.TimestampParam,
@@ -316,6 +323,7 @@
.add_elements(['a'])
.advance_watermark_to(4)
.add_elements(['b'])
+ .advance_watermark_to_infinity()
| 'main window' >> beam.WindowInto(window.FixedWindows(1)))
side = (p
| beam.Create([2, 1, 4])
@@ -323,6 +331,7 @@
| beam.WindowInto(window.FixedWindows(2)))
class RecordFn(beam.DoFn):
+
def process(self,
elm=beam.DoFn.ElementParam,
ts=beam.DoFn.TimestampParam,
@@ -334,8 +343,8 @@
# assert per window
expected_window_to_elements = {
- window.IntervalWindow(2, 3):[('a', Timestamp(2), [2])],
- window.IntervalWindow(4, 5):[('b', Timestamp(4), [4])]
+ window.IntervalWindow(2, 3): [('a', Timestamp(2), [2])],
+ window.IntervalWindow(4, 5): [('b', Timestamp(4), [4])]
}
assert_that(
records,
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 148caae..06fd201 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -63,6 +63,7 @@
from apache_beam.typehints.decorators import get_type_hints
from apache_beam.typehints.trivial_inference import element_type
from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import timestamp
from apache_beam.utils import urns
try:
@@ -91,7 +92,8 @@
'Flatten',
'Create',
'Impulse',
- 'RestrictionProvider'
+ 'RestrictionProvider',
+ 'WatermarkEstimator'
]
# Type variables
@@ -242,6 +244,8 @@
def create_tracker(self, restriction):
"""Produces a new ``RestrictionTracker`` for the given restriction.
+ This API is required to be implemented.
+
Args:
restriction: an object that defines a restriction as identified by a
Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
@@ -252,7 +256,10 @@
raise NotImplementedError
def initial_restriction(self, element):
- """Produces an initial restriction for the given element."""
+ """Produces an initial restriction for the given element.
+
+ This API is required to be implemented.
+ """
raise NotImplementedError
def split(self, element, restriction):
@@ -262,6 +269,9 @@
reading input element for each of the returned restrictions should be the
same as the total set of elements produced by reading the input element for
the input restriction.
+
+ This API is optional if ``split_and_size`` has been implemented.
+
"""
yield restriction
@@ -281,11 +291,16 @@
By default, asks a newly-created restriction tracker for the default size
of the restriction.
+
+ This API is required to be implemented.
"""
- return self.create_tracker(restriction).default_size()
+ raise NotImplementedError
def split_and_size(self, element, restriction):
"""Like split, but also does sizing, returning (restriction, size) pairs.
+
+ This API is optional if ``split`` and ``restriction_size`` have been
+ implemented.
"""
for part in self.split(element, restriction):
yield part, self.restriction_size(element, part)
@@ -379,6 +394,43 @@
return None
+class WatermarkEstimator(object):
+ """A WatermarkEstimator which is used for tracking output_watermark in a
+ DoFn.process(), typically tracking per <element, restriction> pair in SDF in
+ streaming.
+
+ There are 3 APIs in this class: set_watermark, current_watermark and reset
+ with default implementations.
+
+ TODO(BEAM-8537): Create WatermarkEstimatorProvider to support different types.
+ """
+ def __init__(self):
+ self._watermark = None
+
+ def set_watermark(self, watermark):
+ """Update tracking output_watermark with latest output_watermark.
+ This function is called inside an SDF.Process() to track the watermark of
+ output element.
+
+ Args:
+      watermark: the `timestamp.Timestamp` of the current output element.
+ """
+ if not isinstance(watermark, timestamp.Timestamp):
+      raise ValueError('watermark should be an instance of timestamp.Timestamp')
+ if self._watermark is None:
+ self._watermark = watermark
+ else:
+ self._watermark = min(self._watermark, watermark)
+
+ def current_watermark(self):
+ """Get current output_watermark. This function is called by system."""
+ return self._watermark
+
+ def reset(self):
+ """ Reset current tracking watermark to None."""
+ self._watermark = None
+
+
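
A short sketch of the min-tracking behaviour documented above (the same
behaviour is exercised by core_test.py later in this diff):

    from apache_beam.transforms.core import WatermarkEstimator
    from apache_beam.utils.timestamp import Timestamp

    estimator = WatermarkEstimator()
    estimator.set_watermark(Timestamp(100))
    estimator.set_watermark(Timestamp(50))  # the minimum is retained
    assert estimator.current_watermark() == Timestamp(50)
    estimator.reset()
    assert estimator.current_watermark() is None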
class _DoFnParam(object):
"""DoFn parameter."""
@@ -459,6 +511,17 @@
del self._callbacks[:]
+class _WatermarkEstimatorParam(_DoFnParam):
+ """WatermarkEstomator DoFn parameter."""
+
+ def __init__(self, watermark_estimator):
+ if not isinstance(watermark_estimator, WatermarkEstimator):
+      raise ValueError('DoFn.WatermarkEstimatorParam expected a '
+                       'WatermarkEstimator object.')
+ self.watermark_estimator = watermark_estimator
+ self.param_id = 'WatermarkEstimator'
+
+
class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
"""A function object used by a transform with custom processing.
@@ -477,7 +540,7 @@
TimestampParam = _DoFnParam('TimestampParam')
WindowParam = _DoFnParam('WindowParam')
PaneInfoParam = _DoFnParam('PaneInfoParam')
- WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+ WatermarkEstimatorParam = _WatermarkEstimatorParam
BundleFinalizerParam = _BundleFinalizerParam
KeyParam = _DoFnParam('KeyParam')
@@ -489,7 +552,7 @@
TimerParam = _TimerDoFnParam
DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
- WindowParam, WatermarkReporterParam, PaneInfoParam,
+ WindowParam, WatermarkEstimatorParam, PaneInfoParam,
BundleFinalizerParam, KeyParam, StateParam, TimerParam]
RestrictionParam = _RestrictionDoFnParam
@@ -522,7 +585,7 @@
``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
provided here to allow treatment as a Splittable ``DoFn``. The restriction
tracker will be derived from the restriction provider in the parameter.
- ``DoFn.WatermarkReporterParam``: a function that can be used to report
+  ``DoFn.WatermarkEstimatorParam``: a ``WatermarkEstimator`` used to track the
output watermark of Splittable ``DoFn`` implementations.
Args:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
new file mode 100644
index 0000000..1a27bd2
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for core module."""
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.transforms.core import WatermarkEstimator
+from apache_beam.utils.timestamp import Timestamp
+
+
+class WatermarkEstimatorTest(unittest.TestCase):
+
+ def test_set_watermark(self):
+ watermark_estimator = WatermarkEstimator()
+ self.assertEqual(watermark_estimator.current_watermark(), None)
+ # set_watermark should only accept timestamp.Timestamp.
+ with self.assertRaises(ValueError):
+ watermark_estimator.set_watermark(0)
+
+    # The watermark_estimator should always keep the minimal timestamp seen.
+ watermark_estimator.set_watermark(Timestamp(100))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.set_watermark(Timestamp(150))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.set_watermark(Timestamp(50))
+ self.assertEqual(watermark_estimator.current_watermark(), 50)
+
+ def test_reset(self):
+ watermark_estimator = WatermarkEstimator()
+ watermark_estimator.set_watermark(Timestamp(100))
+ self.assertEqual(watermark_estimator.current_watermark(), 100)
+ watermark_estimator.reset()
+ self.assertEqual(watermark_estimator.current_watermark(), None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index dbc4bcd..22ecda3 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -20,6 +20,7 @@
from __future__ import absolute_import
import collections
+import json
import os.path
import pickle
import unittest
@@ -31,6 +32,7 @@
import yaml
import apache_beam as beam
+from apache_beam import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners import pipeline_context
@@ -502,7 +504,10 @@
while hasattr(cls, unique_name):
counter += 1
unique_name = 'test_%s_%d' % (name, counter)
- setattr(cls, unique_name, lambda self: self._run_log_test(spec))
+ test_method = lambda self: self._run_log_test(spec)
+ test_method.__name__ = unique_name
+ test_method.__test__ = True
+ setattr(cls, unique_name, test_method)
# We must prepend an underscore to this name so that the open-source unittest
# runner does not execute this method directly as a test.
@@ -606,24 +611,25 @@
window_fn, trigger_fn, accumulation_mode, timestamp_combiner,
transcript, spec)
- def _windowed_value_info(self, windowed_value):
- # Currently some runners operate at the millisecond level, and some at the
- # microsecond level. Trigger transcript timestamps are expressed as
- # integral units of the finest granularity, whatever that may be.
- # In these tests we interpret them as integral seconds and then truncate
- # the results to integral seconds to allow for portability across
- # different sub-second resolutions.
- window, = windowed_value.windows
- return {
- 'window': [int(window.start), int(window.max_timestamp())],
- 'values': sorted(windowed_value.value),
- 'timestamp': int(windowed_value.timestamp),
- 'index': windowed_value.pane_info.index,
- 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
- 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
- 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
- 'final': windowed_value.pane_info.is_last,
- }
+
+def _windowed_value_info(windowed_value):
+ # Currently some runners operate at the millisecond level, and some at the
+ # microsecond level. Trigger transcript timestamps are expressed as
+ # integral units of the finest granularity, whatever that may be.
+ # In these tests we interpret them as integral seconds and then truncate
+ # the results to integral seconds to allow for portability across
+ # different sub-second resolutions.
+ window, = windowed_value.windows
+ return {
+ 'window': [int(window.start), int(window.max_timestamp())],
+ 'values': sorted(windowed_value.value),
+ 'timestamp': int(windowed_value.timestamp),
+ 'index': windowed_value.pane_info.index,
+ 'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
+ 'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
+ 'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
+ 'final': windowed_value.pane_info.is_last,
+ }
class TriggerDriverTranscriptTest(TranscriptTest):
@@ -645,7 +651,7 @@
for timer_window, (name, time_domain, t_timestamp) in to_fire:
for wvalue in driver.process_timer(
timer_window, name, time_domain, t_timestamp, state):
- output.append(self._windowed_value_info(wvalue))
+ output.append(_windowed_value_info(wvalue))
to_fire = state.get_and_clear_timers(watermark)
for action, params in transcript:
@@ -661,7 +667,7 @@
WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
for t in params]
output = [
- self._windowed_value_info(wv)
+ _windowed_value_info(wv)
for wv in driver.process_elements(state, bundle, watermark)]
fire_timers()
@@ -690,7 +696,7 @@
self.assertEqual([], output, msg='Unexpected output: %s' % output)
-class TestStreamTranscriptTest(TranscriptTest):
+class BaseTestStreamTranscriptTest(TranscriptTest):
"""A suite of TestStream-based tests based on trigger transcript entries.
"""
@@ -702,14 +708,17 @@
if runner_name in spec.get('broken_on', ()):
self.skipTest('Known to be broken on %s' % runner_name)
- test_stream = TestStream()
+ # Elements are encoded as JSON strings to allow other languages to
+ # decode elements while executing the test stream.
+ # TODO(BEAM-8600): Eliminate these gymnastics.
+ test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
for action, params in transcript:
if action == 'expect':
- test_stream.add_elements([('expect', params)])
+ test_stream.add_elements([json.dumps(('expect', params))])
else:
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
if action == 'input':
- test_stream.add_elements([('input', e) for e in params])
+ test_stream.add_elements([json.dumps(('input', e)) for e in params])
elif action == 'watermark':
test_stream.advance_watermark_to(params)
elif action == 'clock':
@@ -718,7 +727,9 @@
pass # Requires inspection of implementation details.
else:
raise ValueError('Unexpected action: %s' % action)
- test_stream.add_elements([('expect', [])])
+ test_stream.add_elements([json.dumps(('expect', []))])
+
+ read_test_stream = test_stream | beam.Map(json.loads)
class Check(beam.DoFn):
"""A StatefulDoFn that verifies outputs are produced as expected.
@@ -731,12 +742,40 @@
The key is ignored, but all items must be on the same key to share state.
"""
+ def __init__(self, allow_out_of_order=True):
+ # Some runners don't support cross-stage TestStream semantics.
+ self.allow_out_of_order = allow_out_of_order
+
def process(
- self, element, seen=beam.DoFn.StateParam(
+ self,
+ element,
+ seen=beam.DoFn.StateParam(
beam.transforms.userstate.BagStateSpec(
'seen',
+ beam.coders.FastPrimitivesCoder())),
+ expected=beam.DoFn.StateParam(
+ beam.transforms.userstate.BagStateSpec(
+ 'expected',
beam.coders.FastPrimitivesCoder()))):
_, (action, data) = element
+
+    # Runners without cross-stage TestStream ordering may deliver
+    # 'expect' records before their matching 'actual' outputs, so
+    # buffer expectations and replay them once the outputs arrive.
+    if self.allow_out_of_order:
+ if action == 'expect' and not list(seen.read()):
+ if data:
+ expected.add(data)
+ return
+ elif action == 'actual' and list(expected.read()):
+ seen.add(data)
+ all_data = list(seen.read())
+ all_expected = list(expected.read())
+ if len(all_data) == len(all_expected[0]):
+ expected.clear()
+ for expect in all_expected[1:]:
+ expected.add(expect)
+ action, data = 'expect', all_expected[0]
+ else:
+ return
+
if action == 'actual':
seen.add(data)
@@ -768,12 +807,14 @@
else:
raise ValueError('Unexpected action: %s' % action)
- with TestPipeline(options=PipelineOptions(streaming=True)) as p:
+ with TestPipeline() as p:
+ # TODO(BEAM-8601): Pass this during pipeline construction.
+ p.options.view_as(StandardOptions).streaming = True
# Split the test stream into a branch of to-be-processed elements, and
# a branch of expected results.
inputs, expected = (
p
- | test_stream
+ | read_test_stream
| beam.MapTuple(
lambda tag, value: beam.pvalue.TaggedOutput(tag, ('key', value))
).with_outputs('input', 'expect'))
@@ -794,7 +835,7 @@
t=beam.DoFn.TimestampParam,
p=beam.DoFn.PaneInfoParam: (
k,
- self._windowed_value_info(WindowedValue(
+ _windowed_value_info(WindowedValue(
vs, windows=[window], timestamp=t, pane_info=p))))
# Place outputs back into the global window to allow flattening
# and share a single state in Check.
@@ -805,7 +846,17 @@
tagged_outputs = (
outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
# pylint: disable=expression-not-assigned
- (tagged_expected, tagged_outputs) | beam.Flatten() | beam.ParDo(Check())
+ ([tagged_expected, tagged_outputs]
+ | beam.Flatten()
+ | beam.ParDo(Check(self.allow_out_of_order)))
+
+
+class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = False
+
+
+class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
+ allow_out_of_order = True
TRANSCRIPT_TEST_FILE = os.path.join(
@@ -814,6 +865,7 @@
if os.path.exists(TRANSCRIPT_TEST_FILE):
TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
+ WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
if __name__ == '__main__':
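Condensed, the JSON-encoding pattern introduced above looks as follows. This is a sketch only: the element values are illustrative, and the `TestStream` import path is assumed to be `apache_beam.testing.test_stream`.

```py
import json

import apache_beam as beam
from apache_beam import coders
from apache_beam.testing.test_stream import TestStream

# Encode elements as UTF-8 JSON strings so that any SDK executing the
# stream can decode them; decode back to Python objects downstream.
test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
test_stream.add_elements([json.dumps(('input', 1))])
test_stream.advance_watermark_to(100)
read_test_stream = test_stream | beam.Map(json.loads)
```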
diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py
index e63f64e..65dbcae 100644
--- a/sdks/python/apache_beam/utils/subprocess_server.py
+++ b/sdks/python/apache_beam/utils/subprocess_server.py
@@ -174,8 +174,9 @@
elif '.dev' in beam_version:
# TODO: Attempt to use nightly snapshots?
raise RuntimeError(
- 'Please build the server with \n cd %s; ./gradlew %s' % (
- os.path.abspath(project_root), gradle_target))
+ ('%s not found. '
+ 'Please build the server with \n cd %s; ./gradlew %s') % (
+ local_path, os.path.abspath(project_root), gradle_target))
else:
return cls.path_to_maven_jar(
artifact_id, cls.BEAM_GROUP_ID, beam_version, cls.APACHE_REPOSITORY)
diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py
index 9bccdfd..a3f3abf 100644
--- a/sdks/python/apache_beam/utils/timestamp.py
+++ b/sdks/python/apache_beam/utils/timestamp.py
@@ -25,6 +25,7 @@
import datetime
import functools
+import time
from builtins import object
import dateutil.parser
@@ -76,6 +77,10 @@
return Timestamp(seconds)
@staticmethod
+ def now():
+ return Timestamp(seconds=time.time())
+
+ @staticmethod
def _epoch_datetime_utc():
return datetime.datetime.fromtimestamp(0, pytz.utc)
@@ -173,6 +178,8 @@
return self + other
def __sub__(self, other):
+ if isinstance(other, Timestamp):
+ return Duration(micros=self.micros - other.micros)
other = Duration.of(other)
return Timestamp(micros=self.micros - other.micros)
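Taken together, the two additions above allow simple elapsed-time measurement; a small sketch of the new behavior (variable names are illustrative):

```py
from apache_beam.utils.timestamp import Duration, Timestamp

start = Timestamp.now()            # wall-clock time as a Timestamp
# ... do some work ...
elapsed = Timestamp.now() - start  # Timestamp - Timestamp yields a Duration
assert isinstance(elapsed, Duration)
```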
diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py
index d26d561..2a4d454 100644
--- a/sdks/python/apache_beam/utils/timestamp_test.py
+++ b/sdks/python/apache_beam/utils/timestamp_test.py
@@ -100,6 +100,7 @@
self.assertEqual(Timestamp(123) - Duration(456), -333)
self.assertEqual(Timestamp(1230) % 456, 318)
self.assertEqual(Timestamp(1230) % Duration(456), 318)
+ self.assertEqual(Timestamp(123) - Timestamp(100), 23)
# Check that direct comparison of Timestamp and Duration is allowed.
self.assertTrue(Duration(123) == Timestamp(123))
@@ -116,6 +117,7 @@
self.assertEqual((Timestamp(123) - Duration(456)).__class__, Timestamp)
self.assertEqual((Timestamp(1230) % 456).__class__, Duration)
self.assertEqual((Timestamp(1230) % Duration(456)).__class__, Duration)
+ self.assertEqual((Timestamp(123) - Timestamp(100)).__class__, Duration)
# Unsupported operations.
with self.assertRaises(TypeError):
@@ -159,6 +161,10 @@
self.assertEqual('Timestamp(-999999999)',
str(Timestamp(-999999999)))
+ def test_now(self):
+ now = Timestamp.now()
+ self.assertTrue(isinstance(now, Timestamp))
+
class DurationTest(unittest.TestCase):
diff --git a/sdks/python/container/base_image_requirements.txt b/sdks/python/container/base_image_requirements.txt
index d3931e7..359d6e5 100644
--- a/sdks/python/container/base_image_requirements.txt
+++ b/sdks/python/container/base_image_requirements.txt
@@ -40,7 +40,7 @@
pydot==1.4.1
pytz==2019.1
pyvcf==0.6.8;python_version<"3.0"
-pyyaml==3.13
+pyyaml==5.1
typing==3.6.6
# Setup packages
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index e3794ba..bc77fd8 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -183,6 +183,7 @@
'_TimerDoFnParam',
'_BundleFinalizerParam',
'_RestrictionDoFnParam',
+ '_WatermarkEstimatorParam',
# Sphinx cannot find this py:class reference target
'typing.Generic',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index ccf90f6..7eea64c 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -214,6 +214,7 @@
ext_modules=cythonize([
'apache_beam/**/*.pyx',
'apache_beam/coders/coder_impl.py',
+ 'apache_beam/metrics/cells.py',
'apache_beam/metrics/execution.py',
'apache_beam/runners/common.py',
'apache_beam/runners/worker/logger.py',
diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle
index f04d28d..3cb4362 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -79,3 +79,22 @@
task flinkValidatesRunner() {
dependsOn 'flinkCompatibilityMatrixLoopback'
}
+
+// TODO(BEAM-8598): Enable on pre-commit.
+task flinkTriggerTranscript() {
+ dependsOn 'setupVirtualenv'
+ dependsOn ':runners:flink:1.9:job-server:shadowJar'
+ doLast {
+ exec {
+ executable 'sh'
+ args '-c', """
+ . ${envdir}/bin/activate \\
+ && cd ${pythonRootDir} \\
+ && pip install -e .[test] \\
+ && python setup.py nosetests \\
+ --tests apache_beam.transforms.trigger_test:WeakTestStreamTranscriptTest \\
+ --test-pipeline-options='--runner=FlinkRunner --environment_type=LOOPBACK --flink_job_server_jar=${project(":runners:flink:1.9:job-server:").shadowJar.archivePath}'
+ """
+ }
+ }
+}
diff --git a/sdks/python/test-suites/portable/py2/build.gradle b/sdks/python/test-suites/portable/py2/build.gradle
index 3c1548d..5d967e4 100644
--- a/sdks/python/test-suites/portable/py2/build.gradle
+++ b/sdks/python/test-suites/portable/py2/build.gradle
@@ -39,6 +39,8 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
// TODO: Move the rest of this file into ../common.gradle.
diff --git a/sdks/python/test-suites/portable/py35/build.gradle b/sdks/python/test-suites/portable/py35/build.gradle
index 1b2cb4f..88b4e2f 100644
--- a/sdks/python/test-suites/portable/py35/build.gradle
+++ b/sdks/python/test-suites/portable/py35/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/sdks/python/test-suites/portable/py36/build.gradle b/sdks/python/test-suites/portable/py36/build.gradle
index 475e110..496777d 100644
--- a/sdks/python/test-suites/portable/py36/build.gradle
+++ b/sdks/python/test-suites/portable/py36/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/sdks/python/test-suites/portable/py37/build.gradle b/sdks/python/test-suites/portable/py37/build.gradle
index 912b316..924de81 100644
--- a/sdks/python/test-suites/portable/py37/build.gradle
+++ b/sdks/python/test-suites/portable/py37/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/website/Gemfile b/website/Gemfile
index 4a08725..1050303 100644
--- a/website/Gemfile
+++ b/website/Gemfile
@@ -20,7 +20,7 @@
source 'https://rubygems.org'
-gem 'jekyll', '3.2'
+gem 'jekyll', '3.6.3'
# Jekyll plugins
group :jekyll_plugins do
diff --git a/website/Gemfile.lock b/website/Gemfile.lock
index e94f132..9db2ebe 100644
--- a/website/Gemfile.lock
+++ b/website/Gemfile.lock
@@ -13,7 +13,7 @@
concurrent-ruby (1.1.4)
ethon (0.11.0)
ffi (>= 1.3.0)
- ffi (1.9.25)
+ ffi (1.11.1)
forwardable-extended (2.6.0)
html-proofer (3.9.3)
activesupport (>= 4.2, < 6.0)
@@ -26,15 +26,16 @@
yell (~> 2.0)
i18n (0.9.5)
concurrent-ruby (~> 1.0)
- jekyll (3.2.0)
+ jekyll (3.6.3)
+ addressable (~> 2.4)
colorator (~> 1.0)
jekyll-sass-converter (~> 1.0)
jekyll-watch (~> 1.1)
- kramdown (~> 1.3)
- liquid (~> 3.0)
+ kramdown (~> 1.14)
+ liquid (~> 4.0)
mercenary (~> 0.3.3)
pathutil (~> 0.9)
- rouge (~> 1.7)
+ rouge (>= 1.7, < 3)
safe_yaml (~> 1.0)
jekyll-redirect-from (0.11.0)
jekyll (>= 2.0)
@@ -45,29 +46,27 @@
jekyll_github_sample (0.3.1)
activesupport (~> 4.0)
jekyll (~> 3.0)
- kramdown (1.16.2)
- liquid (3.0.6)
- listen (3.1.5)
- rb-fsevent (~> 0.9, >= 0.9.4)
- rb-inotify (~> 0.9, >= 0.9.7)
- ruby_dep (~> 1.2)
+ kramdown (1.17.0)
+ liquid (4.0.3)
+ listen (3.2.0)
+ rb-fsevent (~> 0.10, >= 0.10.3)
+ rb-inotify (~> 0.9, >= 0.9.10)
mercenary (0.3.6)
mini_portile2 (2.3.0)
minitest (5.11.3)
nokogiri (1.8.5)
mini_portile2 (~> 2.3.0)
parallel (1.12.1)
- pathutil (0.16.1)
+ pathutil (0.16.2)
forwardable-extended (~> 2.6)
public_suffix (3.0.3)
rake (12.3.0)
- rb-fsevent (0.10.2)
- rb-inotify (0.9.10)
- ffi (>= 0.5.0, < 2)
- rouge (1.11.1)
- ruby_dep (1.5.0)
- safe_yaml (1.0.4)
- sass (3.5.5)
+ rb-fsevent (0.10.3)
+ rb-inotify (0.10.0)
+ ffi (~> 1.0)
+ rouge (2.2.1)
+ safe_yaml (1.0.5)
+ sass (3.7.4)
sass-listen (~> 4.0.0)
sass-listen (4.0.0)
rb-fsevent (~> 0.9, >= 0.9.4)
@@ -85,7 +84,7 @@
DEPENDENCIES
activesupport (< 5.0.0.0)
html-proofer
- jekyll (= 3.2)
+ jekyll (= 3.6.3)
jekyll-redirect-from
jekyll-sass-converter
jekyll_github_sample
diff --git a/website/src/_includes/section-menu/documentation.html b/website/src/_includes/section-menu/documentation.html
index 529c264e..ed776ea 100644
--- a/website/src/_includes/section-menu/documentation.html
+++ b/website/src/_includes/section-menu/documentation.html
@@ -215,6 +215,7 @@
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/distinct/">Distinct</a></li>
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupbykey/">GroupByKey</a></li>
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupintobatches/">GroupIntoBatches</a></li>
+ <li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount/">HllCount</a></li>
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/latest/">Latest</a></li>
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/max/">Max</a></li>
<li><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/mean/">Mean</a></li>
diff --git a/website/src/documentation/programming-guide.md b/website/src/documentation/programming-guide.md
index d78b609..9d644ae 100644
--- a/website/src/documentation/programming-guide.md
+++ b/website/src/documentation/programming-guide.md
@@ -1932,8 +1932,8 @@
suffix ".csv" in the given location:
```java
-p.apply(“ReadFromText”,
- TextIO.read().from("protocol://my_bucket/path/to/input-*.csv");
+p.apply("ReadFromText",
+ TextIO.read().from("protocol://my_bucket/path/to/input-*.csv"));
```
```py
diff --git a/website/src/documentation/transforms/java/aggregation/approximateunique.md b/website/src/documentation/transforms/java/aggregation/approximateunique.md
index 9b3e6d0..448c0ee 100644
--- a/website/src/documentation/transforms/java/aggregation/approximateunique.md
+++ b/website/src/documentation/transforms/java/aggregation/approximateunique.md
@@ -35,6 +35,8 @@
See [BEAM-7703](https://issues.apache.org/jira/browse/BEAM-7703) for updates.
## Related transforms
+* [HllCount]({{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount)
+ estimates the number of distinct elements and creates re-aggregatable sketches using the HyperLogLog++ algorithm.
* [Count]({{ site.baseurl }}/documentation/transforms/java/aggregation/count)
counts the number of elements within each aggregation.
* [Distinct]({{ site.baseurl }}/documentation/transforms/java/aggregation/distinct)
\ No newline at end of file
diff --git a/website/src/documentation/transforms/java/aggregation/hllcount.md b/website/src/documentation/transforms/java/aggregation/hllcount.md
new file mode 100644
index 0000000..506a8dc
--- /dev/null
+++ b/website/src/documentation/transforms/java/aggregation/hllcount.md
@@ -0,0 +1,77 @@
+---
+layout: section
+title: "HllCount"
+permalink: /documentation/transforms/java/aggregation/hllcount/
+section_menu: section-menu/documentation.html
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+# HllCount
+<table align="left">
+ <a target="_blank" class="button"
+ href="https://beam.apache.org/releases/javadoc/current/index.html?org/apache/beam/sdk/extensions/zetasketch/HllCount.html">
+ <img src="https://beam.apache.org/images/logos/sdks/java.png" width="20px" height="20px"
+ alt="Javadoc" />
+ Javadoc
+ </a>
+</table>
+<br>
+
+Estimates the number of distinct elements in a data stream using the
+[HyperLogLog++ algorithm](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/40671.pdf).
+The transforms to create sketches, merge them, and extract estimated counts from them are:
+
+* `HllCount.Init` aggregates inputs into HLL++ sketches.
+* `HllCount.MergePartial` merges HLL++ sketches into a new sketch.
+* `HllCount.Extract` extracts the estimated count of distinct elements from HLL++ sketches.
+
+You can read more about what a sketch is at [github.com/google/zetasketch](https://github.com/google/zetasketch).
+
+## Examples
+**Example 1**: creates a long-type sketch for a `PCollection<Long>` with a custom precision:
+```java
+ PCollection<Long> input = ...;
+ int p = ...;
+ PCollection<byte[]> sketch = input.apply(HllCount.Init.forLongs().withPrecision(p).globally());
+```
+
+**Example 2**: creates a bytes-type sketch for a `PCollection<KV<String, byte[]>>`:
+```java
+ PCollection<KV<String, byte[]>> input = ...;
+ PCollection<KV<String, byte[]>> sketch = input.apply(HllCount.Init.forBytes().perKey());
+```
+
+**Example 3**: merges existing sketches in a `PCollection<byte[]>` into a new sketch,
+which summarizes the union of the inputs that were aggregated in the merged sketches:
+```java
+ PCollection<byte[]> sketches = ...;
+ PCollection<byte[]> mergedSketch = sketches.apply(HllCount.MergePartial.globally());
+```
+
+**Example 4**: estimates the count of distinct elements in a `PCollection<String>`:
+```java
+ PCollection<String> input = ...;
+ PCollection<Long> countDistinct =
+ input.apply(HllCount.Init.forStrings().globally()).apply(HllCount.Extract.globally());
+```
+
+**Example 5**: extracts the count distinct estimate from an existing sketch:
+```java
+ PCollection<byte[]> sketch = ...;
+ PCollection<Long> countDistinct = sketch.apply(HllCount.Extract.globally());
+```
+
+## Related transforms
+* [ApproximateUnique]({{ site.baseurl }}/documentation/transforms/java/aggregation/approximateunique)
+ estimates the number of distinct elements or values in key-value pairs (but does not expose sketches; also less accurate than `HllCount`).
\ No newline at end of file
diff --git a/website/src/documentation/transforms/java/index.md b/website/src/documentation/transforms/java/index.md
index b36e305..71b3721 100644
--- a/website/src/documentation/transforms/java/index.md
+++ b/website/src/documentation/transforms/java/index.md
@@ -58,6 +58,7 @@
<tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupbykey">GroupByKey</a></td><td>Takes a keyed collection of elements and produces a collection where each element
consists of a key and all values associated with that key.</td></tr>
<tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/groupintobatches">GroupIntoBatches</a></td><td>Batches values associated with keys into <code>Iterable</code> batches of some size. Each batch contains elements associated with a specific key.</td></tr>
+ <tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/hllcount">HllCount</a></td><td>Estimates the number of distinct elements and creates re-aggregatable sketches using the HyperLogLog++ algorithm.</td></tr>
<tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/latest">Latest</a></td><td>Selects the latest element within each aggregation according to the implicit timestamp.</td></tr>
<tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/max">Max</a></td><td>Outputs the maximum element within each aggregation.</td></tr>
<tr><td><a href="{{ site.baseurl }}/documentation/transforms/java/aggregation/mean">Mean</a></td><td>Computes the average within each aggregation.</td></tr>
diff --git a/website/src/roadmap/connectors-multi-sdk.md b/website/src/roadmap/connectors-multi-sdk.md
index 5ab6877..a0ccea3 100644
--- a/website/src/roadmap/connectors-multi-sdk.md
+++ b/website/src/roadmap/connectors-multi-sdk.md
@@ -28,18 +28,73 @@
# Cross-language transforms
-As an added benefit of Beam portability efforts, in the future, we’ll be
-able to utilize Beam transforms across languages. This has many benefits.
-For example.
+_Last updated in November 2019._
-* Beam pipelines written using Python and Go SDKs will be able to utilize
-the vast selection of connectors that are currently available for Java SDK.
-* Java SDK will be able to utilize connectors for systems that only offer a
-Python API.
-* Go SDK, will be able to utilize connectors currently available for Java and
-Python SDKs.
-* Connector authors will be able to implement new Beam connectors using a
-language of choice and utilize these connectors from other languages reducing
-the maintenance and support efforts.
+As an added benefit of the Beam portability effort, we are able to utilize Beam transforms across SDKs. This has many benefits.
-See [Beam portability framework roadmap](https://beam.apache.org/roadmap/portability/) for more details.
+* Connector sharing across SDKs. For example,
+ + Beam pipelines written using Python and Go SDKs will be able to utilize the vast selection of connectors that are currently implemented for Java SDK.
+ + Java SDK will be able to utilize connectors for systems that only offer a Python API.
+ + Go SDK will be able to utilize connectors currently available for Java and Python SDKs.
+* Ease of developing and maintaining Beam transforms - in general, with cross-language transforms, transform authors will be able to implement new Beam transforms using a
+language of their choice and make those transforms available to other languages, reducing maintenance and support overhead.
+* [Beam SQL](https://beam.apache.org/documentation/dsls/sql/overview/), which is currently only available to the Java SDK, will become available to the Python and Go SDKs.
+* [Beam TFX transforms](https://www.tensorflow.org/tfx/transform/get_started), which are currently only available to Beam Python SDK pipelines, will become available to the Java and Go SDKs.
+
+## Completed and Ongoing Efforts
+
+Many efforts related to cross-language transforms are currently in flight. Some of the completed and ongoing efforts are summarized below.
+
+### Cross-language transforms API and expansion service
+
+Work related to developing and updating the cross-language transforms API for the Java/Python/Go SDKs, and to the cross-language transform expansion services. A usage sketch follows the list below.
+
+* Basic API for Java SDK - completed
+* Basic API for Python SDK - completed
+* Basic API for Go SDK - Not started
+* Basic cross-language transform expansion service for Java and Python SDKs - completed
+* Artifact staging - In progress - [email thread](https://lists.apache.org/thread.html/6fcee7047f53cf1c0636fb65367ef70842016d57effe2e5795c4137d@%3Cdev.beam.apache.org%3E), [doc](https://docs.google.com/document/d/1XaiNekAY2sptuQRIXpjGAyaYdSc-wlJ-VKjl04c8N48/edit#heading=h.900gc947qrw8)
+
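A usage sketch of the completed Java/Python API work. None of this is part of the change above: the URN, payload, and service address are illustrative, and the signature `ExternalTransform(urn, payload, expansion_service)` in `apache_beam.transforms.external` is an assumption about the current API.

```py
import apache_beam as beam
from apache_beam.transforms.external import ExternalTransform

with beam.Pipeline() as p:
  _ = (p
       | beam.Create(['a', 'b', 'b'])
       # Expand a transform registered under a hypothetical URN in an
       # expansion service assumed to listen on localhost:8097.
       | ExternalTransform(
           'beam:transforms:xlang:example:v1',
           None,  # no configuration payload in this sketch
           'localhost:8097'))
```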
+### Support for Flink runner
+
+Work related to making cross-language transforms available for Flink runner.
+
+* Basic support for executing cross-language transforms on portable Flink runner - completed
+
+### Support for Dataflow runner
+
+Work related to making cross-language transforms available for Dataflow runner.
+
+* Basic support for executing cross-language transforms on Dataflow runner
+ + This work requires updates to the Dataflow service's job submission and job execution logic, and is currently being developed at Google.
+
+### Support for Direct runner
+
+Work related to making cross-language transforms available on the Direct runner.
+
+* Basic support for executing cross-language transforms on portable Direct runner - Not started
+
+### Connector/transform support
+
+Ongoing and planned work related to making existing connectors/transforms available to other SDKs through the cross-language transforms framework.
+
+* Java KafkaIO - In progress - [BEAM-7029](https://issues.apache.org/jira/browse/BEAM-7029)
+* Java PubSubIO - In progress - [BEAM-7738](https://issues.apache.org/jira/browse/BEAM-7738)
+
+### Portable Beam schema
+
+Portable Beam schema support will provide a generalized mechanism for serializing and transferring data across language boundaries, which will be extremely useful for pipelines that employ cross-language transforms.
+
+* Make the row coder a standard coder and implement it in Python - In progress - [BEAM-7886](https://issues.apache.org/jira/browse/BEAM-7886)
+
+### Integration/Performance testing
+
+* Add an integration test suite for cross-language transforms on Flink runner - In progress - [BEAM-6683](https://issues.apache.org/jira/browse/BEAM-6683)
+
+### Documentation
+
+Work related to adding documentation on cross-language transforms to the Beam website.
+
+* Document cross-language transforms API for Java/Python - Not started
+* Document API for making existing transforms available as cross-language transforms for Java/Python - Not started
+
diff --git a/website/src/roadmap/index.md b/website/src/roadmap/index.md
index 21aee01..6c1a1d6 100644
--- a/website/src/roadmap/index.md
+++ b/website/src/roadmap/index.md
@@ -32,18 +32,20 @@
## Portability Framework
-Portability is the primary Beam vision: running pipelines authors with _any SDK_
+Portability is the primary Beam vision: running pipelines authored with _any SDK_
on _any runner_. This is a cross-cutting effort across Java, Python, and Go,
-and every Beam runner.
+and every Beam runner. Portability is currently supported on the
+[Flink]({{site.baseurl}}/documentation/runners/flink/)
+and [Spark]({{site.baseurl}}/documentation/runners/spark/) runners.
See the details on the [Portability Roadmap]({{site.baseurl}}/roadmap/portability/)
-## Python on Flink
+## Cross-language transforms
-A major highlight of the portability effort is the effort in running Python pipelines
-the Flink runner.
-
-You can [follow the instructions to try it out]({{site.baseurl}}/roadmap/portability/#python-on-flink)
+As a benefit of the portability effort, we are able to utilize Beam transforms across SDKs.
+Examples include using Java connectors and Beam SQL from Python or Go pipelines
+or Beam TFX transforms from Java and Go.
+For details see [Roadmap for multi-SDK efforts]({{ site.baseurl }}/roadmap/connectors-multi-sdk/).
## Go SDK