// blob: 43415fa0e1bd1caf1c323a2a89ce886240ca7d2e [file] [log] [blame]
package io.pivotal.gemfire.spark.connector;
import io.pivotal.gemfire.spark.connector.javaapi.*;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SQLContext;
//import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
import org.junit.Test;
import org.scalatest.junit.JUnitSuite;
import scala.Function1;
import scala.Function2;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.mutable.LinkedList;
import scala.reflect.ClassTag;
import static org.junit.Assert.*;
import static io.pivotal.gemfire.spark.connector.javaapi.GemFireJavaUtil.*;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
/**
 * Unit tests for the Java API wrappers returned by the statically imported
 * {@code javaFunctions(...)} overloads.
 *
 * <p>All Spark and GemFire collaborators are Mockito mocks. Each test checks that the
 * wrapper holds the very same underlying object it was created from and, where a save or
 * read method is invoked, that the expected calls reach the mocks.
 */
public class JavaAPITest extends JUnitSuite {

  /** Region name passed to every region-related call in these tests. */
  private static final String REGION_PATH = "testregion";

  /**
   * Creates the mocks shared by most tests.
   *
   * <p>The connection-conf mock is stubbed so that {@code getConnection()} returns the
   * connection mock and {@code locators()} returns a freshly constructed Scala
   * {@code LinkedList}.
   *
   * @return a tuple of (mock SparkContext, mock GemFireConnectionConf, mock GemFireConnection)
   */
  @SuppressWarnings( "unchecked" )
  public Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> createCommonMocks() {
    SparkContext mockSparkContext = mock(SparkContext.class);
    GemFireConnectionConf mockConnConf = mock(GemFireConnectionConf.class);
    GemFireConnection mockConnection = mock(GemFireConnection.class);
    when(mockConnConf.getConnection()).thenReturn(mockConnection);
    // Raw type kept deliberately: the element type of locators() is not visible from Java here.
    when(mockConnConf.locators()).thenReturn(new LinkedList());
    return new Tuple3<>(mockSparkContext, mockConnConf, mockConnection);
  }

  /**
   * Wrapping a {@code SparkContext} keeps the same instance in {@code sc}, and
   * {@code gemfireRegion} validates the region on the connection.
   */
  @Test
  public void testSparkContextFunction() throws Exception {
    Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> tuple3 = createCommonMocks();
    GemFireJavaSparkContextFunctions wrapper = javaFunctions(tuple3._1());
    assertSame(tuple3._1(), wrapper.sc);
    // The assignment is intentionally kept: it makes the compiler check that the result of
    // gemfireRegion is assignable to JavaPairRDD<String, String>.
    JavaPairRDD<String, String> rdd = wrapper.gemfireRegion(REGION_PATH, tuple3._2());
    verify(tuple3._3()).validateRegion(REGION_PATH);
  }

  /**
   * Wrapping a {@code JavaSparkContext} exposes the {@code SparkContext} returned by its
   * {@code sc()} method.
   */
  @Test
  public void testJavaSparkContextFunctions() throws Exception {
    SparkContext mockSparkContext = mock(SparkContext.class);
    JavaSparkContext mockJavaSparkContext = mock(JavaSparkContext.class);
    when(mockJavaSparkContext.sc()).thenReturn(mockSparkContext);
    GemFireJavaSparkContextFunctions wrapper = javaFunctions(mockJavaSparkContext);
    assertSame(mockSparkContext, wrapper.sc);
  }

  /**
   * Wrapping a {@code JavaPairRDD} keeps the underlying RDD, and {@code saveToGemfire}
   * asks the RDD for its SparkContext once and calls {@code runJob} on it once.
   */
  @Test
  @SuppressWarnings( "unchecked" )
  public void testJavaPairRDDFunctions() throws Exception {
    JavaPairRDD<String, Integer> mockPairRDD = mock(JavaPairRDD.class);
    RDD<Tuple2<String, Integer>> mockTuple2RDD = mock(RDD.class);
    when(mockPairRDD.rdd()).thenReturn(mockTuple2RDD);
    GemFireJavaPairRDDFunctions wrapper = javaFunctions(mockPairRDD);
    assertSame(mockTuple2RDD, wrapper.rddf.rdd());

    Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> tuple3 = createCommonMocks();
    when(mockTuple2RDD.sparkContext()).thenReturn(tuple3._1());
    wrapper.saveToGemfire(REGION_PATH, tuple3._2());
    verify(mockTuple2RDD, times(1)).sparkContext();
    verify(tuple3._1(), times(1)).runJob(eq(mockTuple2RDD), any(Function2.class), any(ClassTag.class));
  }

  /**
   * Wrapping a {@code JavaRDD} keeps the underlying RDD, and {@code saveToGemfire} with a
   * pair function asks the RDD for its SparkContext once and calls {@code runJob} on it once.
   */
  @Test
  @SuppressWarnings( "unchecked" )
  public void testJavaRDDFunctions() throws Exception {
    JavaRDD<String> mockJavaRDD = mock(JavaRDD.class);
    RDD<String> mockRDD = mock(RDD.class);
    when(mockJavaRDD.rdd()).thenReturn(mockRDD);
    GemFireJavaRDDFunctions wrapper = javaFunctions(mockJavaRDD);
    assertSame(mockRDD, wrapper.rddf.rdd());

    Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> tuple3 = createCommonMocks();
    when(mockRDD.sparkContext()).thenReturn(tuple3._1());
    PairFunction<String, String, Integer> mockPairFunc = mock(PairFunction.class);
    wrapper.saveToGemfire(REGION_PATH, mockPairFunc, tuple3._2());
    verify(mockRDD, times(1)).sparkContext();
    verify(tuple3._1(), times(1)).runJob(eq(mockRDD), any(Function2.class), any(ClassTag.class));
  }

  /**
   * Wrapping a {@code JavaPairDStream} keeps the underlying DStream, and
   * {@code saveToGemfire} obtains the connection, validates the region and registers a
   * {@code foreachRDD} callback on the DStream.
   */
  @Test
  @SuppressWarnings( "unchecked" )
  public void testJavaPairDStreamFunctions() throws Exception {
    JavaPairDStream<String, String> mockJavaDStream = mock(JavaPairDStream.class);
    DStream<Tuple2<String, String>> mockDStream = mock(DStream.class);
    when(mockJavaDStream.dstream()).thenReturn(mockDStream);
    GemFireJavaPairDStreamFunctions wrapper = javaFunctions(mockJavaDStream);
    assertSame(mockDStream, wrapper.dsf.dstream());

    Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> tuple3 = createCommonMocks();
    wrapper.saveToGemfire(REGION_PATH, tuple3._2());
    verify(tuple3._2()).getConnection();
    verify(tuple3._3()).validateRegion(REGION_PATH);
    verify(mockDStream).foreachRDD(any(Function1.class));
  }

  /**
   * A {@code JavaDStream} of tuples converted with {@code toJavaPairDStream} and then
   * wrapped still refers to the original underlying DStream.
   */
  @Test
  @SuppressWarnings( "unchecked" )
  public void testJavaPairDStreamFunctionsWithTuple2DStream() throws Exception {
    JavaDStream<Tuple2<String, String>> mockJavaDStream = mock(JavaDStream.class);
    DStream<Tuple2<String, String>> mockDStream = mock(DStream.class);
    when(mockJavaDStream.dstream()).thenReturn(mockDStream);
    GemFireJavaPairDStreamFunctions wrapper = javaFunctions(toJavaPairDStream(mockJavaDStream));
    assertSame(mockDStream, wrapper.dsf.dstream());
  }

  /**
   * Wrapping a {@code JavaDStream} keeps the underlying DStream, and {@code saveToGemfire}
   * with a pair function obtains the connection, validates the region and registers a
   * {@code foreachRDD} callback on the DStream.
   */
  @Test
  @SuppressWarnings( "unchecked" )
  public void testJavaDStreamFunctions() throws Exception {
    JavaDStream<String> mockJavaDStream = mock(JavaDStream.class);
    DStream<String> mockDStream = mock(DStream.class);
    when(mockJavaDStream.dstream()).thenReturn(mockDStream);
    GemFireJavaDStreamFunctions wrapper = javaFunctions(mockJavaDStream);
    assertSame(mockDStream, wrapper.dsf.dstream());

    Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> tuple3 = createCommonMocks();
    PairFunction<String, String, Integer> mockPairFunc = mock(PairFunction.class);
    wrapper.saveToGemfire(REGION_PATH, mockPairFunc, tuple3._2());
    verify(tuple3._2()).getConnection();
    verify(tuple3._3()).validateRegion(REGION_PATH);
    verify(mockDStream).foreachRDD(any(Function1.class));
  }

  /**
   * Wrapping a {@code SQLContext} yields a wrapper whose {@code scf} field is exactly a
   * {@code GemFireSQLContextFunctions} (not a subclass).
   */
  @Test
  public void testSQLContextFunction() throws Exception {
    SQLContext mockSQLContext = mock(SQLContext.class);
    GemFireJavaSQLContextFunctions wrapper = javaFunctions(mockSQLContext);
    assertSame(GemFireSQLContextFunctions.class, wrapper.scf.getClass());
  }
}