Fix peers v2 system table behaviour when 2 nodes swap their IP Addresses

Throw if the node id has been changed and does not match the directory. If, however, the _ip_ address has changed, issue Startup and correct the IP address. Disallow taking over the identity of other nodes via hijacking their IPs or via overriding the local node id with theirs.

Patch by Alex Petrov; reviewed by Sam Tunnicliffe for CASSANDRA-19221
diff --git a/src/java/org/apache/cassandra/db/virtual/PeersTable.java b/src/java/org/apache/cassandra/db/virtual/PeersTable.java
index 5b011de..8d50dd7 100644
--- a/src/java/org/apache/cassandra/db/virtual/PeersTable.java
+++ b/src/java/org/apache/cassandra/db/virtual/PeersTable.java
@@ -113,15 +113,6 @@
         return result;
     }
 
-    public static void initializeLegacyPeerTables(ClusterMetadata prev, ClusterMetadata next)
-    {
-        QueryProcessor.executeInternal(String.format("TRUNCATE %s.%s", SYSTEM_KEYSPACE_NAME, PEERS_V2));
-        QueryProcessor.executeInternal(String.format("TRUNCATE %s.%s", SYSTEM_KEYSPACE_NAME, LEGACY_PEERS));
-
-        for (NodeId nodeId : next.directory.peerIds())
-            updateLegacyPeerTable(nodeId, prev, next);
-    }
-
     private static String peers_v2_query = "INSERT INTO %s.%s ("
                                             + "peer, peer_port, "
                                             + "preferred_ip, preferred_port, "
@@ -156,9 +147,7 @@
         if (next.directory.peerState(nodeId) == null || next.directory.peerState(nodeId) == NodeState.LEFT)
         {
             NodeAddresses addresses = prev.directory.getNodeAddresses(nodeId);
-            logger.debug("Purging {} from system.peers_v2 table", addresses);
-            QueryProcessor.executeInternal(String.format(peers_delete_query, SYSTEM_KEYSPACE_NAME, PEERS_V2), addresses.broadcastAddress.getAddress(), addresses.broadcastAddress.getPort());
-            QueryProcessor.executeInternal(String.format(legacy_peers_delete_query, SYSTEM_KEYSPACE_NAME, LEGACY_PEERS), addresses.broadcastAddress.getAddress());
+            removeFromSystemPeersTables(addresses.broadcastAddress);
         }
         else if (NodeState.isPreJoin(next.directory.peerState(nodeId)))
         {
@@ -169,11 +158,7 @@
             NodeAddresses addresses = next.directory.getNodeAddresses(nodeId);
             NodeAddresses oldAddresses = prev.directory.getNodeAddresses(nodeId);
             if (oldAddresses != null && !oldAddresses.equals(addresses))
-            {
-                logger.debug("Purging {} from system.peers_v2 table", oldAddresses);
-                QueryProcessor.executeInternal(String.format(peers_delete_query, SYSTEM_KEYSPACE_NAME, PEERS_V2), oldAddresses.broadcastAddress.getAddress(), oldAddresses.broadcastAddress.getPort());
-                QueryProcessor.executeInternal(String.format(legacy_peers_delete_query, SYSTEM_KEYSPACE_NAME, LEGACY_PEERS), oldAddresses.broadcastAddress.getAddress());
-            }
+                removeFromSystemPeersTables(oldAddresses.broadcastAddress);
 
             Location location = next.directory.location(nodeId);
 
@@ -197,4 +182,11 @@
                                            tokens);
         }
     }
+
+    public static void removeFromSystemPeersTables(InetAddressAndPort addr)
+    {
+        logger.debug("Purging {} from system.peers_v2 table", addr);
+        QueryProcessor.executeInternal(String.format(peers_delete_query, SYSTEM_KEYSPACE_NAME, PEERS_V2), addr.getAddress(), addr.getPort());
+        QueryProcessor.executeInternal(String.format(legacy_peers_delete_query, SYSTEM_KEYSPACE_NAME, LEGACY_PEERS), addr.getAddress());
+    }
 }
\ No newline at end of file
diff --git a/src/java/org/apache/cassandra/tcm/Startup.java b/src/java/org/apache/cassandra/tcm/Startup.java
index dfe4df8..023fcdb 100644
--- a/src/java/org/apache/cassandra/tcm/Startup.java
+++ b/src/java/org/apache/cassandra/tcm/Startup.java
@@ -158,8 +158,17 @@
         UUID currentHostId = SystemKeyspace.getLocalHostId();
         if (nodeId != null && !Objects.equals(nodeId.toUUID(), currentHostId))
         {
-            logger.info("NodeId is wrong, updating from {} to {}", currentHostId, nodeId.toUUID());
-            SystemKeyspace.setLocalHostId(nodeId.toUUID());
+            if (currentHostId == null)
+            {
+                logger.info("Taking over the host ID: {}, replacing address {}", nodeId.toUUID(), FBUtilities.getBroadcastAddressAndPort());
+                SystemKeyspace.setLocalHostId(nodeId.toUUID());
+                return;
+            }
+
+            String error = String.format("NodeId does not match locally set one. Check for the IP address collision: %s vs %s %s.",
+                                         currentHostId, nodeId.toUUID(), FBUtilities.getBroadcastAddressAndPort());
+            logger.error(error);
+            throw new IllegalStateException(error);
         }
     }
 
diff --git a/src/java/org/apache/cassandra/tcm/listeners/LegacyStateListener.java b/src/java/org/apache/cassandra/tcm/listeners/LegacyStateListener.java
index c219904..0a1a759 100644
--- a/src/java/org/apache/cassandra/tcm/listeners/LegacyStateListener.java
+++ b/src/java/org/apache/cassandra/tcm/listeners/LegacyStateListener.java
@@ -63,7 +63,9 @@
             next.tokenMap.lastModified().equals(prev.tokenMap.lastModified()))
             return;
 
-        Set<NodeId> removed = Sets.difference(prev.directory.peerIds(), next.directory.peerIds());
+        Set<InetAddressAndPort> removedAddr = Sets.difference(new HashSet<>(prev.directory.allAddresses()),
+                                                              new HashSet<>(next.directory.allAddresses()));
+
         Set<NodeId> changed = new HashSet<>();
         for (NodeId node : next.directory.peerIds())
         {
@@ -71,10 +73,10 @@
                 changed.add(node);
         }
 
-        for (NodeId remove : removed)
+        for (InetAddressAndPort remove : removedAddr)
         {
-            GossipHelper.evictFromMembership(prev.directory.endpoint(remove));
-            PeersTable.updateLegacyPeerTable(remove, prev, next);
+            GossipHelper.evictFromMembership(remove);
+            PeersTable.removeFromSystemPeersTables(remove);
         }
 
         for (NodeId change : changed)
diff --git a/test/distributed/org/apache/cassandra/distributed/test/log/BounceResetHostIdTest.java b/test/distributed/org/apache/cassandra/distributed/test/log/BounceResetHostIdTest.java
index 0411c41..e593a71 100644
--- a/test/distributed/org/apache/cassandra/distributed/test/log/BounceResetHostIdTest.java
+++ b/test/distributed/org/apache/cassandra/distributed/test/log/BounceResetHostIdTest.java
@@ -18,35 +18,120 @@
 
 package org.apache.cassandra.distributed.test.log;
 
-import java.util.UUID;
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.concurrent.TimeUnit;
 
+import org.junit.Assert;
 import org.junit.Test;
 
-import org.apache.cassandra.db.SystemKeyspace;
 import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.ConsistencyLevel;
+import org.apache.cassandra.distributed.api.Feature;
+import org.apache.cassandra.distributed.shared.AssertUtils;
+import org.apache.cassandra.distributed.shared.ClusterUtils;
 import org.apache.cassandra.distributed.test.TestBaseImpl;
 import org.apache.cassandra.tcm.membership.NodeId;
 
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.apache.cassandra.distributed.shared.AssertUtils.row;
+import static org.junit.Assert.fail;
 
 public class BounceResetHostIdTest extends TestBaseImpl
 {
     @Test
-    public void bounceTest() throws Exception
+    public void swapIpsTest() throws Exception
     {
-        try (Cluster cluster = init(builder().withNodes(1)
-                                             .start()))
+        try (Cluster cluster = builder().withNodes(3)
+                                        .withConfig(c -> c.with(Feature.GOSSIP, Feature.NATIVE_PROTOCOL)
+                                                               // disable DistributedTestSnitch as it tries to query before we set up
+                                                               .set("endpoint_snitch", "org.apache.cassandra.locator.SimpleSnitch"))
+                                        .createWithoutStarting())
         {
-            String wrongId = UUID.randomUUID().toString();
-            cluster.get(1).runOnInstance(() -> {
-                SystemKeyspace.setLocalHostId(UUID.fromString(wrongId));
-                assertFalse(NodeId.isValidNodeId(SystemKeyspace.getLocalHostId()));
-            });
-            cluster.get(1).shutdown().get();
-            cluster.get(1).startup();
-            cluster.get(1).logs().watchFor("NodeId is wrong, updating from "+wrongId+" to "+(new NodeId(1).toUUID()));
-            cluster.get(1).runOnInstance(() -> assertTrue(NodeId.isValidNodeId(SystemKeyspace.getLocalHostId())));
+            // This test relies on node IDs being in the same order as IP addresses
+            for (int i = 1; i <= 3; i++)
+                cluster.get(i).startup();
+
+            cluster.get(2).shutdown().get();
+            ClusterUtils.updateAddress(cluster.get(2), "127.0.0.4");
+            cluster.get(2).startup();
+
+            cluster.get(3).shutdown().get();
+            ClusterUtils.updateAddress(cluster.get(3), "127.0.0.2");
+            cluster.get(3).startup();
+
+            cluster.get(2).shutdown().get();
+            ClusterUtils.updateAddress(cluster.get(2), "127.0.0.3");
+            cluster.get(2).startup();
+
+            ClusterUtils.waitForCMSToQuiesce(cluster, cluster.get(1));
+
+            long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(30);
+            while (true)
+            {
+                try
+                {
+                    AssertUtils.assertRows(sortHelper(cluster.coordinator(2).execute("select peer, host_id from system.peers_v2", ConsistencyLevel.QUORUM)),
+                                           rows(row(InetAddress.getByName("127.0.0.1"), new NodeId(1).toUUID()),
+                                                row(InetAddress.getByName("127.0.0.2"), new NodeId(3).toUUID())
+                                           ));
+                    AssertUtils.assertRows(sortHelper(cluster.coordinator(3).execute("select peer, host_id from system.peers_v2", ConsistencyLevel.QUORUM)),
+                                           rows(row(InetAddress.getByName("127.0.0.1"), new NodeId(1).toUUID()),
+                                                row(InetAddress.getByName("127.0.0.3"), new NodeId(2).toUUID())
+
+                                           ));
+                    return;
+                }
+                catch (AssertionError t)
+                {
+                    // If we are past the deadline, throw; allow to retry otherwise
+                    if (System.nanoTime() > deadline)
+                        throw t;
+                }
+            }
         }
     }
+
+    @Test
+    public void swapIpsDirectlyTest() throws Exception
+    {
+        try (Cluster cluster = builder().withNodes(3)
+                                        .withConfig(c -> c.with(Feature.GOSSIP, Feature.NATIVE_PROTOCOL)
+                                                          // disable DistributedTestSnitch as it tries to query before we set up
+                                                          .set("endpoint_snitch", "org.apache.cassandra.locator.SimpleSnitch"))
+                                        .createWithoutStarting())
+        {
+            // This test relies on node IDs being in the same order as IP addresses
+            for (int i = 1; i <= 3; i++)
+                cluster.get(i).startup();
+
+            cluster.get(2).shutdown().get();
+            cluster.get(3).shutdown().get();
+            ClusterUtils.updateAddress(cluster.get(2), "127.0.0.3");
+            ClusterUtils.updateAddress(cluster.get(3), "127.0.0.2");
+            try
+            {
+                cluster.get(2).startup();
+                fail("Should not have been able to start");
+            }
+            catch (Throwable t)
+            {
+                Assert.assertTrue(t.getMessage().contains("NodeId does not match locally set one"));
+            }
+            try
+            {
+                cluster.get(3).startup();
+                fail("Should not have been able to start");
+            }
+            catch (Throwable t)
+            {
+                Assert.assertTrue(t.getMessage().contains("NodeId does not match locally set one"));
+            }
+        }
+    }
+    public static Object[][] sortHelper(Object[][] rows)
+    {
+        Arrays.sort(rows, Comparator.comparing(r -> ((InetAddress)r[0]).getHostAddress()));
+        return rows;
+    }
 }