/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.spark;

import com.google.common.io.Files;
import junit.framework.TestCase;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.zeppelin.interpreter.Interpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterOutput;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResultMessage;
import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.python.IPythonInterpreterTest;
import org.junit.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
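
/**
 * Tests for {@link IPySparkInterpreter}. Inherits the IPython test cases from
 * {@link IPythonInterpreterTest} and adds Spark-specific checks: RDD actions, Spark SQL via
 * df.show() and z.show(), paragraph cancellation, and code completion. The assertions are
 * factored into the public static {@code testPySpark} so other suites can reuse them.
 */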
public class IPySparkInterpreterTest extends IPythonInterpreterTest {

  private RemoteInterpreterEventClient mockIntpEventClient =
      mock(RemoteInterpreterEventClient.class);
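
  // Local-mode Spark configuration; the Python gateway is bound to localhost and the
  // deprecated-version message is suppressed so it does not leak into asserted paragraph output.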
  @Override
  protected Properties initIntpProperties() {
    Properties p = new Properties();
    p.setProperty("spark.master", "local[4]");
    p.setProperty("master", "local[4]");
    p.setProperty("spark.submit.deployMode", "client");
    p.setProperty("spark.app.name", "Zeppelin Test");
    p.setProperty("zeppelin.spark.useHiveContext", "false");
    p.setProperty("zeppelin.spark.maxResult", "3");
    p.setProperty("zeppelin.spark.importImplicit", "true");
    p.setProperty("zeppelin.pyspark.python", "python");
    p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath());
    p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1");
    p.setProperty("zeppelin.spark.deprecatedMsg.show", "false");
    return p;
  }
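
  // IPySparkInterpreter looks up the SparkInterpreter and PySparkInterpreter in its own
  // interpreter group when it opens, so all three interpreters are registered in the same
  // session before interpreter.open() is called.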
  @Override
  protected void startInterpreter(Properties properties) throws InterpreterException {
    InterpreterContext context = getInterpreterContext();
    context.setIntpEventClient(mockIntpEventClient);
    InterpreterContext.set(context);

    LazyOpenInterpreter sparkInterpreter = new LazyOpenInterpreter(
        new SparkInterpreter(properties));
    intpGroup = new InterpreterGroup();
    intpGroup.put("session_1", new ArrayList<Interpreter>());
    intpGroup.get("session_1").add(sparkInterpreter);
    sparkInterpreter.setInterpreterGroup(intpGroup);

    LazyOpenInterpreter pySparkInterpreter =
        new LazyOpenInterpreter(new PySparkInterpreter(properties));
    intpGroup.get("session_1").add(pySparkInterpreter);
    pySparkInterpreter.setInterpreterGroup(intpGroup);

    interpreter = new LazyOpenInterpreter(new IPySparkInterpreter(properties));
    intpGroup.get("session_1").add(interpreter);
    interpreter.setInterpreterGroup(intpGroup);

    interpreter.open();
  }

  @Override
  public void tearDown() throws InterpreterException {
    intpGroup.close();
    interpreter = null;
    intpGroup = null;
  }
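
  // @Test is commented out, so this entry point does not run in this suite; the shared
  // assertions live in the public static testPySpark(...) below, which other interpreter
  // test suites can also invoke.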
  //@Test
  public void testIPySpark() throws InterruptedException, InterpreterException, IOException {
    testPySpark(interpreter, mockIntpEventClient);
  }

  public static void testPySpark(final Interpreter interpreter,
                                 RemoteInterpreterEventClient mockIntpEventClient)
      throws InterpreterException, IOException, InterruptedException {
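    // reset clears interactions recorded while the interpreters started up, so the
    // verify(...) call below only sees events emitted by this test. The short sleeps after
    // interpret(...) give the asynchronously flushed paragraph output time to arrive.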
    reset(mockIntpEventClient);

    // rdd
    InterpreterContext context = createInterpreterContext(mockIntpEventClient);
    InterpreterResult result = interpreter.interpret("sc.version", context);
    Thread.sleep(100);
    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
    String sparkVersion = context.out.toInterpreterResultMessage().get(0).getData();

    context = createInterpreterContext(mockIntpEventClient);
    result = interpreter.interpret("sc.range(1,10).sum()", context);
    Thread.sleep(100);
    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
    List<InterpreterResultMessage> interpreterResultMessages =
        context.out.toInterpreterResultMessage();
    assertEquals("45", interpreterResultMessages.get(0).getData().trim());
    // spark job url is sent
    verify(mockIntpEventClient).onParaInfosReceived(any(Map.class));

    // spark sql
    context = createInterpreterContext(mockIntpEventClient);
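    // Spark 1.x exposes sqlContext while Spark 2.x exposes the unified spark session;
    // the expected df.show() and z.show() output is the same in both branches.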
    if (!isSpark2(sparkVersion)) {
      result = interpreter.interpret(
          "df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context);
      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
      interpreterResultMessages = context.out.toInterpreterResultMessage();
      assertEquals(
          "+---+---+\n" +
          "| _1| _2|\n" +
          "+---+---+\n" +
          "|  1|  a|\n" +
          "|  2|  b|\n" +
          "+---+---+", interpreterResultMessages.get(0).getData().trim());

      context = createInterpreterContext(mockIntpEventClient);
      result = interpreter.interpret("z.show(df)", context);
      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
      interpreterResultMessages = context.out.toInterpreterResultMessage();
      assertEquals(
          "_1\t_2\n" +
          "1\ta\n" +
          "2\tb", interpreterResultMessages.get(0).getData().trim());
    } else {
      result = interpreter.interpret(
          "df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context);
      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
      interpreterResultMessages = context.out.toInterpreterResultMessage();
      assertEquals(
          "+---+---+\n" +
          "| _1| _2|\n" +
          "+---+---+\n" +
          "|  1|  a|\n" +
          "|  2|  b|\n" +
          "+---+---+", interpreterResultMessages.get(0).getData().trim());

      context = createInterpreterContext(mockIntpEventClient);
      result = interpreter.interpret("z.show(df)", context);
      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
      interpreterResultMessages = context.out.toInterpreterResultMessage();
      assertEquals(
          "_1\t_2\n" +
          "1\ta\n" +
          "2\tb", interpreterResultMessages.get(0).getData().trim());
    }

    // cancel
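    // interpret(...) blocks, so the long-running job is launched from a separate thread and
    // cancelled from here; cancelling interrupts the IPython kernel, which shows up in the
    // output as a Python KeyboardInterrupt with an ERROR result code.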
    if (interpreter instanceof IPySparkInterpreter) {
      final InterpreterContext context2 = createInterpreterContext(mockIntpEventClient);

      Thread thread = new Thread() {
        @Override
        public void run() {
          InterpreterResult result = null;
          try {
            result = interpreter.interpret(
                "import time\nsc.range(1,10).foreach(lambda x: time.sleep(1))", context2);
          } catch (InterpreterException e) {
            e.printStackTrace();
          }
          assertEquals(InterpreterResult.Code.ERROR, result.code());
          List<InterpreterResultMessage> interpreterResultMessages = null;
          try {
            interpreterResultMessages = context2.out.toInterpreterResultMessage();
            assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt"));
          } catch (IOException e) {
            e.printStackTrace();
          }
        }
      };
      thread.start();

      // sleep 1 second to wait for the spark job to start
      Thread.sleep(1000);
      interpreter.cancel(context);
      thread.join();
    }

    // completions
    List<InterpreterCompletion> completions =
        interpreter.completion("sc.ran", 6, createInterpreterContext(mockIntpEventClient));
    assertEquals(1, completions.size());
    assertEquals("range", completions.get(0).getValue());

    completions = interpreter.completion("sc.", 3, createInterpreterContext(mockIntpEventClient));
    assertTrue(completions.size() > 0);
    assertTrue(completions.contains(new InterpreterCompletion("range", "range", "")));

    completions = interpreter.completion("1+1\nsc.", 7,
        createInterpreterContext(mockIntpEventClient));
    assertTrue(completions.size() > 0);
    assertTrue(completions.contains(new InterpreterCompletion("range", "range", "")));

    completions = interpreter.completion("s", 1, createInterpreterContext(mockIntpEventClient));
    assertTrue(completions.size() > 0);
    assertTrue(completions.contains(new InterpreterCompletion("sc", "sc", "")));

    // TODO(zjffdu): pyspark streaming test is temporarily disabled
    context = createInterpreterContext(mockIntpEventClient);
    // result = interpreter.interpret(
    //     "from pyspark.streaming import StreamingContext\n" +
    //     "import time\n" +
    //     "ssc = StreamingContext(sc, 1)\n" +
    //     "rddQueue = []\n" +
    //     "for i in range(5):\n" +
    //     "  rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n" +
    //     "inputStream = ssc.queueStream(rddQueue)\n" +
    //     "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" +
    //     "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" +
    //     "reducedStream.pprint()\n" +
    //     "ssc.start()\n" +
    //     "time.sleep(6)\n" +
    //     "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context);
    // Thread.sleep(1000);
    // assertEquals(InterpreterResult.Code.SUCCESS, result.code());
    // interpreterResultMessages = context.out.toInterpreterResultMessage();
    // assertEquals(1, interpreterResultMessages.size());
    // assertTrue(interpreterResultMessages.get(0).getData().contains("(0, 100)"));
  }

  @Test
  @Override
  public void testIPythonFailToLaunch() throws InterpreterException {
    tearDown();

    Properties properties = initIntpProperties();
    properties.setProperty("spark.pyspark.python", "invalid_python");

    try {
      startInterpreter(properties);
      fail("Should not be able to start IPySparkInterpreter");
    } catch (InterpreterException e) {
      String exceptionMsg = ExceptionUtils.getStackTrace(e);
      TestCase.assertTrue(exceptionMsg, exceptionMsg.contains("No such file or directory"));
    }
  }
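
  // The interpreted "sc.version" comes back as a Python repr, so the version string arrives
  // quoted, e.g. '2.4.0' on Python 3 or u'2.4.0' on Python 2; hence the leading-quote checks.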
  private static boolean isSpark2(String sparkVersion) {
    return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2.");
  }

  private static InterpreterContext createInterpreterContext(
      RemoteInterpreterEventClient mockRemoteEventClient) {
    return InterpreterContext.builder()
        .setNoteId("noteId")
        .setParagraphId("paragraphId")
        .setIntpEventClient(mockRemoteEventClient)
        .setInterpreterOut(new InterpreterOutput(null))
        .build();
  }
}