blob: f7547c577af7cb1cb693d646aa0e79af114dfe42 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.server
import java.security.AccessControlException
import javax.servlet.http.HttpServletRequest
import org.scalatra._
import scala.concurrent._
import scala.concurrent.duration._
import org.apache.livy.{LivyConf, Logging}
import org.apache.livy.rsc.RSCClientFactory
import org.apache.livy.server.batch.BatchSession
import org.apache.livy.sessions.{Session, SessionManager}
import org.apache.livy.sessions.Session.RecoveryMetadata
/** Companion object of the `SessionServlet` class; it mixes in `Logging` and declares no members. */
object SessionServlet extends Logging
/**
 * Base servlet for session management. All helper methods in this class assume that the session
 * id parameter in the handler's URI is "id".
 *
 * @tparam S the session type
 * @tparam R the recovery metadata type used by the session manager
 * @param sessionManager manager used to look up, register and delete sessions
 * @param livyConf configuration; read here for the session creation limit
 * @param accessManager provides the view / modify access checks applied to protected routes
 */
abstract class SessionServlet[S <: Session, R <: RecoveryMetadata](
    private[livy] val sessionManager: SessionManager[S, R],
    val livyConf: LivyConf,
    accessManager: AccessManager)
  extends JsonServlet
  with ApiVersioningSupport
  with MethodOverride
  with UrlGeneratorSupport
  with GZipSupport
{
  /**
   * Creates a new session based on the current request. The implementation is responsible for
   * parsing the body of the request.
   */
  protected def createSession(req: HttpServletRequest): S

  /**
   * Returns an object representing the session data to be sent back to the client. The default
   * implementation returns the session itself.
   */
  protected def clientSessionView(session: S, req: HttpServletRequest): Any = session

  /** Shuts down the underlying session manager. */
  override def shutdown(): Unit = {
    sessionManager.shutdown()
  }

  before() {
    contentType = "application/json"
  }

  // Lists sessions, paginated by the optional "from" (default 0) and "size" (default 100)
  // query parameters. A non-numeric value makes toInt throw NumberFormatException, which is an
  // IllegalArgumentException and is therefore mapped to BadRequest by the error handler below.
  get("/") {
    val from = params.get("from").map(_.toInt).getOrElse(0)
    val size = params.get("size").map(_.toInt).getOrElse(100)
    val sessions = sessionManager.all()
    // Saturating addition: a huge "size" must not wrap around to a negative end index.
    val until = sliceEnd(from, size)
    Map(
      "from" -> from,
      "total" -> sessionManager.size(),
      "sessions" -> sessions.view(from, until).map(clientSessionView(_, request))
    )
  }

  // Kept as a val so that the POST handler can build the "Location" URL from this route.
  val getSession = get("/:id") {
    withUnprotectedSession { session =>
      clientSessionView(session, request)
    }
  }

  get("/:id/state") {
    withUnprotectedSession { session =>
      Map("id" -> session.id, "state" -> session.state.toString)
    }
  }

  // Returns a window of the session's log lines; requires view access.
  get("/:id/log") {
    withViewAccessSession { session =>
      val from = params.get("from").map(_.toInt)
      val size = params.get("size").map(_.toInt)
      val (from_, total, logLines) = serializeLogs(session, from, size)
      Map(
        "id" -> session.id,
        "from" -> from_,
        "total" -> total,
        "log" -> logLines)
    }
  }

  // Deletes a session; requires modify access.
  delete("/:id") {
    withModifyAccessSession { session =>
      sessionManager.delete(session.id) match {
        case Some(future) =>
          // Blocks the request thread, without a timeout, until the deletion future completes.
          // NOTE(review): Await.ready does not rethrow a failure of the future, so "deleted" is
          // reported even if the deletion failed - confirm this is intended.
          Await.ready(future, Duration.Inf)
          Ok(ResponseMessage("deleted"))
        case None =>
          NotFound(ResponseMessage(s"Session ${session.id} already stopped."))
      }
    }
  }

  /**
   * Returns true when the combined number of child processes reported by `RSCClientFactory` and
   * `BatchSession` has reached the limit configured by `LivyConf.SESSION_MAX_CREATION`.
   */
  def tooManySessions(): Boolean = {
    val totalChildProcesses = RSCClientFactory.childProcesses().get() +
      BatchSession.childProcesses.get()
    totalChildProcesses >= livyConf.getInt(LivyConf.SESSION_MAX_CREATION)
  }

  // Creates and registers a new session, unless too many sessions are currently being created.
  post("/") {
    if (tooManySessions()) {
      BadRequest(ResponseMessage("Rejected, too many sessions are being created!"))
    } else {
      val session = sessionManager.register(createSession(request))
      // Because it may take some time to establish the session, update the last activity
      // time before returning the session info to the client.
      session.recordActivity()
      Created(clientSessionView(session, request),
        headers = Map("Location" ->
          (getRequestPathInfo(request) + url(getSession, "id" -> session.id.toString))))
    }
  }

  /**
   * Returns the request's path info, or an empty string when it is absent (null) or just "/".
   */
  private def getRequestPathInfo(request: HttpServletRequest): String =
    Option(request.getPathInfo).filter(_ != "/").getOrElse("")

  error {
    case e: IllegalArgumentException => BadRequest(ResponseMessage(e.getMessage))
    case e: AccessControlException => Forbidden(ResponseMessage(e.getMessage))
  }

  /**
   * Returns the remote user for the given request. Separate method so that tests can override it.
   */
  protected def remoteUser(req: HttpServletRequest): String = req.getRemoteUser()

  /**
   * Performs an operation on the session, without checking for ownership. Operations executed
   * via this method must not modify the session in any way, or return potentially sensitive
   * information.
   */
  protected def withUnprotectedSession(fn: (S => Any)): Any =
    doWithSession(fn, allowAll = true, checkFn = None)

  /**
   * Performs an operation on the session, verifying whether the caller has view access of the
   * session.
   */
  protected def withViewAccessSession(fn: (S => Any)): Any =
    doWithSession(fn, allowAll = false, checkFn = Some(accessManager.hasViewAccess))

  /**
   * Performs an operation on the session, verifying whether the caller has modify access of the
   * session.
   */
  protected def withModifyAccessSession(fn: (S => Any)): Any =
    doWithSession(fn, allowAll = false, checkFn = Some(accessManager.hasModifyAccess))

  /**
   * Looks up the session named by the "id" route parameter and applies `fn` to it.
   *
   * @param fn operation to run on the session
   * @param allowAll when true, no access check is performed
   * @param checkFn access check invoked as (session owner, remote user); when `allowAll` is
   *                false and no check is supplied, access is denied
   * @return the result of `fn`, a Forbidden response when access is denied, or a NotFound
   *         response when no session has the requested id
   */
  private def doWithSession(fn: (S => Any),
      allowAll: Boolean,
      checkFn: Option[(String, String) => Boolean]): Any = {
    val sessionId = params("id").toInt
    sessionManager.get(sessionId) match {
      case Some(session) =>
        // The remote user is only resolved when an access check actually has to run.
        val permitted = allowAll || checkFn.exists(_(session.owner, remoteUser(request)))
        if (permitted) {
          fn(session)
        } else {
          Forbidden()
        }
      case None =>
        NotFound(ResponseMessage(s"Session '$sessionId' not found."))
    }
  }

  /**
   * Selects the window of log lines to return for a session.
   *
   * @param session the session whose log lines are read
   * @param fromOpt index of the first line; when missing or negative, the window is taken
   *                from the end of the log
   * @param sizeOpt maximum number of lines; defaults to 100, a negative value means all lines
   * @return a tuple of (effective start index, total number of log lines, view of the window)
   */
  private def serializeLogs(session: S, fromOpt: Option[Int], sizeOpt: Option[Int]) = {
    val lines = session.logLines()
    val requestedSize = sizeOpt.getOrElse(100)
    val size = if (requestedSize < 0) lines.length else requestedSize
    val requestedFrom = fromOpt.getOrElse(-1)
    val from = if (requestedFrom < 0) math.max(0, lines.length - size) else requestedFrom
    // Saturating addition: a huge "size" must not wrap around to a negative end index.
    val until = sliceEnd(from, size)
    (from, lines.length, lines.view(from, until))
  }

  /**
   * Computes `from + size` as a slice end index, clamping to the `Int` range instead of
   * overflowing, so that extreme client-supplied values cannot wrap around.
   */
  private def sliceEnd(from: Int, size: Int): Int = {
    val end = from.toLong + size.toLong
    math.max(Int.MinValue.toLong, math.min(Int.MaxValue.toLong, end)).toInt
  }
}