Support registering preJob and postJob hooks (#6)
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java b/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
index 26a74e6..3047c97 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/DiffJob.java
@@ -22,6 +22,8 @@
import java.io.Serializable;
import java.math.BigInteger;
import java.util.*;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
@@ -62,6 +64,19 @@
spark.stop();
}
+ // optional code block to run before a job starts
+ private Runnable preJobHook;
+ // optional code block to run after a job completes successfully; otherwise, it is not executed.
+ private Consumer<Map<String, RangeStats>> postJobHook;
+
+ public void addPreJobHook(Runnable preJobHook) {
+ this.preJobHook = preJobHook;
+ }
+
+ public void addPostJobHook(Consumer<Map<String, RangeStats>> postJobHook) {
+ this.postJobHook = postJobHook;
+ }
+
public void run(JobConfiguration configuration, JavaSparkContext sc) {
SparkConf conf = sc.getConf();
// get partitioner from both clusters and verify that they match
@@ -124,6 +139,9 @@
sourceProvider,
targetProvider);
+ if (null != preJobHook)
+ preJobHook.run();
+
// Run the distributed diff and collate results
Map<String, RangeStats> diffStats = sc.parallelize(splits, slices)
.map((split) -> new Differ(configuration,
@@ -140,6 +158,8 @@
// Publish results. This also removes the job from the currently running list
job.finalizeJob(params.jobId, diffStats);
logger.info("FINISHED: {}", diffStats);
+ if (null != postJobHook)
+ postJobHook.accept(diffStats);
} catch (Exception e) {
// If the job errors out, try and mark the job as not running, so it can be restarted.
// If the error was thrown from JobMetadataDb.finalizeJob *after* the job had already