Probabilistic diff to sample partitions for diff testing based on probability
Patch by Jyothsna Konisa; reviewed by Dinesh Joshi, Yifan Cai for CASSANDRA-16967
diff --git a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
index 7a20b30..8d74de8 100644
--- a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
@@ -88,6 +88,13 @@
MetadataKeyspaceOptions metadataOptions();
/**
+ * Sampling probability, ranging from 0 to 1, that decides what fraction of partitions are diffed by the probabilistic diff.
+ * The default value is 1, which means all partitions are diffed.
+ * @return partitionSamplingProbability
+ */
+ double partitionSamplingProbability();
+
+ /**
* Contains the options that specify the retry strategy for retrieving data at the application level.
* Note that it is different than cassandra java driver's {@link com.datastax.driver.core.policies.RetryPolicy},
* which is evaluated at the Netty worker threads.
diff --git a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
index 359466a..7d60403 100644
--- a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
@@ -48,6 +48,7 @@
public String specific_tokens = null;
public String disallowed_tokens = null;
public RetryOptions retry_options;
+ public double partition_sampling_probability = 1;
public static YamlJobConfiguration load(InputStream inputStream) {
Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -103,6 +104,11 @@
return metadata_options;
}
+ @Override
+ public double partitionSamplingProbability() {
+ return partition_sampling_probability;
+ }
+
public RetryOptions retryOptions() {
return retry_options;
}
@@ -130,6 +136,7 @@
", keyspace_tables=" + keyspace_tables +
", buckets=" + buckets +
", rate_limit=" + rate_limit +
+ ", partition_sampling_probability=" + partition_sampling_probability +
", job_id='" + job_id + '\'' +
", token_scan_fetch_size=" + token_scan_fetch_size +
", partition_read_fetch_size=" + partition_read_fetch_size +
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
index cf1c9a5..11794c5 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
@@ -27,10 +27,12 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.function.BiConsumer;
import java.util.function.Function;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +65,7 @@
private final double reverseReadProbability;
private final SpecificTokens specificTokens;
private final RetryStrategyProvider retryStrategyProvider;
+ private final double partitionSamplingProbability;
private static DiffCluster srcDiffCluster;
private static DiffCluster targetDiffCluster;
@@ -103,6 +106,7 @@
this.reverseReadProbability = config.reverseReadProbability();
this.specificTokens = config.specificTokens();
this.retryStrategyProvider = retryStrategyProvider;
+ this.partitionSamplingProbability = config.partitionSamplingProbability();
synchronized (Differ.class)
{
/*
@@ -225,12 +229,28 @@
mismatchReporter,
journal,
COMPARISON_EXECUTOR);
-
- final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider);
+ final Predicate<PartitionKey> partitionSamplingFunction = shouldIncludePartition(jobId, partitionSamplingProbability);
+ final RangeStats tableStats = rangeComparator.compare(sourceKeys, targetKeys, partitionTaskProvider, partitionSamplingFunction);
logger.debug("Table [{}] stats - ({})", context.table.getTable(), tableStats);
return tableStats;
}
+ // Returns a predicate which decides whether a given partition should be included for diffing.
+ // Uses probability-based sampling, seeded deterministically from the job id.
+ @VisibleForTesting
+ static Predicate<PartitionKey> shouldIncludePartition(final UUID jobId, final double partitionSamplingProbability) {
+ if (partitionSamplingProbability > 1 || partitionSamplingProbability <= 0) {
+ logger.error("Invalid partition sampling property {}, it should be between 0 and 1", partitionSamplingProbability);
+ throw new IllegalArgumentException("Invalid partition sampling property, it should be between 0 and 1");
+ }
+ if (partitionSamplingProbability == 1) {
+ return partitionKey -> true;
+ } else {
+ final Random random = new Random(jobId.hashCode());
+ return partitionKey -> random.nextDouble() <= partitionSamplingProbability;
+ }
+ }
+
private Iterator<Row> fetchRows(DiffContext context, PartitionKey key, boolean shouldReverse, DiffCluster.Type type) {
Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
? context.source.getPartition(context.table, key, shouldReverse)
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
index 5d6710e..280fbd5 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/RangeComparator.java
@@ -27,6 +27,7 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.function.Predicate;
import com.google.common.base.Verify;
import org.slf4j.Logger;
@@ -57,6 +58,22 @@
public RangeStats compare(Iterator<PartitionKey> sourceKeys,
Iterator<PartitionKey> targetKeys,
Function<PartitionKey, PartitionComparator> partitionTaskProvider) {
+ return compare(sourceKeys,targetKeys,partitionTaskProvider, partitionKey -> true);
+ }
+
+ /**
+ * Compares partitions in src and target clusters.
+ *
+ * @param sourceKeys partition keys in the source cluster
+ * @param targetKeys partition keys in the target cluster
+ * @param partitionTaskProvider comparison task
+ * @param partitionSampler samples partitions based on the probability for probabilistic diff
+ * @return stats about the diff
+ */
+ public RangeStats compare(Iterator<PartitionKey> sourceKeys,
+ Iterator<PartitionKey> targetKeys,
+ Function<PartitionKey, PartitionComparator> partitionTaskProvider,
+ Predicate<PartitionKey> partitionSampler) {
final RangeStats rangeStats = RangeStats.newStats();
// We can catch this condition earlier, but it doesn't hurt to also check here
@@ -115,11 +132,16 @@
BigInteger token = sourceKey.getTokenAsBigInteger();
try {
- PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
- comparisonExecutor.submit(comparisonTask,
- onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
- onError(rangeStats, token, errorReporter),
- phaser);
+ // Use the partitionSampler predicate to sample partitions: skip the partition
+ // if the sampler returns false, otherwise run the diff on that partition.
+ if (partitionSampler.test(sourceKey)) {
+ PartitionComparator comparisonTask = partitionTaskProvider.apply(sourceKey);
+ comparisonExecutor.submit(comparisonTask,
+ onSuccess(rangeStats, partitionCount, token, highestTokenSeen, mismatchReporter, journal),
+ onError(rangeStats, token, errorReporter),
+ phaser);
+ }
+
} catch (Throwable t) {
// Handle errors thrown when creating the comparison task. This should trap timeouts and
// unavailables occurring when performing the initial query to read the full partition.
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
index 1bf656d..49c1f11 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DiffJobTest.java
@@ -108,5 +108,10 @@
public Optional<UUID> jobId() {
return Optional.of(UUID.randomUUID());
}
+
+ @Override
+ public double partitionSamplingProbability() {
+ return 1;
+ }
}
}
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
index e588575..b1b524d 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/DifferTest.java
@@ -21,16 +21,65 @@
import java.math.BigInteger;
import java.util.Map;
+import java.util.UUID;
import java.util.function.Function;
+import java.util.function.Predicate;
import com.google.common.base.VerifyException;
import com.google.common.collect.Lists;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
public class DifferTest {
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testIncludeAllPartitions() {
+ final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+ final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+ assertTrue(Differ.shouldIncludePartition(uuid, 1).test(testKey));
+ }
+
+ @Test
+ public void shouldIncludePartitionWithProbabilityInvalidProbability() {
+ final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+ final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("Invalid partition sampling property, it should be between 0 and 1");
+ Differ.shouldIncludePartition(uuid, -1).test(testKey);
+ }
+
+ @Test
+ public void shouldIncludePartitionWithProbabilityHalf() {
+ final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+ int count = 0;
+ final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+ final Predicate<PartitionKey> partitionSampler = Differ.shouldIncludePartition(uuid, 0.5);
+ for (int i = 0; i < 20; i++) {
+ if (partitionSampler.test(testKey)) {
+ count++;
+ }
+ }
+ assertTrue(count <= 15);
+ assertTrue(count >= 5);
+ }
+
+ @Test
+ public void shouldIncludePartitionShouldGenerateSameSequenceForGivenJobId() {
+ final UUID uuid = UUID.fromString("cde3b15d-2363-4028-885a-52de58bad64e");
+ final PartitionKey testKey = new RangeComparatorTest.TestPartitionKey(0);
+ final Predicate<PartitionKey> partitionSampler1 = Differ.shouldIncludePartition(uuid, 0.5);
+ final Predicate<PartitionKey> partitionSampler2 = Differ.shouldIncludePartition(uuid, 0.5);
+ for (int i = 0; i < 10; i++) {
+ assertEquals(partitionSampler2.test(testKey), partitionSampler1.test(testKey));
+ }
+ }
@Test(expected = VerifyException.class)
public void rejectNullStartOfRange() {
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
index fd2926b..e09f68f 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
@@ -57,6 +57,38 @@
private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
@Test
+ public void probabilisticDiffIncludeAllPartitions() {
+ RangeComparator comparator = comparator(context(0L, 100L));
+ RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6), keys(0,1, 2, 3, 4, 5, 7), this::alwaysMatch);
+ assertFalse(stats.isEmpty());
+ assertEquals(1, stats.getOnlyInSource());
+ assertEquals(1, stats.getOnlyInTarget());
+ assertEquals(6, stats.getMatchedPartitions());
+ assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+ assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+ assertNothingReported(errors, journal);
+ assertCompared(0, 1, 2, 3, 4, 5);
+ }
+
+ @Test
+ public void probabilisticDiffProbabilityHalf() {
+ RangeComparator comparator = comparator(context(0L, 100L));
+ RangeStats stats = comparator.compare(keys(0, 1, 2, 3, 4, 5, 6),
+ keys(0, 1, 2, 3, 4, 5, 7),
+ this::alwaysMatch,
+ key -> key.getTokenAsBigInteger().intValue() % 2 == 0);
+ assertFalse(stats.isEmpty());
+ assertEquals(1, stats.getOnlyInSource());
+ assertEquals(1, stats.getOnlyInTarget());
+ assertEquals(3, stats.getMatchedPartitions());
+ assertReported(6, MismatchType.ONLY_IN_SOURCE, mismatches);
+ assertReported(7, MismatchType.ONLY_IN_TARGET, mismatches);
+ assertNothingReported(errors, journal);
+ assertCompared(0, 2, 4);
+ }
+
+
+ @Test
public void emptyRange() {
RangeComparator comparator = comparator(context(100L, 100L));
RangeStats stats = comparator.compare(keys(), keys(), this::alwaysMatch);
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
index 17dc67c..b94d22c 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/SchemaTest.java
@@ -29,6 +29,11 @@
public List<String> disallowedKeyspaces() {
return disallowedKeyspaces;
}
+
+ @Override
+ public double partitionSamplingProbability() {
+ return 1;
+ }
}
@Test