blob: 247504f5ebbb95de0630d8d88e0345cf3b198cfe [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{ConnectException, HttpURLConnection, SocketException, URL}
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeoutException
import scala.collection.mutable
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.io.Source
import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonProcessingException
import jakarta.servlet.http.HttpServletResponse
import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SparkException}
import org.apache.spark.deploy.SparkApplication
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.util.Utils
/**
* A client that submits applications to a [[RestSubmissionServer]].
*
* In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action],
* where [action] can be one of create, kill, killall, clear, status, or readyz. Each type of
* request is represented in an HTTP message sent to the following prefixes:
* (1) submit - POST /submissions/create
* (2) kill - POST /submissions/kill/[submissionId]
* (3) kill all - POST /submissions/killall
* (4) clear - POST /submissions/clear
* (5) status - GET /submissions/status/[submissionId]
* (6) readiness - GET /submissions/readyz
*
* In the case of (1), parameters are posted in the HTTP body in the form of JSON fields.
* Otherwise, the URL fully specifies the intended action of the client.
*
* Since the protocol is expected to be stable across Spark versions, existing fields cannot be
* modified or removed, though new optional fields can be added. In the rare event that forward or
* backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
*
* The client and the server must communicate using the same version of the protocol. If there
* is a mismatch, the server will respond with the highest protocol version it supports. A future
* implementation of this client can use that information to retry using the version specified
* by the server.
*/
private[spark] class RestSubmissionClient(master: String) extends Logging {
  import RestSubmissionClient._

  // A "spark://" URL is expanded by Utils.parseStandaloneMasterUrls, possibly into several
  // masters that are tried in order; any other value is kept as-is and is rejected by
  // validateMaster before a request is sent.
  private val masters: Array[String] = if (master.startsWith("spark://")) {
    Utils.parseStandaloneMasterUrls(master)
  } else {
    Array(master)
  }

  // Set of masters that lost contact with us, used to keep track of
  // whether there are masters still alive for us to communicate with
  private val lostMasters = new mutable.HashSet[String]

  /**
   * Submit an application specified by the parameters in the provided request.
   *
   * Masters are tried in order until one answers with a successful
   * [[CreateSubmissionResponse]]. On success the status of the new submission is polled
   * and reported to the user; an unsuccessful or unexpected answer moves on to the next
   * master. Error messages sent by the server are logged by [[readResponse]].
   *
   * @param request the submission request to post as JSON
   * @return the last response received, or null if no response was received at all
   * @throws SubmitRestConnectionException if every master failed to connect
   * @throws IllegalArgumentException if a master URL uses an unsupported scheme
   */
  def createSubmission(request: CreateSubmissionRequest): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request to launch an application in $master.")
    sendToMasters(initialResponse = null)(m => postJson(getSubmitUrl(m), request.toJson)) {
      case s: CreateSubmissionResponse if s.success =>
        reportSubmissionStatus(s)
        handleRestResponse(s)
        true
      case _: CreateSubmissionResponse =>
        // Not accepted by this master; try the next one.
        false
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /**
   * Request that the server kill the specified submission.
   *
   * A [[KillSubmissionResponse]] whose message `Utils.responseFromBackup` recognizes
   * (presumably an answer from a standby master) is skipped and the next master is asked.
   */
  def killSubmission(submissionId: String): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request to kill submission $submissionId in $master.")
    sendToMasters(initialResponse = null)(m => post(getKillUrl(m, submissionId))) {
      case k: KillSubmissionResponse => acceptUnlessFromBackup(k, k.message)
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /**
   * Request that the server kill all submissions.
   *
   * Responses recognized by `Utils.responseFromBackup` are skipped, as in [[killSubmission]].
   */
  def killAllSubmissions(): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request to kill all submissions in $master.")
    sendToMasters(initialResponse = null)(m => post(getKillAllUrl(m))) {
      case k: KillAllSubmissionResponse => acceptUnlessFromBackup(k, k.message)
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /**
   * Request that the server clears all submissions and applications.
   *
   * Responses recognized by `Utils.responseFromBackup` are skipped, as in [[killSubmission]].
   */
  def clear(): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request to clear $master.")
    sendToMasters(initialResponse = null)(m => post(getClearUrl(m))) {
      case k: ClearResponse => acceptUnlessFromBackup(k, k.message)
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /**
   * Check the readiness of Master.
   *
   * Unlike the other requests, the result defaults to an empty [[ErrorResponse]] instead of
   * null when no master returns any response.
   */
  def readyz(): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request to check the status of $master.")
    sendToMasters(initialResponse = new ErrorResponse)(m => get(getReadyzUrl(m))) {
      case k: ReadyzResponse => acceptUnlessFromBackup(k, k.message)
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /**
   * Request the status of a submission from the server.
   *
   * @param submissionId the submission to query
   * @param quiet when true, a successful response is not logged (used while polling)
   */
  def requestSubmissionStatus(
      submissionId: String,
      quiet: Boolean = false): SubmitRestProtocolResponse = {
    logInfo(s"Submitting a request for the status of submission $submissionId in $master.")
    sendToMasters(initialResponse = null)(m => get(getStatusUrl(m, submissionId))) {
      case s: SubmissionStatusResponse if s.success =>
        if (!quiet) {
          handleRestResponse(s)
        }
        true
      case unexpected =>
        handleUnexpectedRestResponse(unexpected)
        false
    }
  }

  /** Construct a message that captures the specified parameters for submitting an application. */
  def constructSubmitRequest(
      appResource: String,
      mainClass: String,
      appArgs: Array[String],
      sparkProperties: Map[String, String],
      environmentVariables: Map[String, String]): CreateSubmissionRequest = {
    val message = new CreateSubmissionRequest
    message.clientSparkVersion = sparkVersion
    message.appResource = appResource
    message.mainClass = mainClass
    message.appArgs = appArgs
    message.sparkProperties = sparkProperties
    message.environmentVariables = environmentVariables
    message.validate()
    message
  }

  /**
   * Send a request to each master in turn until one of them handles it.
   *
   * @param initialResponse value returned if no master produces a response
   * @param send performs the HTTP exchange with the given master
   * @param handled inspects (and logs) a response; returns true when no further master
   *                needs to be contacted
   * @return the last response received, or `initialResponse` if none was received
   * @throws SubmitRestConnectionException once every master has failed to connect
   * @throws IllegalArgumentException if a master URL uses an unsupported scheme
   */
  private def sendToMasters(initialResponse: SubmitRestProtocolResponse)(
      send: String => SubmitRestProtocolResponse)(
      handled: SubmitRestProtocolResponse => Boolean): SubmitRestProtocolResponse = {
    var response = initialResponse
    var done = false
    val remaining = masters.iterator
    while (!done && remaining.hasNext) {
      val m = remaining.next()
      validateMaster(m)
      try {
        response = send(m)
        done = handled(response)
      } catch {
        case e: SubmitRestConnectionException =>
          // Only give up once every known master has been marked as lost.
          if (handleConnectionException(m)) {
            throw new SubmitRestConnectionException("Unable to connect to server", e)
          }
      }
    }
    response
  }

  /**
   * Accept `response` (logging it) unless `Utils.responseFromBackup(message)` is true, in
   * which case it is ignored so that the next master gets asked.
   *
   * @return true if the response was accepted
   */
  private def acceptUnlessFromBackup(
      response: SubmitRestProtocolResponse,
      message: String): Boolean = {
    val fromBackup = Utils.responseFromBackup(message)
    if (!fromBackup) {
      handleRestResponse(response)
    }
    !fromBackup
  }

  /** Open a connection to `url` configured to use the given HTTP method. */
  private def openConnection(url: URL, method: String): HttpURLConnection = {
    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod(method)
    conn
  }

  /** Send a GET request to the specified URL. */
  private def get(url: URL): SubmitRestProtocolResponse = {
    logDebug(s"Sending GET request to server at $url.")
    readResponse(openConnection(url, "GET"))
  }

  /** Send a POST request to the specified URL. */
  private def post(url: URL): SubmitRestProtocolResponse = {
    logDebug(s"Sending POST request to server at $url.")
    readResponse(openConnection(url, "POST"))
  }

  /**
   * Send a POST request with the given JSON as the body to the specified URL.
   *
   * @throws SubmitRestConnectionException if a `ConnectException` occurs while the body
   *                                       is being written
   */
  private def postJson(url: URL, json: String): SubmitRestProtocolResponse = {
    logDebug(s"Sending POST request to server at $url:\n$json")
    val conn = openConnection(url, "POST")
    conn.setRequestProperty("Content-Type", "application/json")
    conn.setRequestProperty("charset", "utf-8")
    conn.setDoOutput(true)
    try {
      val out = new DataOutputStream(conn.getOutputStream)
      Utils.tryWithSafeFinally {
        out.write(json.getBytes(StandardCharsets.UTF_8))
      } {
        out.close()
      }
    } catch {
      case e: ConnectException =>
        throw new SubmitRestConnectionException("Connect Exception when connect to server", e)
    }
    readResponse(conn)
  }

  /**
   * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
   * If the response represents an error, report the embedded message to the user.
   * Waits at most 10 seconds for the server. Exposed for testing.
   *
   * @throws SubmitRestConnectionException if the server is unreachable or does not answer in time
   * @throws SubmitRestProtocolException if the response is malformed or is not a response
   * @throws SparkException for any other non-fatal failure while waiting for the response
   */
  private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
    // scalastyle:off executioncontextglobal
    import scala.concurrent.ExecutionContext.Implicits.global
    // scalastyle:on executioncontextglobal
    val responseFuture = Future {
      val responseCode = connection.getResponseCode
      if (responseCode != HttpServletResponse.SC_OK) {
        // getErrorStream() returns null when the server sent no error body at all.
        val errString = Option(connection.getErrorStream())
          .map(stream => readStream(stream)(_.getLines().mkString("\n")))
          .getOrElse("")
        // getContentType() returns null when the header is missing; treat that as non-JSON.
        val isJson = Option(connection.getContentType()).exists(_.contains("application/json"))
        if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && !isJson) {
          throw new SubmitRestProtocolException(s"Server responded with exception:\n$errString")
        }
        logError(log"Server responded with error:\n${MDC(ERROR, errString)}")
        val error = new ErrorResponse
        if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) {
          error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION
        }
        error.message = errString
        error
      } else {
        val dataStream = connection.getInputStream
        // If the server threw an exception while writing a response, it will not have a body
        if (dataStream == null) {
          throw new SubmitRestProtocolException("Server returned empty body")
        }
        val responseJson = readStream(dataStream)(_.mkString)
        logDebug(s"Response from the server:\n$responseJson")
        val response = SubmitRestProtocolMessage.fromJson(responseJson)
        response.validate()
        response match {
          // If the response is an error, log the message
          case error: ErrorResponse =>
            logError(log"Server responded with error:\n${MDC(ERROR, error.message)}")
            error
          // Otherwise, simply return the response
          case response: SubmitRestProtocolResponse => response
          case unexpected =>
            throw new SubmitRestProtocolException(
              s"Message received from server was not a response:\n${unexpected.toJson}")
        }
      }
    }

    // scalastyle:off awaitresult
    try { Await.result(responseFuture, 10.seconds) } catch {
      // scalastyle:on awaitresult
      case unreachable @ (_: FileNotFoundException | _: SocketException) =>
        throw new SubmitRestConnectionException("Unable to connect to server", unreachable)
      case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
        throw new SubmitRestProtocolException("Malformed response received from server", malformed)
      case timeout: TimeoutException =>
        throw new SubmitRestConnectionException("No response from server", timeout)
      case NonFatal(t) =>
        throw new SparkException("Exception while waiting for response", t)
    }
  }

  /**
   * Decode `stream` as UTF-8 text, hand it to `f`, and always close it afterwards.
   * JSON exchanged with the server is expected to be UTF-8 (RFC 8259).
   */
  private def readStream[T](stream: java.io.InputStream)(f: Source => T): T = {
    val source = Source.fromInputStream(stream, StandardCharsets.UTF_8.name())
    try f(source) finally source.close()
  }

  /** Return the REST URL for creating a new submission. */
  private def getSubmitUrl(master: String): URL = new URL(s"${getBaseUrl(master)}/create")

  /** Return the REST URL for killing an existing submission. */
  private def getKillUrl(master: String, submissionId: String): URL =
    new URL(s"${getBaseUrl(master)}/kill/$submissionId")

  /** Return the REST URL for killing all submissions. */
  private def getKillAllUrl(master: String): URL = new URL(s"${getBaseUrl(master)}/killall")

  /** Return the REST URL for clear all existing submissions and applications. */
  private def getClearUrl(master: String): URL = new URL(s"${getBaseUrl(master)}/clear")

  /** Return the REST URL for requesting the readyz API. */
  private def getReadyzUrl(master: String): URL = new URL(s"${getBaseUrl(master)}/readyz")

  /** Return the REST URL for requesting the status of an existing submission. */
  private def getStatusUrl(master: String, submissionId: String): URL =
    new URL(s"${getBaseUrl(master)}/status/$submissionId")

  /** Return the base URL for communicating with the server, including the protocol version. */
  private def getBaseUrl(master: String): String = {
    val hostPort = supportedMasterPrefixes
      .find(prefix => master.startsWith(prefix))
      .fold(master)(prefix => master.stripPrefix(prefix))
    s"http://${hostPort.stripSuffix("/")}/$PROTOCOL_VERSION/submissions"
  }

  /** Throw an exception if this is not standalone mode. */
  private def validateMaster(master: String): Unit = {
    val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) }
    if (!valid) {
      throw new IllegalArgumentException(
        "This REST client only supports master URLs that start with " +
          "one of the following: " + supportedMasterPrefixes.mkString(","))
    }
  }

  /** Report the status of a newly created submission. */
  private def reportSubmissionStatus(
      submitResponse: CreateSubmissionResponse): Unit = {
    if (submitResponse.success) {
      val submissionId = submitResponse.submissionId
      if (submissionId != null) {
        logInfo(s"Submission successfully created as $submissionId. Polling submission state...")
        pollSubmissionStatus(submissionId)
      } else {
        // should never happen
        logError("Application successfully submitted, but submission ID was not provided!")
      }
    } else {
      val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("")
      logError(log"Application submission failed${MDC(ERROR, failMessage)}")
    }
  }

  /**
   * Poll the status of the specified submission and log it.
   *
   * This retries up to REPORT_DRIVER_STATUS_MAX_TRIES times, waiting
   * REPORT_DRIVER_STATUS_INTERVAL milliseconds between attempts (but not after the last
   * one). It stops early if the server answers with something other than a
   * [[SubmissionStatusResponse]], leaving that case to the upstream caller.
   */
  private def pollSubmissionStatus(submissionId: String): Unit = {
    @scala.annotation.tailrec
    def poll(attempt: Int): Unit = {
      if (attempt > REPORT_DRIVER_STATUS_MAX_TRIES) {
        logError(log"Error: Master did not recognize driver ${MDC(SUBMISSION_ID, submissionId)}.")
      } else {
        requestSubmissionStatus(submissionId, quiet = true) match {
          case status: SubmissionStatusResponse if status.success =>
            logDriverStatus(submissionId, status)
          case _: SubmissionStatusResponse =>
            // Waiting after the final attempt would only delay the error report.
            if (attempt < REPORT_DRIVER_STATUS_MAX_TRIES) {
              Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL)
            }
            poll(attempt + 1)
          case _ =>
            // Unexpected type (or no response); let the upstream caller handle it.
        }
      }
    }
    poll(1)
  }

  /** Log the driver state, worker location and any message carried by a status response. */
  private def logDriverStatus(submissionId: String, status: SubmissionStatusResponse): Unit = {
    // Log driver state, if present
    Option(status.driverState) match {
      case Some(state) => logInfo(s"State of driver $submissionId is now $state.")
      case None =>
        logError(log"State of driver ${MDC(SUBMISSION_ID, submissionId)} was not found!")
    }
    // Log worker node, only when both the id and the host:port are present
    for {
      id <- Option(status.workerId)
      hp <- Option(status.workerHostPort)
    } logInfo(s"Driver is running on worker $id at $hp.")
    // Log exception stack trace, if present
    Option(status.message).foreach { e => logError(log"${MDC(ERROR, e)}") }
  }

  /** Log the response sent by the server in the REST application submission protocol. */
  private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = {
    logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}")
  }

  /** Log an appropriate error if the response sent by the server is not of the expected type. */
  private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = {
    // scalastyle:off line.size.limit
    logError(log"Error: Server responded with message of unexpected type ${MDC(CLASS_NAME, unexpected.messageType)}.")
    // scalastyle:on
  }

  /**
   * When a connection exception is caught, return true if all masters are lost.
   * Note that the heuristic used here does not take into account that masters
   * can recover during the lifetime of this client. This assumption should be
   * harmless because this client currently does not support retrying submission
   * on failure yet (SPARK-6443).
   */
  private def handleConnectionException(masterUrl: String): Boolean = {
    // HashSet.add returns true only the first time a master is recorded as lost.
    if (lostMasters.add(masterUrl)) {
      logWarning(log"Unable to connect to server ${MDC(MASTER_URL, masterUrl)}.")
    }
    lostMasters.size >= masters.length
  }
}
private[spark] object RestSubmissionClient {

  /** Master URL prefixes this client accepts (standalone masters only). */
  val supportedMasterPrefixes: Seq[String] = Seq("spark://")

  // SPARK_HOME and SPARK_CONF_DIR are filtered out because they are usually wrong
  // on the remote machine (SPARK-12345) (SPARK-25934)
  private val EXCLUDED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR")

  // Pause, in milliseconds (passed to Thread.sleep), between two driver status polls.
  private val REPORT_DRIVER_STATUS_INTERVAL = 1000

  // Number of status polls made for a new driver before giving up.
  private val REPORT_DRIVER_STATUS_MAX_TRIES = 10

  /** Protocol version segment placed in every REST URL. */
  val PROTOCOL_VERSION: String = "v1"

  /**
   * Keep only the environment variables worth forwarding to the remote driver: those named
   * `SPARK_*`, minus the ones listed in EXCLUDED_SPARK_ENV_VARS.
   */
  private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] =
    for {
      (name, value) <- env
      if name.startsWith("SPARK_") && !EXCLUDED_SPARK_ENV_VARS(name)
    } yield name -> value

  /** Whether `master` starts with one of [[supportedMasterPrefixes]]. */
  private[spark] def supportsRestClient(master: String): Boolean =
    supportedMasterPrefixes.exists(prefix => master.startsWith(prefix))
}
private[spark] class RestSubmissionClientApp extends SparkApplication {

  /**
   * Submits a request to run the application and return the response. Visible for testing.
   *
   * @param appResource the application resource to launch
   * @param mainClass the main class of the application
   * @param appArgs arguments passed to the application
   * @param conf configuration; must contain `spark.master`, and all its entries are sent as
   *             Spark properties
   * @param env environment variables to send with the request
   * @throws IllegalArgumentException if `spark.master` is not set
   */
  def run(
      appResource: String,
      mainClass: String,
      appArgs: Array[String],
      conf: SparkConf,
      env: Map[String, String] = Map()): SubmitRestProtocolResponse = {
    val master = conf.getOption("spark.master").getOrElse {
      throw new IllegalArgumentException("'spark.master' must be set.")
    }
    val sparkProperties = conf.getAll.toMap
    val client = new RestSubmissionClient(master)
    val submitRequest = client.constructSubmitRequest(
      appResource, mainClass, appArgs, sparkProperties, env)
    client.createSubmission(submitRequest)
  }

  /**
   * Entry point. `args` is `[app resource] [main class] [app args*]`. Only environment
   * variables kept by `RestSubmissionClient.filterSystemEnvironment` are forwarded.
   *
   * @throws IllegalArgumentException if fewer than two arguments are given
   */
  override def start(args: Array[String], conf: SparkConf): Unit = {
    if (args.length < 2) {
      // Throw (as `run` does for a missing master) rather than exiting the JVM, so the
      // failure propagates to whoever invoked start().
      throw new IllegalArgumentException(
        "Usage: RestSubmissionClient [app resource] [main class] [app args*]")
    }
    val appResource = args(0)
    val mainClass = args(1)
    val appArgs = args.slice(2, args.length)
    val env = RestSubmissionClient.filterSystemEnvironment(sys.env)
    run(appResource, mainClass, appArgs, conf, env)
  }
}