blob: 86e2ee68d7a55a9330d7515bc2fb73c7c11f23c7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.samza.runtime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.PushbackSideInputDoFnRunner;
import org.apache.beam.runners.core.SideInputHandler;
import org.apache.beam.runners.core.SimplePushbackSideInputDoFnRunner;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.serialization.Base64Serializer;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.samza.SamzaExecutionContext;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.apache.samza.config.Config;
import org.apache.samza.context.Context;
import org.apache.samza.operators.Scheduler;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Samza operator for {@link DoFn}.
 *
 * <p>Elements are run through a {@link PushbackSideInputDoFnRunner}. Values that the runner rejects
 * from {@code processElementInReadyWindows} are buffered in memory ({@code pushbackValues}) and
 * replayed whenever a new side input value arrives. They are flushed directly to the underlying
 * runner once the side input watermark reaches {@link BoundedWindow#TIMESTAMP_MAX_VALUE}. While
 * values are buffered, the watermark handed to the timer internals is held back to the earliest
 * buffered timestamp.
 *
 * @param <InT> input element type of the {@link DoFn}.
 * @param <FnOutT> main output type of the {@link DoFn}.
 * @param <OutT> element type emitted by this operator (for example {@link RawUnionValue} when
 *     {@link MultiOutputManagerFactory} is used).
 */
public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
  private static final Logger LOG = LoggerFactory.getLogger(DoFnOp.class);

  private final TupleTag<FnOutT> mainOutputTag;
  private final DoFn<InT, FnOutT> doFn;
  private final Coder<?> keyCoder;
  private final Collection<PCollectionView<?>> sideInputs;
  private final List<TupleTag<?>> sideOutputTags;
  private final WindowingStrategy windowingStrategy;
  private final OutputManagerFactory<OutT> outputManagerFactory;
  // NOTE: we use HashMap here to guarantee Serializability
  private final HashMap<String, PCollectionView<?>> idToViewMap;
  private final String transformFullName;
  private final String transformId;
  private final Coder<InT> inputCoder;
  private final HashMap<TupleTag<?>, Coder<?>> outputCoders;
  private final PCollection.IsBounded isBounded;

  // portable api related
  private final boolean isPortable;
  private final RunnerApi.ExecutableStagePayload stagePayload;
  private final HashMap<String, TupleTag<?>> idToTupleTagMap;

  private transient SamzaTimerInternalsFactory<?> timerInternalsFactory;
  private transient DoFnRunner<InT, FnOutT> fnRunner;
  private transient PushbackSideInputDoFnRunner<InT, FnOutT> pushbackFnRunner;
  private transient SideInputHandler sideInputHandler;
  private transient DoFnInvoker<InT, FnOutT> doFnInvoker;

  // Minimum timestamp among pushbackValues (TIMESTAMP_MAX_VALUE when none are buffered), so it is
  // derivable from that list. NOTE(review): pushbackValues is currently a transient in-memory
  // list re-created in open(), not a persisted store.
  // TODO: eagerly initialize the hold in init
  @edu.umd.cs.findbugs.annotations.SuppressWarnings(
      justification = "No bug",
      value = "SE_TRANSIENT_FIELD_NOT_RESTORED")
  private transient Instant pushbackWatermarkHold;

  // TODO: add this to checkpointable state
  private transient Instant inputWatermark;
  private transient Instant sideInputWatermark;
  private transient List<WindowedValue<InT>> pushbackValues;
  // Only created in portable mode; stays null otherwise.
  private transient StageBundleFactory stageBundleFactory;

  private final DoFnSchemaInformation doFnSchemaInformation;
  private final Map<String, PCollectionView<?>> sideInputMapping;

  /**
   * Creates the operator.
   *
   * <p>{@code outputCoders}, {@code idToViewMap} and {@code idToTupleTagMap} are copied into
   * {@link HashMap}s so they are serializable. The other arguments are stored as given.
   *
   * @param mainOutputTag tag of the DoFn's main output.
   * @param doFn the user {@link DoFn}. Its signature and invoker are always used; the function
   *     itself is handed to the runner only in non-portable mode.
   * @param keyCoder key coder forwarded to the timer internals factory and, in non-portable
   *     mode, to the runner.
   * @param inputCoder coder of the input elements (non-portable mode).
   * @param outputCoders coders per output tag (non-portable mode).
   * @param sideInputs side input views consumed by the DoFn.
   * @param sideOutputTags additional output tags (non-portable mode).
   * @param windowingStrategy windowing strategy of the input.
   * @param idToViewMap maps side input ids received by {@link #processSideInput} to views.
   * @param outputManagerFactory creates the output manager wrapping the emitter in {@link #open}.
   * @param transformFullName full transform name.
   * @param transformId transform id, also used for state internals.
   * @param isBounded boundedness of the input, forwarded to the timer internals factory.
   * @param isPortable whether to execute through the portable (Fn API) runner.
   * @param stagePayload executable stage payload (portable mode).
   * @param idToTupleTagMap output id to tag mapping (portable mode).
   * @param doFnSchemaInformation schema information for the DoFn (non-portable mode).
   * @param sideInputMapping side input mapping for the DoFn (non-portable mode).
   */
  public DoFnOp(
      TupleTag<FnOutT> mainOutputTag,
      DoFn<InT, FnOutT> doFn,
      Coder<?> keyCoder,
      Coder<InT> inputCoder,
      Map<TupleTag<?>, Coder<?>> outputCoders,
      Collection<PCollectionView<?>> sideInputs,
      List<TupleTag<?>> sideOutputTags,
      WindowingStrategy windowingStrategy,
      Map<String, PCollectionView<?>> idToViewMap,
      OutputManagerFactory<OutT> outputManagerFactory,
      String transformFullName,
      String transformId,
      PCollection.IsBounded isBounded,
      boolean isPortable,
      RunnerApi.ExecutableStagePayload stagePayload,
      Map<String, TupleTag<?>> idToTupleTagMap,
      DoFnSchemaInformation doFnSchemaInformation,
      Map<String, PCollectionView<?>> sideInputMapping) {
    this.mainOutputTag = mainOutputTag;
    this.doFn = doFn;
    this.sideInputs = sideInputs;
    this.sideOutputTags = sideOutputTags;
    this.inputCoder = inputCoder;
    this.outputCoders = new HashMap<>(outputCoders);
    this.windowingStrategy = windowingStrategy;
    this.idToViewMap = new HashMap<>(idToViewMap);
    this.outputManagerFactory = outputManagerFactory;
    this.transformFullName = transformFullName;
    this.transformId = transformId;
    this.keyCoder = keyCoder;
    this.isBounded = isBounded;
    this.isPortable = isPortable;
    this.stagePayload = stagePayload;
    this.idToTupleTagMap = new HashMap<>(idToTupleTagMap);
    this.doFnSchemaInformation = doFnSchemaInformation;
    this.sideInputMapping = sideInputMapping;
  }

  /**
   * Initializes watermarks, state/timer internals, the side input handler and the DoFn runners,
   * then calls the DoFn's {@code @Setup} through the selected invoker.
   *
   * <p>Pipeline options are read from the {@code "beamPipelineOptions"} config entry.
   */
  @Override
  public void open(
      Config config,
      Context context,
      Scheduler<KeyedTimerData<Void>> timerRegistry,
      OpEmitter<OutT> emitter) {
    this.inputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
    this.sideInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
    // Nothing is pushed back yet, so the hold must not restrict the watermark.
    this.pushbackWatermarkHold = BoundedWindow.TIMESTAMP_MAX_VALUE;

    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
    final SamzaPipelineOptions pipelineOptions =
        Base64Serializer.deserializeUnchecked(
                config.get("beamPipelineOptions"), SerializablePipelineOptions.class)
            .get()
            .as(SamzaPipelineOptions.class);

    // Shared by the timer internals and the side input handler below.
    final SamzaStoreStateInternals.Factory<?> nonKeyedStateInternalsFactory =
        SamzaStoreStateInternals.createStateInternalFactory(
            transformId, null, context.getTaskContext(), pipelineOptions, signature);

    this.timerInternalsFactory =
        SamzaTimerInternalsFactory.createTimerInternalFactory(
            keyCoder,
            (Scheduler) timerRegistry,
            getTimerStateId(signature),
            nonKeyedStateInternalsFactory,
            windowingStrategy,
            isBounded,
            pipelineOptions);

    this.sideInputHandler =
        new SideInputHandler(sideInputs, nonKeyedStateInternalsFactory.stateInternalsForKey(null));

    if (isPortable) {
      final SamzaExecutionContext samzaExecutionContext =
          (SamzaExecutionContext) context.getApplicationContainerContext();
      final ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload);
      this.stageBundleFactory =
          samzaExecutionContext.getJobBundleFactory().forStage(executableStage);
      this.fnRunner =
          SamzaDoFnRunners.createPortable(
              outputManagerFactory.create(emitter),
              stageBundleFactory,
              mainOutputTag,
              idToTupleTagMap,
              context,
              transformFullName);
    } else {
      this.fnRunner =
          SamzaDoFnRunners.create(
              pipelineOptions,
              doFn,
              windowingStrategy,
              transformFullName,
              transformId,
              context,
              mainOutputTag,
              sideInputHandler,
              timerInternalsFactory,
              keyCoder,
              outputManagerFactory.create(emitter),
              inputCoder,
              sideOutputTags,
              outputCoders,
              doFnSchemaInformation,
              sideInputMapping);
    }

    this.pushbackFnRunner =
        SimplePushbackSideInputDoFnRunner.create(fnRunner, sideInputs, sideInputHandler);
    this.pushbackValues = new ArrayList<>();

    // A SamzaDoFnInvokerRegistrar found via ServiceLoader overrides the default invoker;
    // Iterators.getOnlyElement fails if more than one registrar is present.
    final Iterator<SamzaDoFnInvokerRegistrar> invokerReg =
        ServiceLoader.load(SamzaDoFnInvokerRegistrar.class).iterator();
    if (!invokerReg.hasNext()) {
      // use the default invoker here
      this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
    } else {
      this.doFnInvoker = Iterators.getOnlyElement(invokerReg).invokerFor(doFn, context);
    }

    doFnInvoker.invokeSetup();
  }

  /**
   * Builds the timer state id: {@code "timer"} followed by every declared timer id, in the
   * iteration order of {@code signature.timerDeclarations().keySet()}.
   */
  private String getTimerStateId(DoFnSignature signature) {
    final StringBuilder builder = new StringBuilder("timer");
    if (signature.usesTimers()) {
      signature.timerDeclarations().keySet().forEach(builder::append);
    }
    return builder.toString();
  }

  /**
   * Processes one element in its own bundle. Any values the pushback runner rejects are buffered,
   * and the pushback watermark hold is lowered to the earliest rejected timestamp.
   */
  @Override
  public void processElement(WindowedValue<InT> inputElement, OpEmitter<OutT> emitter) {
    pushbackFnRunner.startBundle();
    final Iterable<WindowedValue<InT>> rejectedValues =
        pushbackFnRunner.processElementInReadyWindows(inputElement);
    for (WindowedValue<InT> rejectedValue : rejectedValues) {
      if (rejectedValue.getTimestamp().compareTo(pushbackWatermarkHold) < 0) {
        pushbackWatermarkHold = rejectedValue.getTimestamp();
      }
      pushbackValues.add(rejectedValue);
    }
    pushbackFnRunner.finishBundle();
  }

  /**
   * Advances the input watermark and performs four steps in order:
   *
   * <ol>
   *   <li>Flushes buffered values if no more side input will arrive.
   *   <li>Sets the timer internals' input watermark to min(hold, input watermark).
   *   <li>Fires ready timers in a single bundle.
   *   <li>Emits a new output watermark if it advanced.
   * </ol>
   */
  @Override
  public void processWatermark(Instant watermark, OpEmitter<OutT> emitter) {
    this.inputWatermark = watermark;

    if (sideInputWatermark.isEqual(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
      // this means we will never see any more side input
      emitAllPushbackValues();
    }

    // Buffered (pushed-back) values hold the watermark back to their earliest timestamp.
    final Instant actualInputWatermark =
        pushbackWatermarkHold.isBefore(inputWatermark) ? pushbackWatermarkHold : inputWatermark;

    timerInternalsFactory.setInputWatermark(actualInputWatermark);

    pushbackFnRunner.startBundle();
    for (KeyedTimerData<?> keyedTimerData : timerInternalsFactory.removeReadyTimers()) {
      fireTimer(keyedTimerData);
    }
    pushbackFnRunner.finishBundle();

    final Instant currentOutputWatermark = timerInternalsFactory.getOutputWatermark();
    if (currentOutputWatermark == null || currentOutputWatermark.isBefore(actualInputWatermark)) {
      timerInternalsFactory.setOutputWatermark(actualInputWatermark);
      emitter.emitWatermark(timerInternalsFactory.getOutputWatermark());
    }
  }

  /**
   * Adds a side input value for the view registered under {@code id}, replays all buffered
   * values through {@link #processElement}, and re-evaluates the watermark.
   *
   * @throws IllegalArgumentException if no view is registered for {@code id}.
   */
  @Override
  public void processSideInput(
      String id, WindowedValue<? extends Iterable<?>> elements, OpEmitter<OutT> emitter) {
    @SuppressWarnings("unchecked")
    final WindowedValue<Iterable<?>> retypedElements = (WindowedValue<Iterable<?>>) elements;

    final PCollectionView<?> view = idToViewMap.get(id);
    if (view == null) {
      throw new IllegalArgumentException("No mapping of id " + id + " to view.");
    }

    sideInputHandler.addSideInputValue(view, retypedElements);

    // Values that are still not ready are pushed back again by processElement.
    for (final WindowedValue<InT> value : drainPushbackValues()) {
      processElement(value, emitter);
    }

    // We may be able to advance the output watermark since we may have played some pushed back
    // events.
    processWatermark(this.inputWatermark, emitter);
  }

  /**
   * Records the side input watermark. Once it reaches {@link BoundedWindow#TIMESTAMP_MAX_VALUE},
   * the input watermark is re-processed so that buffered values get flushed.
   */
  @Override
  public void processSideInputWatermark(Instant watermark, OpEmitter<OutT> emitter) {
    sideInputWatermark = watermark;

    if (sideInputWatermark.isEqual(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
      // this means we will never see any more side input
      processWatermark(this.inputWatermark, emitter);
    }
  }

  /** Fires a timer delivered by the scheduler in its own bundle, then removes it. */
  @Override
  public void processTimer(KeyedTimerData<Void> keyedTimerData) {
    pushbackFnRunner.startBundle();
    fireTimer(keyedTimerData);
    pushbackFnRunner.finishBundle();

    this.timerInternalsFactory.removeProcessingTimer((KeyedTimerData) keyedTimerData);
  }

  /**
   * Calls the DoFn's {@code @Teardown}, if the invoker was created, and closes the stage bundle
   * factory.
   *
   * <p>The factory is closed even when teardown throws. A failure to close it is logged rather
   * than rethrown.
   */
  @Override
  public void close() {
    try {
      // doFnInvoker is null if open() failed (or was never called) before creating it.
      if (doFnInvoker != null) {
        doFnInvoker.invokeTeardown();
      }
    } finally {
      // try-with-resources skips close() for a null resource (non-portable mode).
      try (AutoCloseable closer = stageBundleFactory) {
        // do nothing
      } catch (Exception e) {
        LOG.error("Failed to close stage bundle factory", e);
      }
    }
  }

  /**
   * Fires one timer.
   *
   * <p>If the runner is a {@link DoFnRunnerWithKeyedInternals}, the timer is delivered with its
   * key. Otherwise it goes through the pushback runner.
   */
  private void fireTimer(KeyedTimerData<?> keyedTimerData) {
    final TimerInternals.TimerData timer = keyedTimerData.getTimerData();
    LOG.debug("Firing timer {}", timer);

    final StateNamespace namespace = timer.getNamespace();
    // NOTE: not sure why this is safe, but DoFnOperator makes this assumption
    final BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow();

    if (fnRunner instanceof DoFnRunnerWithKeyedInternals) {
      // Need to pass in the keyed TimerData here
      ((DoFnRunnerWithKeyedInternals) fnRunner).onTimer(keyedTimerData, window);
    } else {
      pushbackFnRunner.onTimer(timer.getTimerId(), window, timer.getTimestamp(), timer.getDomain());
    }
  }

  /**
   * Processes every buffered value directly on the underlying runner, bypassing the side input
   * readiness check, within a single bundle. Does nothing if nothing is buffered.
   */
  private void emitAllPushbackValues() {
    if (!pushbackValues.isEmpty()) {
      pushbackFnRunner.startBundle();
      for (final WindowedValue<InT> value : drainPushbackValues()) {
        fnRunner.processElement(value);
      }
      pushbackFnRunner.finishBundle();
    }
  }

  /**
   * Removes and returns a snapshot of all buffered values, and resets the pushback watermark hold
   * to {@link BoundedWindow#TIMESTAMP_MAX_VALUE}.
   */
  private List<WindowedValue<InT>> drainPushbackValues() {
    final List<WindowedValue<InT>> drained = new ArrayList<>(pushbackValues);
    pushbackWatermarkHold = BoundedWindow.TIMESTAMP_MAX_VALUE;
    pushbackValues.clear();
    return drained;
  }

  /**
   * Factory class to create an {@link org.apache.beam.runners.core.DoFnRunners.OutputManager} that
   * emits values to the main output only, which is a single {@link
   * org.apache.beam.sdk.values.PCollection}.
   *
   * @param <OutT> type of the output element.
   */
  public static class SingleOutputManagerFactory<OutT> implements OutputManagerFactory<OutT> {
    @Override
    public DoFnRunners.OutputManager create(OpEmitter<OutT> emitter) {
      return new DoFnRunners.OutputManager() {
        @Override
        public <T> void output(TupleTag<T> tupleTag, WindowedValue<T> windowedValue) {
          // With only one output we know that T is of type OutT.
          @SuppressWarnings("unchecked")
          final WindowedValue<OutT> retypedWindowedValue = (WindowedValue<OutT>) windowedValue;
          emitter.emitElement(retypedWindowedValue);
        }
      };
    }
  }

  /**
   * Factory class to create an {@link org.apache.beam.runners.core.DoFnRunners.OutputManager} that
   * emits values to the main output as well as the side outputs via union type {@link
   * RawUnionValue}. The union index of each value comes from {@code tagToIndexMap}.
   */
  public static class MultiOutputManagerFactory implements OutputManagerFactory<RawUnionValue> {
    private final Map<TupleTag<?>, Integer> tagToIndexMap;

    public MultiOutputManagerFactory(Map<TupleTag<?>, Integer> tagToIndexMap) {
      this.tagToIndexMap = tagToIndexMap;
    }

    /**
     * Creates an output manager that wraps each output value in a {@link RawUnionValue}.
     *
     * <p>The returned manager throws {@link IllegalArgumentException} when it receives a tag
     * missing from {@code tagToIndexMap}.
     */
    @Override
    public DoFnRunners.OutputManager create(OpEmitter<RawUnionValue> emitter) {
      return new DoFnRunners.OutputManager() {
        @Override
        public <T> void output(TupleTag<T> tupleTag, WindowedValue<T> windowedValue) {
          // Look up as Integer first: unboxing a missing mapping would throw an anonymous NPE.
          final Integer index = tagToIndexMap.get(tupleTag);
          if (index == null) {
            throw new IllegalArgumentException("No union index for output tag " + tupleTag);
          }
          final T rawValue = windowedValue.getValue();
          final RawUnionValue rawUnionValue = new RawUnionValue(index, rawValue);
          emitter.emitElement(windowedValue.withValue(rawUnionValue));
        }
      };
    }
  }
}