Add nodetool command to unregister LEFT nodes

Patch by marcuse and Sam Tunnicliffe; reviewed by Sam Tunnicliffe for CASSANDRA-19581

Co-authored-by: Sam Tunnicliffe <samt@apache.org>
Co-authored-by: Marcus Eriksson <marcuse@apache.org>
diff --git a/CHANGES.txt b/CHANGES.txt
index c8355de..12787c4 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 5.1
+ * Add nodetool command to unregister LEFT nodes (CASSANDRA-19581)
  * Add cluster metadata id to gossip syn messages (CASSANDRA-19613)
  * Reduce heap usage occupied by the metrics (CASSANDRA-19567)
  * Improve handling of transient replicas during range movements (CASSANDRA-19344)
diff --git a/src/java/org/apache/cassandra/db/virtual/ClusterMetadataDirectoryTable.java b/src/java/org/apache/cassandra/db/virtual/ClusterMetadataDirectoryTable.java
new file mode 100644
index 0000000..0d026ce
--- /dev/null
+++ b/src/java/org/apache/cassandra/db/virtual/ClusterMetadataDirectoryTable.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.db.virtual;
+
+import java.util.Map;
+
+import org.apache.cassandra.db.marshal.InetAddressType;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.db.marshal.UTF8Type;
+import org.apache.cassandra.db.marshal.UUIDType;
+import org.apache.cassandra.dht.LocalPartitioner;
+import org.apache.cassandra.schema.TableMetadata;
+import org.apache.cassandra.tcm.ClusterMetadata;
+import org.apache.cassandra.tcm.membership.Directory;
+import org.apache.cassandra.tcm.membership.Location;
+import org.apache.cassandra.tcm.membership.NodeAddresses;
+import org.apache.cassandra.tcm.membership.NodeId;
+import org.apache.cassandra.tcm.membership.NodeState;
+import org.apache.cassandra.tcm.membership.NodeVersion;
+
+
+/**
+ * Virtual table listing every node known to the cluster metadata {@link Directory}, one row per node id,
+ * with its host id, membership state, versions, location and addresses as recorded in the locally
+ * current {@link ClusterMetadata}.
+ */
+final class ClusterMetadataDirectoryTable extends AbstractVirtualTable
+{
+    private static final String NODE_ID = "node_id";
+    private static final String HOST_ID = "host_id";
+    private static final String STATE = "state";
+    private static final String CASSANDRA_VERSION = "cassandra_version";
+    private static final String SERIALIZATION_VERSION = "serialization_version";
+    private static final String RACK = "rack";
+    private static final String DC = "dc";
+    private static final String BROADCAST_ADDRESS = "broadcast_address";
+    private static final String BROADCAST_PORT = "broadcast_port";
+    private static final String LOCAL_ADDRESS = "local_address";
+    private static final String LOCAL_PORT = "local_port";
+    private static final String NATIVE_ADDRESS = "native_address";
+    private static final String NATIVE_PORT = "native_port";
+
+
+    ClusterMetadataDirectoryTable(String keyspace)
+    {
+        super(TableMetadata.builder(keyspace, "cluster_metadata_directory")
+                           .comment("cluster metadata directory")
+                           .kind(TableMetadata.Kind.VIRTUAL)
+                           // node_id is an Int32Type column, so the local partitioner must use the same key type
+                           .partitioner(new LocalPartitioner(Int32Type.instance))
+                           .addPartitionKeyColumn(NODE_ID, Int32Type.instance)
+                           .addRegularColumn(HOST_ID, UUIDType.instance)
+                           .addRegularColumn(STATE, UTF8Type.instance)
+                           .addRegularColumn(CASSANDRA_VERSION, UTF8Type.instance)
+                           .addRegularColumn(SERIALIZATION_VERSION, Int32Type.instance)
+                           .addRegularColumn(RACK, UTF8Type.instance)
+                           .addRegularColumn(DC, UTF8Type.instance)
+                           .addRegularColumn(BROADCAST_ADDRESS, InetAddressType.instance)
+                           .addRegularColumn(BROADCAST_PORT, Int32Type.instance)
+                           .addRegularColumn(LOCAL_ADDRESS, InetAddressType.instance)
+                           .addRegularColumn(LOCAL_PORT, Int32Type.instance)
+                           .addRegularColumn(NATIVE_ADDRESS, InetAddressType.instance)
+                           .addRegularColumn(NATIVE_PORT, Int32Type.instance)
+                           .build());
+    }
+
+    /**
+     * Produces one row per entry in the directory's node-state map. The version, location and address
+     * columns are null when the directory holds no corresponding entry for that node.
+     */
+    @Override
+    public DataSet data()
+    {
+        ClusterMetadata metadata = ClusterMetadata.current();
+        Directory directory = metadata.directory;
+        SimpleDataSet result = new SimpleDataSet(metadata());
+        for (Map.Entry<NodeId, NodeState> entry : directory.states.entrySet())
+        {
+            NodeId nodeId = entry.getKey();
+            NodeState nodeState = entry.getValue();
+            NodeAddresses address = directory.getNodeAddresses(nodeId);
+            Location location = directory.location(nodeId);
+            NodeVersion version = directory.version(nodeId);
+            result.row(nodeId.id())
+                  .column(HOST_ID, nodeId.toUUID())
+                  .column(STATE, nodeState.toString())
+                  .column(CASSANDRA_VERSION, version != null ? version.cassandraVersion.toString() : null)
+                  .column(SERIALIZATION_VERSION, version != null ? version.serializationVersion : null)
+                  .column(RACK, location != null ? location.rack : null)
+                  .column(DC, location != null ? location.datacenter : null)
+                  .column(BROADCAST_ADDRESS, address != null ? address.broadcastAddress.getAddress() : null)
+                  .column(BROADCAST_PORT, address != null ? address.broadcastAddress.getPort() : null)
+                  .column(LOCAL_ADDRESS, address != null ? address.localAddress.getAddress() : null)
+                  .column(LOCAL_PORT, address != null ? address.localAddress.getPort() : null)
+                  .column(NATIVE_ADDRESS, address != null ? address.nativeAddress.getAddress() : null)
+                  .column(NATIVE_PORT, address != null ? address.nativeAddress.getPort() : null);
+        }
+        return result;
+    }
+}
diff --git a/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java b/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java
index 973be20..8c1412e 100644
--- a/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java
+++ b/src/java/org/apache/cassandra/db/virtual/SystemViewsKeyspace.java
@@ -64,6 +64,7 @@
                     .add(new PeersTable(VIRTUAL_VIEWS))
                     .add(new LocalTable(VIRTUAL_VIEWS))
                     .add(new ClusterMetadataLogTable(VIRTUAL_VIEWS))
+                    .add(new ClusterMetadataDirectoryTable(VIRTUAL_VIEWS))
                     .addAll(LocalRepairTables.getAll(VIRTUAL_VIEWS))
                     .addAll(CIDRFilteringMetricsTable.getAll(VIRTUAL_VIEWS))
                     .addAll(StorageAttachedIndexTables.getAll(VIRTUAL_VIEWS))
diff --git a/src/java/org/apache/cassandra/service/StorageService.java b/src/java/org/apache/cassandra/service/StorageService.java
index 4dce59e..8514f7b 100644
--- a/src/java/org/apache/cassandra/service/StorageService.java
+++ b/src/java/org/apache/cassandra/service/StorageService.java
@@ -256,6 +256,7 @@
 import static org.apache.cassandra.tcm.membership.NodeState.BOOT_REPLACING;
 import static org.apache.cassandra.tcm.membership.NodeState.JOINED;
 import static org.apache.cassandra.tcm.membership.NodeState.MOVING;
+import static org.apache.cassandra.tcm.membership.NodeState.REGISTERED;
 import static org.apache.cassandra.utils.Clock.Global.currentTimeMillis;
 import static org.apache.cassandra.utils.FBUtilities.getBroadcastAddressAndPort;
 import static org.apache.cassandra.utils.FBUtilities.now;
@@ -1610,7 +1611,7 @@
                         throw new RuntimeException("Can't abort bootstrap for " + nodeId + " since it is not bootstrapping");
                     ClusterMetadataService.instance().commit(new CancelInProgressSequence(nodeId));
                 }
-                ClusterMetadataService.instance().commit(new Unregister(nodeId));
+                ClusterMetadataService.instance().commit(new Unregister(nodeId, EnumSet.of(REGISTERED, BOOTSTRAPPING, BOOT_REPLACING)));
                 break;
             default:
                 throw new RuntimeException("Can't abort bootstrap for node " + nodeId + " since the state is " + nodeState);
diff --git a/src/java/org/apache/cassandra/tcm/CMSOperations.java b/src/java/org/apache/cassandra/tcm/CMSOperations.java
index 12d6c94..0d9f491 100644
--- a/src/java/org/apache/cassandra/tcm/CMSOperations.java
+++ b/src/java/org/apache/cassandra/tcm/CMSOperations.java
@@ -20,6 +20,7 @@
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -31,11 +32,14 @@
 
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.schema.ReplicationParams;
+import org.apache.cassandra.tcm.membership.NodeId;
+import org.apache.cassandra.tcm.membership.NodeState;
 import org.apache.cassandra.tcm.membership.NodeVersion;
 import org.apache.cassandra.tcm.sequences.CancelCMSReconfiguration;
 import org.apache.cassandra.tcm.sequences.InProgressSequences;
 import org.apache.cassandra.tcm.sequences.ReconfigureCMS;
 import org.apache.cassandra.tcm.serialization.Version;
+import org.apache.cassandra.tcm.transformations.Unregister;
 import org.apache.cassandra.tcm.transformations.cms.AdvanceCMSReconfiguration;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.MBeanWrapper;
@@ -196,4 +200,67 @@
     {
         return InProgressSequences.cancelInProgressSequences(sequenceOwner, expectedSequenceKind);
     }
+
+    /**
+     * Unregisters the given nodes from cluster metadata. Every id must refer to a node that is present in
+     * the directory and in {@link NodeState#LEFT} state; otherwise nothing is committed and an exception
+     * describing each offending node (and how to deal with it) is thrown. Each node is then removed by its
+     * own {@link Unregister} commit, which re-checks that the node is still LEFT.
+     *
+     * @param nodeIdStrings node ids in the form parsed by {@link NodeId#fromString}; repeated ids are handled once
+     * @throws IllegalStateException if any node is not present in the directory or is not in LEFT state
+     */
+    @Override
+    public void unregisterLeftNodes(List<String> nodeIdStrings)
+    {
+        // de-duplicate: a second Unregister for an id that was just removed would be rejected
+        List<NodeId> nodeIds = nodeIdStrings.stream()
+                                            .map(NodeId::fromString)
+                                            .distinct()
+                                            .collect(Collectors.toList());
+        ClusterMetadata metadata = ClusterMetadata.current();
+        StringBuilder message = new StringBuilder();
+        for (NodeId nodeId : nodeIds)
+        {
+            if (!metadata.directory.peerIds().contains(nodeId))
+            {
+                // checked first so that an unknown id never reaches the switch on its (absent) state below
+                message.append("Node ").append(nodeId.id()).append(" is not present in the directory").append('\n');
+                continue;
+            }
+            NodeState nodeState = metadata.directory.peerState(nodeId);
+            if (nodeState == NodeState.LEFT)
+                continue;
+
+            message.append("Node ").append(nodeId.id()).append(" is in state ").append(nodeState);
+            switch (nodeState)
+            {
+                case REGISTERED:
+                case BOOTSTRAPPING:
+                case BOOT_REPLACING:
+                    message.append(" - need to use `nodetool abortbootstrap` instead of unregistering").append('\n');
+                    break;
+                case JOINED:
+                    message.append(" - use `nodetool decommission` or `nodetool removenode` to remove this node").append('\n');
+                    break;
+                case MOVING:
+                    message.append(" - wait until move has been completed, then use `nodetool decommission` or `nodetool removenode` to remove this node").append('\n');
+                    break;
+                case LEAVING:
+                    message.append(" - wait until leave-operation has completed, then retry this command").append('\n');
+                    break;
+                default:
+                    message.append('\n');
+                    break;
+            }
+        }
+        if (message.length() > 0)
+            throw new IllegalStateException("Can't unregister node(s):\n" + message);
+
+        for (NodeId nodeId : nodeIds)
+        {
+            logger.info("Unregistering {}", nodeId);
+            cms.commit(new Unregister(nodeId, EnumSet.of(NodeState.LEFT)));
+        }
+    }
 }
diff --git a/src/java/org/apache/cassandra/tcm/CMSOperationsMBean.java b/src/java/org/apache/cassandra/tcm/CMSOperationsMBean.java
index 4f25a5b..3b6e4b5 100644
--- a/src/java/org/apache/cassandra/tcm/CMSOperationsMBean.java
+++ b/src/java/org/apache/cassandra/tcm/CMSOperationsMBean.java
@@ -43,4 +43,12 @@
     public boolean getCommitsPaused();
 
     public boolean cancelInProgressSequences(String sequenceOwner, String expectedSequenceKind);
+
+    /**
+     * Unregisters the given nodes from cluster metadata; every one of them must be in LEFT state.
+     *
+     * @param nodeIds ids of the nodes to unregister, in the string form accepted by {@code NodeId.fromString}
+     * @throws IllegalStateException if any listed node is in a directory state other than LEFT (e.g. JOINED)
+     */
+    public void unregisterLeftNodes(List<String> nodeIds);
 }
diff --git a/src/java/org/apache/cassandra/tcm/transformations/Register.java b/src/java/org/apache/cassandra/tcm/transformations/Register.java
index 014d549..f53eb9a 100644
--- a/src/java/org/apache/cassandra/tcm/transformations/Register.java
+++ b/src/java/org/apache/cassandra/tcm/transformations/Register.java
@@ -19,6 +19,7 @@
 package org.apache.cassandra.tcm.transformations;
 
 import java.io.IOException;
+import java.util.EnumSet;
 import java.util.Map;
 import java.util.UUID;
 
@@ -107,7 +108,7 @@
         {
             if (nodeId != null)
                 ClusterMetadataService.instance()
-                                      .commit(new Unregister(nodeId));
+                                      .commit(new Unregister(nodeId, EnumSet.of(NodeState.LEFT)));
             nodeId = ClusterMetadataService.instance()
                                            .commit(new Register(nodeAddresses, location, nodeVersion))
                      .directory
diff --git a/src/java/org/apache/cassandra/tcm/transformations/Unregister.java b/src/java/org/apache/cassandra/tcm/transformations/Unregister.java
index ff1fc72..6d6fa04 100644
--- a/src/java/org/apache/cassandra/tcm/transformations/Unregister.java
+++ b/src/java/org/apache/cassandra/tcm/transformations/Unregister.java
@@ -19,17 +19,18 @@
 package org.apache.cassandra.tcm.transformations;
 
 import java.io.IOException;
+import java.util.EnumSet;
 
 import com.google.common.annotations.VisibleForTesting;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
+import org.apache.cassandra.db.TypeSizes;
 import org.apache.cassandra.io.util.DataInputPlus;
 import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.tcm.ClusterMetadata;
 import org.apache.cassandra.tcm.ClusterMetadataService;
 import org.apache.cassandra.tcm.Transformation;
 import org.apache.cassandra.tcm.membership.NodeId;
+import org.apache.cassandra.tcm.membership.NodeState;
 import org.apache.cassandra.tcm.sequences.LockedRanges;
 import org.apache.cassandra.tcm.serialization.AsymmetricMetadataSerializer;
 import org.apache.cassandra.tcm.serialization.Version;
@@ -38,13 +39,15 @@
 
 public class Unregister implements Transformation
 {
-    private static final Logger logger = LoggerFactory.getLogger(Unregister.class);
     public static final Serializer serializer = new Serializer();
 
     private final NodeId nodeId;
-    public Unregister(NodeId nodeId)
+    private final EnumSet<NodeState> allowedNodeStartStates; // states the node may be in for execute() to accept this transformation
+
+    public Unregister(NodeId nodeId, EnumSet<NodeState> allowedNodeStartStates)
     {
         this.nodeId = nodeId;
+        this.allowedNodeStartStates = EnumSet.copyOf(allowedNodeStartStates); // private copy: EnumSet is mutable and owned by the caller
     }
 
     @Override
@@ -57,19 +60,25 @@
     public Result execute(ClusterMetadata prev)
     {
         if (!prev.directory.peerIds().contains(nodeId))
-            return new Rejected(INVALID, String.format("Can not unregsiter %s since it is not present in the directory.", nodeId));
+            return new Rejected(INVALID, String.format("Can not unregister %s since it is not present in the directory.", nodeId));
 
-        ClusterMetadata.Transformer next = prev.transformer()
-                                           .unregister(nodeId);
+        // refuse to remove the node unless its current state is one the caller explicitly allowed
+        NodeState startState = prev.directory.peerState(nodeId);
+        if (!allowedNodeStartStates.contains(startState))
+            return new Rejected(INVALID, String.format("Can't unregister %s - node state is %s not %s", nodeId, startState, allowedNodeStartStates));
+        ClusterMetadata.Transformer next = prev.transformer().unregister(nodeId);
 
         return Transformation.success(next, LockedRanges.AffectedRanges.EMPTY);
     }
 
+    /**
+     * Test-only and unsafe: commits an Unregister that accepts the node in any {@link NodeState}, skipping the start-state check.
+     */
     @VisibleForTesting
     public static void unregister(NodeId nodeId)
     {
         ClusterMetadataService.instance()
-                              .commit(new Unregister(nodeId));
+                              .commit(new Unregister(nodeId, EnumSet.allOf(NodeState.class)));
     }
 
     public String toString()
@@ -85,20 +94,41 @@
         {
             assert t instanceof Unregister;
             Unregister register = (Unregister)t;
+            if (version.isAtLeast(Version.V2))
+            {
+                out.writeUnsignedVInt32(register.allowedNodeStartStates.size());
+                for (NodeState allowedState : register.allowedNodeStartStates)
+                    out.writeUTF(allowedState.name());
+            }
             NodeId.serializer.serialize(register.nodeId, out, version);
         }
 
         public Unregister deserialize(DataInputPlus in, Version version) throws IOException
         {
+            boolean hasStartStates = version.isAtLeast(Version.V2); // pre-V2 entries carry no start-state set, so any state is accepted
+            EnumSet<NodeState> states = hasStartStates ? EnumSet.noneOf(NodeState.class) : EnumSet.allOf(NodeState.class);
+            if (hasStartStates)
+            {
+                for (int remaining = in.readUnsignedVInt32(); remaining > 0; remaining--)
+                    states.add(NodeState.valueOf(in.readUTF()));
+            }
             NodeId nodeId = NodeId.serializer.deserialize(in, version);
-            return new Unregister(nodeId);
+            return new Unregister(nodeId, states);
         }
 
         public long serializedSize(Transformation t, Version version)
         {
             assert t instanceof Unregister;
             Unregister unregister = (Unregister) t;
-            return NodeId.serializer.serializedSize(unregister.nodeId, version);
+            long size = NodeId.serializer.serializedSize(unregister.nodeId, version);
+            if (!version.isAtLeast(Version.V2))
+                return size;
+
+            // V2 onwards also carries the allowed start states: a count followed by each state's name
+            size += TypeSizes.sizeofUnsignedVInt(unregister.allowedNodeStartStates.size());
+            for (NodeState allowed : unregister.allowedNodeStartStates)
+                size += TypeSizes.sizeof(allowed.name());
+            return size;
         }
     }
 }
diff --git a/src/java/org/apache/cassandra/tools/NodeTool.java b/src/java/org/apache/cassandra/tools/NodeTool.java
index d31e69b..bd1e302 100644
--- a/src/java/org/apache/cassandra/tools/NodeTool.java
+++ b/src/java/org/apache/cassandra/tools/NodeTool.java
@@ -264,7 +264,8 @@
                .withCommand(CMSAdmin.DescribeCMS.class)
                .withCommand(CMSAdmin.InitializeCMS.class)
                .withCommand(CMSAdmin.ReconfigureCMS.class)
-               .withCommand(CMSAdmin.Snapshot.class);
+               .withCommand(CMSAdmin.Snapshot.class)
+               .withCommand(CMSAdmin.Unregister.class);
 
         Cli<NodeToolCmdRunnable> parser = builder.build();
 
diff --git a/src/java/org/apache/cassandra/tools/nodetool/CMSAdmin.java b/src/java/org/apache/cassandra/tools/nodetool/CMSAdmin.java
index 841da97..e54430a 100644
--- a/src/java/org/apache/cassandra/tools/nodetool/CMSAdmin.java
+++ b/src/java/org/apache/cassandra/tools/nodetool/CMSAdmin.java
@@ -171,4 +171,21 @@
         }
     }
 
+    /**
+     * {@code nodetool cms unregister <nodeId>+}: removes the given nodes from cluster metadata via
+     * {@code CMSOperationsMBean#unregisterLeftNodes}. The command fails if any of the nodes is not in
+     * LEFT state.
+     */
+    @Command(name = "unregister", description = "Unregister nodes in LEFT state")
+    public static class Unregister extends NodeTool.NodeToolCmd
+    {
+        @Arguments(required = true, title = "Unregister nodes in LEFT state", description = "One or more nodeIds to unregister, they all need to be in LEFT state", usage = "<nodeId>+")
+        public List<String> nodeIds;
+
+        @Override
+        protected void execute(NodeProbe probe)
+        {
+            probe.getCMSOperationsProxy().unregisterLeftNodes(nodeIds);
+        }
+    }
 }
diff --git a/test/distributed/org/apache/cassandra/distributed/test/log/RegisterTest.java b/test/distributed/org/apache/cassandra/distributed/test/log/RegisterTest.java
index a32dfbc..1a75b62 100644
--- a/test/distributed/org/apache/cassandra/distributed/test/log/RegisterTest.java
+++ b/test/distributed/org/apache/cassandra/distributed/test/log/RegisterTest.java
@@ -21,6 +21,7 @@
 import java.io.IOException;
 import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
+import java.util.EnumSet;
 
 import org.junit.Test;
 
@@ -41,6 +42,7 @@
 import org.apache.cassandra.tcm.membership.Location;
 import org.apache.cassandra.tcm.membership.NodeAddresses;
 import org.apache.cassandra.tcm.membership.NodeId;
+import org.apache.cassandra.tcm.membership.NodeState;
 import org.apache.cassandra.tcm.membership.NodeVersion;
 import org.apache.cassandra.tcm.sequences.LeaveStreams;
 import org.apache.cassandra.tcm.sequences.UnbootstrapAndLeave;
@@ -81,7 +83,7 @@
                     ClusterMetadataService.instance().commit(unbootstrapAndLeave.startLeave);
                     ClusterMetadataService.instance().commit(unbootstrapAndLeave.midLeave);
                     ClusterMetadataService.instance().commit(unbootstrapAndLeave.finishLeave);
-                    ClusterMetadataService.instance().commit(new Unregister(ClusterMetadata.current().myNodeId()));
+                    ClusterMetadataService.instance().commit(new Unregister(ClusterMetadata.current().myNodeId(), EnumSet.of(NodeState.LEFT)));
                 });
 
                 cluster.get(1).runOnInstance(() -> {
@@ -140,7 +142,7 @@
                         }
 
                         // If we unregister oldNode, then the ceiling for serialization version will rise
-                        ClusterMetadataService.instance().commit(new Unregister(oldNode));
+                        ClusterMetadataService.instance().commit(new Unregister(oldNode, EnumSet.allOf(NodeState.class)));
                         assertEquals(ClusterMetadata.current().directory.clusterMinVersion.serializationVersion,
                                      NodeVersion.CURRENT_METADATA_VERSION.asInt());
                         bytes = t.kind().toVersionedBytes(t);
diff --git a/test/distributed/org/apache/cassandra/distributed/test/log/UnregisterTest.java b/test/distributed/org/apache/cassandra/distributed/test/log/UnregisterTest.java
new file mode 100644
index 0000000..c9dd5b4
--- /dev/null
+++ b/test/distributed/org/apache/cassandra/distributed/test/log/UnregisterTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.distributed.test.log;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import org.junit.Test;
+
+import org.apache.cassandra.distributed.Cluster;
+import org.apache.cassandra.distributed.api.Feature;
+import org.apache.cassandra.distributed.test.TestBaseImpl;
+import org.apache.cassandra.tcm.ClusterMetadata;
+import org.apache.cassandra.tcm.membership.NodeId;
+import org.apache.cassandra.tcm.membership.NodeState;
+
+import static org.apache.cassandra.distributed.shared.ClusterUtils.getNodeId;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Exercises {@code nodetool cms unregister}: only nodes in LEFT state can be unregistered, and the
+ * {@code system_views.cluster_metadata_directory} virtual table reflects the directory before and after.
+ */
+public class UnregisterTest extends TestBaseImpl
+{
+    @Test
+    public void testUnregister() throws Throwable
+    {
+        try (Cluster cluster = init(builder().withNodes(5)
+                                        .withConfig((config) -> config.with(Feature.NETWORK, Feature.GOSSIP))
+                                        .start()))
+        {
+            Map<Integer, String> nodeToNodeId = new HashMap<>();
+            for (int i = 1; i <= 5; i++)
+                nodeToNodeId.put(i, String.valueOf(getNodeId(cluster.get(i)).id()));
+            verifyVirtualTable(cluster, nodeToNodeId, 5);
+            cluster.get(5).nodetoolResult("decommission", "--force").asserts().success();
+
+            verifyVirtualTable(cluster, nodeToNodeId, 5, 5);
+            cluster.get(4).nodetoolResult("decommission", "--force").asserts().success();
+            verifyVirtualTable(cluster, nodeToNodeId, 5, 5, 4);
+            cluster.get(3).nodetoolResult("decommission", "--force").asserts().success();
+            verifyVirtualTable(cluster, nodeToNodeId, 5, 5, 4, 3);
+            // unregister a single node
+            cluster.get(1).nodetoolResult("cms", "unregister", nodeToNodeId.get(5)).asserts().success();
+            verifyVirtualTable(cluster, nodeToNodeId, 4, 4, 3);
+            // a request that includes a non-LEFT node is rejected as a whole, so node 4 must stay registered
+            cluster.get(1).nodetoolResult("cms", "unregister", nodeToNodeId.get(4), nodeToNodeId.get(2)).asserts().failure();
+            verifyVirtualTable(cluster, nodeToNodeId, 4, 4, 3);
+            // unregister multiple nodes
+            cluster.get(1).nodetoolResult("cms", "unregister", nodeToNodeId.get(4), nodeToNodeId.get(3)).asserts().success();
+            verifyVirtualTable(cluster, nodeToNodeId, 2);
+            // try to unregister a joined node, should fail
+            cluster.get(1).nodetoolResult("cms", "unregister", nodeToNodeId.get(2)).asserts().failure();
+            verifyVirtualTable(cluster, nodeToNodeId, 2);
+            // try to unregister a node that is no longer in the directory, should fail
+            cluster.get(1).nodetoolResult("cms", "unregister", nodeToNodeId.get(5)).asserts().failure();
+            verifyVirtualTable(cluster, nodeToNodeId, 2);
+
+            cluster.get(1).runOnInstance(() -> {
+                ClusterMetadata metadata = ClusterMetadata.current();
+                assertEquals(2, metadata.directory.states.size());
+                for (Map.Entry<NodeId, NodeState> entry : metadata.directory.states.entrySet())
+                    assertEquals(NodeState.JOINED, entry.getValue());
+            });
+        }
+    }
+
+    /**
+     * Checks the cluster_metadata_directory virtual table on node 1: it must hold {@code expectedTotal} rows,
+     * exactly the instances listed in {@code expectedLeftNodes} must be LEFT, and all other rows JOINED.
+     *
+     * @param nodeToNodeId maps in-jvm instance number to its cluster metadata node id
+     * @param expectedLeftNodes instance numbers (not node ids) expected to be in LEFT state
+     */
+    private static void verifyVirtualTable(Cluster cluster, Map<Integer, String> nodeToNodeId, int expectedTotal, int ... expectedLeftNodes)
+    {
+        Set<Integer> leftNodeIds = new HashSet<>();
+        for (int i : expectedLeftNodes)
+        {
+            NodeId nodeId = NodeId.fromString(nodeToNodeId.get(i));
+            leftNodeIds.add(nodeId.id());
+        }
+        Object [][] res = cluster.get(1).executeInternal("select node_id, state from system_views.cluster_metadata_directory");
+        assertEquals(expectedTotal, res.length);
+        for (Object [] row : res)
+        {
+            int id = (int)row[0];
+            if (row[1].equals("JOINED"))
+                assertFalse(leftNodeIds.contains(id));
+            else if (row[1].equals("LEFT"))
+                assertTrue(leftNodeIds.remove(id));
+            else
+                throw new AssertionError("Unexpected state: " + row[1]);
+        }
+        // every expected LEFT node must have been seen (and removed from the set) above
+        assertTrue("Expected LEFT nodes missing from the virtual table: " + leftNodeIds, leftNodeIds.isEmpty());
+    }
+}