blob: 07088d901f661aa64a71b68e2e9eedc0c4c9d476 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.util.Collections.emptyList;
import static org.apache.flink.util.Preconditions.checkNotNull;
/**
* This class encapsulates the operation of assigning restored state when restoring from a
* checkpoint.
*/
@Internal
public class StateAssignmentOperation {
// Logger used for rescaling notices and skipped non-restored state messages.
private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

/** All execution job vertices participating in this restore. */
private final Set<ExecutionJobVertex> tasks;

/** Restored checkpoint state, keyed by operator id. */
private final Map<OperatorID, OperatorState> operatorStates;

/** Id of the checkpoint being restored; propagated into every {@code JobManagerTaskRestore}. */
private final long restoreCheckpointId;

/** If true, state that matches no surviving operator is skipped instead of failing restore. */
private final boolean allowNonRestoredState;

/** The state assignments for each ExecutionJobVertex that will be filled in multiple passes. */
private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

/**
 * Stores the assignment of a consumer. {@link IntermediateResult} only allows to traverse
 * producer.
 */
private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment =
        new HashMap<>();
/**
 * Creates a state assignment operation for a single restore.
 *
 * @param restoreCheckpointId id of the checkpoint being restored
 * @param tasks job vertices to which state will be assigned, must not be null
 * @param operatorStates restored operator states keyed by operator id, must not be null
 * @param allowNonRestoredState whether state without a matching operator may be silently skipped
 */
public StateAssignmentOperation(
        long restoreCheckpointId,
        Set<ExecutionJobVertex> tasks,
        Map<OperatorID, OperatorState> operatorStates,
        boolean allowNonRestoredState) {
    this.tasks = Preconditions.checkNotNull(tasks);
    this.operatorStates = Preconditions.checkNotNull(operatorStates);
    this.restoreCheckpointId = restoreCheckpointId;
    this.allowNonRestoredState = allowNonRestoredState;
    this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size());
}
/**
 * Entry point of the operation. Runs three passes: (1) match each restored {@link
 * OperatorState} to its execution job vertex and build a {@link TaskStateAssignment} per
 * vertex, (2) repartition state for vertices whose state (or neighboring channel state)
 * requires it, (3) attach the resulting task state snapshots to the execution attempts.
 */
public void assignStates() {
    checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);
    // Working copy: matched entries are removed so each state is claimed at most once.
    Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
    // find the states of all operators belonging to this task and compute additional
    // information in first pass
    for (ExecutionJobVertex executionJobVertex : tasks) {
        List<OperatorIDPair> operatorIDPairs = executionJobVertex.getOperatorIDs();
        Map<OperatorID, OperatorState> operatorStates =
                CollectionUtil.newHashMapWithExpectedSize(operatorIDPairs.size());
        for (OperatorIDPair operatorIDPair : operatorIDPairs) {
            // Prefer the user-defined (uid-based) operator id when state exists under it,
            // otherwise fall back to the generated id.
            OperatorID operatorID =
                    operatorIDPair
                            .getUserDefinedOperatorID()
                            .filter(localOperators::containsKey)
                            .orElse(operatorIDPair.getGeneratedOperatorID());
            OperatorState operatorState = localOperators.remove(operatorID);
            if (operatorState == null) {
                // No restored state for this operator: create an empty placeholder so the
                // assignment structures stay uniform.
                operatorState =
                        new OperatorState(
                                operatorID,
                                executionJobVertex.getParallelism(),
                                executionJobVertex.getMaxParallelism());
            }
            // The assignment is always keyed by the generated id, regardless of how the
            // state was looked up.
            operatorStates.put(operatorIDPair.getGeneratedOperatorID(), operatorState);
        }
        final TaskStateAssignment stateAssignment =
                new TaskStateAssignment(
                        executionJobVertex,
                        operatorStates,
                        consumerAssignment,
                        vertexAssignments);
        vertexAssignments.put(executionJobVertex, stateAssignment);
        // Register this vertex as the consumer of each of its input data sets
        // (NOTE(review): the loop variable name "producedDataSet" is misleading — these
        // are the vertex's consumed inputs, see getInputs()).
        for (final IntermediateResult producedDataSet : executionJobVertex.getInputs()) {
            consumerAssignment.put(producedDataSet.getId(), stateAssignment);
        }
    }
    // repartition state
    for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
        if (stateAssignment.hasNonFinishedState
                // FLINK-31963: We need to run repartitioning for stateless operators that have
                // upstream output or downstream input states.
                || stateAssignment.hasUpstreamOutputStates()
                || stateAssignment.hasDownstreamInputStates()) {
            assignAttemptState(stateAssignment);
        }
    }
    // actually assign the state
    for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
        // If upstream has output states or downstream has input states, even the empty task
        // state should be assigned for the current task in order to notify this task that the
        // old states will send to it which likely should be filtered.
        if (stateAssignment.hasNonFinishedState
                || stateAssignment.isFullyFinished
                || stateAssignment.hasUpstreamOutputStates()
                || stateAssignment.hasDownstreamInputStates()) {
            assignTaskStateToExecutionJobVertices(stateAssignment);
        }
    }
}
/**
 * Repartitions all state kinds (managed/raw operator state, input-channel and
 * result-subpartition state, keyed state) of one task from the old to the new parallelism
 * and stores the per-subtask results inside the given {@link TaskStateAssignment}.
 */
private void assignAttemptState(TaskStateAssignment taskStateAssignment) {
    // 1. first compute the new parallelism
    checkParallelismPreconditions(taskStateAssignment);
    List<KeyGroupRange> keyGroupPartitions =
            createKeyGroupPartitions(
                    taskStateAssignment.executionJobVertex.getMaxParallelism(),
                    taskStateAssignment.newParallelism);
    /*
     * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism.
     *
     * The old ManagedOperatorStates with old parallelism 3:
     *
     *      parallelism0 parallelism1 parallelism2
     * op0  states0,0    state0,1     state0,2
     * op1
     * op2  states2,0    state2,1     state2,2
     * op3  states3,0    state3,1     state3,2
     *
     * The new ManagedOperatorStates with new parallelism 4:
     *
     *      parallelism0 parallelism1 parallelism2 parallelism3
     * op0  state0,0     state0,1     state0,2     state0,3
     * op1
     * op2  state2,0     state2,1     state2,2     state2,3
     * op3  state3,0     state3,1     state3,2     state3,3
     */
    reDistributePartitionableStates(
            taskStateAssignment.oldState,
            taskStateAssignment.newParallelism,
            OperatorSubtaskState::getManagedOperatorState,
            RoundRobinOperatorStateRepartitioner.INSTANCE,
            taskStateAssignment.subManagedOperatorState);
    reDistributePartitionableStates(
            taskStateAssignment.oldState,
            taskStateAssignment.newParallelism,
            OperatorSubtaskState::getRawOperatorState,
            RoundRobinOperatorStateRepartitioner.INSTANCE,
            taskStateAssignment.subRawOperatorState);
    // Channel state (unaligned checkpoints) and keyed state follow their own mappings.
    reDistributeInputChannelStates(taskStateAssignment);
    reDistributeResultSubpartitionStates(taskStateAssignment);
    reDistributeKeyedStates(keyGroupPartitions, taskStateAssignment);
}
/**
 * Builds and sets the initial state ({@link JobManagerTaskRestore}) for every subtask of the
 * vertex, using either the finished-on-restore marker or the repartitioned state gathered in
 * the given assignment.
 */
private void assignTaskStateToExecutionJobVertices(TaskStateAssignment assignment) {
    ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
    List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
    final int newParallelism = executionJobVertex.getParallelism();
    /*
     * An executionJobVertex's all state handles needed to restore are something like a matrix
     *
     *      parallelism0 parallelism1 parallelism2 parallelism3
     * op0   sh(0,0)     sh(0,1)       sh(0,2)      sh(0,3)
     * op1   sh(1,0)     sh(1,1)       sh(1,2)      sh(1,3)
     * op2   sh(2,0)     sh(2,1)       sh(2,2)      sh(2,3)
     * op3   sh(3,0)     sh(3,1)       sh(3,2)      sh(3,3)
     *
     */
    for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
        Execution currentExecutionAttempt =
                executionJobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();
        if (assignment.isFullyFinished) {
            // All operators in the task finished before the checkpoint: restore the marker.
            assignFinishedStateToTask(currentExecutionAttempt);
        } else {
            assignNonFinishedStateToTask(
                    assignment, operatorIDs, subTaskIndex, currentExecutionAttempt);
        }
    }
}
/** Marks the task as fully finished on restore by assigning the finished marker snapshot. */
private void assignFinishedStateToTask(Execution currentExecutionAttempt) {
    currentExecutionAttempt.setInitialState(
            new JobManagerTaskRestore(
                    restoreCheckpointId, TaskStateSnapshot.FINISHED_ON_RESTORE));
}
/**
 * Collects the repartitioned subtask state of every chained operator into one {@link
 * TaskStateSnapshot} and assigns it as initial state to the given execution attempt.
 */
private void assignNonFinishedStateToTask(
        TaskStateAssignment assignment,
        List<OperatorIDPair> operatorIDs,
        int subTaskIndex,
        Execution currentExecutionAttempt) {
    final TaskStateSnapshot snapshot = new TaskStateSnapshot(operatorIDs.size(), false);
    for (OperatorIDPair idPair : operatorIDs) {
        final OperatorID generatedId = idPair.getGeneratedOperatorID();
        final OperatorInstanceID instanceId = OperatorInstanceID.of(subTaskIndex, generatedId);
        snapshot.putSubtaskStateByOperatorID(
                generatedId, assignment.getSubtaskState(instanceId));
    }
    currentExecutionAttempt.setInitialState(
            new JobManagerTaskRestore(restoreCheckpointId, snapshot));
}
/** Validates parallelism/max-parallelism preconditions for every operator of the assignment. */
public void checkParallelismPreconditions(TaskStateAssignment taskStateAssignment) {
    final ExecutionJobVertex vertex = taskStateAssignment.executionJobVertex;
    taskStateAssignment
            .oldState
            .values()
            .forEach(operatorState -> checkParallelismPreconditions(operatorState, vertex));
}
/**
 * Re-assigns managed and raw keyed state of every operator to the new subtasks according to
 * the given key-group partitioning, filling {@code subManagedKeyedState} and
 * {@code subRawKeyedState} of the assignment.
 */
private void reDistributeKeyedStates(
        List<KeyGroupRange> keyGroupPartitions, TaskStateAssignment stateAssignment) {
    stateAssignment.oldState.forEach(
            (operatorID, operatorState) -> {
                for (int subTaskIndex = 0;
                        subTaskIndex < stateAssignment.newParallelism;
                        subTaskIndex++) {
                    OperatorInstanceID instanceID =
                            OperatorInstanceID.of(subTaskIndex, operatorID);
                    // f0 = managed keyed state, f1 = raw keyed state for this new subtask
                    Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates =
                            reAssignSubKeyedStates(
                                    operatorState,
                                    keyGroupPartitions,
                                    subTaskIndex,
                                    stateAssignment.newParallelism,
                                    operatorState.getParallelism());
                    stateAssignment.subManagedKeyedState.put(instanceID, subKeyedStates.f0);
                    stateAssignment.subRawKeyedState.put(instanceID, subKeyedStates.f1);
                }
            });
}
// TODO rewrite based on operator id
/**
 * Computes the (managed, raw) keyed state lists for one new subtask. When the parallelism is
 * unchanged, the old subtask's handles are taken as-is; otherwise the handles intersecting
 * the subtask's new {@link KeyGroupRange} are extracted from all old subtasks.
 *
 * @return tuple of (managed keyed state, raw keyed state) for the subtask; both empty lists
 *     if the operator has no keyed state for this range
 */
private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(
        OperatorState operatorState,
        List<KeyGroupRange> keyGroupPartitions,
        int subTaskIndex,
        int newParallelism,
        int oldParallelism) {
    List<KeyedStateHandle> subManagedKeyedState;
    List<KeyedStateHandle> subRawKeyedState;
    if (newParallelism == oldParallelism) {
        // No rescaling: each new subtask simply inherits the old subtask's handles.
        if (operatorState.getState(subTaskIndex) != null) {
            subManagedKeyedState =
                    operatorState.getState(subTaskIndex).getManagedKeyedState().asList();
            subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState().asList();
        } else {
            subManagedKeyedState = emptyList();
            subRawKeyedState = emptyList();
        }
    } else {
        // Rescaling: collect all handles that overlap the new subtask's key-group range.
        subManagedKeyedState =
                getManagedKeyedStateHandles(
                        operatorState, keyGroupPartitions.get(subTaskIndex));
        subRawKeyedState =
                getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
    }
    if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
        return new Tuple2<>(emptyList(), emptyList());
    } else {
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }
}
/**
 * Repartitions one kind of partitionable operator state (managed or raw, selected by {@code
 * extractHandle}) of all operators to the new parallelism and puts the per-instance results
 * into {@code result}.
 *
 * @param oldOperatorStates restored states keyed by operator id
 * @param newParallelism parallelism to repartition to
 * @param extractHandle selects which state collection of a subtask to repartition
 * @param stateRepartitioner the repartitioning strategy (round-robin here)
 * @param result output map, filled with one entry per (operator, new subtask)
 */
public static <T extends StateObject> void reDistributePartitionableStates(
        Map<OperatorID, OperatorState> oldOperatorStates,
        int newParallelism,
        Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
        OperatorStateRepartitioner<T> stateRepartitioner,
        Map<OperatorInstanceID, List<T>> result) {
    // The nested list wraps as the level of operator -> subtask -> state object collection
    Map<OperatorID, List<List<T>>> oldStates =
            splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);
    oldOperatorStates.forEach(
            (operatorID, oldOperatorState) ->
                    result.putAll(
                            applyRepartitioner(
                                    operatorID,
                                    stateRepartitioner,
                                    oldStates.get(operatorID),
                                    oldOperatorState.getParallelism(),
                                    newParallelism)));
}
/**
 * Re-assigns result-subpartition (output channel) state of the output operator. Without
 * rescaling, handles are mapped one-to-one per subtask; with rescaling, they are
 * redistributed per produced partition according to the output rescale mapping.
 */
public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
    // FLINK-31963: We can skip this phase if there is no output state AND downstream has no
    // input states
    if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
        return;
    }
    // Channel state may only live on the chain's output operator; anything else means the
    // topology changed in an unsupported way.
    checkForUnsupportedToplogyChanges(
            assignment.oldState,
            OperatorSubtaskState::getResultSubpartitionState,
            assignment.outputOperatorID);
    final OperatorState outputState = assignment.oldState.get(assignment.outputOperatorID);
    final List<List<ResultSubpartitionStateHandle>> outputOperatorState =
            splitBySubtasks(outputState, OperatorSubtaskState::getResultSubpartitionState);
    final ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
    final List<IntermediateDataSet> outputs =
            executionJobVertex.getJobVertex().getProducedDataSets();
    if (outputState.getParallelism() == executionJobVertex.getParallelism()) {
        // No rescaling: direct per-subtask assignment.
        assignment.resultSubpartitionStates.putAll(
                toInstanceMap(assignment.outputOperatorID, outputOperatorState));
        return;
    }
    // Parallelism of this vertex changed, distribute ResultSubpartitionStateHandle
    // according to output mapping.
    for (int partitionIndex = 0; partitionIndex < outputs.size(); partitionIndex++) {
        // With a single output there is no need to filter by partition index.
        final List<List<ResultSubpartitionStateHandle>> partitionState =
                outputs.size() == 1
                        ? outputOperatorState
                        : getPartitionState(
                                outputOperatorState,
                                ResultSubpartitionInfo::getPartitionIdx,
                                partitionIndex);
        final MappingBasedRepartitioner<ResultSubpartitionStateHandle> repartitioner =
                new MappingBasedRepartitioner<>(
                        assignment.getOutputMapping(partitionIndex).getRescaleMappings());
        final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> repartitioned =
                applyRepartitioner(
                        assignment.outputOperatorID,
                        repartitioner,
                        partitionState,
                        outputOperatorState.size(),
                        executionJobVertex.getParallelism());
        addToSubtasks(assignment.resultSubpartitionStates, repartitioned);
    }
}
/**
 * Re-assigns input-channel state of the input operator. Without rescaling (and without a
 * FULL-mapper upstream change, see below), handles are mapped one-to-one per subtask;
 * otherwise they are redistributed per input gate according to the input rescale mapping.
 */
public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
    // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
    // output states
    if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
        return;
    }
    // Channel state may only live on the chain's input operator; anything else means the
    // topology changed in an unsupported way.
    checkForUnsupportedToplogyChanges(
            stateAssignment.oldState,
            OperatorSubtaskState::getInputChannelState,
            stateAssignment.inputOperatorID);
    final ExecutionJobVertex executionJobVertex = stateAssignment.executionJobVertex;
    final List<IntermediateResult> inputs = executionJobVertex.getInputs();
    // check for rescaling: no rescaling = simple reassignment
    final OperatorState inputState =
            stateAssignment.oldState.get(stateAssignment.inputOperatorID);
    final List<List<InputChannelStateHandle>> inputOperatorState =
            splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState);
    // FULL mapper means every new subtask may receive state from every old upstream subtask.
    boolean hasAnyFullMapper =
            executionJobVertex.getJobVertex().getInputs().stream()
                    .map(JobEdge::getDownstreamSubtaskStateMapper)
                    .anyMatch(m -> m.equals(SubtaskStateMapper.FULL));
    // NOTE(review): this compares the *input operator's* old parallelism against each
    // upstream producer's current parallelism; the mapped upstream assignment is only used
    // for its current parallelism — confirm intended semantics against FLINK-31963.
    boolean hasAnyPreviousOperatorChanged =
            executionJobVertex.getInputs().stream()
                    .map(IntermediateResult::getProducer)
                    .map(vertexAssignments::get)
                    .anyMatch(
                            taskStateAssignment -> {
                                final int oldParallelism =
                                        stateAssignment
                                                .oldState
                                                .get(stateAssignment.inputOperatorID)
                                                .getParallelism();
                                return oldParallelism
                                        != taskStateAssignment.executionJobVertex
                                                .getParallelism();
                            });
    // need rescale if any input operator parallelism was changed and have any input with FULL
    // subtask state mapper
    if (inputState.getParallelism() == executionJobVertex.getParallelism()
            && !(hasAnyFullMapper && hasAnyPreviousOperatorChanged)) {
        stateAssignment.inputChannelStates.putAll(
                toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState));
        return;
    }
    // if this subtask has a different degree of parallelism, use the partitioner to figure
    // out to which subtasks the state should be reassigned. In some cases, state is
    // replicated to multiple subtasks and filtered during recovery.
    // example: if this task is downscaled from 3 to 2 and it uses a range partitioner over
    // [0;128)
    // old assignment: 0 -> [0;43); 1 -> [43;87); 2 -> [87;128)
    // new assignment: 0 -> [0;64]; 1 -> [64;128)
    // subtask 0 recovers data from old subtask 0 + 1 and subtask 1 recovers data from old
    // subtask 1 + 2
    for (int gateIndex = 0; gateIndex < inputs.size(); gateIndex++) {
        final RescaleMappings mapping =
                stateAssignment.getInputMapping(gateIndex).getRescaleMappings();
        // With a single input gate there is no need to filter by gate index.
        final List<List<InputChannelStateHandle>> gateState =
                inputs.size() == 1
                        ? inputOperatorState
                        : getPartitionState(
                                inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
        final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
                new MappingBasedRepartitioner<>(mapping);
        final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                applyRepartitioner(
                        stateAssignment.inputOperatorID,
                        repartitioner,
                        gateState,
                        inputOperatorState.size(),
                        stateAssignment.newParallelism);
        addToSubtasks(stateAssignment.inputChannelStates, repartitioned);
    }
}
/** Merges {@code toAdd} into {@code target}, appending values per key; lists are created lazily. */
private static <K, V> void addToSubtasks(Map<K, List<V>> target, Map<K, List<V>> toAdd) {
    for (Map.Entry<K, List<V>> entry : toAdd.entrySet()) {
        final List<V> values = entry.getValue();
        target.computeIfAbsent(entry.getKey(), ignored -> new ArrayList<>(values.size()))
                .addAll(values);
    }
}
/**
 * Fails the restore when channel state is found on an operator other than the expected chain
 * boundary operator — that indicates the topology changed such that a formerly-unchained data
 * exchange with persisted (unaligned checkpoint) data is now chained, which is unsupported.
 *
 * <p>NOTE(review): method name has a typo ("Toplogy" → "Topology"); kept as-is because it is
 * called from several sibling methods.
 *
 * @throws IllegalStateException if an unexpected operator carries channel state
 */
private <T extends AbstractChannelStateHandle<?>> void checkForUnsupportedToplogyChanges(
        Map<OperatorID, OperatorState> oldOperatorStates,
        Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
        OperatorID expectedOperatorID) {
    final List<OperatorID> unexpectedState =
            oldOperatorStates.entrySet().stream()
                    .filter(idAndState -> !idAndState.getKey().equals(expectedOperatorID))
                    .filter(idAndState -> hasChannelState(idAndState.getValue(), extractHandle))
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toList());
    if (!unexpectedState.isEmpty()) {
        throw new IllegalStateException(
                "Cannot recover from unaligned checkpoint when topology changes, such that "
                        + "data exchanges with persisted data are now chained.\n"
                        + "The following operators contain channel state: "
                        + unexpectedState);
    }
}
/** Returns true iff any subtask of the operator carries non-empty channel state. */
private <T extends AbstractChannelStateHandle<?>> boolean hasChannelState(
        OperatorState operatorState,
        Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
    for (OperatorSubtaskState subtaskState : operatorState.getSubtaskStates().values()) {
        if (!isEmpty(extractHandle.apply(subtaskState))) {
            return true;
        }
    }
    return false;
}
/** A channel-state collection counts as empty when no handle carries any data offsets. */
private <T extends AbstractChannelStateHandle<?>> boolean isEmpty(StateObjectCollection<T> s) {
    return s.stream().noneMatch(state -> !state.getOffsets().isEmpty());
}
/**
 * Filters, per subtask, the channel-state handles belonging to the given partition/gate index.
 */
private static <T extends AbstractChannelStateHandle<I>, I> List<List<T>> getPartitionState(
        List<List<T>> subtaskStates, Function<I, Integer> partitionExtractor, int partitionId) {
    final List<List<T>> filtered = new ArrayList<>(subtaskStates.size());
    for (List<T> subtaskState : subtaskStates) {
        final List<T> matching = new ArrayList<>();
        for (T state : subtaskState) {
            // comparison auto-unboxes the extracted Integer, as the original stream did
            if (partitionExtractor.apply(state.getInfo()) == partitionId) {
                matching.add(state);
            }
        }
        filtered.add(matching);
    }
    return filtered;
}
/**
 * Splits each operator's state into a per-subtask list of the selected state kind, keyed by
 * operator id.
 */
private static <T extends StateObject>
        Map<OperatorID, List<List<T>>> splitManagedAndRawOperatorStates(
                Map<OperatorID, OperatorState> operatorStates,
                Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
    final Map<OperatorID, List<List<T>>> split =
            CollectionUtil.newHashMapWithExpectedSize(operatorStates.size());
    operatorStates.forEach(
            (operatorID, operatorState) ->
                    split.put(operatorID, splitBySubtasks(operatorState, extractHandle)));
    return split;
}
/**
 * Extracts the selected state kind from every subtask of the operator, yielding one list per
 * subtask (empty when the subtask has no state).
 */
private static <T extends StateObject> List<List<T>> splitBySubtasks(
        OperatorState operatorState,
        Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
    final int parallelism = operatorState.getParallelism();
    final List<List<T>> statePerSubtask = new ArrayList<>(parallelism);
    for (int subTaskIndex = 0; subTaskIndex < parallelism; subTaskIndex++) {
        final OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
        if (subtaskState == null) {
            statePerSubtask.add(emptyList());
        } else {
            statePerSubtask.add(extractHandle.apply(subtaskState).asList());
        }
    }
    return statePerSubtask;
}
/**
 * Collects all managed keyed state handles of the operator that intersect the given {@link
 * KeyGroupRange}.
 *
 * @param operatorState all state handles of an operator
 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
 * @return all managed keyed state handles which have intersection with the given range
 */
public static List<KeyedStateHandle> getManagedKeyedStateHandles(
        OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
    final int parallelism = operatorState.getParallelism();
    List<KeyedStateHandle> collected = null;
    for (int oldSubtask = 0; oldSubtask < parallelism; oldSubtask++) {
        final OperatorSubtaskState oldSubtaskState = operatorState.getState(oldSubtask);
        if (oldSubtaskState == null) {
            continue;
        }
        final Collection<KeyedStateHandle> managedHandles =
                oldSubtaskState.getManagedKeyedState();
        if (collected == null) {
            // allocated lazily so fully stateless operators return the shared empty list
            collected = new ArrayList<>(parallelism * managedHandles.size());
        }
        extractIntersectingState(managedHandles, subtaskKeyGroupRange, collected);
    }
    return collected == null ? emptyList() : collected;
}
/**
 * Collects all raw keyed state handles of the operator that intersect the given {@link
 * KeyGroupRange}.
 *
 * @param operatorState all state handles of an operator
 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
 * @return all raw keyed state handles which have intersection with the given range
 */
public static List<KeyedStateHandle> getRawKeyedStateHandles(
        OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
    final int parallelism = operatorState.getParallelism();
    List<KeyedStateHandle> collected = null;
    for (int oldSubtask = 0; oldSubtask < parallelism; oldSubtask++) {
        final OperatorSubtaskState oldSubtaskState = operatorState.getState(oldSubtask);
        if (oldSubtaskState == null) {
            continue;
        }
        final Collection<KeyedStateHandle> rawHandles = oldSubtaskState.getRawKeyedState();
        if (collected == null) {
            // allocated lazily so fully stateless operators return the shared empty list
            collected = new ArrayList<>(parallelism * rawHandles.size());
        }
        extractIntersectingState(rawHandles, subtaskKeyGroupRange, collected);
    }
    return collected == null ? emptyList() : collected;
}
/**
 * Extracts the portions of the given state handles that intersect {@code rangeToExtract} and
 * appends them to the collector; null handles and empty intersections are skipped.
 */
@VisibleForTesting
public static void extractIntersectingState(
        Collection<? extends KeyedStateHandle> originalSubtaskStateHandles,
        KeyGroupRange rangeToExtract,
        List<KeyedStateHandle> extractedStateCollector) {
    for (KeyedStateHandle handle : originalSubtaskStateHandles) {
        if (handle == null) {
            continue;
        }
        final KeyedStateHandle intersection = handle.getIntersection(rangeToExtract);
        if (intersection != null) {
            extractedStateCollector.add(intersection);
        }
    }
}
/**
 * Groups the available set of key groups into key group partitions. A key group partition is
 * the set of key groups which is assigned to the same task. Each element of the returned list
 * constitutes one key group partition.
 *
 * <p><b>IMPORTANT</b>: The assignment of key groups to partitions has to be in sync with the
 * KeyGroupStreamPartitioner.
 *
 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
 * @param parallelism Parallelism to generate the key group partitioning for
 * @return List of key group partitions
 */
public static List<KeyGroupRange> createKeyGroupPartitions(
        int numberKeyGroups, int parallelism) {
    Preconditions.checkArgument(numberKeyGroups >= parallelism);
    final List<KeyGroupRange> partitions = new ArrayList<>(parallelism);
    for (int operatorIndex = 0; operatorIndex < parallelism; operatorIndex++) {
        partitions.add(
                KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                        numberKeyGroups, parallelism, operatorIndex));
    }
    return partitions;
}
/**
 * Verifies conditions in regards to parallelism and maxParallelism that must be met when
 * restoring state.
 *
 * <p>Side effect: when the restored max parallelism differs and the vertex allows rescaling
 * it, the vertex's max parallelism is overridden to match the restored state.
 *
 * @param operatorState state to restore
 * @param executionJobVertex task for which the state should be restored
 * @throws IllegalStateException if parallelism exceeds the restored max parallelism, or the
 *     max parallelism changed and cannot be rescaled
 */
private static void checkParallelismPreconditions(
        OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
    // ----------------------------------------max parallelism
    // preconditions-------------------------------------
    if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
        throw new IllegalStateException(
                "The state for task "
                        + executionJobVertex.getJobVertexId()
                        + " can not be restored. The maximum parallelism ("
                        + operatorState.getMaxParallelism()
                        + ") of the restored state is lower than the configured parallelism ("
                        + executionJobVertex.getParallelism()
                        + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
    }
    // check that the number of key groups have not changed or if we need to override it to
    // satisfy the restored state
    if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
        if (executionJobVertex.canRescaleMaxParallelism(operatorState.getMaxParallelism())) {
            LOG.debug(
                    "Rescaling maximum parallelism for JobVertex {} from {} to {}",
                    executionJobVertex.getJobVertexId(),
                    executionJobVertex.getMaxParallelism(),
                    operatorState.getMaxParallelism());
            executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
        } else {
            // if the max parallelism cannot be rescaled, we complain on mismatch
            throw new IllegalStateException(
                    "The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") with which the latest "
                            + "checkpoint of the execution job vertex "
                            + executionJobVertex
                            + " has been taken and the current maximum parallelism ("
                            + executionJobVertex.getMaxParallelism()
                            + ") changed. This "
                            + "is currently not supported.");
        }
    }
}
/**
 * Verifies that all operator states can be mapped to an execution job vertex.
 *
 * @param allowNonRestoredState if false an exception will be thrown if a state could not be
 *     mapped
 * @param operatorStates operator states to map
 * @param tasks tasks to map to
 */
private static void checkStateMappingCompleteness(
        boolean allowNonRestoredState,
        Map<OperatorID, OperatorState> operatorStates,
        Set<ExecutionJobVertex> tasks) {
    // Collect every operator id (generated and user-defined) known to the surviving vertices.
    final Set<OperatorID> allOperatorIDs = new HashSet<>();
    for (ExecutionJobVertex executionJobVertex : tasks) {
        for (OperatorIDPair operatorIDPair : executionJobVertex.getOperatorIDs()) {
            allOperatorIDs.add(operatorIDPair.getGeneratedOperatorID());
            operatorIDPair.getUserDefinedOperatorID().ifPresent(allOperatorIDs::add);
        }
    }
    for (Map.Entry<OperatorID, OperatorState> entry : operatorStates.entrySet()) {
        if (allOperatorIDs.contains(entry.getKey())) {
            continue;
        }
        final OperatorState orphanState = entry.getValue();
        if (!allowNonRestoredState) {
            throw new IllegalStateException(
                    "There is no operator for the state " + orphanState.getOperatorID());
        }
        LOG.info("Skipped checkpoint state for operator {}.", orphanState.getOperatorID());
    }
}
/**
 * Repartitions the chained operator state to the new parallelism and keys the result by
 * {@link OperatorInstanceID}.
 */
public static <T> Map<OperatorInstanceID, List<T>> applyRepartitioner(
        OperatorID operatorID,
        OperatorStateRepartitioner<T> opStateRepartitioner,
        List<List<T>> chainOpParallelStates,
        int oldParallelism,
        int newParallelism) {
    return toInstanceMap(
            operatorID,
            applyRepartitioner(
                    opStateRepartitioner,
                    chainOpParallelStates,
                    oldParallelism,
                    newParallelism));
}
/**
 * Keys each subtask's state list by {@link OperatorInstanceID} (operator id + subtask index).
 *
 * @param operatorID operator the states belong to
 * @param states per-subtask state lists; no entry may be null
 * @return map from operator instance id to that subtask's state list
 * @throws NullPointerException if any per-subtask state list is null
 */
private static <T> Map<OperatorInstanceID, List<T>> toInstanceMap(
        OperatorID operatorID, List<List<T>> states) {
    Map<OperatorInstanceID, List<T>> result =
            CollectionUtil.newHashMapWithExpectedSize(states.size());
    for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) {
        // Bug fix: the previous code passed the boolean expression
        // "states.get(subtaskIndex) != null" to checkNotNull. That boxed Boolean is never
        // null, so the check could never fire. Pass the reference itself so a null entry
        // actually fails fast with the intended message.
        List<T> subtaskState =
                checkNotNull(states.get(subtaskIndex), "states.get(subtaskIndex) is null");
        result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskState);
    }
    return result;
}
/**
 * Repartitions the given operator state using the given {@link OperatorStateRepartitioner}
 * with respect to the new parallelism.
 *
 * @param opStateRepartitioner partitioner to use
 * @param chainOpParallelStates state to repartition; may be null, yielding an empty result
 * @param oldParallelism parallelism with which the state is currently partitioned
 * @param newParallelism parallelism with which the state should be partitioned
 * @return repartitioned state
 */
// TODO rewrite based on operator id
public static <T> List<List<T>> applyRepartitioner(
        OperatorStateRepartitioner<T> opStateRepartitioner,
        List<List<T>> chainOpParallelStates,
        int oldParallelism,
        int newParallelism) {
    return chainOpParallelStates == null
            ? emptyList()
            : opStateRepartitioner.repartitionState(
                    chainOpParallelStates, oldParallelism, newParallelism);
}
}