blob: 5fce164107c2dde1ebd9be7e8c1ece535e2aefeb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.rsc.rpc;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.livy.rsc.RSCConf;
import org.apache.livy.rsc.Utils;
import static org.apache.livy.rsc.RSCConf.Entry.*;
/**
* Encapsulates the RPC functionality. Provides higher-level methods to talk to the remote
* endpoint.
*/
public class Rpc implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(Rpc.class);
static final String SASL_REALM = "rsc";
static final String SASL_USER = "rsc";
static final String SASL_PROTOCOL = "rsc";
static final String SASL_AUTH_CONF = "auth-conf";
/**
* Creates an RPC client for a server running on the given remote host and port.
*
* @param config RPC configuration data.
* @param eloop Event loop for managing the connection.
* @param host Host name or IP address to connect to.
* @param port Port where server is listening.
* @param clientId The client ID that identifies the connection.
* @param secret Secret for authenticating the client with the server.
* @param dispatcher Dispatcher used to handle RPC calls.
* @return A future that can be used to monitor the creation of the RPC object.
*/
public static Promise<Rpc> createClient(
final RSCConf config,
final EventLoopGroup eloop,
String host,
int port,
final String clientId,
final String secret,
final RpcDispatcher dispatcher) throws Exception {
int connectTimeoutMs = (int) config.getTimeAsMs(RPC_CLIENT_CONNECT_TIMEOUT);
final ChannelFuture cf = new Bootstrap()
.group(eloop)
.handler(new ChannelInboundHandlerAdapter() { })
.channel(NioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMs)
.connect(host, port);
final Promise<Rpc> promise = eloop.next().newPromise();
final AtomicReference<Rpc> rpc = new AtomicReference<Rpc>();
// Set up a timeout to undo everything.
final Runnable timeoutTask = new Runnable() {
@Override
public void run() {
promise.setFailure(new TimeoutException("Timed out waiting for RPC server connection."));
}
};
final ScheduledFuture<?> timeoutFuture = eloop.schedule(timeoutTask,
config.getTimeAsMs(RPC_CLIENT_HANDSHAKE_TIMEOUT), TimeUnit.MILLISECONDS);
// The channel listener instantiates the Rpc instance when the connection is established,
// and initiates the SASL handshake.
cf.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture cf) throws Exception {
if (cf.isSuccess()) {
SaslClientHandler saslHandler = new SaslClientHandler(config, clientId, promise,
timeoutFuture, secret, dispatcher);
Rpc rpc = createRpc(config, saslHandler, (SocketChannel) cf.channel(), eloop);
saslHandler.rpc = rpc;
saslHandler.sendHello(cf.channel());
} else {
promise.setFailure(cf.cause());
}
}
});
// Handle cancellation of the promise.
promise.addListener(new GenericFutureListener<Promise<Rpc>>() {
@Override
public void operationComplete(Promise<Rpc> p) {
if (p.isCancelled()) {
cf.cancel(true);
}
}
});
return promise;
}
static Rpc createServer(SaslHandler saslHandler, RSCConf config, SocketChannel channel,
EventExecutorGroup egroup) throws IOException {
return createRpc(config, saslHandler, channel, egroup);
}
private static Rpc createRpc(RSCConf config,
SaslHandler saslHandler,
SocketChannel client,
EventExecutorGroup egroup)
throws IOException {
LogLevel logLevel = LogLevel.TRACE;
String logLevelStr = config.get(RPC_CHANNEL_LOG_LEVEL);
if (logLevelStr != null) {
try {
logLevel = LogLevel.valueOf(logLevelStr);
} catch (Exception e) {
LOG.warn("Invalid log level {}, reverting to default.", logLevelStr);
}
}
boolean logEnabled = false;
switch (logLevel) {
case DEBUG:
logEnabled = LOG.isDebugEnabled();
break;
case ERROR:
logEnabled = LOG.isErrorEnabled();
break;
case INFO:
logEnabled = LOG.isInfoEnabled();
break;
case TRACE:
logEnabled = LOG.isTraceEnabled();
break;
case WARN:
logEnabled = LOG.isWarnEnabled();
break;
}
if (logEnabled) {
client.pipeline().addLast("logger", new LoggingHandler(Rpc.class, logLevel));
}
KryoMessageCodec kryo = new KryoMessageCodec(config.getInt(RPC_MAX_MESSAGE_SIZE),
MessageHeader.class, NullMessage.class, SaslMessage.class);
saslHandler.setKryoMessageCodec(kryo);
client.pipeline()
.addLast("codec", kryo)
.addLast("sasl", saslHandler);
return new Rpc(config, client, egroup);
}
static Rpc createEmbedded(RpcDispatcher dispatcher) {
EmbeddedChannel c = new EmbeddedChannel(
new LoggingHandler(Rpc.class),
new KryoMessageCodec(0, MessageHeader.class, NullMessage.class),
dispatcher);
Rpc rpc = new Rpc(new RSCConf(null), c, ImmediateEventExecutor.INSTANCE);
rpc.dispatcher = dispatcher;
dispatcher.registerRpc(c, rpc);
return rpc;
}
private final RSCConf config;
private final AtomicBoolean rpcClosed;
private final AtomicLong rpcId;
private final Channel channel;
private final EventExecutorGroup egroup;
private volatile RpcDispatcher dispatcher;
private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
private final Collection<OutstandingRpc> rpcCalls = new ConcurrentLinkedQueue<OutstandingRpc>();
private volatile Rpc.MessageHeader lastHeader;
private Rpc(RSCConf config, Channel channel, EventExecutorGroup egroup) {
Utils.checkArgument(channel != null);
Utils.checkArgument(egroup != null);
this.config = config;
this.channel = channel;
this.dispatcher = null;
this.egroup = egroup;
this.rpcClosed = new AtomicBoolean();
this.rpcId = new AtomicLong();
// Note: this does not work for embedded channels.
channel.pipeline().addLast("monitor", new ChannelInboundHandlerAdapter() {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
close();
super.channelInactive(ctx);
}
});
}
/**
* For debugging purposes.
* @return The name of this Class.
*/
protected String name() {
return getClass().getSimpleName();
}
public void handleMsg(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj)
throws Exception {
if (lastHeader == null) {
if (!(msg instanceof MessageHeader)) {
LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
msg != null ? msg.getClass().getName() : null);
throw new IllegalArgumentException();
}
lastHeader = (MessageHeader) msg;
} else {
LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null);
try {
switch (lastHeader.type) {
case CALL:
handleCall(ctx, msg, handleClass, obj);
break;
case REPLY:
handleReply(ctx, msg, findRpcCall(lastHeader.id));
break;
case ERROR:
handleError(ctx, msg, findRpcCall(lastHeader.id));
break;
default:
throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type);
}
} finally {
lastHeader = null;
}
}
}
private void handleCall(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj)
throws Exception {
Method handler = handlers.get(msg.getClass());
if (handler == null) {
// Try both getDeclaredMethod() and getMethod() so that we try both private methods
// of the class, and public methods of parent classes.
try {
handler = handleClass.getDeclaredMethod("handle", ChannelHandlerContext.class,
msg.getClass());
} catch (NoSuchMethodException e) {
try {
handler = handleClass.getMethod("handle", ChannelHandlerContext.class,
msg.getClass());
} catch (NoSuchMethodException e2) {
LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(),
msg.getClass().getName()));
writeMessage(MessageType.ERROR, Utils.stackTraceAsString(e.getCause()));
return;
}
}
handler.setAccessible(true);
handlers.put(msg.getClass(), handler);
}
try {
Object payload = handler.invoke(obj, ctx, msg);
if (payload == null) {
payload = new NullMessage();
}
writeMessage(MessageType.REPLY, payload);
} catch (InvocationTargetException ite) {
LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause());
writeMessage(MessageType.ERROR, Utils.stackTraceAsString(ite.getCause()));
}
}
private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
rpc.future.setSuccess(msg instanceof NullMessage ? null : msg);
}
private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
if (msg instanceof String) {
LOG.warn("Received error message:{}.", msg);
rpc.future.setFailure(new RpcException((String) msg));
} else {
String error = String.format("Received error with unexpected payload (%s).",
msg != null ? msg.getClass().getName() : null);
LOG.warn(String.format("[%s] %s", name(), error));
rpc.future.setFailure(new IllegalArgumentException(error));
ctx.close();
}
}
private void writeMessage(MessageType replyType, Object payload) {
channel.write(new MessageHeader(lastHeader.id, replyType));
channel.writeAndFlush(payload);
}
private OutstandingRpc findRpcCall(long id) {
for (Iterator<OutstandingRpc> it = rpcCalls.iterator(); it.hasNext();) {
OutstandingRpc rpc = it.next();
if (rpc.id == id) {
it.remove();
return rpc;
}
}
throw new IllegalArgumentException(String.format(
"Received RPC reply for unknown RPC (%d).", id));
}
private void registerRpcCall(long id, Promise<?> promise, String type) {
LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
rpcCalls.add(new OutstandingRpc(id, promise));
}
private void discardRpcCall(long id) {
LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
findRpcCall(id);
}
private static class OutstandingRpc {
final long id;
final Promise<Object> future;
@SuppressWarnings("unchecked")
OutstandingRpc(long id, Promise<?> future) {
this.id = id;
this.future = (Promise<Object>) future;
}
}
public void handleChannelException(ChannelHandlerContext ctx, Throwable cause) {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
} else {
LOG.info(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
}
if (lastHeader != null) {
// There's an RPC waiting for a reply. Exception was most probably caught while processing
// the RPC, so send an error.
channel.write(new MessageHeader(lastHeader.id, MessageType.ERROR));
channel.writeAndFlush(Utils.stackTraceAsString(cause));
lastHeader = null;
}
ctx.close();
}
public void handleChannelInactive() {
if (rpcCalls.size() > 0) {
LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcCalls.size());
for (OutstandingRpc rpc : rpcCalls) {
rpc.future.cancel(true);
}
} else {
LOG.debug("Channel {} became inactive.", channel);
}
}
/**
* Send an RPC call to the remote endpoint and returns a future that can be used to monitor the
* operation.
*
* @param msg RPC call to send.
* @return A future used to monitor the operation.
*/
public Future<Void> call(Object msg) {
return call(msg, Void.class);
}
/**
* Send an RPC call to the remote endpoint and returns a future that can be used to monitor the
* operation.
*
* @param msg RPC call to send.
* @param retType Type of expected reply.
* @return A future used to monitor the operation.
*/
public <T> Future<T> call(final Object msg, Class<T> retType) {
Utils.checkArgument(msg != null);
Utils.checkState(channel.isOpen(), "RPC channel is closed.");
try {
final long id = rpcId.getAndIncrement();
final Promise<T> promise = egroup.next().newPromise();
final ChannelFutureListener listener = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture cf) {
if (!cf.isSuccess() && !promise.isDone()) {
LOG.warn("Failed to send RPC, closing connection.", cf.cause());
promise.setFailure(cf.cause());
discardRpcCall(id);
close();
}
}
};
registerRpcCall(id, promise, msg.getClass().getName());
channel.eventLoop().submit(new Runnable() {
@Override
public void run() {
channel.write(new MessageHeader(id, Rpc.MessageType.CALL)).addListener(listener);
channel.writeAndFlush(msg).addListener(listener);
}
});
return promise;
} catch (Exception e) {
throw Utils.propagate(e);
}
}
public Channel getChannel() {
return channel;
}
public void unRegisterRpc() {
if (dispatcher != null) {
dispatcher.unregisterRpc(channel);
}
}
void setDispatcher(RpcDispatcher dispatcher) {
Utils.checkNotNull(dispatcher);
Utils.checkState(this.dispatcher == null, "Dispatcher already set.");
this.dispatcher = dispatcher;
channel.pipeline().addLast("dispatcher", dispatcher);
dispatcher.registerRpc(channel, this);
}
@Override
public void close() {
if (!rpcClosed.compareAndSet(false, true)) {
return;
}
try {
channel.close().sync();
} catch (InterruptedException ie) {
Thread.interrupted();
}
}
static enum MessageType {
CALL,
REPLY,
ERROR;
}
static class MessageHeader {
final long id;
final MessageType type;
MessageHeader() {
this(-1, null);
}
MessageHeader(long id, MessageType type) {
this.id = id;
this.type = type;
}
}
static class NullMessage {
}
static class SaslMessage {
final String clientId;
final byte[] payload;
SaslMessage() {
this(null, null);
}
SaslMessage(byte[] payload) {
this(null, payload);
}
SaslMessage(String clientId, byte[] payload) {
this.clientId = clientId;
this.payload = payload;
}
}
private static class SaslClientHandler extends SaslHandler implements CallbackHandler {
private final SaslClient client;
private final String clientId;
private final String secret;
private final RpcDispatcher dispatcher;
private Promise<Rpc> promise;
private ScheduledFuture<?> timeout;
// Can't be set in constructor due to circular dependency.
private Rpc rpc;
SaslClientHandler(
RSCConf config,
String clientId,
Promise<Rpc> promise,
ScheduledFuture<?> timeout,
String secret,
RpcDispatcher dispatcher)
throws IOException {
super(config);
this.clientId = clientId;
this.promise = promise;
this.timeout = timeout;
this.secret = secret;
this.dispatcher = dispatcher;
this.client = Sasl.createSaslClient(new String[] { config.get(SASL_MECHANISMS) },
null, SASL_PROTOCOL, SASL_REALM, config.getSaslOptions(), this);
}
@Override
protected boolean isComplete() {
return client.isComplete();
}
@Override
protected String getNegotiatedProperty(String name) {
return (String) client.getNegotiatedProperty(name);
}
@Override
protected SaslMessage update(SaslMessage challenge) throws IOException {
byte[] response = client.evaluateChallenge(challenge.payload);
return response != null ? new SaslMessage(response) : null;
}
@Override
public byte[] wrap(byte[] data, int offset, int len) throws IOException {
return client.wrap(data, offset, len);
}
@Override
public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
return client.unwrap(data, offset, len);
}
@Override
public void dispose() throws IOException {
if (!client.isComplete()) {
onError(new SaslException("Client closed before SASL negotiation finished."));
}
client.dispose();
}
@Override
protected void onComplete() throws Exception {
timeout.cancel(true);
rpc.setDispatcher(dispatcher);
promise.setSuccess(rpc);
timeout = null;
promise = null;
}
@Override
protected void onError(Throwable error) {
timeout.cancel(true);
if (!promise.isDone()) {
promise.setFailure(error);
}
}
@Override
public void handle(Callback[] callbacks) {
for (Callback cb : callbacks) {
if (cb instanceof NameCallback) {
((NameCallback)cb).setName(clientId);
} else if (cb instanceof PasswordCallback) {
((PasswordCallback)cb).setPassword(secret.toCharArray());
} else if (cb instanceof RealmCallback) {
RealmCallback rb = (RealmCallback) cb;
rb.setText(rb.getDefaultText());
}
}
}
void sendHello(Channel c) throws Exception {
byte[] hello = client.hasInitialResponse() ?
client.evaluateChallenge(new byte[0]) : new byte[0];
c.writeAndFlush(new SaslMessage(clientId, hello)).sync();
}
}
}