/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.iteration.operator.allround;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTask;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.io.IOException;
import java.util.Collections;

import static org.apache.flink.iteration.operator.OperatorUtils.processOperatorOrUdfIfSatisfy;
import static org.apache.flink.util.Preconditions.checkState;

/**
 * The base class for the all-round wrapper operators. It instantiates the wrapped operator from
 * the given factory, forwards epoch watermark increments to wrapped operators or UDFs that
 * implement {@link IterationListener}, and checkpoints the parallelism together with the latest
 * epoch watermark that has been observed.
 */
public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperator<T>>
extends AbstractWrapperOperator<T> {

protected final S wrappedOperator;

// --------------- state ---------------------------
private int latestEpochWatermark = -1;
private ListState<Integer> parallelismState;
private ListState<Integer> latestEpochWatermarkState;
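
/**
 * Instantiates the wrapped operator from the given factory, using a wrapped stream config and
 * the proxy output, and injects the epoch supplier into any wrapped operator or UDF that
 * implements {@link EpochAware}.
 */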
@SuppressWarnings({"unchecked", "rawtypes"})
public AbstractAllRoundWrapperOperator(
StreamOperatorParameters<IterationRecord<T>> parameters,
StreamOperatorFactory<T> operatorFactory) {
super(parameters, operatorFactory);
this.wrappedOperator =
(S)
StreamOperatorFactoryUtil.<T, S>createOperator(
operatorFactory,
(StreamTask) parameters.getContainingTask(),
OperatorUtils.createWrappedOperatorConfig(
parameters.getStreamConfig(),
containingTask.getUserCodeClassLoader()),
proxyOutput,
parameters.getOperatorEventDispatcher())
.f0;
processOperatorOrUdfIfSatisfy(
wrappedOperator,
EpochAware.class,
epochWatermarkAware ->
epochWatermarkAware.setEpochSupplier(epochWatermarkSupplier));
}
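
/**
 * Reacts to an increased epoch watermark. If the watermark advances beyond the locally recorded
 * one, wrapped operators or UDFs implementing {@link IterationListener} are notified under the
 * corresponding iteration context round; the event is then re-broadcast in any case.
 */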
@Override
public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
if (epochWatermark > latestEpochWatermark) {
latestEpochWatermark = epochWatermark;
setIterationContextRound(epochWatermark);
processOperatorOrUdfIfSatisfy(
wrappedOperator,
IterationListener.class,
listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
clearIterationContextRound();
}
// Always broadcast the event downstream, even if it did not advance the local watermark.
super.onEpochWatermarkIncrement(epochWatermark);
}
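
/**
 * Initializes the wrapped operator's state via a recording initializer so that this wrapper can
 * also access the operator state backend, then restores the checkpointed parallelism (verifying
 * that it has not changed) and the latest epoch watermark.
 */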
@Override
public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
throws Exception {
RecordingStreamTaskStateInitializer recordingStreamTaskStateInitializer =
new RecordingStreamTaskStateInitializer(streamTaskStateManager);
wrappedOperator.initializeState(recordingStreamTaskStateInitializer);
checkState(recordingStreamTaskStateInitializer.lastCreated != null);
OperatorStateStore operatorStateStore =
recordingStreamTaskStateInitializer.lastCreated.operatorStateBackend();
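// The parallelism is written only by subtask 0 into union list state (see snapshotState), so
// every subtask can verify on restore that the parallelism has not changed.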
parallelismState =
operatorStateStore.getUnionListState(
new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
.ifPresent(
oldParallelism ->
checkState(
oldParallelism
== containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks(),
"The all-round wrapper operator is recovered with parallelism changed from "
+ oldParallelism
+ " to "
+ containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks()));
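// Restore the latest epoch watermark recorded before the failure, if any.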
latestEpochWatermarkState =
operatorStateStore.getListState(
new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
.ifPresent(
oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);
}
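
/**
 * Snapshots the wrapper's own state, namely the parallelism (written only by subtask 0) and the
 * latest epoch watermark, and then delegates the snapshot to the wrapped operator.
 */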
@Override
public OperatorSnapshotFutures snapshotState(
long checkpointId,
long timestamp,
CheckpointOptions checkpointOptions,
CheckpointStreamFactory storageLocation)
throws Exception {
// Always clear the union list state before setting the new value.
parallelismState.clear();
if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
parallelismState.update(
Collections.singletonList(
containingTask
.getEnvironment()
.getTaskInfo()
.getNumberOfParallelSubtasks()));
}
latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
return wrappedOperator.snapshotState(
checkpointId, timestamp, checkpointOptions, storageLocation);
}

@Override
public void open() throws Exception {
wrappedOperator.open();
}
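
/**
 * Finishes the wrapped operator with the iteration context round set to
 * {@link IterationRecord#END_EPOCH_WATERMARK}, so that records emitted during finishing are
 * attributed to the terminal epoch.
 */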
@Override
public void finish() throws Exception {
setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
wrappedOperator.finish();
clearIterationContextRound();
}
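
// Like finish(), close() runs with the iteration context round set to the terminal epoch.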
@Override
public void close() throws Exception {
setIterationContextRound(IterationRecord.END_EPOCH_WATERMARK);
wrappedOperator.close();
clearIterationContextRound();
}
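
// The remaining lifecycle, key-context, metric and checkpoint-notification methods are
// delegated directly to the wrapped operator.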
@Override
public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
}

@Override
public void setKeyContextElement1(StreamRecord<?> record) throws Exception {
wrappedOperator.setKeyContextElement1(record);
}

@Override
public void setKeyContextElement2(StreamRecord<?> record) throws Exception {
wrappedOperator.setKeyContextElement2(record);
}

@Override
public OperatorMetricGroup getMetricGroup() {
return wrappedOperator.getMetricGroup();
}

@Override
public OperatorID getOperatorID() {
return wrappedOperator.getOperatorID();
}

@Override
public void notifyCheckpointComplete(long checkpointId) throws Exception {
wrappedOperator.notifyCheckpointComplete(checkpointId);
}

@Override
public void notifyCheckpointAborted(long checkpointId) throws Exception {
wrappedOperator.notifyCheckpointAborted(checkpointId);
}

@Override
public void setCurrentKey(Object key) {
wrappedOperator.setCurrentKey(key);
}

@Override
public Object getCurrentKey() {
return wrappedOperator.getCurrentKey();
}

@VisibleForTesting
int getLatestEpochWatermark() {
return latestEpochWatermark;
}
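
/**
 * A {@link StreamTaskStateInitializer} that delegates to the wrapped initializer and records the
 * last {@link StreamOperatorStateContext} it created, so the outer wrapper can access the
 * wrapped operator's state backend.
 */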
private static class RecordingStreamTaskStateInitializer implements StreamTaskStateInitializer {
private final StreamTaskStateInitializer wrapped;
StreamOperatorStateContext lastCreated;

public RecordingStreamTaskStateInitializer(StreamTaskStateInitializer wrapped) {
this.wrapped = wrapped;
}

@Override
public StreamOperatorStateContext streamOperatorStateContext(
@Nonnull OperatorID operatorID,
@Nonnull String operatorClassName,
@Nonnull ProcessingTimeService processingTimeService,
@Nonnull KeyContext keyContext,
@Nullable TypeSerializer<?> typeSerializer,
@Nonnull CloseableRegistry closeableRegistry,
@Nonnull MetricGroup metricGroup,
double managedMemoryFraction,
boolean isUsingCustomRawKeyedState)
throws Exception {
lastCreated =
wrapped.streamOperatorStateContext(
operatorID,
operatorClassName,
processingTimeService,
keyContext,
typeSerializer,
closeableRegistry,
metricGroup,
managedMemoryFraction,
isUsingCustomRawKeyedState);
return lastCreated;
}
}
}