KAFKA-15775: New consumer listTopics and partitionsFor (#14962)

Implement Consumer.listTopics and Consumer.partitionsFor in the new consumer. The topic metadata request manager already existed so this PR adds expiration to requests, removes some redundant state checking and adds tests.

Reviewers: Lucas Brutschy <lucasbru@apache.org>
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
index 10d706a..56d65a1 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
@@ -51,6 +51,7 @@
 import org.apache.kafka.clients.consumer.internals.events.NewTopicsMetadataUpdateRequestEvent;
 import org.apache.kafka.clients.consumer.internals.events.ResetPositionsApplicationEvent;
 import org.apache.kafka.clients.consumer.internals.events.SubscriptionChangeApplicationEvent;
+import org.apache.kafka.clients.consumer.internals.events.TopicMetadataApplicationEvent;
 import org.apache.kafka.clients.consumer.internals.events.UnsubscribeApplicationEvent;
 import org.apache.kafka.clients.consumer.internals.events.ValidatePositionsApplicationEvent;
 import org.apache.kafka.common.Cluster;
@@ -811,7 +812,31 @@
 
     @Override
     public List<PartitionInfo> partitionsFor(String topic, Duration timeout) {
-        throw new KafkaException("method not implemented");
+        acquireAndEnsureOpen();
+        try {
+            Cluster cluster = this.metadata.fetch();
+            List<PartitionInfo> parts = cluster.partitionsForTopic(topic);
+            if (!parts.isEmpty())
+                return parts;
+
+            if (timeout.toMillis() == 0L) {
+                throw new TimeoutException();
+            }
+
+            final TopicMetadataApplicationEvent topicMetadataApplicationEvent =
+                    new TopicMetadataApplicationEvent(topic, timeout.toMillis());
+            wakeupTrigger.setActiveTask(topicMetadataApplicationEvent.future());
+            try {
+                Map<String, List<PartitionInfo>> topicMetadata =
+                        applicationEventHandler.addAndGet(topicMetadataApplicationEvent, time.timer(timeout));
+
+                return topicMetadata.getOrDefault(topic, Collections.emptyList());
+            } finally {
+                wakeupTrigger.clearTask();
+            }
+        } finally {
+            release();
+        }
     }
 
     @Override
@@ -821,7 +846,23 @@
 
     @Override
     public Map<String, List<PartitionInfo>> listTopics(Duration timeout) {
-        throw new KafkaException("method not implemented");
+        acquireAndEnsureOpen();
+        try {
+            if (timeout.toMillis() == 0L) {
+                throw new TimeoutException();
+            }
+
+            final TopicMetadataApplicationEvent topicMetadataApplicationEvent =
+                    new TopicMetadataApplicationEvent(timeout.toMillis());
+            wakeupTrigger.setActiveTask(topicMetadataApplicationEvent.future());
+            try {
+                return applicationEventHandler.addAndGet(topicMetadataApplicationEvent, time.timer(timeout));
+            } finally {
+                wakeupTrigger.clearTask();
+            }
+        } finally {
+            release();
+        }
     }
 
     @Override
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManager.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManager.java
index 8429e80..75a5ed0 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManager.java
@@ -23,6 +23,7 @@
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.errors.InvalidTopicException;
 import org.apache.kafka.common.errors.RetriableException;
+import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.requests.MetadataRequest;
@@ -30,9 +31,9 @@
 import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
 
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -55,16 +56,13 @@
  * <p>
  * The manager checks the state of the {@link TopicMetadataRequestState} before sending a new one to
  * prevent sending it without backing off from previous attempts.
- * It also checks the state of inflight requests to avoid overwhelming the broker with duplicate requests.
- * The {@code inflightRequests} are memorized by topic name. If all topics are requested, then we use {@code Optional
- * .empty()} as the key.
- * Once a request is completed successfully, its corresponding entry is removed.
+ * Once a request is completed successfully or times out, its corresponding entry is removed.
  * </p>
  */
 
 public class TopicMetadataRequestManager implements RequestManager {
     private final boolean allowAutoTopicCreation;
-    private final Map<Optional<String>, TopicMetadataRequestState> inflightRequests;
+    private final List<TopicMetadataRequestState> inflightRequests;
     private final long retryBackoffMs;
     private final long retryBackoffMaxMs;
     private final Logger log;
@@ -73,7 +71,7 @@
     public TopicMetadataRequestManager(final LogContext context, final ConsumerConfig config) {
         logContext = context;
         log = logContext.logger(getClass());
-        inflightRequests = new HashMap<>();
+        inflightRequests = new LinkedList<>();
         retryBackoffMs = config.getLong(ConsumerConfig.RETRY_BACKOFF_MS_CONFIG);
         retryBackoffMaxMs = config.getLong(ConsumerConfig.RETRY_BACKOFF_MAX_MS_CONFIG);
         allowAutoTopicCreation = config.getBoolean(ConsumerConfig.ALLOW_AUTO_CREATE_TOPICS_CONFIG);
@@ -81,52 +79,87 @@
 
     @Override
     public NetworkClientDelegate.PollResult poll(final long currentTimeMs) {
-        List<NetworkClientDelegate.UnsentRequest> requests = inflightRequests.values().stream()
+        // Prune any requests which have timed out
+        List<TopicMetadataRequestState> expiredRequests = inflightRequests.stream()
+                .filter(req -> req.isExpired(currentTimeMs))
+                .collect(Collectors.toList());
+        expiredRequests.forEach(TopicMetadataRequestState::expire);
+
+        List<NetworkClientDelegate.UnsentRequest> requests = inflightRequests.stream()
             .map(req -> req.send(currentTimeMs))
             .filter(Optional::isPresent)
             .map(Optional::get)
             .collect(Collectors.toList());
+
         return requests.isEmpty() ? EMPTY : new NetworkClientDelegate.PollResult(0, requests);
     }
 
     /**
-     * return the future of the metadata request. Return the existing future if a request for the same topic is already
-     * inflight.
+     * Return the future of the metadata request.
      *
-     * @param topic to be requested. If empty, return the metadata for all topics.
      * @return the future of the metadata request.
      */
-    public CompletableFuture<Map<String, List<PartitionInfo>>> requestTopicMetadata(final Optional<String> topic) {
-        if (inflightRequests.containsKey(topic)) {
-            return inflightRequests.get(topic).future;
-        }
+    public CompletableFuture<Map<String, List<PartitionInfo>>> requestAllTopicsMetadata(final long expirationTimeMs) {
+        TopicMetadataRequestState newRequest = new TopicMetadataRequestState(
+                logContext,
+                expirationTimeMs,
+                retryBackoffMs,
+                retryBackoffMaxMs);
+        inflightRequests.add(newRequest);
+        return newRequest.future;
+    }
 
+    /**
+     * Return the future of the metadata request.
+     *
+     * @param topic to be requested.
+     * @return the future of the metadata request.
+     */
+    public CompletableFuture<Map<String, List<PartitionInfo>>> requestTopicMetadata(final String topic, final long expirationTimeMs) {
         TopicMetadataRequestState newRequest = new TopicMetadataRequestState(
                 logContext,
                 topic,
+                expirationTimeMs,
                 retryBackoffMs,
                 retryBackoffMaxMs);
-        inflightRequests.put(topic, newRequest);
+        inflightRequests.add(newRequest);
         return newRequest.future;
     }
 
     // Visible for testing
     List<TopicMetadataRequestState> inflightRequests() {
-        return new ArrayList<>(inflightRequests.values());
+        return inflightRequests;
     }
 
     class TopicMetadataRequestState extends RequestState {
-        private final Optional<String> topic;
+        private final String topic;
+        private final boolean allTopics;
+        private final long expirationTimeMs;
         CompletableFuture<Map<String, List<PartitionInfo>>> future;
 
         public TopicMetadataRequestState(final LogContext logContext,
-                                         final Optional<String> topic,
+                                         final long expirationTimeMs,
+                                         final long retryBackoffMs,
+                                         final long retryBackoffMaxMs) {
+            super(logContext, TopicMetadataRequestState.class.getSimpleName(), retryBackoffMs,
+                    retryBackoffMaxMs);
+            future = new CompletableFuture<>();
+            this.topic = null;
+            this.allTopics = true;
+            this.expirationTimeMs = expirationTimeMs;
+        }
+
+        public TopicMetadataRequestState(final LogContext logContext,
+                                         final String topic,
+                                         final long expirationTimeMs,
                                          final long retryBackoffMs,
                                          final long retryBackoffMaxMs) {
             super(logContext, TopicMetadataRequestState.class.getSimpleName(), retryBackoffMs,
                 retryBackoffMaxMs);
             future = new CompletableFuture<>();
             this.topic = topic;
+            this.allTopics = false;
+            this.expirationTimeMs = expirationTimeMs;
         }
 
         /**
@@ -134,18 +167,31 @@
          * {@link org.apache.kafka.clients.consumer.internals.NetworkClientDelegate.UnsentRequest} if needed.
          */
         private Optional<NetworkClientDelegate.UnsentRequest> send(final long currentTimeMs) {
+            if (currentTimeMs >= expirationTimeMs) {
+                return Optional.empty();
+            }
+
             if (!canSendRequest(currentTimeMs)) {
                 return Optional.empty();
             }
             onSendAttempt(currentTimeMs);
 
-            final MetadataRequest.Builder request =
-                topic.map(t -> new MetadataRequest.Builder(Collections.singletonList(t), allowAutoTopicCreation))
-                    .orElseGet(MetadataRequest.Builder::allTopics);
+            final MetadataRequest.Builder request = allTopics
+                ? MetadataRequest.Builder.allTopics()
+                : new MetadataRequest.Builder(Collections.singletonList(topic), allowAutoTopicCreation);
 
             return Optional.of(createUnsentRequest(request));
         }
 
+        private boolean isExpired(final long currentTimeMs) {
+            return currentTimeMs >= expirationTimeMs;
+        }
+
+        private void expire() {
+            completeFutureAndRemoveRequest(
+                    new TimeoutException("Timeout expired while fetching topic metadata"));
+        }
+
         private NetworkClientDelegate.UnsentRequest createUnsentRequest(
                 final MetadataRequest.Builder request) {
             NetworkClientDelegate.UnsentRequest unsent = new NetworkClientDelegate.UnsentRequest(
@@ -164,7 +210,12 @@
         private void handleError(final Throwable exception,
                                  final long completionTimeMs) {
             if (exception instanceof RetriableException) {
-                onFailedAttempt(completionTimeMs);
+                if (completionTimeMs >= expirationTimeMs) {
+                    completeFutureAndRemoveRequest(
+                        new TimeoutException("Timeout expired while fetching topic metadata"));
+                } else {
+                    onFailedAttempt(completionTimeMs);
+                }
             } else {
                 completeFutureAndRemoveRequest(exception);
             }
@@ -175,9 +226,14 @@
             try {
                 Map<String, List<PartitionInfo>> res = handleTopicMetadataResponse((MetadataResponse) response.responseBody());
                 future.complete(res);
-                inflightRequests.remove(topic);
+                inflightRequests.remove(this);
             } catch (RetriableException e) {
-                onFailedAttempt(responseTimeMs);
+                if (responseTimeMs >= expirationTimeMs) {
+                    completeFutureAndRemoveRequest(
+                        new TimeoutException("Timeout expired while fetching topic metadata"));
+                } else {
+                    onFailedAttempt(responseTimeMs);
+                }
             } catch (Exception t) {
                 completeFutureAndRemoveRequest(t);
             }
@@ -185,7 +241,7 @@
 
         private void completeFutureAndRemoveRequest(final Throwable throwable) {
             future.completeExceptionally(throwable);
-            inflightRequests.remove(topic);
+            inflightRequests.remove(this);
         }
 
         private Map<String, List<PartitionInfo>> handleTopicMetadataResponse(final MetadataResponse response) {
@@ -212,9 +268,9 @@
                         // if a requested topic is unknown, we just continue and let it be absent
                         // in the returned map
                         continue;
-                    else if (error.exception() instanceof RetriableException) {
+                    else if (error.exception() instanceof RetriableException)
                         throw error.exception();
-                    } else
+                    else
                         throw new KafkaException("Unexpected error fetching metadata for topic " + topic,
                             error.exception());
                 }
@@ -226,7 +282,7 @@
             return topicsPartitionInfos;
         }
 
-        public Optional<String> topic() {
+        public String topic() {
             return topic;
         }
     }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
index 9c0bcde..35f5370 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/ApplicationEventProcessor.java
@@ -32,7 +32,6 @@
 
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Supplier;
@@ -219,8 +218,16 @@
     }
 
     private void process(final TopicMetadataApplicationEvent event) {
-        final CompletableFuture<Map<String, List<PartitionInfo>>> future =
-                requestManagers.topicMetadataRequestManager.requestTopicMetadata(Optional.of(event.topic()));
+        final CompletableFuture<Map<String, List<PartitionInfo>>> future;
+
+        long expirationTimeMs =
+            (event.getTimeoutMs() == Long.MAX_VALUE) ? Long.MAX_VALUE : System.currentTimeMillis() + event.getTimeoutMs();
+        if (event.isAllTopics()) {
+            future = requestManagers.topicMetadataRequestManager.requestAllTopicsMetadata(expirationTimeMs);
+        } else {
+            future = requestManagers.topicMetadataRequestManager.requestTopicMetadata(event.topic(), expirationTimeMs);
+        }
+
         event.chain(future);
     }
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/TopicMetadataApplicationEvent.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/TopicMetadataApplicationEvent.java
index 6486fe6..dd6f842 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/TopicMetadataApplicationEvent.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/events/TopicMetadataApplicationEvent.java
@@ -20,21 +20,44 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 public class TopicMetadataApplicationEvent extends CompletableApplicationEvent<Map<String, List<PartitionInfo>>> {
     private final String topic;
-    public TopicMetadataApplicationEvent(final String topic) {
+    private final boolean allTopics;
+    private final long timeoutMs;
+
+    public TopicMetadataApplicationEvent(final long timeoutMs) {
+        super(Type.TOPIC_METADATA);
+        this.topic = null;
+        this.allTopics = true;
+        this.timeoutMs = timeoutMs;
+    }
+
+    public TopicMetadataApplicationEvent(final String topic, final long timeoutMs) {
         super(Type.TOPIC_METADATA);
         this.topic = topic;
+        this.allTopics = false;
+        this.timeoutMs = timeoutMs;
     }
 
     public String topic() {
         return topic;
     }
 
+    public boolean isAllTopics() {
+        return allTopics;
+    }
+
+    public long getTimeoutMs() {
+        return timeoutMs;
+    }
     @Override
     public String toString() {
-        return "TopicMetadataApplicationEvent(topic=" + topic + ")";
+        return getClass().getSimpleName() + " {" + toStringBase() +
+                ", topic=" + topic +
+                ", allTopics=" + allTopics +
+                ", timeoutMs=" + timeoutMs + "}";
     }
 
     @Override
@@ -45,13 +68,11 @@
 
         TopicMetadataApplicationEvent that = (TopicMetadataApplicationEvent) o;
 
-        return topic.equals(that.topic);
+        return topic.equals(that.topic) && (allTopics == that.allTopics) && (timeoutMs == that.timeoutMs);
     }
 
     @Override
     public int hashCode() {
-        int result = super.hashCode();
-        result = 31 * result + topic.hashCode();
-        return result;
+        return Objects.hash(super.hashCode(), topic, allTopics, timeoutMs);
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkThreadTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkThreadTest.java
index a137091..0eefb59 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkThreadTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkThreadTest.java
@@ -212,7 +212,7 @@
 
     @Test
     void testFetchTopicMetadata() {
-        applicationEventsQueue.add(new TopicMetadataApplicationEvent("topic"));
+        applicationEventsQueue.add(new TopicMetadataApplicationEvent("topic", Long.MAX_VALUE));
         consumerNetworkThread.runOnce();
         verify(applicationEventProcessor).process(any(TopicMetadataApplicationEvent.class));
     }
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManagerTest.java
index e72172e..c7b2315 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/TopicMetadataRequestManagerTest.java
@@ -48,7 +48,6 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Properties;
 import java.util.concurrent.CompletableFuture;
 
@@ -58,7 +57,6 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.Mockito.spy;
 
 public class TopicMetadataRequestManagerTest {
@@ -78,10 +76,18 @@
             new ConsumerConfig(props)));
     }
 
-    @ParameterizedTest
-    @MethodSource("topicsProvider")
-    public void testPoll_SuccessfulRequestTopicMetadata(Optional<String> topic) {
-        this.topicMetadataRequestManager.requestTopicMetadata(topic);
+    @Test
+    public void testPoll_SuccessfulRequestTopicMetadata() {
+        String topic = "hello";
+        this.topicMetadataRequestManager.requestTopicMetadata(topic, Long.MAX_VALUE);
+        this.time.sleep(100);
+        NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
+        assertEquals(1, res.unsentRequests.size());
+    }
+
+    @Test
+    public void testPoll_SuccessfulRequestAllTopicsMetadata() {
+        this.topicMetadataRequestManager.requestAllTopicsMetadata(Long.MAX_VALUE);
         this.time.sleep(100);
         NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
         assertEquals(1, res.unsentRequests.size());
@@ -89,56 +95,86 @@
 
     @ParameterizedTest
     @MethodSource("exceptionProvider")
-    public void testExceptionAndInflightRequests(final Errors error, final boolean shouldRetry) {
+    public void testTopicExceptionAndInflightRequests(final Errors error, final boolean shouldRetry) {
         String topic = "hello";
-        this.topicMetadataRequestManager.requestTopicMetadata(Optional.of("hello"));
+        this.topicMetadataRequestManager.requestTopicMetadata(topic, Long.MAX_VALUE);
         this.time.sleep(100);
         NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
         res.unsentRequests.get(0).future().complete(buildTopicMetadataClientResponse(
             res.unsentRequests.get(0),
-            Optional.of(topic),
+            topic,
             error));
         List<TopicMetadataRequestManager.TopicMetadataRequestState> inflights = this.topicMetadataRequestManager.inflightRequests();
 
         if (shouldRetry) {
             assertEquals(1, inflights.size());
-            assertEquals(topic, inflights.get(0).topic().orElse(null));
+            assertEquals(topic, inflights.get(0).topic());
         } else {
             assertEquals(0, inflights.size());
         }
     }
 
     @ParameterizedTest
-    @MethodSource("topicsProvider")
-    public void testSendingTheSameRequest(Optional<String> topic) {
-        CompletableFuture<Map<String, List<PartitionInfo>>> future = this.topicMetadataRequestManager.requestTopicMetadata(topic);
-        CompletableFuture<Map<String, List<PartitionInfo>>> future2 = this.topicMetadataRequestManager.requestTopicMetadata(topic);
+    @MethodSource("exceptionProvider")
+    public void testAllTopicsExceptionAndInflightRequests(final Errors error, final boolean shouldRetry) {
+        this.topicMetadataRequestManager.requestAllTopicsMetadata(Long.MAX_VALUE);
         this.time.sleep(100);
         NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
-        assertEquals(1, res.unsentRequests.size());
+        res.unsentRequests.get(0).future().complete(buildAllTopicsMetadataClientResponse(
+                res.unsentRequests.get(0),
+                error));
+        List<TopicMetadataRequestManager.TopicMetadataRequestState> inflights = this.topicMetadataRequestManager.inflightRequests();
 
+        if (shouldRetry) {
+            assertEquals(1, inflights.size());
+        } else {
+            assertEquals(0, inflights.size());
+        }
+    }
+
+    @Test
+    public void testExpiringRequest() {
+        String topic = "hello";
+
+        // Request topic metadata with 1000ms expiration
+        long now = this.time.milliseconds();
+        CompletableFuture<Map<String, List<PartitionInfo>>> future =
+            this.topicMetadataRequestManager.requestTopicMetadata(topic, now + 1000L);
+        assertEquals(1, this.topicMetadataRequestManager.inflightRequests().size());
+
+        // Poll the request manager to get the list of requests to send
+        // - fail the request with a RetriableException
+        NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
+        assertEquals(1, res.unsentRequests.size());
         res.unsentRequests.get(0).future().complete(buildTopicMetadataClientResponse(
             res.unsentRequests.get(0),
             topic,
-            Errors.NONE));
+            Errors.REQUEST_TIMED_OUT));
 
-        assertTrue(future.isDone());
-        assertFalse(future.isCompletedExceptionally());
-        try {
-            future.get();
-        } catch (Throwable e) {
-            fail("Expecting to succeed, but got: {}", e);
-        }
-        assertTrue(future2.isDone());
-        assertFalse(future2.isCompletedExceptionally());
+        // Sleep for long enough to exceed the backoff delay but still within the expiration
+        // - fail the request again with a RetriableException
+        this.time.sleep(500);
+        res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
+        assertEquals(1, res.unsentRequests.size());
+        res.unsentRequests.get(0).future().complete(buildTopicMetadataClientResponse(
+            res.unsentRequests.get(0),
+            topic,
+            Errors.REQUEST_TIMED_OUT));
+
+        // Sleep for long enough to expire the request which should fail
+        this.time.sleep(1000);
+        res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
+        assertEquals(0, res.unsentRequests.size());
+        assertEquals(0, this.topicMetadataRequestManager.inflightRequests().size());
+        assertTrue(future.isCompletedExceptionally());
     }
 
     @ParameterizedTest
     @MethodSource("hardFailureExceptionProvider")
     public void testHardFailures(Exception exception) {
-        Optional<String> topic = Optional.of("hello");
+        String topic = "hello";
 
-        this.topicMetadataRequestManager.requestTopicMetadata(topic);
+        this.topicMetadataRequestManager.requestTopicMetadata(topic, Long.MAX_VALUE);
         NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
         assertEquals(1, res.unsentRequests.size());
 
@@ -153,9 +189,9 @@
 
     @Test
     public void testNetworkTimeout() {
-        Optional<String> topic = Optional.of("hello");
+        String topic = "hello";
 
-        topicMetadataRequestManager.requestTopicMetadata(topic);
+        topicMetadataRequestManager.requestTopicMetadata(topic, Long.MAX_VALUE);
         NetworkClientDelegate.PollResult res = this.topicMetadataRequestManager.poll(this.time.milliseconds());
         assertEquals(1, res.unsentRequests.size());
         NetworkClientDelegate.PollResult res2 = this.topicMetadataRequestManager.poll(this.time.milliseconds());
@@ -182,24 +218,44 @@
     }
 
     private ClientResponse buildTopicMetadataClientResponse(
+            final NetworkClientDelegate.UnsentRequest request,
+            final String topic,
+            final Errors error) {
+        AbstractRequest abstractRequest = request.requestBuilder().build();
+        assertTrue(abstractRequest instanceof MetadataRequest);
+        MetadataRequest metadataRequest = (MetadataRequest) abstractRequest;
+        Cluster cluster = mockCluster(3, 0);
+        List<MetadataResponse.TopicMetadata> topics = new ArrayList<>();
+        topics.add(new MetadataResponse.TopicMetadata(error, topic, false,
+                Collections.emptyList()));
+        final MetadataResponse metadataResponse = RequestTestUtils.metadataResponse(cluster.nodes(),
+                cluster.clusterResource().clusterId(),
+                cluster.controller().id(),
+                topics);
+        return new ClientResponse(
+                new RequestHeader(ApiKeys.METADATA, metadataRequest.version(), "mockClientId", 1),
+                request.handler(),
+                "-1",
+                time.milliseconds(),
+                time.milliseconds(),
+                false,
+                null,
+                null,
+                metadataResponse);
+    }
+
+    private ClientResponse buildAllTopicsMetadataClientResponse(
         final NetworkClientDelegate.UnsentRequest request,
-        final Optional<String> topic,
         final Errors error) {
         AbstractRequest abstractRequest = request.requestBuilder().build();
         assertTrue(abstractRequest instanceof MetadataRequest);
         MetadataRequest metadataRequest = (MetadataRequest) abstractRequest;
         Cluster cluster = mockCluster(3, 0);
         List<MetadataResponse.TopicMetadata> topics = new ArrayList<>();
-        if (topic.isPresent()) {
-            topics.add(new MetadataResponse.TopicMetadata(error, topic.get(), false,
-                Collections.emptyList()));
-        } else {
-            // null topic means request for all topics
-            topics.add(new MetadataResponse.TopicMetadata(error, "topic1", false,
-                Collections.emptyList()));
-            topics.add(new MetadataResponse.TopicMetadata(error, "topic2", false,
-                Collections.emptyList()));
-        }
+        topics.add(new MetadataResponse.TopicMetadata(error, "topic1", false,
+            Collections.emptyList()));
+        topics.add(new MetadataResponse.TopicMetadata(error, "topic2", false,
+            Collections.emptyList()));
         final MetadataResponse metadataResponse = RequestTestUtils.metadataResponse(cluster.nodes(),
             cluster.clusterResource().clusterId(),
             cluster.controller().id(),
@@ -226,12 +282,6 @@
     }
 
 
-    private static Collection<Arguments> topicsProvider() {
-        return Arrays.asList(
-            Arguments.of(Optional.of("topic1")),
-            Arguments.of(Optional.empty()));
-    }
-
     private static Collection<Arguments> exceptionProvider() {
         return Arrays.asList(
             Arguments.of(Errors.UNKNOWN_TOPIC_OR_PARTITION, false),
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index b776809..c0a249b 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -564,9 +564,8 @@
     awaitAssignment(consumer, shrunkenAssignment)
   }
 
-  // partitionsFor not implemented in consumer group protocol
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumAndGroupProtocolNames)
-  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersGenericGroupProtocolOnly"))
+  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersAll"))
   def testPartitionsFor(quorum: String, groupProtocol: String): Unit = {
     val numParts = 2
     createTopic("part-test", numParts, 1)
@@ -576,9 +575,8 @@
     assertEquals(2, parts.size)
   }
 
-  // partitionsFor not implemented in consumer group protocol
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumAndGroupProtocolNames)
-  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersGenericGroupProtocolOnly"))
+  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersAll"))
   def testPartitionsForAutoCreate(quorum: String, groupProtocol: String): Unit = {
     val consumer = createConsumer()
     // First call would create the topic
@@ -588,9 +586,8 @@
     }, s"Timed out while awaiting non empty partitions.")
   }
 
-  // partitionsFor not implemented in consumer group protocol
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumAndGroupProtocolNames)
-  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersGenericGroupProtocolOnly"))
+  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersAll"))
   def testPartitionsForInvalidTopic(quorum: String, groupProtocol: String): Unit = {
     val consumer = createConsumer()
     assertThrows(classOf[InvalidTopicException], () => consumer.partitionsFor(";3# ads,{234"))
@@ -1471,9 +1468,8 @@
       startingTimestamp = startTime, timestampType = TimestampType.LOG_APPEND_TIME)
   }
 
-  // listTopics temporarily not supported for consumer group protocol
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumAndGroupProtocolNames)
-  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersGenericGroupProtocolOnly"))
+  @MethodSource(Array("getTestQuorumAndGroupProtocolParametersAll"))
   def testListTopics(quorum: String, groupProtocol: String): Unit = {
     val numParts = 2
     val topic1 = "part-test-topic-1"