blob: eaa023a37403e599a6d2561334c6a60e8b2bdbc5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.test.framework
import java.util.regex.Pattern
import javax.servlet.http.HttpServletResponse
import scala.annotation.tailrec
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Either, Left, Right}
import com.fasterxml.jackson.annotation.JsonIgnoreProperties
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.ning.http.client.AsyncHttpClient
import com.ning.http.client.Response
import org.apache.hadoop.yarn.api.records.ApplicationId
import org.apache.hadoop.yarn.util.ConverterUtils
import org.scalatest.concurrent.Eventually._
import org.apache.livy.server.batch.CreateBatchRequest
import org.apache.livy.server.interactive.CreateInteractiveRequest
import org.apache.livy.sessions.{Kind, SessionKindModule, SessionState}
import org.apache.livy.utils.AppInfo
/** Constants and JSON payload types shared by [[LivyRestClient]] and its session handles. */
object LivyRestClient {
  // REST collection names under the Livy endpoint.
  private val BATCH_TYPE = "batches"
  private val INTERACTIVE_TYPE = "sessions"

  // TODO Define these in production code and share them with test code.

  /** A statement as returned by the statements endpoint (id, state and raw output map). */
  @JsonIgnoreProperties(ignoreUnknown = true)
  private case class StatementResult(id: Int, state: String, output: Map[String, Any])

  /**
   * Response body of the completion endpoint.
   *
   * Unknown properties are ignored, as for the other payload types here, so that extra
   * fields added to the response do not make Jackson's default deserialization fail.
   */
  @JsonIgnoreProperties(ignoreUnknown = true)
  private case class CompletionResult(candidates: Seq[String])

  /**
   * Error details extracted from the output of a failed statement.
   *
   * NOTE(review): with `ignoreUnknown = true`, if the server names the stack-trace key
   * differently from `stackTrace`, this field silently deserializes as null — confirm
   * the key against the server's error payload.
   */
  @JsonIgnoreProperties(ignoreUnknown = true)
  case class StatementError(ename: String, evalue: String, stackTrace: Seq[String])

  /** Subset of a session's server-side state that the tests inspect. */
  @JsonIgnoreProperties(ignoreUnknown = true)
  case class SessionSnapshot(
      id: Int,
      appId: Option[String],
      state: String,
      appInfo: AppInfo,
      log: IndexedSeq[String])
}
/**
 * Small REST client the integration tests use to drive a Livy server.
 *
 * Every request is issued through `httpClient` and awaited synchronously with `.get()`.
 *
 * @param httpClient   HTTP client used for all requests
 * @param livyEndpoint base URL of the Livy server; session URLs are built beneath it
 */
class LivyRestClient(val httpClient: AsyncHttpClient, val livyEndpoint: String) {
  import LivyRestClient._

  /** JSON mapper with Scala support and `SessionKindModule` registered (used for `Kind`). */
  val mapper: ObjectMapper = new ObjectMapper()
    .registerModule(DefaultScalaModule)
    .registerModule(new SessionKindModule())

  /**
   * Handle to one Livy session, addressed by its REST URL.
   *
   * @param id          session id assigned by the server
   * @param sessionType REST collection the session lives in
   */
  class Session(val id: Int, sessionType: String) {
    val url: String = s"$livyEndpoint/$sessionType/$id"

    /**
     * Returns the application id reported in the current snapshot.
     *
     * @throws NoSuchElementException if the server has not reported an appId yet
     */
    def appId(): ApplicationId = {
      val rawAppId = snapshot().appId.getOrElse(
        throw new NoSuchElementException(s"Session $id has no appId yet"))
      ConverterUtils.toApplicationId(rawAppId)
    }

    /** Fetches the session's current state from the server, asserting HTTP 200. */
    def snapshot(): SessionSnapshot = {
      val r = httpClient.prepareGet(url).execute().get()
      assertStatusCode(r, HttpServletResponse.SC_OK)
      mapper.readValue(r.getResponseBodyAsStream, classOf[SessionSnapshot])
    }

    /** Deletes the session, then waits up to 30 seconds until its URL returns HTTP 404. */
    def stop(): Unit = {
      httpClient.prepareDelete(url).execute().get()
      eventually(timeout(30.seconds), interval(1.second)) {
        verifySessionDoesNotExist()
      }
    }

    /** Waits until the session's state equals `state`. */
    def verifySessionState(state: SessionState): Unit = {
      verifySessionState(Set(state))
    }

    /** Waits until the session's state (compared as a string) is one of `states`. */
    def verifySessionState(states: Set[SessionState]): Unit = {
      // Travis uses a very slow VM and needs a longer timeout. The original timeout is
      // kept elsewhere to avoid slowing down local development.
      val t = if (Cluster.isRunningOnTravis) 5.minutes else 2.minutes
      val strStates = states.map(_.toString)
      eventually(timeout(t), interval(1.second)) {
        val s = snapshot().state
        assert(strStates.contains(s), s"Session $id state $s doesn't equal one of $strStates")
      }
    }

    /** Asserts that the session URL currently returns HTTP 404. */
    def verifySessionDoesNotExist(): Unit = {
      val r = httpClient.prepareGet(url).execute().get()
      assertStatusCode(r, HttpServletResponse.SC_NOT_FOUND)
    }
  }

  /** Handle to a batch session. */
  class BatchSession(id: Int) extends Session(id, BATCH_TYPE) {
    def verifySessionDead(): Unit = verifySessionState(SessionState.Dead())
    def verifySessionRunning(): Unit = verifySessionState(SessionState.Running())
    def verifySessionSuccess(): Unit = verifySessionState(SessionState.Success())
  }

  /** Handle to an interactive session that can run statements and request completions. */
  class InteractiveSession(id: Int) extends Session(id, INTERACTIVE_TYPE) {

    /** A statement submitted (on construction) to this session; asserts HTTP 201. */
    class Statement(code: String) {
      val stmtId: Int = {
        val requestBody = Map("code" -> code)
        val r = httpClient.preparePost(s"$url/statements")
          .setBody(mapper.writeValueAsString(requestBody))
          .execute()
          .get()
        assertStatusCode(r, HttpServletResponse.SC_CREATED)
        val newStmt = mapper.readValue(r.getResponseBodyAsStream, classOf[StatementResult])
        newStmt.id
      }

      /**
       * Waits up to one minute for the statement to become available, then returns its
       * text result (`Left`) or its error details (`Right`).
       *
       * @throws IllegalStateException if the output has a missing or unknown status
       */
      final def result(): Either[String, StatementError] = {
        // Retry only the wait for availability. Once available, the output is final, so
        // a malformed output is reported at once instead of being retried until timeout.
        val available = eventually(timeout(1.minute), interval(1.second)) {
          val r = httpClient.prepareGet(s"$url/statements/$stmtId")
            .execute()
            .get()
          assertStatusCode(r, HttpServletResponse.SC_OK)
          val newStmt = mapper.readValue(r.getResponseBodyAsStream, classOf[StatementResult])
          assert(newStmt.state == "available", s"Statement isn't available: ${newStmt.state}")
          assert(newStmt.output != null, s"Statement $stmtId has no output yet")
          newStmt
        }
        parseOutput(available)
      }

      /** Converts the output map of an available statement into a result or an error. */
      private def parseOutput(newStmt: StatementResult): Either[String, StatementError] = {
        val output = newStmt.output
        output.get("status") match {
          case Some("ok") =>
            val data = output("data").asInstanceOf[Map[String, Any]]
            // A Livy table result, when present, wins over plain text and is re-serialized.
            val text: Any = data.get("application/vnd.livy.table.v1+json") match {
              case Some(table) if table != null => mapper.writeValueAsString(table)
              case _ => data.getOrElse("text/plain", "")
            }
            Left(text.asInstanceOf[String])
          case Some("error") => Right(mapper.convertValue(output, classOf[StatementError]))
          case Some(status) =>
            throw new IllegalStateException(s"Unknown statement $stmtId status: $status")
          case None =>
            throw new IllegalStateException(s"Unknown statement $stmtId output: $newStmt")
        }
      }

      /** Asserts success; if `expectedRegex` is non-null, the whole result must match it. */
      def verifyResult(expectedRegex: String): Unit = {
        result() match {
          case Left(result) =>
            Option(expectedRegex).foreach(matchStrings(result, _))
          case Right(error) =>
            assert(false, s"Got error from statement $stmtId $code: ${error.evalue}")
        }
      }

      /** Asserts failure; each non-null argument is a DOTALL regex its field must match. */
      def verifyError(
          ename: String = null, evalue: String = null, stackTrace: String = null): Unit = {
        result() match {
          case Left(_) =>
            assert(false, s"Statement $stmtId `$code` expected to fail, but succeeded.")
          case Right(error) =>
            // A missing stack trace is compared as the empty string.
            val remoteStack = Option(error.stackTrace).getOrElse(Nil).mkString("\n")
            Seq(error.ename -> ename, error.evalue -> evalue, remoteStack -> stackTrace).foreach {
              case (actual, expected) if expected != null => matchStrings(actual, expected)
              case _ =>
            }
        }
      }

      private def matchStrings(actual: String, expected: String): Unit = {
        val regex = Pattern.compile(expected, Pattern.DOTALL)
        assert(regex.matcher(actual).matches(), s"$actual did not match regex $expected")
      }
    }

    /** Completion candidates for `code` at `cursor`, requested on construction (HTTP 200). */
    class Completion(code: String, kind: String, cursor: Int) {
      val completions: Seq[String] = {
        val requestBody = Map("code" -> code, "cursor" -> cursor, "kind" -> kind)
        val r = httpClient.preparePost(s"$url/completion")
          .setBody(mapper.writeValueAsString(requestBody))
          .execute()
          .get()
        assertStatusCode(r, HttpServletResponse.SC_OK)
        val res = mapper.readValue(r.getResponseBodyAsStream, classOf[CompletionResult])
        res.candidates
      }

      final def result(): Seq[String] = completions

      /**
       * Asserts that every returned candidate appears in `expected`.
       *
       * NOTE(review): this is a subset check on the result, so an empty result passes
       * vacuously; the name may suggest the opposite direction — confirm intent.
       */
      def verifyContaining(expected: List[String]): Unit = {
        val unexpected = result().filterNot(expected.contains)
        assert(unexpected.isEmpty,
          s"Unexpected completion proposals $unexpected; expected only $expected")
      }

      def verifyNone(): Unit = {
        assert(result().isEmpty, s"Expected no completion proposals but found $completions")
      }
    }

    def run(code: String): Statement = new Statement(code)

    def complete(code: String, kind: String, cursor: Int): Completion = {
      new Completion(code, kind, cursor)
    }

    /** Submits `code` that is expected to kill the session, then waits for it to be dead. */
    def runFatalStatement(code: String): Unit = {
      val requestBody = Map("code" -> code)
      // Await the submission so that a transport failure surfaces here, not as a
      // misleading session-state timeout below.
      httpClient.preparePost(s"$url/statements")
        .setBody(mapper.writeValueAsString(requestBody))
        .execute()
        .get()
      verifySessionState(SessionState.Dead())
    }

    def verifySessionIdle(): Unit = {
      verifySessionState(SessionState.Idle())
    }
  }

  /**
   * Starts a batch session. `spark.yarn.maxAppAttempts` defaults to "1" unless
   * `sparkConf` overrides it.
   */
  def startBatch(
      file: String,
      className: Option[String],
      args: List[String],
      sparkConf: Map[String, String]): BatchSession = {
    val r = new CreateBatchRequest()
    r.file = file
    r.className = className
    r.args = args
    r.conf = Map("spark.yarn.maxAppAttempts" -> "1") ++ sparkConf
    val id = start(BATCH_TYPE, mapper.writeValueAsString(r))
    new BatchSession(id)
  }

  /** Starts an interactive session of the given kind. */
  def startSession(
      kind: Kind,
      sparkConf: Map[String, String],
      heartbeatTimeoutInSecond: Int): InteractiveSession = {
    val r = new CreateInteractiveRequest()
    r.kind = kind
    r.conf = sparkConf
    r.heartbeatTimeoutInSecond = heartbeatTimeoutInSecond
    val id = start(INTERACTIVE_TYPE, mapper.writeValueAsString(r))
    new InteractiveSession(id)
  }

  /** Wraps an existing interactive session id without contacting the server. */
  def connectSession(id: Int): InteractiveSession = new InteractiveSession(id)

  /** POSTs `body` to the session collection, asserts HTTP 201 and returns the new id. */
  private def start(sessionType: String, body: String): Int = {
    val r = httpClient.preparePost(s"$livyEndpoint/$sessionType")
      .setBody(body)
      .execute()
      .get()
    assertStatusCode(r, HttpServletResponse.SC_CREATED)
    val newSession = mapper.readValue(r.getResponseBodyAsStream, classOf[SessionSnapshot])
    newSession.id
  }

  /** Asserts the response status, including the status and body in the failure message. */
  private def assertStatusCode(r: Response, expected: Int): Unit = {
    def pretty(r: Response): String = {
      s"${r.getStatusCode} ${r.getResponseBody}"
    }
    assert(r.getStatusCode() == expected, s"HTTP status code != $expected: ${pretty(r)}")
  }
}