[LIVY-998][THRIFT] Support connecting to existing sessions using session name via Thrift Server
Co-authored-by: Asif Khatri <asif.khatri@cloudera.com>
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
index cd8d2f5..03e10c3 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
@@ -155,12 +155,13 @@
}
/**
- * If the user specified an existing sessionId to use, the corresponding session is returned,
- * otherwise a new session is created and returned.
+ * If the user specified an existing sessionId or session name to use, the corresponding session
+ * is returned, otherwise a new session is created and returned.
*/
- private def getOrCreateLivySession(
+ def getOrCreateLivySession(
sessionHandle: SessionHandle,
sessionId: Option[Int],
+ sessionName: Option[String],
username: String,
createLivySession: () => InteractiveSession): InteractiveSession = {
sessionId match {
@@ -183,7 +184,27 @@
}
}
case None =>
- createLivySession()
+ sessionName match {
+ case Some(name) =>
+ server.livySessionManager.get(name) match {
+ case None =>
+ createLivySession()
+ case Some(session) if !server.isAllowedToUse(username, session) =>
+ warn(s"$username has no modify permissions to InteractiveSession $name.")
+ throw new IllegalAccessException(
+ s"$username is not allowed to use InteractiveSession $name.")
+ case Some(session) =>
+ if (session.state.isActive) {
+ info(s"Reusing Session $name for $sessionHandle.")
+ session
+ } else {
+ warn(s"InteractiveSession $name is not active anymore.")
+ throw new IllegalArgumentException(s"Session $name is not active anymore.")
+ }
+ }
+ case None =>
+ createLivySession()
+ }
}
}
@@ -248,7 +269,8 @@
livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
override def run(): InteractiveSession = {
livySession =
- getOrCreateLivySession(sessionHandle, sessionId, username, createLivySession)
+ getOrCreateLivySession(sessionHandle, sessionId, createInteractiveRequest.name,
+ username, createLivySession)
synchronized {
managedLivySessionActiveUsers.get(livySession.id).foreach { numUsers =>
managedLivySessionActiveUsers(livySession.id) = numUsers + 1
diff --git a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
index 11eea31..cbfc006 100644
--- a/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
+++ b/thriftserver/server/src/test/scala/org/apache/livy/thriftserver/TestLivyThriftSessionManager.scala
@@ -27,13 +27,13 @@
import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
import org.junit.Assert._
import org.junit.Test
-import org.mockito.Mockito.mock
+import org.mockito.Mockito.{mock, when}
import org.apache.livy.LivyConf
-import org.apache.livy.server.AccessManager
import org.apache.livy.server.interactive.InteractiveSession
-import org.apache.livy.server.recovery.{SessionStore, StateStore}
-import org.apache.livy.sessions.InteractiveSessionManager
+import org.apache.livy.server.recovery.SessionStore
+import org.apache.livy.server.AccessManager
+import org.apache.livy.sessions.{InteractiveSessionManager, SessionState}
import org.apache.livy.utils.Clock.sleep
object ConnectionLimitType extends Enumeration {
@@ -46,7 +46,7 @@
import ConnectionLimitType._
private def createThriftSessionManager(
- limitTypes: ConnectionLimitType*): LivyThriftSessionManager = {
+ limitTypes: ConnectionLimitType*): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
val limit = 3
@@ -62,21 +62,23 @@
}
private def createThriftSessionManager(
- maxSessionWait: Option[String]): LivyThriftSessionManager = {
+ maxSessionWait: Option[String]): (LivyThriftSessionManager, LivyThriftServer) = {
val conf = new LivyConf()
conf.set(LivyConf.LIVY_SPARK_VERSION, sys.env("LIVY_SPARK_VERSION"))
maxSessionWait.foreach(conf.set(LivyConf.THRIFT_SESSION_CREATION_TIMEOUT, _))
this.createThriftSessionManager(conf)
}
- private def createThriftSessionManager(conf: LivyConf): LivyThriftSessionManager = {
+ private def createThriftSessionManager(conf: LivyConf): (LivyThriftSessionManager,
+ LivyThriftServer) = {
val server = new LivyThriftServer(
conf,
mock(classOf[InteractiveSessionManager]),
mock(classOf[SessionStore]),
mock(classOf[AccessManager])
)
- new LivyThriftSessionManager(server, conf)
+ val sessionManager = new LivyThriftSessionManager(server, conf)
+ (sessionManager, server)
}
private def testLimit(
@@ -99,7 +101,7 @@
@Test
def testLimitConnectionsByUser(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(User)
+ val (thriftSessionMgr, _) = createThriftSessionManager(User)
val user = "alice"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections(user, "10.20.30.40", forwardedAddresses)
@@ -111,7 +113,7 @@
@Test
def testLimitConnectionsByIpAddress(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(IpAddress)
+ val (thriftSessionMgr, _) = createThriftSessionManager(IpAddress)
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
thriftSessionMgr.incrementConnections("alice", ipAddress, forwardedAddresses)
@@ -123,7 +125,7 @@
@Test
def testLimitConnectionsByUserAndIpAddress(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(UserIpAddress)
+ val (thriftSessionMgr, _) = createThriftSessionManager(UserIpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val userAndAddress = user + ":" + ipAddress
@@ -149,7 +151,7 @@
@Test
def testMultipleConnectionLimits(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(User, IpAddress)
+ val (thriftSessionMgr, _) = createThriftSessionManager(User, IpAddress)
val user = "alice"
val ipAddress = "10.20.30.40"
val forwardedAddresses = new java.util.ArrayList[String]()
@@ -166,7 +168,7 @@
@Test(expected = classOf[TimeoutException])
def testGetLivySessionWaitForTimeout(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(Some("10ms"))
+ val (thriftSessionMgr, _) = createThriftSessionManager(Some("10ms"))
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
sleep(100)
@@ -178,7 +180,7 @@
@Test(expected = classOf[TimeoutException])
def testGetLivySessionWithTimeoutException(): Unit = {
- val thriftSessionMgr = createThriftSessionManager(None)
+ val (thriftSessionMgr, _) = createThriftSessionManager(None)
val sessionHandle = mock(classOf[SessionHandle])
val future = Future[InteractiveSession] {
throw new TimeoutException("Actively throw TimeoutException in Future.")
@@ -187,4 +189,72 @@
Await.ready(future, Duration(30, TimeUnit.SECONDS))
thriftSessionMgr.getLivySession(sessionHandle)
}
+
+
+  /** Two different session ids must each resolve to the session stubbed for that id. */
+  @Test
+  def testGetOrCreateLivySessionDifferentSessions(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    // Both ids are stubbed as existing sessions, so the create callback must not be invoked.
+    val failOnCreate = () => throw new AssertionError("unexpected session creation")
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+    val session2 = mock(classOf[InteractiveSession])
+    when(session2.state).thenReturn(SessionState.Running)
+    when(session2.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(2)).thenReturn(Some(session2))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(2), None,
+      sessionUser, failOnCreate)
+
+    assertSame(session1, result1)
+    assertSame(session2, result2)
+  }
+
+  /** Repeated lookups by the same id return the stubbed session and never create one. */
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByID(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val failOnCreate = () => throw new AssertionError("unexpected session creation")
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get(1)).thenReturn(Some(session1))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, Some(1), None,
+      sessionUser, failOnCreate)
+
+    assertSame(session1, result1)
+    assertSame(session1, result2)
+  }
+
+
+  /** With no id given, lookups by name return the stubbed session and never create one. */
+  @Test
+  def testGetOrCreateLivySessionExistingSessionByName(): Unit = {
+    val (thriftSessionMgr, server) = createThriftSessionManager(User, IpAddress)
+    val sessionHandle = mock(classOf[SessionHandle])
+    val sessionUser = "testUser"
+    val failOnCreate = () => throw new AssertionError("unexpected session creation")
+    val session1 = mock(classOf[InteractiveSession])
+    when(session1.state).thenReturn(SessionState.Running)
+    when(session1.owner).thenReturn(sessionUser)
+    when(server.livySessionManager.get("sessionName")).thenReturn(Some(session1))
+    val result1 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None,
+      Some("sessionName"), sessionUser, failOnCreate)
+    val result2 = thriftSessionMgr.getOrCreateLivySession(sessionHandle, None,
+      Some("sessionName"), sessionUser, failOnCreate)
+
+    assertSame(session1, result1)
+    assertSame(session1, result2)
+  }
+
}