Handle task location fetch from overlord during rolling upgrades (#16227)

Bug:
#15724 introduced a bug where a rolling upgrade would cause all task locations
returned by the Overlord on an older version to be unknown.

Fix:
If the new API fails, fall back to single task status API which always returns a valid task location.
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java
index 1894336..927130e 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.indexing.client;
 
 import com.google.common.collect.ImmutableSet;
+import org.apache.druid.client.indexing.TaskStatusResponse;
 import org.apache.druid.common.guava.FutureUtils;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskStatus;
@@ -37,6 +38,7 @@
 public class IndexerWorkerManagerClient implements WorkerManagerClient
 {
   private final OverlordClient overlordClient;
+  private final TaskLocationFetcher locationFetcher = new TaskLocationFetcher();
 
   public IndexerWorkerManagerClient(final OverlordClient overlordClient)
   {
@@ -65,16 +67,7 @@
   @Override
   public TaskLocation location(String workerId)
   {
-    final TaskStatus response = FutureUtils.getUnchecked(
-        overlordClient.taskStatuses(ImmutableSet.of(workerId)),
-        true
-    ).get(workerId);
-
-    if (response != null) {
-      return response.getLocation();
-    } else {
-      return TaskLocation.unknown();
-    }
+    return locationFetcher.getLocation(workerId);
   }
 
   @Override
@@ -82,4 +75,31 @@
   {
     // Nothing to do. The OverlordServiceClient is closed by the JVM lifecycle.
   }
+
+  private class TaskLocationFetcher
+  {
+    TaskLocation getLocation(String workerId)
+    {
+      final TaskStatus taskStatus = FutureUtils.getUnchecked(
+          overlordClient.taskStatuses(ImmutableSet.of(workerId)),
+          true
+      ).get(workerId);
+
+      if (taskStatus != null
+          && !TaskLocation.unknown().equals(taskStatus.getLocation())) {
+        return taskStatus.getLocation();
+      }
+
+      // Retry with the single status API
+      final TaskStatusResponse statusResponse = FutureUtils.getUnchecked(
+          overlordClient.taskStatus(workerId),
+          true
+      );
+      if (statusResponse == null || statusResponse.getStatus() == null) {
+        return TaskLocation.unknown();
+      } else {
+        return statusResponse.getStatus().getLocation();
+      }
+    }
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java
new file mode 100644
index 0000000..4b53420
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.indexing.client;
+
+import com.google.common.util.concurrent.Futures;
+import org.apache.druid.client.indexing.TaskStatusResponse;
+import org.apache.druid.indexer.TaskLocation;
+import org.apache.druid.indexer.TaskState;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.TaskStatusPlus;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.rpc.indexing.OverlordClient;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mockito;
+
+import java.util.Collections;
+
+public class IndexerWorkerManagerClientTest
+{
+
+  @Test
+  public void testGetLocationCallsMultiStatusApiByDefault()
+  {
+    final OverlordClient overlordClient = Mockito.mock(OverlordClient.class);
+
+    final String taskId = "worker1";
+    final TaskLocation expectedLocation = new TaskLocation("localhost", 1000, 1100, null);
+    Mockito.when(overlordClient.taskStatuses(Collections.singleton(taskId))).thenReturn(
+        Futures.immediateFuture(
+            Collections.singletonMap(
+                taskId,
+                new TaskStatus(taskId, TaskState.RUNNING, 100L, null, expectedLocation)
+            )
+        )
+    );
+
+    final IndexerWorkerManagerClient managerClient = new IndexerWorkerManagerClient(overlordClient);
+    Assert.assertEquals(managerClient.location(taskId), expectedLocation);
+
+    Mockito.verify(overlordClient, Mockito.times(1)).taskStatuses(ArgumentMatchers.anySet());
+    Mockito.verify(overlordClient, Mockito.never()).taskStatus(ArgumentMatchers.anyString());
+  }
+
+  @Test
+  public void testGetLocationFallsBackToSingleTaskApiIfLocationIsUnknown()
+  {
+    final OverlordClient overlordClient = Mockito.mock(OverlordClient.class);
+
+    final String taskId = "worker1";
+    Mockito.when(overlordClient.taskStatuses(Collections.singleton(taskId))).thenReturn(
+        Futures.immediateFuture(
+            Collections.singletonMap(
+                taskId,
+                new TaskStatus(taskId, TaskState.RUNNING, 100L, null, TaskLocation.unknown())
+            )
+        )
+    );
+
+    final TaskLocation expectedLocation = new TaskLocation("localhost", 1000, 1100, null);
+    final TaskStatusPlus taskStatus = new TaskStatusPlus(
+        taskId,
+        null,
+        null,
+        DateTimes.nowUtc(),
+        DateTimes.nowUtc(),
+        TaskState.RUNNING,
+        null,
+        100L,
+        expectedLocation,
+        "wiki",
+        null
+    );
+
+    Mockito.when(overlordClient.taskStatus(taskId)).thenReturn(
+        Futures.immediateFuture(new TaskStatusResponse(taskId, taskStatus))
+    );
+
+    final IndexerWorkerManagerClient managerClient = new IndexerWorkerManagerClient(overlordClient);
+    Assert.assertEquals(managerClient.location(taskId), expectedLocation);
+
+    Mockito.verify(overlordClient, Mockito.times(1)).taskStatuses(ArgumentMatchers.anySet());
+    Mockito.verify(overlordClient, Mockito.times(1)).taskStatus(ArgumentMatchers.anyString());
+  }
+
+}
diff --git a/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java
index 163c7e1..3f54413 100644
--- a/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java
+++ b/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java
@@ -26,6 +26,8 @@
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.SettableFuture;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
+import org.apache.druid.client.indexing.TaskStatusResponse;
+import org.apache.druid.common.guava.FutureUtils;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
@@ -55,6 +57,7 @@
 
   private final String taskId;
   private final OverlordClient overlordClient;
+  private final TaskLocationFetcher locationFetcher = new TaskLocationFetcher();
   private final Object lock = new Object();
 
   @GuardedBy("lock")
@@ -129,14 +132,20 @@
                       lastKnownLocation = null;
                     } else {
                       lastKnownState = status.getStatusCode();
-
+                      final TaskLocation location;
                       if (TaskLocation.unknown().equals(status.getLocation())) {
+                        location = locationFetcher.getLocation();
+                      } else {
+                        location = status.getLocation();
+                      }
+
+                      if (TaskLocation.unknown().equals(location)) {
                         lastKnownLocation = null;
                       } else {
                         lastKnownLocation = new ServiceLocation(
-                            status.getLocation().getHost(),
-                            status.getLocation().getPort(),
-                            status.getLocation().getTlsPort(),
+                            location.getHost(),
+                            location.getPort(),
+                            location.getTlsPort(),
                             StringUtils.format("%s/%s", BASE_PATH, StringUtils.urlEncode(taskId))
                         );
                       }
@@ -199,4 +208,20 @@
       }
     }
   }
+
+  private class TaskLocationFetcher
+  {
+    TaskLocation getLocation()
+    {
+      final TaskStatusResponse statusResponse = FutureUtils.getUnchecked(
+          overlordClient.taskStatus(taskId),
+          true
+      );
+      if (statusResponse == null || statusResponse.getStatus() == null) {
+        return TaskLocation.unknown();
+      } else {
+        return statusResponse.getStatus().getLocation();
+      }
+    }
+  }
 }
diff --git a/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java b/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java
index 4888078..f754567 100644
--- a/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java
+++ b/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java
@@ -22,9 +22,12 @@
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
+import org.apache.druid.client.indexing.TaskStatusResponse;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.TaskStatusPlus;
+import org.apache.druid.java.util.common.DateTimes;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.rpc.ServiceLocation;
 import org.apache.druid.rpc.ServiceLocations;
@@ -62,6 +65,25 @@
   {
     Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID)))
            .thenReturn(status(TaskState.RUNNING, TaskLocation.unknown()));
+    final TaskStatusResponse response = new TaskStatusResponse(
+        TASK_ID,
+        new TaskStatusPlus(
+            TASK_ID,
+            null,
+            null,
+            DateTimes.nowUtc(),
+            DateTimes.EPOCH,
+            TaskState.RUNNING,
+            null,
+            null,
+            null,
+            TaskLocation.unknown(),
+            null,
+            null
+        )
+    );
+    Mockito.when(overlordClient.taskStatus(TASK_ID))
+           .thenReturn(Futures.immediateFuture(response));
 
     final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient);
     final ListenableFuture<ServiceLocations> future = locator.locate();
@@ -94,6 +116,25 @@
   {
     Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID)))
            .thenReturn(status(TaskState.SUCCESS, TaskLocation.unknown()));
+    final TaskStatusResponse response = new TaskStatusResponse(
+        TASK_ID,
+        new TaskStatusPlus(
+            TASK_ID,
+            null,
+            null,
+            DateTimes.nowUtc(),
+            DateTimes.EPOCH,
+            TaskState.FAILED,
+            null,
+            null,
+            100L,
+            TaskLocation.unknown(),
+            null,
+            null
+        )
+    );
+    Mockito.when(overlordClient.taskStatus(TASK_ID))
+           .thenReturn(Futures.immediateFuture(response));
 
     final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient);
     final ListenableFuture<ServiceLocations> future = locator.locate();
@@ -105,6 +146,25 @@
   {
     Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID)))
            .thenReturn(status(TaskState.FAILED, TaskLocation.unknown()));
+    final TaskStatusResponse response = new TaskStatusResponse(
+        TASK_ID,
+        new TaskStatusPlus(
+            TASK_ID,
+            null,
+            null,
+            DateTimes.nowUtc(),
+            DateTimes.EPOCH,
+            TaskState.FAILED,
+            null,
+            null,
+            100L,
+            TaskLocation.unknown(),
+            null,
+            null
+        )
+    );
+    Mockito.when(overlordClient.taskStatus(TASK_ID))
+           .thenReturn(Futures.immediateFuture(response));
 
     final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient);
     final ListenableFuture<ServiceLocations> future = locator.locate();