/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.livy.rsc.driver;

import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.concurrent.ScheduledFuture;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.livy.client.common.Serializer;
import org.apache.livy.rsc.BaseProtocol;
import org.apache.livy.rsc.BypassJobStatus;
import org.apache.livy.rsc.FutureListener;
import org.apache.livy.rsc.RSCConf;
import org.apache.livy.rsc.Utils;
import org.apache.livy.rsc.rpc.Rpc;
import org.apache.livy.rsc.rpc.RpcDispatcher;
import org.apache.livy.rsc.rpc.RpcServer;

import static org.apache.livy.rsc.RSCConf.Entry.*;

/**
 * Driver code for the Spark client library. Runs inside the remote Spark application,
 * hosts the RPC server that the Livy server connects to, and executes submitted jobs
 * against the JobContext.
 */
@Sharable
public class RSCDriver extends BaseProtocol {
private static final Logger LOG = LoggerFactory.getLogger(RSCDriver.class);
private final Serializer serializer;
  // Guards creation of the JobContext and the queue of jobs waiting for it.
  private final Object jcLock;
  // run() blocks on this lock until shutdown() is called.
  private final Object shutdownLock;
private final ExecutorService executor;
private final File localTmpDir;
// Used to queue up requests while the SparkContext is being created.
private final List<JobWrapper<?>> jobQueue;
// Keeps track of connected clients.
protected final Collection<Rpc> clients;
final Map<String, JobWrapper<?>> activeJobs;
private final Collection<BypassJobWrapper> bypassJobs;
private RpcServer server;
private volatile JobContextImpl jc;
private volatile boolean running;
protected final SparkConf conf;
protected final RSCConf livyConf;
private final AtomicReference<ScheduledFuture<?>> idleTimeout;
  private final AtomicBoolean inShutdown;

public RSCDriver(SparkConf conf, RSCConf livyConf) throws Exception {
Set<PosixFilePermission> perms = PosixFilePermissions.fromString("rwx------");
this.localTmpDir = Files.createTempDirectory("rsc-tmp",
PosixFilePermissions.asFileAttribute(perms)).toFile();
this.executor = Executors.newCachedThreadPool();
this.jobQueue = new LinkedList<>();
this.clients = new ConcurrentLinkedDeque<>();
this.serializer = new Serializer();
this.conf = conf;
this.livyConf = livyConf;
this.jcLock = new Object();
this.shutdownLock = new Object();
this.activeJobs = new ConcurrentHashMap<>();
this.bypassJobs = new ConcurrentLinkedDeque<>();
this.idleTimeout = new AtomicReference<>();
this.inShutdown = new AtomicBoolean(false);
}
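
  /**
   * Stops the driver: cancels any active jobs, shuts down the Spark context and the RPC
   * server, and wakes up threads blocked on the context or shutdown locks. Only the
   * first call has any effect.
   */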
private synchronized void shutdown() {
if (!running) {
return;
}
running = false;
// Cancel any pending jobs.
for (JobWrapper<?> job : activeJobs.values()) {
try {
job.cancel();
} catch (Exception e) {
LOG.warn("Error during cancel job.", e);
}
}
try {
shutdownContext();
} catch (Exception e) {
LOG.warn("Error during shutdown.", e);
}
try {
shutdownServer();
} catch (Exception e) {
LOG.warn("Error during shutdown.", e);
}
synchronized (shutdownLock) {
shutdownLock.notifyAll();
}
synchronized (jcLock) {
jcLock.notifyAll();
}
}
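
  /**
   * Starts the driver-side RPC server, then connects back to the Livy launcher to report
   * this server's address. The idle timeout is armed afterwards in case the Livy server
   * never connects back.
   */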
private void initializeServer() throws Exception {
String clientId = livyConf.get(CLIENT_ID);
Utils.checkArgument(clientId != null, "No client ID provided.");
String secret = livyConf.get(CLIENT_SECRET);
Utils.checkArgument(secret != null, "No secret provided.");
String launcherAddress = livyConf.get(LAUNCHER_ADDRESS);
Utils.checkArgument(launcherAddress != null, "Missing launcher address.");
int launcherPort = livyConf.getInt(LAUNCHER_PORT);
Utils.checkArgument(launcherPort > 0, "Missing launcher port.");
LOG.info("Connecting to: {}:{}", launcherAddress, launcherPort);
    // We need to unset this configuration since it doesn't really apply on the driver side.
// If the driver runs on a multi-homed machine, this can lead to issues where the Livy
// server cannot connect to the auto-detected address, but since the driver can run anywhere
// on the cluster, it would be tricky to solve that problem in a generic way.
livyConf.set(RPC_SERVER_ADDRESS, null);
if (livyConf.getBoolean(TEST_STUCK_START_DRIVER)) {
      // The test flag is on, so just loop forever. The launch should time out, and the
      // YARN application should still be cleaned up.
      LOG.info("Infinite looping as test flag TEST_STUCK_START_DRIVER is turned on.");
      while (true) {
try {
TimeUnit.MINUTES.sleep(10);
} catch (InterruptedException e) {
LOG.warn("Interrupted during test sleep.", e);
}
}
}
    // Bring up the RpcServer and register the secret provided by the Livy server as a client.
LOG.info("Starting RPC server...");
this.server = new RpcServer(livyConf);
server.registerClient(clientId, secret, new RpcServer.ClientCallback() {
@Override
public RpcDispatcher onNewClient(Rpc client) {
registerClient(client);
return RSCDriver.this;
}
@Override
public void onSaslComplete(Rpc client) {
onClientAuthenticated(client);
}
});
    // The RPC library takes care of timing this out.
Rpc callbackRpc = Rpc.createClient(livyConf, server.getEventLoopGroup(),
launcherAddress, launcherPort, clientId, secret, this).get();
try {
callbackRpc.call(new RemoteDriverAddress(server.getAddress(), server.getPort())).get(
livyConf.getTimeAsMs(RPC_CLIENT_HANDSHAKE_TIMEOUT), TimeUnit.MILLISECONDS);
} catch (TimeoutException te) {
LOG.warn("Timed out sending address to Livy server, shutting down.");
throw te;
} finally {
callbackRpc.close();
}
// At this point we install the idle timeout handler, in case the Livy server fails to connect
// back.
setupIdleTimeout();
}
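
  /**
   * Tracks a newly connected client and disarms the idle timeout; the timeout is
   * re-armed when the client's channel closes, unless the driver is already shutting down.
   */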
private void registerClient(final Rpc client) {
clients.add(client);
stopIdleTimeout();
Utils.addListener(client.getChannel().closeFuture(), new FutureListener<Void>() {
@Override
public void onSuccess(Void unused) {
clients.remove(client);
client.unRegisterRpc();
if (!inShutdown.get()) {
setupIdleTimeout();
}
}
});
LOG.debug("Registered new connection from {}.", client.getChannel());
}
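
  /**
   * Arms a timer that shuts the driver down after SERVER_IDLE_TIMEOUT if no client is
   * connected; racy registrations are resolved by the CAS and the re-check below.
   */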
private void setupIdleTimeout() {
    if (!clients.isEmpty()) {
return;
}
Runnable timeoutTask = new Runnable() {
@Override
public void run() {
LOG.warn("Shutting down RSC due to idle timeout ({}).", livyConf.get(SERVER_IDLE_TIMEOUT));
shutdown();
}
};
ScheduledFuture<?> timeout = server.getEventLoopGroup().schedule(timeoutTask,
livyConf.getTimeAsMs(SERVER_IDLE_TIMEOUT), TimeUnit.MILLISECONDS);
// If there's already an idle task registered, then cancel the new one.
if (!this.idleTimeout.compareAndSet(null, timeout)) {
LOG.debug("Timeout task already registered.");
timeout.cancel(false);
}
// If a new client connected while the idle task was being set up, then stop the task.
    if (!clients.isEmpty()) {
stopIdleTimeout();
}
  }

private void stopIdleTimeout() {
ScheduledFuture<?> idleTimeout = this.idleTimeout.getAndSet(null);
if (idleTimeout != null) {
LOG.debug("Cancelling idle timeout since new client connected.");
idleTimeout.cancel(false);
}
}
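
  /** Sends a message to every connected client, logging but not propagating failures. */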
protected void broadcast(Object msg) {
for (Rpc client : clients) {
try {
client.call(msg);
} catch (Exception e) {
LOG.warn("Failed to send message to client " + client, e);
}
}
}

  /**
   * Initializes the SparkContext used by this driver. This implementation creates a
   * context with the provided configuration. Subclasses can override this behavior,
   * and returning a null context is allowed; in that case, the context exposed by
   * JobContext will be null.
   *
   * @return The initialized SparkEntries wrapping the SparkContext.
   */
protected SparkEntries initializeSparkEntries() throws Exception {
SparkEntries entries = new SparkEntries(conf);
// Explicitly call sc() to initialize SparkContext.
entries.sc();
return entries;
}
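
  /**
   * Hook called after a client completes authentication; the default implementation
   * does nothing.
   */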
protected void onClientAuthenticated(final Rpc client) {
}

  /**
   * Called to shut down the driver; any initialization done by initializeSparkEntries()
   * should be undone here. This is guaranteed to be called only once.
   */
protected void shutdownContext() {
if (jc != null) {
jc.stop();
}
executor.shutdownNow();
try {
FileUtils.deleteDirectory(localTmpDir);
} catch (IOException e) {
LOG.warn("Failed to delete local tmp dir: " + localTmpDir, e);
}
}
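
  /** Flags the driver as shutting down, then closes the RPC server and all client connections. */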
private void shutdownServer() {
inShutdown.compareAndSet(false, true);
if (server != null) {
server.close();
}
    for (Rpc client : clients) {
client.close();
}
}
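
  /**
   * Main loop: installs a mutable class loader, starts the RPC server, creates the
   * JobContext, drains any queued jobs, and then blocks until shutdown is requested.
   */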
void run() throws Exception {
this.running = true;
// Set up a class loader that can be modified, so that we can add jars uploaded
// by the client to the driver's class path.
ClassLoader driverClassLoader = new MutableClassLoader(
Thread.currentThread().getContextClassLoader());
Thread.currentThread().setContextClassLoader(driverClassLoader);
try {
initializeServer();
SparkEntries entries = initializeSparkEntries();
synchronized (jcLock) {
jc = new JobContextImpl(entries, localTmpDir, this);
jcLock.notifyAll();
}
synchronized (jcLock) {
for (JobWrapper<?> job : jobQueue) {
submit(job);
}
jobQueue.clear();
}
synchronized (shutdownLock) {
try {
while (running) {
shutdownLock.wait();
}
} catch (InterruptedException ie) {
// Nothing to do.
}
}
} finally {
shutdown();
}
}
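
  /**
   * Submits the job for execution, or queues it if the JobContext has not been created
   * yet. Queued jobs are submitted by run() once the context is up.
   */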
public void submit(JobWrapper<?> job) {
if (jc != null) {
job.submit(executor);
return;
}
synchronized (jcLock) {
if (jc != null) {
job.submit(executor);
} else {
LOG.info("SparkContext not yet up, queueing job request.");
jobQueue.add(job);
}
}
}

  JobContextImpl jobContext() {
    return jc;
  }

  Serializer serializer() {
    return serializer;
  }

<T> void jobFinished(String jobId, T result, Throwable error) {
LOG.debug("Send job({}) result to Client.", jobId);
broadcast(new JobResult<T>(jobId, result, error));
}
void jobStarted(String jobId) {
broadcast(new JobStarted(jobId));
}
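
  // Handlers below are invoked by the RPC dispatcher based on the incoming message type.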
public void handle(ChannelHandlerContext ctx, CancelJob msg) {
JobWrapper<?> job = activeJobs.get(msg.id);
if (job == null || !job.cancel()) {
LOG.info("Requested to cancel an already finished job.");
}
  }

public void handle(ChannelHandlerContext ctx, EndSession msg) {
if (livyConf.getBoolean(TEST_STUCK_END_SESSION)) {
LOG.warn("Ignoring EndSession request because TEST_STUCK_END_SESSION is set.");
} else {
LOG.debug("Shutting down due to EndSession request.");
shutdown();
}
  }

public void handle(ChannelHandlerContext ctx, JobRequest<?> msg) {
LOG.info("Received job request {}", msg.id);
JobWrapper<?> wrapper = new JobWrapper<>(this, msg.id, msg.job);
activeJobs.put(msg.id, wrapper);
submit(wrapper);
  }

public void handle(ChannelHandlerContext ctx, BypassJobRequest msg) throws Exception {
LOG.info("Received bypass job request {}", msg.id);
BypassJobWrapper wrapper = createWrapper(msg);
bypassJobs.add(wrapper);
activeJobs.put(msg.id, wrapper);
if (msg.synchronous) {
waitForJobContext();
try {
wrapper.call();
} catch (Throwable t) {
// Wrapper already logged and saved the exception, just avoid it bubbling up
// to the RPC layer.
}
} else {
submit(wrapper);
}
  }

protected BypassJobWrapper createWrapper(BypassJobRequest msg) throws Exception {
return new BypassJobWrapper(this, msg.id, new BypassJob(this.serializer(), msg.serializedJob));
  }

@SuppressWarnings("unchecked")
public Object handle(ChannelHandlerContext ctx, SyncJobRequest msg) throws Exception {
waitForJobContext();
return msg.job.call(jc);
  }

public BypassJobStatus handle(ChannelHandlerContext ctx, GetBypassJobStatus msg) {
for (Iterator<BypassJobWrapper> it = bypassJobs.iterator(); it.hasNext();) {
BypassJobWrapper job = it.next();
if (job.jobId.equals(msg.id)) {
BypassJobStatus status = job.getStatus();
switch (status.state) {
case CANCELLED:
case FAILED:
case SUCCEEDED:
it.remove();
break;
default:
// No-op.
}
return status;
}
}
throw new NoSuchElementException(msg.id);
}
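
  /**
   * Blocks until the JobContext is available, failing if the driver starts shutting down
   * while waiting.
   */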
private void waitForJobContext() throws InterruptedException {
synchronized (jcLock) {
while (jc == null) {
jcLock.wait();
if (!running) {
throw new IllegalStateException("Remote context is shutting down.");
}
}
}
  }

protected void addFile(String path) {
jc.sc().addFile(path);
}
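
  /**
   * Downloads the file to the driver, adds it to the driver's class loader, and registers
   * it with the SparkContext so that executors can fetch it too.
   */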
protected void addJarOrPyFile(String path) throws Exception {
File localCopyDir = new File(jc.getLocalTmpDir(), "__livy__");
File localCopy = copyFileToLocal(localCopyDir, path, jc.sc().sc());
addLocalFileToClassLoader(localCopy);
jc.sc().addJar(path);
}
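
  /** Adds a local file to the driver's mutable context class loader. */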
public void addLocalFileToClassLoader(File localCopy) throws MalformedURLException {
MutableClassLoader cl = (MutableClassLoader) Thread.currentThread().getContextClassLoader();
cl.addURL(localCopy.toURI().toURL());
}
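
  /**
   * Copies a file, possibly on a remote file system, into localCopyDir using the Hadoop
   * FileSystem API. If the URI has a fragment, it is used as the local file name; copying
   * a name that already exists fails.
   */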
public File copyFileToLocal(
File localCopyDir,
String filePath,
SparkContext sc) throws Exception {
synchronized (jc) {
if (!localCopyDir.isDirectory() && !localCopyDir.mkdir()) {
throw new IOException("Failed to create directory to add pyFile");
}
}
URI uri = new URI(filePath);
String name = uri.getFragment() != null ? uri.getFragment() : uri.getPath();
name = new File(name).getName();
File localCopy = new File(localCopyDir, name);
if (localCopy.exists()) {
throw new IOException(String.format("A file with name %s has " +
"already been uploaded.", name));
}
Configuration conf = sc.hadoopConfiguration();
FileSystem fs = FileSystem.get(uri, conf);
fs.copyToLocalFile(new Path(uri), new Path(localCopy.toURI()));
return localCopy;
}
}