Test Failure: org.apache.cassandra.distributed.test.UpgradeSSTablesTest.truncateWhileUpgrading

patch by Berenguer Blasi; reviewed by Brandon Williams for CASSANDRA-19398
diff --git a/test/distributed/org/apache/cassandra/distributed/test/UpgradeSSTablesTest.java b/test/distributed/org/apache/cassandra/distributed/test/UpgradeSSTablesTest.java
index 2bdcaf1..76c6c18 100644
--- a/test/distributed/org/apache/cassandra/distributed/test/UpgradeSSTablesTest.java
+++ b/test/distributed/org/apache/cassandra/distributed/test/UpgradeSSTablesTest.java
@@ -50,10 +50,11 @@
 
 public class UpgradeSSTablesTest extends TestBaseImpl
 {
+
     @Test
     public void upgradeSSTablesInterruptsOngoingCompaction() throws Throwable
     {
-        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).start()))
+        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).withInstanceInitializer(CompactionLatchByteman::install).start()))
         {
             cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v text, PRIMARY KEY (pk, ck));");
             cluster.get(1).acceptsOnInstance((String ks) -> {
@@ -79,10 +80,15 @@
 
             LogAction logAction = cluster.get(1).logs();
             logAction.mark();
+
             Future<?> future = cluster.get(1).asyncAcceptsOnInstance((String ks) -> {
                 ColumnFamilyStore cfs = Keyspace.open(ks).getColumnFamilyStore("tbl");
                 CompactionManager.instance.submitMaximal(cfs, FBUtilities.nowInSeconds(), false, OperationType.COMPACTION);
             }).apply(KEYSPACE);
+
+            Assert.assertTrue(cluster.get(1).callOnInstance(() -> CompactionLatchByteman.starting.awaitUninterruptibly(1, TimeUnit.MINUTES)));
+            cluster.get(1).runOnInstance(() -> {
+                CompactionLatchByteman.start.decrement();});
             Assert.assertEquals(0, cluster.get(1).nodetool("upgradesstables", "-a", KEYSPACE, "tbl"));
             future.get();
             Assert.assertFalse(logAction.grep("Compaction interrupted").getResult().isEmpty());
@@ -136,7 +142,7 @@
     @Test
     public void cleanupDoesNotInterruptUpgradeSSTables() throws Throwable
     {
-        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).withInstanceInitializer(BB::install).start()))
+        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).withInstanceInitializer(UpgradeSStablesLatchByteman::install).start()))
         {
             cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v text, PRIMARY KEY (pk, ck));");
 
@@ -160,12 +166,12 @@
             LogAction logAction = cluster.get(1).logs();
             logAction.mark();
 
-            // Start upgradingsstables - use BB to pause once inside ActiveCompactions.beginCompaction
+            // Start upgradingsstables - use UpgradeSStablesLatchByteman to pause once inside ActiveCompactions.beginCompaction
             Thread upgradeThread = new Thread(() -> {
                 cluster.get(1).nodetool("upgradesstables", "-a", KEYSPACE, "tbl");
             });
             upgradeThread.start();
-            Assert.assertTrue(cluster.get(1).callOnInstance(() -> BB.starting.awaitUninterruptibly(1, TimeUnit.MINUTES)));
+            Assert.assertTrue(cluster.get(1).callOnInstance(() -> UpgradeSStablesLatchByteman.starting.awaitUninterruptibly(1, TimeUnit.MINUTES)));
 
             // Start a scrub and make sure that it fails, log check later to make sure it was
             // because it cannot cancel the active upgrade sstables
@@ -173,7 +179,7 @@
 
             // Now resume the upgrade sstables so test can shut down
             cluster.get(1).runOnInstance(() -> {
-                BB.start.decrement();
+                UpgradeSStablesLatchByteman.start.decrement();
             });
             upgradeThread.join();
 
@@ -186,7 +192,7 @@
     @Test
     public void truncateWhileUpgrading() throws Throwable
     {
-        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).start()))
+        try (ICluster<IInvokableInstance> cluster = init(builder().withNodes(1).withInstanceInitializer(UpgradeSStablesLatchByteman::install).start()))
         {
             cluster.schemaChange(withKeyspace("CREATE TABLE %s.tbl (pk int, ck int, v text, PRIMARY KEY (pk, ck)) "));
             cluster.get(1).acceptsOnInstance((String ks) -> {
@@ -215,6 +221,8 @@
                 cluster.get(1).nodetool("upgradesstables", "-a", KEYSPACE, "tbl");
             });
 
+            Assert.assertTrue(cluster.get(1).callOnInstance(() -> UpgradeSStablesLatchByteman.starting.awaitUninterruptibly(1, TimeUnit.MINUTES)));
+            cluster.get(1).runOnInstance(() -> {UpgradeSStablesLatchByteman.start.decrement();});
             cluster.schemaChange(withKeyspace("TRUNCATE %s.tbl"));
             upgrade.get();
             Assert.assertFalse(logAction.grep("Compaction interrupted").getResult().isEmpty());
@@ -303,7 +311,7 @@
         }
     }
 
-    public static class BB
+    public static class UpgradeSStablesLatchByteman
     {
         // Will be initialized in the context of the instance class loader
         static CountDownLatch starting = newCountDownLatch(1);
@@ -313,7 +321,7 @@
         {
             new ByteBuddy().rebase(ActiveCompactions.class)
                            .method(named("beginCompaction"))
-                           .intercept(MethodDelegation.to(BB.class))
+                           .intercept(MethodDelegation.to(UpgradeSStablesLatchByteman.class))
                            .make()
                            .load(classLoader, ClassLoadingStrategy.Default.INJECTION);
         }
@@ -336,4 +344,38 @@
             }
         }
     }
+
+    public static class CompactionLatchByteman
+    {
+        // Will be initialized in the context of the instance class loader
+        static CountDownLatch starting = newCountDownLatch(1);
+        static CountDownLatch start = newCountDownLatch(1);
+
+        public static void install(ClassLoader classLoader, Integer num)
+        {
+            new ByteBuddy().rebase(ActiveCompactions.class)
+                           .method(named("beginCompaction"))
+                           .intercept(MethodDelegation.to(CompactionLatchByteman.class))
+                           .make()
+                           .load(classLoader, ClassLoadingStrategy.Default.INJECTION);
+        }
+
+        @SuppressWarnings("unused")
+        public static void beginCompaction(CompactionInfo.Holder ci, @SuperCall Callable<Void> zuperCall)
+        {
+            try
+            {
+                zuperCall.call();
+                if (ci.getCompactionInfo().getTaskType() == OperationType.COMPACTION)
+                {
+                    starting.decrement();
+                    Assert.assertTrue(start.awaitUninterruptibly(1, TimeUnit.MINUTES));
+                }
+            }
+            catch (Exception e)
+            {
+                throw new RuntimeException(e);
+            }
+        }
+    }
 }