collect dc/rack information and validate when building (#15)

Patch by marcuse; reviewed by Alex Petrov for CASSANDRA-16109
diff --git a/pom.xml b/pom.xml
index 5e36e66..49113d7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -69,6 +69,12 @@
             <version>3.15.0</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>3.5.10</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <build>
diff --git a/src/main/java/org/apache/cassandra/distributed/shared/AbstractBuilder.java b/src/main/java/org/apache/cassandra/distributed/shared/AbstractBuilder.java
index 63ceca3..47f2d48 100644
--- a/src/main/java/org/apache/cassandra/distributed/shared/AbstractBuilder.java
+++ b/src/main/java/org/apache/cassandra/distributed/shared/AbstractBuilder.java
@@ -26,7 +26,10 @@
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.BiConsumer;
@@ -55,6 +58,8 @@
     private int broadcastPort = 7012;
     private BiConsumer<ClassLoader, Integer> instanceInitializer = (cl, id) -> {};
     private int datadirCount = 3;
+    private final List<Rack> racks = new ArrayList<>();
+    private boolean finalised;
 
     public AbstractBuilder(Factory<I, C, B> factory)
     {
@@ -116,19 +121,12 @@
 
     public C createWithoutStarting() throws IOException
     {
+        finaliseBuilder();
         if (root == null)
             root = Files.createTempDirectory("dtests").toFile();
 
-        if (nodeCount <= 0)
-            throw new IllegalStateException("Cluster must have at least one node");
-
         root.mkdirs();
 
-        if (nodeIdTopology == null)
-            nodeIdTopology = IntStream.rangeClosed(1, nodeCount).boxed()
-                                      .collect(Collectors.toMap(nodeId -> nodeId,
-                                                                nodeId -> NetworkTopology.dcAndRack(dcName(0), rackName(0))));
-
         // TODO: make token allocation strategy configurable
         if (tokenSupplier == null)
             tokenSupplier = evenlyDistributedTokens(nodeCount);
@@ -159,72 +157,72 @@
         return (B) this;
     }
 
+    /**
+     * Start this many nodes initially
+     *
+     * Note that when using this in combination with withNodeIdTopology or withRacks/withDCs/... we
+     * might reduce the actual number of nodes if there are not enough nodes configured in the node id
+     * topology. For tests where additional nodes are started after the initial ones, it is ok to have
+     * nodeCount < node id topology size
+     */
     public B withNodes(int nodeCount)
     {
         this.nodeCount = nodeCount;
         return (B) this;
     }
 
+    /**
+     * Adds dcCount datacenters, splits the nodeCount over these dcs
+     */
     public B withDCs(int dcCount)
     {
         return withRacks(dcCount, 1);
     }
 
+    /**
+     * Adds this many racks per datacenter
+     *
+     * splits nodeCount over these racks/dcs
+     *
+     */
     public B withRacks(int dcCount, int racksPerDC)
     {
-        if (nodeCount == 0)
-            throw new IllegalStateException("Node count will be calculated. Do not supply total node count in the builder");
+        assert dcCount > 0 && racksPerDC > 0 : "Both dcCount and racksPerDC must be > 0";
 
-        int totalRacks = dcCount * racksPerDC;
-        int nodesPerRack = (nodeCount + totalRacks - 1) / totalRacks; // round up to next integer
-        return withRacks(dcCount, racksPerDC, nodesPerRack);
-    }
-
-    public B withRacks(int dcCount, int racksPerDC, int nodesPerRack)
-    {
-        if (nodeIdTopology != null)
-            throw new IllegalStateException("Network topology already created. Call withDCs/withRacks once or before withDC/withRack calls");
-
-        nodeIdTopology = new HashMap<>();
-        int nodeId = 1;
         for (int dc = 1; dc <= dcCount; dc++)
-        {
             for (int rack = 1; rack <= racksPerDC; rack++)
-            {
-                for (int rackNodeIdx = 0; rackNodeIdx < nodesPerRack; rackNodeIdx++)
-                    nodeIdTopology.put(nodeId++, NetworkTopology.dcAndRack(dcName(dc), rackName(rack)));
-            }
-        }
-        // adjust the node count to match the allocatation
-        final int adjustedNodeCount = dcCount * racksPerDC * nodesPerRack;
-        if (adjustedNodeCount != nodeCount)
-        {
-            assert adjustedNodeCount > nodeCount : "withRacks should only ever increase the node count";
-            System.out.println(String.format("Network topology of %s DCs with %s racks per DC and %s nodes per rack required increasing total nodes to %s",
-                                             dcCount, racksPerDC, nodesPerRack, adjustedNodeCount));
-            nodeCount = adjustedNodeCount;
-        }
+                withRack(dcName(dc), rackName(rack), -1);
         return (B) this;
     }
 
+    /**
+     * Adds dcCount datacenters, each with racksPerDC racks and nodesPerRack nodes in every rack
+     *
+     * Note that an explicitly set node count must be <= dcCount * racksPerDC * nodesPerRack (only that many nodes are started); if it is unset (0) it becomes the size of the resulting topology.
+     */
+    public B withRacks(int dcCount, int racksPerDC, int nodesPerRack)
+    {
+        assert dcCount > 0 && racksPerDC > 0 && nodesPerRack > 0 : "dcCount, racksPerDC and nodesPerRack must be > 0";
+        for (int dc = 1; dc <= dcCount; dc++)
+            for (int rack = 1; rack <= racksPerDC; rack++)
+                withRack(dcName(dc), rackName(rack), nodesPerRack);
+        return (B) this;
+    }
+
+    /**
+     * Add a dc with name dcName containing a single rack with nodeCount nodes
+     */
     public B withDC(String dcName, int nodeCount)
     {
         return withRack(dcName, rackName(1), nodeCount);
     }
 
+    /**
+     * Add a rack in dcName with name rackName containing nodesInRack nodes
+     */
     public B withRack(String dcName, String rackName, int nodesInRack)
     {
-        if (nodeIdTopology == null)
-        {
-            if (nodeCount > 0)
-                throw new IllegalStateException("Node count must not be explicitly set, or allocated using withDCs/withRacks");
-
-            nodeIdTopology = new HashMap<>();
-        }
-        for (int nodeId = nodeCount + 1; nodeId <= nodeCount + nodesInRack; nodeId++)
-            nodeIdTopology.put(nodeId, NetworkTopology.dcAndRack(dcName, rackName));
-
-        nodeCount += nodesInRack;
+        racks.add(new Rack(dcName, rackName, nodesInRack));
         return (B) this;
     }
 
@@ -239,12 +237,6 @@
                 throw new IllegalStateException("Topology is missing entry for nodeId " + nodeId);
         });
 
-        if (nodeCount != nodeIdTopology.size())
-        {
-            nodeCount = nodeIdTopology.size();
-            System.out.println(String.format("Adjusting node count to %s for supplied network topology", nodeCount));
-        }
-
         this.nodeIdTopology = new HashMap<>(nodeIdTopology);
 
         return (B) this;
@@ -281,6 +273,93 @@
         return (B) this;
     }
 
+    private void finaliseBuilder()
+    {
+        if (finalised)
+            return;
+        finalised = true; // the topology is resolved at most once per builder
+        // Resolve the topology: expand configured racks, else reconcile an explicit nodeIdTopology with nodeCount, else put every node in dcName(0)/rackName(0)
+        if (!racks.isEmpty())
+        {
+            setRacks();
+        }
+        else if (nodeIdTopology != null)
+        {
+            if (nodeIdTopology.size() < nodeCount)
+            {
+                System.out.println("Adjusting node count since nodeIdTopology contains fewer nodes");
+                nodeCount = nodeIdTopology.size();
+            }
+            else if (nodeIdTopology.size() > nodeCount)
+            {
+                System.out.printf("nodeIdTopology configured for %d nodes while nodeCount is %d%n", nodeIdTopology.size(), nodeCount);
+            }
+        }
+        else
+        {
+            nodeIdTopology = IntStream.rangeClosed(1, nodeCount).boxed()
+                                      .collect(Collectors.toMap(nodeId -> nodeId,
+                                                                nodeId -> NetworkTopology.dcAndRack(dcName(0), rackName(0))));
+        }
+        // Checked only after the branches above, since they may have changed nodeCount
+        if (nodeCount <= 0)
+            throw new IllegalStateException("Cluster must have at least one node");
+        // Log the result; the loop reads ids 1..size -- NOTE(review): assumes ids are contiguous from 1, confirm withNodeIdTopology guarantees this
+        System.out.println("Node id topology:");
+        for (int i = 1; i <= nodeIdTopology.size(); i++)
+        {
+            NetworkTopology.DcAndRack dcAndRack = nodeIdTopology.get(i);
+            System.out.printf("node %d: dc = %s, rack = %s%n", i, dcAndRack.dc, dcAndRack.rack);
+        }
+        System.out.printf("Configured node count: %d, nodeIdTopology size: %d%n", nodeCount, nodeIdTopology.size());
+    }
+
+    /** Expands the racks registered through withRack/withRacks into nodeIdTopology entries and reconciles nodeCount with the result. */
+    private void setRacks()
+    {
+        if (nodeIdTopology == null)
+            nodeIdTopology = new HashMap<>();
+
+        boolean shouldCalculatePerRackCount = false;
+        boolean hasExplicitPerRackCount = false;
+        for (Rack rack : racks)
+        {
+            if (rack.rackNodeCount == -1) // -1 is what withRacks(dcCount, racksPerDC) registers
+                shouldCalculatePerRackCount = true;
+            else
+                hasExplicitPerRackCount = true;
+        }
+
+        if (shouldCalculatePerRackCount && hasExplicitPerRackCount)
+            throw new IllegalStateException("Can't mix explicit and implicit per rack counts");
+        // New node ids continue after the highest id already present in nodeIdTopology
+        int nodeId = nodeIdTopology.isEmpty() ? 1 : Collections.max(nodeIdTopology.keySet()) + 1;
+        if (shouldCalculatePerRackCount)
+        {
+            if (nodeCount == 0)
+                throw new IllegalStateException("Node count must be set when not setting per rack counts");
+            int totalRacks = racks.size();
+            int nodesPerRack = (nodeCount + totalRacks - 1) / totalRacks; // round up to the next integer
+            for (Rack rack : racks)
+                for (int i = 1; i <= nodesPerRack; i++)
+                    nodeIdTopology.put(nodeId++, NetworkTopology.dcAndRack(rack.dcName, rack.rackName));
+        }
+        else
+        {
+            for (Rack rack : racks)
+                for (int i = 1; i <= rack.rackNodeCount; i++)
+                    nodeIdTopology.put(nodeId++, NetworkTopology.dcAndRack(rack.dcName, rack.rackName));
+        }
+
+        // Throw instead of using an assert so the validation also runs when assertions are disabled
+        if (nodeCount > nodeIdTopology.size())
+            throw new IllegalStateException("withRacks should only ever increase the node count");
+
+        if (nodeCount == 0)
+            nodeCount = nodeIdTopology.size();
+        else if (nodeCount < nodeIdTopology.size())
+            System.out.printf("Network topology of %s requires more nodes, only starting %s out of %s configured nodes%n", nodeIdTopology, nodeCount, nodeIdTopology.size());
+    }
 
     static String dcName(int index)
     {
@@ -291,6 +370,20 @@
     {
         return "rack" + index;
     }
+
+    private static class Rack // one withRack(...) registration, expanded into node ids by setRacks()
+    {
+        final String dcName; // datacenter the rack belongs to
+        final String rackName; // rack name within that datacenter
+        final int rackNodeCount; // nodes in this rack; -1 means "derive from nodeCount" (see setRacks)
+
+        private Rack(String dcName, String rackName, int rackNodeCount)
+        {
+            this.dcName = dcName;
+            this.rackName = rackName;
+            this.rackNodeCount = rackNodeCount;
+        }
+    }
 }
 
 
diff --git a/src/main/java/org/apache/cassandra/distributed/shared/NetworkTopology.java b/src/main/java/org/apache/cassandra/distributed/shared/NetworkTopology.java
index 7bd91d3..3eb2aec 100644
--- a/src/main/java/org/apache/cassandra/distributed/shared/NetworkTopology.java
+++ b/src/main/java/org/apache/cassandra/distributed/shared/NetworkTopology.java
@@ -34,8 +34,8 @@
 
     public static class DcAndRack
     {
-        private final String dc;
-        private final String rack;
+        public final String dc;
+        public final String rack;
 
         private DcAndRack(String dc, String rack)
         {
diff --git a/src/test/java/org/apache/cassandra/distributed/api/BuilderTest.java b/src/test/java/org/apache/cassandra/distributed/api/BuilderTest.java
new file mode 100644
index 0000000..04a9715
--- /dev/null
+++ b/src/test/java/org/apache/cassandra/distributed/api/BuilderTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.api;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.distributed.shared.AbstractBuilder;
+import org.apache.cassandra.distributed.shared.NetworkTopology;
+import org.mockito.Mockito;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class BuilderTest
+{
+    // Verifies how nodeCount, configured racks and an explicit nodeIdTopology are reconciled when the cluster is built
+    @Test
+    public void testNodeCount() throws IOException
+    {
+        ICluster mockCluster = Mockito.mock(ICluster.class);
+        AbstractBuilder<?,?,?> builder = new AbstractBuilder(builder1 -> mockCluster){};
+
+        // empty case
+        try
+        {
+            builder.createWithoutStarting();
+            throw new AssertionError("Expected building a cluster without nodes to fail"); // caught below, fails the message check
+        }
+        catch (Throwable t)
+        {
+            assertThat(t.getMessage()).contains("Cluster must have at least one node");
+        }
+
+        // a single node
+        builder = new AbstractBuilder(builder1 -> mockCluster){};
+        builder.withNodes(1).createWithoutStarting();
+        assertThat(builder.getNodeCount())
+        .isEqualTo(1);
+        assertThat(builder.getNodeIdTopology().size())
+        .isEqualTo(1);
+
+        // numNodes > number of nodes in racks
+        builder = new AbstractBuilder(builder1 -> mockCluster){};
+        builder.withNodes(20)
+               .withRack("dc1", "rack1", 5)
+               .withRack("dc1", "rack2", 5)
+               .withRack("dc2", "rack1", 5);
+        try
+        {
+            builder.createWithoutStarting();
+            throw new AssertionError("Expected building with more nodes than the racks hold to fail"); // caught below, fails the message check
+        }
+        catch (Throwable t)
+        {
+            assertThat(t.getMessage()).contains("withRacks should only ever increase the node count");
+        }
+
+        // numNodes < number of nodes in racks
+        builder = new AbstractBuilder(builder1 -> mockCluster){};
+        builder.withNodes(10)
+               .withRack("dc1", "rack1", 5)
+               .withRack("dc1", "rack2", 5)
+               .withRack("dc2", "rack1", 5);
+        builder.createWithoutStarting();
+        assertThat(builder.getNodeCount())
+        .isEqualTo(10);
+        assertThat(builder.getNodeIdTopology().size())
+        .isEqualTo(15);
+        // racks combined with an explicit nodeIdTopology
+        builder = new AbstractBuilder(builder1 -> mockCluster){};
+        builder.withNodes(10)
+               .withRack("dc1", "rack1", 5)
+               .withRack("dc1", "rack2", 5)
+               .withRack("dc2", "rack1", 5)
+               .withNodeIdTopology(new HashMap<Integer, NetworkTopology.DcAndRack>() {{
+                   for (int i = 0; i < 3; i++)
+                       for (int j = 0; j < 3; j++)
+                           put(i * 3 + j + 1, NetworkTopology.dcAndRack("dc" + (i + 1), "rack" + (j + 1)));
+               }});
+        builder.createWithoutStarting();
+        assertThat(builder.getNodeCount())
+        .isEqualTo(10);
+        assertThat(builder.getNodeIdTopology().size())
+        .isEqualTo(9 + 15); // both given topology and racks are applied
+    }
+}