blob: 67220b4313769ebfe47b698e3970b12e0cb95993 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.livy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import org.apache.commons.lang.StringUtils;
import org.apache.http.auth.AuthSchemeProvider;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.Credentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.AuthSchemes;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLContexts;
import org.apache.http.impl.auth.SPNegoSchemeFactory;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.zeppelin.interpreter.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.MediaType;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.kerberos.client.KerberosRestTemplate;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import javax.net.ssl.SSLContext;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.KeyStore;
import java.security.Principal;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
/**
* Base class for livy interpreters.
*/
public abstract class BaseLivyInterpreter extends Interpreter {
/** Logger shared by this class and its subclasses. */
protected static final Logger LOGGER = LoggerFactory.getLogger(BaseLivyInterpreter.class);
/** Single Gson instance used by the nested request/response classes for JSON conversion. */
private static final Gson gson =
    new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create();
/** Regex matched against a response body to recognise livy's "session not found" answer. */
private static final String SESSION_NOT_FOUND_PATTERN = "\"Session '\\d+' not found.\"";
/** Current livy session; set by initLivySession() and reset to null by close(). */
protected volatile SessionInfo sessionInfo;
/** Base URL of the livy server, read from "zeppelin.livy.url". */
private String livyURL;
/** Maximum time, in seconds, to wait for a new session to become idle. */
private int sessionCreationTimeout;
/** Sleep time, in milliseconds, between two status polls. */
private int pullStatusInterval;
/** Whether the spark application id and web UI address are appended to results. */
protected boolean displayAppInfo;
/** Version of the livy server, determined in initLivySession(). */
protected LivyVersion livyVersion;
/** REST client used for every call to the livy server. */
private RestTemplate restTemplate;
// keep tracking the mapping between paragraphId and statementId, so that we can cancel the
// statement after we execute it.
private final ConcurrentHashMap<String, Integer> paragraphId2StmtIdMapping =
    new ConcurrentHashMap<>();
public BaseLivyInterpreter(Properties property) {
super(property);
this.livyURL = property.getProperty("zeppelin.livy.url");
this.displayAppInfo = Boolean.parseBoolean(
property.getProperty("zeppelin.livy.displayAppInfo", "true"));
this.sessionCreationTimeout = Integer.parseInt(
property.getProperty("zeppelin.livy.session.create_timeout", 120 + ""));
this.pullStatusInterval = Integer.parseInt(
property.getProperty("zeppelin.livy.pull_status.interval.millis", 1000 + ""));
this.restTemplate = createRestTemplate();
}
/**
 * Returns the session kind that is sent to livy when a new session is created.
 */
public abstract String getSessionKind();
/**
 * Opens the interpreter by creating a livy session.
 *
 * @throws RuntimeException wrapping the {@code LivyException} when the session cannot be created
 */
@Override
public void open() {
  try {
    initLivySession();
  } catch (LivyException cause) {
    throw new RuntimeException(
        "Fail to create session, please check livy interpreter log and livy server log", cause);
  }
}
/**
 * Closes the current livy session, if there is one.
 */
@Override
public void close() {
  if (sessionInfo == null) {
    // nothing to close
    return;
  }
  closeSession(sessionInfo.id);
  // forget the session so that a second call does not try to close it again.
  sessionInfo = null;
}
/**
 * Creates a new livy session, optionally resolves its application id and web UI address, and
 * detects the version of the livy server.
 *
 * @throws LivyException if the session cannot be created or a REST call fails
 */
protected void initLivySession() throws LivyException {
  this.sessionInfo = createSession(getUserName(), getSessionKind());
  if (!displayAppInfo) {
    LOGGER.info("Create livy session successfully with sessionId: {}", this.sessionInfo.id);
  } else {
    if (sessionInfo.appId == null) {
      // livy 0.2 doesn't return appId and sparkUiUrl in the response, so they have to be
      // fetched explicitly.
      sessionInfo.appId = extractAppId();
    }
    final boolean hasSparkUiUrl = sessionInfo.appInfo != null
        && !StringUtils.isEmpty(sessionInfo.appInfo.get("sparkUiUrl"));
    sessionInfo.webUIAddress = hasSparkUiUrl
        ? sessionInfo.appInfo.get("sparkUiUrl")
        : extractWebUIAddress();
    LOGGER.info("Create livy session successfully with sessionId: {}, appId: {}, webUI: {}",
        sessionInfo.id, sessionInfo.appId, sessionInfo.webUIAddress);
  }

  // check livy version; a server without the version api is treated as livy 0.2.0
  try {
    this.livyVersion = getLivyVersion();
    LOGGER.info("Use livy " + livyVersion);
  } catch (APINotFoundException versionApiMissing) {
    this.livyVersion = new LivyVersion("0.2.0");
    LOGGER.info("Use livy 0.2.0");
  }
}
/**
 * Returns the application id of the current session. Called by initLivySession() when
 * displayAppInfo is enabled and the create-session response carried no appId.
 *
 * @throws LivyException if the id cannot be obtained
 */
protected abstract String extractAppId() throws LivyException;
/**
 * Returns the web UI address of the current session. Called by initLivySession() when
 * displayAppInfo is enabled and the session's appInfo has no "sparkUiUrl" entry.
 *
 * @throws LivyException if the address cannot be obtained
 */
protected abstract String extractWebUIAddress() throws LivyException;
/**
 * Returns the current livy session, or null when no session is open.
 */
public SessionInfo getSessionInfo() {
  return this.sessionInfo;
}
/**
 * Runs the given code in the livy session and converts failures into an ERROR result.
 *
 * @param st code to run; an empty or null string yields an empty SUCCESS result
 * @param context context of the paragraph that is being run
 * @return the result of the statement, or an ERROR result when a LivyException occurs
 */
@Override
public InterpreterResult interpret(String st, InterpreterContext context) {
  if (StringUtils.isEmpty(st)) {
    return new InterpreterResult(InterpreterResult.Code.SUCCESS, "");
  }

  final String paragraphId = context.getParagraphId();
  try {
    return interpret(st, paragraphId, this.displayAppInfo, true);
  } catch (LivyException failure) {
    LOGGER.error("Fail to interpret:" + st, failure);
    final String relevantMessage = InterpreterUtils.getMostRelevantMessage(failure);
    return new InterpreterResult(InterpreterResult.Code.ERROR, relevantMessage);
  }
}
/**
 * Cancels the statement that is currently registered for the paragraph of the given context.
 * Only logs a warning when the livy version does not support cancellation.
 *
 * @param context context of the paragraph to cancel
 */
@Override
public void cancel(InterpreterContext context) {
  if (!livyVersion.isCancelSupported()) {
    LOGGER.warn("cancel is not supported for this version of livy: " + livyVersion);
    return;
  }

  final String paraId = context.getParagraphId();
  final Integer stmtId = paragraphId2StmtIdMapping.get(paraId);
  try {
    // no statement registered for this paragraph means there is nothing to cancel
    if (stmtId != null) {
      cancelStatement(stmtId);
    }
  } catch (LivyException e) {
    LOGGER.error("Fail to cancel statement " + stmtId + " for paragraph " + paraId, e);
  } finally {
    paragraphId2StmtIdMapping.remove(paraId);
  }
}
/**
 * Returns the form type of this interpreter, which is always {@code FormType.NATIVE}.
 */
@Override
public FormType getFormType() {
return FormType.NATIVE;
}
/**
 * Progress reporting is not implemented here; this method always returns 0.
 */
@Override
public int getProgress(InterpreterContext context) {
return 0;
}
/**
 * Creates a livy session and polls its status until it is ready.
 *
 * <p>Every non-empty interpreter property whose key starts with "livy.spark." is forwarded as
 * session configuration, with the leading "livy." stripped from the key.
 *
 * @param user the proxy user; null or "anonymous" means that no proxy user is sent
 * @param kind the session kind
 * @return the session info of the ready session
 * @throws LivyException if the REST call fails, the session finishes before becoming ready,
 *                       the creation times out, or the waiting thread is interrupted
 */
private SessionInfo createSession(String user, String kind)
    throws LivyException {
  try {
    Map<String, String> conf = new HashMap<>();
    for (Map.Entry<Object, Object> entry : property.entrySet()) {
      if (entry.getKey().toString().startsWith("livy.spark.") &&
          !entry.getValue().toString().isEmpty()) {
        // substring(5) drops the "livy." prefix, e.g. livy.spark.master -> spark.master
        conf.put(entry.getKey().toString().substring(5), entry.getValue().toString());
      }
    }
    CreateSessionRequest request = new CreateSessionRequest(kind,
        user == null || user.equals("anonymous") ? null : user, conf);
    SessionInfo sessionInfo = SessionInfo.fromJson(
        callRestAPI("/sessions", "POST", request.toJson()));
    long start = System.currentTimeMillis();
    // pull the session status until it is idle or timeout
    while (!sessionInfo.isReady()) {
      if ((System.currentTimeMillis() - start) / 1000 > sessionCreationTimeout) {
        String msg = "The creation of session " + sessionInfo.id + " is timeout within "
            + sessionCreationTimeout + " seconds, appId: " + sessionInfo.appId
            + ", log: " + sessionInfo.log;
        throw new LivyException(msg);
      }
      Thread.sleep(pullStatusInterval);
      sessionInfo = getSessionInfo(sessionInfo.id);
      LOGGER.info("Session {} is in state {}, appId {}", sessionInfo.id, sessionInfo.state,
          sessionInfo.appId);
      if (sessionInfo.isFinished()) {
        String msg = "Session " + sessionInfo.id + " is finished, appId: " + sessionInfo.appId
            + ", log: " + sessionInfo.log;
        throw new LivyException(msg);
      }
    }
    return sessionInfo;
  } catch (InterruptedException e) {
    // Restore the interrupt status so that the calling thread can still observe the interruption.
    Thread.currentThread().interrupt();
    LOGGER.error("Error when creating livy session for user " + user, e);
    throw new LivyException(e);
  } catch (Exception e) {
    LOGGER.error("Error when creating livy session for user " + user, e);
    throw new LivyException(e);
  }
}
/**
 * Fetches the current state of the given session from the livy server.
 *
 * @param sessionId id of the session to look up
 * @throws LivyException if the REST call fails
 */
private SessionInfo getSessionInfo(int sessionId) throws LivyException {
  final String sessionJson = callRestAPI("/sessions/" + sessionId, "GET");
  return SessionInfo.fromJson(sessionJson);
}
/**
 * Submits the code as a livy statement and blocks until the statement is available or
 * cancelled.
 *
 * <p>If livy reports that the session no longer exists, a new session is created and the code
 * is submitted once more.
 *
 * @param code the code to execute
 * @param paragraphId id under which the statement is registered so that {@code cancel} can find
 *                    it; when null, nothing is registered
 * @param displayAppInfo whether the application id and web UI address are added to the result
 * @param appendSessionExpired whether a notice is prepended to the result when a new session
 *                             had to be created
 * @return the result built from the finished statement
 * @throws LivyException if a REST call fails or the thread is interrupted while polling
 */
public InterpreterResult interpret(String code,
                                   String paragraphId,
                                   boolean displayAppInfo,
                                   boolean appendSessionExpired) throws LivyException {
  StatementInfo stmtInfo = null;
  boolean sessionExpired = false;
  try {
    try {
      stmtInfo = executeStatement(new ExecuteRequest(code));
    } catch (SessionNotFoundException e) {
      LOGGER.warn("Livy session {} is expired, new session will be created.", sessionInfo.id);
      sessionExpired = true;
      // we don't want to create multiple sessions because it is possible to have multiple thread
      // to call this method, like LivySparkSQLInterpreter which use ParallelScheduler. So we need
      // to check session status again in this sync block
      synchronized (this) {
        if (isSessionExpired()) {
          initLivySession();
        }
      }
      stmtInfo = executeStatement(new ExecuteRequest(code));
    }
    if (paragraphId != null) {
      paragraphId2StmtIdMapping.put(paragraphId, stmtInfo.id);
    }
    // pull the statement status
    while (!stmtInfo.isAvailable()) {
      try {
        Thread.sleep(pullStatusInterval);
      } catch (InterruptedException e) {
        // Restore the interrupt status; it is cleared when InterruptedException is thrown and
        // callers further up the stack may rely on it.
        Thread.currentThread().interrupt();
        LOGGER.error("InterruptedException when pulling statement status.", e);
        throw new LivyException(e);
      }
      stmtInfo = getStatementInfo(stmtInfo.id);
    }
    if (appendSessionExpired) {
      return appendSessionExpire(getResultFromStatementInfo(stmtInfo, displayAppInfo),
          sessionExpired);
    } else {
      return getResultFromStatementInfo(stmtInfo, displayAppInfo);
    }
  } finally {
    // the statement is finished (or failed), so it can no longer be cancelled
    if (paragraphId != null) {
      paragraphId2StmtIdMapping.remove(paragraphId);
    }
  }
}
/**
 * Asks the livy server for its version through the "/version" endpoint.
 *
 * @throws LivyException if the REST call fails
 */
protected LivyVersion getLivyVersion() throws LivyException {
  final String versionJson = callRestAPI("/version", "GET");
  final LivyVersionResponse versionResponse = LivyVersionResponse.fromJson(versionJson);
  return new LivyVersion(versionResponse.version);
}
/**
 * Checks whether the livy server still knows the current session.
 *
 * @return true if the server answers that the session is not found, false otherwise
 * @throws LivyException for any other failure of the REST call
 */
private boolean isSessionExpired() throws LivyException {
  boolean expired = false;
  try {
    getSessionInfo(sessionInfo.id);
  } catch (SessionNotFoundException notFound) {
    expired = true;
  }
  return expired;
}
/**
 * Prepends a session-expired notice to the given result when the session had to be recreated.
 *
 * @param result the original result
 * @param sessionExpired whether a new session was created for this execution
 * @return the original result when the session did not expire, otherwise a new result with
 *         the same code, the HTML notice, and then all messages of the original result
 */
private InterpreterResult appendSessionExpire(InterpreterResult result, boolean sessionExpired) {
  if (!sessionExpired) {
    return result;
  }

  final InterpreterResult withNotice = new InterpreterResult(result.code());
  withNotice.add(InterpreterResult.Type.HTML,
      "<font color=\"red\">Previous livy session is expired, new livy session is created. " +
          "Paragraphs that depend on this paragraph need to be re-executed!" + "</font>");
  for (InterpreterResultMessage original : result.message()) {
    withNotice.add(original.getType(), original.getData());
  }
  return withNotice;
}
/**
 * Converts a finished livy statement into an interpreter result.
 *
 * <p>Error output, cancellation and missing output yield ERROR results. Otherwise table data
 * takes precedence over a png image, which takes precedence over plain text.
 *
 * @param stmtInfo the finished statement
 * @param displayAppInfo whether the application id and web UI link are appended to text results
 */
private InterpreterResult getResultFromStatementInfo(StatementInfo stmtInfo,
                                                     boolean displayAppInfo) {
  if (stmtInfo.output != null && stmtInfo.output.isError()) {
    return new InterpreterResult(InterpreterResult.Code.ERROR, stmtInfo.output.evalue);
  }
  if (stmtInfo.isCancelled()) {
    // corner case, output might be null if it is cancelled.
    return new InterpreterResult(InterpreterResult.Code.ERROR, "Job is cancelled");
  }
  if (stmtInfo.output == null) {
    // This case should never happen, just in case
    return new InterpreterResult(InterpreterResult.Code.ERROR, "Empty output");
  }

  //TODO(zjffdu) support other types of data (like json, image and etc)
  String text = stmtInfo.output.data.plain_text;

  // check table magic result first
  if (stmtInfo.output.data.application_livy_table_json != null) {
    final StringBuilder table = new StringBuilder();
    // header line: column names separated by tabs
    String separator = "";
    for (Map header : stmtInfo.output.data.application_livy_table_json.headers) {
      table.append(separator);
      table.append(header.get("name"));
      separator = "\t";
    }
    table.append("\n");
    // one tab separated line per record
    for (List<Object> row : stmtInfo.output.data.application_livy_table_json.records) {
      table.append(StringUtils.join(row, "\t")).append("\n");
    }
    return new InterpreterResult(InterpreterResult.Code.SUCCESS,
        InterpreterResult.Type.TABLE, table.toString());
  }

  if (stmtInfo.output.data.image_png != null) {
    return new InterpreterResult(InterpreterResult.Code.SUCCESS,
        InterpreterResult.Type.IMG, stmtInfo.output.data.image_png);
  }

  if (text != null) {
    text = text.trim();
    // output that looks like markup gets the %html prefix
    final String[] htmlPrefixes = {"<link", "<script", "<style", "<div"};
    for (String prefix : htmlPrefixes) {
      if (text.startsWith(prefix)) {
        text = "%html " + text;
        break;
      }
    }
  }

  if (!displayAppInfo) {
    return new InterpreterResult(InterpreterResult.Code.SUCCESS, text);
  }

  final InterpreterResult withAppInfo = new InterpreterResult(InterpreterResult.Code.SUCCESS);
  withAppInfo.add(text);
  final String appInfoHtml = "<hr/>Spark Application Id: " + sessionInfo.appId + "<br/>"
      + "Spark WebUI: <a href=\"" + sessionInfo.webUIAddress + "\">"
      + sessionInfo.webUIAddress + "</a>";
  withAppInfo.add(InterpreterResult.Type.HTML, appInfoHtml);
  return withAppInfo;
}
/**
 * Submits a statement to the current session.
 *
 * @param executeRequest the request carrying the code to run
 * @return the statement info parsed from the server response
 * @throws LivyException if the REST call fails
 */
private StatementInfo executeStatement(ExecuteRequest executeRequest)
    throws LivyException {
  final String statementsPath = "/sessions/" + sessionInfo.id + "/statements";
  final String responseJson = callRestAPI(statementsPath, "POST", executeRequest.toJson());
  return StatementInfo.fromJson(responseJson);
}
/**
 * Fetches the current state of a statement of the current session.
 *
 * @param statementId id of the statement
 * @throws LivyException if the REST call fails
 */
private StatementInfo getStatementInfo(int statementId)
    throws LivyException {
  final String statementPath = "/sessions/" + sessionInfo.id + "/statements/" + statementId;
  final String responseJson = callRestAPI(statementPath, "GET");
  return StatementInfo.fromJson(responseJson);
}
/**
 * Asks livy to cancel a statement of the current session.
 *
 * @param statementId id of the statement to cancel
 * @throws LivyException if the REST call fails
 */
private void cancelStatement(int statementId) throws LivyException {
  final String cancelPath =
      "/sessions/" + sessionInfo.id + "/statements/" + statementId + "/cancel";
  callRestAPI(cancelPath, "POST");
}
/**
 * Builds the REST client used to talk to the livy server.
 *
 * <p>For an "https:" URL an HttpClient is created that trusts the certificates in the
 * configured trust store. When both "zeppelin.livy.keytab" and "zeppelin.livy.principal" are
 * set, a KerberosRestTemplate is returned, otherwise a plain RestTemplate.
 *
 * @return the REST template to use for all livy calls
 * @throws RuntimeException if the trust store settings are missing for an https URL or the
 *                          SSL HttpClient cannot be created
 */
private RestTemplate createRestTemplate() {
  String keytabLocation = property.getProperty("zeppelin.livy.keytab");
  String principal = property.getProperty("zeppelin.livy.principal");
  boolean isSpnegoEnabled = StringUtils.isNotEmpty(keytabLocation) &&
      StringUtils.isNotEmpty(principal);
  HttpClient httpClient = null;
  if (livyURL.startsWith("https:")) {
    String keystoreFile = property.getProperty("zeppelin.livy.ssl.trustStore");
    String password = property.getProperty("zeppelin.livy.ssl.trustStorePassword");
    if (StringUtils.isBlank(keystoreFile)) {
      throw new RuntimeException("No zeppelin.livy.ssl.trustStore specified for livy ssl");
    }
    if (StringUtils.isBlank(password)) {
      throw new RuntimeException("No zeppelin.livy.ssl.trustStorePassword specified " +
          "for livy ssl");
    }
    FileInputStream inputStream = null;
    try {
      inputStream = new FileInputStream(keystoreFile);
      KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
      // Load from the stream that is closed in the finally block; opening a second stream
      // here would leak a file handle.
      trustStore.load(inputStream, password.toCharArray());
      SSLContext sslContext = SSLContexts.custom()
          .loadTrustMaterial(trustStore)
          .build();
      SSLConnectionSocketFactory csf = new SSLConnectionSocketFactory(sslContext);
      HttpClientBuilder httpClientBuilder = HttpClients.custom().setSSLSocketFactory(csf);
      // request config that reports authentication as enabled
      RequestConfig reqConfig = new RequestConfig() {
        @Override
        public boolean isAuthenticationEnabled() {
          return true;
        }
      };
      httpClientBuilder.setDefaultRequestConfig(reqConfig);
      // placeholder credentials: both accessors return null
      Credentials credentials = new Credentials() {
        @Override
        public String getPassword() {
          return null;
        }
        @Override
        public Principal getUserPrincipal() {
          return null;
        }
      };
      CredentialsProvider credsProvider = new BasicCredentialsProvider();
      credsProvider.setCredentials(AuthScope.ANY, credentials);
      httpClientBuilder.setDefaultCredentialsProvider(credsProvider);
      if (isSpnegoEnabled) {
        Registry<AuthSchemeProvider> authSchemeProviderRegistry =
            RegistryBuilder.<AuthSchemeProvider>create()
                .register(AuthSchemes.SPNEGO, new SPNegoSchemeFactory())
                .build();
        httpClientBuilder.setDefaultAuthSchemeRegistry(authSchemeProviderRegistry);
      }
      httpClient = httpClientBuilder.build();
    } catch (Exception e) {
      throw new RuntimeException("Failed to create SSL HttpClient", e);
    } finally {
      if (inputStream != null) {
        try {
          inputStream.close();
        } catch (IOException e) {
          // best effort: a failure to close must not prevent the client from being used
          LOGGER.error("Failed to close keystore file", e);
        }
      }
    }
  }
  if (isSpnegoEnabled) {
    if (httpClient == null) {
      return new KerberosRestTemplate(keytabLocation, principal);
    } else {
      return new KerberosRestTemplate(keytabLocation, principal, httpClient);
    }
  }
  if (httpClient == null) {
    return new RestTemplate();
  } else {
    return new RestTemplate(new HttpComponentsClientHttpRequestFactory(httpClient));
  }
}
/**
 * Calls the livy REST API with an empty request body.
 *
 * @param targetURL path relative to the livy URL
 * @param method HTTP method name
 * @return the response body
 * @throws LivyException if the call fails
 */
private String callRestAPI(String targetURL, String method) throws LivyException {
  final String emptyBody = "";
  return callRestAPI(targetURL, method, emptyBody);
}
/**
 * Calls the livy REST API and translates the HTTP outcome into a result or an exception.
 *
 * @param targetURL path relative to the livy URL
 * @param method "POST", "GET" or "DELETE"; any other value leads to a LivyException because
 *               no request is sent
 * @param jsonData request body, only sent for POST
 * @return the response body for status 200 and 201, and for error responses whose body
 *         mentions CreateInteractiveRequest[\"master\"]
 * @throws SessionNotFoundException if the server reports that the session is not found
 * @throws APINotFoundException for any other 404 response
 * @throws LivyException for every other failure
 */
private String callRestAPI(String targetURL, String method, String jsonData)
    throws LivyException {
  targetURL = livyURL + targetURL;
  LOGGER.debug("Call rest api in {}, method: {}, jsonData: {}", targetURL, method, jsonData);
  HttpHeaders headers = new HttpHeaders();
  headers.add("Content-Type", MediaType.APPLICATION_JSON_UTF8_VALUE);
  headers.add("X-Requested-By", "zeppelin");
  ResponseEntity<String> response = null;
  try {
    if (method.equals("POST")) {
      HttpEntity<String> entity = new HttpEntity<>(jsonData, headers);
      response = restTemplate.exchange(targetURL, HttpMethod.POST, entity, String.class);
    } else if (method.equals("GET")) {
      HttpEntity<String> entity = new HttpEntity<>(headers);
      response = restTemplate.exchange(targetURL, HttpMethod.GET, entity, String.class);
    } else if (method.equals("DELETE")) {
      HttpEntity<String> entity = new HttpEntity<>(headers);
      response = restTemplate.exchange(targetURL, HttpMethod.DELETE, entity, String.class);
    }
  } catch (HttpClientErrorException e) {
    // turn the 4xx exception back into a response so that it is handled by the status checks
    // below
    response = new ResponseEntity<String>(e.getResponseBodyAsString(), e.getStatusCode());
    LOGGER.error(String.format("Error with %s StatusCode: %s",
        response.getStatusCode().value(), e.getResponseBodyAsString()));
  } catch (RestClientException e) {
    // Exception happens when kerberos is enabled.
    if (e.getCause() instanceof HttpClientErrorException) {
      HttpClientErrorException cause = (HttpClientErrorException) e.getCause();
      if (cause.getResponseBodyAsString().matches(SESSION_NOT_FOUND_PATTERN)) {
        throw new SessionNotFoundException(cause.getResponseBodyAsString());
      }
      throw new LivyException(cause.getResponseBodyAsString() + "\n"
          + ExceptionUtils.getFullStackTrace(ExceptionUtils.getRootCause(e)));
    }
    throw new LivyException(e);
  }
  if (response == null) {
    throw new LivyException("No http response returned");
  }
  LOGGER.debug("Get response, StatusCode: {}, responseBody: {}", response.getStatusCode(),
      response.getBody());
  // The body may be null (e.g. a response without content), so it is checked before use.
  final String responseString = response.getBody();
  if (response.getStatusCode().value() == 200
      || response.getStatusCode().value() == 201) {
    return responseString;
  } else if (response.getStatusCode().value() == 404) {
    if (responseString != null && responseString.matches(SESSION_NOT_FOUND_PATTERN)) {
      throw new SessionNotFoundException(responseString);
    } else {
      throw new APINotFoundException("No rest api found for " + targetURL +
          ", " + response.getStatusCode());
    }
  } else {
    if (responseString != null
        && responseString.contains("CreateInteractiveRequest[\\\"master\\\"]")) {
      return responseString;
    }
    throw new LivyException(String.format("Error with %s StatusCode: %s",
        response.getStatusCode().value(), responseString));
  }
}
/**
 * Deletes the given session on the livy server. Failures are logged and not propagated.
 *
 * @param sessionId id of the session to delete
 */
private void closeSession(int sessionId) {
  final String sessionPath = "/sessions/" + sessionId;
  try {
    callRestAPI(sessionPath, "DELETE");
  } catch (Exception e) {
    final String message =
        String.format("Error closing session for user with session ID: %s", sessionId);
    LOGGER.error(message, e);
  }
}
/*
* We create these POJOs here to accommodate livy 0.3, which is not released yet. The livy rest api
* changes from version to version, so we create these POJOs on the zeppelin side to accommodate
* the incompatibilities between versions. Later, when livy becomes more stable, we could just
* depend on the livy client jar.
*/
/**
 * Request body for creating a livy session; serialized to JSON with the shared gson instance.
 */
private static class CreateSessionRequest {
// session kind
public final String kind;
// serialized under the JSON name "proxyUser"; createSession passes null for anonymous users
@SerializedName("proxyUser")
public final String user;
// session configuration entries
public final Map<String, String> conf;
public CreateSessionRequest(String kind, String user, Map<String, String> conf) {
this.kind = kind;
this.user = user;
this.conf = conf;
}
/** Returns this request as a JSON string. */
public String toJson() {
return gson.toJson(this);
}
}
/**
 * Description of a livy session as parsed from the JSON returned by the livy server.
 * {@code appId} and {@code webUIAddress} are mutable because the interpreter fills them in
 * after the session has been created.
 */
public static class SessionInfo {
  public final int id;
  public String appId;
  public String webUIAddress;
  public final String owner;
  public final String proxyUser;
  public final String state;
  public final String kind;
  public final Map<String, String> appInfo;
  public final List<String> log;

  public SessionInfo(int id, String appId, String owner, String proxyUser, String state,
                     String kind, Map<String, String> appInfo, List<String> log) {
    this.id = id;
    this.appId = appId;
    this.owner = owner;
    this.proxyUser = proxyUser;
    this.state = state;
    this.kind = kind;
    this.appInfo = appInfo;
    this.log = log;
  }

  /** Returns true when the session state is "idle". */
  public boolean isReady() {
    return state.equals("idle");
  }

  /** Returns true when the session state is "error", "dead" or "success". */
  public boolean isFinished() {
    switch (state) {
      case "error":
      case "dead":
      case "success":
        return true;
      default:
        return false;
    }
  }

  /** Parses a session description from its JSON representation. */
  public static SessionInfo fromJson(String json) {
    return gson.fromJson(json, SessionInfo.class);
  }
}
/**
 * Request body for submitting a statement; serialized to JSON with the shared gson instance.
 */
private static class ExecuteRequest {
// the code to execute
public final String code;
public ExecuteRequest(String code) {
this.code = code;
}
/** Returns this request as a JSON string. */
public String toJson() {
return gson.toJson(this);
}
}
/**
 * Description of a livy statement as parsed from the JSON returned by the livy server.
 */
private static class StatementInfo {
  public Integer id;
  public String state;
  public StatementOutput output;

  public StatementInfo() {
  }

  /** Parses a statement description from its JSON representation. */
  public static StatementInfo fromJson(String json) {
    return gson.fromJson(json, StatementInfo.class);
  }

  /** Returns true when the statement no longer needs polling ("available" or "cancelled"). */
  public boolean isAvailable() {
    final boolean resultAvailable = state.equals("available");
    return resultAvailable || isCancelled();
  }

  /** Returns true when the statement state is "cancelled". */
  public boolean isCancelled() {
    return state.equals("cancelled");
  }

  /** Output section of a statement. */
  private static class StatementOutput {
    public String status;
    public String execution_count;
    public Data data;
    public String ename;
    public String evalue;
    public Object traceback;
    public TableMagic tableMagic;

    /** Returns true when the output status is "error". */
    public boolean isError() {
      return status.equals("error");
    }

    /** Returns this output as a JSON string. */
    public String toJson() {
      return gson.toJson(this);
    }

    /** Result payloads, keyed in JSON by their mime type. */
    private static class Data {
      @SerializedName("text/plain")
      public String plain_text;
      @SerializedName("image/png")
      public String image_png;
      @SerializedName("application/json")
      public String application_json;
      @SerializedName("application/vnd.livy.table.v1+json")
      public TableMagic application_livy_table_json;
    }

    /** Table payload: column headers plus the data rows. */
    private static class TableMagic {
      @SerializedName("headers")
      List<Map> headers;
      @SerializedName("data")
      List<List> records;
    }
  }
}
/**
 * Response of the livy "/version" endpoint; getLivyVersion() only reads {@code version}.
 */
private static class LivyVersionResponse {
public String url;
public String branch;
public String revision;
public String version;
public String date;
public String user;
/** Parses a version response from its JSON representation. */
public static LivyVersionResponse fromJson(String json) {
return gson.fromJson(json, LivyVersionResponse.class);
}
}
}