CASSANDRA-19563: Support bulk write via S3 (#53)
This commit adds a configuration (writer) option for picking a transport other than the previously implemented "direct upload to all sidecars" (now known as the "Direct" transport). The newly implemented "S3_COMPAT" transport allows the job to upload the generated SSTables to an S3-compatible storage system and then inform the Cassandra Sidecar that those files are available for download and commit.
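As a rough illustration, selecting the new transport from a Spark job looks like the following (a minimal sketch: the option keys mirror the LocalS3WriteAndReadJob example added in this commit, the values are illustrative, and `com.example.MyStorageTransportExtension` is a hypothetical hook class):

```java
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

final class S3CompatWriteSketch
{
    // Writes a prepared Dataset through the S3_COMPAT transport instead of
    // uploading the generated SSTables directly to the Sidecars.
    static void write(Dataset<Row> df)
    {
        Map<String, String> writeOptions = new HashMap<>();
        writeOptions.put("sidecar_instances", "localhost");
        writeOptions.put("keyspace", "spark_test");
        writeOptions.put("table", "test");
        writeOptions.put("data_transport", "S3_COMPAT"); // the default remains DIRECT
        // job-side hook that supplies S3 credentials and receives status callbacks
        writeOptions.put("data_transport_extension_class", "com.example.MyStorageTransportExtension");
        df.write()
          .format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
          .options(writeOptions)
          .mode("append")
          .save();
    }
}
```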
Additionally, a plug-in system was added to allow communication between custom transport hooks and the job, so a custom hook can provide updated credentials and out-of-band status updates on S3-related issues.
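The credential-refresh half of that hook system can be sketched as follows (illustrative only: a real hook implements the full StorageTransportExtension interface, as the example extensions in this commit do; the one-minute cadence and the placeholder token values are assumptions):

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
import org.apache.cassandra.spark.transports.storage.StorageCredentials;
import org.apache.cassandra.spark.transports.storage.extensions.CredentialChangeListener;

final class CredentialRefreshSketch
{
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    // Invoked from StorageTransportExtension#setCredentialChangeListener: keeps
    // pushing fresh credentials to the job through the registered listener.
    void start(String jobId, CredentialChangeListener listener)
    {
        scheduler.scheduleAtFixedRate(() -> listener.onCredentialsChanged(jobId, fetchLatest()),
                                      1, 1, TimeUnit.MINUTES);
    }

    private StorageCredentialPair fetchLatest()
    {
        // A real hook would call its credential service here; these values are placeholders
        return new StorageCredentialPair(new StorageCredentials("writeAccessKeyId", "writeSecretKey", "writeSessionToken"),
                                         new StorageCredentials("readAccessKeyId", "readSecretKey", "readSessionToken"));
    }
}
```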
Co-Authored-By: Yifan Cai <ycai@apache.org>
Co-Authored-By: Doug Rohrer <drohrer@apple.com>
Co-Authored-By: Francisco Guerrero <frankgh@apache.org>
Co-Authored-By: Saranya Krishnakumar <saranya_k@apple.com>
Patch by Yifan Cai, Doug Rohrer, Francisco Guerrero, Saranya Krishnakumar; Reviewed by Francisco Guerrero for CASSANDRA-19563
diff --git a/.circleci/config.yml b/.circleci/config.yml
index b7603fc..5f7a9b0 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -111,9 +111,10 @@
./gradlew cassandra-analytics-integration-tests:test --tests $TEST_NAME --no-daemon || EXIT_STATUS=$?;
done;
exit $EXIT_STATUS
+ no_output_timeout: 30m
jobs:
- build-dependencies-jdk8:
+ build-deps-jdk8:
docker:
- image: cimg/openjdk:8.0
resource_class: large
@@ -130,7 +131,7 @@
- "*.jar"
- "org/**/*"
- build-dependencies-jdk11:
+ build-deps-jdk11:
docker:
- image: cimg/openjdk:11.0
resource_class: large
@@ -178,6 +179,7 @@
- image: cimg/openjdk:8.0
resource_class: large
steps:
+ - setup_remote_docker
- install_common
- checkout
- attach_workspace:
@@ -229,6 +231,7 @@
- image: cimg/openjdk:8.0
resource_class: large
steps:
+ - setup_remote_docker
- install_common
- checkout
- attach_workspace:
@@ -281,6 +284,7 @@
- image: cimg/openjdk:11.0
resource_class: large
steps:
+ - setup_remote_docker
- install_common
- checkout
- attach_workspace:
@@ -334,6 +338,7 @@
- image: cimg/openjdk:11.0
resource_class: large
steps:
+ - setup_remote_docker
- install_common
- checkout
- attach_workspace:
@@ -359,29 +364,29 @@
version: 2
build-and-test:
jobs:
- - build-dependencies-jdk8
- - build-dependencies-jdk11
+ - build-deps-jdk8
+ - build-deps-jdk11
- spark2-2_11-jdk8:
requires:
- - build-dependencies-jdk8
+ - build-deps-jdk8
- spark2-2_12-jdk8:
requires:
- - build-dependencies-jdk8
+ - build-deps-jdk8
- spark3-2_12-jdk11:
requires:
- - build-dependencies-jdk11
+ - build-deps-jdk11
- spark3-2_13-jdk11:
requires:
- - build-dependencies-jdk11
+ - build-deps-jdk11
- int-spark2-2_11-jdk8:
requires:
- - build-dependencies-jdk8
+ - build-deps-jdk8
- int-spark2-2_12-jdk8:
requires:
- - build-dependencies-jdk8
+ - build-deps-jdk8
- int-spark3-2_12-jdk11:
requires:
- - build-dependencies-jdk11
+ - build-deps-jdk11
- int-spark3-2_13-jdk11:
requires:
- - build-dependencies-jdk11
+ - build-deps-jdk11
diff --git a/CHANGES.txt b/CHANGES.txt
index 741584c..723135d 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
1.0.0
+ * Support bulk write via S3 (CASSANDRA-19563)
* Support UDTs in the Bulk Writer (CASSANDRA-19340)
* Fix bulk reads of multiple tables that potentially have the same data file name (CASSANDRA-19507)
* Fix XXHash32Digest calculated digest value (CASSANDRA-19500)
diff --git a/build.gradle b/build.gradle
index 4b41e23..f7c8d24 100644
--- a/build.gradle
+++ b/build.gradle
@@ -177,14 +177,6 @@
}
archivesBaseName = "${archivesBaseName}_${scalaMajorVersion}"
- if ("${project.rootProject.ext.jdkLabel}" == '1.8') {
- if ("${version}".contains('-SNAPSHOT')) {
- version = "${version}".replace('-SNAPSHOT', '-jdk8-SNAPSHOT')
- } else {
- version = "${version}-jdk8"
- }
- }
-
repositories {
mavenCentral()
mavenLocal {
@@ -253,4 +245,9 @@
shouldRunAfter(tasks.withType(Checkstyle))
shouldRunAfter(tasks.withType(RatTask))
}
+
+ tasks.register('printCodeVersion') {
+ println()
+ println(version)
+ }
}
diff --git a/cassandra-analytics-core-example/README.md b/cassandra-analytics-core-example/README.md
index 0b86c6f..0958ffc 100644
--- a/cassandra-analytics-core-example/README.md
+++ b/cassandra-analytics-core-example/README.md
@@ -72,7 +72,7 @@
which would have cloned and built the sidecar into `./dependencies/sidecar-build`. Use that build to run the sidecar.
```shell
-cd ./dependencies/sidecar-build
+cd ./dependencies/sidecar-build/trunk
```
Configure the `src/main/dist/sidecar.yaml` file for your local environment. You will most likely only need to configure
@@ -114,56 +114,8 @@
You can, of course, choose to run them (and should when working on the Sidecar project itself).
```shell
-user:~$ ./gradlew run -x integrationTest
-> Task :common:compileJava UP-TO-DATE
-> Task :cassandra40:compileJava UP-TO-DATE
-> Task :compileJava UP-TO-DATE
-> Task :processResources UP-TO-DATE
-> Task :classes UP-TO-DATE
-> Task :jar UP-TO-DATE
-> Task :startScripts UP-TO-DATE
-> Task :cassandra40:processResources NO-SOURCE
-> Task :cassandra40:classes UP-TO-DATE
-> Task :cassandra40:jar UP-TO-DATE
-> Task :common:processResources NO-SOURCE
-> Task :common:classes UP-TO-DATE
-> Task :common:jar UP-TO-DATE
-> Task :distTar
-> Task :distZip
-> Task :assemble
-> Task :common:compileTestFixturesJava UP-TO-DATE
-> Task :compileTestFixturesJava UP-TO-DATE
-> Task :compileTestJava UP-TO-DATE
-> Task :processTestResources UP-TO-DATE
-> Task :testClasses UP-TO-DATE
-> Task :compileIntegrationTestJava UP-TO-DATE
-> Task :processIntegrationTestResources NO-SOURCE
-> Task :integrationTestClasses UP-TO-DATE
-> Task :checkstyleIntegrationTest UP-TO-DATE
-> Task :checkstyleMain UP-TO-DATE
-> Task :checkstyleTest UP-TO-DATE
-> Task :processTestFixturesResources NO-SOURCE
-> Task :testFixturesClasses UP-TO-DATE
-> Task :checkstyleTestFixtures UP-TO-DATE
-> Task :testFixturesJar UP-TO-DATE
-> Task :common:processTestFixturesResources NO-SOURCE
-> Task :common:testFixturesClasses UP-TO-DATE
-> Task :common:testFixturesJar UP-TO-DATE
-> Task :test UP-TO-DATE
-> Task :jacocoTestReport UP-TO-DATE
-> Task :spotbugsIntegrationTest UP-TO-DATE
-> Task :spotbugsMain UP-TO-DATE
-> Task :spotbugsTest UP-TO-DATE
-> Task :spotbugsTestFixtures UP-TO-DATE
-> Task :check
-> Task :copyJolokia UP-TO-DATE
-> Task :installDist
-> Task :copyDist
-> Task :docs:asciidoctor UP-TO-DATE
-> Task :copyDocs UP-TO-DATE
-> Task :generateReDoc NO-SOURCE
-> Task :generateSwaggerUI NO-SOURCE
-> Task :build
+user:~$ ./gradlew run -x test -x integrationTest -x containerTest
+...
> Task :run
Could not start Jolokia agent: java.net.BindException: Address already in use
@@ -191,7 +143,7 @@
### Step 3: Run the Sample Job
-To be able to run the [Sample Job](./src/main/java/org/apache/cassandra/spark/example/SampleCassandraJob.java), you
+To be able to run the [Sample Job](./src/main/java/org/apache/cassandra/spark/example/DirectWriteAndReadJob.java), you
need to create the keyspace and table used for the test.
Connect to your local Cassandra cluster using CCM:
@@ -218,11 +170,46 @@
);
```
+#### Start DirectWriteAndReadJob
+
Finally, we are ready to run the example spark job:
```shell
cd ${ANALYTICS_REPOSITORY_HOME}
./gradlew :cassandra-analytics-core-example:run
+# or this command
+# ./gradlew :cassandra-analytics-core-example:run --args='DirectWriteAndReadJob'
+```
+
+#### Start LocalS3WriteAndReadJob
+
+Alternatively, we can run the [LocalS3WriteAndReadJob](./src/main/java/org/apache/cassandra/spark/example/LocalS3WriteAndReadJob.java), which bulk writes
+data via S3. To run this job, there are two additional prerequisite steps.
+
+First, start S3Mock:
+
+```shell
+docker run -p 127.0.0.1:9090:9090 -p 127.0.0.1:9191:9191 -t adobe/s3mock:2.17.0
+```
+
+Then, restart the Sidecar with the following edits to `sidecar.yaml`.
+The Sidecar schema must be enabled, and the S3 client must be pointed to S3Mock via `endpoint_override`.
+
+```yaml
+sidecar:
+ schema:
+ is_enabled: true
+...
+s3_client:
+ proxy_config:
+ endpoint_override: localhost:9090
+```
+
+Then, we can run the example spark job:
+
+```shell
+cd ${ANALYTICS_REPOSITORY_HOME}
+./gradlew :cassandra-analytics-core-example:run --args='LocalS3WriteAndReadJob'
```
## Tear down
diff --git a/cassandra-analytics-core-example/build.gradle b/cassandra-analytics-core-example/build.gradle
index b0bcfd3..b8ab3d4 100644
--- a/cassandra-analytics-core-example/build.gradle
+++ b/cassandra-analytics-core-example/build.gradle
@@ -34,7 +34,12 @@
}
application {
- mainClass = 'org.apache.cassandra.spark.example.SampleCassandraJob'
+ // Optionally start with a different main class, rather than the default 'DirectWriteAndReadJob'
+ // For example,
+ // ./gradlew :cassandra-analytics-core-example:run --args='DirectWriteAndReadJob'
+ // or
+ // ./gradlew :cassandra-analytics-core-example:run --args='LocalS3WriteAndReadJob'
+ mainClass = 'org.apache.cassandra.spark.example.JobSelector'
applicationDefaultJvmArgs = ["-Dfile.encoding=UTF-8",
"-Djdk.attach.allowAttachSelf=true",
"--add-exports", "java.base/jdk.internal.misc=ALL-UNNAMED",
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/AbstractCassandraJob.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/AbstractCassandraJob.java
new file mode 100644
index 0000000..f63ff4e
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/AbstractCassandraJob.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+
+import com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.KryoRegister;
+import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.sql.DataFrameReader;
+import org.apache.spark.sql.DataFrameWriter;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+
+/**
+ * Example showcasing the Cassandra Spark Analytics write and read capabilities
+ *
+ * <p>Prepare your environment by creating the following keyspace and table
+ *
+ * <p>Schema for the {@code keyspace}:
+ * <pre>
+ * CREATE KEYSPACE spark_test WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': '3'}
+ * AND durable_writes = true;
+ * </pre>
+ *
+ * <p>Schema for the {@code table}:
+ * <pre>
+ * CREATE TABLE spark_test.test (
+ * id BIGINT PRIMARY KEY,
+ * course BLOB,
+ * marks BIGINT
+ * );
+ * </pre>
+ */
+public abstract class AbstractCassandraJob
+{
+ private final Logger logger = LoggerFactory.getLogger(this.getClass());
+
+ private JobConfiguration configuration;
+
+ protected abstract JobConfiguration configureJob(SparkContext sc, SparkConf sparkConf);
+
+ public void start(String[] args)
+ {
+ logger.info("Starting Spark job with args={}", Arrays.toString(args));
+
+ SparkConf sparkConf = new SparkConf().setAppName("Sample Spark Cassandra Bulk Reader Job")
+ .set("spark.master", "local[8]");
+
+ // Add SBW-specific settings
+ // TODO: simplify setting up spark conf
+ BulkSparkConf.setupSparkConf(sparkConf, true);
+ KryoRegister.setup(sparkConf);
+
+ SparkSession spark = SparkSession
+ .builder()
+ .config(sparkConf)
+ .getOrCreate();
+ SparkContext sc = spark.sparkContext();
+ SQLContext sql = spark.sqlContext();
+ logger.info("Spark Conf: " + sparkConf.toDebugString());
+
+ configuration = configureJob(sc, sparkConf);
+ long rowCount = configuration.rowCount;
+
+ try
+ {
+ Dataset<Row> written = null;
+ if (configuration.shouldWrite())
+ {
+ written = write(rowCount, sql, sc);
+ }
+ Dataset<Row> read = null;
+ if (configuration.shouldRead())
+ {
+ read = read(rowCount, sql);
+ }
+
+ if (configuration.shouldWrite() && configuration.shouldRead())
+ {
+ checkSmallDataFrameEquality(written, read);
+ }
+ logger.info("Finished Spark job, shutting down...");
+ sc.stop();
+ }
+ catch (Throwable throwable)
+ {
+ logger.error("Unexpected exception executing Spark job", throwable);
+ try
+ {
+ sc.stop();
+ }
+ catch (Throwable ignored)
+ {
+ }
+ }
+ }
+
+ protected Dataset<Row> write(long rowCount, SQLContext sql, SparkContext sc)
+ {
+ StructType schema = new StructType()
+ .add("id", LongType, false)
+ .add("course", BinaryType, false)
+ .add("marks", LongType, false);
+
+ JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sc);
+ int parallelism = sc.defaultParallelism();
+ JavaRDD<Row> rows = genDataset(javaSparkContext, rowCount, parallelism);
+ Dataset<Row> df = sql.createDataFrame(rows, schema);
+
+ DataFrameWriter<Row> writer = df.write().format("org.apache.cassandra.spark.sparksql.CassandraDataSink");
+ writer.options(configuration.writeOptions);
+ writer.mode("append").save();
+ return df;
+ }
+
+ protected Dataset<Row> read(long expectedRowCount, SQLContext sql)
+ {
+ DataFrameReader reader = sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource");
+ reader.options(configuration.readOptions);
+ Dataset<Row> df = reader.load();
+ long count = df.count();
+ logger.info("Found {} records", count);
+
+ if (count != expectedRowCount)
+ {
+ logger.error("Expected {} records but found {} records", expectedRowCount, count);
+ return null;
+ }
+ return df;
+ }
+
+ private static void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual)
+ {
+ if (actual == null)
+ {
+ throw new NullPointerException("actual dataframe is null");
+ }
+ if (!actual.exceptAll(expected).isEmpty())
+ {
+ throw new IllegalStateException("The content of the dataframes differs");
+ }
+ }
+
+ private static JavaRDD<Row> genDataset(JavaSparkContext sc, long records, Integer parallelism)
+ {
+ long recordsPerPartition = records / parallelism;
+ long remainder = records - (recordsPerPartition * parallelism);
+ List<Integer> seq = IntStream.range(0, parallelism).boxed().collect(Collectors.toList());
+ return sc.parallelize(seq, parallelism).mapPartitionsWithIndex(
+ (Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, integerIterator) -> {
+ long firstRecordNumber = index * recordsPerPartition;
+ long recordsToGenerate = index == parallelism - 1 ? recordsPerPartition + remainder : recordsPerPartition; // last partition picks up the remainder
+ return LongStream.range(0, recordsToGenerate).mapToObj((offset) -> {
+ long i = firstRecordNumber + offset;
+ String courseNameString = String.valueOf(i);
+ int courseNameStringLen = courseNameString.length();
+ int courseNameMultiplier = 1000 / courseNameStringLen;
+ byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier);
+ return RowFactory.create(i, courseName, i);
+ }).iterator();
+ }, false);
+ }
+
+ private static byte[] dupStringAsBytes(String string, Integer times)
+ {
+ byte[] stringBytes = string.getBytes();
+ ByteBuffer buf = ByteBuffer.allocate(stringBytes.length * times);
+ for (int i = 0; i < times; i++)
+ {
+ buf.put(stringBytes);
+ }
+ return buf.array();
+ }
+
+ static class JobConfiguration
+ {
+ long rowCount = 100_000; // deliberately mutable for testing convenience
+ ImmutableMap<String, String> writeOptions;
+ ImmutableMap<String, String> readOptions;
+
+ JobConfiguration(Map<String, String> writeOptions, Map<String, String> readOptions)
+ {
+ this.writeOptions = ImmutableMap.copyOf(writeOptions);
+ this.readOptions = ImmutableMap.copyOf(readOptions);
+ }
+
+ boolean shouldWrite()
+ {
+ return !writeOptions.isEmpty();
+ }
+
+ boolean shouldRead()
+ {
+ return !readOptions.isEmpty();
+ }
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/DirectWriteAndReadJob.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/DirectWriteAndReadJob.java
new file mode 100644
index 0000000..c94b195
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/DirectWriteAndReadJob.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+/**
+ * A sample Cassandra Spark job that writes directly to Cassandra via Sidecar,
+ * then reads from Cassandra
+ */
+public class DirectWriteAndReadJob extends AbstractCassandraJob
+{
+ public static void main(String[] args)
+ {
+ System.setProperty("SKIP_STARTUP_VALIDATIONS", "true");
+ new DirectWriteAndReadJob().start(args);
+ }
+
+ protected JobConfiguration configureJob(SparkContext sc, SparkConf sparkConf)
+ {
+ Map<String, String> writeOptions = new HashMap<>();
+ writeOptions.put("sidecar_instances", "localhost,localhost2,localhost3");
+ writeOptions.put("keyspace", "spark_test");
+ writeOptions.put("table", "test");
+ writeOptions.put("local_dc", "datacenter1");
+ writeOptions.put("bulk_writer_cl", "LOCAL_QUORUM");
+ writeOptions.put("number_splits", "-1");
+ writeOptions.put("data_transport", "DIRECT");
+
+ int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
+ int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors",
+ sparkConf.getInt("spark.executor.instances", 1));
+ int numCores = coresPerExecutor * numExecutors;
+ Map<String, String> readerOptions = new HashMap<>();
+ readerOptions.put("sidecar_instances", "localhost,localhost2,localhost3");
+ readerOptions.put("keyspace", "spark_test");
+ readerOptions.put("table", "test");
+ readerOptions.put("DC", "datacenter1");
+ readerOptions.put("snapshotName", UUID.randomUUID().toString());
+ readerOptions.put("createSnapshot", "true");
+ readerOptions.put("defaultParallelism", String.valueOf(sc.defaultParallelism()));
+ readerOptions.put("numCores", String.valueOf(numCores));
+ readerOptions.put("sizing", "default");
+ return new JobConfiguration(writeOptions, readerOptions);
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java
new file mode 100644
index 0000000..a4524a7
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/ExampleStorageTransportExtension.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.apache.cassandra.spark.transports.storage.extensions.ObjectFailureListener;
+import org.apache.cassandra.spark.transports.storage.extensions.CredentialChangeListener;
+import org.apache.spark.SparkConf;
+
+public class ExampleStorageTransportExtension implements StorageTransportExtension
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(ExampleStorageTransportExtension.class);
+
+ private SparkConf conf;
+ private ScheduledExecutorService scheduledExecutorService =
+ Executors.newSingleThreadScheduledExecutor(ThreadUtil.threadFactory("ExampleBlobStorageOperations"));
+ private String jobId;
+ private long tokenCount = 0;
+ private CredentialChangeListener credentialChangeListener;
+ private ObjectFailureListener objectFailureListener;
+ private boolean shouldFail;
+
+ @Override
+ public void initialize(String jobId, SparkConf conf, boolean isOnDriver)
+ {
+ this.jobId = jobId;
+ this.conf = conf;
+ this.shouldFail = conf.getBoolean("blob_operations_should_fail", false);
+ }
+
+ @Override
+ public StorageTransportConfiguration getStorageConfiguration()
+ {
+ ImmutableMap<String, String> additionalTags = ImmutableMap.of("additional-key", "additional-value");
+ return new StorageTransportConfiguration("writebucket-name",
+ "us-west-2",
+ "readbucket-name",
+ "eu-west-1",
+ "some-prefix-for-each-job",
+ generateTokens(this.tokenCount++),
+ additionalTags);
+ }
+
+ @Override
+ public void onTransportStart(long elapsedMillis)
+ {
+
+ }
+
+ @Override
+ public void setCredentialChangeListener(CredentialChangeListener credentialChangeListener)
+ {
+ LOGGER.info("Token listener registered for job {}", jobId);
+ this.credentialChangeListener = credentialChangeListener;
+ startFakeTokenRefresh();
+ }
+
+ @Override
+ public void setObjectFailureListener(ObjectFailureListener objectFailureListener)
+ {
+ this.objectFailureListener = objectFailureListener;
+ if (this.shouldFail)
+ {
+ scheduledExecutorService.schedule(this::fail, 1, TimeUnit.SECONDS);
+ }
+ }
+
+ private void fail()
+ {
+ this.objectFailureListener.onObjectFailed(this.jobId, "failed_bucket", "failed_key", "Fake failure");
+ }
+
+ @Override
+ public void onObjectPersisted(String bucket, String key, long sizeInBytes)
+ {
+ LOGGER.info("Object {}/{} for job {} persisted with size {} bytes", bucket, key, jobId, sizeInBytes);
+ }
+
+ @Override
+ public void onAllObjectsPersisted(long objectsCount, long rowCount, long elapsedMillis)
+ {
+ LOGGER.info("All {} objects, totaling {} rows, are persisted with elapsed time {}ms",
+ objectsCount, rowCount, elapsedMillis);
+ }
+
+ @Override
+ public void onObjectApplied(String bucket, String key, long sizeInBytes, long elapsedMillis)
+ {
+
+ }
+
+ @Override
+ public void onJobSucceeded(long elapsedMillis)
+ {
+ LOGGER.info("Job {} succeeded with elapsed time {}ms", jobId, elapsedMillis);
+ }
+
+ @Override
+ public void onJobFailed(long elapsedMillis, Throwable throwable)
+ {
+ LOGGER.error("Job {} failed after {}ms", jobId, elapsedMillis, throwable);
+ }
+
+ private void startFakeTokenRefresh()
+ {
+ scheduledExecutorService.scheduleAtFixedRate(this::refreshTokens, 1, 1, TimeUnit.SECONDS);
+
+ }
+
+ private void refreshTokens()
+ {
+ this.credentialChangeListener.onCredentialsChanged(this.jobId, generateTokens(this.tokenCount++));
+ }
+
+ private StorageCredentialPair generateTokens(long tokenCount)
+ {
+ return new StorageCredentialPair(new StorageCredentials("writeAccessKeyId-" + tokenCount,
+ "writeSecretKey-" + tokenCount,
+ "writeSessionToken-" + tokenCount),
+ new StorageCredentials(
+ "readAccessKeyId-" + tokenCount,
+ "readSecretKey-" + tokenCount,
+ "readSessionToken-" + tokenCount));
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/JobSelector.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/JobSelector.java
new file mode 100644
index 0000000..f240a28
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/JobSelector.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+public final class JobSelector
+{
+ private JobSelector()
+ {
+ throw new IllegalStateException(getClass() + " is a static utility class and shall not be instantiated");
+ }
+
+ public static void main(String[] args)
+ {
+ String jobClassName = DirectWriteAndReadJob.class.getSimpleName();
+ if (args.length == 0)
+ {
+ System.out.println("No job name supplied. Falling back to " + jobClassName);
+ }
+ else
+ {
+ jobClassName = args[0];
+ }
+
+ if (jobClassName.equalsIgnoreCase(DirectWriteAndReadJob.class.getSimpleName()))
+ {
+ DirectWriteAndReadJob.main(args);
+ }
+ else if (jobClassName.equalsIgnoreCase(LocalS3WriteAndReadJob.class.getSimpleName()))
+ {
+ // shift by 1
+ String[] newArgs = new String[args.length - 1];
+ System.arraycopy(args, 1, newArgs, 0, newArgs.length);
+ LocalS3WriteAndReadJob.main(newArgs);
+ }
+ else
+ {
+ System.err.println("Unknown job class name supplied. ClassName: " + jobClassName);
+ }
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalS3WriteAndReadJob.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalS3WriteAndReadJob.java
new file mode 100644
index 0000000..749a051
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalS3WriteAndReadJob.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+import static org.apache.cassandra.spark.example.LocalStorageTransportExtension.BUCKET_NAME;
+
+/**
+ * A sample Cassandra Spark job that writes to a local S3 (S3Mock) first and imports into Cassandra via Sidecar
+ */
+public class LocalS3WriteAndReadJob extends AbstractCassandraJob
+{
+ private String dataCenter = "datacenter1";
+ private String sidecarInstances = "localhost,localhost2,localhost3";
+
+ LocalS3WriteAndReadJob(String[] args)
+ {
+ if (args.length == 2)
+ {
+ dataCenter = args[0];
+ sidecarInstances = args[1];
+ }
+ }
+
+ public static void main(String[] args)
+ {
+ System.setProperty("SKIP_STARTUP_VALIDATIONS", "true");
+ // Expects S3Mock to be running locally on port 9090
+ ProcessBuilder pb = new ProcessBuilder();
+ pb.command("curl", "-X", "PUT", "localhost:9090/" + BUCKET_NAME);
+ try
+ {
+ pb.start().waitFor();
+ }
+ catch (Exception e)
+ {
+ // ignore when the bucket is already created
+ }
+
+ new LocalS3WriteAndReadJob(args).start(args);
+ }
+
+ protected JobConfiguration configureJob(SparkContext sc, SparkConf sparkConf)
+ {
+ Map<String, String> writeOptions = new HashMap<>();
+ writeOptions.put("sidecar_instances", sidecarInstances);
+ writeOptions.put("keyspace", "spark_test");
+ writeOptions.put("table", "test");
+ writeOptions.put("local_dc", dataCenter);
+ writeOptions.put("bulk_writer_cl", "LOCAL_QUORUM");
+ writeOptions.put("number_splits", "-1");
+ // ---- Below write options are for S3_COMPAT impl only ----
+ // Set the data transport mode to "S3_COMPAT" to use an AWS S3-compatible
+ // storage service to move data from Spark to Sidecar
+ writeOptions.put("data_transport", "S3_COMPAT");
+ writeOptions.put("data_transport_extension_class", LocalStorageTransportExtension.class.getCanonicalName());
+
+ // Only needed to talk to the local S3Mock server. Do not set this option in other scenarios.
+ writeOptions.put("storage_client_endpoint_override", "http://localhost:9090");
+ // 5 MiB for testing. The default is 100 MiB. Controls the chunk size for multipart uploads.
+ writeOptions.put("storage_client_max_chunk_size_in_bytes", "5242880");
+ // 10 MiB for testing. The default is 5 GiB. Controls the maximum size of each SSTable bundle object on S3.
+ writeOptions.put("max_size_per_sstable_bundle_in_bytes_s3_transport", "10485760");
+ writeOptions.put("max_job_duration_minutes", "10");
+ writeOptions.put("job_id", "a_unique_id_made_of_arbitrary_string");
+
+ int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
+ int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors",
+ sparkConf.getInt("spark.executor.instances", 1));
+ int numCores = coresPerExecutor * numExecutors;
+ Map<String, String> readerOptions = new HashMap<>();
+ readerOptions.put("sidecar_instances", "localhost,localhost2,localhost3");
+ readerOptions.put("keyspace", "spark_test");
+ readerOptions.put("table", "test");
+ readerOptions.put("DC", "datacenter1");
+ readerOptions.put("snapshotName", UUID.randomUUID().toString());
+ readerOptions.put("createSnapshot", "true");
+ readerOptions.put("defaultParallelism", String.valueOf(sc.defaultParallelism()));
+ readerOptions.put("numCores", String.valueOf(numCores));
+ readerOptions.put("sizing", "default");
+
+ JobConfiguration config = new JobConfiguration(writeOptions, readerOptions);
+ config.rowCount = 2_000_000L;
+ return config;
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalStorageTransportExtension.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalStorageTransportExtension.java
new file mode 100644
index 0000000..95d5a96
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/LocalStorageTransportExtension.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.example;
+
+import com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.CredentialChangeListener;
+import org.apache.cassandra.spark.transports.storage.extensions.ObjectFailureListener;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.apache.spark.SparkConf;
+
+public class LocalStorageTransportExtension implements StorageTransportExtension
+{
+ public static final String BUCKET_NAME = "sbw-bucket";
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(LocalStorageTransportExtension.class);
+
+ private String jobId;
+
+ @Override
+ public void initialize(String jobId, SparkConf conf, boolean isOnDriver)
+ {
+ this.jobId = jobId;
+ }
+
+ @Override
+ public StorageTransportConfiguration getStorageConfiguration()
+ {
+ ImmutableMap<String, String> additionalTags = ImmutableMap.of("additional-key", "additional-value");
+ return new StorageTransportConfiguration(BUCKET_NAME,
+ "us-west-1",
+ BUCKET_NAME,
+ "eu-west-1",
+ "key-prefix",
+ generateTokens(),
+ additionalTags);
+ }
+
+ @Override
+ public void onTransportStart(long elapsedMillis)
+ {
+
+ }
+
+ @Override
+ public void setCredentialChangeListener(CredentialChangeListener credentialChangeListener)
+ {
+ }
+
+ @Override
+ public void setObjectFailureListener(ObjectFailureListener objectFailureListener)
+ {
+ }
+
+ @Override
+ public void onObjectPersisted(String bucket, String key, long sizeInBytes)
+ {
+ LOGGER.info("Object {}/{} for job {} persisted with size {} bytes", bucket, key, jobId, sizeInBytes);
+ }
+
+ @Override
+ public void onAllObjectsPersisted(long objectsCount, long rowCount, long elapsedMillis)
+ {
+ LOGGER.info("All {} objects, totaling {} rows, are persisted with elapsed time {}ms",
+ objectsCount, rowCount, elapsedMillis);
+ }
+
+ @Override
+ public void onObjectApplied(String bucket, String key, long sizeInBytes, long elapsedMillis)
+ {
+
+ }
+
+ @Override
+ public void onJobSucceeded(long elapsedMillis)
+ {
+ LOGGER.info("Job {} succeeded with elapsed time {}ms", jobId, elapsedMillis);
+ }
+
+ @Override
+ public void onJobFailed(long elapsedMillis, Throwable throwable)
+ {
+ LOGGER.error("Job {} failed after {}ms", jobId, elapsedMillis, throwable);
+ }
+
+ private StorageCredentialPair generateTokens()
+ {
+ return new StorageCredentialPair(new StorageCredentials("writeKey",
+ "writeSecret",
+ "writeSessionToken"),
+ new StorageCredentials("readKey",
+ "readSecret",
+ "readSessionToken"));
+ }
+}
diff --git a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/SampleCassandraJob.java b/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/SampleCassandraJob.java
deleted file mode 100644
index b71c57b..0000000
--- a/cassandra-analytics-core-example/src/main/java/org/apache/cassandra/spark/example/SampleCassandraJob.java
+++ /dev/null
@@ -1,281 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.cassandra.spark.example;
-
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.LongStream;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.cassandra.spark.KryoRegister;
-import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
-import org.apache.cassandra.spark.bulkwriter.TTLOption;
-import org.apache.cassandra.spark.bulkwriter.TimestampOption;
-import org.apache.cassandra.spark.bulkwriter.WriterOptions;
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.sql.DataFrameWriter;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.StructType;
-
-import static org.apache.spark.sql.types.DataTypes.BinaryType;
-import static org.apache.spark.sql.types.DataTypes.IntegerType;
-import static org.apache.spark.sql.types.DataTypes.LongType;
-
-/**
- * Example showcasing the Cassandra Spark Analytics write and read capabilities
- * <p>
- * Prepare your environment by creating the following keyspace and table
- * <p>
- * Schema for the {@code keyspace}:
- * <pre>
- * CREATE KEYSPACE spark_test WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': '3'}
- * AND durable_writes = true;
- * </pre>
- * <p>
- * Schema for the {@code table}:
- * <pre>
- * CREATE TABLE spark_test.test (
- * id BIGINT PRIMARY KEY,
- * course BLOB,
- * marks BIGINT
- * );
- * </pre>
- */
-public final class SampleCassandraJob
-{
- private static final Logger LOGGER = LoggerFactory.getLogger(SampleCassandraJob.class);
-
- private SampleCassandraJob()
- {
- throw new IllegalStateException(getClass() + " is static utility class and shall not be instantiated");
- }
-
- public static void main(String[] args)
- {
- LOGGER.info("Starting Spark job with args={}", Arrays.toString(args));
-
- SparkConf sparkConf = new SparkConf().setAppName("Sample Spark Cassandra Bulk Reader Job")
- .set("spark.master", "local[8]");
-
- // Add SBW-specific settings
- // TODO: Simplify setting up spark conf
- BulkSparkConf.setupSparkConf(sparkConf, true);
- KryoRegister.setup(sparkConf);
-
- SparkSession spark = SparkSession
- .builder()
- .config(sparkConf)
- .getOrCreate();
- SparkContext sc = spark.sparkContext();
- SQLContext sql = spark.sqlContext();
- LOGGER.info("Spark Conf: " + sparkConf.toDebugString());
- int rowCount = 10_000;
-
- try
- {
- Dataset<Row> written = write(rowCount, sparkConf, sql, sc);
- Dataset<Row> read = read(rowCount, sparkConf, sql, sc);
-
- checkSmallDataFrameEquality(written, read);
- LOGGER.info("Finished Spark job, shutting down...");
- sc.stop();
- }
- catch (Throwable throwable)
- {
- LOGGER.error("Unexpected exception executing Spark job", throwable);
- try
- {
- sc.stop();
- }
- catch (Throwable ignored)
- {
- }
- }
- }
-
- private static Dataset<Row> write(long rowCount, SparkConf sparkConf, SQLContext sql, SparkContext sc)
- {
- JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sc);
- int parallelism = sc.defaultParallelism();
- boolean addTTLColumn = false;
- boolean addTimestampColumn = false;
- JavaRDD<Row> rows = genDataset(javaSparkContext, rowCount, parallelism, addTTLColumn, addTimestampColumn);
- Dataset<Row> df = sql.createDataFrame(rows, getWriteSchema(addTTLColumn, addTimestampColumn));
-
- DataFrameWriter<Row> dfWriter = df.write()
- .format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
- .option("sidecar_instances", "localhost,localhost2,localhost3")
- .option("keyspace", "spark_test")
- .option("table", "test")
- .option("local_dc", "datacenter1")
- .option("bulk_writer_cl", "LOCAL_QUORUM")
- .option("number_splits", "-1")
- // A constant timestamp and TTL can be used by setting the following options.
- // .option(WriterOptions.TTL.name(), TTLOption.constant(20))
- // .option(WriterOptions.TIMESTAMP.name(), TimestampOption.constant(System.currentTimeMillis() * 1000))
- .mode("append");
-
- List<String> addedColumns = new ArrayList<>();
- if (addTTLColumn)
- {
- addedColumns.add("ttl");
- dfWriter = dfWriter
- .option(WriterOptions.TTL.name(), TTLOption.perRow("ttl"));
- }
-
- if (addTimestampColumn)
- {
- addedColumns.add("timestamp");
- dfWriter = dfWriter
- .option(WriterOptions.TIMESTAMP.name(), TimestampOption.perRow("timestamp"));
- }
-
- dfWriter.save();
-
- if (!addedColumns.isEmpty())
- {
- df = df.drop(addedColumns.toArray(new String[0]));
- }
-
- return df;
- }
-
- private static Dataset<Row> read(int expectedRowCount, SparkConf sparkConf, SQLContext sql, SparkContext sc)
- {
- int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
- int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors", sparkConf.getInt("spark.executor.instances", 1));
- int numCores = coresPerExecutor * numExecutors;
-
- Dataset<Row> df = sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource")
- .option("sidecar_instances", "localhost,localhost2,localhost3")
- .option("keyspace", "spark_test")
- .option("table", "test")
- .option("DC", "datacenter1")
- .option("snapshotName", UUID.randomUUID().toString())
- .option("createSnapshot", "true")
- .option("defaultParallelism", sc.defaultParallelism())
- .option("numCores", numCores)
- .option("sizing", "default")
- .load();
-
- long count = df.count();
- LOGGER.info("Found {} records", count);
-
- if (count != expectedRowCount)
- {
- LOGGER.error("Expected {} records but found {} records", expectedRowCount, count);
- return null;
- }
- return df;
- }
-
- private static StructType getWriteSchema(boolean addTTLColumn, boolean addTimestampColumn)
- {
- StructType schema = new StructType()
- .add("id", LongType, false)
- .add("course", BinaryType, false)
- .add("marks", LongType, false);
- if (addTTLColumn)
- {
- schema = schema.add("ttl", IntegerType, false);
- }
- if (addTimestampColumn)
- {
- schema = schema.add("timestamp", LongType, false);
- }
- return schema;
- }
-
- private static void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual)
- {
- if (actual == null)
- {
- throw new NullPointerException("actual dataframe is null");
- }
- if (!actual.exceptAll(expected).isEmpty())
- {
- throw new IllegalStateException("The content of the dataframes differs");
- }
- }
-
- private static JavaRDD<Row> genDataset(JavaSparkContext sc, long records, Integer parallelism,
- boolean addTTLColumn, boolean addTimestampColumn)
- {
- long recordsPerPartition = records / parallelism;
- long remainder = records - (recordsPerPartition * parallelism);
- List<Integer> seq = IntStream.range(0, parallelism).boxed().collect(Collectors.toList());
- int ttl = 120; // data will not be queryable in two minutes
- long timeStamp = System.currentTimeMillis() * 1000;
- JavaRDD<Row> dataset = sc.parallelize(seq, parallelism).mapPartitionsWithIndex(
- (Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, integerIterator) -> {
- long firstRecordNumber = index * recordsPerPartition;
- long recordsToGenerate = index.equals(parallelism) ? remainder : recordsPerPartition;
- java.util.Iterator<Row> rows = LongStream.range(0, recordsToGenerate).mapToObj(offset -> {
- long recordNumber = firstRecordNumber + offset;
- String courseNameString = String.valueOf(recordNumber);
- Integer courseNameStringLen = courseNameString.length();
- Integer courseNameMultiplier = 1000 / courseNameStringLen;
- byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier);
- if (addTTLColumn && addTimestampColumn)
- {
- return RowFactory.create(recordNumber, courseName, recordNumber, ttl, timeStamp);
- }
- if (addTTLColumn)
- {
- return RowFactory.create(recordNumber, courseName, recordNumber, ttl);
- }
- if (addTimestampColumn)
- {
- return RowFactory.create(recordNumber, courseName, recordNumber, timeStamp);
- }
- return RowFactory.create(recordNumber, courseName, recordNumber);
- }).iterator();
- return rows;
- }, false);
- return dataset;
- }
-
- private static byte[] dupStringAsBytes(String string, Integer times)
- {
- byte[] stringBytes = string.getBytes();
- ByteBuffer buffer = ByteBuffer.allocate(stringBytes.length * times);
- for (int time = 0; time < times; time++)
- {
- buffer.put(stringBytes);
- }
- return buffer.array();
- }
-}
diff --git a/cassandra-analytics-core-example/src/main/resources/keystore-private.p12 b/cassandra-analytics-core-example/src/main/resources/keystore-private.p12
new file mode 100644
index 0000000..f6e226d
--- /dev/null
+++ b/cassandra-analytics-core-example/src/main/resources/keystore-private.p12
Binary files differ
diff --git a/cassandra-analytics-core/build.gradle b/cassandra-analytics-core/build.gradle
index 0b2fcc9..b5ea9a2 100644
--- a/cassandra-analytics-core/build.gradle
+++ b/cassandra-analytics-core/build.gradle
@@ -68,12 +68,17 @@
// This dependency must be built by running `scripts/build-dependencies.sh`
api(group: "${sidecarClientGroup}", name: "${sidecarClientName}", version: "${sidecarVersion}")
- implementation(group: 'org.lz4', name: 'lz4-java', version: '1.8.0')
+ implementation(group: 'org.lz4', name: 'lz4-java', version: '1.8.0') // for xxhash
if ("${scalaMajorVersion}" == "2.11") {
implementation(group: 'org.scala-lang.modules', name: "scala-java8-compat_2.11", version: '1.0.1', transitive: false)
}
+ // AWS SDK BOM + S3
+ implementation 'software.amazon.awssdk:s3'
+ implementation 'software.amazon.awssdk:netty-nio-client'
+ implementation platform(group: 'software.amazon.awssdk', name:'bom', version:"${project.aswSdkVersion}")
+
compileOnly(group: "${sparkGroupId}", name: "spark-core_${scalaMajorVersion}", version: "${project.rootProject.sparkVersion}")
compileOnly(group: "${sparkGroupId}", name: "spark-sql_${scalaMajorVersion}", version: "${project.rootProject.sparkVersion}")
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/Sidecar.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/Sidecar.java
index 53f092e..4f73640 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/Sidecar.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/Sidecar.java
@@ -41,12 +41,15 @@
import org.apache.cassandra.sidecar.client.SidecarClientConfig;
import org.apache.cassandra.sidecar.client.SidecarClientConfigImpl;
import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.sidecar.client.SidecarInstanceImpl;
import org.apache.cassandra.sidecar.client.SidecarInstancesProvider;
import org.apache.cassandra.sidecar.client.VertxHttpClient;
import org.apache.cassandra.sidecar.client.VertxRequestExecutor;
import org.apache.cassandra.sidecar.client.retry.ExponentialBackoffRetryPolicy;
import org.apache.cassandra.sidecar.client.retry.RetryPolicy;
import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
+import org.apache.cassandra.spark.bulkwriter.DataTransport;
+import org.apache.cassandra.spark.common.model.CassandraInstance;
import org.apache.cassandra.spark.data.FileType;
import org.apache.cassandra.spark.utils.BuildInfo;
import org.apache.cassandra.spark.utils.MapUtils;
@@ -119,15 +122,28 @@
return buildClient(sidecarConfig, vertx, httpClientConfig, sidecarInstancesProvider);
}
+ static String transportModeBasedWriterUserAgent(DataTransport transport)
+ {
+ switch (transport)
+ {
+ case S3_COMPAT:
+ return BuildInfo.WRITER_S3_USER_AGENT;
+ case DIRECT:
+ default:
+ return BuildInfo.WRITER_USER_AGENT;
+ }
+ }
+
public static SidecarClient from(SidecarInstancesProvider sidecarInstancesProvider, BulkSparkConf conf)
{
Vertx vertx = Vertx.vertx(new VertxOptions().setUseDaemonThread(true)
.setWorkerPoolSize(conf.getMaxHttpConnections()));
+ String userAgent = transportModeBasedWriterUserAgent(conf.getTransportInfo().getTransport());
HttpClientConfig httpClientConfig = new HttpClientConfig.Builder<>()
.timeoutMillis(conf.getHttpResponseTimeoutMs())
.idleTimeoutMillis(conf.getHttpConnectionTimeoutMs())
- .userAgent(BuildInfo.WRITER_USER_AGENT)
+ .userAgent(userAgent)
.keyStoreInputStream(conf.getKeyStore())
.keyStorePassword(conf.getKeyStorePassword())
.keyStoreType(conf.getKeyStoreTypeOrDefault())
@@ -179,6 +195,11 @@
.collect(Collectors.toList());
}
+ public static SidecarInstance toSidecarInstance(CassandraInstance instance, int sidecarPort)
+ {
+ return new SidecarInstanceImpl(instance.nodeName(), sidecarPort);
+ }
+
public static final class ClientConfig
{
public static final String SIDECAR_PORT = "sidecar_port";
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/SidecarInstanceImpl.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/SidecarInstanceImpl.java
deleted file mode 100644
index d73dc2e..0000000
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/clients/SidecarInstanceImpl.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.cassandra.clients;
-
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
-import java.util.Objects;
-
-import com.google.common.base.Preconditions;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
-import org.apache.cassandra.sidecar.client.SidecarInstance;
-
-/**
- * A simple implementation of the {@link SidecarInstance} interface
- */
-public class SidecarInstanceImpl implements Serializable, SidecarInstance
-{
- private static final long serialVersionUID = -8650006905764842232L;
- private static final Logger LOGGER = LoggerFactory.getLogger(SidecarInstanceImpl.class);
-
- private int port;
- private String hostname;
-
- /**
- * Constructs a new Sidecar instance with the given {@code port} and {@code hostname}
- *
- * @param hostname the host name where Sidecar is running
- * @param port the port where Sidecar is running
- */
- public SidecarInstanceImpl(String hostname, int port)
- {
- Preconditions.checkArgument(0 < port && port <= 65535,
- "The Sidecar port number must be in the range 1-65535: %s", port);
- this.port = port;
- this.hostname = Objects.requireNonNull(hostname, "The Sidecar hostname must be non-null");
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public int port()
- {
- return port;
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public String hostname()
- {
- return hostname;
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public String toString()
- {
- return String.format("SidecarInstanceImpl{hostname='%s', port=%d}", hostname, port);
- }
-
- @Override
- public boolean equals(Object object)
- {
- if (this == object)
- {
- return true;
- }
- if (object == null || getClass() != object.getClass())
- {
- return false;
- }
- SidecarInstanceImpl that = (SidecarInstanceImpl) object;
- return port == that.port && Objects.equals(hostname, that.hostname);
- }
-
- @Override
- public int hashCode()
- {
- return Objects.hash(port, hostname);
- }
-
- // JDK Serialization
-
- private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
- {
- LOGGER.debug("Falling back to JDK deserialization");
- hostname = in.readUTF();
- port = in.readInt();
- }
-
- private void writeObject(ObjectOutputStream out) throws IOException, ClassNotFoundException
- {
- LOGGER.debug("Falling back to JDK serialization");
- out.writeUTF(hostname);
- out.writeInt(port);
- }
-
- // Kryo Serialization
-
- public static class Serializer extends com.esotericsoftware.kryo.Serializer<SidecarInstanceImpl>
- {
- @Override
- public void write(Kryo kryo, Output out, SidecarInstanceImpl sidecarInstance)
- {
- out.writeString(sidecarInstance.hostname);
- out.writeInt(sidecarInstance.port);
- }
-
- @Override
- public SidecarInstanceImpl read(Kryo kryo, Input input, Class<SidecarInstanceImpl> type)
- {
- return new SidecarInstanceImpl(input.readString(), input.readInt());
- }
- }
-}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/KryoRegister.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/KryoRegister.java
index 54d14bf..359782b 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/KryoRegister.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/KryoRegister.java
@@ -34,7 +34,6 @@
import org.apache.cassandra.bridge.BigNumberConfigImpl;
import org.apache.cassandra.bridge.CassandraBridgeFactory;
import org.apache.cassandra.bridge.CassandraVersion;
-import org.apache.cassandra.clients.SidecarInstanceImpl;
import org.apache.cassandra.secrets.SslConfig;
import org.apache.cassandra.spark.data.CassandraDataLayer;
import org.apache.cassandra.spark.data.LocalDataLayer;
@@ -66,7 +65,6 @@
KRYO_SERIALIZERS.put(TokenPartitioner.class, new TokenPartitioner.Serializer());
KRYO_SERIALIZERS.put(CassandraDataLayer.class, new CassandraDataLayer.Serializer());
KRYO_SERIALIZERS.put(BigNumberConfigImpl.class, new BigNumberConfigImpl.Serializer());
- KRYO_SERIALIZERS.put(SidecarInstanceImpl.class, new SidecarInstanceImpl.Serializer());
KRYO_SERIALIZERS.put(SslConfig.class, new SslConfig.Serializer());
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConf.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConf.java
index ff7551a..2db19d5 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConf.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConf.java
@@ -38,8 +38,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.cassandra.clients.SidecarInstanceImpl;
+import org.apache.cassandra.sidecar.client.SidecarInstanceImpl;
import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.spark.bulkwriter.blobupload.StorageClientConfig;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
import org.apache.cassandra.spark.bulkwriter.util.SbwKryoRegistrator;
import org.apache.cassandra.spark.utils.BuildInfo;
@@ -83,6 +84,10 @@
public static final int DEFAULT_COMMIT_BATCH_SIZE = 10_000;
public static final int DEFAULT_RING_RETRY_COUNT = 3;
public static final int DEFAULT_SSTABLE_DATA_SIZE_IN_MIB = 160;
+ public static final long DEFAULT_STORAGE_CLIENT_KEEP_ALIVE_SECONDS = 60;
+ public static final int DEFAULT_STORAGE_CLIENT_CONCURRENCY = Runtime.getRuntime().availableProcessors() * 2;
+ public static final int DEFAULT_STORAGE_CLIENT_MAX_CHUNK_SIZE_IN_BYTES = 100 * 1024 * 1024; // 100 MiB
+ private static final long DEFAULT_MAX_SIZE_PER_SSTABLE_BUNDLE_IN_BYTES_S3_TRANSPORT = 5L * 1024 * 1024 * 1024;
// NOTE: All Cassandra Analytics setting names must start with "spark" in order to not be ignored by Spark,
// and must not start with "spark.cassandra" so as to not conflict with Spark Cassandra Connector
@@ -99,9 +104,11 @@
public static final String SIDECAR_REQUEST_TIMEOUT_SECONDS = SETTING_PREFIX + "sidecar.request.timeout.seconds";
public static final String SKIP_CLEAN = SETTING_PREFIX + "job.skip_clean";
public static final String USE_OPENSSL = SETTING_PREFIX + "use_openssl";
+ // defines the max number of consecutive retries allowed in the ring monitor
public static final String RING_RETRY_COUNT = SETTING_PREFIX + "ring_retry_count";
+ public static final String IMPORT_COORDINATOR_TIMEOUT_MULTIPLIER = SETTING_PREFIX + "importCoordinatorTimeoutMultiplier";
+ public static final int MINIMUM_JOB_KEEP_ALIVE_MINUTES = 10;
- public final Set<? extends SidecarInstance> sidecarInstances;
public final String keyspace;
public final String table;
public final ConsistencyLevel.CL consistencyLevel;
@@ -111,6 +118,9 @@
public final int commitBatchSize;
public final boolean skipExtendedVerify;
public final WriteMode writeMode;
+ public final int commitThreadsPerInstance;
+ public final int importCoordinatorTimeoutMultiplier;
+ public boolean quoteIdentifiers;
protected final String keystorePassword;
protected final String keystorePath;
protected final String keystoreBase64Encoded;
@@ -122,14 +132,22 @@
protected final String ttl;
protected final String timestamp;
protected final SparkConf conf;
- public final int commitThreadsPerInstance;
- public boolean quoteIdentifiers;
protected final int effectiveSidecarPort;
protected final int userProvidedSidecarPort;
- protected boolean useOpenSsl;
- protected int ringRetryCount;
protected final Set<String> blockedInstances;
protected final DigestAlgorithmSupplier digestAlgorithmSupplier;
+ protected final StorageClientConfig storageClientConfig;
+ protected final DataTransportInfo dataTransportInfo;
+ protected final int jobKeepAliveMinutes;
+ // An optional unique identifier supplied by the customer. The jobId is different from the restoreJobId, which is used internally.
+ // The value is null when absent
+ protected final String configuredJobId;
+ protected boolean useOpenSsl;
+ protected int ringRetryCount;
+ // create sidecarInstances from sidecarInstancesValue and effectiveSidecarPort
+ private final String sidecarInstancesValue;
+ private transient Set<? extends SidecarInstance> sidecarInstances; // not serialized
+
public BulkSparkConf(SparkConf conf, Map<String, String> options)
{
@@ -137,7 +155,8 @@
Optional<Integer> sidecarPortFromOptions = MapUtils.getOptionalInt(options, WriterOptions.SIDECAR_PORT.name(), "sidecar port");
this.userProvidedSidecarPort = sidecarPortFromOptions.isPresent() ? sidecarPortFromOptions.get() : getOptionalInt(SIDECAR_PORT).orElse(-1);
this.effectiveSidecarPort = this.userProvidedSidecarPort == -1 ? DEFAULT_SIDECAR_PORT : this.userProvidedSidecarPort;
- this.sidecarInstances = buildSidecarInstances(options, effectiveSidecarPort);
+ this.sidecarInstancesValue = MapUtils.getOrThrow(options, WriterOptions.SIDECAR_INSTANCES.name(), "sidecar_instances");
+ this.sidecarInstances = sidecarInstances();
this.keyspace = MapUtils.getOrThrow(options, WriterOptions.KEYSPACE.name());
this.table = MapUtils.getOrThrow(options, WriterOptions.TABLE.name());
this.skipExtendedVerify = MapUtils.getBoolean(options, WriterOptions.SKIP_EXTENDED_VERIFY.name(), true,
@@ -162,10 +181,42 @@
// else fall back to props, and then default if neither specified
this.useOpenSsl = getBoolean(USE_OPENSSL, true);
this.ringRetryCount = getInt(RING_RETRY_COUNT, DEFAULT_RING_RETRY_COUNT);
+ this.importCoordinatorTimeoutMultiplier = getInt(IMPORT_COORDINATOR_TIMEOUT_MULTIPLIER, 2);
this.ttl = MapUtils.getOrDefault(options, WriterOptions.TTL.name(), null);
this.timestamp = MapUtils.getOrDefault(options, WriterOptions.TIMESTAMP.name(), null);
this.quoteIdentifiers = MapUtils.getBoolean(options, WriterOptions.QUOTE_IDENTIFIERS.name(), false, "quote identifiers");
this.blockedInstances = buildBlockedInstances(options);
+ int storageClientConcurrency = MapUtils.getInt(options, WriterOptions.STORAGE_CLIENT_CONCURRENCY.name(),
+ DEFAULT_STORAGE_CLIENT_CONCURRENCY, "storage client concurrency");
+ long storageClientKeepAliveSeconds = MapUtils.getLong(options, WriterOptions.STORAGE_CLIENT_THREAD_KEEP_ALIVE_SECONDS.name(),
+ DEFAULT_STORAGE_CLIENT_KEEP_ALIVE_SECONDS);
+ int storageClientMaxChunkSizeInBytes = MapUtils.getInt(options, WriterOptions.STORAGE_CLIENT_MAX_CHUNK_SIZE_IN_BYTES.name(),
+ DEFAULT_STORAGE_CLIENT_MAX_CHUNK_SIZE_IN_BYTES);
+ String storageClientHttpsProxy = MapUtils.getOrDefault(options, WriterOptions.STORAGE_CLIENT_HTTPS_PROXY.name(), null);
+ String storageClientEndpointOverride = MapUtils.getOrDefault(options, WriterOptions.STORAGE_CLIENT_ENDPOINT_OVERRIDE.name(), null);
+ long nioHttpClientConnectionAcquisitionTimeoutSeconds =
+ MapUtils.getLong(options, WriterOptions.STORAGE_CLIENT_NIO_HTTP_CLIENT_CONNECTION_ACQUISITION_TIMEOUT_SECONDS.name(), 300);
+ int nioHttpClientMaxConcurrency = MapUtils.getInt(options, WriterOptions.STORAGE_CLIENT_NIO_HTTP_CLIENT_MAX_CONCURRENCY.name(), 50);
+ this.storageClientConfig = new StorageClientConfig(storageClientConcurrency,
+ storageClientKeepAliveSeconds,
+ storageClientMaxChunkSizeInBytes,
+ storageClientHttpsProxy,
+ storageClientEndpointOverride,
+ nioHttpClientConnectionAcquisitionTimeoutSeconds,
+ nioHttpClientMaxConcurrency);
+ DataTransport dataTransport = MapUtils.getEnumOption(options, WriterOptions.DATA_TRANSPORT.name(), DataTransport.DIRECT, "Data Transport");
+ long maxSizePerSSTableBundleInBytesS3Transport = MapUtils.getLong(options, WriterOptions.MAX_SIZE_PER_SSTABLE_BUNDLE_IN_BYTES_S3_TRANSPORT.name(),
+ DEFAULT_MAX_SIZE_PER_SSTABLE_BUNDLE_IN_BYTES_S3_TRANSPORT);
+ String transportExtensionClass = MapUtils.getOrDefault(options, WriterOptions.DATA_TRANSPORT_EXTENSION_CLASS.name(), null);
+ this.dataTransportInfo = new DataTransportInfo(dataTransport, transportExtensionClass, maxSizePerSSTableBundleInBytesS3Transport);
+ this.jobKeepAliveMinutes = MapUtils.getInt(options, WriterOptions.JOB_KEEP_ALIVE_MINUTES.name(), MINIMUM_JOB_KEEP_ALIVE_MINUTES);
+ if (this.jobKeepAliveMinutes < MINIMUM_JOB_KEEP_ALIVE_MINUTES)
+ {
+ throw new IllegalArgumentException(String.format("Invalid value for the '%s' Bulk Writer option (%d). It cannot be less than the minimum %s",
+ WriterOptions.JOB_KEEP_ALIVE_MINUTES, jobKeepAliveMinutes, MINIMUM_JOB_KEEP_ALIVE_MINUTES));
+ }
+ this.configuredJobId = MapUtils.getOrDefault(options, WriterOptions.JOB_ID.name(), null);
+
validateEnvironment();
}
@@ -211,14 +262,22 @@
return legacyOptionValue == -1 ? DEFAULT_SSTABLE_DATA_SIZE_IN_MIB : legacyOptionValue;
}
- protected Set<? extends SidecarInstance> buildSidecarInstances(Map<String, String> options, int sidecarPort)
+ protected Set<? extends SidecarInstance> buildSidecarInstances()
{
- String sidecarInstances = MapUtils.getOrThrow(options, WriterOptions.SIDECAR_INSTANCES.name(), "sidecar_instances");
- return Arrays.stream(sidecarInstances.split(","))
- .map(hostname -> new SidecarInstanceImpl(hostname, sidecarPort))
+ return Arrays.stream(sidecarInstancesValue.split(","))
+ .map(hostname -> new SidecarInstanceImpl(hostname, effectiveSidecarPort))
.collect(Collectors.toSet());
}
+ Set<? extends SidecarInstance> sidecarInstances()
+ {
+ if (sidecarInstances == null)
+ {
+ sidecarInstances = buildSidecarInstances();
+ }
+ return sidecarInstances;
+ }
+
protected void validateEnvironment() throws RuntimeException
{
Preconditions.checkNotNull(keyspace);
@@ -430,6 +489,11 @@
return coresPerExecutor * numExecutors;
}
+ public int getJobKeepAliveMinutes()
+ {
+ return jobKeepAliveMinutes;
+ }
+
protected int getInt(String settingName, int defaultValue)
{
String finalSetting = getSettingNameOrDeprecatedName(settingName);
@@ -522,7 +586,7 @@
}
}
- protected SparkConf getConf()
+ public SparkConf getSparkConf()
{
return conf;
}
@@ -537,6 +601,16 @@
return ringRetryCount;
}
+ public StorageClientConfig getStorageClientConfig()
+ {
+ return storageClientConfig;
+ }
+
+ public DataTransportInfo getTransportInfo()
+ {
+ return dataTransportInfo;
+ }
+
public boolean hasKeystoreAndKeystorePassword()
{
return keystorePassword != null && (keystorePath != null || keystoreBase64Encoded != null);
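For context, a minimal sketch (not part of this patch) of how a job could opt into the new transport through the writer options parsed by BulkSparkConf above; the option keys are assumed to mirror the WriterOptions constants referenced in this diff, and the extension class name is a hypothetical placeholder:

    Map<String, String> options = new HashMap<>();
    options.put("sidecar_instances", "host-1,host-2,host-3");
    options.put("keyspace", "my_keyspace");
    options.put("table", "my_table");
    options.put("data_transport", "S3_COMPAT");             // defaults to DIRECT when absent
    options.put("data_transport_extension_class",
                "com.example.MyStorageTransportExtension"); // hypothetical implementation
    options.put("job_keep_alive_minutes", "10");            // must be >= MINIMUM_JOB_KEEP_ALIVE_MINUTES
    BulkSparkConf conf = new BulkSparkConf(sparkContext.getConf(), options);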
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriteValidator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriteValidator.java
index 172586f..42413d0 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriteValidator.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriteValidator.java
@@ -22,6 +22,7 @@
import java.math.BigInteger;
import java.util.AbstractMap;
import java.util.Collection;
+import java.util.List;
import java.util.Map;
import com.google.common.collect.Multimap;
@@ -45,8 +46,8 @@
public BulkWriteValidator(BulkWriterContext bulkWriterContext,
ReplicaAwareFailureHandler<RingInstance> failureHandler)
{
- cluster = bulkWriterContext.cluster();
- job = bulkWriterContext.job();
+ this.cluster = bulkWriterContext.cluster();
+ this.job = bulkWriterContext.job();
this.failureHandler = failureHandler;
}
@@ -87,9 +88,20 @@
public void setPhase(String phase)
{
+ LOGGER.info("Updating write phase from {} to {}", this.phase, phase);
this.phase = phase;
}
+ public synchronized void updateFailureHandler(List<? extends StreamResult> results)
+ {
+ results.stream()
+ .flatMap(res -> res.failures.stream())
+ .forEach(err -> {
+ LOGGER.info("Populate stream error from tasks. {}", err);
+ failureHandler.addFailure(err.failedRange, err.instance, err.errMsg);
+ });
+ }
+
public static void updateFailureHandler(CommitResult commitResult,
String phase,
ReplicaAwareFailureHandler<RingInstance> failureHandler)
@@ -105,6 +117,11 @@
});
}
+ public void updateFailureHandler(Range<BigInteger> failedRange, RingInstance instance, String reason)
+ {
+ failureHandler.addFailure(failedRange, instance, reason);
+ }
+
public void validateClOrFail(TokenRangeMapping<RingInstance> tokenRangeMapping)
{
// Updates failures by looking up instance metadata
@@ -130,7 +147,6 @@
+ "Please rerun import once topology changes are complete.",
instance.nodeName(), cluster.getInstanceState(instance));
throw new RuntimeException(errorMessage);
-
// Check for blocked instances and ranges for the purpose of logging only.
// We check for blocked instances while validating consistency level requirements
case UNAVAILABLE_BLOCKED:
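The new updateFailureHandler overloads give the driver one place to fold task-level failures back into consistency validation. A short usage sketch, assuming streamResults has been collected from the executors as done in CassandraBulkSourceRelation below:

    // Fold per-task stream failures into the shared failure handler, then
    // verify the consistency level still holds for every affected token range.
    writeValidator.updateFailureHandler(streamResults);
    writeValidator.validateClOrFail(cluster.getTokenRangeMapping(false));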
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java
index 945f8a2..40012e0 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/BulkWriterContext.java
@@ -34,11 +34,11 @@
SchemaInfo schema();
- DataTransferApi transfer();
-
CassandraBridge bridge();
// NOTE: This interface intentionally does *not* implement AutoClosable as Spark can close Broadcast variables
// that implement AutoClosable while they are still in use, causing the underlying object to become unusable
void shutdown();
+
+ TransportContext transportContext();
}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CancelJobEvent.java
similarity index 63%
copy from cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java
copy to cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CancelJobEvent.java
index 08ae58a..cb33d77 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CancelJobEvent.java
@@ -19,21 +19,24 @@
package org.apache.cassandra.spark.bulkwriter;
-import java.nio.file.Path;
-
-import org.apache.cassandra.spark.utils.DigestAlgorithm;
-import org.jetbrains.annotations.NotNull;
-
-class NonValidatingTestSSTableWriter extends SSTableWriter
+/**
+ * A simple data structure describing an event that leads to job cancellation.
+ * It contains the reason for the cancellation and, optionally, the cause.
+ */
+public class CancelJobEvent
{
- NonValidatingTestSSTableWriter(MockTableWriter tableWriter, Path path, DigestAlgorithm digestAlgorithm)
+ public final Throwable exception;
+ public final String reason;
+
+ public CancelJobEvent(String reason)
{
- super(tableWriter, path, digestAlgorithm);
+ this.reason = reason;
+ this.exception = null;
}
- @Override
- public void validateSSTables(@NotNull BulkWriterContext writerContext, int partitionId)
+ public CancelJobEvent(String reason, Throwable throwable)
{
- // Skip validation for these tests
+ this.reason = reason;
+ this.exception = throwable;
}
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkSourceRelation.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkSourceRelation.java
index bba5a31..50a0119 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkSourceRelation.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkSourceRelation.java
@@ -26,10 +26,24 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import java.util.UUID;
+import java.util.function.Consumer;
+import javax.validation.constraints.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import o.a.c.sidecar.client.shaded.common.data.CreateRestoreJobRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSecrets;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobStatus;
+import o.a.c.sidecar.client.shaded.common.data.UpdateRestoreJobRequestPayload;
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobStreamResult;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.common.client.ClientException;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportHandler;
+import org.apache.cassandra.spark.utils.BuildInfo;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
@@ -40,7 +54,6 @@
import org.apache.spark.sql.sources.BaseRelation;
import org.apache.spark.sql.sources.InsertableRelation;
import org.apache.spark.sql.types.StructType;
-import org.jetbrains.annotations.NotNull;
import scala.Tuple2;
import scala.collection.JavaConverters;
import scala.util.control.NonFatal$;
@@ -52,6 +65,8 @@
private final SQLContext sqlContext;
private final JavaSparkContext sparkContext;
private final Broadcast<BulkWriterContext> broadcastContext;
+ private final BulkWriteValidator writeValidator;
+ private HeartbeatReporter heartbeatReporter;
private long startTimeNanos;
@SuppressWarnings("RedundantTypeArguments")
@@ -61,6 +76,9 @@
this.sqlContext = sqlContext;
this.sparkContext = JavaSparkContext.fromSparkContext(sqlContext.sparkContext());
this.broadcastContext = sparkContext.<BulkWriterContext>broadcast(writerContext);
+ ReplicaAwareFailureHandler<RingInstance> failureHandler = new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner());
+ this.writeValidator = new BulkWriteValidator(writerContext, failureHandler);
+ onCloudStorageTransport(ignored -> this.heartbeatReporter = new HeartbeatReporter());
}
@Override
@@ -94,12 +112,10 @@
@Override
public void insert(@NotNull Dataset<Row> data, boolean overwrite)
{
+ validateJob(overwrite);
this.startTimeNanos = System.nanoTime();
- if (overwrite)
- {
- throw new LoadNotSupportedException("Overwriting existing data needs TRUNCATE on Cassandra, which is not supported");
- }
- writerContext.cluster().checkBulkWriterIsEnabledOrThrow();
+
+ maybeEnableTransportExtension();
Tokenizer tokenizer = new Tokenizer(writerContext);
TableSchema tableSchema = writerContext.schema().getTableSchema();
JavaPairRDD<DecoratedKey, Object[]> sortedRDD = data.toJavaRDD()
@@ -111,82 +127,159 @@
persist(sortedRDD, data.columns());
}
- private void persist(@NotNull JavaPairRDD<DecoratedKey, Object[]> sortedRDD, String[] columnNames)
+ private void validateJob(boolean overwrite)
{
+ if (overwrite)
+ {
+ throw new LoadNotSupportedException("Overwriting existing data needs TRUNCATE on Cassandra, which is not supported");
+ }
+ writerContext.cluster().checkBulkWriterIsEnabledOrThrow();
+ }
+
+ public void cancelJob(@NotNull CancelJobEvent cancelJobEvent)
+ {
+ if (cancelJobEvent.exception != null)
+ {
+ LOGGER.error("An unrecoverable error occurred during {} stage of import while validating the current cluster state; cancelling job",
+ writeValidator.getPhase(), cancelJobEvent.exception);
+ }
+ else
+ {
+ LOGGER.error("Job was canceled due to '{}' during {} stage of import; please rerun import once topology changes are complete",
+ cancelJobEvent.reason, writeValidator.getPhase());
+ }
try
{
- List<WriteResult> writeResults = sortedRDD
- .mapPartitions(writeRowsInPartition(broadcastContext, columnNames))
- .collect();
+ onCloudStorageTransport(ctx -> abortRestoreJob(ctx, cancelJobEvent.exception));
+ }
+ finally
+ {
+ sparkContext.cancelJobGroup(writerContext.job().getId());
+ }
+ }
- publishSuccessfulJobStats(writeResults);
+ private void persist(@NotNull JavaPairRDD<DecoratedKey, Object[]> sortedRDD, String[] columnNames)
+ {
+ onDirectTransport(ctx -> writeValidator.setPhase("UploadAndCommit"));
+ onCloudStorageTransport(ctx -> {
+ writeValidator.setPhase("UploadToCloudStorage");
+ ctx.transportExtensionImplementation().onTransportStart(elapsedTimeMillis());
+ });
+
+ try
+ {
+ // Copy the broadcast context into a local variable (by passing it as the input) to avoid a serialization error.
+ // Without this, the SerializedLambda captures the CassandraBulkSourceRelation object, which is not serializable
+ // (as Spark requires), as a captured argument, causing a "Task not serializable" error.
+ List<WriteResult> writeResults = sortedRDD
+ .mapPartitions(writeRowsInPartition(broadcastContext, columnNames))
+ .collect();
+
+ // Unpersist broadcast context to free up executors while driver waits for the
+ // import to complete
+ unpersist();
+
+ List<StreamResult> streamResults = writeResults.stream()
+ .map(WriteResult::streamResults)
+ .flatMap(Collection::stream)
+ .collect(Collectors.toList());
+
+ long rowCount = streamResults.stream().mapToLong(res -> res.rowCount).sum();
+ long totalBytesWritten = streamResults.stream().mapToLong(res -> res.bytesWritten).sum();
+ boolean hasClusterTopologyChanged = writeResults.stream().anyMatch(WriteResult::isClusterResizeDetected);
+
+ onCloudStorageTransport(context -> {
+ LOGGER.info("Waiting for Cassandra to complete import slices. rows={} bytes={} cluster_resized={}",
+ rowCount,
+ totalBytesWritten,
+ hasClusterTopologyChanged);
+
+ // Update the failure handler with the stream results from the tasks.
+ // Some token ranges might fail on some instances, but the CL is still satisfied at this step
+ writeValidator.updateFailureHandler(streamResults);
+
+ List<BlobStreamResult> resultsAsBlobStreamResults = streamResults.stream()
+ .map(BlobStreamResult.class::cast)
+ .collect(Collectors.toList());
+
+ int objectsCount = resultsAsBlobStreamResults.stream()
+ .mapToInt(res -> res.createdRestoreSlices.size())
+ .sum();
+ // report the number of objects persisted on S3
+ LOGGER.info("Notifying extension all objects have been persisted, totaling {} objects", objectsCount);
+ context.transportExtensionImplementation()
+ .onAllObjectsPersisted(objectsCount, rowCount, elapsedTimeMillis());
+
+ ImportCompletionCoordinator.of(startTimeNanos, writerContext, context.dataTransferApi(),
+ writeValidator, resultsAsBlobStreamResults,
+ context.transportExtensionImplementation(), this::cancelJob)
+ .waitForCompletion();
+ markRestoreJobAsSucceeded(context);
+ });
+
+ LOGGER.info("Bulk writer job complete. rows={} bytes={} cluster_resized={}",
+ rowCount,
+ totalBytesWritten,
+ hasClusterTopologyChanged);
+ publishSuccessfulJobStats(rowCount, totalBytesWritten, hasClusterTopologyChanged);
}
catch (Throwable throwable)
{
publishFailureJobStats(throwable.getMessage());
- LOGGER.error("Bulk Write Failed", throwable);
- throw new RuntimeException("Bulk Write to Cassandra has failed", throwable);
+ LOGGER.error("Bulk Write Failed.", throwable);
+ RuntimeException failure = new RuntimeException("Bulk Write to Cassandra has failed", throwable);
+ try
+ {
+ onCloudStorageTransport(ctx -> abortRestoreJob(ctx, throwable));
+ }
+ catch (Exception rte)
+ {
+ failure.addSuppressed(rte);
+ }
+
+ throw failure;
}
finally
{
try
{
+ onCloudStorageTransport(ignored -> heartbeatReporter.close());
writerContext.shutdown();
sqlContext().sparkContext().clearJobGroup();
}
catch (Exception ignored)
{
+ LOGGER.warn("Ignored exception during spark job shutdown.", ignored);
// We've made our best effort to close the Bulk Writer context
}
unpersist();
}
}
- private void publishSuccessfulJobStats(List<WriteResult> writeResults)
+ private void publishSuccessfulJobStats(long rowCount, long totalBytesWritten, boolean hasClusterTopologyChanged)
{
- List<StreamResult> streamResults = writeResults.stream()
- .map(WriteResult::streamResults)
- .flatMap(Collection::stream)
- .collect(Collectors.toList());
-
- long rowCount = streamResults.stream().mapToLong(res -> res.rowCount).sum();
- long totalBytesWritten = streamResults.stream().mapToLong(res -> res.bytesWritten).sum();
- boolean hasClusterTopologyChanged = writeResults.stream()
- .anyMatch(WriteResult::isClusterResizeDetected);
- LOGGER.info("Bulk writer job complete. rows={} bytes={} cluster_resize={}",
- rowCount,
- totalBytesWritten,
- hasClusterTopologyChanged);
- writerContext.jobStats().publish(new HashMap<String, String>()
- {
- {
+ writerContext.jobStats().publish(new HashMap<String, String>() // type declaration required to compile with Java 8
+ {{
put("jobId", writerContext.job().getId().toString());
+ put("transportInfo", writerContext.job().transportInfo().toString());
put("rowsWritten", Long.toString(rowCount));
put("bytesWritten", Long.toString(totalBytesWritten));
put("jobStatus", "Succeeded");
put("clusterResizeDetected", String.valueOf(hasClusterTopologyChanged));
- put("jobElapsedTimeMillis", Long.toString(getElapsedTimeMillis()));
- }
- });
+ put("jobElapsedTimeMillis", Long.toString(elapsedTimeMillis()));
+ }});
}
private void publishFailureJobStats(String reason)
{
- writerContext.jobStats().publish(new HashMap<String, String>()
- {
- {
+ writerContext.jobStats().publish(new HashMap<String, String>() // type declaration required to compile with Java 8
+ {{
put("jobId", writerContext.job().getId().toString());
+ put("transportInfo", writerContext.job().transportInfo().toString());
put("jobStatus", "Failed");
put("failureReason", reason);
- put("jobElapsedTimeMillis", Long.toString(getElapsedTimeMillis()));
- }
- });
- }
-
- private long getElapsedTimeMillis()
- {
- long now = System.nanoTime();
- return TimeUnit.NANOSECONDS.toMillis(now - this.startTimeNanos);
+ put("jobElapsedTimeMillis", Long.toString(elapsedTimeMillis()));
+ }});
}
/**
@@ -224,4 +317,121 @@
}
}
}
+
+ // initialization for CloudStorageTransport
+ private void maybeEnableTransportExtension()
+ {
+ onCloudStorageTransport(ctx -> {
+ StorageTransportHandler storageTransportHandler = new StorageTransportHandler(ctx, writerContext.job(), this::cancelJob);
+ StorageTransportExtension impl = ctx.transportExtensionImplementation();
+ impl.setCredentialChangeListener(storageTransportHandler);
+ impl.setObjectFailureListener(storageTransportHandler);
+ createRestoreJob(ctx);
+ heartbeatReporter.schedule("Extend lease",
+ TimeUnit.MINUTES.toMillis(1),
+ () -> extendLeaseForJob(ctx));
+ });
+ }
+
+ private void extendLeaseForJob(TransportContext.CloudStorageTransportContext ctx)
+ {
+ UpdateRestoreJobRequestPayload payload = new UpdateRestoreJobRequestPayload(null, null, null, updatedLeaseTime());
+ try
+ {
+ ctx.dataTransferApi().updateRestoreJob(payload);
+ }
+ catch (ClientException e)
+ {
+ LOGGER.warn("Failed to update expireAt for job", e);
+ }
+ }
+
+ private long updatedLeaseTime()
+ {
+ return System.currentTimeMillis() + TimeUnit.MINUTES.toMillis(writerContext.job().jobKeepAliveMinutes());
+ }
+
+ private long elapsedTimeMillis()
+ {
+ long now = System.nanoTime();
+ return TimeUnit.NANOSECONDS.toMillis(now - this.startTimeNanos);
+ }
+
+ void onCloudStorageTransport(Consumer<TransportContext.CloudStorageTransportContext> consumer)
+ {
+ TransportContext transportContext = writerContext.transportContext();
+ if (transportContext instanceof TransportContext.CloudStorageTransportContext)
+ {
+ consumer.accept((TransportContext.CloudStorageTransportContext) transportContext);
+ }
+ }
+
+ void onDirectTransport(Consumer<TransportContext.DirectDataBulkWriterContext> consumer)
+ {
+ TransportContext transportContext = writerContext.transportContext();
+ if (transportContext instanceof TransportContext.DirectDataBulkWriterContext)
+ {
+ consumer.accept((TransportContext.DirectDataBulkWriterContext) transportContext);
+ }
+ }
+
+ private void createRestoreJob(TransportContext.CloudStorageTransportContext context)
+ {
+ StorageTransportConfiguration conf = context.transportConfiguration();
+ RestoreJobSecrets secrets = conf.getStorageCredentialPair().toRestoreJobSecrets(conf.getReadRegion(),
+ conf.getWriteRegion());
+ JobInfo job = writerContext.job();
+ CreateRestoreJobRequestPayload payload = CreateRestoreJobRequestPayload
+ .builder(secrets, updatedLeaseTime())
+ .jobAgent(BuildInfo.APPLICATION_NAME)
+ .jobId(job.getRestoreJobId())
+ .updateImportOptions(importOptions -> {
+ importOptions.verifySSTables(true) // the end-user is no longer allowed to bypass the basic (non-extended) verification
+ .extendedVerify(false); // always turned off
+ })
+ .build();
+
+ try
+ {
+ context.dataTransferApi().createRestoreJob(payload);
+ }
+ catch (ClientException e)
+ {
+ throw new RuntimeException("Failed to create a new restore job on Sidecar", e);
+ }
+ }
+
+ private void markRestoreJobAsSucceeded(TransportContext.CloudStorageTransportContext context)
+ {
+ UpdateRestoreJobRequestPayload requestPayload = new UpdateRestoreJobRequestPayload(null, null, RestoreJobStatus.SUCCEEDED, null);
+ UUID jobId = writerContext.job().getRestoreJobId();
+ try
+ {
+ LOGGER.info("Marking the restore job as succeeded. jobId={}", jobId);
+ // Prioritize the call to the extension, so onJobSucceeded is always invoked.
+ context.transportExtensionImplementation().onJobSucceeded(elapsedTimeMillis());
+ context.dataTransferApi().updateRestoreJob(requestPayload);
+ }
+ catch (Exception e)
+ {
+ LOGGER.warn("Failed to mark the restore job as succeeded. jobId={}", jobId, e);
+ // Do not rethrow - avoid triggering the catch block at the call-site that marks the job as failed.
+ }
+ }
+
+ private void abortRestoreJob(TransportContext.CloudStorageTransportContext context, Throwable cause)
+ {
+ // Prioritize the call to the extension, so onJobFailed is always invoked.
+ context.transportExtensionImplementation().onJobFailed(elapsedTimeMillis(), cause);
+ UUID jobId = writerContext.job().getRestoreJobId();
+ try
+ {
+ LOGGER.info("Aborting job. jobId={}", jobId);
+ context.dataTransferApi().abortRestoreJob();
+ }
+ catch (ClientException e)
+ {
+ throw new RuntimeException("Failed to abort the restore job on Sidecar. jobId: " + jobId, e);
+ }
+ }
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java
index 0999604..89f2bb9 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraBulkWriterContext.java
@@ -56,11 +56,10 @@
private final JobInfo jobInfo;
private final String lowestCassandraVersion;
private transient CassandraBridge bridge;
- private transient DataTransferApi dataTransferApi;
private final CassandraClusterInfo clusterInfo;
private final SchemaInfo schemaInfo;
-
- private transient JobStatsPublisher jobStatsPublisher;
+ private final transient JobStatsPublisher jobStatsPublisher;
+ protected transient volatile TransportContext transportContext;
protected CassandraBulkWriterContext(@NotNull BulkSparkConf conf,
@NotNull CassandraClusterInfo clusterInfo,
@@ -69,11 +68,13 @@
{
this.conf = conf;
this.clusterInfo = clusterInfo;
+ clusterInfo.startupValidate();
this.jobStatsPublisher = new LogStatsPublisher();
lowestCassandraVersion = clusterInfo.getLowestCassandraVersion();
this.bridge = CassandraBridgeFactory.get(lowestCassandraVersion);
TokenRangeMapping<RingInstance> tokenRangeMapping = clusterInfo.getTokenRangeMapping(true);
jobInfo = new CassandraJobInfo(conf,
+ bridge.getTimeUUID(), // used for creating the restore job on Sidecar
new TokenPartitioner(tokenRangeMapping,
conf.numberSplits,
sparkContext.defaultParallelism(),
@@ -84,8 +85,10 @@
.containsKey(conf.localDC)),
String.format("Keyspace %s is not replicated on datacenter %s", conf.keyspace, conf.localDC));
- String keyspace = jobInfo.keyspace();
- String table = jobInfo.tableName();
+ transportContext = createTransportContext(true);
+
+ String keyspace = jobInfo.qualifiedTableName().keyspace();
+ String table = jobInfo.qualifiedTableName().table();
String keyspaceSchema = clusterInfo.getKeyspaceSchema(true);
Partitioner partitioner = clusterInfo.getPartitioner();
@@ -112,6 +115,7 @@
return bridge;
}
+ // Static factory to create BulkWriterContext based on the requested Bulk Writer transport strategy
public static BulkWriterContext fromOptions(@NotNull SparkContext sparkContext,
@NotNull Map<String, String> strOptions,
@NotNull StructType dfSchema)
@@ -120,9 +124,6 @@
BulkSparkConf conf = new BulkSparkConf(sparkContext.getConf(), strOptions);
CassandraClusterInfo clusterInfo = new CassandraClusterInfo(conf);
-
- clusterInfo.startupValidate();
-
CassandraBulkWriterContext bulkWriterContext = new CassandraBulkWriterContext(conf, clusterInfo, dfSchema, sparkContext);
ShutdownHookManager.addShutdownHook(org.apache.spark.util.ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY(),
ScalaFunctions.wrapLambda(bulkWriterContext::shutdown));
@@ -132,7 +133,7 @@
private void publishInitialJobStats(String sparkVersion)
{
- Map<String, String> initialJobStats = new HashMap<String, String>()
+ Map<String, String> initialJobStats = new HashMap<String, String>() // type declaration required to compile with java8
{{
put("jobId", jobInfo.getId().toString());
put("sparkVersion", sparkVersion);
@@ -146,12 +147,14 @@
public void shutdown()
{
LOGGER.info("Shutting down {}", this);
- synchronized (this)
+ if (clusterInfo != null)
{
- if (clusterInfo != null)
- {
- clusterInfo.close();
- }
+ clusterInfo.close();
+ }
+
+ if (transportContext != null)
+ {
+ transportContext.close();
}
}
@@ -212,17 +215,31 @@
}
@Override
- @NotNull
- public synchronized DataTransferApi transfer()
+ public TransportContext transportContext()
{
- if (dataTransferApi == null)
+ // When running on the driver, transportContext is created in the constructor and is never null
+ if (transportContext != null)
{
- dataTransferApi = new SidecarDataTransferApi(clusterInfo.getCassandraContext(),
- bridge(),
- jobInfo,
- conf);
+ return transportContext;
}
- return dataTransferApi;
+
+ // When running on an executor, transportContext starts out null. Synchronize to avoid creating multiple instances
+ synchronized (this)
+ {
+ if (transportContext == null)
+ {
+ transportContext = createTransportContext(false);
+ }
+ }
+ return transportContext;
+ }
+
+ @NotNull
+ protected TransportContext createTransportContext(boolean isOnDriver)
+ {
+ return conf.getTransportInfo()
+ .getTransport()
+ .createContext(this, conf, isOnDriver);
}
@NotNull
@@ -237,6 +254,6 @@
conf.getTTLOptions(),
conf.getTimestampOptions(),
lowestCassandraVersion,
- conf.quoteIdentifiers);
+ jobInfo.qualifiedTableName().quoteIdentifiers());
}
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraClusterInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraClusterInfo.java
index 7d32f7a..9ba7824 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraClusterInfo.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraClusterInfo.java
@@ -106,19 +106,12 @@
}
@Override
- public boolean instanceIsAvailable(RingInstance ringInstance)
- {
- return instanceIsUp(ringInstance.ringInstance())
- && instanceIsNormal(ringInstance.ringInstance())
- && !instanceIsBlocked(ringInstance);
- }
-
- @Override
public InstanceState getInstanceState(RingInstance ringInstance)
{
return InstanceState.valueOf(ringInstance.ringInstance().state().toUpperCase());
}
+ @Override
public CassandraContext getCassandraContext()
{
CassandraContext currentCassandraContext = cassandraContext;
@@ -157,11 +150,8 @@
@Override
public void close()
{
- synchronized (this)
- {
- LOGGER.info("Closing {}", this);
- getCassandraContext().close();
- }
+ LOGGER.info("Closing {}", this);
+ getCassandraContext().close();
}
@Override
@@ -292,24 +282,31 @@
@Override
public TokenRangeMapping<RingInstance> getTokenRangeMapping(boolean cached)
{
- TokenRangeMapping<RingInstance> tokenRangeReplicas = this.tokenRangeReplicas;
- if (cached && tokenRangeReplicas != null)
+ TokenRangeMapping<RingInstance> topology = this.tokenRangeReplicas;
+ if (cached && topology != null)
{
- return tokenRangeReplicas;
+ return topology;
}
+ // Block for the call-sites requesting the latest view of the ring; it is OK to serve the other call-sites that request the cached view.
+ // Synchronization can be avoided here
+ if (topology != null)
+ {
+ topology = getTokenRangeReplicas();
+ this.tokenRangeReplicas = topology;
+ return topology;
+ }
+
+ // Only synchronize when fetching the ring information for the first time
synchronized (this)
{
- if (!cached || this.tokenRangeReplicas == null)
+ try
{
- try
- {
- this.tokenRangeReplicas = getTokenRangeReplicas();
- }
- catch (Exception exception)
- {
- throw new RuntimeException("Unable to initialize ring information", exception);
- }
+ this.tokenRangeReplicas = getTokenRangeReplicas();
+ }
+ catch (Exception exception)
+ {
+ throw new RuntimeException("Unable to initialize ring information", exception);
}
return this.tokenRangeReplicas;
}
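A standalone sketch of the caching strategy introduced above, assuming the cached field is safely published (for example, volatile): once the first fetch has populated the cache, refreshes happen without synchronization (last writer wins), while the initial population is synchronized so concurrent callers do not issue redundant remote calls:

    import java.util.function.Supplier;

    final class CachedView<T>
    {
        private final Supplier<T> fetcher;
        private volatile T cached;

        CachedView(Supplier<T> fetcher)
        {
            this.fetcher = fetcher;
        }

        T get(boolean useCached)
        {
            T current = cached;
            if (useCached && current != null)
            {
                return current;          // serve the cached view
            }
            if (current != null)
            {
                current = fetcher.get(); // refresh without blocking other readers
                cached = current;
                return current;
            }
            synchronized (this)          // only the very first fetch synchronizes
            {
                if (cached == null)
                {
                    cached = fetcher.get();
                }
                return cached;
            }
        }
    }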
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraContext.java
index 08fb79b..97d19c2 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraContext.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraContext.java
@@ -88,7 +88,7 @@
protected Set<? extends SidecarInstance> createClusterConfig()
{
- return conf.sidecarInstances;
+ return conf.sidecarInstances();
}
public SidecarClient getSidecarClient()
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java
new file mode 100644
index 0000000..963ebad
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraDirectDataTransportContext.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.math.BigInteger;
+
+import com.google.common.collect.Range;
+
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.CassandraBridgeFactory;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.jetbrains.annotations.NotNull;
+
+public class CassandraDirectDataTransportContext implements TransportContext.DirectDataBulkWriterContext
+{
+ @NotNull
+ private final JobInfo jobInfo;
+ @NotNull
+ private final ClusterInfo clusterInfo;
+ @NotNull
+ private final DirectDataTransferApi dataTransferApi;
+
+ public CassandraDirectDataTransportContext(@NotNull BulkWriterContext bulkWriterContext)
+ {
+ this.jobInfo = bulkWriterContext.job();
+ this.clusterInfo = bulkWriterContext.cluster();
+ this.dataTransferApi = createDirectDataTransferApi();
+ }
+
+ @Override
+ public DirectStreamSession createStreamSession(BulkWriterContext writerContext,
+ String sessionId,
+ SortedSSTableWriter sstableWriter,
+ Range<BigInteger> range,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ return new DirectStreamSession(writerContext,
+ sstableWriter,
+ this,
+ sessionId,
+ range,
+ failureHandler);
+ }
+
+ @Override
+ public DirectDataTransferApi dataTransferApi()
+ {
+ return dataTransferApi;
+ }
+
+ // only invoked in the constructor
+ protected DirectDataTransferApi createDirectDataTransferApi()
+ {
+ CassandraBridge bridge = CassandraBridgeFactory.get(clusterInfo.getLowestCassandraVersion());
+ return new SidecarDataTransferApi(clusterInfo.getCassandraContext(), bridge, jobInfo);
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraJobInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraJobInfo.java
index f91b495..3e7139f 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraJobInfo.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraJobInfo.java
@@ -22,18 +22,19 @@
import java.util.UUID;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
+import org.apache.cassandra.spark.data.QualifiedTableName;
import org.jetbrains.annotations.NotNull;
public class CassandraJobInfo implements JobInfo
{
private static final long serialVersionUID = 6140098484732683759L;
- private final BulkSparkConf conf;
- @NotNull
- private final UUID jobId = UUID.randomUUID();
- private final TokenPartitioner tokenPartitioner;
+ protected final BulkSparkConf conf;
+ protected final UUID restoreJobId;
+ protected final TokenPartitioner tokenPartitioner;
- CassandraJobInfo(BulkSparkConf conf, TokenPartitioner tokenPartitioner)
+ public CassandraJobInfo(BulkSparkConf conf, UUID restoreJobId, TokenPartitioner tokenPartitioner)
{
+ this.restoreJobId = restoreJobId;
this.conf = conf;
this.tokenPartitioner = tokenPartitioner;
}
@@ -75,15 +76,45 @@
}
@Override
+ public DataTransportInfo transportInfo()
+ {
+ return conf.getTransportInfo();
+ }
+
+ @Override
+ public int jobKeepAliveMinutes()
+ {
+ return conf.getJobKeepAliveMinutes();
+ }
+
+ @Override
+ public int effectiveSidecarPort()
+ {
+ return conf.getEffectiveSidecarPort();
+ }
+
+ @Override
+ public int importCoordinatorTimeoutMultiplier()
+ {
+ return conf.importCoordinatorTimeoutMultiplier;
+ }
+
+ @Override
public int getCommitThreadsPerInstance()
{
return conf.commitThreadsPerInstance;
}
@Override
- public UUID getId()
+ public UUID getRestoreJobId()
{
- return jobId;
+ return restoreJobId;
+ }
+
+ @Override
+ public String getConfiguredJobId()
+ {
+ return conf.configuredJobId;
}
@Override
@@ -92,28 +123,16 @@
return tokenPartitioner;
}
- @Override
- public boolean quoteIdentifiers()
- {
- return conf.quoteIdentifiers;
- }
-
- @Override
- public String keyspace()
- {
- return conf.keyspace;
- }
-
- @Override
- public String tableName()
- {
- return conf.table;
- }
-
@NotNull
@Override
public DigestAlgorithmSupplier digestAlgorithmSupplier()
{
return conf.digestAlgorithmSupplier;
}
+
+ @NotNull
+ public QualifiedTableName qualifiedTableName()
+ {
+ return new QualifiedTableName(conf.keyspace, conf.table, conf.quoteIdentifiers);
+ }
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java
new file mode 100644
index 0000000..52b9cca
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitor.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+
+/**
+ * A monitor that checks whether the Cassandra topology has changed.
+ * On topology change, the data written by the job is no longer placed accurately, so the job should fail as soon as a change is detected.
+ */
+public class CassandraTopologyMonitor
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(CassandraTopologyMonitor.class);
+ private static final long PERIODIC_CHECK_DELAY_MS = 5000;
+ private static final long MAX_CHECK_ATTEMPTS = 10; // Maximum number of consecutive retries for a failed check task
+
+ private final ClusterInfo clusterInfo;
+ private final TokenRangeMapping<RingInstance> initialTopology;
+ private final ScheduledExecutorService executorService;
+ private final Consumer<CancelJobEvent> onCancelJob;
+ private int retryCount = 0;
+ // isStopped is set to true on job cancellation; the scheduled task then becomes a no-op
+ private volatile boolean isStopped = false;
+
+ public CassandraTopologyMonitor(ClusterInfo clusterInfo, Consumer<CancelJobEvent> onCancelJob)
+ {
+ this.clusterInfo = clusterInfo;
+ // stop the monitor when the job is cancelled
+ this.onCancelJob = onCancelJob.andThen(e -> isStopped = true);
+ this.initialTopology = clusterInfo.getTokenRangeMapping(false);
+ this.executorService = Executors.newSingleThreadScheduledExecutor(ThreadUtil.threadFactory("Cassandra Topology Monitor"));
+ executorService.scheduleWithFixedDelay(this::checkTopology, PERIODIC_CHECK_DELAY_MS, PERIODIC_CHECK_DELAY_MS, TimeUnit.MILLISECONDS);
+ }
+
+ /**
+ * Attempts to stop all tasks; we do not wait here as it is only called on job termination
+ */
+ public void shutdownNow()
+ {
+ executorService.shutdownNow();
+ }
+
+ /**
+ * @return the initial topology retrieved
+ */
+ public TokenRangeMapping<RingInstance> initialTopology()
+ {
+ return initialTopology;
+ }
+
+ private void checkTopology()
+ {
+ if (isStopped)
+ {
+ LOGGER.info("Already stopped. Skip checking topology");
+ return;
+ }
+
+ LOGGER.debug("Checking topology");
+ try
+ {
+ TokenRangeMapping<RingInstance> currentTopology = clusterInfo.getTokenRangeMapping(false);
+ if (!currentTopology.equals(initialTopology))
+ {
+ onCancelJob.accept(new CancelJobEvent("Topology changed during bulk write"));
+ return;
+ }
+ retryCount = 0;
+ }
+ catch (Exception exception)
+ {
+ if (retryCount++ > MAX_CHECK_ATTEMPTS)
+ {
+ LOGGER.error("Could not retrieve current topology. All hosts exhausted. The retrieval has failed consecutive for {} times", retryCount);
+ onCancelJob.accept(new CancelJobEvent("Could not retrieve current Cassandra topology. " +
+ "All hosts and retries have been exhausted.",
+ exception));
+ }
+ else
+ {
+ LOGGER.warn("Could not retrieve current topology. Will retry momentarily. Continuing bulk write.", exception);
+ }
+ }
+ }
+
+ @VisibleForTesting
+ void checkTopologyOnDemand()
+ {
+ checkTopology();
+ }
+}
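A hypothetical wiring of the monitor, mirroring how CassandraBulkSourceRelation cancels the job elsewhere in this patch (runBulkWrite is an illustrative placeholder):

    // Snapshot the topology at job start; cancel the job if it diverges.
    CassandraTopologyMonitor monitor = new CassandraTopologyMonitor(clusterInfo, relation::cancelJob);
    try
    {
        runBulkWrite();        // hypothetical job body
    }
    finally
    {
        monitor.shutdownNow(); // best-effort stop on job termination
    }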
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ClusterInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ClusterInfo.java
index 38d87fc..2c1700d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ClusterInfo.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ClusterInfo.java
@@ -39,8 +39,6 @@
Map<RingInstance, InstanceAvailability> getInstanceAvailability();
- boolean instanceIsAvailable(RingInstance ringInstance);
-
InstanceState getInstanceState(RingInstance instance);
Partitioner getPartitioner();
@@ -50,4 +48,10 @@
TimeSkewResponse getTimeSkew(List<RingInstance> replicas);
String getKeyspaceSchema(boolean cached);
+
+ CassandraContext getCassandraContext();
+
+ default void close()
+ {
+ }
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
index cccfa10..b90b213 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinator.java
@@ -51,26 +51,32 @@
private static final Logger LOGGER = LoggerFactory.getLogger(CommitCoordinator.class);
private final HashMap<RingInstance, ListeningExecutorService> executors = new HashMap<>();
- private final List<StreamResult> successfulUploads;
- private final DataTransferApi transferApi;
+ private final List<DirectStreamResult> successfulUploads;
+ private final DirectDataTransferApi directDataTransferApi;
private final ClusterInfo cluster;
private final JobInfo job;
private ListenableFuture<List<CommitResult>> allCommits;
- private final String jobSufix;
+ private final String jobSuffix;
- public static CommitCoordinator commit(BulkWriterContext bulkWriterContext, StreamResult[] uploadResults)
+ public static CommitCoordinator commit(BulkWriterContext writerContext,
+ TransportContext.DirectDataBulkWriterContext
+ transportContext,
+ DirectStreamResult... uploadResults)
{
- CommitCoordinator coordinator = new CommitCoordinator(bulkWriterContext, uploadResults);
+ CommitCoordinator coordinator = new CommitCoordinator(writerContext.cluster(),
+ writerContext.job(),
+ transportContext.dataTransferApi(),
+ uploadResults);
coordinator.commit();
return coordinator;
}
- private CommitCoordinator(BulkWriterContext writerContext, StreamResult[] uploadResults)
+ private CommitCoordinator(ClusterInfo cluster, JobInfo job, DirectDataTransferApi dataTransferApi, DirectStreamResult[] uploadResults)
{
- this.transferApi = writerContext.transfer();
- this.cluster = writerContext.cluster();
- this.job = writerContext.job();
- this.jobSufix = "-" + job.getId();
+ this.directDataTransferApi = dataTransferApi;
+ this.cluster = cluster;
+ this.job = job;
+ this.jobSuffix = "-" + job.getRestoreJobId();
successfulUploads = Arrays.stream(uploadResults)
.filter(result -> !result.passed.isEmpty())
.collect(Collectors.toList());
@@ -87,9 +93,9 @@
{
// We may have already committed - we should never get here if we did, but if somehow we do we should
// simply return the commit results we already collected
- if (successfulUploads.size() > 0 && successfulUploads.stream()
+ if (!successfulUploads.isEmpty() && successfulUploads.stream()
.allMatch(result -> result.commitResults != null
- && result.commitResults.size() > 0))
+ && !result.commitResults.isEmpty()))
{
List<CommitResult> collect = successfulUploads.stream()
.flatMap(streamResult -> streamResult.commitResults.stream())
@@ -142,7 +148,7 @@
CommitResult commitResult = new CommitResult(migrationId, instance, uploadRanges);
try
{
- DataTransferApi.RemoteCommitResult result = transferApi.commitSSTables(instance, migrationId, uuids);
+ DirectDataTransferApi.RemoteCommitResult result = directDataTransferApi.commitSSTables(instance, migrationId, uuids);
if (result.isSuccess)
{
LOGGER.info("[{}]: Commit succeeded on {} for {}", migrationId, instance.nodeName(), uploadRanges);
@@ -154,7 +160,7 @@
uploadRanges.entrySet(),
result.failedUuids,
result.stdErr);
- if (result.failedUuids.size() > 0)
+ if (!result.failedUuids.isEmpty())
{
addFailures(result.failedUuids, uploadRanges, commitResult, result.stdErr);
}
@@ -181,7 +187,7 @@
String error)
{
failedRanges.forEach(uuid -> {
- String shortUuid = uuid.replace(jobSufix, "");
+ String shortUuid = uuid.replace(jobSuffix, "");
commitResult.addFailedCommit(shortUuid, uploadRanges.get(shortUuid), error != null ? error : "Unknown Commit Failure");
});
}
@@ -192,7 +198,7 @@
LOGGER.debug("Added failures to commitResult by Range: {}", commitResult);
}
- private Map<RingInstance, Map<String, Range<BigInteger>>> getResultsByInstance(List<StreamResult> successfulUploads)
+ private Map<RingInstance, Map<String, Range<BigInteger>>> getResultsByInstance(List<DirectStreamResult> successfulUploads)
{
return successfulUploads
.stream()
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitResult.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitResult.java
index 6020d05..7b89463 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitResult.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/CommitResult.java
@@ -33,7 +33,8 @@
public class CommitResult implements Serializable
{
- private static final Logger LOGGER = LoggerFactory.getLogger(StreamResult.class);
+ private static final Logger LOGGER = LoggerFactory.getLogger(CommitResult.class);
+ private static final long serialVersionUID = 773475991511158249L;
public final String migrationId;
protected RingInstance instance;
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransport.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransport.java
new file mode 100644
index 0000000..2574780
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransport.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import org.apache.cassandra.spark.bulkwriter.blobupload.CassandraCloudStorageTransportContext;
+import org.jetbrains.annotations.NotNull;
+
+public enum DataTransport implements TransportContext.TransportContextProvider
+{
+ DIRECT
+ {
+ @Override
+ public TransportContext createContext(@NotNull BulkWriterContext bulkWriterContext,
+ @NotNull BulkSparkConf conf,
+ boolean isOnDriver)
+ {
+ // DIRECT mode does not need to distinguish between driver and executor
+ return new CassandraDirectDataTransportContext(bulkWriterContext);
+ }
+ },
+ S3_COMPAT
+ {
+ @Override
+ public TransportContext createContext(@NotNull BulkWriterContext bulkWriterContext,
+ @NotNull BulkSparkConf conf,
+ boolean isOnDriver)
+ {
+ return new CassandraCloudStorageTransportContext(bulkWriterContext, conf, isOnDriver);
+ }
+ };
+}
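The enum doubles as a strategy factory: the configured transport builds its own TransportContext. This is the call made by CassandraBulkWriterContext.createTransportContext earlier in this diff:

    TransportContext context = conf.getTransportInfo()
                                   .getTransport()
                                   .createContext(writerContext, conf, /* isOnDriver */ true);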
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransportInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransportInfo.java
new file mode 100644
index 0000000..ac6aa43
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransportInfo.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.io.Serializable;
+
+public class DataTransportInfo implements Serializable
+{
+ private static final long serialVersionUID = 1823178559314014761L;
+ private final DataTransport transport;
+ private final String transportExtensionClass;
+
+ // note: this field should be moved to a more appropriate class
+ private final long maxSizePerBundleInBytes;
+
+ public DataTransportInfo(DataTransport transport, String transportExtensionClass, long maxSizePerBundleInBytes)
+ {
+ this.transport = transport;
+ this.transportExtensionClass = transportExtensionClass;
+ this.maxSizePerBundleInBytes = maxSizePerBundleInBytes;
+ }
+
+ public DataTransport getTransport()
+ {
+ return transport;
+ }
+
+ public String getTransportExtensionClass()
+ {
+ return transportExtensionClass;
+ }
+
+ public long getMaxSizePerBundleInBytes()
+ {
+ return maxSizePerBundleInBytes;
+ }
+
+ public String toString()
+ {
+ return "TransportInfo={dataTransport=" + transport
+ + ",transportExtensionClass=" + transportExtensionClass
+ + ",maxSizePerBundleInBytes=" + maxSizePerBundleInBytes + '}';
+ }
+}
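For illustration, constructing the transport info for the S3_COMPAT path; the extension class name and the 5 GiB bundle cap below are hypothetical values, not defaults from this patch:

    // Hypothetical example values; a real job reads these from the writer options.
    DataTransportInfo info = new DataTransportInfo(DataTransport.S3_COMPAT,
                                                   "com.example.MyStorageTransportExtension",
                                                   5L * 1024 * 1024 * 1024);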
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransferApi.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectDataTransferApi.java
similarity index 91%
rename from cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransferApi.java
rename to cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectDataTransferApi.java
index f36ed73..10654a3 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DataTransferApi.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectDataTransferApi.java
@@ -19,7 +19,6 @@
package org.apache.cassandra.spark.bulkwriter;
-import java.io.Serializable;
import java.nio.file.Path;
import java.util.List;
@@ -28,8 +27,22 @@
import org.apache.cassandra.spark.common.model.CassandraInstance;
import org.jetbrains.annotations.Nullable;
-public interface DataTransferApi extends Serializable
+public interface DirectDataTransferApi
{
+ RemoteCommitResult commitSSTables(CassandraInstance instance,
+ String migrationId,
+ List<String> uuids) throws ClientException;
+
+ void cleanUploadSession(CassandraInstance instance,
+ String sessionID,
+ String jobID) throws ClientException;
+
+ void uploadSSTableComponent(Path componentFile,
+ int ssTableIdx,
+ CassandraInstance instance,
+ String sessionID,
+ Digest digest) throws ClientException;
+
class RemoteCommitResult
{
public final boolean isSuccess;
@@ -48,16 +61,4 @@
this.stdErr = stdErr;
}
}
-
- RemoteCommitResult commitSSTables(CassandraInstance instance,
- String migrationId,
- List<String> uuids) throws ClientException;
-
- void cleanUploadSession(CassandraInstance instance, String sessionID, String jobID) throws ClientException;
-
- void uploadSSTableComponent(Path componentFile,
- int ssTableIdx,
- CassandraInstance instance,
- String sessionID,
- Digest digest) throws ClientException;
}
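Since the interface is no longer Serializable, test doubles become straightforward. A minimal no-op stub, assuming only the types declared above:

    // Sketch of a no-op DirectDataTransferApi stub for tests (illustrative only).
    DirectDataTransferApi noOpApi = new DirectDataTransferApi()
    {
        @Override
        public RemoteCommitResult commitSSTables(CassandraInstance instance, String migrationId, List<String> uuids)
        {
            return null; // placeholder; a real test double would return a canned RemoteCommitResult
        }

        @Override
        public void cleanUploadSession(CassandraInstance instance, String sessionID, String jobID)
        {
        }

        @Override
        public void uploadSSTableComponent(Path componentFile, int ssTableIdx, CassandraInstance instance,
                                           String sessionID, Digest digest)
        {
        }
    };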
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamResult.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamResult.java
new file mode 100644
index 0000000..948566a
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamResult.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.math.BigInteger;
+import java.util.Collections;
+import java.util.List;
+
+import com.google.common.collect.Range;
+
+public class DirectStreamResult extends StreamResult
+{
+ private static final long serialVersionUID = 3531459795301200014L;
+ protected List<CommitResult> commitResults; // CHECKSTYLE IGNORE: Public mutable field
+
+ public DirectStreamResult(String sessionID, Range<BigInteger> tokenRange,
+ List<StreamError> failures, List<RingInstance> passed,
+ long rowCount, long bytesWritten)
+ {
+ super(sessionID, tokenRange, failures, passed, rowCount, bytesWritten);
+ }
+
+ public void setCommitResults(List<CommitResult> commitResult)
+ {
+ this.commitResults = Collections.unmodifiableList(commitResult);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "StreamResult{"
+ + "sessionID='" + sessionID + '\''
+ + ", tokenRange=" + tokenRange
+ + ", rowCount=" + rowCount
+ + ", failures=" + failures
+ + ", commitResults=" + commitResults
+ + ", passed=" + passed
+ + ", bytesWritten=" + bytesWritten
+ + '}';
+ }
+}
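Usage sketch: the commit results are attached after the streaming phase completes; all constructor arguments are assumed to be produced by the stream session:

    // Illustrative only; setCommitResults stores an unmodifiable copy of the list.
    DirectStreamResult result = new DirectStreamResult(sessionID, tokenRange, failures,
                                                       passedInstances, rowCount, bytesWritten);
    result.setCommitResults(commitResults);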
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java
new file mode 100644
index 0000000..fafa4d0
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSession.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.DirectoryStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Range;
+import org.apache.commons.io.FileUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.common.Digest;
+import org.apache.cassandra.spark.common.SSTables;
+
+public class DirectStreamSession extends StreamSession<TransportContext.DirectDataBulkWriterContext>
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(DirectStreamSession.class);
+ private static final String WRITE_PHASE = "UploadAndCommit";
+ private final AtomicInteger nextSSTableIdx = new AtomicInteger(1);
+ private final DirectDataTransferApi directDataTransferApi;
+
+ public DirectStreamSession(BulkWriterContext writerContext,
+ SortedSSTableWriter sstableWriter,
+ TransportContext.DirectDataBulkWriterContext transportContext,
+ String sessionID,
+ Range<BigInteger> tokenRange,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ super(writerContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler);
+ this.directDataTransferApi = transportContext.dataTransferApi();
+ }
+
+ @Override
+ protected StreamResult doScheduleStream(SortedSSTableWriter sstableWriter)
+ {
+ sendSSTables(sstableWriter);
+ // The StreamResult captures any errors encountered while streaming to replicas
+ DirectStreamResult streamResult = new DirectStreamResult(sessionID,
+ tokenRange,
+ errors,
+ new ArrayList<>(replicas),
+ sstableWriter.rowCount(),
+ sstableWriter.bytesWritten());
+ List<CommitResult> cr;
+ try
+ {
+ cr = commit(streamResult);
+ }
+ catch (Exception e)
+ {
+ if (e instanceof InterruptedException)
+ {
+ Thread.currentThread().interrupt();
+ }
+ throw new RuntimeException(e);
+ }
+ streamResult.setCommitResults(cr);
+ LOGGER.debug("StreamResult: {}", streamResult);
+ // Check consistency given the number of failures
+ BulkWriteValidator.validateClOrFail(tokenRangeMapping, failureHandler, LOGGER, WRITE_PHASE, writerContext.job());
+ return streamResult;
+ }
+
+ @Override
+ protected void sendSSTables(final SortedSSTableWriter sstableWriter)
+ {
+ try (DirectoryStream<Path> dataFileStream = Files.newDirectoryStream(sstableWriter.getOutDir(), "*Data.db"))
+ {
+ for (Path dataFile : dataFileStream)
+ {
+ int ssTableIdx = nextSSTableIdx.getAndIncrement();
+
+ LOGGER.info("[{}]: Pushing SSTable {} to replicas {}",
+ sessionID, dataFile,
+ replicas.stream().map(RingInstance::nodeName).collect(Collectors.joining(",")));
+ replicas.removeIf(replica -> !trySendSSTableToReplica(sstableWriter, dataFile, ssTableIdx, replica));
+ }
+
+ LOGGER.info("[{}]: Sent SSTables. sstables={}", sessionID, sstableWriter.sstableCount());
+ }
+ catch (IOException exception)
+ {
+ LOGGER.error("[{}]: Unexpected exception while streaming SSTables {}",
+ sessionID, sstableWriter.getOutDir());
+ cleanAllReplicas();
+ throw new RuntimeException(exception);
+ }
+ finally
+ {
+ // Clean up SSTable files once the task is complete
+ File tempDir = sstableWriter.getOutDir().toFile();
+ LOGGER.info("[{}]: Removing temporary files after stream session from {}", sessionID, tempDir);
+ try
+ {
+ FileUtils.deleteDirectory(tempDir);
+ }
+ catch (IOException exception)
+ {
+ LOGGER.warn("[{}]: Failed to delete temporary directory {}", sessionID, tempDir, exception);
+ }
+ }
+ }
+
+ private boolean trySendSSTableToReplica(SortedSSTableWriter sstableWriter,
+ Path dataFile,
+ int ssTableIdx,
+ RingInstance replica)
+ {
+ try
+ {
+ sendSSTableToReplica(dataFile, ssTableIdx, replica, sstableWriter.fileDigestMap());
+ return true;
+ }
+ catch (Exception exception)
+ {
+ LOGGER.error("[{}]: Failed to stream range {} to instance {}",
+ sessionID, tokenRange, replica.nodeName(), exception);
+ writerContext.cluster().refreshClusterInfo();
+ failureHandler.addFailure(this.tokenRange, replica, exception.getMessage());
+ errors.add(new StreamError(this.tokenRange, replica, exception.getMessage()));
+ clean(replica, sessionID);
+ return false;
+ }
+ }
+
+ private void sendSSTableToReplica(Path dataFile,
+ int ssTableIdx,
+ RingInstance instance,
+ Map<Path, Digest> fileHashes) throws IOException
+ {
+ try (DirectoryStream<Path> componentFileStream = Files.newDirectoryStream(dataFile.getParent(),
+ SSTables.getSSTableBaseName(dataFile) + "*"))
+ {
+ for (Path componentFile : componentFileStream)
+ {
+ // send the data component last
+ if (componentFile.getFileName().toString().endsWith("Data.db"))
+ {
+ continue;
+ }
+ sendSSTableComponent(componentFile, ssTableIdx, instance, fileHashes.get(componentFile));
+ }
+ sendSSTableComponent(dataFile, ssTableIdx, instance, fileHashes.get(dataFile));
+ }
+ }
+
+ private void sendSSTableComponent(Path componentFile,
+ int ssTableIdx,
+ RingInstance instance,
+ Digest digest) throws IOException
+ {
+ Preconditions.checkNotNull(digest, "All files must have a digest. SSTableWriter should have calculated these.");
+ LOGGER.info("[{}]: Uploading {} to {}: Size is {}",
+ sessionID, componentFile, instance.nodeName(), Files.size(componentFile));
+ directDataTransferApi.uploadSSTableComponent(componentFile, ssTableIdx, instance, this.sessionID, digest);
+ }
+
+ private List<CommitResult> commit(DirectStreamResult streamResult) throws ExecutionException, InterruptedException
+ {
+ try (CommitCoordinator cc = CommitCoordinator.commit(writerContext, transportContext, streamResult))
+ {
+ List<CommitResult> commitResults = cc.get();
+ LOGGER.debug("All CommitResults: {}", commitResults);
+ commitResults.forEach(cr -> BulkWriteValidator.updateFailureHandler(cr, WRITE_PHASE, failureHandler));
+ return commitResults;
+ }
+ }
+
+ /* Get all replicas and clean temporary state on them */
+ private void cleanAllReplicas()
+ {
+ Set<RingInstance> instances = new HashSet<>(replicas);
+ errors.forEach(streamError -> instances.add(streamError.instance));
+ instances.forEach(instance -> clean(instance, sessionID));
+ }
+
+ private void clean(RingInstance instance, String sessionID)
+ {
+ if (writerContext.job().getSkipClean())
+ {
+ LOGGER.info("Skip clean requested - not cleaning SSTable session {} on instance {}",
+ sessionID, instance.nodeName());
+ return;
+ }
+ String jobID = writerContext.job().getId();
+ LOGGER.info("Cleaning SSTable session {} on instance {}", sessionID, instance.nodeName());
+ try
+ {
+ directDataTransferApi.cleanUploadSession(instance, sessionID, jobID);
+ }
+ catch (Exception exception)
+ {
+ LOGGER.warn("Failed to clean SSTables on {} for session {} and ignoring errMsg",
+ instance.nodeName(), sessionID, exception);
+ }
+ }
+}
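For orientation, the lifecycle of a direct session in one sketch; all collaborators are assumed to be provided by the transport context and the record writer:

    // Illustrative only: rows are added first, then the session uploads the
    // produced SSTables to each replica, commits them, and validates CL.
    DirectStreamSession session = new DirectStreamSession(writerContext, sstableWriter,
                                                          transportContext, sessionID,
                                                          tokenRange, failureHandler);
    Future<StreamResult> result = session.scheduleStreamAsync(partitionId, executorService);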
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java
new file mode 100644
index 0000000..b27983c
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporter.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.io.Closeable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+
+public class HeartbeatReporter implements Closeable
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatReporter.class);
+
+ private final ScheduledExecutorService scheduler;
+ private final Map<String, ScheduledFuture<?>> scheduledHeartbeats;
+ private boolean isClosed;
+
+ public HeartbeatReporter()
+ {
+ ThreadFactory tf = ThreadUtil.threadFactory("Heartbeat reporter");
+ this.scheduler = Executors.newSingleThreadScheduledExecutor(tf);
+ this.scheduledHeartbeats = new HashMap<>();
+ this.isClosed = false;
+ }
+
+ public synchronized void schedule(String name, long heartBeatIntervalMillis, Runnable heartBeat)
+ {
+ if (isClosed)
+ {
+ LOGGER.info("HeartbeatReporter is already closed");
+ return;
+ }
+
+ if (scheduledHeartbeats.containsKey(name))
+ {
+ LOGGER.info("The heartbeat has been scheduled already. heartbeat={}", name);
+ return;
+ }
+ ScheduledFuture<?> fut = scheduler.scheduleWithFixedDelay(new NoThrow(name, heartBeat),
+ heartBeatIntervalMillis, // initial delay
+ heartBeatIntervalMillis, // delay
+ TimeUnit.MILLISECONDS);
+ scheduledHeartbeats.put(name, fut);
+ }
+
+ // returns true if unscheduled; false if unable to unschedule, typically because it was already unscheduled
+ @VisibleForTesting
+ public synchronized boolean unschedule(String name)
+ {
+ if (isClosed)
+ {
+ LOGGER.info("HeartbeatReporter is already closed");
+ return false;
+ }
+
+ ScheduledFuture<?> fut = scheduledHeartbeats.remove(name);
+ if (fut == null)
+ {
+ return false;
+ }
+ return fut.cancel(true);
+ }
+
+ /**
+ * Close the resources on a best-effort basis. The action is uninterruptible, but the interruption status is restored.
+ */
+ public synchronized void close()
+ {
+ isClosed = true;
+ scheduledHeartbeats.values().forEach(fut -> fut.cancel(true));
+ scheduler.shutdownNow();
+ try
+ {
+ boolean terminated = scheduler.awaitTermination(2, TimeUnit.SECONDS);
+ if (!terminated)
+ {
+ LOGGER.warn("Closing heartbeat reporter times out");
+ }
+ }
+ catch (InterruptedException ie)
+ {
+ Thread.currentThread().interrupt();
+ }
+ catch (Exception exception)
+ {
+ LOGGER.warn("Exception when closing scheduler", exception);
+ }
+ }
+
+ // A Runnable wrapper that does not throw exceptions, so the scheduler keeps executing it on schedule
+ private static class NoThrow implements Runnable
+ {
+ private final String name;
+ private final Runnable beat;
+
+ NoThrow(String name, Runnable beat)
+ {
+ this.beat = beat;
+ this.name = name;
+ }
+
+ @Override
+ public void run()
+ {
+ try
+ {
+ beat.run();
+ }
+ catch (Exception exception)
+ {
+ LOGGER.warn("{} failed to run", name, exception);
+ }
+ }
+ }
+}
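Usage sketch; the heartbeat name, interval, and action below are hypothetical:

    // Illustrative only: report a keep-alive every 30 seconds until closed.
    HeartbeatReporter reporter = new HeartbeatReporter();
    reporter.schedule("restore-job-keep-alive",
                      TimeUnit.SECONDS.toMillis(30),
                      () -> LOGGER.info("sending keep-alive")); // the action is a placeholder
    // ... job runs ...
    reporter.close(); // cancels scheduled heartbeats and shuts the scheduler down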
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java
new file mode 100644
index 0000000..94ea6d0
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinator.java
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Range;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobDataTransferApi;
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobStreamResult;
+import org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice;
+import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
+import org.apache.cassandra.spark.data.ReplicationFactor;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+
+import static org.apache.cassandra.clients.Sidecar.toSidecarInstance;
+import static org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice.ConsistencyLevelCheckResult.NOT_SATISFIED;
+import static org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice.ConsistencyLevelCheckResult.SATISFIED;
+
+public final class ImportCompletionCoordinator
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(ImportCompletionCoordinator.class);
+
+ private final long startTimeNanos;
+ private final BlobDataTransferApi dataTransferApi;
+ private final BulkWriteValidator writeValidator;
+ private final List<BlobStreamResult> blobStreamResultList;
+ private final JobInfo job;
+ private final ScheduledExecutorService scheduler;
+ private final CassandraTopologyMonitor cassandraTopologyMonitor;
+ private final ReplicationFactor replicationFactor;
+ private final StorageTransportExtension extension;
+ private final CompletableFuture<Void> firstFailure = new CompletableFuture<>();
+ private final CompletableFuture<Void> terminal = new CompletableFuture<>();
+ private final Map<CompletableFuture<Void>, RequestAndInstance> importFutures = new HashMap<>();
+ private final AtomicBoolean terminalScheduled = new AtomicBoolean(false);
+ private final AtomicInteger completedSlices = new AtomicInteger(0);
+
+ private long waitStartNanos;
+ private long minSliceSize = Long.MAX_VALUE;
+ private long maxSliceSize = Long.MIN_VALUE;
+ private int totalSlices;
+ private AtomicInteger satisfiedSlices;
+
+ private ImportCompletionCoordinator(long startTimeNanos,
+ BulkWriterContext writerContext,
+ BlobDataTransferApi dataTransferApi,
+ BulkWriteValidator writeValidator,
+ List<BlobStreamResult> blobStreamResultList,
+ StorageTransportExtension extension,
+ Consumer<CancelJobEvent> onCancelJob)
+ {
+ this(startTimeNanos, writerContext, dataTransferApi, writeValidator, blobStreamResultList, extension, onCancelJob, CassandraTopologyMonitor::new);
+ }
+
+ @VisibleForTesting
+ ImportCompletionCoordinator(long startTimeNanos,
+ BulkWriterContext writerContext,
+ BlobDataTransferApi dataTransferApi,
+ BulkWriteValidator writeValidator,
+ List<BlobStreamResult> blobStreamResultList,
+ StorageTransportExtension extension,
+ Consumer<CancelJobEvent> onCancelJob,
+ BiFunction<ClusterInfo, Consumer<CancelJobEvent>, CassandraTopologyMonitor> monitorCreator)
+ {
+ this.startTimeNanos = startTimeNanos;
+ this.job = writerContext.job();
+ this.dataTransferApi = dataTransferApi;
+ this.writeValidator = writeValidator;
+ this.blobStreamResultList = blobStreamResultList;
+ this.extension = extension;
+ ThreadFactory tf = ThreadUtil.threadFactory("Import completion timeout");
+ this.scheduler = Executors.newSingleThreadScheduledExecutor(tf);
+ Consumer<CancelJobEvent> wrapped = cancelJobEvent -> {
+ // try to complete the firstFailure, in order to exit coordinator ASAP
+ firstFailure.completeExceptionally(new RuntimeException(cancelJobEvent.reason, cancelJobEvent.exception));
+ onCancelJob.accept(cancelJobEvent);
+ };
+ this.cassandraTopologyMonitor = monitorCreator.apply(writerContext.cluster(), wrapped);
+ this.replicationFactor = cassandraTopologyMonitor.initialTopology().replicationFactor();
+ }
+
+
+ public static ImportCompletionCoordinator of(long startTimeNanos,
+ BulkWriterContext writerContext,
+ BlobDataTransferApi dataTransferApi,
+ BulkWriteValidator writeValidator,
+ List<BlobStreamResult> resultsAsBlobStreamResults,
+ StorageTransportExtension extension,
+ Consumer<CancelJobEvent> onCancelJob)
+ {
+ return new ImportCompletionCoordinator(startTimeNanos,
+ writerContext, dataTransferApi,
+ writeValidator, resultsAsBlobStreamResults,
+ extension, onCancelJob);
+ }
+
+ /**
+ * Blocks until the imports complete by invoking the CreateRestoreJobSlice call on the server.
+ * The method passes when the successful imports can satisfy the configured consistency level;
+ * otherwise, the method fails.
+ * The wait is indefinite until one of the following conditions is met:
+ * 1) _all_ slices have been checked, or
+ * 2) the Spark job reaches its completion timeout, or
+ * 3) at least one slice fails CL validation, as the job will eventually fail in this case;
+ *    this means that some slices may never be processed by this loop.
+ * <p>
+ * When a slice fails CL validation and there are remaining slices to check, the wait continues.
+ */
+ public void waitForCompletion()
+ {
+ writeValidator.setPhase("WaitForCommitCompletion");
+
+ try
+ {
+ waitForCompletionInternal();
+ }
+ finally
+ {
+ if (terminal.isDone())
+ {
+ LOGGER.info("Concluded the safe termination, given the specified consistency level is satisfied " +
+ "and enough time has been blocked for importing slices.");
+ }
+ cassandraTopologyMonitor.shutdownNow();
+ importFutures.keySet().forEach(f -> f.cancel(true));
+ terminal.complete(null);
+ scheduler.shutdownNow(); // shutdown and do not wait for the termination; the job is completing
+ }
+ }
+
+ private void waitForCompletionInternal()
+ {
+ prepareToPoll();
+
+ startPolling();
+
+ await();
+ }
+
+ private void prepareToPoll()
+ {
+ totalSlices = blobStreamResultList.stream().mapToInt(res -> res.createdRestoreSlices.size()).sum();
+ blobStreamResultList
+ .stream()
+ .flatMap(res -> res.createdRestoreSlices
+ .stream()
+ .map(CreatedRestoreSlice::sliceRequestPayload))
+ .mapToLong(slice -> {
+ // an individual task should never return a slice with a 0-size bundle
+ long size = slice.compressedSizeOrZero();
+ if (size == 0)
+ {
+ throw new IllegalStateException("Found invalid slice with 0 compressed size. " +
+ "slice: " + slice);
+ }
+ return size;
+ })
+ .forEach(size -> {
+ minSliceSize = Math.min(minSliceSize, size);
+ maxSliceSize = Math.max(maxSliceSize, size);
+ });
+ satisfiedSlices = new AtomicInteger(0);
+ waitStartNanos = System.nanoTime();
+ }
+
+ private void startPolling()
+ {
+ for (BlobStreamResult blobStreamResult : blobStreamResultList)
+ {
+ for (CreatedRestoreSlice createdRestoreSlice : blobStreamResult.createdRestoreSlices)
+ {
+ for (RingInstance instance : blobStreamResult.passed)
+ {
+ createSliceInstanceFuture(createdRestoreSlice, instance);
+ }
+ }
+ }
+ }
+
+ private void addCompletionMonitor(CompletableFuture<?> future)
+ {
+ // The whenComplete callback is still invoked when the future is cancelled.
+ // In such a case, expect a CancellationException
+ future.whenComplete((v, t) -> {
+ LOGGER.info("Completed slice requests {}/{}", completedSlices.incrementAndGet(), importFutures.keySet().size());
+
+ if (t instanceof CancellationException)
+ {
+ RequestAndInstance rai = importFutures.get(future);
+ LOGGER.info("Cancelled import. instance={} slice={}", rai.nodeFqdn, rai.requestPayload);
+ return;
+ }
+
+ // only enter the block once
+ if (satisfiedSlices.get() == totalSlices
+ && terminalScheduled.compareAndSet(false, true))
+ {
+ long timeToAllSatisfiedNanos = System.nanoTime() - waitStartNanos;
+ long timeout = estimateTimeout(timeToAllSatisfiedNanos);
+ LOGGER.info("The specified consistency level of the job has been satisfied. " +
+ "Continuing to waiting on slices completion in order to prevent Cassandra side " +
+ "streaming as much as possible. The estimated additional wait time is {} seconds.",
+ TimeUnit.NANOSECONDS.toSeconds(timeout));
+ // schedule to complete the terminal
+ scheduler.schedule(() -> terminal.complete(null),
+ timeout, TimeUnit.NANOSECONDS);
+ }
+ });
+ }
+
+ private void await()
+ {
+ // The wait either fails early once the firstFailure future completes exceptionally, reaches the timeout (while CL is satisfied),
+ // or completes when all import futures finish
+ CompletableFuture.anyOf(firstFailure, terminal,
+ CompletableFuture.allOf(importFutures.keySet().toArray(new CompletableFuture[0])))
+ .join();
+ // double check to make sure all slices are satisfied
+ // Because at this point all ranges have been either satisfied or the job has already failed,
+ // this is really just a sanity check for things like lost futures/future-introduced bugs
+ validateAllRangesAreSatisfied();
+ }
+
+ // calculate the timeout based on 1) the time taken to have all slices satisfied, and 2) the estimated import rate
+ private long estimateTimeout(long timeToAllSatisfiedNanos)
+ {
+ long timeout = timeToAllSatisfiedNanos;
+ // use the minSliceSize to get the slowest import rate. R = minSliceSize / T
+ // use the maxSliceSize to get the highest amount of time needed for import. D = maxSliceSize / R
+ // Please do not combine the two statements below, for readability purposes
+ double estimatedRateFloor = ((double) minSliceSize) / timeToAllSatisfiedNanos;
+ double timeEstimateBasedOnRate = ((double) maxSliceSize) / estimatedRateFloor;
+ timeout = Math.max((long) timeEstimateBasedOnRate, timeout);
+ timeout = job.importCoordinatorTimeoutMultiplier() * timeout;
+ if (TimeUnit.NANOSECONDS.toHours(timeout) > 1)
+ {
+ LOGGER.warn("The estimated additional timeout is more than 1 hour. timeout={} seconds",
+ TimeUnit.NANOSECONDS.toSeconds(timeout));
+ }
+ return timeout;
+ }
+
+ private void createSliceInstanceFuture(CreatedRestoreSlice createdRestoreSlice,
+ RingInstance instance)
+ {
+ if (firstFailure.isCompletedExceptionally())
+ {
+ LOGGER.warn("The job has failed already. Skip sending import request. instance={} slice={}",
+ instance.nodeName(), createdRestoreSlice.sliceRequestPayload());
+ return;
+ }
+ SidecarInstance sidecarInstance = toSidecarInstance(instance, job.effectiveSidecarPort());
+ CreateSliceRequestPayload createSliceRequestPayload = createdRestoreSlice.sliceRequestPayload();
+ CompletableFuture<Void> fut = dataTransferApi.createRestoreSliceFromDriver(sidecarInstance,
+ createSliceRequestPayload);
+ fut = fut.handleAsync((ignored, throwable) -> {
+ if (throwable == null)
+ {
+ handleSuccessfulSliceInstance(createdRestoreSlice, instance, createSliceRequestPayload);
+ }
+ else
+ {
+ // use handle API to swallow the throwable on purpose; the throwable is set to `firstFailure`
+ handleFailedSliceInstance(instance, createSliceRequestPayload, throwable);
+ }
+ return null;
+ });
+ addCompletionMonitor(fut);
+ // Use the fut variable (instead of the new future object returned by whenComplete) as the key on purpose,
+ // so that the whenComplete callback can receive the CancellationException
+ importFutures.put(fut, new RequestAndInstance(createSliceRequestPayload, instance.nodeName()));
+ }
+
+ private void handleFailedSliceInstance(RingInstance instance,
+ CreateSliceRequestPayload createSliceRequestPayload,
+ Throwable throwable)
+ {
+ LOGGER.warn("Import failed. instance={} slice={}", instance.nodeName(), createSliceRequestPayload, throwable);
+
+ Range<BigInteger> range = Range.openClosed(createSliceRequestPayload.startToken(),
+ createSliceRequestPayload.endToken());
+ writeValidator.updateFailureHandler(range, instance, "Failed to import slice. " + throwable.getMessage());
+ // it either passes or throws if the consistency level cannot be satisfied
+ try
+ {
+ writeValidator.validateClOrFail(cassandraTopologyMonitor.initialTopology());
+ }
+ catch (RuntimeException rte)
+ {
+ // record the first failure and cancel queued futures.
+ firstFailure.completeExceptionally(rte);
+ }
+ }
+
+ private void handleSuccessfulSliceInstance(CreatedRestoreSlice createdRestoreSlice,
+ RingInstance instance,
+ CreateSliceRequestPayload createSliceRequestPayload)
+ {
+ LOGGER.info("Import succeeded. instance={} slice={}", instance.nodeName(), createSliceRequestPayload);
+ createdRestoreSlice.addSucceededInstance(instance);
+ if (SATISFIED ==
+ createdRestoreSlice.checkForConsistencyLevel(job.getConsistencyLevel(),
+ replicationFactor,
+ job.getLocalDC()))
+ {
+ satisfiedSlices.incrementAndGet();
+ try
+ {
+ extension.onObjectApplied(createSliceRequestPayload.bucket(),
+ createSliceRequestPayload.key(),
+ createSliceRequestPayload.compressedSizeOrZero(),
+ System.nanoTime() - startTimeNanos);
+ }
+ catch (Throwable t)
+ {
+ // log a warning message and carry on
+ LOGGER.warn("StorageTransportExtension fails to process ObjectApplied notification", t);
+ }
+ }
+ }
+
+ /**
+ * Validate that all ranges have collected enough write acknowledgements to satisfy the consistency level.
+ * It throws when any range lacks enough write acknowledgements.
+ */
+ private void validateAllRangesAreSatisfied()
+ {
+ List<CreatedRestoreSlice> unsatisfiedSlices = new ArrayList<>();
+ for (BlobStreamResult blobStreamResult : blobStreamResultList)
+ {
+ for (CreatedRestoreSlice createdRestoreSlice : blobStreamResult.createdRestoreSlices)
+ {
+ if (NOT_SATISFIED == createdRestoreSlice.checkForConsistencyLevel(job.getConsistencyLevel(),
+ replicationFactor,
+ job.getLocalDC()))
+ {
+ unsatisfiedSlices.add(createdRestoreSlice);
+ }
+ }
+ }
+ if (unsatisfiedSlices.isEmpty())
+ {
+ LOGGER.info("All token ranges have satisfied with consistency level. consistencyLevel={} phase={}",
+ job.getConsistencyLevel(), writeValidator.getPhase());
+ }
+ else
+ {
+ String message = String.format("Some of the token ranges cannot satisfy the consistency level. " +
+ "job=%s phase=%s consistencyLevel=%s ranges=%s",
+ job.getRestoreJobId(), writeValidator.getPhase(), job.getConsistencyLevel(), unsatisfiedSlices);
+ LOGGER.error(message);
+ throw new RuntimeException(message);
+ }
+ }
+
+ @VisibleForTesting
+ Map<CompletableFuture<Void>, RequestAndInstance> importFutures()
+ {
+ return importFutures;
+ }
+
+ @VisibleForTesting
+ CompletableFuture<Void> firstFailure()
+ {
+ return firstFailure;
+ }
+
+ // simple data class to group the request and the node fqdn
+ static class RequestAndInstance
+ {
+ final String nodeFqdn;
+ final CreateSliceRequestPayload requestPayload;
+
+ RequestAndInstance(CreateSliceRequestPayload requestPayload, String nodeFqdn)
+ {
+ this.nodeFqdn = nodeFqdn;
+ this.requestPayload = requestPayload;
+ }
+ }
+}
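To make the estimateTimeout arithmetic concrete, a hypothetical worked example (the numbers are invented for illustration):

    // Suppose all slices are satisfied after T = 200 s, minSliceSize = 100 MiB,
    // maxSliceSize = 1024 MiB, and importCoordinatorTimeoutMultiplier = 2.
    // rate floor     R = 100 MiB / 200 s = 0.5 MiB/s
    // time estimate  D = 1024 MiB / 0.5 MiB/s = 2048 s
    // timeout = max(D, T) * multiplier = max(2048, 200) * 2 = 4096 s of additional wait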
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/JobInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/JobInfo.java
index 10635b1..c93a47d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/JobInfo.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/JobInfo.java
@@ -23,10 +23,12 @@
import java.util.UUID;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
+import org.apache.cassandra.spark.data.QualifiedTableName;
import org.jetbrains.annotations.NotNull;
public interface JobInfo extends Serializable
{
+ // ******************
// Job Information API - should this really just move back to Config? Here to try to reduce the violations of the Law of Demeter more than anything else
ConsistencyLevel getConsistencyLevel();
@@ -41,24 +43,30 @@
int getCommitThreadsPerInstance();
- UUID getId();
+ /**
+ * Return the identifier of the restore job created on the Cassandra Sidecar
+ * @return time-based uuid
+ */
+ UUID getRestoreJobId();
+
+ /**
+ * An optional unique identifier supplied in the Spark configuration
+ * @return an id string or null
+ */
+ String getConfiguredJobId();
+
+ // Convenience method to decide the unique identifier used for the job.
+ // It prefers the configuredJobId if present; otherwise, it falls back to the restoreJobId
+ default String getId()
+ {
+ String configuredJobId = getConfiguredJobId();
+ return configuredJobId == null ? getRestoreJobId().toString() : configuredJobId;
+ }
TokenPartitioner getTokenPartitioner();
boolean skipExtendedVerify();
- boolean quoteIdentifiers();
-
- String keyspace();
-
- String tableName();
-
- @NotNull
- default String getFullTableName()
- {
- return keyspace() + "." + tableName();
- }
-
boolean getSkipClean();
/**
@@ -66,4 +74,23 @@
*/
@NotNull
DigestAlgorithmSupplier digestAlgorithmSupplier();
+
+ QualifiedTableName qualifiedTableName();
+
+ DataTransportInfo transportInfo();
+
+ /**
+ * @return job keep alive time in minutes
+ */
+ int jobKeepAliveMinutes();
+
+ /**
+ * @return sidecar service port
+ */
+ int effectiveSidecarPort();
+
+ /**
+ * @return multiplier to calculate the final timeout for import coordinator
+ */
+ int importCoordinatorTimeoutMultiplier();
}
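The getId() fallback semantics in brief (the ids below are hypothetical):

    // getConfiguredJobId() == null          -> getId() == getRestoreJobId().toString()
    // getConfiguredJobId() == "nightly-etl" -> getId() == "nightly-etl"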
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
index b6fea5a..44af5c5 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/RecordWriter.java
@@ -20,11 +20,9 @@
package org.apache.cassandra.spark.bulkwriter;
import java.io.IOException;
-import java.io.Serializable;
import java.math.BigInteger;
import java.nio.file.Files;
import java.nio.file.Path;
-import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@@ -36,9 +34,10 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -53,41 +52,42 @@
import o.a.c.sidecar.client.shaded.common.data.TimeSkewResponse;
import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+import org.apache.cassandra.spark.bulkwriter.util.TaskContextUtils;
+import org.apache.cassandra.spark.bulkwriter.util.ThreadUtil;
import org.apache.cassandra.spark.data.BridgeUdtValue;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.ReplicationFactor;
import org.apache.cassandra.spark.utils.DigestAlgorithm;
-import org.apache.spark.InterruptibleIterator;
import org.apache.spark.TaskContext;
+import org.jetbrains.annotations.NotNull;
import scala.Tuple2;
-import static org.apache.cassandra.spark.utils.ScalaConversionUtils.asScalaIterator;
-
@SuppressWarnings({ "ConstantConditions" })
-public class RecordWriter implements Serializable
+public class RecordWriter
{
public static final ReplicationFactor IGNORED_REPLICATION_FACTOR = new ReplicationFactor(ReplicationFactor.ReplicationStrategy.SimpleStrategy,
ImmutableMap.of("replication_factor", 1));
private static final Logger LOGGER = LoggerFactory.getLogger(RecordWriter.class);
- private static final long serialVersionUID = 3746578054834640428L;
+
private final BulkWriterContext writerContext;
private final String[] columnNames;
private final SSTableWriterFactory tableWriterFactory;
private final DigestAlgorithm digestAlgorithm;
-
private final BulkWriteValidator writeValidator;
private final ReplicaAwareFailureHandler<RingInstance> failureHandler;
-
private final Supplier<TaskContext> taskContextSupplier;
private final ConcurrentHashMap<String, CqlField.CqlUdt> udtCache = new ConcurrentHashMap<>();
- private SSTableWriter sstableWriter = null;
- private int outputSequence = 0; // sub-folder for possible subrange splits
- private transient volatile CqlTable cqlTable;
+ private final Map<String, Future<StreamResult>> streamFutures;
+ private final ExecutorService executorService;
+ private final Path baseDir;
+
+ private volatile CqlTable cqlTable;
+ private StreamSession<?> streamSession = null;
public RecordWriter(BulkWriterContext writerContext, String[] columnNames)
{
- this(writerContext, columnNames, TaskContext::get, SSTableWriter::new);
+ this(writerContext, columnNames, TaskContext::get, SortedSSTableWriter::new);
}
@VisibleForTesting
@@ -103,27 +103,22 @@
this.failureHandler = new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner());
this.writeValidator = new BulkWriteValidator(writerContext, failureHandler);
this.digestAlgorithm = this.writerContext.job().digestAlgorithmSupplier().get();
+ this.streamFutures = new HashMap<>();
+ this.executorService = Executors.newSingleThreadExecutor(ThreadUtil.threadFactory("RecordWriter-worker"));
+ this.baseDir = TaskContextUtils.getPartitionUniquePath(System.getProperty("java.io.tmpdir"),
+ writerContext.job().getId(),
+ taskContextSupplier.get());
writerContext.cluster().startupValidate();
}
- private Range<BigInteger> getTokenRange(TaskContext taskContext)
- {
- return writerContext.job().getTokenPartitioner().getTokenRange(taskContext.partitionId());
- }
-
- private String getStreamId(TaskContext taskContext)
- {
- return String.format("%d-%s", taskContext.partitionId(), UUID.randomUUID());
- }
-
private CqlTable cqlTable()
{
if (cqlTable == null)
{
cqlTable = writerContext.bridge()
.buildSchema(writerContext.schema().getTableSchema().createStatement,
- writerContext.job().keyspace(),
+ writerContext.job().qualifiedTableName().keyspace(),
IGNORED_REPLICATION_FACTOR,
writerContext.cluster().getPartitioner(),
writerContext.schema().getUserDefinedTypeStatements());
@@ -132,6 +127,11 @@
return cqlTable;
}
+ /**
+ * Write data into the stream
+ * @param sourceIterator source data
+ * @return write result
+ */
public WriteResult write(Iterator<Tuple2<DecoratedKey, Object[]>> sourceIterator)
{
TaskContext taskContext = taskContextSupplier.get();
@@ -148,7 +148,7 @@
LOGGER.info("[{}]: Fetched token range mapping for keyspace: {} with write replicas: {} containing pending " +
"replicas: {}, blocked instances: {}, replacement instances: {}",
taskContext.partitionId(),
- writerContext.job().keyspace(),
+ writerContext.job().qualifiedTableName().keyspace(),
initialTokenRangeMapping.getWriteReplicas().size(),
initialTokenRangeMapping.getPendingReplicas().size(),
initialTokenRangeMapping.getBlockedInstances().size(),
@@ -156,7 +156,6 @@
Map<Range<BigInteger>, List<RingInstance>> initialTokenRangeInstances =
taskTokenRangeMapping(initialTokenRangeMapping, taskTokenRange);
- List<StreamResult> results = new ArrayList<>();
writeValidator.setPhase("Environment Validation");
writeValidator.validateClOrFail(initialTokenRangeMapping);
@@ -165,17 +164,11 @@
// for all replicas in this partition
validateAcceptableTimeSkewOrThrow(new ArrayList<>(instancesFromMapping(initialTokenRangeInstances)));
- scala.collection.Iterator<scala.Tuple2<DecoratedKey, Object[]>> dataIterator =
- new InterruptibleIterator<>(taskContext, asScalaIterator(sourceIterator));
- StreamSession streamSession = null;
+ Iterator<Tuple2<DecoratedKey, Object[]>> dataIterator = new JavaInterruptibleIterator<>(taskContext, sourceIterator);
int partitionId = taskContext.partitionId();
JobInfo job = writerContext.job();
- Path baseDir = Paths.get(System.getProperty("java.io.tmpdir"),
- job.getId().toString(),
- Integer.toString(taskContext.stageAttemptNumber()),
- Integer.toString(taskContext.attemptNumber()),
- Integer.toString(partitionId));
Map<String, Object> valueMap = new HashMap<>();
+
try
{
// preserve the order of ranges
@@ -205,41 +198,61 @@
}
currentRange = subRanges.get(currentRangeIndex);
}
- streamSession = maybeCreateStreamSession(taskContext, streamSession, currentRange, failureHandler, results);
- maybeCreateTableWriter(partitionId, baseDir);
+ maybeCreateStreamSession(taskContext, currentRange);
writeRow(rowData, valueMap, partitionId, streamSession.getTokenRange());
}
// Finalize SSTable for the last StreamSession
- if (sstableWriter != null)
+ if (streamSession != null)
{
- finalizeSSTable(streamSession, partitionId);
- results.add(streamSession.close());
+ flushAsync(partitionId);
}
- LOGGER.info("[{}] Done with all writers and waiting for stream to complete", partitionId);
- // When instances for the partition's token range have changed within the scope of the task execution,
- // we fail the task for it to be retried
- validateTaskTokenRangeMappings(partitionId, initialTokenRangeMapping, taskTokenRange);
+ List<StreamResult> results = waitForStreamCompletionAndValidate(partitionId, initialTokenRangeMapping, taskTokenRange);
return new WriteResult(results, isClusterBeingResized);
}
catch (Exception exception)
{
LOGGER.error("[{}] Failed to write job={}, taskStageAttemptNumber={}, taskAttemptNumber={}",
partitionId,
- job.getId().toString(),
+ job.getId(),
taskContext.stageAttemptNumber(),
taskContext.attemptNumber());
+
+ if (exception instanceof InterruptedException)
+ {
+ Thread.currentThread().interrupt();
+ }
throw new RuntimeException(exception);
}
}
- public static <T> Set<T> symmetricDifference(Set<T> set1, Set<T> set2)
+ @NotNull
+ private List<StreamResult> waitForStreamCompletionAndValidate(int partitionId,
+ TokenRangeMapping<RingInstance> initialTokenRangeMapping,
+ Range<BigInteger> taskTokenRange)
{
- return Stream.concat(
- set1.stream().filter(element -> !set2.contains(element)),
- set2.stream().filter(element -> !set1.contains(element)))
- .collect(Collectors.toSet());
+ List<StreamResult> results = streamFutures.values().stream().map(f -> {
+ try
+ {
+ return f.get();
+ }
+ catch (Exception e)
+ {
+ if (e instanceof InterruptedException)
+ {
+ Thread.currentThread().interrupt();
+ }
+ throw new RuntimeException(e);
+ }
+ }).collect(Collectors.toList());
+
+ LOGGER.info("[{}] Done with all writers and waiting for stream to complete", partitionId);
+
+ // When instances for the partition's token range have changed within the scope of the task execution,
+ // we fail the task for it to be retried
+ validateTaskTokenRangeMappings(partitionId, initialTokenRangeMapping, taskTokenRange);
+ return results;
}
private Map<Range<BigInteger>, List<RingInstance>> taskTokenRangeMapping(TokenRangeMapping<RingInstance> tokenRange,
@@ -261,17 +274,16 @@
* If we do find the need to split a range into sub-ranges, we create the corresponding session for the sub-range
* if the token from the row data belongs to the range.
*/
- private StreamSession maybeCreateStreamSession(TaskContext taskContext,
- StreamSession streamSession,
- Range<BigInteger> currentRange,
- ReplicaAwareFailureHandler<RingInstance> failureHandler,
- List<StreamResult> results)
- throws IOException, ExecutionException, InterruptedException
+ private void maybeCreateStreamSession(TaskContext taskContext,
+ Range<BigInteger> currentRange) throws IOException
{
- streamSession = maybeCreateSubRangeSession(taskContext, streamSession, failureHandler, results, currentRange);
+ maybeCreateSubRangeSession(taskContext, currentRange);
// If we do not have any stream session at this point, we create a session using the partition's token range
- return (streamSession == null) ? createStreamSession(taskContext) : streamSession;
+ if (streamSession == null)
+ {
+ createStreamSessionWithAssignedRange(taskContext);
+ }
}
/**
@@ -279,26 +291,41 @@
* 1) we do not have an existing stream session, or 2) the existing stream session corresponds to a range that
* does NOT match the sub-range the token belongs to.
*/
- private StreamSession maybeCreateSubRangeSession(TaskContext taskContext,
- StreamSession streamSession,
- ReplicaAwareFailureHandler<RingInstance> failureHandler,
- List<StreamResult> results,
- Range<BigInteger> matchingSubRange)
- throws IOException, ExecutionException, InterruptedException
+ private void maybeCreateSubRangeSession(TaskContext taskContext,
+ Range<BigInteger> matchingSubRange) throws IOException
{
- if (streamSession == null || streamSession.getTokenRange() != matchingSubRange)
+ if (streamSession != null && streamSession.getTokenRange().equals(matchingSubRange))
{
- LOGGER.debug("[{}] Creating stream session for range: {}", taskContext.partitionId(), matchingSubRange);
- // Schedule data to be sent if we are processing a batch that has not been scheduled yet.
- if (streamSession != null)
- {
- // Complete existing writes (if any) before the existing stream session is closed
- finalizeSSTable(streamSession, taskContext.partitionId());
- results.add(streamSession.close());
- }
- streamSession = new StreamSession(writerContext, getStreamId(taskContext), matchingSubRange, failureHandler);
+ return;
}
- return streamSession;
+
+ // Schedule data to be sent if we are processing a batch that has not been scheduled yet.
+ if (streamSession != null)
+ {
+ // Complete existing writes (if any) before the existing stream session is closed
+ flushAsync(taskContext.partitionId());
+ }
+
+ streamSession = createStreamSession(taskContext, matchingSubRange);
+ }
+
+ private void createStreamSessionWithAssignedRange(TaskContext taskContext) throws IOException
+ {
+ createStreamSession(taskContext, getTokenRange(taskContext));
+ }
+
+ private StreamSession<?> createStreamSession(TaskContext taskContext, Range<BigInteger> range) throws IOException
+ {
+ LOGGER.info("[{}] Creating new stream session. range={}", taskContext.partitionId(), range);
+
+ String sessionId = TaskContextUtils.createStreamSessionId(taskContext);
+ Path perSessionDirectory = baseDir.resolve(sessionId);
+ Files.createDirectories(perSessionDirectory);
+ SortedSSTableWriter sstableWriter = tableWriterFactory.create(writerContext, perSessionDirectory, digestAlgorithm);
+ LOGGER.info("[{}][{}] Created new SSTable writer with directory={}",
+ taskContext.partitionId(), sessionId, perSessionDirectory);
+ return writerContext.transportContext()
+ .createStreamSession(writerContext, sessionId, sstableWriter, range, failureHandler);
}
/**
@@ -341,6 +368,19 @@
}
}
+ static <T> Set<T> symmetricDifference(Set<T> set1, Set<T> set2)
+ {
+ return Stream.concat(
+ set1.stream().filter(element -> !set2.contains(element)),
+ set2.stream().filter(element -> !set1.contains(element)))
+ .collect(Collectors.toSet());
+ }
+
+ private Range<BigInteger> getTokenRange(TaskContext taskContext)
+ {
+ return writerContext.job().getTokenPartitioner().getTokenRange(taskContext.partitionId());
+ }
+
private void validateAcceptableTimeSkewOrThrow(List<RingInstance> replicas)
{
if (replicas.isEmpty())
@@ -361,18 +401,18 @@
}
}
- private void writeRow(Tuple2<DecoratedKey, Object[]> rowData,
+ private void writeRow(Tuple2<DecoratedKey, Object[]> keyAndRowData,
Map<String, Object> valueMap,
int partitionId,
Range<BigInteger> range) throws IOException
{
- DecoratedKey key = rowData._1();
+ DecoratedKey key = keyAndRowData._1();
BigInteger token = key.getToken();
Preconditions.checkState(range.contains(token),
String.format("Received Token %s outside of expected range %s", token, range));
try
{
- sstableWriter.addRow(token, getBindValuesForColumns(valueMap, columnNames, rowData._2()));
+ streamSession.addRow(token, getBindValuesForColumns(valueMap, columnNames, keyAndRowData._2()));
}
catch (RuntimeException exception)
{
@@ -383,21 +423,10 @@
}
}
- private void maybeCreateTableWriter(int partitionId, Path baseDir) throws IOException
- {
- if (sstableWriter == null)
- {
- Path outDir = Paths.get(baseDir.toString(), Integer.toString(outputSequence++));
- Files.createDirectories(outDir);
-
- sstableWriter = tableWriterFactory.create(writerContext, outDir, digestAlgorithm);
- LOGGER.info("[{}] Created new SSTable writer", partitionId);
- }
- }
-
private Map<String, Object> getBindValuesForColumns(Map<String, Object> map, String[] columnNames, Object[] values)
{
- assert values.length == columnNames.length : "Number of values does not match the number of columns " + values.length + ", " + columnNames.length;
+ Preconditions.checkArgument(values.length == columnNames.length,
+ "Number of values does not match the number of columns " + values.length + ", " + columnNames.length);
for (int i = 0; i < columnNames.length; i++)
{
map.put(columnNames[i], maybeConvertUdt(values[i]));
@@ -438,45 +467,66 @@
}
/**
- * Close the {@link RecordWriter#sstableWriter} if present. Schedule a stream session with the produced sstables.
- * And finally, nullify {@link RecordWriter#sstableWriter}
+ * Flushes the written rows and schedules a stream session with the produced sstables asynchronously.
+ * Finally, nullifies {@link RecordWriter#streamSession}.
+ *
+ * @param partitionId partition id
+ * @throws IOException I/O exceptions during flush
*/
- private void finalizeSSTable(StreamSession streamSession, int partitionId) throws IOException
+ private void flushAsync(int partitionId) throws IOException
{
- if (sstableWriter == null)
- {
- LOGGER.warn("SSTableWriter is null. Nothing to finalize");
- return;
- }
- LOGGER.info("[{}] Closing writer and scheduling SStable stream", partitionId);
- sstableWriter.close(writerContext, partitionId);
- streamSession.scheduleStream(sstableWriter);
- sstableWriter = null;
- }
-
- private StreamSession createStreamSession(TaskContext taskContext)
- {
- Range<BigInteger> tokenRange = getTokenRange(taskContext);
- LOGGER.info("[{}] Creating stream session for range={}", taskContext.partitionId(), tokenRange);
- return new StreamSession(writerContext, getStreamId(taskContext), tokenRange, failureHandler);
+ Preconditions.checkState(streamSession != null);
+ LOGGER.info("[{}][{}] Closing writer and scheduling SStable stream with {} rows",
+ partitionId, streamSession.sessionID, streamSession.rowCount());
+ Future<StreamResult> future = streamSession.scheduleStreamAsync(partitionId, executorService);
+ streamFutures.put(streamSession.sessionID, future);
+ streamSession = null;
}
/**
- * Functional interface that helps with creating {@link SSTableWriter} instances.
+ * Functional interface that helps with creating {@link SortedSSTableWriter} instances.
*/
public interface SSTableWriterFactory
{
/**
- * Creates a new instance of the {@link SSTableWriter} with the provided {@code writerContext},
+ * Creates a new instance of the {@link SortedSSTableWriter} with the provided {@code writerContext},
* {@code outDir}, and {@code digestProvider} parameters.
*
- * @param writerContext the context for the bulk writer job
- * @param outDir an output directory where SSTables components will be written to
+ * @param writerContext the context for the bulk writer job
+ * @param outDir an output directory where SSTables components will be written to
* @param digestAlgorithm a digest provider to calculate digests for every SSTable component
- * @return a new {@link SSTableWriter}
+ * @return a new {@link SortedSSTableWriter}
*/
- SSTableWriter create(BulkWriterContext writerContext,
- Path outDir,
- DigestAlgorithm digestAlgorithm);
+ SortedSSTableWriter create(BulkWriterContext writerContext,
+ Path outDir,
+ DigestAlgorithm digestAlgorithm);
+ }
+
+ // The Java version of org.apache.spark.InterruptibleIterator.
+ // An iterator that wraps around an existing iterator to provide task killing functionality.
+ // It works by checking the interrupted flag in TaskContext.
+ private static class JavaInterruptibleIterator<T> implements Iterator<T>
+ {
+ private final TaskContext taskContext;
+ private final Iterator<T> delegate;
+
+ JavaInterruptibleIterator(TaskContext taskContext, Iterator<T> delegate)
+ {
+ this.taskContext = taskContext;
+ this.delegate = delegate;
+ }
+
+ @Override
+ public boolean hasNext()
+ {
+ taskContext.killTaskIfInterrupted();
+ return delegate.hasNext();
+ }
+
+ @Override
+ public T next()
+ {
+ return delegate.next();
+ }
}
}
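Conceptually, the new JavaInterruptibleIterator behaves like the sketch below (the class itself is private to RecordWriter; this sketch just restates its contract over an assumed wrapped iterator):

    // hasNext() first checks whether the Spark task was killed, then delegates;
    // killTaskIfInterrupted() aborts the task if its interrupted flag is set.
    while (iterator.hasNext())
    {
        Tuple2<DecoratedKey, Object[]> row = iterator.next();
        // ... write the row into the current stream session ...
    }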
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SidecarDataTransferApi.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SidecarDataTransferApi.java
index 8710f88..189fe10 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SidecarDataTransferApi.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SidecarDataTransferApi.java
@@ -35,32 +35,30 @@
import org.apache.cassandra.spark.common.Digest;
import org.apache.cassandra.spark.common.client.ClientException;
import org.apache.cassandra.spark.common.model.CassandraInstance;
+import org.apache.cassandra.spark.data.QualifiedTableName;
import static org.apache.cassandra.bridge.CassandraBridgeFactory.maybeQuotedIdentifier;
/**
- * A {@link DataTransferApi} implementation that interacts with Cassandra Sidecar
+ * A {@link DirectDataTransferApi} implementation that interacts with Cassandra Sidecar
*/
-public class SidecarDataTransferApi implements DataTransferApi
+public class SidecarDataTransferApi implements DirectDataTransferApi
{
- private static final long serialVersionUID = 2563347232666882754L;
private static final Logger LOGGER = LoggerFactory.getLogger(SidecarDataTransferApi.class);
private static final String SSTABLE_NAME_SEPARATOR = "-";
private static final int SSTABLE_GENERATION_REVERSE_OFFSET = 3;
- private final transient SidecarClient sidecarClient;
+ private final SidecarClient sidecarClient;
private final CassandraBridge bridge;
private final int sidecarPort;
private final JobInfo job;
- private final BulkSparkConf conf;
- public SidecarDataTransferApi(CassandraContext cassandraContext, CassandraBridge bridge, JobInfo job, BulkSparkConf conf)
+ public SidecarDataTransferApi(CassandraContext cassandraContext, CassandraBridge bridge, JobInfo job)
{
this.sidecarClient = cassandraContext.getSidecarClient();
this.sidecarPort = cassandraContext.sidecarPort();
this.bridge = bridge;
this.job = job;
- this.conf = conf;
}
@Override
@@ -71,12 +69,13 @@
Digest digest) throws ClientException
{
String componentName = updateComponentName(componentFile, ssTableIdx);
- String uploadId = getUploadId(sessionID, job.getId().toString());
+ String uploadId = getUploadId(sessionID, job.getRestoreJobId().toString());
+ QualifiedTableName qt = job.qualifiedTableName();
try
{
sidecarClient.uploadSSTableRequest(toSidecarInstance(instance),
- maybeQuotedIdentifier(bridge, conf.quoteIdentifiers, conf.keyspace),
- maybeQuotedIdentifier(bridge, conf.quoteIdentifiers, conf.table),
+ maybeQuotedIdentifier(bridge, qt.quoteIdentifiers(), qt.keyspace()),
+ maybeQuotedIdentifier(bridge, qt.quoteIdentifiers(), qt.table()),
uploadId,
componentName,
digest.toSidecarDigest(),
@@ -86,10 +85,10 @@
catch (ExecutionException | InterruptedException exception)
{
LOGGER.warn("Failed to upload file={}, keyspace={}, table={}, uploadId={}, componentName={}, instance={}",
- componentFile, conf.keyspace, conf.table, uploadId, componentName, instance);
+ componentFile, qt.keyspace(), qt.table(), uploadId, componentName, instance);
throw new ClientException(
String.format("Failed to upload file=%s into keyspace=%s, table=%s, componentName=%s with uploadId=%s to instance=%s",
- componentFile, conf.keyspace, conf.table, componentName, uploadId, instance), exception);
+ componentFile, qt.keyspace(), qt.table(), componentName, uploadId, instance), exception);
}
}
@@ -102,7 +101,7 @@
{
throw new UnsupportedOperationException("Only a single UUID is supported, you provided " + uuids.size());
}
- String uploadId = getUploadId(uuids.get(0), job.getId().toString());
+ String uploadId = getUploadId(uuids.get(0), job.getRestoreJobId().toString());
ImportSSTableRequest.ImportOptions importOptions = new ImportSSTableRequest.ImportOptions();
// Always verify SSTables on import
@@ -110,10 +109,11 @@
try
{
+ QualifiedTableName qt = job.qualifiedTableName();
SSTableImportResponse response =
sidecarClient.importSSTableRequest(toSidecarInstance(instance),
- maybeQuotedIdentifier(bridge, conf.quoteIdentifiers, conf.keyspace),
- maybeQuotedIdentifier(bridge, conf.quoteIdentifiers, conf.table),
+ maybeQuotedIdentifier(bridge, qt.quoteIdentifiers(), qt.keyspace()),
+ maybeQuotedIdentifier(bridge, qt.quoteIdentifiers(), qt.table()),
uploadId,
importOptions).get();
if (response.success())
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
similarity index 75%
rename from cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java
rename to cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
index addbc11..839beef 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SSTableWriter.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriter.java
@@ -24,6 +24,7 @@
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -46,10 +47,21 @@
import org.apache.cassandra.spark.utils.DigestAlgorithm;
import org.jetbrains.annotations.NotNull;
+/**
+ * SSTableWriter that expects sorted data
+ * <br>
+ * Note for implementors: the bulk writer always sorts the data in the entire Spark partition before writing. One of
+ * the benefits is that the output SSTables are sorted and non-overlapping, which allows Cassandra to perform
+ * optimizations when importing them, as they can technically be considered a single large SSTable.
+ * You might want to introduce an SSTableWriter for unsorted data, say UnsortedSSTableWriter, and stop sorting the
+ * entire partition, i.e. drop repartitionAndSortWithinPartitions. Doing so, however, forfeits the nice property of
+ * the output SSTables being globally sorted and non-overlapping.
+ * Unless there is a compelling use case for that, we should stick with this SortedSSTableWriter.
+ */
@SuppressWarnings("WeakerAccess")
-public class SSTableWriter
+public class SortedSSTableWriter
{
- private static final Logger LOGGER = LoggerFactory.getLogger(SSTableWriter.class);
+ private static final Logger LOGGER = LoggerFactory.getLogger(SortedSSTableWriter.class);
public static final String CASSANDRA_VERSION_PREFIX = "cassandra-";
@@ -60,17 +72,19 @@
private final Map<Path, Digest> fileDigestMap = new HashMap<>();
private final DigestAlgorithm digestAlgorithm;
+ private int sstableCount = 0;
private long rowCount = 0;
+ private long bytesWritten = 0;
- public SSTableWriter(org.apache.cassandra.bridge.SSTableWriter tableWriter, Path outDir,
- DigestAlgorithm digestAlgorithm)
+ public SortedSSTableWriter(org.apache.cassandra.bridge.SSTableWriter tableWriter, Path outDir,
+ DigestAlgorithm digestAlgorithm)
{
cqlSSTableWriter = tableWriter;
this.outDir = outDir;
this.digestAlgorithm = digestAlgorithm;
}
- public SSTableWriter(BulkWriterContext writerContext, Path outDir, DigestAlgorithm digestAlgorithm)
+ public SortedSSTableWriter(BulkWriterContext writerContext, Path outDir, DigestAlgorithm digestAlgorithm)
{
this.outDir = outDir;
this.digestAlgorithm = digestAlgorithm;
@@ -97,12 +111,20 @@
return CASSANDRA_VERSION_PREFIX + lowestCassandraVersion;
}
+ /**
+ * Add a row to be written.
+ * @param token the hashed token of the row's partition key.
+ * The value must be monotonically increasing across subsequent calls.
+ * @param boundValues bound values of the columns in the row
+ * @throws IOException I/O exception when adding the row
+ */
public void addRow(BigInteger token, Map<String, Object> boundValues) throws IOException
{
- if (minToken == null)
+ if (rowCount == 0)
{
minToken = token;
}
+ // rows are sorted, therefore only maxToken needs to be updated
maxToken = token;
cqlSSTableWriter.addRow(boundValues);
rowCount += 1;
@@ -116,16 +138,35 @@
return rowCount;
}
+ /**
+ * @return the total number of bytes written
+ */
+ public long bytesWritten()
+ {
+ return bytesWritten;
+ }
+
+ /**
+ * @return the total number of sstables written
+ */
+ public int sstableCount()
+ {
+ return sstableCount;
+ }
+
public void close(BulkWriterContext writerContext, int partitionId) throws IOException
{
cqlSSTableWriter.close();
+ sstableCount = 0;
for (Path dataFile : getDataFileStream())
{
// NOTE: We calculate file hashes before re-reading so that we know what we hashed
// is what we validated. Then we send these along with the files and the
// receiving end re-hashes the files to make sure they still match.
fileDigestMap.putAll(calculateFileDigestMap(dataFile));
+ sstableCount += 1;
}
+ bytesWritten = calculatedTotalSize(fileDigestMap.keySet());
validateSSTables(writerContext, partitionId);
}
@@ -138,7 +179,7 @@
try
{
CassandraVersion version = CassandraBridgeFactory.getCassandraVersion(writerContext.cluster().getLowestCassandraVersion());
- String keyspace = writerContext.job().keyspace();
+ String keyspace = writerContext.job().qualifiedTableName().keyspace();
String schema = writerContext.schema().getTableSchema().createStatement;
Set<String> udtStatements = writerContext.schema().getUserDefinedTypeStatements();
String directory = getOutDir().toString();
@@ -179,6 +220,16 @@
return fileHashes;
}
+ private long calculatedTotalSize(Collection<Path> paths) throws IOException
+ {
+ long totalSize = 0;
+ for (Path path : paths)
+ {
+ totalSize += Files.size(path);
+ }
+ return totalSize;
+ }
+
public Range<BigInteger> getTokenRange()
{
return Range.closed(minToken, maxToken);
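
To make the sorted-input contract concrete, here is a minimal, hypothetical usage sketch. The tokens and bound values are illustrative only; a real row binds every column of the table.

    static Range<BigInteger> writeTwoRows(BulkWriterContext writerContext, Path outDir,
                                          DigestAlgorithm digestAlgorithm, int partitionId) throws IOException
    {
        SortedSSTableWriter writer = new SortedSSTableWriter(writerContext, outDir, digestAlgorithm);
        // tokens must arrive in non-decreasing order, e.g. as produced by
        // repartitionAndSortWithinPartitions; the values here are made up
        writer.addRow(BigInteger.valueOf(-100L), Collections.<String, Object>singletonMap("id", 1));
        writer.addRow(BigInteger.valueOf(42L), Collections.<String, Object>singletonMap("id", 2));
        writer.close(writerContext, partitionId);
        return writer.getTokenRange(); // closed range [-100, 42]
    }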
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamError.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamError.java
index da5383d..a16224d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamError.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamError.java
@@ -20,15 +20,31 @@
package org.apache.cassandra.spark.bulkwriter;
import java.io.Serializable;
+import java.math.BigInteger;
+
+import com.google.common.collect.Range;
public class StreamError implements Serializable
{
+ private static final long serialVersionUID = 8897970306271012424L;
public final RingInstance instance;
public final String errMsg;
+ public final Range<BigInteger> failedRange;
- public StreamError(RingInstance instance, String errMsg)
+ public StreamError(Range<BigInteger> failedRange, RingInstance instance, String errMsg)
{
+ // todo: range should be all open-closed, but it is not consistent in the project yet. Enable the check later
+// Preconditions.checkArgument(RangeUtils.isOpenClosedRange(failedRange), "Token range is not open-closed");
+ this.failedRange = failedRange;
this.instance = instance;
this.errMsg = errMsg;
}
+
+ @Override
+ public String toString()
+ {
+ return "StreamError{instance:'" + instance
+ + "',failedRange:'" + failedRange
+ + "',errMsg:'" + errMsg + "'}";
+ }
}
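
A hedged sketch of how a caller records a per-range failure with the new constructor (the tokens and error message are illustrative, and failureHandler, errors, and replica are assumed to be in scope):

    // the failed range is expressed as an open-closed token range
    Range<BigInteger> failedRange = Range.openClosed(BigInteger.valueOf(0L), BigInteger.valueOf(100L));
    failureHandler.addFailure(failedRange, replica, "connection reset");
    errors.add(new StreamError(failedRange, replica, "connection reset"));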
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamResult.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamResult.java
index 83ab675..56b2dbb 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamResult.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamResult.java
@@ -25,22 +25,22 @@
import com.google.common.collect.Range;
-public class StreamResult implements Serializable
+public abstract class StreamResult implements Serializable
{
+ private static final long serialVersionUID = -6533153391143872430L;
public final String sessionID;
public final Range<BigInteger> tokenRange;
- public final List<StreamError> failures;
- public List<CommitResult> commitResults; // CHECKSTYLE IGNORE: Public mutable field
+ public final List<StreamError> failures; // CHECKSTYLE IGNORE: Public mutable field
public final List<RingInstance> passed;
public final long rowCount;
public final long bytesWritten;
- public StreamResult(String sessionID,
- Range<BigInteger> tokenRange,
- List<StreamError> failures,
- List<RingInstance> passed,
- long rowCount,
- long bytesWritten)
+ protected StreamResult(String sessionID,
+ Range<BigInteger> tokenRange,
+ List<StreamError> failures,
+ List<RingInstance> passed,
+ long rowCount,
+ long bytesWritten)
{
this.sessionID = sessionID;
this.tokenRange = tokenRange;
@@ -49,17 +49,4 @@
this.rowCount = rowCount;
this.bytesWritten = bytesWritten;
}
-
- public void setCommitResults(List<CommitResult> commitResult)
- {
- this.commitResults = commitResult;
- }
-
- @Override
- public String toString()
- {
- return String.format("StreamResult{sessionID='%s', tokenRange=%s, failures=%s, commitResults=%s, passed=%s, " +
- "rowCount=%d, bytesWritten=%d}",
- sessionID, tokenRange, failures, commitResults, passed, rowCount, bytesWritten);
- }
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
index 2228682..c437f6e 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/StreamSession.java
@@ -19,78 +19,57 @@
package org.apache.cassandra.spark.bulkwriter;
-import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
-import java.nio.file.DirectoryStream;
-import java.nio.file.Files;
-import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
import java.util.concurrent.Future;
-import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Range;
-import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
-import org.apache.cassandra.spark.common.Digest;
-import org.apache.cassandra.spark.common.SSTables;
-public class StreamSession
+public abstract class StreamSession<T extends TransportContext>
{
private static final Logger LOGGER = LoggerFactory.getLogger(StreamSession.class);
- private static final String WRITE_PHASE = "UploadAndCommit";
- private final BulkWriterContext writerContext;
- private final String sessionID;
- private final Range<BigInteger> tokenRange;
- final List<RingInstance> replicas;
- private final ArrayList<StreamError> errors = new ArrayList<>();
- private final ReplicaAwareFailureHandler<RingInstance> failureHandler;
- private final AtomicInteger nextSSTableIdx = new AtomicInteger(1);
- private final ExecutorService executor;
- private final List<Future<?>> futures = new ArrayList<>();
- private final TokenRangeMapping<RingInstance> tokenRangeMapping;
- private long rowCount = 0; // total number of rows written by the SSTableWriter
- private long bytesWritten = 0;
-
- public StreamSession(final BulkWriterContext writerContext,
- final String sessionID,
- final Range<BigInteger> tokenRange,
- final ReplicaAwareFailureHandler<RingInstance> failureHandler)
- {
- this(writerContext, sessionID, tokenRange, Executors.newSingleThreadExecutor(), failureHandler);
- }
+ protected final BulkWriterContext writerContext;
+ protected final T transportContext;
+ protected final String sessionID;
+ protected final Range<BigInteger> tokenRange;
+ protected final List<RingInstance> replicas;
+ protected final ArrayList<StreamError> errors = new ArrayList<>();
+ protected final ReplicaAwareFailureHandler<RingInstance> failureHandler;
+ protected final TokenRangeMapping<RingInstance> tokenRangeMapping;
+ protected final SortedSSTableWriter sstableWriter;
@VisibleForTesting
- public StreamSession(BulkWriterContext writerContext,
- String sessionID,
- Range<BigInteger> tokenRange,
- ExecutorService executor,
- ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ protected StreamSession(BulkWriterContext writerContext,
+ SortedSSTableWriter sstableWriter,
+ T transportContext,
+ String sessionID,
+ Range<BigInteger> tokenRange,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
{
this.writerContext = writerContext;
+ this.sstableWriter = sstableWriter;
+ this.transportContext = transportContext;
this.tokenRangeMapping = writerContext.cluster().getTokenRangeMapping(true);
this.sessionID = sessionID;
this.tokenRange = tokenRange;
this.failureHandler = failureHandler;
this.replicas = getReplicas();
- this.executor = executor;
}
public Range<BigInteger> getTokenRange()
@@ -98,68 +77,24 @@
return tokenRange;
}
- public void scheduleStream(SSTableWriter ssTableWriter)
+ public void addRow(BigInteger token, Map<String, Object> boundValues) throws IOException
{
- Preconditions.checkState(!ssTableWriter.getTokenRange().isEmpty(), "Trying to stream empty SSTable");
-
- Preconditions.checkState(tokenRange.encloses(ssTableWriter.getTokenRange()),
- String.format("SSTable range %s should be enclosed in the partition range %s",
- ssTableWriter.getTokenRange(), tokenRange));
- rowCount += ssTableWriter.rowCount();
- futures.add(executor.submit(() -> sendSSTables(writerContext, ssTableWriter)));
+ sstableWriter.addRow(token, boundValues);
}
- public StreamResult close() throws ExecutionException, InterruptedException
+ public long rowCount()
{
- for (Future<?> future : futures)
- {
- try
- {
- future.get();
- }
- catch (Exception exception)
- {
- LOGGER.error("Unexpected stream errMsg. "
- + "Stream errors should have converted to StreamError and sent to driver", exception);
- throw new RuntimeException(exception);
- }
- }
-
- executor.shutdown();
- LOGGER.info("[{}]: Closing stream session. Sent {} SSTables", sessionID, futures.size());
-
- // No data written at all
- if (futures.isEmpty())
- {
- return new StreamResult(sessionID, tokenRange, new ArrayList<>(), new ArrayList<>(), rowCount, bytesWritten);
- }
- else
- {
- // StreamResult has errors streaming to replicas
- StreamResult streamResult = new StreamResult(sessionID,
- tokenRange,
- errors,
- new ArrayList<>(replicas),
- rowCount,
- bytesWritten);
- List<CommitResult> cr = commit(streamResult);
- streamResult.setCommitResults(cr);
- LOGGER.debug("StreamResult: {}", streamResult);
- // Check consistency given the no. failures
- BulkWriteValidator.validateClOrFail(tokenRangeMapping, failureHandler, LOGGER, WRITE_PHASE, writerContext.job());
- return streamResult;
- }
+ return sstableWriter.rowCount();
}
- private List<CommitResult> commit(StreamResult streamResult) throws ExecutionException, InterruptedException
+ public Future<StreamResult> scheduleStreamAsync(int partitionId, ExecutorService executorService) throws IOException
{
- try (CommitCoordinator cc = CommitCoordinator.commit(writerContext, new StreamResult[]{streamResult }))
- {
- List<CommitResult> commitResults = cc.get();
- LOGGER.debug("All CommitResults: {}", commitResults);
- commitResults.forEach(cr -> BulkWriteValidator.updateFailureHandler(cr, WRITE_PHASE, failureHandler));
- return commitResults;
- }
+ Preconditions.checkState(!sstableWriter.getTokenRange().isEmpty(), "Trying to stream empty SSTable");
+ Preconditions.checkState(tokenRange.encloses(sstableWriter.getTokenRange()),
+ "SSTable range %s should be enclosed in the partition range %s",
+ sstableWriter.getTokenRange(), tokenRange);
+ sstableWriter.close(writerContext, partitionId);
+ return executorService.submit(() -> doScheduleStream(sstableWriter));
}
@VisibleForTesting
@@ -171,9 +106,7 @@
Map<Range<BigInteger>, List<RingInstance>> overlappingRanges = tokenRangeMapping.getSubRanges(tokenRange).asMapOfRanges();
LOGGER.debug("[{}]: Stream session token range: {} overlaps with ring ranges: {}",
- sessionID,
- tokenRange,
- overlappingRanges);
+ sessionID, tokenRange, overlappingRanges);
List<RingInstance> replicasForTokenRange = overlappingRanges.values().stream()
.flatMap(Collection::stream)
@@ -184,7 +117,7 @@
.collect(Collectors.toList());
Preconditions.checkState(!replicasForTokenRange.isEmpty(),
- String.format("No replicas found for range %s", tokenRange));
+ "No replicas found for range %s", tokenRange);
// In order to better utilize replicas, shuffle the replicaList so each session starts writing to a different replica first.
Collections.shuffle(replicasForTokenRange);
@@ -206,128 +139,19 @@
|| blockedInstanceIps.contains(ringInstance.ipAddress());
}
- private void sendSSTables(BulkWriterContext writerContext, SSTableWriter ssTableWriter)
- {
- try (DirectoryStream<Path> dataFileStream = Files.newDirectoryStream(ssTableWriter.getOutDir(), "*Data.db"))
- {
- for (Path dataFile : dataFileStream)
- {
- int ssTableIdx = nextSSTableIdx.getAndIncrement();
-
- LOGGER.info("[{}]: Pushing SSTable {} to replicas {}",
- sessionID, dataFile, replicas.stream()
- .map(RingInstance::nodeName)
- .collect(Collectors.joining(",")));
- replicas.removeIf(replica -> !trySendSSTableToReplica(writerContext, ssTableWriter, dataFile, ssTableIdx, replica));
- }
- }
- catch (IOException exception)
- {
- LOGGER.error("[{}]: Unexpected exception while streaming SSTables {}",
- sessionID, ssTableWriter.getOutDir());
- cleanAllReplicas();
- throw new RuntimeException(exception);
- }
- finally
- {
- // Clean up SSTable files once the task is complete
- File tempDir = ssTableWriter.getOutDir().toFile();
- LOGGER.info("[{}]:Removing temporary files after stream session from {}", sessionID, tempDir);
- try
- {
- FileUtils.deleteDirectory(tempDir);
- }
- catch (IOException exception)
- {
- LOGGER.warn("[{}]:Failed to delete temporary directory {}", sessionID, tempDir, exception);
- }
- }
- }
-
- private boolean trySendSSTableToReplica(BulkWriterContext writerContext,
- SSTableWriter ssTableWriter,
- Path dataFile,
- int ssTableIdx,
- RingInstance replica)
- {
- try
- {
- sendSSTableToReplica(writerContext, dataFile, ssTableIdx, replica, ssTableWriter.fileDigestMap());
- return true;
- }
- catch (Exception exception)
- {
- LOGGER.error("[{}]: Failed to stream range {} to instance {}",
- sessionID, tokenRange, replica.nodeName(), exception);
- writerContext.cluster().refreshClusterInfo();
- failureHandler.addFailure(tokenRange, replica, exception.getMessage());
- errors.add(new StreamError(replica, exception.getMessage()));
- clean(writerContext, replica, sessionID);
- return false;
- }
- }
+ /**
+ * Schedule the stream with the produced sstables and return the stream result.
+ *
+ * @param sstableWriter produces SSTable(s)
+ * @return stream result
+ */
+ protected abstract StreamResult doScheduleStream(SortedSSTableWriter sstableWriter);
/**
- * Get all replicas and clean temporary state on them
+ * Send the SSTable(s) written by the SSTableWriter.
+ * The code runs on a separate thread.
+ *
+ * @param sstableWriter produces SSTable(s)
*/
- private void cleanAllReplicas()
- {
- Set<RingInstance> instances = new HashSet<>(replicas);
- errors.forEach(streamError -> instances.add(streamError.instance));
- instances.forEach(instance -> clean(writerContext, instance, sessionID));
- }
-
- private void sendSSTableToReplica(BulkWriterContext writerContext,
- Path dataFile,
- int ssTableIdx,
- RingInstance instance,
- Map<Path, Digest> fileDigestMap) throws Exception
- {
- try (DirectoryStream<Path> componentFileStream = Files.newDirectoryStream(dataFile.getParent(), SSTables.getSSTableBaseName(dataFile) + "*"))
- {
- for (Path componentFile : componentFileStream)
- {
- if (componentFile.getFileName().toString().endsWith("Data.db"))
- {
- continue;
- }
- sendSSTableComponent(writerContext, componentFile, ssTableIdx, instance, fileDigestMap.get(componentFile));
- }
- sendSSTableComponent(writerContext, dataFile, ssTableIdx, instance, fileDigestMap.get(dataFile));
- }
- }
-
- private void sendSSTableComponent(BulkWriterContext writerContext,
- Path componentFile,
- int ssTableIdx,
- RingInstance instance,
- Digest digest) throws Exception
- {
- Preconditions.checkNotNull(digest, "All files must have a hash. SSTableWriter should have calculated these. This is a bug.");
- long fileSize = Files.size(componentFile);
- bytesWritten += fileSize;
- LOGGER.info("[{}]: Uploading {} to {}: Size is {}", sessionID, componentFile, instance.nodeName(), fileSize);
- writerContext.transfer().uploadSSTableComponent(componentFile, ssTableIdx, instance, sessionID, digest);
- }
-
- public static void clean(BulkWriterContext writerContext, RingInstance instance, String sessionID)
- {
- if (writerContext.job().getSkipClean())
- {
- LOGGER.info("Skip clean requested - not cleaning SSTable session {} on instance {}",
- sessionID, instance.nodeName());
- return;
- }
- String jobID = writerContext.job().getId().toString();
- LOGGER.info("Cleaning SSTable session {} on instance {}", sessionID, instance.nodeName());
- try
- {
- writerContext.transfer().cleanUploadSession(instance, sessionID, jobID);
- }
- catch (Exception exception)
- {
- LOGGER.warn("Failed to clean SSTables on {} for session {} and ignoring errMsg",
- instance.nodeName(), sessionID, exception);
- }
- }
+ protected abstract void sendSSTables(SortedSSTableWriter sstableWriter);
}
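
To illustrate the resulting template-method contract, here is a minimal, purely hypothetical subclass (not a real transport): doScheduleStream runs on the executor thread submitted by scheduleStreamAsync, and sendSSTables performs the transport-specific upload. It assumes the class lives in the org.apache.cassandra.spark.bulkwriter package so the protected constructors are accessible.

    package org.apache.cassandra.spark.bulkwriter;

    import java.math.BigInteger;
    import java.util.ArrayList;

    import com.google.common.collect.Range;

    import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;

    final class NoOpStreamSession extends StreamSession<TransportContext>
    {
        NoOpStreamSession(BulkWriterContext writerContext, SortedSSTableWriter sstableWriter,
                          TransportContext transportContext, String sessionID,
                          Range<BigInteger> tokenRange,
                          ReplicaAwareFailureHandler<RingInstance> failureHandler)
        {
            super(writerContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler);
        }

        @Override
        protected StreamResult doScheduleStream(SortedSSTableWriter sstableWriter)
        {
            sendSSTables(sstableWriter); // upload first, then summarize
            return new StreamResult(sessionID, tokenRange, new ArrayList<>(), new ArrayList<>(replicas),
                                    sstableWriter.rowCount(), sstableWriter.bytesWritten()) {};
        }

        @Override
        protected void sendSSTables(SortedSSTableWriter sstableWriter)
        {
            // a real transport would upload the contents of sstableWriter.getOutDir() here
        }
    }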
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java
new file mode 100644
index 0000000..59e848a
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/TransportContext.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.io.Serializable;
+import java.math.BigInteger;
+
+import com.google.common.collect.Range;
+
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobDataTransferApi;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * An interface that defines the transport context required to perform bulk writes
+ */
+public interface TransportContext
+{
+ /**
+ * Create a new stream session that writes data to Cassandra
+ *
+ * @param writerContext bulk writer context
+ * @param sessionId identifier of the stream session
+ * @param sstableWriter sstable writer of the stream session
+ * @param range token range of the stream session
+ * @param failureHandler handler to track failures of the stream session
+ * @return a new stream session
+ */
+ StreamSession<? extends TransportContext> createStreamSession(BulkWriterContext writerContext,
+ String sessionId,
+ SortedSSTableWriter sstableWriter,
+ Range<BigInteger> range,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler);
+
+ default void close()
+ {
+ }
+
+ /**
+ * Context used when prepared SSTables are directly written to C* through Sidecar
+ */
+ interface DirectDataBulkWriterContext extends TransportContext
+ {
+ /**
+ * @return data transfer API client for the direct write mode
+ */
+ DirectDataTransferApi dataTransferApi();
+ }
+
+ /**
+ * Context used when SSTables are uploaded to cloud storage
+ */
+ interface CloudStorageTransportContext extends TransportContext
+ {
+ /**
+ * @return data transfer API client for the S3_COMPAT mode
+ * Implementation note: never return null result
+ */
+ @NotNull
+ BlobDataTransferApi dataTransferApi();
+
+ /**
+ * @return configuration for S3_COMPAT
+ * Implementation note: never return null result
+ */
+ @NotNull
+ StorageTransportConfiguration transportConfiguration();
+
+ /**
+ * @return transport extension instance for S3_COMPAT
+ * Implementation note: never return null result
+ */
+ @NotNull
+ StorageTransportExtension transportExtensionImplementation();
+ }
+
+ interface TransportContextProvider extends Serializable
+ {
+ /**
+ * Create a new transport context instance
+ *
+ * @param bulkWriterContext bulk writer context
+ * @param conf bulk writer spark configuration
+ * @param isOnDriver indicates whether the role of the runtime is Spark driver or executor
+ * @return a new transport context instance
+ */
+ TransportContext createContext(@NotNull BulkWriterContext bulkWriterContext,
+ @NotNull BulkSparkConf conf,
+ boolean isOnDriver);
+ }
+}
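
As a hedged sketch of the call site (all parameters assumed to be in scope), the task-side code obtains a transport-specific session without knowing whether DIRECT or S3_COMPAT is configured:

    static StreamSession<? extends TransportContext> newSession(TransportContext transportContext,
                                                                BulkWriterContext writerContext,
                                                                String sessionId,
                                                                SortedSSTableWriter sstableWriter,
                                                                Range<BigInteger> range,
                                                                ReplicaAwareFailureHandler<RingInstance> failureHandler)
    {
        // the context returns the StreamSession subclass matching its transport
        return transportContext.createStreamSession(writerContext, sessionId, sstableWriter, range, failureHandler);
    }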
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/WriterOptions.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/WriterOptions.java
index fd04d97..a7578b0 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/WriterOptions.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/WriterOptions.java
@@ -19,6 +19,11 @@
package org.apache.cassandra.spark.bulkwriter;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+
+/**
+ * Spark options to configure bulk writer
+ */
public enum WriterOptions implements WriterOption
{
SIDECAR_INSTANCES,
@@ -61,4 +66,51 @@
* If unspecified, it defaults to {@code XXHash32} digests. The legacy {@code MD5} digest is also supported.
*/
DIGEST,
+ /**
+ * Option to specify the data transport mode. It accepts either {@link DataTransport#DIRECT} or {@link DataTransport#S3_COMPAT}.
+ * Note that if S3_COMPAT is configured, {@link #DATA_TRANSPORT_EXTENSION_CLASS} must be configured too.
+ */
+ DATA_TRANSPORT,
+ /**
+ * Option to specify the FQCN of class that implements {@link StorageTransportExtension}
+ */
+ DATA_TRANSPORT_EXTENSION_CLASS,
+ /**
+ * Option to tune the concurrency of S3 client's worker thread pool
+ */
+ STORAGE_CLIENT_CONCURRENCY,
+ /**
+ * Option to tune the keep-alive seconds for the threads in the thread pool used by the S3 client
+ */
+ STORAGE_CLIENT_THREAD_KEEP_ALIVE_SECONDS,
+ /**
+ * Option to specify the max chunk size for the multipart upload to S3
+ */
+ STORAGE_CLIENT_MAX_CHUNK_SIZE_IN_BYTES,
+ /**
+ * Option to specify the HTTPS proxy for the S3 client
+ */
+ STORAGE_CLIENT_HTTPS_PROXY,
+ /**
+ * Option to specify the S3 server endpoint override; it is mostly used for testing
+ */
+ STORAGE_CLIENT_ENDPOINT_OVERRIDE,
+ /**
+ * Option to specify the maximum size of a bundle (S3 object) to upload to S3
+ */
+ MAX_SIZE_PER_SSTABLE_BUNDLE_IN_BYTES_S3_TRANSPORT,
+ /**
+ * Option to specify the keep-alive time in minutes; Sidecar considers a job lost/failed
+ * after not receiving its heartbeat for this long
+ */
+ JOB_KEEP_ALIVE_MINUTES,
+ /**
+ * Option to specify the unique identifier of the job
+ */
+ JOB_ID,
+ /**
+ * Option to tune the connection acquisition timeout for the NIO HTTP client employed by the S3 client
+ */
+ STORAGE_CLIENT_NIO_HTTP_CLIENT_CONNECTION_ACQUISITION_TIMEOUT_SECONDS,
+ /**
+ * Option to tune the concurrency of the NIO HTTP client employed by the S3 client
+ */
+ STORAGE_CLIENT_NIO_HTTP_CLIENT_MAX_CONCURRENCY,
}
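
For orientation, a hypothetical write invocation selecting the new transport might look like the sketch below. The option key spellings, the data source name, and com.example.MyStorageTransportExtension are assumptions for illustration, not verified strings.

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    final class S3CompatWriteSketch
    {
        static void write(Dataset<Row> df)
        {
            Map<String, String> options = new HashMap<>();
            options.put("data_transport", "S3_COMPAT");
            // required with S3_COMPAT: FQCN of a StorageTransportExtension implementation
            options.put("data_transport_extension_class", "com.example.MyStorageTransportExtension");
            options.put("job_keep_alive_minutes", "10");
            df.write()
              .format("cassandra-bulk-writer") // placeholder data source name
              .options(options)
              .mode("append")
              .save();
        }
    }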
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobDataTransferApi.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobDataTransferApi.java
new file mode 100644
index 0000000..55eebed
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobDataTransferApi.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import io.netty.handler.codec.http.HttpResponseStatus;
+import o.a.c.sidecar.client.shaded.common.data.CreateRestoreJobRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.CreateRestoreJobResponsePayload;
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSummaryResponsePayload;
+import o.a.c.sidecar.client.shaded.common.data.UpdateRestoreJobRequestPayload;
+import org.apache.cassandra.sidecar.client.HttpResponse;
+import org.apache.cassandra.sidecar.client.HttpResponseImpl;
+import org.apache.cassandra.sidecar.client.SidecarClient;
+import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.sidecar.client.request.CreateRestoreJobSliceRequest;
+import org.apache.cassandra.sidecar.client.request.Request;
+import org.apache.cassandra.sidecar.client.retry.RetryAction;
+import org.apache.cassandra.sidecar.client.retry.RetryPolicy;
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.cassandra.spark.common.client.ClientException;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+
+/**
+ * Encapsulates transfer APIs used by {@link BlobStreamSession}. {@link StorageClient} is used to interact with S3 and
+ * upload SSTable bundles to the S3 bucket. It also holds a {@link SidecarClient} to call the relevant Sidecar APIs.
+ */
+public class BlobDataTransferApi
+{
+ private final JobInfo jobInfo;
+ private final SidecarClient sidecarClient;
+ private final StorageClient storageClient;
+
+ public BlobDataTransferApi(JobInfo jobInfo, SidecarClient sidecarClient, StorageClient storageClient)
+ {
+ this.jobInfo = jobInfo;
+ this.sidecarClient = sidecarClient;
+ this.storageClient = storageClient;
+ }
+
+ /*------ Blob APIs -------*/
+
+ public BundleStorageObject uploadBundle(StorageCredentials writeCredentials, Bundle bundle)
+ throws ClientException
+ {
+ try
+ {
+ return storageClient.multiPartUpload(writeCredentials, bundle);
+ }
+ catch (IOException | ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when uploading bundles to S3", exception);
+ throw new ClientException("Failed to upload bundles to S3", exception);
+ }
+ }
+
+ /*------ Sidecar APIs -------*/
+
+ public CreateRestoreJobResponsePayload createRestoreJob(CreateRestoreJobRequestPayload createRestoreJobRequestPayload)
+ throws ClientException
+ {
+ try
+ {
+ QualifiedTableName qualifiedTableName = jobInfo.qualifiedTableName();
+ return sidecarClient.createRestoreJob(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ createRestoreJobRequestPayload).get();
+ }
+ catch (ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when creating new restore job", exception);
+ throw new ClientException("Failed to create new restore job", exception);
+ }
+ }
+
+ public RestoreJobSummaryResponsePayload restoreJobSummary()
+ throws ClientException
+ {
+ try
+ {
+ QualifiedTableName qualifiedTableName = jobInfo.qualifiedTableName();
+ return sidecarClient.restoreJobSummary(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ jobInfo.getRestoreJobId()).get();
+ }
+ catch (ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when retrieving restore job summary", exception);
+ throw new ClientException("Failed to retrieve restore job summary", exception);
+ }
+ }
+
+ /**
+ * Called from task level to create a restore slice.
+ * The request retries until the slice is created (201) or retries are exhausted.
+ *
+ * @param sidecarInstance the sidecar instance where we will create the slice
+ * @param createSliceRequestPayload the payload to create the slice
+ * @throws ClientException when an error occurs during the slice creation
+ */
+ public void createRestoreSliceFromExecutor(SidecarInstance sidecarInstance,
+ CreateSliceRequestPayload createSliceRequestPayload) throws ClientException
+ {
+ try
+ {
+ createRestoreSlice(sidecarInstance, createSliceRequestPayload, new ExecutorCreateSliceRetryPolicy())
+ .get();
+ }
+ catch (ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when creating restore slice", exception);
+ throw new ClientException("Failed to create restore slice for payload: " + createSliceRequestPayload,
+ exception);
+ }
+ }
+
+ /**
+ * Called from driver level to create a restore slice asynchronously.
+ * The request retries until the slice succeeds (200), fails (550), or retries are exhausted.
+ *
+ * @param sidecarInstance the sidecar instance where we will create the slice
+ * @param createSliceRequestPayload the payload to create the slice
+ * @return future of create restore slice request
+ */
+ public CompletableFuture<Void> createRestoreSliceFromDriver(SidecarInstance sidecarInstance,
+ CreateSliceRequestPayload createSliceRequestPayload)
+ {
+ return createRestoreSlice(sidecarInstance, createSliceRequestPayload,
+ new DriverCreateSliceRetryPolicy(sidecarClient.defaultRetryPolicy()));
+ }
+
+ /**
+ * Create a restore slice with custom retry policy
+ */
+ private CompletableFuture<Void> createRestoreSlice(SidecarInstance sidecarInstance,
+ CreateSliceRequestPayload createSliceRequestPayload,
+ RetryPolicy retryPolicy)
+ {
+ QualifiedTableName qualifiedTableName = jobInfo.qualifiedTableName();
+ CreateRestoreJobSliceRequest request = new CreateRestoreJobSliceRequest(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ jobInfo.getRestoreJobId(),
+ createSliceRequestPayload);
+ return sidecarClient.executeRequestAsync(sidecarClient.requestBuilder()
+ .retryPolicy(retryPolicy)
+ .singleInstanceSelectionPolicy(sidecarInstance)
+ .request(request)
+ .build());
+ }
+
+ public void updateRestoreJob(UpdateRestoreJobRequestPayload updateRestoreJobRequestPayload) throws ClientException
+ {
+ try
+ {
+ QualifiedTableName qualifiedTableName = jobInfo.qualifiedTableName();
+ sidecarClient.updateRestoreJob(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ jobInfo.getRestoreJobId(),
+ updateRestoreJobRequestPayload).get();
+ }
+ catch (ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when updating restore job", exception);
+ throw new ClientException("Failed to update restore job", exception);
+ }
+ }
+
+ public void abortRestoreJob() throws ClientException
+ {
+ try
+ {
+ QualifiedTableName qualifiedTableName = jobInfo.qualifiedTableName();
+ sidecarClient.abortRestoreJob(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ jobInfo.getRestoreJobId()).get();
+ }
+ catch (ExecutionException | InterruptedException exception)
+ {
+ rethrowOnInterruptedException("Got interrupted when aborting restore job", exception);
+ throw new ClientException("Failed to abort restore job", exception);
+ }
+ }
+
+ /**
+ * By default, {@link SidecarClient} retries until a 200 HTTP response. But for the create-slice endpoint at the
+ * task level, we only want to wait until a 201 HTTP response, hence the custom retry policy
+ */
+ class ExecutorCreateSliceRetryPolicy extends RetryPolicy
+ {
+ @Override
+ public void onResponse(CompletableFuture<HttpResponse> completableFuture,
+ Request request, HttpResponse httpResponse, Throwable throwable,
+ int attempts, boolean canRetryOnADifferentHost, RetryAction retryAction)
+ {
+ if (httpResponse != null && httpResponse.statusCode() == HttpResponseStatus.CREATED.code())
+ {
+ completableFuture.complete(httpResponse);
+ }
+ else
+ {
+ sidecarClient.defaultRetryPolicy().onResponse(completableFuture, request, httpResponse,
+ throwable, attempts, canRetryOnADifferentHost,
+ retryAction);
+ }
+ }
+ }
+
+ /**
+ * Retries when the server returns CREATED (201). Otherwise, its behavior is the same as the default policy.
+ */
+ static class DriverCreateSliceRetryPolicy extends RetryPolicy
+ {
+ private static final Logger LOGGER = LoggerFactory.getLogger(DriverCreateSliceRetryPolicy.class);
+ private final RetryPolicy delegate;
+
+ DriverCreateSliceRetryPolicy(RetryPolicy delegate)
+ {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void onResponse(CompletableFuture<HttpResponse> completableFuture,
+ Request request, HttpResponse httpResponse, Throwable throwable,
+ int attempts, boolean canRetryOnADifferentHost, RetryAction retryAction)
+ {
+ if (httpResponse != null && httpResponse.statusCode() == HttpResponseStatus.CREATED.code())
+ {
+ // This is hacky because the sidecar client is not open to modification:
+ // ACCEPTED triggers a special/unlimited retry, which is what we want here.
+ // Therefore, fake an HTTP response by setting the status code to ACCEPTED.
+
+ LOGGER.info("Received CREATED(201) for CreateSliceRequest. " +
+ "Changing the status code to ACCEPTED(202) for unlimited retry.");
+ HttpResponse fakeResponseForRetry = new HttpResponseImpl(HttpResponseStatus.ACCEPTED.code(),
+ httpResponse.statusMessage(),
+ httpResponse.headers(),
+ httpResponse.sidecarInstance());
+ delegate.onResponse(completableFuture, request, fakeResponseForRetry,
+ throwable, attempts, canRetryOnADifferentHost,
+ retryAction);
+ }
+ else
+ {
+ delegate.onResponse(completableFuture, request, httpResponse,
+ throwable, attempts, canRetryOnADifferentHost,
+ retryAction);
+ }
+ }
+ }
+
+ private void rethrowOnInterruptedException(String message, Exception cause) throws ClientException
+ {
+ if (cause instanceof InterruptedException)
+ {
+ Thread.currentThread().interrupt();
+ throw new ClientException(message, cause);
+ }
+ }
+}
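
A hedged usage sketch contrasting the two create-slice entry points above (the instance and payload are assumed to be constructed elsewhere):

    import java.util.concurrent.CompletableFuture;

    import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
    import org.apache.cassandra.sidecar.client.SidecarInstance;
    import org.apache.cassandra.spark.common.client.ClientException;

    final class CreateSliceSketch
    {
        // executor/task side: blocks until the slice is created (201)
        static void onExecutor(BlobDataTransferApi api, SidecarInstance instance,
                               CreateSliceRequestPayload payload) throws ClientException
        {
            api.createRestoreSliceFromExecutor(instance, payload);
        }

        // driver side: resolves once the slice succeeds (200) or fails (550)
        static void onDriver(BlobDataTransferApi api, SidecarInstance instance,
                             CreateSliceRequestPayload payload)
        {
            CompletableFuture<Void> done = api.createRestoreSliceFromDriver(instance, payload);
            done.join();
        }
    }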
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamResult.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamResult.java
new file mode 100644
index 0000000..093a56e
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamResult.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.collect.Range;
+
+import org.apache.cassandra.spark.bulkwriter.RingInstance;
+import org.apache.cassandra.spark.bulkwriter.StreamError;
+import org.apache.cassandra.spark.bulkwriter.StreamResult;
+
+/**
+ * Implementation of {@link StreamResult} that returns results from {@link BlobStreamSession} for the S3_COMPAT data
+ * transport option.
+ */
+public class BlobStreamResult extends StreamResult
+{
+ private static final long serialVersionUID = 9096932762489827053L;
+ public final Set<CreatedRestoreSlice> createdRestoreSlices;
+
+ public static BlobStreamResult empty(String sessionID, Range<BigInteger> tokenRange)
+ {
+ return new BlobStreamResult(sessionID, tokenRange, new ArrayList<>(), new ArrayList<>(), new HashSet<>(), 0, 0);
+ }
+
+ public BlobStreamResult(String sessionID,
+ Range<BigInteger> tokenRange,
+ List<StreamError> failures,
+ List<RingInstance> passed,
+ Set<CreatedRestoreSlice> createdRestoreSlices,
+ long rowCount,
+ long bytesWritten)
+ {
+ super(sessionID, tokenRange, failures, passed, rowCount, bytesWritten);
+ this.createdRestoreSlices = Collections.unmodifiableSet(createdRestoreSlices);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "StreamResult{"
+ + "sessionID='" + sessionID + '\''
+ + ", tokenRange=" + tokenRange
+ + ", rowCount=" + rowCount
+ + ", bytesWritten=" + bytesWritten
+ + ", failures=" + failures
+ + ", createdRestoreSlices=" + createdRestoreSlices
+ + ", passed=" + passed
+ + '}';
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java
new file mode 100644
index 0000000..275ee0e
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSession.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Set;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Range;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSummaryResponsePayload;
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.CassandraBridgeFactory;
+import org.apache.cassandra.clients.Sidecar;
+import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.spark.bulkwriter.BulkWriteValidator;
+import org.apache.cassandra.spark.bulkwriter.BulkWriterContext;
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.cassandra.spark.bulkwriter.RingInstance;
+import org.apache.cassandra.spark.bulkwriter.SortedSSTableWriter;
+import org.apache.cassandra.spark.bulkwriter.StreamError;
+import org.apache.cassandra.spark.bulkwriter.StreamResult;
+import org.apache.cassandra.spark.bulkwriter.StreamSession;
+import org.apache.cassandra.spark.bulkwriter.TransportContext;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.common.client.ClientException;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+
+/**
+ * {@link StreamSession} implementation that is used for streaming bundled SSTables for S3_COMPAT transport option.
+ */
+public class BlobStreamSession extends StreamSession<TransportContext.CloudStorageTransportContext>
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(BlobStreamSession.class);
+ private static final String WRITE_PHASE = "UploadAndPrepareToImport";
+ protected final BundleNameGenerator bundleNameGenerator;
+ protected final BlobDataTransferApi blobDataTransferApi;
+ protected final CassandraBridge bridge;
+ private final Set<CreatedRestoreSlice> createdRestoreSlices = new HashSet<>();
+ private final SSTablesBundler sstablesBundler;
+ private int bundleCount;
+
+ public BlobStreamSession(BulkWriterContext bulkWriterContext, SortedSSTableWriter sstableWriter,
+ TransportContext.CloudStorageTransportContext transportContext,
+ String sessionID, Range<BigInteger> tokenRange,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ this(bulkWriterContext, sstableWriter, transportContext, sessionID, tokenRange,
+ CassandraBridgeFactory.get(bulkWriterContext.cluster().getLowestCassandraVersion()),
+ failureHandler);
+ }
+
+ @VisibleForTesting
+ public BlobStreamSession(BulkWriterContext bulkWriterContext, SortedSSTableWriter sstableWriter,
+ TransportContext.CloudStorageTransportContext transportContext,
+ String sessionID, Range<BigInteger> tokenRange,
+ CassandraBridge bridge, ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ super(bulkWriterContext, sstableWriter, transportContext, sessionID, tokenRange, failureHandler);
+
+ JobInfo job = bulkWriterContext.job();
+ long maxSizePerBundleInBytes = job.transportInfo().getMaxSizePerBundleInBytes();
+ this.bundleNameGenerator = new BundleNameGenerator(job.getRestoreJobId().toString(), sessionID);
+ this.blobDataTransferApi = transportContext.dataTransferApi();
+ this.bridge = bridge;
+ QualifiedTableName qualifiedTableName = job.qualifiedTableName();
+ SSTableLister sstableLister = new SSTableLister(qualifiedTableName, bridge);
+ Path bundleStagingDir = sstableWriter.getOutDir().resolve("bundle_staging");
+ this.sstablesBundler = new SSTablesBundler(bundleStagingDir, sstableLister,
+ bundleNameGenerator, maxSizePerBundleInBytes);
+ }
+
+ @Override
+ protected StreamResult doScheduleStream(SortedSSTableWriter sstableWriter)
+ {
+ sstablesBundler.includeDirectory(sstableWriter.getOutDir());
+
+ sstablesBundler.finish();
+
+ if (!sstablesBundler.hasNext())
+ {
+ if (sstableWriter.sstableCount() != 0)
+ {
+ LOGGER.error("[{}] SSTable writer has produced files, but no bundle is produced", sessionID);
+ throw new RuntimeException("Bundle expected but not found");
+ }
+
+ LOGGER.warn("[{}] SSTableBundler does not produce any bundle to send", sessionID);
+ return BlobStreamResult.empty(sessionID, tokenRange);
+ }
+
+ sendSSTables(sstableWriter);
+ LOGGER.info("[{}]: Uploaded bundles to S3. sstables={} bundles={}", sessionID, sstableWriter.sstableCount(), bundleCount);
+
+ BlobStreamResult streamResult = new BlobStreamResult(sessionID,
+ tokenRange,
+ errors,
+ replicas,
+ createdRestoreSlices,
+ sstableWriter.rowCount(),
+ sstableWriter.bytesWritten());
+ LOGGER.info("StreamResult: {}", streamResult);
+ // If the number of successful createSliceRequests cannot satisfy the configured consistency level,
+ // an exception is thrown and the task is failed. Spark might retry the task.
+ BulkWriteValidator.validateClOrFail(tokenRangeMapping, failureHandler, LOGGER, WRITE_PHASE, writerContext.job());
+ return streamResult;
+ }
+
+ @Override
+ protected void sendSSTables(SortedSSTableWriter sstableWriter)
+ {
+ bundleCount = 0;
+ while (sstablesBundler.hasNext())
+ {
+ bundleCount++;
+ try
+ {
+ sendBundle(sstablesBundler.next(), false);
+ }
+ catch (RuntimeException e)
+ {
+ // log and rethrow
+ LOGGER.error("[{}]: Unexpected exception while upload SSTable", sessionID, e);
+ throw e;
+ }
+ finally
+ {
+ sstablesBundler.cleanupBundle(sessionID);
+ }
+ }
+ }
+
+ void sendBundle(Bundle bundle, boolean hasRefreshedCredentials)
+ {
+ StorageCredentials writeCredentials = getStorageCredentialsFromSidecar();
+ BundleStorageObject bundleStorageObject;
+ try
+ {
+ bundleStorageObject = uploadBundle(writeCredentials, bundle);
+ }
+ catch (Exception e)
+ {
+ // the credential might have expired; retry once
+ if (!hasRefreshedCredentials)
+ {
+ sendBundle(bundle, true);
+ return;
+ }
+ else
+ {
+ LOGGER.error("[{}]: Failed to send SSTables after refreshing token", sessionID, e);
+ throw new RuntimeException(e);
+ }
+ }
+ CreateSliceRequestPayload slicePayload = toCreateSliceRequestPayload(bundleStorageObject);
+ // Create slices on all replicas; remove a replica from the list if the operation fails on it
+ replicas.removeIf(replica -> !tryCreateRestoreSlicePerReplica(replica, slicePayload));
+ if (!replicas.isEmpty())
+ {
+ createdRestoreSlices.add(new CreatedRestoreSlice(slicePayload));
+ }
+ }
+
+ private StorageCredentials getStorageCredentialsFromSidecar()
+ {
+ RestoreJobSummaryResponsePayload summary;
+ try
+ {
+ summary = blobDataTransferApi.restoreJobSummary();
+ }
+ catch (ClientException e)
+ {
+ LOGGER.error("[{}]: Failed to get restore job summary during uploading SSTable bundles", sessionID, e);
+ throw new RuntimeException(e);
+ }
+
+ return StorageCredentials.fromSidecarCredentials(summary.secrets().writeCredentials());
+ }
+
+ /**
+ * Uploads generated SSTable bundle
+ */
+ private BundleStorageObject uploadBundle(StorageCredentials writeCredentials, Bundle bundle) throws ClientException
+ {
+ BundleStorageObject object = blobDataTransferApi.uploadBundle(writeCredentials, bundle);
+ transportContext.transportExtensionImplementation()
+ .onObjectPersisted(transportContext.transportConfiguration()
+ .getWriteBucket(),
+ object.storageObjectKey,
+ bundle.bundleCompressedSize);
+ LOGGER.info("[{}]: Uploaded bundle. storageKey={} uncompressedSize={} compressedSize={}",
+ sessionID,
+ object.storageObjectKey,
+ bundle.bundleUncompressedSize,
+ bundle.bundleCompressedSize);
+ return object;
+ }
+
+ private boolean tryCreateRestoreSlicePerReplica(RingInstance replica, CreateSliceRequestPayload slicePayload)
+ {
+ try
+ {
+ SidecarInstance sidecarInstance = Sidecar.toSidecarInstance(replica, writerContext.job().effectiveSidecarPort());
+ blobDataTransferApi.createRestoreSliceFromExecutor(sidecarInstance, slicePayload);
+ return true;
+ }
+ catch (Exception exception)
+ {
+ LOGGER.error("[{}]: Failed to create slice. instance={}, slicePayload={}",
+ sessionID, replica.nodeName(), slicePayload, exception);
+ writerContext.cluster().refreshClusterInfo();
+ // the failed range is a sub-range of the tokenRange; it is guaranteed not to wrap around
+ Range<BigInteger> failedRange = Range.openClosed(slicePayload.startToken(), slicePayload.endToken());
+ this.failureHandler.addFailure(failedRange, replica, exception.getMessage());
+ errors.add(new StreamError(failedRange, replica, exception.getMessage()));
+ // Do not abort the job on a single slice failure
+ // per Doug: Thinking about this more, you probably shouldn't abort the job in any class that's running in
+ // the executors - just bubble up the exception, let Spark retry the task, and if it fails enough times catch the issue
+ // in the driver and abort the job there. Aborting the job here and then throwing will cause Spark to retry
+ // the task, but since you've already aborted the job there's no retry that's really possible. Would it be
+ // possible to track the task/slice ID and just abort processing that slice here? Essentially,
+ // try to get the Sidecar to skip any uploaded data on an aborted task but let Spark's retry mechanism continue to work properly.
+ //
+ // Re: the question of just aborting processing of the slice. If the createSliceRequest fails,
+ // the sidecar instance does not import the slice, since there is no such task.
+ // Unlike the DIRECT mode, the files are uploaded to blob storage and there is no file on the disk of the sidecar
+ // instance, so there is no need to clean up the failed files on the failed instance.
+ // However, the minority of replicas that created slices successfully continue to process and import them,
+ // so the slice data is present in a minority of the replica set.
+ return false;
+ }
+ }
+
+ private CreateSliceRequestPayload toCreateSliceRequestPayload(BundleStorageObject bundleStorageObject)
+ {
+ Bundle bundle = bundleStorageObject.bundle;
+ String sliceId = generateSliceId(bundle.bundleSequence);
+ return new CreateSliceRequestPayload(sliceId,
+ 0, // todo: Yifan, assign meaningful bucket id
+ transportContext.transportConfiguration()
+ .getReadBucket(),
+ bundleStorageObject.storageObjectKey,
+ bundleStorageObject.storageObjectChecksum,
+ bundle.startToken,
+ bundle.endToken,
+ bundle.bundleUncompressedSize,
+ bundle.bundleCompressedSize);
+ }
+
+ private String generateSliceId(int bundleSequence)
+ {
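+ // e.g. "<sessionID>-0" for the first bundle produced in this session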
+ return sessionID + '-' + bundleSequence;
+ }
+
+ @VisibleForTesting
+ Set<CreatedRestoreSlice> createdRestoreSlices()
+ {
+ return createdRestoreSlices;
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/Bundle.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/Bundle.java
new file mode 100644
index 0000000..02ad1ad
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/Bundle.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import org.apache.commons.io.FileUtils;
+
+import org.apache.cassandra.spark.bulkwriter.blobupload.SSTableCollector.SSTableFilesAndRange;
+import org.apache.cassandra.spark.bulkwriter.util.IOUtils;
+import org.apache.cassandra.spark.common.DataObjectBuilder;
+
+/**
+ * Bundle represents a set of SSTables bundled together, sized according to the bundle size clients set through the writer option.
+ * {@link SSTablesBundler} can create multiple bundles, {@link #bundleSequence} is used to order the produced bundles.
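+ * <p>A minimal construction sketch (the local variable names are illustrative):
+ * <pre>{@code
+ * Bundle bundle = Bundle.builder()
+ *                       .bundleSequence(0)
+ *                       .bundleStagingDirectory(stagingDir)
+ *                       .bundleNameGenerator(nameGenerator)
+ *                       .sourceSSTables(sstables)
+ *                       .build();
+ * }</pre>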
+ */
+public class Bundle
+{
+ private static final String MANIFEST_FILE_NAME = "manifest.json";
+
+ public final BigInteger startToken;
+ public final BigInteger endToken;
+ public final long bundleUncompressedSize;
+ public final long bundleCompressedSize;
+ // path to the bundle directory, which contains multiple files
+ public final Path bundleDirectory;
+ // path to the bundle zip file (single file)
+ public final Path bundleFile;
+ public final int bundleSequence;
+
+ // private access for internal and mutable states
+ private final BundleManifest bundleManifest;
+ private final List<SSTableCollector.SSTableFilesAndRange> sourceSSTables;
+
+ static Builder builder()
+ {
+ return new Builder();
+ }
+
+ protected Bundle(Builder builder)
+ {
+ this.startToken = builder.startToken;
+ this.endToken = builder.endToken;
+ this.bundleManifest = builder.bundleManifest;
+ this.bundleUncompressedSize = builder.bundleUncompressedSize;
+ this.bundleCompressedSize = builder.bundleCompressedSize;
+ this.bundleDirectory = builder.bundleDirectory;
+ this.bundleFile = builder.bundleFile;
+ this.bundleSequence = builder.bundleSequence;
+ this.sourceSSTables = builder.sourceSSTables;
+ }
+
+ public void deleteAll() throws IOException
+ {
+ List<IOException> ioExceptions = new ArrayList<>();
+ sourceSSTables.forEach(sstable -> sstable.files.forEach(path -> {
+ try
+ {
+ Files.deleteIfExists(path);
+ }
+ catch (IOException e)
+ {
+ ioExceptions.add(e);
+ }
+ }));
+
+ try
+ {
+ FileUtils.deleteDirectory(bundleDirectory.toFile());
+ }
+ catch (IOException e)
+ {
+ ioExceptions.add(e);
+ }
+
+ try
+ {
+ Files.deleteIfExists(bundleFile);
+ }
+ catch (IOException e)
+ {
+ ioExceptions.add(e);
+ }
+
+ if (!ioExceptions.isEmpty())
+ {
+ IOException ioe = new IOException("Failed to delete all files of a bundle");
+ ioExceptions.forEach(ioe::addSuppressed);
+ throw ioe;
+ }
+ }
+
+ @VisibleForTesting
+ BundleManifest.Entry manifestEntry(String key)
+ {
+ return bundleManifest.get(key);
+ }
+
+ @Override
+ public String toString()
+ {
+ return "BundleManifest{entryCount: " + bundleManifest.size()
+ + ", bundleSequence: " + bundleSequence
+ + ", bundleFile: " + bundleFile
+ + ", uncompressedSize: " + bundleUncompressedSize
+ + ", compressedSize: " + bundleCompressedSize
+ + ", startToken: " + startToken
+ + ", endToken: " + endToken + '}';
+ }
+
+ /**
+ * Builder for {@link Bundle}
+ */
+ static class Builder implements DataObjectBuilder<Builder, Bundle>
+ {
+ private final BundleManifest bundleManifest;
+
+ private BigInteger startToken;
+ private BigInteger endToken;
+ private Path bundleStagingDirectory;
+ private Path bundleDirectory; // path of the directory that includes the sstables and manifest file to be bundled
+ private Path bundleFile; // path of the bundle/zip file, which is uploaded to S3
+ private int bundleSequence;
+ private List<SSTableCollector.SSTableFilesAndRange> sourceSSTables;
+ private long bundleUncompressedSize;
+ private long bundleCompressedSize;
+ private BundleNameGenerator bundleNameGenerator;
+
+ Builder()
+ {
+ this.bundleManifest = new BundleManifest();
+ }
+
+ /**
+ * Set the staging directory for all bundles
+ * @param bundleStagingDirectory staging directory for all bundles
+ * @return builder
+ */
+ public Builder bundleStagingDirectory(Path bundleStagingDirectory)
+ {
+ Preconditions.checkNotNull(bundleStagingDirectory, "Cannot set bundle staging directory to null");
+ return with(b -> b.bundleStagingDirectory = bundleStagingDirectory);
+ }
+
+ /**
+ * Set the bundle name generator
+ * @param bundleNameGenerator generates bundle name
+ * @return builder
+ */
+ public Builder bundleNameGenerator(BundleNameGenerator bundleNameGenerator)
+ {
+ return with(b -> b.bundleNameGenerator = bundleNameGenerator);
+ }
+
+ /**
+ * Set the sequence of the bundle. It should be monotonically increasing
+ * @param bundleSequence sequence of the bundle
+ * @return builder
+ */
+ public Builder bundleSequence(int bundleSequence)
+ {
+ Preconditions.checkArgument(bundleSequence >= 0, "bundleSequence cannot be negative");
+ return with(b -> b.bundleSequence = bundleSequence);
+ }
+
+ /**
+ * Set the source sstables to be bundled
+ * @param sourceSSTables sstables to be bundled
+ * @return builder
+ */
+ public Builder sourceSSTables(List<SSTableCollector.SSTableFilesAndRange> sourceSSTables)
+ {
+ Preconditions.checkArgument(sourceSSTables != null && !sourceSSTables.isEmpty(),
+ "No files to bundle");
+
+ return with(b -> {
+ b.sourceSSTables = sourceSSTables;
+ b.bundleUncompressedSize = sourceSSTables.stream()
+ .mapToLong(sstable -> sstable.size)
+ .sum();
+ });
+ }
+
+ public Bundle build()
+ {
+ try
+ {
+ prepareBuild();
+ }
+ catch (IOException ioe)
+ {
+ throw new RuntimeException("Unable to produce bundle manifest", ioe);
+ }
+
+ return new Bundle(this);
+ }
+
+ public Builder self()
+ {
+ return this;
+ }
+
+ private void prepareBuild() throws IOException
+ {
+ bundleDirectory = bundleStagingDirectory.resolve(Integer.toString(bundleSequence));
+ Files.createDirectories(bundleDirectory);
+
+ populateBundleManifestAndPersist();
+
+ String bundleName = bundleNameGenerator.generate(startToken, endToken);
+ bundleFile = bundleStagingDirectory.resolve(bundleName);
+ bundleCompressedSize = SSTablesBundler.zip(bundleDirectory, bundleFile);
+ }
+
+ private void populateBundleManifestAndPersist() throws IOException
+ {
+ for (SSTableFilesAndRange sstable : sourceSSTables)
+ {
+ // all components related to one SSTable are moved under the same bundle
+ BundleManifest.Entry manifestEntry = new BundleManifest.Entry(sstable.summary);
+ for (Path componentPath : sstable.files)
+ {
+ String checksum = IOUtils.xxhash32(componentPath);
+ Path targetPath = bundleDirectory.resolve(componentPath.getFileName());
+ // link the original files to the bundle dir to avoid copying data
+ Files.createLink(targetPath, componentPath);
+ manifestEntry.addComponentChecksum(componentPath.getFileName().toString(), checksum);
+ }
+ addManifestEntry(manifestEntry);
+ }
+
+ bundleManifest.persistTo(bundleDirectory.resolve(Bundle.MANIFEST_FILE_NAME));
+ }
+
+ private void addManifestEntry(BundleManifest.Entry entry)
+ {
+ if (bundleManifest.isEmpty())
+ {
+ startToken = entry.startToken();
+ endToken = entry.endToken();
+ }
+ else
+ {
+ startToken = startToken.min(entry.startToken());
+ endToken = endToken.max(entry.endToken());
+ }
+ bundleManifest.addEntry(entry);
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifest.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifest.java
new file mode 100644
index 0000000..a79db54
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import org.apache.cassandra.bridge.SSTableSummary;
+
+/**
+ * Manifest of all SSTables in the bundle.
+ * It is a variant of {@link HashMap}, for the convenience of JSON serialization.
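+ * <p>A minimal usage sketch (assuming an {@link SSTableSummary} named summary):
+ * <pre>{@code
+ * BundleManifest manifest = new BundleManifest();
+ * manifest.addEntry(new BundleManifest.Entry(summary));
+ * manifest.persistTo(bundleDirectory.resolve("manifest.json"));
+ * }</pre>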
+ */
+public class BundleManifest extends HashMap<String, BundleManifest.Entry>
+{
+ public static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writerWithDefaultPrettyPrinter();
+
+ private static final long serialVersionUID = 6593130321276240266L;
+
+ public void addEntry(Entry manifestEntry)
+ {
+ super.put(manifestEntry.key, manifestEntry);
+ }
+
+ public void persistTo(Path filePath) throws IOException
+ {
+ Files.createFile(filePath);
+ OBJECT_WRITER.writeValue(filePath.toFile(), this);
+ }
+
+ /**
+ * Manifest of a single SSTable
+ * componentsChecksum includes the checksums of individual SSTable components;
+ * startToken and endToken represent the token range of the SSTable
+ */
+ public static class Entry
+ {
+ // uniquely identify a manifest entry
+ private final String key;
+ private final Map<String, String> componentsChecksum;
+ private final BigInteger startToken;
+ private final BigInteger endToken;
+
+ @VisibleForTesting
+ Entry(String key, BigInteger startToken, BigInteger endToken)
+ {
+ this.key = key;
+ this.startToken = startToken;
+ this.endToken = endToken;
+ this.componentsChecksum = new HashMap<>();
+ }
+
+ public Entry(SSTableSummary summary)
+ {
+ this.key = summary.sstableId;
+ this.startToken = summary.firstToken;
+ this.endToken = summary.lastToken;
+ this.componentsChecksum = new HashMap<>();
+ }
+
+ public void addComponentChecksum(String component, String checksum)
+ {
+ componentsChecksum.put(component, checksum);
+ }
+
+ @JsonProperty("components_checksum")
+ public Map<String, String> componentsChecksum()
+ {
+ return componentsChecksum;
+ }
+
+ @JsonProperty("start_token")
+ public BigInteger startToken()
+ {
+ return startToken;
+ }
+
+ @JsonProperty("end_token")
+ public BigInteger endToken()
+ {
+ return endToken;
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGenerator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGenerator.java
new file mode 100644
index 0000000..afeead1
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGenerator.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+
+/**
+ * Generate names for SSTable bundles
+ */
+public class BundleNameGenerator
+{
+ private final String commonName;
+
+ public BundleNameGenerator(String jobId, String sessionId)
+ {
+ this.commonName = '_' + jobId + '_' + sessionId + '_';
+ }
+
+ /**
+ * We want to introduce variability in the starting character of the zip file name, to guarantee entropy in the
+ * object name and avoid 503s from S3, working around the throughput limit that is based on the object name.
+ * <p>
+ * We use 62 as the modulus, because 62 = 26 (lower-case letters) + 26 (upper-case letters) + 10 (digits).
+ * For example, seed = 512 maps to the lower-case letter 'q', since 512 % 62 == 16 and 'a' + 16 == 'q'.
+ * </p>
+ * @param seed a random integer to derive the prefix character
+ * @return starting character to be used while naming zipped SSTables file
+ */
+ private char generatePrefixChar(int seed)
+ {
+ int group = Math.floorMod(seed, 62); // floorMod, since tokens (and thus the seed) may be negative
+ if (group <= 25)
+ {
+ return (char) ('a' + group);
+ }
+ else if (group <= 51)
+ {
+ return (char) ('A' + group - 26);
+ }
+ else
+ {
+ return (char) ('0' + group - 52);
+ }
+ }
+
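+ /**
+ * Generates the bundle name from the token range. For example (illustrative values), with jobId "job",
+ * sessionId "1-abc", startToken 512 and endToken 1024, the generated name is "q_job_1-abc_512_1024",
+ * since 512 % 62 == 16 maps to the prefix character 'q'.
+ *
+ * @param startToken start token of the bundle
+ * @param endToken end token of the bundle
+ * @return the generated bundle name
+ */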
+ public String generate(BigInteger startToken, BigInteger endToken)
+ {
+ return generatePrefixChar(startToken.intValue()) + commonName + startToken + '_' + endToken;
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleStorageObject.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleStorageObject.java
new file mode 100644
index 0000000..f0a2a5f
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleStorageObject.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import org.apache.cassandra.spark.common.DataObjectBuilder;
+
+/**
+ * Storage object of the uploaded bundle, including object key and checksum
+ */
+public class BundleStorageObject
+{
+ public final String storageObjectKey;
+ public final String storageObjectChecksum;
+ public final Bundle bundle;
+
+ static Builder builder()
+ {
+ return new Builder();
+ }
+
+ protected BundleStorageObject(Builder builder)
+ {
+ this.storageObjectChecksum = builder.storageObjectChecksum;
+ this.storageObjectKey = builder.storageObjectKey;
+ this.bundle = builder.bundle;
+ }
+
+ @Override
+ public String toString()
+ {
+ return "Bundle{manifest: " + bundle
+ + ", storageObjectKey: " + storageObjectKey
+ + ", storageObjectChecksum: " + storageObjectChecksum + '}';
+ }
+
+ static class Builder implements DataObjectBuilder<Builder, BundleStorageObject>
+ {
+ private String storageObjectChecksum;
+ private String storageObjectKey;
+ private Bundle bundle;
+
+ public Builder storageObjectChecksum(String bundleChecksum)
+ {
+ return with(b -> b.storageObjectChecksum = bundleChecksum);
+ }
+
+ public Builder storageObjectKey(String storageObjectKey)
+ {
+ return with(b -> b.storageObjectKey = storageObjectKey);
+ }
+
+ public Builder bundle(Bundle bundle)
+ {
+ return with(b -> b.bundle = bundle);
+ }
+
+ public BundleStorageObject build()
+ {
+ return new BundleStorageObject(this);
+ }
+
+ public Builder self()
+ {
+ return this;
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java
new file mode 100644
index 0000000..fe07304
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CassandraCloudStorageTransportContext.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.lang.reflect.InvocationTargetException;
+import java.math.BigInteger;
+import java.util.Objects;
+
+import com.google.common.collect.Range;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
+import org.apache.cassandra.spark.bulkwriter.BulkWriterContext;
+import org.apache.cassandra.spark.bulkwriter.ClusterInfo;
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.cassandra.spark.bulkwriter.RingInstance;
+import org.apache.cassandra.spark.bulkwriter.SortedSSTableWriter;
+import org.apache.cassandra.spark.bulkwriter.TransportContext;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.jetbrains.annotations.NotNull;
+
+public class CassandraCloudStorageTransportContext implements TransportContext.CloudStorageTransportContext
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(CassandraCloudStorageTransportContext.class);
+
+ @NotNull
+ private final StorageTransportExtension storageTransportExtension;
+ @NotNull
+ private final StorageTransportConfiguration storageTransportConfiguration;
+ @NotNull
+ private final BlobDataTransferApi dataTransferApi;
+ @NotNull
+ private final BulkSparkConf conf;
+ @NotNull
+ private final JobInfo jobInfo;
+ @NotNull
+ private final ClusterInfo clusterInfo;
+ private StorageClient storageClient;
+
+ public CassandraCloudStorageTransportContext(@NotNull BulkWriterContext bulkWriterContext,
+ @NotNull BulkSparkConf conf,
+ boolean isOnDriver)
+ {
+ // we may not always need a transport extension implementation in cloud based transport context, revisit this
+ // check when we have multiple cloud based transport options supported
+ Objects.requireNonNull(conf.getTransportInfo().getTransportExtensionClass(),
+ "DATA_TRANSPORT_EXTENSION_CLASS must be provided");
+ this.conf = conf;
+ this.jobInfo = bulkWriterContext.job();
+ this.clusterInfo = bulkWriterContext.cluster();
+ this.storageTransportExtension = createStorageTransportExtension(isOnDriver);
+ this.storageTransportConfiguration = storageTransportExtension.getStorageConfiguration();
+ Objects.requireNonNull(storageTransportConfiguration,
+ "Storage configuration cannot be null in order to upload to cloud");
+ this.dataTransferApi = createBlobDataTransferApi();
+ }
+
+ @Override
+ public BlobStreamSession createStreamSession(BulkWriterContext writerContext,
+ String sessionId,
+ SortedSSTableWriter sstableWriter,
+ Range<BigInteger> range,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ return new BlobStreamSession(writerContext, sstableWriter,
+ this, sessionId, range, failureHandler);
+ }
+
+ @Override
+ public BlobDataTransferApi dataTransferApi()
+ {
+ return dataTransferApi;
+ }
+
+ @NotNull
+ @Override
+ public StorageTransportConfiguration transportConfiguration()
+ {
+ return storageTransportConfiguration;
+ }
+
+ /**
+ * Returns the {@link StorageTransportExtension} instance. It is instantiated and initialized
+ * exactly once, in the constructor.
+ *
+ * @return StorageTransportExtension instance
+ */
+ @NotNull
+ @Override
+ public StorageTransportExtension transportExtensionImplementation()
+ {
+ return this.storageTransportExtension;
+ }
+
+ // only invoked from the constructor
+ protected BlobDataTransferApi createBlobDataTransferApi()
+ {
+ storageClient = new StorageClient(storageTransportConfiguration, conf.getStorageClientConfig());
+ return new BlobDataTransferApi(jobInfo,
+ clusterInfo.getCassandraContext().getSidecarClient(),
+ storageClient);
+ }
+
+ // only invoked from the constructor
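+ /**
+ * Loads and initializes the {@link StorageTransportExtension} implementation named by the
+ * DATA_TRANSPORT_EXTENSION_CLASS option. Since the class is instantiated reflectively via its
+ * no-argument constructor, implementations must be public and provide a public no-arg constructor.
+ *
+ * @param isOnDriver whether this context is being created on the Spark driver
+ * @return the initialized extension instance
+ */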
+ protected StorageTransportExtension createStorageTransportExtension(boolean isOnDriver)
+ {
+ String transportExtensionClass = conf.getTransportInfo().getTransportExtensionClass();
+ try
+ {
+ Class<StorageTransportExtension> clazz = (Class<StorageTransportExtension>) Class.forName(transportExtensionClass);
+ StorageTransportExtension extension = clazz.getConstructor().newInstance();
+ LOGGER.info("Initializing storage transport extension. jobId={}, restoreJobId={}",
+ jobInfo.getId(), jobInfo.getRestoreJobId());
+ extension.initialize(jobInfo.getId(), conf.getSparkConf(), isOnDriver);
+ // Only assign when initialization is complete
+ // to avoid exposing uninitialized extension object, which leads to unexpected behavior
+ return extension;
+ }
+ catch (ClassNotFoundException | ClassCastException | InvocationTargetException | InstantiationException
+ | IllegalAccessException | NoSuchMethodException e)
+ {
+ throw new RuntimeException("Invalid storage transport extension class specified: '" + transportExtensionClass, e);
+ }
+ }
+
+ @Override
+ public void close()
+ {
+ if (storageClient != null)
+ {
+ try
+ {
+ storageClient.close();
+ }
+ catch (Exception exception)
+ {
+ LOGGER.warn("Failed to close storage client", exception);
+ }
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSlice.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSlice.java
new file mode 100644
index 0000000..c1cec7d
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSlice.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.Serializable;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+// Note: Must use the sidecar shaded jackson to ser/deser sidecar objects
+import o.a.c.sidecar.client.shaded.com.fasterxml.jackson.core.JsonProcessingException;
+import o.a.c.sidecar.client.shaded.com.fasterxml.jackson.databind.ObjectMapper;
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
+import org.apache.cassandra.spark.common.model.CassandraInstance;
+import org.apache.cassandra.spark.data.ReplicationFactor;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * A serializable wrapper of {@link CreateSliceRequestPayload} that also implements hashCode and equals
+ */
+public class CreatedRestoreSlice implements Serializable
+{
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+ private static final Logger LOGGER = LoggerFactory.getLogger(CreatedRestoreSlice.class);
+ private static final long serialVersionUID = 1738928448022537598L;
+
+ private transient CreateSliceRequestPayload sliceRequestPayload;
+ private transient Set<CassandraInstance> succeededInstances;
+ private transient boolean isSatisfied = false;
+ public final String sliceRequestPayloadJson; // equals and hashCode are implemented using only this field
+
+ public CreatedRestoreSlice(@NotNull CreateSliceRequestPayload sliceRequestPayload)
+ {
+ this.sliceRequestPayload = sliceRequestPayload;
+ this.sliceRequestPayloadJson = toJson(sliceRequestPayload);
+ }
+
+ public CreateSliceRequestPayload sliceRequestPayload()
+ {
+ if (sliceRequestPayloadJson == null)
+ {
+ throw new IllegalStateException("sliceRequestPayloadJson cannot be null");
+ }
+
+ if (sliceRequestPayload != null)
+ {
+ return sliceRequestPayload;
+ }
+
+ // The following code could run multiple times in a multi-threaded environment.
+ // Deserialization is relatively cheap, hence multiple runs are acceptable.
+ try
+ {
+ sliceRequestPayload = MAPPER.readValue(sliceRequestPayloadJson, CreateSliceRequestPayload.class);
+ return sliceRequestPayload;
+ }
+ catch (Exception exception)
+ {
+ LOGGER.error("Unable to deserialize CreateSliceRequestPayload from JSON. requestPayloadJson={}",
+ sliceRequestPayloadJson, exception);
+ throw new RuntimeException("Unable to deserialize CreateSliceRequestPayload from JSON", exception);
+ }
+ }
+
+ public void addSucceededInstance(CassandraInstance instance)
+ {
+ succeededInstances().add(instance);
+ }
+
+ /**
+ * Check whether the slice satisfies the consistency level
+ *
+ * @param consistencyLevel consistency level to check
+ * @param replicationFactor replication factor to check
+ * @param localDC local DC name if any
+ * @return check result, either not satisfied, satisfied, or already satisfied
+ */
+ public synchronized ConsistencyLevelCheckResult checkForConsistencyLevel(ConsistencyLevel consistencyLevel,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ if (isSatisfied)
+ {
+ return ConsistencyLevelCheckResult.ALREADY_SATISFIED;
+ }
+
+ if (!succeededInstances().isEmpty()
+ && consistencyLevel.canBeSatisfied(succeededInstances(), replicationFactor, localDC))
+ {
+ isSatisfied = true;
+ return ConsistencyLevelCheckResult.SATISFIED;
+ }
+
+ return ConsistencyLevelCheckResult.NOT_SATISFIED;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o)
+ {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass())
+ {
+ return false;
+ }
+ CreatedRestoreSlice that = (CreatedRestoreSlice) o;
+ return Objects.equals(sliceRequestPayloadJson, that.sliceRequestPayloadJson);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(sliceRequestPayloadJson);
+ }
+
+ @Override
+ public String toString()
+ {
+ return sliceRequestPayload.toString();
+ }
+
+ Set<CassandraInstance> succeededInstances()
+ {
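+ // lazily initialize with double-checked locking; the field is transient, so it may be null after deserialization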
+ Set<CassandraInstance> currentInstances = succeededInstances;
+ if (currentInstances != null)
+ {
+ return currentInstances;
+ }
+
+ synchronized (this)
+ {
+ if (succeededInstances == null)
+ {
+ succeededInstances = ConcurrentHashMap.newKeySet();
+ }
+ }
+ return succeededInstances;
+ }
+
+ private static String toJson(@NotNull CreateSliceRequestPayload sliceRequestPayload)
+ {
+ try
+ {
+ return MAPPER.writeValueAsString(sliceRequestPayload);
+ }
+ catch (JsonProcessingException jsonProcessingException)
+ {
+ LOGGER.error("Unable to serialize CreateSliceRequestPayload to JSON. requestPayload={}",
+ sliceRequestPayload, jsonProcessingException);
+ throw new RuntimeException("Unable to serialize CreateSliceRequestPayload to JSON", jsonProcessingException);
+ }
+ }
+
+ public enum ConsistencyLevelCheckResult
+ {
+ NOT_SATISFIED,
+ SATISFIED,
+ ALREADY_SATISFIED;
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunker.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunker.java
new file mode 100644
index 0000000..460992c
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunker.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+
+/**
+ * {@link DataChunker} helps break data down into chunks according to the configured maxChunkSizeInBytes.
+ */
+public class DataChunker
+{
+ // See https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
+ private static final int MINIMUM_CHUNK_SIZE_IN_BYTES = 5 * 1024 * 1024; // 5MiB. There is no size requirement for the last chunk
+ private final int chunkSizeInBytes;
+
+ public DataChunker(int chunkSizeInBytes)
+ {
+ this(chunkSizeInBytes, true);
+ }
+
+ /**
+ * Do not use this constructor, unless testing
+ * @param chunkSizeInBytes chunk size in bytes
+ * @param validate whether to enable validation of chunkSizeInBytes
+ */
+ @VisibleForTesting
+ DataChunker(int chunkSizeInBytes, boolean validate)
+ {
+ if (validate)
+ {
+ Preconditions.checkArgument(chunkSizeInBytes >= MINIMUM_CHUNK_SIZE_IN_BYTES,
+ "Chunk size is too small. Minimum size: " + MINIMUM_CHUNK_SIZE_IN_BYTES);
+ }
+ this.chunkSizeInBytes = chunkSizeInBytes;
+ }
+
+ /**
+ * Chunk the input stream based on chunkSize
+ * @param channel data source file channel
+ * @return iterator of ByteBuffers. Call-sites should check {@link ByteBuffer#limit()}
+ * to determine how much data to read, especially the last chunk
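+ * <p>A minimal usage sketch (chunker is an existing DataChunker instance; upload(...) is a hypothetical consumer):
+ * <pre>{@code
+ * try (ReadableByteChannel channel = Files.newByteChannel(bundleFile))
+ * {
+ *     Iterator<ByteBuffer> chunks = chunker.chunks(channel);
+ *     while (chunks.hasNext())
+ *     {
+ *         upload(chunks.next());
+ *     }
+ * }
+ * }</pre>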
+ */
+ public Iterator<ByteBuffer> chunks(ReadableByteChannel channel)
+ {
+ return new Iterator<ByteBuffer>()
+ {
+ private boolean eosReached = false;
+ private final ByteBuffer buffer = ByteBuffer.allocate(chunkSizeInBytes);
+
+ public boolean hasNext()
+ {
+ buffer.clear();
+ try
+ {
+ eosReached = -1 == channel.read(buffer);
+ buffer.flip();
+ }
+ catch (IOException e)
+ {
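+ // treat a read failure as end-of-stream; iteration simply stops and no further element is surfaced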
+ eosReached = true;
+ }
+ return !eosReached;
+ }
+
+ public ByteBuffer next()
+ {
+ if (eosReached)
+ {
+ throw new NoSuchElementException("End of stream has reached");
+ }
+
+ return buffer;
+ }
+ };
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableBundleSpec.md b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableBundleSpec.md
new file mode 100644
index 0000000..df31901
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableBundleSpec.md
@@ -0,0 +1,78 @@
+<!--
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+-->
+
+# SSTable Bundle Specification
+
+When bulk writing SSTables using the `S3_COMPAT` mode, the writer groups several SSTables into a single file, called a `Bundle`.
+
+This document walks through the specification of `Bundle`.
+
+## Bundle Structure
+
+A `Bundle` is essentially a zip file.
+
+Each bundle consists of a `BundleManifest` file and sstable files. The manifest file describes the bundled sstable files, and it is used for validation.
+
+Visually, the bundle structure is illustrated in the following tree diagram
+
+```plaintext
+bundle/
+├── manifest.json
+├── sstable-1
+├── sstable-2
+└── ...
+```
+
+> ℹ️ An SSTable consists of several components/files
+
+## Manifest Schema
+
+A manifest file is in `JSON` format. It is essentially a map of `<SSTableId, SSTableMetadata>`.
+`SSTableId` is the shared prefix of the components of an SSTable. `SSTableMetadata` consists of a string map `componentsChecksum`, plus `startToken` and `endToken`.
+
+> ℹ️ The checksum value of the individual component is computed using XXHash32
+
+Here is an example of the manifest
+
+```json
+{
+ "nb-1-big-" : {
+ "start_token" : 1,
+ "end_token" : 3,
+ "components_checksum" : {
+ "nb-1-big-Data.db" : "12345678",
+ "nb-1-big-Index.db" : "12345678",
+ "nb-1-big-CompressionInfo.db" : "12345678",
+ "nb-1-big-Summary.db" : "12345678",
+ ...
+ }
+ },
+ "nb-2-big-" : {
+ "start_token" : 4,
+ "end_token" : 7,
+ "components_checksum" : {
+ "nb-2-big-Data.db" : "12345678",
+ "nb-2-big-Index.db" : "12345678",
+ "nb-2-big-CompressionInfo.db" : "12345678",
+ "nb-2-big-Summary.db" : "12345678",
+ ...
+ },
+ ...
+ }
+}
+```
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java
new file mode 100644
index 0000000..7bb1c6d
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableCollector.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.nio.file.Path;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.cassandra.bridge.SSTableSummary;
+
+/**
+ * Collect SSTables from listing the included directories
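+ * <p>A minimal consumption sketch (a sketch only; {@link SSTableLister} is the provided implementation):
+ * <pre>{@code
+ * SSTableCollector collector = new SSTableLister(qualifiedTableName, bridge);
+ * collector.includeDirectory(outputDir);
+ * while (!collector.isEmpty())
+ * {
+ *     SSTableFilesAndRange sstable = collector.consumeOne();
+ *     // process the sstable files
+ * }
+ * }</pre>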
+ */
+public interface SSTableCollector
+{
+ /**
+ * Include sstables under the directory
+ * @param dir directory that contains sstables
+ */
+ void includeDirectory(Path dir);
+
+ /**
+ * @return total size of all sstables included
+ */
+ long totalSize();
+
+ /**
+ * Get an SSTable from the collector, but do not remove it
+ * @return sstable or null if the collector is empty
+ */
+ SSTableFilesAndRange peek();
+
+ /**
+ * Get an SSTable from the collector and remove it
+ * @return sstable or null if the collector is empty
+ */
+ SSTableFilesAndRange consumeOne();
+
+ /**
+ * @return true if the collector is empty; otherwise, false
+ */
+ boolean isEmpty();
+
+ /**
+ * Simple record class containing SSTable component file paths, summary and size
+ */
+ class SSTableFilesAndRange
+ {
+ public final Set<Path> files; // immutable set
+ public final SSTableSummary summary;
+ public final long size;
+
+ public SSTableFilesAndRange(SSTableSummary summary, List<Path> components, long size)
+ {
+ this.summary = summary;
+ this.files = Collections.unmodifiableSet(new HashSet<>(components));
+ this.size = size;
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java
new file mode 100644
index 0000000..8b7a227
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableLister.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.LinkOption;
+import java.nio.file.Path;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.SSTableSummary;
+import org.apache.cassandra.spark.data.FileSystemSSTable;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.data.SSTable;
+import org.apache.cassandra.spark.stats.Stats;
+
+/**
+ * {@link SSTableLister} lists the directories containing SSTables.
+ * Internally, the listed SSTables are sorted by the insertion order of the directories,
+ * and by the first and last tokens of the SSTables.
+ * Therefore, the SSTables come out sorted when consumed from the lister.
+ */
+public class SSTableLister implements SSTableCollector
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(SSTableLister.class);
+ private static final Comparator<SSTableFilesAndRange> SORT_BY_FIRST_TOKEN_THEN_LAST_TOKEN =
+ Comparator.<SSTableFilesAndRange, BigInteger>comparing(sstable -> sstable.summary.firstToken)
+ .thenComparing(sstable -> sstable.summary.lastToken);
+ private final QualifiedTableName qualifiedTableName;
+ private final CassandraBridge bridge;
+ private final Queue<SSTableFilesAndRange> sstables;
+ private final Set<Path> sstableDirectories;
+ private long totalSize;
+
+ public SSTableLister(QualifiedTableName qualifiedTableName, CassandraBridge bridge)
+ {
+ this.qualifiedTableName = qualifiedTableName;
+ this.bridge = bridge;
+ this.sstables = new LinkedBlockingQueue<>();
+ this.sstableDirectories = new HashSet<>();
+ }
+
+ @Override
+ public void includeDirectory(Path dir)
+ {
+ if (!sstableDirectories.add(dir))
+ {
+ throw new IllegalArgumentException("The directory has been included already! Input dir: " + dir
+ + "; existing directories: " + sstableDirectories);
+ }
+
+ listSSTables(dir)
+ .map(components -> {
+ SSTable sstable = buildSSTable(components);
+ SSTableSummary summary = bridge.getSSTableSummary(qualifiedTableName.keyspace(),
+ qualifiedTableName.table(),
+ sstable);
+ long size = sizeSum(components);
+ totalSize += size;
+ return new SSTableFilesAndRange(summary, components, size); // reuse the size computed above
+ })
+ .sorted(SORT_BY_FIRST_TOKEN_THEN_LAST_TOKEN)
+ .forEach(sstables::add);
+ }
+
+ @Override
+ public long totalSize()
+ {
+ return totalSize;
+ }
+
+ @Override
+ public SSTableFilesAndRange peek()
+ {
+ return sstables.peek();
+ }
+
+ @Override
+ public SSTableFilesAndRange consumeOne()
+ {
+ SSTableFilesAndRange sstable = sstables.poll();
+ if (sstable != null)
+ {
+ totalSize -= sstable.size;
+ }
+
+ return sstable;
+ }
+
+ @Override
+ public boolean isEmpty()
+ {
+ return sstables.isEmpty();
+ }
+
+ private Stream<List<Path>> listSSTables(Path dir)
+ {
+ Map<String, List<Path>> componentsByPrefix = new HashMap<>();
+ try (Stream<Path> stream = Files.list(dir))
+ {
+ stream.forEach(path -> {
+ final String ssTablePrefix = getSSTablePrefix(path.getFileName().toString());
+
+ if (ssTablePrefix.isEmpty())
+ {
+ // ignore files that are not SSTables components
+ return;
+ }
+
+ List<Path> prefixPaths = componentsByPrefix.computeIfAbsent(ssTablePrefix, ignored -> new ArrayList<>(8));
+ prefixPaths.add(path);
+ });
+ return componentsByPrefix.values().stream();
+ }
+ catch (IOException e)
+ {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private long sizeSum(List<Path> files)
+ {
+ return files
+ .stream()
+ .mapToLong(path -> {
+ try
+ {
+ BasicFileAttributes fileAttributes = Files.readAttributes(path, BasicFileAttributes.class,
+ // not expecting links and do not follow links
+ LinkOption.NOFOLLOW_LINKS);
+ if (fileAttributes != null && fileAttributes.isRegularFile())
+ {
+ return fileAttributes.size();
+ }
+ else
+ {
+ return 0L;
+ }
+ }
+ catch (IOException e)
+ {
+ LOGGER.warn("Failed to get size of file. path={}", path);
+ return 0L;
+ }
+ })
+ .sum();
+ }
+
+ private String getSSTablePrefix(String componentName)
+ {
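+ // e.g. "nb-1-big-Data.db" -> "nb-1-big-"; file names without '-' yield an empty prefix and are skipped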
+ return componentName.substring(0, componentName.lastIndexOf('-') + 1);
+ }
+
+ private SSTable buildSSTable(List<Path> components)
+ {
+ List<Path> dataComponents = components.stream()
+ .filter(path -> path.getFileName().toString().contains("Data.db"))
+ .collect(Collectors.toList());
+ if (dataComponents.size() != 1)
+ {
+ throw new IllegalArgumentException("SSTable should have only one data component");
+ }
+ return new FileSystemSSTable(dataComponents.get(0), true, () -> Stats.DoNothingStats.INSTANCE);
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java
new file mode 100644
index 0000000..c93044f
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundler.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.stream.Stream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipOutputStream;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.spark.bulkwriter.blobupload.SSTableCollector.SSTableFilesAndRange;
+
+/**
+ * {@link SSTablesBundler} bundles SSTables in the output directory provided by
+ * {@link org.apache.cassandra.bridge.SSTableWriter}. With output from {@link SSTableLister}, we get sorted
+ * list of {@link SSTableFilesAndRange}. According to sorted order, we move all component files
+ * related to a SSTable into bundle folder. When a bundle's size exceeds configured, a new bundle is created and
+ * SSTable components are moved into new bundle folder.
+ * <br>
+ * When a bundle is being closed, {@link Bundle} generated for that bundle gets written to manifest.json file
+ * and added to bundle folder. The entire folder is then zipped and added to zipped_bundles folder
+ * <br>
+ * Under output directory of {@link org.apache.cassandra.bridge.SSTableWriter}, sample folders created look like
+ * bundle0, bundle1, bundle2, zipped_bundles
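+ * <br>
+ * A minimal driving loop (a sketch; the variable names are illustrative):
+ * <pre>{@code
+ * SSTablesBundler bundler = new SSTablesBundler(stagingDir, collector, nameGenerator, maxBundleSizeInBytes);
+ * bundler.includeDirectory(sstableOutputDir);
+ * bundler.finish(); // no more input directories; drain the remaining sstables
+ * while (bundler.hasNext())
+ * {
+ *     Bundle bundle = bundler.next(); // the bundle is zipped and ready for upload
+ * }
+ * }</pre>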
+ */
+public class SSTablesBundler implements Iterator<Bundle>
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(SSTablesBundler.class);
+ private final SSTableCollector collector;
+ private final BundleNameGenerator bundleNameGenerator;
+ private final Path bundleStagingDir;
+ private final long maxSizePerBundleInBytes;
+ private boolean reachedEnd = false;
+ private int bundleIndex = 0;
+ private Bundle currentBundle = null;
+
+ public SSTablesBundler(Path bundleStagingDir, SSTableCollector collector,
+ BundleNameGenerator bundleNameGenerator, long maxSizePerBundleInBytes)
+ {
+ this.bundleStagingDir = bundleStagingDir;
+ this.collector = collector;
+ this.bundleNameGenerator = bundleNameGenerator;
+ this.maxSizePerBundleInBytes = maxSizePerBundleInBytes;
+ }
+
+ @Override
+ public boolean hasNext()
+ {
+ if (reachedEnd)
+ {
+ // consume all sstables from collector
+ return !collector.isEmpty();
+ }
+ else
+ {
+ // consume only when sstables have enough total size
+ return collector.totalSize() > maxSizePerBundleInBytes;
+ }
+ }
+
+ @Override
+ public Bundle next()
+ {
+ if (!hasNext())
+ {
+ throw new NoSuchElementException("Bundles have exhausted");
+ }
+
+ try
+ {
+ currentBundle = computeNext();
+ return currentBundle;
+ }
+ catch (Exception exception)
+ {
+ throw new RuntimeException("Unable to produce bundle", exception);
+ }
+ }
+
+ public void includeDirectory(Path dir)
+ {
+ collector.includeDirectory(dir);
+ }
+
+ public void finish()
+ {
+ reachedEnd = true;
+ }
+
+ public void cleanupBundle(String sessionID)
+ {
+ LOGGER.info("[{}]: Clean up bundle files after stream session bundle={}", sessionID, currentBundle);
+ Bundle bundle = currentBundle;
+ currentBundle = null;
+ try
+ {
+ bundle.deleteAll();
+ }
+ catch (IOException exception)
+ {
+ // log the captured reference; currentBundle has already been reset to null at this point
+ LOGGER.warn("[{}]: Failed to clean up bundle files bundle={}", sessionID, bundle, exception);
+ }
+ }
+
+ private Bundle computeNext() throws IOException
+ {
+ List<SSTableFilesAndRange> sstableFiles = new ArrayList<>();
+ long size = 0;
+ while (!collector.isEmpty())
+ {
+ SSTableFilesAndRange sstable = collector.peek();
+ long lastSize = size;
+ size += sstable.size;
+ // Stop adding more, _only_ when
+ // 1) it has included some sstables already, and
+ // 2) adding this one will exceed the size limit
+ // It means that if the first sstable included in the loop is larger than the limit,
+ // the large sstable is added regardless.
+ if (size > maxSizePerBundleInBytes && lastSize != 0)
+ {
+ break;
+ }
+ else
+ {
+ sstableFiles.add(sstable);
+ collector.consumeOne();
+ }
+ }
+
+ // if not exist yet, create folder for holding all zipped bundles
+ Files.createDirectories(bundleStagingDir);
+ return Bundle.builder()
+ .bundleSequence(bundleIndex++)
+ .bundleStagingDirectory(bundleStagingDir)
+ .sourceSSTables(sstableFiles)
+ .bundleNameGenerator(bundleNameGenerator)
+ .build();
+ }
+
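+ /**
+ * Zips the regular files directly under sourcePath (depth 1, non-recursive) into targetPath
+ *
+ * @param sourcePath directory whose files are zipped
+ * @param targetPath path of the zip file to create
+ * @return size of the produced zip file in bytes
+ * @throws IOException when writing the zip file fails
+ */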
+ static long zip(Path sourcePath, Path targetPath) throws IOException
+ {
+ try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(targetPath));
+ Stream<Path> stream = Files.walk(sourcePath, 1))
+ {
+ stream.filter(Files::isRegularFile)
+ .forEach(path -> {
+ ZipEntry zipEntry = new ZipEntry(sourcePath.relativize(path).toString());
+ try
+ {
+ zos.putNextEntry(zipEntry);
+ Files.copy(path, zos);
+ zos.closeEntry();
+ }
+ catch (IOException e)
+ {
+ LOGGER.error("Unexpected error while zipping file. path={}", path, e);
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ return targetPath.toFile().length();
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClient.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClient.java
new file mode 100644
index 0000000..bda3d90
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClient.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.utils.ByteBufferUtils;
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
+import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
+import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
+import software.amazon.awssdk.http.nio.netty.ProxyConfiguration;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
+import software.amazon.awssdk.services.s3.model.CompletedPart;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.Tag;
+import software.amazon.awssdk.services.s3.model.Tagging;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
+import software.amazon.awssdk.utils.ThreadFactoryBuilder;
+
+/**
+ * Client used for uploading SSTable bundles to an S3 bucket
+ */
+public class StorageClient implements AutoCloseable
+{
+ public static final char SEPARATOR = '/';
+ private final StorageTransportConfiguration storageTransportConfiguration;
+ private final DataChunker dataChunker;
+ private final Tagging tagging;
+ private final S3AsyncClient client;
+ private final Map<StorageCredentials, AwsCredentialsProvider> credentialsCache;
+
+ public StorageClient(StorageTransportConfiguration storageTransportConfiguration,
+ StorageClientConfig storageClientConfig)
+ {
+ this.storageTransportConfiguration = storageTransportConfiguration;
+ ThreadPoolExecutor executor = new ThreadPoolExecutor(storageClientConfig.concurrency, // core
+ storageClientConfig.concurrency, // max
+ // keep alive
+ storageClientConfig.threadKeepAliveSeconds, TimeUnit.SECONDS,
+ new LinkedBlockingQueue<>(), // unbounded work queue
+ new ThreadFactoryBuilder().threadNamePrefix(storageClientConfig.threadNamePrefix)
+ .daemonThreads(true)
+ .build());
+ // Must set it to allow threads to time out, so that it can release resources when idle.
+ executor.allowCoreThreadTimeOut(true);
+ Map<SdkAdvancedAsyncClientOption<?>, ?> advancedOptions = Collections.singletonMap(
+ SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, executor
+ );
+
+
+ S3AsyncClientBuilder clientBuilder = S3AsyncClient.builder()
+ .region(Region.of(this.storageTransportConfiguration.getWriteRegion()))
+ .asyncConfiguration(b -> b.advancedOptions(advancedOptions));
+ if (storageClientConfig.endpointOverride != null)
+ {
+ clientBuilder.endpointOverride(storageClientConfig.endpointOverride)
+ .forcePathStyle(true);
+ }
+ if (storageClientConfig.httpsProxy != null)
+ {
+ ProxyConfiguration proxyConfig = ProxyConfiguration.builder()
+ .host(storageClientConfig.httpsProxy.getHost())
+ .port(storageClientConfig.httpsProxy.getPort())
+ .scheme(storageClientConfig.httpsProxy.getScheme())
+ .build();
+ Duration connectionAcquisitionTimeout = Duration.ofSeconds(storageClientConfig.nioHttpClientConnectionAcquisitionTimeoutSeconds);
+ clientBuilder.httpClient(NettyNioAsyncHttpClient.builder()
+ .proxyConfiguration(proxyConfig)
+ .connectionAcquisitionTimeout(connectionAcquisitionTimeout)
+ .maxConcurrency(storageClientConfig.nioHttpClientMaxConcurrency)
+ .build());
+ }
+ this.client = clientBuilder.build();
+ this.dataChunker = new DataChunker(storageClientConfig.maxChunkSizeInBytes);
+ List<Tag> tags = this.storageTransportConfiguration.getTags()
+ .entrySet()
+ .stream()
+ .map(entry -> Tag.builder()
+ .key(entry.getKey())
+ .value(entry.getValue())
+ .build())
+ .collect(Collectors.toList());
+ this.tagging = Tagging.builder().tagSet(tags).build();
+ this.credentialsCache = new ConcurrentHashMap<>();
+ }
+
+ /**
+ * Uses {@link CreateMultipartUploadRequest} to break each SSTable bundle into chunks according to the
+ * configured chunk size, and uploads the chunks to S3.
+ *
+ * @param credentials credentials used for uploading to S3
+ * @param bundle bundle of SSTables
+ * @return BundleStorageObject representing the uploaded bundle
+ * @throws IOException when an IO exception occurs during the multipart upload
+ * @throws ExecutionException when it fails to retrieve the state of a task
+ * @throws InterruptedException when the thread is interrupted
+ */
+ public BundleStorageObject multiPartUpload(StorageCredentials credentials,
+ Bundle bundle)
+ throws IOException, ExecutionException, InterruptedException
+ {
+ if (credentials == null)
+ {
+ throw new IllegalArgumentException("No credentials provided for uploading bundles");
+ }
+
+ String key = calculateStorageKeyForBundle(storageTransportConfiguration.getPrefix(),
+ bundle.bundleFile);
+ CreateMultipartUploadRequest multipartUploadRequest = CreateMultipartUploadRequest.builder()
+ .overrideConfiguration(credentialsOverride(credentials))
+ .bucket(storageTransportConfiguration.getWriteBucket())
+ .key(key)
+ .tagging(tagging)
+ .build();
+
+ CreateMultipartUploadResponse multipartUploadResponse = client.createMultipartUpload(multipartUploadRequest).get();
+ String uploadId = multipartUploadResponse.uploadId();
+
+ List<CompletedPart> completedParts = uploadPartsOfBundle(key, uploadId, credentials, bundle);
+
+ // tell S3 to merge all completed parts by issuing the CompleteMultipartUploadRequest
+ CompletedMultipartUpload completedUpload = CompletedMultipartUpload.builder()
+ .parts(completedParts)
+ .build();
+
+ CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
+ .overrideConfiguration(credentialsOverride(credentials))
+ .bucket(storageTransportConfiguration.getWriteBucket())
+ .key(key)
+ .uploadId(uploadId)
+ .multipartUpload(completedUpload)
+ .build();
+ CompleteMultipartUploadResponse completeMultipartUploadResponse = client.completeMultipartUpload(completeMultipartUploadRequest).get();
+ return BundleStorageObject.builder()
+ .bundle(bundle)
+ .storageObjectKey(key)
+ .storageObjectChecksum(completeMultipartUploadResponse.eTag())
+ .build();
+ }
+
+ @Override
+ public void close()
+ {
+ if (client != null)
+ {
+ client.close();
+ }
+ }
+
+ private List<CompletedPart> uploadPartsOfBundle(String key, String uploadId,
+ StorageCredentials credentials,
+ Bundle bundle)
+ throws IOException, ExecutionException, InterruptedException
+ {
+ List<CompletableFuture<CompletedPart>> futures = new ArrayList<>();
+ // upload the zip file using multipart upload
+ // TODO: use a simple (single-part) upload when the zip file is small
+ try (ReadableByteChannel channel = Channels.newChannel(Files.newInputStream(bundle.bundleFile)))
+ {
+ Iterator<ByteBuffer> chunks = dataChunker.chunks(channel);
+ // part number must start from 1 and not exceed 10,000.
+ // See https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpuchecksums
+ int partNumber = 1;
+ while (chunks.hasNext())
+ {
+ int loopPartNumber = partNumber;
+ ByteBuffer buffer = chunks.next();
+ AsyncRequestBody body = AsyncRequestBody.fromBytes(ByteBufferUtils.getArray(buffer));
+ UploadPartRequest uploadPartRequest = UploadPartRequest.builder()
+ .overrideConfiguration(credentialsOverride(credentials))
+ .bucket(storageTransportConfiguration.getWriteBucket())
+ .key(key)
+ .uploadId(uploadId)
+ .partNumber(loopPartNumber)
+ .build();
+
+ Function<UploadPartResponse, CompletedPart> buildPart = r -> CompletedPart.builder()
+ .partNumber(loopPartNumber)
+ .eTag(r.eTag())
+ .build();
+ CompletableFuture<CompletedPart> completedPart = client.uploadPart(uploadPartRequest, body)
+ .thenApply(buildPart);
+ futures.add(completedPart);
+ partNumber++;
+ }
+ }
+
+ CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).get(); // fail fast if any part upload failed
+ return futures.stream().map(CompletableFuture::join).collect(Collectors.toList()); // all parts complete; collect in part-number order
+ }
+
+ private String calculateStorageKeyForBundle(String prefix, Path bundleLocation)
+ {
+ return prefix + SEPARATOR + bundleLocation.getFileName();
+ }
+
+ private AwsCredentialsProvider toCredentialsProvider(StorageCredentials storageCredentials)
+ {
+ AwsCredentials credentials = AwsSessionCredentials.create(storageCredentials.getAccessKeyId(),
+ storageCredentials.getSecretKey(),
+ storageCredentials.getSessionToken());
+ return StaticCredentialsProvider.create(credentials);
+ }
+
+ private Consumer<AwsRequestOverrideConfiguration.Builder> credentialsOverride(StorageCredentials credentials)
+ {
+ return b -> b.credentialsProvider(credentialsCache.computeIfAbsent(credentials, this::toCredentialsProvider));
+ }
+}
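For orientation, a minimal usage sketch of the client above. The StorageTransportConfiguration and Bundle are assumed to have been produced elsewhere in the job, every literal value below is illustrative rather than a recommended default, and exception handling is elided:

    StorageClientConfig clientConfig = new StorageClientConfig(4,               // concurrency
                                                               60,              // threadKeepAliveSeconds
                                                               5 * 1024 * 1024, // maxChunkSizeInBytes
                                                               null,            // httpsProxy (optional)
                                                               null,            // endpointOverride (testing only)
                                                               300,             // nioHttpClientConnectionAcquisitionTimeoutSeconds
                                                               50);             // nioHttpClientMaxConcurrency
    try (StorageClient storageClient = new StorageClient(transportConfiguration, clientConfig))
    {
        StorageCredentials writeCredentials = new StorageCredentials("accessKeyId", "secretKey", "sessionToken");
        // Chunks the bundle zip, uploads the parts concurrently, then completes the multipart upload;
        // the returned object carries the storage key and the ETag checksum.
        BundleStorageObject uploaded = storageClient.multiPartUpload(writeCredentials, bundle);
    }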
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClientConfig.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClientConfig.java
new file mode 100644
index 0000000..39ff4dd
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/blobupload/StorageClientConfig.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.Serializable;
+import java.net.URI;
+import java.net.URISyntaxException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class StorageClientConfig implements Serializable
+{
+ private static final long serialVersionUID = -1572678388713210328L;
+ private static final Logger LOGGER = LoggerFactory.getLogger(StorageClientConfig.class);
+
+ public final String threadNamePrefix;
+ // Controls the max concurrency/parallelism of the thread pool used by the S3 client
+ public final int concurrency;
+ // Controls the timeout of idle threads
+ public final long threadKeepAliveSeconds;
+ public final int maxChunkSizeInBytes;
+ public final URI httpsProxy; // optional; configures the HTTPS proxy for the S3 client
+ public final long nioHttpClientConnectionAcquisitionTimeoutSeconds; // optional; only applied to the NettyNioAsyncHttpClient
+ public final int nioHttpClientMaxConcurrency; // optional; only applied to the NettyNioAsyncHttpClient
+
+ public final URI endpointOverride; // nullable; only used for testing.
+
+ public StorageClientConfig(int concurrency,
+ long threadKeepAliveSeconds,
+ int maxChunkSizeInBytes,
+ String httpsProxy,
+ String endpointOverride,
+ long nioHttpClientConnectionAcquisitionTimeoutSeconds,
+ int nioHttpClientMaxConcurrency)
+ {
+ this.threadNamePrefix = "storage-client";
+ this.concurrency = concurrency;
+ this.threadKeepAliveSeconds = threadKeepAliveSeconds;
+ this.maxChunkSizeInBytes = maxChunkSizeInBytes;
+ this.httpsProxy = toURI(httpsProxy, "HttpsProxy");
+ this.endpointOverride = toURI(endpointOverride, "EndpointOverride");
+ this.nioHttpClientConnectionAcquisitionTimeoutSeconds = nioHttpClientConnectionAcquisitionTimeoutSeconds;
+ this.nioHttpClientMaxConcurrency = nioHttpClientMaxConcurrency;
+ }
+
+ private URI toURI(String uriString, String hint)
+ {
+ if (uriString == null)
+ {
+ return null;
+ }
+
+ try
+ {
+ return new URI(uriString);
+ }
+ catch (URISyntaxException e)
+ {
+ LOGGER.error("{} is specified, but the value is invalid. input={}", hint, uriString);
+ throw new RuntimeException("Unable to resolve " + uriString, e);
+ }
+ }
+}
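When httpsProxy is set, the StorageClient above routes S3 traffic through the Netty NIO HTTP client configured with that proxy, and an unparsable URI fails fast in this constructor with a RuntimeException. A sketch with a hypothetical proxy URI:

    StorageClientConfig proxied = new StorageClientConfig(4, 60, 5 * 1024 * 1024,
                                                          "https://proxy.internal.example.com:8080", // hypothetical proxy URI
                                                          null, 300, 50);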
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ConsistencyLevel.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ConsistencyLevel.java
index 7785587..dd48bfe 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ConsistencyLevel.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ConsistencyLevel.java
@@ -19,16 +19,42 @@
package org.apache.cassandra.spark.bulkwriter.token;
+import java.util.Collection;
+import java.util.Objects;
import java.util.Set;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.base.Preconditions;
+
+import org.apache.cassandra.spark.common.model.CassandraInstance;
+import org.apache.cassandra.spark.data.ReplicationFactor;
public interface ConsistencyLevel
{
+ /**
+ * Whether the consistency level only considers replicas in the local data center.
+ *
+ * @return true if only considering the local replicas; otherwise, return false
+ */
boolean isLocal();
- Logger LOGGER = LoggerFactory.getLogger(ConsistencyLevel.class);
+ /**
+ * Checks whether the consistency level can be satisfied with the collection of succeeded instances
+ *
+ * @param succeededInstances the succeeded instances in the replica set
+ * @param replicationFactor replication factor to check with
+ * @param localDC the local data center name if required for the check
+ * @return true if the consistency level is _definitively_ satisfied;
+ * false means only that no conclusion can be drawn
+ */
+ boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC);
+
+ default void ensureNetworkTopologyStrategy(ReplicationFactor replicationFactor, CL cl)
+ {
+ Preconditions.checkArgument(replicationFactor.getReplicationStrategy() == ReplicationFactor.ReplicationStrategy.NetworkTopologyStrategy,
+ cl.name() + " only make sense for NetworkTopologyStrategy keyspaces");
+ }
/**
* Checks if the consistency guarantees are maintained, given the failed, blocked and replacing instances, consistency-level and the replication-factor.
@@ -45,7 +71,7 @@
*
* @param writeReplicas the set of replicas for write operations
* @param pendingReplicas the set of replicas pending status
- * @param replacementInstances the instances being replaced
+ * @param replacementInstances the set of instances that are replacing the other instances
* @param blockedInstances the set of instances that have been blocked for the bulk operation
* @param failedInstanceIps the set of instances where there were failures
* @param localDC the local datacenter used for consistency level, or {@code null} if not provided
@@ -56,7 +82,7 @@
Set<String> replacementInstances,
Set<String> blockedInstances,
Set<String> failedInstanceIps,
- String localDC);
+ String localDC); // TODO: simplify the parameter list; not all parameters are required by implementations
// Check if successful writes forms quorum of non-replacing nodes - N/A as quorum is if there are no failures/blocked
enum CL implements ConsistencyLevel
@@ -80,6 +106,17 @@
int failedExcludingReplacements = failedInstanceIps.size() - replacementInstances.size();
return failedExcludingReplacements <= 0 && blockedInstances.isEmpty();
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ int rf = replicationFactor.getTotalReplicationFactor();
+ // The effective RF during expansion can be larger than the defined RF,
+ // so the CL satisfaction check uses >= to account for that scenario
+ return succeededInstances.size() >= rf;
+ }
},
EACH_QUORUM
@@ -100,6 +137,28 @@
{
return (failedInstanceIps.size() + blockedInstances.size()) <= (writeReplicas.size() - (writeReplicas.size() / 2 + 1));
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ ensureNetworkTopologyStrategy(replicationFactor, EACH_QUORUM);
+ Objects.requireNonNull(localDC, "localDC cannot be null");
+
+ for (String datacenter : replicationFactor.getOptions().keySet())
+ {
+ int rf = replicationFactor.getOptions().get(datacenter);
+ int majority = rf / 2 + 1;
+ if (succeededInstances.stream()
+ .filter(instance -> instance.datacenter().equalsIgnoreCase(datacenter))
+ .count() < majority)
+ {
+ return false;
+ }
+ }
+ return true;
+ }
},
QUORUM
{
@@ -119,6 +178,15 @@
{
return (failedInstanceIps.size() + blockedInstances.size()) <= (writeReplicas.size() - (writeReplicas.size() / 2 + 1));
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ int rf = replicationFactor.getTotalReplicationFactor();
+ return succeededInstances.size() > rf / 2;
+ }
},
LOCAL_QUORUM
{
@@ -138,6 +206,20 @@
{
return (failedInstanceIps.size() + blockedInstances.size()) <= (writeReplicas.size() - (writeReplicas.size() / 2 + 1));
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ ensureNetworkTopologyStrategy(replicationFactor, LOCAL_QUORUM);
+ Objects.requireNonNull(localDC, "localDC cannot be null");
+
+ int rf = replicationFactor.getOptions().get(localDC);
+ return succeededInstances.stream()
+ .filter(instance -> instance.datacenter().equalsIgnoreCase(localDC))
+ .count() > rf / 2;
+ }
},
ONE
{
@@ -158,6 +240,14 @@
return (failedInstanceIps.size() + blockedInstances.size())
<= (writeReplicas.size() - pendingReplicas.size() - 1);
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ return !succeededInstances.isEmpty();
+ }
},
TWO
{
@@ -178,6 +268,14 @@
return (failedInstanceIps.size() + blockedInstances.size())
<= (writeReplicas.size() - pendingReplicas.size() - 2);
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ return succeededInstances.size() >= 2;
+ }
},
LOCAL_ONE
{
@@ -197,6 +295,17 @@
{
return (failedInstanceIps.size() + blockedInstances.size()) <= (writeReplicas.size() - pendingReplicas.size() - 1);
}
+
+ @Override
+ public boolean canBeSatisfied(Collection<? extends CassandraInstance> succeededInstances,
+ ReplicationFactor replicationFactor,
+ String localDC)
+ {
+ ensureNetworkTopologyStrategy(replicationFactor, LOCAL_ONE);
+ Objects.requireNonNull(localDC, "localDC cannot be null");
+
+ return succeededInstances.stream().anyMatch(instance -> instance.datacenter().equalsIgnoreCase(localDC));
+ }
};
}
}
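All of the quorum variants above reduce to the same integer arithmetic; a self-contained illustration of the majority rule, independent of the Cassandra types:

    // majority = rf / 2 + 1; integer division is intentional (RF 3 -> 2, RF 5 -> 3).
    // succeeded >= majority is equivalent to the succeeded > rf / 2 check used above.
    static boolean quorumSatisfied(int succeeded, int rf)
    {
        return succeeded >= rf / 2 + 1;
    }

    quorumSatisfied(2, 3); // true  -> QUORUM definitively satisfied
    quorumSatisfied(1, 3); // false -> no conclusion can be drawn yet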
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ReplicaAwareFailureHandler.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ReplicaAwareFailureHandler.java
index d66f65c..06382f8 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ReplicaAwareFailureHandler.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/token/ReplicaAwareFailureHandler.java
@@ -64,7 +64,7 @@
* @param casInstance the instance on which the range failed
* @param errMessage the error that occurred for this particular range/instance pair
*/
- public void addFailure(Range<BigInteger> tokenRange, Instance casInstance, String errMessage)
+ public synchronized void addFailure(Range<BigInteger> tokenRange, Instance casInstance, String errMessage)
{
RangeMap<BigInteger, Multimap<Instance, String>> overlappingFailures = failedRangesMap.subRangeMap(tokenRange);
RangeMap<BigInteger, Multimap<Instance, String>> mappingsToAdd = TreeRangeMap.create();
@@ -97,7 +97,7 @@
* @return list of failed entries for token ranges that break consistency. This should ideally be empty for a
* successful operation.
*/
- public Collection<AbstractMap.SimpleEntry<Range<BigInteger>, Multimap<Instance, String>>>
+ public synchronized Collection<AbstractMap.SimpleEntry<Range<BigInteger>, Multimap<Instance, String>>>
getFailedEntries(TokenRangeMapping<? extends CassandraInstance> tokenRangeMapping,
ConsistencyLevel cl,
String localDC)
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/IOUtils.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/IOUtils.java
new file mode 100644
index 0000000..99e3829
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/IOUtils.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.util;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.stream.Stream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipOutputStream;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import net.jpountz.xxhash.StreamingXXHash32;
+import net.jpountz.xxhash.XXHashFactory;
+
+public final class IOUtils
+{
+ private static final Logger LOGGER = LoggerFactory.getLogger(IOUtils.class);
+
+ private static final int HASH_BUFFER_SIZE = 512 * 1024; // 512KiB
+
+ private IOUtils()
+ {
+ throw new UnsupportedOperationException("Cannot instantiate utility class");
+ }
+
+ /**
+ * Zip the files under the source path. It does not zip recursively.
+ * @param sourcePath directory that contains files to be zipped
+ * @param targetPath output zip file path
+ * @return compressed size, i.e. the size of the zip file
+ * @throws IOException I/O exception during zipping
+ */
+ public static long zip(Path sourcePath, Path targetPath) throws IOException
+ {
+ return zip(sourcePath, targetPath, 1);
+ }
+
+ /**
+ * Zip the files under the source path. Only files within {@code maxDepth} directory levels are included.
+ * @param sourcePath directory that contains files to be zipped
+ * @param targetPath output zip file path
+ * @param maxDepth the maximum number of directory levels to visit
+ * @return compressed size, i.e. the size of the zip file
+ * @throws IOException I/O exception during zipping
+ */
+ public static long zip(Path sourcePath, Path targetPath, int maxDepth) throws IOException
+ {
+ if (!Files.isDirectory(sourcePath))
+ {
+ throw new IOException("Not a directory. sourcePath: " + sourcePath);
+ }
+
+ try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(targetPath));
+ Stream<Path> stream = Files.walk(sourcePath, maxDepth))
+ {
+ stream.filter(Files::isRegularFile)
+ .forEach(path -> {
+ ZipEntry zipEntry = new ZipEntry(sourcePath.relativize(path).toString());
+ try
+ {
+ zos.putNextEntry(zipEntry);
+ Files.copy(path, zos);
+ zos.closeEntry();
+ }
+ catch (IOException e)
+ {
+ LOGGER.error("Unexpected error while zipping SSTable components, path = {} not zipped, ",
+ path, e);
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ return targetPath.toFile().length();
+ }
+
+ /**
+ * Calculate the checksum of the file using the specified buffer size
+ * @param path file
+ * @param bufferSize buffer size for file content to calculate checksum
+ * @return checksum string
+ * @throws IOException I/O exception during checksum calculation
+ */
+ public static String xxhash32(Path path, int bufferSize) throws IOException
+ {
+ XXHashFactory factory = XXHashFactory.safeInstance();
+ try (InputStream inputStream = Files.newInputStream(path);
+ StreamingXXHash32 hasher = factory.newStreamingHash32(0))
+ {
+ int len;
+ byte[] buffer = new byte[bufferSize];
+ while ((len = inputStream.read(buffer)) != -1)
+ {
+ hasher.update(buffer, 0, len);
+ }
+ return Integer.toHexString(hasher.getValue());
+ }
+ }
+
+ /**
+ * Calculate the checksum of the file using the default buffer size
+ * @param path file
+ * @return checksum string
+ * @throws IOException I/O exception during checksum calculation
+ */
+ public static String xxhash32(Path path) throws IOException
+ {
+ return xxhash32(path, HASH_BUFFER_SIZE);
+ }
+}
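A short usage sketch of the helpers above; the paths are hypothetical, and zip(Path, Path) defaults to maxDepth = 1, so only files directly under the source directory are bundled:

    Path ssTableDir = Paths.get("/tmp/job/sstables");          // hypothetical
    Path bundleFile = Paths.get("/tmp/job/bundle.zip");        // hypothetical
    long compressedSize = IOUtils.zip(ssTableDir, bundleFile); // returns the size of the zip file
    String checksum = IOUtils.xxhash32(bundleFile);            // hex string of the streaming XXHash32 (seed 0)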
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/SbwKryoRegistrator.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/SbwKryoRegistrator.java
index d9bfb76..6fd362e 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/SbwKryoRegistrator.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/SbwKryoRegistrator.java
@@ -31,6 +31,9 @@
import org.apache.cassandra.spark.bulkwriter.CassandraBulkWriterContext;
import org.apache.cassandra.spark.bulkwriter.RingInstance;
import org.apache.cassandra.spark.bulkwriter.TokenPartitioner;
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
import org.apache.spark.SparkConf;
import org.apache.spark.serializer.KryoRegistrator;
import org.jetbrains.annotations.NotNull;
@@ -55,6 +58,9 @@
javaSerializableClasses.stream()
.sorted(Comparator.comparing(Class::getCanonicalName))
.forEach(javaSerializableClass -> kryo.register(javaSerializableClass, new SbwJavaSerializer()));
+ kryo.register(StorageTransportConfiguration.class, new StorageTransportConfiguration.Serializer());
+ kryo.register(StorageCredentialPair.class, new StorageCredentialPair.Serializer());
+ kryo.register(StorageCredentials.class, new StorageCredentials.Serializer());
}
public static void addJavaSerializableClass(@NotNull Class<? extends Serializable> javaSerializableClass)
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/TaskContextUtils.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/TaskContextUtils.java
new file mode 100644
index 0000000..a60bff0
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/bulkwriter/util/TaskContextUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.util;
+
+import java.math.BigInteger;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.UUID;
+
+import com.google.common.collect.Range;
+
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.spark.TaskContext;
+
+public final class TaskContextUtils
+{
+ private TaskContextUtils()
+ {
+ }
+
+ public static Range<BigInteger> getTokenRange(TaskContext taskContext, JobInfo job)
+ {
+ return job.getTokenPartitioner().getTokenRange(taskContext.partitionId());
+ }
+
+ /**
+ * Create a new stream session identifier
+ * <p>
+ * Note that the stream session ID is used as part of a filename, so it must not contain characters that are invalid in filenames.
+ * @param taskContext task context
+ * @return a new stream ID
+ */
+ public static String createStreamSessionId(TaskContext taskContext)
+ {
+ return String.format("%d-%d-%s", taskContext.partitionId(), taskContext.attemptNumber(), UUID.randomUUID());
+ }
+
+ /**
+ * Create a path that is unique to the Spark partition
+ * @param basePath the base directory for the job's local artifacts
+ * @param jobId the unique identifier of the job
+ * @param taskContext the task context of the running Spark task
+ * @return a path unique to the partition and its attempt
+ */
+ public static Path getPartitionUniquePath(String basePath, String jobId, TaskContext taskContext)
+ {
+ return Paths.get(basePath,
+ jobId,
+ Integer.toString(taskContext.stageAttemptNumber()),
+ Integer.toString(taskContext.attemptNumber()),
+ Integer.toString(taskContext.partitionId()));
+ }
+}
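The partition-unique path layers stage attempt, task attempt, and partition ID under the job directory, so task retries never collide on disk. For illustration, with hypothetical identifiers:

    // basePath = /tmp/bulk, jobId = 9f2c, stageAttemptNumber = 0, attemptNumber = 1, partitionId = 42
    Path workDir = TaskContextUtils.getPartitionUniquePath("/tmp/bulk", "9f2c", taskContext);
    // -> /tmp/bulk/9f2c/0/1/42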
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/DataObjectBuilder.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/DataObjectBuilder.java
new file mode 100644
index 0000000..fa88061
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/DataObjectBuilder.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.common;
+
+import java.util.function.Consumer;
+
+/**
+ * Interface to build data objects
+ * @param <T> type of builder
+ * @param <R> type of result from build
+ */
+public interface DataObjectBuilder<T extends DataObjectBuilder<?, ?>, R>
+{
+ /**
+ * Build the data object of type R
+ * @return the built data object
+ */
+ R build();
+
+ /**
+ * Self-typing hook for the fluent API
+ * @return this builder, typed as the implementing class
+ */
+ T self();
+
+ /**
+ * Update fields in builder
+ * @param updater function to update fields
+ * @return builder itself for chained invocation
+ */
+ default T with(Consumer<? super T> updater)
+ {
+ updater.accept(self());
+ return self();
+ }
+}
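A minimal sketch of a concrete builder implementing the interface above; the Person type and its fields are hypothetical:

    final class Person
    {
        final String name;
        final int age;

        Person(String name, int age)
        {
            this.name = name;
            this.age = age;
        }
    }

    final class PersonBuilder implements DataObjectBuilder<PersonBuilder, Person>
    {
        String name;
        int age;

        @Override
        public PersonBuilder self()
        {
            return this;
        }

        @Override
        public Person build()
        {
            return new Person(name, age);
        }
    }

    // The default with(...) enables chained field updates without each setter returning the builder:
    Person person = new PersonBuilder().with(b -> b.name = "alice")
                                       .with(b -> b.age = 30)
                                       .build();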
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/LogStatsPublisher.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/LogStatsPublisher.java
index ad99755..ce6ac0c 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/LogStatsPublisher.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/common/stats/LogStatsPublisher.java
@@ -35,6 +35,6 @@
@Override
public void publish(Map<String, String> stats)
{
- LOGGER.info("Job Stats:" + stats);
+ LOGGER.info("Job Stats: {}", stats);
}
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/CassandraDataLayer.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/CassandraDataLayer.java
index ccbb05e..5bd2a1d 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/CassandraDataLayer.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/CassandraDataLayer.java
@@ -29,7 +29,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -69,11 +68,11 @@
import org.apache.cassandra.bridge.CassandraVersion;
import org.apache.cassandra.clients.ExecutorHolder;
import org.apache.cassandra.clients.Sidecar;
-import org.apache.cassandra.clients.SidecarInstanceImpl;
import org.apache.cassandra.secrets.SslConfig;
import org.apache.cassandra.secrets.SslConfigSecretsProvider;
import org.apache.cassandra.sidecar.client.SidecarClient;
import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.sidecar.client.SidecarInstanceImpl;
import org.apache.cassandra.sidecar.client.SimpleSidecarInstancesProvider;
import org.apache.cassandra.sidecar.client.exception.RetriesExhaustedException;
import org.apache.cassandra.spark.config.SchemaFeature;
@@ -120,7 +119,10 @@
protected String maybeQuotedKeyspace;
protected String maybeQuotedTable;
protected CassandraBridge bridge;
- protected Set<? extends SidecarInstance> clusterConfig;
+ // create clusterConfig from sidecarInstances and sidecarPort, see initializeClusterConfig
+ protected String sidecarInstances;
+ protected int sidecarPort;
+ protected transient Set<? extends SidecarInstance> clusterConfig;
protected TokenPartitioner tokenPartitioner;
protected Map<String, AvailabilityHint> availabilityHints;
protected Sidecar.ClientConfig sidecarClientConfig;
@@ -152,6 +154,8 @@
this.table = options.table();
this.quoteIdentifiers = options.quoteIdentifiers();
this.sidecarClientConfig = sidecarClientConfig;
+ this.sidecarInstances = options.sidecarInstances;
+ this.sidecarPort = options.sidecarPort;
this.sslConfig = sslConfig;
this.bigNumberConfigMap = options.bigNumberConfigMap();
this.enableStats = options.enableStats();
@@ -175,7 +179,8 @@
@NotNull TokenPartitioner tokenPartitioner,
@NotNull CassandraVersion version,
@NotNull ConsistencyLevel consistencyLevel,
- @NotNull Set<SidecarInstanceImpl> clusterConfig,
+ @NotNull String sidecarInstances,
+ int sidecarPort,
@NotNull Map<String, PartitionedDataLayer.AvailabilityHint> availabilityHints,
@NotNull Map<String, BigNumberConfigImpl> bigNumberConfigMap,
boolean enableStats,
@@ -194,7 +199,7 @@
this.quoteIdentifiers = quoteIdentifiers;
this.cqlTable = cqlTable;
this.tokenPartitioner = tokenPartitioner;
- this.clusterConfig = clusterConfig;
+ this.clusterConfig = initializeClusterConfig(sidecarInstances, sidecarPort);
this.availabilityHints = availabilityHints;
this.sidecarClientConfig = sidecarClientConfig;
this.sslConfig = sslConfig;
@@ -694,7 +699,9 @@
this.cqlTable = bridge.javaDeserialize(in, CqlTable.class); // Delegate (de-)serialization of version-specific objects to the Cassandra Bridge
this.tokenPartitioner = (TokenPartitioner) in.readObject();
- this.clusterConfig = (Set<SidecarInstanceImpl>) in.readObject();
+ this.sidecarInstances = in.readUTF();
+ this.sidecarPort = in.readInt();
+ this.clusterConfig = initializeClusterConfig(sidecarInstances, sidecarPort);
this.availabilityHints = (Map<String, AvailabilityHint>) in.readObject();
this.bigNumberConfigMap = (Map<String, BigNumberConfigImpl>) in.readObject();
this.enableStats = in.readBoolean();
@@ -742,7 +749,8 @@
out.writeObject(this.sslConfig);
bridge.javaSerialize(out, this.cqlTable); // Delegate (de-)serialization of version-specific objects to the Cassandra Bridge
out.writeObject(this.tokenPartitioner);
- out.writeObject(this.clusterConfig);
+ out.writeUTF(this.sidecarInstances);
+ out.writeInt(this.sidecarPort);
out.writeObject(this.availabilityHints);
out.writeObject(this.bigNumberConfigMap);
out.writeBoolean(this.enableStats);
@@ -811,7 +819,8 @@
kryo.writeObject(out, dataLayer.tokenPartitioner);
kryo.writeObject(out, dataLayer.version());
kryo.writeObject(out, dataLayer.consistencyLevel);
- kryo.writeObject(out, dataLayer.clusterConfig);
+ out.writeString(dataLayer.sidecarInstances);
+ out.writeInt(dataLayer.sidecarPort);
kryo.writeObject(out, dataLayer.availabilityHints);
out.writeBoolean(dataLayer.bigNumberConfigMap.isEmpty()); // Kryo fails to deserialize bigNumberConfigMap map if empty
if (!dataLayer.bigNumberConfigMap.isEmpty())
@@ -860,7 +869,8 @@
kryo.readObject(in, TokenPartitioner.class),
kryo.readObject(in, CassandraVersion.class),
kryo.readObject(in, ConsistencyLevel.class),
- kryo.readObject(in, HashSet.class),
+ in.readString(), // sidecarInstances
+ in.readInt(), // sidecarPort
(Map<String, PartitionedDataLayer.AvailabilityHint>) kryo.readObject(in, HashMap.class),
in.readBoolean() ? Collections.emptyMap()
: (Map<String, BigNumberConfigImpl>) kryo.readObject(in, HashMap.class),
@@ -889,8 +899,14 @@
protected Set<? extends SidecarInstance> initializeClusterConfig(ClientConfig options)
{
- return Arrays.stream(options.sidecarInstances().split(","))
- .map(hostname -> new SidecarInstanceImpl(hostname, options.sidecarPort()))
+ return initializeClusterConfig(options.sidecarInstances, options.sidecarPort());
+ }
+
+ // not intended to be overridden
+ private Set<? extends SidecarInstance> initializeClusterConfig(String sidecarInstances, int sidecarPort)
+ {
+ return Arrays.stream(sidecarInstances.split(","))
+ .map(hostname -> new SidecarInstanceImpl(hostname, sidecarPort))
.collect(Collectors.toSet());
}
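The serialization change above stores the primitive inputs (sidecarInstances, sidecarPort) and rebuilds the transient clusterConfig on deserialization, rather than serializing the derived Set. A minimal sketch of the same pattern, with hypothetical names and using default serialization instead of the explicit readUTF/readInt calls above:

    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.Serializable;
    import java.util.Arrays;
    import java.util.Set;
    import java.util.stream.Collectors;

    class HostsConfig implements Serializable
    {
        private String hosts;                    // serialized source of truth, e.g. "a,b,c"
        private int port;                        // serialized source of truth
        private transient Set<String> endpoints; // derived state, rebuilt after deserialization

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
        {
            in.defaultReadObject();
            endpoints = Arrays.stream(hosts.split(","))
                              .map(host -> host + ":" + port)
                              .collect(Collectors.toSet());
        }
    }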
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/FileSystemSSTable.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/FileSystemSSTable.java
index c8c4296..07f228f 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/FileSystemSSTable.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/FileSystemSSTable.java
@@ -26,6 +26,7 @@
import java.io.InputStream;
import java.nio.file.Path;
import java.util.function.Supplier;
+import javax.validation.constraints.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,15 +37,16 @@
import org.apache.cassandra.spark.utils.streaming.SSTableInputStream;
import org.jetbrains.annotations.Nullable;
-class FileSystemSSTable extends SSTable
+public class FileSystemSSTable extends SSTable
{
private static final Logger LOGGER = LoggerFactory.getLogger(FileSystemSSTable.class);
+ private static final long serialVersionUID = -7545780596504602254L;
- private final Path dataFilePath;
- private final boolean useSSTableInputStream;
- private final Supplier<Stats> stats;
+ private final transient Path dataFilePath;
+ private final transient boolean useSSTableInputStream;
+ private final transient Supplier<Stats> stats;
- FileSystemSSTable(Path dataFilePath, boolean useSSTableInputStream, Supplier<Stats> stats)
+ public FileSystemSSTable(@NotNull Path dataFilePath, boolean useSSTableInputStream, @NotNull Supplier<Stats> stats)
{
this.dataFilePath = dataFilePath;
this.useSSTableInputStream = useSSTableInputStream;
@@ -62,8 +64,8 @@
try
{
return useSSTableInputStream
- ? new SSTableInputStream<>(new FileSystemSource(this, fileType, filePath), stats.get())
- : new BufferedInputStream(new FileInputStream(filePath.toFile()));
+ ? new SSTableInputStream<>(new FileSystemSource(this, fileType, filePath), stats.get())
+ : new BufferedInputStream(new FileInputStream(filePath.toFile()));
}
catch (FileNotFoundException exception)
{
@@ -110,6 +112,6 @@
public boolean equals(Object other)
{
return other instanceof FileSystemSSTable
- && this.dataFilePath.equals(((FileSystemSSTable) other).dataFilePath);
+ && this.dataFilePath.equals(((FileSystemSSTable) other).dataFilePath);
}
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/QualifiedTableName.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/QualifiedTableName.java
new file mode 100644
index 0000000..da228d7
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/data/QualifiedTableName.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.spark.data;
+
+import java.util.Objects;
+
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Contains the keyspace and table name in Cassandra
+ */
+public class QualifiedTableName
+{
+ @NotNull
+ private final String keyspace;
+ @NotNull
+ private final String table;
+ private final boolean quoteIdentifiers;
+
+ /**
+ * Constructs a qualified table name with the given {@code keyspace} and {@code tableName}
+ *
+ * @param keyspace the unquoted keyspace name in Cassandra
+ * @param tableName the unquoted table name in Cassandra
+ * @param quoteIdentifiers indicate whether the identifiers should be quoted
+ */
+ public QualifiedTableName(String keyspace, String tableName, boolean quoteIdentifiers)
+ {
+ this.keyspace = Objects.requireNonNull(keyspace);
+ this.table = Objects.requireNonNull(tableName);
+ this.quoteIdentifiers = quoteIdentifiers;
+ }
+
+ /**
+ * Constructs a qualified table name whose keyspace and table name do not need to be quoted
+ *
+ * @param keyspace the unquoted keyspace name in Cassandra
+ * @param tableName the unquoted table name in Cassandra
+ */
+ public QualifiedTableName(String keyspace, String tableName)
+ {
+ this(keyspace, tableName, false);
+ }
+
+ /**
+ * @return the keyspace in Cassandra
+ */
+ public String keyspace()
+ {
+ return keyspace;
+ }
+
+ /**
+ * @return the table name in Cassandra
+ */
+ public String table()
+ {
+ return table;
+ }
+
+ /**
+ * @return whether the identifiers should be quoted
+ */
+ public boolean quoteIdentifiers()
+ {
+ return quoteIdentifiers;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public String toString()
+ {
+ return maybeQuote(keyspace) + "." + maybeQuote(table);
+ }
+
+ private String maybeQuote(String name)
+ {
+ if (quoteIdentifiers)
+ {
+ return '"' + name + '"';
+ }
+
+ return name;
+ }
+}
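Quoting is applied only when rendering the name; the identifiers themselves are stored unquoted:

    new QualifiedTableName("ks", "MyTable").toString();       // ks.MyTable
    new QualifiedTableName("ks", "MyTable", true).toString(); // "ks"."MyTable"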
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/sparksql/CassandraDataSink.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/sparksql/CassandraDataSink.java
index 50c9359..ad87277 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/sparksql/CassandraDataSink.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/sparksql/CassandraDataSink.java
@@ -74,15 +74,15 @@
// Initialize the job group ID for later use if we need to cancel the job
// TODO: Can we get a more descriptive "description" in here from the end user somehow?
BulkWriterContext writerContext = createBulkWriterContext(
- sqlContext.sparkContext(),
- ScalaConversionUtils.<String, String>mapAsJavaMap(parameters),
- data.schema());
+ sqlContext.sparkContext(),
+ ScalaConversionUtils.<String, String>mapAsJavaMap(parameters),
+ data.schema());
try
{
JobInfo jobInfo = writerContext.job();
- String description = "Cassandra Bulk Load for table " + jobInfo.getFullTableName();
+ String description = "Cassandra Bulk Load for table " + jobInfo.qualifiedTableName();
CassandraBulkSourceRelation relation = new CassandraBulkSourceRelation(writerContext, sqlContext);
- sqlContext.sparkContext().setJobGroup(jobInfo.getId().toString(), description, false);
+ sqlContext.sparkContext().setJobGroup(jobInfo.getId(), description, false);
relation.insert(data, false);
return relation;
}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentialPair.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentialPair.java
new file mode 100644
index 0000000..7ee52e1
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentialPair.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSecrets;
+
+/**
+ * A class representing the pair of credentials needed to complete an analytics operation using the Storage transport.
+ * The read and write credentials may be identical, but they may also be the credentials for two different
+ * buckets when cross-region synchronization is used to transfer data between regions.
+ */
+public class StorageCredentialPair implements Serializable
+{
+ private static final long serialVersionUID = 6084829690503608102L;
+ StorageCredentials writeCredentials;
+ StorageCredentials readCredentials;
+
+ /**
+ * Create a new instance of a StorageCredentialPair
+ *
+ * @param writeCredentials the credentials used for writing to the storage endpoint.
+ * @param readCredentials the credentials used to read from the storage endpoint.
+ */
+ public StorageCredentialPair(StorageCredentials writeCredentials, StorageCredentials readCredentials)
+ {
+ this.writeCredentials = writeCredentials;
+ this.readCredentials = readCredentials;
+ }
+
+ public StorageCredentials getWriteCredentials()
+ {
+ return writeCredentials;
+ }
+
+ public StorageCredentials getReadCredentials()
+ {
+ return readCredentials;
+ }
+
+ @Override
+ public String toString()
+ {
+ return "StorageCredentialPair{"
+ + "writeCredentials=" + writeCredentials
+ + ", readCredentials=" + readCredentials
+ + '}';
+ }
+
+ public RestoreJobSecrets toRestoreJobSecrets(String readRegion, String writeRegion)
+ {
+ return new RestoreJobSecrets(readCredentials.toSidecarCredentials(readRegion),
+ writeCredentials.toSidecarCredentials(writeRegion));
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o)
+ {
+ return true;
+ }
+ if (!(o instanceof StorageCredentialPair))
+ {
+ return false;
+ }
+ StorageCredentialPair that = (StorageCredentialPair) o;
+ return Objects.equals(writeCredentials, that.writeCredentials) && Objects.equals(readCredentials, that.readCredentials);
+ }
+
+ public int hashCode()
+ {
+ return Objects.hash(writeCredentials, readCredentials);
+ }
+
+ public static class Serializer extends com.esotericsoftware.kryo.Serializer<StorageCredentialPair>
+ {
+
+ public void write(Kryo kryo, Output out, StorageCredentialPair object)
+ {
+ kryo.writeObject(out, object.writeCredentials);
+ kryo.writeObject(out, object.readCredentials);
+ }
+
+ public StorageCredentialPair read(Kryo kryo, Input in, Class<StorageCredentialPair> type)
+ {
+ return new StorageCredentialPair(
+ kryo.readObject(in, StorageCredentials.class),
+ kryo.readObject(in, StorageCredentials.class)
+ );
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentials.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentials.java
new file mode 100644
index 0000000..670416c
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/StorageCredentials.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+/**
+ * StorageCredentials represents the security information required to read from or write to a storage endpoint.
+ * Storage credentials can be either an access key ID and secret key, or additionally include a session token
+ * when a temporary IAM credential is used.
+ */
+public class StorageCredentials implements Serializable
+{
+
+ private static final long serialVersionUID = 7704978447772194970L;
+ private final String accessKeyId;
+ private final String secretKey;
+ private final String sessionToken;
+
+ public static StorageCredentials fromSidecarCredentials(o.a.c.sidecar.client.shaded.common.data.StorageCredentials credentials)
+ {
+ return new StorageCredentials(credentials.accessKeyId(),
+ credentials.secretAccessKey(),
+ credentials.sessionToken());
+ }
+
+ /**
+ * Creates a Storage Credential instance with only an Access Key and Secret Key.
+ *
+ * @param accessKeyId the accessKeyId to use to access the S3 bucket
+ * @param secretKey the secretKey to use to access the S3 Bucket
+ */
+ public StorageCredentials(String accessKeyId, String secretKey)
+ {
+ this(accessKeyId, secretKey, null);
+ }
+
+ /**
+ * Creates a Storage Credential instance with an Access Key, Secret Key, and Session Token.
+ * Used when a temporary IAM credential is to be provided to S3 for authentication/authorization.
+ * See <a href="https://docs.aws.amazon.com/AmazonS3/latest/userguide/AuthUsingTempSessionToken.html">The Amazon
+ * Documentation on Temporary IAM Credentials</a>
+ * for more details.
+ *
+ * @param accessKeyId the accessKeyId to use to access the S3 bucket
+ * @param secretKey the secretKey to use to access the S3 Bucket
+ * @param sessionToken the session token to use to access the S3 Bucket
+ */
+ public StorageCredentials(String accessKeyId, String secretKey, String sessionToken)
+ {
+ this.accessKeyId = accessKeyId;
+ this.secretKey = secretKey;
+ this.sessionToken = sessionToken;
+ }
+
+ public String getAccessKeyId()
+ {
+ return accessKeyId;
+ }
+
+ public String getSecretKey()
+ {
+ return secretKey;
+ }
+
+ public String getSessionToken()
+ {
+ return sessionToken;
+ }
+
+ public o.a.c.sidecar.client.shaded.common.data.StorageCredentials toSidecarCredentials(String region)
+ {
+ return o.a.c.sidecar.client.shaded.common.data.StorageCredentials
+ .builder()
+ .accessKeyId(accessKeyId)
+ .secretAccessKey(secretKey)
+ .sessionToken(sessionToken)
+ .region(region)
+ .build();
+ }
+
+ @Override
+ public String toString()
+ {
+ return "StorageCredentials{"
+ + "accessKeyId='" + accessKeyId + '\''
+ + ", secretKey='" + redact(secretKey) + '\''
+ + ", sessionToken='" + redact(sessionToken) + '\''
+ + '}';
+ }
+
+ private String redact(String value)
+ {
+ if (value == null || value.isEmpty())
+ {
+ return "NOT_PROVIDED";
+ }
+ return "*****";
+ }
+
+ public boolean equals(Object o)
+ {
+ if (this == o)
+ {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass())
+ {
+ return false;
+ }
+ StorageCredentials that = (StorageCredentials) o;
+ return Objects.equals(accessKeyId, that.accessKeyId)
+ && Objects.equals(secretKey, that.secretKey)
+ && Objects.equals(sessionToken, that.sessionToken);
+ }
+
+ public int hashCode()
+ {
+ return Objects.hash(accessKeyId, secretKey, sessionToken);
+ }
+
+ public static class Serializer extends com.esotericsoftware.kryo.Serializer<StorageCredentials>
+ {
+ public void write(Kryo kryo, Output out, StorageCredentials obj)
+ {
+ out.writeString(obj.accessKeyId);
+ out.writeString(obj.secretKey);
+ out.writeString(obj.sessionToken);
+ }
+
+ public StorageCredentials read(Kryo kryo, Input input, Class<StorageCredentials> type)
+ {
+ return new StorageCredentials(
+ input.readString(),
+ input.readString(),
+ input.readString()
+ );
+ }
+ }
+}
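Note that toString() above leaves the access key ID visible but redacts the secret key and session token, so credentials can be logged safely; the values below are hypothetical:

    StorageCredentials credentials = new StorageCredentials("exampleKeyId", "exampleSecret", "exampleToken");
    credentials.toString();
    // -> StorageCredentials{accessKeyId='exampleKeyId', secretKey='*****', sessionToken='*****'}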
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CommonStorageTransportExtension.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CommonStorageTransportExtension.java
new file mode 100644
index 0000000..fc3f45c
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CommonStorageTransportExtension.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+import org.apache.spark.SparkConf;
+
+/**
+ * Extension methods that are invoked in both Spark driver and executors
+ * Package-private interface only to be extended by {@link StorageTransportExtension}
+ */
+interface CommonStorageTransportExtension
+{
+ /**
+ * Initializes the instance of this class after it has been created.
+ * The initialization implementation can differ based on whether it is running on the Spark driver or an executor
+ *
+ * @param jobId the unique identifier for the job.
+ * It can either be supplied by the customer with {@link org.apache.cassandra.spark.bulkwriter.WriterOptions#JOB_ID},
+ * or be a unique ID string generated by the job on startup if no jobId is supplied.
+ * @param conf the spark configuration
+ * @param isOnDriver indicates whether the runtime is the Spark driver or an executor
+ */
+ void initialize(String jobId, SparkConf conf, boolean isOnDriver);
+
+ /**
+ * Returns the {@link StorageTransportConfiguration}
+ *
+ * @return the {@link StorageTransportConfiguration}
+ */
+ StorageTransportConfiguration getStorageConfiguration();
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CredentialChangeListener.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CredentialChangeListener.java
new file mode 100644
index 0000000..e725c8c
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/CredentialChangeListener.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+
+/**
+ * A listener interface that is notified on access token changes
+ */
+public interface CredentialChangeListener
+{
+ /**
+ * Method called when new access tokens are available for the job with ID {@code jobId}.
+ * The previous set of credentials and the newly-provided set must both be valid simultaneously
+ * for the Spark job to have time to rotate credentials without interruption.
+ * These tokens should be provided well before expiration, giving the job time to distribute them to
+ * the consumers of the storage transport endpoint.
+ *
+ * @param jobId the unique identifier for the job
+ * @param newTokens a map of access tokens used to authenticate to the storage transport
+ */
+ void onCredentialsChanged(String jobId, StorageCredentialPair newTokens);
+}
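A minimal sketch of a listener that publishes rotated credentials atomically, so readers always see a consistent pair during the overlap window; the class and method names are hypothetical:

    import java.util.concurrent.atomic.AtomicReference;

    import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;

    class RotatingCredentials implements CredentialChangeListener
    {
        private final AtomicReference<StorageCredentialPair> current = new AtomicReference<>();

        @Override
        public void onCredentialsChanged(String jobId, StorageCredentialPair newTokens)
        {
            current.set(newTokens); // old and new pairs overlap; readers simply pick up the latest
        }

        StorageCredentialPair latest()
        {
            return current.get();
        }
    }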
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/DriverStorageTransportExtension.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/DriverStorageTransportExtension.java
new file mode 100644
index 0000000..7c9fb99
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/DriverStorageTransportExtension.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+/**
+ * Extension methods that are invoked in the Spark driver only.
+ * Package-private interface; only to be extended by {@link StorageTransportExtension}
+ */
+interface DriverStorageTransportExtension
+{
+ /**
+ * Notifies the extension that data transport has been started. This method will be called from the driver.
+ * @param elapsedMillis the elapsed time from the start of the bulk write run until this step for the job in milliseconds
+ */
+ void onTransportStart(long elapsedMillis);
+
+ /**
+ * Sets the {@link CredentialChangeListener} to listen for token changes. This method
+ * will be called from the driver.
+ *
+ * @param credentialChangeListener an implementation of the {@link CredentialChangeListener}
+ */
+ void setCredentialChangeListener(CredentialChangeListener credentialChangeListener);
+
+ /**
+ * Sets the {@link ObjectFailureListener} to listen for object failures. This method
+ * will be called from the driver.
+ *
+ * @param objectFailureListener an implementation of the {@link ObjectFailureListener}
+ */
+ void setObjectFailureListener(ObjectFailureListener objectFailureListener);
+
+ /**
+ * Notifies the extension that all the objects have been persisted to the blob store successfully.
+ * This method is called from driver when all executor tasks complete.
+ *
+ * @param objectsCount the total count of objects persisted
+ * @param rowCount the total count of rows persisted
+ * @param elapsedMillis the elapsed time from the start of the bulk write run until this step for the job in milliseconds
+ */
+ void onAllObjectsPersisted(long objectsCount, long rowCount, long elapsedMillis);
+
+ /**
+ * Notifies the extension that the object identified by the bucket and key has been applied, meaning
+ * the SSTables included in the object are imported into Cassandra and the desired consistency level is satisfied.
+ * <br>
+ * The notification is only emitted once per object and as soon as the consistency level is satisfied.
+ *
+ * @param bucket the bucket containing the object
+ * @param key the object key
+ * @param sizeInBytes the size of the object in bytes
+ * @param elapsedMillis the elapsed time from the start of the bulk write run until this step for the job in milliseconds
+ */
+ void onObjectApplied(String bucket, String key, long sizeInBytes, long elapsedMillis);
+
+ /**
+ * Notifies the extension that the job has completed successfully. This method will be called
+ * from the driver at the end of the Spark Bulk Writer execution when the job succeeds.
+ *
+ * @param elapsedMillis the elapsed time from the start of the bulk write run until this step for the job in milliseconds
+ */
+ void onJobSucceeded(long elapsedMillis);
+
+ /**
+ * Notifies the extension that the job has failed with exception {@link Throwable throwable}.
+ * This method will be called from the driver at the end of the Spark Bulk Writer execution when the job fails.
+ *
+ * @param elapsedMillis the elapsed time from the start of the bulk write run until this step for the job in milliseconds
+ * @param throwable the exception encountered by the job
+ */
+ void onJobFailed(long elapsedMillis, Throwable throwable);
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ExecutorStorageTransportExtension.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ExecutorStorageTransportExtension.java
new file mode 100644
index 0000000..dfd50b2
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ExecutorStorageTransportExtension.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+/**
+ * Extension methods that are invoked in Spark executors only.
+ * Package-private interface; only to be extended by {@link StorageTransportExtension}
+ */
+interface ExecutorStorageTransportExtension
+{
+ /**
+ * Notifies the extension that the object identified by {@code bucket} and {@code key} has been successfully persisted to the blob store.
+ * This method will be called from each task during the job execution.
+ *
+ * @param bucket the bucket to which the file was written
+ * @param key the key to the object written
+ * @param sizeInBytes the size of the object, in bytes
+ */
+ void onObjectPersisted(String bucket, String key, long sizeInBytes);
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ObjectFailureListener.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ObjectFailureListener.java
new file mode 100644
index 0000000..018ed49
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/ObjectFailureListener.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+/**
+ * A listener interface that is notified on failures processing an object
+ */
+public interface ObjectFailureListener
+{
+ /**
+ * Method called when an unrecoverable error has been encountered for the object identified
+ * by {@code bucket} and {@code key} for the given {@code jobId}, with a descriptive {@code errorMessage}.
+ *
+ * @param jobId the unique identifier for the job. It could be customer-supplied
+ * or self-generated (when not supplied)
+ * @param bucket the object storage bucket
+ * @param key the key in the object storage
+ * @param errorMessage a description of the error
+ */
+ void onObjectFailed(String jobId, String bucket, String key, String errorMessage);
+}
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportConfiguration.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportConfiguration.java
new file mode 100644
index 0000000..47abcd6
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportConfiguration.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+import java.io.Serializable;
+import java.net.URI;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+
+/**
+ * Holds the configuration for the cloud storage (blob) data transport
+ */
+public class StorageTransportConfiguration implements Serializable
+{
+ private static final long serialVersionUID = -8164878804296039585L;
+ private final String writeBucket;
+ private final String writeRegion;
+ private final String readBucket;
+ private final String readRegion;
+ private final String prefix;
+ private final Map<String, String> tags;
+ private StorageCredentialPair storageCredentialPair;
+
+ public StorageTransportConfiguration(String writeBucket, String writeRegion,
+ String readBucket, String readRegion,
+ String prefix,
+ StorageCredentialPair storageCredentialPair,
+ Map<String, String> tags)
+ {
+ this.writeBucket = writeBucket;
+ this.writeRegion = writeRegion;
+ this.readBucket = readBucket;
+ this.readRegion = readRegion;
+ this.prefix = prefix;
+ this.storageCredentialPair = storageCredentialPair;
+ this.tags = Collections.unmodifiableMap(tags);
+ }
+
+ /**
+ * @return the name of the bucket to which the job writes
+ */
+ public String getWriteBucket()
+ {
+ return writeBucket;
+ }
+
+ /**
+ * @return the pair of credentials used to authenticate to the storage transport
+ */
+ public StorageCredentialPair getStorageCredentialPair()
+ {
+ return storageCredentialPair;
+ }
+
+ public void setBlobCredentialPair(StorageCredentialPair newCredentials)
+ {
+ this.storageCredentialPair = newCredentials;
+ }
+
+ public String getWriteRegion()
+ {
+ return writeRegion;
+ }
+
+ public String getReadBucket()
+ {
+ return readBucket;
+ }
+
+ public String getReadRegion()
+ {
+ return readRegion;
+ }
+
+ public String getPrefix()
+ {
+ return prefix;
+ }
+
+ public Map<String, String> getTags()
+ {
+ return tags;
+ }
+
+ public boolean equals(Object o)
+ {
+ if (this == o)
+ {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass())
+ {
+ return false;
+ }
+ StorageTransportConfiguration that = (StorageTransportConfiguration) o;
+ return Objects.equals(writeBucket, that.writeBucket)
+ && Objects.equals(writeRegion, that.writeRegion)
+ && Objects.equals(readBucket, that.readBucket)
+ && Objects.equals(readRegion, that.readRegion)
+ && Objects.equals(prefix, that.prefix)
+ && Objects.equals(storageCredentialPair, that.storageCredentialPair)
+ && Objects.equals(tags, that.tags);
+ }
+
+ public int hashCode()
+ {
+ return Objects.hash(writeBucket, writeRegion, readBucket, readRegion, prefix, storageCredentialPair, tags);
+ }
+
+ public static class Serializer extends com.esotericsoftware.kryo.Serializer<StorageTransportConfiguration>
+ {
+ public void write(Kryo kryo, Output out, StorageTransportConfiguration obj)
+ {
+ out.writeString(obj.writeBucket);
+ out.writeString(obj.writeRegion);
+ out.writeString(obj.readBucket);
+ out.writeString(obj.readRegion);
+ out.writeString(obj.prefix);
+ kryo.writeObject(out, obj.storageCredentialPair);
+ kryo.writeObject(out, obj.tags);
+ }
+
+ @SuppressWarnings("unchecked")
+ public StorageTransportConfiguration read(Kryo kryo, Input in, Class<StorageTransportConfiguration> type)
+ {
+ return new StorageTransportConfiguration(in.readString(),
+ in.readString(),
+ in.readString(),
+ in.readString(),
+ in.readString(),
+ kryo.readObject(in, StorageCredentialPair.class),
+ kryo.readObject(in, HashMap.class));
+ }
+ }
+}
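A minimal sketch of registering the new Kryo serializer and round-tripping a configuration; it mirrors the KryoSerializationTests case added later in this patch. The bucket, region, credential, and tag values are placeholders, and the two registrators (with assumed packages matching the test file's imports) are expected to register the supporting types the same way the test's static block does.

```java
import java.util.Collections;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.cassandra.spark.KryoRegister;
import org.apache.cassandra.spark.bulkwriter.util.SbwKryoRegistrator;
import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
import org.apache.cassandra.spark.transports.storage.StorageCredentials;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;

public class ConfigRoundTripDemo
{
    public static void main(String[] args)
    {
        Kryo kryo = new Kryo();
        // Register serializers the same way the test suite's static block does
        new KryoRegister().registerClasses(kryo);
        new SbwKryoRegistrator().registerClasses(kryo);

        StorageTransportConfiguration config = new StorageTransportConfiguration(
        "writeBucket", "writeRegion",
        "readBucket", "readRegion",
        "prefix",
        new StorageCredentialPair(new StorageCredentials("keyId1", "secret1", "sessionToken1"),
                                  new StorageCredentials("keyId2", "secret2", "sessionToken2")),
        Collections.singletonMap("tag1", "tagVal1"));

        try (Output out = new Output(1024, -1))
        {
            kryo.writeObject(out, config);
            Input in = new Input(out.getBuffer(), 0, out.position());
            StorageTransportConfiguration roundTripped = kryo.readObject(in, StorageTransportConfiguration.class);
            System.out.println(config.equals(roundTripped)); // expected: true
        }
    }
}
```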
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportExtension.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportExtension.java
new file mode 100644
index 0000000..6165e72
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportExtension.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+import org.apache.spark.SparkConf;
+
+/**
+ * The facade interface defines the contract of the extension for cloud storage data transport.
+ * It serves as the integration point for the library consumer.
+ * - Register callbacks for data transport progress
+ * - Supply the information necessary for a successful data transport, e.g. credentials and bucket
+ * <br>
+ * Notes for the interface implementors:
+ * - Not all methods defined in the interface are invoked in both Spark driver and executors.
+ * 1. The methods in {@link CommonStorageTransportExtension} are invoked in both places.
+ * 2. The methods in {@link ExecutorStorageTransportExtension} are invoked in Spark executors only.
+ * 3. The methods in {@link DriverStorageTransportExtension} are invoked in Spark driver only.
+ * - The Analytics library guarantees the following sequence in Spark driver on initialization
+ * 1. Create the new {@link StorageTransportExtension} instance
+ * 2. Invoke {@link #initialize(String, SparkConf, boolean)}
+ * 3. Invoke {@link #getStorageConfiguration()}
+ * 4. Invoke {@link #setCredentialChangeListener(CredentialChangeListener)}
+ * 5. Invoke {@link #setObjectFailureListener(ObjectFailureListener)}
+ */
+public interface StorageTransportExtension extends CommonStorageTransportExtension,
+ ExecutorStorageTransportExtension,
+ DriverStorageTransportExtension
+{
+ // No extra methods to be defined
+ // When adding a new one, please determine its call-site, and add it in one of the three interfaces
+}
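To make the contract concrete, here is a partial, logging-only sketch of an implementation. It covers only the methods visible in this diff; CommonStorageTransportExtension may declare additional methods not shown here, so treat this as illustrative rather than compilable against the full interface.

```java
import org.apache.spark.SparkConf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.spark.transports.storage.extensions.CredentialChangeListener;
import org.apache.cassandra.spark.transports.storage.extensions.ObjectFailureListener;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;

// Logging-only sketch; a real hook would build the configuration from its own settings.
public class LoggingStorageTransportExtension implements StorageTransportExtension
{
    private static final Logger LOGGER = LoggerFactory.getLogger(LoggingStorageTransportExtension.class);
    private String jobId;

    @Override
    public void initialize(String jobId, SparkConf conf, boolean isOnDriver)
    {
        this.jobId = jobId; // guaranteed to be invoked before any other extension method
    }

    @Override
    public StorageTransportConfiguration getStorageConfiguration()
    {
        return null; // a real implementation returns buckets, regions, prefix, and credentials
    }

    @Override
    public void setCredentialChangeListener(CredentialChangeListener listener)
    {
        // driver-only: retain the listener and invoke it whenever credentials rotate
    }

    @Override
    public void setObjectFailureListener(ObjectFailureListener listener)
    {
        // driver-only: retain the listener and invoke it on unrecoverable object errors
    }

    @Override
    public void onTransportStart(long elapsedMillis)
    {
        LOGGER.info("Transport started for job {} after {} ms", jobId, elapsedMillis);
    }

    @Override
    public void onObjectPersisted(String bucket, String key, long sizeInBytes)
    {
        // executor-only: one call per object uploaded to the blob store
    }

    @Override
    public void onAllObjectsPersisted(long objectsCount, long rowCount, long elapsedMillis)
    {
        LOGGER.info("Persisted {} objects / {} rows for job {}", objectsCount, rowCount, jobId);
    }

    @Override
    public void onObjectApplied(String bucket, String key, long sizeInBytes, long elapsedMillis)
    {
        // emitted once per object, as soon as the consistency level is satisfied
    }

    @Override
    public void onJobSucceeded(long elapsedMillis)
    {
        LOGGER.info("Job {} succeeded after {} ms", jobId, elapsedMillis);
    }

    @Override
    public void onJobFailed(long elapsedMillis, Throwable throwable)
    {
        LOGGER.error("Job {} failed after {} ms", jobId, elapsedMillis, throwable);
    }
}
```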
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportHandler.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportHandler.java
new file mode 100644
index 0000000..954f7de
--- /dev/null
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/transports/storage/extensions/StorageTransportHandler.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.transports.storage.extensions;
+
+import java.util.Objects;
+import java.util.UUID;
+import java.util.function.Consumer;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSecrets;
+import o.a.c.sidecar.client.shaded.common.data.UpdateRestoreJobRequestPayload;
+import org.apache.cassandra.spark.bulkwriter.CancelJobEvent;
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.cassandra.spark.bulkwriter.TransportContext;
+import org.apache.cassandra.spark.common.client.ClientException;
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+
+public class StorageTransportHandler implements CredentialChangeListener, ObjectFailureListener
+{
+ private final TransportContext.CloudStorageTransportContext transportContext;
+ private final Consumer<CancelJobEvent> cancelConsumer;
+ private final JobInfo jobInfo;
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(StorageTransportHandler.class);
+
+ public StorageTransportHandler(TransportContext.CloudStorageTransportContext transportContext,
+ JobInfo jobInfo,
+ Consumer<CancelJobEvent> cancelConsumer)
+ {
+ this.transportContext = transportContext;
+ this.jobInfo = jobInfo;
+ this.cancelConsumer = cancelConsumer;
+ }
+
+ @Override
+ public void onCredentialsChanged(String jobId, StorageCredentialPair newCredentials)
+ {
+ validateReceivedJobId(jobId);
+ if (Objects.equals(transportContext.transportConfiguration().getStorageCredentialPair(), newCredentials))
+ {
+ LOGGER.info("The received new credential is the same as the existing one. Skip updating.");
+ return;
+ }
+
+ LOGGER.info("Refreshing cloud storage credentials. jobId={}, credentials={}", jobId, newCredentials);
+ transportContext.transportConfiguration().setBlobCredentialPair(newCredentials);
+ updateCredentials(jobInfo.getRestoreJobId(), newCredentials);
+ }
+
+ @Override
+ public void onObjectFailed(String jobId, String bucket, String key, String errorMessage)
+ {
+ validateReceivedJobId(jobId);
+ LOGGER.error("Object with bucket {} and key {} failed to be transported correctly. Cancelling job. Error was: {}", bucket, key, errorMessage);
+ cancelConsumer.accept(new CancelJobEvent(errorMessage));
+ }
+
+ private void updateCredentials(UUID jobId, StorageCredentialPair credentialPair)
+ {
+ StorageTransportConfiguration conf = transportContext.transportConfiguration();
+ RestoreJobSecrets secrets = credentialPair.toRestoreJobSecrets(conf.getReadRegion(), conf.getWriteRegion());
+ UpdateRestoreJobRequestPayload requestPayload = new UpdateRestoreJobRequestPayload(null, secrets, null, null);
+ try
+ {
+ transportContext.dataTransferApi().updateRestoreJob(requestPayload);
+ }
+ catch (ClientException e)
+ {
+ throw new RuntimeException("Failed to update secretes for restore job. restoreJobId: " + jobId, e);
+ }
+ }
+
+ private void validateReceivedJobId(String jobId)
+ {
+ String actualJobId = jobInfo.getId();
+ if (!Objects.equals(jobId, actualJobId))
+ {
+ throw new IllegalStateException("Received jobId does not match with the actual one. Received: " + jobId
+ + "; actual: " + actualJobId);
+ }
+ }
+}
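Since the handler implements both CredentialChangeListener and ObjectFailureListener, the presumed driver-side wiring looks roughly like the sketch below. The actual registration code lives elsewhere in this patch, so the method placement and parameters here are assumptions.

```java
import java.util.function.Consumer;

import org.apache.cassandra.spark.bulkwriter.CancelJobEvent;
import org.apache.cassandra.spark.bulkwriter.JobInfo;
import org.apache.cassandra.spark.bulkwriter.TransportContext;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportHandler;

final class HandlerWiring
{
    // Registers one handler instance for both listener callbacks.
    static void wire(TransportContext.CloudStorageTransportContext transportContext,
                     JobInfo jobInfo,
                     StorageTransportExtension extension,
                     Consumer<CancelJobEvent> cancelJob)
    {
        StorageTransportHandler handler = new StorageTransportHandler(transportContext, jobInfo, cancelJob);
        extension.setCredentialChangeListener(handler); // refreshed credentials -> Sidecar update
        extension.setObjectFailureListener(handler);    // unrecoverable object failure -> cancel job
    }
}
```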
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/BuildInfo.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/BuildInfo.java
index c432258..8e3ffd8 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/BuildInfo.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/BuildInfo.java
@@ -41,6 +41,7 @@
public static final String BUILD_VERSION_AND_REVISION = getBuildVersionAndRevision();
public static final String READER_USER_AGENT = getUserAgent("reader");
public static final String WRITER_USER_AGENT = getUserAgent("writer");
+ public static final String WRITER_S3_USER_AGENT = getUserAgent("writer-s3");
private BuildInfo()
{
diff --git a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/RangeUtils.java b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/RangeUtils.java
index 423b9dc..3df7c10 100644
--- a/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/RangeUtils.java
+++ b/cassandra-analytics-core/src/main/java/org/apache/cassandra/spark/utils/RangeUtils.java
@@ -50,9 +50,7 @@
{
Preconditions.checkArgument(range.lowerEndpoint().compareTo(range.upperEndpoint()) <= 0,
"RangeUtils assume ranges are not wrap-around");
- Preconditions.checkArgument(range.lowerBoundType() == BoundType.OPEN
- && range.upperBoundType() == BoundType.CLOSED,
- "Input must be an open-closed range");
+ Preconditions.checkArgument(isOpenClosedRange(range), "Input must be an open-closed range");
if (range.isEmpty())
{
@@ -63,6 +61,16 @@
}
/**
+ * Check whether a range is open (exclusive) on its lower end and closed (inclusive) on its upper end.
+ * @param range range
+ * @return true if the range is open-closed.
+ */
+ public static boolean isOpenClosedRange(Range<?> range)
+ {
+ return range.lowerBoundType() == BoundType.OPEN && range.upperBoundType() == BoundType.CLOSED;
+ }
+
+ /**
* Splits the given range into equal-sized small ranges. Number of splits can be controlled by
* nrSplits. If nrSplits are smaller than size of the range, split size would be set to 1, which is
* the minimum allowed. For example, if the input range is {@code (0, 1]} and nrSplits is 10, the split
@@ -79,9 +87,7 @@
{
Preconditions.checkArgument(range.lowerEndpoint().compareTo(range.upperEndpoint()) <= 0,
"RangeUtils assume ranges are not wrap-around");
- Preconditions.checkArgument(range.lowerBoundType() == BoundType.OPEN
- && range.upperBoundType() == BoundType.CLOSED,
- "Input must be an open-closed range");
+ Preconditions.checkArgument(isOpenClosedRange(range), "Input must be an open-closed range");
if (range.isEmpty())
{
@@ -175,9 +181,7 @@
@NotNull
public static TokenRange toTokenRange(@NotNull Range<BigInteger> range)
{
- Preconditions.checkArgument(range.lowerBoundType() == BoundType.OPEN
- && range.upperBoundType() == BoundType.CLOSED,
- "Input must be an open-closed range");
+ Preconditions.checkArgument(isOpenClosedRange(range), "Input must be an open-closed range");
return TokenRange.openClosed(range.lowerEndpoint(), range.upperEndpoint());
}
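The hunks above extract the repeated open-closed precondition into isOpenClosedRange. A quick illustration of the convention with Guava ranges (standard Guava Range usage, not code from this patch):

```java
import java.math.BigInteger;

import com.google.common.collect.Range;
import org.apache.cassandra.spark.utils.RangeUtils;

public class OpenClosedDemo
{
    public static void main(String[] args)
    {
        Range<BigInteger> ok = Range.openClosed(BigInteger.ZERO, BigInteger.TEN);   // (0, 10]
        Range<BigInteger> bad = Range.closedOpen(BigInteger.ZERO, BigInteger.TEN);  // [0, 10)

        System.out.println(RangeUtils.isOpenClosedRange(ok));  // true: accepted by toTokenRange(...)
        System.out.println(RangeUtils.isOpenClosedRange(bad)); // false: the Preconditions check
        // would reject it with "Input must be an open-closed range"
    }
}
```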
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarClientConfigTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarClientConfigTest.java
index c20910b..bff83f3 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarClientConfigTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarClientConfigTest.java
@@ -22,8 +22,11 @@
import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;
+import org.apache.cassandra.spark.bulkwriter.DataTransport;
+
import static org.apache.cassandra.spark.bulkwriter.BulkSparkConf.DEFAULT_SIDECAR_PORT;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* Unit tests for the {@link Sidecar.ClientConfig} inner class
@@ -101,4 +104,14 @@
Sidecar.ClientConfig clientConfig = Sidecar.ClientConfig.create(ImmutableMap.of("timeoutseconds", "2"));
assertEquals(2, clientConfig.timeoutSeconds());
}
+
+ @Test
+ public void testTransportModeBasedWriterUserAgent()
+ {
+ String userAgentStr = Sidecar.transportModeBasedWriterUserAgent(DataTransport.DIRECT);
+ assertTrue(userAgentStr.endsWith(" writer"));
+
+ userAgentStr = Sidecar.transportModeBasedWriterUserAgent(DataTransport.S3_COMPAT);
+ assertTrue(userAgentStr.endsWith(" writer-s3"));
+ }
}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceImplTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceImplTest.java
deleted file mode 100644
index b782514..0000000
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceImplTest.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.cassandra.clients;
-
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-
-import com.esotericsoftware.kryo.io.Output;
-import org.apache.cassandra.sidecar.client.SidecarInstance;
-
-import static org.apache.cassandra.spark.utils.SerializationUtils.deserialize;
-import static org.apache.cassandra.spark.utils.SerializationUtils.kryoDeserialize;
-import static org.apache.cassandra.spark.utils.SerializationUtils.kryoSerialize;
-import static org.apache.cassandra.spark.utils.SerializationUtils.register;
-import static org.apache.cassandra.spark.utils.SerializationUtils.serialize;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-
-/**
- * Unit tests for the {@link SidecarInstanceImpl} class
- */
-public class SidecarInstanceImplTest extends SidecarInstanceTest
-{
- @BeforeAll
- public static void setupKryo()
- {
- register(SidecarInstanceImpl.class, new SidecarInstanceImpl.Serializer());
- }
-
- @Override
- protected SidecarInstance newInstance(String hostname, int port)
- {
- return new SidecarInstanceImpl(hostname, port);
- }
-
- @Test
- public void testJdkSerDe()
- {
- SidecarInstance instance = newInstance("localhost", 9043);
- byte[] bytes = serialize(instance);
- SidecarInstance deserialized = deserialize(bytes, SidecarInstanceImpl.class);
- assertNotNull(deserialized);
- assertEquals("localhost", deserialized.hostname());
- assertEquals(9043, deserialized.port());
- }
-
- @Test
- public void testKryoSerDe()
- {
- SidecarInstance instance = newInstance("localhost", 9043);
- Output out = kryoSerialize(instance);
- SidecarInstance deserialized = kryoDeserialize(out, SidecarInstanceImpl.class);
- assertNotNull(deserialized);
- assertEquals("localhost", deserialized.hostname());
- assertEquals(9043, deserialized.port());
- }
-}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceTest.java
index 450dfd5..743f806 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/clients/SidecarInstanceTest.java
@@ -49,7 +49,7 @@
}
catch (IllegalArgumentException illegalArgumentException)
{
- assertEquals("The Sidecar port number must be in the range 1-65535: " + invalidPortNumber,
+ assertEquals("Invalid port number for the Sidecar service: " + invalidPortNumber,
illegalArgumentException.getMessage());
}
}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/KryoSerializationTests.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/KryoSerializationTests.java
index 8190351..6354428 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/KryoSerializationTests.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/KryoSerializationTests.java
@@ -35,6 +35,7 @@
import com.esotericsoftware.kryo.io.Output;
import org.apache.cassandra.bridge.CassandraBridge;
import org.apache.cassandra.secrets.SslConfig;
+import org.apache.cassandra.spark.bulkwriter.util.SbwKryoRegistrator;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.LocalDataLayer;
@@ -42,6 +43,9 @@
import org.apache.cassandra.spark.data.partitioner.CassandraInstance;
import org.apache.cassandra.spark.data.partitioner.CassandraRing;
import org.apache.cassandra.spark.data.partitioner.TokenPartitioner;
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -237,7 +241,7 @@
CqlTable table = new CqlTable("test_keyspace",
"test_table",
"create table test_keyspace.test_table"
- + " (a bigint, b bigint, c bigint, d bigint, e bigint, primary key((a, b), c));",
+ + " (a bigint, b bigint, c bigint, d bigint, e bigint, primary key((a, b), c));",
replicationFactor,
fields);
@@ -283,7 +287,7 @@
LocalDataLayer localDataLayer = new LocalDataLayer(bridge.getVersion(),
"test_keyspace",
"create table test_keyspace.test_table"
- + " (a int, b int, c int, primary key(a, b));",
+ + " (a int, b int, c int, primary key(a, b));",
path1,
path2,
path3);
@@ -342,15 +346,15 @@
public void testSslConfig()
{
SslConfig config = new SslConfig.Builder<>()
- .keyStorePath("keyStorePath")
- .base64EncodedKeyStore("encodedKeyStore")
- .keyStorePassword("keyStorePassword")
- .keyStoreType("keyStoreType")
- .trustStorePath("trustStorePath")
- .base64EncodedTrustStore("encodedTrustStore")
- .trustStorePassword("trustStorePassword")
- .trustStoreType("trustStoreType")
- .build();
+ .keyStorePath("keyStorePath")
+ .base64EncodedKeyStore("encodedKeyStore")
+ .keyStorePassword("keyStorePassword")
+ .keyStoreType("keyStoreType")
+ .trustStorePath("trustStorePath")
+ .base64EncodedTrustStore("encodedTrustStore")
+ .trustStorePassword("trustStorePassword")
+ .trustStoreType("trustStoreType")
+ .build();
Output out = serialize(config);
SslConfig deserialized = deserialize(out, SslConfig.class);
@@ -363,4 +367,34 @@
assertEquals(config.trustStorePassword(), deserialized.trustStorePassword());
assertEquals(config.trustStoreType(), deserialized.trustStoreType());
}
+
+ @Test
+ public void testStorageTransportConfiguration()
+ {
+ final StorageTransportConfiguration config = new StorageTransportConfiguration(
+ "writeBucket",
+ "writeRegion",
+ "readBucket",
+ "readRegion",
+ "prefix",
+ new StorageCredentialPair(
+ new StorageCredentials("keyId1", "secret1", "sessionToken1"),
+ new StorageCredentials("keyId2", "secret2", "sessionToken2")
+ ),
+ ImmutableMap.of("tag1", "tagVal1", "tag2", "tagVal2")
+ );
+
+ StorageTransportConfiguration deserialized;
+ try (Output out = serialize(config))
+ {
+ deserialized = deserialize(out, StorageTransportConfiguration.class);
+ }
+ assertEquals(config, deserialized);
+ }
+
+ static
+ {
+ new KryoRegister().registerClasses(KRYO);
+ new SbwKryoRegistrator().registerClasses(KRYO);
+ }
}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConfTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConfTest.java
index b55045d..b059479 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConfTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/BulkSparkConfTest.java
@@ -34,6 +34,7 @@
import org.jetbrains.annotations.NotNull;
import static org.apache.cassandra.spark.bulkwriter.BulkSparkConf.DEFAULT_SIDECAR_PORT;
+import static org.apache.cassandra.spark.bulkwriter.BulkSparkConf.MINIMUM_JOB_KEEP_ALIVE_MINUTES;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.emptyString;
import static org.hamcrest.core.Is.is;
@@ -225,6 +226,33 @@
assertTrue(bulkSparkConf.quoteIdentifiers);
}
+ @Test
+ public void testInvalidJobKeepAliveMinutes()
+ {
+ Map<String, String> options = copyDefaultOptions();
+ options.put(WriterOptions.JOB_KEEP_ALIVE_MINUTES.name(), "-100");
+ IllegalArgumentException iae = assertThrows(IllegalArgumentException.class, () -> new BulkSparkConf(sparkConf, options));
+ assertEquals("Invalid value for the 'JOB_KEEP_ALIVE_MINUTES' Bulk Writer option (-100). It cannot be less than the minimum 10",
+ iae.getMessage());
+ }
+
+ @Test
+ public void testDefaultJobKeepAliveMinutes()
+ {
+ Map<String, String> options = copyDefaultOptions();
+ BulkSparkConf conf = new BulkSparkConf(sparkConf, options);
+ assertEquals(MINIMUM_JOB_KEEP_ALIVE_MINUTES, conf.getJobKeepAliveMinutes());
+ }
+
+ @Test
+ public void testJobKeepAliveMinutes()
+ {
+ Map<String, String> options = copyDefaultOptions();
+ options.put(WriterOptions.JOB_KEEP_ALIVE_MINUTES.name(), "30");
+ BulkSparkConf conf = new BulkSparkConf(sparkConf, options);
+ assertEquals(30, conf.getJobKeepAliveMinutes());
+ }
+
private Map<String, String> copyDefaultOptions()
{
TreeMap<String, String> map = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitorTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitorTest.java
new file mode 100644
index 0000000..0b45d25
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CassandraTopologyMonitorTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class CassandraTopologyMonitorTest
+{
+ @Test
+ void testNoTopologyChange()
+ {
+ ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
+ when(mockClusterInfo.getTokenRangeMapping(false)).thenReturn(buildTopology(10));
+ AtomicBoolean noChange = new AtomicBoolean(true);
+ CassandraTopologyMonitor monitor = new CassandraTopologyMonitor(mockClusterInfo, event -> noChange.set(false));
+ monitor.checkTopologyOnDemand();
+ assertTrue(noChange.get());
+ }
+
+ @Test
+ void testTopologyChanged()
+ {
+ ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
+ when(mockClusterInfo.getTokenRangeMapping(false))
+ .thenReturn(buildTopology(10))
+ .thenReturn(buildTopology(11)); // token moved
+ AtomicBoolean noChange = new AtomicBoolean(true);
+ CassandraTopologyMonitor monitor = new CassandraTopologyMonitor(mockClusterInfo, event -> noChange.set(false));
+ monitor.checkTopologyOnDemand();
+ assertFalse(noChange.get());
+ }
+
+ private TokenRangeMapping<RingInstance> buildTopology(int instancesCount)
+ {
+ return TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), instancesCount);
+ }
+}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinatorTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinatorTest.java
new file mode 100644
index 0000000..258cfe2
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/CommitCoordinatorTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Range;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Unit tests for {@link CommitCoordinator}
+ */
+class CommitCoordinatorTest
+{
+ TokenRangeMapping<RingInstance> topology;
+ MockBulkWriterContext context;
+ TransportContext.DirectDataBulkWriterContext transportContext;
+
+ @BeforeEach
+ public void setup()
+ {
+ topology = TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), 3);
+ context = new MockBulkWriterContext(topology);
+ transportContext = (TransportContext.DirectDataBulkWriterContext) context.transportContext();
+ }
+
+ @Test
+ void commitsForEachSuccessfulUpload() throws ExecutionException, InterruptedException
+ {
+ int successfulUploads = 3;
+ DirectStreamResult uploadResult = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResult))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ commitResults.forEach(cr -> {
+ assertEquals(0, cr.failures.size());
+ assertEquals(1, cr.passed.size());
+ });
+ }
+ }
+
+ @Test
+ void commitWillNotCommitWhenUploadFailed() throws ExecutionException, InterruptedException
+ {
+ int successfulUploads = 1;
+ int failedUploads = 2;
+ DirectStreamResult uploadResult = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .withFailedUploads(failedUploads)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResult))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ CommitResult cr = commitResults.get(0);
+ assertEquals(0, cr.failures.size()); // Failed uploads should not be committed at all
+ assertEquals(successfulUploads, cr.passed.size());
+ }
+ }
+
+ @Test
+ void commitWillNotCommitWhenAlreadyCommitted() throws ExecutionException, InterruptedException
+ {
+
+ context.setCommitResultSupplier((uuids, dc) -> {
+ throw new RuntimeException("Should not have called commit");
+ });
+
+ int successfulUploads = 3;
+ int successfulCommits = 3;
+ DirectStreamResult uploadResults = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .withSuccessfulCommits(successfulCommits)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResults))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ commitResults.forEach(cr -> {
+ assertEquals(0, cr.failures.size());
+ assertEquals(1, cr.passed.size());
+ });
+ }
+ }
+
+ @Test
+ void commitWillReturnFailuresWhenCommitRequestFails() throws ExecutionException, InterruptedException
+ {
+ context.setCommitResultSupplier((uuids, dc) -> {
+ throw new RuntimeException("Intentionally Failing Commit for uuids: " + Arrays.toString(uuids.toArray()));
+ });
+ int successfulUploads = 3;
+ DirectStreamResult uploadResults = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResults))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ commitResults.forEach(cr -> {
+ assertEquals(1, cr.failures.size());
+ assertEquals(0, cr.passed.size());
+ });
+ }
+ }
+
+ @Test
+ void commitWillReturnFailuresWhenCommitFailsOnServerWithSpecificUuids() throws ExecutionException, InterruptedException
+ {
+ context.setCommitResultSupplier((uuids, dc) -> new DirectDataTransferApi.RemoteCommitResult(false, uuids, Collections.emptyList(), "Failed nodetool import"));
+ int successfulUploads = 3;
+ DirectStreamResult uploadResults = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResults))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ commitResults.forEach(cr -> {
+ assertEquals(1, cr.failures.size());
+ assertEquals(0, cr.passed.size());
+ });
+ }
+ }
+
+ @Test
+ void commitWillReturnFailuresWhenCommitFailsOnServerWithNoUuids() throws ExecutionException, InterruptedException
+ {
+ context.setCommitResultSupplier((uuids, dc) -> new DirectDataTransferApi.RemoteCommitResult(false, Collections.emptyList(), Collections.emptyList(), "Failed nodetool import"));
+ int successfulUploads = 3;
+ DirectStreamResult uploadResults = DirectStreamResultBuilder
+ .withTopology(topology)
+ .withSuccessfulUploads(successfulUploads)
+ .build();
+ try (CommitCoordinator coordinator = CommitCoordinator.commit(context, transportContext, uploadResults))
+ {
+ List<CommitResult> commitResults = coordinator.get();
+ assertEquals(successfulUploads, commitResults.size());
+ commitResults.forEach(cr -> {
+ assertEquals(1, cr.failures.size());
+ assertEquals(0, cr.passed.size());
+ });
+ }
+ }
+
+ static class DirectStreamResultBuilder
+ {
+ private static final Range<BigInteger> TEST_RANGE = Range.openClosed(BigInteger.valueOf(0), BigInteger.valueOf(200));
+
+ private final TokenRangeMapping<RingInstance> topology;
+ private int successfulUploads;
+ private int failedUploads;
+ private int successfulCommits;
+ private int failedCommits;
+ private RingInstance[] allInstances;
+
+ DirectStreamResultBuilder(TokenRangeMapping<RingInstance> topology)
+ {
+ this.topology = topology;
+ }
+
+ static DirectStreamResultBuilder withTopology(TokenRangeMapping<RingInstance> topology)
+ {
+ return new DirectStreamResultBuilder(topology);
+ }
+
+ public DirectStreamResultBuilder withSuccessfulUploads(int successfulUploads)
+ {
+ this.successfulUploads = successfulUploads;
+ return this;
+ }
+
+ public DirectStreamResultBuilder withFailedUploads(int failedUploads)
+ {
+ this.failedUploads = failedUploads;
+ return this;
+ }
+
+ public DirectStreamResultBuilder withSuccessfulCommits(int successfulCommits)
+ {
+ this.successfulCommits = successfulCommits;
+ return this;
+ }
+
+ public DirectStreamResultBuilder withFailedCommits(int failedCommits)
+ {
+ this.failedCommits = failedCommits;
+ return this;
+ }
+
+ DirectStreamResult build()
+ {
+ allInstances = this.topology.getTokenRanges().keySet().toArray(new RingInstance[0]);
+ DirectStreamResult sr = new DirectStreamResult(UUID.randomUUID().toString(),
+ TEST_RANGE,
+ buildFailures(),
+ buildPassed(), 0, 0);
+ if (successfulCommits > 0 || failedCommits > 0)
+ {
+ List<CommitResult> commitResults = new ArrayList<>();
+ for (RingInstance inst : this.topology.getTokenRanges().keySet())
+ {
+ CommitResult cr = new CommitResult(
+ sr.sessionID,
+ inst,
+ ImmutableMap.of(sr.sessionID, this.topology.getTokenRanges().get(inst).stream().findFirst().get())
+ );
+ commitResults.add(cr);
+ }
+ sr.setCommitResults(commitResults);
+ }
+ return sr;
+ }
+
+ private List<RingInstance> buildPassed()
+ {
+ return new ArrayList<>(Arrays.asList(allInstances).subList(0, successfulUploads));
+ }
+
+ private ArrayList<StreamError> buildFailures()
+ {
+ ArrayList<StreamError> failedInstances = new ArrayList<>();
+ for (int i = failedUploads - 1; i >= 0; i--)
+ {
+ failedInstances.add(new StreamError(TEST_RANGE, allInstances[i], "failed"));
+ }
+ return failedInstances;
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
similarity index 61%
rename from cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionTest.java
rename to cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
index 82924cb..aa16da3 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/DirectStreamSessionTest.java
@@ -21,11 +21,13 @@
import java.io.IOException;
import java.math.BigInteger;
+import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
@@ -42,20 +44,23 @@
import org.apache.cassandra.spark.common.model.CassandraInstance;
import org.apache.cassandra.spark.utils.DigestAlgorithm;
import org.apache.cassandra.spark.utils.XXHash32DigestAlgorithm;
+import org.assertj.core.api.Assertions;
import org.jetbrains.annotations.NotNull;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.iterableWithSize;
import static org.hamcrest.Matchers.matchesPattern;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
-public class StreamSessionTest
+public class DirectStreamSessionTest
{
public static final String LOAD_RANGE_ERROR_PREFIX = "Failed to load 1 ranges with LOCAL_QUORUM";
private static final Map<String, Object> COLUMN_BOUND_VALUES = ImmutableMap.of("id", 0, "date", 1, "course", "course", "marks", 2);
@@ -64,8 +69,8 @@
private static final int FILES_PER_SSTABLE = 8;
private static final int RF = 3;
- private StreamSession ss;
private MockBulkWriterContext writerContext;
+ private TransportContext.DirectDataBulkWriterContext transportContext;
private List<String> expectedInstances;
private TokenRangeMapping<RingInstance> tokenRangeMapping;
private MockScheduledExecutorService executor;
@@ -81,28 +86,29 @@
tokenRangeMapping = TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), 12);
writerContext = getBulkWriterContext();
tableWriter = new MockTableWriter(folder);
+ transportContext = (TransportContext.DirectDataBulkWriterContext) writerContext.transportContext();
executor = new MockScheduledExecutorService();
- ss = new StreamSession(writerContext, "sessionId", range, executor, new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
expectedInstances = Lists.newArrayList("DC1-i1", "DC1-i2", "DC1-i3");
}
@Test
- public void testGetReplicasReturnsCorrectData()
+ void testGetReplicasReturnsCorrectData()
{
- List<RingInstance> replicas = ss.getReplicas();
+ StreamSession<?> streamSession = createStreamSession(SortedSSTableWriter::new);
+ List<RingInstance> replicas = streamSession.getReplicas();
assertNotNull(replicas);
List<String> actualInstances = replicas.stream().map(RingInstance::nodeName).collect(Collectors.toList());
assertThat(actualInstances, containsInAnyOrder(expectedInstances.toArray()));
}
@Test
- public void testScheduleStreamSendsCorrectFilesToCorrectInstances() throws IOException, ExecutionException, InterruptedException
+ void testScheduleStreamSendsCorrectFilesToCorrectInstances() throws IOException, ExecutionException, InterruptedException
{
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
- ss.scheduleStream(tr);
- ss.close(); // Force "execution" of futures
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
+ assertThat(ss.rowCount(), is(1L));
+ StreamResult streamResult = ss.scheduleStreamAsync(1, executor).get();
+ assertThat(streamResult.rowCount, is(1L));
executor.assertFuturesCalled();
assertThat(executor.futures.size(), equalTo(1)); // We only scheduled one SSTable
assertThat(writerContext.getUploads().values().stream().mapToInt(Collection::size).sum(), equalTo(RF * FILES_PER_SSTABLE));
@@ -111,19 +117,33 @@
}
@Test
- public void testMismatchedTokenRangeFails() throws IOException
+ void testEmptyTokenRangeFails()
{
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(9999L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
+ Exception exception = assertThrows(IllegalStateException.class,
+ () -> new DirectStreamSession(
+ writerContext,
+ new NonValidatingTestSortedSSTableWriter(tableWriter, folder, digestAlgorithm),
+ transportContext,
+ "sessionId",
+ Range.range(BigInteger.valueOf(0L), BoundType.OPEN, BigInteger.valueOf(0L), BoundType.CLOSED),
+ replicaAwareFailureHandler())
+ );
+ assertThat(exception.getMessage(), is("No replicas found for range (0‥0]"));
+ }
+
+ @Test
+ void testMismatchedTokenRangeFails() throws IOException
+ {
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ ss.addRow(BigInteger.valueOf(9999L), COLUMN_BOUND_VALUES);
IllegalStateException illegalStateException = assertThrows(IllegalStateException.class,
- () -> ss.scheduleStream(tr));
+ () -> ss.scheduleStreamAsync(1, executor));
assertThat(illegalStateException.getMessage(), matchesPattern(
"SSTable range \\[9999(‥|..)9999] should be enclosed in the partition range \\[101(‥|..)199]"));
}
@Test
- public void testUploadFailureCallsClean() throws IOException, ExecutionException, InterruptedException
+ void testUploadFailureCallsClean() throws IOException, ExecutionException, InterruptedException
{
runFailedUpload();
@@ -134,7 +154,7 @@
}
@Test
- public void testUploadFailureSkipsCleanWhenConfigured() throws IOException, ExecutionException, InterruptedException
+ void testUploadFailureSkipsCleanWhenConfigured() throws IOException, ExecutionException, InterruptedException
{
writerContext.setSkipCleanOnFailures(true);
runFailedUpload();
@@ -147,29 +167,29 @@
.stream()
.flatMap(Collection::stream)
.collect(Collectors.toList());
- assertTrue(uploads.size() > 0);
+ assertFalse(uploads.isEmpty());
assertTrue(uploads.stream().noneMatch(u -> u.uploadSucceeded));
}
@Test
- public void testUploadFailureRefreshesClusterInfo() throws IOException, ExecutionException, InterruptedException
+ void testUploadFailureRefreshesClusterInfo() throws IOException, ExecutionException, InterruptedException
{
runFailedUpload();
assertThat(writerContext.refreshClusterInfoCallCount(), equalTo(3));
}
@Test
- public void testOutDirCreationFailureCleansAllReplicas()
+ void testOutDirCreationFailureCleansAllReplicas()
{
- assertThrows(RuntimeException.class, () -> {
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, tableWriter.getOutDir(), digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
+ ExecutionException ex = assertThrows(ExecutionException.class, () -> {
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
+ Future<?> fut = ss.scheduleStreamAsync(1, executor);
tableWriter.removeOutDir();
- ss.scheduleStream(tr);
- ss.close();
+ fut.get();
});
+ Assertions.assertThat(ex).hasRootCauseInstanceOf(NoSuchFileException.class);
List<String> actualInstances = writerContext.getCleanedInstances().stream()
.map(CassandraInstance::nodeName)
.collect(Collectors.toList());
@@ -177,46 +197,30 @@
}
@Test
- public void streamWithNoWritersReturnsEmptyStreamResult() throws ExecutionException, InterruptedException
- {
- writerContext.setInstancesAreAvailable(false);
- ss = new StreamSession(writerContext, "sessionId", range, executor, new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
- StreamResult result = ss.close();
- assertThat(result.failures.size(), equalTo(0));
- assertThat(result.passed.size(), equalTo(0));
- assertThat(result.sessionID, equalTo("sessionId"));
- assertThat(result.tokenRange, equalTo(range));
- }
-
- @Test
- public void failedCleanDoesNotThrow() throws IOException, ExecutionException, InterruptedException
+ void failedCleanDoesNotThrow() throws IOException
{
writerContext.setCleanShouldThrow(true);
runFailedUpload();
}
@Test
- public void testLocalQuorumSucceedsWhenSingleCommitFails(
- ) throws IOException, ExecutionException, InterruptedException
+ void testLocalQuorumSucceedsWhenSingleCommitFails() throws IOException, ExecutionException, InterruptedException
{
- ss = new StreamSession(writerContext, "sessionId", range, executor, new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
AtomicBoolean success = new AtomicBoolean(true);
writerContext.setCommitResultSupplier((uuids, dc) -> {
// Return failed result for 1st result, success for the rest
if (success.getAndSet(false))
{
- return new DataTransferApi.RemoteCommitResult(false, uuids, null, "");
+ return new DirectDataTransferApi.RemoteCommitResult(false, uuids, null, "");
}
else
{
- return new DataTransferApi.RemoteCommitResult(true, null, uuids, "");
+ return new DirectDataTransferApi.RemoteCommitResult(true, null, uuids, "");
}
});
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
- ss.scheduleStream(tr);
- ss.close(); // Force "execution" of futures
+ ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
+ ss.scheduleStreamAsync(1, executor).get();
executor.assertFuturesCalled();
assertThat(writerContext.getUploads().values().stream().mapToInt(Collection::size).sum(), equalTo(RF * FILES_PER_SSTABLE));
final List<String> instances = writerContext.getUploads().keySet().stream().map(CassandraInstance::nodeName).collect(Collectors.toList());
@@ -224,44 +228,40 @@
}
@Test
- public void testLocalQuorumFailsWhenCommitsFail() throws IOException, ExecutionException, InterruptedException
+ void testLocalQuorumFailsWhenCommitsFail() throws IOException
{
- ss = new StreamSession(writerContext, "sessionId", range, executor, new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
AtomicBoolean success = new AtomicBoolean(true);
// Return successful result for 1st result, failed for the rest
writerContext.setCommitResultSupplier((uuids, dc) -> {
if (success.getAndSet(false))
{
- return new DataTransferApi.RemoteCommitResult(true, null, uuids, "");
+ return new DirectDataTransferApi.RemoteCommitResult(true, null, uuids, "");
}
else
{
- return new DataTransferApi.RemoteCommitResult(false, uuids, null, "");
+ return new DirectDataTransferApi.RemoteCommitResult(false, uuids, null, "");
}
});
-
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
- ss.scheduleStream(tr);
- RuntimeException exception = assertThrows(RuntimeException.class, () -> ss.close()); // Force "execution" of futures
+ ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
+ ExecutionException exception = assertThrows(ExecutionException.class,
+ () -> ss.scheduleStreamAsync(1, executor).get());
assertEquals("Failed to load 1 ranges with LOCAL_QUORUM for job " + writerContext.job().getId()
- + " in phase UploadAndCommit.", exception.getMessage());
+ + " in phase UploadAndCommit.", exception.getCause().getMessage());
executor.assertFuturesCalled();
assertThat(writerContext.getUploads().values().stream().mapToInt(Collection::size).sum(), equalTo(RF * FILES_PER_SSTABLE));
List<String> instances = writerContext.getUploads().keySet().stream().map(CassandraInstance::nodeName).collect(Collectors.toList());
assertThat(instances, containsInAnyOrder(expectedInstances.toArray()));
}
- private void runFailedUpload() throws IOException, ExecutionException, InterruptedException
+ private void runFailedUpload() throws IOException
{
writerContext.setUploadSupplier(instance -> false);
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
- tr.close(writerContext, 1);
- ss.scheduleStream(tr);
- RuntimeException ex = assertThrows(RuntimeException.class, () -> ss.close());
- assertThat(ex.getMessage(), startsWith(LOAD_RANGE_ERROR_PREFIX));
+ StreamSession<?> ss = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ ss.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
+ ExecutionException ex = assertThrows(ExecutionException.class,
+ () -> ss.scheduleStreamAsync(1, executor).get());
+ assertThat(ex.getCause().getMessage(), startsWith(LOAD_RANGE_ERROR_PREFIX));
}
@NotNull
@@ -269,4 +269,20 @@
{
return new MockBulkWriterContext(tokenRangeMapping);
}
+
+ @NotNull
+ private ReplicaAwareFailureHandler<RingInstance> replicaAwareFailureHandler()
+ {
+ return new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner());
+ }
+
+ private DirectStreamSession createStreamSession(MockTableWriter.Creator writerCreator)
+ {
+ return new DirectStreamSession(writerContext,
+ writerCreator.create(tableWriter, folder, digestAlgorithm),
+ transportContext,
+ "sessionId",
+ range,
+ replicaAwareFailureHandler());
+ }
}
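
The tests above replace the old synchronous pattern (write an SSTable, scheduleStream it, then close() the session to force execution) with a session-owned flow: rows go straight into the StreamSession and streaming is scheduled asynchronously. A minimal sketch of the new flow, reusing this test class's fixture (createStreamSession, executor, COLUMN_BOUND_VALUES):

    StreamSession<?> session = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
    session.addRow(BigInteger.valueOf(102L), COLUMN_BOUND_VALUES);
    Future<?> stream = session.scheduleStreamAsync(1, executor);
    try
    {
        stream.get(); // blocks until upload and commit complete
    }
    catch (ExecutionException e)
    {
        // Failures now surface as the cause of an ExecutionException,
        // e.g. NoSuchFileException when the output directory was removed.
        Throwable root = e.getCause();
    }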
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporterTest.java
new file mode 100644
index 0000000..7a1c4bd
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/HeartbeatReporterTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class HeartbeatReporterTest
+{
+ private HeartbeatReporter heartbeatReporter = new HeartbeatReporter();
+ private String heartbeatName = "test-heartbeat";
+
+ @AfterEach
+ public void teardown()
+ {
+ heartbeatReporter.unschedule(heartbeatName);
+ }
+
+ @Test
+ public void testScheduleHeartbeat()
+ {
+ CountDownLatch latch = new CountDownLatch(10);
+ long start = System.nanoTime();
+ heartbeatReporter.schedule(heartbeatName, 10, latch::countDown);
+ Uninterruptibles.awaitUninterruptibly(latch);
+ assertEquals(0, latch.getCount());
+ assertTrue(System.nanoTime() > start + TimeUnit.MILLISECONDS.toNanos(10 * 10));
+ }
+
+ @Test
+ public void testScheduleSuppressThrows()
+ {
+ CountDownLatch latch = new CountDownLatch(10);
+ long start = System.nanoTime();
+ heartbeatReporter.schedule(heartbeatName, 10, () -> {
+ latch.countDown();
+ throw new RuntimeException("It fails");
+ });
+ Uninterruptibles.awaitUninterruptibly(latch);
+ assertEquals(0, latch.getCount());
+ assertTrue(System.nanoTime() > start + TimeUnit.MILLISECONDS.toNanos(10 * 10));
+ }
+}
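
HeartbeatReporter itself is outside this diff; the following is a minimal sketch, consistent with these tests but not necessarily the patch's implementation, of a reporter that runs named tasks on a scheduled executor and swallows any Throwable so a failing heartbeat keeps firing (an uncaught exception would otherwise cancel a scheduleWithFixedDelay task for good):

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.ScheduledFuture;
    import java.util.concurrent.TimeUnit;

    public class HeartbeatReporterSketch
    {
        private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        private final Map<String, ScheduledFuture<?>> scheduled = new ConcurrentHashMap<>();

        public void schedule(String name, long periodMillis, Runnable heartbeat)
        {
            scheduled.computeIfAbsent(name, k -> executor.scheduleWithFixedDelay(() -> {
                try
                {
                    heartbeat.run();
                }
                catch (Throwable t)
                {
                    // suppress: a failing heartbeat must not cancel the schedule
                }
            }, 0, periodMillis, TimeUnit.MILLISECONDS));
        }

        public void unschedule(String name)
        {
            ScheduledFuture<?> future = scheduled.remove(name);
            if (future != null)
            {
                future.cancel(true);
            }
        }
    }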
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinatorTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinatorTest.java
new file mode 100644
index 0000000..8d35f8d
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/ImportCompletionCoordinatorTest.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter;
+
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Range;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.RingEntry;
+import org.apache.cassandra.sidecar.client.SidecarClient;
+import org.apache.cassandra.sidecar.client.SidecarInstanceImpl;
+import org.apache.cassandra.sidecar.client.exception.RetriesExhaustedException;
+import org.apache.cassandra.sidecar.client.request.Request;
+import org.apache.cassandra.spark.bulkwriter.ImportCompletionCoordinator.RequestAndInstance;
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobDataTransferApi;
+import org.apache.cassandra.spark.bulkwriter.blobupload.BlobStreamResult;
+import org.apache.cassandra.spark.bulkwriter.blobupload.CreatedRestoreSlice;
+import org.apache.cassandra.spark.bulkwriter.blobupload.StorageClient;
+import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.data.partitioner.Partitioner;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.mockito.ArgumentCaptor;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atMostOnce;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class ImportCompletionCoordinatorTest
+{
+ private static final int TOTAL_INSTANCES = 10;
+
+ BulkWriterContext mockWriterContext;
+ BulkWriteValidator writerValidator;
+ TokenRangeMapping<RingInstance> topology;
+ JobInfo mockJobInfo;
+ BlobDataTransferApi dataTransferApi;
+ UUID jobId;
+ StorageTransportExtension mockExtension;
+ ArgumentCaptor<String> appliedObjectKeys;
+ Consumer<CancelJobEvent> onCancelJob;
+
+ @BeforeEach
+ public void setup() throws Exception
+ {
+ mockJobInfo = mock(JobInfo.class);
+ jobId = UUID.randomUUID();
+ when(mockJobInfo.getId()).thenReturn(jobId.toString());
+ when(mockJobInfo.getRestoreJobId()).thenReturn(jobId);
+ when(mockJobInfo.qualifiedTableName()).thenReturn(new QualifiedTableName("testkeyspace", "testtable"));
+ when(mockJobInfo.getConsistencyLevel()).thenReturn(ConsistencyLevel.CL.QUORUM);
+ when(mockJobInfo.effectiveSidecarPort()).thenReturn(9043);
+ when(mockJobInfo.jobKeepAliveMinutes()).thenReturn(-1);
+
+ mockWriterContext = mock(BulkWriterContext.class);
+ ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
+ when(mockWriterContext.cluster()).thenReturn(mockClusterInfo);
+
+ CassandraContext mockCassandraContext = mock(CassandraContext.class);
+ when(mockClusterInfo.getCassandraContext()).thenReturn(mockCassandraContext);
+ topology = TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), TOTAL_INSTANCES);
+ when(mockClusterInfo.getTokenRangeMapping(anyBoolean())).thenReturn(topology);
+ when(mockWriterContext.job()).thenReturn(mockJobInfo);
+
+ writerValidator = new BulkWriteValidator(mockWriterContext, new ReplicaAwareFailureHandler<>(Partitioner.Murmur3Partitioner));
+
+ // clients are not used in this test class; mocking happens at the API method level
+ BlobDataTransferApi api = new BlobDataTransferApi(mockJobInfo, mock(SidecarClient.class), mock(StorageClient.class));
+ dataTransferApi = spy(api);
+
+ mockExtension = mock(StorageTransportExtension.class);
+ appliedObjectKeys = ArgumentCaptor.forClass(String.class);
+ doNothing().when(mockExtension).onObjectApplied(any(), appliedObjectKeys.capture(), anyLong(), anyLong());
+
+ onCancelJob = event -> {
+ throw new RuntimeException("It should not be called");
+ };
+ }
+
+ @Test
+ void testAwaitForCompletionWithNoErrors()
+ {
+ List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 0);
+ ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob)
+ .waitForCompletion();
+ validateAllSlicesWereCalledAtMostOnce(resultList);
+ assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
+ "All objects should be applied and reported for exactly once");
+ assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
+ }
+
+ @Test
+ void testAwaitForCompletionWithNoErrorsAndSlowImport()
+ {
+ List<BlobStreamResult> resultList = buildBlobStreamResult(0, true, 0);
+ ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob)
+ .waitForCompletion();
+ validateAllSlicesWereCalledAtMostOnce(resultList);
+ assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
+ "All objects should be applied and reported for exactly once");
+ assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
+ }
+
+ @Test // the test scenario has errors during checking, but CL passes overall and the import succeeds
+ void testAwaitForCompletionWithErrorsAndCLPasses()
+ {
+ // There is 1 failure in each replica set; 2 out of 3 replicas succeed.
+ List<BlobStreamResult> resultList = buildBlobStreamResult(1, false, 0);
+ ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob)
+ .waitForCompletion();
+ validateAllSlicesWereCalledAtMostOnce(resultList);
+ assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
+ "All objects should be applied and reported for exactly once");
+ assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
+ }
+
+ @Test // the test scenario has errors that fail CL, so the import fails
+ void testAwaitForCompletionWithErrorsAndCLFails()
+ {
+ // There are 2 failures in each replica set, so only 1 out of 3 replicas succeeds.
+ // All replica sets fail, but the number of failed ranges is not deterministic.
+ // Therefore, the assertion omits the number of ranges from the message.
+ String errorMessage = "ranges with QUORUM for job " + jobId + " in phase WaitForCommitCompletion";
+ List<BlobStreamResult> resultList = buildBlobStreamResult(2, false, 0);
+ RuntimeException exception = assertThrows(RuntimeException.class, () -> {
+ ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob)
+ .waitForCompletion();
+ });
+ assertNotNull(exception.getMessage());
+ assertTrue(exception.getMessage().contains("Failed to load"));
+ assertTrue(exception.getMessage().contains(errorMessage));
+ assertNotNull(exception.getCause());
+ validateAllSlicesWereCalledAtMostOnce(resultList);
+ assertEquals(0, appliedObjectKeys.getAllValues().size(),
+ "No object should be applied and reported");
+ }
+
+ @Test
+ void testCLUnsatisfiedRanges()
+ {
+ String errorMessage = "Some of the token ranges cannot satisfy with consistency level. job=" + jobId + " phase=WaitForCommitCompletion";
+ // The CL check won't fail, as there are no failed instances.
+ // It won't be satisfied either, since there are not enough available instances.
+ List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 2);
+ RuntimeException exception = assertThrows(RuntimeException.class, () -> {
+ ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob)
+ .waitForCompletion();
+ });
+ assertNotNull(exception.getMessage());
+ assertTrue(exception.getMessage().contains(errorMessage));
+ assertNull(exception.getCause());
+ validateAllSlicesWereCalledAtMostOnce(resultList);
+ assertEquals(0, appliedObjectKeys.getAllValues().size(),
+ "No object should be applied and reported");
+ }
+
+ @Test
+ void testAwaitShouldPassWithStuckSliceWhenClSatisfied()
+ {
+ /*
+ * A slice import can get stuck on the server side, i.e. the import request never indicates
+ * that the slice is complete. If the consistency level has been satisfied for all ranges,
+ * it is safe to ignore the abnormal status of the stuck slices.
+ * The test verifies that in such a scenario, ImportCompletionCoordinator does not block
+ * forever and can conclude a successful result.
+ */
+ List<BlobStreamResult> resultList = buildBlobStreamResultWithNoProgressImports(1);
+ ImportCompletionCoordinator coordinator = ImportCompletionCoordinator.of(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancelJob);
+ coordinator.waitForCompletion();
+ assertEquals(resultList.size(), appliedObjectKeys.getAllValues().size(),
+ "All objects should be applied and reported for exactly once");
+ assertEquals(allTestObjectKeys(), new HashSet<>(appliedObjectKeys.getAllValues()));
+ Map<CompletableFuture<Void>, RequestAndInstance> importFutures = coordinator.importFutures();
+ int cancelledImports = importFutures.keySet().stream().mapToInt(f -> f.isCancelled() ? 1 : 0).sum();
+ assertEquals(TOTAL_INSTANCES, cancelledImports,
+ "Each replica set should have a slice gets cancelled due to making no progress");
+ }
+
+ @Test
+ void testJobCancelOnTopologyChanged()
+ {
+ AtomicBoolean isCancelled = new AtomicBoolean(false);
+ Consumer<CancelJobEvent> onCancel = event -> {
+ isCancelled.set(true);
+ };
+ BulkWriterContext mockWriterContext = mock(BulkWriterContext.class);
+ ClusterInfo mockClusterInfo = mock(ClusterInfo.class);
+ when(mockWriterContext.cluster()).thenReturn(mockClusterInfo);
+ when(mockClusterInfo.getTokenRangeMapping(false))
+ .thenReturn(TokenRangeMappingUtils.buildTokenRangeMapping(0,
+ ImmutableMap.of("DC1", 3),
+ TOTAL_INSTANCES))
+ .thenReturn(TokenRangeMappingUtils.buildTokenRangeMapping(0,
+ ImmutableMap.of("DC1", 3),
+ TOTAL_INSTANCES + 1)); // adding a new instance; expansion
+ List<BlobStreamResult> resultList = buildBlobStreamResult(0, false, 0);
+ AtomicReference<CassandraTopologyMonitor> monitorRef = new AtomicReference<>(null);
+ ImportCompletionCoordinator coordinator = new ImportCompletionCoordinator(0, mockWriterContext, dataTransferApi,
+ writerValidator, resultList, mockExtension, onCancel,
+ (clusterInfo, onCancelJob) -> {
+ monitorRef.set(new CassandraTopologyMonitor(clusterInfo, onCancelJob));
+ return monitorRef.get();
+ });
+ monitorRef.get().checkTopologyOnDemand();
+ CompletionException coordinatorEx = assertThrows(CompletionException.class, coordinator::waitForCompletion);
+ assertEquals("Topology changed during bulk write", coordinatorEx.getCause().getMessage());
+ assertTrue(isCancelled.get());
+ CompletableFuture<Void> firstFailure = coordinator.firstFailure();
+ assertTrue(firstFailure.isCompletedExceptionally());
+ ExecutionException firstFailureEx = assertThrows(ExecutionException.class, firstFailure::get);
+ assertEquals(coordinatorEx.getCause(), firstFailureEx.getCause());
+ }
+
+ private Set<String> allTestObjectKeys()
+ {
+ return IntStream.range(0, 10).boxed().map(i -> "key_for_instance_" + i).collect(Collectors.toSet());
+ }
+
+ private List<BlobStreamResult> buildBlobStreamResultWithNoProgressImports(int noProgressInstanceCount)
+ {
+ return buildBlobStreamResult(0, false, 0, noProgressInstanceCount);
+ }
+
+ private List<BlobStreamResult> buildBlobStreamResult(int failedInstanceCount, boolean simulateSlowImport, int unavailableInstanceCount)
+ {
+ return buildBlobStreamResult(failedInstanceCount, simulateSlowImport, unavailableInstanceCount, 0);
+ }
+
+ /**
+ * @param failedInstanceCount number of instances in each replica set that fail the HTTP request
+ * @param simulateSlowImport whether to simulate a slow import with an artificial delay
+ * @param unavailableInstanceCount number of instances in each replica set that are not included in the BlobStreamResult
+ * @param noProgressInstanceCount number of instances in each replica set that make no progress, i.e. their futures never complete
+ * @return a list of blob stream results
+ */
+ private List<BlobStreamResult> buildBlobStreamResult(int failedInstanceCount,
+ boolean simulateSlowImport,
+ int unavailableInstanceCount,
+ int noProgressInstanceCount)
+ {
+ List<BlobStreamResult> resultList = new ArrayList<>();
+ int totalInstances = 10;
+
+ for (int i = 0; i < totalInstances; i++)
+ {
+ List<RingInstance> replicaSet = Arrays.asList(ringInstance(i, totalInstances),
+ ringInstance(i + 1, totalInstances),
+ ringInstance(i + 2, totalInstances));
+ Set<CreatedRestoreSlice> createdRestoreSlices = new HashSet<>();
+ int failedPerReplica = failedInstanceCount;
+ int unavailablePerReplica = unavailableInstanceCount;
+ int noProgressPerReplicaSet = noProgressInstanceCount;
+ // create one distinct slice per instance
+ CreateSliceRequestPayload mockCreateSliceRequestPayload = mock(CreateSliceRequestPayload.class);
+ when(mockCreateSliceRequestPayload.startToken()).thenReturn(BigInteger.valueOf(100 * i));
+ when(mockCreateSliceRequestPayload.endToken()).thenReturn(BigInteger.valueOf(100 * (1 + i)));
+ when(mockCreateSliceRequestPayload.sliceId()).thenReturn(UUID.randomUUID().toString());
+ when(mockCreateSliceRequestPayload.key()).thenReturn("key_for_instance_" + i); // to be captured by extension mock
+ when(mockCreateSliceRequestPayload.bucket()).thenReturn("bucket"); // to be captured by extension mock
+ when(mockCreateSliceRequestPayload.compressedSize()).thenReturn(1L); // to be captured by extension mock
+ when(mockCreateSliceRequestPayload.compressedSizeOrZero()).thenReturn(1L);
+ List<RingInstance> passedReplicaSet = new ArrayList<>();
+ for (RingInstance instance : replicaSet)
+ {
+ if (unavailablePerReplica-- > 0)
+ {
+ continue; // do not include this instance
+ }
+ passedReplicaSet.add(instance);
+ createdRestoreSlices.add(new CreatedRestoreSlice(mockCreateSliceRequestPayload));
+ if (simulateSlowImport && i == totalInstances - 1)
+ {
+ // only add slowness for the last import
+ doAnswer((Answer<CompletableFuture<Void>>) invocation -> {
+ Thread.sleep(ThreadLocalRandom.current().nextInt(2000));
+ return CompletableFuture.completedFuture(null);
+ })
+ .when(dataTransferApi)
+ .createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
+ eq(mockCreateSliceRequestPayload));
+ }
+ else if (noProgressPerReplicaSet-- > 0)
+ {
+ // return a future that never completes
+ doReturn(new CompletableFuture<>())
+ .when(dataTransferApi)
+ .createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
+ eq(mockCreateSliceRequestPayload));
+ }
+ else if (failedPerReplica-- > 0)
+ {
+ CompletableFuture<Void> future = new CompletableFuture<>();
+ future.completeExceptionally(RetriesExhaustedException.of(10, mock(Request.class), null));
+ doReturn(future)
+ .when(dataTransferApi)
+ .createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
+ eq(mockCreateSliceRequestPayload));
+ }
+ else
+ {
+ doReturn(CompletableFuture.completedFuture(null))
+ .when(dataTransferApi)
+ .createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
+ eq(mockCreateSliceRequestPayload));
+ }
+ }
+ BlobStreamResult result = new BlobStreamResult("", mock(Range.class), Collections.emptyList(),
+ passedReplicaSet, createdRestoreSlices, 0, 0);
+ resultList.add(result);
+ }
+ return resultList;
+ }
+
+ // Some slice requests might not be issued due to short-circuiting, hence verify at most once
+ private void validateAllSlicesWereCalledAtMostOnce(List<BlobStreamResult> resultList)
+ {
+ for (BlobStreamResult blobStreamResult : resultList)
+ {
+ for (RingInstance instance : blobStreamResult.passed)
+ {
+ for (CreatedRestoreSlice createdRestoreSlice : blobStreamResult.createdRestoreSlices)
+ {
+ verify(dataTransferApi, atMostOnce())
+ .createRestoreSliceFromDriver(eq(new SidecarInstanceImpl(instance.nodeName(), 9043)),
+ eq(createdRestoreSlice.sliceRequestPayload()));
+ }
+ }
+ }
+ }
+
+ private RingInstance ringInstance(int i, int totalInstances)
+ {
+ int instanceInRing = i % totalInstances + 1;
+ return new RingInstance(new RingEntry.Builder()
+ .datacenter("DC1")
+ .address("127.0.0." + instanceInRing)
+ .token(String.valueOf(i * 100))
+ .fqdn("instance-" + instanceInRing)
+ .build());
+ }
+}
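
The coordinator under test is not shown in this hunk. The core pattern these tests exercise can be sketched as follows (names are illustrative, not the real API): per range, wait until enough replica imports succeed, then cancel the still-pending imports rather than block on stuck slices forever:

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.atomic.AtomicInteger;

    final class CompletionSketch
    {
        // Completes the returned future once `required` imports succeed for a range;
        // remaining in-flight imports are then cancelled (cancel() on an already
        // completed future is a harmless no-op).
        static CompletableFuture<Void> awaitQuorum(List<CompletableFuture<Void>> imports, int required)
        {
            CompletableFuture<Void> rangeSatisfied = new CompletableFuture<>();
            AtomicInteger succeeded = new AtomicInteger();
            for (CompletableFuture<Void> f : imports)
            {
                f.whenComplete((ignored, failure) -> {
                    if (failure == null && succeeded.incrementAndGet() >= required)
                    {
                        rangeSatisfied.complete(null);
                        imports.forEach(other -> other.cancel(true));
                    }
                });
            }
            return rangeSatisfied;
        }
    }

This matches what testAwaitShouldPassWithStuckSliceWhenClSatisfied asserts: once the consistency level is met, the futures that never complete end up cancelled instead of hanging the job.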
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
index cf6a6f5..205d412 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockBulkWriterContext.java
@@ -19,6 +19,7 @@
package org.apache.cassandra.spark.bulkwriter;
+import java.math.BigInteger;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
@@ -37,6 +38,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Range;
import org.apache.commons.lang3.tuple.Pair;
import o.a.c.sidecar.client.shaded.common.data.TimeSkewResponse;
@@ -44,6 +46,7 @@
import org.apache.cassandra.bridge.CassandraBridgeFactory;
import org.apache.cassandra.bridge.CassandraVersion;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
import org.apache.cassandra.spark.common.Digest;
import org.apache.cassandra.spark.common.client.ClientException;
@@ -54,6 +57,7 @@
import org.apache.cassandra.spark.common.schema.ColumnTypes;
import org.apache.cassandra.spark.common.stats.JobStatsPublisher;
import org.apache.cassandra.spark.data.CqlField;
+import org.apache.cassandra.spark.data.QualifiedTableName;
import org.apache.cassandra.spark.data.partitioner.Partitioner;
import org.apache.cassandra.spark.validation.StartupValidator;
import org.apache.spark.sql.types.DataTypes;
@@ -65,7 +69,7 @@
import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.VARCHAR;
import static org.apache.cassandra.spark.bulkwriter.TableSchemaTestCommon.mockCqlType;
-public class MockBulkWriterContext implements BulkWriterContext, ClusterInfo, JobInfo, SchemaInfo, DataTransferApi, JobStatsPublisher
+public class MockBulkWriterContext implements BulkWriterContext, ClusterInfo, JobInfo, SchemaInfo, JobStatsPublisher
{
private static final long serialVersionUID = -2912371629236770646L;
public static final String[] DEFAULT_PARTITION_KEY_COLUMNS = {"id", "date"};
@@ -75,10 +79,8 @@
new String[]{"id", "date", "course", "marks"},
new org.apache.spark.sql.types.DataType[]{DataTypes.IntegerType, DataTypes.DateType, DataTypes.StringType, DataTypes.IntegerType},
new CqlField.CqlType[]{mockCqlType(INT), mockCqlType(DATE), mockCqlType(VARCHAR), mockCqlType(INT)});
- private final boolean quoteIdentifiers;
private ConsistencyLevel.CL consistencyLevel;
private int sstableDataSizeInMB = 128;
- private int sstableWriteBatchSize = 2;
private CassandraBridge bridge = CassandraBridgeFactory.get(CassandraVersion.FOURZERO);
@Override
@@ -87,7 +89,7 @@
// DO NOTHING
}
- public interface CommitResultSupplier extends BiFunction<List<String>, String, RemoteCommitResult>
+ public interface CommitResultSupplier extends BiFunction<List<String>, String, DirectDataTransferApi.RemoteCommitResult>
{
}
@@ -105,11 +107,10 @@
private final TableSchema schema;
private final TokenRangeMapping<RingInstance> tokenRangeMapping;
private final Set<CassandraInstance> cleanCalledForInstance = Collections.synchronizedSet(new HashSet<>());
- private boolean instancesAreAvailable = true;
private boolean cleanShouldThrow = false;
private final TokenPartitioner tokenPartitioner;
private final String cassandraVersion;
- private CommitResultSupplier crSupplier = (uuids, dc) -> new RemoteCommitResult(true, Collections.emptyList(), uuids, null);
+ private CommitResultSupplier crSupplier = (uuids, dc) -> new DirectDataTransferApi.RemoteCommitResult(true, Collections.emptyList(), uuids, null);
private Predicate<CassandraInstance> uploadRequestConsumer = instance -> true;
@@ -139,7 +140,6 @@
String[] primaryKeyColumnNames,
boolean quoteIdentifiers)
{
- this.quoteIdentifiers = quoteIdentifiers;
this.tokenRangeMapping = tokenRangeMapping;
this.tokenPartitioner = new TokenPartitioner(tokenRangeMapping, 1, 2, 2, false);
this.cassandraVersion = cassandraVersion;
@@ -158,7 +158,7 @@
.withWriteMode(WriteMode.INSERT)
.withDataFrameSchema(validDataFrameSchema)
.withTTLSetting(ttlOption);
- if (quoteIdentifiers())
+ if (quoteIdentifiers)
{
builder.withQuotedIdentifiers();
}
@@ -205,6 +205,12 @@
}
@Override
+ public CassandraContext getCassandraContext()
+ {
+ return null;
+ }
+
+ @Override
public void refreshClusterInfo()
{
refreshClusterInfoCallCount++;
@@ -263,6 +269,30 @@
return DigestAlgorithms.XXHASH32;
}
+ @Override
+ public DataTransportInfo transportInfo()
+ {
+ return new DataTransportInfo(DataTransport.DIRECT, null, 0);
+ }
+
+ @Override
+ public int jobKeepAliveMinutes()
+ {
+ return 1;
+ }
+
+ @Override
+ public int effectiveSidecarPort()
+ {
+ return 9043;
+ }
+
+ @Override
+ public int importCoordinatorTimeoutMultiplier()
+ {
+ return 2;
+ }
+
public void setSkipCleanOnFailures(boolean skipClean)
{
this.skipClean = skipClean;
@@ -281,12 +311,18 @@
}
@Override
- public UUID getId()
+ public UUID getRestoreJobId()
{
return jobId;
}
@Override
+ public String getConfiguredJobId()
+ {
+ return null;
+ }
+
+ @Override
public TokenPartitioner getTokenPartitioner()
{
return tokenPartitioner;
@@ -323,52 +359,12 @@
return cassandraVersion;
}
- @Override
- public RemoteCommitResult commitSSTables(CassandraInstance instance, String migrationId, List<String> uuids)
- {
- commits.computeIfAbsent(instance, k -> new ArrayList<>()).add(migrationId);
- return crSupplier.apply(buildCompleteBatchIds(uuids), instance.datacenter());
- }
-
private List<String> buildCompleteBatchIds(List<String> uuids)
{
return uuids.stream().map(uuid -> uuid + "-" + jobId).collect(Collectors.toList());
}
@Override
- public void cleanUploadSession(CassandraInstance instance, String sessionID, String jobID) throws ClientException
- {
- cleanCalledForInstance.add(instance);
- if (cleanShouldThrow)
- {
- throw new ClientException("Clean was called but was set to throw");
- }
- }
-
- @Override
- public void uploadSSTableComponent(Path componentFile,
- int ssTableIdx,
- CassandraInstance instance,
- String sessionID,
- Digest digest) throws ClientException
- {
- boolean uploadSucceeded = uploadRequestConsumer.test(instance);
- uploads.compute(instance, (k, pathList) -> {
- if (pathList == null)
- {
- pathList = new ArrayList<>();
- }
- pathList.add(new UploadRequest(componentFile, ssTableIdx, instance, sessionID, digest, uploadSucceeded));
- return pathList;
- });
- if (!uploadSucceeded)
- {
- throw new ClientException("Failed upload");
- }
- uploadsLatch.countDown();
- }
-
- @Override
public Map<RingInstance, InstanceAvailability> getInstanceAvailability()
{
return tokenRangeMapping.getReplicaMetadata().stream()
@@ -377,12 +373,6 @@
}
@Override
- public boolean instanceIsAvailable(RingInstance ringInstance)
- {
- return instancesAreAvailable;
- }
-
- @Override
public InstanceState getInstanceState(RingInstance ringInstance)
{
return InstanceState.NORMAL;
@@ -393,11 +383,6 @@
this.uploadRequestConsumer = uploadRequestConsumer;
}
- public void setInstancesAreAvailable(boolean instancesAreAvailable)
- {
- this.instancesAreAvailable = instancesAreAvailable;
- }
-
public int refreshClusterInfoCallCount()
{
return refreshClusterInfoCallCount;
@@ -450,9 +435,79 @@
}
@Override
- public DataTransferApi transfer()
+ public TransportContext transportContext()
{
- return this;
+ MockBulkWriterContext mockBulkWriterContext = this;
+ return new TransportContext.DirectDataBulkWriterContext()
+ {
+ @Override
+ public DirectDataTransferApi dataTransferApi()
+ {
+ return new DirectDataTransferApi()
+ {
+ @Override
+ public DirectDataTransferApi.RemoteCommitResult commitSSTables(CassandraInstance instance, String migrationId, List<String> uuids)
+ {
+ commits.compute(instance, (ignored, commitList) -> {
+ if (commitList == null)
+ {
+ commitList = new ArrayList<>();
+ }
+ commitList.add(migrationId);
+ return commitList;
+ });
+ return crSupplier.apply(buildCompleteBatchIds(uuids), instance.datacenter());
+ }
+
+ @Override
+ public void cleanUploadSession(CassandraInstance instance, String sessionID, String jobID) throws ClientException
+ {
+ cleanCalledForInstance.add(instance);
+ if (cleanShouldThrow)
+ {
+ throw new ClientException("Clean was called but was set to throw");
+ }
+ }
+
+ @Override
+ public void uploadSSTableComponent(Path componentFile,
+ int ssTableIdx,
+ CassandraInstance instance,
+ String sessionID,
+ Digest digest) throws ClientException
+ {
+ boolean uploadSucceeded = uploadRequestConsumer.test(instance);
+ uploads.compute(instance, (k, pathList) -> {
+ if (pathList == null)
+ {
+ pathList = new ArrayList<>();
+ }
+ pathList.add(new UploadRequest(componentFile, ssTableIdx, instance, sessionID, digest, uploadSucceeded));
+ return pathList;
+ });
+ if (!uploadSucceeded)
+ {
+ throw new ClientException("Failed upload");
+ }
+ }
+ };
+ }
+
+ @Override
+ public StreamSession<?> createStreamSession(BulkWriterContext writerContext,
+ String sessionId,
+ SortedSSTableWriter sstableWriter,
+ Range<BigInteger> range,
+ ReplicaAwareFailureHandler<RingInstance> failureHandler)
+ {
+ return new DirectStreamSession(mockBulkWriterContext,
+ sstableWriter,
+ this,
+ sessionId,
+ range,
+ failureHandler);
+ }
+ };
}
public CassandraBridge bridge()
@@ -461,21 +516,9 @@
}
@Override
- public boolean quoteIdentifiers()
+ public QualifiedTableName qualifiedTableName()
{
- return quoteIdentifiers;
- }
-
- @Override
- public String keyspace()
- {
- return "keyspace";
- }
-
- @Override
- public String tableName()
- {
- return "table";
+ return new QualifiedTableName("keyspace", "table", false);
}
// Startup Validation
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockScheduledExecutorService.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockScheduledExecutorService.java
index 41c6d75..ca99c5d 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockScheduledExecutorService.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockScheduledExecutorService.java
@@ -24,6 +24,7 @@
import java.util.concurrent.Callable;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@@ -46,10 +47,10 @@
@NotNull
@Override
- public ScheduledFuture<?> scheduleAtFixedRate(@NotNull Runnable command,
- long initialDelay,
- long period,
- @NotNull TimeUnit unit)
+ public ScheduledFuture<?> scheduleWithFixedDelay(@NotNull Runnable command,
+ long initialDelay,
+ long period,
+ @NotNull TimeUnit unit)
{
this.period = period;
this.timeUnit = unit;
@@ -67,6 +68,15 @@
@NotNull
@Override
+ public <T> Future<T> submit(@NotNull Callable<T> task)
+ {
+ MockScheduledFuture<T> future = new MockScheduledFuture<>(task);
+ futures.add(future);
+ return future;
+ }
+
+ @NotNull
+ @Override
public List<Runnable> shutdownNow()
{
stopped = true;
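
The override swap above (scheduleAtFixedRate to scheduleWithFixedDelay) mirrors the scheduling semantics the production code now relies on. For context, a small JDK sketch of the difference: fixed rate measures the period from the start of each run, so a slow task leads to back-to-back executions, while fixed delay measures it from the end of each run, so executions can never bunch up:

    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;

    public class SchedulingSketch
    {
        public static void main(String[] args) throws InterruptedException
        {
            ScheduledExecutorService ses = Executors.newScheduledThreadPool(2);
            // next run starts one second after the *start* of the previous run
            ses.scheduleAtFixedRate(() -> System.out.println("rate"), 0, 1, TimeUnit.SECONDS);
            // next run starts one second after the *end* of the previous run
            ses.scheduleWithFixedDelay(() -> System.out.println("delay"), 0, 1, TimeUnit.SECONDS);
            TimeUnit.SECONDS.sleep(5);
            ses.shutdownNow();
        }
    }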
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
index 716b7ee..9551fe2 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/MockTableWriter.java
@@ -30,6 +30,7 @@
import org.apache.commons.io.FileUtils;
import org.apache.cassandra.bridge.SSTableWriter;
+import org.apache.cassandra.spark.utils.DigestAlgorithm;
public class MockTableWriter implements SSTableWriter
{
@@ -100,4 +101,12 @@
{
FileUtils.deleteDirectory(outDir.toFile());
}
+
+ public interface Creator
+ {
+ // matches SortedSSTableWriter's constructor signature
+ SortedSSTableWriter create(MockTableWriter tableWriter,
+ Path outDir,
+ DigestAlgorithm digestAlgorithm);
+ }
}
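
Any constructor with the matching shape satisfies Creator as a method reference, which is how the stream-session tests construct their writers. A usage sketch (tableWriter, outDir, and digestAlgorithm stand in for test fixture values):

    MockTableWriter.Creator creator = NonValidatingTestSortedSSTableWriter::new;
    SortedSSTableWriter writer = creator.create(tableWriter, outDir, digestAlgorithm);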
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
similarity index 85%
rename from cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java
rename to cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
index 08ae58a..eae8795 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSSTableWriter.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/NonValidatingTestSortedSSTableWriter.java
@@ -24,9 +24,9 @@
import org.apache.cassandra.spark.utils.DigestAlgorithm;
import org.jetbrains.annotations.NotNull;
-class NonValidatingTestSSTableWriter extends SSTableWriter
+public class NonValidatingTestSortedSSTableWriter extends SortedSSTableWriter
{
- NonValidatingTestSSTableWriter(MockTableWriter tableWriter, Path path, DigestAlgorithm digestAlgorithm)
+ public NonValidatingTestSortedSSTableWriter(MockTableWriter tableWriter, Path path, DigestAlgorithm digestAlgorithm)
{
super(tableWriter, path, digestAlgorithm);
}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
index e8f5138..14abca6 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/RecordWriterTest.java
@@ -74,7 +74,7 @@
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.when;
-public class RecordWriterTest
+class RecordWriterTest
{
private static final int REPLICA_COUNT = 3;
private static final int FILES_PER_SSTABLE = 8;
@@ -125,7 +125,7 @@
}
@Test
- public void testWriteFailWhenTopologyChangeWithinTask()
+ void testWriteFailWhenTopologyChangeWithinTask()
{
// Generate token range mapping to simulate node movement of the first node by assigning it a different token
// within the same partition
@@ -138,7 +138,7 @@
moveTargetToken);
MockBulkWriterContext m = Mockito.spy(writerContext);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
when(m.getTokenRangeMapping(false)).thenCallRealMethod().thenReturn(testMapping);
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
@@ -147,7 +147,7 @@
}
@Test
- public void testWriteWithBlockedInstances()
+ void testWriteWithBlockedInstances()
{
String blockedInstanceIp = "127.0.0.2";
@@ -167,7 +167,7 @@
writerContext = new MockBulkWriterContext(tokenRangeMapping, DEFAULT_CASSANDRA_VERSION, ConsistencyLevel.CL.QUORUM);
MockBulkWriterContext m = Mockito.spy(writerContext);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
when(m.getTokenRangeMapping(anyBoolean())).thenReturn(testMapping);
when(m.getInstanceAvailability()).thenReturn(availability);
@@ -176,11 +176,11 @@
Map<CassandraInstance, List<UploadRequest>> uploads = writerContext.getUploads();
// Should not upload to blocked instances
assertThat(uploads.keySet().size(), is(REPLICA_COUNT - 1));
- assertFalse(uploads.keySet().stream().map(i -> i.ipAddress()).collect(Collectors.toSet()).contains(blockedInstanceIp));
+ assertFalse(uploads.keySet().stream().map(CassandraInstance::ipAddress).collect(Collectors.toSet()).contains(blockedInstanceIp));
}
@Test
- public void testWriteWithExclusions()
+ void testWriteWithExclusions()
{
TokenRangeMapping<RingInstance> testMapping =
TokenRangeMappingUtils.buildTokenRangeMappingWithFailures(100000,
@@ -188,7 +188,7 @@
12);
MockBulkWriterContext m = Mockito.spy(writerContext);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
when(m.getTokenRangeMapping(anyBoolean())).thenReturn(testMapping);
when(m.getInstanceAvailability()).thenCallRealMethod();
@@ -199,14 +199,14 @@
}
@Test
- public void testSuccessfulWrite() throws InterruptedException
+ void testSuccessfulWrite() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
validateSuccessfulWrite(writerContext, data, COLUMN_NAMES);
}
@Test
- public void testWriteWithMixedCaseColumnNames() throws InterruptedException
+ void testWriteWithMixedCaseColumnNames() throws InterruptedException
{
boolean quoteIdentifiers = true;
String[] pk = {"ID", "date"};
@@ -230,9 +230,9 @@
}
@Test
- public void testSuccessfulWriteCheckUploads()
+ void testSuccessfulWriteCheckUploads()
{
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
rw.write(data);
Map<CassandraInstance, List<UploadRequest>> uploads = writerContext.getUploads();
@@ -246,14 +246,14 @@
}
@Test
- public void testWriteWithConstantTTL() throws InterruptedException
+ void testWriteWithConstantTTL() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(false, false);
validateSuccessfulWrite(writerContext, data, COLUMN_NAMES);
}
@Test
- public void testWriteWithTTLColumn() throws InterruptedException
+ void testWriteWithTTLColumn() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(true, false);
String[] columnNamesWithTtl =
@@ -264,14 +264,14 @@
}
@Test
- public void testWriteWithConstantTimestamp() throws InterruptedException
+ void testWriteWithConstantTimestamp() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(false, false);
validateSuccessfulWrite(writerContext, data, COLUMN_NAMES);
}
@Test
- public void testWriteWithTimestampColumn() throws InterruptedException
+ void testWriteWithTimestampColumn() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(false, true);
String[] columnNamesWithTimestamp =
@@ -282,7 +282,7 @@
}
@Test
- public void testWriteWithTimestampAndTTLColumn() throws InterruptedException
+ void testWriteWithTimestampAndTTLColumn() throws InterruptedException
{
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(true, true);
String[] columnNames =
@@ -293,7 +293,7 @@
}
@Test
- public void testWriteWithSubRanges()
+ void testWriteWithSubRanges()
{
MockBulkWriterContext m = Mockito.spy(writerContext);
TokenPartitioner mtp = Mockito.mock(TokenPartitioner.class);
@@ -303,7 +303,7 @@
Range<BigInteger> overlapRange = Range.openClosed(BigInteger.valueOf(-9223372036854775808L), BigInteger.valueOf(200000));
when(mtp.getTokenRange(anyInt())).thenReturn(overlapRange);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateDataWithFakeToken(ROWS_COUNT, range);
List<StreamResult> res = rw.write(data).streamResults();
assertEquals(1, res.size());
@@ -321,7 +321,7 @@
}
@Test
- public void testWriteWithDataInMultipleSubRanges()
+ void testWriteWithDataInMultipleSubRanges()
{
MockBulkWriterContext m = Mockito.spy(writerContext);
TokenPartitioner mtp = Mockito.mock(TokenPartitioner.class);
@@ -330,7 +330,7 @@
Range<BigInteger> overlapRange = Range.openClosed(BigInteger.valueOf(-9223372036854775808L), BigInteger.valueOf(200000));
Range<BigInteger> firstSubrange = Range.openClosed(BigInteger.valueOf(-9223372036854775808L), BigInteger.valueOf(100000));
when(mtp.getTokenRange(anyInt())).thenReturn(overlapRange);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
// Generate rows that belong to the first sub-range
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateDataWithFakeToken(ROWS_COUNT, firstSubrange);
@@ -351,7 +351,7 @@
}
@Test
- public void testWriteWithTokensAcrossSubRanges()
+ void testWriteWithTokensAcrossSubRanges()
{
MockBulkWriterContext m = Mockito.spy(writerContext);
m.setSstableDataSizeInMB(1);
@@ -362,7 +362,7 @@
Range<BigInteger> firstSubrange = Range.openClosed(BigInteger.valueOf(-9223372036854775808L), BigInteger.valueOf(100000L));
Range<BigInteger> secondSubrange = Range.openClosed(BigInteger.valueOf(100000L), BigInteger.valueOf(200000L));
when(mtp.getTokenRange(anyInt())).thenReturn(overlapRange);
- rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(m, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
int numRows = 1; // generate 1 row in each range
Iterator<Tuple2<DecoratedKey, Object[]>> firstRangeData = generateDataWithFakeToken(numRows, firstSubrange);
Iterator<Tuple2<DecoratedKey, Object[]>> secondRangeData = generateDataWithFakeToken(numRows, secondSubrange);
@@ -386,9 +386,10 @@
}
@Test
- public void testCorruptSSTable()
+ void testCorruptSSTable()
{
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, (wc, path, dp) -> new SSTableWriter(tw.setOutDir(path), path, digestAlgorithm));
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
+ (wc, path, dp) -> new SortedSSTableWriter(tw.setOutDir(path), path, digestAlgorithm));
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
// TODO: Add better error handling with human-readable exception messages in SSTableReader::new
// That way we can assert on the exception thrown here
@@ -396,9 +397,10 @@
}
@Test
- public void testWriteWithOutOfRangeTokenFails()
+ void testWriteWithOutOfRangeTokenFails()
{
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, (wc, path, dp) -> new SSTableWriter(tw, folder, digestAlgorithm));
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
+ (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData(5, Range.all(), false, false, false);
RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
String expectedErr = "java.lang.IllegalStateException: Received Token " +
@@ -407,9 +409,10 @@
}
@Test
- public void testAddRowThrowingFails()
+ void testAddRowThrowingFails()
{
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, (wc, path, dp) -> new SSTableWriter(tw, folder, digestAlgorithm));
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
+ (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
tw.setAddRowThrows(true);
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
@@ -417,11 +420,12 @@
}
@Test
- public void testBadTimeSkewFails()
+ void testBadTimeSkewFails()
{
// Mock context returns a 60-minute allowable time skew, so we use something just outside the limits
long sixtyOneMinutesInMillis = TimeUnit.MINUTES.toMillis(61);
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, (wc, path, dp) -> new SSTableWriter(tw, folder, digestAlgorithm));
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc,
+ (wc, path, dp) -> new SortedSSTableWriter(tw, folder, digestAlgorithm));
writerContext.setTimeProvider(() -> System.currentTimeMillis() - sixtyOneMinutesInMillis);
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
RuntimeException ex = assertThrows(RuntimeException.class, () -> rw.write(data));
@@ -429,12 +433,12 @@
}
@Test
- public void testTimeSkewWithinLimitsSucceeds()
+ void testTimeSkewWithinLimitsSucceeds()
{
// Mock context returns a 60-minute allowable time skew, so we use something just inside the limits
long fiftyNineMinutesInMillis = TimeUnit.MINUTES.toMillis(59);
long remoteTime = System.currentTimeMillis() - fiftyNineMinutesInMillis;
- rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, SSTableWriter::new);
+ rw = new RecordWriter(writerContext, COLUMN_NAMES, () -> tc, SortedSSTableWriter::new);
writerContext.setTimeProvider(() -> remoteTime); // Return a very low "current time" to make sure we fail if skew is too bad
Iterator<Tuple2<DecoratedKey, Object[]>> data = generateData();
rw.write(data);
@@ -457,7 +461,7 @@
int expectedUploads,
CountDownLatch uploadsLatch) throws InterruptedException
{
- RecordWriter rw = new RecordWriter(writerContext, columnNames, () -> tc, SSTableWriter::new);
+ RecordWriter rw = new RecordWriter(writerContext, columnNames, () -> tc, SortedSSTableWriter::new);
rw.write(data);
uploadsLatch.await(1, TimeUnit.SECONDS);
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
similarity index 96%
rename from cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterTest.java
rename to cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
index 58dc982..e7acbfb 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SSTableWriterTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/SortedSSTableWriterTest.java
@@ -43,7 +43,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
-public class SSTableWriterTest
+public class SortedSSTableWriterTest
{
private static String previousMbeanState;
@@ -86,7 +86,7 @@
public void canCreateWriterForVersion(String version) throws IOException
{
MockBulkWriterContext writerContext = new MockBulkWriterContext(tokenRangeMapping, version, ConsistencyLevel.CL.LOCAL_QUORUM);
- SSTableWriter tw = new SSTableWriter(writerContext, tmpDir, new XXHash32DigestAlgorithm());
+ SortedSSTableWriter tw = new SortedSSTableWriter(writerContext, tmpDir, new XXHash32DigestAlgorithm());
tw.addRow(BigInteger.ONE, ImmutableMap.of("id", 1, "date", 1, "course", "foo", "marks", 1));
tw.close(writerContext, 1);
try (DirectoryStream<Path> dataFileStream = Files.newDirectoryStream(tw.getOutDir(), "*Data.db"))
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
index b809f88..b9722d1 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/StreamSessionConsistencyTest.java
@@ -24,10 +24,10 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -68,8 +68,8 @@
@TempDir
private Path folder;
private MockTableWriter tableWriter;
- private StreamSession streamSession;
private MockBulkWriterContext writerContext;
+ private TransportContext.DirectDataBulkWriterContext transportContext;
private final MockScheduledExecutorService executor = new MockScheduledExecutorService();
private DigestAlgorithm digestAlgorithm;
@@ -87,11 +87,7 @@
digestAlgorithm = new XXHash32DigestAlgorithm();
tableWriter = new MockTableWriter(folder);
writerContext = new MockBulkWriterContext(TOKEN_RANGE_MAPPING, "cassandra-4.0.0", consistencyLevel);
- streamSession = new StreamSession(writerContext,
- "sessionId",
- RANGE,
- executor,
- new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
+ transportContext = (TransportContext.DirectDataBulkWriterContext) writerContext.transportContext();
}
@ParameterizedTest(name = "CL: {0}, numFailures: {1}")
@@ -101,12 +97,6 @@
throws IOException, ExecutionException, InterruptedException
{
setup(consistencyLevel, failuresPerDc);
- streamSession = new StreamSession(writerContext,
- "sessionId",
- RANGE,
- executor,
- new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
-
AtomicInteger dc1Failures = new AtomicInteger(failuresPerDc.get(0));
AtomicInteger dc2Failures = new AtomicInteger(failuresPerDc.get(1));
ImmutableMap<String, AtomicInteger> dcFailures = ImmutableMap.of("DC1", dc1Failures, "DC2", dc2Failures);
@@ -115,28 +105,26 @@
writerContext.setCommitResultSupplier((uuids, dc) -> {
if (dcFailures.get(dc).getAndDecrement() > 0)
{
- return new DataTransferApi.RemoteCommitResult(false, uuids, Collections.emptyList(), "");
+ return new DirectDataTransferApi.RemoteCommitResult(false, null, uuids, "");
}
else
{
- return new DataTransferApi.RemoteCommitResult(true, Collections.emptyList(), uuids, "");
+ return new DirectDataTransferApi.RemoteCommitResult(true, uuids, null, "");
}
});
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
- tr.close(writerContext, 1);
- streamSession.scheduleStream(tr);
+ StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
+ Future<?> fut = streamSession.scheduleStreamAsync(1, executor);
if (shouldFail)
{
- RuntimeException exception = assertThrows(RuntimeException.class,
- () -> streamSession.close()); // Force "execution" of futures
+ ExecutionException exception = assertThrows(ExecutionException.class, fut::get);
assertEquals("Failed to load 1 ranges with " + consistencyLevel
+ " for job " + writerContext.job().getId()
- + " in phase UploadAndCommit.", exception.getMessage());
+ + " in phase UploadAndCommit.", exception.getCause().getMessage());
}
else
{
- streamSession.close(); // Force "execution" of futures
+ fut.get();
}
executor.assertFuturesCalled();
assertThat(writerContext.getUploads().values().stream()
@@ -156,32 +144,25 @@
throws IOException, ExecutionException, InterruptedException
{
setup(consistencyLevel, failuresPerDc);
- streamSession = new StreamSession(writerContext,
- "sessionId",
- RANGE,
- executor,
- new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
AtomicInteger dc1Failures = new AtomicInteger(failuresPerDc.get(0));
AtomicInteger dc2Failures = new AtomicInteger(failuresPerDc.get(1));
int numFailures = dc1Failures.get() + dc2Failures.get();
ImmutableMap<String, AtomicInteger> dcFailures = ImmutableMap.of("DC1", dc1Failures, "DC2", dc2Failures);
boolean shouldFail = calculateFailure(consistencyLevel, dc1Failures.get(), dc2Failures.get());
writerContext.setUploadSupplier(instance -> dcFailures.get(instance.datacenter()).getAndDecrement() <= 0);
- SSTableWriter tr = new NonValidatingTestSSTableWriter(tableWriter, folder, digestAlgorithm);
- tr.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
- tr.close(writerContext, 1);
- streamSession.scheduleStream(tr);
+ StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
+ streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
+ Future<?> fut = streamSession.scheduleStreamAsync(1, executor);
if (shouldFail)
{
- RuntimeException exception = assertThrows(RuntimeException.class,
- () -> streamSession.close()); // Force "execution" of futures
+ ExecutionException exception = assertThrows(ExecutionException.class, fut::get);
assertEquals("Failed to load 1 ranges with " + consistencyLevel
+ " for job " + writerContext.job().getId()
- + " in phase UploadAndCommit.", exception.getMessage());
+ + " in phase UploadAndCommit.", exception.getCause().getMessage());
}
else
{
- streamSession.close(); // Force "execution" of futures
+ fut.get(); // Force "execution" of futures
}
executor.assertFuturesCalled();
int totalFilesToUpload = REPLICATION_FACTOR * NUMBER_DCS * FILES_PER_SSTABLE;
@@ -224,4 +205,14 @@
throw new IllegalArgumentException("CL: " + consistencyLevel + " not supported");
}
}
+
+ private StreamSession<?> createStreamSession(MockTableWriter.Creator writerCreator)
+ {
+ return new DirectStreamSession(writerContext,
+ writerCreator.create(tableWriter, folder, digestAlgorithm),
+ transportContext,
+ "sessionId",
+ RANGE,
+ new ReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()));
+ }
}
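Reviewer note (not part of the patch): the updated test encodes the reworked asynchronous API. A minimal usage sketch of the pattern the new assertions pin down — names taken from the diff above, surrounding wiring hypothetical:

    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.Future;

    // Streams are now scheduled asynchronously; failures no longer surface from
    // close() but as the cause of the ExecutionException thrown by Future.get().
    Future<?> fut = streamSession.scheduleStreamAsync(1, executor);
    try
    {
        fut.get();
    }
    catch (InterruptedException e)
    {
        Thread.currentThread().interrupt();
    }
    catch (ExecutionException e)
    {
        // e.getCause().getMessage() reads "Failed to load ... in phase UploadAndCommit."
    }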
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java
new file mode 100644
index 0000000..0dc1ff1
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BlobStreamSessionTest.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.net.URISyntaxException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+import com.google.common.collect.BoundType;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Range;
+import org.apache.commons.io.FileUtils;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSecrets;
+import o.a.c.sidecar.client.shaded.common.data.RestoreJobSummaryResponsePayload;
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.SSTableSummary;
+import org.apache.cassandra.sidecar.client.SidecarClient;
+import org.apache.cassandra.sidecar.client.SidecarInstance;
+import org.apache.cassandra.spark.bulkwriter.BulkWriterContext;
+import org.apache.cassandra.spark.bulkwriter.ClusterInfo;
+import org.apache.cassandra.spark.bulkwriter.DataTransport;
+import org.apache.cassandra.spark.bulkwriter.DataTransportInfo;
+import org.apache.cassandra.spark.bulkwriter.JobInfo;
+import org.apache.cassandra.spark.bulkwriter.MockBulkWriterContext;
+import org.apache.cassandra.spark.bulkwriter.MockTableWriter;
+import org.apache.cassandra.spark.bulkwriter.NonValidatingTestSortedSSTableWriter;
+import org.apache.cassandra.spark.bulkwriter.RingInstance;
+import org.apache.cassandra.spark.bulkwriter.SortedSSTableWriter;
+import org.apache.cassandra.spark.bulkwriter.TokenRangeMappingUtils;
+import org.apache.cassandra.spark.bulkwriter.TransportContext;
+import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
+import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
+import org.apache.cassandra.spark.common.client.ClientException;
+import org.apache.cassandra.spark.data.FileSystemSSTable;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.apache.cassandra.spark.utils.TemporaryDirectory;
+import org.apache.cassandra.spark.utils.XXHash32DigestAlgorithm;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+class BlobStreamSessionTest
+{
+ @TempDir
+ private Path folder;
+
+ @Test
+ void testSendBundles() throws IOException, URISyntaxException
+ {
+ // setup
+ UUID jobId = UUID.randomUUID();
+ String sessionId = "1-" + UUID.randomUUID();
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId.toString(), sessionId);
+ TransportContext.CloudStorageTransportContext transportContext = mock(TransportContext.CloudStorageTransportContext.class);
+ TokenRangeMapping<RingInstance> topology = TokenRangeMappingUtils.buildTokenRangeMapping(0, ImmutableMap.of("DC1", 3), 3);
+ MockBulkWriterContext bulkWriterContext = new MockBulkWriterContext(topology);
+ BulkWriterContext spiedWriterContext = spy(bulkWriterContext);
+ ReplicaAwareFailureHandler<RingInstance> replicaAwareFailureHandler = new ReplicaAwareFailureHandler<>(bulkWriterContext.cluster().getPartitioner());
+ Range<BigInteger> range = Range.range(BigInteger.valueOf(100L), BoundType.OPEN, BigInteger.valueOf(199L), BoundType.CLOSED);
+ JobInfo job = mock(JobInfo.class);
+ when(job.getRestoreJobId()).thenReturn(jobId);
+ when(job.qualifiedTableName()).thenReturn(new QualifiedTableName("ks", "table1"));
+ MockTableWriter tableWriter = new MockTableWriter(folder);
+ SortedSSTableWriter sstableWriter = new NonValidatingTestSortedSSTableWriter(tableWriter, folder, new XXHash32DigestAlgorithm());
+
+ DataTransportInfo transportInfo = mock(DataTransportInfo.class);
+ when(transportInfo.getTransport()).thenReturn(DataTransport.S3_COMPAT);
+ when(transportInfo.getMaxSizePerBundleInBytes()).thenReturn(5 * 1024L);
+ when(job.transportInfo()).thenReturn(transportInfo);
+ when(spiedWriterContext.job()).thenReturn(job);
+ when(job.effectiveSidecarPort()).thenReturn(65055);
+
+ ClusterInfo clusterInfo = mock(ClusterInfo.class);
+ when(clusterInfo.getTokenRangeMapping(anyBoolean())).thenReturn(topology);
+ when(spiedWriterContext.cluster()).thenReturn(clusterInfo);
+ StorageTransportConfiguration storageTransportConfiguration = mock(StorageTransportConfiguration.class);
+ when(transportContext.transportConfiguration()).thenReturn(storageTransportConfiguration);
+
+ StorageTransportExtension transportExtension = mock(StorageTransportExtension.class);
+ when(transportContext.transportExtensionImplementation()).thenReturn(transportExtension);
+
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ // setup continued
+ Path sourceDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
+ Path outputDir = tempDir.path();
+ FileUtils.copyDirectory(sourceDir.toFile(), outputDir.toFile());
+
+ CassandraBridge bridge = generateBridge(outputDir);
+ SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ SSTablesBundler ssTablesBundler = new SSTablesBundler(outputDir, ssTableLister, nameGenerator, 5 * 1024);
+ ssTablesBundler.includeDirectory(outputDir);
+ ssTablesBundler.finish();
+ List<Bundle> bundles = ImmutableList.copyOf(ssTablesBundler);
+
+ SidecarClient sidecarClient = mock(SidecarClient.class);
+ StorageClient storageClient = mock(StorageClient.class);
+ MockBlobTransferApi blobDataTransferApi = new MockBlobTransferApi(spiedWriterContext.job(), sidecarClient, storageClient);
+ when(transportContext.dataTransferApi()).thenReturn(blobDataTransferApi);
+ when(transportContext.transportConfiguration().getReadBucket()).thenReturn("readBucket");
+
+ BlobStreamSession ss = new BlobStreamSession(spiedWriterContext, sstableWriter,
+ transportContext, sessionId,
+ range, bridge, replicaAwareFailureHandler);
+
+ // test begins
+ for (Bundle bundle : bundles)
+ {
+ ss.sendBundle(bundle, true);
+ }
+
+ assertEquals(bundles.size(), ss.createdRestoreSlices().size(),
+ "It should create 1 slice per bundle");
+ Bundle actualBundle1 = blobDataTransferApi.uploadedBundleManifest.get(BigInteger.valueOf(1L));
+ BundleManifest.Entry actualBundle1Entry = actualBundle1.manifestEntry("na-1-big-");
+ assertEquals(BigInteger.valueOf(1L), actualBundle1Entry.startToken());
+ assertEquals(BigInteger.valueOf(3L), actualBundle1Entry.endToken());
+ Map<String, String> bundle1ComponentsChecksum = actualBundle1Entry.componentsChecksum();
+ assertEquals("f48b39a3", bundle1ComponentsChecksum.get("na-1-big-Data.db"));
+ assertEquals("ee128018", bundle1ComponentsChecksum.get("na-1-big-Index.db"));
+ assertEquals("e2c32c23", bundle1ComponentsChecksum.get("na-1-big-Summary.db"));
+ assertEquals("f773fcc6", bundle1ComponentsChecksum.get("na-1-big-Statistics.db"));
+ assertEquals("7c8ef1f5", bundle1ComponentsChecksum.get("na-1-big-TOC.txt"));
+ assertEquals("72fc4f9c", bundle1ComponentsChecksum.get("na-1-big-Filter.db"));
+
+ Bundle actualBundle2 = blobDataTransferApi.uploadedBundleManifest.get(BigInteger.valueOf(3L));
+ BundleManifest.Entry actualBundle2Entry = actualBundle2.manifestEntry("na-2-big-");
+ assertEquals(BigInteger.valueOf(3L), actualBundle2Entry.startToken());
+ assertEquals(BigInteger.valueOf(6L), actualBundle2Entry.endToken());
+ Map<String, String> bundle2ComponentsChecksum = actualBundle2Entry.componentsChecksum();
+ assertEquals("f48b39a3", bundle2ComponentsChecksum.get("na-2-big-Data.db"));
+ assertEquals("ee128018", bundle2ComponentsChecksum.get("na-2-big-Index.db"));
+ assertEquals("e2c32c23", bundle2ComponentsChecksum.get("na-2-big-Summary.db"));
+ assertEquals("f773fcc6", bundle2ComponentsChecksum.get("na-2-big-Statistics.db"));
+ assertEquals("7c8ef1f5", bundle2ComponentsChecksum.get("na-2-big-TOC.txt"));
+ assertEquals("72fc4f9c", bundle2ComponentsChecksum.get("na-2-big-Filter.db"));
+ }
+ }
+
+ private CassandraBridge generateBridge(Path outputDir)
+ {
+ CassandraBridge bridge = mock(CassandraBridge.class);
+
+ SSTableSummary summary1 = new SSTableSummary(BigInteger.valueOf(1L), BigInteger.valueOf(3L), "na-1-big-");
+ SSTableSummary summary2 = new SSTableSummary(BigInteger.valueOf(3L), BigInteger.valueOf(6L), "na-2-big-");
+
+ FileSystemSSTable ssTable1 = new FileSystemSSTable(outputDir.resolve("na-1-big-Data.db"), false, null);
+ FileSystemSSTable ssTable2 = new FileSystemSSTable(outputDir.resolve("na-2-big-Data.db"), false, null);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable1)).thenReturn(summary1);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable2)).thenReturn(summary2);
+ return bridge;
+ }
+
+ public static class MockBlobTransferApi extends BlobDataTransferApi
+ {
+ Map<BigInteger, Bundle> uploadedBundleManifest = new HashMap<>();
+
+ MockBlobTransferApi(JobInfo jobInfo, SidecarClient sidecarClient, StorageClient storageClient)
+ {
+ super(jobInfo, sidecarClient, storageClient);
+ }
+
+ @Override
+ public RestoreJobSummaryResponsePayload restoreJobSummary()
+ {
+ RestoreJobSummaryResponsePayload payload = mock(RestoreJobSummaryResponsePayload.class);
+ o.a.c.sidecar.client.shaded.common.data.StorageCredentials credentials = mock(o.a.c.sidecar.client.shaded.common.data.StorageCredentials.class);
+ when(credentials.accessKeyId()).thenReturn("id");
+ when(credentials.secretAccessKey()).thenReturn("key");
+ when(credentials.sessionToken()).thenReturn("token");
+ when(payload.secrets()).thenReturn(new RestoreJobSecrets(credentials, credentials));
+ return payload;
+ }
+
+ @Override
+ public BundleStorageObject uploadBundle(StorageCredentials writeCredentials, Bundle bundle)
+ {
+ uploadedBundleManifest.put(bundle.startToken, bundle);
+ return BundleStorageObject.builder()
+ .bundle(bundle)
+ .storageObjectChecksum("dummy")
+ .storageObjectKey("some_prefix-" + bundle.bundleFile.getFileName().toString())
+ .build();
+ }
+
+ @Override
+ public void createRestoreSliceFromExecutor(SidecarInstance sidecarInstance,
+ CreateSliceRequestPayload createSliceRequestPayload) throws ClientException
+ {
+ // the request is always successful
+ }
+ }
+}
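For orientation, the S3_COMPAT flow that testSendBundles drives through sendBundle(bundle, true) reduces to the following sketch. The two API methods are the ones mocked above; toCreateSliceRequest, replicas, and writeCredentials are hypothetical stand-ins for the session's internal wiring:

    // Per-bundle flow (a sketch, not the shipped BlobStreamSession):
    for (Bundle bundle : bundles)
    {
        // 1. upload the zipped bundle to the S3-compatible store
        BundleStorageObject object = dataTransferApi.uploadBundle(writeCredentials, bundle);
        // 2. tell each replica's Sidecar a slice is available for restore,
        //    which is why the test expects one created slice per bundle
        for (SidecarInstance replica : replicas)
        {
            dataTransferApi.createRestoreSliceFromExecutor(replica, toCreateSliceRequest(object));
        }
    }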
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifestTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifestTest.java
new file mode 100644
index 0000000..7c2beed
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleManifestTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+import java.nio.file.FileAlreadyExistsException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.apache.commons.io.FileUtils;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+
+import static org.apache.cassandra.spark.bulkwriter.blobupload.BundleManifest.Entry;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class BundleManifestTest
+{
+ @TempDir
+ private Path tempFolder;
+
+ @Test
+ void testJsonSerialization() throws JsonProcessingException
+ {
+ BundleManifest bundleManifest = testManifest();
+ String value = BundleManifest.OBJECT_WRITER.writeValueAsString(bundleManifest);
+ assertEquals(EXPECTED_JSON, value);
+ }
+
+ @Test
+ void testPersistToFile() throws Exception
+ {
+ Path manifestFile = tempFolder.resolve("manifest.json");
+ assertFalse(Files.exists(manifestFile));
+ BundleManifest bundleManifest = testManifest();
+ bundleManifest.persistTo(manifestFile);
+ String persistedContent = FileUtils.readFileToString(manifestFile.toFile());
+ assertEquals(EXPECTED_JSON, persistedContent);
+ }
+
+ @Test
+ void testPersistToFileFailsWithExistingFile() throws Exception
+ {
+ // the file already exists
+ Path manifestFile = Files.createFile(tempFolder.resolve("manifest.json"));
+ assertTrue(Files.exists(manifestFile));
+ assertThrows(FileAlreadyExistsException.class,
+ () -> testManifest().persistTo(manifestFile));
+ }
+
+ private BundleManifest testManifest()
+ {
+ BundleManifest bundleManifest = new BundleManifest();
+ Entry manifestEntry1 = new Entry("prefix1",
+ BigInteger.valueOf(1L),
+ BigInteger.valueOf(3L));
+ manifestEntry1.addComponentChecksum("prefix1-data.db", "checksumVal");
+ manifestEntry1.addComponentChecksum("prefix1-statistics.db", "checksumVal");
+
+ Entry manifestEntry2 = new Entry("prefix2",
+ BigInteger.valueOf(4L),
+ BigInteger.valueOf(6L));
+ manifestEntry2.addComponentChecksum("prefix2-data.db", "checksumVal");
+ manifestEntry2.addComponentChecksum("prefix2-statistics.db", "checksumVal");
+
+ bundleManifest.addEntry(manifestEntry1);
+ bundleManifest.addEntry(manifestEntry2);
+
+ return bundleManifest;
+ }
+
+ private static final String EXPECTED_JSON = "{\n"
+ + " \"prefix2\" : {\n"
+ + " \"components_checksum\" : {\n"
+ + " \"prefix2-data.db\" : \"checksumVal\",\n"
+ + " \"prefix2-statistics.db\" : \"checksumVal\"\n"
+ + " },\n"
+ + " \"start_token\" : 4,\n"
+ + " \"end_token\" : 6\n"
+ + " },\n"
+ + " \"prefix1\" : {\n"
+ + " \"components_checksum\" : {\n"
+ + " \"prefix1-data.db\" : \"checksumVal\",\n"
+ + " \"prefix1-statistics.db\" : \"checksumVal\"\n"
+ + " },\n"
+ + " \"start_token\" : 1,\n"
+ + " \"end_token\" : 3\n"
+ + " }\n"
+ + "}";
+}
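The FileAlreadyExistsException expectation pins down the write semantics. A minimal persistTo consistent with it — an illustrative sketch assuming CREATE_NEW semantics, not necessarily the shipped code:

    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    // CREATE_NEW makes the write fail fast with FileAlreadyExistsException
    // instead of silently overwriting a previously persisted manifest.
    public void persistTo(Path manifestFile) throws IOException
    {
        try (OutputStream out = Files.newOutputStream(manifestFile, StandardOpenOption.CREATE_NEW))
        {
            OBJECT_WRITER.writeValue(out, this);
        }
    }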
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGeneratorTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGeneratorTest.java
new file mode 100644
index 0000000..bb3ef0d
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleNameGeneratorTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class BundleNameGeneratorTest
+{
+ @Test
+ void testNameGenerated()
+ {
+ String jobId = "ea3b3e6b-0d78-4913-89f2-15fcf98711d0";
+ String sessionId = "1-9062a40b-41ae-40b0-8ba6-47f9bbec6cba";
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId, sessionId);
+
+ String expectedName = "b_" + jobId + '_' + sessionId + "_1_3";
+ assertEquals(expectedName, nameGenerator.generate(BigInteger.valueOf(1L), BigInteger.valueOf(3L)));
+ expectedName = "d_" + jobId + '_' + sessionId + "_3_6";
+ assertEquals(expectedName, nameGenerator.generate(BigInteger.valueOf(3L), BigInteger.valueOf(6L)));
+ }
+
+ @Test
+ void testAllStartCharsGenerated()
+ {
+ String jobId = "ea3b3e6b-0d78-4913-89f2-15fcf98711d0";
+ char[] expectedResults = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".toCharArray();
+
+ String sessionId = "1-9062a40b-41ae-40b0-8ba6-47f9bbec6cba";
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId, sessionId);
+
+ // iterate 0..61 to cover every possible leading character (indices are taken mod 62)
+ for (int i = 0; i < 62; i++)
+ {
+ assertEquals(expectedResults[i], nameGenerator.generate(BigInteger.valueOf(i), BigInteger.valueOf(i + 1)).charAt(0));
+ }
+ }
+}
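Taken together, the two tests pin down the name format. Inferring from the expected values alone (not from the implementation), names look like prefix_jobId_sessionId_startToken_endToken, with the prefix drawn from [a-zA-Z0-9] by startToken mod 62 — a sketch consistent with the assertions above:

    import java.math.BigInteger;

    // The rotating prefix character spreads bundle names across the object
    // store's key space: start=1 -> 'b', start=3 -> 'd', start=4 -> 'e'.
    private static final char[] PREFIX_ALPHABET =
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".toCharArray();

    static String bundleName(String jobId, String sessionId, BigInteger start, BigInteger end)
    {
        char prefix = PREFIX_ALPHABET[start.mod(BigInteger.valueOf(62)).intValue()];
        return prefix + "_" + jobId + "_" + sessionId + "_" + start + "_" + end;
    }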
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleTest.java
new file mode 100644
index 0000000..f80499d
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/BundleTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.FileInputStream;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.zip.ZipInputStream;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import org.apache.cassandra.bridge.SSTableSummary;
+import org.apache.cassandra.spark.bulkwriter.blobupload.SSTableCollector.SSTableFilesAndRange;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class BundleTest
+{
+ @TempDir
+ private Path tempFolder;
+
+ @Test
+ void testBuildBundleAndDelete() throws Exception
+ {
+ Path stagingDir = Files.createDirectories(tempFolder.resolve("staging"));
+ long totalSize = 0;
+ int sstableCount = 3;
+ int componentCount = 5;
+ List<SSTableFilesAndRange> sourceSSTables = new ArrayList<>(sstableCount);
+ for (int i = 0; i < sstableCount; i++)
+ {
+ sourceSSTables.add(mockSSTableFilesAndRange(componentCount, 100));
+ totalSize += 100;
+ }
+ Bundle bundle = Bundle.builder()
+ .bundleSequence(0)
+ .sourceSSTables(sourceSSTables)
+ .bundleNameGenerator(new BundleNameGenerator("jobId", "sessionId"))
+ .bundleStagingDirectory(stagingDir)
+ .build();
+ assertEquals(totalSize, bundle.bundleUncompressedSize);
+ assertEquals(BigInteger.ONE, bundle.startToken);
+ assertEquals(BigInteger.TEN, bundle.endToken);
+ assertNotNull(bundle.bundleFile);
+ assertTrue(Files.exists(bundle.bundleFile));
+ ZipInputStream zis = new ZipInputStream(new FileInputStream(bundle.bundleFile.toFile()));
+ int actualFilesCount = 0;
+ while (zis.getNextEntry() != null)
+ {
+ actualFilesCount++;
+ }
+ // the extra file (+ 1) is the manifest file
+ assertEquals(sstableCount * componentCount + 1, actualFilesCount);
+
+ bundle.deleteAll();
+ assertFalse(Files.exists(bundle.bundleFile));
+ assertFalse(Files.exists(bundle.bundleDirectory));
+ long filesCount = Files.list(stagingDir).count();
+ assertEquals(0, filesCount);
+ }
+
+ private SSTableFilesAndRange mockSSTableFilesAndRange(int fileCount, long size) throws Exception
+ {
+ SSTableSummary summary = new SSTableSummary(BigInteger.ONE, BigInteger.TEN,
+ UUID.randomUUID().toString());
+ List<Path> paths = new ArrayList<>(fileCount);
+ for (int i = 0; i < fileCount; i++)
+ {
+ paths.add(Files.createFile(tempFolder.resolve(UUID.randomUUID().toString())));
+ }
+ return new SSTableFilesAndRange(summary, paths, size);
+ }
+}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSliceTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSliceTest.java
new file mode 100644
index 0000000..1c02633
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/CreatedRestoreSliceTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.math.BigInteger;
+
+import org.junit.jupiter.api.Test;
+
+import o.a.c.sidecar.client.shaded.common.data.CreateSliceRequestPayload;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class CreatedRestoreSliceTest
+{
+ @Test
+ void testEqualsAndHashcode()
+ {
+ CreateSliceRequestPayload req = new CreateSliceRequestPayload("id", 0, "bucket", "key", "checksum",
+ BigInteger.ONE, BigInteger.TEN, 234L, 123L);
+ CreatedRestoreSlice slice = new CreatedRestoreSlice(req);
+
+ assertThat(slice.sliceRequestPayloadJson)
+ .isEqualTo("{\"sliceId\":\"id\"," +
+ "\"bucketId\":0," +
+ "\"storageBucket\":\"bucket\"," +
+ "\"storageKey\":\"key\"," +
+ "\"sliceChecksum\":\"checksum\"," +
+ "\"startToken\":1," +
+ "\"endToken\":10," +
+ "\"sliceUncompressedSize\":234," +
+ "\"sliceCompressedSize\":123}");
+
+ assertThat(slice).isEqualTo(new CreatedRestoreSlice(req))
+ .hasSameHashCodeAs(new CreatedRestoreSlice(req));
+
+ CreateSliceRequestPayload differentReq = new CreateSliceRequestPayload("newId", 0, "bucket", "key", "checksum",
+ BigInteger.ZERO, BigInteger.valueOf(2L), 234L, 123L);
+ assertThat(slice).isNotEqualTo(new CreatedRestoreSlice(differentReq))
+ .doesNotHaveSameHashCodeAs(new CreatedRestoreSlice(differentReq));
+ }
+}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunkerTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunkerTest.java
new file mode 100644
index 0000000..ae54d04
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/DataChunkerTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
+import java.util.Iterator;
+import java.util.Random;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.spark.utils.ByteBufferUtils;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class DataChunkerTest
+{
+ @Test
+ public void testChunksGeneratedWithWholeChunks() throws IOException
+ {
+ testChunking(512, 4, 512 / 4);
+ }
+
+ @Test
+ public void testChunksGeneratedWithSmallerLastChunk() throws IOException
+ {
+ testChunking(513, 4, 513 / 4 + 1);
+ }
+
+ private void testChunking(int totalSize, int chunkSize, int expectedChunks) throws IOException
+ {
+ DataChunker chunker = new DataChunker(chunkSize, false);
+ Random rd = new Random();
+ byte[] expected = new byte[totalSize];
+ rd.nextBytes(expected);
+
+ try (ReadableByteChannel channel = Channels.newChannel(new ByteArrayInputStream(expected)))
+ {
+ int size = 0;
+ Iterator<ByteBuffer> chunks = chunker.chunks(channel);
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ while (chunks.hasNext())
+ {
+ ByteBuffer buffer = chunks.next();
+ bos.write(ByteBufferUtils.getArray(buffer));
+ size += 1;
+ }
+ assertEquals(expectedChunks, size);
+ assertArrayEquals(expected, bos.toByteArray());
+ }
+ }
+}
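What the two cases assert, in sketch form — assuming DataChunker simply slices the channel into fixed-size buffers with a smaller trailing chunk (buffer handling here is illustrative):

    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.channels.ReadableByteChannel;
    import java.util.ArrayList;
    import java.util.List;

    // 512 bytes with chunkSize 4 yields 128 equal chunks; 513 bytes yields
    // 128 full chunks plus one 1-byte chunk (hence 513 / 4 + 1 above).
    static List<ByteBuffer> chunk(ReadableByteChannel channel, int chunkSize) throws IOException
    {
        List<ByteBuffer> chunks = new ArrayList<>();
        ByteBuffer buffer = ByteBuffer.allocate(chunkSize);
        while (channel.read(buffer) != -1)
        {
            if (!buffer.hasRemaining())
            {
                buffer.flip();
                chunks.add(buffer);
                buffer = ByteBuffer.allocate(chunkSize);
            }
        }
        if (buffer.position() > 0)
        {
            buffer.flip();
            chunks.add(buffer); // smaller last chunk
        }
        return chunks;
    }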
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java
new file mode 100644
index 0000000..5abfa19
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTableListerTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.net.URISyntaxException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.SSTableSummary;
+import org.apache.cassandra.spark.data.FileSystemSSTable;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.utils.TemporaryDirectory;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class SSTableListerTest
+{
+ @Test
+ void testOutput() throws URISyntaxException
+ {
+ Path outputDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
+ CassandraBridge bridge = mock(CassandraBridge.class);
+
+ SSTableSummary summary1 = new SSTableSummary(BigInteger.valueOf(1L), BigInteger.valueOf(3L), "na-1-big-");
+ SSTableSummary summary2 = new SSTableSummary(BigInteger.valueOf(3L), BigInteger.valueOf(6L), "na-2-big-");
+
+ FileSystemSSTable ssTable1 = new FileSystemSSTable(outputDir.resolve("na-1-big-Data.db"), false, null);
+ FileSystemSSTable ssTable2 = new FileSystemSSTable(outputDir.resolve("na-2-big-Data.db"), false, null);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable1)).thenReturn(summary1);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable2)).thenReturn(summary2);
+ SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ ssTableLister.includeDirectory(outputDir);
+ List<SSTableLister.SSTableFilesAndRange> sstables = new ArrayList<>();
+ // 10196 is the total size of files in /data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0
+ // If this assertion fails, the contents of the folder have likely changed.
+ assertEquals(10196, ssTableLister.totalSize());
+ while (!ssTableLister.isEmpty())
+ {
+ sstables.add(ssTableLister.consumeOne());
+ }
+ assertEquals(2, sstables.size());
+ Set<String> ssTablePrefixes = sstables.stream()
+ .map(sstable -> sstable.summary.sstableId)
+ .collect(Collectors.toSet());
+
+ assertTrue(ssTablePrefixes.contains("na-1-big-"));
+ assertTrue(ssTablePrefixes.contains("na-2-big-"));
+
+ Set<Path> range1Files = sstables.get(0).files;
+ Set<Path> range2Files = sstables.get(1).files;
+
+ assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Data.db")));
+ assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Index.db")));
+ assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Summary.db")));
+ assertTrue(range1Files.contains(outputDir.resolve("na-1-big-Statistics.db")));
+ assertTrue(range1Files.contains(outputDir.resolve("na-1-big-TOC.txt")));
+
+ assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Data.db")));
+ assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Index.db")));
+ assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Summary.db")));
+ assertTrue(range2Files.contains(outputDir.resolve("na-2-big-Statistics.db")));
+ assertTrue(range2Files.contains(outputDir.resolve("na-2-big-TOC.txt")));
+ }
+
+ @Test
+ void testEmptyDir() throws IOException
+ {
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ CassandraBridge bridge = mock(CassandraBridge.class);
+ SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ ssTableLister.includeDirectory(tempDir.path());
+ assertNull(ssTableLister.peek());
+ assertNull(ssTableLister.consumeOne());
+ assertTrue(ssTableLister.isEmpty());
+ }
+ }
+}
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundlerTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundlerTest.java
new file mode 100644
index 0000000..f673f11
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/blobupload/SSTablesBundlerTest.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.blobupload;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.net.URISyntaxException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.UUID;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.commons.io.FileUtils;
+import org.junit.jupiter.api.Test;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.cassandra.bridge.CassandraBridge;
+import org.apache.cassandra.bridge.SSTableSummary;
+import org.apache.cassandra.spark.bulkwriter.util.IOUtils;
+import org.apache.cassandra.spark.data.FileSystemSSTable;
+import org.apache.cassandra.spark.data.QualifiedTableName;
+import org.apache.cassandra.spark.utils.TemporaryDirectory;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class SSTablesBundlerTest
+{
+ private final String jobId = UUID.randomUUID().toString();
+ private final String sessionId = "1-" + UUID.randomUUID();
+
+ @Test
+ void testNumberOfBundlesGenerated() throws IOException, URISyntaxException
+ {
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId, sessionId);
+
+ Path sourceDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
+ Path outputDir = tempDir.path();
+ FileUtils.copyDirectory(sourceDir.toFile(), outputDir.toFile());
+
+ CassandraBridge bridge = mockCassandraBridge(outputDir);
+ SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ SSTablesBundler ssTablesBundler = new SSTablesBundler(outputDir, ssTableLister, nameGenerator, 5 * 1024);
+ ssTablesBundler.includeDirectory(outputDir);
+ ssTablesBundler.finish();
+ List<Bundle> bundles = ImmutableList.copyOf(ssTablesBundler);
+ assertEquals(2, bundles.size());
+
+ Path bundle0 = outputDir.resolve("0");
+ Path bundle1 = outputDir.resolve("1");
+ assertTrue(Files.exists(bundle0) && Files.isDirectory(bundle0));
+ assertTrue(Files.exists(bundle1) && Files.isDirectory(bundle1));
+ String expectedZippedBundlePath1 = "b_" + jobId + "_" + sessionId + "_1_3";
+ String expectedZippedBundlePath2 = "e_" + jobId + "_" + sessionId + "_4_6";
+ assertTrue(Files.exists(outputDir.resolve(expectedZippedBundlePath1)));
+ assertTrue(Files.exists(outputDir.resolve(expectedZippedBundlePath2)));
+ }
+ }
+
+ @Test
+ void testManifestWritten() throws IOException, URISyntaxException
+ {
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId, sessionId);
+
+ Path sourceDir = Paths.get(getClass().getResource("/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0").toURI());
+ Path outputDir = tempDir.path();
+ FileUtils.copyDirectory(sourceDir.toFile(), outputDir.toFile());
+
+ CassandraBridge bridge = mockCassandraBridge(outputDir);
+ SSTableLister writerOutputAnalyzer = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ SSTablesBundler ssTablesBundler = new SSTablesBundler(outputDir, writerOutputAnalyzer, nameGenerator, 5 * 1024);
+ ssTablesBundler.includeDirectory(outputDir);
+ ssTablesBundler.finish();
+ // drain the iterator to compute and write out all bundles
+ while (ssTablesBundler.hasNext())
+ {
+ ssTablesBundler.next();
+ }
+
+ String expectedBundle0Manifest = "{\n" +
+ " \"na-1-big-\" : {\n" +
+ " \"components_checksum\" : {\n" +
+ " \"na-1-big-Summary.db\" : \"e2c32c23\",\n" +
+ " \"na-1-big-TOC.txt\" : \"7c8ef1f5\",\n" +
+ " \"na-1-big-Filter.db\" : \"72fc4f9c\",\n" +
+ " \"na-1-big-Index.db\" : \"ee128018\",\n" +
+ " \"na-1-big-Data.db\" : \"f48b39a3\",\n" +
+ " \"na-1-big-Statistics.db\" : \"f773fcc6\"\n" +
+ " },\n" +
+ " \"start_token\" : 1,\n" +
+ " \"end_token\" : 3\n" +
+ " }\n" +
+ "}";
+ String expectedBundle1Manifest = "{\n"
+ + " \"na-2-big-\" : {\n"
+ + " \"components_checksum\" : {\n"
+ + " \"na-2-big-Filter.db\" : \"72fc4f9c\",\n"
+ + " \"na-2-big-TOC.txt\" : \"7c8ef1f5\",\n"
+ + " \"na-2-big-Index.db\" : \"ee128018\",\n"
+ + " \"na-2-big-Data.db\" : \"f48b39a3\",\n"
+ + " \"na-2-big-Summary.db\" : \"e2c32c23\",\n"
+ + " \"na-2-big-Statistics.db\" : \"f773fcc6\"\n"
+ + " },\n"
+ + " \"start_token\" : 4,\n"
+ + " \"end_token\" : 6\n"
+ + " }\n"
+ + "}";
+ Path bundle0Manifest = outputDir.resolve("0").resolve("manifest.json");
+ Path bundle1Manifest = outputDir.resolve("1").resolve("manifest.json");
+ assertTrue(Files.exists(bundle0Manifest));
+ assertTrue(Files.exists(bundle1Manifest));
+ ObjectMapper mapper = new ObjectMapper();
+ Map actualBundle0 = mapper.readValue(bundle0Manifest.toFile(), Map.class);
+ Map expectedBundle0 = mapper.readValue(expectedBundle0Manifest, Map.class);
+ assertEquals(expectedBundle0, actualBundle0);
+ Map actualBundle1 = mapper.readValue(bundle1Manifest.toFile(), Map.class);
+ Map expectedBundle1 = mapper.readValue(expectedBundle1Manifest, Map.class);
+ assertEquals(expectedBundle1, actualBundle1);
+ }
+ }
+
+ @Test
+ void testChecksumComputedForEmptyFile() throws IOException
+ {
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ Path empty = Files.createFile(tempDir.path().resolve("empty"));
+ assertEquals("2cc5d05", IOUtils.xxhash32(empty));
+ }
+ }
+
+ @Test
+ void testEmptyOutputDir() throws IOException
+ {
+ try (TemporaryDirectory tempDir = new TemporaryDirectory())
+ {
+ BundleNameGenerator nameGenerator = new BundleNameGenerator(jobId, sessionId);
+
+ Path outputDir = tempDir.path();
+ CassandraBridge bridge = mockCassandraBridge(outputDir);
+ SSTableLister ssTableLister = new SSTableLister(new QualifiedTableName("ks", "table1"), bridge);
+ SSTablesBundler ssTablesBundler = new SSTablesBundler(outputDir, ssTableLister, nameGenerator, 200);
+ assertThrows(NoSuchElementException.class, ssTablesBundler::next);
+ }
+ }
+
+ private CassandraBridge mockCassandraBridge(Path outputDir)
+ {
+ CassandraBridge bridge = mock(CassandraBridge.class);
+
+ SSTableSummary summary1 = new SSTableSummary(BigInteger.valueOf(1L), BigInteger.valueOf(3L), "na-1-big-");
+ SSTableSummary summary2 = new SSTableSummary(BigInteger.valueOf(4L), BigInteger.valueOf(6L), "na-2-big-");
+
+ FileSystemSSTable ssTable1 = new FileSystemSSTable(outputDir.resolve("na-1-big-Data.db"), false, null);
+ FileSystemSSTable ssTable2 = new FileSystemSSTable(outputDir.resolve("na-2-big-Data.db"), false, null);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable1)).thenReturn(summary1);
+ when(bridge.getSSTableSummary("ks", "table1", ssTable2)).thenReturn(summary2);
+ return bridge;
+ }
+}
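The 5 * 1024-byte limit against the ~10196 bytes of test SSTables is what forces two bundles. The grouping the assertions imply is a greedy, whole-SSTable packing; a sketch follows (the size accessor on SSTableFilesAndRange is an assumption — its constructor takes (summary, files, size)):

    import java.util.ArrayList;
    import java.util.List;

    // Greedy packing sketch: SSTables are never split across bundles, so a new
    // bundle starts once adding the next SSTable would exceed maxSizePerBundle.
    static List<List<SSTableFilesAndRange>> group(List<SSTableFilesAndRange> sstables,
                                                  long maxSizePerBundle)
    {
        List<List<SSTableFilesAndRange>> bundles = new ArrayList<>();
        List<SSTableFilesAndRange> current = new ArrayList<>();
        long currentSize = 0;
        for (SSTableFilesAndRange sstable : sstables)
        {
            long size = sstable.size; // assumed field name
            if (!current.isEmpty() && currentSize + size > maxSizePerBundle)
            {
                bundles.add(current);
                current = new ArrayList<>();
                currentSize = 0;
            }
            current.add(sstable);
            currentSize += size;
        }
        if (!current.isEmpty())
        {
            bundles.add(current);
        }
        return bundles;
    }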
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/token/BulkWriterConsistencyLevelTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/token/BulkWriterConsistencyLevelTest.java
new file mode 100644
index 0000000..858934a
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/bulkwriter/token/BulkWriterConsistencyLevelTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.bulkwriter.token;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.cassandra.spark.common.model.CassandraInstance;
+import org.apache.cassandra.spark.data.ReplicationFactor;
+
+import static org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel.CL;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class BulkWriterConsistencyLevelTest
+{
+ private static final ReplicationFactor replicationFactor = new ReplicationFactor(ImmutableMap.of(
+ "class", "NetworkTopologyStrategy",
+ "dc1", "3"));
+
+ private static List<CassandraInstance> succeededNone = Collections.emptyList();
+ private static List<CassandraInstance> succeededOne;
+ private static List<CassandraInstance> succeededTwo;
+ private static List<CassandraInstance> succeededThree;
+
+ private static Set<String> zero = Collections.emptySet();
+ private static Set<String> one = intToSet(1);
+ private static Set<String> two = intToSet(2);
+ private static Set<String> three = intToSet(3);
+
+ @BeforeAll
+ static void setup()
+ {
+ CassandraInstance i1 = mockInstance("dc1");
+ CassandraInstance i2 = mockInstance("dc1");
+ CassandraInstance i3 = mockInstance("dc1");
+ succeededOne = Arrays.asList(i1);
+ succeededTwo = Arrays.asList(i1, i2);
+ succeededThree = Arrays.asList(i1, i2, i3);
+ }
+
+ @Test
+ void testCanBeSatisfiedReturnsTrue()
+ {
+ testCanBeSatisfied(CL.ONE, succeededOne, true);
+ testCanBeSatisfied(CL.ONE, succeededTwo, true);
+ testCanBeSatisfied(CL.ONE, succeededThree, true);
+
+ testCanBeSatisfied(CL.TWO, succeededTwo, true);
+ testCanBeSatisfied(CL.TWO, succeededThree, true);
+
+ testCanBeSatisfied(CL.LOCAL_ONE, succeededOne, true);
+ testCanBeSatisfied(CL.LOCAL_ONE, succeededTwo, true);
+ testCanBeSatisfied(CL.LOCAL_ONE, succeededThree, true);
+
+ testCanBeSatisfied(CL.LOCAL_QUORUM, succeededTwo, true);
+ testCanBeSatisfied(CL.LOCAL_QUORUM, succeededThree, true);
+
+ testCanBeSatisfied(CL.EACH_QUORUM, succeededTwo, true);
+ testCanBeSatisfied(CL.EACH_QUORUM, succeededThree, true);
+
+ testCanBeSatisfied(CL.QUORUM, succeededTwo, true);
+ testCanBeSatisfied(CL.QUORUM, succeededThree, true);
+
+ testCanBeSatisfied(CL.ALL, succeededThree, true);
+ }
+
+ @Test
+ void testCanBeSatisfiedReturnsFalse()
+ {
+ testCanBeSatisfied(CL.ONE, succeededNone, false);
+
+ testCanBeSatisfied(CL.TWO, succeededNone, false);
+ testCanBeSatisfied(CL.TWO, succeededOne, false);
+
+ testCanBeSatisfied(CL.LOCAL_ONE, succeededNone, false);
+
+ testCanBeSatisfied(CL.LOCAL_QUORUM, succeededNone, false);
+ testCanBeSatisfied(CL.LOCAL_QUORUM, succeededOne, false);
+
+ testCanBeSatisfied(CL.EACH_QUORUM, succeededNone, false);
+ testCanBeSatisfied(CL.EACH_QUORUM, succeededOne, false);
+
+ testCanBeSatisfied(CL.QUORUM, succeededNone, false);
+ testCanBeSatisfied(CL.QUORUM, succeededOne, false);
+
+ testCanBeSatisfied(CL.ALL, succeededNone, false);
+ testCanBeSatisfied(CL.ALL, succeededOne, false);
+ testCanBeSatisfied(CL.ALL, succeededTwo, false);
+ }
+
+ @Test
+ void testCheckConsistencyReturnsTrue()
+ {
+ testCheckConsistency(CL.ONE, /* total */ three, /* failed */ zero, zero, true);
+ testCheckConsistency(CL.ONE, /* total */ three, /* failed */ one, zero, true);
+ testCheckConsistency(CL.ONE, /* total */ three, /* failed */ two, zero, true);
+
+ testCheckConsistency(CL.TWO, /* total */ three, /* failed */ zero, zero, true);
+ testCheckConsistency(CL.TWO, /* total */ three, /* failed */ one, zero, true);
+
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ zero, /* pending */ zero, true);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ zero, /* pending */ one, true);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ zero, /* pending */ two, true);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ one, /* pending */ one, true);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ two, /* pending */ zero, true);
+
+ testCheckConsistency(CL.LOCAL_QUORUM, /* total */ three, /* failed */ zero, zero, true);
+ testCheckConsistency(CL.LOCAL_QUORUM, /* total */ three, /* failed */ one, zero, true);
+
+ testCheckConsistency(CL.EACH_QUORUM, /* total */ three, /* failed */ zero, zero, true);
+ testCheckConsistency(CL.EACH_QUORUM, /* total */ three, /* failed */ one, zero, true);
+
+ testCheckConsistency(CL.QUORUM, /* total */ three, /* failed */ zero, zero, true);
+ testCheckConsistency(CL.QUORUM, /* total */ three, /* failed */ one, zero, true);
+
+ testCheckConsistency(CL.ALL, /* total */ three, /* failed */ zero, zero, true);
+ }
+
+ @Test
+ void testCheckConsistencyReturnsFalse()
+ {
+ testCheckConsistency(CL.ONE, /* total */ three, /* failed */ three, zero, false);
+
+ testCheckConsistency(CL.TWO, /* total */ three, /* failed */ three, zero, false);
+ testCheckConsistency(CL.TWO, /* total */ three, /* failed */ two, zero, false);
+
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ three, /* pending */ zero, false);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ two, /* pending */ one, false);
+ testCheckConsistency(CL.LOCAL_ONE, /* total */ three, /* failed */ one, /* pending */ two, false);
+
+ testCheckConsistency(CL.LOCAL_QUORUM, /* total */ three, /* failed */ three, zero, false);
+ testCheckConsistency(CL.LOCAL_QUORUM, /* total */ three, /* failed */ two, zero, false);
+
+ testCheckConsistency(CL.EACH_QUORUM, /* total */ three, /* failed */ three, zero, false);
+ testCheckConsistency(CL.EACH_QUORUM, /* total */ three, /* failed */ two, zero, false);
+
+ testCheckConsistency(CL.QUORUM, /* total */ three, /* failed */ three, zero, false);
+ testCheckConsistency(CL.QUORUM, /* total */ three, /* failed */ two, zero, false);
+
+ testCheckConsistency(CL.ALL, /* total */ three, /* failed */ one, zero, false);
+ testCheckConsistency(CL.ALL, /* total */ three, /* failed */ two, zero, false);
+ testCheckConsistency(CL.ALL, /* total */ three, /* failed */ three, zero, false);
+ }
+
+ private void testCanBeSatisfied(ConsistencyLevel cl, List<CassandraInstance> succeeded, boolean expectedResult)
+ {
+ assertThat(cl.canBeSatisfied(succeeded, replicationFactor, "dc1")).isEqualTo(expectedResult);
+ }
+
+ private void testCheckConsistency(ConsistencyLevel cl, Set<String> total, Set<String> failed, Set<String> pending, boolean expectedResult)
+ {
+ assertThat(cl.checkConsistency(total, pending, zero, // replacement instances: not used here
+ zero, // blocking instances are folded into the failed set
+ failed, "dc1")).isEqualTo(expectedResult);
+ }
+
+ private static CassandraInstance mockInstance(String dc)
+ {
+ CassandraInstance i = mock(CassandraInstance.class);
+ when(i.datacenter()).thenReturn(dc);
+ return i;
+ }
+
+ private static Set<String> intToSet(int i)
+ {
+ return IntStream.range(0, i).mapToObj(String::valueOf).collect(Collectors.toSet());
+ }
+}
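The failure matrix above reduces to standard quorum arithmetic for RF = 3 in dc1. A sketch of the math the cases encode (not the production checkConsistency logic):

    // With NetworkTopologyStrategy dc1=3: quorum = 3 / 2 + 1 = 2, so LOCAL_QUORUM,
    // QUORUM and EACH_QUORUM tolerate one failed replica, TWO needs two successes,
    // ONE needs one, and ALL tolerates none.
    static boolean quorumSatisfied(int replicationFactor, int succeeded)
    {
        int quorum = replicationFactor / 2 + 1;
        return succeeded >= quorum;
    }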
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/BuildInfoTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/BuildInfoTest.java
index ccac01e..35f820a 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/BuildInfoTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/BuildInfoTest.java
@@ -40,6 +40,9 @@
{
assertTrue(BuildInfo.WRITER_USER_AGENT.endsWith(" writer"));
assertNotEquals("unknown", BuildInfo.getBuildVersion());
+
+ assertTrue(BuildInfo.WRITER_S3_USER_AGENT.endsWith(" writer-s3"));
+ assertNotEquals("unknown", BuildInfo.getBuildVersion());
}
@Test
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/IOUtilsTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/IOUtilsTest.java
new file mode 100644
index 0000000..63f6d35
--- /dev/null
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/IOUtilsTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.spark.utils;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.zip.ZipInputStream;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import org.apache.cassandra.spark.bulkwriter.util.IOUtils;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class IOUtilsTest
+{
+ @TempDir
+ private Path tempFolder;
+
+ @Test
+ void testZip() throws Exception
+ {
+ File zipSourceDir = Files.createDirectories(tempFolder.resolve("zipSource")).toFile();
+ int expectedFileCount = 10;
+ for (int i = 0; i < expectedFileCount; i++)
+ {
+ new File(zipSourceDir, Integer.toString(i)).createNewFile();
+ }
+ File targetZip = tempFolder.resolve("zip").toFile();
+ long zipFileSize = IOUtils.zip(zipSourceDir.toPath(), targetZip.toPath());
+ assertTrue(targetZip.exists());
+ assertTrue(zipFileSize > 0);
+
+ ZipInputStream zis = new ZipInputStream(new FileInputStream(targetZip));
+ int actualFilesCount = 0;
+ while (zis.getNextEntry() != null)
+ {
+ actualFilesCount++;
+ }
+ assertEquals(expectedFileCount, actualFilesCount);
+ }
+
+ @Test
+ void testZipFailsOnInvalidInput()
+ {
+ Path file = tempFolder.resolve("file");
+ IOException thrown = assertThrows(IOException.class,
+ () -> IOUtils.zip(file, file));
+ assertTrue(thrown.getMessage().contains("Not a directory"));
+ }
+
+ @Test
+ void testChecksumCalculationShouldBeDeterministic() throws Exception
+ {
+ Path file = tempFolder.resolve("file");
+ Files.write(file, "Hello World!".getBytes(StandardCharsets.UTF_8));
+ String checksum1 = IOUtils.xxhash32(file);
+ String checksum2 = IOUtils.xxhash32(file);
+ assertEquals(checksum1, checksum2,
+ "Deterministic checksum calculation should yield same result for same input");
+ assertEquals("bd69788", checksum1);
+
+ Path anotherFile = tempFolder.resolve("anotherFile");
+ Files.write(anotherFile, "Hello World!".getBytes(StandardCharsets.UTF_8));
+ String checksum3 = IOUtils.xxhash32(anotherFile);
+ assertEquals(checksum1, checksum3, "Checksum should be same for the same content");
+ }
+
+ @Test
+ void testChecksumShouldBeDifferentForDifferentContent() throws Exception
+ {
+ Path file1 = tempFolder.resolve("file1");
+ Path file2 = tempFolder.resolve("file2");
+ Files.write(file1, "I am in file 1".getBytes(StandardCharsets.UTF_8));
+ Files.write(file2, "File 2 is where you find me".getBytes(StandardCharsets.UTF_8));
+ String checksum1 = IOUtils.xxhash32(file1);
+ String checksum2 = IOUtils.xxhash32(file2);
+ assertNotEquals(checksum1, checksum2);
+ assertEquals("a6a6a5ba", checksum1);
+ assertEquals("9e9b9db5", checksum2);
+ }
+}
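The expected digests pin down the parameters: XXH32 of empty input with seed 0 is 0x02CC5D05, matching "2cc5d05" above, so the digest is a seed-0 xxhash32 rendered with Integer.toHexString (no zero padding, hence 7-character values). A sketch of an equivalent computation using lz4-java — an illustration, not the patch's IOUtils:

    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;

    import net.jpountz.xxhash.StreamingXXHash32;
    import net.jpountz.xxhash.XXHashFactory;

    // Streams the file through a seed-0 XXH32 and hex-encodes the 32-bit value.
    static String xxhash32Hex(Path file) throws IOException
    {
        StreamingXXHash32 hash = XXHashFactory.fastestInstance().newStreamingHash32(0);
        try (InputStream in = Files.newInputStream(file))
        {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = in.read(buffer)) != -1)
            {
                hash.update(buffer, 0, read);
            }
        }
        return Integer.toHexString(hash.getValue());
    }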
diff --git a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/XXHash32DigestAlgorithmTest.java b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/XXHash32DigestAlgorithmTest.java
index 8903cc3..0c136df 100644
--- a/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/XXHash32DigestAlgorithmTest.java
+++ b/cassandra-analytics-core/src/test/java/org/apache/cassandra/spark/utils/XXHash32DigestAlgorithmTest.java
@@ -49,7 +49,7 @@
// $ xxh32sum file2.txt # -> ef976cbe
// $ xxh32sum file3.txt # -> 08321e1e
- @ParameterizedTest(name = "{index} fileName={0} expectedMd5={1}")
+ @ParameterizedTest(name = "{index} fileName={0} expectedDigest={1}")
@CsvSource({
"file1.txt,d76a44a5",
"file2.txt,ef976cbe",
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Data.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Data.db
new file mode 100644
index 0000000..fe53589
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Data.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Filter.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Filter.db
new file mode 100644
index 0000000..8868e5c
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Filter.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Index.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Index.db
new file mode 100644
index 0000000..b3094bf
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Index.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Statistics.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Statistics.db
new file mode 100644
index 0000000..f9940d0
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Statistics.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Summary.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Summary.db
new file mode 100644
index 0000000..9b24e04
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-Summary.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-TOC.txt b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-TOC.txt
new file mode 100644
index 0000000..6ea912e
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-1-big-TOC.txt
@@ -0,0 +1,8 @@
+TOC.txt
+Data.db
+Statistics.db
+Summary.db
+Filter.db
+Digest.crc32
+Index.db
+CompressionInfo.db
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Data.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Data.db
new file mode 100644
index 0000000..fe53589
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Data.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Filter.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Filter.db
new file mode 100644
index 0000000..8868e5c
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Filter.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Index.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Index.db
new file mode 100644
index 0000000..b3094bf
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Index.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Statistics.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Statistics.db
new file mode 100644
index 0000000..f9940d0
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Statistics.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Summary.db b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Summary.db
new file mode 100644
index 0000000..9b24e04
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-Summary.db
Binary files differ
diff --git a/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-TOC.txt b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-TOC.txt
new file mode 100644
index 0000000..6ea912e
--- /dev/null
+++ b/cassandra-analytics-core/src/test/resources/data/ks/table1-ea3b3e6b-0d78-4913-89f2-15fcf98711d0/na-2-big-TOC.txt
@@ -0,0 +1,8 @@
+TOC.txt
+Data.db
+Statistics.db
+Summary.db
+Filter.db
+Digest.crc32
+Index.db
+CompressionInfo.db
diff --git a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
index e216018..3257338 100644
--- a/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
+++ b/cassandra-analytics-integration-framework/src/main/java/org/apache/cassandra/sidecar/testing/SharedClusterIntegrationTestBase.java
@@ -22,10 +22,12 @@
import java.io.IOException;
import java.net.BindException;
import java.net.InetSocketAddress;
+import java.net.URI;
import java.net.UnknownHostException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -43,6 +45,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.codahale.metrics.MetricRegistry;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Session;
@@ -74,14 +77,19 @@
import org.apache.cassandra.sidecar.common.utils.SidecarVersionProvider;
import org.apache.cassandra.sidecar.config.JmxConfiguration;
import org.apache.cassandra.sidecar.config.KeyStoreConfiguration;
+import org.apache.cassandra.sidecar.config.S3ClientConfiguration;
+import org.apache.cassandra.sidecar.config.S3ProxyConfiguration;
import org.apache.cassandra.sidecar.config.ServiceConfiguration;
import org.apache.cassandra.sidecar.config.SidecarConfiguration;
import org.apache.cassandra.sidecar.config.SslConfiguration;
import org.apache.cassandra.sidecar.config.yaml.KeyStoreConfigurationImpl;
+import org.apache.cassandra.sidecar.config.yaml.S3ClientConfigurationImpl;
+import org.apache.cassandra.sidecar.config.yaml.SchemaKeyspaceConfigurationImpl;
import org.apache.cassandra.sidecar.config.yaml.ServiceConfigurationImpl;
import org.apache.cassandra.sidecar.config.yaml.SidecarConfigurationImpl;
import org.apache.cassandra.sidecar.config.yaml.SslConfigurationImpl;
import org.apache.cassandra.sidecar.exceptions.ThrowableUtils;
+import org.apache.cassandra.sidecar.metrics.instance.InstanceHealthMetrics;
import org.apache.cassandra.sidecar.server.MainModule;
import org.apache.cassandra.sidecar.server.Server;
import org.apache.cassandra.sidecar.utils.CassandraVersionProvider;
@@ -134,6 +142,8 @@
@ExtendWith(VertxExtension.class)
public abstract class SharedClusterIntegrationTestBase
{
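+ // Key for sidecarAdditionalOptions() used to point the Sidecar's S3 client at a test S3 endpoint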
+ public static final String SIDECAR_S3_ENDPOINT_OVERRIDE_OPT = "S3_ENDPOINT_OVERRIDE";
+
protected final Logger logger = LoggerFactory.getLogger(SharedClusterIntegrationTestBase.class);
private static final int MAX_CLUSTER_PROVISION_RETRIES = 5;
@@ -309,6 +319,14 @@
}
/**
+ * Override to provide additional options used to configure the Sidecar.
+ *
+ * @return a map of additional Sidecar configuration options
+ */
+ protected Map<String, String> sidecarAdditionalOptions()
+ {
+ return Collections.emptyMap();
+ }
+
+ /**
* Starts Sidecar configured to run against the provided Cassandra {@code cluster}.
*
* @param cluster the cluster to use
@@ -317,7 +335,7 @@
protected void startSidecar(ICluster<? extends IInstance> cluster) throws InterruptedException
{
VertxTestContext context = new VertxTestContext();
- AbstractModule testModule = new IntegrationTestModule(cluster, classLoaderWrapper, mtlsTestHelper);
+ AbstractModule testModule = new IntegrationTestModule(cluster, classLoaderWrapper, mtlsTestHelper, sidecarAdditionalOptions());
injector = Guice.createInjector(Modules.override(new MainModule()).with(testModule));
dnsResolver = injector.getInstance(DnsResolver.class);
vertx = injector.getInstance(Vertx.class);
@@ -471,14 +489,26 @@
private final ICluster<? extends IInstance> cluster;
private final IsolatedDTestClassLoaderWrapper wrapper;
private final MtlsTestHelper mtlsTestHelper;
+ private final Map<String, String> additionalOptions;
IntegrationTestModule(ICluster<? extends IInstance> cluster,
IsolatedDTestClassLoaderWrapper wrapper,
- MtlsTestHelper mtlsTestHelper)
+ MtlsTestHelper mtlsTestHelper,
+ Map<String, String> additionalOptions)
{
this.cluster = cluster;
this.wrapper = wrapper;
this.mtlsTestHelper = mtlsTestHelper;
+ this.additionalOptions = additionalOptions;
+ }
+
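+ // Expose the CQL session provider as a singleton binding so the delegate factory below
+ // receives the shared instance through injection instead of constructing its own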
+ @Provides
+ @Singleton
+ public CQLSessionProvider cqlSessionProvider()
+ {
+ List<InetSocketAddress> contactPoints = buildContactPoints();
+ return new TemporaryCqlSessionProvider(contactPoints,
+ SharedExecutorNettyOptions.INSTANCE);
}
@Provides
@@ -487,14 +517,10 @@
SidecarConfiguration configuration,
CassandraVersionProvider cassandraVersionProvider,
SidecarVersionProvider sidecarVersionProvider,
+ CQLSessionProvider cqlSessionProvider,
DnsResolver dnsResolver)
{
JmxConfiguration jmxConfiguration = configuration.serviceConfiguration().jmxConfiguration();
-
- List<InetSocketAddress> contactPoints = buildContactPoints();
- CQLSessionProvider cqlSessionProvider = new TemporaryCqlSessionProvider(contactPoints,
- SharedExecutorNettyOptions.INSTANCE);
-
List<InstanceMetadata> instanceMetadataList =
IntStream.rangeClosed(1, cluster.size())
.mapToObj(i -> buildInstanceMetadata(vertx,
@@ -516,6 +542,9 @@
ServiceConfiguration conf = ServiceConfigurationImpl.builder()
.host("0.0.0.0") // binds to all interfaces, potential security issue if left running for long
.port(0) // let the test find an available port
+ .schemaKeyspaceConfiguration(SchemaKeyspaceConfigurationImpl.builder()
+ .isEnabled(true)
+ .build())
.build();
@@ -547,9 +576,11 @@
LOGGER.info("Not enabling mTLS for testing purposes. Set '{}' to 'true' if you would " +
"like mTLS enabled.", CASSANDRA_INTEGRATION_TEST_ENABLE_MTLS);
}
+ S3ClientConfiguration s3ClientConfig = new S3ClientConfigurationImpl("s3-client", 4, 60L, buildTestS3ProxyConfig());
return SidecarConfigurationImpl.builder()
.serviceConfiguration(conf)
.sslConfiguration(sslConfiguration)
+ .s3ClientConfiguration(s3ClientConfig)
.build();
}
@@ -568,6 +599,36 @@
.collect(Collectors.toList());
}
+ private S3ProxyConfiguration buildTestS3ProxyConfig()
+ {
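+ // No proxy, username, or password; only the endpoint is overridden to point at the
+ // local S3 mock server (or the value provided via SIDECAR_S3_ENDPOINT_OVERRIDE_OPT)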
+ return new S3ProxyConfiguration()
+ {
+ @Override
+ public URI proxy()
+ {
+ return null;
+ }
+
+ @Override
+ public String username()
+ {
+ return null;
+ }
+
+ @Override
+ public String password()
+ {
+ return null;
+ }
+
+ @Override
+ public URI endpointOverride()
+ {
+ return URI.create(additionalOptions.getOrDefault(SIDECAR_S3_ENDPOINT_OVERRIDE_OPT, "http://localhost:9090"));
+ }
+ };
+ }
+
static int tryGetIntConfig(IInstanceConfig config, String configName, int defaultValue)
{
try
@@ -610,6 +671,7 @@
.port(config.jmxPort())
.connectionMaxRetries(jmxConfiguration.maxRetries())
.connectionRetryDelayMillis(jmxConfiguration.retryDelayMillis()));
+ MetricRegistry metricRegistry = new MetricRegistry();
CassandraAdapterDelegate delegate = new CassandraAdapterDelegate(vertx,
config.num(),
versionProvider,
@@ -618,7 +680,8 @@
new DriverUtils(),
sidecarVersion,
ipAddress,
- port);
+ port,
+ new InstanceHealthMetrics(metricRegistry));
return InstanceMetadataImpl.builder()
.id(config.num())
.host(hostName)
@@ -626,7 +689,7 @@
.dataDirs(Arrays.asList(dataDirectories))
.stagingDir(stagingDir)
.delegate(delegate)
- .globalMetricRegistryName("test")
+ .metricRegistry(metricRegistry)
.build();
}
diff --git a/cassandra-analytics-integration-tests/build.gradle b/cassandra-analytics-integration-tests/build.gradle
index a6af9f9..d1728c3 100644
--- a/cassandra-analytics-integration-tests/build.gradle
+++ b/cassandra-analytics-integration-tests/build.gradle
@@ -70,6 +70,13 @@
testImplementation(group: "${sparkGroupId}", name: "spark-sql_${scalaMajorVersion}", version: "${sparkVersion}")
testImplementation(project(path: ":cassandra-analytics-integration-framework"))
+
+ testImplementation 'software.amazon.awssdk:s3'
+ testImplementation 'software.amazon.awssdk:netty-nio-client'
+ testImplementation platform(group: 'software.amazon.awssdk', name: 'bom', version: "${project.awsSdkVersion}")
+ testImplementation 'com.adobe.testing:s3mock-testcontainers:2.17.0' // 3.x versions do not support Java 11
+
+ testRuntimeOnly 'com.fasterxml.jackson.core:jackson-annotations:2.14.2'
}
test {
@@ -117,6 +124,15 @@
println("JVM arguments for $project.name are $allJvmArgs")
}
+ // container tests do not run on Java 8
+ if ("true" == System.getenv("skipContainerTest") || JavaVersion.current().isJava8()) {
+ exclude("**/testcontainer/**")
+
+ filter {
+ setFailOnNoMatchingTests(false) // do not fail the build if no matching tests are found, since CI runs individual test classes
+ }
+ }
+
useJUnitPlatform()
def destDir = Paths.get(rootProject.rootDir.absolutePath, "build", "test-reports", "integration").toFile()
reports {
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/BulkWriteS3CompatModeSimpleTest.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/BulkWriteS3CompatModeSimpleTest.java
new file mode 100644
index 0000000..da15e0e
--- /dev/null
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/BulkWriteS3CompatModeSimpleTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.analytics.testcontainer;
+
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.jupiter.api.Test;
+
+import com.adobe.testing.s3mock.testcontainers.S3MockContainer;
+import org.apache.cassandra.analytics.DataGenerationUtils;
+import org.apache.cassandra.analytics.SharedClusterSparkIntegrationTestBase;
+import org.apache.cassandra.sidecar.testing.QualifiedName;
+import org.apache.cassandra.testing.ClusterBuilderConfiguration;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+
+import static org.apache.cassandra.testing.TestUtils.CREATE_TEST_TABLE_STATEMENT;
+import static org.apache.cassandra.testing.TestUtils.DC1_RF3;
+import static org.apache.cassandra.testing.TestUtils.ROW_COUNT;
+import static org.apache.cassandra.testing.TestUtils.TEST_KEYSPACE;
+import static org.assertj.core.api.Assertions.assertThat;
+
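+/**
+ * Tests bulk writes using the S3_COMPAT transport, with an S3Mock container standing in
+ * for the S3-compatible storage system.
+ */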
+class BulkWriteS3CompatModeSimpleTest extends SharedClusterSparkIntegrationTestBase
+{
+ public static final String BUCKET_NAME = "sbw-bucket";
+ private static final QualifiedName TABLE_NAME = new QualifiedName(TEST_KEYSPACE, BulkWriteS3CompatModeSimpleTest.class.getSimpleName().toLowerCase());
+ private S3MockContainer s3Mock;
+
+ @Override
+ protected void afterClusterProvisioned()
+ {
+ // s3Mock must be started before the sidecar, so the actual S3 server port can be passed to the sidecar at startup
+ super.afterClusterProvisioned();
+ s3Mock = new S3MockContainer("2.17.0")
+ .withInitialBuckets(BUCKET_NAME);
+ s3Mock.start();
+ assertThat(s3Mock.isRunning()).isTrue();
+ }
+
+ @Override
+ protected void afterClusterShutdown()
+ {
+ s3Mock.stop();
+ }
+
+ @Override
+ protected ClusterBuilderConfiguration testClusterConfiguration()
+ {
+ return super.testClusterConfiguration()
+ .nodesPerDc(3);
+ }
+
+ @Override
+ protected void initializeSchemaForTest()
+ {
+ createTestKeyspace(TEST_KEYSPACE, DC1_RF3);
+ createTestTable(TABLE_NAME, CREATE_TEST_TABLE_STATEMENT);
+ }
+
+ @Override
+ protected Map<String, String> sidecarAdditionalOptions()
+ {
+ return ImmutableMap.of(SIDECAR_S3_ENDPOINT_OVERRIDE_OPT, s3Mock.getHttpEndpoint());
+ }
+
+ /**
+ * Writes data using the S3_COMPAT transport and reads it back using CQL,
+ * asserting that all written data is read back.
+ */
+ @Test
+ void testS3CompatBulkWrite()
+ {
+ SparkSession spark = getOrCreateSparkSession();
+ Dataset<Row> df = DataGenerationUtils.generateCourseData(spark, ROW_COUNT);
+ Map<String, String> s3CompatOptions = ImmutableMap.of(
+ "data_transport", "S3_COMPAT",
+ "data_transport_extension_class", LocalStorageTransportExtension.class.getCanonicalName(),
+ "storage_client_endpoint_override", s3Mock.getHttpEndpoint() // point to s3Mock server
+ );
+ bulkWriterDataFrameWriter(df, TABLE_NAME, s3CompatOptions).save();
+ sparkTestUtils.validateWrites(df.collectAsList(), queryAllData(TABLE_NAME));
+ }
+}
diff --git a/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/LocalStorageTransportExtension.java b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/LocalStorageTransportExtension.java
new file mode 100644
index 0000000..9ab92d9
--- /dev/null
+++ b/cassandra-analytics-integration-tests/src/test/java/org/apache/cassandra/analytics/testcontainer/LocalStorageTransportExtension.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.analytics.testcontainer;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.cassandra.spark.transports.storage.StorageCredentialPair;
+import org.apache.cassandra.spark.transports.storage.StorageCredentials;
+import org.apache.cassandra.spark.transports.storage.extensions.CredentialChangeListener;
+import org.apache.cassandra.spark.transports.storage.extensions.ObjectFailureListener;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportConfiguration;
+import org.apache.cassandra.spark.transports.storage.extensions.StorageTransportExtension;
+import org.apache.spark.SparkConf;
+
+import static org.apache.cassandra.analytics.testcontainer.BulkWriteS3CompatModeSimpleTest.BUCKET_NAME;
+
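+/**
+ * A {@link StorageTransportExtension} for testing that points the job at the local S3 mock
+ * bucket and supplies static placeholder credentials.
+ */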
+public class LocalStorageTransportExtension implements StorageTransportExtension
+{
+ @Override
+ public void initialize(String jobId, SparkConf conf, boolean isOnDriver)
+ {
+ }
+
+ @Override
+ public StorageTransportConfiguration getStorageConfiguration()
+ {
+ return new StorageTransportConfiguration(BUCKET_NAME,
+ "us-west-1",
+ BUCKET_NAME,
+ "eu-west-1",
+ "key-prefix",
+ generateTokens(),
+ ImmutableMap.of());
+ }
+
+ @Override
+ public void onTransportStart(long elapsedMillis)
+ {
+ }
+
+ @Override
+ public void setCredentialChangeListener(CredentialChangeListener credentialChangeListener)
+ {
+ }
+
+ @Override
+ public void setObjectFailureListener(ObjectFailureListener objectFailureListener)
+ {
+ }
+
+ @Override
+ public void onObjectPersisted(String bucket, String key, long sizeInBytes)
+ {
+ }
+
+ @Override
+ public void onAllObjectsPersisted(long objectsCount, long rowCount, long elapsedMillis)
+ {
+ }
+
+ @Override
+ public void onObjectApplied(String bucket, String key, long sizeInBytes, long elapsedMillis)
+ {
+ }
+
+ @Override
+ public void onJobSucceeded(long elapsedMillis)
+ {
+ }
+
+ @Override
+ public void onJobFailed(long elapsedMillis, Throwable throwable)
+ {
+ }
+
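+ // Static placeholder credentials; assumed sufficient here because the S3 mock used by the
+ // test does not validate them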
+ private StorageCredentialPair generateTokens()
+ {
+ return new StorageCredentialPair(new StorageCredentials("writeKey",
+ "writeSecret",
+ "writeSessionToken"),
+ new StorageCredentials("readKey",
+ "readSecret",
+ "readSessionToken"));
+ }
+}
diff --git a/cassandra-analytics-integration-tests/src/test/resources/logback-test.xml b/cassandra-analytics-integration-tests/src/test/resources/logback-test.xml
index f1265f8..14a4cda 100644
--- a/cassandra-analytics-integration-tests/src/test/resources/logback-test.xml
+++ b/cassandra-analytics-integration-tests/src/test/resources/logback-test.xml
@@ -38,7 +38,7 @@
<root level="INFO">
<appender-ref ref="STDOUT" />
</root>
- <root level="DEBUG">
+ <root level="INFO">
<appender-ref ref="FILE" />
</root>
diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java
index b234065..79e8480 100644
--- a/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java
+++ b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridge.java
@@ -45,6 +45,8 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import javax.annotation.Nullable;
+import javax.validation.constraints.NotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
@@ -54,6 +56,7 @@
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.ReplicationFactor;
+import org.apache.cassandra.spark.data.SSTable;
import org.apache.cassandra.spark.data.SSTablesSupplier;
import org.apache.cassandra.spark.data.partitioner.Partitioner;
import org.apache.cassandra.spark.reader.IndexEntry;
@@ -64,8 +67,6 @@
import org.apache.cassandra.spark.sparksql.filters.SparkRangeFilter;
import org.apache.cassandra.spark.stats.Stats;
import org.apache.cassandra.spark.utils.TimeProvider;
-import org.jetbrains.annotations.NotNull;
-import org.jetbrains.annotations.Nullable;
/**
* Provides an abstract interface for all calls to the Cassandra code of a specific version
@@ -391,6 +392,10 @@
Set<String> userDefinedTypeStatements,
int bufferSizeMB);
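+
+ /**
+ * Reads the summary of the given SSTable: the first and last tokens it covers and its id.
+ */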
+ public abstract SSTableSummary getSSTableSummary(@NotNull String keyspace,
+ @NotNull String table,
+ @NotNull SSTable ssTable);
+
public interface IRow
{
Object get(int position);
diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/SSTableSummary.java b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/SSTableSummary.java
new file mode 100644
index 0000000..6a2dd25
--- /dev/null
+++ b/cassandra-bridge/src/main/java/org/apache/cassandra/bridge/SSTableSummary.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.cassandra.bridge;
+
+import java.math.BigInteger;
+
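+/**
+ * Summary of an SSTable: the first and last tokens it covers and its id
+ * (the common filename prefix shared by the SSTable's components).
+ */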
+public class SSTableSummary
+{
+ public final BigInteger firstToken;
+ public final BigInteger lastToken;
+ public final String sstableId;
+
+ public SSTableSummary(BigInteger firstToken, BigInteger lastToken, String sstableId)
+ {
+ this.firstToken = firstToken;
+ this.lastToken = lastToken;
+ this.sstableId = sstableId;
+ }
+}
diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/FutureUtils.java b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/FutureUtils.java
index d9a8d6e..2e7d3bd 100644
--- a/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/FutureUtils.java
+++ b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/FutureUtils.java
@@ -20,6 +20,7 @@
package org.apache.cassandra.spark.utils;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
@@ -78,19 +79,19 @@
* Await all futures and combine into single result
*
* @param <T> result type returned by this method
- * @param futures list of futures
+ * @param futures collection of futures
* @param acceptPartialResult if false, fail the entire request if a single failure occurs, if true just log partial failures
- * @param logger consumer to log errors
+ * @param onFailure consumer of errors
* @return result of all combined futures
*/
- public static <T> List<T> awaitAll(List<CompletableFuture<T>> futures,
+ public static <T> List<T> awaitAll(Collection<CompletableFuture<T>> futures,
boolean acceptPartialResult,
- Consumer<Throwable> logger)
+ Consumer<Throwable> onFailure)
{
- List<T> result = new ArrayList<>(futures.size() * 10); // TODO: Comment on why allocate tenfold
+ List<T> result = new ArrayList<>(futures.size());
for (CompletableFuture<T> future : futures)
{
- FutureResult<T> futureResult = await(future, logger);
+ FutureResult<T> futureResult = await(future, onFailure);
if (futureResult.throwable != null)
{
// Failed
diff --git a/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/MapUtils.java b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/MapUtils.java
index b2fa27c..e899b10 100644
--- a/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/MapUtils.java
+++ b/cassandra-bridge/src/main/java/org/apache/cassandra/spark/utils/MapUtils.java
@@ -189,20 +189,23 @@
return value != null ? Long.parseLong(value) : defaultValue;
}
- public static <T extends Enum<T>> T getEnumOption(Map<String, String> options, String key, T defaultValue)
- {
- return getEnumOption(options, key, defaultValue, null);
- }
-
- @SuppressWarnings("unchecked")
- public static <T extends Enum<T>> T getEnumOption(Map<String, String> options,
- String key, T defaultValue,
- String displayName)
+ /**
+ * Returns the enum variant for the given {@code key}; the enum class is derived from the
+ * {@code defaultValue}, which is returned when the lookup misses.
+ *
+ * @param <T> the enum type
+ * @param options the options map
+ * @param key the key to look up
+ * @param defaultValue the default value
+ * @param displayName an optional name to display in the error message
+ * @return the enum variant, or the default value if the lookup misses
+ */
+ public static <T extends Enum<T>> T getEnumOption(Map<String, String> options, String key, T defaultValue, String displayName)
{
String value = options.get(lowerCaseKey(key));
try
{
- return value != null ? (T) Enum.valueOf(defaultValue.getDeclaringClass(), value) : defaultValue;
+ return value != null ? Enum.valueOf(defaultValue.getDeclaringClass(), value) : defaultValue;
}
catch (IllegalArgumentException exception)
{
@@ -219,7 +222,7 @@
* @param options the map
* @param key the key to the map
* @param defaultValue the default value
- * @return the long value
+ * @return the String value, or the default value if the key is absent
*/
public static String getOrDefault(Map<String, String> options, String key, String defaultValue)
{
diff --git a/cassandra-bridge/src/main/scala-2.12-spark-3/org/apache/cassandra/spark/utils/ScalaConversionUtils.java b/cassandra-bridge/src/main/scala-2.12-spark-3/org/apache/cassandra/spark/utils/ScalaConversionUtils.java
index 23eabb3..1987332 100644
--- a/cassandra-bridge/src/main/scala-2.12-spark-3/org/apache/cassandra/spark/utils/ScalaConversionUtils.java
+++ b/cassandra-bridge/src/main/scala-2.12-spark-3/org/apache/cassandra/spark/utils/ScalaConversionUtils.java
@@ -36,21 +36,21 @@
public static <A> java.lang.Iterable<A> asJavaIterable(scala.collection.Iterable<A> iterable)
{
- return JavaConverters.asJavaIterable(iterable);
+ return JavaConverters.<A>asJavaIterable(iterable);
}
public static <A> scala.collection.Iterator<A> asScalaIterator(java.util.Iterator<A> iterator)
{
- return JavaConverters.asScalaIterator(iterator);
+ return JavaConverters.<A>asScalaIterator(iterator);
}
public static <A, B> java.util.Map<A, B> mapAsJavaMap(scala.collection.Map<A, B> map)
{
- return JavaConverters.mapAsJavaMap(map);
+ return JavaConverters.<A, B>mapAsJavaMap(map);
}
public static <A> List<A> mutableSeqAsJavaList(Seq<A> seq)
{
- return JavaConverters.mutableSeqAsJavaList(seq);
+ return JavaConverters.<A>mutableSeqAsJavaList(seq);
}
}
diff --git a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java
index 0aa576e..364b91b 100644
--- a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java
+++ b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/bridge/CassandraBridgeImplementation.java
@@ -53,6 +53,7 @@
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.config.ParameterizedClass;
import org.apache.cassandra.cql3.ColumnIdentifier;
+import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.commitlog.CommitLogSegmentManagerStandard;
import org.apache.cassandra.db.marshal.AbstractType;
@@ -70,12 +71,14 @@
import org.apache.cassandra.io.sstable.SSTableTombstoneWriter;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.locator.SimpleSnitch;
+import org.apache.cassandra.schema.Schema;
import org.apache.cassandra.schema.TableMetadata;
import org.apache.cassandra.schema.TableMetadataRef;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.CqlType;
import org.apache.cassandra.spark.data.ReplicationFactor;
+import org.apache.cassandra.spark.data.SSTable;
import org.apache.cassandra.spark.data.SSTablesSupplier;
import org.apache.cassandra.spark.data.complex.CqlCollection;
import org.apache.cassandra.spark.data.complex.CqlFrozen;
@@ -109,9 +112,11 @@
import org.apache.cassandra.spark.reader.CompactionStreamScanner;
import org.apache.cassandra.spark.reader.IndexEntry;
import org.apache.cassandra.spark.reader.IndexReader;
+import org.apache.cassandra.spark.reader.ReaderUtils;
import org.apache.cassandra.spark.reader.Rid;
import org.apache.cassandra.spark.reader.SchemaBuilder;
import org.apache.cassandra.spark.reader.StreamScanner;
+import org.apache.cassandra.spark.reader.SummaryDbUtils;
import org.apache.cassandra.spark.reader.common.IndexIterator;
import org.apache.cassandra.spark.sparksql.filters.PartitionKeyFilter;
import org.apache.cassandra.spark.sparksql.filters.PruneColumnFilter;
@@ -121,6 +126,7 @@
import org.apache.cassandra.spark.utils.TimeProvider;
import org.apache.cassandra.tools.JsonTransformer;
import org.apache.cassandra.tools.Util;
+import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.UUIDGen;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -599,6 +605,45 @@
userDefinedTypeStatements, bufferSizeMB);
}
+ @Override
+ public SSTableSummary getSSTableSummary(@NotNull String keyspace,
+ @NotNull String table,
+ @NotNull SSTable ssTable)
+ {
+ TableMetadata metadata = Schema.instance.getTableMetadata(keyspace, table);
+ if (metadata == null)
+ {
+ throw new RuntimeException("Could not create table metadata needed for reading SSTable summaries for keyspace: " + keyspace);
+ }
+ try
+ {
+ SummaryDbUtils.Summary summary = SummaryDbUtils.readSummary(metadata, ssTable);
+ Pair<DecoratedKey, DecoratedKey> keys = Pair.create(summary.first(), summary.last());
+ if (keys.left == null || keys.right == null)
+ {
+ keys = ReaderUtils.keysFromIndex(metadata, ssTable);
+ }
+ if (keys.left == null || keys.right == null)
+ {
+ throw new RuntimeException("Could not load SSTable first or last tokens for SSTable: " + ssTable.getDataFileName());
+ }
+ DecoratedKey first = keys.left;
+ DecoratedKey last = keys.right;
+ BigInteger firstToken = ReaderUtils.tokenToBigInteger(first.getToken());
+ BigInteger lastToken = ReaderUtils.tokenToBigInteger(last.getToken());
+ return new SSTableSummary(firstToken, lastToken, getSSTablePrefix(ssTable.getDataFileName()));
+ }
+ catch (IOException exception)
+ {
+ throw new RuntimeException(exception);
+ }
+ }
+
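+ // The prefix is the data file name up to and including the last '-',
+ // e.g. "na-1-big-Data.db" -> "na-1-big-"; it is shared by all components of the same SSTable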
+ private String getSSTablePrefix(String dataFileName)
+ {
+ return dataFileName.substring(0, dataFileName.lastIndexOf('-') + 1);
+ }
+
// Version-Specific Test Utility Methods
@Override
diff --git a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/ReaderUtils.java b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/ReaderUtils.java
index e7bbe81..2f887c6 100644
--- a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/ReaderUtils.java
+++ b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/ReaderUtils.java
@@ -80,10 +80,10 @@
{
private static final int CHECKSUM_LENGTH = 4; // CRC32
private static final Constructor<?> SERIALIZATION_HEADER =
- Arrays.stream(SerializationHeader.Component.class.getDeclaredConstructors())
- .filter(constructor -> constructor.getParameterCount() == 5)
- .findFirst()
- .orElseThrow(() -> new RuntimeException("Could not find SerializationHeader.Component constructor"));
+ Arrays.stream(SerializationHeader.Component.class.getDeclaredConstructors())
+ .filter(constructor -> constructor.getParameterCount() == 5)
+ .findFirst()
+ .orElseThrow(() -> new RuntimeException("Could not find SerializationHeader.Component constructor"));
public static final ByteBuffer SUPER_COLUMN_MAP_COLUMN = ByteBufferUtil.EMPTY_BYTE_BUFFER;
static
@@ -145,7 +145,7 @@
// We use comparator.size() rather than clustering.size() because of static clusterings
int clusteringSize = metadata.comparator.size();
int size = clusteringSize + (TableMetadata.Flag.isDense(metadata.flags) ? 0 : 1)
- + (collectionElement == null ? 0 : 1);
+ + (collectionElement == null ? 0 : 1);
if (TableMetadata.Flag.isSuper(metadata.flags))
{
size = clusteringSize + 1;
@@ -196,8 +196,8 @@
return CompositeType.build(ByteBufferAccessor.instance, isStatic, values);
}
- static Pair<DecoratedKey, DecoratedKey> keysFromIndex(@NotNull TableMetadata metadata,
- @NotNull SSTable ssTable) throws IOException
+ public static Pair<DecoratedKey, DecoratedKey> keysFromIndex(@NotNull TableMetadata metadata,
+ @NotNull SSTable ssTable) throws IOException
{
try (InputStream primaryIndex = ssTable.openPrimaryIndexStream())
{
@@ -513,10 +513,10 @@
}
static List<PartitionKeyFilter> filterKeyInBloomFilter(
- @NotNull SSTable ssTable,
- @NotNull IPartitioner partitioner,
- Descriptor descriptor,
- @NotNull List<PartitionKeyFilter> partitionKeyFilters) throws IOException
+ @NotNull SSTable ssTable,
+ @NotNull IPartitioner partitioner,
+ Descriptor descriptor,
+ @NotNull List<PartitionKeyFilter> partitionKeyFilters) throws IOException
{
try
{
diff --git a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/SummaryDbUtils.java b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/SummaryDbUtils.java
index a667dcc..0a2df42 100644
--- a/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/SummaryDbUtils.java
+++ b/cassandra-four-zero-bridge/src/main/java/org/apache/cassandra/spark/reader/SummaryDbUtils.java
@@ -36,9 +36,9 @@
/**
* Helper methods for reading the Summary.db SSTable file component
*/
-final class SummaryDbUtils
+public final class SummaryDbUtils
{
- static class Summary
+ public static class Summary
{
private final IndexSummary indexSummary;
private final DecoratedKey firstKey;
@@ -74,7 +74,7 @@
throw new IllegalStateException(getClass() + " is static utility class and shall not be instantiated");
}
- static Summary readSummary(@NotNull TableMetadata metadata, @NotNull SSTable ssTable) throws IOException
+ public static Summary readSummary(@NotNull TableMetadata metadata, @NotNull SSTable ssTable) throws IOException
{
try (InputStream in = ssTable.openSummaryStream())
{
diff --git a/gradle.properties b/gradle.properties
index 8d64afc..0989924 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -31,5 +31,6 @@
# force version 4.5.1 of vertx to prevent issues initializing io.vertx.core.json.jackson.JacksonCodec,
# which requires a newer version of jackson, which is not available in spark 2
vertxVersion=4.5.1
+awsSdkVersion=2.25.31
org.gradle.jvmargs=-Xmx4096m
diff --git a/scripts/build-sidecar.sh b/scripts/build-sidecar.sh
index 56dc5df..09a59ec 100755
--- a/scripts/build-sidecar.sh
+++ b/scripts/build-sidecar.sh
@@ -24,7 +24,7 @@
SCRIPT_DIR=$( dirname -- "$( readlink -f -- "$0"; )"; )
SIDECAR_REPO="${SIDECAR_REPO:-https://github.com/apache/cassandra-sidecar.git}"
SIDECAR_BRANCH="${SIDECAR_BRANCH:-trunk}"
- SIDECAR_COMMIT="${SIDECAR_COMMIT:-20795db4d708b9287e0a2281695923bfb6fa9138}"
+ SIDECAR_COMMIT="${SIDECAR_COMMIT:-fd6f7ac5f9f19dbbeeb9e7f80ca1fcbf60d5a4c6}"
SIDECAR_JAR_DIR="$(dirname "${SCRIPT_DIR}/")/dependencies"
SIDECAR_JAR_DIR=${CASSANDRA_DEP_DIR:-$SIDECAR_JAR_DIR}
SIDECAR_BUILD_DIR="${SIDECAR_JAR_DIR}/sidecar-build"
@@ -47,7 +47,9 @@
cd "${SIDECAR_BUILD_DIR}"
echo "branch ${SIDECAR_BRANCH} sha ${SIDECAR_COMMIT}"
# check out the correct cassandra version:
+ # if the SIDECAR_BRANCH directory does not exist, we initialize it from scratch
if [ ! -d "${SIDECAR_BRANCH}" ] ; then
+ # if SIDECAR_COMMIT is defined, we fetch the specified commit
if [ -n "${SIDECAR_COMMIT}" ] ; then
mkdir -p "${SIDECAR_BRANCH}"
cd "${SIDECAR_BRANCH}"
@@ -55,14 +57,17 @@
git remote add upstream "${SIDECAR_REPO}"
git fetch --depth=1 upstream "${SIDECAR_COMMIT}"
git reset --hard FETCH_HEAD
- else
+ else # otherwise we clone the branch
git clone --depth 1 --single-branch --branch "${SIDECAR_BRANCH}" "${SIDECAR_REPO}" "${SIDECAR_BRANCH}"
cd "${SIDECAR_BRANCH}"
fi
- else
+ else # the SIDECAR_BRANCH directory already exists; this is a delta build
cd "${SIDECAR_BRANCH}"
+ # no SIDECAR_COMMIT defined; pull any new commits on the branch
if [ -z "${SIDECAR_COMMIT}" ] ; then
- git pull
+ git pull upstream "${SIDECAR_BRANCH}"
+ else # we pull the specified commit from upstream
+ git pull upstream "${SIDECAR_COMMIT}"
fi
fi
if [ -z "${SIDECAR_COMMIT}" ] ; then