blob: 2447682fcbdd4a520270100a6ff5f1158a7cb40d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.workflow
import java.io.Serializable
import java.util.concurrent.TimeUnit
import akka.actor._
import akka.event.Logging
import akka.io.IO
import akka.pattern.ask
import akka.util.Timeout
import com.github.nscala_time.time.Imports.DateTime
import com.twitter.bijection.Injection
import com.twitter.chill.{KryoBase, KryoInjection, ScalaKryoInstantiator}
import com.typesafe.config.ConfigFactory
import de.javakaffee.kryoserializers.SynchronizedCollectionsSerializer
import grizzled.slf4j.Logging
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.predictionio.authentication.KeyAuthentication
import org.apache.predictionio.configuration.SSLConfiguration
import org.apache.predictionio.controller.{Engine, Params, Utils, WithPrId}
import org.apache.predictionio.core.{BaseAlgorithm, BaseServing, Doer}
import org.apache.predictionio.data.storage.{EngineInstance, Storage}
import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization.write
import spray.can.Http
import spray.can.server.ServerSettings
import spray.http.MediaTypes._
import spray.http._
import spray.httpx.Json4sSupport
import spray.routing._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.language.existentials
import scala.util.{Failure, Random, Success}
import scalaj.http.HttpOptions
class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
override def newKryo(): KryoBase = {
val kryo = super.newKryo()
kryo.setClassLoader(classLoader)
SynchronizedCollectionsSerializer.registerSerializers(kryo)
kryo
}
}
object KryoInstantiator extends Serializable {
def newKryoInjection : Injection[Any, Array[Byte]] = {
val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader)
KryoInjection.instance(kryoInstantiator)
}
}
case class ServerConfig(
batch: String = "",
engineInstanceId: String = "",
engineId: Option[String] = None,
engineVersion: Option[String] = None,
engineVariant: String = "",
env: Option[String] = None,
ip: String = "0.0.0.0",
port: Int = 8000,
feedback: Boolean = false,
eventServerIp: String = "0.0.0.0",
eventServerPort: Int = 7070,
accessKey: Option[String] = None,
logUrl: Option[String] = None,
logPrefix: Option[String] = None,
logFile: Option[String] = None,
verbose: Boolean = false,
debug: Boolean = false,
jsonExtractor: JsonExtractorOption = JsonExtractorOption.Both)
case class StartServer()
case class BindServer()
case class StopServer()
case class ReloadServer()
object CreateServer extends Logging {
val actorSystem = ActorSystem("pio-server")
val engineInstances = Storage.getMetaDataEngineInstances
val modeldata = Storage.getModelDataModels
def main(args: Array[String]): Unit = {
val parser = new scopt.OptionParser[ServerConfig]("CreateServer") {
opt[String]("batch") action { (x, c) =>
c.copy(batch = x)
} text("Batch label of the deployment.")
opt[String]("engineId") action { (x, c) =>
c.copy(engineId = Some(x))
} text("Engine ID.")
opt[String]("engineVersion") action { (x, c) =>
c.copy(engineVersion = Some(x))
} text("Engine version.")
opt[String]("engine-variant") required() action { (x, c) =>
c.copy(engineVariant = x)
} text("Engine variant JSON.")
opt[String]("ip") action { (x, c) =>
c.copy(ip = x)
}
opt[String]("env") action { (x, c) =>
c.copy(env = Some(x))
} text("Comma-separated list of environmental variables (in 'FOO=BAR' " +
"format) to pass to the Spark execution environment.")
opt[Int]("port") action { (x, c) =>
c.copy(port = x)
} text("Port to bind to (default: 8000).")
opt[String]("engineInstanceId") required() action { (x, c) =>
c.copy(engineInstanceId = x)
} text("Engine instance ID.")
opt[Unit]("feedback") action { (_, c) =>
c.copy(feedback = true)
} text("Enable feedback loop to event server.")
opt[String]("event-server-ip") action { (x, c) =>
c.copy(eventServerIp = x)
}
opt[Int]("event-server-port") action { (x, c) =>
c.copy(eventServerPort = x)
} text("Event server port. Default: 7070")
opt[String]("accesskey") action { (x, c) =>
c.copy(accessKey = Some(x))
} text("Event server access key.")
opt[String]("log-url") action { (x, c) =>
c.copy(logUrl = Some(x))
}
opt[String]("log-prefix") action { (x, c) =>
c.copy(logPrefix = Some(x))
}
opt[String]("log-file") action { (x, c) =>
c.copy(logFile = Some(x))
}
opt[Unit]("verbose") action { (x, c) =>
c.copy(verbose = true)
} text("Enable verbose output.")
opt[Unit]("debug") action { (x, c) =>
c.copy(debug = true)
} text("Enable debug output.")
opt[String]("json-extractor") action { (x, c) =>
c.copy(jsonExtractor = JsonExtractorOption.withName(x))
}
}
parser.parse(args, ServerConfig()) map { sc =>
WorkflowUtils.modifyLogging(sc.verbose)
engineInstances.get(sc.engineInstanceId) map { engineInstance =>
val engineId = sc.engineId.getOrElse(engineInstance.engineId)
val engineVersion = sc.engineVersion.getOrElse(
engineInstance.engineVersion)
val engineFactoryName = engineInstance.engineFactory
val master = actorSystem.actorOf(Props(
classOf[MasterActor],
sc,
engineInstance,
engineFactoryName),
"master")
implicit val timeout = Timeout(5.seconds)
master ? StartServer()
actorSystem.awaitTermination
} getOrElse {
error(s"Invalid engine instance ID. Aborting server.")
}
}
}
def createServerActorWithEngine[TD, EIN, PD, Q, P, A](
sc: ServerConfig,
engineInstance: EngineInstance,
engine: Engine[TD, EIN, PD, Q, P, A],
engineLanguage: EngineLanguage.Value): ActorRef = {
val engineParams = engine.engineInstanceToEngineParams(
engineInstance, sc.jsonExtractor)
val kryo = KryoInstantiator.newKryoInjection
val modelsFromEngineInstance =
kryo.invert(modeldata.get(engineInstance.id).get.models).get.
asInstanceOf[Seq[Any]]
val batch = if (engineInstance.batch.nonEmpty) {
s"${engineInstance.engineFactory} (${engineInstance.batch})"
} else {
engineInstance.engineFactory
}
val sparkContext = WorkflowContext(
batch = batch,
executorEnv = engineInstance.env,
mode = "Serving",
sparkEnv = engineInstance.sparkConf)
val models = engine.prepareDeploy(
sparkContext,
engineParams,
engineInstance.id,
modelsFromEngineInstance,
params = WorkflowParams()
)
val algorithms = engineParams.algorithmParamsList.map { case (n, p) =>
Doer(engine.algorithmClassMap(n), p)
}
val servingParamsWithName = engineParams.servingParams
val serving = Doer(engine.servingClassMap(servingParamsWithName._1),
servingParamsWithName._2)
actorSystem.actorOf(
Props(
classOf[ServerActor[Q, P]],
sc,
engineInstance,
engine,
engineLanguage,
engineParams.dataSourceParams._2,
engineParams.preparatorParams._2,
algorithms,
engineParams.algorithmParamsList.map(_._2),
models,
serving,
engineParams.servingParams._2))
}
}
class MasterActor (
sc: ServerConfig,
engineInstance: EngineInstance,
engineFactoryName: String) extends Actor with SSLConfiguration with KeyAuthentication {
val log = Logging(context.system, this)
implicit val system = context.system
var sprayHttpListener: Option[ActorRef] = None
var currentServerActor: Option[ActorRef] = None
var retry = 3
val serverConfig = ConfigFactory.load("server.conf")
val sslEnforced = serverConfig.getBoolean("org.apache.predictionio.server.ssl-enforced")
val protocol = if (sslEnforced) "https://" else "http://"
def undeploy(ip: String, port: Int): Unit = {
val serverUrl = s"${protocol}${ip}:${port}"
log.info(
s"Undeploying any existing engine instance at $serverUrl")
try {
val code = scalaj.http.Http(s"$serverUrl/stop")
.option(HttpOptions.allowUnsafeSSL)
.param(ServerKey.param, ServerKey.get)
.method("POST").asString.code
code match {
case 200 => ()
case 404 => log.error(
s"Another process is using $serverUrl. Unable to undeploy.")
case _ => log.error(
s"Another process is using $serverUrl, or an existing " +
s"engine server is not responding properly (HTTP $code). " +
"Unable to undeploy.")
}
} catch {
case e: java.net.ConnectException =>
log.warning(s"Nothing at $serverUrl")
case _: Throwable =>
log.error("Another process might be occupying " +
s"$ip:$port. Unable to undeploy.")
}
}
def receive: Actor.Receive = {
case x: StartServer =>
val actor = createServerActor(
sc,
engineInstance,
engineFactoryName)
currentServerActor = Some(actor)
undeploy(sc.ip, sc.port)
self ! BindServer()
case x: BindServer =>
currentServerActor map { actor =>
val settings = ServerSettings(system)
IO(Http) ! Http.Bind(
actor,
interface = sc.ip,
port = sc.port,
settings = Some(settings.copy(sslEncryption = sslEnforced)))
} getOrElse {
log.error("Cannot bind a non-existing server backend.")
}
case x: StopServer =>
log.info(s"Stop server command received.")
sprayHttpListener.map { l =>
log.info("Server is shutting down.")
l ! Http.Unbind(5.seconds)
system.shutdown()
} getOrElse {
log.warning("No active server is running.")
}
case x: ReloadServer =>
log.info("Reload server command received.")
val latestEngineInstance =
CreateServer.engineInstances.getLatestCompleted(
engineInstance.engineId,
engineInstance.engineVersion,
engineInstance.engineVariant)
latestEngineInstance map { lr =>
val actor = createServerActor(sc, lr, engineFactoryName)
sprayHttpListener.map { l =>
l ! Http.Unbind(5.seconds)
val settings = ServerSettings(system)
IO(Http) ! Http.Bind(
actor,
interface = sc.ip,
port = sc.port,
settings = Some(settings.copy(sslEncryption = sslEnforced)))
currentServerActor.get ! Kill
currentServerActor = Some(actor)
} getOrElse {
log.warning("No active server is running. Abort reloading.")
}
} getOrElse {
log.warning(
s"No latest completed engine instance for ${engineInstance.engineId} " +
s"${engineInstance.engineVersion}. Abort reloading.")
}
case x: Http.Bound =>
val serverUrl = s"${protocol}${sc.ip}:${sc.port}"
log.info(s"Engine is deployed and running. Engine API is live at ${serverUrl}.")
sprayHttpListener = Some(sender)
case x: Http.CommandFailed =>
if (retry > 0) {
retry -= 1
log.error(s"Bind failed. Retrying... ($retry more trial(s))")
context.system.scheduler.scheduleOnce(1.seconds) {
self ! BindServer()
}
} else {
log.error("Bind failed. Shutting down.")
system.shutdown()
}
}
def createServerActor(
sc: ServerConfig,
engineInstance: EngineInstance,
engineFactoryName: String): ActorRef = {
val (engineLanguage, engineFactory) =
WorkflowUtils.getEngine(engineFactoryName, getClass.getClassLoader)
val engine = engineFactory()
// EngineFactory return a base engine, which may not be deployable.
if (!engine.isInstanceOf[Engine[_,_,_,_,_,_]]) {
throw new NoSuchMethodException(s"Engine $engine is not deployable")
}
val deployableEngine = engine.asInstanceOf[Engine[_,_,_,_,_,_]]
CreateServer.createServerActorWithEngine(
sc,
engineInstance,
// engine,
deployableEngine,
engineLanguage)
}
}
class ServerActor[Q, P](
val args: ServerConfig,
val engineInstance: EngineInstance,
val engine: Engine[_, _, _, Q, P, _],
val engineLanguage: EngineLanguage.Value,
val dataSourceParams: Params,
val preparatorParams: Params,
val algorithms: Seq[BaseAlgorithm[_, _, Q, P]],
val algorithmsParams: Seq[Params],
val models: Seq[Any],
val serving: BaseServing[Q, P],
val servingParams: Params) extends Actor with HttpService with KeyAuthentication {
val serverStartTime = DateTime.now
val log = Logging(context.system, this)
var requestCount: Int = 0
var avgServingSec: Double = 0.0
var lastServingSec: Double = 0.0
/** The following is required by HttpService */
def actorRefFactory: ActorContext = context
implicit val timeout = Timeout(5, TimeUnit.SECONDS)
val pluginsActorRef =
context.actorOf(Props(classOf[PluginsActor], args.engineVariant), "PluginsActor")
val pluginContext = EngineServerPluginContext(log, args.engineVariant)
def receive: Actor.Receive = runRoute(myRoute)
val feedbackEnabled = if (args.feedback) {
if (args.accessKey.isEmpty) {
log.error("Feedback loop cannot be enabled because accessKey is empty.")
false
} else {
true
}
} else false
def remoteLog(logUrl: String, logPrefix: String, message: String): Unit = {
implicit val formats = Utils.json4sDefaultFormats
try {
scalaj.http.Http(logUrl).postData(
logPrefix + write(Map(
"engineInstance" -> engineInstance,
"message" -> message))).asString
} catch {
case e: Throwable =>
log.error(s"Unable to send remote log: ${e.getMessage}")
}
}
val myRoute =
path("") {
get {
respondWithMediaType(`text/html`) {
detach() {
complete {
html.index(
args,
engineInstance,
algorithms.map(_.toString),
algorithmsParams.map(_.toString),
models.map(_.toString),
dataSourceParams.toString,
preparatorParams.toString,
servingParams.toString,
serverStartTime,
feedbackEnabled,
args.eventServerIp,
args.eventServerPort,
requestCount,
avgServingSec,
lastServingSec
).toString
}
}
}
}
} ~
path("queries.json") {
post {
detach() {
entity(as[String]) { queryString =>
try {
val servingStartTime = DateTime.now
val jsonExtractorOption = args.jsonExtractor
val queryTime = DateTime.now
// Extract Query from Json
val query = JsonExtractor.extract(
jsonExtractorOption,
queryString,
algorithms.head.queryClass,
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories
)
val queryJValue = JsonExtractor.toJValue(
jsonExtractorOption,
query,
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
// Deploy logic. First call Serving.supplement, then Algo.predict,
// finally Serving.serve.
val supplementedQuery = serving.supplementBase(query)
// TODO: Parallelize the following.
val predictions = algorithms.zip(models).map { case (a, m) =>
a.predictBase(m, supplementedQuery)
}
// Notice that it is by design to call Serving.serve with the
// *original* query.
val prediction = serving.serveBase(query, predictions)
val predictionJValue = JsonExtractor.toJValue(
jsonExtractorOption,
prediction,
algorithms.head.querySerializer,
algorithms.head.gsonTypeAdapterFactories)
/** Handle feedback to Event Server
* Send the following back to the Event Server
* - appId
* - engineInstanceId
* - query
* - prediction
* - prId
*/
val result = if (feedbackEnabled) {
implicit val formats =
algorithms.headOption map { alg =>
alg.querySerializer
} getOrElse {
Utils.json4sDefaultFormats
}
// val genPrId = Random.alphanumeric.take(64).mkString
def genPrId: String = Random.alphanumeric.take(64).mkString
val newPrId = prediction match {
case id: WithPrId =>
val org = id.prId
if (org.isEmpty) genPrId else org
case _ => genPrId
}
// also save Query's prId as prId of this pio_pr predict events
val queryPrId =
query match {
case id: WithPrId =>
Map("prId" -> id.prId)
case _ =>
Map.empty
}
val data = Map(
// "appId" -> dataSourceParams.asInstanceOf[ParamsWithAppId].appId,
"event" -> "predict",
"eventTime" -> queryTime.toString(),
"entityType" -> "pio_pr", // prediction result
"entityId" -> newPrId,
"properties" -> Map(
"engineInstanceId" -> engineInstance.id,
"query" -> query,
"prediction" -> prediction)) ++ queryPrId
// At this point args.accessKey should be Some(String).
val accessKey = args.accessKey.getOrElse("")
val f: Future[Int] = Future {
scalaj.http.Http(
s"http://${args.eventServerIp}:${args.eventServerPort}/" +
s"events.json?accessKey=$accessKey").postData(
write(data)).header(
"content-type", "application/json").asString.code
}
f onComplete {
case Success(code) => {
if (code != 201) {
log.error(s"Feedback event failed. Status code: $code."
+ s"Data: ${write(data)}.")
}
}
case Failure(t) => {
log.error(s"Feedback event failed: ${t.getMessage}") }
}
// overwrite prId in predictedResult
// - if it is WithPrId,
// then overwrite with new prId
// - if it is not WithPrId, no prId injection
if (prediction.isInstanceOf[WithPrId]) {
predictionJValue merge parse(s"""{"prId" : "$newPrId"}""")
} else {
predictionJValue
}
} else predictionJValue
val pluginResult =
pluginContext.outputBlockers.values.foldLeft(result) { case (r, p) =>
p.process(engineInstance, queryJValue, r, pluginContext)
}
pluginsActorRef ! (engineInstance, queryJValue, result)
// Bookkeeping
val servingEndTime = DateTime.now
lastServingSec =
(servingEndTime.getMillis - servingStartTime.getMillis) / 1000.0
avgServingSec =
((avgServingSec * requestCount) + lastServingSec) /
(requestCount + 1)
requestCount += 1
respondWithMediaType(`application/json`) {
complete(compact(render(pluginResult)))
}
} catch {
case e: MappingException =>
val msg = s"Query:\n$queryString\n\nStack Trace:\n" +
s"${ExceptionUtils.getStackTrace(e)}\n\n"
log.error(msg)
args.logUrl map { url =>
remoteLog(
url,
args.logPrefix.getOrElse(""),
msg)
}
complete(StatusCodes.BadRequest, e.getMessage)
case e: Throwable =>
val msg = s"Query:\n$queryString\n\nStack Trace:\n" +
s"${ExceptionUtils.getStackTrace(e)}\n\n"
log.error(msg)
args.logUrl map { url =>
remoteLog(
url,
args.logPrefix.getOrElse(""),
msg)
}
complete(StatusCodes.InternalServerError, msg)
}
}
}
}
} ~
path("reload") {
authenticate(withAccessKeyFromFile) { request =>
post {
complete {
context.actorSelection("/user/master") ! ReloadServer()
"Reloading..."
}
}
}
} ~
path("stop") {
authenticate(withAccessKeyFromFile) { request =>
post {
complete {
context.system.scheduler.scheduleOnce(1.seconds) {
context.actorSelection("/user/master") ! StopServer()
}
"Shutting down..."
}
}
}
} ~
pathPrefix("assets") {
getFromResourceDirectory("assets")
} ~
path("plugins.json") {
import EngineServerJson4sSupport._
get {
respondWithMediaType(MediaTypes.`application/json`) {
complete {
Map("plugins" -> Map(
"outputblockers" -> pluginContext.outputBlockers.map { case (n, p) =>
n -> Map(
"name" -> p.pluginName,
"description" -> p.pluginDescription,
"class" -> p.getClass.getName,
"params" -> pluginContext.pluginParams(p.pluginName))
},
"outputsniffers" -> pluginContext.outputSniffers.map { case (n, p) =>
n -> Map(
"name" -> p.pluginName,
"description" -> p.pluginDescription,
"class" -> p.getClass.getName,
"params" -> pluginContext.pluginParams(p.pluginName))
}
))
}
}
}
} ~
path("plugins" / Segments) { segments =>
import EngineServerJson4sSupport._
get {
respondWithMediaType(MediaTypes.`application/json`) {
complete {
val pluginArgs = segments.drop(2)
val pluginType = segments(0)
val pluginName = segments(1)
pluginType match {
case EngineServerPlugin.outputBlocker =>
pluginContext.outputBlockers(pluginName).handleREST(
pluginArgs)
case EngineServerPlugin.outputSniffer =>
pluginsActorRef ? PluginsActor.HandleREST(
pluginName = pluginName,
pluginArgs = pluginArgs) map {
_.asInstanceOf[String]
}
}
}
}
}
}
}
object EngineServerJson4sSupport extends Json4sSupport {
implicit def json4sFormats: Formats = DefaultFormats
}