/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.livy.thriftserver.rpc

import java.lang.reflect.InvocationTargetException

import scala.collection.immutable.HashMap
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import org.apache.hive.service.cli.SessionHandle

import org.apache.livy._
import org.apache.livy.server.interactive.InteractiveSession
import org.apache.livy.thriftserver.serde.ColumnOrientedResultSet
import org.apache.livy.thriftserver.types.DataType
import org.apache.livy.utils.LivySparkUtils

class RpcClient(livySession: InteractiveSession) extends Logging {
  import RpcClient._

  private val isSpark1 = {
    val (sparkMajorVersion, _) =
      LivySparkUtils.formatSparkVersion(livySession.livyConf.get(LivyConf.LIVY_SPARK_VERSION))
    sparkMajorVersion == 1
  }

  private val defaultIncrementalCollect =
    livySession.livyConf.getBoolean(LivyConf.THRIFT_INCR_COLLECT_ENABLED).toString

  private val rscClient = livySession.client.get

  def isValid: Boolean = rscClient.isAlive

  private def sessionId(sessionHandle: SessionHandle): String = {
    sessionHandle.getSessionId.toString
  }
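
  /**
   * Submits the given SQL statement for asynchronous execution on the remote Spark context.
   * The result iterator and the result schema are stored remotely under `statementId`, so that
   * they can later be retrieved through `fetchResult` and `fetchResultSchema`.
   */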
  @throws[Exception]
  def executeSql(
      sessionHandle: SessionHandle,
      statementId: String,
      statement: String): JobHandle[_] = {
    info(s"RSC client is executing SQL query: $statement, statementId = $statementId, session = " +
      sessionHandle)
    require(null != statementId, s"Invalid statementId specified. StatementId = $statementId")
    require(null != statement, s"Invalid statement specified. Statement = $statement")
    livySession.recordActivity()
    rscClient.submit(executeSqlJob(sessionId(sessionHandle),
      statementId,
      statement,
      isSpark1,
      defaultIncrementalCollect,
      s"spark.${LivyConf.THRIFT_INCR_COLLECT_ENABLED}"))
  }
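
  /**
   * Fetches up to `maxRows` rows from the result iterator stored remotely for `statementId` and
   * packs them into a `ColumnOrientedResultSet` with the given column types.
   */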
  @throws[Exception]
  def fetchResult(statementId: String,
      types: Array[DataType],
      maxRows: Int): JobHandle[ColumnOrientedResultSet] = {
    info(s"RSC client is fetching result for statementId $statementId with $maxRows maxRows.")
    require(null != statementId, s"Invalid statementId specified. StatementId = $statementId")
    livySession.recordActivity()
    rscClient.submit(fetchResultJob(statementId, types, maxRows))
  }
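
  /**
   * Retrieves the JSON representation of the result schema stored remotely for `statementId`.
   */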
  @throws[Exception]
  def fetchResultSchema(statementId: String): JobHandle[String] = {
    info(s"RSC client is fetching result schema for statementId = $statementId")
    require(null != statementId, s"Invalid statementId specified. statementId = $statementId")
    livySession.recordActivity()
    rscClient.submit(fetchResultSchemaJob(statementId))
  }
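
  /**
   * Cancels the Spark job group associated with `statementId` and removes the statement's result
   * iterator and schema from the remote shared maps.
   */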
  @throws[Exception]
  def cleanupStatement(statementId: String, cancelJob: Boolean = false): JobHandle[_] = {
    info(s"Cleaning up remote session for statementId = $statementId")
    require(null != statementId, s"Invalid statementId specified. statementId = $statementId")
    livySession.recordActivity()
    rscClient.submit(cleanupStatementJob(statementId))
  }

  /**
   * Creates a new Spark session (a new SQLContext/HiveContext on Spark 1) for the specified
   * SessionHandle and stores it in a shared map, so that each incoming session uses its own:
   * this is needed to avoid interference between different users working on the same remote
   * Livy session (e.g. setting a property, changing the current database, etc.).
   */
  @throws[Exception]
  def executeRegisterSession(sessionHandle: SessionHandle): JobHandle[_] = {
    info(s"RSC client is executing register session $sessionHandle")
    livySession.recordActivity()
    rscClient.submit(registerSessionJob(sessionId(sessionHandle), isSpark1))
  }

  /**
   * Removes the Spark session created for the specified session from the shared map.
   */
  @throws[Exception]
  def executeUnregisterSession(sessionHandle: SessionHandle): JobHandle[_] = {
    info(s"RSC client is executing unregister session $sessionHandle")
    livySession.recordActivity()
    rscClient.submit(unregisterSessionJob(sessionId(sessionHandle)))
  }
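
  // Illustrative call order, as driven by the Livy Thrift server (a sketch, not executed here):
  //   executeRegisterSession(sessionHandle)              // create an isolated Spark session
  //   executeSql(sessionHandle, statementId, statement)  // run the query asynchronously
  //   fetchResultSchema(statementId)                     // JSON schema of the result
  //   fetchResult(statementId, types, maxRows)           // one or more batches of rows
  //   cleanupStatement(statementId)                      // drop the statement's remote state
  //   executeUnregisterSession(sessionHandle)            // drop the per-session Spark session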
}

/**
 * Since no instance of this class exists on the remote side, all the job definitions are placed
 * here in the companion object in order to enforce that they do not access any class attribute.
 */
object RpcClient {
  // Maps a session ID to its SparkSession (or HiveContext/SQLContext according to the Spark
  // version used).
  val SESSION_SPARK_ENTRY_MAP = "livy.thriftserver.rpc_sessionIdToSparkSQLSession"
  val STATEMENT_RESULT_ITER_MAP = "livy.thriftserver.rpc_statementIdToResultIter"
  val STATEMENT_SCHEMA_MAP = "livy.thriftserver.rpc_statementIdToSchema"
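
  /**
   * Creates a session-specific Spark session via `newSession` (a SparkSession, or a
   * SQLContext/HiveContext on Spark 1) and registers it in SESSION_SPARK_ENTRY_MAP; it also
   * initializes the statement schema and result-iterator maps if they do not exist yet.
   */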
  private def registerSessionJob(sessionId: String, isSpark1: Boolean): Job[_] = new Job[Boolean] {
    override def call(jc: JobContext): Boolean = {
      val spark: Any = if (isSpark1) {
        Option(jc.hivectx()).getOrElse(jc.sqlctx())
      } else {
        jc.sparkSession()
      }
      val sessionSpecificSpark = spark.getClass.getMethod("newSession").invoke(spark)
      jc.sc().synchronized {
        val existingMap =
          Try(jc.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP))
            .getOrElse(new HashMap[String, AnyRef]())
        jc.setSharedObject(SESSION_SPARK_ENTRY_MAP,
          existingMap + ((sessionId, sessionSpecificSpark)))
        Try(jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP))
          .failed.foreach { _ =>
            jc.setSharedObject(STATEMENT_SCHEMA_MAP, new HashMap[String, String]())
          }
        Try(jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP))
          .failed.foreach { _ =>
            jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, new HashMap[String, Iterator[_]]())
          }
      }
      true
    }
  }
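
  /** Removes the Spark session registered for `sessionId` from SESSION_SPARK_ENTRY_MAP. */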
  private def unregisterSessionJob(sessionId: String): Job[_] = new Job[Boolean] {
    override def call(jobContext: JobContext): Boolean = {
      jobContext.sc().synchronized {
        val existingMap =
          jobContext.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP)
        jobContext.setSharedObject(SESSION_SPARK_ENTRY_MAP, existingMap - sessionId)
      }
      true
    }
  }
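
  /**
   * Cancels any Spark jobs still running under the statement's job group and drops the
   * statement's entries from the result-iterator and schema maps.
   */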
  private def cleanupStatementJob(statementId: String): Job[_] = new Job[Boolean] {
    override def call(jc: JobContext): Boolean = {
      val sparkContext = jc.sc()
      sparkContext.cancelJobGroup(statementId)
      sparkContext.synchronized {
        // Clear job group only if current job group is same as expected job group.
        if (sparkContext.getLocalProperty("spark.jobGroup.id") == statementId) {
          sparkContext.clearJobGroup()
        }
        val iterMap = jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP)
        jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, iterMap - statementId)
        val schemaMap = jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP)
        jc.setSharedObject(STATEMENT_SCHEMA_MAP, schemaMap - statementId)
      }
      true
    }
  }
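
  /** Returns the JSON schema previously stored for `statementId` by `executeSqlJob`. */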
  private def fetchResultSchemaJob(statementId: String): Job[String] = new Job[String] {
    override def call(jobContext: JobContext): String = {
      jobContext.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP)(statementId)
    }
  }
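
  /**
   * Drains up to `maxRows` rows from the statement's result iterator, reading each column
   * reflectively via `Row.get(i)`, and packs them into a `ColumnOrientedResultSet`.
   */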
  private def fetchResultJob(statementId: String,
      types: Array[DataType],
      maxRows: Int): Job[ColumnOrientedResultSet] = new Job[ColumnOrientedResultSet] {
    override def call(jobContext: JobContext): ColumnOrientedResultSet = {
      val statementIterMap =
        jobContext.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP)
      val iter = statementIterMap(statementId)
      if (null == iter) {
        // Previous query execution failed.
        throw new NoSuchElementException("No successful query executed for output")
      }
      val resultSet = new ColumnOrientedResultSet(types)
      val numOfColumns = types.length
      if (!iter.hasNext) {
        resultSet
      } else {
        var curRow = 0
        while (curRow < maxRows && iter.hasNext) {
          val sparkRow = iter.next()
          val row = ArrayBuffer[Any]()
          var curCol: Integer = 0
          while (curCol < numOfColumns) {
            row += sparkRow.getClass.getMethod("get", classOf[Int]).invoke(sparkRow, curCol)
            curCol += 1
          }
          resultSet.addRow(row.toArray.asInstanceOf[Array[Object]])
          curRow += 1
        }
        resultSet
      }
    }
  }
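
  /**
   * Runs the statement through `spark.sql(...)` via reflection, stores the result schema (as
   * JSON) and a result iterator in the shared maps, and tags the Spark jobs with `statementId`
   * as the job group so that they can be cancelled later. Results are either collected eagerly
   * or streamed through `toLocalIterator`, depending on the incremental-collect setting.
   */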
  private def executeSqlJob(sessionId: String,
      statementId: String,
      statement: String,
      isSpark1: Boolean,
      defaultIncrementalCollect: String,
      incrementalCollectEnabledProp: String): Job[_] = new Job[Boolean] {
    override def call(jc: JobContext): Boolean = {
      val sparkContext = jc.sc()
      sparkContext.synchronized {
        sparkContext.setJobGroup(statementId, statement)
      }
      val spark = jc.getSharedObject[HashMap[String, AnyRef]](SESSION_SPARK_ENTRY_MAP)(sessionId)
      try {
        val result = spark.getClass.getMethod("sql", classOf[String]).invoke(spark, statement)
        val schema = result.getClass.getMethod("schema").invoke(result)
        val jsonString = schema.getClass.getMethod("json").invoke(schema).asInstanceOf[String]

        // Set the schema in the shared map
        sparkContext.synchronized {
          val existingMap = jc.getSharedObject[HashMap[String, String]](STATEMENT_SCHEMA_MAP)
          jc.setSharedObject(STATEMENT_SCHEMA_MAP, existingMap + ((statementId, jsonString)))
        }

        val incrementalCollect = {
          if (isSpark1) {
            spark.getClass.getMethod("getConf", classOf[String], classOf[String])
              .invoke(spark,
                incrementalCollectEnabledProp,
                defaultIncrementalCollect)
              .asInstanceOf[String].toBoolean
          } else {
            val conf = spark.getClass.getMethod("conf").invoke(spark)
            conf.getClass.getMethod("get", classOf[String], classOf[String])
              .invoke(conf,
                incrementalCollectEnabledProp,
                defaultIncrementalCollect)
              .asInstanceOf[String].toBoolean
          }
        }

        val iter = if (incrementalCollect) {
          val rdd = result.getClass.getMethod("rdd").invoke(result)
          rdd.getClass.getMethod("toLocalIterator").invoke(rdd).asInstanceOf[Iterator[_]]
        } else {
          result.getClass.getMethod("collect").invoke(result).asInstanceOf[Array[_]].iterator
        }

        // Set the iterator in the shared map
        sparkContext.synchronized {
          val existingMap =
            jc.getSharedObject[HashMap[String, Iterator[_]]](STATEMENT_RESULT_ITER_MAP)
          jc.setSharedObject(STATEMENT_RESULT_ITER_MAP, existingMap + ((statementId, iter)))
        }
      } catch {
        case e: InvocationTargetException => throw e.getCause
      }
      true
    }
  }
}