Merge pull request #10043 [BEAM-8597] Allow TestStream trigger tests to run on other runners.
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 3ea3643..557f45f 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -1905,6 +1905,7 @@
mustRunAfter = [
':runners:flink:1.9:job-server-container:docker',
':runners:flink:1.9:job-server:shadowJar',
+ ':runners:spark:job-server:shadowJar',
':sdks:python:container:py2:docker',
':sdks:python:container:py35:docker',
':sdks:python:container:py36:docker',
@@ -1958,6 +1959,7 @@
addPortableWordCountTask(true, "PortableRunner")
addPortableWordCountTask(false, "FlinkRunner")
addPortableWordCountTask(true, "FlinkRunner")
+ addPortableWordCountTask(false, "SparkRunner")
}
}
}
diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle
index 3254a85..6281b94 100644
--- a/runners/flink/flink_runner.gradle
+++ b/runners/flink/flink_runner.gradle
@@ -200,6 +200,7 @@
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
if (config.streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesImpulse'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages' // BEAM-8598
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
} else {
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
index 7e298e7..dcf15f4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/FnApiWindowMappingFn.java
@@ -56,7 +56,6 @@
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.slf4j.Logger;
@@ -253,7 +252,7 @@
}
// Check to see if processing the request failed.
- throwIfFailure(processResponse);
+ MoreFutures.get(processResponse);
waitForInboundTermination.awaitCompletion();
WindowedValue<KV<byte[], TargetWindowT>> sideInputWindow = outputValue.poll();
@@ -300,22 +299,10 @@
processBundleDescriptor.toBuilder().setId(descriptorId).build())
.build())
.build());
- throwIfFailure(response);
+ // Check if the bundle descriptor is registered successfully.
+ MoreFutures.get(response);
processBundleDescriptorId = descriptorId;
}
return processBundleDescriptorId;
}
-
- private static InstructionResponse throwIfFailure(
- CompletionStage<InstructionResponse> responseFuture)
- throws ExecutionException, InterruptedException {
- InstructionResponse response = MoreFutures.get(responseFuture);
- if (!Strings.isNullOrEmpty(response.getError())) {
- throw new IllegalStateException(
- String.format(
- "Client failed to process %s with error [%s].",
- response.getInstructionId(), response.getError()));
- }
- return response;
- }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
index 0a3346c..bf42c4d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/RegisterAndProcessBundleOperation.java
@@ -370,12 +370,8 @@
* elements consumed from the upstream read operation.
*
* <p>May be called at any time, including before start() and after finish().
- *
- * @throws InterruptedException
- * @throws ExecutionException
*/
- public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress()
- throws InterruptedException, ExecutionException {
+ public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress() {
// processBundleId may be reset if this bundle finishes asynchronously.
String processBundleId = this.processBundleId;
@@ -393,13 +389,7 @@
return instructionRequestHandler
.handle(processBundleRequest)
- .thenApply(
- response -> {
- if (!response.getError().isEmpty()) {
- throw new IllegalStateException(response.getError());
- }
- return response.getProcessBundleProgress();
- });
+ .thenApply(InstructionResponse::getProcessBundleProgress);
}
/** Returns the final metrics returned by the SDK harness when it completes the bundle. */
@@ -636,53 +626,36 @@
return true;
}
- private static CompletionStage<BeamFnApi.InstructionResponse> throwIfFailure(
+ private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
CompletionStage<InstructionResponse> responseFuture) {
return responseFuture.thenApply(
response -> {
- if (!response.getError().isEmpty()) {
- throw new IllegalStateException(
- String.format(
- "Client failed to process %s with error [%s].",
- response.getInstructionId(), response.getError()));
+ switch (response.getResponseCase()) {
+ case PROCESS_BUNDLE:
+ return response.getProcessBundle();
+ default:
+ throw new IllegalStateException(
+ String.format(
+ "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
+ TextFormat.printToString(response)));
}
- return response;
});
}
- private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
- CompletionStage<InstructionResponse> responseFuture) {
- return throwIfFailure(responseFuture)
- .thenApply(
- response -> {
- switch (response.getResponseCase()) {
- case PROCESS_BUNDLE:
- return response.getProcessBundle();
- default:
- throw new IllegalStateException(
- String.format(
- "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
- TextFormat.printToString(response)));
- }
- });
- }
-
private static CompletionStage<BeamFnApi.RegisterResponse> getRegisterResponse(
- CompletionStage<InstructionResponse> responseFuture)
- throws ExecutionException, InterruptedException {
- return throwIfFailure(responseFuture)
- .thenApply(
- response -> {
- switch (response.getResponseCase()) {
- case REGISTER:
- return response.getRegister();
- default:
- throw new IllegalStateException(
- String.format(
- "SDK harness returned wrong kind of response to RegisterRequest: %s",
- TextFormat.printToString(response)));
- }
- });
+ CompletionStage<InstructionResponse> responseFuture) {
+ return responseFuture.thenApply(
+ response -> {
+ switch (response.getResponseCase()) {
+ case REGISTER:
+ return response.getRegister();
+ default:
+ throw new IllegalStateException(
+ String.format(
+ "SDK harness returned wrong kind of response to RegisterRequest: %s",
+ TextFormat.printToString(response)));
+ }
+ });
}
private static void cancelIfNotNull(CompletionStage<?> future) {
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
new file mode 100644
index 0000000..1e5fa53
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandler.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+
+/**
+ * A handler which is invoked when the SDK returns {@link BeamFnApi.DelayedBundleApplication}s as
+ * part of the bundle completion.
+ *
+ * <p>These bundle applications must be resumed, otherwise data loss will occur.
+ *
+ * <p>See <a href="https://s.apache.org/beam-breaking-fusion">breaking the fusion barrier</a> for
+ * further details.
+ */
+@FunctionalInterface
+public interface BundleCheckpointHandler {
+  /**
+   * Called after the bundle completes when its response contains at least one residual root.
+   *
+   * @param response the completed bundle's response; its residual roots must be resumed
+   */
+  void onCheckpoint(BeamFnApi.ProcessBundleResponse response);
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
new file mode 100644
index 0000000..849663b
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleFinalizationHandler.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+
+/**
+ * A handler for the runner when a finalization request has been received.
+ *
+ * <p>The runner is responsible for finalizing the bundle when all output from the bundle has been
+ * durably persisted.
+ *
+ * <p>See <a href="https://s.apache.org/beam-finalizing-bundles">finalizing bundles</a> for further
+ * details.
+ */
+@FunctionalInterface
+public interface BundleFinalizationHandler {
+  /**
+   * Called after the bundle completes when its response requests finalization.
+   *
+   * @param response the completed bundle's response
+   */
+  void requestsFinalization(ProcessBundleResponse response);
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
index b655732..8a9dc75 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/InstructionRequestHandler.java
@@ -20,7 +20,10 @@
import java.util.concurrent.CompletionStage;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
-/** Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}. */
+/**
+ * Interface for any function that can handle a Fn API {@link BeamFnApi.InstructionRequest}.
+ * Error responses are reported by completing the returned future exceptionally, not as values.
+ */
public interface InstructionRequestHandler extends AutoCloseable {
CompletionStage<BeamFnApi.InstructionResponse> handle(BeamFnApi.InstructionRequest request);
}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index 1ee5184..86232dd 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -95,6 +95,9 @@
* // send all main input elements ...
* }
* }</pre>
+ *
+ * <p>An exception during {@link #close()} will be thrown if the bundle requests finalization or
+ * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
*/
public ActiveBundle newBundle(
Map<String, RemoteOutputReceiver<?>> outputReceivers,
@@ -122,6 +125,47 @@
* try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
* FnDataReceiver<InputT> inputReceiver =
* (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
+ * // send all main input elements ...
+ * }
+ * }</pre>
+ *
+ * <p>An exception during {@link #close()} will be thrown if the bundle requests finalization or
+ * attempts to checkpoint by returning a {@link BeamFnApi.DelayedBundleApplication}.
+ */
+  public ActiveBundle newBundle(
+      Map<String, RemoteOutputReceiver<?>> outputReceivers,
+      StateRequestHandler stateRequestHandler,
+      BundleProgressHandler progressHandler) {
+    return newBundle(
+        outputReceivers,
+        stateRequestHandler,
+        progressHandler,
+        response -> {
+          throw new UnsupportedOperationException(
+              String.format(
+                  "The %s does not have a registered bundle checkpoint handler.",
+                  ActiveBundle.class.getSimpleName()));
+        },
+        response -> {
+          throw new UnsupportedOperationException(
+              String.format(
+                  "The %s does not have a registered bundle finalization handler.",
+                  ActiveBundle.class.getSimpleName()));
+        });
+  }
+
+ /**
+ * Start a new bundle for the given {@link BeamFnApi.ProcessBundleDescriptor} identifier.
+ *
+ * <p>The input channels for the returned {@link ActiveBundle} are derived from the instructions
+ * in the {@link BeamFnApi.ProcessBundleDescriptor}.
+ *
+ * <p>NOTE: It is important to {@link #close()} each bundle after all elements are emitted.
+ *
+ * <pre>{@code
+ * try (ActiveBundle bundle = SdkHarnessClient.newBundle(...)) {
+ * FnDataReceiver<InputT> inputReceiver =
+ * (FnDataReceiver) bundle.getInputReceivers().get(mainPCollectionId);
* // send all elements ...
* }
* }</pre>
@@ -129,7 +173,9 @@
public ActiveBundle newBundle(
Map<String, RemoteOutputReceiver<?>> outputReceivers,
StateRequestHandler stateRequestHandler,
- BundleProgressHandler progressHandler) {
+ BundleProgressHandler progressHandler,
+ BundleCheckpointHandler checkpointHandler,
+ BundleFinalizationHandler finalizationHandler) {
String bundleId = idGenerator.getId();
final CompletionStage<BeamFnApi.InstructionResponse> genericResponse =
@@ -175,7 +221,9 @@
dataReceiversBuilder.build(),
outputClients,
stateDelegator.registerForProcessBundleInstructionId(bundleId, stateRequestHandler),
- progressHandler);
+ progressHandler,
+ checkpointHandler,
+ finalizationHandler);
}
private <OutputT> InboundDataClient attachReceiver(
@@ -193,6 +241,8 @@
private final Map<String, InboundDataClient> outputClients;
private final StateDelegator.Registration stateRegistration;
private final BundleProgressHandler progressHandler;
+ private final BundleCheckpointHandler checkpointHandler;
+ private final BundleFinalizationHandler finalizationHandler;
private ActiveBundle(
String bundleId,
@@ -200,13 +250,17 @@
Map<String, CloseableFnDataReceiver> inputReceivers,
Map<String, InboundDataClient> outputClients,
StateDelegator.Registration stateRegistration,
- BundleProgressHandler progressHandler) {
+ BundleProgressHandler progressHandler,
+ BundleCheckpointHandler checkpointHandler,
+ BundleFinalizationHandler finalizationHandler) {
this.bundleId = bundleId;
this.response = response;
this.inputReceivers = inputReceivers;
this.outputClients = outputClients;
this.stateRegistration = stateRegistration;
this.progressHandler = progressHandler;
+ this.checkpointHandler = checkpointHandler;
+ this.finalizationHandler = finalizationHandler;
}
/** Returns an id used to represent this bundle. */
@@ -256,13 +310,15 @@
BeamFnApi.ProcessBundleResponse completedResponse = MoreFutures.get(response);
progressHandler.onCompleted(completedResponse);
if (completedResponse.getResidualRootsCount() > 0) {
- throw new IllegalStateException(
- "TODO: [BEAM-2939] residual roots in process bundle response not yet supported.");
+ checkpointHandler.onCheckpoint(completedResponse);
+ }
+ if (completedResponse.getRequiresFinalization()) {
+ finalizationHandler.requestsFinalization(completedResponse);
}
} else {
// TODO: [BEAM-3962] Handle aborting the bundle being processed.
throw new IllegalStateException(
- "Processing bundle failed, " + "TODO: [BEAM-3962] abort bundle.");
+ "Processing bundle failed, TODO: [BEAM-3962] abort bundle.");
}
} catch (Exception e) {
if (exception == null) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
index 6a28735..089c8d1 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClientTest.java
@@ -31,6 +31,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
@@ -41,6 +42,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
@@ -623,6 +625,106 @@
assertThat(requests.get(1).getProcessBundle().getCacheTokensList(), is(cacheTokens));
}
+  @Test
+  public void testBundleCheckpointCallback() throws Exception {
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    // A response carrying a residual root must reach the checkpoint handler, not the finalizer.
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder()
+            .addResidualRoots(DelayedBundleApplication.getDefaultInstance())
+            .build();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockCheckpointHandler).onCheckpoint(response);
+    verifyZeroInteractions(mockFinalizationHandler);
+  }
+
+  @Test
+  public void testBundleFinalizationCallback() throws Exception {
+    InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
+    CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
+
+    CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
+    when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
+        .thenReturn(createRegisterResponse())
+        .thenReturn(processBundleResponseFuture);
+
+    FullWindowedValueCoder<String> coder =
+        FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
+    BundleProcessor processor =
+        sdkHarnessClient.getProcessor(
+            descriptor,
+            Collections.singletonMap(
+                "inputPC",
+                RemoteInputDestination.of(
+                    (FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
+    when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
+    when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
+
+    RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
+    BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
+    BundleCheckpointHandler mockCheckpointHandler = mock(BundleCheckpointHandler.class);
+    BundleFinalizationHandler mockFinalizationHandler = mock(BundleFinalizationHandler.class);
+
+    // A response requesting finalization must reach the finalizer, not the checkpoint handler.
+    ProcessBundleResponse response =
+        ProcessBundleResponse.newBuilder().setRequiresFinalization(true).build();
+
+    try (ActiveBundle activeBundle =
+        processor.newBundle(
+            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver),
+            (request) -> {
+              throw new UnsupportedOperationException();
+            },
+            mockProgressHandler,
+            mockCheckpointHandler,
+            mockFinalizationHandler)) {
+      processBundleResponseFuture.complete(
+          InstructionResponse.newBuilder().setProcessBundle(response).build());
+    }
+
+    verify(mockProgressHandler).onCompleted(response);
+    verify(mockFinalizationHandler).requestsFinalization(response);
+    verifyZeroInteractions(mockCheckpointHandler);
+  }
+
private CompletableFuture<InstructionResponse> createRegisterResponse() {
return CompletableFuture.completedFuture(
InstructionResponse.newBuilder()
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
new file mode 100644
index 0000000..55999ce
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesTestStreamWithMultipleStages.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+/**
+ * Subcategory for {@link UsesTestStream} tests which use {@link TestStream} across multiple
+ * stages. Some runners do not properly support quiescence in the way that {@link TestStream}
+ * demands.
+ */
+public interface UsesTestStreamWithMultipleStages extends UsesTestStream {}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
index 5e4cdcb..e48b6b2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.testing;
+import static org.apache.beam.sdk.transforms.windowing.Window.into;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -28,12 +29,22 @@
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.TestStream.Builder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.Keys;
import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
@@ -44,6 +55,7 @@
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Never;
import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,8 +63,10 @@
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Rule;
@@ -263,7 +277,7 @@
FixedWindows windows = FixedWindows.of(Duration.standardHours(6));
PCollection<String> windowedValues =
p.apply(stream)
- .apply(Window.into(windows))
+ .apply(into(windows))
.apply(WithKeys.of(1))
.apply(GroupByKey.create())
.apply(Values.create())
@@ -387,6 +401,76 @@
}
@Test
+  @Category({ValidatesRunner.class, UsesTestStream.class, UsesTestStreamWithMultipleStages.class})
+  public void testMultiStage() throws Exception {
+    // Trailing comments name the value each event contributes: raw elements flow straight into
+    // the Flatten, and their upper-cased copies come from the windowed GroupByKey branch. The
+    // asserted ordering relies on each stage quiescing before the next event is applied.
+    TestStream<String> testStream =
+        TestStream.create(StringUtf8Coder.of())
+            .addElements("before") // before
+            .advanceWatermarkTo(Instant.ofEpochSecond(0)) // BEFORE
+            .addElements(TimestampedValue.of("after", Instant.ofEpochSecond(10))) // after
+            .advanceWatermarkToInfinity(); // AFTER
+
+    PCollection<String> input = p.apply(testStream);
+
+    PCollection<String> grouped =
+        input
+            .apply(Window.into(FixedWindows.of(Duration.standardSeconds(1))))
+            .apply(
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of(e, e)))
+            .apply(GroupByKey.create())
+            .apply(Keys.create())
+            .apply("Upper", MapElements.into(TypeDescriptors.strings()).via(String::toUpperCase))
+            .apply("Rewindow", Window.into(new GlobalWindows()));
+
+    PCollection<String> result =
+        PCollectionList.of(ImmutableList.of(input, grouped))
+            .apply(Flatten.pCollections())
+            .apply(
+                "Key",
+                MapElements.into(
+                        TypeDescriptors.kvs(TypeDescriptors.strings(), TypeDescriptors.strings()))
+                    .via(e -> KV.of("key", e)))
+            .apply(
+                ParDo.of(
+                    new DoFn<KV<String, String>, String>() {
+                      @StateId("seen")
+                      private final StateSpec<ValueState<String>> seenSpec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @TimerId("emit")
+                      private final TimerSpec emitSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+                      @ProcessElement
+                      public void process(
+                          ProcessContext context,
+                          @StateId("seen") ValueState<String> seenState,
+                          @TimerId("emit") Timer emitTimer) {
+                        // Append to the comma-separated log, reading state once.
+                        String element = context.element().getValue();
+                        String seen = seenState.read();
+                        seenState.write(seen == null ? element : seen + "," + element);
+                        // Watermark passes 100s only at advanceWatermarkToInfinity().
+                        emitTimer.set(Instant.ofEpochSecond(100));
+                      }
+
+                      @OnTimer("emit")
+                      public void onEmit(
+                          OnTimerContext context, @StateId("seen") ValueState<String> seenState) {
+                        context.output(seenState.read());
+                      }
+                    }));
+
+    PAssert.that(result).containsInAnyOrder("before,BEFORE,after,AFTER");
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
@Category(UsesTestStreamWithProcessingTime.class)
public void testCoder() throws Exception {
TestStream<String> testStream =
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
new file mode 100644
index 0000000..0204da8
--- /dev/null
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+
+cdef class MetricCell(object):
+ cdef object _lock  # threading.Lock set in __init__ (cells.py); bypassed by compiled CounterCell.update
+ cpdef bint update(self, value) except -1  # except -1: Python exceptions propagate to callers
+
+# Declarations below type the same-named classes in cells.py when it is compiled with Cython.
+cdef class CounterCell(MetricCell):
+ cdef readonly libc.stdint.int64_t value  # readonly: readable from Python, not assignable
+
+ @cython.locals(ivalue=libc.stdint.int64_t)  # C type of the local `ivalue` in update()
+ cpdef bint update(self, value) except -1
+
+
+cdef class DistributionCell(MetricCell):
+ cdef readonly DistributionData data
+
+ @cython.locals(ivalue=libc.stdint.int64_t)  # C type of the local `ivalue` in _update()
+ cdef inline bint _update(self, value) except -1
+
+
+cdef class GaugeCell(MetricCell):
+ cdef readonly object data  # untyped Python object
+
+
+cdef class DistributionData(object):
+ cdef readonly libc.stdint.int64_t sum  # sum/count/min/max are all stored as C int64
+ cdef readonly libc.stdint.int64_t count
+ cdef readonly libc.stdint.int64_t min
+ cdef readonly libc.stdint.int64_t max
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index e7336e4..d30dd2a 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -30,12 +30,16 @@
from google.protobuf import timestamp_pb2
-from apache_beam.metrics.metricbase import Counter
-from apache_beam.metrics.metricbase import Distribution
-from apache_beam.metrics.metricbase import Gauge
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import metrics_pb2
+try:
+ import cython
+except ImportError:
+ class fake_cython:
+ compiled = False
+ globals()['cython'] = fake_cython
+
__all__ = ['DistributionResult', 'GaugeResult']
@@ -52,11 +56,17 @@
def __init__(self):
self._lock = threading.Lock()
+ def update(self, value):
+ raise NotImplementedError
+
def get_cumulative(self):
raise NotImplementedError
+ def __reduce__(self):
+ raise NotImplementedError
-class CounterCell(Counter, MetricCell):
+
+class CounterCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta of a counter metric.
@@ -80,27 +90,41 @@
return result
def inc(self, n=1):
- with self._lock:
- self.value += n
+ self.update(n)
+
+ def dec(self, n=1):
+ self.update(-n)
+
+ def update(self, value):
+ if cython.compiled:
+ ivalue = value
+ # We hold the GIL, no need for another lock.
+ self.value += ivalue
+ else:
+ with self._lock:
+ self.value += value
def get_cumulative(self):
with self._lock:
return self.value
- def to_runner_api_monitoring_info(self):
- """Returns a Metric with this counter value for use in a MonitoringInfo."""
- # TODO(ajamato): Update this code to be consistent with Gauges
- # and Distributions. Since there is no CounterData class this method
- # was added to CounterCell. Consider adding a CounterData class or
- # removing the GaugeData and DistributionData classes.
- return metrics_pb2.Metric(
- counter_data=metrics_pb2.CounterData(
- int64_value=self.get_cumulative()
- )
- )
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
+ value=self.get_cumulative()))
+
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_counter(
+ name.namespace, name.name,
+ metrics_pb2.Metric(
+ counter_data=metrics_pb2.CounterData(
+ int64_value=self.get_cumulative())),
+ ptransform=transform_id)
-class DistributionCell(Distribution, MetricCell):
+class DistributionCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta for a distribution metric.
@@ -124,26 +148,43 @@
return result
def update(self, value):
- with self._lock:
+ if cython.compiled:
+ # We will hold the GIL throughout the entire _update.
self._update(value)
+ else:
+ with self._lock:
+ self._update(value)
def _update(self, value):
- value = int(value)
- self.data.count += 1
- self.data.sum += value
- self.data.min = (value
- if self.data.min is None or self.data.min > value
- else self.data.min)
- self.data.max = (value
- if self.data.max is None or self.data.max < value
- else self.data.max)
+ if cython.compiled:
+ ivalue = value
+ else:
+ ivalue = int(value)
+ self.data.count = self.data.count + 1
+ self.data.sum = self.data.sum + ivalue
+ if ivalue < self.data.min:
+ self.data.min = ivalue
+ if ivalue > self.data.max:
+ self.data.max = ivalue
def get_cumulative(self):
with self._lock:
return self.data.get_cumulative()
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ distribution_data=self.get_cumulative().to_runner_api())
-class GaugeCell(Gauge, MetricCell):
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_distribution(
+ name.namespace, name.name,
+ self.get_cumulative().to_runner_api_monitoring_info(),
+ ptransform=transform_id)
+
+
+class GaugeCell(MetricCell):
"""For internal use only; no backwards-compatibility guarantees.
Tracks the current value and delta for a gauge metric.
@@ -167,6 +208,9 @@
return result
def set(self, value):
+ self.update(value)
+
+ def update(self, value):
value = int(value)
with self._lock:
# Set the value directly without checking timestamp, because
@@ -178,6 +222,18 @@
with self._lock:
return self.data.get_cumulative()
+ def to_runner_api_user_metric(self, metric_name):
+ return beam_fn_api_pb2.Metrics.User(
+ metric_name=metric_name.to_runner_api(),
+ gauge_data=self.get_cumulative().to_runner_api())
+
+ def to_runner_api_monitoring_info(self, name, transform_id):
+ from apache_beam.metrics import monitoring_infos
+ return monitoring_infos.int64_user_gauge(
+ name.namespace, name.name,
+ self.get_cumulative().to_runner_api_monitoring_info(),
+ ptransform=transform_id)
+
class DistributionResult(object):
"""The result of a Distribution metric."""
@@ -198,7 +254,7 @@
return not self == other
def __repr__(self):
- return '<DistributionResult(sum={}, count={}, min={}, max={})>'.format(
+ return 'DistributionResult(sum={}, count={}, min={}, max={})'.format(
self.sum,
self.count,
self.min,
@@ -206,11 +262,11 @@
@property
def max(self):
- return self.data.max
+ return self.data.max if self.data.count else None
@property
def min(self):
- return self.data.min
+ return self.data.min if self.data.count else None
@property
def count(self):
@@ -340,10 +396,15 @@
by other than the DistributionCell that contains it.
"""
def __init__(self, sum, count, min, max):
- self.sum = sum
- self.count = count
- self.min = min
- self.max = max
+ if count:
+ self.sum = sum
+ self.count = count
+ self.min = min
+ self.max = max
+ else:
+ self.sum = self.count = 0
+ self.min = 2**63 - 1
+ self.max = -2**63
def __eq__(self, other):
return (self.sum == other.sum and
@@ -359,7 +420,7 @@
return not self == other
def __repr__(self):
- return '<DistributionData(sum={}, count={}, min={}, max={})>'.format(
+ return 'DistributionData(sum={}, count={}, min={}, max={})'.format(
self.sum,
self.count,
self.min,
@@ -372,15 +433,11 @@
if other is None:
return self
- new_min = (None if self.min is None and other.min is None else
- min(x for x in (self.min, other.min) if x is not None))
- new_max = (None if self.max is None and other.max is None else
- max(x for x in (self.max, other.max) if x is not None))
return DistributionData(
self.sum + other.sum,
self.count + other.count,
- new_min,
- new_max)
+ self.min if self.min < other.min else other.min,
+ self.max if self.max > other.max else other.max)
@staticmethod
def singleton(value):
@@ -449,7 +506,7 @@
"""
@staticmethod
def identity_element():
- return DistributionData(0, 0, None, None)
+ return DistributionData(0, 0, 2**63 - 1, -2**63)
def combine(self, x, y):
return x.combine(y)
diff --git a/sdks/python/apache_beam/metrics/execution.pxd b/sdks/python/apache_beam/metrics/execution.pxd
index 74b34fb..6e1cbb0 100644
--- a/sdks/python/apache_beam/metrics/execution.pxd
+++ b/sdks/python/apache_beam/metrics/execution.pxd
@@ -16,10 +16,30 @@
#
cimport cython
+cimport libc.stdint
+
+from apache_beam.metrics.cells cimport MetricCell
+
+
+cdef object get_current_tracker
+
+
+cdef class _TypedMetricName(object):
+ cdef readonly object cell_type
+ cdef readonly object metric_name
+ cdef readonly object fast_name
+ cdef libc.stdint.int64_t _hash
+
+
+cdef object _DEFAULT
+
+
+cdef class MetricUpdater(object):
+ cdef _TypedMetricName typed_metric_name
+ cdef object default
cdef class MetricsContainer(object):
cdef object step_name
- cdef public object counters
- cdef public object distributions
- cdef public object gauges
+ cdef public dict metrics
+ cpdef MetricCell get_metric_cell(self, metric_key)
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 91fe2f8..6918914 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -33,14 +33,13 @@
from __future__ import absolute_import
from builtins import object
-from collections import defaultdict
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.cells import CounterCell
from apache_beam.metrics.cells import DistributionCell
from apache_beam.metrics.cells import GaugeCell
-from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker import statesampler
+from apache_beam.runners.worker.statesampler import get_current_tracker
class MetricKey(object):
@@ -150,88 +149,117 @@
MetricsEnvironment = _MetricsEnvironment()
+class _TypedMetricName(object):
+ """Like MetricName, but also stores the cell type of the metric."""
+ def __init__(self, cell_type, metric_name):
+ self.cell_type = cell_type
+ self.metric_name = metric_name
+ if isinstance(metric_name, str):
+ self.fast_name = metric_name
+ else:
+ self.fast_name = '%d_%s%s' % (
+ len(metric_name.name), metric_name.name, metric_name.namespace)
+ # Cached for speed, as this is used as a key for every counter update.
+ self._hash = hash((cell_type, self.fast_name))
+
+ def __eq__(self, other):
+ return self is other or (
+ self.cell_type == other.cell_type and self.fast_name == other.fast_name)
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return self._hash
+
+ def __reduce__(self):
+ return _TypedMetricName, (self.cell_type, self.metric_name)
+
+
+_DEFAULT = None
+
+
+class MetricUpdater(object):
+ """Callable that updates a metric cell through the current state tracker."""
+ def __init__(self, cell_type, metric_name, default=None):
+ self.typed_metric_name = _TypedMetricName(cell_type, metric_name)
+ self.default = default
+
+ def __call__(self, value=_DEFAULT):
+ if value is _DEFAULT:
+ if self.default is _DEFAULT:
+ raise ValueError('Missing value for update of %s' %
+ self.typed_metric_name.metric_name)
+ value = self.default
+ tracker = get_current_tracker()
+ if tracker is not None:
+ tracker.update_metric(self.typed_metric_name, value)
+
+ def __reduce__(self):
+ return MetricUpdater, (
+ self.typed_metric_name.cell_type,
+ self.typed_metric_name.metric_name,
+ self.default)
+
+
class MetricsContainer(object):
"""Holds the metrics of a single step and a single bundle."""
def __init__(self, step_name):
self.step_name = step_name
- self.counters = defaultdict(lambda: CounterCell())
- self.distributions = defaultdict(lambda: DistributionCell())
- self.gauges = defaultdict(lambda: GaugeCell())
+ self.metrics = dict()
def get_counter(self, metric_name):
- return self.counters[metric_name]
+ return self.get_metric_cell(_TypedMetricName(CounterCell, metric_name))
def get_distribution(self, metric_name):
- return self.distributions[metric_name]
+ return self.get_metric_cell(_TypedMetricName(DistributionCell, metric_name))
def get_gauge(self, metric_name):
- return self.gauges[metric_name]
+ return self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name))
+
+ def get_metric_cell(self, typed_metric_name):
+ cell = self.metrics.get(typed_metric_name, None)
+ if cell is None:
+ cell = self.metrics[typed_metric_name] = typed_metric_name.cell_type()
+ return cell
def get_cumulative(self):
"""Return MetricUpdates with cumulative values of all metrics in container.
This returns all the cumulative values for all metrics.
"""
- counters = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.counters.items()}
+ counters = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == CounterCell}
- distributions = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.distributions.items()}
+ distributions = {
+ MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == DistributionCell}
- gauges = {MetricKey(self.step_name, k): v.get_cumulative()
- for k, v in self.gauges.items()}
+ gauges = {MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+ for k, v in self.metrics.items()
+ if k.cell_type == GaugeCell}
return MetricUpdates(counters, distributions, gauges)
def to_runner_api(self):
- return (
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
- value=v.get_cumulative()))
- for k, v in self.counters.items()] +
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- distribution_data=v.get_cumulative().to_runner_api())
- for k, v in self.distributions.items()] +
- [beam_fn_api_pb2.Metrics.User(
- metric_name=k.to_runner_api(),
- gauge_data=v.get_cumulative().to_runner_api())
- for k, v in self.gauges.items()]
- )
+ return [cell.to_runner_api_user_metric(key.metric_name)
+ for key, cell in self.metrics.items()]
def to_runner_api_monitoring_infos(self, transform_id):
"""Returns a list of MonitoringInfos for the metrics in this container."""
- all_user_metrics = []
- for k, v in self.counters.items():
- all_user_metrics.append(monitoring_infos.int64_user_counter(
- k.namespace, k.name,
- v.to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
-
- for k, v in self.distributions.items():
- all_user_metrics.append(monitoring_infos.int64_user_distribution(
- k.namespace, k.name,
- v.get_cumulative().to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
-
- for k, v in self.gauges.items():
- all_user_metrics.append(monitoring_infos.int64_user_gauge(
- k.namespace, k.name,
- v.get_cumulative().to_runner_api_monitoring_info(),
- ptransform=transform_id
- ))
+ all_user_metrics = [
+ cell.to_runner_api_monitoring_info(key.metric_name, transform_id)
+ for key, cell in self.metrics.items()]
return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
def reset(self):
- for counter in self.counters.values():
- counter.reset()
- for distribution in self.distributions.values():
- distribution.reset()
- for gauge in self.gauges.values():
- gauge.reset()
+ for metric in self.metrics.values():
+ metric.reset()
+
+ def __reduce__(self):
+ raise NotImplementedError
class MetricUpdates(object):
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index 9af1696..fc363a4 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -73,12 +73,6 @@
class TestMetricsContainer(unittest.TestCase):
- def test_create_new_counter(self):
- mc = MetricsContainer('astep')
- self.assertFalse(MetricName('namespace', 'name') in mc.counters)
- mc.get_counter(MetricName('namespace', 'name'))
- self.assertTrue(MetricName('namespace', 'name') in mc.counters)
-
def test_add_to_counter(self):
mc = MetricsContainer('astep')
counter = mc.get_counter(MetricName('namespace', 'name'))
diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py
index acd4771..8bbe191 100644
--- a/sdks/python/apache_beam/metrics/metric.py
+++ b/sdks/python/apache_beam/metrics/metric.py
@@ -29,7 +29,8 @@
import inspect
from builtins import object
-from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics import cells
+from apache_beam.metrics.execution import MetricUpdater
from apache_beam.metrics.metricbase import Counter
from apache_beam.metrics.metricbase import Distribution
from apache_beam.metrics.metricbase import Gauge
@@ -101,11 +102,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingCounter, self).__init__()
self.metric_name = metric_name
-
- def inc(self, n=1):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_counter(self.metric_name).inc(n)
+ self.inc = MetricUpdater(cells.CounterCell, metric_name, default=1)
class DelegatingDistribution(Distribution):
"""Metrics Distribution Delegates functionality to MetricsEnvironment."""
@@ -113,11 +110,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingDistribution, self).__init__()
self.metric_name = metric_name
-
- def update(self, value):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_distribution(self.metric_name).update(value)
+ self.update = MetricUpdater(cells.DistributionCell, metric_name)
class DelegatingGauge(Gauge):
"""Metrics Gauge that Delegates functionality to MetricsEnvironment."""
@@ -125,11 +118,7 @@
def __init__(self, metric_name):
super(Metrics.DelegatingGauge, self).__init__()
self.metric_name = metric_name
-
- def set(self, value):
- container = MetricsEnvironment.current_container()
- if container is not None:
- container.get_gauge(self.metric_name).set(value)
+ self.set = MetricUpdater(cells.GaugeCell, metric_name)
class MetricResults(object):
diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py
index 6e8ee08..cb18dc7 100644
--- a/sdks/python/apache_beam/metrics/metric_test.py
+++ b/sdks/python/apache_beam/metrics/metric_test.py
@@ -130,31 +130,36 @@
statesampler.set_current_tracker(sampler)
state1 = sampler.scoped_state('mystep', 'myState',
metrics_container=MetricsContainer('mystep'))
- sampler.start()
- with state1:
- counter_ns = 'aCounterNamespace'
- distro_ns = 'aDistributionNamespace'
- name = 'a_name'
- counter = Metrics.counter(counter_ns, name)
- distro = Metrics.distribution(distro_ns, name)
- counter.inc(10)
- counter.dec(3)
- distro.update(10)
- distro.update(2)
- self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
- self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
- del distro
- del counter
+ try:
+ sampler.start()
+ with state1:
+ counter_ns = 'aCounterNamespace'
+ distro_ns = 'aDistributionNamespace'
+ name = 'a_name'
+ counter = Metrics.counter(counter_ns, name)
+ distro = Metrics.distribution(distro_ns, name)
+ counter.inc(10)
+ counter.dec(3)
+ distro.update(10)
+ distro.update(2)
+ self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
+ self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
- container = MetricsEnvironment.current_container()
- self.assertEqual(
- container.counters[MetricName(counter_ns, name)].get_cumulative(),
- 7)
- self.assertEqual(
- container.distributions[MetricName(distro_ns, name)].get_cumulative(),
- DistributionData(12, 2, 2, 10))
- sampler.stop()
+ del distro
+ del counter
+
+ container = MetricsEnvironment.current_container()
+ self.assertEqual(
+ container.get_counter(
+ MetricName(counter_ns, name)).get_cumulative(),
+ 7)
+ self.assertEqual(
+ container.get_distribution(
+ MetricName(distro_ns, name)).get_cumulative(),
+ DistributionData(12, 2, 2, 10))
+ finally:
+ sampler.stop()
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py
new file mode 100644
index 0000000..ca03310
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/spark_runner.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A runner for executing portable pipelines on Spark."""
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import re
+
+from apache_beam.options import pipeline_options
+from apache_beam.runners.portability import job_server
+from apache_beam.runners.portability import portable_runner
+
+# https://spark.apache.org/docs/latest/submitting-applications.html#master-urls
+LOCAL_MASTER_PATTERN = r'^local(\[.+\])?$'
+
+
+class SparkRunner(portable_runner.PortableRunner):
+ def run_pipeline(self, pipeline, options):
+ spark_options = options.view_as(SparkRunnerOptions)
+ portable_options = options.view_as(pipeline_options.PortableOptions)
+ if (re.match(LOCAL_MASTER_PATTERN, spark_options.spark_master_url)
+ and not portable_options.environment_type
+ and not portable_options.output_executable_path):
+ portable_options.environment_type = 'LOOPBACK'
+ return super(SparkRunner, self).run_pipeline(pipeline, options)
+
+ def default_job_server(self, options):
+ # TODO(BEAM-8139) submit a Spark jar to a cluster
+ return job_server.StopOnExitJobServer(SparkJarJobServer(options))
+
+
+class SparkRunnerOptions(pipeline_options.PipelineOptions):
+ @classmethod
+ def _add_argparse_args(cls, parser):
+ parser.add_argument('--spark_master_url',
+ default='local[4]',
+ help='Spark master URL (spark://HOST:PORT). '
+ 'Use "local" (single-threaded) or "local[*]" '
+ '(multi-threaded) to start a local cluster for '
+ 'the execution.')
+ parser.add_argument('--spark_job_server_jar',
+ help='Path or URL to a Beam Spark jobserver jar.')
+ parser.add_argument('--artifacts_dir', default=None)
+
+
+class SparkJarJobServer(job_server.JavaJarJobServer):
+ def __init__(self, options):
+ super(SparkJarJobServer, self).__init__()
+ options = options.view_as(SparkRunnerOptions)
+ self._jar = options.spark_job_server_jar
+ self._master_url = options.spark_master_url
+ self._artifacts_dir = options.artifacts_dir
+
+ def path_to_jar(self):
+ if self._jar:
+ return self._jar
+ else:
+ return self.path_to_beam_jar('runners:spark:job-server:shadowJar')
+
+ def java_arguments(self, job_port, artifacts_dir):
+ return [
+ '--spark-master-url', self._master_url,
+ '--artifacts-dir', (self._artifacts_dir
+ if self._artifacts_dir else artifacts_dir),
+ '--job-port', job_port,
+ '--artifact-port', 0,
+ '--expansion-port', 0
+ ]
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index c17ab08..fe9c492 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -38,6 +38,7 @@
'apache_beam.runners.interactive.interactive_runner.InteractiveRunner',
'apache_beam.runners.portability.flink_runner.FlinkRunner',
'apache_beam.runners.portability.portable_runner.PortableRunner',
+ 'apache_beam.runners.portability.spark_runner.SparkRunner',
'apache_beam.runners.test.TestDirectRunner',
'apache_beam.runners.test.TestDataflowRunner',
)
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
index 799bd0d..aebf9f6 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd
@@ -43,6 +43,9 @@
cdef int32_t current_state_index
+ cpdef ScopedState current_state(self)
+ cdef inline ScopedState current_state_c(self)
+
cpdef _scoped_state(
self, counter_name, name_context, output_counter, metrics_container)
@@ -56,7 +59,7 @@
cdef readonly object name_context
cdef readonly int64_t _nsecs
cdef int32_t old_state_index
- cdef readonly MetricsContainer _metrics_container
+ cdef readonly MetricsContainer metrics_container
cpdef __enter__(self)
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 325ec99..8d2346a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -159,8 +159,12 @@
(<ScopedState>state)._nsecs = 0
self.started = self.finished = False
- def current_state(self):
- return self.scoped_states_by_index[self.current_state_index]
+ cpdef ScopedState current_state(self):
+ return self.current_state_c()
+
+ cdef inline ScopedState current_state_c(self):
+ # Faster than cpdef due to self always being a Python subclass.
+ return <ScopedState>self.scoped_states_by_index[self.current_state_index]
cpdef _scoped_state(self, counter_name, name_context, output_counter,
metrics_container):
@@ -189,6 +193,11 @@
pythread.PyThread_release_lock(self.lock)
return scoped_state
+ def update_metric(self, typed_metric_name, value):
+ # Each of these is a cdef lookup.
+ self.current_state_c().metrics_container.get_metric_cell(
+ typed_metric_name).update(value)
+
cdef class ScopedState(object):
"""Context manager class managing transitions for a given sampler state."""
@@ -205,7 +214,7 @@
self.name_context = step_name_context
self.state_index = state_index
self.counter = counter
- self._metrics_container = metrics_container
+ self.metrics_container = metrics_container
@property
def nsecs(self):
@@ -232,7 +241,3 @@
self.sampler.current_state_index = self.old_state_index
self.sampler.state_transition_count += 1
pythread.PyThread_release_lock(self.sampler.lock)
-
- @property
- def metrics_container(self):
- return self._metrics_container
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 0091828..fb2592c 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -50,6 +50,10 @@
return ScopedState(
self, counter_name, name_context, output_counter, metrics_container)
+ def update_metric(self, typed_metric_name, value):
+ self.current_state().metrics_container.get_metric_cell(
+ typed_metric_name).update(value)
+
def _enter_state(self, state):
self.state_transition_count += 1
self._state_stack.append(state)
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index ccf90f6..7eea64c 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -214,6 +214,7 @@
ext_modules=cythonize([
'apache_beam/**/*.pyx',
'apache_beam/coders/coder_impl.py',
+ 'apache_beam/metrics/cells.py',
'apache_beam/metrics/execution.py',
'apache_beam/runners/common.py',
'apache_beam/runners/worker/logger.py',
diff --git a/sdks/python/test-suites/portable/py2/build.gradle b/sdks/python/test-suites/portable/py2/build.gradle
index 3c1548d..5d967e4 100644
--- a/sdks/python/test-suites/portable/py2/build.gradle
+++ b/sdks/python/test-suites/portable/py2/build.gradle
@@ -39,6 +39,8 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
// TODO: Move the rest of this file into ../common.gradle.
diff --git a/sdks/python/test-suites/portable/py35/build.gradle b/sdks/python/test-suites/portable/py35/build.gradle
index 1b2cb4f..88b4e2f 100644
--- a/sdks/python/test-suites/portable/py35/build.gradle
+++ b/sdks/python/test-suites/portable/py35/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/sdks/python/test-suites/portable/py36/build.gradle b/sdks/python/test-suites/portable/py36/build.gradle
index 475e110..496777d 100644
--- a/sdks/python/test-suites/portable/py36/build.gradle
+++ b/sdks/python/test-suites/portable/py36/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/sdks/python/test-suites/portable/py37/build.gradle b/sdks/python/test-suites/portable/py37/build.gradle
index 912b316..924de81 100644
--- a/sdks/python/test-suites/portable/py37/build.gradle
+++ b/sdks/python/test-suites/portable/py37/build.gradle
@@ -36,4 +36,6 @@
dependsOn ':runners:flink:1.9:job-server:shadowJar'
dependsOn portableWordCountFlinkRunnerBatch
dependsOn portableWordCountFlinkRunnerStreaming
+ dependsOn ':runners:spark:job-server:shadowJar'
+ dependsOn portableWordCountSparkRunnerBatch
}
diff --git a/website/Gemfile b/website/Gemfile
index 4a08725..1050303 100644
--- a/website/Gemfile
+++ b/website/Gemfile
@@ -20,7 +20,7 @@
source 'https://rubygems.org'
-gem 'jekyll', '3.2'
+gem 'jekyll', '3.6.3'
# Jekyll plugins
group :jekyll_plugins do
diff --git a/website/Gemfile.lock b/website/Gemfile.lock
index e94f132..9db2ebe 100644
--- a/website/Gemfile.lock
+++ b/website/Gemfile.lock
@@ -13,7 +13,7 @@
concurrent-ruby (1.1.4)
ethon (0.11.0)
ffi (>= 1.3.0)
- ffi (1.9.25)
+ ffi (1.11.1)
forwardable-extended (2.6.0)
html-proofer (3.9.3)
activesupport (>= 4.2, < 6.0)
@@ -26,15 +26,16 @@
yell (~> 2.0)
i18n (0.9.5)
concurrent-ruby (~> 1.0)
- jekyll (3.2.0)
+ jekyll (3.6.3)
+ addressable (~> 2.4)
colorator (~> 1.0)
jekyll-sass-converter (~> 1.0)
jekyll-watch (~> 1.1)
- kramdown (~> 1.3)
- liquid (~> 3.0)
+ kramdown (~> 1.14)
+ liquid (~> 4.0)
mercenary (~> 0.3.3)
pathutil (~> 0.9)
- rouge (~> 1.7)
+ rouge (>= 1.7, < 3)
safe_yaml (~> 1.0)
jekyll-redirect-from (0.11.0)
jekyll (>= 2.0)
@@ -45,29 +46,27 @@
jekyll_github_sample (0.3.1)
activesupport (~> 4.0)
jekyll (~> 3.0)
- kramdown (1.16.2)
- liquid (3.0.6)
- listen (3.1.5)
- rb-fsevent (~> 0.9, >= 0.9.4)
- rb-inotify (~> 0.9, >= 0.9.7)
- ruby_dep (~> 1.2)
+ kramdown (1.17.0)
+ liquid (4.0.3)
+ listen (3.2.0)
+ rb-fsevent (~> 0.10, >= 0.10.3)
+ rb-inotify (~> 0.9, >= 0.9.10)
mercenary (0.3.6)
mini_portile2 (2.3.0)
minitest (5.11.3)
nokogiri (1.8.5)
mini_portile2 (~> 2.3.0)
parallel (1.12.1)
- pathutil (0.16.1)
+ pathutil (0.16.2)
forwardable-extended (~> 2.6)
public_suffix (3.0.3)
rake (12.3.0)
- rb-fsevent (0.10.2)
- rb-inotify (0.9.10)
- ffi (>= 0.5.0, < 2)
- rouge (1.11.1)
- ruby_dep (1.5.0)
- safe_yaml (1.0.4)
- sass (3.5.5)
+ rb-fsevent (0.10.3)
+ rb-inotify (0.10.0)
+ ffi (~> 1.0)
+ rouge (2.2.1)
+ safe_yaml (1.0.5)
+ sass (3.7.4)
sass-listen (~> 4.0.0)
sass-listen (4.0.0)
rb-fsevent (~> 0.9, >= 0.9.4)
@@ -85,7 +84,7 @@
DEPENDENCIES
activesupport (< 5.0.0.0)
html-proofer
- jekyll (= 3.2)
+ jekyll (= 3.6.3)
jekyll-redirect-from
jekyll-sass-converter
jekyll_github_sample