blob: 0c149b01e9c33380e751fe8eb2186a5df5227928 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.rsc.rpc;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.concurrent.Promise;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.livy.rsc.Utils;
/**
* An implementation of ChannelInboundHandler that dispatches incoming messages to an instance
* method based on the method signature.
* <p>
* A handler's signature must be of the form:
* <p>
* <blockquote><tt>protected void handle(ChannelHandlerContext, MessageType)</tt></blockquote>
* <p>
* Where "MessageType" must match exactly the type of the message to handle. Polymorphism is not
* supported. Handlers can return a value, which becomes the RPC reply; if a null is returned, then
* a reply is still sent, with an empty payload.
*/
public abstract class RpcDispatcher extends SimpleChannelInboundHandler<Object> {
private static final Logger LOG = LoggerFactory.getLogger(RpcDispatcher.class);
private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
private final Collection<OutstandingRpc> rpcs = new ConcurrentLinkedQueue<OutstandingRpc>();
private volatile Rpc.MessageHeader lastHeader;
/**
* Override this to add a name to the dispatcher, for debugging purposes.
* @return The name of this dispatcher.
*/
protected String name() {
return getClass().getSimpleName();
}
@Override
protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
if (lastHeader == null) {
if (!(msg instanceof Rpc.MessageHeader)) {
LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
msg != null ? msg.getClass().getName() : null);
throw new IllegalArgumentException();
}
lastHeader = (Rpc.MessageHeader) msg;
} else {
LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null);
try {
switch (lastHeader.type) {
case CALL:
handleCall(ctx, msg);
break;
case REPLY:
handleReply(ctx, msg, findRpc(lastHeader.id));
break;
case ERROR:
handleError(ctx, msg, findRpc(lastHeader.id));
break;
default:
throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type);
}
} finally {
lastHeader = null;
}
}
}
private OutstandingRpc findRpc(long id) {
for (Iterator<OutstandingRpc> it = rpcs.iterator(); it.hasNext();) {
OutstandingRpc rpc = it.next();
if (rpc.id == id) {
it.remove();
return rpc;
}
}
throw new IllegalArgumentException(String.format(
"Received RPC reply for unknown RPC (%d).", id));
}
private void handleCall(ChannelHandlerContext ctx, Object msg) throws Exception {
Method handler = handlers.get(msg.getClass());
if (handler == null) {
// Try both getDeclaredMethod() and getMethod() so that we try both private methods
// of the class, and public methods of parent classes.
try {
handler = getClass().getDeclaredMethod("handle", ChannelHandlerContext.class,
msg.getClass());
} catch (NoSuchMethodException e) {
try {
handler = getClass().getMethod("handle", ChannelHandlerContext.class,
msg.getClass());
} catch (NoSuchMethodException e2) {
LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(),
msg.getClass().getName()));
writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(e.getCause()));
return;
}
}
handler.setAccessible(true);
handlers.put(msg.getClass(), handler);
}
try {
Object payload = handler.invoke(this, ctx, msg);
if (payload == null) {
payload = new Rpc.NullMessage();
}
writeMessage(ctx, Rpc.MessageType.REPLY, payload);
} catch (InvocationTargetException ite) {
LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause());
writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(ite.getCause()));
}
}
private void writeMessage(ChannelHandlerContext ctx, Rpc.MessageType replyType, Object payload) {
ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, replyType));
ctx.channel().writeAndFlush(payload);
}
private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
throws Exception {
rpc.future.setSuccess(msg instanceof Rpc.NullMessage ? null : msg);
}
private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
throws Exception {
if (msg instanceof String) {
LOG.warn("Received error message:{}.", msg);
rpc.future.setFailure(new RpcException((String) msg));
} else {
String error = String.format("Received error with unexpected payload (%s).",
msg != null ? msg.getClass().getName() : null);
LOG.warn(String.format("[%s] %s", name(), error));
rpc.future.setFailure(new IllegalArgumentException(error));
ctx.close();
}
}
@Override
public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
} else {
LOG.info("[{}] Closing channel due to exception in pipeline ({}).", name(),
cause.getMessage());
}
if (lastHeader != null) {
// There's an RPC waiting for a reply. Exception was most probably caught while processing
// the RPC, so send an error.
ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, Rpc.MessageType.ERROR));
ctx.channel().writeAndFlush(Utils.stackTraceAsString(cause));
lastHeader = null;
}
ctx.close();
}
@Override
public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (rpcs.size() > 0) {
LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcs.size());
for (OutstandingRpc rpc : rpcs) {
rpc.future.cancel(true);
}
} else {
LOG.debug("Channel {} became inactive.", ctx.channel());
}
super.channelInactive(ctx);
}
void registerRpc(long id, Promise<?> promise, String type) {
LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
rpcs.add(new OutstandingRpc(id, promise));
}
void discardRpc(long id) {
LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
findRpc(id);
}
private static class OutstandingRpc {
final long id;
final Promise<Object> future;
@SuppressWarnings("unchecked")
OutstandingRpc(long id, Promise<?> future) {
this.id = id;
this.future = (Promise<Object>) future;
}
}
}