SOLR-11535: Fix race condition in singleton-per-collection StateWatcher creation (#1964)

diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index a9c2054..281a1df 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -176,6 +176,8 @@
 
 * SOLR-17176: Fix log history V2 API serialization (Michael Gibney)
 
+* SOLR-11535: Fix race condition in singleton-per-collection StateWatcher creation (Michael Gibney)
+
 Dependency Upgrades
 ---------------------
 
diff --git a/solr/core/src/java/org/apache/solr/cloud/ActiveReplicaWatcher.java b/solr/core/src/java/org/apache/solr/cloud/ActiveReplicaWatcher.java
index 4fe3963..30cd1ba 100644
--- a/solr/core/src/java/org/apache/solr/cloud/ActiveReplicaWatcher.java
+++ b/solr/core/src/java/org/apache/solr/cloud/ActiveReplicaWatcher.java
@@ -119,6 +119,7 @@
   }
 
   // synchronized due to SOLR-11535
+  // TODO: can we remove `synchronized`, now that SOLR-11535 is fixed?
   @Override
   public synchronized boolean onStateChanged(Set<String> liveNodes, DocCollection collectionState) {
     if (log.isDebugEnabled()) {
diff --git a/solr/core/src/java/org/apache/solr/cloud/ZkController.java b/solr/core/src/java/org/apache/solr/cloud/ZkController.java
index 9044281..a2377ce 100644
--- a/solr/core/src/java/org/apache/solr/cloud/ZkController.java
+++ b/solr/core/src/java/org/apache/solr/cloud/ZkController.java
@@ -2784,6 +2784,7 @@
 
     @Override
     // synchronized due to SOLR-11535
+    // TODO: can we remove `synchronized`, now that SOLR-11535 is fixed?
     public synchronized boolean onStateChanged(DocCollection collectionState) {
       if (getCoreContainer().getCoreDescriptor(coreName) == null) return true;
 
diff --git a/solr/core/src/test/org/apache/solr/cloud/overseer/ZkStateReaderTest.java b/solr/core/src/test/org/apache/solr/cloud/overseer/ZkStateReaderTest.java
index 43c52d9..710fdc2 100644
--- a/solr/core/src/test/org/apache/solr/cloud/overseer/ZkStateReaderTest.java
+++ b/solr/core/src/test/org/apache/solr/cloud/overseer/ZkStateReaderTest.java
@@ -26,11 +26,16 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.BrokenBarrierException;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.LongAdder;
 import org.apache.lucene.util.IOUtils;
 import org.apache.solr.SolrTestCaseJ4;
 import org.apache.solr.cloud.OverseerTest;
@@ -692,6 +697,82 @@
   }
 
   /**
+   * Simulates race condition that can arise from the normal way in which the removal of collection
+   * StateWatchers is deferred.
+   *
+   * <p>StateWatchers are registered at the level of Zk code, so when StateWatchers are removed in
+   * Solr code, the actual removal is deferred until the next callback for the associated collection
+   * fires, at which point the removed watcher should allow itself to expire. If a watcher is
+   * re-added for the associated collection in the intervening time, only the most recently added
+   * watcher should re-register; the removed watcher should simply expire.
+   *
+   * <p>Duplicate/redundant StateWatchers should no longer be registered with the new code that
+   * tracks the currently registered singleton-per-collection watcher in Solr code, and only
+   * re-registers the currently active watcher, with all other watchers allowing themselves to
+   * expire.
+   */
+  public void testStateWatcherRaceCondition() throws Exception {
+    ZkStateWriter writer = fixture.writer;
+    final ZkStateReader reader = fixture.reader;
+    fixture.zkClient.makePath(ZkStateReader.COLLECTIONS_ZKNODE + "/c1", true);
+    int extraWatchers = 10;
+    int iterations = 10;
+    for (int i = 0; i < extraWatchers; i++) {
+      // add and remove a bunch of watchers
+      DocCollectionWatcher w = (coll) -> false;
+      try {
+        reader.registerDocCollectionWatcher("c1", w);
+      } finally {
+        reader.removeDocCollectionWatcher("c1", w);
+      }
+    }
+    final ConcurrentHashMap<Integer, LongAdder> invoked = new ConcurrentHashMap<>();
+    CyclicBarrier barrier = new CyclicBarrier(2);
+    reader.registerDocCollectionWatcher(
+        "c1",
+        (coll) -> {
+          // add a watcher that tracks how many times it's invoked per znode version
+          if (coll != null) {
+            try {
+              barrier.await(250, TimeUnit.MILLISECONDS);
+            } catch (InterruptedException | TimeoutException | BrokenBarrierException e) {
+              throw new RuntimeException(e);
+            }
+            invoked.computeIfAbsent(coll.getZNodeVersion(), (k) -> new LongAdder()).increment();
+          }
+          return false;
+        });
+
+    ClusterState clusterState = reader.getClusterState();
+    int dataVersion = -1;
+    for (int i = 0; i < iterations; i++) {
+      // create or update collection
+      DocCollection state =
+          DocCollection.create(
+              "c1",
+              new HashMap<>(),
+              Map.of(ZkStateReader.CONFIGNAME_PROP, ConfigSetsHandler.DEFAULT_CONFIGSET_NAME),
+              DocRouter.DEFAULT,
+              dataVersion,
+              Instant.now(),
+              PerReplicaStatesOps.getZkClientPrsSupplier(
+                  fixture.zkClient, DocCollection.getCollectionPath("c1")));
+      ZkWriteCommand wc = new ZkWriteCommand("c1", state);
+      writer.enqueueUpdate(clusterState, Collections.singletonList(wc), null);
+      clusterState = writer.writePendingUpdates();
+      barrier.await(250, TimeUnit.MILLISECONDS); // wait for the watch callback to execute
+      fixture.zkClient.makePath(ZkStateReader.COLLECTIONS_ZKNODE + "/c1" + i, true);
+      dataVersion = clusterState.getCollectionOrNull("c1").getZNodeVersion();
+    }
+    // expect to have been invoked for each iteration ...
+    assertEquals(iterations, invoked.size());
+    // ... and only _once_ for each iteration
+    assertTrue(
+        "wrong number of watchers (expected 1): " + invoked,
+        invoked.values().stream().mapToLong(LongAdder::sum).allMatch((l) -> l == 1));
+  }
+
+  /**
    * Ensure that collection state fetching (getCollectionLive etc.) would not throw exception when
    * the state.json is deleted in between the state.json read and PRS entries read
    */
diff --git a/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/ZkStateReader.java b/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/ZkStateReader.java
index 1049631..32ec0b2 100644
--- a/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/ZkStateReader.java
+++ b/solr/solrj-zookeeper/src/java/org/apache/solr/common/cloud/ZkStateReader.java
@@ -372,6 +372,22 @@
 
   private static class StatefulCollectionWatch extends CollectionWatch<DocCollectionWatcher> {
     private DocCollection currentState;
+
+    /**
+     * The {@link StateWatcher} that is associated with this {@link StatefulCollectionWatch}. It is
+     * necessary to track this because of the way {@link StateWatcher} instances expire
+     * asynchronously: once registered with ZooKeeper, a {@link StateWatcher} cannot be removed, and
+     * its {@link StateWatcher#process(WatchedEvent)} method will be invoked upon node update.
+     * Because it is not possible to synchronously remove the {@link StateWatcher} as part of a
+     * transaction with {@link ZkStateReader#collectionWatches}, we keep track of a unique {@link
+     * StateWatcher} here, so that all other {@link StateWatcher}s may properly expire in a deferred
+     * way.
+     */
+    private volatile StateWatcher associatedWatcher;
+
+    private StatefulCollectionWatch(StateWatcher associatedWatcher) {
+      this.associatedWatcher = associatedWatcher;
+    }
   }
 
   public static final Set<String> KNOWN_CLUSTER_PROPS =
@@ -651,8 +667,10 @@
 
   /** Refresh collections. */
   private void refreshCollections() {
-    for (String coll : collectionWatches.watchedCollections()) {
-      new StateWatcher(coll).refreshAndWatch();
+    for (Entry<String, StatefulCollectionWatch> e : collectionWatches.watchedCollectionEntries()) {
+      StateWatcher newStateWatcher = new StateWatcher(e.getKey());
+      e.getValue().associatedWatcher = newStateWatcher;
+      newStateWatcher.refreshAndWatch();
     }
   }
 
@@ -1336,8 +1354,9 @@
         return;
       }
 
-      if (!collectionWatches.watchedCollections().contains(coll)) {
-        // This collection is no longer interesting, stop watching.
+      StatefulCollectionWatch scw = collectionWatches.statefulWatchesByCollectionName.get(coll);
+      if (scw == null || scw.associatedWatcher != this) {
+        // Collection no longer interesting, or we have been replaced by a different watcher.
         log.debug("Uninteresting collection {}", coll);
         return;
       }
@@ -1679,19 +1698,21 @@
    * @see ZkStateReader#unregisterCore(String)
    */
   public void registerCore(String collection) {
-    AtomicBoolean reconstructState = new AtomicBoolean(false);
+    AtomicReference<StateWatcher> newWatcherRef = new AtomicReference<>();
     collectionWatches.compute(
         collection,
         (k, v) -> {
           if (v == null) {
-            reconstructState.set(true);
-            v = new StatefulCollectionWatch();
+            StateWatcher stateWatcher = new StateWatcher(collection);
+            newWatcherRef.set(stateWatcher);
+            v = new StatefulCollectionWatch(stateWatcher);
           }
           v.coreRefCount++;
           return v;
         });
-    if (reconstructState.get()) {
-      new StateWatcher(collection).refreshAndWatch();
+    StateWatcher newWatcher = newWatcherRef.get();
+    if (newWatcher != null) {
+      newWatcher.refreshAndWatch();
     }
   }
 
@@ -1763,26 +1784,29 @@
    * <p>The Watcher will automatically be removed when it's <code>onStateChanged</code> returns
    * <code>true</code>
    */
-  public void registerDocCollectionWatcher(String collection, DocCollectionWatcher stateWatcher) {
-    AtomicBoolean watchSet = new AtomicBoolean(false);
+  public void registerDocCollectionWatcher(
+      String collection, DocCollectionWatcher docCollectionWatcher) {
+    AtomicReference<StateWatcher> newWatcherRef = new AtomicReference<>();
     collectionWatches.compute(
         collection,
         (k, v) -> {
           if (v == null) {
-            v = new StatefulCollectionWatch();
-            watchSet.set(true);
+            StateWatcher stateWatcher = new StateWatcher(collection);
+            newWatcherRef.set(stateWatcher);
+            v = new StatefulCollectionWatch(stateWatcher);
           }
-          v.stateWatchers.add(stateWatcher);
+          v.stateWatchers.add(docCollectionWatcher);
           return v;
         });
 
-    if (watchSet.get()) {
-      new StateWatcher(collection).refreshAndWatch();
+    StateWatcher newWatcher = newWatcherRef.get();
+    if (newWatcher != null) {
+      newWatcher.refreshAndWatch();
     }
 
     DocCollection state = clusterState.getCollectionOrNull(collection);
-    if (stateWatcher.onStateChanged(state) == true) {
-      removeDocCollectionWatcher(collection, stateWatcher);
+    if (docCollectionWatcher.onStateChanged(state) == true) {
+      removeDocCollectionWatcher(collection, docCollectionWatcher);
     }
   }