| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.hadoop.ipc; |
| |
| import java.net.InetAddress; |
| import java.net.Socket; |
| import java.net.InetSocketAddress; |
| import java.net.SocketTimeoutException; |
| import java.net.UnknownHostException; |
| import java.io.IOException; |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.BufferedInputStream; |
| import java.io.BufferedOutputStream; |
| import java.io.FilterInputStream; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| |
| import java.security.PrivilegedExceptionAction; |
| import java.util.Hashtable; |
| import java.util.Iterator; |
| import java.util.Random; |
| import java.util.Set; |
| import java.util.Map.Entry; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicLong; |
| |
| import javax.net.SocketFactory; |
| |
| import org.apache.commons.logging.*; |
| |
| import org.apache.hadoop.classification.InterfaceAudience; |
| import org.apache.hadoop.classification.InterfaceStability; |
| import org.apache.hadoop.conf.Configuration; |
| import org.apache.hadoop.fs.CommonConfigurationKeys; |
| import org.apache.hadoop.fs.CommonConfigurationKeysPublic; |
| import org.apache.hadoop.io.IOUtils; |
| import org.apache.hadoop.io.Text; |
| import org.apache.hadoop.io.Writable; |
| import org.apache.hadoop.io.WritableUtils; |
| import org.apache.hadoop.io.DataOutputBuffer; |
| import org.apache.hadoop.net.NetUtils; |
| import org.apache.hadoop.security.KerberosInfo; |
| import org.apache.hadoop.security.SaslRpcClient; |
| import org.apache.hadoop.security.SaslRpcServer.AuthMethod; |
| import org.apache.hadoop.security.SecurityUtil; |
| import org.apache.hadoop.security.UserGroupInformation; |
| import org.apache.hadoop.security.token.Token; |
| import org.apache.hadoop.security.token.TokenIdentifier; |
| import org.apache.hadoop.security.token.TokenSelector; |
| import org.apache.hadoop.security.token.TokenInfo; |
| import org.apache.hadoop.util.ReflectionUtils; |
| |
| /** A client for an IPC service. IPC calls take a single {@link Writable} as a |
| * parameter, and return a {@link Writable} as their value. A service runs on |
| * a port and is defined by a parameter class and a value class. |
| * |
| * @see Server |
| */ |
| public class Client { |
| |
| public static final Log LOG = |
| LogFactory.getLog(Client.class); |
| private Hashtable<ConnectionId, Connection> connections = |
| new Hashtable<ConnectionId, Connection>(); |
| |
| private Class<? extends Writable> valueClass; // class of call values |
| private int counter; // counter for call ids |
| private AtomicBoolean running = new AtomicBoolean(true); // if client runs |
| final private Configuration conf; |
| |
| private SocketFactory socketFactory; // how to create sockets |
| private int refCount = 1; |
| |
| final static int PING_CALL_ID = -1; |
| |
| /** |
| * set the ping interval value in configuration |
| * |
| * @param conf Configuration |
| * @param pingInterval the ping interval |
| */ |
| final public static void setPingInterval(Configuration conf, int pingInterval) { |
| conf.setInt(CommonConfigurationKeys.IPC_PING_INTERVAL_KEY, pingInterval); |
| } |
| |
| /** |
| * Get the ping interval from configuration; |
| * If not set in the configuration, return the default value. |
| * |
| * @param conf Configuration |
| * @return the ping interval |
| */ |
| final static int getPingInterval(Configuration conf) { |
| return conf.getInt(CommonConfigurationKeys.IPC_PING_INTERVAL_KEY, |
| CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT); |
| } |
| |
| /** |
| * The time after which a RPC will timeout. |
| * If ping is not enabled (via ipc.client.ping), then the timeout value is the |
| * same as the pingInterval. |
| * If ping is enabled, then there is no timeout value. |
| * |
| * @param conf Configuration |
| * @return the timeout period in milliseconds. -1 if no timeout value is set |
| */ |
| final public static int getTimeout(Configuration conf) { |
| if (!conf.getBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true)) { |
| return getPingInterval(conf); |
| } |
| return -1; |
| } |
| |
| /** |
| * Increment this client's reference count |
| * |
| */ |
| synchronized void incCount() { |
| refCount++; |
| } |
| |
| /** |
| * Decrement this client's reference count |
| * |
| */ |
| synchronized void decCount() { |
| refCount--; |
| } |
| |
| /** |
| * Return if this client has no reference |
| * |
| * @return true if this client has no reference; false otherwise |
| */ |
| synchronized boolean isZeroReference() { |
| return refCount==0; |
| } |
| |
| /** A call waiting for a value. */ |
| private class Call { |
| int id; // call id |
| Writable param; // parameter |
| Writable value; // value, null if error |
| IOException error; // exception, null if value |
| boolean done; // true when call is done |
| |
| protected Call(Writable param) { |
| this.param = param; |
| synchronized (Client.this) { |
| this.id = counter++; |
| } |
| } |
| |
| /** Indicate when the call is complete and the |
| * value or error are available. Notifies by default. */ |
| protected synchronized void callComplete() { |
| this.done = true; |
| notify(); // notify caller |
| } |
| |
| /** Set the exception when there is an error. |
| * Notify the caller the call is done. |
| * |
| * @param error exception thrown by the call; either local or remote |
| */ |
| public synchronized void setException(IOException error) { |
| this.error = error; |
| callComplete(); |
| } |
| |
| /** Set the return value when there is no error. |
| * Notify the caller the call is done. |
| * |
| * @param value return value of the call. |
| */ |
| public synchronized void setValue(Writable value) { |
| this.value = value; |
| callComplete(); |
| } |
| |
| public synchronized Writable getValue() { |
| return value; |
| } |
| } |
| |
| /** Thread that reads responses and notifies callers. Each connection owns a |
| * socket connected to a remote address. Calls are multiplexed through this |
| * socket: responses may be delivered out of order. */ |
| private class Connection extends Thread { |
| private InetSocketAddress server; // server ip:port |
| private String serverPrincipal; // server's krb5 principal name |
| private ConnectionHeader header; // connection header |
| private final ConnectionId remoteId; // connection id |
| private AuthMethod authMethod; // authentication method |
| private boolean useSasl; |
| private Token<? extends TokenIdentifier> token; |
| private SaslRpcClient saslRpcClient; |
| |
| private Socket socket = null; // connected socket |
| private DataInputStream in; |
| private DataOutputStream out; |
| private int rpcTimeout; |
| private int maxIdleTime; //connections will be culled if it was idle for |
| //maxIdleTime msecs |
| private int maxRetries; //the max. no. of retries for socket connections |
| private boolean tcpNoDelay; // if T then disable Nagle's Algorithm |
| private boolean doPing; //do we need to send ping message |
| private int pingInterval; // how often sends ping to the server in msecs |
| |
| // currently active calls |
| private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>(); |
| private AtomicLong lastActivity = new AtomicLong();// last I/O activity time |
| private AtomicBoolean shouldCloseConnection = new AtomicBoolean(); // indicate if the connection is closed |
| private IOException closeException; // close reason |
| |
| public Connection(ConnectionId remoteId) throws IOException { |
| this.remoteId = remoteId; |
| this.server = remoteId.getAddress(); |
| if (server.isUnresolved()) { |
| throw NetUtils.wrapException(remoteId.getAddress().getHostName(), |
| remoteId.getAddress().getPort(), |
| null, |
| 0, |
| new UnknownHostException()); |
| } |
| this.rpcTimeout = remoteId.getRpcTimeout(); |
| this.maxIdleTime = remoteId.getMaxIdleTime(); |
| this.maxRetries = remoteId.getMaxRetries(); |
| this.tcpNoDelay = remoteId.getTcpNoDelay(); |
| this.doPing = remoteId.getDoPing(); |
| this.pingInterval = remoteId.getPingInterval(); |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("The ping interval is " + this.pingInterval + " ms."); |
| } |
| |
| UserGroupInformation ticket = remoteId.getTicket(); |
| Class<?> protocol = remoteId.getProtocol(); |
| this.useSasl = UserGroupInformation.isSecurityEnabled(); |
| if (useSasl && protocol != null) { |
| TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf); |
| if (tokenInfo != null) { |
| TokenSelector<? extends TokenIdentifier> tokenSelector = null; |
| try { |
| tokenSelector = tokenInfo.value().newInstance(); |
| } catch (InstantiationException e) { |
| throw new IOException(e.toString()); |
| } catch (IllegalAccessException e) { |
| throw new IOException(e.toString()); |
| } |
| InetSocketAddress addr = remoteId.getAddress(); |
| token = tokenSelector.selectToken(new Text(addr.getAddress() |
| .getHostAddress() + ":" + addr.getPort()), |
| ticket.getTokens()); |
| } |
| KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); |
| if (krbInfo != null) { |
| serverPrincipal = remoteId.getServerPrincipal(); |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("RPC Server's Kerberos principal name for protocol=" |
| + protocol.getCanonicalName() + " is " + serverPrincipal); |
| } |
| } |
| } |
| |
| if (!useSasl) { |
| authMethod = AuthMethod.SIMPLE; |
| } else if (token != null) { |
| authMethod = AuthMethod.DIGEST; |
| } else { |
| authMethod = AuthMethod.KERBEROS; |
| } |
| |
| header = new ConnectionHeader(protocol == null ? null : protocol |
| .getName(), ticket, authMethod); |
| |
| if (LOG.isDebugEnabled()) |
| LOG.debug("Use " + authMethod + " authentication for protocol " |
| + protocol.getSimpleName()); |
| |
| this.setName("IPC Client (" + socketFactory.hashCode() +") connection to " + |
| remoteId.getAddress().toString() + |
| " from " + ((ticket==null)?"an unknown user":ticket.getUserName())); |
| this.setDaemon(true); |
| } |
| |
| /** Update lastActivity with the current time. */ |
| private void touch() { |
| lastActivity.set(System.currentTimeMillis()); |
| } |
| |
| /** |
| * Add a call to this connection's call queue and notify |
| * a listener; synchronized. |
| * Returns false if called during shutdown. |
| * @param call to add |
| * @return true if the call was added. |
| */ |
| private synchronized boolean addCall(Call call) { |
| if (shouldCloseConnection.get()) |
| return false; |
| calls.put(call.id, call); |
| notify(); |
| return true; |
| } |
| |
| /** This class sends a ping to the remote side when timeout on |
| * reading. If no failure is detected, it retries until at least |
| * a byte is read. |
| */ |
| private class PingInputStream extends FilterInputStream { |
| /* constructor */ |
| protected PingInputStream(InputStream in) { |
| super(in); |
| } |
| |
| /* Process timeout exception |
| * if the connection is not going to be closed or |
| * is not configured to have a RPC timeout, send a ping. |
| * (if rpcTimeout is not set to be 0, then RPC should timeout. |
| * otherwise, throw the timeout exception. |
| */ |
| private void handleTimeout(SocketTimeoutException e) throws IOException { |
| if (shouldCloseConnection.get() || !running.get() || rpcTimeout > 0) { |
| throw e; |
| } else { |
| sendPing(); |
| } |
| } |
| |
| /** Read a byte from the stream. |
| * Send a ping if timeout on read. Retries if no failure is detected |
| * until a byte is read. |
| * @throws IOException for any IO problem other than socket timeout |
| */ |
| public int read() throws IOException { |
| do { |
| try { |
| return super.read(); |
| } catch (SocketTimeoutException e) { |
| handleTimeout(e); |
| } |
| } while (true); |
| } |
| |
| /** Read bytes into a buffer starting from offset <code>off</code> |
| * Send a ping if timeout on read. Retries if no failure is detected |
| * until a byte is read. |
| * |
| * @return the total number of bytes read; -1 if the connection is closed. |
| */ |
| public int read(byte[] buf, int off, int len) throws IOException { |
| do { |
| try { |
| return super.read(buf, off, len); |
| } catch (SocketTimeoutException e) { |
| handleTimeout(e); |
| } |
| } while (true); |
| } |
| } |
| |
| private synchronized void disposeSasl() { |
| if (saslRpcClient != null) { |
| try { |
| saslRpcClient.dispose(); |
| saslRpcClient = null; |
| } catch (IOException ignored) { |
| } |
| } |
| } |
| |
| private synchronized boolean shouldAuthenticateOverKrb() throws IOException { |
| UserGroupInformation loginUser = UserGroupInformation.getLoginUser(); |
| UserGroupInformation currentUser = UserGroupInformation.getCurrentUser(); |
| UserGroupInformation realUser = currentUser.getRealUser(); |
| if (authMethod == AuthMethod.KERBEROS && loginUser != null && |
| // Make sure user logged in using Kerberos either keytab or TGT |
| loginUser.hasKerberosCredentials() && |
| // relogin only in case it is the login user (e.g. JT) |
| // or superuser (like oozie). |
| (loginUser.equals(currentUser) || loginUser.equals(realUser))) { |
| return true; |
| } |
| return false; |
| } |
| |
| private synchronized boolean setupSaslConnection(final InputStream in2, |
| final OutputStream out2) |
| throws IOException { |
| saslRpcClient = new SaslRpcClient(authMethod, token, serverPrincipal); |
| return saslRpcClient.saslConnect(in2, out2); |
| } |
| |
| /** |
| * Update the server address if the address corresponding to the host |
| * name has changed. |
| * |
| * @return true if an addr change was detected. |
| * @throws IOException when the hostname cannot be resolved. |
| */ |
| private synchronized boolean updateAddress() throws IOException { |
| // Do a fresh lookup with the old host name. |
| InetSocketAddress currentAddr = NetUtils.createSocketAddrForHost( |
| server.getHostName(), server.getPort()); |
| |
| if (!server.equals(currentAddr)) { |
| LOG.warn("Address change detected. Old: " + server.toString() + |
| " New: " + currentAddr.toString()); |
| server = currentAddr; |
| return true; |
| } |
| return false; |
| } |
| |
| private synchronized void setupConnection() throws IOException { |
| short ioFailures = 0; |
| short timeoutFailures = 0; |
| while (true) { |
| try { |
| this.socket = socketFactory.createSocket(); |
| this.socket.setTcpNoDelay(tcpNoDelay); |
| |
| /* |
| * Bind the socket to the host specified in the principal name of the |
| * client, to ensure Server matching address of the client connection |
| * to host name in principal passed. |
| */ |
| if (UserGroupInformation.isSecurityEnabled()) { |
| KerberosInfo krbInfo = |
| remoteId.getProtocol().getAnnotation(KerberosInfo.class); |
| if (krbInfo != null && krbInfo.clientPrincipal() != null) { |
| String host = |
| SecurityUtil.getHostFromPrincipal(remoteId.getTicket().getUserName()); |
| |
| // If host name is a valid local address then bind socket to it |
| InetAddress localAddr = NetUtils.getLocalInetAddress(host); |
| if (localAddr != null) { |
| this.socket.bind(new InetSocketAddress(localAddr, 0)); |
| } |
| } |
| } |
| |
| // connection time out is 20s |
| NetUtils.connect(this.socket, server, 20000); |
| if (rpcTimeout > 0) { |
| pingInterval = rpcTimeout; // rpcTimeout overwrites pingInterval |
| } |
| this.socket.setSoTimeout(pingInterval); |
| return; |
| } catch (SocketTimeoutException toe) { |
| /* Check for an address change and update the local reference. |
| * Reset the failure counter if the address was changed |
| */ |
| if (updateAddress()) { |
| timeoutFailures = ioFailures = 0; |
| } |
| /* |
| * The max number of retries is 45, which amounts to 20s*45 = 15 |
| * minutes retries. |
| */ |
| handleConnectionFailure(timeoutFailures++, 45, toe); |
| } catch (IOException ie) { |
| if (updateAddress()) { |
| timeoutFailures = ioFailures = 0; |
| } |
| handleConnectionFailure(ioFailures++, maxRetries, ie); |
| } |
| } |
| } |
| |
| /** |
| * If multiple clients with the same principal try to connect to the same |
| * server at the same time, the server assumes a replay attack is in |
| * progress. This is a feature of kerberos. In order to work around this, |
| * what is done is that the client backs off randomly and tries to initiate |
| * the connection again. The other problem is to do with ticket expiry. To |
| * handle that, a relogin is attempted. |
| */ |
| private synchronized void handleSaslConnectionFailure( |
| final int currRetries, final int maxRetries, final Exception ex, |
| final Random rand, final UserGroupInformation ugi) throws IOException, |
| InterruptedException { |
| ugi.doAs(new PrivilegedExceptionAction<Object>() { |
| public Object run() throws IOException, InterruptedException { |
| final short MAX_BACKOFF = 5000; |
| closeConnection(); |
| disposeSasl(); |
| if (shouldAuthenticateOverKrb()) { |
| if (currRetries < maxRetries) { |
| if(LOG.isDebugEnabled()) { |
| LOG.debug("Exception encountered while connecting to " |
| + "the server : " + ex); |
| } |
| // try re-login |
| if (UserGroupInformation.isLoginKeytabBased()) { |
| UserGroupInformation.getLoginUser().reloginFromKeytab(); |
| } else { |
| UserGroupInformation.getLoginUser().reloginFromTicketCache(); |
| } |
| // have granularity of milliseconds |
| //we are sleeping with the Connection lock held but since this |
| //connection instance is being used for connecting to the server |
| //in question, it is okay |
| Thread.sleep((rand.nextInt(MAX_BACKOFF) + 1)); |
| return null; |
| } else { |
| String msg = "Couldn't setup connection for " |
| + UserGroupInformation.getLoginUser().getUserName() + " to " |
| + serverPrincipal; |
| LOG.warn(msg); |
| throw (IOException) new IOException(msg).initCause(ex); |
| } |
| } else { |
| LOG.warn("Exception encountered while connecting to " |
| + "the server : " + ex); |
| } |
| if (ex instanceof RemoteException) |
| throw (RemoteException) ex; |
| throw new IOException(ex); |
| } |
| }); |
| } |
| |
| |
| /** Connect to the server and set up the I/O streams. It then sends |
| * a header to the server and starts |
| * the connection thread that waits for responses. |
| */ |
| private synchronized void setupIOstreams() throws InterruptedException { |
| if (socket != null || shouldCloseConnection.get()) { |
| return; |
| } |
| try { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("Connecting to "+server); |
| } |
| short numRetries = 0; |
| final short MAX_RETRIES = 5; |
| Random rand = null; |
| while (true) { |
| setupConnection(); |
| InputStream inStream = NetUtils.getInputStream(socket); |
| OutputStream outStream = NetUtils.getOutputStream(socket); |
| writeRpcHeader(outStream); |
| if (useSasl) { |
| final InputStream in2 = inStream; |
| final OutputStream out2 = outStream; |
| UserGroupInformation ticket = remoteId.getTicket(); |
| if (authMethod == AuthMethod.KERBEROS) { |
| if (ticket.getRealUser() != null) { |
| ticket = ticket.getRealUser(); |
| } |
| } |
| boolean continueSasl = false; |
| try { |
| continueSasl = ticket |
| .doAs(new PrivilegedExceptionAction<Boolean>() { |
| @Override |
| public Boolean run() throws IOException { |
| return setupSaslConnection(in2, out2); |
| } |
| }); |
| } catch (Exception ex) { |
| if (rand == null) { |
| rand = new Random(); |
| } |
| handleSaslConnectionFailure(numRetries++, MAX_RETRIES, ex, rand, |
| ticket); |
| continue; |
| } |
| if (continueSasl) { |
| // Sasl connect is successful. Let's set up Sasl i/o streams. |
| inStream = saslRpcClient.getInputStream(inStream); |
| outStream = saslRpcClient.getOutputStream(outStream); |
| } else { |
| // fall back to simple auth because server told us so. |
| authMethod = AuthMethod.SIMPLE; |
| header = new ConnectionHeader(header.getProtocol(), header |
| .getUgi(), authMethod); |
| useSasl = false; |
| } |
| } |
| |
| if (doPing) { |
| this.in = new DataInputStream(new BufferedInputStream( |
| new PingInputStream(inStream))); |
| } else { |
| this.in = new DataInputStream(new BufferedInputStream(inStream)); |
| } |
| this.out = new DataOutputStream(new BufferedOutputStream(outStream)); |
| writeHeader(); |
| |
| // update last activity time |
| touch(); |
| |
| // start the receiver thread after the socket connection has been set |
| // up |
| start(); |
| return; |
| } |
| } catch (Throwable t) { |
| if (t instanceof IOException) { |
| markClosed((IOException)t); |
| } else { |
| markClosed(new IOException("Couldn't set up IO streams", t)); |
| } |
| close(); |
| } |
| } |
| |
| private void closeConnection() { |
| if (socket == null) { |
| return; |
| } |
| // close the current connection |
| try { |
| socket.close(); |
| } catch (IOException e) { |
| LOG.warn("Not able to close a socket", e); |
| } |
| // set socket to null so that the next call to setupIOstreams |
| // can start the process of connect all over again. |
| socket = null; |
| } |
| |
| /* Handle connection failures |
| * |
| * If the current number of retries is equal to the max number of retries, |
| * stop retrying and throw the exception; Otherwise backoff 1 second and |
| * try connecting again. |
| * |
| * This Method is only called from inside setupIOstreams(), which is |
| * synchronized. Hence the sleep is synchronized; the locks will be retained. |
| * |
| * @param curRetries current number of retries |
| * @param maxRetries max number of retries allowed |
| * @param ioe failure reason |
| * @throws IOException if max number of retries is reached |
| */ |
| private void handleConnectionFailure( |
| int curRetries, int maxRetries, IOException ioe) throws IOException { |
| |
| closeConnection(); |
| |
| // throw the exception if the maximum number of retries is reached |
| if (curRetries >= maxRetries) { |
| throw ioe; |
| } |
| |
| // otherwise back off and retry |
| try { |
| Thread.sleep(1000); |
| } catch (InterruptedException ignored) {} |
| |
| LOG.info("Retrying connect to server: " + server + |
| ". Already tried " + curRetries + " time(s)."); |
| } |
| |
| /* Write the RPC header */ |
| private void writeRpcHeader(OutputStream outStream) throws IOException { |
| DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outStream)); |
| // Write out the header, version and authentication method |
| out.write(Server.HEADER.array()); |
| out.write(Server.CURRENT_VERSION); |
| authMethod.write(out); |
| out.flush(); |
| } |
| |
| /* Write the protocol header for each connection |
| * Out is not synchronized because only the first thread does this. |
| */ |
| private void writeHeader() throws IOException { |
| // Write out the ConnectionHeader |
| DataOutputBuffer buf = new DataOutputBuffer(); |
| header.write(buf); |
| |
| // Write out the payload length |
| int bufLen = buf.getLength(); |
| out.writeInt(bufLen); |
| out.write(buf.getData(), 0, bufLen); |
| } |
| |
| /* wait till someone signals us to start reading RPC response or |
| * it is idle too long, it is marked as to be closed, |
| * or the client is marked as not running. |
| * |
| * Return true if it is time to read a response; false otherwise. |
| */ |
| private synchronized boolean waitForWork() { |
| if (calls.isEmpty() && !shouldCloseConnection.get() && running.get()) { |
| long timeout = maxIdleTime- |
| (System.currentTimeMillis()-lastActivity.get()); |
| if (timeout>0) { |
| try { |
| wait(timeout); |
| } catch (InterruptedException e) {} |
| } |
| } |
| |
| if (!calls.isEmpty() && !shouldCloseConnection.get() && running.get()) { |
| return true; |
| } else if (shouldCloseConnection.get()) { |
| return false; |
| } else if (calls.isEmpty()) { // idle connection closed or stopped |
| markClosed(null); |
| return false; |
| } else { // get stopped but there are still pending requests |
| markClosed((IOException)new IOException().initCause( |
| new InterruptedException())); |
| return false; |
| } |
| } |
| |
| public InetSocketAddress getRemoteAddress() { |
| return server; |
| } |
| |
| /* Send a ping to the server if the time elapsed |
| * since last I/O activity is equal to or greater than the ping interval |
| */ |
| private synchronized void sendPing() throws IOException { |
| long curTime = System.currentTimeMillis(); |
| if ( curTime - lastActivity.get() >= pingInterval) { |
| lastActivity.set(curTime); |
| synchronized (out) { |
| out.writeInt(PING_CALL_ID); |
| out.flush(); |
| } |
| } |
| } |
| |
| public void run() { |
| if (LOG.isDebugEnabled()) |
| LOG.debug(getName() + ": starting, having connections " |
| + connections.size()); |
| |
| try { |
| while (waitForWork()) {//wait here for work - read or close connection |
| receiveResponse(); |
| } |
| } catch (Throwable t) { |
| // This truly is unexpected, since we catch IOException in receiveResponse |
| // -- this is only to be really sure that we don't leave a client hanging |
| // forever. |
| LOG.warn("Unexpected error reading responses on connection " + this, t); |
| markClosed(new IOException("Error reading responses", t)); |
| } |
| |
| close(); |
| |
| if (LOG.isDebugEnabled()) |
| LOG.debug(getName() + ": stopped, remaining connections " |
| + connections.size()); |
| } |
| |
| /** Initiates a call by sending the parameter to the remote server. |
| * Note: this is not called from the Connection thread, but by other |
| * threads. |
| */ |
| public void sendParam(Call call) { |
| if (shouldCloseConnection.get()) { |
| return; |
| } |
| |
| DataOutputBuffer d=null; |
| try { |
| synchronized (this.out) { |
| if (LOG.isDebugEnabled()) |
| LOG.debug(getName() + " sending #" + call.id); |
| |
| //for serializing the |
| //data to be written |
| d = new DataOutputBuffer(); |
| d.writeInt(call.id); |
| call.param.write(d); |
| byte[] data = d.getData(); |
| int dataLength = d.getLength(); |
| out.writeInt(dataLength); //first put the data length |
| out.write(data, 0, dataLength);//write the data |
| out.flush(); |
| } |
| } catch(IOException e) { |
| markClosed(e); |
| } finally { |
| //the buffer is just an in-memory buffer, but it is still polite to |
| // close early |
| IOUtils.closeStream(d); |
| } |
| } |
| |
| /* Receive a response. |
| * Because only one receiver, so no synchronization on in. |
| */ |
| private void receiveResponse() { |
| if (shouldCloseConnection.get()) { |
| return; |
| } |
| touch(); |
| |
| try { |
| int id = in.readInt(); // try to read an id |
| |
| if (LOG.isDebugEnabled()) |
| LOG.debug(getName() + " got value #" + id); |
| |
| Call call = calls.get(id); |
| |
| int state = in.readInt(); // read call status |
| if (state == Status.SUCCESS.state) { |
| Writable value = ReflectionUtils.newInstance(valueClass, conf); |
| value.readFields(in); // read value |
| call.setValue(value); |
| calls.remove(id); |
| } else if (state == Status.ERROR.state) { |
| call.setException(new RemoteException(WritableUtils.readString(in), |
| WritableUtils.readString(in))); |
| calls.remove(id); |
| } else if (state == Status.FATAL.state) { |
| // Close the connection |
| markClosed(new RemoteException(WritableUtils.readString(in), |
| WritableUtils.readString(in))); |
| } |
| } catch (IOException e) { |
| markClosed(e); |
| } |
| } |
| |
| private synchronized void markClosed(IOException e) { |
| if (shouldCloseConnection.compareAndSet(false, true)) { |
| closeException = e; |
| notifyAll(); |
| } |
| } |
| |
| /** Close the connection. */ |
| private synchronized void close() { |
| if (!shouldCloseConnection.get()) { |
| LOG.error("The connection is not in the closed state"); |
| return; |
| } |
| |
| // release the resources |
| // first thing to do;take the connection out of the connection list |
| synchronized (connections) { |
| if (connections.get(remoteId) == this) { |
| connections.remove(remoteId); |
| } |
| } |
| |
| // close the streams and therefore the socket |
| IOUtils.closeStream(out); |
| IOUtils.closeStream(in); |
| disposeSasl(); |
| |
| // clean up all calls |
| if (closeException == null) { |
| if (!calls.isEmpty()) { |
| LOG.warn( |
| "A connection is closed for no cause and calls are not empty"); |
| |
| // clean up calls anyway |
| closeException = new IOException("Unexpected closed connection"); |
| cleanupCalls(); |
| } |
| } else { |
| // log the info |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("closing ipc connection to " + server + ": " + |
| closeException.getMessage(),closeException); |
| } |
| |
| // cleanup calls |
| cleanupCalls(); |
| } |
| if (LOG.isDebugEnabled()) |
| LOG.debug(getName() + ": closed"); |
| } |
| |
| /* Cleanup all calls and mark them as done */ |
| private void cleanupCalls() { |
| Iterator<Entry<Integer, Call>> itor = calls.entrySet().iterator() ; |
| while (itor.hasNext()) { |
| Call c = itor.next().getValue(); |
| c.setException(closeException); // local exception |
| itor.remove(); |
| } |
| } |
| } |
| |
| /** Call implementation used for parallel calls. */ |
| private class ParallelCall extends Call { |
| private ParallelResults results; |
| private int index; |
| |
| public ParallelCall(Writable param, ParallelResults results, int index) { |
| super(param); |
| this.results = results; |
| this.index = index; |
| } |
| |
| /** Deliver result to result collector. */ |
| protected void callComplete() { |
| results.callComplete(this); |
| } |
| } |
| |
| /** Result collector for parallel calls. */ |
| private static class ParallelResults { |
| private Writable[] values; |
| private int size; |
| private int count; |
| |
| public ParallelResults(int size) { |
| this.values = new Writable[size]; |
| this.size = size; |
| } |
| |
| /** Collect a result. */ |
| public synchronized void callComplete(ParallelCall call) { |
| values[call.index] = call.getValue(); // store the value |
| count++; // count it |
| if (count == size) // if all values are in |
| notify(); // then notify waiting caller |
| } |
| } |
| |
/**
 * Construct an IPC client whose call values are of the given
 * {@link Writable} class.
 *
 * @param valueClass class instantiated to hold each call's returned value
 * @param conf configuration used by this client and its connections
 * @param factory socket factory used to open connections
 */
public Client(Class<? extends Writable> valueClass, Configuration conf,
    SocketFactory factory) {
  this.socketFactory = factory;
  this.conf = conf;
  this.valueClass = valueClass;
}
| |
/**
 * Construct an IPC client that uses the default socket factory obtained
 * from {@code NetUtils.getDefaultSocketFactory(conf)}.
 *
 * @param valueClass class instantiated to hold each call's returned value
 * @param conf configuration used by this client; also the source of the
 *        default socket factory
 */
public Client(Class<? extends Writable> valueClass, Configuration conf) {
  this(valueClass, conf,
      NetUtils.getDefaultSocketFactory(conf));
}
| |
| /** Return the socket factory of this client |
| * |
| * @return this client's socket factory |
| */ |
| SocketFactory getSocketFactory() { |
| return socketFactory; |
| } |
| |
| /** Stop all threads related to this client. No further calls may be made |
| * using this client. */ |
| public void stop() { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("Stopping client"); |
| } |
| |
| if (!running.compareAndSet(true, false)) { |
| return; |
| } |
| |
| // wake up all connections |
| synchronized (connections) { |
| for (Connection conn : connections.values()) { |
| conn.interrupt(); |
| } |
| } |
| |
| // wait until all connections are closed |
| while (!connections.isEmpty()) { |
| try { |
| Thread.sleep(100); |
| } catch (InterruptedException e) { |
| } |
| } |
| } |
| |
| /** Make a call, passing <code>param</code>, to the IPC server running at |
| * <code>address</code>, returning the value. Throws exceptions if there are |
| * network problems or if the remote code threw an exception. |
| * @deprecated Use {@link #call(Writable, ConnectionId)} instead |
| */ |
| @Deprecated |
| public Writable call(Writable param, InetSocketAddress address) |
| throws InterruptedException, IOException { |
| return call(param, address, null); |
| } |
| |
| /** Make a call, passing <code>param</code>, to the IPC server running at |
| * <code>address</code> with the <code>ticket</code> credentials, returning |
| * the value. |
| * Throws exceptions if there are network problems or if the remote code |
| * threw an exception. |
| * @deprecated Use {@link #call(Writable, ConnectionId)} instead |
| */ |
| @Deprecated |
| public Writable call(Writable param, InetSocketAddress addr, |
| UserGroupInformation ticket) |
| throws InterruptedException, IOException { |
| ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0, |
| conf); |
| return call(param, remoteId); |
| } |
| |
| /** Make a call, passing <code>param</code>, to the IPC server running at |
| * <code>address</code> which is servicing the <code>protocol</code> protocol, |
| * with the <code>ticket</code> credentials and <code>rpcTimeout</code> as |
| * timeout, returning the value. |
| * Throws exceptions if there are network problems or if the remote code |
| * threw an exception. |
| * @deprecated Use {@link #call(Writable, ConnectionId)} instead |
| */ |
| @Deprecated |
| public Writable call(Writable param, InetSocketAddress addr, |
| Class<?> protocol, UserGroupInformation ticket, |
| int rpcTimeout) |
| throws InterruptedException, IOException { |
| ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol, |
| ticket, rpcTimeout, conf); |
| return call(param, remoteId); |
| } |
| |
| /** |
| * Make a call, passing <code>param</code>, to the IPC server running at |
| * <code>address</code> which is servicing the <code>protocol</code> protocol, |
| * with the <code>ticket</code> credentials, <code>rpcTimeout</code> as |
| * timeout and <code>conf</code> as conf for this connection, returning the |
| * value. Throws exceptions if there are network problems or if the remote |
| * code threw an exception. |
| */ |
| public Writable call(Writable param, InetSocketAddress addr, |
| Class<?> protocol, UserGroupInformation ticket, |
| int rpcTimeout, Configuration conf) |
| throws InterruptedException, IOException { |
| ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol, |
| ticket, rpcTimeout, conf); |
| return call(param, remoteId); |
| } |
| |
| /** Make a call, passing <code>param</code>, to the IPC server defined by |
| * <code>remoteId</code>, returning the value. |
| * Throws exceptions if there are network problems or if the remote code |
| * threw an exception. */ |
| public Writable call(Writable param, ConnectionId remoteId) |
| throws InterruptedException, IOException { |
| Call call = new Call(param); |
| Connection connection = getConnection(remoteId, call); |
| connection.sendParam(call); // send the parameter |
| boolean interrupted = false; |
| synchronized (call) { |
| while (!call.done) { |
| try { |
| call.wait(); // wait for the result |
| } catch (InterruptedException ie) { |
| // save the fact that we were interrupted |
| interrupted = true; |
| } |
| } |
| |
| if (interrupted) { |
| // set the interrupt flag now that we are done waiting |
| Thread.currentThread().interrupt(); |
| } |
| |
| if (call.error != null) { |
| if (call.error instanceof RemoteException) { |
| call.error.fillInStackTrace(); |
| throw call.error; |
| } else { // local exception |
| InetSocketAddress address = remoteId.getAddress(); |
| throw NetUtils.wrapException(address.getHostName(), |
| address.getPort(), |
| NetUtils.getHostname(), |
| 0, |
| call.error); |
| } |
| } else { |
| return call.value; |
| } |
| } |
| } |
| |
| /** |
| * @deprecated Use {@link #call(Writable[], InetSocketAddress[], |
| * Class, UserGroupInformation, Configuration)} instead |
| */ |
| @Deprecated |
| public Writable[] call(Writable[] params, InetSocketAddress[] addresses) |
| throws IOException, InterruptedException { |
| return call(params, addresses, null, null, conf); |
| } |
| |
| /** |
| * @deprecated Use {@link #call(Writable[], InetSocketAddress[], |
| * Class, UserGroupInformation, Configuration)} instead |
| */ |
| @Deprecated |
| public Writable[] call(Writable[] params, InetSocketAddress[] addresses, |
| Class<?> protocol, UserGroupInformation ticket) |
| throws IOException, InterruptedException { |
| return call(params, addresses, protocol, ticket, conf); |
| } |
| |
| |
| /** Makes a set of calls in parallel. Each parameter is sent to the |
| * corresponding address. When all values are available, or have timed out |
| * or errored, the collected results are returned in an array. The array |
| * contains nulls for calls that timed out or errored. */ |
| public Writable[] call(Writable[] params, InetSocketAddress[] addresses, |
| Class<?> protocol, UserGroupInformation ticket, Configuration conf) |
| throws IOException, InterruptedException { |
| if (addresses.length == 0) return new Writable[0]; |
| |
| ParallelResults results = new ParallelResults(params.length); |
| synchronized (results) { |
| for (int i = 0; i < params.length; i++) { |
| ParallelCall call = new ParallelCall(params[i], results, i); |
| try { |
| ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i], |
| protocol, ticket, 0, conf); |
| Connection connection = getConnection(remoteId, call); |
| connection.sendParam(call); // send each parameter |
| } catch (IOException e) { |
| // log errors |
| LOG.info("Calling "+addresses[i]+" caught: " + |
| e.getMessage(),e); |
| results.size--; // wait for one fewer result |
| } |
| } |
| while (results.count != results.size) { |
| try { |
| results.wait(); // wait for all results |
| } catch (InterruptedException e) {} |
| } |
| |
| return results.values; |
| } |
| } |
| |
| // for unit testing only |
| @InterfaceAudience.Private |
| @InterfaceStability.Unstable |
| Set<ConnectionId> getConnectionIds() { |
| synchronized (connections) { |
| return connections.keySet(); |
| } |
| } |
| |
| /** Get a connection from the pool, or create a new one and add it to the |
| * pool. Connections to a given ConnectionId are reused. */ |
| private Connection getConnection(ConnectionId remoteId, |
| Call call) |
| throws IOException, InterruptedException { |
| if (!running.get()) { |
| // the client is stopped |
| throw new IOException("The client is stopped"); |
| } |
| Connection connection; |
| /* we could avoid this allocation for each RPC by having a |
| * connectionsId object and with set() method. We need to manage the |
| * refs for keys in HashMap properly. For now its ok. |
| */ |
| do { |
| synchronized (connections) { |
| connection = connections.get(remoteId); |
| if (connection == null) { |
| connection = new Connection(remoteId); |
| connections.put(remoteId, connection); |
| } |
| } |
| } while (!connection.addCall(call)); |
| |
| //we don't invoke the method below inside "synchronized (connections)" |
| //block above. The reason for that is if the server happens to be slow, |
| //it will take longer to establish a connection and that will slow the |
| //entire system down. |
| connection.setupIOstreams(); |
| return connection; |
| } |
| |
| /** |
| * This class holds the address and the user ticket. The client connections |
| * to servers are uniquely identified by <remoteAddress, protocol, ticket> |
| */ |
| @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) |
| @InterfaceStability.Evolving |
| public static class ConnectionId { |
| InetSocketAddress address; |
| UserGroupInformation ticket; |
| Class<?> protocol; |
| private static final int PRIME = 16777619; |
| private int rpcTimeout; |
| private String serverPrincipal; |
| private int maxIdleTime; //connections will be culled if it was idle for |
| //maxIdleTime msecs |
| private int maxRetries; //the max. no. of retries for socket connections |
| private boolean tcpNoDelay; // if T then disable Nagle's Algorithm |
| private boolean doPing; //do we need to send ping message |
| private int pingInterval; // how often sends ping to the server in msecs |
| |
| ConnectionId(InetSocketAddress address, Class<?> protocol, |
| UserGroupInformation ticket, int rpcTimeout, |
| String serverPrincipal, int maxIdleTime, |
| int maxRetries, boolean tcpNoDelay, |
| boolean doPing, int pingInterval) { |
| this.protocol = protocol; |
| this.address = address; |
| this.ticket = ticket; |
| this.rpcTimeout = rpcTimeout; |
| this.serverPrincipal = serverPrincipal; |
| this.maxIdleTime = maxIdleTime; |
| this.maxRetries = maxRetries; |
| this.tcpNoDelay = tcpNoDelay; |
| this.doPing = doPing; |
| this.pingInterval = pingInterval; |
| } |
| |
| InetSocketAddress getAddress() { |
| return address; |
| } |
| |
| Class<?> getProtocol() { |
| return protocol; |
| } |
| |
| UserGroupInformation getTicket() { |
| return ticket; |
| } |
| |
| private int getRpcTimeout() { |
| return rpcTimeout; |
| } |
| |
| String getServerPrincipal() { |
| return serverPrincipal; |
| } |
| |
| int getMaxIdleTime() { |
| return maxIdleTime; |
| } |
| |
| int getMaxRetries() { |
| return maxRetries; |
| } |
| |
| boolean getTcpNoDelay() { |
| return tcpNoDelay; |
| } |
| |
| boolean getDoPing() { |
| return doPing; |
| } |
| |
| int getPingInterval() { |
| return pingInterval; |
| } |
| |
| /** |
| * Returns a ConnectionId object. |
| * @param addr Remote address for the connection. |
| * @param protocol Protocol for RPC. |
| * @param ticket UGI |
| * @param rpcTimeout timeout |
| * @param conf Configuration object |
| * @return A ConnectionId instance |
| * @throws IOException |
| */ |
| public static ConnectionId getConnectionId(InetSocketAddress addr, |
| Class<?> protocol, UserGroupInformation ticket, int rpcTimeout, |
| Configuration conf) throws IOException { |
| String remotePrincipal = getRemotePrincipal(conf, addr, protocol); |
| boolean doPing = |
| conf.getBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true); |
| return new ConnectionId(addr, protocol, ticket, |
| rpcTimeout, remotePrincipal, |
| conf.getInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_KEY, |
| CommonConfigurationKeysPublic.IPC_CLIENT_CONNECTION_MAXIDLETIME_DEFAULT), |
| conf.getInt(CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_KEY, |
| CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT), |
| conf.getBoolean(CommonConfigurationKeysPublic.IPC_CLIENT_TCPNODELAY_KEY, |
| CommonConfigurationKeysPublic.IPC_CLIENT_TCPNODELAY_DEFAULT), |
| doPing, |
| (doPing ? Client.getPingInterval(conf) : 0)); |
| } |
| |
| private static String getRemotePrincipal(Configuration conf, |
| InetSocketAddress address, Class<?> protocol) throws IOException { |
| if (!UserGroupInformation.isSecurityEnabled() || protocol == null) { |
| return null; |
| } |
| KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); |
| if (krbInfo != null) { |
| String serverKey = krbInfo.serverPrincipal(); |
| if (serverKey == null) { |
| throw new IOException( |
| "Can't obtain server Kerberos config key from protocol=" |
| + protocol.getCanonicalName()); |
| } |
| return SecurityUtil.getServerPrincipal(conf.get(serverKey), address |
| .getAddress()); |
| } |
| return null; |
| } |
| |
| static boolean isEqual(Object a, Object b) { |
| return a == null ? b == null : a.equals(b); |
| } |
| |
| @Override |
| public boolean equals(Object obj) { |
| if (obj == this) { |
| return true; |
| } |
| if (obj instanceof ConnectionId) { |
| ConnectionId that = (ConnectionId) obj; |
| return isEqual(this.address, that.address) |
| && this.doPing == that.doPing |
| && this.maxIdleTime == that.maxIdleTime |
| && this.maxRetries == that.maxRetries |
| && this.pingInterval == that.pingInterval |
| && isEqual(this.protocol, that.protocol) |
| && this.rpcTimeout == that.rpcTimeout |
| && isEqual(this.serverPrincipal, that.serverPrincipal) |
| && this.tcpNoDelay == that.tcpNoDelay |
| && isEqual(this.ticket, that.ticket); |
| } |
| return false; |
| } |
| |
| @Override |
| public int hashCode() { |
| int result = 1; |
| result = PRIME * result + ((address == null) ? 0 : address.hashCode()); |
| result = PRIME * result + (doPing ? 1231 : 1237); |
| result = PRIME * result + maxIdleTime; |
| result = PRIME * result + maxRetries; |
| result = PRIME * result + pingInterval; |
| result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode()); |
| result = PRIME * result + rpcTimeout; |
| result = PRIME * result |
| + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode()); |
| result = PRIME * result + (tcpNoDelay ? 1231 : 1237); |
| result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode()); |
| return result; |
| } |
| } |
| } |