blob: c1c953400648f1a234427b4bd40fd0aec193c6dd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.rsc;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.livy.Job;
import org.apache.livy.JobHandle;
import org.apache.livy.LivyClient;
import org.apache.livy.client.common.BufferUtils;
import org.apache.livy.rsc.driver.AddFileJob;
import org.apache.livy.rsc.driver.AddJarJob;
import org.apache.livy.rsc.rpc.Rpc;
import org.apache.livy.sessions.SessionState;
import static org.apache.livy.rsc.RSCConf.Entry.*;
public class RSCClient implements LivyClient {
// Logger for this class.
private static final Logger LOG = LoggerFactory.getLogger(RSCClient.class);
// Counter shared by all clients; each new client takes the next value as its executorGroupId.
private static final AtomicInteger EXECUTOR_GROUP_ID = new AtomicInteger();
// Configuration handed to the constructor (thread count, shutdown timeout, RPC settings).
private final RSCConf conf;
// Promise handed to the constructor; its completion triggers connectToContext() on success
// or connectionError() on failure. It is cancelled by stop().
private final Promise<ContextInfo> contextInfoPromise;
// Handles of submitted jobs, keyed by job id. Entries are removed when a job result arrives
// or when the job's promise completes.
private final Map<String, JobHandleImpl<?>> jobs;
// Protocol object passed to Rpc.createClient() and used to send requests to the driver.
private final ClientProtocol protocol;
// Completed with the Rpc once the connection to the context succeeds, failed otherwise.
private final Promise<Rpc> driverRpc;
// Value taken from EXECUTOR_GROUP_ID; appears in event loop thread names and debug logs.
private final int executorGroupId;
// Event loop group passed to Rpc.createClient() and used to create promises; shut down in stop().
private final EventLoopGroup eventLoopGroup;
// Completed with the rsc:// URI built from the context info, or failed with the connect error.
private final Promise<URI> serverUriPromise;
// Set by connectToContext(); null until the context info has been received.
private ContextInfo contextInfo;
// Process object handed to the constructor; exposed through getDriverProcess().
private Process driverProcess;
// True once construction finishes; set to false by the first effective call to stop().
private volatile boolean isAlive;
// Last repl state string received in a ReplState message; null until one arrives.
private volatile String replState;
// System.nanoTime() reading taken at construction and again whenever the repl state goes
// from busy to idle.
private volatile long replLastActivity = System.nanoTime();
/**
 * Creates a client that connects to the remote context as soon as {@code ctx} completes.
 *
 * <p>On success of {@code ctx} the client opens the RPC connection and completes the server
 * URI promise; on failure it stops itself and fails the server URI promise with the same
 * error.
 *
 * @param conf configuration; {@code RPC_MAX_THREADS} sizes the event loop group
 * @param ctx promise supplying the connection details of the remote context
 * @param driverProcess process object returned unchanged by {@link #getDriverProcess()}
 * @throws IOException part of the existing signature; no statement in this body throws it
 *     directly
 */
RSCClient(RSCConf conf, Promise<ContextInfo> ctx, Process driverProcess) throws IOException {
  this.conf = conf;
  this.contextInfoPromise = ctx;
  this.driverProcess = driverProcess;
  this.jobs = new ConcurrentHashMap<>();
  this.protocol = new ClientProtocol();
  this.driverRpc = ImmediateEventExecutor.INSTANCE.newPromise();
  this.executorGroupId = EXECUTOR_GROUP_ID.incrementAndGet();
  this.eventLoopGroup = new NioEventLoopGroup(
      conf.getInt(RPC_MAX_THREADS),
      Utils.newDaemonThreadFactory("RSCClient-" + executorGroupId + "-%d"));
  this.serverUriPromise = ImmediateEventExecutor.INSTANCE.newPromise();

  // Mark the client alive BEFORE registering the listener. If the promise is already
  // complete (or completes concurrently), the failure path calls stop(false), which only
  // cleans up while isAlive is true. Setting the flag afterwards would turn that stop() into
  // a no-op, leave the event loop group running and report a dead client as alive.
  isAlive = true;

  Utils.addListener(this.contextInfoPromise, new FutureListener<ContextInfo>() {
    @Override
    public void onSuccess(ContextInfo info) throws Exception {
      connectToContext(info);
      // The URI carries the client id and secret as its user-info component.
      String url = String.format("rsc://%s:%s@%s:%d",
          info.clientId, info.secret, info.remoteAddress, info.remotePort);
      serverUriPromise.setSuccess(URI.create(url));
    }

    @Override
    public void onFailure(Throwable error) {
      connectionError(error);
      serverUriPromise.setFailure(error);
    }
  });
}
/**
 * Reports the current value of the liveness flag.
 *
 * @return {@code true} after construction, {@code false} once {@code stop} has taken effect
 */
public boolean isAlive() {
  final boolean alive = this.isAlive;
  return alive;
}
/**
 * @return the process object that was handed to the constructor (may be whatever the caller
 *     supplied, including {@code null})
 */
public Process getDriverProcess() {
  return this.driverProcess;
}
/**
 * Records the context info and starts the RPC connection to the remote context.
 *
 * <p>When the connection is established the driver RPC promise is completed and a listener is
 * attached to the channel's close future so that an unexpected close stops this client. Any
 * failure, whether thrown while creating the client or reported asynchronously, is routed to
 * {@code connectionError}.
 *
 * @param info connection details (address, port, client id, secret) of the remote context
 * @throws Exception declared by the signature; exceptions raised while creating the RPC
 *     client are caught and passed to {@code connectionError}
 */
private synchronized void connectToContext(final ContextInfo info) throws Exception {
  this.contextInfo = info;

  // Reacts to the outcome of the asynchronous connection attempt.
  final FutureListener<Rpc> onConnect = new FutureListener<Rpc>() {
    @Override
    public void onSuccess(final Rpc rpc) throws Exception {
      driverRpc.setSuccess(rpc);
      Utils.addListener(rpc.getChannel().closeFuture(), new FutureListener<Void>() {
        @Override
        public void onSuccess(Void unused) {
          // A close while the client is no longer alive is the expected result of stop().
          if (!isAlive) {
            return;
          }
          LOG.warn("Client RPC channel closed unexpectedly.");
          try {
            stop(false);
          } catch (Exception ignored) {
            // stop() itself prints warning.
          }
        }
      });
      LOG.debug("Connected to context {} ({}, {}).", info.clientId,
          rpc.getChannel(), executorGroupId);
    }

    @Override
    public void onFailure(Throwable error) throws Exception {
      driverRpc.setFailure(error);
      connectionError(error);
    }
  };

  try {
    final Promise<Rpc> pendingRpc = Rpc.createClient(conf,
        eventLoopGroup,
        info.remoteAddress,
        info.remotePort,
        info.clientId,
        info.secret,
        protocol);
    Utils.addListener(pendingRpc, onConnect);
  } catch (Exception cause) {
    connectionError(cause);
  }
}
/**
 * Logs a connection failure and stops this client without shutting down the remote context.
 *
 * @param error the reason the connection could not be established
 */
private void connectionError(Throwable error) {
  LOG.error("Failed to connect to context.", error);
  try {
    this.stop(false);
  } catch (Exception ignored) {
    // stop() itself prints warning.
  }
}
/**
 * Sends {@code msg} to the driver, waiting for the driver RPC to become available if needed.
 *
 * <p>If the driver RPC is already connected the call is issued immediately and its future is
 * returned. Otherwise a new promise is returned; it is completed with the outcome of the call
 * once the RPC connects, or failed if the RPC connection itself fails.
 *
 * @param msg message to deliver
 * @param retType expected type of the reply
 * @param <T> reply type
 * @return future for the reply
 */
private <T> io.netty.util.concurrent.Future<T> deferredCall(final Object msg,
    final Class<T> retType) {
  if (!driverRpc.isSuccess()) {
    // No driver RPC yet: hand back a promise and deliver the message once the RPC is up.
    final Promise<T> deferred = eventLoopGroup.next().newPromise();
    Utils.addListener(driverRpc, new FutureListener<Rpc>() {
      @Override
      public void onSuccess(Rpc rpc) throws Exception {
        // Relay the outcome of the real call to the promise the caller is holding.
        Utils.addListener(rpc.call(msg, retType), new FutureListener<T>() {
          @Override
          public void onSuccess(T reply) throws Exception {
            deferred.setSuccess(reply);
          }

          @Override
          public void onFailure(Throwable cause) throws Exception {
            deferred.setFailure(cause);
          }
        });
      }

      @Override
      public void onFailure(Throwable cause) throws Exception {
        deferred.setFailure(cause);
      }
    });
    return deferred;
  }

  try {
    final Rpc rpc = driverRpc.get();
    return rpc.call(msg, retType);
  } catch (Exception cause) {
    throw Utils.propagate(cause);
  }
}
/**
 * @return future completed with the {@code rsc://} URI built from the context info, or failed
 *     with the error that prevented the context info from being obtained
 */
public Future<URI> getServerUri() {
  final Future<URI> uri = this.serverUriPromise;
  return uri;
}
/**
 * Submits a job for asynchronous execution by delegating to the client protocol.
 *
 * @param job job to run remotely
 * @param <T> result type of the job
 * @return handle tracking the submitted job
 */
@Override
public <T> JobHandle<T> submit(Job<T> job) {
  final JobHandle<T> handle = this.protocol.submit(job);
  return handle;
}
/**
 * Sends a job as a synchronous job request by delegating to the client protocol.
 *
 * @param job job to run remotely
 * @param <T> result type of the job
 * @return future for the job's result
 */
@Override
public <T> Future<T> run(Job<T> job) {
  final Future<T> result = this.protocol.run(job);
  return result;
}
/**
 * Stops this client. Only the first call while the client is alive has any effect.
 *
 * <p>The context info promise is cancelled. If {@code shutdownContext} is set and the driver
 * RPC is connected, an end-session message is sent and the call waits up to
 * {@code CLIENT_SHUTDOWN_TIMEOUT} for the RPC channel to close. Regardless of the outcome, the
 * RPC is closed, every pending job is failed and the event loop group is shut down.
 *
 * @param shutdownContext whether the remote context should be asked to end its session
 * @throws RuntimeException (via {@code Utils.propagate}) if cancelling or waiting for the
 *     channel to close fails; cleanup still runs first
 */
@Override
public synchronized void stop(boolean shutdownContext) {
  if (!isAlive) {
    return;
  }
  isAlive = false;

  // Remembered so the interrupt status can be restored once cleanup is finished; restoring
  // it earlier could disturb blocking calls made during cleanup.
  boolean interrupted = false;
  try {
    this.contextInfoPromise.cancel(true);

    if (shutdownContext && driverRpc.isSuccess()) {
      protocol.endSession();

      // Because the remote context won't really reply to the end session message -
      // since it closes the channel while handling it, we wait for the RPC's channel
      // to close instead.
      long stopTimeout = conf.getTimeAsMs(CLIENT_SHUTDOWN_TIMEOUT);
      driverRpc.get().getChannel().closeFuture().get(stopTimeout,
          TimeUnit.MILLISECONDS);
    }
  } catch (Exception e) {
    interrupted = e instanceof InterruptedException;
    LOG.warn("Exception while waiting for end session reply.", e);
    // Same idiom as deferredCall(): make the rethrow explicit.
    throw Utils.propagate(e);
  } finally {
    if (driverRpc.isSuccess()) {
      try {
        driverRpc.get().close();
      } catch (Exception e) {
        LOG.warn("Error stopping RPC.", e);
      }
    }

    // Report failure for all pending jobs, so that clients can react.
    for (Map.Entry<String, JobHandleImpl<?>> pending : jobs.entrySet()) {
      LOG.info("Failing pending job {} due to shutdown.", pending.getKey());
      pending.getValue().setFailure(new IOException("RSCClient instance stopped."));
    }

    eventLoopGroup.shutdownGracefully();

    if (interrupted) {
      // Do not swallow the interrupt: let callers further up the stack observe it.
      Thread.currentThread().interrupt();
    }
  }

  if (contextInfo != null) {
    LOG.debug("Disconnected from context {}, shutdown = {}.", contextInfo.clientId,
        shutdownContext);
  }
}
/**
 * Not supported by this client.
 *
 * @param jar ignored
 * @return never returns normally
 * @throws UnsupportedOperationException always; the message points callers to addJar
 */
@Override
public Future<?> uploadJar(File jar) {
  final String reason = "Use addJar to add the jar to the remote context!";
  throw new UnsupportedOperationException(reason);
}
/**
 * Submits an {@code AddJarJob} carrying the string form of the given URI.
 *
 * @param uri location of the jar
 * @return future of the submitted job
 */
@Override
public Future<?> addJar(URI uri) {
  final String location = uri.toString();
  return submit(new AddJarJob(location));
}
/**
 * Not supported by this client.
 *
 * @param file ignored
 * @return never returns normally
 * @throws UnsupportedOperationException always; the message points callers to addFile
 */
@Override
public Future<?> uploadFile(File file) {
  final String reason = "Use addFile to add the file to the remote context!";
  throw new UnsupportedOperationException(reason);
}
/**
 * Submits an {@code AddFileJob} carrying the string form of the given URI.
 *
 * @param uri location of the file
 * @return future of the submitted job
 */
@Override
public Future<?> addFile(URI uri) {
  final String location = uri.toString();
  return submit(new AddFileJob(location));
}
/**
 * Sends an already-serialized job to the driver by delegating to the client protocol.
 *
 * @param serializedJob serialized form of the job
 * @param jobType job type string forwarded with the request
 * @param sync flag forwarded with the request
 * @return the id generated for the job
 */
public String bypass(ByteBuffer serializedJob, String jobType, boolean sync) {
  final String jobId = this.protocol.bypass(serializedJob, jobType, sync);
  return jobId;
}
/**
 * Requests the status of a bypass job from the driver.
 *
 * @param id id previously returned by {@code bypass}
 * @return future for the job's status
 */
public Future<BypassJobStatus> getBypassJobStatus(String id) {
  final Future<BypassJobStatus> status = this.protocol.getBypassJobStatus(id);
  return status;
}
/**
 * Sends a cancel request for the given job to the driver; the reply is not awaited.
 *
 * @param jobId id of the job to cancel
 */
public void cancel(String jobId) {
  this.protocol.cancel(jobId);
}
/**
 * @return the context info recorded when connecting, or {@code null} if none has been
 *     received yet
 */
ContextInfo getContextInfo() {
  return this.contextInfo;
}
/**
 * Sends a repl job request with the given code to the driver.
 *
 * @param code code to execute
 * @param codeType code type string forwarded with the request
 * @return future for the integer reply (used as a statement id by {@code cancelReplCode})
 * @throws Exception declared by the signature
 */
public Future<Integer> submitReplCode(String code, String codeType) throws Exception {
  final BaseProtocol.ReplJobRequest request = new BaseProtocol.ReplJobRequest(code, codeType);
  return deferredCall(request, Integer.class);
}
/**
 * Sends a cancel request for a repl statement to the driver; the reply is not awaited.
 *
 * @param statementId id of the statement to cancel
 * @throws Exception declared by the signature
 */
public void cancelReplCode(int statementId) throws Exception {
  final BaseProtocol.CancelReplJobRequest request =
      new BaseProtocol.CancelReplJobRequest(statementId);
  deferredCall(request, Void.class);
}
/**
 * Requests repl job results from the driver, passing the given window parameters.
 *
 * @param from forwarded to the request (presumably the first statement to return; confirm
 *     against {@code BaseProtocol.GetReplJobResults})
 * @param size forwarded to the request (presumably the maximum number of results; confirm)
 * @return future for the results
 * @throws Exception declared by the signature
 */
public Future<ReplJobResults> getReplJobResults(Integer from, Integer size) throws Exception {
  final BaseProtocol.GetReplJobResults request =
      new BaseProtocol.GetReplJobResults(from, size);
  return deferredCall(request, ReplJobResults.class);
}
/**
 * Requests repl job results from the driver using the request's no-argument form.
 *
 * @return future for the results
 * @throws Exception declared by the signature
 */
public Future<ReplJobResults> getReplJobResults() throws Exception {
  final BaseProtocol.GetReplJobResults request = new BaseProtocol.GetReplJobResults();
  return deferredCall(request, ReplJobResults.class);
}
/**
 * Sends a code-completion request to the driver.
 *
 * @param code code to complete
 * @param codeType code type string forwarded with the request
 * @param cursor cursor position forwarded with the request
 * @return future for the array of strings sent back in reply
 * @throws Exception declared by the signature
 */
public Future<String[]> completeReplCode(String code, String codeType, int cursor)
    throws Exception {
  final BaseProtocol.ReplCompleteRequest request =
      new BaseProtocol.ReplCompleteRequest(code, codeType, cursor);
  return deferredCall(request, String[].class);
}
/**
 * Returns the most recently received repl state.
 *
 * @return the repl state, or {@code null} if this client is not connected to a repl session
 *     (no repl state message has been received)
 */
public String getReplState() {
  return this.replState;
}
/**
 * Returns the timestamp of the repl's last activity. The value is refreshed whenever the repl
 * state changes from busy to idle.
 *
 * <p>The timestamp is a {@link System#nanoTime()} reading, so it is only meaningful when
 * compared with other {@code nanoTime} readings, not as wall-clock time.
 *
 * @return last activity timestamp
 */
public long getReplLastActivity() {
  return this.replLastActivity;
}
/**
 * Client side of the protocol spoken with the remote driver.
 *
 * <p>The package-private methods build request messages and send them through
 * {@code deferredCall}. The private {@code handle} overloads process messages coming back
 * from the driver. Nothing in this file calls them directly; presumably the RPC layer
 * dispatches to them by message type, so their names and signatures must stay as they are.
 */
private class ClientProtocol extends BaseProtocol {

  /**
   * Registers a new job handle and sends the job request to the driver.
   *
   * @param job job to run remotely
   * @param <T> result type of the job
   * @return handle for the job; it is moved to QUEUED once the request has been delivered and
   *     failed if the request cannot be delivered
   */
  <T> JobHandleImpl<T> submit(Job<T> job) {
    final String jobId = UUID.randomUUID().toString();
    Object msg = new JobRequest<T>(jobId, job);

    final Promise<T> promise = eventLoopGroup.next().newPromise();
    final JobHandleImpl<T> handle = new JobHandleImpl<T>(RSCClient.this,
        promise, jobId);
    // Register before sending so that replies can always find the handle.
    jobs.put(jobId, handle);

    final io.netty.util.concurrent.Future<Void> rpc = deferredCall(msg, Void.class);
    LOG.debug("Sending JobRequest[{}].", jobId);

    Utils.addListener(rpc, new FutureListener<Void>() {
      @Override
      public void onSuccess(Void unused) throws Exception {
        handle.changeState(JobHandle.State.QUEUED);
      }

      @Override
      public void onFailure(Throwable error) throws Exception {
        LOG.error("RPC error.", error);
        promise.tryFailure(error);
      }
    });
    promise.addListener(new GenericFutureListener<Promise<T>>() {
      @Override
      public void operationComplete(Promise<T> p) {
        // jobId is generated above and can never be null, so it is removed unconditionally.
        jobs.remove(jobId);
        // A cancelled job should also cancel a request that is still in flight.
        if (p.isCancelled() && !rpc.isDone()) {
          rpc.cancel(true);
        }
      }
    });
    return handle;
  }

  /**
   * Sends a synchronous job request.
   *
   * @param job job to run remotely
   * @param <T> result type of the job
   * @return future for the reply, cast unchecked to the job's result type
   */
  @SuppressWarnings("unchecked")
  <T> Future<T> run(Job<T> job) {
    return (Future<T>) deferredCall(new SyncJobRequest(job), Object.class);
  }

  /**
   * Sends an already-serialized job under a freshly generated id.
   *
   * @param serializedJob serialized job; converted to a byte array for the request
   * @param jobType job type string forwarded with the request
   * @param sync flag forwarded with the request
   * @return the generated job id; the outcome of the send itself is not reported
   */
  String bypass(ByteBuffer serializedJob, String jobType, boolean sync) {
    String jobId = UUID.randomUUID().toString();
    Object msg =
        new BypassJobRequest(jobId, jobType, BufferUtils.toByteArray(serializedJob), sync);
    deferredCall(msg, Void.class);
    return jobId;
  }

  /**
   * Requests the status of a bypass job.
   *
   * @param id id of the bypass job
   * @return future for the job's status
   */
  Future<BypassJobStatus> getBypassJobStatus(String id) {
    return deferredCall(new GetBypassJobStatus(id), BypassJobStatus.class);
  }

  /**
   * Sends a cancel request for the given job; the reply is not awaited.
   *
   * @param jobId id of the job to cancel
   */
  void cancel(String jobId) {
    deferredCall(new CancelJob(jobId), Void.class);
  }

  /**
   * Sends an end-session message to the driver.
   *
   * @return future for the reply
   */
  Future<?> endSession() {
    return deferredCall(new EndSession(), Void.class);
  }

  /** Logs an initialization error reported by the remote driver. */
  private void handle(ChannelHandlerContext ctx, InitializationError msg) {
    LOG.warn("Error reported from remote driver: {}", msg.stackTrace);
  }

  /**
   * Completes the handle of the job named in the result, successfully or with a failure
   * wrapping the reported error, and drops it from the pending job map.
   */
  private void handle(ChannelHandlerContext ctx, JobResult msg) {
    JobHandleImpl<?> handle = jobs.remove(msg.id);
    if (handle != null) {
      LOG.info("Received result for {}", msg.id);
      // TODO: need a better exception for this.
      Throwable error = msg.error != null ? new RuntimeException(msg.error) : null;
      if (error == null) {
        handle.setSuccess(msg.result);
      } else {
        handle.setFailure(error);
      }
    } else {
      LOG.warn("Received result for unknown job {}", msg.id);
    }
  }

  /** Moves the handle of the job named in the message to the STARTED state. */
  private void handle(ChannelHandlerContext ctx, JobStarted msg) {
    JobHandleImpl<?> handle = jobs.get(msg.id);
    if (handle != null) {
      handle.changeState(JobHandle.State.STARTED);
    } else {
      LOG.warn("Received event for unknown job {}", msg.id);
    }
  }

  /**
   * Records the repl state carried by the message and refreshes the last-activity timestamp
   * when the state goes from busy to idle.
   */
  private void handle(ChannelHandlerContext ctx, ReplState msg) {
    // msg is dereferenced here, so a later null check on it could never be false; the
    // redundant check that used to sit in the condition below has been removed.
    LOG.trace("Received repl state for {}", msg.state);
    // Update last activity timestamp when state change is from busy to idle.
    if (SessionState.Busy$.MODULE$.state().equals(replState) &&
        SessionState.Idle$.MODULE$.state().equals(msg.state)) {
      replLastActivity = System.nanoTime();
    }
    replState = msg.state;
  }
}
}