blob: 0b77a69b5a67929df735eb1539b5829409b741da [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.functions;
import java.io.IOException;
import java.util.EnumMap;
import java.util.Locale;
import java.util.Map;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.InMemoryTimerInternals;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.control.TimerReceiverFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.InMemoryBagUserStateFactory;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
import org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Flink operator that passes its input DataSet through an SDK-executed {@link
* org.apache.beam.runners.core.construction.graph.ExecutableStage}.
*
* <p>The output of this operation is a multiplexed DataSet whose elements are tagged with a union
* coder. The coder's tags are determined by the output coder map. The resulting data set should be
* further processed by a {@link FlinkExecutableStagePruningFunction}.
*/
public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
    implements MapPartitionFunction<WindowedValue<InputT>, RawUnionValue>,
        GroupReduceFunction<WindowedValue<InputT>, RawUnionValue> {

  private static final Logger LOG = LoggerFactory.getLogger(FlinkExecutableStageFunction.class);

  // Main constructor fields. All must be Serializable because Flink distributes Functions to
  // task managers via java serialization.

  // Pipeline options for initializing the FileSystems
  private final SerializablePipelineOptions pipelineOptions;
  // The executable stage this function will run.
  private final RunnerApi.ExecutableStagePayload stagePayload;
  // Pipeline options. Used for provisioning api.
  private final JobInfo jobInfo;
  // Map from PCollection id to the union tag used to represent this PCollection in the output.
  private final Map<String, Integer> outputMap;
  private final FlinkExecutableStageContextFactory contextFactory;
  private final Coder windowCoder;
  // Unique name for namespacing metrics
  private final String stepName;

  // Worker-local fields. These should only be constructed and consumed on Flink TaskManagers.
  private transient RuntimeContext runtimeContext;
  private transient FlinkMetricContainer container;
  private transient StateRequestHandler stateRequestHandler;
  private transient ExecutableStageContext stageContext;
  private transient StageBundleFactory stageBundleFactory;
  private transient BundleProgressHandler progressHandler;
  // Only initialized when the ExecutableStage is stateful
  private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;
  private transient ExecutableStage executableStage;
  // Current key in stateful/timer processing; captured from timer elements during reduce() so the
  // subsequent timer-firing bundle knows which key the timers belong to.
  private transient Object currentTimerKey;

  /**
   * Creates a new function for executing an {@link ExecutableStage} over a Flink DataSet.
   *
   * @param stepName unique name used for namespacing metrics
   * @param pipelineOptions options used to initialize {@link FileSystems} on workers
   * @param stagePayload the serialized executable stage to run
   * @param jobInfo provisioning information for the job
   * @param outputMap map from PCollection id to union tag for the multiplexed output
   * @param contextFactory factory producing the per-worker {@link ExecutableStageContext}
   * @param windowCoder coder for the windows of timer elements
   */
  public FlinkExecutableStageFunction(
      String stepName,
      PipelineOptions pipelineOptions,
      RunnerApi.ExecutableStagePayload stagePayload,
      JobInfo jobInfo,
      Map<String, Integer> outputMap,
      FlinkExecutableStageContextFactory contextFactory,
      Coder windowCoder) {
    this.stepName = stepName;
    this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
    this.stagePayload = stagePayload;
    this.jobInfo = jobInfo;
    this.outputMap = outputMap;
    this.contextFactory = contextFactory;
    this.windowCoder = windowCoder;
  }

  /** Initializes worker-local state: SDK harness context, bundle factory, and handlers. */
  @Override
  public void open(Configuration parameters) throws Exception {
    // Register standard file systems.
    FileSystems.setDefaultPipelineOptions(pipelineOptions.get());
    executableStage = ExecutableStage.fromPayload(stagePayload);
    runtimeContext = getRuntimeContext();
    container = new FlinkMetricContainer(getRuntimeContext());
    // TODO: Wire this into the distributed cache and make it pluggable.
    stageContext = contextFactory.get(jobInfo);
    stageBundleFactory = stageContext.getStageBundleFactory(executableStage);
    // NOTE: It's safe to reuse the state handler between partitions because each partition uses the
    // same backing runtime context and broadcast variables. We use checkState below to catch errors
    // in backward-incompatible Flink changes.
    stateRequestHandler =
        getStateRequestHandler(
            executableStage, stageBundleFactory.getProcessBundleDescriptor(), runtimeContext);
    // Forward SDK harness progress/completion metrics into the Flink metric system.
    progressHandler =
        new BundleProgressHandler() {
          @Override
          public void onProgress(ProcessBundleProgressResponse progress) {
            container.updateMetrics(stepName, progress.getMonitoringInfosList());
          }

          @Override
          public void onCompleted(ProcessBundleResponse response) {
            container.updateMetrics(stepName, response.getMonitoringInfosList());
          }
        };
  }

  /**
   * Builds a composite {@link StateRequestHandler} that routes side-input requests to broadcast
   * variables and (for stateful stages) bag-user-state requests to an in-memory store.
   */
  private StateRequestHandler getStateRequestHandler(
      ExecutableStage executableStage,
      ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor,
      RuntimeContext runtimeContext) {
    final StateRequestHandler sideInputHandler;
    StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
        BatchSideInputHandlerFactory.forStage(
            executableStage, runtimeContext::getBroadcastVariable);
    try {
      sideInputHandler =
          StateRequestHandlers.forSideInputHandlerFactory(
              ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
    } catch (IOException e) {
      throw new RuntimeException("Failed to setup state handler", e);
    }

    final StateRequestHandler userStateHandler;
    if (!executableStage.getUserStates().isEmpty()) {
      // Stateful stage: keep user state in memory; reset per key in reduce().
      bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
      userStateHandler =
          StateRequestHandlers.forBagUserStateHandlerFactory(
              processBundleDescriptor, bagUserStateHandlerFactory);
    } else {
      userStateHandler = StateRequestHandler.unsupported();
    }

    EnumMap<StateKey.TypeCase, StateRequestHandler> handlerMap =
        new EnumMap<>(StateKey.TypeCase.class);
    handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
    handlerMap.put(StateKey.TypeCase.BAG_USER_STATE, userStateHandler);
    return StateRequestHandlers.delegateBasedUponType(handlerMap);
  }

  /** For non-stateful processing via a simple MapPartitionFunction. */
  @Override
  public void mapPartition(
      Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
      throws Exception {
    ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
    // The bundle is flushed and closed by try-with-resources before mapPartition returns.
    try (RemoteBundle bundle =
        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
      processElements(iterable, bundle);
    }
  }

  /** For stateful and timer processing via a GroupReduceFunction. */
  @Override
  public void reduce(Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
      throws Exception {
    // Need to discard the old key's state
    if (bagUserStateHandlerFactory != null) {
      bagUserStateHandlerFactory.resetForNewKey();
    }

    // Used with Batch, we know that all the data is available for this key. We can't use the
    // timer manager from the context because it doesn't exist. So we create one and advance
    // time to the end after processing all elements.
    final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
    timerInternals.advanceProcessingTime(Instant.now());
    timerInternals.advanceSynchronizedProcessingTime(Instant.now());

    ReceiverFactory receiverFactory =
        new ReceiverFactory(
            collector,
            outputMap,
            new TimerReceiverFactory(
                stageBundleFactory,
                (WindowedValue timerElement, TimerInternals.TimerData timerData) -> {
                  // Remember the key so the timer-firing bundle below can scope its timers.
                  currentTimerKey = ((KV) timerElement.getValue()).getKey();
                  timerInternals.setTimer(timerData);
                },
                windowCoder));

    // First process all elements and make sure no more elements can arrive
    try (RemoteBundle bundle =
        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
      processElements(iterable, bundle);
    }

    // Finish any pending windows by advancing the input watermark to infinity.
    timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
    // Finally, advance the processing time to infinity to fire any timers.
    timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
    timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);

    // Now we fire the timers and process elements generated by timers (which may be timers itself)
    try (RemoteBundle bundle =
        stageBundleFactory.getBundle(receiverFactory, stateRequestHandler, progressHandler)) {
      PipelineTranslatorUtils.fireEligibleTimers(
          timerInternals,
          (String timerId, WindowedValue timerValue) -> {
            FnDataReceiver<WindowedValue<?>> fnTimerReceiver =
                bundle.getInputReceivers().get(timerId);
            Preconditions.checkNotNull(fnTimerReceiver, "No FnDataReceiver found for %s", timerId);
            try {
              fnTimerReceiver.accept(timerValue);
            } catch (Exception e) {
              // Chain the original exception so the SDK-side failure is not lost.
              throw new RuntimeException(
                  String.format(Locale.ENGLISH, "Failed to process timer: %s", timerValue), e);
            }
          },
          currentTimerKey);
    }
  }

  /** Feeds all input elements of this partition/group into the bundle's main input receiver. */
  private void processElements(Iterable<WindowedValue<InputT>> iterable, RemoteBundle bundle)
      throws Exception {
    Preconditions.checkArgument(bundle != null, "RemoteBundle must not be null");

    String inputPCollectionId = executableStage.getInputPCollection().getId();
    FnDataReceiver<WindowedValue<?>> mainReceiver =
        Preconditions.checkNotNull(
            bundle.getInputReceivers().get(inputPCollectionId),
            "Main input receiver for %s could not be initialized",
            inputPCollectionId);
    for (WindowedValue<InputT> input : iterable) {
      mainReceiver.accept(input);
    }
  }

  @Override
  public void close() throws Exception {
    // close may be called multiple times when an exception is thrown
    if (stageContext != null) {
      // Close the bundle factory before the stage context; try-with-resources closes in reverse
      // declaration order and propagates the first failure while still closing both.
      try (AutoCloseable bundleFactoryCloser = stageBundleFactory;
          AutoCloseable closable = stageContext) {
      } catch (Exception e) {
        LOG.error("Error in close: ", e);
        throw e;
      }
    }
    stageContext = null;
  }

  /**
   * Receiver factory that wraps outgoing elements with the corresponding union tag for a
   * multiplexed PCollection and optionally handles timer items.
   */
  private static class ReceiverFactory implements OutputReceiverFactory {

    // Guards the Flink collector, which is not thread-safe but may be fed by concurrent receivers.
    private final Object collectorLock = new Object();

    @GuardedBy("collectorLock")
    private final Collector<RawUnionValue> collector;

    private final Map<String, Integer> outputMap;
    @Nullable private final TimerReceiverFactory timerReceiverFactory;

    ReceiverFactory(Collector<RawUnionValue> collector, Map<String, Integer> outputMap) {
      this(collector, outputMap, null);
    }

    ReceiverFactory(
        Collector<RawUnionValue> collector,
        Map<String, Integer> outputMap,
        @Nullable TimerReceiverFactory timerReceiverFactory) {
      this.collector = collector;
      this.outputMap = outputMap;
      this.timerReceiverFactory = timerReceiverFactory;
    }

    @Override
    public <OutputT> FnDataReceiver<OutputT> create(String collectionId) {
      Integer unionTag = outputMap.get(collectionId);
      if (unionTag != null) {
        int tagInt = unionTag;
        return receivedElement -> {
          synchronized (collectorLock) {
            collector.collect(new RawUnionValue(tagInt, receivedElement));
          }
        };
      } else if (timerReceiverFactory != null) {
        // Delegate to TimerReceiverFactory
        return timerReceiverFactory.create(collectionId);
      } else {
        throw new IllegalStateException(
            String.format(Locale.ENGLISH, "Unknown PCollectionId %s", collectionId));
      }
    }
  }
}