blob: 69c1bb3311867fd0fc57dec4a57acdd528e41914 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.rpc;
import com.google.protobuf.Internal.EnumLite;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.concurrent.GenericFutureListener;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.exec.proto.GeneralRPCProtos.RpcMode;
import org.apache.drill.exec.proto.UserBitShared.DrillPBError;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import java.io.Closeable;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* The Rpc Bus deals with incoming and outgoing communication and is used on both the server and the client side of a
* system.
*
* @param <T> RPC type
* @param <C> Remote connection type
*/
public abstract class RpcBus<T extends EnumLite, C extends RemoteConnection> implements Closeable {
/** Per-instance logger named after the concrete runtime class (obtained via {@code this.getClass()}). */
final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(this.getClass());
/** Shared reply message written to the channel by the inbound handler whenever a PING arrives. */
private static final OutboundRpcMessage PONG = new OutboundRpcMessage(RpcMode.PONG, 0, 0, Acks.OK);
/**
 * Supplies the default protobuf instance for a response rpc type. The inbound handler uses the
 * returned instance's class and parser to decode the body of RESPONSE messages.
 *
 * @param rpcType numeric rpc type carried by the incoming response
 * @return default instance of the expected response message
 * @throws RpcException declared for implementations; presumably raised for unknown rpc types - confirm in subclasses
 */
protected abstract MessageLite getResponseDefaultInstance(int rpcType) throws RpcException;
/**
 * Processes an incoming REQUEST message. The inbound handler retains {@code pBody} and {@code dBody}
 * before this call and releases them once it returns or throws.
 *
 * @param connection connection the request arrived on
 * @param rpcType numeric rpc type of the request
 * @param pBody protobuf body of the request (may be null as far as the caller's null-safe retain/release suggests)
 * @param dBody data body of the request (same nullability note as {@code pBody})
 * @param sender single-use sender for the reply
 * @throws RpcException a {@code UserRpcException} thrown from here is converted into a failure response by the
 *           inbound handler; other exceptions propagate out of the handler
 */
protected abstract void handle(C connection, int rpcType, ByteBuf pBody, ByteBuf dBody, ResponseSender sender)
throws RpcException;
/** RPC configuration; used in this class only inside assertions that validate send/receive type pairs. */
protected final RpcConfig rpcConfig;
/** Local endpoint recorded by {@code setAddresses}; null until then. Read when building the channel-closed message. */
protected volatile SocketAddress local;
/** Remote endpoint recorded by {@code setAddresses}; null until then. Read when building the channel-closed message. */
protected volatile SocketAddress remote;
/**
 * Creates a bus bound to the given RPC configuration.
 *
 * @param rpcConfig configuration stored in {@link #rpcConfig}; it is not validated here
 */
public RpcBus(RpcConfig rpcConfig) {
this.rpcConfig = rpcConfig;
}
/**
 * Records the endpoints of the underlying channel so they can be reported when the channel closes.
 *
 * @param remote address of the remote endpoint
 * @param local address of the local endpoint
 */
protected void setAddresses(SocketAddress remote, SocketAddress local){
// Keep this order: ChannelClosedHandler only checks 'local' for null before formatting both
// addresses, so 'remote' is published first and is already visible once 'local' becomes non-null.
this.remote = remote;
this.local = local;
}
/**
 * Sends a request and hands back a future that doubles as the outcome listener for the call.
 *
 * @param connection connection to send on
 * @param rpcType rpc type of the request
 * @param protobufBody protobuf body of the request
 * @param clazz expected response message class
 * @param dataBodies optional data buffers that travel with the request
 * @return the future that was registered as the listener for this request
 */
public <SEND extends MessageLite, RECEIVE extends MessageLite>
DrillRpcFuture<RECEIVE> send(C connection, T rpcType, SEND protobufBody, Class<RECEIVE> clazz,
    ByteBuf... dataBodies) {
  final DrillRpcFutureImpl<RECEIVE> outcome = new DrillRpcFutureImpl<>();
  send(outcome, connection, rpcType, protobufBody, clazz, dataBodies);
  return outcome;
}
/**
 * Sends a request on behalf of a caller that is not on the rpc event thread; equivalent to the
 * full overload with {@code allowInEventLoop} set to {@code false}.
 *
 * @param listener receives the outcome of the request
 * @param connection connection to send on
 * @param rpcType rpc type of the request
 * @param protobufBody protobuf body of the request
 * @param clazz expected response message class
 * @param dataBodies optional data buffers that travel with the request
 */
public <SEND extends MessageLite, RECEIVE extends MessageLite>
void send(RpcOutcomeListener<RECEIVE> listener, C connection, T rpcType, SEND protobufBody, Class<RECEIVE> clazz,
    ByteBuf... dataBodies) {
  final boolean allowInEventLoop = false;
  send(listener, connection, rpcType, protobufBody, clazz, allowInEventLoop, dataBodies);
}
/**
 * Writes a REQUEST message to the connection's channel and registers {@code listener} for its outcome.
 *
 * <p>If the message is not handed to the channel (the wait for writability was abandoned, or any
 * exception or assertion failure occurred inside the send path), every non-null buffer in
 * {@code dataBodies} is released here; failures are reported through {@code listener.failed(...)}
 * wrapped in an {@link RpcException} rather than thrown.
 *
 * @param listener receives the outcome of the request, including send failures
 * @param connection connection to send on
 * @param rpcType rpc type of the request
 * @param protobufBody protobuf body of the request; must not be null
 * @param clazz expected response message class
 * @param allowInEventLoop when {@code false}, the call must not originate from the rpc event thread and
 *          first waits via {@code connection.blockOnNotWritable(listener)}
 * @param dataBodies optional data buffers that travel with the request
 * @throws IllegalArgumentException if called from the rpc event thread while {@code allowInEventLoop} is false
 */
public <SEND extends MessageLite, RECEIVE extends MessageLite>
void send(RpcOutcomeListener<RECEIVE> listener, C connection, T rpcType, SEND protobufBody, Class<RECEIVE> clazz,
    boolean allowInEventLoop, ByteBuf... dataBodies) {

  // NOTE(review): when this precondition fails the data buffers are not released here - confirm that
  // callers keep ownership on this path.
  Preconditions
      .checkArgument(
          allowInEventLoop || !connection.inEventLoop(),
          "You attempted to send while inside the rpc event thread. This isn't allowed because sending will block if the channel is backed up.");

  boolean completed = false;

  try {
    if (!allowInEventLoop && !connection.blockOnNotWritable(listener)) {
      // We are not in the event loop and the wait for a writable channel did not succeed
      // (e.g. we were interrupted while blocking): skip sending this message.
      return;
    }

    // Validate the body before it is dereferenced by the assertion below.
    Preconditions.checkNotNull(protobufBody);
    assert !Arrays.asList(dataBodies).contains(null);
    assert rpcConfig.checkSend(rpcType, protobufBody.getClass(), clazz);

    ChannelListenerWithCoordinationId futureListener = connection.createNewRpcListener(listener, clazz);
    OutboundRpcMessage m = new OutboundRpcMessage(RpcMode.REQUEST, rpcType, futureListener.getCoordinationId(), protobufBody, dataBodies);
    ChannelFuture channelFuture = connection.getChannel().writeAndFlush(m);
    channelFuture.addListener(futureListener);
    channelFuture.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
    completed = true;
  } catch (Exception | AssertionError e) {
    listener.failed(new RpcException("Failure sending message.", e));
  } finally {
    if (!completed && dataBodies != null) {
      for (ByteBuf b : dataBodies) {
        // A null element (the very condition the assertion above rejects) must not abort the
        // cleanup with a NullPointerException and leak the remaining buffers.
        if (b != null) {
          b.release();
        }
      }
    }
  }
}
/**
 * Hook for subclasses to supply the connection object for a socket channel.
 * NOTE(review): not called within this class; presumably invoked by client/server subclasses
 * during channel initialization - confirm against callers.
 *
 * @param channel socket channel the connection is associated with
 * @return connection of this bus's remote connection type
 */
protected abstract C initRemoteConnection(SocketChannel channel);
/**
 * Future listener that turns completion of the observed channel future into a
 * {@link ChannelClosedException} and passes it to the associated connection.
 */
public class ChannelClosedHandler implements GenericFutureListener<ChannelFuture> {

  final C clientConnection;
  private final Channel channel;

  public ChannelClosedHandler(C clientConnection, Channel channel) {
    this.clientConnection = clientConnection;
    this.channel = channel;
  }

  /**
   * Builds the "channel closed" exception and notifies the connection.
   *
   * <p>The addresses recorded on the bus are used when a local address has been recorded; otherwise
   * the addresses are taken from the future's channel. The future's cause, if any, is attached.
   */
  @Override
  public void operationComplete(ChannelFuture future) throws Exception {
    final boolean addressesRecorded = local != null;
    final SocketAddress localEnd = addressesRecorded ? local : future.channel().localAddress();
    final SocketAddress remoteEnd = addressesRecorded ? remote : future.channel().remoteAddress();
    final String msg = String.format("Channel closed %s <--> %s.", localEnd, remoteEnd);

    final Throwable cause = future.cause();
    final ChannelClosedException closed;
    if (cause == null) {
      closed = new ChannelClosedException(msg);
    } else {
      closed = new ChannelClosedException(msg, cause);
    }
    clientConnection.channelClosed(closed);
  }
}
/**
 * Creates the listener to be notified when the given channel closes.
 *
 * @param channel channel being observed
 * @param clientConnection connection to notify about the close
 * @return a new {@link ChannelClosedHandler} for the pair
 */
protected GenericFutureListener<ChannelFuture> getCloseHandler(SocketChannel channel, C clientConnection) {
  final ChannelClosedHandler closeHandler = new ChannelClosedHandler(clientConnection, channel);
  return closeHandler;
}
/**
 * Single-use sender that writes either a RESPONSE or a RESPONSE_FAILURE message for one
 * coordination id. A second use of the same instance fails with {@link IllegalStateException}.
 */
private class ResponseSenderImpl implements ResponseSender {

  private final RemoteConnection connection;
  private final int coordinationId;
  private final AtomicBoolean sent = new AtomicBoolean(false);

  public ResponseSenderImpl(RemoteConnection connection, int coordinationId) {
    this.coordinationId = coordinationId;
    this.connection = connection;
  }

  /**
   * Writes the given response to the connection's channel.
   *
   * @param r response carrying the rpc type, protobuf body and data bodies
   * @throws IllegalStateException if this sender has already been used
   */
  public void send(Response r) {
    assert rpcConfig.checkResponseSend(r.rpcType, r.pBody.getClass());
    sendOnce();

    final OutboundRpcMessage reply =
        new OutboundRpcMessage(RpcMode.RESPONSE, r.rpcType, coordinationId, r.pBody, r.dBodies);
    if (RpcConstants.EXTRA_DEBUGGING) {
      logger.debug("Adding message to outbound buffer. {}", reply);
      logger.debug("Sending response with Sender {}", System.identityHashCode(this));
    }
    flush(reply);
  }

  /**
   * Marks this sender as used; fails if it had been used before.
   */
  private void sendOnce() {
    final boolean firstUse = sent.compareAndSet(false, true);
    if (firstUse) {
      return;
    }
    throw new IllegalStateException("Attempted to utilize a sender multiple times.");
  }

  /**
   * Reports a failed request to the remote side as a RESPONSE_FAILURE message.
   *
   * @param e failure raised while handling the request
   * @throws IllegalStateException if this sender has already been used
   */
  void sendFailure(UserRpcException e) {
    sendOnce();

    final UserException uex = UserException.systemError(e).addIdentity(e.getEndpoint()).build(logger);
    logger.error("Unexpected Error while handling request message", e);

    final OutboundRpcMessage failureReply =
        new OutboundRpcMessage(RpcMode.RESPONSE_FAILURE, 0, coordinationId, uex.getOrCreatePBError(false));
    if (RpcConstants.EXTRA_DEBUGGING) {
      logger.debug("Adding message to outbound buffer. {}", failureReply);
    }
    flush(failureReply);
  }

  /** Writes and flushes one outbound message on the connection's channel. */
  private void flush(OutboundRpcMessage outbound) {
    connection.getChannel().writeAndFlush(outbound);
  }
}
/** Retains {@code buf}; a null buffer is ignored. */
private static void retainByteBuf(ByteBuf buf) {
  if (buf == null) {
    return;
  }
  buf.retain();
}
/** Releases {@code buf}; a null buffer is ignored. */
private static void releaseByteBuf(ByteBuf buf) {
  if (buf == null) {
    return;
  }
  buf.release();
}
/**
 * Channel handler that dispatches each decoded {@code InboundRpcMessage} according to its mode:
 * requests go to {@code handle(...)}, responses complete the matching rpc outcome, failures are
 * recorded on the connection, PING is answered with PONG and PONG is forwarded as a user event.
 */
protected class InboundHandler extends MessageToMessageDecoder<InboundRpcMessage> {

  /** Threshold used when the {@code drill.exec.rpcDelayWarning} property is not a valid integer. */
  private static final long DEFAULT_DELAY_WARNING_MILLIS = 500;

  private final C connection;

  public InboundHandler(C connection) {
    super();
    Preconditions.checkNotNull(connection);
    this.connection = connection;
  }

  /**
   * Handles one inbound message. Nothing is added to {@code output}.
   *
   * <p>Once dispatch has started the message is always released, even if handling or the
   * slow-message check throws.
   *
   * @param ctx handler context of the channel the message arrived on
   * @param msg decoded inbound message
   * @param output downstream output list (unused)
   * @throws Exception any exception raised while handling the message, other than a
   *           {@code UserRpcException} from a request handler, which is sent back as a failure response
   */
  @Override
  protected void decode(final ChannelHandlerContext ctx, final InboundRpcMessage msg, final List<Object> output)
      throws Exception {
    if (!ctx.channel().isOpen()) {
      // NOTE(review): this path returns without msg.release(), unlike every other path below -
      // confirm that the message's buffers are reclaimed elsewhere when the channel is closed.
      return;
    }
    if (RpcConstants.EXTRA_DEBUGGING) {
      logger.debug("Received message {}", msg);
    }
    final Channel channel = connection.getChannel();
    final Stopwatch watch = Stopwatch.createStarted();

    try {
      switch (msg.mode) {
      case REQUEST: {
        final ResponseSenderImpl sender = new ResponseSenderImpl(connection, msg.coordinationId);
        // Hold the bodies for the duration of the handler call.
        retainByteBuf(msg.pBody);
        retainByteBuf(msg.dBody);
        try {
          handle(connection, msg.rpcType, msg.pBody, msg.dBody, sender);
        } catch (UserRpcException e) {
          sender.sendFailure(e);
        } finally {
          releaseByteBuf(msg.pBody);
          releaseByteBuf(msg.dBody);
        }
        break;
      }

      case RESPONSE: {
        retainByteBuf(msg.pBody);
        retainByteBuf(msg.dBody);
        try {
          final MessageLite defaultResponse = getResponseDefaultInstance(msg.rpcType);
          assert rpcConfig.checkReceive(msg.rpcType, defaultResponse.getClass());
          final RpcOutcome<?> rpcFuture = connection.getAndRemoveRpcOutcome(msg.rpcType, msg.coordinationId,
              defaultResponse.getClass());
          final Parser<?> parser = defaultResponse.getParserForType();
          final Object value = parser.parseFrom(new ByteBufInputStream(msg.pBody, msg.pBody.readableBytes()));
          rpcFuture.set(value, msg.dBody);
          if (RpcConstants.EXTRA_DEBUGGING) {
            logger.debug("Updated rpc future {} with value {}", rpcFuture, value);
          }
        } catch (Exception ex) {
          logger.error("Failure while handling response.", ex);
          throw ex;
        } finally {
          releaseByteBuf(msg.pBody);
          releaseByteBuf(msg.dBody);
        }
        break;
      }

      case RESPONSE_FAILURE: {
        DrillPBError failure = DrillPBError.parseFrom(new ByteBufInputStream(msg.pBody, msg.pBody.readableBytes()));
        connection.recordRemoteFailure(msg.coordinationId, failure);
        if (RpcConstants.EXTRA_DEBUGGING) {
          logger.debug("Updated rpc future with coordinationId {} with failure {}", msg.coordinationId, failure);
        }
        break;
      }

      case PING:
        channel.writeAndFlush(PONG);
        break;

      case PONG:
        ctx.fireUserEventTriggered(RpcMode.PONG);
        break;

      default:
        throw new UnsupportedOperationException();
      }
    } finally {
      try {
        warnIfSlow(msg, watch.elapsed(TimeUnit.MILLISECONDS));
      } finally {
        // Must run even if the timing/logging code above throws, otherwise the message leaks.
        msg.release();
      }
    }
  }

  /**
   * Logs a warning when handling took longer than the threshold configured through the
   * {@code drill.exec.rpcDelayWarning} system property (milliseconds, default 500).
   */
  private void warnIfSlow(InboundRpcMessage msg, long time) {
    long delayThreshold;
    try {
      delayThreshold = Integer.parseInt(System.getProperty("drill.exec.rpcDelayWarning", "500"));
    } catch (NumberFormatException e) {
      // A malformed property must not fail message handling or mask an in-flight exception.
      delayThreshold = DEFAULT_DELAY_WARNING_MILLIS;
    }
    if (time > delayThreshold) {
      logger.warn(String.format(
          "Message of mode %s of rpc type %d took longer than %dms. Actual duration was %dms.",
          msg.mode, msg.rpcType, delayThreshold, time));
    }
  }
}
/**
 * Decodes a protobuf message from the readable bytes of a buffer.
 *
 * @param pBody buffer holding the serialized message
 * @param parser parser for the expected message type
 * @return the parsed message
 * @throws RpcException if the parser rejects the bytes; the protobuf exception is kept as the cause
 */
public static <T> T get(ByteBuf pBody, Parser<T> parser) throws RpcException {
  final ByteBufInputStream body = new ByteBufInputStream(pBody);
  try {
    return parser.parseFrom(body);
  } catch (InvalidProtocolBufferException e) {
    final String parserName = parser.getClass().getCanonicalName();
    final String message = String.format("Failure while decoding message with parser of type. %s", parserName);
    throw new RpcException(message, e);
  }
}
}