Merge pull request #722 from thelastpickle/mck/16205

Use the token allocation strategy generator tool for 4.0 vnode clusters
diff --git a/ccmlib/cluster.py b/ccmlib/cluster.py
index 966f9d5..33c612f 100644
--- a/ccmlib/cluster.py
+++ b/ccmlib/cluster.py
@@ -4,6 +4,7 @@
 import itertools
 import os
 import random
+import re
 import shutil
 import signal
 import subprocess
@@ -24,6 +25,9 @@
 except ImportError:
     from urlparse import urlparse
 
+
+CLUSTER_WAIT_TIMEOUT_IN_SECS = 120
+
 class Cluster(object):
 
     def __init__(self, path, name, partitioner=None, install_dir=None, create_directory=True, version=None, verbose=False, derived_cassandra_version=None, **kwargs):
@@ -247,7 +251,7 @@
         node._save()
         return self
 
-    def populate(self, nodes, debug=False, tokens=None, use_vnodes=False, ipprefix='127.0.0.', ipformat=None, install_byteman=False, use_single_interface=False):
+    def populate(self, nodes, debug=False, tokens=None, use_vnodes=None, ipprefix='127.0.0.', ipformat=None, install_byteman=False, use_single_interface=False):
         """Populate a cluster with nodes
         @use_single_interface : Populate the cluster with nodes that all share a single network interface.
         """
@@ -258,7 +262,15 @@
         node_count = nodes
         dcs = []
 
-        self.use_vnodes = use_vnodes
+        if use_vnodes is None:
+            self.use_vnodes = (
+                (tokens is not None and len(tokens) > 1)
+                    or ('num_tokens' in self._config_options
+                        and self._config_options['num_tokens'] is not None
+                        and int(self._config_options['num_tokens']) > 1))
+        else:
+            self.use_vnodes = use_vnodes
+
         if isinstance(nodes, list):
             self.set_configuration_options(values={'endpoint_snitch': 'org.apache.cassandra.locator.PropertyFileSnitch'})
             node_count = 0
@@ -276,11 +288,22 @@
             if 'node%s' % i in list(self.nodes.values()):
                 raise common.ArgumentError('Cannot create existing node node%s' % i)
 
-        if tokens is None and not use_vnodes:
-            if dcs is None or len(dcs) <= 1:
-                tokens = self.balanced_tokens(node_count)
+        if tokens is None:
+            if self.use_vnodes:
+                # from 4.0 tokens can be pre-generated via the `allocate_tokens_for_local_replication_factor: 3` strategy
+                #  this saves time, as allocating tokens during first start is slow and non-concurrent
+                if self.can_generate_tokens() and not 'CASSANDRA_TOKEN_PREGENERATION_DISABLED' in self._environment_variables:
+                    if len(dcs) <= 1:
+                        for x in xrange(0, node_count):
+                            dcs.append('dc1')
+
+                    tokens = self.generated_tokens(dcs)
             else:
-                tokens = self.balanced_tokens_across_dcs(dcs)
+                common.debug("using balanced tokens for non-vnode cluster")
+                if len(dcs) <= 1:
+                    tokens = self.balanced_tokens(node_count)
+                else:
+                    tokens = self.balanced_tokens_across_dcs(dcs)
 
         if not ipformat:
             ipformat = ipprefix + "%d"
@@ -350,6 +373,47 @@
         tokens.extend(new_tokens)
         return tokens
 
+    def can_generate_tokens(self):
+        """Return True when the version is >= '4', the partitioner is unset or Murmur3/Random, and num_tokens is configured > 1."""
+        num_tokens = self._config_options['num_tokens'] if 'num_tokens' in self._config_options else None
+        partitioner_ok = self.partitioner is None or 'Murmur3' in self.partitioner or 'Random' in self.partitioner
+        return self.cassandra_version() >= '4' and partitioner_ok and num_tokens is not None and int(num_tokens) > 1
+
+    def generated_tokens(self, dcs):
+        """Return pre-generated tokens for all nodes; each run of equal consecutive names in `dcs` is handled as one datacenter."""
+        tokens = []
+        # all nodes are in rack1
+        # walk the list and flush the current run to generate_dc_tokens each time the name changes
+        run_dc, run_length = dcs[0], 0
+        for dc_name in dcs:
+            if dc_name != run_dc:
+                self.generate_dc_tokens(run_length, tokens)
+                run_dc, run_length = dc_name, 0
+            run_length += 1
+        # flush the final run
+        self.generate_dc_tokens(run_length, tokens)
+        return tokens
+
+    def generate_dc_tokens(self, node_count, tokens):
+        """Run the generatetokens tool for `node_count` nodes and append one cleaned-up token string per node to `tokens`; raises ArgumentError if unsupported."""
+        if self.cassandra_version() < '4' or (self.partitioner and not ('Murmur3' in self.partitioner or 'Random' in self.partitioner)):
+            raise common.ArgumentError("generate-tokens script only for >=4.0 and Murmur3 or Random")
+        if not ('num_tokens' in self._config_options and self._config_options['num_tokens'] is not None and int(self._config_options['num_tokens']) > 1):
+            raise common.ArgumentError("Cannot use generate-tokens script without num_tokens > 1")
+
+        partitioner = 'RandomPartitioner' if ( self.partitioner and 'Random' in self.partitioner) else 'Murmur3Partitioner'
+        generate_tokens = common.join_bin(self.get_install_dir(), os.path.join('tools', 'bin'), 'generatetokens')
+        cmd_list = [generate_tokens, '-n', str(node_count), '-t', str(self._config_options.get("num_tokens")), '--rf', str(min(3,node_count)), '-p', partitioner]
+        process = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=os.environ.copy())
+        # communicate() drains both pipes and reaps the child, so it cannot block on a full stderr pipe or be left as a zombie
+        lines = process.communicate()[0].decode("utf-8").splitlines() + [''] * (node_count + 1)
+
+        # the first line is "Generating tokens for X nodes with" and can be ignored; short output is padded with empty entries
+        for line in lines[1:node_count + 1]:
+            tokens.append(re.sub(r'^.*?:', '', line).replace('[','').replace(' ','').replace(']',''))
+
+        common.debug("pregenerated tokens from cmd_list: {} are {}".format(str(cmd_list),tokens))
+
     def remove(self, node=None):
         if node is not None:
             if node.name not in self.nodes:
@@ -451,9 +515,12 @@
                 if os.path.exists(node.logfilename()):
                     mark = node.mark_log()
 
+                # if the node is going to allocate_strategy_ tokens during start, then wait_for_binary_proto=True
+                node_wait_for_binary_proto = (self.can_generate_tokens() and self.use_vnodes and node.initial_token is None)
+
                 p = node.start(update_pid=False, jvm_args=jvm_args, jvm_version=jvm_version,
                                profile_options=profile_options, verbose=verbose, quiet_start=quiet_start,
-                               allow_root=allow_root)
+                               allow_root=allow_root, wait_for_binary_proto=node_wait_for_binary_proto)
 
                 # Prior to JDK8, starting every node at once could lead to a
                 # nanotime collision where the RNG that generates a node's tokens
@@ -470,7 +537,7 @@
             for node, p, mark in started:
                 try:
                     start_message = "Listening for thrift clients..." if self.cassandra_version() < "2.2" else "Starting listening for CQL clients"
-                    node.watch_log_for(start_message, timeout=kwargs.get('timeout',60), process=p, verbose=verbose, from_mark=mark)
+                    node.watch_log_for(start_message, timeout=kwargs.get('timeout',CLUSTER_WAIT_TIMEOUT_IN_SECS), process=p, verbose=verbose, from_mark=mark)
                 except RuntimeError:
                     return None
 
@@ -689,15 +756,8 @@
         for node in self.nodelist():
             if node.data_center is not None:
                 dcs.append((node.address(), node.data_center))
-
-        content = ""
-        for k, v in dcs:
-            content = "%s%s=%s:r1\n" % (content, k, v)
-
         for node in self.nodelist():
-            topology_file = os.path.join(node.get_conf_dir(), 'cassandra-topology.properties')
-            with open(topology_file, 'w') as f:
-                f.write(content)
+            node.update_topology(dcs)
 
     def enable_ssl(self, ssl_path, require_client_auth):
         shutil.copyfile(os.path.join(ssl_path, 'keystore.jks'), os.path.join(self.get_path(), 'keystore.jks'))
diff --git a/ccmlib/cmds/cluster_cmds.py b/ccmlib/cmds/cluster_cmds.py
index 1725144..dcca8fa 100644
--- a/ccmlib/cmds/cluster_cmds.py
+++ b/ccmlib/cmds/cluster_cmds.py
@@ -145,8 +145,11 @@
         if self.options.binary_protocol:
             cluster.set_configuration_options({'start_native_transport': True})
 
-        if cluster.cassandra_version() >= "1.2" and self.options.vnodes:
-            cluster.set_configuration_options({'num_tokens': 256})
+        if self.options.vnodes:
+            if cluster.cassandra_version() >= "4":
+                cluster.set_configuration_options({'num_tokens': 16})
+            elif cluster.cassandra_version() >= "1.2":
+                cluster.set_configuration_options({'num_tokens': 256})
 
         if not self.options.no_switch:
             common.switch_cluster(self.path, self.name)
@@ -284,8 +287,11 @@
 
     def run(self):
         try:
-            if self.cluster.cassandra_version() >= "1.2" and self.options.vnodes:
-                self.cluster.set_configuration_options({'num_tokens': 256})
+            if self.options.vnodes:
+                if self.cluster.cassandra_version() >= "4":
+                    self.cluster.set_configuration_options({'num_tokens': 16})
+                elif self.cluster.cassandra_version() >= "1.2":
+                    self.cluster.set_configuration_options({'num_tokens': 256})
 
             if not (self.options.ipprefix or self.options.ipformat):
                 self.options.ipformat = '127.0.0.%d'
diff --git a/ccmlib/node.py b/ccmlib/node.py
index 6d6ae8f..d555641 100644
--- a/ccmlib/node.py
+++ b/ccmlib/node.py
@@ -29,6 +29,7 @@
 
 logger = logging.getLogger(__name__)
 
+NODE_WAIT_TIMEOUT_IN_SECS = 90
 
 class Status():
     UNINITIALIZED = "UNINITIALIZED"
@@ -115,6 +116,7 @@
         self.workloads = []
         self._dse_config_options = {}
         self.__config_options = {}
+        self._topology = [('default', 'dc1')]
         self.__install_dir = None
         self.__global_log_level = None
         self.__classes_log_level = {}
@@ -620,21 +622,21 @@
         log for 'Starting listening for CQL clients' before checking for the
         interface to be listening.
 
-        Emits a warning if not listening after 30 seconds.
+        Emits a warning if not listening after NODE_WAIT_TIMEOUT_IN_SECS seconds.
         """
         if self.cluster.version() >= '1.2':
             self.watch_log_for("Starting listening for CQL clients", **kwargs)
 
         binary_itf = self.network_interfaces['binary']
-        if not common.check_socket_listening(binary_itf, timeout=30):
-            warnings.warn("Binary interface %s:%s is not listening after 30 seconds, node may have failed to start."
-                          % (binary_itf[0], binary_itf[1]))
+        if not common.check_socket_listening(binary_itf, timeout=NODE_WAIT_TIMEOUT_IN_SECS):
+            warnings.warn("Binary interface %s:%s is not listening after %s seconds, node may have failed to start."
+                          % (binary_itf[0], binary_itf[1], NODE_WAIT_TIMEOUT_IN_SECS))
 
     def wait_for_thrift_interface(self, **kwargs):
         """
         Waits for the Thrift interface to be listening.
 
-        Emits a warning if not listening after 30 seconds.
+        Emits a warning if not listening after NODE_WAIT_TIMEOUT_IN_SECS seconds.
         """
         if self.cluster.version() >= '4':
             return;
@@ -642,8 +644,9 @@
         self.watch_log_for("Listening for thrift clients...", **kwargs)
 
         thrift_itf = self.network_interfaces['thrift']
-        if not common.check_socket_listening(thrift_itf, timeout=30):
-            warnings.warn("Thrift interface {}:{} is not listening after 30 seconds, node may have failed to start.".format(thrift_itf[0], thrift_itf[1]))
+        if not common.check_socket_listening(thrift_itf, timeout=NODE_WAIT_TIMEOUT_IN_SECS):
+            warnings.warn(
+                "Thrift interface {}:{} is not listening after {} seconds, node may have failed to start.".format(thrift_itf[0], thrift_itf[1], NODE_WAIT_TIMEOUT_IN_SECS))
 
     def get_launch_bin(self):
         cdir = self.get_install_dir()
@@ -1493,6 +1496,7 @@
         self._update_config()
         self.copy_config_files()
         self._update_yaml()
+        self._update_topology_file()
         # loggers changed > 2.1
         if self.get_base_cassandra_version() < 2.1:
             self._update_log4j()
@@ -1820,6 +1824,19 @@
                             common.replace_in_file(f, '-Djava.net.preferIPv4Stack=true', '')
                     break
 
+    def update_topology(self, topology):
+        """Remember `topology`, an iterable of (address, datacenter) pairs, and rewrite the topology file from it."""
+        self._topology = topology
+        self._update_topology_file()
+
+    def _update_topology_file(self):
+        """Write one `<address>=<datacenter>:r1` line per topology entry to cassandra-topology.properties in the conf dir."""
+        lines = ["%s=%s:r1\n" % (address, datacenter) for address, datacenter in self._topology]
+
+        properties_path = os.path.join(self.get_conf_dir(), 'cassandra-topology.properties')
+        with open(properties_path, 'w') as properties_file:
+            properties_file.writelines(lines)
+
     def __update_status(self):
         if self.pid is None:
             if self.status == Status.UP or self.status == Status.DECOMMISSIONED: