blob: 878c85f9d2d9f71887797ce12c1877d96de90c0f [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hama.ipc;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.*;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.ReferenceCountUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hama.util.BSPNetUtils;
import javax.net.SocketFactory;
import java.io.DataInputStream;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* A client for an IPC service using netty. IPC calls take a single
* {@link Writable} as a parameter, and return a {@link Writable} as their
* value. A service runs on a port and is defined by a parameter class and a
* value class.
*
* @see AsyncServer
*/
public class AsyncClient {
// Configuration key read when the default connection retry policy is built.
private static final String IPC_CLIENT_CONNECT_MAX_RETRIES_KEY = "ipc.client.connect.max.retries";
// Default for the key above; also used as the bound on the one-second waits
// performed by call(Writable, ConnectionId).
private static final int IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT = 10;
private static final Log LOG = LogFactory.getLog(AsyncClient.class);
// Pool of connections, keyed by connection id.
private Hashtable<ConnectionId, Connection> connections = new Hashtable<ConnectionId, Connection>();
private Class<? extends Writable> valueClass; // class of call values
private int counter = 0; // counter for call ids; incremented under the client lock
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
final private Configuration conf; // configuration obj
private SocketFactory socketFactory; // only use in order to meet the
// consistency with other clients
// Reference count maintained by incCount()/decCount(); starts at one.
private int refCount = 1;
// Configuration key holding the ping interval.
final private static String PING_INTERVAL_NAME = "ipc.ping.interval";
final static int DEFAULT_PING_INTERVAL = 60000; // 1 min
/**
 * Stores the given ping interval in the supplied configuration under the
 * ping interval key.
 *
 * @param conf configuration to update
 * @param pingInterval the ping interval to record
 */
public static final void setPingInterval(Configuration conf, int pingInterval) {
  conf.setInt(PING_INTERVAL_NAME, pingInterval);
}
/**
 * Reads the ping interval from the configuration, falling back to
 * {@link #DEFAULT_PING_INTERVAL} when the key is absent.
 *
 * @param conf configuration to read from
 * @return the configured ping interval, or the default
 */
static final int getPingInterval(Configuration conf) {
  final int interval = conf.getInt(PING_INTERVAL_NAME, DEFAULT_PING_INTERVAL);
  return interval;
}
/**
 * Computes the RPC timeout. When ping is disabled (<code>ipc.client.ping</code>
 * is false) the timeout equals the ping interval; when ping is enabled
 * (the default) there is no timeout.
 *
 * @param conf configuration to read from
 * @return the timeout period in milliseconds, or -1 if no timeout applies
 */
public static final int getTimeout(Configuration conf) {
  final boolean pingEnabled = conf.getBoolean("ipc.client.ping", true);
  return pingEnabled ? -1 : getPingInterval(conf);
}
/**
 * Adds one to this client's reference count.
 */
synchronized void incCount() {
  refCount += 1;
}
/**
 * Subtracts one from this client's reference count.
 */
synchronized void decCount() {
  refCount -= 1;
}
/**
 * Tells whether the reference count has dropped to zero.
 *
 * @return true if this client has no reference; false otherwise
 */
synchronized boolean isZeroReference() {
  final boolean unreferenced = (0 == refCount);
  return unreferenced;
}
/**
* A connection to a single remote server. Each connection owns a netty
* channel connected to a remote address. Calls are multiplexed through this
* channel and matched to responses by call id: responses may be delivered
* out of order.
*/
private class Connection {
private InetSocketAddress serverAddress; // server ip:port; may be replaced by updateAddress()
private ConnectionHeader header; // connection header, sent by writeHeader()
private final ConnectionId remoteId; // connection id
private AuthMethod authMethod; // authentication method
private EventLoopGroup group; // netty event loop group used by the bootstrap
private Bootstrap bootstrap; // netty client bootstrap
private Channel channel; // netty channel; assigned when setupConnection() succeeds
private int rpcTimeout; // when positive, overrides pingInterval in setupConnection()
private int maxIdleTime; // connections will be culled if it was idle
private final RetryPolicy connectionRetryPolicy; // consulted when a connect attempt fails
private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
private int pingInterval; // how often sends ping to the server in msecs
// currently active calls
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
private AtomicBoolean shouldCloseConnection = new AtomicBoolean(); // indicates
// that the connection has been marked for closing (see markClosed())
private IOException closeException; // if the connection is closed, close
// reason
/**
 * Sets up the connection configuration from the given connection id.
 *
 * The remote address is validated before any netty resource is allocated,
 * so a failed construction does not leave an event loop group behind.
 *
 * @param remoteId remote connection id
 * @throws IOException an {@link UnknownHostException} if the remote address
 *           is unresolved
 */
public Connection(ConnectionId remoteId) throws IOException {
  this.remoteId = remoteId;
  this.serverAddress = remoteId.getAddress();
  // Fail fast, before the event loop group is created, so nothing leaks.
  if (serverAddress.isUnresolved()) {
    throw new UnknownHostException("unknown host: "
        + remoteId.getAddress().getHostName());
  }
  group = new EpollEventLoopGroup();
  bootstrap = new Bootstrap();
  this.maxIdleTime = remoteId.getMaxIdleTime();
  this.connectionRetryPolicy = remoteId.connectionRetryPolicy;
  this.tcpNoDelay = remoteId.getTcpNoDelay();
  this.pingInterval = remoteId.getPingInterval();
  if (LOG.isDebugEnabled()) {
    LOG.debug("The ping interval is " + this.pingInterval + "ms.");
  }
  this.rpcTimeout = remoteId.getRpcTimeout();
  Class<?> protocol = remoteId.getProtocol();
  // Only simple authentication is set up here.
  authMethod = AuthMethod.SIMPLE;
  header = new ConnectionHeader(protocol == null ? null
      : protocol.getName(), null, authMethod);
}
/**
 * Registers a call with this connection and wakes up one waiter on the
 * connection monitor. Nothing is registered once the connection has been
 * marked for closing.
 *
 * @param call the call to register
 * @return true if the call was added, false if the connection is closing
 */
private synchronized boolean addCall(Call call) {
  final boolean closing = shouldCloseConnection.get();
  if (!closing) {
    calls.put(call.id, call);
    notify();
  }
  return !closing;
}
/**
 * Re-resolves the server host name and, if the resulting address differs
 * from the one currently held, switches to the new address.
 *
 * @return true if the server address was replaced, false if it is unchanged
 * @throws IOException declared for callers; see
 *           {@code BSPNetUtils.makeSocketAddr}
 */
private synchronized boolean updateAddress() throws IOException {
  // Fresh lookup using the host name and port of the current address.
  final String host = serverAddress.getHostName();
  final int port = serverAddress.getPort();
  final InetSocketAddress resolved = BSPNetUtils.makeSocketAddr(host, port);
  if (serverAddress.equals(resolved)) {
    return false;
  }
  LOG.warn("Address change detected. Old: " + serverAddress.toString()
      + " New: " + resolved.toString());
  serverAddress = resolved;
  return true;
}
/**
 * Connects to the server and sends the connection header, unless the
 * channel is already active. Any failure marks the connection as closed
 * (wrapping non-I/O failures in an {@link IOException}) and closes it.
 *
 * @throws InterruptedException declared for callers
 */
private void setupIOstreams() throws InterruptedException {
  final boolean alreadyConnected = channel != null && channel.isActive();
  if (alreadyConnected) {
    return;
  }
  try {
    if (LOG.isDebugEnabled()) {
      LOG.debug("Connecting to " + serverAddress);
    }
    setupConnection();
    writeHeader();
  } catch (Throwable failure) {
    final IOException reason;
    if (failure instanceof IOException) {
      reason = (IOException) failure;
    } else {
      reason = new IOException("Couldn't set up IO streams", failure);
    }
    markClosed(reason);
    close();
  }
}
/**
 * Configures the netty client and connects to the server, retrying as
 * long as the connection retry policy allows.
 *
 * @throws Exception the last connect failure (or a failure raised by the
 *           retry policy) once no further retry is allowed
 */
private void setupConnection() throws Exception {
  // The failure count must live outside the retry loop; otherwise every
  // attempt reports zero previous failures and a count-limited retry policy
  // never gives up.
  short ioFailures = 0;
  // rpcTimeout overwrites pingInterval
  if (rpcTimeout > 0) {
    pingInterval = rpcTimeout;
  }
  while (true) {
    try {
      // Reuse the event loop group that is already allocated; only create a
      // new one when there is none or a previous failed attempt shut it
      // down. Replacing a live group would leak it.
      if (group == null || group.isShuttingDown()) {
        group = new EpollEventLoopGroup();
      }
      // Bootstrap is a helper class that sets up a client. It keeps a single
      // handler, so the channel initializer is the only one registered.
      bootstrap = new Bootstrap();
      bootstrap.group(group).channel(EpollSocketChannel.class)
          .option(ChannelOption.TCP_NODELAY, this.tcpNoDelay)
          .option(ChannelOption.SO_KEEPALIVE, true)
          .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, pingInterval)
          .option(ChannelOption.SO_SNDBUF, 30 * 1024 * 1024)
          .handler(new ChannelInitializer<SocketChannel>() {
            @Override
            public void initChannel(SocketChannel ch) throws Exception {
              ChannelPipeline p = ch.pipeline();
              // maxIdleTime is expressed in milliseconds, so the unit must
              // be given explicitly (the three-argument form is in seconds).
              p.addLast(new IdleStateHandler(0, 0, maxIdleTime,
                  TimeUnit.MILLISECONDS));
              // Register message processing handler
              p.addLast(new NioClientInboundHandler());
            }
          });
      // Connect to the server and wait for the attempt to finish.
      ChannelFuture channelFuture = bootstrap.connect(
          serverAddress.getAddress(), serverAddress.getPort()).sync();
      // Get io channel
      channel = channelFuture.channel();
      LOG.info("AsyncClient startup");
      return;
    } catch (Exception ie) {
      // Check for an address change and update the local reference. Reset
      // the failure counter if the address was changed.
      if (updateAddress()) {
        ioFailures = 0;
      }
      // Rethrows when the retry policy does not allow another attempt.
      handleConnectionFailure(ioFailures++, ie);
    }
  }
}
/**
 * Writes the RPC header and the connection header to the channel.
 *
 * The frame consists of the total length, the protocol magic bytes, the
 * protocol version, one byte for the authentication method, and the
 * length-prefixed serialized connection header. Failures are logged and not
 * propagated. If serialization fails before the buffer is handed to the
 * channel, the buffer is released so it does not leak.
 */
private void writeHeader() {
  DataOutputBuffer rpcBuff = null;
  DataOutputBuffer headerBuf = null;
  ByteBuf buf = null;
  // Becomes true once ownership of buf passes to the channel.
  boolean handedOff = false;
  try {
    buf = channel.alloc().buffer();
    rpcBuff = new DataOutputBuffer();
    authMethod.write(rpcBuff);
    headerBuf = new DataOutputBuffer();
    header.write(headerBuf);
    byte[] data = headerBuf.getData();
    int dataLength = headerBuf.getLength();
    // write rpcheader
    buf.writeInt(AsyncServer.HEADER_LENGTH + dataLength);
    buf.writeBytes(AsyncServer.HEADER.array());
    buf.writeByte(AsyncServer.CURRENT_VERSION);
    buf.writeByte(rpcBuff.getData()[0]);
    // write header
    buf.writeInt(dataLength);
    buf.writeBytes(data, 0, dataLength);
    // The channel takes care of the buffer from here on, so flag it first to
    // avoid releasing it a second time.
    handedOff = true;
    channel.writeAndFlush(buf);
  } catch (Exception e) {
    // Pass the exception as the throwable so the stack trace is kept.
    LOG.error("Couldn't send header", e);
  } finally {
    if (buf != null && !handedOff) {
      ReferenceCountUtil.release(buf);
    }
    IOUtils.closeStream(rpcBuff);
    IOUtils.closeStream(headerBuf);
  }
}
/**
 * Shuts down the event loop group of this connection gracefully, unless it
 * has already terminated. Failures are logged and not propagated.
 */
private void closeConnection() {
  try {
    if (this.group.isTerminated()) {
      return;
    }
    this.group.shutdownGracefully();
    LOG.info("client gracefully shutdown");
  } catch (Exception shutdownFailure) {
    LOG.warn("Not able to close a client", shutdownFailure);
  }
}
/**
* This class process received response message from server.
*/
private class NioClientInboundHandler extends ChannelInboundHandlerAdapter {
/**
* Receive a response. This method is called with the received response
* message, whenever new data is received from a server.
*
* @param ctx
* @param cause
*/
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
ByteBuf byteBuf = (ByteBuf) msg;
ByteBufInputStream byteBufInputStream = new ByteBufInputStream(byteBuf);
DataInputStream in = new DataInputStream(byteBufInputStream);
while (true) {
try {
if (in.available() <= 0)
break;
// try to read an id
int id = in.readInt();
if (LOG.isDebugEnabled())
LOG.debug(serverAddress.getHostName() + " got value #" + id);
Call call = calls.get(id);
// read call status
int state = in.readInt();
if (state == Status.SUCCESS.state) {
Writable value = ReflectionUtils.newInstance(valueClass, conf);
value.readFields(in); // read value
call.setValue(value);
calls.remove(id);
} else if (state == Status.ERROR.state) {
String className = WritableUtils.readString(in);
byte[] errorBytes = new byte[in.available()];
in.readFully(errorBytes);
call.setException(new RemoteException(className, new String(
errorBytes)));
calls.remove(id);
} else if (state == Status.FATAL.state) {
// Close the connection
markClosed(new RemoteException(WritableUtils.readString(in),
WritableUtils.readString(in)));
} else {
byte[] garbageBytes = new byte[in.available()];
in.readFully(garbageBytes);
}
} catch (IOException e) {
markClosed(e);
}
}
IOUtils.closeStream(in);
IOUtils.closeStream(byteBufInputStream);
ReferenceCountUtil.release(msg);
}
/**
* Ths event handler method is called with a Throwable due to an I/O
* error. Then, exception is logged and its associated channel is closed
* here
*
* @param ctx
* @param cause
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOG.error("Occured I/O Error : " + cause.getMessage());
ctx.close();
}
/**
* this method is triggered after a long reading/writing/idle time, it is
* marked as to be closed, or the client is marked as not running.
*
* @param ctx
* @param evt
*/
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
if (e.state() != IdleState.ALL_IDLE) {
if (!calls.isEmpty() && !shouldCloseConnection.get()
&& running.get()) {
return;
} else if (shouldCloseConnection.get()) {
markClosed(null);
} else if (calls.isEmpty()) { // idle connection closed or stopped
markClosed(null);
} else { // get stopped but there are still pending requests
markClosed((IOException) new IOException()
.initCause(new InterruptedException()));
}
closeConnection();
}
}
}
}
/**
 * Handles a connection failure with a fixed retry budget. The connection is
 * shut down first. When the number of retries has reached the maximum, the
 * failure is rethrown; otherwise the method backs off for one second and
 * logs that another attempt will follow.
 *
 * @param curRetries current number of retries
 * @param maxRetries max number of retries allowed
 * @param ioe failure reason
 * @throws IOException if max number of retries is reached
 */
@SuppressWarnings("unused")
private void handleConnectionFailure(int curRetries, int maxRetries,
    IOException ioe) throws IOException {
  closeConnection();
  final boolean exhausted = curRetries >= maxRetries;
  if (exhausted) {
    // no retries left: surface the original failure
    throw ioe;
  }
  // back off before the next attempt
  try {
    Thread.sleep(1000);
  } catch (InterruptedException ignored) {
    // an interrupted back-off is ignored and the retry goes ahead
  }
  final StringBuilder message = new StringBuilder();
  message.append("Retrying connect to server: ").append(serverAddress);
  message.append(". Already tried ").append(curRetries);
  message.append(" time(s); maxRetries=").append(maxRetries);
  LOG.info(message.toString());
}
/**
 * Handles a connection failure by consulting the connection retry policy.
 * The connection is shut down first. If the policy allows another attempt,
 * this is logged and the method returns; otherwise the failure is rethrown.
 * A failure raised by the policy itself is rethrown as an
 * {@link IOException} (wrapped unless it already is one).
 *
 * @param curRetries current number of retries
 * @param ioe failure reason
 * @throws Exception if the retry policy does not allow another attempt
 */
private void handleConnectionFailure(int curRetries, Exception ioe)
    throws Exception {
  closeConnection();
  boolean tryAgain;
  try {
    tryAgain = connectionRetryPolicy.shouldRetry(ioe, curRetries);
  } catch (Exception policyFailure) {
    if (policyFailure instanceof IOException) {
      throw (IOException) policyFailure;
    }
    throw new IOException(policyFailure);
  }
  if (tryAgain) {
    LOG.info("Retrying connect to server: " + serverAddress
        + ". Already tried " + curRetries + " time(s); retry policy is "
        + connectionRetryPolicy);
    return;
  }
  throw ioe;
}
/**
 * Gives the server address this connection currently points at.
 *
 * @return remote server address
 */
public InetSocketAddress getRemoteAddress() {
  return this.serverAddress;
}
/**
 * Starts a call by serializing its id and parameter and writing them,
 * prefixed with their length, to the channel. Any failure marks the
 * connection as closed.
 *
 * @param sendCall the call whose parameter is sent
 */
public void sendParam(Call sendCall) {
  if (LOG.isDebugEnabled()) {
    LOG.debug(this.getClass().getName() + " sending #" + sendCall.id);
  }
  DataOutputBuffer payload = null;
  try {
    payload = new DataOutputBuffer();
    payload.writeInt(sendCall.id);
    sendCall.param.write(payload);
    final int length = payload.getLength();
    final byte[] bytes = payload.getData();
    final ByteBuf frame = channel.alloc().buffer();
    frame.writeInt(length);
    frame.writeBytes(bytes, 0, length);
    // a failure that is already known at this point is rethrown
    final Throwable failure = channel.writeAndFlush(frame).cause();
    if (failure != null) {
      throw failure;
    }
  } catch (IOException ioe) {
    markClosed(ioe);
  } catch (Throwable t) {
    markClosed(new IOException(t));
  } finally {
    // only an in-memory buffer, but closed early all the same
    IOUtils.closeStream(payload);
  }
}
/**
 * Marks the connection as to be closed. Only the first invocation has an
 * effect: it records the reason and wakes up all waiters on the connection
 * monitor.
 *
 * @param ioe the reason for closing; may be null
 */
private synchronized void markClosed(IOException ioe) {
  final boolean firstToMark = shouldCloseConnection.compareAndSet(false, true);
  if (!firstToMark) {
    return;
  }
  closeException = ioe;
  notifyAll();
}
/**
 * Closes the connection. The connection must have been marked as closed
 * before; otherwise an error is logged and nothing happens. The connection
 * is taken out of the pool and shut down, and all pending calls are
 * completed with the close reason.
 */
private synchronized void close() {
  if (!shouldCloseConnection.get()) {
    LOG.error("The connection is not in the closed state");
    return;
  }
  // take the connection out of the pool first, then release its resources
  synchronized (connections) {
    if (connections.get(remoteId) == this) {
      connections.remove(remoteId).closeConnection();
    }
  }
  // complete whatever calls are still pending
  if (closeException != null) {
    if (LOG.isDebugEnabled()) {
      LOG.debug("closing ipc connection to " + serverAddress + ": "
          + closeException.getMessage(), closeException);
    }
    cleanupCalls();
  } else if (!calls.isEmpty()) {
    LOG.warn("A connection is closed for no cause and calls are not empty");
    // no reason was recorded, so make one up for the pending calls
    closeException = new IOException("Unexpected closed connection");
    cleanupCalls();
  }
  if (LOG.isDebugEnabled()) {
    LOG.debug(serverAddress.getHostName() + ": closed");
  }
}
/**
 * Completes every pending call with the close reason and removes it from
 * the table of active calls.
 */
private void cleanupCalls() {
  for (Iterator<Entry<Integer, Call>> pending = calls.entrySet().iterator();
      pending.hasNext();) {
    final Call unfinished = pending.next().getValue();
    unfinished.setException(closeException); // local exception
    pending.remove();
  }
}
}
/**
 * A single pending call: the parameter that was sent, and the value or the
 * error that completes it.
 */
private class Call {
  int id; // call id, taken from the client-wide counter
  Writable param; // parameter
  Writable value; // value, null if error
  IOException error; // exception, null if value
  boolean done; // true when call is done

  /**
   * Creates a call for the given parameter and assigns it the next call id.
   *
   * @param param the parameter to send
   */
  protected Call(Writable param) {
    this.param = param;
    // ids are handed out under the client lock
    synchronized (AsyncClient.this) {
      this.id = counter;
      counter = counter + 1;
    }
  }

  /**
   * Marks the call as done, now that its value or error is available, and
   * notifies one thread waiting on this call.
   */
  protected synchronized void callComplete() {
    done = true;
    notify(); // notify caller
  }

  /**
   * Records the error of a failed call and completes the call.
   *
   * @param error exception thrown by the call; either local or remote
   */
  public synchronized void setException(IOException error) {
    this.error = error;
    this.callComplete();
  }

  /**
   * Records the return value of a successful call and completes the call.
   *
   * @param value return value of the call.
   */
  public synchronized void setValue(Writable value) {
    this.value = value;
    this.callComplete();
  }
}
/**
 * A call that is part of a set of parallel calls and reports its completion
 * to a shared result collector.
 */
private class ParallelCall extends Call {
  private ParallelResults results; // collector shared by the whole set
  private int index; // slot of this call within the collector

  /**
   * @param param the parameter to send
   * @param results the collector to report to
   * @param index the slot of this call within the collector
   */
  public ParallelCall(Writable param, ParallelResults results, int index) {
    super(param);
    this.index = index;
    this.results = results;
  }

  /** Delivers the result to the result collector. */
  @Override
  protected void callComplete() {
    this.results.callComplete(this);
  }
}
/**
 * Collects the values of a set of parallel calls and notifies a waiter once
 * the expected number of calls has completed.
 */
private static class ParallelResults {
  private Writable[] values; // one slot per call
  private int size; // number of completions to wait for
  private int count; // number of completions seen so far

  /**
   * @param size the number of calls in the set
   */
  public ParallelResults(int size) {
    this.size = size;
    this.values = new Writable[size];
  }

  /**
   * Stores the value of a completed call in its slot and, when this was the
   * last expected completion, notifies the waiting caller.
   *
   * @param call the call that completed
   */
  public synchronized void callComplete(ParallelCall call) {
    values[call.index] = call.value;
    if (++count == size) {
      notify();
    }
  }
}
/**
 * Creates an IPC client whose call values are instances of the given
 * {@link Writable} class.
 *
 * @param valueClass class of the call values
 * @param conf configuration used by this client
 * @param factory socket factory; only kept for consistency with the other
 *          clients
 */
public AsyncClient(Class<? extends Writable> valueClass, Configuration conf,
    SocketFactory factory) {
  this.conf = conf;
  this.valueClass = valueClass;
  // kept only so that this client looks like the other clients
  this.socketFactory = factory;
}
/**
 * Creates an IPC client that uses the default socket factory derived from
 * the configuration.
 *
 * @param valueClass class of the call values
 * @param conf configuration used by this client
 */
public AsyncClient(Class<? extends Writable> valueClass, Configuration conf) {
  // the socket factory is only passed on for consistency with other clients
  this(valueClass, conf, BSPNetUtils.getDefaultSocketFactory(conf));
}
/**
 * Gives the socket factory this client was created with. It is only kept
 * for consistency with the other clients.
 *
 * @return this client's socket factory
 */
SocketFactory getSocketFactory() {
  return this.socketFactory;
}
/**
 * Stops this client and shuts down all of its pooled connections. Only the
 * first invocation has an effect. No further calls may be made using this
 * client.
 */
public void stop() {
  if (LOG.isDebugEnabled()) {
    LOG.debug("Stopping client");
  }
  final boolean wasRunning = running.compareAndSet(true, false);
  if (wasRunning) {
    // shut down every pooled connection
    synchronized (connections) {
      final Iterator<Connection> pooled = connections.values().iterator();
      while (pooled.hasNext()) {
        pooled.next().closeConnection();
      }
    }
  }
}
/**
 * Makes a call, passing <code>param</code>, to the IPC server running at
 * <code>addr</code> which is servicing the <code>protocol</code> protocol.
 * The connection id is derived from the address, protocol, ticket, timeout
 * and configuration, and the call is delegated to
 * {@link #call(Writable, ConnectionId)}.
 *
 * @param param the parameter to send
 * @param addr the server address
 * @param protocol the protocol serviced by the server
 * @param ticket the user credentials
 * @param rpcTimeout the RPC timeout
 * @param conf the configuration for this connection
 * @return Response Writable value
 * @throws InterruptedException
 * @throws IOException
 */
public Writable call(Writable param, InetSocketAddress addr,
    Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
    Configuration conf) throws InterruptedException, IOException {
  return call(param, ConnectionId.getConnectionId(addr, protocol, ticket,
      rpcTimeout, conf));
}
/**
 * Makes a call, passing <code>param</code>, to the IPC server defined by
 * <code>remoteId</code>, and waits for its completion. The wait is made of
 * one-second slices and is abandoned after a bounded number of them, in
 * which case the value held by the call at that time is returned. A remote
 * error is rethrown as is; a local error is wrapped with the server address.
 *
 * @param param the parameter to send
 * @param remoteId the connection id of the server
 * @return Response Writable value
 * @throws InterruptedException
 * @throws IOException
 */
public Writable call(Writable param, ConnectionId remoteId)
    throws InterruptedException, IOException {
  final Call pending = new Call(param);
  final Connection conn = getConnection(remoteId, pending);
  conn.sendParam(pending); // send the parameter
  synchronized (pending) {
    boolean wasInterrupted = false;
    int waits = 0;
    while (!pending.done) {
      try {
        pending.wait(1000); // wait for the result
        // give up eventually so a lost response cannot hang the client
        if (waits == IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT) {
          break;
        }
        waits++;
      } catch (InterruptedException ie) {
        wasInterrupted = true;
      }
    }
    if (wasInterrupted) {
      // restore the interrupt flag now that the waiting is over
      Thread.currentThread().interrupt();
    }
    final IOException failure = pending.error;
    if (failure == null) {
      return pending.value;
    }
    if (failure instanceof RemoteException) {
      failure.fillInStackTrace();
      throw failure;
    }
    // local exception: report the address held by the connection, which
    // reflects an ip change, unlike the remoteId
    throw wrapException(conn.getRemoteAddress(), failure);
  }
}
/**
 * Builds an exception that names the address a call was made to and has the
 * given exception as its cause. The new exception carries the stack trace
 * of the place where it is created. A {@link ConnectException} or
 * {@link SocketTimeoutException} is wrapped in one of the same type;
 * anything else is wrapped in a plain {@link IOException}.
 *
 * @param addr target address
 * @param exception the relevant exception
 * @return an exception to throw
 */
private IOException wrapException(InetSocketAddress addr,
    IOException exception) {
  final IOException wrapped;
  if (exception instanceof ConnectException) {
    // connection refused; include the host:port in the error
    wrapped = new ConnectException("Call to " + addr
        + " failed on connection exception: " + exception);
  } else if (exception instanceof SocketTimeoutException) {
    wrapped = new SocketTimeoutException("Call to "
        + addr + " failed on socket timeout exception: " + exception);
  } else {
    wrapped = new IOException("Call to " + addr
        + " failed on local exception: " + exception);
  }
  wrapped.initCause(exception);
  return wrapped;
}
/**
 * Makes a set of calls in parallel. Each parameter is sent to the
 * corresponding address. When all values are available, or have errored,
 * the collected results are returned in an array. The array contains nulls
 * for calls that errored.
 *
 * An interruption received while waiting for the results does not abort the
 * wait; the interrupt status of the calling thread is restored before this
 * method returns, as is done for single calls.
 *
 * @param params the parameters, one per call
 * @param addresses the server addresses, one per call
 * @param protocol the protocol serviced by the servers
 * @param ticket the user credentials
 * @param conf the configuration for the connections
 * @return Response Writable value array
 * @throws IOException
 * @throws InterruptedException
 */
public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
    Class<?> protocol, UserGroupInformation ticket, Configuration conf)
    throws IOException, InterruptedException {
  if (addresses.length == 0) {
    return new Writable[0];
  }
  ParallelResults results = new ParallelResults(params.length);
  boolean interrupted = false;
  synchronized (results) {
    for (int i = 0; i < params.length; i++) {
      ParallelCall call = new ParallelCall(params[i], results, i);
      try {
        ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
            protocol, ticket, 0, conf);
        Connection connection = getConnection(remoteId, call);
        connection.sendParam(call); // send each parameter
      } catch (IOException e) {
        // log errors
        LOG.info("Calling " + addresses[i] + " caught: " + e.getMessage(), e);
        results.size--; // wait for one fewer result
      }
    }
    while (results.count != results.size) {
      try {
        results.wait(); // wait for all results
      } catch (InterruptedException e) {
        // keep waiting, but remember the interruption for the caller
        interrupted = true;
      }
    }
  }
  if (interrupted) {
    // restore the interrupt flag now that the waiting is over
    Thread.currentThread().interrupt();
  }
  return results.values;
}
/**
 * Gives the key set of the connection pool. For unit testing only.
 *
 * @return the ids of the pooled connections
 */
Set<ConnectionId> getConnectionIds() {
  synchronized (connections) {
    final Set<ConnectionId> ids = connections.keySet();
    return ids;
  }
}
/**
 * Gets a connection from the pool, or creates a new one and adds it to the
 * pool, registers the call with it and makes sure it is connected.
 * Connections to a given ConnectionId are reused as long as they are
 * usable.
 *
 * A pooled connection is replaced when it has been marked for closing, has
 * no channel, or its channel is not writable or not active. Replacing
 * connections that are marked for closing guarantees that the registration
 * loop below terminates, since such a connection rejects every call.
 *
 * @param remoteId the connection id of the server
 * @param call the call to register with the connection
 * @return connection
 * @throws IOException if the client is stopped or the connection cannot be
 *           created
 * @throws InterruptedException
 */
private synchronized Connection getConnection(ConnectionId remoteId, Call call)
    throws IOException, InterruptedException {
  if (!running.get()) {
    // the client is stopped
    throw new IOException("The client is stopped");
  }
  Connection connection;
  do {
    connection = connections.get(remoteId);
    final boolean unusable = connection != null
        && (connection.shouldCloseConnection.get()
            || connection.channel == null
            || !connection.channel.isWritable()
            || !connection.channel.isActive());
    if (connection == null || unusable) {
      // put() replaces the mapping of an unusable connection, if any
      connection = new Connection(remoteId);
      connections.put(remoteId, connection);
    }
  } while (!connection.addCall(call));
  // Returns immediately when the channel is already active.
  connection.setupIOstreams();
  return connection;
}
/**
 * Identifies a client connection to a server. Besides the remote address,
 * protocol and user ticket, the identity includes the connection settings
 * (RPC timeout, server principal, max idle time, retry policy, tcpNoDelay
 * and ping interval): two ids are equal only if all of them are equal.
 */
static class ConnectionId {
  InetSocketAddress address;
  UserGroupInformation ticket;
  Class<?> protocol;
  private static final int PRIME = 16777619;
  private int rpcTimeout;
  private String serverPrincipal;
  private int maxIdleTime; // connections will be culled if it was idle for
  // maxIdleTime msecs
  private final RetryPolicy connectionRetryPolicy;
  private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
  private int pingInterval; // how often sends ping to the server in msecs

  ConnectionId(InetSocketAddress address, Class<?> protocol,
      UserGroupInformation ticket, int rpcTimeout, String serverPrincipal,
      int maxIdleTime, RetryPolicy connectionRetryPolicy, boolean tcpNoDelay,
      int pingInterval) {
    this.protocol = protocol;
    this.address = address;
    this.ticket = ticket;
    this.rpcTimeout = rpcTimeout;
    this.serverPrincipal = serverPrincipal;
    this.maxIdleTime = maxIdleTime;
    this.connectionRetryPolicy = connectionRetryPolicy;
    this.tcpNoDelay = tcpNoDelay;
    this.pingInterval = pingInterval;
  }

  InetSocketAddress getAddress() {
    return address;
  }

  Class<?> getProtocol() {
    return protocol;
  }

  private int getRpcTimeout() {
    return rpcTimeout;
  }

  String getServerPrincipal() {
    return serverPrincipal;
  }

  int getMaxIdleTime() {
    return maxIdleTime;
  }

  boolean getTcpNoDelay() {
    return tcpNoDelay;
  }

  int getPingInterval() {
    return pingInterval;
  }

  /** Builds a connection id with an RPC timeout of zero. */
  static ConnectionId getConnectionId(InetSocketAddress addr,
      Class<?> protocol, UserGroupInformation ticket, Configuration conf)
      throws IOException {
    return getConnectionId(addr, protocol, ticket, 0, conf);
  }

  /** Builds a connection id with the default connection retry policy. */
  static ConnectionId getConnectionId(InetSocketAddress addr,
      Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
      Configuration conf) throws IOException {
    return getConnectionId(addr, protocol, ticket, rpcTimeout, null, conf);
  }

  /**
   * Builds a connection id, reading the max idle time, tcpNoDelay and ping
   * interval from the configuration.
   *
   * @param connectionRetryPolicy the retry policy to use; when null, a
   *          policy that retries up to the configured maximum count with a
   *          fixed sleep of one second is used
   */
  static ConnectionId getConnectionId(InetSocketAddress addr,
      Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
      RetryPolicy connectionRetryPolicy, Configuration conf)
      throws IOException {
    if (connectionRetryPolicy == null) {
      final int max = conf.getInt(IPC_CLIENT_CONNECT_MAX_RETRIES_KEY,
          IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT);
      connectionRetryPolicy = RetryPolicies
          .retryUpToMaximumCountWithFixedSleep(max, 1, TimeUnit.SECONDS);
    }
    return new ConnectionId(addr, protocol, ticket, rpcTimeout,
        null,
        conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
        connectionRetryPolicy,
        conf.getBoolean("ipc.client.tcpnodelay", true),
        AsyncClient.getPingInterval(conf));
  }

  /** Null-safe equality: two nulls are equal, a null equals nothing else. */
  static boolean isEqual(Object a, Object b) {
    return a == null ? b == null : a.equals(b);
  }

  @Override
  public boolean equals(Object obj) {
    if (obj == this) {
      return true;
    }
    if (obj instanceof ConnectionId) {
      ConnectionId that = (ConnectionId) obj;
      return isEqual(this.address, that.address)
          && this.maxIdleTime == that.maxIdleTime
          && isEqual(this.connectionRetryPolicy, that.connectionRetryPolicy)
          && this.pingInterval == that.pingInterval
          && isEqual(this.protocol, that.protocol)
          && this.rpcTimeout == that.rpcTimeout
          && isEqual(this.serverPrincipal, that.serverPrincipal)
          && this.tcpNoDelay == that.tcpNoDelay
          && isEqual(this.ticket, that.ticket);
    }
    return false;
  }

  /**
   * Combines every field compared by {@link #equals(Object)}. Each step
   * folds the running result in, so no field is dropped from the hash, and
   * null fields (including the retry policy) contribute zero.
   */
  @Override
  public int hashCode() {
    int result = (connectionRetryPolicy == null) ? 0
        : connectionRetryPolicy.hashCode();
    result = PRIME * result + ((address == null) ? 0 : address.hashCode());
    result = PRIME * result + maxIdleTime;
    result = PRIME * result + pingInterval;
    result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
    result = PRIME * result + rpcTimeout;
    result = PRIME * result
        + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
    result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
    result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
    return result;
  }
}
}