blob: ce205270e56330fd4070d035693caf083a18484b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.fn.control;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.RegisterRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.RequestCase;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.dataflow.worker.ByteStringCoder;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionContext.DataflowStepContext;
import org.apache.beam.runners.dataflow.worker.DataflowOperationContext;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Operation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.OperationContext;
import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputReceiver;
import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateDelegator;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead;
import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.MoreFutures;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.TextFormat;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This {@link Operation} is responsible for communicating with the SDK harness and asking it to
* process a bundle of work. This operation registers the {@link
* org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor} when executed the first
* time. Afterwards, it only asks the SDK harness to process the bundle using the already registered
* {@link org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor}.
*
* <p>This operation supports restart.
*/
public class RegisterAndProcessBundleOperation extends Operation {
// Logger; the constructor uses it to dump the bundle descriptor as a dot graph at debug level.
private static final Logger LOG =
LoggerFactory.getLogger(RegisterAndProcessBundleOperation.class);
// Passed to the superclass constructor: this operation is created without output receivers.
private static final OutputReceiver[] EMPTY_RECEIVERS = new OutputReceiver[0];
// Supplies instruction ids for the register, process-bundle and progress requests.
private final IdGenerator idGenerator;
// Sends InstructionRequests and hands back the pending InstructionResponse.
private final InstructionRequestHandler instructionRequestHandler;
// Used in start() to register this operation's state handler for the bundle's instruction id.
private final StateDelegator beamFnStateDelegator;
// Register request sent on the first start(); the constructor checks it holds one descriptor.
private final RegisterRequest registerRequest;
// PTransform id -> user-namespaced step context; built (immutable) in the constructor.
private final Map<String, DataflowStepContext> ptransformIdToUserStepContext;
// PTransform id -> reader used to answer multimap side input state requests.
private final Map<String, SideInputReader> ptransformIdToSideInputReader;
// (PTransform id, side input id) -> view describing that side input.
private final Table<String, String, PCollectionView<?>>
ptransformIdToSideInputIdToPCollectionView;
// Bag states looked up so far, keyed by state key; cleared in finish().
private final ConcurrentHashMap<StateKey, BagState<ByteString>> userStateData;
// PCollection id -> name context; exposed unchanged through getPCollectionIdToNameContext().
private final Map<String, NameContext> pcollectionIdToNameContext;
// Non-null once the register request has been sent; start() uses this to register only once.
private @Nullable CompletionStage<InstructionResponse> registerFuture;
// Pending response for the bundle requested by the most recent start(); null before that.
private @Nullable CompletionStage<InstructionResponse> processBundleResponse;
// Instruction id of the current bundle; reset to null by finish(). Volatile because
// getProcessBundleProgress() reads it without holding the lock.
private volatile @Nullable String processBundleId = null;
// Handle for the state handler registered in start(); not assigned until start() runs.
private StateDelegator.Registration deregisterStateHandler;
// The three fields below describe the RemoteGrpcPortRead transform of the descriptor. They are
// only set when the descriptor contains exactly one such transform, and are null otherwise.
private @Nullable String grpcReadTransformId = null;
private String grpcReadTransformOutputName = null;
private String grpcReadTransformOutputPCollectionName = null;
// PCollection ids produced by gRPC read transforms or consumed by gRPC write transforms.
private final Set<String> grpcReadTransformReadWritePCollectionNames;
/**
 * Creates an operation that registers a single process bundle descriptor with the SDK harness and
 * then asks the harness to process bundles against it.
 *
 * <p>Besides storing its collaborators, the constructor derives the user-namespaced step
 * contexts, remembers the gRPC read transform when the descriptor has exactly one, and collects
 * the PCollection ids that sit next to gRPC read/write transforms.
 *
 * @param idGenerator source of instruction ids
 * @param instructionRequestHandler channel used to send instruction requests
 * @param beamFnStateDelegator registry for per-bundle state request handlers
 * @param registerRequest register request; must contain exactly one process bundle descriptor
 * @param ptransformIdToOperationContext not used by this constructor
 * @param ptransformIdToSystemStepContext system step contexts, converted to user-namespaced ones
 * @param ptransformIdToSideInputReader side input readers keyed by PTransform id
 * @param ptransformIdToSideInputIdToPCollectionView views keyed by PTransform id and side input id
 * @param pcollectionIdToNameContext name contexts keyed by PCollection id
 * @param context operation context handed to the superclass
 * @throws IllegalStateException if the register request does not hold exactly one descriptor
 */
public RegisterAndProcessBundleOperation(
    IdGenerator idGenerator,
    InstructionRequestHandler instructionRequestHandler,
    StateDelegator beamFnStateDelegator,
    RegisterRequest registerRequest,
    Map<String, DataflowOperationContext> ptransformIdToOperationContext,
    Map<String, DataflowStepContext> ptransformIdToSystemStepContext,
    Map<String, SideInputReader> ptransformIdToSideInputReader,
    Table<String, String, PCollectionView<?>> ptransformIdToSideInputIdToPCollectionView,
    Map<String, NameContext> pcollectionIdToNameContext,
    OperationContext context) {
  super(EMPTY_RECEIVERS, context);
  this.idGenerator = idGenerator;
  this.instructionRequestHandler = instructionRequestHandler;
  this.beamFnStateDelegator = beamFnStateDelegator;
  this.registerRequest = registerRequest;
  this.ptransformIdToSideInputReader = ptransformIdToSideInputReader;
  this.ptransformIdToSideInputIdToPCollectionView = ptransformIdToSideInputIdToPCollectionView;
  this.pcollectionIdToNameContext = pcollectionIdToNameContext;

  // State requests are served from the user namespace of each step context.
  ImmutableMap.Builder<String, DataflowStepContext> userContexts = ImmutableMap.builder();
  for (Map.Entry<String, DataflowStepContext> systemContext :
      ptransformIdToSystemStepContext.entrySet()) {
    DataflowStepContext userContext = systemContext.getValue().namespacedToUser();
    userContexts.put(systemContext.getKey(), userContext);
  }
  this.ptransformIdToUserStepContext = userContexts.build();
  this.userStateData = new ConcurrentHashMap<>();

  checkState(
      registerRequest.getProcessBundleDescriptorCount() == 1,
      "Only one bundle registration at a time currently supported.");
  ProcessBundleDescriptor descriptor = registerRequest.getProcessBundleDescriptor(0);
  if (LOG.isDebugEnabled()) {
    LOG.debug("Process bundle descriptor {}", toDot(descriptor));
  }

  Set<Entry<String, PTransform>> transforms = descriptor.getTransformsMap().entrySet();
  for (Entry<String, PTransform> candidate : transforms) {
    PTransform transform = candidate.getValue();
    if (!transform.getSpec().getUrn().equals(RemoteGrpcPortRead.URN)) {
      continue;
    }
    if (grpcReadTransformId != null) {
      // TODO: Handle the case of more than one input.
      // A second gRPC read was found: forget the first one and stop looking.
      grpcReadTransformId = null;
      grpcReadTransformOutputName = null;
      grpcReadTransformOutputPCollectionName = null;
      break;
    }
    Map<String, String> outputs = transform.getOutputsMap();
    grpcReadTransformId = candidate.getKey();
    grpcReadTransformOutputName = Iterables.getOnlyElement(outputs.keySet());
    grpcReadTransformOutputPCollectionName = outputs.get(grpcReadTransformOutputName);
  }

  grpcReadTransformReadWritePCollectionNames =
      extractCrossBoundaryGrpcPCollectionNames(transforms);
}
/**
 * Collects the PCollection ids attached to the gRPC port transforms in {@code ptransforms}: the
 * single output of every {@link RemoteGrpcPortRead} transform and the single input of every
 * {@link RemoteGrpcPortWrite} transform.
 *
 * @param ptransforms entries of PTransform id to PTransform to scan
 * @return a new mutable set with the collected PCollection ids (possibly empty)
 */
private Set<String> extractCrossBoundaryGrpcPCollectionNames(
    final Set<Entry<String, PTransform>> ptransforms) {
  Set<String> crossBoundaryPCollections = new HashSet<>();
  for (Entry<String, PTransform> entry : ptransforms) {
    PTransform transform = entry.getValue();
    String urn = transform.getSpec().getUrn();
    // GRPC Read/Write expected to only have one Output/Input respectively; getOnlyElement
    // fails if that does not hold.
    if (urn.equals(RemoteGrpcPortRead.URN)) {
      crossBoundaryPCollections.add(Iterables.getOnlyElement(transform.getOutputsMap().values()));
    }
    if (urn.equals(RemoteGrpcPortWrite.URN)) {
      crossBoundaryPCollections.add(Iterables.getOnlyElement(transform.getInputsMap().values()));
    }
  }
  return crossBoundaryPCollections;
}
/**
 * Renders the process bundle descriptor as a Graphviz dot digraph.
 *
 * <p>Every PCollection and PTransform becomes a node, every transform input an edge from the
 * PCollection to the transform, and every transform output an edge from the transform to the
 * PCollection. Edges are labelled with the local input/output name.
 */
private static String toDot(ProcessBundleDescriptor processBundleDescriptor) {
  Map<String, RunnerApi.PCollection> pcollections = processBundleDescriptor.getPcollectionsMap();
  Map<String, RunnerApi.PTransform> transforms = processBundleDescriptor.getTransformsMap();

  // Give each node a short dot identifier ("n0", "n1", ...). The "pc "/"pt " prefixes keep
  // PCollection and PTransform ids from colliding in the lookup map.
  Map<String, String> dotIds = Maps.newHashMap();
  for (String pcollectionId : pcollections.keySet()) {
    dotIds.put("pc " + pcollectionId, "n" + dotIds.size());
  }
  for (String transformId : transforms.keySet()) {
    dotIds.put("pt " + transformId, "n" + dotIds.size());
  }

  StringBuilder dot = new StringBuilder();
  dot.append("digraph network {\n");

  for (Entry<String, RunnerApi.PCollection> pcollection : pcollections.entrySet()) {
    String label = pcollection.getKey() + ": " + pcollection.getValue().getUniqueName();
    dot.append(
        String.format(
            " %s [fontname=\"Courier New\" label=\"%s\"];%n",
            dotIds.get("pc " + pcollection.getKey()), escapeDot(label)));
  }

  for (Entry<String, RunnerApi.PTransform> transformEntry : transforms.entrySet()) {
    RunnerApi.PTransform transform = transformEntry.getValue();
    String transformNode = dotIds.get("pt " + transformEntry.getKey());
    String label =
        transformEntry.getKey()
            + ": "
            + transform.getSpec().getUrn()
            + " "
            + transform.getUniqueName();
    dot.append(
        String.format(
            " %s [fontname=\"Courier New\" label=\"%s\"];%n", transformNode, escapeDot(label)));

    for (Entry<String, String> input : transform.getInputsMap().entrySet()) {
      dot.append(
          String.format(
              " %s -> %s [fontname=\"Courier New\" label=\"%s\"];%n",
              dotIds.get("pc " + input.getValue()), transformNode, escapeDot(input.getKey())));
    }
    for (Entry<String, String> output : transform.getOutputsMap().entrySet()) {
      dot.append(
          String.format(
              " %s -> %s [fontname=\"Courier New\" label=\"%s\"];%n",
              transformNode, dotIds.get("pc " + output.getValue()), escapeDot(output.getKey())));
    }
  }

  dot.append("}");
  return dot.toString();
}
/**
 * Escapes a string for use inside a double-quoted Graphviz label: backslashes and double quotes
 * are backslash-escaped, and each newline becomes the dot escape {@code \l}.
 */
private static String escapeDot(String s) {
  // Backslashes must be handled first so the escapes added below are not escaped again.
  String withBackslashes = s.replace("\\", "\\\\");
  String withQuotes = withBackslashes.replace("\"", "\\\"");
  // http://www.graphviz.org/doc/info/attrs.html#k:escString
  // The escape sequences "\n", "\l" and "\r" divide the label into lines, centered,
  // left-justified, and right-justified, respectively.
  return withQuotes.replace("\n", "\\l");
}
/**
 * Returns an id for the current bundle being processed.
 *
 * <p>Generates new id with idGenerator if no id is cached.
 *
 * <p><b>Note</b>: This operation could be used across multiple bundles, so a unique id is
 * generated for every bundle. {@link Operation Operations} accessing the bundle id should only
 * call this once per bundle and cache the id in the {@link Operation#start()} method and clear it
 * in the {@link Operation#finish()} method.
 *
 * @return the cached or newly generated instruction id
 */
public synchronized String getProcessBundleInstructionId() {
  // Work on a snapshot of the volatile field: finish() clears it without taking this lock, so
  // re-reading the field for the return value could hand back null.
  String currentId = processBundleId;
  if (currentId == null) {
    currentId = idGenerator.getId();
    processBundleId = currentId;
  }
  return currentId;
}
/**
 * Returns the instruction id of the bundle currently being processed without generating one.
 *
 * @return the cached bundle instruction id, or {@code null} when no id has been generated since
 *     the last {@link #finish()}
 */
public @Nullable String getCurrentProcessBundleInstructionId() {
  return processBundleId;
}
/**
 * Starts processing a bundle: sends the register request the first time this operation is
 * started, registers a state request handler for the bundle's instruction id, and then sends the
 * process bundle request. The pending bundle response is stored for {@link #finish()}.
 *
 * @throws IllegalStateException if the register request does not hold exactly one descriptor
 */
@Override
public void start() throws Exception {
try (Closeable scope = context.enterStart()) {
super.start();
// Only register once by using the presence of the future as a signal.
if (registerFuture == null) {
InstructionRequest request =
InstructionRequest.newBuilder()
.setInstructionId(idGenerator.getId())
.setRegister(registerRequest)
.build();
registerFuture = instructionRequestHandler.handle(request);
// NOTE(review): the stage returned here is discarded and nothing waits on it, so a failed
// registration is not reported at this point - confirm that is intended.
getRegisterResponse(registerFuture);
}
checkState(
registerRequest.getProcessBundleDescriptorCount() == 1,
"Only one bundle registration at a time currently supported.");
InstructionRequest processBundleRequest =
InstructionRequest.newBuilder()
.setInstructionId(getProcessBundleInstructionId())
.setProcessBundle(
ProcessBundleRequest.newBuilder()
.setProcessBundleDescriptorReference(
registerRequest.getProcessBundleDescriptor(0).getId()))
.build();
// The state handler is registered before the bundle request is sent.
deregisterStateHandler =
beamFnStateDelegator.registerForProcessBundleInstructionId(
getProcessBundleInstructionId(), this::delegateByStateKeyType);
processBundleResponse = instructionRequestHandler.handle(processBundleRequest);
}
}
/**
 * Blocks until the SDK harness has answered the process bundle request, then deregisters the
 * state handler, drops the cached user state and clears the bundle id so that the next bundle
 * gets a fresh one.
 *
 * @throws IllegalStateException if the response contains residual roots
 */
@Override
public void finish() throws Exception {
// TODO: Once we have access to windowing strategy via the ParDoPayload, add support to garbage
// collect any user state set. Also add support for consuming those garbage collection timers.
try (Closeable scope = context.enterFinish()) {
// Await completion or failure
BeamFnApi.ProcessBundleResponse completedResponse =
MoreFutures.get(getProcessBundleResponse(processBundleResponse));
if (completedResponse.getResidualRootsCount() > 0) {
throw new IllegalStateException(
"TODO: [BEAM-2939] residual roots in process bundle response not yet supported.");
}
// NOTE(review): the cleanup below is skipped when anything above throws; presumably abort()
// is expected to run in that case - confirm against the caller.
deregisterStateHandler.deregister();
userStateData.clear();
processBundleId = null;
super.finish();
}
}
/**
 * Aborts the bundle: aborts the state handler registration (if one exists) and cancels the
 * pending register and process bundle futures.
 *
 * <p>Safe to call before {@link #start()} has registered a state handler.
 */
@Override
public void abort() throws Exception {
  try (Closeable scope = context.enterAbort()) {
    // The registration is only assigned in start(); without this guard an early abort would
    // fail with a NullPointerException and skip the cancellation below.
    if (deregisterStateHandler != null) {
      deregisterStateHandler.abort();
    }
    cancelIfNotNull(registerFuture);
    cancelIfNotNull(processBundleResponse);
    super.abort();
  }
}
/**
 * Returns the PCollection id to {@link NameContext} map that was passed to the constructor (the
 * same instance, not a copy).
 */
public Map<String, NameContext> getPCollectionIdToNameContext() {
return this.pcollectionIdToNameContext;
}
/**
 * Returns the immutable map from PTransform id to the user-namespaced step context that was
 * derived in the constructor.
 */
public Map<String, DataflowStepContext> getPtransformIdToUserStepContext() {
return ptransformIdToUserStepContext;
}
/**
 * Asks the SDK harness for the progress of the bundle that is currently being processed.
 *
 * <p>When no bundle id is set (before {@code start()} or after {@code finish()}), no request is
 * sent and an already-completed stage holding the default (empty) progress response is returned.
 *
 * <p>May be called at any time, including before start() and after finish().
 *
 * @return a stage completing with the progress response; it completes exceptionally with an
 *     {@link IllegalStateException} when the harness answers with an error
 * @throws InterruptedException declared for callers; not thrown by the code in this method
 * @throws ExecutionException declared for callers; not thrown by the code in this method
 */
public CompletionStage<BeamFnApi.ProcessBundleProgressResponse> getProcessBundleProgress()
    throws InterruptedException, ExecutionException {
  // processBundleId may be reset if this bundle finishes asynchronously, so read it only once.
  final String activeBundleId = this.processBundleId;
  if (activeBundleId == null) {
    BeamFnApi.ProcessBundleProgressResponse noProgress =
        BeamFnApi.ProcessBundleProgressResponse.getDefaultInstance();
    return CompletableFuture.completedFuture(noProgress);
  }

  ProcessBundleProgressRequest.Builder progressQuery =
      ProcessBundleProgressRequest.newBuilder().setInstructionReference(activeBundleId);
  InstructionRequest progressRequest =
      InstructionRequest.newBuilder()
          .setInstructionId(idGenerator.getId())
          .setProcessBundleProgress(progressQuery)
          .build();

  CompletionStage<InstructionResponse> pending =
      instructionRequestHandler.handle(progressRequest);
  return pending.thenApply(
      response -> {
        if (response.getError().isEmpty()) {
          return response.getProcessBundleProgress();
        }
        throw new IllegalStateException(response.getError());
      });
}
/**
 * Returns the final metrics reported by the SDK harness in its process bundle response.
 *
 * <p>The bundle response future is only assigned in {@code start()}, so this must not be called
 * before the operation has been started.
 */
public CompletionStage<BeamFnApi.Metrics> getFinalMetrics() {
  CompletionStage<BeamFnApi.ProcessBundleResponse> bundleResponse =
      getProcessBundleResponse(processBundleResponse);
  return bundleResponse.thenApply(BeamFnApi.ProcessBundleResponse::getMetrics);
}
/**
 * Returns the monitoring infos reported by the SDK harness in its process bundle response.
 *
 * <p>The bundle response future is only assigned in {@code start()}, so this must not be called
 * before the operation has been started.
 */
public CompletionStage<List<MonitoringInfo>> getFinalMonitoringInfos() {
  CompletionStage<BeamFnApi.ProcessBundleResponse> bundleResponse =
      getProcessBundleResponse(processBundleResponse);
  return bundleResponse.thenApply(BeamFnApi.ProcessBundleResponse::getMonitoringInfosList);
}
/**
 * Reports whether the SDK harness has already answered the process bundle request with an error.
 *
 * @return {@code true} only when the response is available and carries a non-empty error;
 *     {@code false} when no bundle was started or the response is still pending
 * @throws ExecutionException if the response future completed exceptionally
 * @throws InterruptedException declared by the underlying {@code Future.get()} call
 */
public boolean hasFailed() throws ExecutionException, InterruptedException {
  if (processBundleResponse == null) {
    return false;
  }
  CompletableFuture<InstructionResponse> response = processBundleResponse.toCompletableFuture();
  if (!response.isDone()) {
    // At the very least, we don't know that this has failed yet.
    return false;
  }
  return !response.get().getError().isEmpty();
}
/**
 * Returns a subset of monitoring infos that refer to grpc IO: the element-count infos whose
 * PCollection label is one of the PCollections attached to a gRPC read or write transform.
 *
 * @param monitoringInfos infos to filter; not iterated when no gRPC PCollections are known
 * @return a new list with the matching infos, in iteration order
 */
public List<MonitoringInfo> findIOPCollectionMonitoringInfos(
    Iterable<MonitoringInfo> monitoringInfos) {
  List<MonitoringInfo> ioInfos = new ArrayList<>();
  if (grpcReadTransformReadWritePCollectionNames.isEmpty()) {
    return ioInfos;
  }
  for (MonitoringInfo info : monitoringInfos) {
    if (!info.getUrn().equals(MonitoringInfoConstants.Urns.ELEMENT_COUNT)) {
      continue;
    }
    String pcollectionId =
        info.getLabelsOrDefault(MonitoringInfoConstants.Labels.PCOLLECTION, null);
    if (pcollectionId == null) {
      continue;
    }
    if (grpcReadTransformReadWritePCollectionNames.contains(pcollectionId)) {
      ioInfos.add(info);
    }
  }
  return ioInfos;
}
/**
 * Returns the element count recorded for the output PCollection of the gRPC read transform.
 *
 * @param monitoringInfos monitoring infos to search
 * @return the counter value of the first matching element-count info, or 0 when there is no
 *     single gRPC read transform or no matching info
 */
long getInputElementsConsumed(final Iterable<MonitoringInfo> monitoringInfos) {
  if (grpcReadTransformId == null) {
    return 0;
  }
  for (MonitoringInfo info : monitoringInfos) {
    boolean isElementCount = info.getUrn().equals(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
    if (!isElementCount) {
      continue;
    }
    String pcollectionId =
        info.getLabelsOrDefault(MonitoringInfoConstants.Labels.PCOLLECTION, null);
    if (pcollectionId == null || !pcollectionId.equals(grpcReadTransformOutputPCollectionName)) {
      continue;
    }
    return info.getMetric().getCounterData().getInt64Value();
  }
  return 0;
}
/**
 * Returns the number of input elements consumed by the gRPC read, if known, otherwise 0.
 *
 * @param metrics metrics reported by the SDK harness
 * @return the output element count recorded for the gRPC read transform's output, or 0 when the
 *     descriptor has no single gRPC read transform or the metrics hold no entry for it
 */
double getInputElementsConsumed(BeamFnApi.Metrics metrics) {
  // Both fields stay null unless the descriptor has exactly one gRPC read transform. The
  // protobuf map accessors below reject null keys, so answer "unknown" here, matching the
  // MonitoringInfo-based overload.
  if (grpcReadTransformId == null || grpcReadTransformOutputName == null) {
    return 0;
  }
  return metrics
      .getPtransformsOrDefault(
          grpcReadTransformId, BeamFnApi.Metrics.PTransform.getDefaultInstance())
      .getProcessedElements()
      .getMeasured()
      .getOutputElementCountsOrDefault(grpcReadTransformOutputName, 0);
}
/**
 * Routes a state request to the handler for its state key type. Only bag user state and multimap
 * side input keys are supported.
 *
 * @throws UnsupportedOperationException for any other state key type
 */
private CompletionStage<BeamFnApi.StateResponse.Builder> delegateByStateKeyType(
    StateRequest stateRequest) {
  StateKey.TypeCase keyType = stateRequest.getStateKey().getTypeCase();
  if (keyType == StateKey.TypeCase.BAG_USER_STATE) {
    return handleBagUserState(stateRequest);
  }
  if (keyType == StateKey.TypeCase.MULTIMAP_SIDE_INPUT) {
    return handleMultimapSideInput(stateRequest);
  }
  throw new UnsupportedOperationException(
      String.format("Dataflow does not handle StateRequests of type %s", keyType));
}
/**
 * Answers a multimap side input state request by reading the side input through the matching
 * {@link SideInputReader} and returning the encoded values for the requested key and window.
 *
 * <p>Only {@code GET} requests are supported. The window and the user key are decoded from the
 * state key, the values are looked up, re-encoded with the view's value coder and concatenated
 * into a single response.
 *
 * @param stateRequest request whose state key is a multimap side input key
 * @return an already-completed stage holding the get response
 * @throws IllegalStateException if the request is not a GET, or the PTransform, side input,
 *     materialization or coder is not what this handler expects
 * @throws IllegalArgumentException if the window or key cannot be decoded, or the values cannot
 *     be encoded
 */
private CompletionStage<BeamFnApi.StateResponse.Builder> handleMultimapSideInput(
    StateRequest stateRequest) {
  // All checks pass message templates instead of pre-formatted strings, so the message is only
  // built when a check actually fails; this method runs once per state request.
  checkState(
      stateRequest.getRequestCase() == RequestCase.GET,
      "MultimapSideInput state requests only support '%s' requests, received '%s'",
      RequestCase.GET,
      stateRequest.getRequestCase());
  StateKey.MultimapSideInput multimapSideInputStateKey =
      stateRequest.getStateKey().getMultimapSideInput();
  SideInputReader sideInputReader =
      ptransformIdToSideInputReader.get(multimapSideInputStateKey.getPtransformId());
  checkState(
      sideInputReader != null,
      "Unknown PTransform '%s'",
      multimapSideInputStateKey.getPtransformId());
  PCollectionView<Materializations.MultimapView<Object, Object>> view =
      (PCollectionView<Materializations.MultimapView<Object, Object>>)
          ptransformIdToSideInputIdToPCollectionView.get(
              multimapSideInputStateKey.getPtransformId(),
              multimapSideInputStateKey.getSideInputId());
  checkState(
      view != null,
      "Unknown side input '%s' on PTransform '%s'",
      multimapSideInputStateKey.getSideInputId(),
      multimapSideInputStateKey.getPtransformId());
  checkState(
      Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
          view.getViewFn().getMaterialization().getUrn()),
      "Unknown materialization for side input '%s' on PTransform '%s' with urn '%s'",
      multimapSideInputStateKey.getSideInputId(),
      multimapSideInputStateKey.getPtransformId(),
      view.getViewFn().getMaterialization().getUrn());
  Coder<?> viewCoder = view.getCoderInternal();
  checkState(
      viewCoder instanceof KvCoder,
      "Materialization of side input '%s' on PTransform '%s' expects %s but received %s.",
      multimapSideInputStateKey.getSideInputId(),
      multimapSideInputStateKey.getPtransformId(),
      KvCoder.class.getSimpleName(),
      // Null-safe so that a missing coder fails the check instead of throwing while the
      // message arguments are being evaluated.
      viewCoder == null ? null : viewCoder.getClass().getSimpleName());
  Coder<Object> keyCoder = ((KvCoder) viewCoder).getKeyCoder();
  Coder<Object> valueCoder = ((KvCoder) viewCoder).getValueCoder();
  BoundedWindow window;
  try {
    // TODO: Use EncodedWindow instead of decoding the window.
    window =
        view.getWindowingStrategyInternal()
            .getWindowFn()
            .windowCoder()
            .decode(multimapSideInputStateKey.getWindow().newInput());
  } catch (IOException e) {
    throw new IllegalArgumentException(
        String.format(
            "Unable to decode window for side input '%s' on PTransform '%s'.",
            multimapSideInputStateKey.getSideInputId(),
            multimapSideInputStateKey.getPtransformId()),
        e);
  }
  Object userKey;
  try {
    // TODO: Use the encoded representation of the key.
    userKey = keyCoder.decode(multimapSideInputStateKey.getKey().newInput());
  } catch (IOException e) {
    throw new IllegalArgumentException(
        String.format(
            "Unable to decode user key for side input '%s' on PTransform '%s'.",
            multimapSideInputStateKey.getSideInputId(),
            multimapSideInputStateKey.getPtransformId()),
        e);
  }
  Materializations.MultimapView<Object, Object> sideInput = sideInputReader.get(view, window);
  Iterable<Object> values = sideInput.get(userKey);
  try {
    // TODO: Chunk the requests and use a continuation key to support side input values
    // that are larger then 2 GiBs.
    // TODO: Use the raw value so we don't go through a decode/encode cycle for no reason.
    return CompletableFuture.completedFuture(
        StateResponse.newBuilder()
            .setGet(StateGetResponse.newBuilder().setData(encodeAndConcat(values, valueCoder))));
  } catch (IOException e) {
    throw new IllegalArgumentException(
        String.format(
            "Unable to encode values for side input '%s' on PTransform '%s'.",
            multimapSideInputStateKey.getSideInputId(),
            multimapSideInputStateKey.getPtransformId()),
        e);
  }
}
/**
 * Serves a bag user state request (GET, APPEND or CLEAR) from the user step context of the
 * requesting PTransform.
 *
 * <p>The bag state for a state key is looked up on first use and cached in {@code userStateData}
 * for subsequent requests. The state is always resolved in the global window namespace.
 *
 * @param stateRequest request whose state key is a bag user state key
 * @return an already-completed stage holding the matching response builder
 * @throws IllegalStateException if the PTransform id of the state key is unknown
 * @throws IllegalArgumentException if the request is neither GET, APPEND nor CLEAR
 */
private CompletionStage<BeamFnApi.StateResponse.Builder> handleBagUserState(
    StateRequest stateRequest) {
  StateKey.BagUserState bagUserStateKey = stateRequest.getStateKey().getBagUserState();
  DataflowStepContext userStepContext =
      ptransformIdToUserStepContext.get(bagUserStateKey.getPtransformId());
  // Template form: the message is only formatted when the check fails, not on every request.
  checkState(
      userStepContext != null,
      "Unknown PTransform id '%s'",
      bagUserStateKey.getPtransformId());
  // TODO: We should not be required to hold onto a pointer to the bag states for the
  // user. InMemoryStateInternals assumes that the Java garbage collector does the clean-up work
  // but instead StateInternals should hold its own references and write out any data and
  // clear references when the MapTask within Dataflow completes like how WindmillStateInternals
  // works.
  BagState<ByteString> state =
      userStateData.computeIfAbsent(
          stateRequest.getStateKey(),
          unused ->
              userStepContext
                  .stateInternals()
                  .state(
                      // TODO: Once we have access to the ParDoPayload, use its windowing strategy
                      // to decode the window for the well known window types. Longer term we need
                      // to swap
                      // to use the encoded version and not rely on needing to decode the entire
                      // window.
                      StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE),
                      StateTags.bag(bagUserStateKey.getUserStateId(), ByteStringCoder.of())));
  switch (stateRequest.getRequestCase()) {
    case GET:
      return CompletableFuture.completedFuture(
          StateResponse.newBuilder()
              .setGet(StateGetResponse.newBuilder().setData(concat(state.read()))));
    case APPEND:
      state.add(stateRequest.getAppend().getData());
      return CompletableFuture.completedFuture(
          StateResponse.newBuilder().setAppend(StateAppendResponse.getDefaultInstance()));
    case CLEAR:
      state.clear();
      return CompletableFuture.completedFuture(
          StateResponse.newBuilder().setClear(StateClearResponse.getDefaultInstance()));
    default:
      throw new IllegalArgumentException(
          String.format("Unknown request type %s", stateRequest.getRequestCase()));
  }
}
/** Always returns {@code true}: this operation reports that it can be restarted. */
@Override
public boolean supportsRestart() {
return true;
}
/**
 * Returns a stage that passes the response through unchanged when it carries no error, and
 * completes exceptionally with an {@link IllegalStateException} when it does.
 */
private static CompletionStage<BeamFnApi.InstructionResponse> throwIfFailure(
    CompletionStage<InstructionResponse> responseFuture) {
  return responseFuture.thenApply(
      response -> {
        String error = response.getError();
        if (error.isEmpty()) {
          return response;
        }
        throw new IllegalStateException(
            String.format(
                "Client failed to process %s with error [%s].",
                response.getInstructionId(), error));
      });
}
/**
 * Extracts the process bundle payload from an instruction response.
 *
 * <p>The returned stage completes exceptionally with an {@link IllegalStateException} when the
 * response carries an error or is not a process bundle response.
 */
private static CompletionStage<BeamFnApi.ProcessBundleResponse> getProcessBundleResponse(
    CompletionStage<InstructionResponse> responseFuture) {
  CompletionStage<InstructionResponse> successful = throwIfFailure(responseFuture);
  return successful.thenApply(
      response -> {
        if (response.getResponseCase() == InstructionResponse.ResponseCase.PROCESS_BUNDLE) {
          return response.getProcessBundle();
        }
        throw new IllegalStateException(
            String.format(
                "SDK harness returned wrong kind of response to ProcessBundleRequest: %s",
                TextFormat.printToString(response)));
      });
}
/**
 * Extracts the register payload from an instruction response. This method does not block.
 *
 * <p>The returned stage completes exceptionally with an {@link IllegalStateException} when the
 * response carries an error or is not a register response.
 */
private static CompletionStage<BeamFnApi.RegisterResponse> getRegisterResponse(
    CompletionStage<InstructionResponse> responseFuture)
    throws ExecutionException, InterruptedException {
  CompletionStage<InstructionResponse> successful = throwIfFailure(responseFuture);
  return successful.thenApply(
      response -> {
        if (response.getResponseCase() == InstructionResponse.ResponseCase.REGISTER) {
          return response.getRegister();
        }
        throw new IllegalStateException(
            String.format(
                "SDK harness returned wrong kind of response to RegisterRequest: %s",
                TextFormat.printToString(response)));
      });
}
/** Cancels the given stage's future (with interruption requested); does nothing for null. */
private static void cancelIfNotNull(CompletionStage<?> future) {
  if (future == null) {
    return;
  }
  // TODO: add cancel(boolean) to MoreFutures
  future.toCompletableFuture().cancel(true);
}
/**
 * Concatenates the given byte strings in iteration order.
 *
 * @param values byte strings to join; may be null
 * @return the concatenation, or the empty byte string when {@code values} is null or empty
 */
private ByteString concat(Iterable<ByteString> values) {
  if (values == null) {
    return ByteString.EMPTY;
  }
  ByteString joined = ByteString.EMPTY;
  for (ByteString piece : values) {
    joined = joined.concat(piece);
  }
  return joined;
}
/**
 * Encodes every value with {@code valueCoder} and concatenates the encodings in iteration order.
 *
 * <p>A value whose encoding is empty is represented by a single zero byte.
 *
 * @param values values to encode; may be null, which yields an empty result
 * @param valueCoder coder used for every value
 * @return the concatenated encodings
 * @throws IOException if encoding a value fails
 */
static ByteString encodeAndConcat(Iterable<Object> values, Coder valueCoder) throws IOException {
  ByteString.Output encoded = ByteString.newOutput();
  if (values == null) {
    return encoded.toByteString();
  }
  for (Object element : values) {
    int sizeBefore = encoded.size();
    valueCoder.encode(element, encoded);
    boolean wroteNothing = encoded.size() == sizeBefore;
    // Pad empty values by one byte as per the Beam Fn data transfer specification.
    if (wroteNothing) {
      encoded.write(0);
    }
  }
  return encoded.toByteString();
}
/** Describes this operation by its current bundle id and the registered bundle descriptors. */
@Override
public String toString() {
  MoreObjects.ToStringHelper description = MoreObjects.toStringHelper(this);
  description.add("processBundleId", processBundleId);
  description.add("processBundleDescriptors", registerRequest.getProcessBundleDescriptorList());
  return description.toString();
}
}