blob: f98194c6320cf48d279021e875abe7a05d1f1890 [file] [log] [blame]
/*
* Copyright 2017 HugeGraph Authors
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.baidu.hugegraph.computer.core.network.netty;
import java.util.concurrent.TimeUnit;
import com.baidu.hugegraph.computer.core.common.exception.TransportException;
import com.baidu.hugegraph.computer.core.network.ConnectionId;
import com.baidu.hugegraph.computer.core.network.MessageHandler;
import com.baidu.hugegraph.computer.core.network.TransportUtil;
import com.baidu.hugegraph.computer.core.network.buffer.FileRegionBuffer;
import com.baidu.hugegraph.computer.core.network.buffer.NetworkBuffer;
import com.baidu.hugegraph.computer.core.network.message.AbstractMessage;
import com.baidu.hugegraph.computer.core.network.message.AckMessage;
import com.baidu.hugegraph.computer.core.network.message.DataMessage;
import com.baidu.hugegraph.computer.core.network.message.FinishMessage;
import com.baidu.hugegraph.computer.core.network.message.StartMessage;
import com.baidu.hugegraph.computer.core.network.session.ServerSession;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoop;
import io.netty.channel.socket.SocketChannel;
import io.netty.util.concurrent.ScheduledFuture;
/**
 * Server-side Netty handler for the computer transport protocol.
 *
 * <p>It reacts to the start / data / finish messages sent by a client:
 * state transitions are recorded in the {@link ServerSession}, received
 * data is passed to the {@link MessageHandler}, and {@link AckMessage}s
 * are written back to the client. After a start message has been acked,
 * a task scheduled on the channel's event loop periodically checks whether
 * a data or finish ack should be sent; that task is cancelled once the
 * finish is acked or the channel becomes inactive.
 */
public class NettyServerHandler extends AbstractNettyHandler {

    // Delay (ms) before the first run of the periodic ack-check task
    private static final long INITIAL_DELAY = 0L;

    private final MessageHandler handler;
    private final ServerSession serverSession;
    // Listener attached to every ack write
    private final ChannelFutureListenerOnWrite listenerOnWrite;
    // Periodic ack-check task; null while no task is scheduled
    private ScheduledFuture<?> respondAckTask;

    /**
     * @param serverSession session tracking the receive/ack state
     * @param handler       receives data, lifecycle and error callbacks
     */
    public NettyServerHandler(ServerSession serverSession,
                              MessageHandler handler) {
        this.serverSession = serverSession;
        this.handler = handler;
        this.listenerOnWrite = new ChannelFutureListenerOnWrite(this.handler);
    }

    /**
     * Records the start in the session and immediately acks it.
     */
    @Override
    protected void processStartMessage(ChannelHandlerContext ctx,
                                       Channel channel,
                                       StartMessage startMessage) {
        this.serverSession.onRecvStateStart();
        this.ackStartMessage(ctx);
    }

    /**
     * Records the finish in the session and acks it right away if the
     * session reports it can be acked now; otherwise the finish ack is
     * left to the periodic {@link #checkAndRespondAck} task.
     */
    @Override
    protected void processFinishMessage(ChannelHandlerContext ctx,
                                        Channel channel,
                                        FinishMessage finishMessage) {
        int finishId = finishMessage.requestId();
        boolean needAckFinish = this.serverSession.onRecvStateFinish(finishId);
        if (needAckFinish) {
            this.ackFinishMessage(ctx, this.serverSession.finishId());
        }
    }

    /**
     * Records receipt of a data message and hands its body to the message
     * handler. A {@link FileRegionBuffer} body is transferred with zero
     * copy and handled asynchronously once the transfer completes; any
     * other body is handled synchronously and marked handled at once.
     */
    @Override
    protected void processDataMessage(ChannelHandlerContext ctx,
                                      Channel channel,
                                      DataMessage dataMessage) {
        NetworkBuffer body = dataMessage.body();
        try {
            int requestId = dataMessage.requestId();
            this.serverSession.onRecvData(requestId);

            if (body instanceof FileRegionBuffer) {
                this.processFileRegionBuffer(ctx, channel, dataMessage,
                                             (FileRegionBuffer) body);
            } else {
                this.handler.handle(dataMessage.type(), dataMessage.partition(),
                                    body);
                this.serverSession.onHandledData(requestId);
            }
        } finally {
            /*
             * NOTE(review): for a FileRegionBuffer this runs before the
             * asynchronous transfer completes, while the completion listener
             * still reads dataMessage.body() and releases the message again;
             * presumably FileRegionBuffer.release() tolerates this — confirm.
             */
            body.release();
        }
    }

    /**
     * Starts a zero-copy transfer of the file region from the socket into
     * the output path chosen by the message handler. When the transfer
     * completes, the data is handled (on success) or the failure is
     * reported, and then the channel's read settings are restored and the
     * message is released regardless of the outcome.
     */
    private void processFileRegionBuffer(ChannelHandlerContext ctx,
                                         Channel channel,
                                         DataMessage dataMessage,
                                         FileRegionBuffer fileRegionBuffer) {
        // Let the next read pull the whole file region at once
        TransportUtil.setMaxBytesPerRead(channel, fileRegionBuffer.length());
        String outputPath = this.handler.genOutputPath(dataMessage.type(),
                                                       dataMessage.partition());
        /*
         * Submit zero-copy task to EventLoop, it will be executed next time
         * network data is received.
         */
        ChannelFuture channelFuture = fileRegionBuffer.transformFromChannel(
                                      (SocketChannel) channel, outputPath);

        channelFuture.addListener((ChannelFutureListener) future -> {
            try {
                if (future.isSuccess()) {
                    this.handler.handle(dataMessage.type(),
                                        dataMessage.partition(),
                                        dataMessage.body());
                    this.serverSession.onHandledData(dataMessage.requestId());
                } else {
                    this.exceptionCaught(ctx, future.cause());
                }
            } catch (Throwable throwable) {
                this.exceptionCaught(ctx, throwable);
            } finally {
                /*
                 * Always restore the read settings and release the message,
                 * so a failing handler cannot leave the channel reading with
                 * the enlarged limit.
                 */
                this.completeFileRegionRead(ctx, future.channel(),
                                            dataMessage);
            }
        });
    }

    /**
     * Resets the channel's max bytes per read back to the frame header
     * length, resets the receive-buffer allocator handle, and releases the
     * data message. Any failure is reported through
     * {@link #exceptionCaught} instead of escaping the future listener.
     */
    private void completeFileRegionRead(ChannelHandlerContext ctx,
                                        Channel channel,
                                        DataMessage dataMessage) {
        try {
            try {
                // Reset max bytes next read to length of frame
                TransportUtil.setMaxBytesPerRead(channel,
                                                 AbstractMessage.HEADER_LENGTH);
                channel.unsafe().recvBufAllocHandle().reset(channel.config());
            } finally {
                dataMessage.release();
            }
        } catch (Throwable throwable) {
            this.exceptionCaught(ctx, throwable);
        }
    }

    /**
     * The server never expects ack messages.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void processAckMessage(ChannelHandlerContext ctx, Channel channel,
                                     AckMessage ackMessage) {
        throw new UnsupportedOperationException(
                  "Server does not support processAckMessage()");
    }

    /**
     * Writes the start ack, completes the start state, notifies the handler,
     * and schedules the periodic ack-check task if none is scheduled yet.
     */
    private void ackStartMessage(ChannelHandlerContext ctx) {
        AckMessage startAck = new AckMessage(AbstractMessage.START_SEQ);
        ctx.writeAndFlush(startAck).addListener(this.listenerOnWrite);
        this.serverSession.completeStateStart();

        Channel channel = ctx.channel();
        this.handler.onStarted(TransportUtil.remoteConnectionId(channel));

        // Add a scheduled task to check and respond ack
        if (this.respondAckTask == null) {
            EventLoop eventLoop = channel.eventLoop();
            this.respondAckTask = eventLoop.scheduleWithFixedDelay(
                                  () -> this.checkAndRespondAck(ctx),
                                  INITIAL_DELAY,
                                  this.serverSession.minAckInterval(),
                                  TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Writes the finish ack with {@code finishId}, completes the finish
     * state, notifies the handler, and stops the periodic ack-check task.
     */
    private void ackFinishMessage(ChannelHandlerContext ctx,
                                  int finishId) {
        AckMessage finishAck = new AckMessage(finishId);
        ctx.writeAndFlush(finishAck).addListener(this.listenerOnWrite);
        this.serverSession.completeStateFinish();

        Channel channel = ctx.channel();
        this.handler.onFinished(TransportUtil.remoteConnectionId(channel));

        this.stopRespondAckTask();
    }

    /**
     * Cancels and forgets the periodic ack-check task, if one is scheduled.
     * A run already in progress is not interrupted.
     */
    private void stopRespondAckTask() {
        if (this.respondAckTask != null) {
            this.respondAckTask.cancel(false);
            this.respondAckTask = null;
        }
    }

    /**
     * Writes a data ack for {@code ackId} and tells the session it was sent.
     */
    private void ackDataMessage(ChannelHandlerContext ctx, int ackId) {
        AckMessage ackMessage = new AckMessage(ackId);
        ctx.writeAndFlush(ackMessage).addListener(this.listenerOnWrite);
        this.serverSession.onDataAckSent(ackId);
    }

    /**
     * Body of the periodic task: a pending finish ack takes priority over a
     * data ack for the highest handled request id.
     */
    private void checkAndRespondAck(ChannelHandlerContext ctx) {
        if (this.serverSession.needAckFinish()) {
            this.ackFinishMessage(ctx, this.serverSession.finishId());
        } else if (this.serverSession.needAckData()) {
            this.ackDataMessage(ctx, this.serverSession.maxHandledId());
        }
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();
        ConnectionId connectionId = TransportUtil.remoteConnectionId(channel);
        this.handler.onChannelActive(connectionId);
        super.channelActive(ctx);
    }

    /**
     * Stops the periodic ack-check task, which would otherwise keep firing
     * against this closed channel when no finish was acked, then notifies
     * the handler.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        this.stopRespondAckTask();

        Channel channel = ctx.channel();
        ConnectionId connectionId = TransportUtil.remoteConnectionId(channel);
        this.handler.onChannelInactive(connectionId);
        super.channelInactive(ctx);
    }

    /**
     * Forwards the error to the message handler together with the remote
     * connection id. A non-{@link TransportException} cause is wrapped in
     * one that mentions the remote address; the channel is not closed here.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        TransportException exception;
        Channel channel = ctx.channel();
        if (cause instanceof TransportException) {
            exception = (TransportException) cause;
        } else {
            exception = new TransportException(
                        "%s when the server receive data from '%s'",
                        cause, cause.getMessage(),
                        TransportUtil.remoteAddress(channel));
        }
        ConnectionId connectionId = TransportUtil.remoteConnectionId(channel);
        this.handler.exceptionCaught(exception, connectionId);
    }

    @Override
    protected ServerSession session() {
        return this.serverSession;
    }

    @Override
    protected MessageHandler transportHandler() {
        return this.handler;
    }
}