blob: c4e30204030ef34f39599660eeeb45d2a64ca61f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.checkpoint;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.CollectionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.IntStream;
import static java.util.Collections.emptySet;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
/**
* Used by {@link StateAssignmentOperation} to store temporal information while creating {@link
* OperatorSubtaskState}.
*/
class TaskStateAssignment {
    private static final Logger LOG = LoggerFactory.getLogger(TaskStateAssignment.class);

    /** The vertex whose (possibly rescaled) subtasks receive the reassigned state. */
    final ExecutionJobVertex executionJobVertex;
    /** Pre-restore checkpointed state of every operator in this vertex's chain, by operator id. */
    final Map<OperatorID, OperatorState> oldState;
    /** True iff at least one operator in the chain collected any subtask state. */
    final boolean hasNonFinishedState;
    /** True iff all operators in the chain were fully finished at checkpoint time. */
    final boolean isFullyFinished;
    /** True iff any subtask checkpointed unaligned input-channel state. */
    final boolean hasInputState;
    /** True iff any subtask checkpointed unaligned result-subpartition state. */
    final boolean hasOutputState;
    /** Parallelism of the vertex after the restore (possibly differing from the old one). */
    final int newParallelism;
    /** Id of the chain's input operator (last entry of the operator-id list, see ctor). */
    final OperatorID inputOperatorID;
    /** Id of the chain's output operator (first entry of the operator-id list, see ctor). */
    final OperatorID outputOperatorID;

    // Per (operator, new subtask) slices of redistributed state; filled externally by
    // StateAssignmentOperation and read back through getSubtaskState(...).
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState;
    final Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState;
    final Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState;
    final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
    final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;
    /** The subtask mapping when the output operator was rescaled. */
    private final Map<Integer, SubtasksRescaleMapping> outputSubtaskMappings = new HashMap<>();
    /** The subtask mapping when the input operator was rescaled. */
    private final Map<Integer, SubtasksRescaleMapping> inputSubtaskMappings = new HashMap<>();

    // Lazily computed caches; null until the corresponding getter is first called.
    @Nullable private TaskStateAssignment[] downstreamAssignments;
    @Nullable private TaskStateAssignment[] upstreamAssignments;
    @Nullable private Boolean hasUpstreamOutputStates;
    @Nullable private Boolean hasDownstreamInputStates;

    /** Maps produced dataset id to the assignment of the vertex consuming that dataset. */
    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
    /** Maps a job vertex to its assignment; used to resolve upstream producers. */
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /**
     * Creates the assignment for one job vertex.
     *
     * @param executionJobVertex the vertex being restored
     * @param oldState checkpointed state of all chained operators, keyed by operator id
     * @param consumerAssignment shared lookup from dataset id to consuming vertex's assignment
     * @param vertexAssignments shared lookup from vertex to its assignment
     */
    public TaskStateAssignment(
            ExecutionJobVertex executionJobVertex,
            Map<OperatorID, OperatorState> oldState,
            Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment,
            Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) {
        this.executionJobVertex = executionJobVertex;
        this.oldState = oldState;
        this.hasNonFinishedState =
                oldState.values().stream()
                        .anyMatch(operatorState -> operatorState.getNumberCollectedStates() > 0);
        this.isFullyFinished = oldState.values().stream().anyMatch(OperatorState::isFullyFinished);
        if (isFullyFinished) {
            // Finished-ness must be uniform across the chain; a mix indicates a corrupt
            // checkpoint or an inconsistent restore plan.
            checkState(
                    oldState.values().stream().allMatch(OperatorState::isFullyFinished),
                    "JobVertex could not have mixed finished and unfinished operators");
        }
        newParallelism = executionJobVertex.getParallelism();
        this.consumerAssignment = checkNotNull(consumerAssignment);
        this.vertexAssignments = checkNotNull(vertexAssignments);
        // Each of the per-instance maps will hold at most one entry per (operator, subtask).
        final int expectedNumberOfSubtasks = newParallelism * oldState.size();
        subManagedOperatorState =
                CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        subRawOperatorState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        inputChannelStates = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        resultSubpartitionStates =
                CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        subManagedKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        subRawKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
        final List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        // The operator-id list is ordered output-first: index 0 is the chain's output (tail)
        // operator, the last index is the chain's input (head) operator.
        outputOperatorID = operatorIDs.get(0).getGeneratedOperatorID();
        inputOperatorID = operatorIDs.get(operatorIDs.size() - 1).getGeneratedOperatorID();
        // Unaligned-checkpoint channel state only ever exists on the chain's boundary operators.
        hasInputState =
                oldState.get(inputOperatorID).getStates().stream()
                        .anyMatch(subState -> !subState.getInputChannelState().isEmpty());
        hasOutputState =
                oldState.get(outputOperatorID).getStates().stream()
                        .anyMatch(subState -> !subState.getResultSubpartitionState().isEmpty());
    }

    /**
     * Returns the assignments of the vertices consuming this vertex's produced datasets, in
     * produced-dataset order. Computed lazily and cached.
     */
    public TaskStateAssignment[] getDownstreamAssignments() {
        if (downstreamAssignments == null) {
            downstreamAssignments =
                    Arrays.stream(executionJobVertex.getProducedDataSets())
                            .map(result -> consumerAssignment.get(result.getId()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return downstreamAssignments;
    }

    /** Returns the index of {@code assignment} within {@code assignments} (identity via equals). */
    private static int getAssignmentIndex(
            TaskStateAssignment[] assignments, TaskStateAssignment assignment) {
        return Arrays.asList(assignments).indexOf(assignment);
    }

    /**
     * Returns the assignments of the vertices producing this vertex's inputs, in input (gate)
     * order. Computed lazily and cached.
     */
    public TaskStateAssignment[] getUpstreamAssignments() {
        if (upstreamAssignments == null) {
            upstreamAssignments =
                    executionJobVertex.getInputs().stream()
                            .map(result -> vertexAssignments.get(result.getProducer()))
                            .toArray(TaskStateAssignment[]::new);
        }
        return upstreamAssignments;
    }

    /**
     * Assembles the {@link OperatorSubtaskState} for one (operator, new subtask) instance from the
     * previously distributed state slices, including the in-flight data rescaling descriptors for
     * the chain's input and output operators.
     */
    public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
        checkState(
                subManagedKeyedState.containsKey(instanceID)
                        || !subRawKeyedState.containsKey(instanceID),
                "If an operator has no managed key state, it should also not have a raw keyed state.");
        final StateObjectCollection<InputChannelStateHandle> inputState =
                getState(instanceID, inputChannelStates);
        final StateObjectCollection<ResultSubpartitionStateHandle> outputState =
                getState(instanceID, resultSubpartitionStates);
        return OperatorSubtaskState.builder()
                .setManagedOperatorState(getState(instanceID, subManagedOperatorState))
                .setRawOperatorState(getState(instanceID, subRawOperatorState))
                .setManagedKeyedState(getState(instanceID, subManagedKeyedState))
                .setRawKeyedState(getState(instanceID, subRawKeyedState))
                .setInputChannelState(inputState)
                .setResultSubpartitionState(outputState)
                // For the input operator: pair our input-gate mappings with the corresponding
                // output mapping of each upstream assignment (looked up by our position among
                // that upstream's downstream assignments).
                .setInputRescalingDescriptor(
                        createRescalingDescriptor(
                                instanceID,
                                inputOperatorID,
                                getUpstreamAssignments(),
                                (assignment, recompute) -> {
                                    int assignmentIndex =
                                            getAssignmentIndex(
                                                    assignment.getDownstreamAssignments(), this);
                                    return assignment.getOutputMapping(assignmentIndex, recompute);
                                },
                                inputSubtaskMappings,
                                this::getInputMapping))
                // Symmetrically for the output operator against all downstream assignments.
                .setOutputRescalingDescriptor(
                        createRescalingDescriptor(
                                instanceID,
                                outputOperatorID,
                                getDownstreamAssignments(),
                                (assignment, recompute) -> {
                                    int assignmentIndex =
                                            getAssignmentIndex(
                                                    assignment.getUpstreamAssignments(), this);
                                    return assignment.getInputMapping(assignmentIndex, recompute);
                                },
                                outputSubtaskMappings,
                                this::getOutputMapping))
                .build();
    }

    /** True iff any upstream vertex checkpointed result-subpartition state. Cached. */
    public boolean hasUpstreamOutputStates() {
        if (hasUpstreamOutputStates == null) {
            hasUpstreamOutputStates =
                    Arrays.stream(getUpstreamAssignments())
                            .anyMatch(assignment -> assignment.hasOutputState);
        }
        return hasUpstreamOutputStates;
    }

    /** True iff any downstream vertex checkpointed input-channel state. Cached. */
    public boolean hasDownstreamInputStates() {
        if (hasDownstreamInputStates == null) {
            hasDownstreamInputStates =
                    Arrays.stream(getDownstreamAssignments())
                            .anyMatch(assignment -> assignment.hasInputState);
        }
        return hasDownstreamInputStates;
    }

    /** Debug-logs and passes through a per-gate/partition descriptor. */
    private InflightDataGateOrPartitionRescalingDescriptor log(
            InflightDataGateOrPartitionRescalingDescriptor descriptor, int subtask, int partition) {
        LOG.debug(
                "created {} for task={} subtask={} partition={}",
                descriptor,
                executionJobVertex.getName(),
                subtask,
                partition);
        return descriptor;
    }

    /** Debug-logs and passes through a per-subtask rescaling descriptor. */
    private InflightDataRescalingDescriptor log(
            InflightDataRescalingDescriptor descriptor, int subtask) {
        LOG.debug(
                "created {} for task={} subtask={}",
                descriptor,
                executionJobVertex.getName(),
                subtask);
        return descriptor;
    }

    /**
     * Builds the in-flight data rescaling descriptor for one side (input or output) of the chain.
     *
     * <p>Returns {@link InflightDataRescalingDescriptor#NO_RESCALE} when the instance is not the
     * boundary operator of this side, when neither side has any channel state, or when all
     * per-gate/partition descriptors turn out to be identity mappings.
     *
     * @param instanceID instance being assembled; filtered against {@code expectedOperatorID}
     * @param expectedOperatorID boundary operator (input or output) this descriptor applies to
     * @param connectedAssignments upstream or downstream assignments, in gate/partition order
     * @param mappingRetriever fetches the counterpart mapping from a connected assignment; the
     *     boolean argument requests computation if not yet cached
     * @param subtaskGateOrPartitionMappings this side's cached per-gate/partition mappings
     * @param subtaskMappingCalculator computes this side's mapping for a given gate/partition
     */
    private InflightDataRescalingDescriptor createRescalingDescriptor(
            OperatorInstanceID instanceID,
            OperatorID expectedOperatorID,
            TaskStateAssignment[] connectedAssignments,
            BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> mappingRetriever,
            Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
            Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator) {
        if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        // First pass: only look at already-cached counterpart mappings (recompute=false).
        SubtasksRescaleMapping[] rescaledChannelsMappings =
                Arrays.stream(connectedAssignments)
                        .map(assignment -> mappingRetriever.apply(assignment, false))
                        .toArray(SubtasksRescaleMapping[]::new);
        // no state on input and output, especially for any aligned checkpoint
        if (subtaskGateOrPartitionMappings.isEmpty()
                && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)) {
            return InflightDataRescalingDescriptor.NO_RESCALE;
        }
        InflightDataGateOrPartitionRescalingDescriptor[] gateOrPartitionDescriptors =
                createGateOrPartitionRescalingDescriptors(
                        instanceID,
                        connectedAssignments,
                        assignment -> mappingRetriever.apply(assignment, true),
                        subtaskGateOrPartitionMappings,
                        subtaskMappingCalculator,
                        rescaledChannelsMappings);
        if (Arrays.stream(gateOrPartitionDescriptors)
                .allMatch(InflightDataGateOrPartitionRescalingDescriptor::isIdentity)) {
            return log(InflightDataRescalingDescriptor.NO_RESCALE, instanceID.getSubtaskId());
        } else {
            return log(
                    new InflightDataRescalingDescriptor(gateOrPartitionDescriptors),
                    instanceID.getSubtaskId());
        }
    }

    /**
     * Produces one descriptor per gate/partition, computing any mapping that was not yet cached
     * (both the counterpart's and this side's) on demand.
     */
    private InflightDataGateOrPartitionRescalingDescriptor[]
            createGateOrPartitionRescalingDescriptors(
                    OperatorInstanceID instanceID,
                    TaskStateAssignment[] connectedAssignments,
                    Function<TaskStateAssignment, SubtasksRescaleMapping> mappingCalculator,
                    Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
                    Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
                    SubtasksRescaleMapping[] rescaledChannelsMappings) {
        return IntStream.range(0, rescaledChannelsMappings.length)
                .mapToObj(
                        partition -> {
                            TaskStateAssignment connectedAssignment =
                                    connectedAssignments[partition];
                            SubtasksRescaleMapping rescaleMapping =
                                    Optional.ofNullable(rescaledChannelsMappings[partition])
                                            .orElseGet(
                                                    () ->
                                                            mappingCalculator.apply(
                                                                    connectedAssignment));
                            SubtasksRescaleMapping subtaskMapping =
                                    Optional.ofNullable(
                                                    subtaskGateOrPartitionMappings.get(partition))
                                            .orElseGet(
                                                    () ->
                                                            subtaskMappingCalculator.apply(
                                                                    partition));
                            return getInflightDataGateOrPartitionRescalingDescriptor(
                                    instanceID, partition, rescaleMapping, subtaskMapping);
                        })
                .toArray(InflightDataGateOrPartitionRescalingDescriptor[]::new);
    }

    /**
     * Builds the descriptor for a single gate/partition from this side's subtask mapping and the
     * counterpart's channel mapping, classifying it as IDENTITY or RESCALING.
     */
    private InflightDataGateOrPartitionRescalingDescriptor
            getInflightDataGateOrPartitionRescalingDescriptor(
                    OperatorInstanceID instanceID,
                    int partition,
                    SubtasksRescaleMapping rescaleMapping,
                    SubtasksRescaleMapping subtaskMapping) {
        int[] oldSubtaskInstances =
                subtaskMapping.rescaleMappings.getMappedIndexes(instanceID.getSubtaskId());
        // no scaling or simple scale-up without the need of virtual
        // channels.
        boolean isIdentity =
                (subtaskMapping.rescaleMappings.isIdentity()
                                && rescaleMapping.getRescaleMappings().isIdentity())
                        || oldSubtaskInstances.length == 0;
        final Set<Integer> ambiguousSubtasks =
                subtaskMapping.mayHaveAmbiguousSubtasks
                        ? subtaskMapping.rescaleMappings.getAmbiguousTargets()
                        : emptySet();
        return log(
                new InflightDataGateOrPartitionRescalingDescriptor(
                        oldSubtaskInstances,
                        rescaleMapping.getRescaleMappings(),
                        ambiguousSubtasks,
                        isIdentity ? MappingType.IDENTITY : MappingType.RESCALING),
                instanceID.getSubtaskId(),
                partition);
    }

    /** Wraps a state-slice lookup into a collection, empty when no slice was assigned. */
    private <T extends StateObject> StateObjectCollection<T> getState(
            OperatorInstanceID instanceID,
            Map<OperatorInstanceID, List<T>> subManagedOperatorState) {
        List<T> value = subManagedOperatorState.get(instanceID);
        return value != null ? new StateObjectCollection<>(value) : StateObjectCollection.empty();
    }

    /** Cached output mapping for a partition; computes it only when {@code recompute} is set. */
    private SubtasksRescaleMapping getOutputMapping(int assignmentIndex, boolean recompute) {
        SubtasksRescaleMapping mapping = outputSubtaskMappings.get(assignmentIndex);
        if (recompute && mapping == null) {
            return getOutputMapping(assignmentIndex);
        } else {
            return mapping;
        }
    }

    /** Cached input mapping for a gate; computes it only when {@code recompute} is set. */
    private SubtasksRescaleMapping getInputMapping(int assignmentIndex, boolean recompute) {
        SubtasksRescaleMapping mapping = inputSubtaskMappings.get(assignmentIndex);
        if (recompute && mapping == null) {
            return getInputMapping(assignmentIndex);
        } else {
            return mapping;
        }
    }

    /**
     * Computes (and caches) the new-to-old subtask mapping of this vertex's output operator for
     * the given produced partition, using the downstream gate's upstream rescaler.
     *
     * @throws IllegalStateException if a previously cached mapping for the partition disagrees
     */
    public SubtasksRescaleMapping getOutputMapping(int partitionIndex) {
        final TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
        final IntermediateResult output = executionJobVertex.getProducedDataSets()[partitionIndex];
        final int gateIndex = downstreamAssignment.executionJobVertex.getInputs().indexOf(output);
        final SubtaskStateMapper mapper =
                checkNotNull(
                        downstreamAssignment
                                .executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(gateIndex)
                                .getUpstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        oldState.get(outputOperatorID).getParallelism(), newParallelism);
        return outputSubtaskMappings.compute(
                partitionIndex,
                (idx, oldMapping) ->
                        checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
    }

    /**
     * Computes (and caches) the new-to-old subtask mapping of this vertex's input operator for
     * the given input gate, using that gate's downstream rescaler.
     *
     * @throws IllegalStateException if a previously cached mapping for the gate disagrees
     */
    public SubtasksRescaleMapping getInputMapping(int gateIndex) {
        final SubtaskStateMapper mapper =
                checkNotNull(
                        executionJobVertex
                                .getJobVertex()
                                .getInputs()
                                .get(gateIndex)
                                .getDownstreamSubtaskStateMapper(),
                        "No channel rescaler found during rescaling of channel state");
        final RescaleMappings mapping =
                mapper.getNewToOldSubtasksMapping(
                        oldState.get(inputOperatorID).getParallelism(), newParallelism);
        return inputSubtaskMappings.compute(
                gateIndex,
                (idx, oldMapping) ->
                        checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
    }

    @Override
    public String toString() {
        return "TaskStateAssignment for " + executionJobVertex.getName();
    }

    /**
     * Merges a freshly computed mapping with an already cached one, verifying both agree; the
     * ambiguity flag is OR-combined.
     *
     * @throws IllegalStateException if the two mappings differ
     */
    private static @Nonnull SubtasksRescaleMapping checkSubtaskMapping(
            @Nullable SubtasksRescaleMapping oldMapping,
            RescaleMappings mapping,
            boolean mayHaveAmbiguousSubtasks) {
        if (oldMapping == null) {
            return new SubtasksRescaleMapping(mapping, mayHaveAmbiguousSubtasks);
        }
        if (!oldMapping.rescaleMappings.equals(mapping)) {
            // Note: separating space added before "Found" to fix the previously garbled
            // "...parallelism?Found ..." message.
            throw new IllegalStateException(
                    "Incompatible subtask mappings: are multiple operators "
                            + "ingesting/producing intermediate results with varying degrees of parallelism?"
                            + " Found "
                            + oldMapping
                            + " and "
                            + mapping
                            + ".");
        }
        return new SubtasksRescaleMapping(
                mapping, oldMapping.mayHaveAmbiguousSubtasks || mayHaveAmbiguousSubtasks);
    }

    /** Immutable pair of a subtask mapping and whether its targets may be ambiguous. */
    static class SubtasksRescaleMapping {
        private final RescaleMappings rescaleMappings;
        /**
         * If channel data cannot be safely divided into subtasks (several new subtask indexes are
         * associated with the same old subtask index). Mostly used for range partitioners.
         */
        private final boolean mayHaveAmbiguousSubtasks;

        private SubtasksRescaleMapping(
                RescaleMappings rescaleMappings, boolean mayHaveAmbiguousSubtasks) {
            this.rescaleMappings = rescaleMappings;
            this.mayHaveAmbiguousSubtasks = mayHaveAmbiguousSubtasks;
        }

        public RescaleMappings getRescaleMappings() {
            return rescaleMappings;
        }

        public boolean isMayHaveAmbiguousSubtasks() {
            return mayHaveAmbiguousSubtasks;
        }
    }
}