[LIVY-998][THRIFT] Support connecting to existing sessions using session name via Thrift Server

Co-authored-by: Asif Khatri <asif.khatri@cloudera.com>
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
index cd8d2f5..03e10c3 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
@@ -155,12 +155,21 @@
   }
 
   /**
-   * If the user specified an existing sessionId to use, the corresponding session is returned,
-   * otherwise a new session is created and returned.
+   * If the user specified an existing sessionId or session name to use, the corresponding session
+   * is returned, otherwise a new session is created and returned.
+   *
+   * The sessionId takes precedence: the session name is only consulted when no sessionId is
+   * given. A name that matches no existing session leads to the creation of a new session.
+   *
+   * Visibility is limited to the thriftserver package so that unit tests can call this method.
+   *
+   * @throws IllegalAccessException if the user is not allowed to use the session found by name
+   * @throws IllegalArgumentException if the session found by name is not active anymore
    */
-  private def getOrCreateLivySession(
+  private[thriftserver] def getOrCreateLivySession(
       sessionHandle: SessionHandle,
       sessionId: Option[Int],
+      sessionName: Option[String],
       username: String,
       createLivySession: () => InteractiveSession): InteractiveSession = {
     sessionId match {
@@ -183,7 +184,25 @@
             }
         }
       case None =>
-        createLivySession()
+        // No session id was given: try to resolve an existing session by its name instead.
+        val namedSession = sessionName.flatMap { name =>
+          server.livySessionManager.get(name).map(found => (name, found))
+        }
+        namedSession match {
+          case Some((name, session)) if !server.isAllowedToUse(username, session) =>
+            warn(s"$username has no modify permissions to InteractiveSession $name.")
+            throw new IllegalAccessException(
+              s"$username is not allowed to use InteractiveSession $name.")
+          case Some((name, session)) if session.state.isActive =>
+            info(s"Reusing Session $name for $sessionHandle.")
+            session
+          case Some((name, _)) =>
+            warn(s"InteractiveSession $name is not active anymore.")
+            throw new IllegalArgumentException(s"Session $name is not active anymore.")
+          case None =>
+            // Either no name was given or no session is registered under that name.
+            createLivySession()
+        }
     }
   }
 
@@ -248,7 +269,8 @@
         livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
           override def run(): InteractiveSession = {
             livySession =
-              getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
+              getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name,
+                username, createLivySession)
             synchronized {
               managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
                 managedLivySessionActiveUsers(livySession.id) = numUsers + 1
diff --git a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
index 11eea31..cbfc006 100644
--- a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
+++ b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
@@ -27,13 +27,13 @@
 import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
 import org.junit.Assert._
 import org.junit.Test
-import org.mockito.Mockito.mock
+import org.mockito.Mockito.{mock, when}
 
 import org.apache.livy.LivyConf
-import org.apache.livy.server.AccessManager
 import org.apache.livy.server.interactive.InteractiveSession
-import org.apache.livy.server.recovery.{SessionStore, StateStore}
-import org.apache.livy.sessions.InteractiveSessionManager
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.server.AccessManager
+import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
 import org.apache.livy.utils.Clock.sleep
 
 object ConnectionLimitType extends Enumeration {
@@ -46,7 +46,7 @@
   import ConnectionLimitType._
 
   private def createThriftSessionManager(
-      limitTypes: ConnectionLimitType*): LivyThriftSessionManager = {
+      limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = {
     val conf = new LivyConf()
     conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
     val limit = 3
@@ -62,21 +62,31 @@
   }
 
+  /**
+   * Builds the session manager / server pair from a fresh LivyConf, overriding the session
+   * creation timeout when `maxSessionWait` is defined.
+   */
   private def createThriftSessionManager(
-      maxSessionWait: Option[String]): LivyThriftSessionManager = {
+      maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = {
     val conf = new LivyConf()
     conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
     maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _))
     this.createThriftSessionManager(conf)
   }
 
-  private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = {
+  /**
+   * Creates a LivyThriftSessionManager on top of a LivyThriftServer whose collaborators are
+   * mocks. The server is returned as well so that tests can stub its mocked session manager.
+   */
+  private def createThriftSessionManager(
+      conf: LivyConf): (LivyThriftSessionManager, LivyThriftServer) = {
     val server = new LivyThriftServer(
       conf,
       mock(classOf[InteractiveSessionManager]),
       mock(classOf[SessionStore]),
       mock(classOf[AccessManager])
     )
-    new LivyThriftSessionManager(server, conf)
+    val manager = new LivyThriftSessionManager(server, conf)
+    (manager, server)
   }
 
   private def testLimit(
@@ -99,7 +101,7 @@
 
   @Test
   def testLimitConnectionsByUser(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(User)
+    val (thriftSessionMgr, _) = createThriftSessionManager(User)
     val user = "alice"
     val forwardedAddresses = new java.util.ArrayList[String]()
     thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses)
@@ -111,7 +113,7 @@
 
   @Test
   def testLimitConnectionsByIpAddress(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(IpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress)
     val ipAddress = "10.20.30.40"
     val forwardedAddresses = new java.util.ArrayList[String]()
     thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses)
@@ -123,7 +125,7 @@
 
   @Test
   def testLimitConnectionsByUserAndIpAddress(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(UserIpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress)
     val user = "alice"
     val ipAddress = "10.20.30.40"
     val userAndAddress = user + ":" + ipAddress
@@ -149,7 +151,7 @@
 
   @Test
   def testMultipleConnectionLimits(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(User, IpAddress)
+    val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress)
     val user = "alice"
     val ipAddress = "10.20.30.40"
     val forwardedAddresses = new java.util.ArrayList[String]()
@@ -166,7 +168,7 @@
 
   @Test(expected = classOf[TimeoutException])
   def testGetLivySessionWaitForTimeout(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(Some("10ms"))
+    val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms"))
     val sessionHandle = mock(classOf[SessionHandle])
     val future = Future[InteractiveSession] {
       sleep(100)
@@ -178,7 +180,7 @@
 
   @Test(expected = classOf[TimeoutException])
   def testGetLivySessionWithTimeoutException(): Unit = {
-    val thriftSessionMgr = createThriftSessionManager(None)
+    val (thriftSessionMgr, _) = createThriftSessionManager(None)
     val sessionHandle = mock(classOf[SessionHandle])
     val future = Future[InteractiveSession] {
       throw new TimeoutException("Actively throw TimeoutException in Future.")
@@ -187,4 +189,89 @@
     Await.ready(future, Duration(30, TimeUnit.SECONDS))
     thriftSessionMgr.getLivySession(sessionHandle)
   }
+
+  /** Creates a mocked InteractiveSession in the Running state that is owned by `owner`. */
+  private def mockRunningSession(owner: String): InteractiveSession = {
+    val session = mock(classOf[InteractiveSession])
+    when(session.state).thenReturn(SessionState.Running)
+    when(session.owner).thenReturn(owner)
+    session
+  }
+
+  /** Session factory for tests in which an existing session must be reused, never created. */
+  private val failOnCreate: () => InteractiveSession =
+    () => throw new AssertionError("A new Livy session must not be created.")
+
+  /** Two different session ids must resolve to their own, distinct sessions. */
+  @Test
+  def testGetOrCreateLivySessionDifferentSessions(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val session1 = mockRunningSession(sessionUser)
+    val session2 = mockRunningSession(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+    when(server.livySessionManager.get(2)).thenReturn(Some(session2))
+
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(2), None,
+      sessionUser, failOnCreate)
+
+    // Pin the exact stubbed instances instead of only checking that the results differ.
+    assertSame(session1, result1)
+    assertSame(session2, result2)
+    assertNotSame(result1, result2)
+  }
+
+  /** Repeated lookups with the same session id must return the same existing session. */
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByID(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val session1 = mockRunningSession(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+
+    assertSame(session1, result1)
+    assertSame(session1, result2)
+  }
+
+  /** Repeated lookups with the same session name must return the same existing session. */
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByName(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val sessionName = Some("sessionName")
+    val session1 = mockRunningSession(sessionUser)
+    when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1))
+
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
+      sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None, sessionName,
+      sessionUser, failOnCreate)
+
+    assertSame(session1, result1)
+    assertSame(session1, result2)
+  }
+
+  /** A session name that matches no existing session must result in a new session. */
+  @Test
+  def testGetOrCreateLivySessionUnknownNameCreatesSession(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val createdSession = mock(classOf[InteractiveSession])
+    when(server.livySessionManager.get("unknownName")).thenReturn(None)
+
+    val result = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None,
+      Some("unknownName"), "testUser", () => createdSession)
+
+    assertSame(createdSession, result)
+  }
 }