Add probabilistic diff to sample partitions for diff testing with a configurable probability

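The new option is set in the YAML job configuration. A minimal sketch, assuming a typical job
config layout (only partition_sampling_probability is added by this patch; the surrounding keys
are illustrative):

    keyspace_tables:
      - ks1.tbl1
    buckets: 10
    rate_limit: 10000
    # diff roughly half of the partitions in each split
    partition_sampling_probability: 0.5
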
Patch by Jyothsna Konisa; reviewed by Dinesh Joshi, Yifan Cai for CASSANDRA-16967
diff --git a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
index 7a20b30..8d74de8 100644
--- a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
@@ -88,6 +88,13 @@
     MetadataKeyspaceOptions metadataOptions();
 
     /**
+     * Sampling probability, ranging from 0 to 1, which determines how many partitions are diffed by the probabilistic diff.
+     * The default value is 1, meaning all partitions are diffed.
+     * @return the partition sampling probability
+     */
+    double partitionSamplingProbability();
+
+    /**
      * Contains the options that specify the retry strategy for retrieving data at the application level.
      * Note that it is different than cassandra java driver's {@link com.datastax.driver.core.policies.RetryPolicy},
      * which is evaluated at the Netty worker threads.
diff --git a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
index 359466a..7d60403 100644
--- a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
@@ -48,6 +48,7 @@
     public String specific_tokens = null;
     public String disallowed_tokens = null;
     public RetryOptions retry_options;
+    public double partition_sampling_probability = 1;
 
     public static YamlJobConfiguration load(InputStream inputStream) {
         Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -103,6 +104,11 @@
         return metadata_options;
     }
 
+    @Override
+    public double partitionSamplingProbability() {
+        return partition_sampling_probability;
+    }
+
     public RetryOptions retryOptions() {
         return retry_options;
     }
@@ -130,6 +136,7 @@
                ", keyspace_tables=" + keyspace_tables +
                ", buckets=" + buckets +
                ", rate_limit=" + rate_limit +
+               ", partition_sampling_probability=" + partition_sampling_probability +
                ", job_id='" + job_id + '\'' +
                ", token_scan_fetch_size=" + token_scan_fetch_size +
                ", partition_read_fetch_size=" + partition_read_fetch_size +
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
index cf1c9a5..11794c5 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
@@ -27,10 +27,12 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Callable;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +65,7 @@
     private final double reverseReadProbability;
     private final SpecificTokens specificTokens;
     private final RetryStrategyProvider retryStrategyProvider;
+    private final double partitionSamplingProbability;
 
     private static DiffCluster srcDiffCluster;
     private static DiffCluster targetDiffCluster;
@@ -103,6 +106,7 @@
         this.reverseReadProbability = config.reverseReadProbability();
         this.specificTokens = config.specificTokens();
         this.retryStrategyProvider = retryStrategyProvider;
+        this.partitionSamplingProbability = config.partitionSamplingProbability();
         synchronized (Differ.class)
         {
             /*
@@ -225,12 +229,28 @@
                                                               mismatchReporter,
                                                               journal,
                                                               COMPARISON_EXECUTOR);
-
-        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider);
+        final Predicate<PartitionKey> partitionSamplingFunction = shouldIncludePartition(jobId, partitionSamplingProbability);
+        final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider, partitionSamplingFunction);
         logger.debug("Table [{}] stats - ({})", context.table.getTable(), tableStats);
         return tableStats;
     }
 
+    // Returns a predicate that decides whether a partition should be included for diffing,
+    // sampling partitions based on the configured probability.
+    @VisibleForTesting
+    static Predicate<PartitionKey> shouldIncludePartition(final UUID jobId, final double partitionSamplingProbability) {
+        if (partitionSamplingProbability > 1 || partitionSamplingProbability <= 0) {
+            logger.error("Invalid partition sampling probability {}, it should be between 0 and 1", partitionSamplingProbability);
+            throw new IllegalArgumentException("Invalid partition sampling probability, it should be between 0 and 1");
+        }
+        if (partitionSamplingProbability == 1) {
+            return partitionKey -> true;
+        } else {
+            final Random random = new Random(jobId.hashCode());
+            return partitionKey -> random.nextDouble() <= partitionSamplingProbability;
+        }
+    }
+
     private Iterator<Row> fetchRows(DiffContext context, PartitionKey key, boolean shouldReverse, DiffCluster.Type type) {
         Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
                                              ? context.source.getPartition(context.table, key, shouldReverse)
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
index 5d6710e..280fbd5 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
@@ -27,6 +27,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.Verify;
 import org.slf4j.Logger;
@@ -57,6 +58,22 @@
     public RangeStats compare(Iterator<PartitionKey> sourceKeys,
                               Iterator<PartitionKey> targetKeys,
                               Function<PartitionKey, PartitionComparator> partitionTaskProvider) {
+        return compare(sourceKeys, targetKeys, partitionTaskProvider, partitionKey -> true);
+    }
+
+    /**
+     * Compares partitions in the source and target clusters.
+     *
+     * @param sourceKeys partition keys in the source cluster
+     * @param targetKeys partition keys in the target cluster
+     * @param partitionTaskProvider provides the comparison task for a partition
+     * @param partitionSampler predicate that samples partitions for probabilistic diff
+     * @return stats about the diff
+     */
+    public RangeStats compare(Iterator<PartitionKey> sourceKeys,
+                              Iterator<PartitionKey> targetKeys,
+                              Function<PartitionKey, PartitionComparator> partitionTaskProvider,
+                              Predicate<PartitionKey> partitionSampler) {
 
         final RangeStats rangeStats = RangeStats.newStats();
         // We can catch this condition earlier, but it doesn't hurt to also check here
@@ -115,11 +132,16 @@
 
                     BigInteger token = sourceKey.getTokenAsBigInteger();
                     try {
-                        PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
-                        comparisonExecutor.submit(comparisonTask,
-                                                  onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
-                                                  onError(rangeStats, token, errorReporter),
-                                                  phaser);
+                        // Use the partitionSampler to sample partitions: skip the partition if the
+                        // sampler returns false, otherwise run the diff on that partition.
+                        if (partitionSampler.test(sourceKey)) {
+                            PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
+                            comparisonExecutor.submit(comparisonTask,
+                                                      onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
+                                                      onError(rangeStats, token, errorReporter),
+                                                      phaser);
+                        }
+
                     } catch (Throwable t) {
                         // Handle errors thrown when creating the comparison task. This should trap timeouts and
                         // unavailables occurring when performing the initial query to read the full partition.
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
index 1bf656d..49c1f11 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
@@ -108,5 +108,10 @@
         public Optional<UUID> jobId() {
             return Optional.of(UUID.randomUUID());
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1;
+        }
     }
 }
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
index e588575..b1b524d 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
@@ -21,16 +21,65 @@
 
 import java.math.BigInteger;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import com.google.common.base.VerifyException;
 import com.google.common.collect.Lists;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 public class DifferTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testIncludeAllPartitions() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        assertTrue(Differ.shouldIncludePartition(uuid, 1).test(testKey));
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityInvalidProbability() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        expectedException.expect(IllegalArgumentException.class);
+        expectedException.expectMessage("Invalid partition sampling probability, it should be between 0 and 1");
+        Differ.shouldIncludePartition(uuid, -1).test(testKey);
+    }
+
+    @Test
+    public void shouldIncludePartitionWithProbabilityHalf() {
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        int count = 0;
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final Predicate<PartitionKey> partitionSampler = Differ.shouldIncludePartition(uuid, 0.5);
+        for (int i = 0; i < 20; i++) {
+            if (partitionSampler.test(testKey)) {
+                count++;
+            }
+        }
+        assertTrue(count <= 15);
+        assertTrue(count >= 5);
+    }
+
+    @Test
+    public void shouldIncludePartitionShouldGenerateSameSequenceForGivenJobId() {
+        final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+        final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+        final Predicate<PartitionKey> partitionSampler1 = Differ.shouldIncludePartition(uuid, 0.5);
+        final Predicate<PartitionKey> partitionSampler2 = Differ.shouldIncludePartition(uuid, 0.5);
+        for (int i = 0; i < 10; i++) {
+            assertEquals(partitionSampler2.test(testKey), partitionSampler1.test(testKey));
+        }
+    }
 
     @Test(expected = VerifyException.class)
     public void rejectNullStartOfRange() {
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
index fd2926b..e09f68f 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
@@ -57,6 +57,38 @@
     private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
 
     @Test
+    public void probabilisticDiffIncludeAllPartitions() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6), keys(0, 1, 2, 3, 4, 5, 7), this::alwaysMatch);
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(6, stats.getMatchedPartitions());
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+        assertNothingReported(errors, journal);
+        assertCompared(0, 1, 2, 3, 4, 5);
+    }
+
+    @Test
+    public void probabilisticDiffProbabilityHalf() {
+        RangeComparator comparator = comparator(context(0L, 100L));
+        RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6),
+                                              keys(0, 1, 2, 3, 4, 5, 7),
+                                              this::alwaysMatch,
+                                              key -> key.getTokenAsBigInteger().intValue() % 2 == 0);
+        assertFalse(stats.isEmpty());
+        assertEquals(1, stats.getOnlyInSource());
+        assertEquals(1, stats.getOnlyInTarget());
+        assertEquals(3, stats.getMatchedPartitions());
+        assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+        assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+        assertNothingReported(errors, journal);
+        assertCompared(0, 2, 4);
+    }
+
+
+    @Test
     public void emptyRange() {
         RangeComparator comparator = comparator(context(100L, 100L));
         RangeStats stats = comparator.compare(keys(), keys(), this::alwaysMatch);
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
index 17dc67c..b94d22c 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
@@ -29,6 +29,11 @@
         public List<String> disallowedKeyspaces() {
             return disallowedKeyspaces;
         }
+
+        @Override
+        public double partitionSamplingProbability() {
+            return 1;
+        }
     }
 
     @Test