blob: b6adc235724a397b8e8df3ab3ca8564adf044e3a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.client.http
import java.io.{File, InputStream}
import java.net.{InetAddress, URI}
import java.nio.file.{Files, Paths}
import java.util.concurrent.{Future => JFuture, _}
import java.util.concurrent.atomic.AtomicLong
import javax.servlet.ServletContext
import javax.servlet.http.HttpServletRequest
import scala.concurrent.{ExecutionContext, Future}
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfterAll, FunSpecLike}
import org.scalatra.LifeCycle
import org.scalatra.servlet.ScalatraListener
import org.apache.livy._
import org.apache.livy.client.common.{BufferUtils, Serializer}
import org.apache.livy.client.common.HttpMessages._
import org.apache.livy.server.{AccessManager, WebServer}
import org.apache.livy.server.interactive.{InteractiveSession, InteractiveSessionServlet}
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.sessions.{InteractiveSessionManager, SessionState, Spark}
import org.apache.livy.test.jobs.Echo
import org.apache.livy.utils.AppInfo
/**
* The test for the HTTP client is written in Scala so we can reuse the code in the livy-server
* module, which implements the client session backend. The client servlet has some functionality
* overridden to avoid creating sub-processes for each session.
*/
class HttpClientSpec extends FunSpecLike with BeforeAndAfterAll with LivyBaseUnitTestSuite {

  import HttpClientSpec._

  /** Timeout, in seconds, used when blocking on job handles and add/upload futures. */
  private val TIMEOUT_S = 10
  /** Source of unique job ids handed out by the mocked session's job stubs. */
  private val ID_GENERATOR = new AtomicLong()
  /** Serializer used to encode the job results reported by the mocked session. */
  private val serializer = new Serializer()

  private var server: WebServer = _
  private var client: LivyClient = _

  /**
   * Starts an embedded Livy web server bound to "0.0.0.0" on port 0 (the actual port is read
   * back from `server.port` by the tests). Its servlets are installed by
   * [[HttpClientTestBootstrap]] through the Scalatra listener.
   */
  override def beforeAll(): Unit = {
    super.beforeAll()
    server = new WebServer(new LivyConf(), "0.0.0.0", 0)
    server.context.setResourceBase("src/main/org/apache/livy/server")
    server.context.setInitParameter(ScalatraListener.LifeCycleKey,
      classOf[HttpClientTestBootstrap].getCanonicalName)
    server.context.addEventListener(new ScalatraListener)
    server.start()
  }

  /**
   * Releases the client, the server and the shared mocked session.
   *
   * The client is stopped before the server so that its shutdown request can still reach the
   * server. Each step runs in a `finally` so a failure in one does not skip the rest, and the
   * parent's cleanup always runs last.
   */
  override def afterAll(): Unit = {
    try {
      if (client != null) {
        try client.stop(true) finally client = null
      }
    } finally {
      try {
        if (server != null) {
          try server.stop() finally server = null
        }
      } finally {
        session = null
        super.afterAll()
      }
    }
  }

  describe("HTTP client library") {
    it("should create clients") {
      // WebServer does this internally instead of respecting "0.0.0.0", so try to use the same
      // address.
      val uri = s"http://${InetAddress.getLocalHost.getHostAddress}:${server.port}/"
      client = new LivyClientBuilder(false).setURI(new URI(uri)).build()
    }

    withClient("should run and monitor asynchronous jobs") {
      testJob(false)
    }

    withClient("should propagate errors from jobs") {
      val errorMessage = "This job throws an error."
      val (jobId, handle) = runJob(false, { id => Seq(
        new JobStatus(id, JobHandle.State.FAILED, null, errorMessage))
      })

      val error = intercept[ExecutionException] {
        handle.get(TIMEOUT_S, TimeUnit.SECONDS)
      }
      assert(error.getCause() != null)
      assert(error.getCause().getMessage().contains(errorMessage))
      verify(session, times(1)).jobStatus(meq(jobId))
    }

    withClient("should run and monitor synchronous jobs") {
      // Exercise the synchronous client.run() path rather than client.submit().
      testJob(true)
    }

    withClient("should add files and jars") {
      val furi = new URI("hdfs:file")
      val juri = new URI("hdfs:jar")

      client.addFile(furi).get(TIMEOUT_S, TimeUnit.SECONDS)
      client.addJar(juri).get(TIMEOUT_S, TimeUnit.SECONDS)

      verify(session, times(1)).addFile(meq(furi))
      verify(session, times(1)).addJar(meq(juri))
    }

    withClient("should upload files and jars") {
      uploadAndVerify("file")
      uploadAndVerify("jar")
    }

    withClient("should cancel jobs") {
      val (jobId, handle) = runJob(false, { id => Seq(
        new JobStatus(id, JobHandle.State.STARTED, null, null),
        new JobStatus(id, JobHandle.State.CANCELLED, null, null))
      })
      handle.cancel(true)

      intercept[CancellationException] {
        handle.get(TIMEOUT_S, TimeUnit.SECONDS)
      }

      verify(session, times(1)).cancelJob(meq(jobId))
    }

    withClient("should notify listeners of job completion") {
      val (jobId, handle) = runJob(false, { id => Seq(
        new JobStatus(id, JobHandle.State.STARTED, null, null),
        new JobStatus(id, JobHandle.State.SUCCEEDED, serialize(id), null))
      })

      val listener = mock(classOf[JobHandle.Listener[Long]])
      handle.asInstanceOf[JobHandle[Long]].addListener(listener)

      assert(handle.get(TIMEOUT_S, TimeUnit.SECONDS) === jobId)
      verify(listener, times(1)).onJobSucceeded(any(), any())
    }

    withClient("should time out handle get() call") {
      // JobHandleImpl polls the job status with exponential backoff. Given an initial wait of
      // 100ms, the fourth poll only happens after a wait of 800ms, so a get() that waits just
      // 100ms should time out before the job is reported as succeeded.
      val (jobId, handle) = runJob(false, { id => Seq(
        new JobStatus(id, JobHandle.State.STARTED, null, null),
        new JobStatus(id, JobHandle.State.STARTED, null, null),
        new JobStatus(id, JobHandle.State.STARTED, null, null),
        new JobStatus(id, JobHandle.State.SUCCEEDED, serialize(id), null))
      })

      intercept[TimeoutException] {
        handle.get(100, TimeUnit.MILLISECONDS)
      }

      assert(handle.get(TIMEOUT_S, TimeUnit.SECONDS) === jobId)
    }

    withClient("should handle null responses") {
      testJob(false, response = Some(null))
    }

    withClient("should connect to existing sessions") {
      val sid = client.asInstanceOf[HttpClient].getSessionId()
      val uri = s"http://${InetAddress.getLocalHost.getHostAddress}:${server.port}" +
        s"${LivyConnection.SESSIONS_URI}/$sid"
      val newClient = new LivyClientBuilder(false).setURI(new URI(uri)).build()
      newClient.stop(false)
      verify(session, never()).stop()
    }

    withClient("should tear down clients") {
      client.stop(true)
      verify(session, times(1)).stop()
    }
  }

  /**
   * Uploads a temporary file through the client and checks that the mocked session received a
   * stream with the file's contents.
   *
   * @param cmd "file" to use `uploadFile`/`addFile`; any other value uses `uploadJar`/`addJar`.
   */
  private def uploadAndVerify(cmd: String): Unit = {
    import java.io.DataInputStream
    import java.nio.charset.StandardCharsets.UTF_8

    val f = File.createTempFile("uploadTestFile", cmd)
    try {
      val expectedStr = "Test data"
      val expectedData = expectedStr.getBytes(UTF_8)
      Files.write(Paths.get(f.getAbsolutePath), expectedData)
      val b = new Array[Byte](expectedData.length)
      val captor = ArgumentCaptor.forClass(classOf[InputStream])
      if (cmd == "file") {
        client.uploadFile(f).get(TIMEOUT_S, TimeUnit.SECONDS)
        verify(session, times(1)).addFile(captor.capture(), meq(f.getName))
      } else {
        client.uploadJar(f).get(TIMEOUT_S, TimeUnit.SECONDS)
        verify(session, times(1)).addJar(captor.capture(), meq(f.getName))
      }
      // A single read() may return fewer bytes than requested; readFully keeps reading.
      new DataInputStream(captor.getValue).readFully(b)
      assert(expectedStr === new String(b, UTF_8))
    } finally {
      // Don't leave test artifacts behind in the temp directory.
      f.delete()
    }
  }

  /**
   * Stubs the mocked session to accept one [[Echo]] job and report the given statuses, then
   * sends that job through the client.
   *
   * @param sync whether to use the synchronous `client.run()` (stubbing the session's `runJob`)
   *             or the asynchronous `client.submit()` (stubbing `submitJob`).
   * @param genStatusFn builds, from the generated job id, the statuses returned by successive
   *                    `jobStatus` calls. It must return at least one status; once exhausted,
   *                    the last status keeps being returned.
   * @return the generated job id and the future/handle returned by the client.
   */
  private def runJob(sync: Boolean, genStatusFn: Long => Seq[JobStatus]): (Long, JFuture[Int]) = {
    val jobId = java.lang.Long.valueOf(ID_GENERATOR.incrementAndGet())
    if (sync) {
      when(session.runJob(any(classOf[Array[Byte]]), anyString())).thenReturn(jobId)
    } else {
      when(session.submitJob(any(classOf[Array[Byte]]), anyString())).thenReturn(jobId)
    }

    val statuses = genStatusFn(jobId)
    require(statuses.nonEmpty, "At least one job status is required.")
    when(session.jobStatus(meq(jobId))).thenReturn(statuses.head, statuses.tail: _*)

    val job = new Echo(42)
    val handle = if (sync) client.run(job) else client.submit(job)
    (jobId, handle)
  }

  /**
   * Runs a job that reports STARTED then SUCCEEDED, and checks both the result and that the
   * status was polled exactly twice.
   *
   * @param sync see [[runJob]].
   * @param response the value the job should report; when `None`, the job id itself is used.
   */
  private def testJob(sync: Boolean, response: Option[Any] = None): Unit = {
    val (jobId, handle) = runJob(sync, { id => Seq(
      new JobStatus(id, JobHandle.State.STARTED, null, null),
      new JobStatus(id, JobHandle.State.SUCCEEDED, serialize(response.getOrElse(id)), null))
    })
    assert(handle.get(TIMEOUT_S, TimeUnit.SECONDS) === response.getOrElse(jobId))
    verify(session, times(2)).jobStatus(meq(jobId))
  }

  /** Registers a test that is skipped (via `assume`) when no client has been created. */
  private def withClient(desc: String)(fn: => Unit): Unit = {
    it(desc) {
      assume(client != null, "No active client.")
      fn
    }
  }

  /** Serializes `value` into the byte form used for job results in [[JobStatus]]. */
  def serialize(value: Any): Array[Byte] = {
    BufferUtils.toByteArray(serializer.serialize(value))
  }

}
private object HttpClientSpec {

  /**
   * The mocked session created by `HttpClientTestBootstrap`, exposed to the test suite.
   *
   * This is a deliberate hack: tests stub whatever behavior they need on this mock before
   * sending requests to the server. It holds null until the bootstrap creates the session, and
   * the suite's afterAll() resets it to null.
   */
  var session: InteractiveSession = null

}
/**
 * Scalatra bootstrap used by `HttpClientSpec`. It mounts an `InteractiveSessionServlet` whose
 * `createSession` returns a Mockito mock instead of a real session. That mock is published
 * through `HttpClientSpec.session` so tests can stub and verify it.
 */
private class HttpClientTestBootstrap extends LifeCycle {

  // NOTE(review): presumably needed implicitly by one of the Livy constructors below — confirm.
  private implicit def executor: ExecutionContext = ExecutionContext.global

  override def init(context: ServletContext): Unit = {
    val conf = new LivyConf()
    val stateStore = mock(classOf[SessionStore])
    val sessionManager = new InteractiveSessionManager(conf, stateStore, Some(Seq.empty))
    val accessManager = new AccessManager(conf)
    val servlet = new InteractiveSessionServlet(sessionManager, stateStore, conf, accessManager) {
      override protected def createSession(req: HttpServletRequest): InteractiveSession = {
        // The suite shares a single mocked session; fail before allocating an id for a second.
        require(HttpClientSpec.session == null, "Session already created?")

        val session = mock(classOf[InteractiveSession])
        val id = sessionManager.nextId()
        // Stub every accessor the servlet may read when rendering the session.
        when(session.id).thenReturn(id)
        when(session.name).thenReturn(None)
        when(session.appId).thenReturn(None)
        when(session.appInfo).thenReturn(AppInfo())
        when(session.state).thenReturn(SessionState.Idle)
        when(session.proxyUser).thenReturn(None)
        when(session.kind).thenReturn(Spark)
        when(session.driverMemory).thenReturn(None)
        when(session.driverCores).thenReturn(None)
        when(session.executorMemory).thenReturn(None)
        when(session.executorCores).thenReturn(None)
        when(session.numExecutors).thenReturn(None)
        when(session.queue).thenReturn(None)
        when(session.conf).thenReturn(Map("" -> ""))
        when(session.archives).thenReturn(List())
        when(session.files).thenReturn(List())
        when(session.jars).thenReturn(List())
        when(session.pyFiles).thenReturn(List())
        when(session.stop()).thenReturn(Future.successful(()))
        when(session.ttl).thenReturn(None)
        when(session.idleTimeout).thenReturn(None)

        HttpClientSpec.session = session
        session
      }
    }

    context.mount(servlet, s"${LivyConnection.SESSIONS_URI}/*")
  }

}