[LIVY-735][RSC] Fix rpc channel closed when multi clients connect to one driver

## What changes were proposed in this pull request?

Currently, the driver tries to support communicating with multi-clients, by registering each client at https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java#L220.

But actually, if multi-clients connect to one driver, the rpc channel will close, the reason are as follows.

1.  In every communication, client sends two packages to driver: header{type, id}, and payload at https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L144.

2. If client1 sends header1, payload1, and client2 sends header2, payload2 at the same time.
  The driver receives the package in the order: header1, header2, payload1, payload2.

3. When driver receives header1, driver assigns lastHeader at https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L73.

4. Then driver receives header2, driver process it as a payload at https://github.com/apache/incubator-livy/blob/master/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java#L78 which cause exception and rpc channel closed.

In the muti-active HA mode, the design doc is at: https://docs.google.com/document/d/1bD3qYZpw14_NuCcSGUOfqQ0pqvSbCQsOLFuZp26Ohjc/edit?usp=sharing, the session is allocated among servers by consistent hashing. If a new livy joins, some session will be migrated from old livy to new livy. If the session client in new livy connect to driver before stoping session client in old livy, then two session clients will both connect to driver, and rpc channel close.  In this case, it's hard to ensure only one client connect to one driver at any time. So it's better to support multi-clients connect to one driver, which has no side effects.

How to fix:
1. Move the code of processing client message from `RpcDispatcher` to each `Rpc`.
2. Each `Rpc` registers itself to `channelRpc` in RpcDispatcher.
3. `RpcDispatcher` dispatches each message to `Rpc` according to  `ctx.channel()`.

## How was this patch tested?

Existed UT and IT

Author: runzhiwang <runzhiwang@tencent.com>

Closes #268 from runzhiwang/multi-client-one-driver.
diff --git a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
index 0d8eec5..a8f31f7 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/driver/RSCDriver.java
@@ -224,6 +224,7 @@
       @Override
       public void onSuccess(Void unused) {
         clients.remove(client);
+        client.unRegisterRpc();
         if (!inShutdown.get()) {
           setupIdleTimeout();
         }
diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
index 868dc6d..5fce164 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/Rpc.java
@@ -19,10 +19,11 @@
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedList;
-import java.util.Map;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -208,6 +209,7 @@
         dispatcher);
     Rpc rpc = new Rpc(new RSCConf(null), c, ImmediateEventExecutor.INSTANCE);
     rpc.dispatcher = dispatcher;
+    dispatcher.registerRpc(c, rpc);
     return rpc;
   }
 
@@ -218,6 +220,10 @@
   private final EventExecutorGroup egroup;
   private volatile RpcDispatcher dispatcher;
 
+  private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
+  private final Collection<OutstandingRpc> rpcCalls = new ConcurrentLinkedQueue<OutstandingRpc>();
+  private volatile Rpc.MessageHeader lastHeader;
+
   private Rpc(RSCConf config, Channel channel, EventExecutorGroup egroup) {
     Utils.checkArgument(channel != null);
     Utils.checkArgument(egroup != null);
@@ -239,6 +245,166 @@
   }
 
   /**
+   * For debugging purposes.
+   * @return The name of this Class.
+   */
+  protected String name() {
+    return getClass().getSimpleName();
+  }
+
+  public void handleMsg(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj)
+      throws Exception {
+    if (lastHeader == null) {
+      if (!(msg instanceof MessageHeader)) {
+        LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
+          msg != null ? msg.getClass().getName() : null);
+        throw new IllegalArgumentException();
+      }
+      lastHeader = (MessageHeader) msg;
+    } else {
+      LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
+        lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null);
+      try {
+        switch (lastHeader.type) {
+          case CALL:
+            handleCall(ctx, msg, handleClass, obj);
+            break;
+          case REPLY:
+            handleReply(ctx, msg, findRpcCall(lastHeader.id));
+            break;
+          case ERROR:
+            handleError(ctx, msg, findRpcCall(lastHeader.id));
+            break;
+          default:
+            throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type);
+        }
+      } finally {
+        lastHeader = null;
+      }
+    }
+  }
+
+  private void handleCall(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj)
+      throws Exception {
+    Method handler = handlers.get(msg.getClass());
+    if (handler == null) {
+      // Try both getDeclaredMethod() and getMethod() so that we try both private methods
+      // of the class, and public methods of parent classes.
+      try {
+        handler = handleClass.getDeclaredMethod("handle", ChannelHandlerContext.class,
+            msg.getClass());
+      } catch (NoSuchMethodException e) {
+        try {
+          handler = handleClass.getMethod("handle", ChannelHandlerContext.class,
+              msg.getClass());
+        } catch (NoSuchMethodException e2) {
+          LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(),
+            msg.getClass().getName()));
+          writeMessage(MessageType.ERROR, Utils.stackTraceAsString(e.getCause()));
+          return;
+        }
+      }
+      handler.setAccessible(true);
+      handlers.put(msg.getClass(), handler);
+    }
+
+    try {
+      Object payload = handler.invoke(obj, ctx, msg);
+      if (payload == null) {
+        payload = new NullMessage();
+      }
+      writeMessage(MessageType.REPLY, payload);
+    } catch (InvocationTargetException ite) {
+      LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause());
+      writeMessage(MessageType.ERROR, Utils.stackTraceAsString(ite.getCause()));
+    }
+  }
+
+  private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
+    rpc.future.setSuccess(msg instanceof NullMessage ? null : msg);
+  }
+
+  private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
+    if (msg instanceof String) {
+      LOG.warn("Received error message:{}.", msg);
+      rpc.future.setFailure(new RpcException((String) msg));
+    } else {
+      String error = String.format("Received error with unexpected payload (%s).",
+          msg != null ? msg.getClass().getName() : null);
+      LOG.warn(String.format("[%s] %s", name(), error));
+      rpc.future.setFailure(new IllegalArgumentException(error));
+      ctx.close();
+    }
+  }
+
+  private void writeMessage(MessageType replyType, Object payload) {
+    channel.write(new MessageHeader(lastHeader.id, replyType));
+    channel.writeAndFlush(payload);
+  }
+
+  private OutstandingRpc findRpcCall(long id) {
+    for (Iterator<OutstandingRpc> it = rpcCalls.iterator(); it.hasNext();) {
+      OutstandingRpc rpc = it.next();
+      if (rpc.id == id) {
+        it.remove();
+        return rpc;
+      }
+    }
+    throw new IllegalArgumentException(String.format(
+        "Received RPC reply for unknown RPC (%d).", id));
+  }
+
+  private void registerRpcCall(long id, Promise<?> promise, String type) {
+    LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
+    rpcCalls.add(new OutstandingRpc(id, promise));
+  }
+
+  private void discardRpcCall(long id) {
+    LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
+    findRpcCall(id);
+  }
+
+  private static class OutstandingRpc {
+    final long id;
+    final Promise<Object> future;
+
+    @SuppressWarnings("unchecked")
+    OutstandingRpc(long id, Promise<?> future) {
+      this.id = id;
+      this.future = (Promise<Object>) future;
+    }
+  }
+
+  public void handleChannelException(ChannelHandlerContext ctx, Throwable cause) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
+    } else {
+      LOG.info(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
+    }
+
+    if (lastHeader != null) {
+      // There's an RPC waiting for a reply. Exception was most probably caught while processing
+      // the RPC, so send an error.
+      channel.write(new MessageHeader(lastHeader.id, MessageType.ERROR));
+      channel.writeAndFlush(Utils.stackTraceAsString(cause));
+      lastHeader = null;
+    }
+
+    ctx.close();
+  }
+
+  public void handleChannelInactive() {
+    if (rpcCalls.size() > 0) {
+      LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcCalls.size());
+      for (OutstandingRpc rpc : rpcCalls) {
+        rpc.future.cancel(true);
+      }
+    } else {
+      LOG.debug("Channel {} became inactive.", channel);
+    }
+  }
+
+  /**
    * Send an RPC call to the remote endpoint and returns a future that can be used to monitor the
    * operation.
    *
@@ -269,13 +435,13 @@
             if (!cf.isSuccess() && !promise.isDone()) {
               LOG.warn("Failed to send RPC, closing connection.", cf.cause());
               promise.setFailure(cf.cause());
-              dispatcher.discardRpc(id);
+              discardRpcCall(id);
               close();
             }
           }
       };
 
-      dispatcher.registerRpc(id, promise, msg.getClass().getName());
+      registerRpcCall(id, promise, msg.getClass().getName());
       channel.eventLoop().submit(new Runnable() {
         @Override
         public void run() {
@@ -294,11 +460,18 @@
     return channel;
   }
 
+  public void unRegisterRpc() {
+    if (dispatcher != null) {
+      dispatcher.unregisterRpc(channel);
+    }
+  }
+
   void setDispatcher(RpcDispatcher dispatcher) {
     Utils.checkNotNull(dispatcher);
     Utils.checkState(this.dispatcher == null, "Dispatcher already set.");
     this.dispatcher = dispatcher;
     channel.pipeline().addLast("dispatcher", dispatcher);
+    dispatcher.registerRpc(channel, this);
   }
 
   @Override
diff --git a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
index 0c149b0..88744c2 100644
--- a/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
+++ b/rsc/src/main/java/org/apache/livy/rsc/rpc/RpcDispatcher.java
@@ -17,22 +17,15 @@
 
 package org.apache.livy.rsc.rpc;
 
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Collection;
-import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedQueue;
 
+import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
-import io.netty.util.concurrent.Promise;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.livy.rsc.Utils;
-
 /**
  * An implementation of ChannelInboundHandler that dispatches incoming messages to an instance
  * method based on the method signature.
@@ -49,10 +42,7 @@
 
   private static final Logger LOG = LoggerFactory.getLogger(RpcDispatcher.class);
 
-  private final Map<Class<?>, Method> handlers = new ConcurrentHashMap<>();
-  private final Collection<OutstandingRpc> rpcs = new ConcurrentLinkedQueue<OutstandingRpc>();
-
-  private volatile Rpc.MessageHeader lastHeader;
+  private final Map<Channel, Rpc> channelRpc = new ConcurrentHashMap<>();
 
   /**
    * Override this to add a name to the dispatcher, for debugging purposes.
@@ -62,161 +52,36 @@
     return getClass().getSimpleName();
   }
 
+  public void registerRpc(Channel channel, Rpc rpc) {
+    channelRpc.put(channel, rpc);
+  }
+
+  public void unregisterRpc(Channel channel) {
+    channelRpc.remove(channel);
+  }
+
+  private Rpc getRpc(ChannelHandlerContext ctx) {
+    Channel channel = ctx.channel();
+    if (!channelRpc.containsKey(channel)) {
+      throw new IllegalArgumentException("not existed channel:" + channel);
+    }
+
+    return channelRpc.get(channel);
+  }
+
   @Override
   protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
-    if (lastHeader == null) {
-      if (!(msg instanceof Rpc.MessageHeader)) {
-        LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
-            msg != null ? msg.getClass().getName() : null);
-        throw new IllegalArgumentException();
-      }
-      lastHeader = (Rpc.MessageHeader) msg;
-    } else {
-      LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
-        lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null);
-      try {
-        switch (lastHeader.type) {
-        case CALL:
-          handleCall(ctx, msg);
-          break;
-        case REPLY:
-          handleReply(ctx, msg, findRpc(lastHeader.id));
-          break;
-        case ERROR:
-          handleError(ctx, msg, findRpc(lastHeader.id));
-          break;
-        default:
-          throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type);
-        }
-      } finally {
-        lastHeader = null;
-      }
-    }
-  }
-
-  private OutstandingRpc findRpc(long id) {
-    for (Iterator<OutstandingRpc> it = rpcs.iterator(); it.hasNext();) {
-      OutstandingRpc rpc = it.next();
-      if (rpc.id == id) {
-        it.remove();
-        return rpc;
-      }
-    }
-    throw new IllegalArgumentException(String.format(
-        "Received RPC reply for unknown RPC (%d).", id));
-  }
-
-  private void handleCall(ChannelHandlerContext ctx, Object msg) throws Exception {
-    Method handler = handlers.get(msg.getClass());
-    if (handler == null) {
-      // Try both getDeclaredMethod() and getMethod() so that we try both private methods
-      // of the class, and public methods of parent classes.
-      try {
-        handler = getClass().getDeclaredMethod("handle", ChannelHandlerContext.class,
-            msg.getClass());
-      } catch (NoSuchMethodException e) {
-        try {
-          handler = getClass().getMethod("handle", ChannelHandlerContext.class,
-              msg.getClass());
-        } catch (NoSuchMethodException e2) {
-          LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", name(),
-            msg.getClass().getName()));
-          writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(e.getCause()));
-          return;
-        }
-      }
-      handler.setAccessible(true);
-      handlers.put(msg.getClass(), handler);
-    }
-
-    try {
-      Object payload = handler.invoke(this, ctx, msg);
-      if (payload == null) {
-        payload = new Rpc.NullMessage();
-      }
-      writeMessage(ctx, Rpc.MessageType.REPLY, payload);
-    } catch (InvocationTargetException ite) {
-      LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause());
-      writeMessage(ctx, Rpc.MessageType.ERROR, Utils.stackTraceAsString(ite.getCause()));
-    }
-  }
-
-  private void writeMessage(ChannelHandlerContext ctx, Rpc.MessageType replyType, Object payload) {
-    ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, replyType));
-    ctx.channel().writeAndFlush(payload);
-  }
-
-  private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
-      throws Exception {
-    rpc.future.setSuccess(msg instanceof Rpc.NullMessage ? null : msg);
-  }
-
-  private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
-      throws Exception {
-    if (msg instanceof String) {
-      LOG.warn("Received error message:{}.", msg);
-      rpc.future.setFailure(new RpcException((String) msg));
-    } else {
-      String error = String.format("Received error with unexpected payload (%s).",
-          msg != null ? msg.getClass().getName() : null);
-      LOG.warn(String.format("[%s] %s", name(), error));
-      rpc.future.setFailure(new IllegalArgumentException(error));
-      ctx.close();
-    }
+    getRpc(ctx).handleMsg(ctx, msg, getClass(), this);
   }
 
   @Override
   public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
-    if (LOG.isDebugEnabled()) {
-      LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
-    } else {
-      LOG.info("[{}] Closing channel due to exception in pipeline ({}).", name(),
-          cause.getMessage());
-    }
-
-    if (lastHeader != null) {
-      // There's an RPC waiting for a reply. Exception was most probably caught while processing
-      // the RPC, so send an error.
-      ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, Rpc.MessageType.ERROR));
-      ctx.channel().writeAndFlush(Utils.stackTraceAsString(cause));
-      lastHeader = null;
-    }
-
-    ctx.close();
+    getRpc(ctx).handleChannelException(ctx, cause);
   }
 
   @Override
   public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
-    if (rpcs.size() > 0) {
-      LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcs.size());
-      for (OutstandingRpc rpc : rpcs) {
-        rpc.future.cancel(true);
-      }
-    } else {
-      LOG.debug("Channel {} became inactive.", ctx.channel());
-    }
+    getRpc(ctx).handleChannelInactive();
     super.channelInactive(ctx);
   }
-
-  void registerRpc(long id, Promise<?> promise, String type) {
-    LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
-    rpcs.add(new OutstandingRpc(id, promise));
-  }
-
-  void discardRpc(long id) {
-    LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
-    findRpc(id);
-  }
-
-  private static class OutstandingRpc {
-    final long id;
-    final Promise<Object> future;
-
-    @SuppressWarnings("unchecked")
-    OutstandingRpc(long id, Promise<?> future) {
-      this.id = id;
-      this.future = (Promise<Object>) future;
-    }
-  }
-
 }