KAFKA-19617: ConsumerPerformance#ConsumerPerfRebListener get corrupted value when the number of partitions is increased (#20388)

With changes to the consumer protocol, rebalance may not necessarily
result in a "stop the world".  Thus, the method for calculating pause
time in `ConsumerPerformance#ConsumerPerfRebListener` needs to be
modified.

Stop time is only recorded if `assignedPartitions` is empty.

Reviewers: Andrew Schofield <aschofield@confluent.io>
diff --git a/tools/src/main/java/org/apache/kafka/tools/ConsumerPerformance.java b/tools/src/main/java/org/apache/kafka/tools/ConsumerPerformance.java
index 0334af8..f4a987e 100644
--- a/tools/src/main/java/org/apache/kafka/tools/ConsumerPerformance.java
+++ b/tools/src/main/java/org/apache/kafka/tools/ConsumerPerformance.java
@@ -36,6 +36,7 @@
 import java.text.SimpleDateFormat;
 import java.time.Duration;
 import java.util.Collection;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.Properties;
@@ -165,7 +166,7 @@
                     if (showDetailedStats)
                         printConsumerProgress(0, bytesRead, lastBytesRead, recordsRead, lastRecordsRead,
                             lastReportTimeMs, currentTimeMs, dateFormat, joinTimeMsInSingleRound.get());
-                    joinTimeMsInSingleRound = new AtomicLong(0);
+                    joinTimeMsInSingleRound.set(0);
                     lastReportTimeMs = currentTimeMs;
                     lastRecordsRead = recordsRead;
                     lastBytesRead = bytesRead;
@@ -230,24 +231,32 @@
     public static class ConsumerPerfRebListener implements ConsumerRebalanceListener {
         private final AtomicLong joinTimeMs;
         private final AtomicLong joinTimeMsInSingleRound;
+        private final Collection<TopicPartition> assignedPartitions;
         private long joinStartMs;
 
         public ConsumerPerfRebListener(AtomicLong joinTimeMs, long joinStartMs, AtomicLong joinTimeMsInSingleRound) {
             this.joinTimeMs = joinTimeMs;
             this.joinStartMs = joinStartMs;
             this.joinTimeMsInSingleRound = joinTimeMsInSingleRound;
+            this.assignedPartitions = new HashSet<>();
         }
 
         @Override
         public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
-            joinStartMs = System.currentTimeMillis();
+            assignedPartitions.removeAll(partitions);
+            if (assignedPartitions.isEmpty()) {
+                joinStartMs = System.currentTimeMillis();
+            }
         }
 
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
-            long elapsedMs = System.currentTimeMillis() - joinStartMs;
-            joinTimeMs.addAndGet(elapsedMs);
-            joinTimeMsInSingleRound.addAndGet(elapsedMs);
+            if (assignedPartitions.isEmpty()) {
+                long elapsedMs = System.currentTimeMillis() - joinStartMs;
+                joinTimeMs.addAndGet(elapsedMs);
+                joinTimeMsInSingleRound.addAndGet(elapsedMs);
+            }
+            assignedPartitions.addAll(partitions);
         }
     }
 
diff --git a/tools/src/test/java/org/apache/kafka/tools/ConsumerPerformanceTest.java b/tools/src/test/java/org/apache/kafka/tools/ConsumerPerformanceTest.java
index 497deb7..9e38587 100644
--- a/tools/src/test/java/org/apache/kafka/tools/ConsumerPerformanceTest.java
+++ b/tools/src/test/java/org/apache/kafka/tools/ConsumerPerformanceTest.java
@@ -20,6 +20,7 @@
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.internals.AutoOffsetResetStrategy;
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.Exit;
 import org.apache.kafka.common.utils.Utils;
 
@@ -35,9 +36,12 @@
 import java.nio.file.Path;
 import java.text.SimpleDateFormat;
 import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ConsumerPerformanceTest {
@@ -302,6 +306,61 @@
         assertTrue(Utils.isBlank(err), "Should be no stderr message, but was \"" + err + "\"");
     }
 
+    @Test
+    public void testConsumerListenerWithAllPartitionRevokedAndAssigned() throws InterruptedException {
+        String topicName = "topic";
+        TopicPartition tp0 = new TopicPartition(topicName, 0);
+        TopicPartition tp1 = new TopicPartition(topicName, 1);
+        AtomicLong joinTimeMs = new AtomicLong(0);
+        AtomicLong joinTimeMsInSingleRound = new AtomicLong(0);
+        ConsumerPerformance.ConsumerPerfRebListener listener = new ConsumerPerformance.ConsumerPerfRebListener(joinTimeMs, 0, joinTimeMsInSingleRound);
+        listener.onPartitionsAssigned(Set.of(tp0));
+        long lastJoinTimeMs = joinTimeMs.get();
+
+        // All assigned partitions have been revoked.
+        listener.onPartitionsRevoked(Set.of(tp0));
+        Thread.sleep(100);
+        listener.onPartitionsAssigned(Set.of(tp1));
+
+        assertNotEquals(lastJoinTimeMs, joinTimeMs.get());
+    }
+
+    @Test
+    public void testConsumerListenerWithPartialPartitionRevokedAndAssigned() throws InterruptedException {
+        String topicName = "topic";
+        TopicPartition tp0 = new TopicPartition(topicName, 0);
+        TopicPartition tp1 = new TopicPartition(topicName, 1);
+        AtomicLong joinTimeMs = new AtomicLong(0);
+        AtomicLong joinTimeMsInSingleRound = new AtomicLong(0);
+        ConsumerPerformance.ConsumerPerfRebListener listener = new ConsumerPerformance.ConsumerPerfRebListener(joinTimeMs, 0, joinTimeMsInSingleRound);
+        listener.onPartitionsAssigned(Set.of(tp0, tp1));
+        long lastJoinTimeMs = joinTimeMs.get();
+
+        // The assigned partitions were partially revoked.
+        listener.onPartitionsRevoked(Set.of(tp0));
+        Thread.sleep(100);
+        listener.onPartitionsAssigned(Set.of(tp0));
+
+        assertEquals(lastJoinTimeMs, joinTimeMs.get());
+    }
+
+    @Test
+    public void testConsumerListenerWithoutPartitionRevoked() throws InterruptedException {
+        String topicName = "topic";
+        TopicPartition tp0 = new TopicPartition(topicName, 0);
+        TopicPartition tp1 = new TopicPartition(topicName, 1);
+        AtomicLong joinTimeMs = new AtomicLong(0);
+        AtomicLong joinTimeMsInSingleRound = new AtomicLong(0);
+        ConsumerPerformance.ConsumerPerfRebListener listener = new ConsumerPerformance.ConsumerPerfRebListener(joinTimeMs, 0, joinTimeMsInSingleRound);
+        listener.onPartitionsAssigned(Set.of(tp0));
+        long lastJoinTimeMs = joinTimeMs.get();
+
+        Thread.sleep(100);
+        listener.onPartitionsAssigned(Set.of(tp1));
+
+        assertEquals(lastJoinTimeMs, joinTimeMs.get());
+    }
+
     private void testHeaderMatchContent(boolean detailed, int expectedOutputLineCount, Runnable runnable) {
         String out = ToolsTestUtils.captureStandardOut(() -> {
             ConsumerPerformance.printHeader(detailed);