HADOOP-18365. Update the remote address when a change is detected (#4692)

Avoid reconnecting to the old address after detecting that the address has been updated.

* Fix Checkstyle line length violation
* Keep ConnectionId as Immutable for map key

The ConnectionId is used as a key in the connections map, and updating the remoteId caused problems with the cleanup of connections when the removeMethod was used.

Instead of updating the address within the remoteId, use the removeMethod to cleanup references to the current identifier and then replace it with a new identifier using the updated address.

* Use final to protect immutable ConnectionId

Mark non-test fields as private and final, and add a missing accessor.

* Use a stable hashCode to allow safe IP addr changes
* Add test that updated address is used

Once the address has been updated, it should be used in future calls.  Check to ensure that a second request succeeds and that it uses the existing updated address instead of having to re-resolve.

Signed-off-by: Nick Dimiduk <ndimiduk@apache.org>
Signed-off-by: sokui
Signed-off-by: XanderZu
Signed-off-by: stack <stack@apache.org>
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
index 534824e..20fc9ef 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
@@ -419,7 +419,7 @@
    * socket: responses may be delivered out of order. */
   private class Connection extends Thread {
     private InetSocketAddress server;             // server ip:port
-    private final ConnectionId remoteId;                // connection id
+    private final ConnectionId remoteId;          // connection id
     private AuthMethod authMethod; // authentication method
     private AuthProtocol authProtocol;
     private int serviceClass;
@@ -645,6 +645,9 @@
         LOG.warn("Address change detected. Old: " + server.toString() +
                                  " New: " + currentAddr.toString());
         server = currentAddr;
+        // Update the remote address so that reconnections are with the updated address.
+        // This avoids thrashing.
+        remoteId.setAddress(currentAddr);
         UserGroupInformation ticket = remoteId.getTicket();
         this.setName("IPC Client (" + socketFactory.hashCode()
             + ") connection to " + server.toString() + " from "
@@ -1700,9 +1703,9 @@
   @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"})
   @InterfaceStability.Evolving
   public static class ConnectionId {
-    InetSocketAddress address;
-    UserGroupInformation ticket;
-    final Class<?> protocol;
+    private InetSocketAddress address;
+    private final UserGroupInformation ticket;
+    private final Class<?> protocol;
     private static final int PRIME = 16777619;
     private final int rpcTimeout;
     private final int maxIdleTime; //connections will be culled if it was idle for 
@@ -1717,7 +1720,7 @@
     private final int pingInterval; // how often sends ping to the server in msecs
     private String saslQop; // here for testing
     private final Configuration conf; // used to get the expected kerberos principal name
-    
+
     public ConnectionId(InetSocketAddress address, Class<?> protocol,
                  UserGroupInformation ticket, int rpcTimeout,
                  RetryPolicy connectionRetryPolicy, Configuration conf) {
@@ -1753,7 +1756,28 @@
     InetSocketAddress getAddress() {
       return address;
     }
-    
+
+    /**
+     * This is used to update the remote address when an address change is detected.  This method
+     * ensures that the {@link #hashCode()} won't change.
+     *
+     * @param address the updated address
+     * @throws IllegalArgumentException if the hostname or port doesn't match
+     * @see Connection#updateAddress()
+     */
+    void setAddress(InetSocketAddress address) {
+      if (!Objects.equals(this.address.getHostName(), address.getHostName())) {
+        throw new IllegalArgumentException("Hostname must match: " + this.address + " vs "
+            + address);
+      }
+      if (this.address.getPort() != address.getPort()) {
+        throw new IllegalArgumentException("Port must match: " + this.address + " vs " + address);
+      }
+
+      this.address = address;
+    }
+
+
     Class<?> getProtocol() {
       return protocol;
     }
@@ -1864,7 +1888,11 @@
     @Override
     public int hashCode() {
       int result = connectionRetryPolicy.hashCode();
-      result = PRIME * result + ((address == null) ? 0 : address.hashCode());
+      // We calculate based on the host name and port without the IP address, since the hashCode
+      // must be stable even if the IP address is updated.
+      result = PRIME * result + ((address == null || address.getHostName() == null) ? 0 :
+          address.getHostName().hashCode());
+      result = PRIME * result + ((address == null) ? 0 : address.getPort());
       result = PRIME * result + (doPing ? 1231 : 1237);
       result = PRIME * result + maxIdleTime;
       result = PRIME * result + pingInterval;
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java
index 2a19ad2..3e4ee70 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/WritableRpcEngine.java
@@ -323,7 +323,7 @@
       Client.ConnectionId connId, Configuration conf, SocketFactory factory)
       throws IOException {
     return getProxy(protocol, clientVersion, connId.getAddress(),
-        connId.ticket, conf, factory, connId.getRpcTimeout(),
+        connId.getTicket(), conf, factory, connId.getRpcTimeout(),
         connId.getRetryPolicy(), null, null);
   }
 
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
index 95ff302..1e78079 100644
--- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
@@ -18,6 +18,7 @@
 
 package org.apache.hadoop.ipc;
 
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -93,6 +94,7 @@
 import org.apache.hadoop.test.LambdaTestUtils;
 import org.apache.hadoop.test.Whitebox;
 import org.apache.hadoop.util.StringUtils;
+import org.assertj.core.api.Condition;
 import org.junit.Assert;
 import org.junit.Assume;
 import org.junit.Before;
@@ -815,6 +817,81 @@
     }
   }
 
+  /**
+   * The {@link ConnectionId#hashCode} has to be stable despite updates that occur as the the
+   * address evolves over time.  The {@link ConnectionId} is used as a primary key in maps, so
+   * its hashCode can't change.
+   *
+   * @throws IOException if there is a client or server failure
+   */
+  @Test
+  public void testStableHashCode() throws IOException {
+    Server server = new TestServer(5, false);
+    try {
+      server.start();
+
+      // Leave host unresolved to start. Use "localhost" as opposed
+      // to local IP from NetUtils.getConnectAddress(server) to force
+      // resolution later
+      InetSocketAddress unresolvedAddr = InetSocketAddress.createUnresolved(
+          "localhost", NetUtils.getConnectAddress(server).getPort());
+
+      // Setup: Create a ConnectionID using an unresolved address, and get it's hashCode to serve
+      // as a point of comparison.
+      int rpcTimeout = MIN_SLEEP_TIME * 2;
+      final ConnectionId remoteId = getConnectionId(unresolvedAddr, rpcTimeout, conf);
+      int expected = remoteId.hashCode();
+
+      // Start client
+      Client.setConnectTimeout(conf, 100);
+      Client client = new Client(LongWritable.class, conf);
+      try {
+        // Test: Call should re-resolve host and succeed
+        LongWritable param = new LongWritable(RANDOM.nextLong());
+        client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId,
+            RPC.RPC_SERVICE_CLASS_DEFAULT, null);
+        int actual = remoteId.hashCode();
+
+        // Verify: The hashCode should match, although the InetAddress is different since it has
+        // now been resolved
+        assertThat(remoteId.getAddress()).isNotEqualTo(unresolvedAddr);
+        assertThat(remoteId.getAddress().getHostName()).isEqualTo(unresolvedAddr.getHostName());
+        assertThat(remoteId.hashCode()).isEqualTo(expected);
+
+        // Test: Call should succeed without having to re-resolve
+        InetSocketAddress expectedSocketAddress = remoteId.getAddress();
+        param = new LongWritable(RANDOM.nextLong());
+        client.call(RPC.RpcKind.RPC_BUILTIN, param, remoteId,
+            RPC.RPC_SERVICE_CLASS_DEFAULT, null);
+
+        // Verify: The same instance of the InetSocketAddress has been used to make the second
+        // call
+        assertThat(remoteId.getAddress()).isSameAs(expectedSocketAddress);
+
+        // Verify: The hashCode is protected against updates to the host name
+        String hostName = InetAddress.getLocalHost().getHostName();
+        InetSocketAddress mismatchedHostName = NetUtils.createSocketAddr(
+            InetAddress.getLocalHost().getHostName(),
+            remoteId.getAddress().getPort());
+        assertThatExceptionOfType(IllegalArgumentException.class)
+            .isThrownBy(() -> remoteId.setAddress(mismatchedHostName))
+            .withMessageStartingWith("Hostname must match");
+
+        // Verify: The hashCode is protected against updates to the port
+        InetSocketAddress mismatchedPort = NetUtils.createSocketAddr(
+            remoteId.getAddress().getHostName(),
+            remoteId.getAddress().getPort() + 1);
+        assertThatExceptionOfType(IllegalArgumentException.class)
+            .isThrownBy(() -> remoteId.setAddress(mismatchedPort))
+            .withMessageStartingWith("Port must match");
+      } finally {
+        client.stop();
+      }
+    } finally {
+      server.stop();
+    }
+  }
+
   @Test(timeout=60000)
   public void testIpcFlakyHostResolution() throws IOException {
     // start server