Fix crash in multi-topics consumer when one internal consumer fails getBrokerConsumerStatsAsync (#538)
diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index 4f7a1dd..1d488d8 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc
@@ -997,9 +997,14 @@ lock.unlock(); LOG_ERROR(cnxString_ << " Client is not connected to the broker"); promise.setFailed(ResultNotConnected); + return promise.getFuture(); } pendingConsumerStatsMap_.insert(std::make_pair(requestId, promise)); lock.unlock(); + if (mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest("CONSUMER_STATS", requestId)) { + return promise.getFuture(); + } sendCommand(Commands::newConsumerStats(consumerId, requestId)); return promise.getFuture(); }
diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index aae53d2..b277000 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h
@@ -219,6 +219,8 @@ mockingRequests_.store(true, std::memory_order_release); } + void handleKeepAliveTimeout(); + private: struct PendingRequestData { Promise<Result, ResponseData> promise; @@ -284,8 +286,6 @@ void handleGetLastMessageIdTimeout(const ASIO_ERROR&, const LastMessageIdRequestData& data); - void handleKeepAliveTimeout(); - template <typename Handler> inline AllocHandler<Handler> customAllocReadHandler(Handler h) { return AllocHandler<Handler>(readHandlerAllocator_, h);
diff --git a/lib/MockServer.h b/lib/MockServer.h index bd413d3..2d830fc 100644 --- a/lib/MockServer.h +++ b/lib/MockServer.h
@@ -75,11 +75,18 @@ } }); } - schedule(connection, request + std::to_string(requestId), iter->second, [connection, requestId] { - proto::CommandSuccess success; - success.set_request_id(requestId); - connection->handleSuccess(success); - }); + schedule(connection, request + std::to_string(requestId), iter->second, + [connection, request, requestId] { + if (request == "CONSUMER_STATS") { + proto::CommandConsumerStatsResponse response; + response.set_request_id(requestId); + connection->handleConsumerStatsResponse(response); + } else { + proto::CommandSuccess success; + success.set_request_id(requestId); + connection->handleSuccess(success); + } + }); return true; } else { return false;
diff --git a/lib/MultiTopicsConsumerImpl.cc b/lib/MultiTopicsConsumerImpl.cc index 6e0ba86..9c741fa 100644 --- a/lib/MultiTopicsConsumerImpl.cc +++ b/lib/MultiTopicsConsumerImpl.cc
@@ -847,48 +847,47 @@ Lock lock(mutex_); MultiTopicsBrokerConsumerStatsPtr statsPtr = std::make_shared<MultiTopicsBrokerConsumerStatsImpl>(numberTopicPartitions_->load()); - LatchPtr latchPtr = std::make_shared<Latch>(numberTopicPartitions_->load()); + auto latchPtr = std::make_shared<std::atomic_size_t>(numberTopicPartitions_->load()); lock.unlock(); size_t i = 0; - consumers_.forEachValue([this, &latchPtr, &statsPtr, &i, callback](const ConsumerImplPtr& consumer) { - size_t index = i++; - auto weakSelf = weak_from_this(); - consumer->getBrokerConsumerStatsAsync([this, weakSelf, latchPtr, statsPtr, index, callback]( - Result result, const BrokerConsumerStats& stats) { - auto self = weakSelf.lock(); - if (self) { - handleGetConsumerStats(result, stats, latchPtr, statsPtr, index, callback); - } + auto failedResult = std::make_shared<std::atomic<Result>>(ResultOk); + consumers_.forEachValue( + [this, &latchPtr, &statsPtr, &i, callback, &failedResult](const ConsumerImplPtr& consumer) { + size_t index = i++; + auto weakSelf = weak_from_this(); + consumer->getBrokerConsumerStatsAsync( + [this, weakSelf, latchPtr, statsPtr, index, callback, failedResult]( + Result result, const BrokerConsumerStats& stats) { + auto self = weakSelf.lock(); + if (!self) { + return; + } + if (result == ResultOk) { + std::lock_guard<std::mutex> lock{mutex_}; + statsPtr->add(stats, index); + } else { + // Store the first failed result as the final failed result + auto expected = ResultOk; + failedResult->compare_exchange_strong(expected, result); + } + if (--*latchPtr == 0) { + if (auto firstFailedResult = failedResult->load(std::memory_order_acquire); + firstFailedResult == ResultOk) { + callback(ResultOk, BrokerConsumerStats{statsPtr}); + } else { + // Fail the whole operation if any of the consumers failed + callback(firstFailedResult, {}); + } + } + }); }); - }); } void MultiTopicsConsumerImpl::getLastMessageIdAsync(const BrokerGetLastMessageIdCallback& callback) { 
callback(ResultOperationNotSupported, GetLastMessageIdResponse()); } -void MultiTopicsConsumerImpl::handleGetConsumerStats(Result res, - const BrokerConsumerStats& brokerConsumerStats, - const LatchPtr& latchPtr, - const MultiTopicsBrokerConsumerStatsPtr& statsPtr, - size_t index, - const BrokerConsumerStatsCallback& callback) { - Lock lock(mutex_); - if (res == ResultOk) { - latchPtr->countdown(); - statsPtr->add(brokerConsumerStats, index); - } else { - lock.unlock(); - callback(res, BrokerConsumerStats()); - return; - } - if (latchPtr->getCount() == 0) { - lock.unlock(); - callback(ResultOk, BrokerConsumerStats(statsPtr)); - } -} - std::shared_ptr<TopicName> MultiTopicsConsumerImpl::topicNamesValid(const std::vector<std::string>& topics) { TopicNamePtr topicNamePtr = std::shared_ptr<TopicName>();
diff --git a/lib/MultiTopicsConsumerImpl.h b/lib/MultiTopicsConsumerImpl.h index b22227e..dc62865 100644 --- a/lib/MultiTopicsConsumerImpl.h +++ b/lib/MultiTopicsConsumerImpl.h
@@ -28,7 +28,6 @@ #include "ConsumerImpl.h" #include "ConsumerInterceptors.h" #include "Future.h" -#include "Latch.h" #include "LookupDataResult.h" #include "SynchronizedHashMap.h" #include "TestUtil.h" @@ -100,9 +99,6 @@ uint64_t getNumberOfConnectedConsumer() override; void hasMessageAvailableAsync(const HasMessageAvailableCallback& callback) override; - void handleGetConsumerStats(Result, const BrokerConsumerStats&, const LatchPtr&, - const MultiTopicsBrokerConsumerStatsPtr&, size_t, - const BrokerConsumerStatsCallback&); // return first topic name when all topics name valid, or return null pointer static std::shared_ptr<TopicName> topicNamesValid(const std::vector<std::string>& topics); void unsubscribeOneTopicAsync(const std::string& topic, const ResultCallback& callback);
diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc index f1bca77..795613e 100644 --- a/tests/ConsumerTest.cc +++ b/tests/ConsumerTest.cc
@@ -40,6 +40,7 @@ #include "WaitUtils.h" #include "lib/ClientConnection.h" #include "lib/Future.h" +#include "lib/Latch.h" #include "lib/LogUtils.h" #include "lib/MessageIdUtil.h" #include "lib/MultiTopicsConsumerImpl.h"
diff --git a/tests/MultiTopicsConsumerTest.cc b/tests/MultiTopicsConsumerTest.cc index d59b50d..db3bc96 100644 --- a/tests/MultiTopicsConsumerTest.cc +++ b/tests/MultiTopicsConsumerTest.cc
@@ -20,9 +20,13 @@ #include <pulsar/Client.h> #include <chrono> +#include <future> +#include <thread> #include "ThreadSafeMessages.h" #include "lib/LogUtils.h" +#include "lib/MockServer.h" +#include "tests/PulsarFriend.h" static const std::string lookupUrl = "pulsar://localhost:6650"; @@ -142,3 +146,29 @@ client.close(); } + +TEST(MultiTopicsConsumerTest, testGetConsumerStatsFail) { + Client client{lookupUrl}; + std::vector<std::string> topics{"testGetConsumerStatsFail0", "testGetConsumerStatsFail1"}; + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topics, "sub", consumer)); + + auto connection = *PulsarFriend::getConnections(client).begin(); + auto mockServer = std::make_shared<MockServer>(connection); + connection->attachMockServer(mockServer); + + mockServer->setRequestDelay({{"CONSUMER_STATS", 3000}}); + auto future = std::async(std::launch::async, [&consumer]() { + BrokerConsumerStats stats; + return consumer.getBrokerConsumerStats(stats); + }); + // Trigger the `getBrokerConsumerStats` in a new thread + future.wait_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + connection->handleKeepAliveTimeout(); + ASSERT_EQ(ResultDisconnected, future.get()); + + mockServer->close(); + client.close(); +}