blob: 71bf36c313e3b5bbbfce452a8206795b7eb4679d [file] [log] [blame]
/** Copyright 2014 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prediction.workflow
import io.prediction.controller.IEngineFactory
import io.prediction.controller.EmptyParams
import io.prediction.controller.Engine
import io.prediction.controller.PAlgorithm
import io.prediction.controller.Params
import io.prediction.controller.java.LJavaAlgorithm
import io.prediction.controller.java.LJavaServing
import io.prediction.core.BaseAlgorithm
import io.prediction.core.BaseServing
import io.prediction.core.Doer
import io.prediction.data.storage.EngineInstance
import io.prediction.data.storage.EngineManifest
import io.prediction.data.storage.Storage
import akka.actor.{ Actor, ActorRef, ActorSystem, Kill, Props }
import akka.event.Logging
import akka.io.IO
import akka.pattern.ask
import akka.util.Timeout
import com.github.nscala_time.time.Imports.DateTime
import com.google.gson.Gson
import com.twitter.chill.KryoInjection
import com.twitter.chill.ScalaKryoInstantiator
import grizzled.slf4j.Logging
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization.{ read, write }
import spray.can.Http
import spray.routing._
import spray.http._
import spray.http.MediaTypes._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.language.existentials
import scala.reflect.runtime.universe
import java.io.File
import java.io.ByteArrayInputStream
import java.io.ObjectInputStream
import java.net.URLClassLoader
/** A chill `ScalaKryoInstantiator` whose Kryo instances resolve classes
  * through an explicitly supplied class loader.
  *
  * `CreateServer` uses this to deserialize persisted models with
  * `getClass.getClassLoader`. Presumably this is so that engine classes are
  * found even when chill's default loader cannot see them; confirm against
  * the deployment setup.
  *
  * @param classLoader loader installed on every Kryo instance this creates
  */
class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
  override def newKryo = {
    val instance = super.newKryo
    // Every Kryo produced here looks classes up via the supplied loader.
    instance.setClassLoader(classLoader)
    instance
  }
}
/** Command-line configuration of the deployment server.
  *
  * @param engineInstanceId ID of the engine instance to serve
  * @param engineId when set, overrides the engine ID stored on the instance
  * @param engineVersion when set, overrides the engine version stored on the
  *   instance
  * @param ip address the HTTP listener binds to
  * @param port port the HTTP listener binds to
  */
case class ServerConfig(
  engineInstanceId: String = "",
  engineId: Option[String] = None,
  engineVersion: Option[String] = None,
  ip: String = "localhost",
  port: Int = 8000
)

/** Tells the master actor to create a server actor and bind it over HTTP. */
case class StartServer()

/** Tells the master actor to unbind the HTTP listener and shut down. */
case class StopServer()

/** Tells the master actor to swap in the latest completed engine instance. */
case class ReloadServer()
/** Entry point of the deployment server, plus the factory that turns a
  * trained engine instance into a [[ServerActor]].
  */
object CreateServer extends Logging {
  val actorSystem = ActorSystem("pio-server")
  val engineInstances = Storage.getMetaDataEngineInstances
  val engineManifests = Storage.getMetaDataEngineManifests
  val modeldata = Storage.getModelDataModels

  /** Parses the command line, starts the master actor and blocks until the
    * actor system terminates.
    *
    * If parsing or any lookup fails, the actor system is shut down. Otherwise
    * its threads would keep the JVM running after the "Aborting server"
    * message (assuming Akka's default non-daemonic threads).
    */
  def main(args: Array[String]): Unit = {
    val parser = new scopt.OptionParser[ServerConfig]("CreateServer") {
      opt[String]("engineId") action { (x, c) =>
        c.copy(engineId = Some(x))
      } text("Engine ID.")
      opt[String]("engineVersion") action { (x, c) =>
        c.copy(engineVersion = Some(x))
      } text("Engine version.")
      opt[String]("ip") action { (x, c) =>
        c.copy(ip = x)
      } text("IP to bind to (default: localhost).")
      opt[Int]("port") action { (x, c) =>
        c.copy(port = x)
      } text("Port to bind to (default: 8000).")
      opt[String]("engineInstanceId") required() action { (x, c) =>
        c.copy(engineInstanceId = x)
      } text("Engine instance ID.")
    }

    parser.parse(args, ServerConfig()) match {
      case Some(sc) => runServer(sc)
      case None =>
        // scopt has already reported the argument problem.
        actorSystem.shutdown
    }
  }

  /** Resolves the engine instance and manifest selected by `sc`, starts the
    * master actor and waits for the actor system to terminate. Logs an error
    * and shuts the actor system down when a lookup fails.
    */
  private def runServer(sc: ServerConfig): Unit = {
    engineInstances.get(sc.engineInstanceId) match {
      case None =>
        error(s"Invalid engine instance ID. Aborting server.")
        actorSystem.shutdown
      case Some(engineInstance) =>
        // Command-line values take precedence over those stored on the
        // engine instance.
        val engineId = sc.engineId.getOrElse(engineInstance.engineId)
        val engineVersion = sc.engineVersion.getOrElse(
          engineInstance.engineVersion)
        engineManifests.get(engineId, engineVersion) match {
          case None =>
            error(s"Invalid engine ID or version. Aborting server.")
            actorSystem.shutdown
          case Some(manifest) =>
            WorkflowUtils.checkUpgrade("deployment")
            val master = actorSystem.actorOf(Props(
              classOf[MasterActor],
              sc,
              engineInstance,
              manifest.engineFactory,
              manifest),
              "master")
            // Fire-and-forget: MasterActor never replies to StartServer, so
            // an ask (`?`) would only produce a future that times out.
            master ! StartServer()
            actorSystem.awaitTermination
        }
    }
  }

  /** Builds a [[ServerActor]] serving `engineInstance` with `engine`.
    *
    * Algorithm, serving, data source and preparator parameters are decoded
    * from the JSON stored on the engine instance. Persisted models are read
    * back with Kryo. Models of parallel (`PAlgorithm`) algorithms are
    * retrained here from prepared data, on a Spark context that is created
    * only when such an algorithm exists.
    *
    * @param sc server configuration passed on to the actor
    * @param engineInstance engine instance whose parameters and models are used
    * @param engine engine providing the component classes
    * @param engineLanguage language used to decode parameter JSON
    * @param manifest manifest of the deployed engine
    * @return reference to the newly created [[ServerActor]]
    * @throws NoSuchElementException if no model data, or no stored model set,
    *   exists for the engine instance
    */
  def createServerActorWithEngine[TD, DP, PD, Q, P, A](
    sc: ServerConfig,
    engineInstance: EngineInstance,
    engine: Engine[TD, DP, PD, Q, P, A],
    engineLanguage: EngineLanguage.Value,
    manifest: EngineManifest): ActorRef = {
    implicit val formats = DefaultFormats

    // (algorithm name, decoded params) pairs in the stored order.
    val algorithmsParamsWithNames =
      read[Seq[(String, JValue)]](engineInstance.algorithmsParams).map {
        case (algoName, params) =>
          val extractedParams = WorkflowUtils.extractParams(
            engineLanguage,
            compact(render(params)),
            engine.algorithmClassMap(algoName))
          (algoName, extractedParams)
      }
    val algorithmsParams = algorithmsParamsWithNames.map { _._2 }
    val algorithms = algorithmsParamsWithNames.map { case (n, p) =>
      Doer(engine.algorithmClassMap(n), p)
    }

    val servingParams =
      if (engineInstance.servingParams == "")
        EmptyParams()
      else
        WorkflowUtils.extractParams(
          engineLanguage,
          engineInstance.servingParams,
          engine.servingClass)
    val serving = Doer(engine.servingClass, servingParams)

    val pAlgorithmExists =
      algorithms.exists(_.isInstanceOf[PAlgorithm[_, PD, _, Q, P]])
    val sparkContext =
      if (pAlgorithmExists)
        Some(WorkflowContext(engineInstance.batch, engineInstance.env))
      else
        None

    // Prepared data keyed by data-source split index. Only computed when a
    // Spark context exists, i.e. when some algorithm must be retrained.
    val evalPreparedMap = sparkContext map { ctx =>
      logger.info("Data Source")
      val dataSourceParams = WorkflowUtils.extractParams(
        engineLanguage,
        engineInstance.dataSourceParams,
        engine.dataSourceClass)
      val dataSource = Doer(engine.dataSourceClass, dataSourceParams)
      val evalParamsDataMap: Map[EI, (DP, TD, RDD[(Q, A)])] = dataSource
        .readBase(ctx)
        .zipWithIndex
        .map(_.swap)
        .toMap
      val evalDataMap: Map[EI, (TD, RDD[(Q, A)])] = evalParamsDataMap.map {
        case (ei, (_, td, qa)) => (ei, (td, qa))
      }
      logger.info("Preparator")
      val preparatorParams = WorkflowUtils.extractParams(
        engineLanguage,
        engineInstance.preparatorParams,
        engine.preparatorClass)
      val preparator = Doer(engine.preparatorClass, preparatorParams)
      val preparedMap: Map[EI, PD] = evalDataMap.map {
        case (ei, data) => (ei, preparator.prepareBase(ctx, data._1))
      }
      logger.info("Preparator complete")
      preparedMap
    }

    // Deserialize the persisted models with a loader that can see this
    // class's (and presumably the engine's) classes.
    val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader)
    val kryo = KryoInjection.instance(kryoInstantiator)
    val persisted = modeldata.get(engineInstance.id).getOrElse {
      throw new NoSuchElementException(
        s"No model data found for engine instance ${engineInstance.id}.")
    }
    val modelsFromEngineInstance =
      kryo.invert(persisted.models).get.asInstanceOf[Seq[Seq[Any]]]
    // Only the first stored model set is served.
    val storedModels = modelsFromEngineInstance.headOption.getOrElse {
      throw new NoSuchElementException(
        s"Engine instance ${engineInstance.id} has no stored models.")
    }

    val models = storedModels.zip(algorithms).zip(algorithmsParams).map {
      case ((m, a), p) =>
        if (m.isInstanceOf[PersistentModelManifest]) {
          info("Custom-persisted model detected for algorithm " +
            a.getClass.getName)
          WorkflowUtils.getPersistentModel(
            m.asInstanceOf[PersistentModelManifest],
            engineInstance.id,
            p,
            sparkContext,
            getClass.getClassLoader)
        } else if (a.isInstanceOf[PAlgorithm[_, _, _, Q, P]]) {
          info(s"Parallel model detected for algorithm ${a.getClass.getName}")
          // pAlgorithmExists guarantees both Options are defined here. The
          // model is retrained on the prepared data of split 0.
          a.trainBase(sparkContext.get, evalPreparedMap.get(0))
        } else if (m == null) {
          warn(s"Null model detected for algorithm ${a.getClass.getName}")
          m
        } else {
          info(s"Loaded model ${m.getClass.getName} for algorithm " +
            s"${a.getClass.getName}")
          m
        }
    }

    actorSystem.actorOf(
      Props(
        classOf[ServerActor[Q, P]],
        sc,
        engineInstance,
        engine,
        engineLanguage,
        manifest,
        algorithms,
        algorithmsParams,
        models,
        serving,
        servingParams))
  }
}
/** Owns the HTTP binding. It creates the [[ServerActor]], binds it through
  * spray-can, and reacts to [[StartServer]], [[StopServer]] and
  * [[ReloadServer]].
  *
  * @param sc server configuration (bind address and port)
  * @param engineInstance engine instance served at start-up
  * @param engineFactoryName engine factory class name used to build the engine
  * @param manifest manifest of the deployed engine
  */
class MasterActor(
  sc: ServerConfig,
  engineInstance: EngineInstance,
  engineFactoryName: String,
  manifest: EngineManifest) extends Actor {
  val log = Logging(context.system, this)
  implicit val system = context.system
  // Set when spray-can reports Http.Bound; this is the actor Unbind targets.
  var sprayHttpListener: Option[ActorRef] = None
  // Server actor currently bound to the HTTP port, if any.
  var currentServerActor: Option[ActorRef] = None

  def receive = {
    case _: StartServer =>
      val actor = createServerActor(
        sc,
        engineInstance,
        engineFactoryName,
        manifest)
      IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
      currentServerActor = Some(actor)

    case _: StopServer =>
      log.info(s"Stop server command received.")
      sprayHttpListener match {
        case Some(l) =>
          log.info("Server is shutting down.")
          l ! Http.Unbind(5.seconds)
          system.shutdown
        case None =>
          log.warning("No active server is running.")
      }

    case _: ReloadServer =>
      log.info("Reload server command received.")
      val latestEngineInstance =
        CreateServer.engineInstances.getLatestCompleted(
          manifest.id,
          manifest.version)
      latestEngineInstance match {
        case Some(lr) =>
          sprayHttpListener match {
            case Some(l) =>
              // Build the replacement only when there is a listener to swap
              // it into. Otherwise the new actor (and the models it loads)
              // would be left orphaned.
              val actor = createServerActor(sc, lr, engineFactoryName, manifest)
              // NOTE(review): Bind is sent without waiting for the Unbind to
              // finish. If the port is still held, spray-can may answer with
              // Http.CommandFailed, which shuts the system down below.
              l ! Http.Unbind(5.seconds)
              IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
              currentServerActor.foreach(_ ! Kill)
              currentServerActor = Some(actor)
            case None =>
              log.warning("No active server is running. Abort reloading.")
          }
        case None =>
          log.warning(
            s"No latest completed engine instance for ${manifest.id} " +
            s"${manifest.version}. Abort reloading.")
      }

    case _: Http.Bound =>
      log.info("Bind successful. Ready to serve.")
      sprayHttpListener = Some(sender)

    case _: Http.CommandFailed =>
      log.error("Bind failed. Shutting down.")
      system.shutdown
  }

  /** Instantiates the engine from `engineFactoryName` and builds a server
    * actor for `engineInstance` via
    * [[CreateServer.createServerActorWithEngine]].
    */
  def createServerActor(
    sc: ServerConfig,
    engineInstance: EngineInstance,
    engineFactoryName: String,
    manifest: EngineManifest): ActorRef = {
    val (engineLanguage, engine) =
      WorkflowUtils.getEngine(engineFactoryName, getClass.getClassLoader)
    CreateServer.createServerActorWithEngine(
      sc,
      engineInstance,
      engine,
      engineLanguage,
      manifest)
  }
}
/** Spray HTTP service that answers queries for one deployed engine instance.
  *
  * Routes:
  *  - `GET /`: status page rendered by the `html.index` template
  *  - `POST /queries.json`: runs the query through every algorithm, then
  *    through the serving component, and returns the result as JSON
  *  - `GET /reload`: asks `/user/master` to reload
  *  - `GET /stop`: asks `/user/master` to stop, after a one-second delay
  *  - `/assets/...`: static files from the `assets` resource directory
  *
  * @param args server configuration, shown on the status page
  * @param engineInstance engine instance being served
  * @param engine engine the components come from
  * @param engineLanguage language of the engine
  * @param manifest manifest of the deployed engine
  * @param algorithms algorithm instances, index-aligned with `models`
  * @param algorithmsParams parameters of each algorithm
  * @param models model of each algorithm
  * @param serving serving component that combines the predictions
  * @param servingParams parameters of the serving component
  */
class ServerActor[Q, P](
  val args: ServerConfig,
  val engineInstance: EngineInstance,
  val engine: Engine[_, _, _, Q, P, _],
  val engineLanguage: EngineLanguage.Value,
  val manifest: EngineManifest,
  val algorithms: Seq[BaseAlgorithm[_ <: Params, _, _, Q, P]],
  val algorithmsParams: Seq[Params],
  val models: Seq[Any],
  val serving: BaseServing[_ <: Params, Q, P],
  val servingParams: Params) extends Actor with HttpService {
  val serverStartTime = DateTime.now
  lazy val gson = new Gson
  val log = Logging(context.system, this)
  // Java algorithms receive Gson-decoded queries; Scala algorithms receive
  // json4s-extracted queries.
  val (javaAlgorithms, scalaAlgorithms) =
    algorithms.partition(_.isInstanceOf[LJavaAlgorithm[_, _, _, Q, P]])

  def actorRefFactory = context

  def receive = runRoute(myRoute)

  val myRoute =
    path("") {
      get {
        respondWithMediaType(`text/html`) {
          detach() {
            complete {
              html.index(
                args,
                manifest,
                engineInstance,
                algorithms.map(_.toString),
                algorithmsParams.map(_.toString),
                models.map(_.toString),
                servingParams.toString,
                serverStartTime).toString
            }
          }
        }
      }
    } ~
    path("queries.json") {
      post {
        detach() {
          entity(as[String]) { queryString =>
            try {
              val javaQuery = if (javaAlgorithms.nonEmpty) {
                Some(gson.fromJson(
                  queryString,
                  javaAlgorithms.head.asInstanceOf[LJavaAlgorithm[_, _, _, Q, P]].
                    queryClass))
              } else None
              val scalaQuery = if (scalaAlgorithms.nonEmpty) {
                Some(Extraction.extract(parse(queryString))(
                  scalaAlgorithms.head.querySerializer,
                  scalaAlgorithms.head.queryManifest))
              } else None
              // Each algorithm gets the query decoded for its language.
              val predictions = algorithms.zipWithIndex.map { case (a, ai) =>
                if (a.isInstanceOf[LJavaAlgorithm[_, _, _, Q, P]])
                  a.predictBase(models(ai), javaQuery.get)
                else
                  a.predictBase(models(ai), scalaQuery.get)
              }
              val json = if (serving.isInstanceOf[LJavaServing[_, Q, P]]) {
                val prediction = serving.serveBase(javaQuery.get, predictions)
                gson.toJson(prediction)
              } else {
                val prediction = serving.serveBase(scalaQuery.get, predictions)
                compact(render(Extraction.decompose(prediction)(
                  scalaAlgorithms.head.querySerializer)))
              }
              respondWithMediaType(`application/json`) {
                complete(json)
              }
            } catch {
              case e: MappingException =>
                log.error(
                  s"Query '${queryString}' is invalid. Reason: ${e.getMessage}")
                complete(StatusCodes.BadRequest, e.getMessage)
            }
          }
        }
      }
    } ~
    path("reload") {
      get {
        complete {
          context.actorSelection("/user/master") ! ReloadServer()
          "Reloading..."
        }
      }
    } ~
    path("stop") {
      get {
        complete {
          // Resolve the selection while the request is being handled by this
          // actor. The scheduled callback runs later on an ExecutionContext
          // thread, where accessing `context` is unsafe, so it captures only
          // the ActorSelection.
          val master = context.actorSelection("/user/master")
          context.system.scheduler.scheduleOnce(1.seconds) {
            master ! StopServer()
          }
          "Shutting down..."
        }
      }
    } ~
    pathPrefix("assets") {
      getFromResourceDirectory("assets")
    }
}