blob: 126f1939aa823b6f242c73e805225effd7433546 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.s2graph.rest.netty
import java.util.concurrent.Executors
import com.typesafe.config.ConfigFactory
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.channel._
import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.http.HttpHeaders._
import io.netty.handler.codec.http._
import io.netty.handler.logging.{LogLevel, LoggingHandler}
import io.netty.util.CharsetUtil
import org.apache.s2graph.core.GraphExceptions.{BadQueryException}
import org.apache.s2graph.core.mysqls.Experiment
import org.apache.s2graph.core.rest.RestHandler
import org.apache.s2graph.core.rest.RestHandler.{CanLookup, HandlerResult}
import org.apache.s2graph.core.utils.Extensions._
import org.apache.s2graph.core.utils.logger
import org.apache.s2graph.core.{S2Graph, PostProcess}
import play.api.libs.json._
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.io.Source
import scala.util.{Failure, Success, Try}
import scala.language.existentials
/**
 * Netty inbound handler that routes s2graph REST requests.
 *
 * GET serves the health-check pages and `/graphs/getEdge/{srcId}/{tgtId}/{labelName}/{direction}`,
 * PUT toggles the health flags held on [[NettyServer]], and POST forwards the request body to
 * `s2rest.doPost`. Every response written through [[simpleResponse]] carries a `Content-Length`.
 *
 * @param s2rest REST facade that executes the graph operations
 * @param ec     execution context on which result-future callbacks run
 */
class S2RestHandler(s2rest: RestHandler)(implicit ec: ExecutionContext) extends SimpleChannelInboundHandler[FullHttpRequest] {
  val ApplicationJson = "application/json"
  val Ok = HttpResponseStatus.OK
  val CloseOpt = Option(ChannelFutureListener.CLOSE)
  val BadRequest = HttpResponseStatus.BAD_REQUEST
  val BadGateway = HttpResponseStatus.BAD_GATEWAY
  val NotFound = HttpResponseStatus.NOT_FOUND
  val InternalServerError = HttpResponseStatus.INTERNAL_SERVER_ERROR

  /** Header lookup over Netty's [[HttpHeaders]]; presumably resolved implicitly by `s2rest.doPost`. */
  implicit val nettyHeadersLookup: CanLookup[HttpHeaders] = new CanLookup[HttpHeaders] {
    override def lookup(m: HttpHeaders, key: String) = Option(m.get(key))
  }

  /**
   * Answers an unknown route with an empty body and closes the connection.
   * NOTE(review): responds 502 Bad Gateway rather than 404; kept as-is in case clients rely on it.
   */
  def badRoute(ctx: ChannelHandlerContext) =
    simpleResponse(ctx, BadGateway, byteBufOpt = None, channelFutureListenerOpt = CloseOpt)

  /**
   * Writes and flushes a full HTTP/1.1 response.
   *
   * @param byteBufOpt               response body; an empty body when `None`
   * @param headers                  headers applied in order (a later duplicate key wins)
   * @param channelFutureListenerOpt listener attached to the write future, e.g. to close the channel
   */
  def simpleResponse(ctx: ChannelHandlerContext,
                     httpResponseStatus: HttpResponseStatus,
                     byteBufOpt: Option[ByteBuf] = None,
                     headers: Seq[(String, String)] = Nil,
                     channelFutureListenerOpt: Option[ChannelFutureListener] = None): Unit = {
    val res: FullHttpResponse = byteBufOpt match {
      case None => new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, httpResponseStatus)
      case Some(byteBuf) => new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, httpResponseStatus, byteBuf)
    }
    headers.foreach { case (k, v) => res.headers().set(k, v) }
    // Frame every response explicitly so keep-alive clients know where the body ends.
    if (!res.headers().contains(Names.CONTENT_LENGTH))
      res.headers().set(Names.CONTENT_LENGTH, res.content().readableBytes().toString)

    val channelFuture = ctx.writeAndFlush(res)
    channelFutureListenerOpt.foreach(listener => channelFuture.addListener(listener))
  }

  /**
   * Sends the JSON produced by `result.body` once it completes.
   *
   * On success the response is 200, and the connection is kept open only if the request asked
   * for keep-alive. A [[BadQueryException]] maps to 400 and any other failure maps to 500. Both
   * error responses close the connection and say so via `Connection: close`.
   *
   * @param requestBody text included in the access/error log line (the URI for GET requests)
   * @param startedAt   epoch millis when the request arrived, used for the logged duration
   */
  def toResponse(ctx: ChannelHandlerContext, req: FullHttpRequest, requestBody: String, result: HandlerResult, startedAt: Long): Unit = {
    val keepAlive = HttpHeaders.isKeepAlive(req)
    // Handler-supplied headers come after Content-Type so they can override it, as before.
    val commonHeaders: Seq[(String, String)] = Seq(Names.CONTENT_TYPE -> ApplicationJson) ++ result.headers
    val successHeaders =
      if (keepAlive) commonHeaders :+ (Names.CONNECTION -> HttpHeaders.Values.KEEP_ALIVE)
      else commonHeaders
    val successCloseOpt = if (keepAlive) None else CloseOpt
    // Error responses always close the connection, so never advertise keep-alive on them.
    val errorHeaders = commonHeaders :+ (Names.CONNECTION -> HttpHeaders.Values.CLOSE)

    result.body onComplete {
      case Success(json) =>
        val duration = System.currentTimeMillis() - startedAt
        val bucketName = result.headers.toMap.getOrElse(Experiment.ImpressionKey, "")
        logger.info(s"${req.getMethod} ${req.getUri} took ${duration} ms 200 ${s2rest.calcSize(json)} ${requestBody} ${bucketName}")
        respondJson(ctx, Ok, json.toString, successHeaders, successCloseOpt)
      case Failure(e: BadQueryException) =>
        logger.error(s"${requestBody}, ${e.getMessage}", e)
        respondJson(ctx, BadRequest, PostProcess.badRequestResults(e).toString, errorHeaders, CloseOpt)
      case Failure(e) =>
        // Catch every failure (not only Exception) so the client always receives a response.
        logger.error(s"${requestBody}, ${e.getMessage}", e)
        respondJson(ctx, InternalServerError, PostProcess.emptyResults.toString, errorHeaders, CloseOpt)
    }
  }

  /** Encodes `body` as UTF-8 and writes it with the given status, headers and close policy. */
  private def respondJson(ctx: ChannelHandlerContext,
                          status: HttpResponseStatus,
                          body: String,
                          headers: Seq[(String, String)],
                          closeOpt: Option[ChannelFutureListener]): Unit = {
    val buf: ByteBuf = Unpooled.copiedBuffer(body, CharsetUtil.UTF_8)
    simpleResponse(ctx, status, byteBufOpt = Option(buf), headers = headers, channelFutureListenerOpt = closeOpt)
  }

  /** Responds 200 with the deploy info when `predicate` holds, otherwise 404; closes either way. */
  private def healthCheck(ctx: ChannelHandlerContext)(predicate: Boolean): Unit = {
    if (predicate) {
      val healthCheckMsg = Unpooled.copiedBuffer(NettyServer.deployInfo, CharsetUtil.UTF_8)
      simpleResponse(ctx, Ok, byteBufOpt = Option(healthCheckMsg), channelFutureListenerOpt = CloseOpt)
    } else {
      simpleResponse(ctx, NotFound, channelFutureListenerOpt = CloseOpt)
    }
  }

  /** Applies `updateOp(newValue)` and echoes the new value back with 200, then closes. */
  private def updateHealthCheck(ctx: ChannelHandlerContext)(newValue: Boolean)(updateOp: Boolean => Unit): Unit = {
    updateOp(newValue)
    val newHealthCheckMsg = Unpooled.copiedBuffer(newValue.toString, CharsetUtil.UTF_8)
    simpleResponse(ctx, Ok, byteBufOpt = Option(newHealthCheckMsg), channelFutureListenerOpt = CloseOpt)
  }

  /**
   * Routes one aggregated request by method and URI.
   *
   * Malformed input (a `getEdge` path with fewer than four trailing segments, or a PUT value
   * that is not a boolean) throws here; the exception is answered by [[exceptionCaught]].
   */
  override def channelRead0(ctx: ChannelHandlerContext, req: FullHttpRequest): Unit = {
    val uri = req.getUri
    val startedAt = System.currentTimeMillis()
    val checkFunc = healthCheck(ctx) _
    val updateFunc = updateHealthCheck(ctx) _

    req.getMethod match {
      case HttpMethod.GET =>
        uri match {
          case "/health_check.html" => checkFunc(NettyServer.isHealthy)
          case "/fallback_check.html" => checkFunc(NettyServer.isFallbackHealthy)
          case "/query_fallback_check.html" => checkFunc(NettyServer.isQueryFallbackHealthy)
          case s if s.startsWith("/graphs/getEdge/") =>
            if (!NettyServer.isQueryFallbackHealthy) {
              val result = HandlerResult(body = Future.successful(PostProcess.emptyResults))
              toResponse(ctx, req, s, result, startedAt)
            } else {
              val Array(srcId, tgtId, labelName, direction) = s.split("/").takeRight(4)
              val params = Json.arr(Json.obj("label" -> labelName, "direction" -> direction, "from" -> srcId, "to" -> tgtId))
              val result = s2rest.checkEdges(params)
              toResponse(ctx, req, s, result, startedAt)
            }
          case _ => badRoute(ctx)
        }

      case HttpMethod.PUT =>
        // URI prefix -> flag setter; the flag value is the last path segment.
        val setters: Seq[(String, Boolean => Unit)] = Seq(
          "/health_check/" -> ((v: Boolean) => NettyServer.isHealthy = v),
          "/query_fallback_check/" -> ((v: Boolean) => NettyServer.isQueryFallbackHealthy = v),
          "/fallback_check/" -> ((v: Boolean) => NettyServer.isFallbackHealthy = v)
        )
        setters.collectFirst { case (prefix, setter) if uri.startsWith(prefix) => setter } match {
          case Some(setter) => updateFunc(uri.split("/").last.toBoolean)(setter)
          case None => badRoute(ctx)
        }

      case HttpMethod.POST =>
        val body = req.content.toString(CharsetUtil.UTF_8)
        val result =
          if (!NettyServer.isQueryFallbackHealthy) HandlerResult(body = Future.successful(PostProcess.emptyResults))
          else s2rest.doPost(uri, body, req.headers())
        toResponse(ctx, req, body, result, startedAt)

      case _ =>
        simpleResponse(ctx, BadRequest, byteBufOpt = None, channelFutureListenerOpt = CloseOpt)
    }
  }

  /**
   * An I/O failure on the connection just closes the channel. Any other exception is logged
   * (with its stack trace) and answered with an empty 400 before closing.
   */
  override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
    cause match {
      case _: java.io.IOException =>
        ctx.channel().close()
      case _ =>
        logger.error(s"exception on query.", cause)
        simpleResponse(ctx, BadRequest, byteBufOpt = None, channelFutureListenerOpt = CloseOpt)
    }
  }
}
// Simple http server
object NettyServer extends App {
  /** Thread-pool sizing; should be the same as Bootstrap.onStart on play. */
  val numOfThread = Runtime.getRuntime.availableProcessors()
  val threadPool = Executors.newFixedThreadPool(numOfThread)
  val ec = ExecutionContext.fromExecutor(threadPool)

  val config = ConfigFactory.load()
  // Each setting falls back to a default when it is missing or unreadable.
  val port = Try(config.getInt("http.port")).getOrElse(9000)
  val transport = Try(config.getString("netty.transport")).getOrElse("jdk")
  val maxBodySize = Try(config.getInt("max.body.size")).getOrElse(65536 * 2)

  // init s2graph with config
  val s2graph = new S2Graph(config)(ec)
  val rest = new RestHandler(s2graph)(ec)

  /** Contents of ./release_info (served by the health-check pages), read once and closed. */
  val deployInfo = Try {
    val source = Source.fromFile("./release_info")
    try source.mkString("") finally source.close()
  }.getOrElse("release info not found\n")

  // These flags are flipped by PUT requests and read by handlers on other Netty worker threads,
  // so they are @volatile to make updates visible across threads.
  @volatile var isHealthy = config.getBooleanWithFallback("app.health.on", true)
  @volatile var isFallbackHealthy = true
  @volatile var isQueryFallbackHealthy = true

  logger.info(s"starts with num of thread: $numOfThread, ${threadPool.getClass.getSimpleName}")
  logger.info(s"transport: $transport")

  // Configure the server: "native" selects the epoll transport, anything else the JDK NIO one.
  val (bossGroup, workerGroup, channelClass) = transport match {
    case "native" =>
      (new EpollEventLoopGroup(1), new EpollEventLoopGroup(), classOf[EpollServerSocketChannel])
    case _ =>
      (new NioEventLoopGroup(1), new NioEventLoopGroup(), classOf[NioServerSocketChannel])
  }

  try {
    val b: ServerBootstrap = new ServerBootstrap()
      .option(ChannelOption.SO_BACKLOG, Int.box(2048))
    b.group(bossGroup, workerGroup).channel(channelClass)
      .handler(new LoggingHandler(LogLevel.INFO))
      .childHandler(new ChannelInitializer[SocketChannel] {
        override def initChannel(ch: SocketChannel): Unit = {
          val p = ch.pipeline()
          p.addLast(new HttpServerCodec())
          // Aggregate chunks into one FullHttpRequest, capped at maxBodySize bytes.
          p.addLast(new HttpObjectAggregator(maxBodySize))
          p.addLast(new S2RestHandler(rest)(ec))
        }
      })

    // for check server is started
    logger.info(s"Listening for HTTP on /0.0.0.0:$port")
    val ch: Channel = b.bind(port).sync().channel()
    ch.closeFuture().sync()
  } finally {
    bossGroup.shutdownGracefully()
    workerGroup.shutdownGracefully()
    // Stop the fixed pool even if s2graph shutdown fails: its non-daemon threads would
    // otherwise keep the JVM alive after the server channel has closed.
    try s2graph.shutdown()
    finally threadPool.shutdown()
  }
}