/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cassandra.spark.example;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.spark.KryoRegister;
import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.BinaryType;
import static org.apache.spark.sql.types.DataTypes.LongType;
/**
* Example showcasing the Cassandra Spark Analytics write and read capabilities
* <p>
* Prepare your environment by creating the following keyspace and table
* <p>
* Schema for the {@code keyspace}:
* <pre>
* CREATE KEYSPACE spark_test WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': '3'}
* AND durable_writes = true;
* </pre>
* <p>
* Schema for the {@code table}:
* <pre>
* CREATE TABLE spark_test.test (
* id BIGINT PRIMARY KEY,
* course BLOB,
* marks BIGINT
* );
* </pre>
*/
public final class SampleCassandraJob
{
private static final Logger LOGGER = LoggerFactory.getLogger(SampleCassandraJob.class);
private SampleCassandraJob()
{
throw new IllegalStateException(getClass() + " is a static utility class and shall not be instantiated");
}
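/**
 * Runs the sample job end to end: configures Spark for the bulk writer, writes the generated
 * rows to the {@code spark_test.test} table, reads them back and checks that the two dataframes match.
 *
 * @param args command-line arguments; logged but otherwise unused
 */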
public static void main(String[] args)
{
LOGGER.info("Starting Spark job with args={}", Arrays.toString(args));
SparkConf sparkConf = new SparkConf().setAppName("Sample Spark Cassandra Bulk Reader Job")
.set("spark.master", "local[8]");
// Add SBW-specific settings
// TODO: Simplify setting up spark conf
BulkSparkConf.setupSparkConf(sparkConf, true);
KryoRegister.setup(sparkConf);
SparkSession spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate();
SparkContext sc = spark.sparkContext();
SQLContext sql = spark.sqlContext();
LOGGER.info("Spark Conf: " + sparkConf.toDebugString());
int rowCount = 10_000;
try
{
Dataset<Row> written = write(rowCount, sparkConf, sql, sc);
Dataset<Row> read = read(rowCount, sparkConf, sql, sc);
checkSmallDataFrameEquality(written, read);
LOGGER.info("Finished Spark job, shutting down...");
sc.stop();
}
catch (Throwable throwable)
{
LOGGER.error("Unexpected exception executing Spark job", throwable);
try
{
sc.stop();
}
catch (Throwable ignored)
{
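// Best-effort shutdown; the original failure has already been logged above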
}
}
}
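/**
 * Generates {@code rowCount} rows matching the {@code spark_test.test} schema and writes them to
 * Cassandra using the bulk writer data source.
 *
 * @return the DataFrame that was written, for later comparison against the data read back
 */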
private static Dataset<Row> write(long rowCount, SparkConf sparkConf, SQLContext sql, SparkContext sc)
{
StructType schema = new StructType()
.add("id", LongType, false)
.add("course", BinaryType, false)
.add("marks", LongType, false);
JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sc);
int parallelism = sc.defaultParallelism();
JavaRDD<Row> rows = genDataset(javaSparkContext, rowCount, parallelism);
Dataset<Row> df = sql.createDataFrame(rows, schema);
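// Write the DataFrame to Cassandra through the Analytics bulk writer (CassandraDataSink).
// The sidecar_instances option lists the Cassandra Sidecar hosts for the target cluster, and
// bulk_writer_cl is the consistency level the bulk write must satisfy.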
df.write()
.format("org.apache.cassandra.spark.sparksql.CassandraDataSink")
.option("sidecar_instances", "localhost,localhost2,localhost3")
.option("keyspace", "spark_test")
.option("table", "test")
.option("local_dc", "datacenter1")
.option("bulk_writer_cl", "LOCAL_QUORUM")
.option("number_splits", "-1")
.mode("append")
.save();
return df;
}
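/**
 * Reads the {@code spark_test.test} table back through the bulk reader data source.
 *
 * @return the DataFrame that was read, or {@code null} if the row count does not match
 *         {@code expectedRowCount}
 */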
private static Dataset<Row> read(int expectedRowCount, SparkConf sparkConf, SQLContext sql, SparkContext sc)
{
int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors", sparkConf.getInt("spark.executor.instances", 1));
int numCores = coresPerExecutor * numExecutors;
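// Read the rows back through the Analytics bulk reader (CassandraDataSource). A fresh snapshot
// named after a random UUID is created on the Cassandra hosts, and the computed core count is
// used to size the read parallelism.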
Dataset<Row> df = sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource")
.option("sidecar_instances", "localhost,localhost2,localhost3")
.option("keyspace", "spark_test")
.option("table", "test")
.option("DC", "datacenter1")
.option("snapshotName", UUID.randomUUID().toString())
.option("createSnapshot", "true")
.option("defaultParallelism", sc.defaultParallelism())
.option("numCores", numCores)
.option("sizing", "default")
.load();
long count = df.count();
LOGGER.info("Found {} records", count);
if (count != expectedRowCount)
{
LOGGER.error("Expected {} records but found {} records", expectedRowCount, count);
return null;
}
return df;
}
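/**
 * Throws if the data read back contains rows that were not written, or if the read dataframe
 * is {@code null}. Intended for the small sample dataset produced by this job.
 */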
private static void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual)
{
if (actual == null)
{
throw new NullPointerException("actual dataframe is null");
}
if (!actual.exceptAll(expected).isEmpty())
{
throw new IllegalStateException("The content of the dataframes differs");
}
}
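/**
 * Generates the sample dataset as a {@link JavaRDD} of {@link Row}s, splitting the requested
 * number of records across {@code parallelism} partitions. The record number serves as both the
 * {@code id} and {@code marks} values, and is repeated to build the {@code course} blob.
 */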
private static JavaRDD<Row> genDataset(JavaSparkContext sc, long records, Integer parallelism)
{
long recordsPerPartition = records / parallelism;
long remainder = records - (recordsPerPartition * parallelism);
List<Integer> seq = IntStream.range(0, parallelism).boxed().collect(Collectors.toList());
JavaRDD<Row> dataset = sc.parallelize(seq, parallelism).mapPartitionsWithIndex(
(Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, integerIterator) -> {
long firstRecordNumber = index * recordsPerPartition;
// The last partition also generates the remainder so that exactly `records` rows are produced in total
long recordsToGenerate = index.equals(parallelism - 1) ? recordsPerPartition + remainder : recordsPerPartition;
Iterator<Row> rows = LongStream.range(0, recordsToGenerate).mapToObj(offset -> {
long recordNumber = firstRecordNumber + offset;
String courseNameString = String.valueOf(recordNumber);
Integer courseNameStringLen = courseNameString.length();
Integer courseNameMultiplier = 1000 / courseNameStringLen;
byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier);
return RowFactory.create(recordNumber, courseName, recordNumber);
}).iterator();
return rows;
}, false);
return dataset;
}
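/**
 * Returns the bytes of {@code string} repeated {@code times} times; used to build the
 * {@code course} blob payload for each generated row.
 */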
private static byte[] dupStringAsBytes(String string, Integer times)
{
byte[] stringBytes = string.getBytes();
ByteBuffer buffer = ByteBuffer.allocate(stringBytes.length * times);
for (int time = 0; time < times; time++)
{
buffer.put(stringBytes);
}
return buffer.array();
}
}