KAFKA-16027: Refactor testUpdatePartitionLeadership (#17083)
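
Moves testUpdatePartitionLeadership out of MetadataTest and into
ConsumerMetadataTest, splitting it into two focused tests:
testInvalidPartitionLeadershipUpdates, which checks that invalid leadership
updates are ignored, and testValidPartitionLeadershipUpdate, which checks that
an update with a newer leader epoch is applied. Shared setup and validation
logic moves into private helpers, and Metadata#partitionMetadataIfCurrent is
made public so the validation helper, which now lives in a different package,
can read the cached partition metadata.
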
Reviewers: David Arthur <mumrah@gmail.com>
diff --git a/clients/src/main/java/org/apache/kafka/clients/Metadata.java b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
index 5f2a412..ece1a25 100644
--- a/clients/src/main/java/org/apache/kafka/clients/Metadata.java
+++ b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
@@ -273,7 +273,7 @@
     /**
      * Return the cached partition info if it exists and a newer leader epoch isn't known about.
      */
-    synchronized Optional<MetadataResponse.PartitionMetadata> partitionMetadataIfCurrent(TopicPartition topicPartition) {
+    public synchronized Optional<MetadataResponse.PartitionMetadata> partitionMetadataIfCurrent(TopicPartition topicPartition) {
         Integer epoch = lastSeenLeaderEpochs.get(topicPartition);
         Optional<MetadataResponse.PartitionMetadata> partitionMetadata = metadataSnapshot.partitionMetadata(topicPartition);
         if (epoch == null) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
index 3573e5b..450048e 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.clients;
 
 import org.apache.kafka.common.Cluster;
-import org.apache.kafka.common.ClusterResourceListener;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
@@ -35,9 +34,7 @@
 import org.apache.kafka.common.protocol.MessageUtil;
 import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
-import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata;
 import org.apache.kafka.common.requests.RequestTestUtils;
-import org.apache.kafka.common.requests.RequestTestUtils.PartitionMetadataSupplier;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
@@ -45,11 +42,9 @@
 import org.apache.kafka.test.MockClusterResourceListener;
 
 import org.junit.jupiter.api.Test;
-import org.mockito.Mockito;
 
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -65,7 +60,6 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.stream.Collectors;
 
 import static org.apache.kafka.test.TestUtils.assertOptional;
 import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -75,10 +69,6 @@
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 
 public class MetadataTest {
 
@@ -1149,123 +1139,6 @@
         assertEquals(2, metadata.fetch().partition(tp1).leader().id());
     }
 
-    @Test
-    public void testUpdatePartitionLeadership() {
-        Time time = new MockTime();
-
-        // Initialize metadata
-        int numNodes = 5;
-        metadata = new Metadata(refreshBackoffMs, refreshBackoffMaxMs, metadataExpireMs, new LogContext(), new ClusterResourceListeners());
-        ClusterResourceListener mockListener = Mockito.mock(ClusterResourceListener.class);
-        metadata.addClusterUpdateListener(mockListener);
-        // topic1 has 2 partitions: tp11, tp12
-        // topic2 has 1 partition: tp21
-        String topic1 = "topic1";
-        TopicPartition tp11 = new TopicPartition(topic1, 0);
-        PartitionMetadata part1Metadata = new PartitionMetadata(Errors.NONE, tp11, Optional.of(1), Optional.of(100), Arrays.asList(1, 2), Arrays.asList(1, 2), Collections.singletonList(3));
-        Uuid topic1Id = Uuid.randomUuid();
-        TopicPartition tp12 = new TopicPartition(topic1, 1);
-        PartitionMetadata part12Metadata = new PartitionMetadata(Errors.NONE, tp12, Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
-
-        String topic2 = "topic2";
-        TopicPartition tp21 = new TopicPartition(topic2, 0);
-        PartitionMetadata part2Metadata = new PartitionMetadata(Errors.NONE, tp21, Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
-        Uuid topic2Id = Uuid.randomUuid();
-
-        Set<String> internalTopics = Collections.singleton(Topic.GROUP_METADATA_TOPIC_NAME);
-        TopicPartition internalPart = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0);
-        Uuid internalTopicId = Uuid.randomUuid();
-        PartitionMetadata internalTopicMetadata = new PartitionMetadata(Errors.NONE, internalPart, Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
-
-        Map<String, Uuid> topicIds = new HashMap<>();
-        topicIds.put(topic1, topic1Id);
-        topicIds.put(topic2, topic2Id);
-        topicIds.put(internalTopics.iterator().next(), internalTopicId);
-
-        Map<String, Integer> topicPartitionCounts = new HashMap<>();
-        topicPartitionCounts.put(topic1, 2);
-        topicPartitionCounts.put(topic2, 1);
-        topicPartitionCounts.put(internalTopics.iterator().next(), 1);
-        PartitionMetadataSupplier metadataSupplier = (error, partition, leaderId, leaderEpoch, replicas, isr, offlineReplicas) -> {
-            if (partition.equals(tp11))
-                return part1Metadata;
-            else if (partition.equals(tp21))
-                return part2Metadata;
-            else if (partition.equals(tp12))
-                return part12Metadata;
-            else if (partition.equals(internalPart))
-                return internalTopicMetadata;
-            throw new RuntimeException("Unexpected partition " + partition);
-        };
-
-        // Setup invalid topics and unauthorized topics.
-        Map<String, Errors> errorCounts = new HashMap<>();
-        Set<String> invalidTopics = Collections.singleton("topic3");
-        errorCounts.put(invalidTopics.iterator().next(), Errors.INVALID_TOPIC_EXCEPTION);
-        Set<String> unauthorizedTopics = Collections.singleton("topic4");
-        errorCounts.put(unauthorizedTopics.iterator().next(), Errors.TOPIC_AUTHORIZATION_FAILED);
-
-        metadata.requestUpdate(true);
-        Metadata.MetadataRequestAndVersion versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds());
-        assertFalse(versionAndBuilder.isPartialUpdate);
-        String clusterId = "kafka-cluster";
-        metadata.update(versionAndBuilder.requestVersion,
-            RequestTestUtils.metadataUpdateWith(clusterId, numNodes, errorCounts, topicPartitionCounts, tp -> null, metadataSupplier, ApiKeys.METADATA.latestVersion(), topicIds),
-            false, time.milliseconds());
-        List<Node> nodes = new ArrayList<>(metadata.fetch().nodes());
-        Node controller = metadata.fetch().controller();
-        assertEquals(numNodes, nodes.size());
-        assertFalse(metadata.updateRequested());
-        validateForUpdatePartitionLeadership(metadata, part1Metadata, part2Metadata, part12Metadata, internalTopicMetadata, nodes, clusterId, unauthorizedTopics, invalidTopics, internalTopics, controller, topicIds);
-        // Since cluster metadata was updated, listener should be called.
-        verify(mockListener, times(1)).onUpdate(any());
-        Mockito.reset(mockListener);
-
-        // TEST1: Ensure invalid updates get ignored
-        Map<TopicPartition, Metadata.LeaderIdAndEpoch> updates = new HashMap<>();
-        // New leader info is empty/invalid.
-        updates.put(new TopicPartition(topic1, 999), new Metadata.LeaderIdAndEpoch(Optional.empty(), Optional.empty()));
-        // Leader's node is unknown
-        updates.put(tp21, new  Metadata.LeaderIdAndEpoch(Optional.of(99999), Optional.of(99999)));
-        // Partition missing from existing metadata
-        updates.put(new TopicPartition("topic_missing_from_existing_metadata", 1), new  Metadata.LeaderIdAndEpoch(Optional.of(0), Optional.of(99999)));
-        // New leader info is stale.
-        updates.put(tp11, new  Metadata.LeaderIdAndEpoch(part1Metadata.leaderId, Optional.of(part1Metadata.leaderEpoch.get() - 1)));
-        Set<TopicPartition> updatedTps = metadata.updatePartitionLeadership(updates, nodes);
-        assertTrue(updatedTps.isEmpty());
-        validateForUpdatePartitionLeadership(metadata, part1Metadata, part2Metadata, part12Metadata, internalTopicMetadata, nodes, clusterId, unauthorizedTopics, invalidTopics, internalTopics, controller, topicIds);
-        // Since cluster metadata is unchanged, listener shouldn't be called.
-        verify(mockListener, never()).onUpdate(any());
-        Mockito.reset(mockListener);
-
-
-        //TEST2: Ensure valid update to tp11 is applied to the metadata.  Rest (tp12, tp21) remain unchanged.
-        // 1. New Node with id=999 is added.
-        // 2. Existing node with id=0 has host, port changed, so is updated.
-        Integer part1NewLeaderId = part1Metadata.leaderId.get() + 1;
-        Integer part1NewLeaderEpoch = part1Metadata.leaderEpoch.get() + 1;
-        updates.put(tp11, new Metadata.LeaderIdAndEpoch(Optional.of(part1NewLeaderId), Optional.of(part1NewLeaderEpoch)));
-        PartitionMetadata updatedPart1Metadata = new PartitionMetadata(part1Metadata.error, part1Metadata.topicPartition, Optional.of(part1NewLeaderId), Optional.of(part1NewLeaderEpoch), part1Metadata.replicaIds, part1Metadata.inSyncReplicaIds, part1Metadata.offlineReplicaIds);
-
-        Node newNode = new Node(999, "testhost", 99999, "testrack");
-        nodes.add(newNode);
-        int index = nodes.stream().filter(node -> node.id() == 0).findFirst().map(nodes::indexOf).orElse(-1);
-        Node existingNode = nodes.get(index);
-        Node updatedNode = new Node(existingNode.id(), "newhost", existingNode.port(), "newrack");
-        nodes.remove(index);
-        nodes.add(updatedNode);
-
-        updatedTps = metadata.updatePartitionLeadership(updates, nodes);
-
-        assertEquals(1, updatedTps.size());
-        assertEquals(part1Metadata.topicPartition, updatedTps.toArray()[0]);
-        // Validate metadata is changed for partition1, hosts are updated, everything else remains unchanged.
-        validateForUpdatePartitionLeadership(metadata, updatedPart1Metadata, part2Metadata, part12Metadata, internalTopicMetadata, nodes, clusterId, unauthorizedTopics, invalidTopics, internalTopics, controller, topicIds);
-        // Since cluster metadata was updated, listener should be called.
-        verify(mockListener, times(1)).onUpdate(any());
-        Mockito.reset(mockListener);
-    }
-
     /**
      * Test that concurrently updating Metadata, and fetching the corresponding MetadataSnapshot & Cluster work as expected, i.e.
      * snapshot & cluster contain the relevant updates.
@@ -1358,46 +1231,4 @@
         // Executor service should shut down quickly, as all tasks are finished at this point.
         assertTrue(service.awaitTermination(60, TimeUnit.SECONDS));
     }
-
-    /**
-     * For testUpdatePartially, validates that updatedMetadata is matching expected part1Metadata, part2Metadata, internalPartMetadata, nodes & more.
-     */
-    void validateForUpdatePartitionLeadership(Metadata updatedMetadata,
-                                              PartitionMetadata part1Metadata, PartitionMetadata part2Metadata, PartitionMetadata part12Metadata,
-                                              PartitionMetadata internalPartMetadata,
-                                              List<Node> expectedNodes,
-                                              String expectedClusterId,
-                                              Set<String> expectedUnauthorisedTopics, Set<String> expectedInvalidTopics, Set<String> expectedInternalTopics,
-                                              Node expectedController,
-                                              Map<String, Uuid> expectedTopicIds) {
-        Cluster updatedCluster = updatedMetadata.fetch();
-        assertEquals(updatedCluster.clusterResource().clusterId(), expectedClusterId);
-        assertEquals(new HashSet<>(expectedNodes), new HashSet<>(updatedCluster.nodes()));
-        assertEquals(3, updatedCluster.topics().size());
-        assertEquals(expectedInternalTopics, updatedCluster.internalTopics());
-        assertEquals(expectedInvalidTopics, updatedCluster.invalidTopics());
-        assertEquals(expectedUnauthorisedTopics, updatedCluster.unauthorizedTopics());
-        assertEquals(expectedController, updatedCluster.controller());
-        assertEquals(expectedTopicIds, updatedMetadata.topicIds());
-
-        Map<Integer, Node> nodeMap = expectedNodes.stream().collect(Collectors.toMap(Node::id, e -> e));
-        for (PartitionMetadata partitionMetadata: Arrays.asList(part1Metadata, part2Metadata, part12Metadata, internalPartMetadata)) {
-            TopicPartition tp = new TopicPartition(partitionMetadata.topic(), partitionMetadata.partition());
-
-            Metadata.LeaderAndEpoch expectedLeaderInfo = new Metadata.LeaderAndEpoch(Optional.of(nodeMap.get(partitionMetadata.leaderId.get())), partitionMetadata.leaderEpoch);
-            assertEquals(expectedLeaderInfo, updatedMetadata.currentLeader(tp));
-
-            // Compare the partition-metadata.
-            Optional<PartitionMetadata> optionalUpdatedMetadata = updatedMetadata.partitionMetadataIfCurrent(tp);
-            assertTrue(optionalUpdatedMetadata.isPresent());
-            PartitionMetadata updatedPartMetadata = optionalUpdatedMetadata.get();
-            assertEquals(partitionMetadata.topicPartition, updatedPartMetadata.topicPartition);
-            assertEquals(partitionMetadata.error, updatedPartMetadata.error);
-            assertEquals(partitionMetadata.leaderId, updatedPartMetadata.leaderId);
-            assertEquals(partitionMetadata.leaderEpoch, updatedPartMetadata.leaderEpoch);
-            assertEquals(partitionMetadata.replicaIds, updatedPartMetadata.replicaIds);
-            assertEquals(partitionMetadata.inSyncReplicaIds, updatedPartMetadata.inSyncReplicaIds);
-            assertEquals(partitionMetadata.offlineReplicaIds, partitionMetadata.offlineReplicaIds);
-        }
-    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
index ca0f9e4..f4cdc60 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
@@ -16,11 +16,16 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.ClusterResourceListener;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.internals.Topic;
+import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
@@ -31,8 +36,10 @@
 import org.apache.kafka.common.utils.Utils;
 
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -40,7 +47,9 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
@@ -49,11 +58,16 @@
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 public class ConsumerMetadataTest {
 
     private final Node node = new Node(1, "localhost", 9092);
     private final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
+
     private final Time time = new MockTime();
 
     @Test
@@ -184,4 +198,185 @@
                 subscription, new LogContext(), new ClusterResourceListeners());
     }
 
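+    /**
+     * Verifies that updatePartitionLeadership ignores updates that carry no leader information,
+     * reference an unknown leader node, target a partition missing from the cached metadata, or
+     * carry a stale leader epoch, and leaves the cached metadata and the cluster resource
+     * listener untouched.
+     */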
+    @Test
+    public void testInvalidPartitionLeadershipUpdates() {
+        Metadata metadata = initializeMetadata();
+        List<Node> originalNodes = initializeNodes(metadata);
+        ClusterResourceListener mockListener = initializeMockListener(metadata);
+
+        // Ensure invalid updates get ignored
+        Map<TopicPartition, Metadata.LeaderIdAndEpoch> invalidUpdates = new HashMap<>();
+        // no leader information at all
+        invalidUpdates.put(new TopicPartition("topic1", 999), new Metadata.LeaderIdAndEpoch(Optional.empty(), Optional.empty()));
+        // leader ID that does not match any known node
+        invalidUpdates.put(new TopicPartition("topic2", 0), new Metadata.LeaderIdAndEpoch(Optional.of(99999), Optional.of(99999)));
+        // partition missing from the existing metadata
+        invalidUpdates.put(new TopicPartition("topic_missing_from_existing_metadata", 1), new Metadata.LeaderIdAndEpoch(Optional.of(0), Optional.of(99999)));
+        // leader epoch older than the cached epoch (100)
+        invalidUpdates.put(new TopicPartition("topic1", 0), new Metadata.LeaderIdAndEpoch(Optional.of(1), Optional.of(99)));
+
+        Set<TopicPartition> updatedTps = metadata.updatePartitionLeadership(invalidUpdates, originalNodes);
+        assertTrue(updatedTps.isEmpty(), "Invalid updates should be ignored");
+
+        Cluster updatedCluster = metadata.fetch();
+        assertEquals(new HashSet<>(originalNodes), new HashSet<>(updatedCluster.nodes()));
+        verify(mockListener, never()).onUpdate(any());
+        validateForUpdatePartitionLeadership(metadata,
+                metadataSupplier(Errors.NONE, new TopicPartition("topic1", 0), Optional.of(1), Optional.of(100), Arrays.asList(1, 2), Arrays.asList(1, 2), Collections.singletonList(3)),
+                metadataSupplier(Errors.NONE, new TopicPartition("topic2", 0), Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                metadataSupplier(Errors.NONE, new TopicPartition("topic1", 1), Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                metadataSupplier(Errors.NONE, new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0), Optional.of(2), Optional.of(300), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                originalNodes,
+                "kafka-cluster",
+                Collections.singleton("topic4"),
+                Collections.singleton("topic3"),
+                Collections.singleton(Topic.GROUP_METADATA_TOPIC_NAME),
+                updatedCluster.controller(),
+                metadata.topicIds());
+    }
+
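+    /**
+     * Verifies that updatePartitionLeadership applies an update that moves topic1-0 to a new
+     * leader with a newer epoch, leaves the other partitions unchanged, and notifies the cluster
+     * resource listener exactly once.
+     */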
+    @Test
+    public void testValidPartitionLeadershipUpdate() {
+        Metadata metadata = initializeMetadata();
+        List<Node> originalNodes = initializeNodes(metadata);
+        ClusterResourceListener mockListener = initializeMockListener(metadata);
+
+        // Ensure valid update to tp11 is applied
+        Map<TopicPartition, Metadata.LeaderIdAndEpoch> validUpdates = new HashMap<>();
+        TopicPartition tp11 = new TopicPartition("topic1", 0);
+
+        Integer newLeaderId = 2; // node 2 is already a replica of tp11, so it is a valid new leader
+        Integer newLeaderEpoch = 101; // newer than the cached epoch of 100
+
+        validUpdates.put(tp11, new Metadata.LeaderIdAndEpoch(Optional.of(newLeaderId), Optional.of(newLeaderEpoch)));
+
+        Set<TopicPartition> updatedTps = metadata.updatePartitionLeadership(validUpdates, originalNodes);
+
+        assertEquals(1, updatedTps.size());
+        assertEquals(tp11, updatedTps.iterator().next(), "tp11 should be updated");
+
+        Cluster updatedCluster = metadata.fetch();
+        assertEquals(new HashSet<>(originalNodes), new HashSet<>(updatedCluster.nodes()));
+        verify(mockListener, times(1)).onUpdate(any());
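+        // tp11 should now reflect the new leader and epoch; the other partitions keep the
+        // metadata installed by initializeMetadata().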
+        validateForUpdatePartitionLeadership(metadata,
+                new MetadataResponse.PartitionMetadata(Errors.NONE, tp11, Optional.of(newLeaderId), Optional.of(newLeaderEpoch), Arrays.asList(1, 2), Arrays.asList(1, 2), Collections.singletonList(3)),
+                metadataSupplier(Errors.NONE, new TopicPartition("topic2", 0), Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                metadataSupplier(Errors.NONE, new TopicPartition("topic1", 1), Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                metadataSupplier(Errors.NONE, new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0), Optional.of(2), Optional.of(300), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1)),
+                originalNodes,
+                "kafka-cluster",
+                Collections.singleton("topic4"),
+                Collections.singleton("topic3"),
+                Collections.singleton(Topic.GROUP_METADATA_TOPIC_NAME),
+                updatedCluster.controller(),
+                metadata.topicIds());
+    }
+
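+    /**
+     * Builds a Metadata instance backed by a 5-node "kafka-cluster" containing topic1 (2
+     * partitions), topic2 (1 partition) and the internal __consumer_offsets topic, plus topic3
+     * reported as INVALID_TOPIC_EXCEPTION and topic4 reported as TOPIC_AUTHORIZATION_FAILED.
+     */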
+    private Metadata initializeMetadata() {
+        Metadata metadata = new Metadata(100, 1000, 1000, new LogContext(), new ClusterResourceListeners());
+
+        String topic1 = "topic1";
+        String topic2 = "topic2";
+        String topic3 = "topic3";
+        String topic4 = "topic4";
+        String clusterId = "kafka-cluster";
+        Uuid topic1Id = Uuid.randomUuid();
+        Uuid topic2Id = Uuid.randomUuid();
+        Uuid internalTopicId = Uuid.randomUuid();
+
+        Map<String, Uuid> topicIds = new HashMap<>();
+        topicIds.put(topic1, topic1Id);
+        topicIds.put(topic2, topic2Id);
+        topicIds.put(Topic.GROUP_METADATA_TOPIC_NAME, internalTopicId);
+
+        Map<String, Errors> errorCounts = new HashMap<>();
+        errorCounts.put(topic3, Errors.INVALID_TOPIC_EXCEPTION);
+        errorCounts.put(topic4, Errors.TOPIC_AUTHORIZATION_FAILED);
+
+        Map<String, Integer> topicPartitionCounts = new HashMap<>();
+        topicPartitionCounts.put(topic1, 2);
+        topicPartitionCounts.put(topic2, 1);
+        topicPartitionCounts.put(Topic.GROUP_METADATA_TOPIC_NAME, 1);
+
+        metadata.requestUpdate(true);
+        Metadata.MetadataRequestAndVersion versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds());
+        metadata.update(versionAndBuilder.requestVersion,
+                RequestTestUtils.metadataUpdateWith(clusterId,
+                        5,
+                        errorCounts,
+                        topicPartitionCounts,
+                        tp -> null,
+                        this::metadataSupplier,
+                        ApiKeys.METADATA.latestVersion(),
+                        topicIds),
+                false,
+                time.milliseconds());
+
+        return metadata;
+    }
+
+    private List<Node> initializeNodes(Metadata metadata) {
+        return new ArrayList<>(metadata.fetch().nodes());
+    }
+
+    private ClusterResourceListener initializeMockListener(Metadata metadata) {
+        ClusterResourceListener mockListener = Mockito.mock(ClusterResourceListener.class);
+        metadata.addClusterUpdateListener(mockListener);
+        return mockListener;
+    }
+
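+    /**
+     * Returns the fixed PartitionMetadata used for each partition set up in initializeMetadata().
+     * The error/leader/replica arguments are ignored, so the tests also call this method to build
+     * the expected metadata for partitions that should be unchanged.
+     */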
+    private MetadataResponse.PartitionMetadata metadataSupplier(Errors error, TopicPartition partition, Optional<Integer> leaderId, Optional<Integer> leaderEpoch, List<Integer> replicas, List<Integer> isr, List<Integer> offlineReplicas) {
+        if ("topic1".equals(partition.topic()) && partition.partition() == 0)
+            return new MetadataResponse.PartitionMetadata(Errors.NONE, partition, Optional.of(1), Optional.of(100), Arrays.asList(1, 2), Arrays.asList(1, 2), Collections.singletonList(3));
+        else if ("topic1".equals(partition.topic()) && partition.partition() == 1)
+            return new MetadataResponse.PartitionMetadata(Errors.NONE, partition, Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
+        else if ("topic2".equals(partition.topic()) && partition.partition() == 0)
+            return new MetadataResponse.PartitionMetadata(Errors.NONE, partition, Optional.of(2), Optional.of(200), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
+        else if (Topic.GROUP_METADATA_TOPIC_NAME.equals(partition.topic()) && partition.partition() == 0)
+            return new MetadataResponse.PartitionMetadata(Errors.NONE, partition, Optional.of(2), Optional.of(300), Arrays.asList(2, 3), Arrays.asList(2, 3), Collections.singletonList(1));
+        else throw new RuntimeException("Unexpected partition " + partition);
+    }
+
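+    /**
+     * Asserts that the cluster-level metadata (nodes, topics, controller, topic IDs) and the
+     * cached PartitionMetadata and current leader of each of the four partitions match the
+     * expected values.
+     */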
+    private void validateForUpdatePartitionLeadership(Metadata updatedMetadata,
+                                                      MetadataResponse.PartitionMetadata part1Metadata,
+                                                      MetadataResponse.PartitionMetadata part2Metadata,
+                                                      MetadataResponse.PartitionMetadata part12Metadata,
+                                                      MetadataResponse.PartitionMetadata internalPartMetadata,
+                                                      List<Node> expectedNodes,
+                                                      String expectedClusterId,
+                                                      Set<String> expectedUnauthorisedTopics,
+                                                      Set<String> expectedInvalidTopics,
+                                                      Set<String> expectedInternalTopics,
+                                                      Node expectedController,
+                                                      Map<String, Uuid> expectedTopicIds) {
+        Cluster updatedCluster = updatedMetadata.fetch();
+        assertEquals(expectedClusterId, updatedCluster.clusterResource().clusterId());
+        assertEquals(new HashSet<>(expectedNodes), new HashSet<>(updatedCluster.nodes()));
+        assertEquals(3, updatedCluster.topics().size());
+        assertEquals(expectedInternalTopics, updatedCluster.internalTopics());
+        assertEquals(expectedInvalidTopics, updatedCluster.invalidTopics());
+        assertEquals(expectedUnauthorisedTopics, updatedCluster.unauthorizedTopics());
+        assertEquals(expectedController, updatedCluster.controller());
+        assertEquals(expectedTopicIds, updatedMetadata.topicIds());
+
+        Map<Integer, Node> nodeMap = expectedNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
+        for (MetadataResponse.PartitionMetadata partitionMetadata : Arrays.asList(part1Metadata, part2Metadata, part12Metadata, internalPartMetadata)) {
+            TopicPartition tp = new TopicPartition(partitionMetadata.topic(), partitionMetadata.partition());
+
+            Metadata.LeaderAndEpoch expectedLeaderInfo = new Metadata.LeaderAndEpoch(Optional.of(nodeMap.get(partitionMetadata.leaderId.get())), partitionMetadata.leaderEpoch);
+            assertEquals(expectedLeaderInfo, updatedMetadata.currentLeader(tp));
+
+            Optional<MetadataResponse.PartitionMetadata> optionalUpdatedMetadata = updatedMetadata.partitionMetadataIfCurrent(tp);
+            assertTrue(optionalUpdatedMetadata.isPresent());
+            MetadataResponse.PartitionMetadata updatedPartMetadata = optionalUpdatedMetadata.get();
+            assertEquals(partitionMetadata.topicPartition, updatedPartMetadata.topicPartition);
+            assertEquals(partitionMetadata.error, updatedPartMetadata.error);
+            assertEquals(partitionMetadata.leaderId, updatedPartMetadata.leaderId);
+            assertEquals(partitionMetadata.leaderEpoch, updatedPartMetadata.leaderEpoch);
+            assertEquals(partitionMetadata.replicaIds, updatedPartMetadata.replicaIds);
+            assertEquals(partitionMetadata.inSyncReplicaIds, updatedPartMetadata.inSyncReplicaIds);
+            assertEquals(partitionMetadata.offlineReplicaIds, updatedPartMetadata.offlineReplicaIds);
+        }
+    }
 }