[FLINK-38453] Add full splits to KafkaSourceEnumState

KafkaEnumerator's state contains the TopicPartitions only but not the offsets, so it doesn't contain the full split state contrary to the design intent.

There are a couple of issues with that approach. It implicitly assumes that splits are fully assigned to readers before the first checkpoint. Else the enumerator will invoke the offset initializer again on recovery from such a checkpoint leading to inconsistencies (LATEST may be initialized during the first attempt for some partitions and initialized during second attempt for others).

Through addSplitBack callback, you may also get these scenarios later for BATCH which actually leads to duplicate rows (in case of EARLIEST or SPECIFIC-OFFSETS) or data loss (in case of LATEST). Finally, it's not possible to safely use KafkaSource as part of a HybridSource because the offset initializer cannot even be recreated on recovery.

All cases are solved by also retaining the offset in the enumerator state. To that end, this commit merges the async discovery phases to immediately initialize the splits from the partitions. Any subsequent checkpoint will contain the proper start offset.
diff --git a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612 b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
index 236fade..e496d80 100644
--- a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
+++ b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
@@ -23,11 +23,11 @@
 Method <org.apache.flink.connector.kafka.dynamic.source.reader.DynamicKafkaSourceReader.syncAvailabilityHelperWithReaders()> calls method <org.apache.flink.streaming.runtime.io.MultipleFuturesAvailabilityHelper.anyOf(int, java.util.concurrent.CompletableFuture)> in (DynamicKafkaSourceReader.java:500)
 Method <org.apache.flink.connector.kafka.sink.ExactlyOnceKafkaWriter.getProducerPool()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (ExactlyOnceKafkaWriter.java:0)
 Method <org.apache.flink.connector.kafka.sink.ExactlyOnceKafkaWriter.getTransactionalIdPrefix()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (ExactlyOnceKafkaWriter.java:0)
-Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getCoLocationGroupKey()> in (KafkaSink.java:178)
-Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getInputs()> in (KafkaSink.java:181)
-Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getOutputType()> in (KafkaSink.java:177)
-Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.setCoLocationGroupKey(java.lang.String)> in (KafkaSink.java:180)
-Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> checks instanceof <org.apache.flink.streaming.api.connector.sink2.CommittableMessageTypeInfo> in (KafkaSink.java:177)
+Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getCoLocationGroupKey()> in (KafkaSink.java:183)
+Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getInputs()> in (KafkaSink.java:186)
+Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.getOutputType()> in (KafkaSink.java:182)
+Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> calls method <org.apache.flink.api.dag.Transformation.setCoLocationGroupKey(java.lang.String)> in (KafkaSink.java:185)
+Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> checks instanceof <org.apache.flink.streaming.api.connector.sink2.CommittableMessageTypeInfo> in (KafkaSink.java:182)
 Method <org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)> has generic parameter type <org.apache.flink.streaming.api.datastream.DataStream<org.apache.flink.streaming.api.connector.sink2.CommittableMessage<org.apache.flink.connector.kafka.sink.KafkaCommittable>>> with type argument depending on <org.apache.flink.streaming.api.connector.sink2.CommittableMessage> in (KafkaSink.java:0)
 Method <org.apache.flink.connector.kafka.sink.KafkaSink.getKafkaProducerConfig()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSink.java:0)
 Method <org.apache.flink.connector.kafka.sink.KafkaSinkBuilder.setRecordSerializer(org.apache.flink.connector.kafka.sink.KafkaRecordSerializationSchema)> calls method <org.apache.flink.api.java.ClosureCleaner.clean(java.lang.Object, org.apache.flink.api.common.ExecutionConfig$ClosureCleanerLevel, boolean)> in (KafkaSinkBuilder.java:154)
@@ -39,9 +39,12 @@
 Method <org.apache.flink.connector.kafka.source.KafkaSource.getConfiguration()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSource.java:0)
 Method <org.apache.flink.connector.kafka.source.KafkaSource.getKafkaSubscriber()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSource.java:0)
 Method <org.apache.flink.connector.kafka.source.KafkaSource.getStoppingOffsetsInitializer()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSource.java:0)
-Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeTopicPartitions(java.util.Collection)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumStateSerializer.java:0)
+Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV1(java.util.Collection)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumStateSerializer.java:0)
+Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV2(java.util.Collection, boolean)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumStateSerializer.java:0)
+Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV3(org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumStateSerializer.java:0)
 Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.deepCopyProperties(java.util.Properties, java.util.Properties)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumerator.java:0)
 Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPartitionChange(java.util.Set)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumerator.java:0)
+Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPendingPartitionSplitAssignment()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumerator.java:0)
 Method <org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getSplitOwner(org.apache.kafka.common.TopicPartition, int)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaSourceEnumerator.java:0)
 Method <org.apache.flink.connector.kafka.source.reader.KafkaPartitionSplitReader.consumer()> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaPartitionSplitReader.java:0)
 Method <org.apache.flink.connector.kafka.source.reader.KafkaPartitionSplitReader.setConsumerClientRack(java.util.Properties, java.lang.String)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in (KafkaPartitionSplitReader.java:0)
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
index ff7cc21..7643e62 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
@@ -35,14 +35,13 @@
 import org.apache.flink.connector.kafka.source.KafkaPropertiesUtil;
 import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
 import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator;
-import org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
+import org.apache.flink.connector.kafka.source.enumerator.SplitAndAssignmentStatus;
 import org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
 import org.apache.flink.connector.kafka.source.enumerator.subscriber.KafkaSubscriber;
 import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.TopicPartition;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -174,8 +173,8 @@
                 dynamicKafkaSourceEnumState.getClusterEnumeratorStates().entrySet()) {
             this.latestClusterTopicsMap.put(
                     clusterEnumState.getKey(),
-                    clusterEnumState.getValue().assignedPartitions().stream()
-                            .map(TopicPartition::topic)
+                    clusterEnumState.getValue().assignedSplits().stream()
+                            .map(KafkaPartitionSplit::getTopic)
                             .collect(Collectors.toSet()));
 
             createEnumeratorWithAssignedTopicPartitions(
@@ -291,9 +290,9 @@
                 final Set<String> activeTopics = activeClusterTopics.getValue();
 
                 // filter out removed topics
-                Set<TopicPartitionAndAssignmentStatus> partitions =
-                        kafkaSourceEnumState.partitions().stream()
-                                .filter(tp -> activeTopics.contains(tp.topicPartition().topic()))
+                Set<SplitAndAssignmentStatus> partitions =
+                        kafkaSourceEnumState.splits().stream()
+                                .filter(tp -> activeTopics.contains(tp.split().getTopic()))
                                 .collect(Collectors.toSet());
 
                 newKafkaSourceEnumState =
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
index b7d1153..e8f9600 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
@@ -26,11 +26,8 @@
 
     /** Partitions that have been assigned to readers. */
     ASSIGNED(0),
-    /**
-     * The partitions that have been discovered during initialization but not assigned to readers
-     * yet.
-     */
-    UNASSIGNED_INITIAL(1);
+    /** The partitions that have been discovered but not assigned to readers yet. */
+    UNASSIGNED(1);
     private final int statusCode;
 
     AssignmentStatus(int statusCode) {
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
index 66ceeeb..649bd58 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
@@ -19,9 +19,9 @@
 package org.apache.flink.connector.kafka.source.enumerator;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
-import org.apache.kafka.common.TopicPartition;
-
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -29,8 +29,8 @@
 /** The state of Kafka source enumerator. */
 @Internal
 public class KafkaSourceEnumState {
-    /** Partitions with status: ASSIGNED or UNASSIGNED_INITIAL. */
-    private final Set<TopicPartitionAndAssignmentStatus> partitions;
+    /** Splits with status: ASSIGNED or UNASSIGNED_INITIAL. */
+    private final Set<SplitAndAssignmentStatus> splits;
     /**
      * this flag will be marked as true if initial partitions are discovered after enumerator
      * starts.
@@ -38,57 +38,54 @@
     private final boolean initialDiscoveryFinished;
 
     public KafkaSourceEnumState(
-            Set<TopicPartitionAndAssignmentStatus> partitions, boolean initialDiscoveryFinished) {
-        this.partitions = partitions;
+            Set<SplitAndAssignmentStatus> splits, boolean initialDiscoveryFinished) {
+        this.splits = splits;
         this.initialDiscoveryFinished = initialDiscoveryFinished;
     }
 
     public KafkaSourceEnumState(
-            Set<TopicPartition> assignPartitions,
-            Set<TopicPartition> unassignedInitialPartitions,
+            Collection<KafkaPartitionSplit> assignedSplits,
+            Collection<KafkaPartitionSplit> unassignedSplits,
             boolean initialDiscoveryFinished) {
-        this.partitions = new HashSet<>();
-        partitions.addAll(
-                assignPartitions.stream()
+        this.splits = new HashSet<>();
+        splits.addAll(
+                assignedSplits.stream()
                         .map(
                                 topicPartition ->
-                                        new TopicPartitionAndAssignmentStatus(
+                                        new SplitAndAssignmentStatus(
                                                 topicPartition, AssignmentStatus.ASSIGNED))
                         .collect(Collectors.toSet()));
-        partitions.addAll(
-                unassignedInitialPartitions.stream()
+        splits.addAll(
+                unassignedSplits.stream()
                         .map(
                                 topicPartition ->
-                                        new TopicPartitionAndAssignmentStatus(
-                                                topicPartition,
-                                                AssignmentStatus.UNASSIGNED_INITIAL))
+                                        new SplitAndAssignmentStatus(
+                                                topicPartition, AssignmentStatus.UNASSIGNED))
                         .collect(Collectors.toSet()));
         this.initialDiscoveryFinished = initialDiscoveryFinished;
     }
 
-    public Set<TopicPartitionAndAssignmentStatus> partitions() {
-        return partitions;
+    public Set<SplitAndAssignmentStatus> splits() {
+        return splits;
     }
 
-    public Set<TopicPartition> assignedPartitions() {
-        return filterPartitionsByAssignmentStatus(AssignmentStatus.ASSIGNED);
+    public Collection<KafkaPartitionSplit> assignedSplits() {
+        return filterByAssignmentStatus(AssignmentStatus.ASSIGNED);
     }
 
-    public Set<TopicPartition> unassignedInitialPartitions() {
-        return filterPartitionsByAssignmentStatus(AssignmentStatus.UNASSIGNED_INITIAL);
+    public Collection<KafkaPartitionSplit> unassignedSplits() {
+        return filterByAssignmentStatus(AssignmentStatus.UNASSIGNED);
     }
 
     public boolean initialDiscoveryFinished() {
         return initialDiscoveryFinished;
     }
 
-    private Set<TopicPartition> filterPartitionsByAssignmentStatus(
+    private Collection<KafkaPartitionSplit> filterByAssignmentStatus(
             AssignmentStatus assignmentStatus) {
-        return partitions.stream()
-                .filter(
-                        partitionWithStatus ->
-                                partitionWithStatus.assignmentStatus().equals(assignmentStatus))
-                .map(TopicPartitionAndAssignmentStatus::topicPartition)
-                .collect(Collectors.toSet());
+        return splits.stream()
+                .filter(split -> split.assignmentStatus().equals(assignmentStatus))
+                .map(SplitAndAssignmentStatus::split)
+                .collect(Collectors.toList());
     }
 }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
index f8dc17d..99176cf 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
@@ -37,6 +37,8 @@
 import java.util.Map;
 import java.util.Set;
 
+import static org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit.MIGRATED;
+
 /**
  * The {@link org.apache.flink.core.io.SimpleVersionedSerializer Serializer} for the enumerator
  * state of Kafka source.
@@ -58,7 +60,12 @@
      */
     private static final int VERSION_2 = 2;
 
-    private static final int CURRENT_VERSION = VERSION_2;
+    private static final int VERSION_3 = 3;
+
+    private static final int CURRENT_VERSION = VERSION_3;
+
+    private static final KafkaPartitionSplitSerializer SPLIT_SERIALIZER =
+            new KafkaPartitionSplitSerializer();
 
     @Override
     public int getVersion() {
@@ -67,15 +74,22 @@
 
     @Override
     public byte[] serialize(KafkaSourceEnumState enumState) throws IOException {
-        Set<TopicPartitionAndAssignmentStatus> partitions = enumState.partitions();
+        return serializeV3(enumState);
+    }
+
+    @VisibleForTesting
+    static byte[] serializeV3(KafkaSourceEnumState enumState) throws IOException {
+        Set<SplitAndAssignmentStatus> splits = enumState.splits();
         boolean initialDiscoveryFinished = enumState.initialDiscoveryFinished();
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
-            out.writeInt(partitions.size());
-            for (TopicPartitionAndAssignmentStatus topicPartitionAndAssignmentStatus : partitions) {
-                out.writeUTF(topicPartitionAndAssignmentStatus.topicPartition().topic());
-                out.writeInt(topicPartitionAndAssignmentStatus.topicPartition().partition());
-                out.writeInt(topicPartitionAndAssignmentStatus.assignmentStatus().getStatusCode());
+            out.writeInt(splits.size());
+            out.writeInt(SPLIT_SERIALIZER.getVersion());
+            for (SplitAndAssignmentStatus split : splits) {
+                final byte[] splitBytes = SPLIT_SERIALIZER.serialize(split.split());
+                out.writeInt(splitBytes.length);
+                out.write(splitBytes);
+                out.writeInt(split.assignmentStatus().getStatusCode());
             }
             out.writeBoolean(initialDiscoveryFinished);
             out.flush();
@@ -86,22 +100,14 @@
     @Override
     public KafkaSourceEnumState deserialize(int version, byte[] serialized) throws IOException {
         switch (version) {
-            case CURRENT_VERSION:
-                return deserializeTopicPartitionAndAssignmentStatus(serialized);
+            case VERSION_3:
+                return deserializeVersion3(serialized);
+            case VERSION_2:
+                return deserializeVersion2(serialized);
             case VERSION_1:
-                return deserializeAssignedTopicPartitions(serialized);
+                return deserializeVersion1(serialized);
             case VERSION_0:
-                Map<Integer, Set<KafkaPartitionSplit>> currentPartitionAssignment =
-                        SerdeUtils.deserializeSplitAssignments(
-                                serialized, new KafkaPartitionSplitSerializer(), HashSet::new);
-                Set<TopicPartition> currentAssignedSplits = new HashSet<>();
-                currentPartitionAssignment.forEach(
-                        (reader, splits) ->
-                                splits.forEach(
-                                        split ->
-                                                currentAssignedSplits.add(
-                                                        split.getTopicPartition())));
-                return new KafkaSourceEnumState(currentAssignedSplits, new HashSet<>(), true);
+                return deserializeVersion0(serialized);
             default:
                 throw new IOException(
                         String.format(
@@ -111,44 +117,24 @@
         }
     }
 
-    private static KafkaSourceEnumState deserializeAssignedTopicPartitions(
-            byte[] serializedTopicPartitions) throws IOException {
-        try (ByteArrayInputStream bais = new ByteArrayInputStream(serializedTopicPartitions);
-                DataInputStream in = new DataInputStream(bais)) {
+    private static KafkaSourceEnumState deserializeVersion3(byte[] serialized) throws IOException {
 
-            final int numPartitions = in.readInt();
-            Set<TopicPartitionAndAssignmentStatus> partitions = new HashSet<>(numPartitions);
-            for (int i = 0; i < numPartitions; i++) {
-                final String topic = in.readUTF();
-                final int partition = in.readInt();
-                partitions.add(
-                        new TopicPartitionAndAssignmentStatus(
-                                new TopicPartition(topic, partition), AssignmentStatus.ASSIGNED));
-            }
-            if (in.available() > 0) {
-                throw new IOException("Unexpected trailing bytes in serialized topic partitions");
-            }
-            return new KafkaSourceEnumState(partitions, true);
-        }
-    }
-
-    private static KafkaSourceEnumState deserializeTopicPartitionAndAssignmentStatus(
-            byte[] serialized) throws IOException {
+        final KafkaPartitionSplitSerializer splitSerializer = new KafkaPartitionSplitSerializer();
 
         try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
                 DataInputStream in = new DataInputStream(bais)) {
 
             final int numPartitions = in.readInt();
-            Set<TopicPartitionAndAssignmentStatus> partitions = new HashSet<>(numPartitions);
+            final int splitVersion = in.readInt();
+            Set<SplitAndAssignmentStatus> partitions = new HashSet<>(numPartitions);
 
             for (int i = 0; i < numPartitions; i++) {
-                final String topic = in.readUTF();
-                final int partition = in.readInt();
+                final KafkaPartitionSplit split =
+                        splitSerializer.deserialize(splitVersion, in.readNBytes(in.readInt()));
                 final int statusCode = in.readInt();
                 partitions.add(
-                        new TopicPartitionAndAssignmentStatus(
-                                new TopicPartition(topic, partition),
-                                AssignmentStatus.ofStatusCode(statusCode)));
+                        new SplitAndAssignmentStatus(
+                                split, AssignmentStatus.ofStatusCode(statusCode)));
             }
             final boolean initialDiscoveryFinished = in.readBoolean();
             if (in.available() > 0) {
@@ -159,14 +145,26 @@
         }
     }
 
+    private static KafkaSourceEnumState deserializeVersion0(byte[] serialized) throws IOException {
+        Map<Integer, Set<KafkaPartitionSplit>> currentPartitionAssignment =
+                SerdeUtils.deserializeSplitAssignments(
+                        serialized, new KafkaPartitionSplitSerializer(), HashSet::new);
+        Set<KafkaPartitionSplit> currentAssignedSplits = new HashSet<>();
+        for (Map.Entry<Integer, Set<KafkaPartitionSplit>> entry :
+                currentPartitionAssignment.entrySet()) {
+            currentAssignedSplits.addAll(entry.getValue());
+        }
+        return new KafkaSourceEnumState(currentAssignedSplits, new HashSet<>(), true);
+    }
+
     @VisibleForTesting
-    public static byte[] serializeTopicPartitions(Collection<TopicPartition> topicPartitions)
-            throws IOException {
+    static byte[] serializeV1(Collection<KafkaPartitionSplit> splits) throws IOException {
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
 
-            out.writeInt(topicPartitions.size());
-            for (TopicPartition tp : topicPartitions) {
+            out.writeInt(splits.size());
+            for (KafkaPartitionSplit split : splits) {
+                final TopicPartition tp = split.getTopicPartition();
                 out.writeUTF(tp.topic());
                 out.writeInt(tp.partition());
             }
@@ -175,4 +173,74 @@
             return baos.toByteArray();
         }
     }
+
+    private static KafkaSourceEnumState deserializeVersion1(byte[] serializedTopicPartitions)
+            throws IOException {
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(serializedTopicPartitions);
+                DataInputStream in = new DataInputStream(bais)) {
+
+            final int numPartitions = in.readInt();
+            Set<SplitAndAssignmentStatus> partitions = new HashSet<>(numPartitions);
+            for (int i = 0; i < numPartitions; i++) {
+                final String topic = in.readUTF();
+                final int partition = in.readInt();
+                partitions.add(
+                        new SplitAndAssignmentStatus(
+                                new KafkaPartitionSplit(
+                                        new TopicPartition(topic, partition), MIGRATED),
+                                AssignmentStatus.ASSIGNED));
+            }
+            if (in.available() > 0) {
+                throw new IOException("Unexpected trailing bytes in serialized topic partitions");
+            }
+            return new KafkaSourceEnumState(partitions, true);
+        }
+    }
+
+    @VisibleForTesting
+    static byte[] serializeV2(
+            Collection<SplitAndAssignmentStatus> splits, boolean initialDiscoveryFinished)
+            throws IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                DataOutputStream out = new DataOutputStream(baos)) {
+            out.writeInt(splits.size());
+            for (SplitAndAssignmentStatus splitAndAssignmentStatus : splits) {
+                final TopicPartition topicPartition =
+                        splitAndAssignmentStatus.split().getTopicPartition();
+                out.writeUTF(topicPartition.topic());
+                out.writeInt(topicPartition.partition());
+                out.writeInt(splitAndAssignmentStatus.assignmentStatus().getStatusCode());
+            }
+            out.writeBoolean(initialDiscoveryFinished);
+            out.flush();
+            return baos.toByteArray();
+        }
+    }
+
+    private static KafkaSourceEnumState deserializeVersion2(byte[] serialized) throws IOException {
+
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
+                DataInputStream in = new DataInputStream(bais)) {
+
+            final int numPartitions = in.readInt();
+            Set<SplitAndAssignmentStatus> partitions = new HashSet<>(numPartitions);
+
+            for (int i = 0; i < numPartitions; i++) {
+                final String topic = in.readUTF();
+                final int partition = in.readInt();
+                final int statusCode = in.readInt();
+                partitions.add(
+                        new SplitAndAssignmentStatus(
+                                new KafkaPartitionSplit(
+                                        new TopicPartition(topic, partition), MIGRATED),
+                                AssignmentStatus.ofStatusCode(statusCode)));
+            }
+            final boolean initialDiscoveryFinished = in.readBoolean();
+            if (in.available() > 0) {
+                throw new IOException("Unexpected trailing bytes in serialized topic partitions");
+            }
+
+            return new KafkaSourceEnumState(partitions, initialDiscoveryFinished);
+        }
+    }
 }
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
index f305819..e65e9a5 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
@@ -56,6 +56,9 @@
 import java.util.concurrent.ExecutionException;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** The enumerator class for Kafka source. */
 @Internal
@@ -72,13 +75,12 @@
     private final Boundedness boundedness;
 
     /** Partitions that have been assigned to readers. */
-    private final Set<TopicPartition> assignedPartitions;
+    private final Map<TopicPartition, KafkaPartitionSplit> assignedSplits;
 
     /**
-     * The partitions that have been discovered during initialization but not assigned to readers
-     * yet.
+     * The splits that have been discovered during initialization but not assigned to readers yet.
      */
-    private final Set<TopicPartition> unassignedInitialPartitions;
+    private final Map<TopicPartition, KafkaPartitionSplit> unassignedSplits;
 
     /**
      * The discovered and initialized partition splits that are waiting for owner reader to be
@@ -96,7 +98,8 @@
     // initializing partition discovery has finished.
     private boolean noMoreNewPartitionSplits = false;
     // this flag will be marked as true if initial partitions are discovered after enumerator starts
-    private boolean initialDiscoveryFinished;
+    // the flag is read and set in main thread but also read in worker thread
+    private volatile boolean initialDiscoveryFinished;
 
     public KafkaSourceEnumerator(
             KafkaSubscriber subscriber,
@@ -131,7 +134,10 @@
         this.context = context;
         this.boundedness = boundedness;
 
-        this.assignedPartitions = new HashSet<>(kafkaSourceEnumState.assignedPartitions());
+        Map<AssignmentStatus, List<KafkaPartitionSplit>> splits =
+                initializeMigratedSplits(kafkaSourceEnumState.splits());
+        this.assignedSplits = indexByPartition(splits.get(AssignmentStatus.ASSIGNED));
+        this.unassignedSplits = indexByPartition(splits.get(AssignmentStatus.UNASSIGNED));
         this.pendingPartitionSplitAssignment = new HashMap<>();
         this.partitionDiscoveryIntervalMs =
                 KafkaSourceOptions.getOption(
@@ -139,12 +145,74 @@
                         KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS,
                         Long::parseLong);
         this.consumerGroupId = properties.getProperty(ConsumerConfig.GROUP_ID_CONFIG);
-        this.unassignedInitialPartitions =
-                new HashSet<>(kafkaSourceEnumState.unassignedInitialPartitions());
         this.initialDiscoveryFinished = kafkaSourceEnumState.initialDiscoveryFinished();
     }
 
     /**
+     * Initialize migrated splits to splits with concrete starting offsets. This method ensures that
+     * the costly offset resolution is performed only when there are splits that have been
+     * checkpointed with previous enumerator versions.
+     *
+     * <p>Note that this method is deliberately performed in the main thread to avoid a checkpoint
+     * of the splits without starting offset.
+     */
+    private Map<AssignmentStatus, List<KafkaPartitionSplit>> initializeMigratedSplits(
+            Set<SplitAndAssignmentStatus> splits) {
+        final Set<TopicPartition> migratedPartitions =
+                splits.stream()
+                        .filter(
+                                splitStatus ->
+                                        splitStatus.split().getStartingOffset()
+                                                == KafkaPartitionSplit.MIGRATED)
+                        .map(splitStatus -> splitStatus.split().getTopicPartition())
+                        .collect(Collectors.toSet());
+
+        if (migratedPartitions.isEmpty()) {
+            return splitByAssignmentStatus(splits.stream());
+        }
+
+        final Map<TopicPartition, Long> startOffsets =
+                startingOffsetInitializer.getPartitionOffsets(
+                        migratedPartitions, getOffsetsRetriever());
+        return splitByAssignmentStatus(
+                splits.stream()
+                        .map(splitStatus -> resolveMigratedSplit(splitStatus, startOffsets)));
+    }
+
+    private static Map<AssignmentStatus, List<KafkaPartitionSplit>> splitByAssignmentStatus(
+            Stream<SplitAndAssignmentStatus> stream) {
+        return stream.collect(
+                Collectors.groupingBy(
+                        SplitAndAssignmentStatus::assignmentStatus,
+                        Collectors.mapping(SplitAndAssignmentStatus::split, Collectors.toList())));
+    }
+
+    private static SplitAndAssignmentStatus resolveMigratedSplit(
+            SplitAndAssignmentStatus splitStatus, Map<TopicPartition, Long> startOffsets) {
+        final KafkaPartitionSplit split = splitStatus.split();
+        if (split.getStartingOffset() != KafkaPartitionSplit.MIGRATED) {
+            return splitStatus;
+        }
+        final Long startOffset = startOffsets.get(split.getTopicPartition());
+        checkState(
+                startOffset != null,
+                "Cannot find starting offset for migrated partition %s",
+                split.getTopicPartition());
+        return new SplitAndAssignmentStatus(
+                new KafkaPartitionSplit(split.getTopicPartition(), startOffset),
+                splitStatus.assignmentStatus());
+    }
+
+    private Map<TopicPartition, KafkaPartitionSplit> indexByPartition(
+            List<KafkaPartitionSplit> splits) {
+        if (splits == null) {
+            return new HashMap<>();
+        }
+        return splits.stream()
+                .collect(Collectors.toMap(KafkaPartitionSplit::getTopicPartition, split -> split));
+    }
+
+    /**
      * Start the enumerator.
      *
      * <p>Depending on {@link #partitionDiscoveryIntervalMs}, the enumerator will trigger a one-time
@@ -153,9 +221,7 @@
      * <p>The invoking chain of partition discovery would be:
      *
      * <ol>
-     *   <li>{@link #getSubscribedTopicPartitions} in worker thread
-     *   <li>{@link #checkPartitionChanges} in coordinator thread
-     *   <li>{@link #initializePartitionSplits} in worker thread
+     *   <li>{@link #findNewPartitionSplits} in worker thread
      *   <li>{@link #handlePartitionSplitChanges} in coordinator thread
      * </ol>
      */
@@ -169,8 +235,8 @@
                     consumerGroupId,
                     partitionDiscoveryIntervalMs);
             context.callAsync(
-                    this::getSubscribedTopicPartitions,
-                    this::checkPartitionChanges,
+                    this::findNewPartitionSplits,
+                    this::handlePartitionSplitChanges,
                     0,
                     partitionDiscoveryIntervalMs);
         } else {
@@ -178,7 +244,7 @@
                     "Starting the KafkaSourceEnumerator for consumer group {} "
                             + "without periodic partition discovery.",
                     consumerGroupId);
-            context.callAsync(this::getSubscribedTopicPartitions, this::checkPartitionChanges);
+            context.callAsync(this::findNewPartitionSplits, this::handlePartitionSplitChanges);
         }
     }
 
@@ -189,6 +255,9 @@
 
     @Override
     public void addSplitsBack(List<KafkaPartitionSplit> splits, int subtaskId) {
+        for (KafkaPartitionSplit split : splits) {
+            unassignedSplits.put(split.getTopicPartition(), split);
+        }
         addPartitionSplitChangeToPendingAssignments(splits);
 
         // If the failed subtask has already restarted, we need to assign pending splits to it
@@ -209,7 +278,7 @@
     @Override
     public KafkaSourceEnumState snapshotState(long checkpointId) throws Exception {
         return new KafkaSourceEnumState(
-                assignedPartitions, unassignedInitialPartitions, initialDiscoveryFinished);
+                assignedSplits.values(), unassignedSplits.values(), initialDiscoveryFinished);
     }
 
     @Override
@@ -229,38 +298,16 @@
      *
      * @return Set of subscribed {@link TopicPartition}s
      */
-    private Set<TopicPartition> getSubscribedTopicPartitions() {
-        return subscriber.getSubscribedTopicPartitions(adminClient);
-    }
-
-    /**
-     * Check if there's any partition changes within subscribed topic partitions fetched by worker
-     * thread, and invoke {@link KafkaSourceEnumerator#initializePartitionSplits(PartitionChange)}
-     * in worker thread to initialize splits for new partitions.
-     *
-     * <p>NOTE: This method should only be invoked in the coordinator executor thread.
-     *
-     * @param fetchedPartitions Map from topic name to its description
-     * @param t Exception in worker thread
-     */
-    private void checkPartitionChanges(Set<TopicPartition> fetchedPartitions, Throwable t) {
-        if (t != null) {
-            throw new FlinkRuntimeException(
-                    "Failed to list subscribed topic partitions due to ", t);
-        }
-
-        if (!initialDiscoveryFinished) {
-            unassignedInitialPartitions.addAll(fetchedPartitions);
-            initialDiscoveryFinished = true;
-        }
+    private PartitionSplitChange findNewPartitionSplits() {
+        final Set<TopicPartition> fetchedPartitions =
+                subscriber.getSubscribedTopicPartitions(adminClient);
 
         final PartitionChange partitionChange = getPartitionChange(fetchedPartitions);
         if (partitionChange.isEmpty()) {
-            return;
+            return null;
         }
-        context.callAsync(
-                () -> initializePartitionSplits(partitionChange),
-                this::handlePartitionSplitChanges);
+
+        return initializePartitionSplits(partitionChange);
     }
 
     /**
@@ -290,13 +337,14 @@
         OffsetsInitializer.PartitionOffsetsRetriever offsetsRetriever = getOffsetsRetriever();
         // initial partitions use OffsetsInitializer specified by the user while new partitions use
         // EARLIEST
-        Map<TopicPartition, Long> startingOffsets = new HashMap<>();
-        startingOffsets.putAll(
-                newDiscoveryOffsetsInitializer.getPartitionOffsets(
-                        newPartitions, offsetsRetriever));
-        startingOffsets.putAll(
-                startingOffsetInitializer.getPartitionOffsets(
-                        unassignedInitialPartitions, offsetsRetriever));
+        final OffsetsInitializer initializer;
+        if (!initialDiscoveryFinished) {
+            initializer = startingOffsetInitializer;
+        } else {
+            initializer = newDiscoveryOffsetsInitializer;
+        }
+        Map<TopicPartition, Long> startingOffsets =
+                initializer.getPartitionOffsets(newPartitions, offsetsRetriever);
 
         Map<TopicPartition, Long> stoppingOffsets =
                 stoppingOffsetInitializer.getPartitionOffsets(newPartitions, offsetsRetriever);
@@ -322,14 +370,21 @@
      * @param t Exception in worker thread
      */
     private void handlePartitionSplitChanges(
-            PartitionSplitChange partitionSplitChange, Throwable t) {
+            @Nullable PartitionSplitChange partitionSplitChange, Throwable t) {
         if (t != null) {
             throw new FlinkRuntimeException("Failed to initialize partition splits due to ", t);
         }
+        initialDiscoveryFinished = true;
         if (partitionDiscoveryIntervalMs <= 0) {
             LOG.debug("Partition discovery is disabled.");
             noMoreNewPartitionSplits = true;
         }
+        if (partitionSplitChange == null) {
+            return;
+        }
+        for (KafkaPartitionSplit split : partitionSplitChange.newPartitionSplits) {
+            unassignedSplits.put(split.getTopicPartition(), split);
+        }
         // TODO: Handle removed partitions.
         addPartitionSplitChangeToPendingAssignments(partitionSplitChange.newPartitionSplits);
         assignPendingPartitionSplits(context.registeredReaders().keySet());
@@ -373,8 +428,8 @@
                 // Mark pending partitions as already assigned
                 pendingAssignmentForReader.forEach(
                         split -> {
-                            assignedPartitions.add(split.getTopicPartition());
-                            unassignedInitialPartitions.remove(split.getTopicPartition());
+                            assignedSplits.put(split.getTopicPartition(), split);
+                            unassignedSplits.remove(split.getTopicPartition());
                         });
             }
         }
@@ -414,7 +469,7 @@
                     }
                 };
 
-        assignedPartitions.forEach(dedupOrMarkAsRemoved);
+        assignedSplits.keySet().forEach(dedupOrMarkAsRemoved);
         pendingPartitionSplitAssignment.forEach(
                 (reader, splits) ->
                         splits.forEach(
@@ -446,6 +501,11 @@
         return new PartitionOffsetsRetrieverImpl(adminClient, groupId);
     }
 
+    @VisibleForTesting
+    Map<Integer, Set<KafkaPartitionSplit>> getPendingPartitionSplitAssignment() {
+        return pendingPartitionSplitAssignment;
+    }
+
     /**
      * Returns the index of the target subtask that a specific Kafka partition should be assigned
      * to.
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
similarity index 75%
rename from flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java
rename to flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
index 2caed99..a7763fb 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
@@ -19,23 +19,21 @@
 package org.apache.flink.connector.kafka.source.enumerator;
 
 import org.apache.flink.annotation.Internal;
-
-import org.apache.kafka.common.TopicPartition;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
 /** Kafka partition with assign status. */
 @Internal
-public class TopicPartitionAndAssignmentStatus {
-    private final TopicPartition topicPartition;
+public class SplitAndAssignmentStatus {
+    private final KafkaPartitionSplit split;
     private final AssignmentStatus assignmentStatus;
 
-    public TopicPartitionAndAssignmentStatus(
-            TopicPartition topicPartition, AssignmentStatus assignStatus) {
-        this.topicPartition = topicPartition;
+    public SplitAndAssignmentStatus(KafkaPartitionSplit split, AssignmentStatus assignStatus) {
+        this.split = split;
         this.assignmentStatus = assignStatus;
     }
 
-    public TopicPartition topicPartition() {
-        return topicPartition;
+    public KafkaPartitionSplit split() {
+        return split;
     }
 
     public AssignmentStatus assignmentStatus() {
diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
index 7c04600..52cb3b9 100644
--- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
+++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
@@ -41,10 +41,14 @@
     public static final long EARLIEST_OFFSET = -2;
     // Indicating the split should consume from the last committed offset.
     public static final long COMMITTED_OFFSET = -3;
+    // Used to indicate the split has been migrated from an earlier enumerator state; offset needs
+    // to be initialized on recovery
+    public static final long MIGRATED = Long.MIN_VALUE;
 
     // Valid special starting offsets
     public static final Set<Long> VALID_STARTING_OFFSET_MARKERS =
-            new HashSet<>(Arrays.asList(EARLIEST_OFFSET, LATEST_OFFSET, COMMITTED_OFFSET));
+            new HashSet<>(
+                    Arrays.asList(EARLIEST_OFFSET, LATEST_OFFSET, COMMITTED_OFFSET, MIGRATED));
     public static final Set<Long> VALID_STOPPING_OFFSET_MARKERS =
             new HashSet<>(Arrays.asList(LATEST_OFFSET, COMMITTED_OFFSET, NO_STOPPING_OFFSET));
 
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
index 66caec4..251309b 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
@@ -22,7 +22,8 @@
 import org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream;
 import org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus;
 import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
-import org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
+import org.apache.flink.connector.kafka.source.enumerator.SplitAndAssignmentStatus;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -33,6 +34,8 @@
 import java.util.Properties;
 import java.util.Set;
 
+import static org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus.ASSIGNED;
+import static org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus.UNASSIGNED;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
@@ -81,28 +84,16 @@
                                 "cluster0",
                                 new KafkaSourceEnumState(
                                         ImmutableSet.of(
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic0", 0),
-                                                        AssignmentStatus.ASSIGNED),
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic1", 1),
-                                                        AssignmentStatus.UNASSIGNED_INITIAL)),
+                                                getSplitAssignment("topic0", 0, ASSIGNED),
+                                                getSplitAssignment("topic1", 1, UNASSIGNED)),
                                         true),
                                 "cluster1",
                                 new KafkaSourceEnumState(
                                         ImmutableSet.of(
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic2", 0),
-                                                        AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic3", 1),
-                                                        AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic4", 2),
-                                                        AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new TopicPartitionAndAssignmentStatus(
-                                                        new TopicPartition("topic5", 3),
-                                                        AssignmentStatus.UNASSIGNED_INITIAL)),
+                                                getSplitAssignment("topic2", 0, UNASSIGNED),
+                                                getSplitAssignment("topic3", 1, UNASSIGNED),
+                                                getSplitAssignment("topic4", 2, UNASSIGNED),
+                                                getSplitAssignment("topic5", 3, UNASSIGNED)),
                                         false)));
 
         DynamicKafkaSourceEnumState dynamicKafkaSourceEnumStateAfterSerde =
@@ -115,4 +106,10 @@
                 .usingRecursiveComparison()
                 .isEqualTo(dynamicKafkaSourceEnumStateAfterSerde);
     }
+
+    private static SplitAndAssignmentStatus getSplitAssignment(
+            String topic, int partition, AssignmentStatus assignStatus) {
+        return new SplitAndAssignmentStatus(
+                new KafkaPartitionSplit(new TopicPartition(topic, partition), 0), assignStatus);
+    }
 }
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
index 8613334..f974b6f 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
@@ -33,7 +33,6 @@
 import org.apache.flink.connector.kafka.dynamic.source.split.DynamicKafkaSourceSplit;
 import org.apache.flink.connector.kafka.source.KafkaSourceOptions;
 import org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus;
-import org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
 import org.apache.flink.connector.kafka.source.enumerator.initializer.NoStoppingOffsetsInitializer;
 import org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
 import org.apache.flink.connector.kafka.testutils.MockKafkaMetadataService;
@@ -429,7 +428,7 @@
             assertThat(
                             stateBeforeSplitAssignment.getClusterEnumeratorStates().values()
                                     .stream()
-                                    .map(subState -> subState.assignedPartitions().stream())
+                                    .map(subState -> subState.assignedSplits().stream())
                                     .count())
                     .as("no readers registered, so state should be empty")
                     .isZero();
@@ -458,7 +457,7 @@
 
             assertThat(
                             stateAfterSplitAssignment.getClusterEnumeratorStates().values().stream()
-                                    .flatMap(enumState -> enumState.assignedPartitions().stream())
+                                    .flatMap(enumState -> enumState.assignedSplits().stream())
                                     .count())
                     .isEqualTo(
                             NUM_SPLITS_PER_CLUSTER
@@ -514,15 +513,13 @@
 
             assertThat(getFilteredTopicPartitions(initialState, TOPIC, AssignmentStatus.ASSIGNED))
                     .hasSize(2);
-            assertThat(
-                            getFilteredTopicPartitions(
-                                    initialState, TOPIC, AssignmentStatus.UNASSIGNED_INITIAL))
+            assertThat(getFilteredTopicPartitions(initialState, TOPIC, AssignmentStatus.UNASSIGNED))
                     .hasSize(1);
             assertThat(getFilteredTopicPartitions(initialState, topic2, AssignmentStatus.ASSIGNED))
                     .hasSize(2);
             assertThat(
                             getFilteredTopicPartitions(
-                                    initialState, topic2, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    initialState, topic2, AssignmentStatus.UNASSIGNED))
                     .hasSize(1);
 
             // mock metadata change
@@ -540,13 +537,13 @@
                     .hasSize(3);
             assertThat(
                             getFilteredTopicPartitions(
-                                    migratedState, TOPIC, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    migratedState, TOPIC, AssignmentStatus.UNASSIGNED))
                     .isEmpty();
             assertThat(getFilteredTopicPartitions(migratedState, topic2, AssignmentStatus.ASSIGNED))
                     .isEmpty();
             assertThat(
                             getFilteredTopicPartitions(
-                                    migratedState, topic2, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    migratedState, topic2, AssignmentStatus.UNASSIGNED))
                     .isEmpty();
         }
     }
@@ -955,12 +952,14 @@
     private List<TopicPartition> getFilteredTopicPartitions(
             DynamicKafkaSourceEnumState state, String topic, AssignmentStatus assignmentStatus) {
         return state.getClusterEnumeratorStates().values().stream()
-                .flatMap(s -> s.partitions().stream())
+                .flatMap(s -> s.splits().stream())
                 .filter(
                         partition ->
-                                partition.topicPartition().topic().equals(topic)
+                                partition.split().getTopic().equals(topic)
                                         && partition.assignmentStatus() == assignmentStatus)
-                .map(TopicPartitionAndAssignmentStatus::topicPartition)
+                .map(
+                        splitAndAssignmentStatus ->
+                                splitAndAssignmentStatus.split().getTopicPartition())
                 .collect(Collectors.toList());
     }
 
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
index 6c172e4..5207687 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
@@ -29,8 +29,10 @@
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -46,8 +48,8 @@
     public void testEnumStateSerde() throws IOException {
         final KafkaSourceEnumState state =
                 new KafkaSourceEnumState(
-                        constructTopicPartitions(0),
-                        constructTopicPartitions(NUM_PARTITIONS_PER_TOPIC),
+                        constructTopicSplits(0),
+                        constructTopicSplits(NUM_PARTITIONS_PER_TOPIC),
                         true);
         final KafkaSourceEnumStateSerializer serializer = new KafkaSourceEnumStateSerializer();
 
@@ -56,26 +58,35 @@
         final KafkaSourceEnumState restoredState =
                 serializer.deserialize(serializer.getVersion(), bytes);
 
-        assertThat(restoredState.assignedPartitions()).isEqualTo(state.assignedPartitions());
-        assertThat(restoredState.unassignedInitialPartitions())
-                .isEqualTo(state.unassignedInitialPartitions());
+        assertThat(restoredState.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(state.assignedSplits());
+        assertThat(restoredState.unassignedSplits())
+                .containsExactlyInAnyOrderElementsOf(state.unassignedSplits());
         assertThat(restoredState.initialDiscoveryFinished()).isTrue();
     }
 
     @Test
     public void testBackwardCompatibility() throws IOException {
 
-        final Set<TopicPartition> topicPartitions = constructTopicPartitions(0);
-        final Map<Integer, Set<KafkaPartitionSplit>> splitAssignments =
-                toSplitAssignments(topicPartitions);
+        final Set<KafkaPartitionSplit> splits = constructTopicSplits(0);
+        final Map<Integer, Collection<KafkaPartitionSplit>> splitAssignments =
+                toSplitAssignments(splits);
+        final List<SplitAndAssignmentStatus> splitAndAssignmentStatuses =
+                splits.stream()
+                        .map(
+                                split ->
+                                        new SplitAndAssignmentStatus(
+                                                split, getAssignmentStatus(split)))
+                        .collect(Collectors.toList());
 
         // Create bytes in the way of KafkaEnumStateSerializer version 0 doing serialization
         final byte[] bytesV0 =
                 SerdeUtils.serializeSplitAssignments(
                         splitAssignments, new KafkaPartitionSplitSerializer());
         // Create bytes in the way of KafkaEnumStateSerializer version 1 doing serialization
-        final byte[] bytesV1 =
-                KafkaSourceEnumStateSerializer.serializeTopicPartitions(topicPartitions);
+        final byte[] bytesV1 = KafkaSourceEnumStateSerializer.serializeV1(splits);
+        final byte[] bytesV2 =
+                KafkaSourceEnumStateSerializer.serializeV2(splitAndAssignmentStatuses, false);
 
         // Deserialize above bytes with KafkaEnumStateSerializer version 2 to check backward
         // compatibility
@@ -83,46 +94,72 @@
                 new KafkaSourceEnumStateSerializer().deserialize(0, bytesV0);
         final KafkaSourceEnumState kafkaSourceEnumStateV1 =
                 new KafkaSourceEnumStateSerializer().deserialize(1, bytesV1);
+        final KafkaSourceEnumState kafkaSourceEnumStateV2 =
+                new KafkaSourceEnumStateSerializer().deserialize(2, bytesV2);
 
-        assertThat(kafkaSourceEnumStateV0.assignedPartitions()).isEqualTo(topicPartitions);
-        assertThat(kafkaSourceEnumStateV0.unassignedInitialPartitions()).isEmpty();
+        assertThat(kafkaSourceEnumStateV0.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(splits);
+        assertThat(kafkaSourceEnumStateV0.unassignedSplits()).isEmpty();
         assertThat(kafkaSourceEnumStateV0.initialDiscoveryFinished()).isTrue();
 
-        assertThat(kafkaSourceEnumStateV1.assignedPartitions()).isEqualTo(topicPartitions);
-        assertThat(kafkaSourceEnumStateV1.unassignedInitialPartitions()).isEmpty();
+        assertThat(kafkaSourceEnumStateV1.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(splits);
+        assertThat(kafkaSourceEnumStateV1.unassignedSplits()).isEmpty();
         assertThat(kafkaSourceEnumStateV1.initialDiscoveryFinished()).isTrue();
+
+        final Map<AssignmentStatus, Set<KafkaPartitionSplit>> splitsByStatus =
+                splitAndAssignmentStatuses.stream()
+                        .collect(
+                                Collectors.groupingBy(
+                                        SplitAndAssignmentStatus::assignmentStatus,
+                                        Collectors.mapping(
+                                                SplitAndAssignmentStatus::split,
+                                                Collectors.toSet())));
+        assertThat(kafkaSourceEnumStateV2.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(splitsByStatus.get(AssignmentStatus.ASSIGNED));
+        assertThat(kafkaSourceEnumStateV2.unassignedSplits())
+                .containsExactlyInAnyOrderElementsOf(
+                        splitsByStatus.get(AssignmentStatus.UNASSIGNED));
+        assertThat(kafkaSourceEnumStateV2.initialDiscoveryFinished()).isFalse();
     }
 
-    private Set<TopicPartition> constructTopicPartitions(int startPartition) {
+    private static AssignmentStatus getAssignmentStatus(KafkaPartitionSplit split) {
+        return AssignmentStatus.values()[
+                Math.abs(split.hashCode()) % AssignmentStatus.values().length];
+    }
+
+    private Set<KafkaPartitionSplit> constructTopicSplits(int startPartition) {
         // Create topic partitions for readers.
         // Reader i will be assigned with NUM_PARTITIONS_PER_TOPIC splits, with topic name
         // "topic-{i}" and
         // NUM_PARTITIONS_PER_TOPIC partitions. The starting partition number is startPartition
         // Totally NUM_READERS * NUM_PARTITIONS_PER_TOPIC partitions will be created.
-        Set<TopicPartition> topicPartitions = new HashSet<>();
+        Set<KafkaPartitionSplit> topicPartitions = new HashSet<>();
         for (int readerId = 0; readerId < NUM_READERS; readerId++) {
             for (int partition = startPartition;
                     partition < startPartition + NUM_PARTITIONS_PER_TOPIC;
                     partition++) {
-                topicPartitions.add(new TopicPartition(TOPIC_PREFIX + readerId, partition));
+                topicPartitions.add(
+                        new KafkaPartitionSplit(
+                                new TopicPartition(TOPIC_PREFIX + readerId, partition),
+                                KafkaPartitionSplit.MIGRATED));
             }
         }
         return topicPartitions;
     }
 
-    private Map<Integer, Set<KafkaPartitionSplit>> toSplitAssignments(
-            Collection<TopicPartition> topicPartitions) {
+    private Map<Integer, Collection<KafkaPartitionSplit>> toSplitAssignments(
+            Collection<KafkaPartitionSplit> splits) {
         // Assign splits to readers according to topic name. For example, topic "topic-5" will be
         // assigned to reader with ID=5
-        Map<Integer, Set<KafkaPartitionSplit>> splitAssignments = new HashMap<>();
-        topicPartitions.forEach(
-                (tp) ->
-                        splitAssignments
-                                .computeIfAbsent(
-                                        Integer.valueOf(
-                                                tp.topic().substring(TOPIC_PREFIX.length())),
-                                        HashSet::new)
-                                .add(new KafkaPartitionSplit(tp, STARTING_OFFSET)));
+        Map<Integer, Collection<KafkaPartitionSplit>> splitAssignments = new HashMap<>();
+        for (KafkaPartitionSplit split : splits) {
+            splitAssignments
+                    .computeIfAbsent(
+                            Integer.valueOf(split.getTopic().substring(TOPIC_PREFIX.length())),
+                            HashSet::new)
+                    .add(split);
+        }
         return splitAssignments;
     }
 }
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
similarity index 71%
rename from flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java
rename to flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
index 8b308af..3e64e62 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
@@ -30,14 +30,20 @@
 import org.apache.flink.connector.kafka.testutils.KafkaSourceTestEnv;
 import org.apache.flink.mock.Whitebox;
 
+import com.google.common.collect.Iterables;
 import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.NewTopic;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.StringDeserializer;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.assertj.core.api.SoftAssertions;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -50,13 +56,15 @@
 import java.util.Properties;
 import java.util.Set;
 import java.util.StringJoiner;
+import java.util.concurrent.TimeUnit;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit.MIGRATED;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link KafkaSourceEnumerator}. */
-public class KafkaEnumeratorTest {
+public class KafkaSourceEnumeratorTest {
     private static final int NUM_SUBTASKS = 3;
     private static final String DYNAMIC_TOPIC_NAME = "dynamic_topic";
     private static final int NUM_PARTITIONS_DYNAMIC_TOPIC = 4;
@@ -74,15 +82,30 @@
     private static final boolean DISABLE_PERIODIC_PARTITION_DISCOVERY = false;
     private static final boolean INCLUDE_DYNAMIC_TOPIC = true;
     private static final boolean EXCLUDE_DYNAMIC_TOPIC = false;
+    private static KafkaSourceEnumerator.PartitionOffsetsRetrieverImpl retriever;
+    private static final Map<TopicPartition, Long> specificOffsets = new HashMap<>();
 
-    @BeforeClass
+    @BeforeAll
     public static void setup() throws Throwable {
         KafkaSourceTestEnv.setup();
+        retriever =
+                new KafkaSourceEnumerator.PartitionOffsetsRetrieverImpl(
+                        KafkaSourceTestEnv.getAdminClient(), KafkaSourceTestEnv.GROUP_ID);
         KafkaSourceTestEnv.setupTopic(TOPIC1, true, true, KafkaSourceTestEnv::getRecordsForTopic);
         KafkaSourceTestEnv.setupTopic(TOPIC2, true, true, KafkaSourceTestEnv::getRecordsForTopic);
+
+        for (Map.Entry<TopicPartition, Long> partitionEnd :
+                retriever
+                        .endOffsets(KafkaSourceTestEnv.getPartitionsForTopics(PRE_EXISTING_TOPICS))
+                        .entrySet()) {
+            specificOffsets.put(
+                    partitionEnd.getKey(),
+                    partitionEnd.getValue() / (partitionEnd.getKey().partition() + 1));
+        }
+        assertThat(specificOffsets).hasSize(2 * KafkaSourceTestEnv.NUM_PARTITIONS);
     }
 
-    @AfterClass
+    @AfterAll
     public static void tearDown() throws Exception {
         KafkaSourceTestEnv.tearDown();
     }
@@ -249,7 +272,8 @@
         }
     }
 
-    @Test(timeout = 30000L)
+    @Test
+    @Timeout(value = 30, unit = TimeUnit.SECONDS)
     public void testDiscoverPartitionsPeriodically() throws Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
@@ -261,7 +285,7 @@
                                 OffsetsInitializer.latest());
                 AdminClient adminClient = KafkaSourceTestEnv.getAdminClient()) {
 
-            startEnumeratorAndRegisterReaders(context, enumerator);
+            startEnumeratorAndRegisterReaders(context, enumerator, OffsetsInitializer.latest());
 
             // invoke partition discovery callable again and there should be no new assignments.
             runPeriodicPartitionDiscovery(context);
@@ -289,11 +313,13 @@
                     break;
                 }
             }
+            // later elements are initialized with EARLIEST
             verifyLastReadersAssignments(
                     context,
                     Arrays.asList(READER0, READER1),
                     Collections.singleton(DYNAMIC_TOPIC_NAME),
-                    3);
+                    3,
+                    OffsetsInitializer.earliest());
 
             // new partitions use EARLIEST_OFFSET, while initial partitions use LATEST_OFFSET
             List<KafkaPartitionSplit> initialPartitionAssign =
@@ -316,39 +342,123 @@
         }
     }
 
+    /**
+     * Ensures that migrated splits are immediately initialized with {@link OffsetsInitializer},
+     * such that an early {@link
+     * org.apache.flink.api.connector.source.SplitEnumerator#snapshotState(long)} doesn't see the
+     * special value.
+     */
     @Test
-    public void testAddSplitsBack() throws Throwable {
+    public void shouldEagerlyInitializeSplitOffsetsOnMigration() throws Throwable {
+        final TopicPartition assigned1 = new TopicPartition(TOPIC1, 0);
+        final TopicPartition assigned2 = new TopicPartition(TOPIC1, 1);
+        final TopicPartition unassigned1 = new TopicPartition(TOPIC2, 0);
+        final TopicPartition unassigned2 = new TopicPartition(TOPIC2, 1);
+
+        final long migratedOffset1 = 11L;
+        final long migratedOffset2 = 22L;
+        final OffsetsInitializer offsetsInitializer =
+                new OffsetsInitializer() {
+                    @Override
+                    public Map<TopicPartition, Long> getPartitionOffsets(
+                            Collection<TopicPartition> partitions,
+                            PartitionOffsetsRetriever partitionOffsetsRetriever) {
+                        return Map.of(assigned1, migratedOffset1, unassigned2, migratedOffset2);
+                    }
+
+                    @Override
+                    public OffsetResetStrategy getAutoOffsetResetStrategy() {
+                        return null;
+                    }
+                };
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
                 KafkaSourceEnumerator enumerator =
-                        createEnumerator(context, ENABLE_PERIODIC_PARTITION_DISCOVERY)) {
+                        createEnumerator(
+                                context,
+                                offsetsInitializer,
+                                PRE_EXISTING_TOPICS,
+                                List.of(
+                                        new KafkaPartitionSplit(assigned1, MIGRATED),
+                                        new KafkaPartitionSplit(assigned2, 2)),
+                                List.of(
+                                        new KafkaPartitionSplit(unassigned1, 1),
+                                        new KafkaPartitionSplit(unassigned2, MIGRATED)),
+                                false,
+                                new Properties())) {
+            final KafkaSourceEnumState state = enumerator.snapshotState(1L);
+            assertThat(state.assignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(assigned1, migratedOffset1),
+                            new KafkaPartitionSplit(assigned2, 2));
+            assertThat(state.unassignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(unassigned1, 1),
+                            new KafkaPartitionSplit(unassigned2, migratedOffset2));
+        }
+    }
 
-            startEnumeratorAndRegisterReaders(context, enumerator);
+    @ParameterizedTest
+    @EnumSource(StandardOffsetsInitializer.class)
+    public void testAddSplitsBack(StandardOffsetsInitializer offsetsInitializer) throws Throwable {
+        try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
+                        new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
+                KafkaSourceEnumerator enumerator =
+                        createEnumerator(
+                                context,
+                                ENABLE_PERIODIC_PARTITION_DISCOVERY,
+                                true,
+                                offsetsInitializer.getOffsetsInitializer())) {
+
+            startEnumeratorAndRegisterReaders(
+                    context, enumerator, offsetsInitializer.getOffsetsInitializer());
+
+            // READER2 not yet assigned
+            final Set<KafkaPartitionSplit> unassignedSplits =
+                    enumerator.getPendingPartitionSplitAssignment().get(READER2);
+            assertThat(enumerator.snapshotState(1L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(unassignedSplits);
 
             // Simulate a reader failure.
             context.unregisterReader(READER0);
-            enumerator.addSplitsBack(
-                    context.getSplitsAssignmentSequence().get(0).assignment().get(READER0),
-                    READER0);
+            final List<KafkaPartitionSplit> assignedSplits =
+                    context.getSplitsAssignmentSequence().get(0).assignment().get(READER0);
+            final List<KafkaPartitionSplit> advancedSplits =
+                    assignedSplits.stream()
+                            .map(
+                                    split ->
+                                            new KafkaPartitionSplit(
+                                                    split.getTopicPartition(),
+                                                    split.getStartingOffset() + 1))
+                            .collect(Collectors.toList());
+            enumerator.addSplitsBack(advancedSplits, READER0);
             assertThat(context.getSplitsAssignmentSequence())
                     .as("The added back splits should have not been assigned")
                     .hasSize(2);
 
+            assertThat(enumerator.snapshotState(2L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(
+                            Iterables.concat(
+                                    advancedSplits, unassignedSplits)); // READER0 + READER2
+
             // Simulate a reader recovery.
             registerReader(context, enumerator, READER0);
-            verifyLastReadersAssignments(
-                    context, Collections.singleton(READER0), PRE_EXISTING_TOPICS, 3);
+            verifyAssignments(
+                    Map.of(READER0, advancedSplits),
+                    context.getSplitsAssignmentSequence().get(2).assignment());
+            assertThat(enumerator.snapshotState(3L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(unassignedSplits);
         }
     }
 
     @Test
     public void testWorkWithPreexistingAssignments() throws Throwable {
-        Set<TopicPartition> preexistingAssignments;
+        Collection<KafkaPartitionSplit> preexistingAssignments;
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context1 =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
                 KafkaSourceEnumerator enumerator =
                         createEnumerator(context1, ENABLE_PERIODIC_PARTITION_DISCOVERY)) {
-            startEnumeratorAndRegisterReaders(context1, enumerator);
+            startEnumeratorAndRegisterReaders(context1, enumerator, OffsetsInitializer.earliest());
             preexistingAssignments =
                     asEnumState(context1.getSplitsAssignmentSequence().get(0).assignment());
         }
@@ -409,56 +519,50 @@
         }
     }
 
-    @Test
-    public void testSnapshotState() throws Throwable {
+    @ParameterizedTest
+    @EnumSource(StandardOffsetsInitializer.class)
+    public void testSnapshotState(StandardOffsetsInitializer offsetsInitializer) throws Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
-                KafkaSourceEnumerator enumerator = createEnumerator(context, false)) {
+                KafkaSourceEnumerator enumerator =
+                        createEnumerator(
+                                context, false, true, offsetsInitializer.getOffsetsInitializer())) {
             enumerator.start();
 
             // Step1: Before first discovery, so the state should be empty
             final KafkaSourceEnumState state1 = enumerator.snapshotState(1L);
-            assertThat(state1.assignedPartitions()).isEmpty();
-            assertThat(state1.unassignedInitialPartitions()).isEmpty();
+            assertThat(state1.assignedSplits()).isEmpty();
+            assertThat(state1.unassignedSplits()).isEmpty();
             assertThat(state1.initialDiscoveryFinished()).isFalse();
 
             registerReader(context, enumerator, READER0);
             registerReader(context, enumerator, READER1);
 
-            // Step2: First partition discovery after start, but no assignments to readers
-            context.runNextOneTimeCallable();
-            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
-            assertThat(state2.assignedPartitions()).isEmpty();
-            assertThat(state2.unassignedInitialPartitions()).isNotEmpty();
-            assertThat(state2.initialDiscoveryFinished()).isTrue();
-
-            // Step3: Assign partials partitions to reader0 and reader1
+            // Step2: Assign partials partitions to reader0 and reader1
             context.runNextOneTimeCallable();
 
             // The state should contain splits assigned to READER0 and READER1, but no READER2
             // register.
             // Thus, both assignedPartitions and unassignedInitialPartitions are not empty.
-            final KafkaSourceEnumState state3 = enumerator.snapshotState(3L);
+            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
             verifySplitAssignmentWithPartitions(
                     getExpectedAssignments(
-                            new HashSet<>(Arrays.asList(READER0, READER1)), PRE_EXISTING_TOPICS),
-                    state3.assignedPartitions());
-            assertThat(state3.unassignedInitialPartitions()).isNotEmpty();
-            assertThat(state3.initialDiscoveryFinished()).isTrue();
-            // total partitions of state2 and state3  are equal
-            // state2 only includes unassignedInitialPartitions
-            // state3 includes unassignedInitialPartitions + assignedPartitions
-            Set<TopicPartition> allPartitionOfState3 = new HashSet<>();
-            allPartitionOfState3.addAll(state3.unassignedInitialPartitions());
-            allPartitionOfState3.addAll(state3.assignedPartitions());
-            assertThat(state2.unassignedInitialPartitions()).isEqualTo(allPartitionOfState3);
+                            new HashSet<>(Arrays.asList(READER0, READER1)),
+                            PRE_EXISTING_TOPICS,
+                            offsetsInitializer.getOffsetsInitializer()),
+                    state2.assignedSplits());
+            assertThat(state2.assignedSplits()).isNotEmpty();
+            assertThat(state2.unassignedSplits()).isNotEmpty();
+            assertThat(state2.initialDiscoveryFinished()).isTrue();
 
-            // Step4: register READER2, then all partitions are assigned
+            // Step3: register READER2, then all partitions are assigned
             registerReader(context, enumerator, READER2);
-            final KafkaSourceEnumState state4 = enumerator.snapshotState(4L);
-            assertThat(state4.assignedPartitions()).isEqualTo(allPartitionOfState3);
-            assertThat(state4.unassignedInitialPartitions()).isEmpty();
-            assertThat(state4.initialDiscoveryFinished()).isTrue();
+            final KafkaSourceEnumState state3 = enumerator.snapshotState(3L);
+            assertThat(state3.assignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(
+                            Iterables.concat(state2.assignedSplits(), state2.unassignedSplits()));
+            assertThat(state3.unassignedSplits()).isEmpty();
+            assertThat(state3.initialDiscoveryFinished()).isTrue();
         }
     }
 
@@ -530,7 +634,8 @@
 
     private void startEnumeratorAndRegisterReaders(
             MockSplitEnumeratorContext<KafkaPartitionSplit> context,
-            KafkaSourceEnumerator enumerator)
+            KafkaSourceEnumerator enumerator,
+            OffsetsInitializer offsetsInitializer)
             throws Throwable {
         // Start the enumerator and it should schedule a one time task to discover and assign
         // partitions.
@@ -543,12 +648,20 @@
         // Run the partition discover callable and check the partition assignment.
         runPeriodicPartitionDiscovery(context);
         verifyLastReadersAssignments(
-                context, Collections.singleton(READER0), PRE_EXISTING_TOPICS, 1);
+                context,
+                Collections.singleton(READER0),
+                PRE_EXISTING_TOPICS,
+                1,
+                offsetsInitializer);
 
         // Register reader 1 after first partition discovery.
         registerReader(context, enumerator, READER1);
         verifyLastReadersAssignments(
-                context, Collections.singleton(READER1), PRE_EXISTING_TOPICS, 2);
+                context,
+                Collections.singleton(READER1),
+                PRE_EXISTING_TOPICS,
+                2,
+                offsetsInitializer);
     }
 
     // ----------------------------------------
@@ -619,8 +732,8 @@
             MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
             OffsetsInitializer startingOffsetsInitializer,
             Collection<String> topicsToSubscribe,
-            Set<TopicPartition> assignedPartitions,
-            Set<TopicPartition> unassignedInitialPartitions,
+            Collection<KafkaPartitionSplit> assignedSplits,
+            Collection<KafkaPartitionSplit> unassignedInitialSplits,
             boolean initialDiscoveryFinished,
             Properties overrideProperties) {
         // Use a TopicPatternSubscriber so that no exception if a subscribed topic hasn't been
@@ -644,11 +757,29 @@
                 enumContext,
                 Boundedness.CONTINUOUS_UNBOUNDED,
                 new KafkaSourceEnumState(
-                        assignedPartitions, unassignedInitialPartitions, initialDiscoveryFinished));
+                        assignedSplits, unassignedInitialSplits, initialDiscoveryFinished));
     }
 
     // ---------------------
 
+    /** The standard {@link OffsetsInitializer}s used for parameterized tests. */
+    enum StandardOffsetsInitializer {
+        EARLIEST_OFFSETS(OffsetsInitializer.earliest()),
+        LATEST_OFFSETS(OffsetsInitializer.latest()),
+        SPECIFIC_OFFSETS(OffsetsInitializer.offsets(specificOffsets, OffsetResetStrategy.NONE)),
+        COMMITTED_OFFSETS(OffsetsInitializer.committedOffsets());
+
+        private final OffsetsInitializer offsetsInitializer;
+
+        StandardOffsetsInitializer(OffsetsInitializer offsetsInitializer) {
+            this.offsetsInitializer = offsetsInitializer;
+        }
+
+        public OffsetsInitializer getOffsetsInitializer() {
+            return offsetsInitializer;
+        }
+    }
+
     private void registerReader(
             MockSplitEnumeratorContext<KafkaPartitionSplit> context,
             KafkaSourceEnumerator enumerator,
@@ -662,63 +793,84 @@
             Collection<Integer> readers,
             Set<String> topics,
             int expectedAssignmentSeqSize) {
+        verifyLastReadersAssignments(
+                context, readers, topics, expectedAssignmentSeqSize, OffsetsInitializer.earliest());
+    }
+
+    private void verifyLastReadersAssignments(
+            MockSplitEnumeratorContext<KafkaPartitionSplit> context,
+            Collection<Integer> readers,
+            Set<String> topics,
+            int expectedAssignmentSeqSize,
+            OffsetsInitializer offsetsInitializer) {
         verifyAssignments(
-                getExpectedAssignments(new HashSet<>(readers), topics),
+                getExpectedAssignments(new HashSet<>(readers), topics, offsetsInitializer),
                 context.getSplitsAssignmentSequence()
                         .get(expectedAssignmentSeqSize - 1)
                         .assignment());
     }
 
     private void verifyAssignments(
-            Map<Integer, Set<TopicPartition>> expectedAssignments,
+            Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignments,
             Map<Integer, List<KafkaPartitionSplit>> actualAssignments) {
-        actualAssignments.forEach(
-                (reader, splits) -> {
-                    Set<TopicPartition> expectedAssignmentsForReader =
-                            expectedAssignments.get(reader);
-                    assertThat(expectedAssignmentsForReader).isNotNull();
-                    assertThat(splits.size()).isEqualTo(expectedAssignmentsForReader.size());
-                    for (KafkaPartitionSplit split : splits) {
-                        assertThat(expectedAssignmentsForReader)
-                                .contains(split.getTopicPartition());
+        assertThat(actualAssignments).containsOnlyKeys(expectedAssignments.keySet());
+        SoftAssertions.assertSoftly(
+                softly -> {
+                    for (Map.Entry<Integer, List<KafkaPartitionSplit>> actual :
+                            actualAssignments.entrySet()) {
+                        softly.assertThat(actual.getValue())
+                                .as("Assignment for reader %s", actual.getKey())
+                                .containsExactlyInAnyOrderElementsOf(
+                                        expectedAssignments.get(actual.getKey()));
                     }
                 });
     }
 
-    private Map<Integer, Set<TopicPartition>> getExpectedAssignments(
-            Set<Integer> readers, Set<String> topics) {
-        Map<Integer, Set<TopicPartition>> expectedAssignments = new HashMap<>();
-        Set<TopicPartition> allPartitions = new HashSet<>();
+    private Map<Integer, Collection<KafkaPartitionSplit>> getExpectedAssignments(
+            Set<Integer> readers,
+            Set<String> topics,
+            OffsetsInitializer startingOffsetsInitializer) {
+        Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignments = new HashMap<>();
+        Set<KafkaPartitionSplit> allPartitions = new HashSet<>();
 
         if (topics.contains(DYNAMIC_TOPIC_NAME)) {
             for (int i = 0; i < NUM_PARTITIONS_DYNAMIC_TOPIC; i++) {
-                allPartitions.add(new TopicPartition(DYNAMIC_TOPIC_NAME, i));
+                TopicPartition tp = new TopicPartition(DYNAMIC_TOPIC_NAME, i);
+                allPartitions.add(createSplit(tp, startingOffsetsInitializer));
             }
         }
 
         for (TopicPartition tp : KafkaSourceTestEnv.getPartitionsForTopics(PRE_EXISTING_TOPICS)) {
             if (topics.contains(tp.topic())) {
-                allPartitions.add(tp);
+                allPartitions.add(createSplit(tp, startingOffsetsInitializer));
             }
         }
 
-        for (TopicPartition tp : allPartitions) {
-            int ownerReader = KafkaSourceEnumerator.getSplitOwner(tp, NUM_SUBTASKS);
+        for (KafkaPartitionSplit split : allPartitions) {
+            int ownerReader =
+                    KafkaSourceEnumerator.getSplitOwner(split.getTopicPartition(), NUM_SUBTASKS);
             if (readers.contains(ownerReader)) {
-                expectedAssignments.computeIfAbsent(ownerReader, r -> new HashSet<>()).add(tp);
+                expectedAssignments.computeIfAbsent(ownerReader, r -> new HashSet<>()).add(split);
             }
         }
         return expectedAssignments;
     }
 
+    private static KafkaPartitionSplit createSplit(
+            TopicPartition tp, OffsetsInitializer startingOffsetsInitializer) {
+        return new KafkaPartitionSplit(
+                tp, startingOffsetsInitializer.getPartitionOffsets(List.of(tp), retriever).get(tp));
+    }
+
     private void verifySplitAssignmentWithPartitions(
-            Map<Integer, Set<TopicPartition>> expectedAssignment,
-            Set<TopicPartition> actualTopicPartitions) {
-        final Set<TopicPartition> allTopicPartitionsFromAssignment = new HashSet<>();
-        expectedAssignment.forEach(
-                (reader, topicPartitions) ->
-                        allTopicPartitionsFromAssignment.addAll(topicPartitions));
-        assertThat(actualTopicPartitions).isEqualTo(allTopicPartitionsFromAssignment);
+            Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignment,
+            Collection<KafkaPartitionSplit> actualTopicPartitions) {
+        final Set<KafkaPartitionSplit> allTopicPartitionsFromAssignment =
+                expectedAssignment.values().stream()
+                        .flatMap(Collection::stream)
+                        .collect(Collectors.toSet());
+        assertThat(actualTopicPartitions)
+                .containsExactlyInAnyOrderElementsOf(allTopicPartitionsFromAssignment);
     }
 
     /** get all assigned partition splits of topics. */
@@ -740,12 +892,11 @@
         return allSplits;
     }
 
-    private Set<TopicPartition> asEnumState(Map<Integer, List<KafkaPartitionSplit>> assignments) {
-        Set<TopicPartition> enumState = new HashSet<>();
-        assignments.forEach(
-                (reader, assignment) ->
-                        assignment.forEach(split -> enumState.add(split.getTopicPartition())));
-        return enumState;
+    private Collection<KafkaPartitionSplit> asEnumState(
+            Map<Integer, List<KafkaPartitionSplit>> assignments) {
+        return assignments.values().stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toList());
     }
 
     private void runOneTimePartitionDiscovery(