StaticTokenTreeBuilder should respect possibility of duplicate tokens

patch by jrwest and xedin; reviewed by xedin for CASSANDRA-11525
diff --git a/CHANGES.txt b/CHANGES.txt
index 58d8ae8..392d9e7 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 3.5
+ * StaticTokenTreeBuilder should respect possibility of duplicate tokens (CASSANDRA-11525)
  * Correctly fix potential assertion error during compaction (CASSANDRA-11353)
  * Avoid index segment stitching in RAM which lead to OOM on big SSTable files (CASSANDRA-11383)
  * Fix clustering and row filters for LIKE queries on clustering columns (CASSANDRA-11397)
diff --git a/src/java/org/apache/cassandra/index/sasi/disk/AbstractTokenTreeBuilder.java b/src/java/org/apache/cassandra/index/sasi/disk/AbstractTokenTreeBuilder.java
index 4e93b2b..9a1f7f1 100644
--- a/src/java/org/apache/cassandra/index/sasi/disk/AbstractTokenTreeBuilder.java
+++ b/src/java/org/apache/cassandra/index/sasi/disk/AbstractTokenTreeBuilder.java
@@ -397,6 +397,7 @@
 
             public short offsetExtra()
             {
+                // extra offset is supposed to be an unsigned 16-bit integer
                 return (short) offset;
             }
         }
diff --git a/src/java/org/apache/cassandra/index/sasi/disk/StaticTokenTreeBuilder.java b/src/java/org/apache/cassandra/index/sasi/disk/StaticTokenTreeBuilder.java
index 147427e..7a41b38 100644
--- a/src/java/org/apache/cassandra/index/sasi/disk/StaticTokenTreeBuilder.java
+++ b/src/java/org/apache/cassandra/index/sasi/disk/StaticTokenTreeBuilder.java
@@ -79,7 +79,7 @@
 
     public boolean isEmpty()
     {
-        return combinedTerm.getTokenIterator().getCount() == 0;
+        return tokenCount == 0;
     }
 
     public Iterator<Pair<Long, LongSet>> iterator()
@@ -100,7 +100,7 @@
 
     public long getTokenCount()
     {
-        return combinedTerm.getTokenIterator().getCount();
+        return tokenCount;
     }
 
     @Override
@@ -130,64 +130,50 @@
     {
         RangeIterator<Long, Token> tokens = combinedTerm.getTokenIterator();
 
-        tokenCount = tokens.getCount();
+        tokenCount = 0;
         treeMinToken = tokens.getMinimum();
         treeMaxToken = tokens.getMaximum();
         numBlocks = 1;
 
-        if (tokenCount <= TOKENS_PER_BLOCK)
+        root = new InteriorNode();
+        rightmostParent = (InteriorNode) root;
+        Leaf lastLeaf = null;
+        Long lastToken, firstToken = null;
+        int leafSize = 0;
+        while (tokens.hasNext())
         {
-            leftmostLeaf = new StaticLeaf(tokens, tokens.getMinimum(), tokens.getMaximum(), tokens.getCount(), true);
-            rightmostLeaf = leftmostLeaf;
-            root = leftmostLeaf;
+            Long token = tokens.next().get();
+            if (firstToken == null)
+                firstToken = token;
+
+            tokenCount++;
+            leafSize++;
+
+            // skip until the last token in the leaf
+            if (tokenCount % TOKENS_PER_BLOCK != 0 && token != treeMaxToken)
+                continue;
+
+            lastToken = token;
+            Leaf leaf = new PartialLeaf(firstToken, lastToken, leafSize);
+            if (lastLeaf == null) // first leaf created
+                leftmostLeaf = leaf;
+            else
+                lastLeaf.next = leaf;
+
+
+            rightmostParent.add(leaf);
+            lastLeaf = rightmostLeaf = leaf;
+            firstToken = null;
+            numBlocks++;
+            leafSize = 0;
         }
-        else
+
+        // if the tree is really a single leaf the empty root interior
+        // node must be discarded
+        if (root.tokenCount() == 0)
         {
-            root = new InteriorNode();
-            rightmostParent = (InteriorNode) root;
-
-            // build all the leaves except for maybe
-            // the last leaf which is not completely full .
-            // This loop relies on the fact that multiple index segments
-            // will never have token intersection for a single term,
-            // because it's impossible to encounter the same value for
-            // the same column multiple times in a single key/sstable.
-            Leaf lastLeaf = null;
-            long numFullLeaves = tokenCount / TOKENS_PER_BLOCK;
-            for (long i = 0; i < numFullLeaves; i++)
-            {
-                Long firstToken = tokens.next().get();
-                for (int j = 1; j < (TOKENS_PER_BLOCK - 1); j++)
-                    tokens.next();
-
-                Long lastToken = tokens.next().get();
-                Leaf leaf = new PartialLeaf(firstToken, lastToken, TOKENS_PER_BLOCK);
-
-                if (lastLeaf == null)
-                    leftmostLeaf = leaf;
-                else
-                    lastLeaf.next = leaf;
-
-                rightmostParent.add(leaf);
-                lastLeaf = rightmostLeaf = leaf;
-                numBlocks++;
-            }
-
-            // build the last leaf out of any remaining tokens if necessary
-            // safe downcast since TOKENS_PER_BLOCK is an int
-            int remainingTokens = (int) (tokenCount % TOKENS_PER_BLOCK);
-            if (remainingTokens != 0)
-            {
-                Long firstToken = tokens.next().get();
-                Long lastToken = firstToken;
-                while (tokens.hasNext())
-                    lastToken = tokens.next().get();
-
-                Leaf leaf = new PartialLeaf(firstToken, lastToken, remainingTokens);
-                rightmostParent.add(leaf);
-                lastLeaf.next = rightmostLeaf = leaf;
-                numBlocks++;
-            }
+            numBlocks = 1;
+            root = new StaticLeaf(combinedTerm.getTokenIterator(), treeMinToken, treeMaxToken, tokenCount, true);
         }
     }
 
diff --git a/src/java/org/apache/cassandra/index/sasi/disk/TokenTree.java b/src/java/org/apache/cassandra/index/sasi/disk/TokenTree.java
index 3f8182d..c69ce00 100644
--- a/src/java/org/apache/cassandra/index/sasi/disk/TokenTree.java
+++ b/src/java/org/apache/cassandra/index/sasi/disk/TokenTree.java
@@ -470,30 +470,32 @@
         private long[] fetchOffsets()
         {
             short info = buffer.getShort(position);
-            short offsetShort = buffer.getShort(position + SHORT_BYTES);
-            int offsetInt = buffer.getInt(position + (2 * SHORT_BYTES) + LONG_BYTES);
+            // offset extra is unsigned short (right-most 16 bits of 48 bits allowed for an offset)
+            int offsetExtra = buffer.getShort(position + SHORT_BYTES) & 0xFFFF;
+            // offset data is the left-most (32-bit) base of the actual offset in the index file
+            int offsetData = buffer.getInt(position + (2 * SHORT_BYTES) + LONG_BYTES);
 
             EntryType type = EntryType.of(info & TokenTreeBuilder.ENTRY_TYPE_MASK);
 
             switch (type)
             {
                 case SIMPLE:
-                    return new long[] { offsetInt };
+                    return new long[] { offsetData };
 
                 case OVERFLOW:
-                    long[] offsets = new long[offsetShort]; // offsetShort contains count of tokens
-                    long offsetPos = (buffer.position() + (2 * (leafSize * LONG_BYTES)) + (offsetInt * LONG_BYTES));
+                    long[] offsets = new long[offsetExtra]; // offsetExtra contains count of tokens
+                    long offsetPos = (buffer.position() + (2 * (leafSize * LONG_BYTES)) + (offsetData * LONG_BYTES));
 
-                    for (int i = 0; i < offsetShort; i++)
+                    for (int i = 0; i < offsetExtra; i++)
                         offsets[i] = buffer.getLong(offsetPos + (i * LONG_BYTES));
 
                     return offsets;
 
                 case FACTORED:
-                    return new long[] { (((long) offsetInt) << Short.SIZE) + offsetShort };
+                    return new long[] { (((long) offsetData) << Short.SIZE) + offsetExtra };
 
                 case PACKED:
-                    return new long[] { offsetShort, offsetInt };
+                    return new long[] { offsetExtra, offsetData };
 
                 default:
                     throw new IllegalStateException("Unknown entry type: " + type);
diff --git a/test/unit/org/apache/cassandra/index/sasi/disk/TokenTreeTest.java b/test/unit/org/apache/cassandra/index/sasi/disk/TokenTreeTest.java
index 189e9c6..67e54f4 100644
--- a/test/unit/org/apache/cassandra/index/sasi/disk/TokenTreeTest.java
+++ b/test/unit/org/apache/cassandra/index/sasi/disk/TokenTreeTest.java
@@ -33,6 +33,7 @@
 import org.apache.cassandra.index.sasi.utils.MappedBuffer;
 import org.apache.cassandra.index.sasi.utils.RangeIterator;
 import org.apache.cassandra.db.marshal.LongType;
+import org.apache.cassandra.index.sasi.utils.RangeUnionIterator;
 import org.apache.cassandra.io.compress.BufferType;
 import org.apache.cassandra.io.util.FileUtils;
 import org.apache.cassandra.utils.MurmurHash;
@@ -52,7 +53,7 @@
     private static final Function<Long, DecoratedKey> KEY_CONVERTER = new KeyConverter();
 
     static LongSet singleOffset = new LongOpenHashSet() {{ add(1); }};
-    static LongSet bigSingleOffset = new LongOpenHashSet() {{ add(((long) Integer.MAX_VALUE) + 10); }};
+    static LongSet bigSingleOffset = new LongOpenHashSet() {{ add(2147521562L); }};
     static LongSet shortPackableCollision = new LongOpenHashSet() {{ add(2L); add(3L); }}; // can pack two shorts
     static LongSet intPackableCollision = new LongOpenHashSet() {{ add(6L); add(((long) Short.MAX_VALUE) + 1); }}; // can pack int & short
     static LongSet multiCollision =  new LongOpenHashSet() {{ add(3L); add(4L); add(5L); }}; // can't pack
@@ -353,6 +354,75 @@
         Assert.assertEquals(EntryType.OVERFLOW, EntryType.of(EntryType.OVERFLOW.ordinal()));
     }
 
+    @Test
+    public void testMergingOfEqualTokenTrees() throws Exception
+    {
+        testMergingOfEqualTokenTrees(simpleTokenMap);
+        testMergingOfEqualTokenTrees(bigTokensMap);
+    }
+
+    public void testMergingOfEqualTokenTrees(SortedMap<Long, LongSet> tokensMap) throws Exception
+    {
+        TokenTreeBuilder tokensA = new DynamicTokenTreeBuilder(tokensMap);
+        TokenTreeBuilder tokensB = new DynamicTokenTreeBuilder(tokensMap);
+
+        TokenTree a = buildTree(tokensA);
+        TokenTree b = buildTree(tokensB);
+
+        TokenTreeBuilder tokensC = new StaticTokenTreeBuilder(new CombinedTerm(null, null)
+        {
+            public RangeIterator<Long, Token> getTokenIterator()
+            {
+                RangeIterator.Builder<Long, Token> union = RangeUnionIterator.builder();
+                union.add(a.iterator(new KeyConverter()));
+                union.add(b.iterator(new KeyConverter()));
+
+                return union.build();
+            }
+        });
+
+        TokenTree c = buildTree(tokensC);
+        Assert.assertEquals(tokensMap.size(), c.getCount());
+
+        Iterator<Token> tokenIterator = c.iterator(KEY_CONVERTER);
+        Iterator<Map.Entry<Long, LongSet>> listIterator = tokensMap.entrySet().iterator();
+        while (tokenIterator.hasNext() && listIterator.hasNext())
+        {
+            Token treeNext = tokenIterator.next();
+            Map.Entry<Long, LongSet> listNext = listIterator.next();
+
+            Assert.assertEquals(listNext.getKey(), treeNext.get());
+            Assert.assertEquals(convert(listNext.getValue()), convert(treeNext));
+        }
+
+        for (Map.Entry<Long, LongSet> entry : tokensMap.entrySet())
+        {
+            TokenTree.OnDiskToken result = c.get(entry.getKey(), KEY_CONVERTER);
+            Assert.assertNotNull("failed to find object for token " + entry.getKey(), result);
+
+            LongSet found = result.getOffsets();
+            Assert.assertEquals(entry.getValue(), found);
+
+        }
+    }
+
+
+    private static TokenTree buildTree(TokenTreeBuilder builder) throws Exception
+    {
+        builder.finish();
+        final File treeFile = File.createTempFile("token-tree-", "db");
+        treeFile.deleteOnExit();
+
+        try (SequentialWriter writer = new SequentialWriter(treeFile, 4096, BufferType.ON_HEAP))
+        {
+            builder.write(writer);
+            writer.sync();
+        }
+
+        final RandomAccessReader reader = RandomAccessReader.open(treeFile);
+        return new TokenTree(new MappedBuffer(reader));
+    }
+
     private static class EntrySetSkippableIterator extends RangeIterator<Long, TokenWithOffsets>
     {
         private final PeekingIterator<Map.Entry<Long, LongSet>> elements;