/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.cassandra.spark.example;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import com.google.common.collect.ImmutableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.spark.KryoRegister;
import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.DataFrameWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.BinaryType;
import static org.apache.spark.sql.types.DataTypes.LongType;

/**
* Example showcasing the Cassandra Spark Analytics write and read capabilities
*
* <p>Prepare your environment by creating the following keyspace and table
*
* <p>Schema for the {@code keyspace}:
* <pre>
* CREATE KEYSPACE spark_test WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': '3'}
* AND durable_writes = true;
* </pre>
*
* <p>Schema for the {@code table}:
* <pre>
 * CREATE TABLE spark_test.test (
 *     id BIGINT PRIMARY KEY,
 *     course BLOB,
 *     marks BIGINT
 * );
* </pre>
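 *
 * <p>A concrete job extends this class and supplies the data source/sink options. Below is a
 * minimal sketch, assuming a hypothetical {@code SampleCassandraJob} in the same package and
 * placeholder option keys (the exact keys required by the data sink/source depend on your
 * deployment):
 * <pre>
 * public class SampleCassandraJob extends AbstractCassandraJob
 * {
 *     public static void main(String[] args)
 *     {
 *         new SampleCassandraJob().start(args);
 *     }
 *
 *     protected JobConfiguration configureJob(SparkContext sc, SparkConf sparkConf)
 *     {
 *         // Placeholder option keys, shown for illustration only
 *         return new JobConfiguration(
 *                 ImmutableMap.of("keyspace", "spark_test", "table", "test"),   // write options
 *                 ImmutableMap.of("keyspace", "spark_test", "table", "test"));  // read options
 *     }
 * }
 * </pre>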
*/
public abstract class AbstractCassandraJob
{
private final Logger logger = LoggerFactory.getLogger(this.getClass());
private JobConfiguration configuration;
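
    /**
     * Implemented by concrete jobs to provide the configuration for this run: the row count to
     * generate plus the write and read options passed to the Cassandra data sink/source.
     */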
protected abstract JobConfiguration configureJob(SparkContext sc, SparkConf sparkConf);
public void start(String[] args)
{
logger.info("Starting Spark job with args={}", Arrays.toString(args));
SparkConf sparkConf = new SparkConf().setAppName("Sample Spark Cassandra Bulk Reader Job")
.set("spark.master", "local[8]");
        // Add Spark Bulk Writer (SBW) specific settings
// TODO: simplify setting up spark conf
BulkSparkConf.setupSparkConf(sparkConf, true);
KryoRegister.setup(sparkConf);
SparkSession spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate();
SparkContext sc = spark.sparkContext();
SQLContext sql = spark.sqlContext();
logger.info("Spark Conf: " + sparkConf.toDebugString());
configuration = configureJob(sc, sparkConf);
long rowCount = configuration.rowCount;
try
{
Dataset<Row> written = null;
if (configuration.shouldWrite())
{
written = write(rowCount, sql, sc);
}
Dataset<Row> read = null;
if (configuration.shouldRead())
{
read = read(rowCount, sql);
}
if (configuration.shouldWrite() && configuration.shouldRead())
{
checkSmallDataFrameEquality(written, read);
}
logger.info("Finished Spark job, shutting down...");
sc.stop();
}
catch (Throwable throwable)
{
logger.error("Unexpected exception executing Spark job", throwable);
try
{
sc.stop();
}
catch (Throwable ignored)
{
}
}
}
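
    /**
     * Generates {@code rowCount} rows matching the example schema and writes them to Cassandra
     * through the {@code CassandraDataSink} format using the configured write options.
     */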
protected Dataset<Row> write(long rowCount, SQLContext sql, SparkContext sc)
{
StructType schema = new StructType()
.add("id", LongType, false)
.add("course", BinaryType, false)
.add("marks", LongType, false);
JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sc);
int parallelism = sc.defaultParallelism();
JavaRDD<Row> rows = genDataset(javaSparkContext, rowCount, parallelism);
Dataset<Row> df = sql.createDataFrame(rows, schema);
DataFrameWriter<Row> writer = df.write().format("org.apache.cassandra.spark.sparksql.CassandraDataSink");
writer.options(configuration.writeOptions);
writer.mode("append").save();
return df;
}
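
    /**
     * Reads the data back through the {@code CassandraDataSource} format using the configured
     * read options and verifies the row count, returning {@code null} on a mismatch.
     */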
protected Dataset<Row> read(long expectedRowCount, SQLContext sql)
{
DataFrameReader reader = sql.read().format("org.apache.cassandra.spark.sparksql.CassandraDataSource");
reader.options(configuration.readOptions);
Dataset<Row> df = reader.load();
long count = df.count();
logger.info("Found {} records", count);
if (count != expectedRowCount)
{
logger.error("Expected {} records but found {} records", expectedRowCount, count);
return null;
}
return df;
}
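
    // Compares the written and read dataframes; intended for small datasets only, since
    // exceptAll has to compare every row of both sides.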
private static void checkSmallDataFrameEquality(Dataset<Row> expected, Dataset<Row> actual)
{
if (actual == null)
{
throw new NullPointerException("actual dataframe is null");
}
        // Compare in both directions so extra rows on either side are also detected
        if (!actual.exceptAll(expected).isEmpty() || !expected.exceptAll(actual).isEmpty())
{
throw new IllegalStateException("The content of the dataframes differs");
}
}
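
    // Builds a synthetic RDD of (id, course, marks) rows split across `parallelism` partitions;
    // the last partition also produces any remainder rows so exactly `records` rows are generated.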
private static JavaRDD<Row> genDataset(JavaSparkContext sc, long records, Integer parallelism)
{
long recordsPerPartition = records / parallelism;
long remainder = records - (recordsPerPartition * parallelism);
List<Integer> seq = IntStream.range(0, parallelism).boxed().collect(Collectors.toList());
return sc.parallelize(seq, parallelism).mapPartitionsWithIndex(
(Function2<Integer, Iterator<Integer>, Iterator<Row>>) (index, integerIterator) -> {
long firstRecordNumber = index * recordsPerPartition;
                    // Partition indexes run from 0 to parallelism - 1; the last partition also
                    // generates the remainder rows so the total adds up to `records`
                    long recordsToGenerate = recordsPerPartition + (index == parallelism - 1 ? remainder : 0);
return LongStream.range(0, recordsToGenerate).mapToObj((offset) -> {
long i = firstRecordNumber + offset;
String courseNameString = String.valueOf(i);
int courseNameStringLen = courseNameString.length();
int courseNameMultiplier = 1000 / courseNameStringLen;
byte[] courseName = dupStringAsBytes(courseNameString, courseNameMultiplier);
return RowFactory.create(i, courseName, i);
}).iterator();
}, false);
}
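
    // Repeats the bytes of the given string the requested number of times to build the blob
    // payload stored in the course column.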
private static byte[] dupStringAsBytes(String string, Integer times)
{
        byte[] stringBytes = string.getBytes(StandardCharsets.UTF_8);
ByteBuffer buf = ByteBuffer.allocate(stringBytes.length * times);
for (int i = 0; i < times; i++)
{
buf.put(stringBytes);
}
return buf.array();
}
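
    // Holds the options for a run: a non-empty write-options map enables the write phase and a
    // non-empty read-options map enables the read phase.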
static class JobConfiguration
{
        long rowCount = 100_000;  // deliberately mutable for testing convenience
ImmutableMap<String, String> writeOptions;
ImmutableMap<String, String> readOptions;
JobConfiguration(Map<String, String> writeOptions, Map<String, String> readOptions)
{
this.writeOptions = ImmutableMap.copyOf(writeOptions);
this.readOptions = ImmutableMap.copyOf(readOptions);
}
boolean shouldWrite()
{
return !writeOptions.isEmpty();
}
boolean shouldRead()
{
return !readOptions.isEmpty();
}
}
}