blob: 7745f30c7c9715c6fc73058df63b82fc4b840df6 [file] [log] [blame]
/** Copyright 2015 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prediction.workflow
import akka.actor._
import akka.event.Logging
import akka.io.IO
import akka.pattern.ask
import akka.util.Timeout
import com.github.nscala_time.time.Imports.DateTime
import com.google.gson.Gson
import com.twitter.chill.KryoBase
import com.twitter.chill.KryoInjection
import com.twitter.chill.ScalaKryoInstantiator
import grizzled.slf4j.Logging
import io.prediction.controller.Engine
import io.prediction.controller.Params
import io.prediction.controller.Utils
import io.prediction.controller.WithPrId
import io.prediction.controller.WorkflowParams
import io.prediction.controller.java.LJavaAlgorithm
import io.prediction.controller.java.LJavaServing
import io.prediction.controller.java.PJavaAlgorithm
import io.prediction.core.BaseAlgorithm
import io.prediction.core.BaseServing
import io.prediction.core.Doer
import io.prediction.data.storage.EngineInstance
import io.prediction.data.storage.EngineManifest
import io.prediction.data.storage.Storage
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization.write
import spray.can.Http
import spray.http.MediaTypes._
import spray.http._
import spray.routing._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.concurrent.future
import scala.language.existentials
import scala.util.Failure
import scala.util.Random
import scala.util.Success
import java.io.PrintWriter
import java.io.StringWriter
/** Kryo instantiator whose Kryo instances resolve classes through the given
  * class loader.
  *
  * @param classLoader class loader installed on every Kryo instance created
  *                    by [[newKryo]]
  */
class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
  /** Builds the default Scala Kryo and points it at `classLoader`. */
  override def newKryo(): KryoBase = {
    val configured = super.newKryo()
    configured.setClassLoader(classLoader)
    configured
  }
}
/** Command-line configuration of the engine server.
  *
  * @param batch            batch label of the deployment; when empty the batch
  *                         label of the engine instance is used instead
  * @param engineInstanceId ID of the engine instance to deploy
  * @param engineId         optional engine ID overriding the one stored on the
  *                         engine instance
  * @param engineVersion    optional engine version overriding the one stored on
  *                         the engine instance
  * @param ip               interface the HTTP server binds to
  * @param port             port the HTTP server binds to
  * @param feedback         whether the feedback loop to the event server was
  *                         requested
  * @param eventServerIp    event server host used by the feedback loop
  * @param eventServerPort  event server port used by the feedback loop
  * @param accessKey        event server access key; required for feedback
  * @param logUrl           optional URL that query errors are posted to
  * @param logPrefix        optional prefix prepended to remote log payloads
  * @param logFile          value of the `--log-file` option
  * @param verbose          enable verbose output
  * @param debug            enable debug output
  */
case class ServerConfig(
    batch: String = "",
    engineInstanceId: String = "",
    engineId: Option[String] = None,
    engineVersion: Option[String] = None,
    ip: String = "localhost",
    port: Int = 8000,
    feedback: Boolean = false,
    eventServerIp: String = "localhost",
    eventServerPort: Int = 7070,
    accessKey: Option[String] = None,
    logUrl: Option[String] = None,
    logPrefix: Option[String] = None,
    logFile: Option[String] = None,
    verbose: Boolean = false,
    debug: Boolean = false)
/** Asks the master actor to create the server actor and begin binding. */
case class StartServer()

/** Asks the master actor to bind the current server actor to the HTTP port. */
case class BindServer()

/** Asks the master actor to unbind the HTTP listener and shut down. */
case class StopServer()

/** Asks the master actor to redeploy the latest completed engine instance. */
case class ReloadServer()

/** Asks the upgrade actor to run an upgrade check. */
case class UpgradeCheck()
/** Command-line entry point of the engine server and factory of the actor
  * that serves a deployed engine instance.
  */
object CreateServer extends Logging {
  val actorSystem = ActorSystem("pio-server")
  val engineInstances = Storage.getMetaDataEngineInstances
  val engineManifests = Storage.getMetaDataEngineManifests
  val modeldata = Storage.getModelDataModels

  /** Parses the command line, looks up the requested engine instance and its
    * manifest, starts the upgrade checker and the master actor, and then
    * blocks until the actor system terminates.
    *
    * When the arguments cannot be parsed, or the engine instance or manifest
    * cannot be found, the actor system is shut down so that it does not keep
    * running after the server has been aborted.
    *
    * @param args command-line arguments, see the option definitions below
    */
  def main(args: Array[String]): Unit = {
    val parser = new scopt.OptionParser[ServerConfig]("CreateServer") {
      opt[String]("batch") action { (x, c) =>
        c.copy(batch = x)
      } text("Batch label of the deployment.")
      opt[String]("engineId") action { (x, c) =>
        c.copy(engineId = Some(x))
      } text("Engine ID.")
      opt[String]("engineVersion") action { (x, c) =>
        c.copy(engineVersion = Some(x))
      } text("Engine version.")
      opt[String]("ip") action { (x, c) =>
        c.copy(ip = x)
      } text("IP to bind to (default: localhost).")
      opt[Int]("port") action { (x, c) =>
        c.copy(port = x)
      } text("Port to bind to (default: 8000).")
      opt[String]("engineInstanceId") required() action { (x, c) =>
        c.copy(engineInstanceId = x)
      } text("Engine instance ID.")
      opt[Unit]("feedback") action { (_, c) =>
        c.copy(feedback = true)
      } text("Enable feedback loop to event server.")
      opt[String]("event-server-ip") action { (x, c) =>
        c.copy(eventServerIp = x)
      } text("Event server IP. Default: localhost")
      opt[Int]("event-server-port") action { (x, c) =>
        c.copy(eventServerPort = x)
      } text("Event server port. Default: 7070")
      opt[String]("accesskey") action { (x, c) =>
        c.copy(accessKey = Some(x))
      } text("Event server access key.")
      opt[String]("log-url") action { (x, c) =>
        c.copy(logUrl = Some(x))
      }
      opt[String]("log-prefix") action { (x, c) =>
        c.copy(logPrefix = Some(x))
      }
      opt[String]("log-file") action { (x, c) =>
        c.copy(logFile = Some(x))
      }
      opt[Unit]("verbose") action { (_, c) =>
        c.copy(verbose = true)
      } text("Enable verbose output.")
      opt[Unit]("debug") action { (_, c) =>
        c.copy(debug = true)
      } text("Enable debug output.")
    }

    parser.parse(args, ServerConfig()) match {
      case None =>
        // Nothing will be started, so do not leave the actor system running.
        actorSystem.shutdown()
      case Some(sc) =>
        WorkflowUtils.modifyLogging(sc.verbose)
        engineInstances.get(sc.engineInstanceId) match {
          case None =>
            abortServer("Invalid engine instance ID. Aborting server.")
          case Some(engineInstance) =>
            // Command-line values take precedence over the stored ones.
            val engineId = sc.engineId.getOrElse(engineInstance.engineId)
            val engineVersion = sc.engineVersion.getOrElse(
              engineInstance.engineVersion)
            engineManifests.get(engineId, engineVersion) match {
              case None =>
                abortServer("Invalid engine ID or version. Aborting server.")
              case Some(manifest) =>
                runServer(sc, engineInstance, manifest)
            }
        }
    }
  }

  /** Logs `message` as an error and shuts the actor system down. */
  private def abortServer(message: String): Unit = {
    error(message)
    actorSystem.shutdown()
  }

  /** Starts the daily upgrade check and the master actor, then blocks until
    * the actor system terminates.
    */
  private def runServer(
      sc: ServerConfig,
      engineInstance: EngineInstance,
      manifest: EngineManifest): Unit = {
    val engineFactoryName = engineInstance.engineFactory
    val upgrade = actorSystem.actorOf(Props(
      classOf[UpgradeActor],
      engineFactoryName))
    actorSystem.scheduler.schedule(
      0.seconds,
      1.days,
      upgrade,
      UpgradeCheck())
    val master = actorSystem.actorOf(Props(
      classOf[MasterActor],
      sc,
      engineInstance,
      engineFactoryName,
      manifest),
      "master")
    // MasterActor does not reply to StartServer, so a plain tell is used
    // rather than an ask whose future could only ever time out.
    master ! StartServer()
    actorSystem.awaitTermination()
  }

  /** Creates the [[ServerActor]] that serves `engineInstance` with `engine`.
    *
    * The persisted models of the engine instance are deserialized with Kryo,
    * handed to `engine.prepareDeploy`, and passed to the new actor together
    * with the instantiated algorithms and serving component.
    *
    * @param sc             server configuration
    * @param engineInstance engine instance to deploy
    * @param engine         deployable engine the instance belongs to
    * @param engineLanguage language of the engine
    * @param manifest       manifest of the engine
    * @return reference to the newly created server actor
    * @throws NoSuchElementException if no model data is stored for the engine
    *                                instance
    */
  def createServerActorWithEngine[TD, EIN, PD, Q, P, A](
      sc: ServerConfig,
      engineInstance: EngineInstance,
      engine: Engine[TD, EIN, PD, Q, P, A],
      engineLanguage: EngineLanguage.Value,
      manifest: EngineManifest): ActorRef = {
    val engineParams = engine.engineInstanceToEngineParams(engineInstance)

    val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader)
    val kryo = KryoInjection.instance(kryoInstantiator)
    val persistedModel = modeldata.get(engineInstance.id).getOrElse(
      throw new NoSuchElementException(
        s"No model data found for engine instance ${engineInstance.id}"))
    // invert yields a Try-like result; .get rethrows a deserialization failure.
    val modelsFromEngineInstance =
      kryo.invert(persistedModel.models).get.asInstanceOf[Seq[Any]]

    val sparkContext = WorkflowContext(
      batch = if (sc.batch == "") engineInstance.batch else sc.batch,
      executorEnv = engineInstance.env,
      mode = "Serving",
      sparkEnv = engineInstance.sparkConf)

    val models = engine.prepareDeploy(
      sparkContext,
      engineParams,
      engineInstance.id,
      modelsFromEngineInstance,
      params = WorkflowParams()
    )

    val algorithms = engineParams.algorithmParamsList.map { case (n, p) =>
      Doer(engine.algorithmClassMap(n), p)
    }

    val servingParamsWithName = engineParams.servingParams
    val serving = Doer(
      engine.servingClassMap(servingParamsWithName._1),
      servingParamsWithName._2)

    // The argument order must match the ServerActor constructor.
    actorSystem.actorOf(
      Props(
        classOf[ServerActor[Q, P]],
        sc,
        engineInstance,
        engine,
        engineLanguage,
        manifest,
        engineParams.dataSourceParams._2,
        engineParams.preparatorParams._2,
        algorithms,
        engineParams.algorithmParamsList.map(_._2),
        models,
        serving,
        engineParams.servingParams._2))
  }
}
/** Actor that runs an upgrade check each time it receives an
  * [[UpgradeCheck]] message.
  *
  * @param engineClass engine class name passed to the upgrade check
  */
class UpgradeActor(engineClass: String) extends Actor {
  val log = Logging(context.system, this)
  implicit val system = context.system

  def receive: Actor.Receive = {
    case _: UpgradeCheck =>
      WorkflowUtils.checkUpgrade("deployment", engineClass)
  }
}
/** Supervising actor of a deployment: it creates the server actor, binds it
  * to the HTTP port, and handles stop and reload commands.
  *
  * @param sc                server configuration
  * @param engineInstance    engine instance deployed on start
  * @param engineFactoryName name of the engine factory used to build the engine
  * @param manifest          manifest of the engine
  */
class MasterActor(
    sc: ServerConfig,
    engineInstance: EngineInstance,
    engineFactoryName: String,
    manifest: EngineManifest) extends Actor {
  val log = Logging(context.system, this)
  implicit val system = context.system
  /** HTTP listener that confirmed the bind, once bound. */
  var sprayHttpListener: Option[ActorRef] = None
  /** Server actor currently handling requests, once created. */
  var currentServerActor: Option[ActorRef] = None
  /** Remaining bind attempts after a failed bind. */
  var retry = 3

  /** Asks whatever is listening at `ip:port` to stop by requesting its
    * `/stop` URL, and logs the outcome. Failures are logged, never thrown,
    * except for fatal JVM errors.
    *
    * @param ip   host of the server to undeploy
    * @param port port of the server to undeploy
    */
  def undeploy(ip: String, port: Int): Unit = {
    val serverUrl = s"http://${ip}:${port}"
    log.info(
      s"Undeploying any existing engine instance at $serverUrl")
    try {
      val code = scalaj.http.Http(s"$serverUrl/stop").asString.code
      code match {
        case 200 => ()
        case 404 => log.error(
          s"Another process is using $serverUrl. Unable to undeploy.")
        case _ => log.error(
          s"Another process is using $serverUrl, or an existing " +
          s"engine server is not responding properly (HTTP $code). " +
          "Unable to undeploy.")
      }
    } catch {
      case _: java.net.ConnectException =>
        log.warning(s"Nothing at $serverUrl")
      case scala.util.control.NonFatal(_) =>
        log.error("Another process might be occupying " +
          s"$ip:$port. Unable to undeploy.")
    }
  }

  def receive: Actor.Receive = {
    case _: StartServer =>
      val actor = createServerActor(
        sc,
        engineInstance,
        engineFactoryName,
        manifest)
      currentServerActor = Some(actor)
      undeploy(sc.ip, sc.port)
      self ! BindServer()

    case _: BindServer =>
      currentServerActor match {
        case Some(actor) =>
          IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
        case None =>
          log.error("Cannot bind a non-existing server backend.")
      }

    case _: StopServer =>
      log.info("Stop server command received.")
      sprayHttpListener match {
        case Some(listener) =>
          log.info("Server is shutting down.")
          listener ! Http.Unbind(5.seconds)
          system.shutdown()
        case None =>
          log.warning("No active server is running.")
      }

    case _: ReloadServer =>
      log.info("Reload server command received.")
      val latestEngineInstance =
        CreateServer.engineInstances.getLatestCompleted(
          manifest.id,
          manifest.version,
          engineInstance.engineVariant)
      latestEngineInstance match {
        case None =>
          log.warning(
            s"No latest completed engine instance for ${manifest.id} " +
            s"${manifest.version}. Abort reloading.")
        case Some(latest) =>
          sprayHttpListener match {
            case None =>
              log.warning("No active server is running. Abort reloading.")
            case Some(listener) =>
              // Create the replacement only once it is known to be usable, so
              // an aborted reload does not leave an orphaned server actor.
              val actor =
                createServerActor(sc, latest, engineFactoryName, manifest)
              listener ! Http.Unbind(5.seconds)
              IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
              currentServerActor.foreach(_ ! Kill)
              currentServerActor = Some(actor)
          }
      }

    case _: Http.Bound =>
      log.info("Bind successful. Ready to serve.")
      sprayHttpListener = Some(sender)

    case _: Http.CommandFailed =>
      if (retry > 0) {
        retry -= 1
        log.error(s"Bind failed. Retrying... ($retry more trial(s))")
        context.system.scheduler.scheduleOnce(1.seconds, self, BindServer())
      } else {
        log.error("Bind failed. Shutting down.")
        system.shutdown()
      }
  }

  /** Builds the engine from its factory and creates a server actor for it.
    *
    * @param sc                server configuration
    * @param engineInstance    engine instance to serve
    * @param engineFactoryName name of the engine factory
    * @param manifest          manifest of the engine
    * @return reference to the newly created server actor
    * @throws NoSuchMethodException if the factory returns an engine that is
    *                               not an `Engine` and so cannot be deployed
    */
  def createServerActor(
      sc: ServerConfig,
      engineInstance: EngineInstance,
      engineFactoryName: String,
      manifest: EngineManifest): ActorRef = {
    val (engineLanguage, engineFactory) =
      WorkflowUtils.getEngine(engineFactoryName, getClass.getClassLoader)
    val engine = engineFactory()

    // The engine factory returns a base engine, which may not be deployable.
    engine match {
      case deployableEngine: Engine[_, _, _, _, _, _] =>
        CreateServer.createServerActorWithEngine(
          sc,
          engineInstance,
          deployableEngine,
          engineLanguage,
          manifest)
      case _ =>
        throw new NoSuchMethodException(s"Engine $engine is not deployable")
    }
  }
}
/** HTTP service actor of a deployed engine instance.
  *
  * Routes:
  *  - `GET /` renders the status page.
  *  - `POST /queries.json` runs a query through every algorithm and the
  *    serving component and returns the result as JSON.
  *  - `GET /reload` and `GET /stop` forward the command to `/user/master`.
  *  - `/assets` serves static resources.
  *
  * @param args             server configuration
  * @param engineInstance   engine instance being served
  * @param engine           engine the instance belongs to
  * @param engineLanguage   language of the engine
  * @param manifest         manifest of the engine
  * @param dataSourceParams data source parameters, shown on the status page
  * @param preparatorParams preparator parameters, shown on the status page
  * @param algorithms       algorithms, index-aligned with `models`
  * @param algorithmsParams algorithm parameters, shown on the status page
  * @param models           models, index-aligned with `algorithms`
  * @param serving          serving component combining the predictions
  * @param servingParams    serving parameters, shown on the status page
  */
class ServerActor[Q, P](
    val args: ServerConfig,
    val engineInstance: EngineInstance,
    val engine: Engine[_, _, _, Q, P, _],
    val engineLanguage: EngineLanguage.Value,
    val manifest: EngineManifest,
    val dataSourceParams: Params,
    val preparatorParams: Params,
    val algorithms: Seq[BaseAlgorithm[_, _, Q, P]],
    val algorithmsParams: Seq[Params],
    val models: Seq[Any],
    val serving: BaseServing[Q, P],
    val servingParams: Params) extends Actor with HttpService {
  val serverStartTime = DateTime.now
  lazy val gson = new Gson
  val log = Logging(context.system, this)
  val (javaAlgorithms, scalaAlgorithms) = algorithms.partition(_.isJava)
  // Serving statistics. They are updated from detached request handlers, so
  // all access goes through recordServing / servingStats.
  var requestCount: Int = 0
  var avgServingSec: Double = 0.0
  var lastServingSec: Double = 0.0

  def actorRefFactory: ActorContext = context

  def receive: Actor.Receive = runRoute(myRoute)

  /** True when feedback was requested and an access key is available. */
  val feedbackEnabled = if (args.feedback) {
    if (args.accessKey.isEmpty) {
      log.error("Feedback loop cannot be enabled because accessKey is empty.")
      false
    } else {
      true
    }
  } else false

  /** Posts `message`, together with the engine instance, to `logUrl`.
    * Failures are logged locally and never propagated, except for fatal JVM
    * errors.
    *
    * @param logUrl    URL the log entry is posted to
    * @param logPrefix text prepended to the JSON payload
    * @param message   message to log
    */
  def remoteLog(logUrl: String, logPrefix: String, message: String): Unit = {
    implicit val formats = Utils.json4sDefaultFormats
    try {
      scalaj.http.Http(logUrl).postData(
        logPrefix + write(Map(
          "engineInstance" -> engineInstance,
          "message" -> message))).asString
    } catch {
      case scala.util.control.NonFatal(e) =>
        log.error(s"Unable to send remote log: ${e.getMessage}")
    }
  }

  /** Returns the stack trace of `e` as a string. */
  def getStackTraceString(e: Throwable): String = {
    val writer = new StringWriter()
    val printWriter = new PrintWriter(writer)
    e.printStackTrace(printWriter)
    printWriter.flush()
    writer.toString
  }

  /** Records one served request that took `servingSec` seconds. */
  private def recordServing(servingSec: Double): Unit = synchronized {
    lastServingSec = servingSec
    avgServingSec =
      ((avgServingSec * requestCount) + servingSec) / (requestCount + 1)
    requestCount += 1
  }

  /** Consistent snapshot of (request count, average seconds, last seconds). */
  private def servingStats: (Int, Double, Double) = synchronized {
    (requestCount, avgServingSec, lastServingSec)
  }

  val myRoute =
    path("") {
      get {
        respondWithMediaType(`text/html`) {
          detach() {
            complete {
              val (servedRequests, averageSec, latestSec) = servingStats
              html.index(
                args,
                manifest,
                engineInstance,
                algorithms.map(_.toString),
                algorithmsParams.map(_.toString),
                models.map(_.toString),
                dataSourceParams.toString,
                preparatorParams.toString,
                servingParams.toString,
                serverStartTime,
                feedbackEnabled,
                args.eventServerIp,
                args.eventServerPort,
                servedRequests,
                averageSec,
                latestSec
              ).toString
            }
          }
        }
      }
    } ~
    path("queries.json") {
      post {
        detach() {
          entity(as[String]) { queryString =>
            try {
              val servingStartTime = DateTime.now
              val queryTime = DateTime.now

              // Java algorithms read the query with Gson, Scala ones with
              // json4s; each flavour is parsed once, using its first algorithm.
              val javaQuery = javaAlgorithms.headOption map { alg =>
                val queryClass = if (
                  alg.isInstanceOf[LJavaAlgorithm[_, _, Q, P]]) {
                  alg.asInstanceOf[LJavaAlgorithm[_, _, Q, P]].queryClass
                } else {
                  alg.asInstanceOf[PJavaAlgorithm[_, _, Q, P]].queryClass
                }
                gson.fromJson(queryString, queryClass)
              }
              val scalaQuery = scalaAlgorithms.headOption map { alg =>
                Extraction.extract(parse(queryString))(
                  alg.querySerializer, alg.queryManifest)
              }

              val predictions = algorithms.zipWithIndex.map { case (a, ai) =>
                if (a.isJava) {
                  a.predictBase(models(ai), javaQuery.get)
                } else {
                  a.predictBase(models(ai), scalaQuery.get)
                }
              }

              // (result as JSON, served prediction, query that was served)
              val (resultJson, servedPrediction, servedQuery) =
                if (serving.isInstanceOf[LJavaServing[Q, P]]) {
                  val prediction = serving.serveBase(javaQuery.get, predictions)
                  // parse to Json4s JObject for later merging with prId
                  (parse(gson.toJson(prediction)), prediction, javaQuery.get)
                } else {
                  val prediction =
                    serving.serveBase(scalaQuery.get, predictions)
                  (Extraction.decompose(prediction)(
                    scalaAlgorithms.head.querySerializer),
                    prediction,
                    scalaQuery.get)
                }

              /** Handle feedback to Event Server
                * Send the following back to the Event Server
                * - appId
                * - engineInstanceId
                * - query
                * - prediction
                * - prId
                */
              val result = if (feedbackEnabled) {
                implicit val formats =
                  if (!scalaAlgorithms.isEmpty) {
                    scalaAlgorithms.head.querySerializer
                  } else {
                    Utils.json4sDefaultFormats
                  }

                def genPrId: String = Random.alphanumeric.take(64).mkString

                // Reuse the prediction's own prId when it has a non-empty one.
                val newPrId: String = servedPrediction match {
                  case withPrId: WithPrId if !withPrId.prId.isEmpty =>
                    withPrId.prId
                  case _ => genPrId
                }

                // also save Query's prId as prId of this pio_pr predict events
                val queryPrId = servedQuery match {
                  case withPrId: WithPrId => Map("prId" -> withPrId.prId)
                  case _ => Map()
                }

                val data = Map(
                  "event" -> "predict",
                  "eventTime" -> queryTime.toString(),
                  "entityType" -> "pio_pr", // prediction result
                  "entityId" -> newPrId,
                  "properties" -> Map(
                    "engineInstanceId" -> engineInstance.id,
                    "query" -> servedQuery,
                    "prediction" -> servedPrediction)) ++ queryPrId

                // At this point args.accessKey should be Some(String).
                val accessKey = args.accessKey.getOrElse("")
                val f: Future[Int] = Future {
                  scalaj.http.Http(
                    s"http://${args.eventServerIp}:${args.eventServerPort}/" +
                    s"events.json?accessKey=$accessKey").postData(
                    write(data)).header(
                    "content-type", "application/json").asString.code
                }
                f onComplete {
                  case Success(code) =>
                    if (code != 201) {
                      log.error(s"Feedback event failed. Status code: $code. " +
                        s"Data: ${write(data)}.")
                    }
                  case Failure(t) =>
                    log.error(s"Feedback event failed: ${t.getMessage}")
                }

                // overwrite prId in predictedResult
                // - if it is WithPrId,
                //   then overwrite with new prId
                // - if it is not WithPrId, no prId injection
                servedPrediction match {
                  case _: WithPrId =>
                    // Built as an AST so that any character in the prId is
                    // represented correctly in the JSON.
                    val prIdJson: JValue =
                      JObject(List(("prId", JString(newPrId))))
                    resultJson merge prIdJson
                  case _ => resultJson
                }
              } else resultJson

              // Bookkeeping
              val servingEndTime = DateTime.now
              recordServing(
                (servingEndTime.getMillis - servingStartTime.getMillis) /
                  1000.0)

              respondWithMediaType(`application/json`) {
                complete(compact(render(result)))
              }
            } catch {
              case e: MappingException =>
                log.error(
                  s"Query '$queryString' is invalid. Reason: ${e.getMessage}")
                args.logUrl foreach { url =>
                  remoteLog(
                    url,
                    args.logPrefix.getOrElse(""),
                    s"Query:\n$queryString\n\nStack Trace:\n" +
                      s"${getStackTraceString(e)}\n\n")
                }
                complete(StatusCodes.BadRequest, e.getMessage)
              // Deliberately broad: this is the request boundary, and every
              // failure is reported to the client and to the remote log.
              case e: Throwable =>
                val msg = s"Query:\n$queryString\n\nStack Trace:\n" +
                  s"${getStackTraceString(e)}\n\n"
                log.error(msg)
                args.logUrl foreach { url =>
                  remoteLog(
                    url,
                    args.logPrefix.getOrElse(""),
                    msg)
                }
                complete(StatusCodes.InternalServerError, msg)
            }
          }
        }
      }
    } ~
    path("reload") {
      get {
        complete {
          context.actorSelection("/user/master") ! ReloadServer()
          "Reloading..."
        }
      }
    } ~
    path("stop") {
      get {
        complete {
          // Resolve the selection before scheduling, so that the delayed task
          // does not touch the actor context when it eventually runs.
          val master = context.actorSelection("/user/master")
          context.system.scheduler.scheduleOnce(1.seconds) {
            master ! StopServer()
          }
          "Shutting down..."
        }
      }
    } ~
    pathPrefix("assets") {
      getFromResourceDirectory("assets")
    }
}