[FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584) (#22595)

This commit fixes problems in StateAssignmentOperation for unaligned checkpoints with stateless operators that have upstream operators with output partition state or downstream operators with input channel state.

(cherry picked from commit 354c0f455b92c083299d8028f161f0dd113ab614)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 681e0b1..e476c6b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -136,19 +136,24 @@
 
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            if (stateAssignment.hasNonFinishedState) {
+            if (stateAssignment.hasNonFinishedState
+                    // FLINK-31963: We need to run repartitioning for stateless operators that have
+                    // upstream output or downstream input states.
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignAttemptState(stateAssignment);
             }
         }
 
         // actually assign the state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            // If upstream has output states, even the empty task state should be assigned for the
-            // current task in order to notify this task that the old states will send to it which
-            // likely should be filtered.
+            // If upstream has output states or downstream has input states, even an empty task
+            // state should be assigned to the current task in order to notify this task that
+            // old states will be sent to it, which likely should be filtered.
             if (stateAssignment.hasNonFinishedState
                     || stateAssignment.isFullyFinished
-                    || stateAssignment.hasUpstreamOutputStates()) {
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignTaskStateToExecutionJobVertices(stateAssignment);
             }
         }
@@ -345,9 +350,10 @@
                                         newParallelism)));
     }
 
-    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
-            TaskStateAssignment assignment) {
-        if (!assignment.hasOutputState) {
+    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
+        // FLINK-31963: We can skip this phase only if there is no output state AND downstream has
+        // no input states.
+        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -394,7 +400,9 @@
     }
 
     public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
-        if (!stateAssignment.hasInputState) {
+        // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
+        // output states
+        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
@@ -435,7 +443,7 @@
                             : getPartitionState(
                                     inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
             final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
-                    new MappingBasedRepartitioner(mapping);
+                    new MappingBasedRepartitioner<>(mapping);
             final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                     applyRepartitioner(
                             stateAssignment.inputOperatorID,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index 75ffc71..e9f9d11 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -84,6 +84,8 @@
 
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
+    @Nullable private Boolean hasUpstreamOutputStates;
+    @Nullable private Boolean hasDownstreamInputStates;
 
     private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -202,8 +204,21 @@
     }
 
     public boolean hasUpstreamOutputStates() {
-        return Arrays.stream(getUpstreamAssignments())
-                .anyMatch(assignment -> assignment.hasOutputState);
+        if (hasUpstreamOutputStates == null) {
+            hasUpstreamOutputStates =
+                    Arrays.stream(getUpstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasOutputState);
+        }
+        return hasUpstreamOutputStates;
+    }
+
+    public boolean hasDownstreamInputStates() {
+        if (hasDownstreamInputStates == null) {
+            hasDownstreamInputStates =
+                    Arrays.stream(getDownstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasInputState);
+        }
+        return hasDownstreamInputStates;
     }
 
     private InflightDataGateOrPartitionRescalingDescriptor log(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
index f9cb551..bffdd86 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
@@ -22,6 +22,7 @@
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -51,6 +52,9 @@
 import org.junit.ClassRule;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumMap;
@@ -82,6 +86,7 @@
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -785,6 +790,129 @@
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewResultSubpartitionStateHandle(10, random),
+                                                createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewInputChannelStateHandle(10, random),
+                                                createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        checkArgument(
+                upstreamOpState != downstreamOpState
+                        && (upstreamOpState == null || downstreamOpState == null),
+                "Either upstream or downstream state must exist, but not both");
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamTaskStateSnapshots,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamTaskStateSnapshots,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());
+    }
+
     @Test
     public void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
@@ -949,15 +1077,50 @@
                                 }));
     }
 
+    private static class OperatorIdWithParallelism {
+        private final OperatorID operatorID;
+        private final int parallelism;
+
+        public OperatorID getOperatorID() {
+            return operatorID;
+        }
+
+        public int getParallelism() {
+            return parallelism;
+        }
+
+        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
+            this.operatorID = operatorID;
+            this.parallelism = parallelism;
+        }
+    }
+
     private Map<OperatorID, ExecutionJobVertex> buildVertices(
             List<OperatorID> operatorIds,
-            int parallelism,
+            int parallelisms,
+            SubtaskStateMapper downstreamRescaler,
+            SubtaskStateMapper upstreamRescaler)
+            throws JobException, JobExecutionException {
+        List<OperatorIdWithParallelism> opIdsWithParallelism =
+                operatorIds.stream()
+                        .map(operatorID -> new OperatorIdWithParallelism(operatorID, parallelisms))
+                        .collect(Collectors.toList());
+        return buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
+    }
+
+    private Map<OperatorID, ExecutionJobVertex> buildVertices(
+            List<OperatorIdWithParallelism> operatorIdsAndParallelism,
             SubtaskStateMapper downstreamRescaler,
             SubtaskStateMapper upstreamRescaler)
             throws JobException, JobExecutionException {
         final JobVertex[] jobVertices =
-                operatorIds.stream()
-                        .map(id -> createJobVertex(id, id, parallelism))
+                operatorIdsAndParallelism.stream()
+                        .map(
+                                idWithParallelism ->
+                                        createJobVertex(
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getParallelism()))
                         .toArray(JobVertex[]::new);
         for (int index = 1; index < jobVertices.length; index++) {
             connectVertices(
@@ -1029,6 +1192,15 @@
         return jobVertex;
     }
 
+    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(
+            ExecutionJobVertex executionJobVertex) {
+        return Arrays.stream(executionJobVertex.getTaskVertices())
+                .map(ExecutionVertex::getCurrentExecutionAttempt)
+                .map(Execution::getTaskRestore)
+                .map(JobManagerTaskRestore::getTaskStateSnapshot)
+                .collect(Collectors.toList());
+    }
+
     private OperatorSubtaskState getAssignedState(
             ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
         return executionJobVertex
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
index c75c8db..2c71a97 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
@@ -113,7 +113,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 final SingleOutputStreamOperator<Long> stream =
                         env.fromSource(
@@ -121,7 +122,8 @@
                                                 minCheckpoints,
                                                 parallelism,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source")
                                 .slotSharingGroup(slotSharing ? "default" : "source")
@@ -144,7 +146,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -154,7 +157,8 @@
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
@@ -182,7 +186,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Tuple2<Integer, Long>> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -193,7 +198,8 @@
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
index fd89c38..b6ee276 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
@@ -68,6 +68,7 @@
     private final int oldParallelism;
     private final int newParallelism;
     private final int buffersPerChannel;
+    private final long sourceSleepMs;
 
     enum Topology implements DagCreator {
         PIPELINE {
@@ -76,7 +77,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMillis) {
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> source =
                         createSourcePipeline(
@@ -86,6 +88,7 @@
                                 expectedRestarts,
                                 parallelism,
                                 0,
+                                sourceSleepMillis,
                                 val -> true);
                 addFailingSink(source, minCheckpoints, slotSharing);
             }
@@ -97,7 +100,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -111,6 +115,7 @@
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource =
                             combinedSource == null
@@ -134,10 +139,10 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
-                checkState(parallelism >= 4);
                 final DataStream<Long> source1 =
                         createSourcePipeline(
                                 env,
@@ -146,6 +151,7 @@
                                 expectedRestarts,
                                 parallelism / 2,
                                 0,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 0);
                 final DataStream<Long> source2 =
                         createSourcePipeline(
@@ -155,6 +161,7 @@
                                 expectedRestarts,
                                 parallelism / 3,
                                 1,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 1);
 
                 KeySelector<Long, Long> keySelector = i -> withoutHeader(i) % NUM_GROUPS;
@@ -174,7 +181,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -188,6 +196,7 @@
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource = combinedSource == null ? source : combinedSource.union(source);
                 }
@@ -202,7 +211,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide =
@@ -211,7 +221,8 @@
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source");
                 final DataStream<Long> source =
@@ -222,6 +233,7 @@
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -249,7 +261,8 @@
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide1 =
@@ -258,7 +271,8 @@
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-1")
                                 .setParallelism(1);
@@ -268,7 +282,8 @@
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-2")
                                 .setParallelism(1);
@@ -278,7 +293,8 @@
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-3")
                                 .setParallelism(1);
@@ -290,6 +306,7 @@
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -349,13 +366,15 @@
                 int expectedRestarts,
                 int parallelism,
                 int inputIndex,
+                long sourceSleepMs,
                 FilterFunction<Long> sourceFilter) {
             return env.fromSource(
                             new LongSource(
                                     minCheckpoints,
                                     parallelism,
                                     expectedRestarts,
-                                    env.getCheckpointInterval()),
+                                    env.getCheckpointInterval(),
+                                    sourceSleepMs),
                             noWatermarks(),
                             "source" + inputIndex)
                     .uid("source" + inputIndex)
@@ -459,46 +478,61 @@
         }
     }
 
-    @Parameterized.Parameters(name = "{0} {1} from {2} to {3}, buffersPerChannel = {4}")
+    @Parameterized.Parameters(
+            name = "{0} {1} from {2} to {3}, sourceSleepMs = {4}, buffersPerChannel = {5}")
     public static Object[][] getScaleFactors() {
+        // Cases with `sourceSleepMs` > 0 throttle the sources so that rescaling is tested without
+        // backpressure and with only very few captured in-flight records, see FLINK-31963.
         Object[][] parameters =
                 new Object[][] {
-                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7},
-                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12},
-                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2},
-                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7},
-                    new Object[] {"downscale", Topology.BROADCAST, 5, 2},
-                    new Object[] {"upscale", Topology.BROADCAST, 2, 5},
-                    new Object[] {"upscale", Topology.PIPELINE, 1, 2},
-                    new Object[] {"upscale", Topology.PIPELINE, 2, 3},
-                    new Object[] {"upscale", Topology.PIPELINE, 3, 7},
-                    new Object[] {"upscale", Topology.PIPELINE, 4, 8},
-                    new Object[] {"upscale", Topology.PIPELINE, 20, 21},
-                    new Object[] {"downscale", Topology.PIPELINE, 2, 1},
-                    new Object[] {"downscale", Topology.PIPELINE, 3, 2},
-                    new Object[] {"downscale", Topology.PIPELINE, 7, 3},
-                    new Object[] {"downscale", Topology.PIPELINE, 8, 4},
-                    new Object[] {"downscale", Topology.PIPELINE, 21, 20},
-                    new Object[] {"no scale", Topology.PIPELINE, 1, 1},
-                    new Object[] {"no scale", Topology.PIPELINE, 3, 3},
-                    new Object[] {"no scale", Topology.PIPELINE, 7, 7},
-                    new Object[] {"no scale", Topology.PIPELINE, 20, 20},
-                    new Object[] {"upscale", Topology.UNION, 1, 2},
-                    new Object[] {"upscale", Topology.UNION, 2, 3},
-                    new Object[] {"upscale", Topology.UNION, 3, 7},
-                    new Object[] {"downscale", Topology.UNION, 2, 1},
-                    new Object[] {"downscale", Topology.UNION, 3, 2},
-                    new Object[] {"downscale", Topology.UNION, 7, 3},
-                    new Object[] {"no scale", Topology.UNION, 1, 1},
-                    new Object[] {"no scale", Topology.UNION, 7, 7},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2, 0L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7, 0L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 2, 0L},
+                    new Object[] {"upscale", Topology.BROADCAST, 2, 5, 0L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.BROADCAST, 3, 5, 5L},
+                    new Object[] {"upscale", Topology.PIPELINE, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 4, 8, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 20, 21, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.PIPELINE, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 8, 4, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 21, 20, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.PIPELINE, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 3, 3, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 7, 7, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 20, 20, 0L},
+                    new Object[] {"upscale", Topology.UNION, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.UNION, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.UNION, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.UNION, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.UNION, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.UNION, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.UNION, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.UNION, 7, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7, 0L},
                 };
         return Arrays.stream(parameters)
                 .map(
@@ -516,10 +550,12 @@
             Topology topology,
             int oldParallelism,
             int newParallelism,
+            long sourceSleepMs,
             int buffersPerChannel) {
         this.topology = topology;
         this.oldParallelism = oldParallelism;
         this.newParallelism = newParallelism;
+        this.sourceSleepMs = sourceSleepMs;
         this.buffersPerChannel = buffersPerChannel;
     }
 
@@ -529,7 +565,8 @@
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
                         .setExpectedFailures(1)
-                        .setBuffersPerChannel(buffersPerChannel);
+                        .setBuffersPerChannel(buffersPerChannel)
+                        .setSourceSleepMs(sourceSleepMs);
         prescaleSettings.setGenerateCheckpoint(true);
         final File checkpointDir = super.execute(prescaleSettings);
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
index 3420cf6..7ae3ad6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
@@ -207,7 +207,8 @@
                 setupEnv,
                 settings.minCheckpoints,
                 settings.channelType.slotSharing,
-                settings.expectedFailures - settings.failuresAfterSourceFinishes);
+                settings.expectedFailures - settings.failuresAfterSourceFinishes,
+                settings.sourceSleepMs);
 
         return setupEnv.getStreamGraph();
     }
@@ -221,16 +222,19 @@
         private final int numSplits;
         private final int expectedRestarts;
         private final long checkpointingInterval;
+        private final long sourceSleepMs;
 
         protected LongSource(
                 int minCheckpoints,
                 int numSplits,
                 int expectedRestarts,
-                long checkpointingInterval) {
+                long checkpointingInterval,
+                long sourceSleepMs) {
             this.minCheckpoints = minCheckpoints;
             this.numSplits = numSplits;
             this.expectedRestarts = expectedRestarts;
             this.checkpointingInterval = checkpointingInterval;
+            this.sourceSleepMs = sourceSleepMs;
         }
 
         @Override
@@ -244,7 +248,8 @@
                     readerContext.getIndexOfSubtask(),
                     minCheckpoints,
                     expectedRestarts,
-                    checkpointingInterval);
+                    checkpointingInterval,
+                    sourceSleepMs);
         }
 
         @Override
@@ -285,17 +290,20 @@
             private int numCompletedCheckpoints;
             private boolean finishing;
             private boolean recovered;
+            private final long sourceSleepMs;
             @Nullable private Deadline pumpingUntil = null;
 
             public LongSourceReader(
                     int subtaskIndex,
                     int minCheckpoints,
                     int expectedRestarts,
-                    long checkpointingInterval) {
+                    long checkpointingInterval,
+                    long sourceSleepMs) {
                 this.subtaskIndex = subtaskIndex;
                 this.minCheckpoints = minCheckpoints;
                 this.expectedRestarts = expectedRestarts;
-                pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.sourceSleepMs = sourceSleepMs;
             }
 
             @Override
@@ -304,6 +312,9 @@
             @Override
             public InputStatus pollNext(ReaderOutput<Long> output) throws InterruptedException {
                 for (LongSplit split : splits) {
+                    if (sourceSleepMs > 0L) {
+                        Thread.sleep(sourceSleepMs);
+                    }
                     output.collect(withHeader(split.nextNumber), split.nextNumber);
                     split.nextNumber += split.increment;
                 }
@@ -627,7 +638,8 @@
                 StreamExecutionEnvironment environment,
                 int minCheckpoints,
                 boolean slotSharing,
-                int expectedFailuresUntilSourceFinishes);
+                int expectedFailuresUntilSourceFinishes,
+                long sourceSleepMs);
     }
 
     /** Which channels are used to connect the tasks. */
@@ -664,6 +676,7 @@
         private int failuresAfterSourceFinishes = 0;
         private ChannelType channelType = ChannelType.MIXED;
         private int buffersPerChannel = 1;
+        private long sourceSleepMs = 0;
 
         public UnalignedSettings(DagCreator dagCreator) {
             this.dagCreator = dagCreator;
@@ -719,6 +732,11 @@
             return this;
         }
 
+        public UnalignedSettings setSourceSleepMs(long sourceSleepMs) {
+            this.sourceSleepMs = sourceSleepMs;
+            return this;
+        }
+
         public void configure(StreamExecutionEnvironment env) {
             env.enableCheckpointing(Math.max(100L, parallelism * 50L));
             env.getCheckpointConfig().setAlignmentTimeout(Duration.ofMillis(alignmentTimeout));
@@ -791,6 +809,8 @@
                     + failuresAfterSourceFinishes
                     + ", channelType="
                     + channelType
+                    + ", sourceSleepMs="
+                    + sourceSleepMs
                     + '}';
         }
     }