[STORM-3679] Fix the misuse of nodeId as hostname in LoadAwareShuffleGrouping (#3313)

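LoadAwareShuffleGrouping previously built its source NodeInfo from
context.getThisWorkerHost(), which, despite its name, returned the assignment
id (the supervisor node id), and it then fed NodeInfo.get_node() values
straight into DNSToSwitchMapping.resolve(), which expects resolvable
hostnames. Condensed from the removed lines below:

    // before: node ids went into the rack lookup as if they were hostnames
    sourceNodeInfo = new NodeInfo(context.getThisWorkerHost(),          // actually a node id
                                  Sets.newHashSet((long) context.getThisWorkerPort()));
    hosts.add(taskToNodePort.get(task).get_node());                     // node id, not a hostname
    dnsToSwitchMapping.resolve(new ArrayList<>(hosts));                 // rack lookup on node ids

The fix renames the accessor to getAssignmentId(), threads a nodeId-to-hostname
map (cachedNodeToHost, kept fresh by WorkerState) through WorkerTopologyContext,
and translates node ids to hostnames before any rack resolution.
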
diff --git a/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java b/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
index 1bc5970..3110696 100644
--- a/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
+++ b/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
@@ -120,6 +120,8 @@
     final ConcurrentMap<String, Long> blobToLastKnownVersion;
     final ReentrantReadWriteLock endpointSocketLock;
     final AtomicReference<Map<Integer, NodeInfo>> cachedTaskToNodePort;
+    // cachedNodeToHost can be temporarily out of sync with cachedTaskToNodePort
+    final AtomicReference<Map<String, String>> cachedNodeToHost;
     final AtomicReference<Map<NodeInfo, IConnection>> cachedNodeToPortSocket;
     // executor id is in form [start_task_id end_task_id]
     final Map<List<Long>, JCQueue> executorReceiveQueueMap;
@@ -214,6 +216,7 @@
         this.endpointSocketLock = new ReentrantReadWriteLock();
         this.cachedNodeToPortSocket = new AtomicReference<>(new HashMap<>());
         this.cachedTaskToNodePort = new AtomicReference<>(new HashMap<>());
+        this.cachedNodeToHost = new AtomicReference<>(new HashMap<>());
         this.suicideCallback = Utils.mkSuicideFn();
         this.uptime = Utils.makeUptimeComputer();
         this.defaultSharedResources = makeDefaultResources();
@@ -426,9 +429,9 @@
             }
         }
 
-        Set<NodeInfo> currentConnections = cachedNodeToPortSocket.get().keySet();
-        Set<NodeInfo> newConnections = Sets.difference(neededConnections, currentConnections);
-        Set<NodeInfo> removeConnections = Sets.difference(currentConnections, neededConnections);
+        final Set<NodeInfo> currentConnections = cachedNodeToPortSocket.get().keySet();
+        final Set<NodeInfo> newConnections = Sets.difference(neededConnections, currentConnections);
+        final Set<NodeInfo> removeConnections = Sets.difference(currentConnections, neededConnections);
 
         Map<String, String> nodeHost = assignment != null ? assignment.get_node_host() : null;
         // Add new connections atomically
@@ -446,7 +449,6 @@
             return next;
         });
 
-
         try {
             endpointSocketLock.writeLock().lock();
             cachedTaskToNodePort.set(newTaskToNodePort);
@@ -454,6 +456,13 @@
             endpointSocketLock.writeLock().unlock();
         }
 
+        // It is okay for cachedNodeToHost to be temporarily out of sync with cachedTaskToNodePort
+        if (nodeHost != null) {
+            cachedNodeToHost.set(nodeHost);
+        } else {
+            cachedNodeToHost.set(new HashMap<>());
+        }
+
         for (NodeInfo nodeInfo : removeConnections) {
             cachedNodeToPortSocket.get().get(nodeInfo).close();
         }
@@ -610,7 +619,7 @@
             return new WorkerTopologyContext(systemTopology, topologyConf, taskToComponent, componentToSortedTasks,
                                              componentToStreamToFields, topologyId, codeDir, pidDir, port, localTaskIds,
                                              defaultSharedResources,
-                                             userSharedResources, cachedTaskToNodePort, assignmentId);
+                                             userSharedResources, cachedTaskToNodePort, assignmentId, cachedNodeToHost);
         } catch (IOException e) {
             throw Utils.wrapInRuntime(e);
         }
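
WorkerState refreshes both caches by swapping whole replacement maps into
AtomicReferences, so readers never observe a half-updated map, only a possibly
stale one. A minimal sketch of that snapshot-swap pattern; the wrapper class
NodeHostCache is hypothetical and not part of this change:

    import java.util.HashMap;
    import java.util.Map;
    import java.util.concurrent.atomic.AtomicReference;

    public class NodeHostCache {
        private final AtomicReference<Map<String, String>> cachedNodeToHost =
            new AtomicReference<>(new HashMap<>());

        // Mirrors the refresh above: fall back to an empty map so readers
        // never see null.
        public void refresh(Map<String, String> nodeHost) {
            cachedNodeToHost.set(nodeHost != null ? nodeHost : new HashMap<>());
        }

        public String hostOf(String nodeId) {
            // One get() per lookup pass pins a single consistent snapshot.
            return cachedNodeToHost.get().get(nodeId);
        }
    }
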
diff --git a/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java b/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
index f969ade..15b690d 100644
--- a/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
+++ b/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
@@ -51,6 +51,7 @@
     private NodeInfo sourceNodeInfo;
     private List<Integer> targetTasks;
     private AtomicReference<Map<Integer, NodeInfo>> taskToNodePort;
+    private AtomicReference<Map<String, String>> nodeToHost;
     private Map<String, Object> conf;
     private DNSToSwitchMapping dnsToSwitchMapping;
     private Map<LocalityScope, List<Integer>> localityGroup;
@@ -60,8 +61,9 @@
     @Override
     public void prepare(WorkerTopologyContext context, GlobalStreamId stream, List<Integer> targetTasks) {
         random = new Random();
-        sourceNodeInfo = new NodeInfo(context.getThisWorkerHost(), Sets.newHashSet((long) context.getThisWorkerPort()));
+        sourceNodeInfo = new NodeInfo(context.getAssignmentId(), Sets.newHashSet((long) context.getThisWorkerPort()));
         taskToNodePort = context.getTaskToNodePort();
+        nodeToHost = context.getNodeToHost();
         this.targetTasks = targetTasks;
         capacity = targetTasks.size() == 1 ? 1 : Math.max(1000, targetTasks.size() * 5);
         conf = context.getConf();
@@ -110,13 +112,18 @@
     }
 
     private void refreshLocalityGroup() {
+        // taskToNodePort and nodeToHost might be out of sync while WorkerState refreshes them,
+        // but this is okay since it only causes a temporary misjudgement of the LocalityScope
         Map<Integer, NodeInfo> cachedTaskToNodePort = taskToNodePort.get();
-        Map<String, String> hostToRack = getHostToRackMapping(cachedTaskToNodePort);
+        Map<String, String> cachedNodeToHost = nodeToHost.get();
+
+        Map<String, String> hostToRack = getHostToRackMapping(cachedTaskToNodePort, cachedNodeToHost);
 
         localityGroup.values().stream().forEach(v -> v.clear());
 
         for (int target : targetTasks) {
-            LocalityScope scope = calculateScope(cachedTaskToNodePort, hostToRack, target);
+            LocalityScope scope = calculateScope(cachedTaskToNodePort, cachedNodeToHost, hostToRack, target);
+            LOG.debug("targetTask {} is in LocalityScope {}", target, scope);
             if (!localityGroup.containsKey(scope)) {
                 localityGroup.put(scope, new ArrayList<>());
             }
@@ -249,41 +256,55 @@
         arr[j] = tmp;
     }
 
-    private LocalityScope calculateScope(Map<Integer, NodeInfo> taskToNodePort, Map<String, String> hostToRack, int target) {
+    private LocalityScope calculateScope(Map<Integer, NodeInfo> taskToNodePort, Map<String, String> nodeToHost,
+                                         Map<String, String> hostToRack, int target) {
         NodeInfo targetNodeInfo = taskToNodePort.get(target);
 
         if (targetNodeInfo == null) {
             return LocalityScope.EVERYTHING;
         }
 
-        String sourceRack = hostToRack.get(sourceNodeInfo.get_node());
-        String targetRack = hostToRack.get(targetNodeInfo.get_node());
-
-        if (sourceRack != null && targetRack != null && sourceRack.equals(targetRack)) {
-            if (sourceNodeInfo.get_node().equals(targetNodeInfo.get_node())) {
-                if (sourceNodeInfo.get_port().equals(targetNodeInfo.get_port())) {
-                    return LocalityScope.WORKER_LOCAL;
-                }
-                return LocalityScope.HOST_LOCAL;
+        if (sourceNodeInfo.get_node().equals(targetNodeInfo.get_node())) {
+            if (sourceNodeInfo.get_port().equals(targetNodeInfo.get_port())) {
+                return LocalityScope.WORKER_LOCAL;
             }
-            return LocalityScope.RACK_LOCAL;
+            return LocalityScope.HOST_LOCAL;
         } else {
-            return LocalityScope.EVERYTHING;
+            String sourceHostname = nodeToHost.get(sourceNodeInfo.get_node());
+            String targetHostname = nodeToHost.get(targetNodeInfo.get_node());
+
+            String sourceRack = (sourceHostname == null) ? null : hostToRack.get(sourceHostname);
+            String targetRack = (targetHostname == null) ? null : hostToRack.get(targetHostname);
+
+            if (sourceRack != null && sourceRack.equals(targetRack)) {
+                return LocalityScope.RACK_LOCAL;
+            } else {
+                return LocalityScope.EVERYTHING;
+            }
         }
     }
 
-    private Map<String, String> getHostToRackMapping(Map<Integer, NodeInfo> taskToNodePort) {
-        Set<String> hosts = new HashSet();
+    private Map<String, String> getHostToRackMapping(Map<Integer, NodeInfo> taskToNodePort, Map<String, String> nodeToHost) {
+        Set<String> hosts = new HashSet<>();
         for (int task : targetTasks) {
             //if the worker containing this task is about to be killed by an assignment sync,
             //taskToNodePort will have been refreshed by WorkerState to an empty map
             if (taskToNodePort.containsKey(task)) {
-                hosts.add(taskToNodePort.get(task).get_node());
+                String node = taskToNodePort.get(task).get_node();
+                String hostname = nodeToHost.get(node);
+                if (hostname != null) {
+                    hosts.add(hostname);
+                }
             } else {
                 LOG.error("Could not find task NodeInfo from local cache.");
             }
         }
-        hosts.add(sourceNodeInfo.get_node());
+
+        String node = sourceNodeInfo.get_node();
+        String hostname = nodeToHost.get(node);
+        if (hostname != null) {
+            hosts.add(hostname);
+        }
         return dnsToSwitchMapping.resolve(new ArrayList<>(hosts));
     }
 
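
The reworked calculateScope decides worker and host locality on node ids alone
and consults hostnames only for the rack comparison. A minimal sketch of that
decision ladder, using stand-ins for the Storm types (ScopeSketch and scopeOf
are illustrative names):

    import java.util.Map;

    enum LocalityScope { WORKER_LOCAL, HOST_LOCAL, RACK_LOCAL, EVERYTHING }

    class ScopeSketch {
        static LocalityScope scopeOf(String srcNode, long srcPort,
                                     String dstNode, long dstPort,
                                     Map<String, String> nodeToHost,
                                     Map<String, String> hostToRack) {
            // Same node id: at worst host-local, no hostname needed.
            if (srcNode.equals(dstNode)) {
                return srcPort == dstPort ? LocalityScope.WORKER_LOCAL
                                          : LocalityScope.HOST_LOCAL;
            }
            // Racks are keyed by hostname, so node ids are translated first;
            // an unresolvable host degrades to EVERYTHING rather than
            // producing a false rack match.
            String srcHost = nodeToHost.get(srcNode);
            String dstHost = nodeToHost.get(dstNode);
            String srcRack = (srcHost == null) ? null : hostToRack.get(srcHost);
            String dstRack = (dstHost == null) ? null : hostToRack.get(dstHost);
            return (srcRack != null && srcRack.equals(dstRack))
                ? LocalityScope.RACK_LOCAL : LocalityScope.EVERYTHING;
        }
    }
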
diff --git a/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java b/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
index 36ef800..f5dec11 100644
--- a/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
+++ b/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
@@ -32,6 +32,7 @@
     private String pidDir;
     private AtomicReference<Map<Integer, NodeInfo>> taskToNodePort;
     private String assignmentId;
+    private final AtomicReference<Map<String, String>> nodeToHost;
 
     public WorkerTopologyContext(
         StormTopology topology,
@@ -47,7 +48,8 @@
         Map<String, Object> defaultResources,
         Map<String, Object> userResources,
         AtomicReference<Map<Integer, NodeInfo>> taskToNodePort,
-        String assignmentId
+        String assignmentId,
+        AtomicReference<Map<String, String>> nodeToHost
     ) {
         super(topology, topoConf, taskToComponent, componentToSortedTasks, componentToStreamToFields, stormId);
         this.codeDir = codeDir;
@@ -66,6 +68,7 @@
         this.workerTasks = workerTasks;
         this.taskToNodePort = taskToNodePort;
         this.assignmentId = assignmentId;
+        this.nodeToHost = nodeToHost;
 
     }
 
@@ -83,7 +86,7 @@
         Map<String, Object> defaultResources,
         Map<String, Object> userResources) {
         this(topology, topoConf, taskToComponent, componentToSortedTasks, componentToStreamToFields, stormId,
-             codeDir, pidDir, workerPort, workerTasks, defaultResources, userResources, null, null);
+             codeDir, pidDir, workerPort, workerTasks, defaultResources, userResources, null, null, null);
     }
 
     /**
@@ -97,7 +100,7 @@
         return workerPort;
     }
 
-    public String getThisWorkerHost() {
+    public String getAssignmentId() {
         return assignmentId;
     }
 
@@ -111,6 +114,14 @@
     }
 
     /**
+     * Get a map from nodeId to hostname.
+     * @return a map from nodeId to hostname
+     */
+    public AtomicReference<Map<String, String>> getNodeToHost() {
+        return nodeToHost;
+    }
+
+    /**
      * Gets the location of the external resources for this worker on the local filesystem. These external resources typically include bolts
      * implemented in other languages, such as Ruby or Python.
      */
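
With the rename, callers that genuinely need a hostname must now go through the
node-to-host map explicitly. A minimal sketch, assuming a prepared
WorkerTopologyContext; HostLookup and lookupOwnHost are hypothetical helpers,
not part of this change:

    import java.util.Map;
    import org.apache.storm.task.WorkerTopologyContext;

    class HostLookup {
        static String lookupOwnHost(WorkerTopologyContext context) {
            String nodeId = context.getAssignmentId();   // supervisor node id, not a hostname
            Map<String, String> nodeToHost = context.getNodeToHost().get();
            return nodeToHost.get(nodeId);               // may be null right after a resync
        }
    }
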
diff --git a/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java b/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
index 678d803..448d8fc 100644
--- a/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
+++ b/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
@@ -66,8 +66,10 @@
         NodeInfo nodeInfo = new NodeInfo("node-id", Sets.newHashSet(6700L));
         availableTaskIds.forEach(e -> taskNodeToPort.put(e, nodeInfo));
         when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskNodeToPort));
-        when(context.getThisWorkerHost()).thenReturn("node-id");
+        when(context.getAssignmentId()).thenReturn("node-id");
         when(context.getThisWorkerPort()).thenReturn(6700);
+        AtomicReference<Map<String, String>> nodeToHost = new AtomicReference<>(Collections.singletonMap("node-id", "hostname1"));
+        when(context.getNodeToHost()).thenReturn(nodeToHost);
         return context;
     }
 
@@ -566,8 +568,14 @@
         taskNodeToPort.put(3, new NodeInfo("node-id2", Sets.newHashSet(6703L)));
 
         when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskNodeToPort));
-        when(context.getThisWorkerHost()).thenReturn("node-id");
+        when(context.getAssignmentId()).thenReturn("node-id");
         when(context.getThisWorkerPort()).thenReturn(6701);
+
+        Map<String, String> nodeToHost = new HashMap<>();
+        nodeToHost.put("node-id", "hostname1");
+        nodeToHost.put("node-id2", "hostname2");
+        when(context.getNodeToHost()).thenReturn(new AtomicReference<>(nodeToHost));
+
         return context;
     }
 }
\ No newline at end of file