fix(hotkey): fix two hidden dangers in hotkey_collector (#632)

diff --git a/src/server/hotkey_collector.cpp b/src/server/hotkey_collector.cpp
index 892e6eb..52808ba 100644
--- a/src/server/hotkey_collector.cpp
+++ b/src/server/hotkey_collector.cpp
@@ -39,7 +39,6 @@
     3,
     "the variance threshold to detect hot key during fine analysis of hotkey detection");
 
-// TODO: (Tangyanzhao) add a limit to avoid changing when detecting
 DSN_DEFINE_uint32("pegasus.server",
                   hotkey_buckets_num,
                   37,
@@ -58,7 +57,7 @@
     return true;
 });
 
-DSN_DEFINE_int32(
+DSN_DEFINE_uint32(
     "pegasus.server",
     max_seconds_to_detect_hotkey,
     150,
@@ -101,20 +100,78 @@
 }
 
 // TODO: (Tangyanzhao) replace it to xxhash
-/*extern*/ int get_bucket_id(dsn::string_view data)
+
+/*extern*/ int get_bucket_id(dsn::string_view data, int bucket_num)
 {
-    size_t hash_value = boost::hash_range(data.begin(), data.end());
-    return static_cast<int>(hash_value % FLAGS_hotkey_buckets_num);
+    return static_cast<int>(boost::hash_range(data.begin(), data.end()) % bucket_num);
 }
 
 hotkey_collector::hotkey_collector(dsn::replication::hotkey_type::type hotkey_type,
                                    dsn::replication::replica_base *r_base)
-    : replica_base(r_base),
-      _state(hotkey_collector_state::STOPPED),
-      _hotkey_type(hotkey_type),
-      _internal_collector(std::make_shared<hotkey_empty_data_collector>(this)),
-      _collector_start_time_second(0)
+    : replica_base(r_base), _hotkey_type(hotkey_type)
 {
+    int now_hash_bucket_num = FLAGS_hotkey_buckets_num;
+    _internal_coarse_collector =
+        std::make_shared<hotkey_coarse_data_collector>(this, now_hash_bucket_num);
+    _internal_fine_collector =
+        std::make_shared<hotkey_fine_data_collector>(this, now_hash_bucket_num);
+    _internal_empty_collector = std::make_shared<hotkey_empty_data_collector>(this);
+}
+
+inline void hotkey_collector::change_state_to_stopped()
+{
+    _state.store(hotkey_collector_state::STOPPED);
+    _result.if_find_result.store(false);
+    _internal_coarse_collector->clear();
+    _internal_fine_collector->clear();
+}
+
+inline void hotkey_collector::change_state_to_coarse_detecting()
+{
+    _state.store(hotkey_collector_state::COARSE_DETECTING);
+    _collector_start_time_second.store(dsn_now_s());
+}
+
+inline void hotkey_collector::change_state_to_fine_detecting()
+{
+    _state.store(hotkey_collector_state::FINE_DETECTING);
+    _internal_fine_collector->change_target_bucket(_result.coarse_bucket_index);
+}
+
+inline void hotkey_collector::change_state_to_finished()
+{
+    _state.store(hotkey_collector_state::FINISHED);
+    _result.if_find_result.store(true);
+}
+
+inline std::shared_ptr<internal_collector_base> hotkey_collector::get_internal_collector_by_state()
+{
+    switch (_state.load()) {
+    case hotkey_collector_state::COARSE_DETECTING:
+        return _internal_coarse_collector;
+    case hotkey_collector_state::FINE_DETECTING:
+        return _internal_fine_collector;
+    default:
+        return _internal_empty_collector;
+    }
+}
+
+inline void hotkey_collector::change_state_by_result()
+{
+    switch (_state.load()) {
+    case hotkey_collector_state::COARSE_DETECTING:
+        if (_result.coarse_bucket_index != -1) {
+            change_state_to_fine_detecting();
+        }
+        break;
+    case hotkey_collector_state::FINE_DETECTING:
+        if (!_result.hot_hash_key.empty()) {
+            change_state_to_finished();
+        }
+        break;
+    default:
+        break;
+    }
 }
 
 void hotkey_collector::handle_rpc(const dsn::replication::detect_hotkey_request &req,
@@ -145,19 +202,24 @@
 void hotkey_collector::capture_hash_key(const dsn::blob &hash_key, int64_t weight)
 {
     // TODO: (Tangyanzhao) add a unit test to ensure data integrity
-    _internal_collector->capture_data(hash_key, weight);
+    switch (_state.load()) {
+    case hotkey_collector_state::COARSE_DETECTING:
+    case hotkey_collector_state::FINE_DETECTING:
+        get_internal_collector_by_state()->capture_data(hash_key, weight > 0 ? weight : 1);
+        return;
+    default:
+        return;
+    }
 }
 
 void hotkey_collector::analyse_data()
 {
     switch (_state.load()) {
     case hotkey_collector_state::COARSE_DETECTING:
+    case hotkey_collector_state::FINE_DETECTING:
         if (!terminate_if_timeout()) {
-            _internal_collector->analyse_data(_result);
-            if (_result.coarse_bucket_index != -1) {
-                // TODO: (Tangyanzhao) reset _internal_collector to hotkey_fine_data_collector
-                _state.store(hotkey_collector_state::FINE_DETECTING);
-            }
+            get_internal_collector_by_state()->analyse_data(_result);
+            change_state_by_result();
         }
         return;
     default:
@@ -186,9 +248,7 @@
         dwarn_replica(hint);
         return;
     case hotkey_collector_state::STOPPED:
-        _collector_start_time_second = dsn_now_s();
-        _internal_collector.reset(new hotkey_coarse_data_collector(this));
-        _state.store(hotkey_collector_state::COARSE_DETECTING);
+        change_state_to_coarse_detecting();
         resp.err = dsn::ERR_OK;
         hint = fmt::format("starting to detect {} hotkey", dsn::enum_to_string(_hotkey_type));
         ddebug_replica(hint);
@@ -204,32 +264,28 @@
 
 void hotkey_collector::on_stop_detect(dsn::replication::detect_hotkey_response &resp)
 {
-    terminate();
+    change_state_to_stopped();
     resp.err = dsn::ERR_OK;
     std::string hint =
         fmt::format("{} hotkey stopped, cache cleared", dsn::enum_to_string(_hotkey_type));
     ddebug_replica(hint);
 }
 
-void hotkey_collector::terminate()
-{
-    _state.store(hotkey_collector_state::STOPPED);
-    _internal_collector.reset();
-    _collector_start_time_second = 0;
-}
-
 bool hotkey_collector::terminate_if_timeout()
 {
-    if (dsn_now_s() >= _collector_start_time_second + FLAGS_max_seconds_to_detect_hotkey) {
+    if (dsn_now_s() >= _collector_start_time_second.load() + FLAGS_max_seconds_to_detect_hotkey) {
         ddebug_replica("hotkey collector work time is exhausted but no hotkey has been found");
-        terminate();
+        change_state_to_stopped();
         return true;
     }
     return false;
 }
 
-hotkey_coarse_data_collector::hotkey_coarse_data_collector(replica_base *base)
-    : internal_collector_base(base), _hash_buckets(FLAGS_hotkey_buckets_num)
+hotkey_coarse_data_collector::hotkey_coarse_data_collector(replica_base *base,
+                                                           uint32_t hotkey_buckets_num)
+    : internal_collector_base(base),
+      _hash_bucket_num(hotkey_buckets_num),
+      _hash_buckets(hotkey_buckets_num)
 {
     for (auto &bucket : _hash_buckets) {
         bucket.store(0);
@@ -238,12 +294,12 @@
 
 void hotkey_coarse_data_collector::capture_data(const dsn::blob &hash_key, uint64_t weight)
 {
-    _hash_buckets[get_bucket_id(hash_key)].fetch_add(weight);
+    _hash_buckets[get_bucket_id(hash_key, _hash_bucket_num)].fetch_add(weight);
 }
 
 void hotkey_coarse_data_collector::analyse_data(detect_hotkey_result &result)
 {
-    std::vector<uint64_t> buckets(FLAGS_hotkey_buckets_num);
+    std::vector<uint64_t> buckets(_hash_bucket_num);
     for (int i = 0; i < buckets.size(); i++) {
         buckets[i] = _hash_buckets[i].load();
         _hash_buckets[i].store(0);
@@ -256,24 +312,31 @@
 
 void hotkey_coarse_data_collector::clear()
 {
-    for (int i = 0; i < FLAGS_hotkey_buckets_num; i++) {
+    for (int i = 0; i < _hash_bucket_num; i++) {
         _hash_buckets[i].store(0);
     }
 }
 
 hotkey_fine_data_collector::hotkey_fine_data_collector(replica_base *base,
-                                                       int target_bucket_index,
-                                                       int max_queue_size)
+                                                       uint32_t hotkey_buckets_num,
+                                                       uint32_t max_queue_size)
     : internal_collector_base(base),
       _max_queue_size(max_queue_size),
-      _target_bucket_index(target_bucket_index),
-      _capture_key_queue(max_queue_size)
+      _capture_key_queue(max_queue_size),
+      _hash_bucket_num(hotkey_buckets_num)
+
 {
+    _target_bucket_index.store(-1);
+}
+
+void hotkey_fine_data_collector::change_target_bucket(int target_bucket_index)
+{
+    _target_bucket_index.store(target_bucket_index);
 }
 
 void hotkey_fine_data_collector::capture_data(const dsn::blob &hash_key, uint64_t weight)
 {
-    if (get_bucket_id(hash_key) != _target_bucket_index) {
+    if (get_bucket_id(hash_key, _hash_bucket_num) != _target_bucket_index.load()) {
         return;
     }
     // abandon the key if enqueue failed (possibly because not enough room to enqueue)
@@ -325,9 +388,9 @@
         }
     }
 
-    // hash_key_counts stores the number of occurrences of each string captured in a period of time
-    // The size of weights influences our hotkey determination strategy
-    // weights.size() <= 2: the hotkey must exist (the most weighted key), because
+    // hash_key_counts stores the number of occurrences of each string captured in a period of
+    // time The size of weights influences our hotkey determination strategy weights.size() <=
+    // 2: the hotkey must exist (the most weighted key), because
     //                      the two-level filtering significantly reduces the
     //                      possibility that the hottest key is not the actual hotkey.
     // weights.size() >= 3: use find_outlier_index to determine whether a hotkey exists
@@ -340,6 +403,7 @@
 
 void hotkey_fine_data_collector::clear()
 {
+    _target_bucket_index.store(-1);
     std::pair<dsn::blob, uint64_t> key_weight_pair;
     while (_capture_key_queue.try_dequeue(key_weight_pair)) {
     }
diff --git a/src/server/hotkey_collector.h b/src/server/hotkey_collector.h
index 066c7df..82d5d61 100644
--- a/src/server/hotkey_collector.h
+++ b/src/server/hotkey_collector.h
@@ -26,14 +26,22 @@
 namespace server {
 
 class internal_collector_base;
+class hotkey_empty_data_collector;
+class hotkey_coarse_data_collector;
+class hotkey_fine_data_collector;
 
 struct detect_hotkey_result
 {
-    int coarse_bucket_index = -1;
+    std::atomic<bool> if_find_result;
+    int coarse_bucket_index;
     std::string hot_hash_key;
+    detect_hotkey_result() : coarse_bucket_index(-1), hot_hash_key("")
+    {
+        if_find_result.store(false);
+    }
 };
 
-extern int get_bucket_id(dsn::string_view data);
+extern int get_bucket_id(dsn::string_view data, int bucket_num);
 extern bool
 find_outlier_index(const std::vector<uint64_t> &captured_keys, int threshold, int &hot_index);
 
@@ -90,14 +98,24 @@
 private:
     void on_start_detect(dsn::replication::detect_hotkey_response &resp);
     void on_stop_detect(dsn::replication::detect_hotkey_response &resp);
-    void terminate();
-    bool terminate_if_timeout();
 
+    void change_state_to_stopped();
+    void change_state_to_coarse_detecting();
+    void change_state_to_fine_detecting();
+    void change_state_to_finished();
+
+    bool terminate_if_timeout();
+    std::shared_ptr<internal_collector_base> get_internal_collector_by_state();
+    void change_state_by_result();
+
+    const dsn::replication::hotkey_type::type _hotkey_type;
     detect_hotkey_result _result;
     std::atomic<hotkey_collector_state> _state;
-    const dsn::replication::hotkey_type::type _hotkey_type;
-    std::shared_ptr<internal_collector_base> _internal_collector;
-    uint64_t _collector_start_time_second;
+    std::atomic<uint64_t> _collector_start_time_second;
+
+    std::shared_ptr<hotkey_empty_data_collector> _internal_empty_collector;
+    std::shared_ptr<hotkey_coarse_data_collector> _internal_coarse_collector;
+    std::shared_ptr<hotkey_fine_data_collector> _internal_fine_collector;
 };
 
 // Be sure every function in internal_collector_base should be thread safe
@@ -124,7 +142,7 @@
 class hotkey_coarse_data_collector : public internal_collector_base
 {
 public:
-    explicit hotkey_coarse_data_collector(replica_base *base);
+    explicit hotkey_coarse_data_collector(replica_base *base, uint32_t hotkey_buckets_num);
     void capture_data(const dsn::blob &hash_key, uint64_t weight) override;
     void analyse_data(detect_hotkey_result &result) override;
     void clear() override;
@@ -132,6 +150,7 @@
 private:
     hotkey_coarse_data_collector() = delete;
 
+    const uint32_t _hash_bucket_num;
     std::vector<std::atomic<uint64_t>> _hash_buckets;
 
     friend class coarse_collector_test;
@@ -140,18 +159,22 @@
 class hotkey_fine_data_collector : public internal_collector_base
 {
 public:
-    hotkey_fine_data_collector(replica_base *base, int target_bucket_index, int max_queue_size);
+    hotkey_fine_data_collector(replica_base *base,
+                               uint32_t hotkey_buckets_num,
+                               uint32_t max_queue_size = 1000);
     void capture_data(const dsn::blob &hash_key, uint64_t weight) override;
     void analyse_data(detect_hotkey_result &result) override;
+    void change_target_bucket(int target_bucket_index);
     void clear() override;
 
 private:
     hotkey_fine_data_collector() = delete;
 
     const uint32_t _max_queue_size;
-    const uint32_t _target_bucket_index;
+    std::atomic<int32_t> _target_bucket_index;
     // ConcurrentQueue is a lock-free queue to capture keys
     moodycamel::ConcurrentQueue<std::pair<dsn::blob, uint64_t>> _capture_key_queue;
+    const uint32_t _hash_bucket_num;
 
     friend class fine_collector_test;
 };
diff --git a/src/server/test/hotkey_collector_test.cpp b/src/server/test/hotkey_collector_test.cpp
index 7c98ed9..7f2b5fd 100644
--- a/src/server/test/hotkey_collector_test.cpp
+++ b/src/server/test/hotkey_collector_test.cpp
@@ -48,7 +48,8 @@
 {
     int bucket_id = -1;
     for (int i = 0; i < 1000000; i++) {
-        bucket_id = get_bucket_id(dsn::blob::create_from_bytes(generate_hash_key_by_random(false)));
+        bucket_id = get_bucket_id(dsn::blob::create_from_bytes(generate_hash_key_by_random(false)),
+                                  FLAGS_hotkey_buckets_num);
         ASSERT_GE(bucket_id, 0);
         ASSERT_LT(bucket_id, FLAGS_hotkey_buckets_num);
     }
@@ -81,7 +82,7 @@
 class coarse_collector_test : public pegasus_server_test_base
 {
 public:
-    coarse_collector_test() : coarse_collector(_server.get()){};
+    coarse_collector_test() : coarse_collector(_server.get(), FLAGS_hotkey_buckets_num){};
 
     hotkey_coarse_data_collector coarse_collector;
 
@@ -134,9 +135,11 @@
 public:
     int max_queue_size = 1000;
     int target_bucket_index = 0;
-    fine_collector_test() : fine_collector(_server.get(), 0, max_queue_size){};
-
     hotkey_fine_data_collector fine_collector;
+    fine_collector_test() : fine_collector(_server.get(), 1, max_queue_size)
+    {
+        fine_collector.change_target_bucket(0);
+    };
 
     int now_queue_size()
     {
@@ -153,8 +156,6 @@
 
 TEST_F(fine_collector_test, fine_collector)
 {
-    auto hotkey_buckets_num_backup = FLAGS_hotkey_buckets_num;
-    FLAGS_hotkey_buckets_num = 1;
     detect_hotkey_result result;
 
     for (int i = 0; i < 1000; i++) {
@@ -192,8 +193,6 @@
     }
     _tracker.wait_outstanding_tasks();
     ASSERT_LT(now_queue_size(), max_queue_size * 2);
-
-    FLAGS_hotkey_buckets_num = hotkey_buckets_num_backup;
 }
 
 } // namespace server