KAFKA-16526; Quorum state data version 1 (#15859)

Allow KRaft replicas to read and write version 0 and 1 of the quorum-state file. Which version is written is controlled by the kraft.version. With kraft.version 0, version 0 of the quorum-state file is written. With kraft.version 1, version 1 of the quorum-state file is written. Version 1 of the quorum-state file adds the VotedDirectoryId field and removes the CurrentVoters. The other fields removed in version 1 are not important as they were not overwritten or used by KRaft.

In kraft.version 1 the set of voters will be stored in the kraft partition log segments and snapshots.

To implement this feature the following changes were made to KRaft.

FileBasedStateStore was renamed to FileQuorumStateStore to better match the name of the implemented interface QuorumStateStore.

The QuorumStateStore::writeElectionState was extended to include the kraft.version. This version is used to determine which version of QuorumStateData to store. When writing version 0 the VotedDirectoryId is not persisted but the latest value is kept in-memory. This allows the replica to vote consistently while they stay online. If a replica restarts in the middle of an election it will forget the VotedDirectoryId if the kraft.version is 0. This should be rare in practice and should only happen if there is an election and failure while the system is upgrading to kraft.version 1.

The type ElectionState, the interface EpochState and all of the implementations of EpochState (VotedState, UnattachedState, FollowerState, ResignedState, CandidateState and LeaderState) are extended to support the new voted directory id.

The type QuorumState is changed so that local directory id is used. The type is also changed so that the latest value for the set of voters and the kraft version is query from the KRaftControlRecordStateMachine.

The replica directory id is read from the meta.properties and passed to the KafkaRaftClient. The replica directory id is guaranteed to be set in the local replica.

Adds a new metric for current-vote-directory-id which exposes the latest in-memory value of the voted directory id.

Renames VoterSet.VoterKey to ReplicaKey.

It is important to note that after this change, version 1 of the quorum-state file will not be written by kraft controllers and brokers. This change adds support reading and writing version 1 of the file in preparation for future changes.

Reviewers: Jun Rao <junrao@apache.org>
diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala
index 0430952..f44197a 100644
--- a/core/src/main/scala/kafka/raft/RaftManager.scala
+++ b/core/src/main/scala/kafka/raft/RaftManager.scala
@@ -41,7 +41,7 @@
 import org.apache.kafka.common.security.JaasContext
 import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.kafka.common.utils.{LogContext, Time, Utils}
-import org.apache.kafka.raft.{FileBasedStateStore, KafkaNetworkChannel, KafkaRaftClient, KafkaRaftClientDriver, LeaderAndEpoch, RaftClient, QuorumConfig, ReplicatedLog}
+import org.apache.kafka.raft.{FileQuorumStateStore, KafkaNetworkChannel, KafkaRaftClient, KafkaRaftClientDriver, LeaderAndEpoch, RaftClient, QuorumConfig, ReplicatedLog}
 import org.apache.kafka.server.ProcessRole
 import org.apache.kafka.server.common.serialization.RecordSerde
 import org.apache.kafka.server.util.KafkaScheduler
@@ -138,6 +138,7 @@
 class KafkaRaftManager[T](
   clusterId: String,
   config: KafkaConfig,
+  metadataLogDirUuid: Uuid,
   recordSerde: RecordSerde[T],
   topicPartition: TopicPartition,
   topicId: Uuid,
@@ -184,7 +185,7 @@
     client.initialize(
       controllerQuorumVotersFuture.get(),
       config.controllerListenerNames.head,
-      new FileBasedStateStore(new File(dataDir, FileBasedStateStore.DEFAULT_FILE_NAME)),
+      new FileQuorumStateStore(new File(dataDir, FileQuorumStateStore.DEFAULT_FILE_NAME)),
       metrics
     )
     netChannel.start()
@@ -218,6 +219,7 @@
   private def buildRaftClient(): KafkaRaftClient[T] = {
     val client = new KafkaRaftClient(
       OptionalInt.of(config.nodeId),
+      metadataLogDirUuid,
       recordSerde,
       netChannel,
       replicatedLog,
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index 0f54121..8dc35ff 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -236,6 +236,9 @@
         val initialMetaPropsEnsemble = {
           val loader = new MetaPropertiesEnsemble.Loader()
           config.logDirs.foreach(loader.addLogDir)
+          if (config.migrationEnabled) {
+            loader.addMetadataLogDir(config.metadataLogDir)
+          }
           loader.load()
         }
 
@@ -432,6 +435,8 @@
           raftManager = new KafkaRaftManager[ApiMessageAndVersion](
             metaPropsEnsemble.clusterId().get(),
             config,
+            // metadata log dir and directory.id must exist because migration is enabled
+            metaPropsEnsemble.logDirProps.get(metaPropsEnsemble.metadataLogDir.get).directoryId.get,
             new MetadataRecordSerde,
             KafkaRaftServer.MetadataPartition,
             KafkaRaftServer.MetadataTopicId,
diff --git a/core/src/main/scala/kafka/server/SharedServer.scala b/core/src/main/scala/kafka/server/SharedServer.scala
index e2a5e33..215208f 100644
--- a/core/src/main/scala/kafka/server/SharedServer.scala
+++ b/core/src/main/scala/kafka/server/SharedServer.scala
@@ -257,6 +257,7 @@
         val _raftManager = new KafkaRaftManager[ApiMessageAndVersion](
           clusterId,
           sharedServerConfig,
+          metaPropsEnsemble.logDirProps.get(metaPropsEnsemble.metadataLogDir.get).directoryId.get,
           new MetadataRecordSerde,
           KafkaRaftServer.MetadataPartition,
           KafkaRaftServer.MetadataTopicId,
diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala
index 1be874a..5060e25 100644
--- a/core/src/main/scala/kafka/tools/TestRaftServer.scala
+++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala
@@ -52,6 +52,7 @@
  */
 class TestRaftServer(
   val config: KafkaConfig,
+  val nodeDirectoryId: Uuid,
   val throughput: Int,
   val recordSize: Int
 ) extends Logging {
@@ -86,6 +87,7 @@
     raftManager = new KafkaRaftManager[Array[Byte]](
       Uuid.ZERO_UUID.toString,
       config,
+      nodeDirectoryId,
       new ByteArraySerde,
       partition,
       topicId,
@@ -431,6 +433,11 @@
       .ofType(classOf[Int])
       .defaultsTo(256)
 
+    val directoryId: OptionSpec[String] = parser.accepts("replica-directory-id", "The directory id of the replica")
+      .withRequiredArg
+      .describedAs("directory id")
+      .ofType(classOf[String])
+
     options = parser.parse(args : _*)
   }
 
@@ -444,6 +451,11 @@
       if (configFile == null) {
         throw new InvalidConfigurationException("Missing configuration file. Should specify with '--config'")
       }
+
+      val directoryIdAsString = opts.options.valueOf(opts.directoryId)
+      if (directoryIdAsString == null) {
+        throw new InvalidConfigurationException("Missing replica directory id. Should specify with --replica-directory-id")
+      }
       val serverProps = Utils.loadProps(configFile)
 
       // KafkaConfig requires either `process.roles` or `zookeeper.connect`. Neither are
@@ -453,7 +465,7 @@
       val config = KafkaConfig.fromProps(serverProps, doLog = false)
       val throughput = opts.options.valueOf(opts.throughputOpt)
       val recordSize = opts.options.valueOf(opts.recordSizeOpt)
-      val server = new TestRaftServer(config, throughput, recordSize)
+      val server = new TestRaftServer(config, Uuid.fromString(directoryIdAsString), throughput, recordSize)
 
       Exit.addShutdownHook("raft-shutdown-hook", server.shutdown())
 
diff --git a/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala
index a75abd5..e6153bd 100644
--- a/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala
@@ -110,6 +110,7 @@
     new KafkaRaftManager[Array[Byte]](
       Uuid.randomUuid.toString,
       config,
+      Uuid.randomUuid,
       new ByteArraySerde,
       topicPartition,
       topicId,
diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
index 947f0ff..d433afd 100644
--- a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
+++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
@@ -337,7 +337,7 @@
         .setLastContainedLogTimestamp(lastContainedLogTimestamp)
         .setRawSnapshotWriter(metadataLog.createNewSnapshot(new OffsetAndEpoch(0, 0)).get)
         .setKraftVersion(1)
-        .setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)))))
+        .setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))))
         .build(MetadataRecordSerde.INSTANCE)
     ) { snapshotWriter =>
       snapshotWriter.append(metadataRecords.asJava)
diff --git a/raft/src/main/java/org/apache/kafka/raft/CandidateState.java b/raft/src/main/java/org/apache/kafka/raft/CandidateState.java
index e4c68db..9e08b3d 100644
--- a/raft/src/main/java/org/apache/kafka/raft/CandidateState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/CandidateState.java
@@ -16,19 +16,22 @@
  */
 package org.apache.kafka.raft;
 
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.common.utils.Timer;
-import org.slf4j.Logger;
-
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Timer;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.apache.kafka.raft.internals.VoterSet;
+import org.slf4j.Logger;
 
 public class CandidateState implements EpochState {
     private final int localId;
+    private final Uuid localDirectoryId;
     private final int epoch;
     private final int retries;
     private final Map<Integer, State> voteStates = new HashMap<>();
@@ -39,7 +42,7 @@
     private final Logger log;
 
     /**
-     * The lifetime of a candidate state is the following:
+     * The lifetime of a candidate state is the following.
      *
      *  1. Once started, it would keep record of the received votes.
      *  2. If majority votes granted, it can then end its life and will be replaced by a leader state;
@@ -51,14 +54,27 @@
     protected CandidateState(
         Time time,
         int localId,
+        Uuid localDirectoryId,
         int epoch,
-        Set<Integer> voters,
+        VoterSet voters,
         Optional<LogOffsetMetadata> highWatermark,
         int retries,
         int electionTimeoutMs,
         LogContext logContext
     ) {
+        if (!voters.isVoter(ReplicaKey.of(localId, Optional.of(localDirectoryId)))) {
+            throw new IllegalArgumentException(
+                String.format(
+                    "Local replica (%d, %s) must be in the set of voters %s",
+                    localId,
+                    localDirectoryId,
+                    voters
+                )
+            );
+        }
+
         this.localId = localId;
+        this.localDirectoryId = localDirectoryId;
         this.epoch = epoch;
         this.highWatermark = highWatermark;
         this.retries = retries;
@@ -68,7 +84,7 @@
         this.backoffTimer = time.timer(0);
         this.log = logContext.logger(CandidateState.class);
 
-        for (Integer voterId : voters) {
+        for (Integer voterId : voters.voterIds()) {
             voteStates.put(voterId, State.UNRECORDED);
         }
         voteStates.put(localId, State.GRANTED);
@@ -227,7 +243,11 @@
 
     @Override
     public ElectionState election() {
-        return ElectionState.withVotedCandidate(epoch, localId, voteStates.keySet());
+        return ElectionState.withVotedCandidate(
+            epoch,
+            ReplicaKey.of(localId, Optional.of(localDirectoryId)),
+            voteStates.keySet()
+        );
     }
 
     @Override
@@ -241,24 +261,33 @@
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
+    public boolean canGrantVote(
+        ReplicaKey candidateKey,
+        boolean isLogUpToDate
+    ) {
         // Still reject vote request even candidateId = localId, Although the candidate votes for
         // itself, this vote is implicit and not "granted".
-        log.debug("Rejecting vote request from candidate {} since we are already candidate in epoch {}",
-            candidateId, epoch);
+        log.debug(
+            "Rejecting vote request from candidate ({}) since we are already candidate in epoch {}",
+            candidateKey,
+            epoch
+        );
         return false;
     }
 
     @Override
     public String toString() {
-        return "CandidateState(" +
-            "localId=" + localId +
-            ", epoch=" + epoch +
-            ", retries=" + retries +
-            ", voteStates=" + voteStates +
-            ", highWatermark=" + highWatermark +
-            ", electionTimeoutMs=" + electionTimeoutMs +
-            ')';
+        return String.format(
+            "CandidateState(localId=%d, localDirectoryId=%s,epoch=%d, retries=%d, voteStates=%s, " +
+            "highWatermark=%s, electionTimeoutMs=%d)",
+            localId,
+            localDirectoryId,
+            epoch,
+            retries,
+            voteStates,
+            highWatermark,
+            electionTimeoutMs
+        );
     }
 
     @Override
diff --git a/raft/src/main/java/org/apache/kafka/raft/ElectionState.java b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java
index 43db9c0..005ff23 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ElectionState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java
@@ -16,46 +16,44 @@
  */
 package org.apache.kafka.raft;
 
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.OptionalInt;
 import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.raft.generated.QuorumStateData;
+import org.apache.kafka.raft.internals.ReplicaKey;
 
 /**
  * Encapsulate election state stored on disk after every state change.
  */
-public class ElectionState {
-    public final int epoch;
-    public final OptionalInt leaderIdOpt;
-    public final OptionalInt votedIdOpt;
+final public class ElectionState {
+    private static int unknownLeaderId = -1;
+    private static int notVoted = -1;
+    private static Uuid noVotedDirectoryId = Uuid.ZERO_UUID;
+
+    private final int epoch;
+    private final OptionalInt leaderId;
+    private final Optional<ReplicaKey> votedKey;
+    // This is deprecated. It is only used when writing version 0 of the quorum state file
     private final Set<Integer> voters;
 
-    ElectionState(int epoch,
-                  OptionalInt leaderIdOpt,
-                  OptionalInt votedIdOpt,
-                  Set<Integer> voters) {
+    ElectionState(
+        int epoch,
+        OptionalInt leaderId,
+        Optional<ReplicaKey> votedKey,
+        Set<Integer> voters
+    ) {
         this.epoch = epoch;
-        this.leaderIdOpt = leaderIdOpt;
-        this.votedIdOpt = votedIdOpt;
+        this.leaderId = leaderId;
+        this.votedKey = votedKey;
         this.voters = voters;
     }
 
-    public static ElectionState withVotedCandidate(int epoch, int votedId, Set<Integer> voters) {
-        if (votedId < 0)
-            throw new IllegalArgumentException("Illegal voted Id " + votedId + ": must be non-negative");
-        if (!voters.contains(votedId))
-            throw new IllegalArgumentException("Voted candidate with id " + votedId + " is not among the valid voters");
-        return new ElectionState(epoch, OptionalInt.empty(), OptionalInt.of(votedId), voters);
-    }
-
-    public static ElectionState withElectedLeader(int epoch, int leaderId, Set<Integer> voters) {
-        if (leaderId < 0)
-            throw new IllegalArgumentException("Illegal leader Id " + leaderId + ": must be non-negative");
-        if (!voters.contains(leaderId))
-            throw new IllegalArgumentException("Leader with id " + leaderId + " is not among the valid voters");
-        return new ElectionState(epoch, OptionalInt.of(leaderId), OptionalInt.empty(), voters);
-    }
-
-    public static ElectionState withUnknownLeader(int epoch, Set<Integer> voters) {
-        return new ElectionState(epoch, OptionalInt.empty(), OptionalInt.empty(), voters);
+    public int epoch() {
+        return epoch;
     }
 
     public boolean isLeader(int nodeId) {
@@ -64,47 +62,100 @@
         return leaderIdOrSentinel() == nodeId;
     }
 
-    public boolean isVotedCandidate(int nodeId) {
-        if (nodeId < 0)
-            throw new IllegalArgumentException("Invalid negative nodeId: " + nodeId);
-        return votedIdOpt.orElse(-1) == nodeId;
+    /**
+     * Return if the replica has voted for the given candidate.
+     *
+     * A replica has voted for a candidate if all of the following are true:
+     * 1. the node's id and voted id match and
+     * 2. if the voted directory id is set, it matches the node's directory id
+     *
+     * @param nodeKey the id and directory id of the replica
+     * @return true when the arguments match, otherwise false
+     */
+    public boolean isVotedCandidate(ReplicaKey nodeKey) {
+        if (nodeKey.id() < 0) {
+            throw new IllegalArgumentException("Invalid node key " + nodeKey);
+        } else if (!votedKey.isPresent()) {
+            return false;
+        } else if (votedKey.get().id() != nodeKey.id()) {
+            return false;
+        } else if (!votedKey.get().directoryId().isPresent()) {
+            // when the persisted voted directory id is not present assume that we voted for this candidate;
+            // this happens when the kraft version is 0.
+            return true;
+        }
+
+        return votedKey.get().directoryId().equals(nodeKey.directoryId());
     }
 
     public int leaderId() {
-        if (!leaderIdOpt.isPresent())
+        if (!leaderId.isPresent())
             throw new IllegalStateException("Attempt to access nil leaderId");
-        return leaderIdOpt.getAsInt();
-    }
-
-    public int votedId() {
-        if (!votedIdOpt.isPresent())
-            throw new IllegalStateException("Attempt to access nil votedId");
-        return votedIdOpt.getAsInt();
-    }
-
-    public Set<Integer> voters() {
-        return voters;
-    }
-
-    public boolean hasLeader() {
-        return leaderIdOpt.isPresent();
-    }
-
-    public boolean hasVoted() {
-        return votedIdOpt.isPresent();
+        return leaderId.getAsInt();
     }
 
     public int leaderIdOrSentinel() {
-        return leaderIdOpt.orElse(-1);
+        return leaderId.orElse(unknownLeaderId);
     }
 
+    public OptionalInt optionalLeaderId() {
+        return leaderId;
+    }
+
+    public ReplicaKey votedKey() {
+        if (!votedKey.isPresent()) {
+            throw new IllegalStateException("Attempt to access nil votedId");
+        }
+
+        return votedKey.get();
+    }
+
+    public Optional<ReplicaKey> optionalVotedKey() {
+        return votedKey;
+    }
+
+    public boolean hasLeader() {
+        return leaderId.isPresent();
+    }
+
+    public boolean hasVoted() {
+        return votedKey.isPresent();
+    }
+
+    public QuorumStateData toQuorumStateData(short version) {
+        QuorumStateData data = new QuorumStateData()
+            .setLeaderEpoch(epoch)
+            .setLeaderId(leaderIdOrSentinel())
+            .setVotedId(votedKey.map(ReplicaKey::id).orElse(notVoted));
+
+        if (version == 0) {
+            List<QuorumStateData.Voter> dataVoters = voters
+                .stream()
+                .map(voterId -> new QuorumStateData.Voter().setVoterId(voterId))
+                .collect(Collectors.toList());
+            data.setCurrentVoters(dataVoters);
+        } else if (version == 1) {
+            data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(noVotedDirectoryId));
+        } else {
+            throw new IllegalStateException(
+                String.format(
+                    "File quorum state store doesn't handle supported version %d", version
+                )
+            );
+        }
+
+        return data;
+    }
 
     @Override
     public String toString() {
-        return "Election(epoch=" + epoch +
-                ", leaderIdOpt=" + leaderIdOpt +
-                ", votedIdOpt=" + votedIdOpt +
-                ')';
+        return String.format(
+            "Election(epoch=%d, leaderId=%s, votedKey=%s, voters=%s)",
+            epoch,
+            leaderId,
+            votedKey,
+            voters
+        );
     }
 
     @Override
@@ -115,15 +166,51 @@
         ElectionState that = (ElectionState) o;
 
         if (epoch != that.epoch) return false;
-        if (!leaderIdOpt.equals(that.leaderIdOpt)) return false;
-        return votedIdOpt.equals(that.votedIdOpt);
+        if (!leaderId.equals(that.leaderId)) return false;
+        if (!votedKey.equals(that.votedKey)) return false;
+
+        return voters.equals(that.voters);
     }
 
     @Override
     public int hashCode() {
-        int result = epoch;
-        result = 31 * result + leaderIdOpt.hashCode();
-        result = 31 * result + votedIdOpt.hashCode();
-        return result;
+        return Objects.hash(epoch, leaderId, votedKey, voters);
+    }
+
+    public static ElectionState withVotedCandidate(int epoch, ReplicaKey votedKey, Set<Integer> voters) {
+        if (votedKey.id() < 0) {
+            throw new IllegalArgumentException("Illegal voted Id " + votedKey.id() + ": must be non-negative");
+        }
+
+        return new ElectionState(epoch, OptionalInt.empty(), Optional.of(votedKey), voters);
+    }
+
+    public static ElectionState withElectedLeader(int epoch, int leaderId, Set<Integer> voters) {
+        if (leaderId < 0) {
+            throw new IllegalArgumentException("Illegal leader Id " + leaderId + ": must be non-negative");
+        }
+
+        return new ElectionState(epoch, OptionalInt.of(leaderId), Optional.empty(), voters);
+    }
+
+    public static ElectionState withUnknownLeader(int epoch, Set<Integer> voters) {
+        return new ElectionState(epoch, OptionalInt.empty(), Optional.empty(), voters);
+    }
+
+    public static ElectionState fromQuorumStateData(QuorumStateData data) {
+        Optional<Uuid> votedDirectoryId = data.votedDirectoryId().equals(noVotedDirectoryId) ?
+            Optional.empty() :
+            Optional.of(data.votedDirectoryId());
+
+        Optional<ReplicaKey> votedKey = data.votedId() == notVoted ?
+            Optional.empty() :
+            Optional.of(ReplicaKey.of(data.votedId(), votedDirectoryId));
+
+        return new ElectionState(
+            data.leaderEpoch(),
+            data.leaderId() == unknownLeaderId ? OptionalInt.empty() : OptionalInt.of(data.leaderId()),
+            votedKey,
+            data.currentVoters().stream().map(QuorumStateData.Voter::voterId).collect(Collectors.toSet())
+        );
     }
 }
diff --git a/raft/src/main/java/org/apache/kafka/raft/EpochState.java b/raft/src/main/java/org/apache/kafka/raft/EpochState.java
index 9cf231c..c9ab157 100644
--- a/raft/src/main/java/org/apache/kafka/raft/EpochState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/EpochState.java
@@ -18,6 +18,7 @@
 
 import java.io.Closeable;
 import java.util.Optional;
+import org.apache.kafka.raft.internals.ReplicaKey;
 
 public interface EpochState extends Closeable {
 
@@ -26,15 +27,16 @@
     }
 
     /**
-     * Decide whether to grant a vote to a candidate, it is the responsibility of the caller to invoke
-     * {@link QuorumState#transitionToVoted(int, int)} if vote is granted.
+     * Decide whether to grant a vote to a candidate.
      *
-     * @param candidateId The ID of the voter who attempt to become leader
-     * @param isLogUpToDate Whether the candidate’s log is at least as up-to-date as receiver’s log, it
-     *                      is the responsibility of the caller to compare the log in advance
-     * @return true If grant vote.
+     * It is the responsibility of the caller to invoke
+     * {@link QuorumState#transitionToVoted(int, ReplicaKey)} if vote is granted.
+     *
+     * @param candidateKey the id and directory of the candidate
+     * @param isLogUpToDate whether the candidate’s log is at least as up-to-date as receiver’s log
+     * @return true if it can grant the vote, false otherwise
      */
-    boolean canGrantVote(int candidateId, boolean isLogUpToDate);
+    boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate);
 
     /**
      * Get the current election state, which is guaranteed to be immutable.
@@ -50,5 +52,4 @@
      * User-friendly description of the state
      */
     String name();
-
 }
diff --git a/raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java b/raft/src/main/java/org/apache/kafka/raft/FileQuorumStateStore.java
similarity index 68%
rename from raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java
rename to raft/src/main/java/org/apache/kafka/raft/FileQuorumStateStore.java
index a805d63..698a2ca 100644
--- a/raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java
+++ b/raft/src/main/java/org/apache/kafka/raft/FileQuorumStateStore.java
@@ -20,56 +20,65 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.fasterxml.jackson.databind.node.ShortNode;
-import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.generated.QuorumStateData;
-import org.apache.kafka.raft.generated.QuorumStateData.Voter;
 import org.apache.kafka.raft.generated.QuorumStateDataJsonConverter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.BufferedReader;
 import java.io.BufferedWriter;
-import java.io.UncheckedIOException;
 import java.io.EOFException;
 import java.io.File;
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.OutputStreamWriter;
+import java.io.UncheckedIOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.util.List;
-import java.util.OptionalInt;
-import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.Optional;
 
 /**
  * Local file based quorum state store. It takes the JSON format of {@link QuorumStateData}
- * with an extra data version number as part of the data for easy deserialization.
+ * with an extra data version number field (data_version) as part of the data.
  *
- * Example format:
+ * Example version 0 format:
  * <pre>
- * {"clusterId":"",
- *   "leaderId":1,
- *   "leaderEpoch":2,
- *   "votedId":-1,
- *   "appliedOffset":0,
- *   "currentVoters":[],
- *   "data_version":0}
+ * {
+ *   "clusterId": "",
+ *   "leaderId": 1,
+ *   "leaderEpoch": 2,
+ *   "votedId": -1,
+ *   "appliedOffset": 0,
+ *   "currentVoters": [],
+ *   "data_version": 0
+ * }
+ * </pre>
+ *
+ * Example version 1 format:
+ * <pre>
+ * {
+ *   "leaderId": -1,
+ *   "leaderEpoch": 2,
+ *   "votedId": 1,
+ *   "votedDirectoryId": "J8aAPcfLQt2bqs1JT_rMgQ",
+ *   "data_version": 1
+ * }
  * </pre>
  * */
-public class FileBasedStateStore implements QuorumStateStore {
-    private static final Logger log = LoggerFactory.getLogger(FileBasedStateStore.class);
+public class FileQuorumStateStore implements QuorumStateStore {
+    private static final Logger log = LoggerFactory.getLogger(FileQuorumStateStore.class);
     private static final String DATA_VERSION = "data_version";
 
-    static final short HIGHEST_SUPPORTED_VERSION = 0;
+    static final short LOWEST_SUPPORTED_VERSION = 0;
+    static final short HIGHEST_SUPPORTED_VERSION = 1;
 
     public static final String DEFAULT_FILE_NAME = "quorum-state";
 
     private final File stateFile;
 
-    public FileBasedStateStore(final File stateFile) {
+    public FileQuorumStateStore(final File stateFile) {
         this.stateFile = stateFile;
     }
 
@@ -95,11 +104,18 @@
                     " does not have " + DATA_VERSION + " field");
             }
 
-            if (dataVersionNode.asInt() != 0) {
-                throw new UnsupportedVersionException("Unknown data version of " + dataVersionNode);
+            final short dataVersion = dataVersionNode.shortValue();
+            if (dataVersion < LOWEST_SUPPORTED_VERSION || dataVersion > HIGHEST_SUPPORTED_VERSION) {
+                throw new IllegalStateException(
+                    String.format(
+                        "data_version (%d) is not within the min (%d) and max ($d) supported version",
+                        dataVersion,
+                        LOWEST_SUPPORTED_VERSION,
+                        HIGHEST_SUPPORTED_VERSION
+                    )
+                );
             }
 
-            final short dataVersion = dataVersionNode.shortValue();
             return QuorumStateDataJsonConverter.read(dataObject, dataVersion);
         } catch (IOException e) {
             throw new UncheckedIOException(
@@ -111,30 +127,23 @@
      * Reads the election state from local file.
      */
     @Override
-    public ElectionState readElectionState() {
+    public Optional<ElectionState> readElectionState() {
         if (!stateFile.exists()) {
-            return null;
+            return Optional.empty();
         }
 
-        QuorumStateData data = readStateFromFile(stateFile);
-
-        return new ElectionState(data.leaderEpoch(),
-            data.leaderId() == UNKNOWN_LEADER_ID ? OptionalInt.empty() :
-                OptionalInt.of(data.leaderId()),
-            data.votedId() == NOT_VOTED ? OptionalInt.empty() :
-                OptionalInt.of(data.votedId()),
-            data.currentVoters()
-                .stream().map(Voter::voterId).collect(Collectors.toSet()));
+        return Optional.of(ElectionState.fromQuorumStateData(readStateFromFile(stateFile)));
     }
 
     @Override
-    public void writeElectionState(ElectionState latest) {
-        QuorumStateData data = new QuorumStateData()
-            .setLeaderEpoch(latest.epoch)
-            .setVotedId(latest.hasVoted() ? latest.votedId() : NOT_VOTED)
-            .setLeaderId(latest.hasLeader() ? latest.leaderId() : UNKNOWN_LEADER_ID)
-            .setCurrentVoters(voters(latest.voters()));
-        writeElectionStateToFile(stateFile, data);
+    public void writeElectionState(ElectionState latest, short kraftVersion) {
+        short quorumStateVersion = quorumStateVersionFromKRaftVersion(kraftVersion);
+
+        writeElectionStateToFile(
+            stateFile,
+            latest.toQuorumStateData(quorumStateVersion),
+            quorumStateVersion
+        );
     }
 
     @Override
@@ -142,12 +151,28 @@
         return stateFile.toPath();
     }
 
-    private List<Voter> voters(Set<Integer> votersId) {
-        return votersId.stream().map(
-            voterId -> new Voter().setVoterId(voterId)).collect(Collectors.toList());
+    private short quorumStateVersionFromKRaftVersion(short kraftVersion) {
+        if (kraftVersion == 0) {
+            return 0;
+        } else if (kraftVersion == 1) {
+            return 1;
+        } else {
+            throw new IllegalArgumentException(
+                String.format("Unknown kraft.version %d", kraftVersion)
+            );
+        }
     }
 
-    private void writeElectionStateToFile(final File stateFile, QuorumStateData state) {
+    private void writeElectionStateToFile(final File stateFile, QuorumStateData state, short version) {
+        if (version > HIGHEST_SUPPORTED_VERSION) {
+            throw new IllegalArgumentException(
+                String.format(
+                    "Quorum state data version (%d) is greater than the supported version (%d)",
+                    version,
+                    HIGHEST_SUPPORTED_VERSION
+                )
+            );
+        }
         final File temp = new File(stateFile.getAbsolutePath() + ".tmp");
         deleteFileIfExists(temp);
 
@@ -159,8 +184,8 @@
                      new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8)
                  )
             ) {
-                ObjectNode jsonState = (ObjectNode) QuorumStateDataJsonConverter.write(state, HIGHEST_SUPPORTED_VERSION);
-                jsonState.set(DATA_VERSION, new ShortNode(HIGHEST_SUPPORTED_VERSION));
+                ObjectNode jsonState = (ObjectNode) QuorumStateDataJsonConverter.write(state, version);
+                jsonState.set(DATA_VERSION, new ShortNode(version));
                 writer.write(jsonState.toString());
                 writer.flush();
                 fileOutputStream.getFD().sync();
diff --git a/raft/src/main/java/org/apache/kafka/raft/FollowerState.java b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java
index aa8f9f7..49bfaff 100644
--- a/raft/src/main/java/org/apache/kafka/raft/FollowerState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java
@@ -16,17 +16,16 @@
  */
 package org.apache.kafka.raft;
 
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.Set;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.apache.kafka.snapshot.RawSnapshotWriter;
 import org.slf4j.Logger;
 
-import java.util.Optional;
-import java.util.OptionalInt;
-import java.util.OptionalLong;
-import java.util.Set;
-
 public class FollowerState implements EpochState {
     private final int fetchTimeoutMs;
     private final int epoch;
@@ -63,12 +62,7 @@
 
     @Override
     public ElectionState election() {
-        return new ElectionState(
-            epoch,
-            OptionalInt.of(leaderId),
-            OptionalInt.empty(),
-            voters
-        );
+        return ElectionState.withElectedLeader(epoch, leaderId, voters);
     }
 
     @Override
@@ -158,9 +152,13 @@
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
-        log.debug("Rejecting vote request from candidate {} since we already have a leader {} in epoch {}",
-                candidateId, leaderId(), epoch);
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
+        log.debug(
+            "Rejecting vote request from candidate ({}) since we already have a leader {} in epoch {}",
+            candidateKey,
+            leaderId(),
+            epoch
+        );
         return false;
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 70408c7..288933b 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -18,6 +18,7 @@
 
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.errors.ClusterAuthorizationException;
 import org.apache.kafka.common.errors.NotLeaderOrFollowerException;
 import org.apache.kafka.common.memory.MemoryPool;
@@ -48,8 +49,8 @@
 import org.apache.kafka.common.requests.DescribeQuorumResponse;
 import org.apache.kafka.common.requests.EndQuorumEpochRequest;
 import org.apache.kafka.common.requests.EndQuorumEpochResponse;
-import org.apache.kafka.common.requests.FetchResponse;
 import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
 import org.apache.kafka.common.requests.FetchSnapshotRequest;
 import org.apache.kafka.common.requests.FetchSnapshotResponse;
 import org.apache.kafka.common.requests.VoteRequest;
@@ -69,6 +70,7 @@
 import org.apache.kafka.raft.internals.KafkaRaftMetrics;
 import org.apache.kafka.raft.internals.MemoryBatchReader;
 import org.apache.kafka.raft.internals.RecordsBatchReader;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.apache.kafka.raft.internals.ThresholdPurgatory;
 import org.apache.kafka.raft.internals.VoterSet;
 import org.apache.kafka.server.common.serialization.RecordSerde;
@@ -150,6 +152,7 @@
     public static final int MAX_FETCH_SIZE_BYTES = MAX_BATCH_SIZE_BYTES;
 
     private final OptionalInt nodeId;
+    private final Uuid nodeDirectoryId;
     private final AtomicReference<GracefulShutdown> shutdown = new AtomicReference<>();
     private final LogContext logContext;
     private final Logger logger;
@@ -197,6 +200,7 @@
      */
     public KafkaRaftClient(
         OptionalInt nodeId,
+        Uuid nodeDirectoryId,
         RecordSerde<T> serde,
         NetworkChannel channel,
         ReplicatedLog log,
@@ -208,6 +212,7 @@
     ) {
         this(
             nodeId,
+            nodeDirectoryId,
             serde,
             channel,
             new BlockingMessageQueue(),
@@ -225,6 +230,7 @@
 
     KafkaRaftClient(
         OptionalInt nodeId,
+        Uuid nodeDirectoryId,
         RecordSerde<T> serde,
         NetworkChannel channel,
         RaftMessageQueue messageQueue,
@@ -239,6 +245,7 @@
         QuorumConfig quorumConfig
     ) {
         this.nodeId = nodeId;
+        this.nodeDirectoryId = nodeDirectoryId;
         this.logContext = logContext;
         this.serde = serde;
         this.channel = channel;
@@ -396,7 +403,9 @@
 
         quorum = new QuorumState(
             nodeId,
-            lastVoterSet.voterIds(),
+            nodeDirectoryId,
+            partitionState::lastVoterSet,
+            partitionState::lastKraftVersion,
             quorumConfig.electionTimeoutMs(),
             quorumConfig.fetchTimeoutMs(),
             quorumStateStore,
@@ -426,10 +435,7 @@
         }
 
         // When there is only a single voter, become candidate immediately
-        if (quorum.isVoter()
-            && quorum.remoteVoters().isEmpty()
-            && !quorum.isCandidate()) {
-
+        if (quorum.isOnlyVoter() && !quorum.isCandidate()) {
             transitionToCandidate(currentTimeMs);
         }
     }
@@ -539,8 +545,8 @@
         resetConnections();
     }
 
-    private void transitionToVoted(int candidateId, int epoch) {
-        quorum.transitionToVoted(epoch, candidateId);
+    private void transitionToVoted(ReplicaKey candidateKey, int epoch) {
+        quorum.transitionToVoted(epoch, candidateKey);
         maybeFireLeaderChange();
         resetConnections();
     }
@@ -627,10 +633,14 @@
         }
 
         OffsetAndEpoch lastEpochEndOffsetAndEpoch = new OffsetAndEpoch(lastEpochEndOffset, lastEpoch);
-        boolean voteGranted = quorum.canGrantVote(candidateId, lastEpochEndOffsetAndEpoch.compareTo(endOffset()) >= 0);
+        ReplicaKey candidateKey = ReplicaKey.of(candidateId, Optional.empty());
+        boolean voteGranted = quorum.canGrantVote(
+            candidateKey,
+            lastEpochEndOffsetAndEpoch.compareTo(endOffset()) >= 0
+        );
 
         if (voteGranted && quorum.isUnattached()) {
-            transitionToVoted(candidateId, candidateEpoch);
+            transitionToVoted(candidateKey, candidateEpoch);
         }
 
         logger.info("Vote request {} with epoch {} is {}", request, candidateEpoch, voteGranted ? "granted" : "rejected");
@@ -1700,16 +1710,16 @@
     }
 
     /**
-     * Validate a request which is only valid between voters. If an error is
-     * present in the returned value, it should be returned in the response.
+     * Validate common state for requests to establish leadership.
+     *
+     * These include the Vote, BeginQuorumEpoch and EndQuorumEpoch RPCs. If an error is present in
+     * the returned value, it should be returned in the response.
      */
     private Optional<Errors> validateVoterOnlyRequest(int remoteNodeId, int requestEpoch) {
         if (requestEpoch < quorum.epoch()) {
             return Optional.of(Errors.FENCED_LEADER_EPOCH);
         } else if (remoteNodeId < 0) {
             return Optional.of(Errors.INVALID_REQUEST);
-        } else if (quorum.isObserver() || !quorum.isVoter(remoteNodeId)) {
-            return Optional.of(Errors.INCONSISTENT_VOTER_SET);
         } else {
             return Optional.empty();
         }
@@ -2300,9 +2310,9 @@
         }
 
         if (quorum.isObserver()
-            || quorum.remoteVoters().isEmpty()
-            || quorum.hasRemoteLeader()) {
-
+            || quorum.isOnlyVoter()
+            || quorum.hasRemoteLeader()
+        ) {
             shutdown.complete();
             return true;
         }
diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
index cfa50d4..df4cc13 100644
--- a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
@@ -17,14 +17,15 @@
 package org.apache.kafka.raft;
 
 import org.apache.kafka.common.message.DescribeQuorumResponseData;
-import org.apache.kafka.common.message.LeaderChangeMessage;
 import org.apache.kafka.common.message.LeaderChangeMessage.Voter;
+import org.apache.kafka.common.message.LeaderChangeMessage;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.ControlRecordUtils;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.raft.internals.BatchAccumulator;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.slf4j.Logger;
 
 import java.util.ArrayList;
@@ -162,7 +163,7 @@
             .setLeaderId(this.election().leaderId())
             .setVoters(voters)
             .setGrantingVoters(grantingVoters);
-        
+
         accumulator.appendLeaderChangeMessage(leaderChangeMessage, currentTimeMs);
         accumulator.forceDrain();
     }
@@ -513,15 +514,18 @@
                 endOffset,
                 lastFetchTimestamp,
                 lastCaughtUpTimestamp,
-                hasAcknowledgedLeader 
+                hasAcknowledgedLeader
             );
         }
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
-        log.debug("Rejecting vote request from candidate {} since we are already leader in epoch {}",
-            candidateId, epoch);
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
+        log.debug(
+            "Rejecting vote request from candidate ({}) since we are already leader in epoch {}",
+            candidateKey,
+            epoch
+        );
         return false;
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java b/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java
index fbfaa82..5c9c20b 100644
--- a/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java
+++ b/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java
@@ -189,7 +189,7 @@
             InetSocketAddress address = new InetSocketAddress(host, port);
             if (address.getHostString().equals(NON_ROUTABLE_HOST) && requireRoutableAddresses) {
                 throw new ConfigException(
-                    String.format("Host string ({}) is not routeable", address.getHostString())
+                    String.format("Host string (%s) is not routeable", address.getHostString())
                 );
             } else {
                 voterMap.put(voterId, address);
diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumState.java b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java
index b38e943..81a6c01 100644
--- a/raft/src/main/java/org/apache/kafka/raft/QuorumState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java
@@ -16,21 +16,21 @@
  */
 package org.apache.kafka.raft;
 
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.raft.internals.BatchAccumulator;
-import org.slf4j.Logger;
-
 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.OptionalInt;
 import java.util.Random;
-import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.function.Supplier;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.raft.internals.BatchAccumulator;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.apache.kafka.raft.internals.VoterSet;
+import org.slf4j.Logger;
 
 /**
  * This class is responsible for managing the current state of this node and ensuring
@@ -76,10 +76,12 @@
  */
 public class QuorumState {
     private final OptionalInt localId;
+    private final Uuid localDirectoryId;
     private final Time time;
     private final Logger log;
     private final QuorumStateStore store;
-    private final Set<Integer> voters;
+    private final Supplier<VoterSet> latestVoterSet;
+    private final Supplier<Short> latestKraftVersion;
     private final Random random;
     private final int electionTimeoutMs;
     private final int fetchTimeoutMs;
@@ -87,16 +89,22 @@
 
     private volatile EpochState state;
 
-    public QuorumState(OptionalInt localId,
-                       Set<Integer> voters,
-                       int electionTimeoutMs,
-                       int fetchTimeoutMs,
-                       QuorumStateStore store,
-                       Time time,
-                       LogContext logContext,
-                       Random random) {
+    public QuorumState(
+        OptionalInt localId,
+        Uuid localDirectoryId,
+        Supplier<VoterSet> latestVoterSet,
+        Supplier<Short> latestKraftVersion,
+        int electionTimeoutMs,
+        int fetchTimeoutMs,
+        QuorumStateStore store,
+        Time time,
+        LogContext logContext,
+        Random random
+    ) {
         this.localId = localId;
-        this.voters = new HashSet<>(voters);
+        this.localDirectoryId = localDirectoryId;
+        this.latestVoterSet = latestVoterSet;
+        this.latestKraftVersion = latestKraftVersion;
         this.electionTimeoutMs = electionTimeoutMs;
         this.fetchTimeoutMs = fetchTimeoutMs;
         this.store = store;
@@ -112,45 +120,30 @@
         // when we send Vote or BeginEpoch requests.
 
         ElectionState election;
-        try {
-            election = store.readElectionState();
-            if (election == null) {
-                election = ElectionState.withUnknownLeader(0, voters);
-            }
-        } catch (final UncheckedIOException e) {
-            // For exceptions during state file loading (missing or not readable),
-            // we could assume the file is corrupted already and should be cleaned up.
-            log.warn("Clearing local quorum state store after error loading state {}",
-                store, e);
-            store.clear();
-            election = ElectionState.withUnknownLeader(0, voters);
-        }
+        election = store
+            .readElectionState()
+            .orElseGet(() -> ElectionState.withUnknownLeader(0, latestVoterSet.get().voterIds()));
 
         final EpochState initialState;
-        if (!election.voters().isEmpty() && !voters.equals(election.voters())) {
-            throw new IllegalStateException("Configured voter set: " + voters
-                + " is different from the voter set read from the state file: " + election.voters()
-                + ". Check if the quorum configuration is up to date, "
-                + "or wipe out the local state file if necessary");
-        } else if (election.hasVoted() && !isVoter()) {
-            String localIdDescription = localId.isPresent() ?
-                localId.getAsInt() + " is not a voter" :
-                "is undefined";
-            throw new IllegalStateException("Initialized quorum state " + election
-                + " with a voted candidate, which indicates this node was previously "
-                + " a voter, but the local id " + localIdDescription);
-        } else if (election.epoch < logEndOffsetAndEpoch.epoch()) {
+        if (election.hasVoted() && !localId.isPresent()) {
+            throw new IllegalStateException(
+                String.format(
+                    "Initialized quorum state (%s) with a voted candidate but without a local id",
+                    election
+                )
+            );
+        } else if (election.epoch() < logEndOffsetAndEpoch.epoch()) {
             log.warn(
                 "Epoch from quorum store file ({}) is {}, which is smaller than last written " +
                 "epoch {} in the log",
                 store.path(),
-                election.epoch,
+                election.epoch(),
                 logEndOffsetAndEpoch.epoch()
             );
             initialState = new UnattachedState(
                 time,
                 logEndOffsetAndEpoch.epoch(),
-                voters,
+                latestVoterSet.get().voterIds(),
                 Optional.empty(),
                 randomElectionTimeoutMs(),
                 logContext
@@ -165,18 +158,22 @@
             initialState = new ResignedState(
                 time,
                 localId.getAsInt(),
-                election.epoch,
-                voters,
+                election.epoch(),
+                latestVoterSet.get().voterIds(),
                 randomElectionTimeoutMs(),
                 Collections.emptyList(),
                 logContext
             );
-        } else if (localId.isPresent() && election.isVotedCandidate(localId.getAsInt())) {
+        } else if (
+            localId.isPresent() &&
+            election.isVotedCandidate(ReplicaKey.of(localId.getAsInt(), Optional.of(localDirectoryId)))
+        ) {
             initialState = new CandidateState(
                 time,
                 localId.getAsInt(),
-                election.epoch,
-                voters,
+                localDirectoryId,
+                election.epoch(),
+                latestVoterSet.get(),
                 Optional.empty(),
                 1,
                 randomElectionTimeoutMs(),
@@ -185,9 +182,9 @@
         } else if (election.hasVoted()) {
             initialState = new VotedState(
                 time,
-                election.epoch,
-                election.votedId(),
-                voters,
+                election.epoch(),
+                election.votedKey(),
+                latestVoterSet.get().voterIds(),
                 Optional.empty(),
                 randomElectionTimeoutMs(),
                 logContext
@@ -195,9 +192,9 @@
         } else if (election.hasLeader()) {
             initialState = new FollowerState(
                 time,
-                election.epoch,
+                election.epoch(),
                 election.leaderId(),
-                voters,
+                latestVoterSet.get().voterIds(),
                 Optional.empty(),
                 fetchTimeoutMs,
                 logContext
@@ -205,8 +202,8 @@
         } else {
             initialState = new UnattachedState(
                 time,
-                election.epoch,
-                voters,
+                election.epoch(),
+                latestVoterSet.get().voterIds(),
                 Optional.empty(),
                 randomElectionTimeoutMs(),
                 logContext
@@ -216,8 +213,11 @@
         durableTransitionTo(initialState);
     }
 
-    public Set<Integer> remoteVoters() {
-        return voters.stream().filter(voterId -> voterId != localIdOrSentinel()).collect(Collectors.toSet());
+    public boolean isOnlyVoter() {
+        return localId.isPresent() &&
+            latestVoterSet.get().isOnlyVoter(
+                ReplicaKey.of(localId.getAsInt(), Optional.of(localDirectoryId))
+            );
     }
 
     public int localIdOrSentinel() {
@@ -232,6 +232,10 @@
         return localId;
     }
 
+    public Uuid localDirectoryId() {
+        return localDirectoryId;
+    }
+
     public int epoch() {
         return state.epoch();
     }
@@ -262,11 +266,17 @@
     }
 
     public boolean isVoter() {
-        return localId.isPresent() && voters.contains(localId.getAsInt());
+        if (!localId.isPresent()) {
+            return false;
+        }
+
+        return latestVoterSet
+            .get()
+            .isVoter(ReplicaKey.of(localId.getAsInt(), Optional.of(localDirectoryId)));
     }
 
-    public boolean isVoter(int nodeId) {
-        return voters.contains(nodeId);
+    public boolean isVoter(ReplicaKey nodeKey) {
+        return latestVoterSet.get().isVoter(nodeKey);
     }
 
     public boolean isObserver() {
@@ -286,7 +296,7 @@
                 time,
                 localIdOrThrow(),
                 epoch,
-                voters,
+                latestVoterSet.get().voterIds(),
                 randomElectionTimeoutMs(),
                 preferredSuccessors,
                 logContext
@@ -321,7 +331,7 @@
         durableTransitionTo(new UnattachedState(
             time,
             epoch,
-            voters,
+            latestVoterSet.get().voterIds(),
             state.highWatermark(),
             electionTimeoutMs,
             logContext
@@ -336,40 +346,54 @@
      */
     public void transitionToVoted(
         int epoch,
-        int candidateId
+        ReplicaKey candidateKey
     ) {
-        if (localId.isPresent() && candidateId == localId.getAsInt()) {
-            throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId +
-                " and epoch=" + epoch + " since it matches the local broker.id");
-        } else if (isObserver()) {
-            throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId +
-                " and epoch=" + epoch + " since the local broker.id=" + localId + " is not a voter");
-        } else if (!isVoter(candidateId)) {
-            throw new IllegalStateException("Cannot transition to Voted with voterId=" + candidateId +
-                " and epoch=" + epoch + " since it is not one of the voters " + voters);
-        }
-
         int currentEpoch = state.epoch();
-        if (epoch < currentEpoch) {
-            throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId +
-                " and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger");
+        if (localId.isPresent() && candidateKey.id() == localId.getAsInt()) {
+            throw new IllegalStateException(
+                String.format(
+                    "Cannot transition to Voted for %s and epoch %d since it matches the local " +
+                    "broker.id",
+                    candidateKey,
+                    epoch
+                )
+            );
+        } else if (!localId.isPresent()) {
+            throw new IllegalStateException("Cannot transition to voted without a replica id");
+        } else if (epoch < currentEpoch) {
+            throw new IllegalStateException(
+                String.format(
+                    "Cannot transition to Voted for %s and epoch %d since the current epoch " +
+                    "(%d) is larger",
+                    candidateKey,
+                    epoch,
+                    currentEpoch
+                )
+            );
         } else if (epoch == currentEpoch && !isUnattached()) {
-            throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId +
-                " and epoch=" + epoch + " from the current state " + state);
+            throw new IllegalStateException(
+                String.format(
+                    "Cannot transition to Voted for %s and epoch %d from the current state (%s)",
+                    candidateKey,
+                    epoch,
+                    state
+                )
+            );
         }
 
         // Note that we reset the election timeout after voting for a candidate because we
         // know that the candidate has at least as good of a chance of getting elected as us
-
-        durableTransitionTo(new VotedState(
-            time,
-            epoch,
-            candidateId,
-            voters,
-            state.highWatermark(),
-            randomElectionTimeoutMs(),
-            logContext
-        ));
+        durableTransitionTo(
+            new VotedState(
+                time,
+                epoch,
+                candidateKey,
+                latestVoterSet.get().voterIds(),
+                state.highWatermark(),
+                randomElectionTimeoutMs(),
+                logContext
+            )
+        );
     }
 
     /**
@@ -379,16 +403,11 @@
         int epoch,
         int leaderId
     ) {
+        int currentEpoch = state.epoch();
         if (localId.isPresent() && leaderId == localId.getAsInt()) {
             throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
                 " and epoch=" + epoch + " since it matches the local broker.id=" + localId);
-        } else if (!isVoter(leaderId)) {
-            throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
-                " and epoch=" + epoch + " since it is not one of the voters " + voters);
-        }
-
-        int currentEpoch = state.epoch();
-        if (epoch < currentEpoch) {
+        } else if (epoch < currentEpoch) {
             throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
                 " and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger");
         } else if (epoch == currentEpoch
@@ -397,21 +416,30 @@
                 " and epoch=" + epoch + " from state " + state);
         }
 
-        durableTransitionTo(new FollowerState(
-            time,
-            epoch,
-            leaderId,
-            voters,
-            state.highWatermark(),
-            fetchTimeoutMs,
-            logContext
-        ));
+        durableTransitionTo(
+            new FollowerState(
+                time,
+                epoch,
+                leaderId,
+                latestVoterSet.get().voterIds(),
+                state.highWatermark(),
+                fetchTimeoutMs,
+                logContext
+            )
+        );
     }
 
     public void transitionToCandidate() {
         if (isObserver()) {
-            throw new IllegalStateException("Cannot transition to Candidate since the local broker.id=" + localId +
-                " is not one of the voters " + voters);
+            throw new IllegalStateException(
+                String.format(
+                    "Cannot transition to Candidate since the local id (%s) and directory id (%s) " +
+                    "is not one of the voters %s",
+                    localId,
+                    localDirectoryId,
+                    latestVoterSet.get()
+                )
+            );
         } else if (isLeader()) {
             throw new IllegalStateException("Cannot transition to Candidate since the local broker.id=" + localId +
                 " since this node is already a Leader with state " + state);
@@ -424,8 +452,9 @@
         durableTransitionTo(new CandidateState(
             time,
             localIdOrThrow(),
+            localDirectoryId,
             newEpoch,
-            voters,
+            latestVoterSet.get(),
             state.highWatermark(),
             retries,
             electionTimeoutMs,
@@ -435,8 +464,15 @@
 
     public <T> LeaderState<T> transitionToLeader(long epochStartOffset, BatchAccumulator<T> accumulator) {
         if (isObserver()) {
-            throw new IllegalStateException("Cannot transition to Leader since the local broker.id="  + localId +
-                " is not one of the voters " + voters);
+            throw new IllegalStateException(
+                String.format(
+                    "Cannot transition to Leader since the local id (%s) and directory id (%s) " +
+                    "is not one of the voters %s",
+                    localId,
+                    localDirectoryId,
+                    latestVoterSet.get()
+                )
+            );
         } else if (!isCandidate()) {
             throw new IllegalStateException("Cannot transition to Leader from current state " + state);
         }
@@ -461,7 +497,7 @@
             localIdOrThrow(),
             epoch(),
             epochStartOffset,
-            voters,
+            latestVoterSet.get().voterIds(),
             candidateState.grantingVoters(),
             accumulator,
             fetchTimeoutMs,
@@ -471,24 +507,24 @@
         return state;
     }
 
-    private void durableTransitionTo(EpochState state) {
-        if (this.state != null) {
+    private void durableTransitionTo(EpochState newState) {
+        if (state != null) {
             try {
-                this.state.close();
+                state.close();
             } catch (IOException e) {
                 throw new UncheckedIOException(
-                    "Failed to transition from " + this.state.name() + " to " + state.name(), e);
+                    "Failed to transition from " + state.name() + " to " + newState.name(), e);
             }
         }
 
-        this.store.writeElectionState(state.election());
-        memoryTransitionTo(state);
+        store.writeElectionState(newState.election(), latestKraftVersion.get());
+        memoryTransitionTo(newState);
     }
 
-    private void memoryTransitionTo(EpochState state) {
-        EpochState from = this.state;
-        this.state = state;
-        log.info("Completed transition to {} from {}", state, from);
+    private void memoryTransitionTo(EpochState newState) {
+        EpochState from = state;
+        state = newState;
+        log.info("Completed transition to {} from {}", newState, from);
     }
 
     private int randomElectionTimeoutMs() {
@@ -497,8 +533,8 @@
         return electionTimeoutMs + random.nextInt(electionTimeoutMs);
     }
 
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
-        return state.canGrantVote(candidateId, isLogUpToDate);
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
+        return state.canGrantVote(candidateKey, isLogUpToDate);
     }
 
     public FollowerState followerStateOrThrow() {
@@ -508,9 +544,17 @@
     }
 
     public VotedState votedStateOrThrow() {
-        if (isVoted())
-            return (VotedState) state;
-        throw new IllegalStateException("Expected to be Voted, but current state is " + state);
+        return maybeVotedState()
+            .orElseThrow(() -> new IllegalStateException("Expected to be Voted, but current state is " + state));
+    }
+
+    public Optional<VotedState> maybeVotedState() {
+        EpochState fixedState = state;
+        if (fixedState instanceof VotedState) {
+            return Optional.of((VotedState) fixedState);
+        } else {
+            return Optional.empty();
+        }
     }
 
     public UnattachedState unattachedStateOrThrow() {
@@ -519,18 +563,16 @@
         throw new IllegalStateException("Expected to be Unattached, but current state is " + state);
     }
 
-    @SuppressWarnings("unchecked")
     public <T> LeaderState<T> leaderStateOrThrow() {
-        if (isLeader())
-            return (LeaderState<T>) state;
-        throw new IllegalStateException("Expected to be Leader, but current state is " + state);
+        return this.<T>maybeLeaderState()
+            .orElseThrow(() -> new IllegalStateException("Expected to be Leader, but current state is " + state));
     }
 
     @SuppressWarnings("unchecked")
     public <T> Optional<LeaderState<T>> maybeLeaderState() {
-        EpochState state = this.state;
-        if (state instanceof  LeaderState) {
-            return Optional.of((LeaderState<T>) state);
+        EpochState fixedState = state;
+        if (fixedState instanceof LeaderState) {
+            return Optional.of((LeaderState<T>) fixedState);
         } else {
             return Optional.empty();
         }
@@ -550,7 +592,7 @@
 
     public LeaderAndEpoch leaderAndEpoch() {
         ElectionState election = state.election();
-        return new LeaderAndEpoch(election.leaderIdOpt, election.epoch);
+        return new LeaderAndEpoch(election.optionalLeaderId(), election.epoch());
     }
 
     public boolean isFollower() {
diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java b/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java
index 1f2f057..e3d252e 100644
--- a/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java
+++ b/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java
@@ -17,29 +17,29 @@
 package org.apache.kafka.raft;
 
 import java.nio.file.Path;
+import java.util.Optional;
 
 /**
  *  Maintain the save and retrieval of quorum state information, so far only supports
  *  read and write of election states.
  */
 public interface QuorumStateStore {
-
-    int UNKNOWN_LEADER_ID = -1;
-    int NOT_VOTED = -1;
-
     /**
      * Read the latest election state.
      *
-     * @return The latest written election state or `null` if there is none
+     * @return the latest written election state or {@code Optional.empty()} if there is none
      */
-    ElectionState readElectionState();
+    Optional<ElectionState> readElectionState();
 
     /**
-     * Persist the updated election state. This must be atomic, both writing the full updated state
-     * and replacing the old state.
-     * @param latest The latest election state
+     * Persist the updated election state.
+     *
+     * This must be atomic, both writing the full updated state and replacing the old state.
+     *
+     * @param latest the latest election state
+     * @param kraftVersion the finalized kraft.version
      */
-    void writeElectionState(ElectionState latest);
+    void writeElectionState(ElectionState latest, short kraftVersion);
 
     /**
      * Path to the quorum state store
diff --git a/raft/src/main/java/org/apache/kafka/raft/ResignedState.java b/raft/src/main/java/org/apache/kafka/raft/ResignedState.java
index 899823a..58f1d5d 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ResignedState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ResignedState.java
@@ -19,6 +19,7 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.slf4j.Logger;
 
 import java.util.HashSet;
@@ -131,9 +132,13 @@
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
-        log.debug("Rejecting vote request from candidate {} since we have resigned as candidate/leader in epoch {}",
-            candidateId, epoch);
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
+        log.debug(
+            "Rejecting vote request from candidate ({}) since we have resigned as candidate/leader in epoch {}",
+            candidateKey,
+            epoch
+        );
+
         return false;
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java b/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java
index 4dc5fc7..a5ed20c 100644
--- a/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java
@@ -19,10 +19,10 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.slf4j.Logger;
 
 import java.util.Optional;
-import java.util.OptionalInt;
 import java.util.Set;
 
 /**
@@ -56,12 +56,7 @@
 
     @Override
     public ElectionState election() {
-        return new ElectionState(
-            epoch,
-            OptionalInt.empty(),
-            OptionalInt.empty(),
-            voters
-        );
+        return ElectionState.withUnknownLeader(epoch, voters);
     }
 
     @Override
@@ -94,11 +89,14 @@
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
         if (!isLogUpToDate) {
-            log.debug("Rejecting vote request from candidate {} since candidate epoch/offset is not up to date with us",
-                candidateId);
+            log.debug(
+                "Rejecting vote request from candidate ({}) since candidate epoch/offset is not up to date with us",
+                candidateKey
+            );
         }
+
         return isLogUpToDate;
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/VotedState.java b/raft/src/main/java/org/apache/kafka/raft/VotedState.java
index d88668d..550f014 100644
--- a/raft/src/main/java/org/apache/kafka/raft/VotedState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/VotedState.java
@@ -16,24 +16,24 @@
  */
 package org.apache.kafka.raft;
 
+import java.util.Optional;
+import java.util.Set;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Timer;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.slf4j.Logger;
 
-import java.util.Optional;
-import java.util.OptionalInt;
-import java.util.Set;
-
 /**
  * The "voted" state is for voters who have cast their vote for a specific candidate.
+ *
  * Once a vote has been cast, it is not possible for a voter to change its vote until a
  * new election is started. If the election timeout expires before a new leader is elected,
  * then the voter will become a candidate.
  */
 public class VotedState implements EpochState {
     private final int epoch;
-    private final int votedId;
+    private final ReplicaKey votedKey;
     private final Set<Integer> voters;
     private final int electionTimeoutMs;
     private final Timer electionTimer;
@@ -43,14 +43,14 @@
     public VotedState(
         Time time,
         int epoch,
-        int votedId,
+        ReplicaKey votedKey,
         Set<Integer> voters,
         Optional<LogOffsetMetadata> highWatermark,
         int electionTimeoutMs,
         LogContext logContext
     ) {
         this.epoch = epoch;
-        this.votedId = votedId;
+        this.votedKey = votedKey;
         this.voters = voters;
         this.highWatermark = highWatermark;
         this.electionTimeoutMs = electionTimeoutMs;
@@ -60,16 +60,11 @@
 
     @Override
     public ElectionState election() {
-        return new ElectionState(
-            epoch,
-            OptionalInt.empty(),
-            OptionalInt.of(votedId),
-            voters
-        );
+        return ElectionState.withVotedCandidate(epoch, votedKey, voters);
     }
 
-    public int votedId() {
-        return votedId;
+    public ReplicaKey votedKey() {
+        return votedKey;
     }
 
     @Override
@@ -93,13 +88,19 @@
     }
 
     @Override
-    public boolean canGrantVote(int candidateId, boolean isLogUpToDate) {
-        if (votedId() == candidateId) {
-            return true;
+    public boolean canGrantVote(ReplicaKey candidateKey, boolean isLogUpToDate) {
+        if (votedKey.id() == candidateKey.id()) {
+            return !votedKey.directoryId().isPresent() || votedKey.directoryId().equals(candidateKey.directoryId());
         }
 
-        log.debug("Rejecting vote request from candidate {} since we already have voted for " +
-            "another candidate {} in epoch {}", candidateId, votedId(), epoch);
+        log.debug(
+            "Rejecting vote request from candidate ({}), already have voted for another " +
+            "candidate ({}) in epoch {}",
+            candidateKey,
+            votedKey,
+            epoch
+        );
+
         return false;
     }
 
@@ -110,12 +111,14 @@
 
     @Override
     public String toString() {
-        return "Voted(" +
-            "epoch=" + epoch +
-            ", votedId=" + votedId +
-            ", voters=" + voters +
-            ", electionTimeoutMs=" + electionTimeoutMs +
-            ')';
+        return String.format(
+            "Voted(epoch=%d, votedKey=%s, voters=%s, electionTimeoutMs=%d, highWatermark=%s)",
+            epoch,
+            votedKey,
+            voters,
+            electionTimeoutMs,
+            highWatermark
+        );
     }
 
     @Override
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
index f16ec2f..f5f70e6 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
@@ -269,7 +269,7 @@
         } else if (batch.baseOffset() != nextOffset) {
             throw new IllegalArgumentException(
                 String.format(
-                    "Expected a base offset of {} but got {}",
+                    "Expected a base offset of %d but got %d",
                     nextOffset,
                     batch.baseOffset()
                 )
@@ -277,7 +277,7 @@
         } else if (batch.partitionLeaderEpoch() != epoch) {
             throw new IllegalArgumentException(
                 String.format(
-                    "Expected a partition leader epoch of {} but got {}",
+                    "Expected a partition leader epoch of %d but got %d",
                     epoch,
                     batch.partitionLeaderEpoch()
                 )
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachine.java b/raft/src/main/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachine.java
index dd6e6a0..25ad3d0 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachine.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachine.java
@@ -136,6 +136,15 @@
     }
 
     /**
+     * Returns the last kraft version.
+     */
+    public short lastKraftVersion() {
+        synchronized (kraftVersionHistory) {
+            return kraftVersionHistory.lastEntry().map(LogHistory.Entry::value).orElse((short) 0);
+        }
+    }
+
+    /**
      * Returns the voter set at a given offset.
      *
      * @param offset the offset (inclusive)
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java b/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java
index 1ed2a4f..3bdac5f 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.raft.internals;
 
 import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.metrics.Gauge;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
@@ -42,6 +43,7 @@
 
     private final MetricName currentLeaderIdMetricName;
     private final MetricName currentVotedIdMetricName;
+    private final MetricName currentVotedDirectoryIdMetricName;
     private final MetricName currentEpochMetricName;
     private final MetricName currentStateMetricName;
     private final MetricName highWatermarkMetricName;
@@ -87,17 +89,34 @@
         this.currentLeaderIdMetricName = metrics.metricName("current-leader", metricGroupName, "The current quorum leader's id; -1 indicates unknown");
         metrics.addMetric(this.currentLeaderIdMetricName, (mConfig, currentTimeMs) -> state.leaderId().orElse(-1));
 
-        this.currentVotedIdMetricName = metrics.metricName("current-vote", metricGroupName, "The current voted leader's id; -1 indicates not voted for anyone");
+        this.currentVotedIdMetricName = metrics.metricName("current-vote", metricGroupName, "The current voted id; -1 indicates not voted for anyone");
         metrics.addMetric(this.currentVotedIdMetricName, (mConfig, currentTimeMs) -> {
             if (state.isLeader() || state.isCandidate()) {
                 return state.localIdOrThrow();
-            } else if (state.isVoted()) {
-                return state.votedStateOrThrow().votedId();
             } else {
-                return -1;
+                return (double) state.maybeVotedState()
+                    .map(votedState -> votedState.votedKey().id())
+                    .orElse(-1);
             }
         });
 
+        this.currentVotedDirectoryIdMetricName = metrics.metricName(
+            "current-vote-directory-id",
+            metricGroupName,
+            String.format("The current voted directory id; %s indicates not voted for a directory id", Uuid.ZERO_UUID)
+        );
+        Gauge<String> votedDirectoryIdProvider = (mConfig, currentTimestamp) -> {
+            if (state.isLeader() || state.isCandidate()) {
+                return state.localDirectoryId().toString();
+            } else {
+                return state.maybeVotedState()
+                    .flatMap(votedState -> votedState.votedKey().directoryId())
+                    .orElse(Uuid.ZERO_UUID)
+                    .toString();
+            }
+        };
+        metrics.addMetric(this.currentVotedDirectoryIdMetricName, null, votedDirectoryIdProvider);
+
         this.currentEpochMetricName = metrics.metricName("current-epoch", metricGroupName, "The current quorum epoch.");
         metrics.addMetric(this.currentEpochMetricName, (mConfig, currentTimeMs) -> state.epoch());
 
@@ -196,6 +215,7 @@
         Arrays.asList(
             currentLeaderIdMetricName,
             currentVotedIdMetricName,
+            currentVotedDirectoryIdMetricName,
             currentEpochMetricName,
             currentStateMetricName,
             highWatermarkMetricName,
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/ReplicaKey.java b/raft/src/main/java/org/apache/kafka/raft/internals/ReplicaKey.java
new file mode 100644
index 0000000..7d799a9
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/ReplicaKey.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import java.util.Objects;
+import java.util.Optional;
+import org.apache.kafka.common.Uuid;
+
+public final class ReplicaKey {
+    private final int id;
+    private final Optional<Uuid> directoryId;
+
+    private ReplicaKey(int id, Optional<Uuid> directoryId) {
+        this.id = id;
+        this.directoryId = directoryId;
+    }
+
+    public int id() {
+        return id;
+    }
+
+    public Optional<Uuid> directoryId() {
+        return directoryId;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        ReplicaKey that = (ReplicaKey) o;
+
+        if (id != that.id) return false;
+        if (!Objects.equals(directoryId, that.directoryId)) return false;
+
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, directoryId);
+    }
+
+    @Override
+    public String toString() {
+        return String.format("ReplicaKey(id=%d, directoryId=%s)", id, directoryId);
+    }
+
+    public static ReplicaKey of(int id, Optional<Uuid> directoryId) {
+        return new ReplicaKey(id, directoryId);
+    }
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java b/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java
index 9ca3836..d5a046e 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java
@@ -65,6 +65,43 @@
     }
 
     /**
+     * Returns if the node is a voter in the set of voters.
+     *
+     * If the voter set includes the directory id, the {@code nodeKey} directory id must match the
+     * directory id specified by the voter set.
+     *
+     * If the voter set doesn't include the directory id ({@code Optional.empty()}), a node is in
+     * the voter set as long as the node id matches. The directory id is not checked.
+     *
+     * @param nodeKey the node's id and directory id
+     * @return true if the node is a voter in the voter set, otherwise false
+     */
+    public boolean isVoter(ReplicaKey nodeKey) {
+        VoterNode node = voters.get(nodeKey.id());
+        if (node != null) {
+            if (node.voterKey().directoryId().isPresent()) {
+                return node.voterKey().directoryId().equals(nodeKey.directoryId());
+            } else {
+                // configured voter set doesn't include a directory id so it is a voter as long as the node id
+                // matches
+                return true;
+            }
+        } else {
+            return false;
+        }
+    }
+
+    /**
+     * Returns if the node is the only voter in the set of voters.
+     *
+     * @param nodeKey the node's id and directory id
+     * @return true if the node is the only voter in the voter set, otherwise false
+     */
+    public boolean isOnlyVoter(ReplicaKey nodeKey) {
+        return voters.size() == 1 && isVoter(nodeKey);
+    }
+
+    /**
      * Returns all of the voter ids.
      */
     public Set<Integer> voterIds() {
@@ -102,7 +139,7 @@
      * @param voterKey the voter key
      * @return a new voter set if the voter was removed, otherwise {@code Optional.empty()}
      */
-    public Optional<VoterSet> removeVoter(VoterKey voterKey) {
+    public Optional<VoterSet> removeVoter(ReplicaKey voterKey) {
         VoterNode oldVoter = voters.get(voterKey.id());
         if (oldVoter != null && Objects.equals(oldVoter.voterKey(), voterKey)) {
             HashMap<Integer, VoterNode> newVoters = new HashMap<>(voters);
@@ -168,20 +205,20 @@
      * @return true if they have an overlapping majority, false otherwise
      */
     public boolean hasOverlappingMajority(VoterSet that) {
-        Set<VoterKey> thisVoterKeys = voters
+        Set<ReplicaKey> thisReplicaKeys = voters
             .values()
             .stream()
             .map(VoterNode::voterKey)
             .collect(Collectors.toSet());
 
-        Set<VoterKey> thatVoterKeys = that.voters
+        Set<ReplicaKey> thatReplicaKeys = that.voters
             .values()
             .stream()
             .map(VoterNode::voterKey)
             .collect(Collectors.toSet());
 
-        if (Utils.diff(HashSet::new, thisVoterKeys, thatVoterKeys).size() > 1) return false;
-        if (Utils.diff(HashSet::new, thatVoterKeys, thisVoterKeys).size() > 1) return false;
+        if (Utils.diff(HashSet::new, thisReplicaKeys, thatReplicaKeys).size() > 1) return false;
+        if (Utils.diff(HashSet::new, thatReplicaKeys, thisReplicaKeys).size() > 1) return false;
 
         return true;
     }
@@ -206,58 +243,13 @@
         return String.format("VoterSet(voters=%s)", voters);
     }
 
-    public final static class VoterKey {
-        private final int id;
-        private final Optional<Uuid> directoryId;
-
-        private VoterKey(int id, Optional<Uuid> directoryId) {
-            this.id = id;
-            this.directoryId = directoryId;
-        }
-
-        public int id() {
-            return id;
-        }
-
-        public Optional<Uuid> directoryId() {
-            return directoryId;
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-
-            VoterKey that = (VoterKey) o;
-
-            if (id != that.id) return false;
-            if (!Objects.equals(directoryId, that.directoryId)) return false;
-
-            return true;
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(id, directoryId);
-        }
-
-        @Override
-        public String toString() {
-            return String.format("VoterKey(id=%d, directoryId=%s)", id, directoryId);
-        }
-
-        public static VoterKey of(int id, Optional<Uuid> directoryId) {
-            return new VoterKey(id, directoryId);
-        }
-    }
-
-    final static class VoterNode {
-        private final VoterKey voterKey;
+    public final static class VoterNode {
+        private final ReplicaKey voterKey;
         private final Map<String, InetSocketAddress> listeners;
         private final SupportedVersionRange supportedKRaftVersion;
 
         VoterNode(
-            VoterKey voterKey,
+            ReplicaKey voterKey,
             Map<String, InetSocketAddress> listeners,
             SupportedVersionRange supportedKRaftVersion
         ) {
@@ -266,7 +258,7 @@
             this.supportedKRaftVersion = supportedKRaftVersion;
         }
 
-        VoterKey voterKey() {
+        public ReplicaKey voterKey() {
             return voterKey;
         }
 
@@ -337,7 +329,7 @@
             voterNodes.put(
                 voter.voterId(),
                 new VoterNode(
-                    VoterKey.of(voter.voterId(), directoryId),
+                    ReplicaKey.of(voter.voterId(), directoryId),
                     listeners,
                     new SupportedVersionRange(
                         voter.kRaftVersionFeature().minSupportedVersion(),
@@ -365,7 +357,7 @@
                 Collectors.toMap(
                     Map.Entry::getKey,
                     entry -> new VoterNode(
-                        VoterKey.of(entry.getKey(), Optional.empty()),
+                        ReplicaKey.of(entry.getKey(), Optional.empty()),
                         Collections.singletonMap(listener, entry.getValue()),
                         new SupportedVersionRange((short) 0, (short) 0)
                     )
diff --git a/raft/src/main/resources/common/message/QuorumStateData.json b/raft/src/main/resources/common/message/QuorumStateData.json
index d71a32c..fdfe45c 100644
--- a/raft/src/main/resources/common/message/QuorumStateData.json
+++ b/raft/src/main/resources/common/message/QuorumStateData.json
@@ -16,19 +16,17 @@
 {
   "type": "data",
   "name": "QuorumStateData",
-  "validVersions": "0",
+  "validVersions": "0-1",
   "flexibleVersions": "0+",
   "fields": [
-    {"name": "ClusterId", "type": "string", "versions": "0+"},
-    {"name": "LeaderId", "type": "int32", "versions": "0+", "default": "-1"},
-    {"name": "LeaderEpoch", "type": "int32", "versions": "0+", "default": "-1"},
-    {"name": "VotedId", "type": "int32", "versions": "0+", "default": "-1"},
-    {"name": "AppliedOffset", "type": "int64", "versions": "0+"},
-    {"name": "CurrentVoters", "type": "[]Voter", "versions": "0+", "nullableVersions": "0+"}
-  ],
-  "commonStructs": [
-    { "name": "Voter", "versions": "0+", "fields": [
-      {"name": "VoterId", "type": "int32", "versions": "0+"}
+    { "name": "ClusterId", "type": "string", "versions": "0" },
+    { "name": "LeaderId", "type": "int32", "versions": "0+", "default": "-1" },
+    { "name": "LeaderEpoch", "type": "int32", "versions": "0+", "default": "-1" },
+    { "name": "VotedId", "type": "int32", "versions": "0+", "default": "-1" },
+    { "name": "VotedDirectoryId", "type": "uuid", "versions": "1+" },
+    { "name": "AppliedOffset", "type": "int64", "versions": "0" },
+    { "name": "CurrentVoters", "type": "[]Voter", "versions": "0", "nullableVersions": "0", "fields": [
+      { "name": "VoterId", "type": "int32", "versions": "0" }
     ]}
   ]
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java
index 71a2375..524a93f 100644
--- a/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java
@@ -19,13 +19,18 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.apache.kafka.raft.internals.VoterSet;
+import org.apache.kafka.raft.internals.VoterSetTest;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
+import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -33,22 +38,20 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class CandidateStateTest {
-    private final int localId = 0;
+    private final VoterSet.VoterNode localNode = VoterSetTest.voterNode(0, true);
     private final int epoch = 5;
     private final MockTime time = new MockTime();
     private final int electionTimeoutMs = 5000;
     private final LogContext logContext = new LogContext();
 
-    private CandidateState newCandidateState(
-            Set<Integer> voters,
-            Optional<LogOffsetMetadata> highWatermark
-    ) {
+    private CandidateState newCandidateState(VoterSet voters) {
         return new CandidateState(
                 time,
-                localId,
+                localNode.voterKey().id(),
+                localNode.voterKey().directoryId().get(),
                 epoch,
                 voters,
-                highWatermark,
+                Optional.empty(),
                 0,
                 electionTimeoutMs,
                 logContext
@@ -57,7 +60,7 @@
 
     @Test
     public void testSingleNodeQuorum() {
-        CandidateState state = newCandidateState(Collections.singleton(localId), Optional.empty());
+        CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList()));
         assertTrue(state.isVoteGranted());
         assertFalse(state.isVoteRejected());
         assertEquals(Collections.emptySet(), state.unrecordedVoters());
@@ -66,7 +69,9 @@
     @Test
     public void testTwoNodeQuorumVoteRejected() {
         int otherNodeId = 1;
-        CandidateState state = newCandidateState(Utils.mkSet(localId, otherNodeId), Optional.empty());
+        CandidateState state = newCandidateState(
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertFalse(state.isVoteGranted());
         assertFalse(state.isVoteRejected());
         assertEquals(Collections.singleton(otherNodeId), state.unrecordedVoters());
@@ -79,7 +84,8 @@
     public void testTwoNodeQuorumVoteGranted() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertFalse(state.isVoteGranted());
         assertFalse(state.isVoteRejected());
         assertEquals(Collections.singleton(otherNodeId), state.unrecordedVoters());
@@ -94,7 +100,8 @@
         int node1 = 1;
         int node2 = 2;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, node1, node2), Optional.empty());
+            voterSetWithLocal(Arrays.asList(node1, node2))
+        );
         assertFalse(state.isVoteGranted());
         assertFalse(state.isVoteRejected());
         assertEquals(Utils.mkSet(node1, node2), state.unrecordedVoters());
@@ -113,7 +120,8 @@
         int node1 = 1;
         int node2 = 2;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, node1, node2), Optional.empty());
+            voterSetWithLocal(Arrays.asList(node1, node2))
+        );
         assertFalse(state.isVoteGranted());
         assertFalse(state.isVoteRejected());
         assertEquals(Utils.mkSet(node1, node2), state.unrecordedVoters());
@@ -131,15 +139,20 @@
     public void testCannotRejectVoteFromLocalId() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
-        assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(localId));
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> state.recordRejectedVote(localNode.voterKey().id())
+        );
     }
 
     @Test
     public void testCannotChangeVoteGrantedToRejected() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertTrue(state.recordGrantedVote(otherNodeId));
         assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId));
         assertTrue(state.isVoteGranted());
@@ -149,7 +162,8 @@
     public void testCannotChangeVoteRejectedToGranted() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertTrue(state.recordRejectedVote(otherNodeId));
         assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId));
         assertTrue(state.isVoteRejected());
@@ -158,8 +172,7 @@
     @Test
     public void testCannotGrantOrRejectNonVoters() {
         int nonVoterId = 1;
-        CandidateState state = newCandidateState(
-            Collections.singleton(localId), Optional.empty());
+        CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList()));
         assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId));
         assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId));
     }
@@ -168,7 +181,8 @@
     public void testIdempotentGrant() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertTrue(state.recordGrantedVote(otherNodeId));
         assertFalse(state.recordGrantedVote(otherNodeId));
     }
@@ -177,7 +191,8 @@
     public void testIdempotentReject() {
         int otherNodeId = 1;
         CandidateState state = newCandidateState(
-            Utils.mkSet(localId, otherNodeId), Optional.empty());
+            voterSetWithLocal(Collections.singletonList(otherNodeId))
+        );
         assertTrue(state.recordRejectedVote(otherNodeId));
         assertFalse(state.recordRejectedVote(otherNodeId));
     }
@@ -186,13 +201,41 @@
     @ValueSource(booleans = {true, false})
     public void testGrantVote(boolean isLogUpToDate) {
         CandidateState state = newCandidateState(
-            Utils.mkSet(1, 2, 3),
-            Optional.empty()
+            voterSetWithLocal(Arrays.asList(1, 2, 3))
         );
 
-        assertFalse(state.canGrantVote(1, isLogUpToDate));
-        assertFalse(state.canGrantVote(2, isLogUpToDate));
-        assertFalse(state.canGrantVote(3, isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(0, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(1, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(2, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
     }
 
+    @Test
+    public void testElectionState() {
+        VoterSet voters = voterSetWithLocal(Arrays.asList(1, 2, 3));
+        CandidateState state = newCandidateState(voters);
+        assertEquals(
+            ElectionState.withVotedCandidate(
+                epoch,
+                localNode.voterKey(),
+                voters.voterIds()
+            ),
+            state.election()
+        );
+    }
+
+    @Test
+    public void testInvalidVoterSet() {
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))
+        );
+    }
+
+    private VoterSet voterSetWithLocal(Collection<Integer> remoteVoters) {
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(remoteVoters, true);
+        voterMap.put(localNode.voterKey().id(), localNode);
+
+        return VoterSetTest.voterSet(voterMap);
+    }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/ElectionStateTest.java b/raft/src/test/java/org/apache/kafka/raft/ElectionStateTest.java
new file mode 100644
index 0000000..c0b135c
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/ElectionStateTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.generated.QuorumStateData;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+final class ElectionStateTest {
+    @Test
+    void testVotedCandidateWithoutVotedId() {
+        ElectionState electionState = ElectionState.withUnknownLeader(5, Collections.emptySet());
+        assertFalse(electionState.isVotedCandidate(ReplicaKey.of(1, Optional.empty())));
+    }
+
+    @Test
+    void testVotedCandidateWithoutVotedDirectoryId() {
+        ElectionState electionState = ElectionState.withVotedCandidate(
+            5,
+            ReplicaKey.of(1, Optional.empty()),
+            Collections.emptySet()
+        );
+        assertTrue(electionState.isVotedCandidate(ReplicaKey.of(1, Optional.empty())));
+        assertTrue(
+            electionState.isVotedCandidate(ReplicaKey.of(1, Optional.of(Uuid.randomUuid())))
+        );
+    }
+
+    @Test
+    void testVotedCandidateWithVotedDirectoryId() {
+        ReplicaKey votedKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid()));
+        ElectionState electionState = ElectionState.withVotedCandidate(
+            5,
+            votedKey,
+            Collections.emptySet()
+        );
+        assertFalse(electionState.isVotedCandidate(ReplicaKey.of(1, Optional.empty())));
+        assertTrue(electionState.isVotedCandidate(votedKey));
+    }
+
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    void testQuorumStateDataRoundTrip(short version) {
+        ReplicaKey votedKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid()));
+        List<ElectionState> electionStates = Arrays.asList(
+            ElectionState.withUnknownLeader(5, Utils.mkSet(1, 2, 3)),
+            ElectionState.withElectedLeader(5, 1, Utils.mkSet(1, 2, 3)),
+            ElectionState.withVotedCandidate(5, votedKey, Utils.mkSet(1, 2, 3))
+        );
+
+        final List<ElectionState> expected;
+        if (version == 0) {
+            expected = Arrays.asList(
+                ElectionState.withUnknownLeader(5, Utils.mkSet(1, 2, 3)),
+                ElectionState.withElectedLeader(5, 1, Utils.mkSet(1, 2, 3)),
+                ElectionState.withVotedCandidate(
+                    5,
+                    ReplicaKey.of(1, Optional.empty()),
+                    Utils.mkSet(1, 2, 3)
+                )
+            );
+        } else {
+            expected = Arrays.asList(
+                ElectionState.withUnknownLeader(5, Collections.emptySet()),
+                ElectionState.withElectedLeader(5, 1, Collections.emptySet()),
+                ElectionState.withVotedCandidate(5, votedKey, Collections.emptySet())
+            );
+        }
+
+        int expectedId = 0;
+        for (ElectionState electionState : electionStates) {
+            QuorumStateData data = electionState.toQuorumStateData(version);
+            assertEquals(expected.get(expectedId), ElectionState.fromQuorumStateData(data));
+            expectedId++;
+        }
+    }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java b/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java
deleted file mode 100644
index c6c1f6c..0000000
--- a/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java
+++ /dev/null
@@ -1,166 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.raft;
-
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.kafka.common.errors.UnsupportedVersionException;
-import org.apache.kafka.common.protocol.types.TaggedFields;
-import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.raft.generated.QuorumStateData;
-import org.apache.kafka.test.TestUtils;
-import org.junit.jupiter.api.AfterEach;
-
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.io.UncheckedIOException;
-import java.nio.charset.StandardCharsets;
-import java.util.OptionalInt;
-import java.util.Set;
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class FileBasedStateStoreTest {
-
-    private FileBasedStateStore stateStore;
-
-    @Test
-    public void testReadElectionState() throws IOException {
-        final File stateFile = TestUtils.tempFile();
-
-        stateStore = new FileBasedStateStore(stateFile);
-
-        final int leaderId = 1;
-        final int epoch = 2;
-        Set<Integer> voters = Utils.mkSet(leaderId);
-
-        stateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters));
-        assertTrue(stateFile.exists());
-        assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), stateStore.readElectionState());
-
-        // Start another state store and try to read from the same file.
-        final FileBasedStateStore secondStateStore = new FileBasedStateStore(stateFile);
-        assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), secondStateStore.readElectionState());
-    }
-
-    @Test
-    public void testWriteElectionState() throws IOException {
-        final File stateFile = TestUtils.tempFile();
-
-        stateStore = new FileBasedStateStore(stateFile);
-
-        // We initialized a state from the metadata log
-        assertTrue(stateFile.exists());
-
-        // The temp file should be removed
-        final File createdTempFile = new File(stateFile.getAbsolutePath() + ".tmp");
-        assertFalse(createdTempFile.exists());
-
-        final int epoch = 2;
-        final int leaderId = 1;
-        final int votedId = 5;
-        Set<Integer> voters = Utils.mkSet(leaderId, votedId);
-
-        stateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters));
-
-        assertEquals(stateStore.readElectionState(), new ElectionState(epoch,
-            OptionalInt.of(leaderId), OptionalInt.empty(), voters));
-
-        stateStore.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters));
-
-        assertEquals(stateStore.readElectionState(), new ElectionState(epoch,
-            OptionalInt.empty(), OptionalInt.of(votedId), voters));
-
-        final FileBasedStateStore rebootStateStore = new FileBasedStateStore(stateFile);
-
-        assertEquals(rebootStateStore.readElectionState(), new ElectionState(epoch,
-            OptionalInt.empty(), OptionalInt.of(votedId), voters));
-
-        stateStore.clear();
-        assertFalse(stateFile.exists());
-    }
-
-    @Test
-    public void testCantReadVersionQuorumState() throws IOException {
-        String jsonString = "{\"leaderId\":9990,\"leaderEpoch\":3012,\"votedId\":-1," +
-                "\"appliedOffset\": 0,\"currentVoters\":[{\"voterId\":9990},{\"voterId\":9991},{\"voterId\":9992}]," +
-                "\"data_version\":2}";
-        assertCantReadQuorumStateVersion(jsonString);
-    }
-
-    @Test
-    public void testSupportedVersion() {
-        // If the next few checks fail, please check that they are compatible with previous releases of KRaft
-
-        // Check that FileBasedStateStore supports the latest version
-        assertEquals(FileBasedStateStore.HIGHEST_SUPPORTED_VERSION, QuorumStateData.HIGHEST_SUPPORTED_VERSION);
-        // Check that the supported versions haven't changed
-        assertEquals(0, QuorumStateData.HIGHEST_SUPPORTED_VERSION);
-        assertEquals(0, QuorumStateData.LOWEST_SUPPORTED_VERSION);
-        // For the latest version check that the number of tagged fields hasn't changed
-        TaggedFields taggedFields = (TaggedFields) QuorumStateData.SCHEMA_0.get(6).def.type;
-        assertEquals(0, taggedFields.numFields());
-    }
-
-    public void assertCantReadQuorumStateVersion(String jsonString) throws IOException {
-        final File stateFile = TestUtils.tempFile();
-        stateStore = new FileBasedStateStore(stateFile);
-
-        // We initialized a state from the metadata log
-        assertTrue(stateFile.exists());
-
-        writeToStateFile(stateFile, jsonString);
-
-        assertThrows(UnsupportedVersionException.class, () -> stateStore.readElectionState());
-
-        stateStore.clear();
-        assertFalse(stateFile.exists());
-    }
-
-    private void writeToStateFile(final File stateFile, String jsonString) {
-        try (final FileOutputStream fileOutputStream = new FileOutputStream(stateFile);
-             final BufferedWriter writer = new BufferedWriter(
-                     new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) {
-            ObjectMapper mapper = new ObjectMapper();
-            JsonNode node = mapper.readTree(jsonString);
-
-            writer.write(node.toString());
-            writer.flush();
-            fileOutputStream.getFD().sync();
-
-        } catch (IOException e) {
-            throw new UncheckedIOException(
-                    String.format("Error while writing to Quorum state file %s",
-                            stateFile.getAbsolutePath()), e);
-        }
-    }
-
-
-    @AfterEach
-    public void cleanup() throws IOException {
-        if (stateStore != null) {
-            stateStore.clear();
-        }
-    }
-}
diff --git a/raft/src/test/java/org/apache/kafka/raft/FileQuorumStateStoreTest.java b/raft/src/test/java/org/apache/kafka/raft/FileQuorumStateStoreTest.java
new file mode 100644
index 0000000..d7ed2c8
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/FileQuorumStateStoreTest.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.protocol.types.TaggedFields;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.generated.QuorumStateData;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.apache.kafka.test.TestUtils;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.Optional;
+import java.util.Set;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class FileQuorumStateStoreTest {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    void testWriteReadElectedLeader(short kraftVersion) throws IOException {
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(TestUtils.tempFile());
+
+        final int epoch = 2;
+        final int voter1 = 1;
+        final int voter2 = 2;
+        final int voter3 = 3;
+        Set<Integer> voters = Utils.mkSet(voter1, voter2, voter3);
+
+        stateStore.writeElectionState(
+            ElectionState.withElectedLeader(epoch, voter1, voters),
+            kraftVersion
+        );
+
+        final Optional<ElectionState> expected;
+        if (kraftVersion == 1) {
+            expected = Optional.of(
+                ElectionState.withElectedLeader(epoch, voter1, Collections.emptySet())
+            );
+        } else {
+            expected = Optional.of(ElectionState.withElectedLeader(epoch, voter1, voters));
+        }
+
+        assertEquals(expected, stateStore.readElectionState());
+
+        stateStore.clear();
+    }
+
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    void testWriteReadVotedCandidate(short kraftVersion) throws IOException {
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(TestUtils.tempFile());
+
+        final int epoch = 2;
+        final int voter1 = 1;
+        final Optional<Uuid> voter1DirectoryId = Optional.of(Uuid.randomUuid());
+        final ReplicaKey voter1Key = ReplicaKey.of(voter1, voter1DirectoryId);
+        final int voter2 = 2;
+        final int voter3 = 3;
+        Set<Integer> voters = Utils.mkSet(voter1, voter2, voter3);
+
+        stateStore.writeElectionState(
+            ElectionState.withVotedCandidate(epoch, voter1Key, voters),
+            kraftVersion
+        );
+
+        final Optional<ElectionState> expected;
+        if (kraftVersion == 1) {
+            expected = Optional.of(
+                ElectionState.withVotedCandidate(
+                    epoch,
+                    voter1Key,
+                    Collections.emptySet()
+                )
+            );
+        } else {
+            expected = Optional.of(
+                ElectionState.withVotedCandidate(
+                    epoch,
+                    ReplicaKey.of(voter1, Optional.empty()),
+                    voters
+                )
+            );
+        }
+
+        assertEquals(expected, stateStore.readElectionState());
+        stateStore.clear();
+    }
+
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    void testWriteReadUnknownLeader(short kraftVersion) throws IOException {
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(TestUtils.tempFile());
+
+        final int epoch = 2;
+        Set<Integer> voters = Utils.mkSet(1, 2, 3);
+
+        stateStore.writeElectionState(
+            ElectionState.withUnknownLeader(epoch, voters),
+            kraftVersion
+        );
+
+        final Optional<ElectionState> expected;
+        if (kraftVersion == 1) {
+            expected = Optional.of(ElectionState.withUnknownLeader(epoch, Collections.emptySet()));
+        } else {
+            expected = Optional.of(ElectionState.withUnknownLeader(epoch, voters));
+        }
+
+        assertEquals(expected, stateStore.readElectionState());
+        stateStore.clear();
+    }
+
+    @Test
+    void testReload()  throws IOException {
+        final File stateFile = TestUtils.tempFile();
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(stateFile);
+
+        final int epoch = 2;
+        Set<Integer> voters = Utils.mkSet(1, 2, 3);
+
+        stateStore.writeElectionState(ElectionState.withUnknownLeader(epoch, voters), (short) 1);
+
+        // Check that state is persisted
+        FileQuorumStateStore reloadedStore = new FileQuorumStateStore(stateFile);
+        assertEquals(
+            Optional.of(ElectionState.withUnknownLeader(epoch, Collections.emptySet())),
+            reloadedStore.readElectionState()
+        );
+    }
+
+    @Test
+    void testCreateAndClear() throws IOException {
+        final File stateFile = TestUtils.tempFile();
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(stateFile);
+
+        // We initialized a state from the metadata log
+        assertTrue(stateFile.exists());
+
+        // The temp file should be removed
+        final File createdTempFile = new File(stateFile.getAbsolutePath() + ".tmp");
+        assertFalse(createdTempFile.exists());
+
+        // Clear delete the state file
+        stateStore.clear();
+        assertFalse(stateFile.exists());
+    }
+
+    @Test
+    public void testCantReadVersionQuorumState() throws IOException {
+        String jsonString = "{\"leaderId\":9990,\"leaderEpoch\":3012,\"votedId\":-1," +
+                "\"appliedOffset\": 0,\"currentVoters\":[{\"voterId\":9990},{\"voterId\":9991},{\"voterId\":9992}]," +
+                "\"data_version\":2}";
+        final File stateFile = TestUtils.tempFile();
+        writeToStateFile(stateFile, jsonString);
+
+        FileQuorumStateStore stateStore = new FileQuorumStateStore(stateFile);
+        assertThrows(IllegalStateException.class, stateStore::readElectionState);
+
+        stateStore.clear();
+    }
+
+    @Test
+    public void testSupportedVersion() {
+        // If the next few checks fail, please check that they are compatible with previous releases of KRaft
+
+        // Check that FileQuorumStateStore supports the latest version
+        assertEquals(FileQuorumStateStore.HIGHEST_SUPPORTED_VERSION, QuorumStateData.HIGHEST_SUPPORTED_VERSION);
+        // Check that the supported versions haven't changed
+        assertEquals(1, QuorumStateData.HIGHEST_SUPPORTED_VERSION);
+        assertEquals(0, QuorumStateData.LOWEST_SUPPORTED_VERSION);
+        // For the latest version check that the number of tagged fields hasn't changed
+        TaggedFields taggedFields = (TaggedFields) QuorumStateData.SCHEMA_1.get(4).def.type;
+        assertEquals(0, taggedFields.numFields());
+    }
+
+    private void writeToStateFile(final File stateFile, String jsonString) {
+        try (final FileOutputStream fileOutputStream = new FileOutputStream(stateFile);
+             final BufferedWriter writer = new BufferedWriter(
+                     new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) {
+            ObjectMapper mapper = new ObjectMapper();
+            JsonNode node = mapper.readTree(jsonString);
+
+            writer.write(node.toString());
+            writer.flush();
+            fileOutputStream.getFD().sync();
+
+        } catch (IOException e) {
+            throw new UncheckedIOException(
+                    String.format("Error while writing to Quorum state file %s",
+                            stateFile.getAbsolutePath()), e);
+        }
+    }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java
index 42c6bc9..1894472 100644
--- a/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java
@@ -19,6 +19,7 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
@@ -90,9 +91,9 @@
             Optional.empty()
         );
 
-        assertFalse(state.canGrantVote(1, isLogUpToDate));
-        assertFalse(state.canGrantVote(2, isLogUpToDate));
-        assertFalse(state.canGrantVote(3, isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(1, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(2, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
     }
 
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
index 1b1d8fc..c531e58 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -43,6 +43,7 @@
 import org.apache.kafka.raft.errors.BufferAllocationException;
 import org.apache.kafka.raft.errors.NotLeaderException;
 import org.apache.kafka.raft.errors.UnexpectedBaseOffsetException;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
@@ -134,7 +135,7 @@
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0))
-            .withVotedCandidate(epoch, localId)
+            .withVotedCandidate(epoch, ReplicaKey.of(localId, Optional.empty()))
             .build();
 
         assertEquals(0L, context.log.endOffset().offset);
@@ -185,7 +186,7 @@
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0))
-            .withVotedCandidate(epoch, localId)
+            .withVotedCandidate(epoch, ReplicaKey.of(localId, Optional.empty()))
             .build();
 
         // Resign from candidate, will restart in candidate state
@@ -661,7 +662,7 @@
         Set<Integer> voters = Utils.mkSet(localId, 1, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .withVotedCandidate(2, localId)
+            .withVotedCandidate(2, ReplicaKey.of(localId, Optional.empty()))
             .build();
         context.assertVotedCandidate(2, localId);
         assertEquals(0L, context.log.endOffset().offset);
@@ -759,7 +760,7 @@
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .withVotedCandidate(votedCandidateEpoch, otherNodeId)
+            .withVotedCandidate(votedCandidateEpoch, ReplicaKey.of(otherNodeId, Optional.empty()))
             .build();
 
         context.deliverRequest(context.beginEpochRequest(votedCandidateEpoch, otherNodeId));
@@ -1166,7 +1167,7 @@
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId, votedCandidateId);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .withVotedCandidate(epoch, votedCandidateId)
+            .withVotedCandidate(epoch, ReplicaKey.of(votedCandidateId, Optional.empty()))
             .build();
 
         context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1));
@@ -1209,8 +1210,8 @@
         context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 1));
         context.pollUntilResponse();
 
-        context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.empty(), false);
-        context.assertUnknownLeader(epoch);
+        context.assertSentVoteResponse(Errors.NONE, epoch + 1, OptionalInt.empty(), true);
+        context.assertVotedCandidate(epoch + 1, otherNodeId);
     }
 
     @Test
@@ -1307,7 +1308,7 @@
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .withVotedCandidate(leaderEpoch, localId)
+            .withVotedCandidate(leaderEpoch, ReplicaKey.of(localId, Optional.empty()))
             .build();
 
         context.pollUntilRequest();
@@ -1690,7 +1691,7 @@
     }
 
     @Test
-    public void testVoterOnlyRequestValidation() throws Exception {
+    public void testLeaderAcceptVoteFromNonVoter() throws Exception {
         int localId = 0;
         int otherNodeId = 1;
         int epoch = 5;
@@ -1699,19 +1700,13 @@
         RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
 
         int nonVoterId = 2;
+        context.deliverRequest(context.voteRequest(epoch - 1, nonVoterId, 0, 0));
+        context.client.poll();
+        context.assertSentVoteResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.of(localId), false);
+
         context.deliverRequest(context.voteRequest(epoch, nonVoterId, 0, 0));
         context.client.poll();
-        context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId), false);
-
-        context.deliverRequest(context.beginEpochRequest(epoch, nonVoterId));
-        context.client.poll();
-        context.assertSentBeginQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId));
-
-        context.deliverRequest(context.endEpochRequest(epoch, nonVoterId, Collections.singletonList(otherNodeId)));
-        context.client.poll();
-
-        // The sent request has no localId as a preferable voter.
-        context.assertSentEndQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId));
+        context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false);
     }
 
     @Test
diff --git a/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
index bd39d0b..e8fd1bb 100644
--- a/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
@@ -21,6 +21,7 @@
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.internals.BatchAccumulator;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
@@ -544,9 +545,9 @@
     public void testGrantVote(boolean isLogUpToDate) {
         LeaderState<?> state = newLeaderState(Utils.mkSet(1, 2, 3), 1);
 
-        assertFalse(state.canGrantVote(1, isLogUpToDate));
-        assertFalse(state.canGrantVote(2, isLogUpToDate));
-        assertFalse(state.canGrantVote(3, isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(1, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(2, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
     }
 
     private static class MockOffsetMetadata implements OffsetMetadata {
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java b/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java
index 87f7c0d..0a94a21 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java
@@ -18,18 +18,22 @@
 
 import java.nio.file.FileSystems;
 import java.nio.file.Path;
+import java.util.Optional;
+import org.apache.kafka.raft.generated.QuorumStateData;
 
 public class MockQuorumStateStore implements QuorumStateStore {
-    private ElectionState current;
+    private Optional<QuorumStateData> current = Optional.empty();
 
     @Override
-    public ElectionState readElectionState() {
-        return current;
+    public Optional<ElectionState> readElectionState() {
+        return current.map(ElectionState::fromQuorumStateData);
     }
 
     @Override
-    public void writeElectionState(ElectionState update) {
-        this.current = update;
+    public void writeElectionState(ElectionState update, short kraftVersion) {
+        current = Optional.of(
+            update.toQuorumStateData(quorumStateVersionFromKRaftVersion(kraftVersion))
+        );
     }
 
     @Override
@@ -39,6 +43,18 @@
 
     @Override
     public void clear() {
-        current = null;
+        current = Optional.empty();
+    }
+
+    private short quorumStateVersionFromKRaftVersion(short kraftVersion) {
+        if (kraftVersion == 0) {
+            return 0;
+        } else if (kraftVersion == 1) {
+            return 1;
+        } else {
+            throw new IllegalArgumentException(
+                String.format("Unknown kraft.version %d", kraftVersion)
+            );
+        }
     }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java
index ce33328..08acba1 100644
--- a/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java
@@ -16,11 +16,15 @@
  */
 package org.apache.kafka.raft;
 
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.internals.BatchAccumulator;
-import org.junit.jupiter.api.Test;
+import org.apache.kafka.raft.internals.ReplicaKey;
+import org.apache.kafka.raft.internals.VoterSetTest;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.Mockito;
 
 import java.io.UncheckedIOException;
@@ -32,12 +36,16 @@
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class QuorumStateTest {
     private final int localId = 0;
+    private final Uuid localDirectoryId = Uuid.randomUuid();
+    private final ReplicaKey localVoterKey = ReplicaKey.of(
+        localId,
+        Optional.of(localDirectoryId)
+    );
     private final int logEndEpoch = 0;
     private final MockQuorumStateStore store = new MockQuorumStateStore();
     private final MockTime time = new MockTime();
@@ -46,17 +54,20 @@
     private final MockableRandom random = new MockableRandom(1L);
     private final BatchAccumulator<?> accumulator = Mockito.mock(BatchAccumulator.class);
 
-    private QuorumState buildQuorumState(Set<Integer> voters) {
-        return buildQuorumState(OptionalInt.of(localId), voters);
+    private QuorumState buildQuorumState(Set<Integer> voters, short kraftVersion) {
+        return buildQuorumState(OptionalInt.of(localId), voters, kraftVersion);
     }
 
     private QuorumState buildQuorumState(
         OptionalInt localId,
-        Set<Integer> voters
+        Set<Integer> voters,
+        short kraftVersion
     ) {
         return new QuorumState(
             localId,
-            voters,
+            localDirectoryId,
+            () -> VoterSetTest.voterSet(VoterSetTest.voterMap(voters, false)),
+            () -> kraftVersion,
             electionTimeoutMs,
             fetchTimeoutMs,
             store,
@@ -66,12 +77,13 @@
         );
     }
 
-    @Test
-    public void testInitializePrimordialEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializePrimordialEpoch(short kraftVersion) {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         assertTrue(state.isUnattached());
         assertEquals(0, state.epoch());
         state.transitionToCandidate();
@@ -80,18 +92,19 @@
         assertEquals(1, candidateState.epoch());
     }
 
-    @Test
-    public void testInitializeAsUnattached() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeAsUnattached(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        store.writeElectionState(ElectionState.withUnknownLeader(epoch, voters));
+        store.writeElectionState(ElectionState.withUnknownLeader(epoch, voters), kraftVersion);
 
         int jitterMs = 2500;
         random.mockNextInt(jitterMs);
 
-        QuorumState state = buildQuorumState(voters);
+        QuorumState state = buildQuorumState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
 
         assertTrue(state.isUnattached());
@@ -101,15 +114,16 @@
             unattachedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testInitializeAsFollower() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeAsFollower(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        store.writeElectionState(ElectionState.withElectedLeader(epoch, node1, voters));
+        store.writeElectionState(ElectionState.withElectedLeader(epoch, node1, voters), kraftVersion);
 
-        QuorumState state = buildQuorumState(voters);
+        QuorumState state = buildQuorumState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isFollower());
         assertEquals(epoch, state.epoch());
@@ -120,64 +134,86 @@
         assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testInitializeAsVoted() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeAsVoted(short kraftVersion) {
         int node1 = 1;
+        Optional<Uuid> node1DirectoryId = Optional.of(Uuid.randomUuid());
         int node2 = 2;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        store.writeElectionState(ElectionState.withVotedCandidate(epoch, node1, voters));
+        store.writeElectionState(
+            ElectionState.withVotedCandidate(epoch, ReplicaKey.of(node1, node1DirectoryId), voters),
+            kraftVersion
+        );
 
         int jitterMs = 2500;
         random.mockNextInt(jitterMs);
 
-        QuorumState state = buildQuorumState(voters);
+        QuorumState state = buildQuorumState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isVoted());
         assertEquals(epoch, state.epoch());
 
         VotedState votedState = state.votedStateOrThrow();
         assertEquals(epoch, votedState.epoch());
-        assertEquals(node1, votedState.votedId());
-        assertEquals(electionTimeoutMs + jitterMs,
-            votedState.remainingElectionTimeMs(time.milliseconds()));
+        assertEquals(
+            ReplicaKey.of(node1, persistedDirectoryId(node1DirectoryId, kraftVersion)),
+            votedState.votedKey()
+        );
+
+        assertEquals(
+            electionTimeoutMs + jitterMs,
+            votedState.remainingElectionTimeMs(time.milliseconds())
+        );
     }
 
-    @Test
-    public void testInitializeAsResignedCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeAsResignedCandidate(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        ElectionState election = ElectionState.withVotedCandidate(epoch, localId, voters);
-        store.writeElectionState(election);
+        ElectionState election = ElectionState.withVotedCandidate(
+            epoch,
+            localVoterKey,
+            voters
+        );
+        store.writeElectionState(election, kraftVersion);
 
         int jitterMs = 2500;
         random.mockNextInt(jitterMs);
 
-        QuorumState state = buildQuorumState(voters);
+        QuorumState state = buildQuorumState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isCandidate());
         assertEquals(epoch, state.epoch());
 
         CandidateState candidateState = state.candidateStateOrThrow();
         assertEquals(epoch, candidateState.epoch());
-        assertEquals(election, candidateState.election());
+        assertEquals(
+            ElectionState.withVotedCandidate(epoch, localVoterKey, voters),
+            candidateState.election()
+        );
         assertEquals(Utils.mkSet(node1, node2), candidateState.unrecordedVoters());
         assertEquals(Utils.mkSet(localId), candidateState.grantingVoters());
         assertEquals(Collections.emptySet(), candidateState.rejectingVoters());
-        assertEquals(electionTimeoutMs + jitterMs,
-            candidateState.remainingElectionTimeMs(time.milliseconds()));
+        assertEquals(
+            electionTimeoutMs + jitterMs,
+            candidateState.remainingElectionTimeMs(time.milliseconds())
+        );
     }
 
-    @Test
-    public void testInitializeAsResignedLeader() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeAsResignedLeader(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
         ElectionState election = ElectionState.withElectedLeader(epoch, localId, voters);
-        store.writeElectionState(election);
+        store.writeElectionState(election, kraftVersion);
 
         // If we were previously a leader, we will start as resigned in order to ensure
         // a new leader gets elected. This ensures that records are always uniquely
@@ -187,7 +223,7 @@
         int jitterMs = 2500;
         random.mockNextInt(jitterMs);
 
-        QuorumState state = buildQuorumState(voters);
+        QuorumState state = buildQuorumState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertFalse(state.isLeader());
         assertEquals(epoch, state.epoch());
@@ -200,14 +236,15 @@
             resignedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testCandidateToCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToCandidate(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.transitionToCandidate();
         assertTrue(state.isCandidate());
         assertEquals(1, state.epoch());
@@ -243,14 +280,15 @@
             candidate2.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testCandidateToResigned() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToResigned(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.transitionToCandidate();
         assertTrue(state.isCandidate());
         assertEquals(1, state.epoch());
@@ -260,28 +298,30 @@
         assertTrue(state.isCandidate());
     }
 
-    @Test
-    public void testCandidateToLeader()  {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToLeader(short kraftVersion)  {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.transitionToCandidate();
         assertTrue(state.isCandidate());
         assertEquals(1, state.epoch());
 
         state.transitionToLeader(0L, accumulator);
-        LeaderState<Object> leaderState = state.leaderStateOrThrow();
+        LeaderState<?> leaderState = state.leaderStateOrThrow();
         assertTrue(state.isLeader());
         assertEquals(1, leaderState.epoch());
         assertEquals(Optional.empty(), leaderState.highWatermark());
     }
 
-    @Test
-    public void testCandidateToLeaderWithoutGrantedVote() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToLeaderWithoutGrantedVote(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         assertFalse(state.candidateStateOrThrow().isVoteGranted());
@@ -292,72 +332,111 @@
         assertTrue(state.isLeader());
     }
 
-    @Test
-    public void testCandidateToFollower() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToFollower(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
 
         state.transitionToFollower(5, otherNodeId);
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.of(otherNodeId), state.leaderId());
-        assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testCandidateToUnattached() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToUnattached(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
 
         state.transitionToUnattached(5);
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.empty(), state.leaderId());
-        assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testCandidateToVoted() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToVoted(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
 
-        state.transitionToVoted(5, otherNodeId);
+        state.transitionToVoted(5, otherNodeKey);
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.empty(), state.leaderId());
 
         VotedState followerState = state.votedStateOrThrow();
-        assertEquals(otherNodeId, followerState.votedId());
-        assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(otherNodeKey, followerState.votedKey());
+
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    5,
+                    ReplicaKey.of(
+                        otherNodeId,
+                        persistedDirectoryId(otherNodeDirectoryId, kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion))
+            ),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testCandidateToAnyStateLowerEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCandidateToAnyStateLowerEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
         state.transitionToCandidate();
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId));
+        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey));
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId));
         assertEquals(6, state.epoch());
-        assertEquals(ElectionState.withVotedCandidate(6, localId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    6,
+                    ReplicaKey.of(
+                        localId,
+                        persistedDirectoryId(Optional.of(localDirectoryId), kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion)
+                )
+            ),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testLeaderToLeader() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToLeader(short kraftVersion) {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         state.transitionToLeader(0L, accumulator);
@@ -369,12 +448,13 @@
         assertEquals(1, state.epoch());
     }
 
-    @Test
-    public void testLeaderToResigned() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToResigned(short kraftVersion) {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         state.transitionToLeader(0L, accumulator);
@@ -390,12 +470,13 @@
         assertEquals(Collections.emptySet(), resignedState.unackedVoters());
     }
 
-    @Test
-    public void testLeaderToCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToCandidate(short kraftVersion) {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
+        assertEquals(Optional.empty(), store.readElectionState());
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         state.transitionToLeader(0L, accumulator);
@@ -407,12 +488,13 @@
         assertEquals(1, state.epoch());
     }
 
-    @Test
-    public void testLeaderToFollower() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToFollower(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
 
         state.transitionToCandidate();
         state.candidateStateOrThrow().recordGrantedVote(otherNodeId);
@@ -421,14 +503,18 @@
 
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.of(otherNodeId), state.leaderId());
-        assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testLeaderToUnattached() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToUnattached(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         state.candidateStateOrThrow().recordGrantedVote(otherNodeId);
@@ -436,109 +522,177 @@
         state.transitionToUnattached(5);
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.empty(), state.leaderId());
-        assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testLeaderToVoted() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToVoted(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToCandidate();
         state.candidateStateOrThrow().recordGrantedVote(otherNodeId);
         state.transitionToLeader(0L, accumulator);
-        state.transitionToVoted(5, otherNodeId);
+        state.transitionToVoted(5, otherNodeKey);
 
         assertEquals(5, state.epoch());
         assertEquals(OptionalInt.empty(), state.leaderId());
+
         VotedState votedState = state.votedStateOrThrow();
-        assertEquals(otherNodeId, votedState.votedId());
-        assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(otherNodeKey, votedState.votedKey());
+
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    5,
+                    ReplicaKey.of(
+                        otherNodeId,
+                        persistedDirectoryId(otherNodeDirectoryId, kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion)
+                )
+            ),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testLeaderToAnyStateLowerEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testLeaderToAnyStateLowerEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
         state.transitionToCandidate();
         state.candidateStateOrThrow().recordGrantedVote(otherNodeId);
         state.transitionToLeader(0L, accumulator);
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId));
+        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey));
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId));
         assertEquals(6, state.epoch());
-        assertEquals(ElectionState.withElectedLeader(6, localId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(6, localId, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testCannotFollowOrVoteForSelf() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCannotFollowOrVoteForSelf(short kraftVersion) {
         Set<Integer> voters = Utils.mkSet(localId);
-        assertNull(store.readElectionState());
-        QuorumState state = initializeEmptyState(voters);
+        assertEquals(Optional.empty(), store.readElectionState());
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
 
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(0, localId));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(0, localId));
+        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(0, localVoterKey));
     }
 
-    @Test
-    public void testUnattachedToLeaderOrResigned() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToLeaderOrResigned(short kraftVersion) {
         int leaderId = 1;
         int epoch = 5;
         Set<Integer> voters = Utils.mkSet(localId, leaderId);
-        store.writeElectionState(ElectionState.withVotedCandidate(epoch, leaderId, voters));
-        QuorumState state = initializeEmptyState(voters);
+        store.writeElectionState(
+            ElectionState.withVotedCandidate(
+                epoch,
+                ReplicaKey.of(leaderId, Optional.empty()),
+                voters
+            ),
+            kraftVersion
+        );
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isUnattached());
         assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator));
         assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList()));
     }
 
-    @Test
-    public void testUnattachedToVotedSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToVotedSameEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
 
         int jitterMs = 2500;
         random.mockNextInt(electionTimeoutMs, jitterMs);
-        state.transitionToVoted(5, otherNodeId);
+        state.transitionToVoted(5, otherNodeKey);
 
         VotedState votedState = state.votedStateOrThrow();
         assertEquals(5, votedState.epoch());
-        assertEquals(otherNodeId, votedState.votedId());
-        assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(otherNodeKey, votedState.votedKey());
+
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    5,
+                    ReplicaKey.of(
+                        otherNodeId,
+                        persistedDirectoryId(otherNodeDirectoryId, kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion)
+                )
+            ),
+            store.readElectionState()
+        );
 
         // Verify election timeout is reset when we vote for a candidate
         assertEquals(electionTimeoutMs + jitterMs,
             votedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testUnattachedToVotedHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToVotedHigherEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
-        state.transitionToVoted(8, otherNodeId);
+        state.transitionToVoted(8, otherNodeKey);
 
         VotedState votedState = state.votedStateOrThrow();
         assertEquals(8, votedState.epoch());
-        assertEquals(otherNodeId, votedState.votedId());
-        assertEquals(ElectionState.withVotedCandidate(8, otherNodeId, voters), store.readElectionState());
+        assertEquals(otherNodeKey, votedState.votedKey());
+
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    8,
+                    ReplicaKey.of(
+                        otherNodeId,
+                        persistedDirectoryId(otherNodeDirectoryId, kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion)
+                )
+            ),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testUnattachedToCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToCandidate(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
 
@@ -553,11 +707,12 @@
             candidateState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testUnattachedToUnattached() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToUnattached(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
 
@@ -573,11 +728,12 @@
             unattachedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testUnattachedToFollowerSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToFollowerSameEpoch(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
 
@@ -589,11 +745,12 @@
         assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testUnattachedToFollowerHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToFollowerHigherEpoch(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
 
@@ -605,40 +762,47 @@
         assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testUnattachedToAnyStateLowerEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testUnattachedToAnyStateLowerEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, Optional.empty());
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId));
+        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey));
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId));
         assertEquals(5, state.epoch());
-        assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testVotedToInvalidLeaderOrResigned() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToInvalidLeaderOrResigned(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, node1);
+        state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty()));
         assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0, accumulator));
         assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList()));
     }
 
-    @Test
-    public void testVotedToCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToCandidate(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, node1);
+        state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty()));
 
         int jitterMs = 2500;
         random.mockNextInt(electionTimeoutMs, jitterMs);
@@ -650,69 +814,86 @@
             candidateState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testVotedToVotedSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToVotedSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToUnattached(5);
-        state.transitionToVoted(8, node1);
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node1));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node2));
+        state.transitionToVoted(8, ReplicaKey.of(node1, Optional.of(Uuid.randomUuid())));
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(8, ReplicaKey.of(node1, Optional.empty()))
+        );
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(8, ReplicaKey.of(node2, Optional.empty()))
+        );
     }
 
-    @Test
-    public void testVotedToFollowerSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToFollowerSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, node1);
+        state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty()));
         state.transitionToFollower(5, node2);
 
         FollowerState followerState = state.followerStateOrThrow();
         assertEquals(5, followerState.epoch());
         assertEquals(node2, followerState.leaderId());
-        assertEquals(ElectionState.withElectedLeader(5, node2, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(5, node2, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testVotedToFollowerHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToFollowerHigherEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, node1);
+        state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty()));
         state.transitionToFollower(8, node2);
 
         FollowerState followerState = state.followerStateOrThrow();
         assertEquals(8, followerState.epoch());
         assertEquals(node2, followerState.leaderId());
-        assertEquals(ElectionState.withElectedLeader(8, node2, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(8, node2, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testVotedToUnattachedSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToUnattachedSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, node1);
+        state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty()));
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(5));
     }
 
-    @Test
-    public void testVotedToUnattachedHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToUnattachedHigherEpoch(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, otherNodeId);
+        state.transitionToVoted(5, ReplicaKey.of(otherNodeId, Optional.empty()));
 
         long remainingElectionTimeMs = state.votedStateOrThrow().remainingElectionTimeMs(time.milliseconds());
         time.sleep(1000);
@@ -726,26 +907,42 @@
             unattachedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testVotedToAnyStateLowerEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testVotedToAnyStateLowerEpoch(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        state.transitionToVoted(5, otherNodeId);
+        state.transitionToVoted(5, otherNodeKey);
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId));
+        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey));
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId));
         assertEquals(5, state.epoch());
-        assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(
+                ElectionState.withVotedCandidate(
+                    5,
+                    ReplicaKey.of(
+                        otherNodeId,
+                        persistedDirectoryId(otherNodeDirectoryId, kraftVersion)
+                    ),
+                    persistedVoters(voters, kraftVersion)
+                )
+            ),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testFollowerToFollowerSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToFollowerSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(8, node1));
@@ -754,15 +951,19 @@
         FollowerState followerState = state.followerStateOrThrow();
         assertEquals(8, followerState.epoch());
         assertEquals(node2, followerState.leaderId());
-        assertEquals(ElectionState.withElectedLeader(8, node2, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(8, node2, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testFollowerToFollowerHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToFollowerHigherEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
         state.transitionToFollower(9, node1);
@@ -770,27 +971,32 @@
         FollowerState followerState = state.followerStateOrThrow();
         assertEquals(9, followerState.epoch());
         assertEquals(node1, followerState.leaderId());
-        assertEquals(ElectionState.withElectedLeader(9, node1, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(9, node1, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testFollowerToLeaderOrResigned() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToLeaderOrResigned(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
         assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0, accumulator));
         assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList()));
     }
 
-    @Test
-    public void testFollowerToCandidate() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToCandidate(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
 
@@ -804,23 +1010,25 @@
             candidateState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testFollowerToUnattachedSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToUnattachedSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(8));
     }
 
-    @Test
-    public void testFollowerToUnattachedHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToUnattachedHigherEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
 
@@ -834,83 +1042,140 @@
             unattachedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testFollowerToVotedSameEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToVotedSameEpoch(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
 
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node1));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, localId));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node2));
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(8, ReplicaKey.of(node1, Optional.empty()))
+        );
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(8, ReplicaKey.of(localId, Optional.empty()))
+        );
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(8, ReplicaKey.of(node2, Optional.empty()))
+        );
     }
 
-    @Test
-    public void testFollowerToVotedHigherEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToVotedHigherEpoch(short kraftVersion) {
         int node1 = 1;
+        Optional<Uuid> node1DirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey node1Key = ReplicaKey.of(node1, node1DirectoryId);
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(localId, node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(8, node2);
 
         int jitterMs = 2500;
         random.mockNextInt(electionTimeoutMs, jitterMs);
-        state.transitionToVoted(9, node1);
+
+        state.transitionToVoted(9, node1Key);
         assertTrue(state.isVoted());
+
         VotedState votedState = state.votedStateOrThrow();
         assertEquals(9, votedState.epoch());
-        assertEquals(node1, votedState.votedId());
+        assertEquals(node1Key, votedState.votedKey());
+
         assertEquals(electionTimeoutMs + jitterMs,
             votedState.remainingElectionTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testFollowerToAnyStateLowerEpoch() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testFollowerToAnyStateLowerEpoch(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         state.transitionToFollower(5, otherNodeId);
         assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId));
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(4, ReplicaKey.of(otherNodeId, Optional.empty()))
+        );
         assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId));
         assertEquals(5, state.epoch());
-        assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState());
+        assertEquals(
+            Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))),
+            store.readElectionState()
+        );
     }
 
-    @Test
-    public void testCannotBecomeFollowerOfNonVoter() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testCanBecomeFollowerOfNonVoter(short kraftVersion) {
         int otherNodeId = 1;
         int nonVoterId = 2;
+        Optional<Uuid> nonVoterDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey nonVoterKey = ReplicaKey.of(nonVoterId, nonVoterDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, nonVoterId));
-        assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, nonVoterId));
+
+        // Transition to voted
+        state.transitionToVoted(4, nonVoterKey);
+        assertTrue(state.isVoted());
+
+        VotedState votedState = state.votedStateOrThrow();
+        assertEquals(4, votedState.epoch());
+        assertEquals(nonVoterKey, votedState.votedKey());
+
+        // Transition to follower
+        state.transitionToFollower(4, nonVoterId);
+        assertEquals(new LeaderAndEpoch(OptionalInt.of(nonVoterId), 4), state.leaderAndEpoch());
     }
 
-    @Test
-    public void testObserverCannotBecomeCandidateOrLeaderOrVoted() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testObserverCannotBecomeCandidateOrLeader(short kraftVersion) {
         int otherNodeId = 1;
         Set<Integer> voters = Utils.mkSet(otherNodeId);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isObserver());
         assertThrows(IllegalStateException.class, state::transitionToCandidate);
         assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator));
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(5, otherNodeId));
     }
 
-    @Test
-    public void testObserverFollowerToUnattached() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testObserverWithIdCanVote(short kraftVersion) {
+        int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
+        Set<Integer> voters = Utils.mkSet(otherNodeId);
+
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
+        state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
+        assertTrue(state.isObserver());
+
+        state.transitionToVoted(5, otherNodeKey);
+        assertTrue(state.isVoted());
+
+        VotedState votedState = state.votedStateOrThrow();
+        assertEquals(5, votedState.epoch());
+        assertEquals(otherNodeKey, votedState.votedKey());
+    }
+
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testObserverFollowerToUnattached(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isObserver());
 
@@ -924,12 +1189,13 @@
         assertEquals(Long.MAX_VALUE, unattachedState.electionTimeoutMs());
     }
 
-    @Test
-    public void testObserverUnattachedToFollower() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testObserverUnattachedToFollower(short kraftVersion) {
         int node1 = 1;
         int node2 = 2;
         Set<Integer> voters = Utils.mkSet(node1, node2);
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         assertTrue(state.isObserver());
 
@@ -942,12 +1208,13 @@
         assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds()));
     }
 
-    @Test
-    public void testInitializeWithCorruptedStore() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeWithCorruptedStore(short kraftVersion) {
         QuorumStateStore stateStore = Mockito.mock(QuorumStateStore.class);
         Mockito.doThrow(UncheckedIOException.class).when(stateStore).readElectionState();
 
-        QuorumState state = buildQuorumState(Utils.mkSet(localId));
+        QuorumState state = buildQuorumState(Utils.mkSet(localId), kraftVersion);
 
         int epoch = 2;
         state.initialize(new OffsetAndEpoch(0L, epoch));
@@ -956,28 +1223,15 @@
         assertFalse(state.hasLeader());
     }
 
-    @Test
-    public void testInconsistentVotersBetweenConfigAndState() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testHasRemoteLeader(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
-        QuorumState state = initializeEmptyState(voters);
-
-        int unknownVoterId = 2;
-        Set<Integer> stateVoters = Utils.mkSet(localId, otherNodeId, unknownVoterId);
-
-        int epoch = 5;
-        store.writeElectionState(ElectionState.withElectedLeader(epoch, localId, stateVoters));
-        assertThrows(IllegalStateException.class,
-            () -> state.initialize(new OffsetAndEpoch(0L, logEndEpoch)));
-    }
-
-    @Test
-    public void testHasRemoteLeader() {
-        int otherNodeId = 1;
-        Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
-
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         assertFalse(state.hasRemoteLeader());
 
         state.transitionToCandidate();
@@ -990,19 +1244,22 @@
         state.transitionToUnattached(state.epoch() + 1);
         assertFalse(state.hasRemoteLeader());
 
-        state.transitionToVoted(state.epoch() + 1, otherNodeId);
+        state.transitionToVoted(state.epoch() + 1, otherNodeKey);
         assertFalse(state.hasRemoteLeader());
 
         state.transitionToFollower(state.epoch() + 1, otherNodeId);
         assertTrue(state.hasRemoteLeader());
     }
 
-    @Test
-    public void testHighWatermarkRetained() {
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testHighWatermarkRetained(short kraftVersion) {
         int otherNodeId = 1;
+        Optional<Uuid> otherNodeDirectoryId = Optional.of(Uuid.randomUuid());
+        ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId);
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
-        QuorumState state = initializeEmptyState(voters);
+        QuorumState state = initializeEmptyState(voters, kraftVersion);
         state.transitionToFollower(5, otherNodeId);
 
         FollowerState followerState = state.followerStateOrThrow();
@@ -1014,7 +1271,7 @@
         state.transitionToUnattached(6);
         assertEquals(highWatermark, state.highWatermark());
 
-        state.transitionToVoted(7, otherNodeId);
+        state.transitionToVoted(7, otherNodeKey);
         assertEquals(highWatermark, state.highWatermark());
 
         state.transitionToCandidate();
@@ -1028,16 +1285,20 @@
         assertEquals(Optional.empty(), state.highWatermark());
     }
 
-    @Test
-    public void testInitializeWithEmptyLocalId() {
-        QuorumState state = buildQuorumState(OptionalInt.empty(), Utils.mkSet(0, 1));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testInitializeWithEmptyLocalId(short kraftVersion) {
+        QuorumState state = buildQuorumState(OptionalInt.empty(), Utils.mkSet(0, 1), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
 
         assertTrue(state.isObserver());
         assertFalse(state.isVoter());
 
         assertThrows(IllegalStateException.class, state::transitionToCandidate);
-        assertThrows(IllegalStateException.class, () -> state.transitionToVoted(1, 1));
+        assertThrows(
+            IllegalStateException.class,
+            () -> state.transitionToVoted(1, ReplicaKey.of(1, Optional.empty()))
+        );
         assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator));
 
         state.transitionToFollower(1, 1);
@@ -1047,25 +1308,46 @@
         assertTrue(state.isUnattached());
     }
 
-    @Test
-    public void testObserverInitializationFailsIfElectionStateHasVotedCandidate() {
-        Set<Integer> voters = Utils.mkSet(0, 1);
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void testNoLocalIdInitializationFailsIfElectionStateHasVotedCandidate(short kraftVersion) {
         int epoch = 5;
         int votedId = 1;
+        Set<Integer> voters = Utils.mkSet(0, votedId);
 
-        store.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters));
+        store.writeElectionState(
+            ElectionState.withVotedCandidate(
+                epoch,
+                ReplicaKey.of(votedId, Optional.empty()),
+                voters
+            ),
+            kraftVersion
+        );
 
-        QuorumState state1 = buildQuorumState(OptionalInt.of(2), voters);
-        assertThrows(IllegalStateException.class, () -> state1.initialize(new OffsetAndEpoch(0, 0)));
-
-        QuorumState state2 = buildQuorumState(OptionalInt.empty(), voters);
+        QuorumState state2 = buildQuorumState(OptionalInt.empty(), voters, kraftVersion);
         assertThrows(IllegalStateException.class, () -> state2.initialize(new OffsetAndEpoch(0, 0)));
     }
 
-    private QuorumState initializeEmptyState(Set<Integer> voters) {
-        QuorumState state = buildQuorumState(voters);
-        store.writeElectionState(ElectionState.withUnknownLeader(0, voters));
+    private QuorumState initializeEmptyState(Set<Integer> voters, short kraftVersion) {
+        QuorumState state = buildQuorumState(voters, kraftVersion);
+        store.writeElectionState(ElectionState.withUnknownLeader(0, voters), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, logEndEpoch));
         return state;
     }
+
+    private Set<Integer> persistedVoters(Set<Integer> voters, short kraftVersion) {
+        if (kraftVersion == 1) {
+            return Collections.emptySet();
+        }
+
+        return voters;
+    }
+
+    private Optional<Uuid> persistedDirectoryId(Optional<Uuid> directoryId, short kraftVersion) {
+        if (kraftVersion == 1) {
+            return directoryId;
+        }
+
+        return Optional.empty();
+    }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index c797100..5c2ab10 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -54,6 +54,7 @@
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.internals.BatchBuilder;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.apache.kafka.raft.internals.StringSerde;
 import org.apache.kafka.server.common.serialization.RecordSerde;
 import org.apache.kafka.snapshot.RawSnapshotWriter;
@@ -138,6 +139,8 @@
         private final Uuid clusterId = Uuid.randomUuid();
         private final Set<Integer> voters;
         private final OptionalInt localId;
+        private final Uuid localDirectoryId = Uuid.randomUuid();
+        private final short kraftVersion = 0;
 
         private int requestTimeoutMs = DEFAULT_REQUEST_TIMEOUT_MS;
         private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS;
@@ -154,17 +157,26 @@
         }
 
         Builder withElectedLeader(int epoch, int leaderId) {
-            quorumStateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters));
+            quorumStateStore.writeElectionState(
+                ElectionState.withElectedLeader(epoch, leaderId, voters),
+                kraftVersion
+            );
             return this;
         }
 
         Builder withUnknownLeader(int epoch) {
-            quorumStateStore.writeElectionState(ElectionState.withUnknownLeader(epoch, voters));
+            quorumStateStore.writeElectionState(
+                ElectionState.withUnknownLeader(epoch, voters),
+                kraftVersion
+            );
             return this;
         }
 
-        Builder withVotedCandidate(int epoch, int votedId) {
-            quorumStateStore.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters));
+        Builder withVotedCandidate(int epoch, ReplicaKey votedKey) {
+            quorumStateStore.writeElectionState(
+                ElectionState.withVotedCandidate(epoch, votedKey, voters),
+                kraftVersion
+            );
             return this;
         }
 
@@ -247,6 +259,7 @@
 
             KafkaRaftClient<String> client = new KafkaRaftClient<>(
                 localId,
+                localDirectoryId,
                 SERDE,
                 channel,
                 messageQueue,
@@ -392,8 +405,8 @@
     }
 
     LeaderAndEpoch currentLeaderAndEpoch() {
-        ElectionState election = quorumStateStore.readElectionState();
-        return new LeaderAndEpoch(election.leaderIdOpt, election.epoch);
+        ElectionState election = quorumStateStore.readElectionState().get();
+        return new LeaderAndEpoch(election.optionalLeaderId(), election.epoch());
     }
 
     void expectAndGrantVotes(int epoch) throws Exception {
@@ -439,21 +452,37 @@
         pollUntil(channel::hasSentRequests);
     }
 
-    void assertVotedCandidate(int epoch, int leaderId) {
-        assertEquals(ElectionState.withVotedCandidate(epoch, leaderId, voters), quorumStateStore.readElectionState());
+    void assertVotedCandidate(int epoch, int candidateId) {
+        assertEquals(
+            ElectionState.withVotedCandidate(
+                epoch,
+                ReplicaKey.of(candidateId, Optional.empty()),
+                voters
+            ),
+            quorumStateStore.readElectionState().get()
+        );
     }
 
     public void assertElectedLeader(int epoch, int leaderId) {
-        assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState());
+        assertEquals(
+            ElectionState.withElectedLeader(epoch, leaderId, voters),
+            quorumStateStore.readElectionState().get()
+        );
     }
 
     void assertUnknownLeader(int epoch) {
-        assertEquals(ElectionState.withUnknownLeader(epoch, voters), quorumStateStore.readElectionState());
+        assertEquals(
+            ElectionState.withUnknownLeader(epoch, voters),
+            quorumStateStore.readElectionState().get()
+        );
     }
 
     void assertResignedLeader(int epoch, int leaderId) {
         assertTrue(client.quorum().isResigned());
-        assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState());
+        assertEquals(
+            ElectionState.withElectedLeader(epoch, leaderId, voters),
+            quorumStateStore.readElectionState().get()
+        );
     }
 
     DescribeQuorumResponseData collectDescribeQuorumResponse() {
@@ -496,9 +525,7 @@
         return voteRequests.iterator().next().correlationId();
     }
 
-    void assertSentVoteResponse(
-            Errors error
-    ) {
+    void assertSentVoteResponse(Errors error) {
         List<RaftResponse.Outbound> sentMessages = drainSentResponses(ApiKeys.VOTE);
         assertEquals(1, sentMessages.size());
         RaftMessage raftMessage = sentMessages.get(0);
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
index 94a42ab..f52ee37 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
@@ -510,6 +510,7 @@
 
     private static class PersistentState {
         final MockQuorumStateStore store = new MockQuorumStateStore();
+        final Uuid nodeDirectoryId = Uuid.randomUuid();
         final MockLog log;
 
         PersistentState(int nodeId) {
@@ -630,13 +631,13 @@
                 return false;
 
             RaftNode first = iter.next();
-            ElectionState election = first.store.readElectionState();
+            ElectionState election = first.store.readElectionState().get();
             if (!election.hasLeader())
                 return false;
 
             while (iter.hasNext()) {
                 RaftNode next = iter.next();
-                if (!election.equals(next.store.readElectionState()))
+                if (!election.equals(next.store.readElectionState().get()))
                     return false;
             }
 
@@ -739,6 +740,7 @@
 
             KafkaRaftClient<Integer> client = new KafkaRaftClient<>(
                 OptionalInt.of(nodeId),
+                persistentState.nodeDirectoryId,
                 serde,
                 channel,
                 messageQueue,
@@ -777,10 +779,7 @@
         final MockNetworkChannel channel;
         final MockMessageQueue messageQueue;
         final MockQuorumStateStore store;
-        final LogContext logContext;
         final ReplicatedCounter counter;
-        final Time time;
-        final Random random;
         final RecordSerde<Integer> intSerde;
 
         private RaftNode(
@@ -801,9 +800,6 @@
             this.channel = channel;
             this.messageQueue = messageQueue;
             this.store = store;
-            this.logContext = logContext;
-            this.time = time;
-            this.random = random;
             this.counter = new ReplicatedCounter(nodeId, client, logContext);
             this.intSerde = intSerde;
         }
@@ -850,14 +846,10 @@
     }
 
     private static class InflightRequest {
-        final int correlationId;
         final int sourceId;
-        final int destinationId;
 
-        private InflightRequest(int correlationId, int sourceId, int destinationId) {
-            this.correlationId = correlationId;
+        private InflightRequest(int sourceId) {
             this.sourceId = sourceId;
-            this.destinationId = destinationId;
         }
     }
 
@@ -934,18 +926,18 @@
                 PersistentState state = nodeStateEntry.getValue();
                 Integer oldEpoch = nodeEpochs.get(nodeId);
 
-                ElectionState electionState = state.store.readElectionState();
-                if (electionState == null) {
+                Optional<ElectionState> electionState = state.store.readElectionState();
+                if (!electionState.isPresent()) {
                     continue;
                 }
 
-                Integer newEpoch = electionState.epoch;
+                int newEpoch = electionState.get().epoch();
                 if (oldEpoch > newEpoch) {
                     fail("Non-monotonic update of epoch detected on node " + nodeId + ": " +
                             oldEpoch + " -> " + newEpoch);
                 }
                 cluster.ifRunning(nodeId, nodeState -> {
-                    assertEquals(newEpoch.intValue(), nodeState.client.quorum().epoch());
+                    assertEquals(newEpoch, nodeState.client.quorum().epoch());
                 });
                 nodeEpochs.put(nodeId, newEpoch);
             }
@@ -986,16 +978,18 @@
         public void verify() {
             for (Map.Entry<Integer, PersistentState> nodeEntry : cluster.nodes.entrySet()) {
                 PersistentState state = nodeEntry.getValue();
-                ElectionState electionState = state.store.readElectionState();
+                Optional<ElectionState> electionState = state.store.readElectionState();
 
-                if (electionState != null && electionState.epoch >= epoch && electionState.hasLeader()) {
-                    if (epoch == electionState.epoch && leaderId.isPresent()) {
-                        assertEquals(leaderId.getAsInt(), electionState.leaderId());
-                    } else {
-                        epoch = electionState.epoch;
-                        leaderId = OptionalInt.of(electionState.leaderId());
+                electionState.ifPresent(election -> {
+                    if (election.epoch() >= epoch && election.hasLeader()) {
+                        if (epoch == election.epoch() && leaderId.isPresent()) {
+                            assertEquals(leaderId.getAsInt(), election.leaderId());
+                        } else {
+                            epoch = election.epoch();
+                            leaderId = OptionalInt.of(election.leaderId());
+                        }
                     }
-                }
+                });
             }
         }
     }
@@ -1208,7 +1202,7 @@
                 return;
 
             cluster.nodeIfRunning(destinationId).ifPresent(node -> {
-                inflight.put(correlationId, new InflightRequest(correlationId, senderId, destinationId));
+                inflight.put(correlationId, new InflightRequest(senderId));
 
                 inbound.completion.whenComplete((response, exception) -> {
                     if (response != null && filters.get(destinationId).acceptOutbound(response)) {
diff --git a/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java
index c57ed70..88e6f9f 100644
--- a/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java
@@ -19,12 +19,14 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -87,9 +89,9 @@
             Collections.emptyList()
         );
 
-        assertFalse(state.canGrantVote(1, isLogUpToDate));
-        assertFalse(state.canGrantVote(2, isLogUpToDate));
-        assertFalse(state.canGrantVote(3, isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(1, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(2, Optional.empty()), isLogUpToDate));
+        assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
     }
 
     @Test
diff --git a/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java
index 96f2a52..3550cd9 100644
--- a/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java
@@ -19,6 +19,7 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
@@ -80,8 +81,17 @@
                 Optional.empty()
         );
 
-        assertEquals(isLogUpToDate, state.canGrantVote(1, isLogUpToDate));
-        assertEquals(isLogUpToDate, state.canGrantVote(2, isLogUpToDate));
-        assertEquals(isLogUpToDate, state.canGrantVote(3, isLogUpToDate));
+        assertEquals(
+            isLogUpToDate,
+            state.canGrantVote(ReplicaKey.of(1, Optional.empty()), isLogUpToDate)
+        );
+        assertEquals(
+            isLogUpToDate,
+            state.canGrantVote(ReplicaKey.of(2, Optional.empty()), isLogUpToDate)
+        );
+        assertEquals(
+            isLogUpToDate,
+            state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate)
+        );
     }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java
index 317b80f..dca1a08 100644
--- a/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java
@@ -16,15 +16,16 @@
  */
 package org.apache.kafka.raft;
 
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
-import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.internals.ReplicaKey;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.Optional;
-import java.util.Set;
+import java.util.Collections;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -39,14 +40,14 @@
     private final int electionTimeoutMs = 10000;
 
     private VotedState newVotedState(
-        Set<Integer> voters,
+        Optional<Uuid> votedDirectoryId,
         Optional<LogOffsetMetadata> highWatermark
     ) {
         return new VotedState(
             time,
             epoch,
-            votedId,
-            voters,
+            ReplicaKey.of(votedId, votedDirectoryId),
+            Collections.emptySet(),
             highWatermark,
             electionTimeoutMs,
             logContext
@@ -55,13 +56,15 @@
 
     @Test
     public void testElectionTimeout() {
-        Set<Integer> voters = Utils.mkSet(1, 2, 3);
-
-        VotedState state = newVotedState(voters, Optional.empty());
+        VotedState state = newVotedState(Optional.empty(), Optional.empty());
+        ReplicaKey votedKey  = ReplicaKey.of(votedId, Optional.empty());
 
         assertEquals(epoch, state.epoch());
-        assertEquals(votedId, state.votedId());
-        assertEquals(ElectionState.withVotedCandidate(epoch, votedId, voters), state.election());
+        assertEquals(votedKey, state.votedKey());
+        assertEquals(
+            ElectionState.withVotedCandidate(epoch, votedKey, Collections.emptySet()),
+            state.election()
+        );
         assertEquals(electionTimeoutMs, state.remainingElectionTimeMs(time.milliseconds()));
         assertFalse(state.hasElectionTimeoutExpired(time.milliseconds()));
 
@@ -76,14 +79,37 @@
 
     @ParameterizedTest
     @ValueSource(booleans = {true, false})
-    public void testGrantVote(boolean isLogUpToDate) {
-        VotedState state = newVotedState(
-            Utils.mkSet(1, 2, 3),
-            Optional.empty()
+    public void testCanGrantVoteWithoutDirectoryId(boolean isLogUpToDate) {
+        VotedState state = newVotedState(Optional.empty(), Optional.empty());
+
+        assertTrue(
+            state.canGrantVote(ReplicaKey.of(votedId, Optional.empty()), isLogUpToDate)
+        );
+        assertTrue(
+            state.canGrantVote(
+                ReplicaKey.of(votedId, Optional.of(Uuid.randomUuid())),
+                isLogUpToDate
+            )
         );
 
-        assertTrue(state.canGrantVote(1, isLogUpToDate));
-        assertFalse(state.canGrantVote(2, isLogUpToDate));
-        assertFalse(state.canGrantVote(3, isLogUpToDate));
+        assertFalse(
+            state.canGrantVote(ReplicaKey.of(votedId + 1, Optional.empty()), isLogUpToDate)
+        );
+    }
+
+    @Test
+    void testCanGrantVoteWithDirectoryId() {
+        Optional<Uuid> votedDirectoryId = Optional.of(Uuid.randomUuid());
+        VotedState state = newVotedState(votedDirectoryId, Optional.empty());
+
+        assertTrue(state.canGrantVote(ReplicaKey.of(votedId, votedDirectoryId), false));
+
+        assertFalse(
+            state.canGrantVote(ReplicaKey.of(votedId, Optional.of(Uuid.randomUuid())), false)
+        );
+        assertFalse(state.canGrantVote(ReplicaKey.of(votedId, Optional.empty()), false));
+
+        assertFalse(state.canGrantVote(ReplicaKey.of(votedId + 1, votedDirectoryId), false));
+        assertFalse(state.canGrantVote(ReplicaKey.of(votedId + 1, Optional.empty()), false));
     }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java
index 355085c..80f7df0 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java
@@ -52,7 +52,7 @@
     @Test
     void testEmptyPartition() {
         MockLog log = buildLog();
-        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
 
         KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(voterSet));
 
@@ -65,7 +65,7 @@
     @Test
     void testUpdateWithoutSnapshot() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
         int epoch = 1;
 
@@ -85,7 +85,7 @@
         );
 
         // Append the voter set control record
-        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 log.endOffset().offset,
@@ -108,7 +108,7 @@
     @Test
     void testUpdateWithEmptySnapshot() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
         int epoch = 1;
 
@@ -136,7 +136,7 @@
         );
 
         // Append the voter set control record
-        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 log.endOffset().offset,
@@ -159,14 +159,14 @@
     @Test
     void testUpdateWithSnapshot() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         int epoch = 1;
 
         KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(staticVoterSet));
 
         // Create a snapshot that has kraft.version and voter set control records
         short kraftVersion = 1;
-        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
 
         RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
             .setRawSnapshotWriter(log.createNewSnapshotUnchecked(new OffsetAndEpoch(10, epoch)).get())
@@ -188,7 +188,7 @@
     @Test
     void testUpdateWithSnapshotAndLogOverride() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
         int epoch = 1;
 
@@ -196,7 +196,7 @@
 
         // Create a snapshot that has kraft.version and voter set control records
         short kraftVersion = 1;
-        VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
 
         OffsetAndEpoch snapshotId = new OffsetAndEpoch(10, epoch);
         RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
@@ -209,7 +209,7 @@
         log.truncateToLatestSnapshot();
 
         // Append the voter set control record
-        VoterSet voterSet = snapshotVoterSet.addVoter(VoterSetTest.voterNode(7)).get();
+        VoterSet voterSet = snapshotVoterSet.addVoter(VoterSetTest.voterNode(7, true)).get();
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 log.endOffset().offset,
@@ -235,7 +235,7 @@
     @Test
     void testTruncateTo() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
         int epoch = 1;
 
@@ -256,7 +256,7 @@
 
         // Append the voter set control record
         long firstVoterSetOffset = log.endOffset().offset;
-        VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 firstVoterSetOffset,
@@ -270,7 +270,7 @@
 
         // Append another voter set control record
         long voterSetOffset = log.endOffset().offset;
-        VoterSet voterSet = firstVoterSet.addVoter(VoterSetTest.voterNode(7)).get();
+        VoterSet voterSet = firstVoterSet.addVoter(VoterSetTest.voterNode(7, true)).get();
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 voterSetOffset,
@@ -303,7 +303,7 @@
     @Test
     void testTrimPrefixTo() {
         MockLog log = buildLog();
-        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
         int epoch = 1;
 
@@ -325,7 +325,7 @@
 
         // Append the voter set control record
         long firstVoterSetOffset = log.endOffset().offset;
-        VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6)));
+        VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 firstVoterSetOffset,
@@ -339,7 +339,7 @@
 
         // Append another voter set control record
         long voterSetOffset = log.endOffset().offset;
-        VoterSet voterSet = firstVoterSet.addVoter(VoterSetTest.voterNode(7)).get();
+        VoterSet voterSet = firstVoterSet.addVoter(VoterSetTest.voterNode(7, true)).get();
         log.appendAsLeader(
             MemoryRecords.withVotersRecord(
                 voterSetOffset,
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
index 4b2cc5a..1b729e3 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
@@ -16,7 +16,7 @@
  */
 package org.apache.kafka.raft.internals;
 
-
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.LogContext;
@@ -28,10 +28,13 @@
 import org.apache.kafka.raft.OffsetAndEpoch;
 import org.apache.kafka.raft.QuorumState;
 import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.Mockito;
 
+import java.util.Map;
 import java.util.Collections;
+import java.util.Optional;
 import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.Random;
@@ -42,6 +45,7 @@
 public class KafkaRaftMetricsTest {
 
     private final int localId = 0;
+    private final Uuid localDirectoryId = Uuid.randomUuid();
     private final int electionTimeoutMs = 5000;
     private final int fetchTimeoutMs = 10000;
 
@@ -60,10 +64,21 @@
         metrics.close();
     }
 
-    private QuorumState buildQuorumState(Set<Integer> voters) {
+    private QuorumState buildQuorumState(Set<Integer> voters, short kraftVersion) {
+        boolean withDirectoryId = kraftVersion > 0;
+
+        return buildQuorumState(
+            VoterSetTest.voterSet(VoterSetTest.voterMap(voters, withDirectoryId)),
+            kraftVersion
+        );
+    }
+
+    private QuorumState buildQuorumState(VoterSet voterSet, short kraftVersion) {
         return new QuorumState(
             OptionalInt.of(localId),
-            voters,
+            localDirectoryId,
+            () -> voterSet,
+            () -> kraftVersion,
             electionTimeoutMs,
             fetchTimeoutMs,
             new MockQuorumStateStore(),
@@ -73,9 +88,21 @@
         );
     }
 
-    @Test
-    public void shouldRecordVoterQuorumState() {
-        QuorumState state = buildQuorumState(Utils.mkSet(localId, 1, 2));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordVoterQuorumState(short kraftVersion) {
+        boolean withDirectoryId = kraftVersion > 0;
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Utils.mkSet(1, 2), withDirectoryId);
+        voterMap.put(
+            localId,
+            VoterSetTest.voterNode(
+                ReplicaKey.of(
+                    localId,
+                    withDirectoryId ? Optional.of(localDirectoryId) : Optional.empty()
+                )
+            )
+        );
+        QuorumState state = buildQuorumState(VoterSetTest.voterSet(voterMap), kraftVersion);
 
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@@ -83,6 +110,10 @@
         assertEquals("unattached", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -90,6 +121,10 @@
         assertEquals("candidate", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) localId, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            localDirectoryId.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 1, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -98,6 +133,10 @@
         assertEquals("leader", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) localId, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) localId, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            localDirectoryId.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 1, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -109,16 +148,24 @@
         assertEquals("follower", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 2, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue());
 
         state.followerStateOrThrow().updateHighWatermark(OptionalLong.of(10L));
         assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue());
 
-        state.transitionToVoted(3, 2);
+        state.transitionToVoted(3, ReplicaKey.of(2, Optional.empty()));
         assertEquals("voted", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) 2, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 3, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -126,19 +173,28 @@
         assertEquals("unattached", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 4, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue());
     }
 
-    @Test
-    public void shouldRecordNonVoterQuorumState() {
-        QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordNonVoterQuorumState(short kraftVersion) {
+        QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
         assertEquals("unattached", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -146,6 +202,10 @@
         assertEquals("observer", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 2, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
@@ -156,13 +216,18 @@
         assertEquals("unattached", getMetric(metrics, "current-state").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue());
         assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
+        assertEquals(
+            Uuid.ZERO_UUID.toString(),
+            getMetric(metrics, "current-vote-directory-id").metricValue()
+        );
         assertEquals((double) 4, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue());
     }
 
-    @Test
-    public void shouldRecordLogEnd() {
-        QuorumState state = buildQuorumState(Collections.singleton(localId));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordLogEnd(short kraftVersion) {
+        QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
@@ -175,9 +240,10 @@
         assertEquals((double) 1, getMetric(metrics, "log-end-epoch").metricValue());
     }
 
-    @Test
-    public void shouldRecordNumUnknownVoterConnections() {
-        QuorumState state = buildQuorumState(Collections.singleton(localId));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordNumUnknownVoterConnections(short kraftVersion) {
+        QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
@@ -188,9 +254,10 @@
         assertEquals((double) 2, getMetric(metrics, "number-unknown-voter-connections").metricValue());
     }
 
-    @Test
-    public void shouldRecordPollIdleRatio() {
-        QuorumState state = buildQuorumState(Collections.singleton(localId));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordPollIdleRatio(short kraftVersion) {
+        QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
@@ -260,9 +327,10 @@
         assertEquals(0.5, getMetric(metrics, "poll-idle-ratio-avg").metricValue());
     }
 
-    @Test
-    public void shouldRecordLatency() {
-        QuorumState state = buildQuorumState(Collections.singleton(localId));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordLatency(short kraftVersion) {
+        QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
@@ -291,9 +359,10 @@
         assertEquals(60.0, getMetric(metrics, "commit-latency-max").metricValue());
     }
 
-    @Test
-    public void shouldRecordRate() {
-        QuorumState state = buildQuorumState(Collections.singleton(localId));
+    @ParameterizedTest
+    @ValueSource(shorts = {0, 1})
+    public void shouldRecordRate(short kraftVersion) {
+        QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
         state.initialize(new OffsetAndEpoch(0L, 0));
         raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
index 580f509..8b5fe69 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java
@@ -202,7 +202,9 @@
     @Test
     public void testControlRecordIterationWithKraftVerion1() {
         AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
-        VoterSet voterSet = new VoterSet(new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3))));
+        VoterSet voterSet = new VoterSet(
+            new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
+        );
         RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
             .setTime(new MockTime())
             .setKraftVersion((short) 1)
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java
index c5c26e3..14386f8 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java
@@ -27,7 +27,7 @@
 final public class VoterSetHistoryTest {
     @Test
     void testStaicVoterSet() {
-        VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
 
         assertEquals(Optional.empty(), votersHistory.valueAtOrBefore(0));
@@ -58,17 +58,17 @@
 
     @Test
     void testAddAt() {
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
 
         assertThrows(
             IllegalArgumentException.class,
-            () -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3))))
+            () -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))
         );
         assertEquals(staticVoterSet, votersHistory.lastValue());
 
-        voterMap.put(4, VoterSetTest.voterNode(4));
+        voterMap.put(4, VoterSetTest.voterNode(4, true));
         VoterSet addedVoterSet = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(100, addedVoterSet);
 
@@ -89,7 +89,7 @@
     void testAddAtNonOverlapping() {
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
 
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet voterSet = new VoterSet(new HashMap<>(voterMap));
 
         // Add a starting voter to the history
@@ -109,8 +109,8 @@
 
         // Add voters so that it doesn't overlap
         VoterSet nonoverlappingAddSet = voterSet
-            .addVoter(VoterSetTest.voterNode(4)).get()
-            .addVoter(VoterSetTest.voterNode(5)).get();
+            .addVoter(VoterSetTest.voterNode(4, true)).get()
+            .addVoter(VoterSetTest.voterNode(5, true)).get();
 
         assertThrows(
             IllegalArgumentException.class,
@@ -121,7 +121,7 @@
 
     @Test
     void testNonoverlappingFromStaticVoterSet() {
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
 
@@ -136,17 +136,17 @@
 
     @Test
     void testTruncateTo() {
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
 
         // Add voter 4 to the voter set and voter set history
-        voterMap.put(4, VoterSetTest.voterNode(4));
+        voterMap.put(4, VoterSetTest.voterNode(4, true));
         VoterSet voterSet1234 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(100, voterSet1234);
 
         // Add voter 5 to the voter set and voter set history
-        voterMap.put(5, VoterSetTest.voterNode(5));
+        voterMap.put(5, VoterSetTest.voterNode(5, true));
         VoterSet voterSet12345 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(200, voterSet12345);
 
@@ -162,17 +162,17 @@
 
     @Test
     void testTrimPrefixTo() {
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
 
         // Add voter 4 to the voter set and voter set history
-        voterMap.put(4, VoterSetTest.voterNode(4));
+        voterMap.put(4, VoterSetTest.voterNode(4, true));
         VoterSet voterSet1234 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(100, voterSet1234);
 
         // Add voter 5 to the voter set and voter set history
-        voterMap.put(5, VoterSetTest.voterNode(5));
+        voterMap.put(5, VoterSetTest.voterNode(5, true));
         VoterSet voterSet12345 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(200, voterSet12345);
 
@@ -195,17 +195,17 @@
 
     @Test
     void testClear() {
-        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
         VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
 
         // Add voter 4 to the voter set and voter set history
-        voterMap.put(4, VoterSetTest.voterNode(4));
+        voterMap.put(4, VoterSetTest.voterNode(4, true));
         VoterSet voterSet1234 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(100, voterSet1234);
 
         // Add voter 5 to the voter set and voter set history
-        voterMap.put(5, VoterSetTest.voterNode(5));
+        voterMap.put(5, VoterSetTest.voterNode(5, true));
         VoterSet voterSet12345 = new VoterSet(new HashMap<>(voterMap));
         votersHistory.addAt(200, voterSet12345);
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java
index 6226a88..f0ed10a 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java
@@ -18,10 +18,10 @@
 
 import java.net.InetSocketAddress;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
@@ -30,7 +30,9 @@
 import org.apache.kafka.common.feature.SupportedVersionRange;
 import org.junit.jupiter.api.Test;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 final public class VoterSetTest {
     @Test
@@ -40,7 +42,7 @@
 
     @Test
     void testVoterAddress() {
-        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
         assertEquals(Optional.of(new InetSocketAddress("replica-1", 1234)), voterSet.voterAddress(1, "LISTENER"));
         assertEquals(Optional.empty(), voterSet.voterAddress(1, "MISSING"));
         assertEquals(Optional.empty(), voterSet.voterAddress(4, "LISTENER"));
@@ -48,29 +50,29 @@
 
     @Test
     void testVoterIds() {
-        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
         assertEquals(new HashSet<>(Arrays.asList(1, 2, 3)), voterSet.voterIds());
     }
 
     @Test
     void testAddVoter() {
-        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
 
-        assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1)));
+        assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1, true)));
 
-        VoterSet.VoterNode voter4 = voterNode(4);
+        VoterSet.VoterNode voter4 = voterNode(4, true);
         aVoterMap.put(voter4.voterKey().id(), voter4);
         assertEquals(Optional.of(new VoterSet(new HashMap<>(aVoterMap))), voterSet.addVoter(voter4));
     }
 
     @Test
     void testRemoveVoter() {
-        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
 
-        assertEquals(Optional.empty(), voterSet.removeVoter(VoterSet.VoterKey.of(4, Optional.empty())));
-        assertEquals(Optional.empty(), voterSet.removeVoter(VoterSet.VoterKey.of(4, Optional.of(Uuid.randomUuid()))));
+        assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.empty())));
+        assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.of(Uuid.randomUuid()))));
 
         VoterSet.VoterNode voter3 = aVoterMap.remove(3);
         assertEquals(
@@ -80,19 +82,78 @@
     }
 
     @Test
+    void testIsVoterWithDirectoryId() {
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
+        VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
+
+        assertTrue(voterSet.isVoter(aVoterMap.get(1).voterKey()));
+        assertFalse(voterSet.isVoter(ReplicaKey.of(1, Optional.of(Uuid.randomUuid()))));
+        assertFalse(voterSet.isVoter(ReplicaKey.of(1, Optional.empty())));
+        assertFalse(
+            voterSet.isVoter(ReplicaKey.of(2, aVoterMap.get(1).voterKey().directoryId()))
+        );
+        assertFalse(
+            voterSet.isVoter(ReplicaKey.of(4, aVoterMap.get(1).voterKey().directoryId()))
+        );
+        assertFalse(voterSet.isVoter(ReplicaKey.of(4, Optional.empty())));
+    }
+
+    @Test
+    void testIsVoterWithoutDirectoryId() {
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), false);
+        VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
+
+        assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.empty())));
+        assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.of(Uuid.randomUuid()))));
+        assertFalse(voterSet.isVoter(ReplicaKey.of(4, Optional.of(Uuid.randomUuid()))));
+        assertFalse(voterSet.isVoter(ReplicaKey.of(4, Optional.empty())));
+    }
+
+    @Test
+    void testIsOnlyVoterInStandalone() {
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1), true);
+        VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
+
+        assertTrue(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(1, Optional.of(Uuid.randomUuid()))));
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(1, Optional.empty())));
+        assertFalse(
+            voterSet.isOnlyVoter(ReplicaKey.of(4, aVoterMap.get(1).voterKey().directoryId()))
+        );
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(4, Optional.empty())));
+    }
+
+    @Test
+    void testIsOnlyVoterInNotStandalone() {
+        Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2), true);
+        VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
+
+        assertFalse(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(1, Optional.of(Uuid.randomUuid()))));
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(1, Optional.empty())));
+        assertFalse(
+            voterSet.isOnlyVoter(ReplicaKey.of(2, aVoterMap.get(1).voterKey().directoryId()))
+        );
+        assertFalse(
+            voterSet.isOnlyVoter(ReplicaKey.of(4, aVoterMap.get(1).voterKey().directoryId()))
+        );
+        assertFalse(voterSet.isOnlyVoter(ReplicaKey.of(4, Optional.empty())));
+    }
+
+    @Test
     void testRecordRoundTrip() {
-        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3)));
+        VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
 
         assertEquals(voterSet, VoterSet.fromVotersRecord(voterSet.toVotersRecord((short) 0)));
     }
 
     @Test
     void testOverlappingMajority() {
-        Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3));
+        Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
         VoterSet startingVoterSet = voterSet(startingVoterMap);
 
         VoterSet biggerVoterSet = startingVoterSet
-            .addVoter(voterNode(4))
+            .addVoter(voterNode(4, true))
             .get();
         assertMajorities(true, startingVoterSet, biggerVoterSet);
 
@@ -104,21 +165,21 @@
         VoterSet replacedVoterSet = startingVoterSet
             .removeVoter(startingVoterMap.get(1).voterKey())
             .get()
-            .addVoter(voterNode(1))
+            .addVoter(voterNode(1, true))
             .get();
         assertMajorities(true, startingVoterSet, replacedVoterSet);
     }
 
     @Test
     void testNonoverlappingMajority() {
-        Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3, 4, 5));
+        Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3, 4, 5), true);
         VoterSet startingVoterSet = voterSet(startingVoterMap);
 
         // Two additions don't have an overlapping majority
         VoterSet biggerVoterSet = startingVoterSet
-            .addVoter(voterNode(6))
+            .addVoter(voterNode(6, true))
             .get()
-            .addVoter(voterNode(7))
+            .addVoter(voterNode(7, true))
             .get();
         assertMajorities(false, startingVoterSet, biggerVoterSet);
 
@@ -134,11 +195,11 @@
         VoterSet replacedVoterSet = startingVoterSet
             .removeVoter(startingVoterMap.get(1).voterKey())
             .get()
-            .addVoter(voterNode(1))
+            .addVoter(voterNode(1, true))
             .get()
             .removeVoter(startingVoterMap.get(2).voterKey())
             .get()
-            .addVoter(voterNode(2))
+            .addVoter(voterNode(2, true))
             .get();
         assertMajorities(false, startingVoterSet, replacedVoterSet);
     }
@@ -156,23 +217,38 @@
         );
     }
 
-    public static Map<Integer, VoterSet.VoterNode> voterMap(List<Integer> replicas) {
+    public static Map<Integer, VoterSet.VoterNode> voterMap(
+        Collection<Integer> replicas,
+        boolean withDirectoryId
+    ) {
         return replicas
             .stream()
             .collect(
                 Collectors.toMap(
                     Function.identity(),
-                    VoterSetTest::voterNode
+                    id -> VoterSetTest.voterNode(id, withDirectoryId)
                 )
             );
     }
 
-    static VoterSet.VoterNode voterNode(int id) {
+    public static VoterSet.VoterNode voterNode(int id, boolean withDirectoryId) {
+        return voterNode(
+            ReplicaKey.of(
+                id,
+                withDirectoryId ? Optional.of(Uuid.randomUuid()) : Optional.empty()
+            )
+        );
+    }
+
+    public static VoterSet.VoterNode voterNode(ReplicaKey replicaKey) {
         return new VoterSet.VoterNode(
-            VoterSet.VoterKey.of(id, Optional.of(Uuid.randomUuid())),
+            replicaKey,
             Collections.singletonMap(
                 "LISTENER",
-                InetSocketAddress.createUnresolved(String.format("replica-%d", id), 1234)
+                InetSocketAddress.createUnresolved(
+                    String.format("replica-%d", replicaKey.id()),
+                    1234
+                )
             ),
             new SupportedVersionRange((short) 0, (short) 0)
         );
diff --git a/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java b/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java
index fe8654f..17b7c5d 100644
--- a/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java
+++ b/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java
@@ -96,7 +96,9 @@
     void testBuilderKRaftVersion0WithVoterSet() {
         OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
         int maxBatchSize = 1024;
-        VoterSet voterSet = VoterSetTest.voterSet(new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3))));
+        VoterSet voterSet = VoterSetTest.voterSet(
+            new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
+        );
         AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
         RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
             .setKraftVersion((short) 0)
@@ -114,7 +116,9 @@
     void testKBuilderRaftVersion1WithVoterSet() {
         OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
         int maxBatchSize = 1024;
-        VoterSet voterSet = VoterSetTest.voterSet(new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3))));
+        VoterSet voterSet = VoterSetTest.voterSet(
+            new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
+        );
         AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
         RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
             .setKraftVersion((short) 1)