/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.iteration.operator.perround;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MetricOptions;
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.streaming.util.LatencyStats;
import org.apache.flink.util.CloseableIterable;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.function.BiConsumerWithException;
import org.apache.commons.collections.IteratorUtils;
import org.rocksdb.RocksDB;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import static org.apache.flink.util.Preconditions.checkState;
/**
 * The base class for all the per-round wrapper operators. A per-round wrapper creates a fresh
 * instance of the wrapped operator for each round (epoch) of the iteration; once the epoch
 * watermark for a round is received, it closes that round's operator and cleans up all the
 * operator and keyed states the operator has registered.
 */
public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>>
extends AbstractWrapperOperator<T>
implements StreamOperatorStateHandler.CheckpointedStreamOperator {
private static final Logger LOG =
LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);
private static final String HEAP_KEYED_STATE_NAME =
"org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";
private static final String ROCKSDB_KEYED_STATE_NAME =
"org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";
/** The wrapped operators for each round. */
private final Map<Integer, S> wrappedOperators;
protected final LatencyStats latencyStats;
private transient StreamOperatorStateContext streamOperatorStateContext;
private transient StreamOperatorStateHandler stateHandler;
private transient InternalTimeServiceManager<?> timeServiceManager;
private transient KeySelector<?, ?> stateKeySelector1;
private transient KeySelector<?, ?> stateKeySelector2;
private int latestEpochWatermark = -1;
// --------------- state ---------------------------
    /** Stores the parallelism of the last run so that we can check whether the parallelism has changed. */
private ListState<Integer> parallelismState;
/** Stores the latest epoch watermark received. */
private ListState<Integer> latestEpochWatermarkState;
    /**
     * Stores the list of epochs that are not aligned yet. After failover, the wrapped operators
     * corresponding to these epochs are re-created.
     */
private ListState<Integer> pendingEpochState;
private ListState<Integer> rawStateEpochState;
public AbstractPerRoundWrapperOperator(
StreamOperatorParameters<IterationRecord<T>> parameters,
StreamOperatorFactory<T> operatorFactory) {
super(parameters, operatorFactory);
this.wrappedOperators = new HashMap<>();
this.latencyStats = initializeLatencyStats();
}
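
    /**
     * Returns the wrapped operator for the given round, creating and initializing it on first
     * access with no raw operator state.
     */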
protected S getWrappedOperator(int round) {
return getWrappedOperator(
round, CloseableIterable.<StatePartitionStreamProvider>empty().iterator(), 0);
}
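
    /**
     * Returns the wrapped operator for the given round, lazily creating it from a clone of the
     * operator factory and restoring it from {@code count} raw operator state partitions taken
     * from {@code rawOperatorStates}.
     */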
@SuppressWarnings({"unchecked", "rawtypes"})
private S getWrappedOperator(
int round, Iterator<StatePartitionStreamProvider> rawOperatorStates, int count) {
S wrappedOperator = wrappedOperators.get(round);
if (wrappedOperator != null) {
return wrappedOperator;
}
        // Clone the operator factory so that each round gets a distinct operator instance;
        // this is required in particular for SimpleOperatorFactory, which wraps a pre-created
        // operator instance.
try {
StreamOperatorFactory<T> clonedOperatorFactory =
InstantiationUtil.clone(
operatorFactory, containingTask.getUserCodeClassLoader());
wrappedOperator =
(S)
StreamOperatorFactoryUtil.<T, S>createOperator(
clonedOperatorFactory,
(StreamTask) parameters.getContainingTask(),
OperatorUtils.createWrappedOperatorConfig(
parameters.getStreamConfig(),
containingTask.getUserCodeClassLoader()),
proxyOutput,
parameters.getOperatorEventDispatcher())
.f0;
initializeStreamOperator(wrappedOperator, round, rawOperatorStates, count);
wrappedOperators.put(round, wrappedOperator);
return wrappedOperator;
} catch (Exception e) {
ExceptionUtils.rethrow(e);
}
return wrappedOperator;
}
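
    /**
     * Ends the input of the wrapped operator for the given epoch and emits the maximum
     * watermark, so that any pending event-time timers can fire before the operator is closed.
     */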
protected abstract void endInputAndEmitMaxWatermark(S operator, int epoch, int epochWatermark)
throws Exception;
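
    /**
     * Closes the wrapped operator of the given epoch: notifies the epoch watermark increment,
     * ends its input, finishes and closes the operator, and then cleans up the operator and
     * keyed states registered under this round's prefix.
     */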
protected void closeStreamOperator(S operator, int epoch, int epochWatermark) throws Exception {
setIterationContextRound(epoch);
OperatorUtils.processOperatorOrUdfIfSatisfy(
operator,
IterationListener.class,
listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
endInputAndEmitMaxWatermark(operator, epoch, epochWatermark);
operator.finish();
operator.close();
setIterationContextRound(null);
        // Clean up the states used by this operator.
cleanupOperatorStates(epoch);
if (stateHandler.getKeyedStateBackend() != null) {
cleanupKeyedStates(epoch);
}
}
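
    /**
     * Destroys the wrapped operator of the finished round once its epoch watermark is received;
     * for the terminal {@link IterationRecord#END_EPOCH_WATERMARK}, all remaining operators are
     * closed in ascending epoch order.
     */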
@Override
public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
checkState(epochWatermark >= 0, "The epoch watermark should be non-negative.");
if (epochWatermark > latestEpochWatermark) {
latestEpochWatermark = epochWatermark;
            // Destroys the operators whose rounds are covered by the epoch watermark. Note
            // that onEpochWatermarkIncrement is first called with 0 and then increments by 1
            // on each call, except for the final END_EPOCH_WATERMARK round.
try {
if (epochWatermark < IterationRecord.END_EPOCH_WATERMARK) {
S wrappedOperator = wrappedOperators.remove(epochWatermark);
if (wrappedOperator != null) {
closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
}
} else {
List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
Collections.sort(sortedEpochs);
for (Integer epoch : sortedEpochs) {
closeStreamOperator(wrappedOperators.remove(epoch), epoch, epochWatermark);
}
}
} catch (Exception exception) {
ExceptionUtils.rethrow(exception);
}
}
super.onEpochWatermarkIncrement(epochWatermark);
}
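
    /** Applies the given consumer to every (epoch, wrapped operator) pair. */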
protected void processForEachWrappedOperator(
BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
consumer.accept(entry.getKey(), entry.getValue());
}
}
@Override
public void open() throws Exception {}
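
    /**
     * Initializes the state of the wrapper itself. This mirrors the state initialization of
     * AbstractStreamOperator: it builds the shared {@link StreamOperatorStateContext} that the
     * per-round operators later access through proxies.
     */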
@Override
public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
throws Exception {
final TypeSerializer<?> keySerializer =
streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());
streamOperatorStateContext =
streamTaskStateManager.streamOperatorStateContext(
getOperatorID(),
getClass().getSimpleName(),
parameters.getProcessingTimeService(),
this,
keySerializer,
containingTask.getCancelables(),
metrics,
streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
ManagedMemoryUseCase.STATE_BACKEND,
containingTask
.getEnvironment()
.getTaskManagerInfo()
.getConfiguration(),
containingTask.getUserCodeClassLoader()),
isUsingCustomRawKeyedState());
stateHandler =
new StreamOperatorStateHandler(
streamOperatorStateContext,
containingTask.getExecutionConfig(),
containingTask.getCancelables());
stateHandler.initializeOperatorState(this);
this.timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();
stateKeySelector1 =
streamConfig.getStatePartitioner(0, containingTask.getUserCodeClassLoader());
stateKeySelector2 =
streamConfig.getStatePartitioner(1, containingTask.getUserCodeClassLoader());
}
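
    /**
     * Restores the wrapper's own list states and re-creates the wrapped operators for all
     * epochs that were pending at checkpoint time, routing each operator its share of the raw
     * operator state partitions.
     */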
@Override
public void initializeState(StateInitializationContext context) throws Exception {
parallelismState =
context.getOperatorStateStore()
.getUnionListState(
new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
.ifPresent(
oldParallelism ->
checkState(
oldParallelism
== containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks(),
"The all-round wrapper operator is recovered with parallelism changed from "
+ oldParallelism
+ " to "
+ containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks()));
latestEpochWatermarkState =
context.getOperatorStateStore()
.getListState(
new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
.ifPresent(
oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);
        // Note that the list must be sorted.
rawStateEpochState =
context.getOperatorStateStore()
.getListState(new ListStateDescriptor<>("rawStateEpoch", Integer.class));
List<Integer> rawStateEpochs = IteratorUtils.toList(rawStateEpochState.get().iterator());
        // Note that the list must be sorted.
pendingEpochState =
context.getOperatorStateStore()
.getListState(
new ListStateDescriptor<>("pendingEpochs", IntSerializer.INSTANCE));
List<Integer> pendingEpochs = IteratorUtils.toList(pendingEpochState.get().iterator());
        // Unfortunately, for the raw state we cannot open the next input stream until the
        // records of the previous stream have been consumed, so we have to "merge" the two
        // sorted lists while re-creating the operators.
Iterator<StatePartitionStreamProvider> rawStates =
context.getRawOperatorStateInputs().iterator();
int nextRawStateEntryIndex = 0;
for (int epoch : pendingEpochs) {
checkState(
nextRawStateEntryIndex == rawStateEpochs.size()
|| rawStateEpochs.get(nextRawStateEntryIndex) >= epoch,
String.format(
"Unexpected raw state indices %s and epochs %s",
rawStateEpochs.toString(), pendingEpochs.toString()));
            // Find how many raw state entries this epoch has.
int numberOfStateEntries = 0;
while (nextRawStateEntryIndex < rawStateEpochs.size()
&& rawStateEpochs.get(nextRawStateEntryIndex) == epoch) {
numberOfStateEntries++;
nextRawStateEntryIndex++;
}
            // Re-create and open the operator for this epoch, feeding it its raw state entries.
getWrappedOperator(epoch, rawStates, numberOfStateEntries);
}
}
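
    /**
     * Indicates whether this operator writes its own raw keyed state; kept {@code false} here
     * and overridable by subclasses, matching the corresponding hook in AbstractStreamOperator.
     */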
@Internal
protected boolean isUsingCustomRawKeyedState() {
return false;
}
@Override
public void finish() throws Exception {
        checkState(
                wrappedOperators.isEmpty(),
                "Some wrapped operators are not closed yet: " + wrappedOperators.keySet());
}
@Override
public void close() throws Exception {
if (stateHandler != null) {
stateHandler.dispose();
}
}
@Override
public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
entry.getValue().prepareSnapshotPreBarrier(checkpointId);
}
}
@Override
public OperatorSnapshotFutures snapshotState(
long checkpointId,
long timestamp,
CheckpointOptions checkpointOptions,
CheckpointStreamFactory factory)
throws Exception {
return stateHandler.snapshotState(
this,
Optional.ofNullable(timeServiceManager),
streamConfig.getOperatorName(),
checkpointId,
timestamp,
checkpointOptions,
factory,
isUsingCustomRawKeyedState());
}
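
    /**
     * Snapshots the raw state of every wrapped operator in ascending epoch order, remembering
     * which epoch each raw state partition belongs to, and then snapshots the wrapper's own
     * list states.
     */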
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
OperatorStateCheckpointOutputStream rawOperatorStateOutputStream =
context.getRawOperatorStateOutput();
List<Integer> operatorStateEpoch = new ArrayList<>();
List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
Collections.sort(sortedEpochs);
for (int epoch : sortedEpochs) {
S wrappedOperator = wrappedOperators.get(epoch);
            if (wrappedOperator
                    instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) {
((StreamOperatorStateHandler.CheckpointedStreamOperator) wrappedOperator)
.snapshotState(new ProxyStateSnapshotContext(context));
                // Tag each raw operator state partition written so far with its epoch.
int numberOfPartitions = rawOperatorStateOutputStream.getNumberOfPartitions();
while (operatorStateEpoch.size() < numberOfPartitions) {
operatorStateEpoch.add(epoch);
}
}
}
        // Then snapshot our own states.
        // Always clear the union list state before setting the value.
parallelismState.clear();
if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
parallelismState.update(
Collections.singletonList(
containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks()));
}
latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
// The list must be sorted
rawStateEpochState.update(operatorStateEpoch);
// The list must be sorted
pendingEpochState.update(sortedEpochs);
}
@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public void setKeyContextElement1(StreamRecord record) throws Exception {
setKeyContextElement(record, stateKeySelector1);
}
@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public void setKeyContextElement2(StreamRecord record) throws Exception {
setKeyContextElement(record, stateKeySelector2);
}
    private <R> void setKeyContextElement(StreamRecord<R> record, KeySelector<R, ?> selector)
            throws Exception {
if (selector != null
&& ((IterationRecord<?>) record.getValue()).getType()
== IterationRecord.Type.RECORD) {
Object key = selector.getKey(record.getValue());
setCurrentKey(key);
}
}
@Override
public OperatorMetricGroup getMetricGroup() {
return metrics;
}
@Override
public OperatorID getOperatorID() {
return streamConfig.getOperatorID();
}
@Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
            entry.getValue().notifyCheckpointComplete(checkpointId);
}
}
@Override
public void notifyCheckpointAborted(long checkpointId) throws Exception {
for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
entry.getValue().notifyCheckpointAborted(checkpointId);
}
}
@Override
public void setCurrentKey(Object key) {
stateHandler.setCurrentKey(key);
}
@Override
public Object getCurrentKey() {
if (stateHandler == null) {
return null;
}
return stateHandler.getCurrentKey();
}
protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
// all operators are tracking latencies
this.latencyStats.reportLatency(marker);
// everything except sinks forwards latency markers
this.output.emitLatencyMarker(marker);
}
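
    /**
     * Builds the latency stats in the same way as AbstractStreamOperator does; if the metric
     * configuration cannot be read, falls back to an unregistered metric group.
     */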
private LatencyStats initializeLatencyStats() {
try {
Configuration taskManagerConfig =
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();
int historySize = taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
if (historySize <= 0) {
LOG.warn(
"{} has been set to a value equal or below 0: {}. Using default.",
MetricOptions.LATENCY_HISTORY_SIZE,
historySize);
historySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
}
final String configuredGranularity =
taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
LatencyStats.Granularity granularity;
try {
granularity =
LatencyStats.Granularity.valueOf(
configuredGranularity.toUpperCase(Locale.ROOT));
} catch (IllegalArgumentException iae) {
granularity = LatencyStats.Granularity.OPERATOR;
LOG.warn(
"Configured value {} option for {} is invalid. Defaulting to {}.",
configuredGranularity,
MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
granularity);
}
MetricGroup jobMetricGroup = this.metrics.getJobMetricGroup();
return new LatencyStats(
jobMetricGroup.addGroup("latency"),
historySize,
containingTask.getIndexInSubtaskGroup(),
getOperatorID(),
granularity);
} catch (Exception e) {
LOG.warn("An error occurred while instantiating latency metrics.", e);
return new LatencyStats(
UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
.addGroup("latency"),
1,
0,
new OperatorID(),
LatencyStats.Granularity.SINGLE);
}
}
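
    /**
     * Initializes a newly created wrapped operator: its state is scoped with the round prefix
     * via a {@link ProxyStreamOperatorStateContext}, and it is handed {@code count} raw
     * operator state partitions before being opened.
     */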
private void initializeStreamOperator(
S operator,
int round,
Iterator<StatePartitionStreamProvider> rawOperatorStates,
int count)
throws Exception {
operator.initializeState(
(operatorID,
operatorClassName,
processingTimeService,
keyContext,
keySerializer,
streamTaskCloseableRegistry,
metricGroup,
managedMemoryFraction,
isUsingCustomRawKeyedState) ->
new ProxyStreamOperatorStateContext(
streamOperatorStateContext,
getRoundStatePrefix(round),
rawOperatorStates,
count));
operator.open();
}
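
    /**
     * Drops, via reflection into {@link DefaultOperatorStateBackend}, all operator states whose
     * names start with this round's prefix. This relies on internal field names; for other
     * backend implementations only a warning is logged.
     */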
private void cleanupOperatorStates(int round) {
String roundPrefix = getRoundStatePrefix(round);
OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend();
if (operatorStateBackend instanceof DefaultOperatorStateBackend) {
for (String fieldNames :
new String[] {
"registeredOperatorStates",
"registeredBroadcastStates",
"accessedStatesByName",
"accessedBroadcastStatesByName"
}) {
Map<String, ?> field =
ReflectionUtils.getFieldValue(
operatorStateBackend,
DefaultOperatorStateBackend.class,
fieldNames);
field.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix));
}
} else {
LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend);
}
}
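
    /**
     * Drops, via reflection, all keyed states registered under this round's prefix. Only the
     * heap and RocksDB backends are supported; for RocksDB the per-round column families are
     * dropped explicitly. Other backends merely log a warning.
     */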
private void cleanupKeyedStates(int round) {
String roundPrefix = getRoundStatePrefix(round);
KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend();
if (keyedStateBackend.getClass().getName().equals(HEAP_KEYED_STATE_NAME)) {
ReflectionUtils.<Map<String, ?>>getFieldValue(
keyedStateBackend, HeapKeyedStateBackend.class, "registeredKVStates")
.entrySet()
.removeIf(entry -> entry.getKey().startsWith(roundPrefix));
ReflectionUtils.<Map<String, ?>>getFieldValue(
keyedStateBackend, HeapKeyedStateBackend.class, "createdKVStates")
.entrySet()
.removeIf(entry -> entry.getKey().startsWith(roundPrefix));
ReflectionUtils.<Map<String, ?>>getFieldValue(
keyedStateBackend,
AbstractKeyedStateBackend.class,
"keyValueStatesByName")
.entrySet()
.removeIf(entry -> entry.getKey().startsWith(roundPrefix));
} else if (keyedStateBackend.getClass().getName().equals(ROCKSDB_KEYED_STATE_NAME)) {
RocksDB db =
ReflectionUtils.getFieldValue(
keyedStateBackend, RocksDBKeyedStateBackend.class, "db");
HashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation =
ReflectionUtils.getFieldValue(
keyedStateBackend,
RocksDBKeyedStateBackend.class,
"kvStateInformation");
kvStateInformation.entrySet().stream()
.filter(entry -> entry.getKey().startsWith(roundPrefix))
.forEach(
entry -> {
try {
db.dropColumnFamily(entry.getValue().columnFamilyHandle);
} catch (Exception e) {
                                    LOG.error(
                                            "Failed to drop state {} for round {}",
                                            entry.getKey(),
                                            round,
                                            e);
}
});
kvStateInformation.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix));
ReflectionUtils.<Map<String, ?>>getFieldValue(
keyedStateBackend, RocksDBKeyedStateBackend.class, "createdKVStates")
.entrySet()
.removeIf(entry -> entry.getKey().startsWith(roundPrefix));
ReflectionUtils.<Map<String, ?>>getFieldValue(
keyedStateBackend,
AbstractKeyedStateBackend.class,
"keyValueStatesByName")
.entrySet()
.removeIf(entry -> entry.getKey().startsWith(roundPrefix));
} else {
LOG.warn("Unable to cleanup the keyed state {}", keyedStateBackend);
}
}
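
    /** Returns the prefix used to scope state names to a round, e.g. {@code "r2-"} for round 2. */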
private String getRoundStatePrefix(int round) {
return "r" + round + "-";
}
int getLatestEpochWatermark() {
return latestEpochWatermark;
}
public Map<Integer, S> getWrappedOperators() {
return wrappedOperators;
}
}