Fix race condition in scheduler message processing logic. (#1930)

This PR aims to fix the race condition that happens while processing scheduler messages. The previous logic, which dynamically deleted task partitions from the scheduler message IdealState, could cause conflicts and result in inconsistent message status updates. Since updating the task partitions is not a necessary step, this PR removes the corresponding logic and simplifies the message handling procedure.

This PR will help to stabilize TestSchedulerMessage.java.
diff --git a/helix-core/src/main/java/org/apache/helix/controller/stages/ExternalViewComputeStage.java b/helix-core/src/main/java/org/apache/helix/controller/stages/ExternalViewComputeStage.java
index 50cf0db..79b4d41 100644
--- a/helix-core/src/main/java/org/apache/helix/controller/stages/ExternalViewComputeStage.java
+++ b/helix-core/src/main/java/org/apache/helix/controller/stages/ExternalViewComputeStage.java
@@ -23,12 +23,10 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
-
 import org.apache.helix.HelixDataAccessor;
 import org.apache.helix.HelixDefinedState;
 import org.apache.helix.HelixException;
@@ -48,11 +46,8 @@
 import org.apache.helix.model.Partition;
 import org.apache.helix.model.Resource;
 import org.apache.helix.model.ResourceConfig;
-import org.apache.helix.model.StateModelDefinition;
 import org.apache.helix.model.StatusUpdate;
 import org.apache.helix.monitoring.mbeans.ClusterStatusMonitor;
-import org.apache.helix.zookeeper.datamodel.ZNRecord;
-import org.apache.helix.zookeeper.datamodel.ZNRecordDelta;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -203,7 +198,7 @@
 
       // For SCHEDULER_TASK_RESOURCE resource group (helix task queue), we need to find out which
       // task partitions are finished (COMPLETED or ERROR), update the status update of the original
-      // scheduler message, and then remove the partitions from the ideal state
+      // scheduler message.
       if (idealState != null
           && idealState.getStateModelDefRef().equalsIgnoreCase(
           DefaultSchedulerMessageHandlerFactory.SCHEDULER_TASK_QUEUE)) {
@@ -215,15 +210,9 @@
   private void updateScheduledTaskStatus(ExternalView ev, HelixManager manager,
       IdealState taskQueueIdealState) {
     HelixDataAccessor accessor = manager.getHelixDataAccessor();
-    ZNRecord finishedTasks = new ZNRecord(ev.getResourceName());
 
-    // Place holder for finished partitions
-    Map<String, String> emptyMap = new HashMap<String, String>();
-    List<String> emptyList = new LinkedList<String>();
-
-    Map<String, Integer> controllerMsgIdCountMap = new HashMap<String, Integer>();
-    Map<String, Map<String, String>> controllerMsgUpdates =
-        new HashMap<String, Map<String, String>>();
+    Map<String, Integer> controllerMsgIdCountMap = new HashMap<>();
+    Map<String, Map<String, String>> controllerFinishedMsgs = new HashMap<>();
 
     Builder keyBuilder = accessor.keyBuilder();
 
@@ -232,61 +221,52 @@
         if (taskState.equalsIgnoreCase(HelixDefinedState.ERROR.toString()) || taskState
             .equalsIgnoreCase("COMPLETED")) {
           LogUtil.logInfo(LOG, _eventId, taskPartitionName + " finished as " + taskState);
-          finishedTasks.getListFields().put(taskPartitionName, emptyList);
-          finishedTasks.getMapFields().put(taskPartitionName, emptyMap);
 
           // Update original scheduler message status update
-          if (taskQueueIdealState.getRecord().getMapField(taskPartitionName) != null) {
-            String controllerMsgId = taskQueueIdealState.getRecord().getMapField(taskPartitionName)
-                .get(DefaultSchedulerMessageHandlerFactory.CONTROLLER_MSG_ID);
+          Map<String, String> taskPartitionStatus = taskQueueIdealState.getRecord().getMapField(taskPartitionName);
+          if (taskPartitionStatus != null) {
+            String controllerMsgId = taskPartitionStatus.get(DefaultSchedulerMessageHandlerFactory.CONTROLLER_MSG_ID);
             if (controllerMsgId != null) {
-              LogUtil.logInfo(LOG, _eventId,
-                  taskPartitionName + " finished with controllerMsg " + controllerMsgId);
-              if (!controllerMsgUpdates.containsKey(controllerMsgId)) {
-                controllerMsgUpdates.put(controllerMsgId, new HashMap<String, String>());
-              }
-              controllerMsgUpdates.get(controllerMsgId).put(taskPartitionName, taskState);
+              LogUtil.logInfo(LOG, _eventId, taskPartitionName + " finished with controllerMsg " + controllerMsgId);
+              controllerFinishedMsgs.computeIfAbsent(controllerMsgId, id -> new HashMap<>())
+                  .put(taskPartitionName, taskState);
             }
           }
         }
       }
     }
+
     // fill the controllerMsgIdCountMap
-    for (String taskId : taskQueueIdealState.getPartitionSet()) {
-      String controllerMsgId =
-          taskQueueIdealState.getRecord().getMapField(taskId)
-              .get(DefaultSchedulerMessageHandlerFactory.CONTROLLER_MSG_ID);
+    for (Map<String, String> taskInfo : taskQueueIdealState.getRecord().getMapFields().values()) {
+      String controllerMsgId = taskInfo.get(DefaultSchedulerMessageHandlerFactory.CONTROLLER_MSG_ID);
       if (controllerMsgId != null) {
-        if (!controllerMsgIdCountMap.containsKey(controllerMsgId)) {
-          controllerMsgIdCountMap.put(controllerMsgId, 0);
-        }
-        controllerMsgIdCountMap.put(controllerMsgId,
-            (controllerMsgIdCountMap.get(controllerMsgId) + 1));
+        controllerMsgIdCountMap.put(controllerMsgId, controllerMsgIdCountMap.getOrDefault(controllerMsgId, 0) + 1);
       }
     }
 
-    if (controllerMsgUpdates.size() > 0) {
-      for (String controllerMsgId : controllerMsgUpdates.keySet()) {
+    if (controllerFinishedMsgs.size() > 0) {
+      for (String controllerMsgId : controllerFinishedMsgs.keySet()) {
         PropertyKey controllerStatusUpdateKey =
             keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), controllerMsgId);
         StatusUpdate controllerStatusUpdate = accessor.getProperty(controllerStatusUpdateKey);
-        for (String taskPartitionName : controllerMsgUpdates.get(controllerMsgId).keySet()) {
-          Map<String, String> result = new HashMap<String, String>();
-          result.put("Result", controllerMsgUpdates.get(controllerMsgId).get(taskPartitionName));
-          controllerStatusUpdate.getRecord().setMapField(
-              "MessageResult "
-                  + taskQueueIdealState.getRecord().getMapField(taskPartitionName)
-                      .get(Message.Attributes.TGT_NAME.toString())
-                  + " "
-                  + taskPartitionName
-                  + " "
-                  + taskQueueIdealState.getRecord().getMapField(taskPartitionName)
-                      .get(Message.Attributes.MSG_ID.toString()), result);
-        }
-        // All done for the scheduled tasks that came from controllerMsgId, add summary for it
+
         Integer controllerMsgIdCount = controllerMsgIdCountMap.get(controllerMsgId);
         if (controllerMsgIdCount != null
-            && controllerMsgUpdates.get(controllerMsgId).size() == controllerMsgIdCount.intValue()) {
+            && controllerFinishedMsgs.get(controllerMsgId).size() == controllerMsgIdCount.intValue()) {
+          // All done for the scheduled tasks that came from controllerMsgId, add summary for it
+          for (String taskPartitionName : controllerFinishedMsgs.get(controllerMsgId).keySet()) {
+            Map<String, String> result = new HashMap<>();
+            result.put("Result", controllerFinishedMsgs.get(controllerMsgId).get(taskPartitionName));
+            controllerStatusUpdate.getRecord().setMapField(
+                "MessageResult "
+                    + taskQueueIdealState.getRecord().getMapField(taskPartitionName)
+                    .get(Message.Attributes.TGT_NAME.toString())
+                    + " "
+                    + taskPartitionName
+                    + " "
+                    + taskQueueIdealState.getRecord().getMapField(taskPartitionName)
+                    .get(Message.Attributes.MSG_ID.toString()), result);
+          }
           int finishedTasksNum = 0;
           int completedTasksNum = 0;
           for (String key : controllerStatusUpdate.getRecord().getMapFields().keySet()) {
@@ -300,7 +280,7 @@
               }
             }
           }
-          Map<String, String> summary = new TreeMap<String, String>();
+          Map<String, String> summary = new TreeMap<>();
           summary.put("TotalMessages:", "" + finishedTasksNum);
           summary.put("CompletedMessages", "" + completedTasksNum);
 
@@ -310,18 +290,5 @@
         accessor.updateProperty(controllerStatusUpdateKey, controllerStatusUpdate);
       }
     }
-
-    if (finishedTasks.getListFields().size() > 0) {
-      ZNRecordDelta znDelta = new ZNRecordDelta(finishedTasks, ZNRecordDelta.MergeOperation.SUBTRACT);
-      List<ZNRecordDelta> deltaList = new LinkedList<ZNRecordDelta>();
-      deltaList.add(znDelta);
-      IdealState delta = new IdealState(taskQueueIdealState.getResourceName());
-      delta.setDeltaList(deltaList);
-
-      // Remove the finished (COMPLETED or ERROR) tasks from the SCHEDULER_TASK_RESOURCE idealstate
-      keyBuilder = accessor.keyBuilder();
-      accessor.updateProperty(keyBuilder.idealStates(taskQueueIdealState.getResourceName()), delta);
-    }
   }
-
 }
diff --git a/helix-core/src/main/java/org/apache/helix/manager/zk/DefaultSchedulerMessageHandlerFactory.java b/helix-core/src/main/java/org/apache/helix/manager/zk/DefaultSchedulerMessageHandlerFactory.java
index 7df5615..19308bb 100644
--- a/helix-core/src/main/java/org/apache/helix/manager/zk/DefaultSchedulerMessageHandlerFactory.java
+++ b/helix-core/src/main/java/org/apache/helix/manager/zk/DefaultSchedulerMessageHandlerFactory.java
@@ -90,7 +90,7 @@
       String key = "MessageResult " + message.getMsgSrc() + " " + UUID.randomUUID();
       _resultSummaryMap.put(key, message.getResultMap());
 
-      if (this.isDone()) {
+      if (isDone()) {
         _logger.info("Scheduler msg " + _originalMessage.getMsgId() + " completed");
         _statusUpdateUtil.logInfo(_originalMessage, SchedulerAsyncCallback.class,
             "Scheduler task completed", _manager);
@@ -100,23 +100,18 @@
 
     private void addSummary(Map<String, Map<String, String>> _resultSummaryMap,
         Message originalMessage, HelixManager manager, boolean timeOut) {
-      Map<String, String> summary = new TreeMap<String, String>();
+      Map<String, String> summary = new TreeMap<>();
       summary.put("TotalMessages:", "" + _resultSummaryMap.size());
       summary.put("Timeout", "" + timeOut);
       _resultSummaryMap.put("Summary", summary);
 
       HelixDataAccessor accessor = manager.getHelixDataAccessor();
       Builder keyBuilder = accessor.keyBuilder();
-      ZNRecord statusUpdate =
-          accessor.getProperty(
-              keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(),
-                  originalMessage.getMsgId())).getRecord();
-
-      statusUpdate.getMapFields().putAll(_resultSummaryMap);
-      accessor.setProperty(
-          keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(),
-              originalMessage.getMsgId()), new StatusUpdate(statusUpdate));
-
+      accessor.updateProperty(
+          keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), originalMessage.getMsgId()), status -> {
+            status.getMapFields().putAll(_resultSummaryMap);
+            return status;
+          }, null);
     }
   }
 
@@ -169,7 +164,7 @@
             _manager.getClusterName(), clusterName));
       }
 
-      Map<String, String> sendSummary = new HashMap<String, String>();
+      Map<String, String> sendSummary = new HashMap<>();
       sendSummary.put("MessageCount", "0");
       Map<InstanceType, List<Message>> messages =
           _manager.getMessagingService().generateMessage(recipientCriteria, messageTemplate);
@@ -225,7 +220,6 @@
         }
       }
       // Record the number of messages sent into scheduler message status updates
-
       ZNRecord statusUpdate =
           accessor.getProperty(
               keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(),
diff --git a/helix-core/src/test/java/org/apache/helix/integration/messaging/TestSchedulerMessage.java b/helix-core/src/test/java/org/apache/helix/integration/messaging/TestSchedulerMessage.java
index a4878e9..c4c6c51 100644
--- a/helix-core/src/test/java/org/apache/helix/integration/messaging/TestSchedulerMessage.java
+++ b/helix-core/src/test/java/org/apache/helix/integration/messaging/TestSchedulerMessage.java
@@ -41,6 +41,7 @@
 import org.apache.helix.PropertyKey;
 import org.apache.helix.PropertyKey.Builder;
 import org.apache.helix.PropertyPathBuilder;
+import org.apache.helix.TestHelper;
 import org.apache.helix.integration.common.ZkStandAloneCMTestBase;
 import org.apache.helix.manager.zk.DefaultSchedulerMessageHandlerFactory;
 import org.apache.helix.messaging.AsyncCallback;
@@ -186,14 +187,11 @@
     _factory._results.clear();
     HelixManager manager = null;
     for (int i = 0; i < NODE_NR; i++) {
-      _participants[i].getMessagingService()
-          .registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
-
+      _participants[i].getMessagingService().registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
       manager = _participants[i];
     }
 
-    Message schedulerMessage =
-        new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
+    Message schedulerMessage = new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
     schedulerMessage.setTgtSessionId("*");
     schedulerMessage.setTgtName("CONTROLLER");
     // TODO: change it to "ADMIN" ?
@@ -227,45 +225,42 @@
     Builder keyBuilder = helixDataAccessor.keyBuilder();
     helixDataAccessor.createControllerMessage(schedulerMessage);
 
-    for (int i = 0; i < 30; i++) {
-      Thread.sleep(2000);
-      if (_PARTITIONS == _factory._results.size()) {
-        break;
-      }
-    }
+    Assert.assertTrue(TestHelper.verify(() -> _PARTITIONS == _factory._results.size(), TestHelper.WAIT_DURATION));
 
     Assert.assertEquals(_PARTITIONS, _factory._results.size());
-    PropertyKey controllerTaskStatus = keyBuilder
-        .controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), schedulerMessage.getMsgId());
+    PropertyKey controllerTaskStatus =
+        keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), schedulerMessage.getMsgId());
 
-    int messageResultCount = 0;
-    for (int i = 0; i < 10; i++) {
-      Thread.sleep(1000);
+    Assert.assertTrue(TestHelper.verify(() -> {
       ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-      Assert.assertEquals("" + (_PARTITIONS * 3),
-          statusUpdate.getMapField("SentMessageCount").get("MessageCount"));
+      try {
+        if (_PARTITIONS * 3 != Integer.parseInt(statusUpdate.getMapField("SentMessageCount").get("MessageCount"))) {
+          return false;
+        }
+      } catch (Exception ex) {
+        return false;
+      }
+      int messageResultCount = 0;
       for (String key : statusUpdate.getMapFields().keySet()) {
         if (key.startsWith("MessageResult ")) {
           messageResultCount++;
-          Assert.assertTrue(statusUpdate.getMapField(key).size() > 1);
+          if (statusUpdate.getMapField(key).size() <= 1) {
+            return false;
+          }
         }
       }
-      if (messageResultCount == _PARTITIONS * 3) {
-        break;
-      } else {
-        Thread.sleep(2000);
+      if (messageResultCount != _PARTITIONS * 3) {
+        return false;
       }
-    }
-    Assert.assertEquals(messageResultCount, _PARTITIONS * 3);
-    int count = 0;
-    for (Set<String> val : _factory._results.values()) {
-      count += val.size();
-    }
-    Assert.assertEquals(count, _PARTITIONS * 3);
+      int count = 0;
+      for (Set<String> val : _factory._results.values()) {
+        count += val.size();
+      }
+      return count == _PARTITIONS * 3;
+    }, TestHelper.WAIT_DURATION));
 
     // test the ZkPathDataDumpTask
-    String controllerStatusPath =
-        PropertyPathBuilder.controllerStatusUpdate(manager.getClusterName());
+    String controllerStatusPath = PropertyPathBuilder.controllerStatusUpdate(manager.getClusterName());
     List<String> subPaths = _gZkClient.getChildren(controllerStatusPath);
     Assert.assertTrue(subPaths.size() > 0);
     for (String subPath : subPaths) {
@@ -274,8 +269,8 @@
       Assert.assertTrue(subsubPaths.size() > 0);
     }
 
-    String instanceStatusPath = PropertyPathBuilder.instanceStatusUpdate(manager.getClusterName(),
-        "localhost_" + (START_PORT));
+    String instanceStatusPath =
+        PropertyPathBuilder.instanceStatusUpdate(manager.getClusterName(), "localhost_" + (START_PORT));
     subPaths = _gZkClient.getChildren(instanceStatusPath);
     Assert.assertEquals(subPaths.size(), 0);
     for (String subPath : subPaths) {
@@ -317,14 +312,12 @@
     _factory._results.clear();
     HelixManager manager = null;
     for (int i = 0; i < NODE_NR; i++) {
-      _participants[i].getMessagingService()
-          .registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
+      _participants[i].getMessagingService().registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
 
       manager = _participants[i]; // _startCMResultMap.get(hostDest)._manager;
     }
 
-    Message schedulerMessage =
-        new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
+    Message schedulerMessage = new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
     schedulerMessage.setTgtSessionId("*");
     schedulerMessage.setTgtName("CONTROLLER");
     // TODO: change it to "ADMIN" ?
@@ -363,14 +356,16 @@
     Thread.sleep(3000);
 
     Assert.assertEquals(0, _factory._results.size());
-    PropertyKey controllerTaskStatus = keyBuilder
-        .controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), schedulerMessage.getMsgId());
-    for (int i = 0; i < 10; i++) {
+    PropertyKey controllerTaskStatus =
+        keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), schedulerMessage.getMsgId());
+
+    // Need to wait until record is ready
+    Assert.assertTrue(TestHelper.verify(() -> {
       StatusUpdate update = helixDataAccessor.getProperty(controllerTaskStatus);
-      if (update == null || update.getRecord().getMapField("SentMessageCount") == null) {
-        Thread.sleep(1000);
-      }
-    }
+      return update != null && update.getRecord().getMapField("SentMessageCount") != null;
+    }, 10 * 1000));
+
+    // Ensure the records remains to be zero
     ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
     Assert.assertEquals(statusUpdate.getMapField("SentMessageCount").get("MessageCount"), "0");
     int count = 0;
@@ -386,17 +381,11 @@
     Thread.sleep(2000);
     HelixManager manager = null;
     for (int i = 0; i < NODE_NR; i++) {
-      _participants[i].getMessagingService()
-          .registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
-
-      _participants[i].getMessagingService()
-          .registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
-
+      _participants[i].getMessagingService().registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
       manager = _participants[i];
     }
 
-    Message schedulerMessage =
-        new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
+    Message schedulerMessage = new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
     schedulerMessage.setTgtSessionId("*");
     schedulerMessage.setTgtName("CONTROLLER");
     // TODO: change it to "ADMIN" ?
@@ -428,8 +417,8 @@
     schedulerMessage.getRecord().setSimpleField("TIMEOUT", "-1");
     schedulerMessage.getRecord().setSimpleField("WAIT_ALL", "true");
 
-    schedulerMessage.getRecord().setSimpleField(
-        DefaultSchedulerMessageHandlerFactory.SCHEDULER_TASK_QUEUE, "TestSchedulerMsg3");
+    schedulerMessage.getRecord()
+        .setSimpleField(DefaultSchedulerMessageHandlerFactory.SCHEDULER_TASK_QUEUE, "TestSchedulerMsg3");
     Criteria cr2 = new Criteria();
     cr2.setRecipientInstanceType(InstanceType.CONTROLLER);
     cr2.setInstanceName("*");
@@ -458,41 +447,42 @@
       crString = sw.toString();
       schedulerMessage.getRecord().setSimpleField("Criteria", crString);
       manager.getMessagingService().sendAndWait(cr2, schedulerMessage, callback, -1);
-      String msgId = callback._message.getResultMap()
-          .get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
+      String msgId = callback._message.getResultMap().get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
 
       HelixDataAccessor helixDataAccessor = manager.getHelixDataAccessor();
       Builder keyBuilder = helixDataAccessor.keyBuilder();
 
-      for (int j = 0; j < 100; j++) {
-        Thread.sleep(200);
-        PropertyKey controllerTaskStatus =
-            keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
+      // Wait until all sub messages to be processed
+      PropertyKey controllerTaskStatus = keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
+      int instanceOrder = i;
+      Assert.assertTrue(TestHelper.verify(() -> {
         ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-        if (statusUpdate.getMapFields().containsKey("Summary")) {
-          break;
+        if (!statusUpdate.getMapFields().containsKey("Summary")) {
+          return false;
         }
-      }
-
-      Thread.sleep(3000);
-      PropertyKey controllerTaskStatus =
-          keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
-      ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-      Assert.assertEquals("" + (_PARTITIONS * 3 / 5),
-          statusUpdate.getMapField("SentMessageCount").get("MessageCount"));
-      int messageResultCount = 0;
-      for (String key : statusUpdate.getMapFields().keySet()) {
-        if (key.startsWith("MessageResult")) {
-          messageResultCount++;
+        try {
+          if (_PARTITIONS * 3 / 5 != Integer.parseInt(
+              statusUpdate.getMapField("SentMessageCount").get("MessageCount"))) {
+            return false;
+          }
+        } catch (Exception ex) {
+          return false;
         }
-      }
-      Assert.assertEquals(messageResultCount, _PARTITIONS * 3 / 5);
-
-      int count = 0;
-      for (Set<String> val : _factory._results.values()) {
-        count += val.size();
-      }
-      Assert.assertEquals(count, _PARTITIONS * 3 / 5 * (i + 1));
+        int messageResultCount = 0;
+        for (String key : statusUpdate.getMapFields().keySet()) {
+          if (key.startsWith("MessageResult")) {
+            messageResultCount++;
+          }
+        }
+        if (messageResultCount != _PARTITIONS * 3 / 5) {
+          return false;
+        }
+        int count = 0;
+        for (Set<String> val : _factory._results.values()) {
+          count += val.size();
+        }
+        return count == _PARTITIONS * 3 / 5 * (instanceOrder + 1);
+      }, TestHelper.WAIT_DURATION));
     }
   }
 
@@ -501,13 +491,11 @@
     _factory._results.clear();
     HelixManager manager = null;
     for (int i = 0; i < NODE_NR; i++) {
-      _participants[i].getMessagingService()
-          .registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
+      _participants[i].getMessagingService().registerMessageHandlerFactory(_factory.getMessageTypes(), _factory);
       manager = _participants[i];
     }
 
-    Message schedulerMessage =
-        new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
+    Message schedulerMessage = new Message(MessageType.SCHEDULER_MSG + "", UUID.randomUUID().toString());
     schedulerMessage.setTgtSessionId("*");
     schedulerMessage.setTgtName("CONTROLLER");
     // TODO: change it to "ADMIN" ?
@@ -539,8 +527,8 @@
     schedulerMessage.getRecord().setSimpleField("TIMEOUT", "-1");
     schedulerMessage.getRecord().setSimpleField("WAIT_ALL", "true");
 
-    schedulerMessage.getRecord().setSimpleField(
-        DefaultSchedulerMessageHandlerFactory.SCHEDULER_TASK_QUEUE, "TestSchedulerMsg4");
+    schedulerMessage.getRecord()
+        .setSimpleField(DefaultSchedulerMessageHandlerFactory.SCHEDULER_TASK_QUEUE, "TestSchedulerMsg4");
     Criteria cr2 = new Criteria();
     cr2.setRecipientInstanceType(InstanceType.CONTROLLER);
     cr2.setInstanceName("*");
@@ -551,8 +539,9 @@
     constraints.put("TRANSITION", "OFFLINE-COMPLETED");
     constraints.put("CONSTRAINT_VALUE", "1");
     constraints.put("INSTANCE", ".*");
-    manager.getClusterManagmentTool().setConstraint(manager.getClusterName(),
-        ConstraintType.MESSAGE_CONSTRAINT, "constraint1", new ConstraintItem(constraints));
+    manager.getClusterManagmentTool()
+        .setConstraint(manager.getClusterName(), ConstraintType.MESSAGE_CONSTRAINT, "constraint1",
+            new ConstraintItem(constraints));
 
     MockAsyncCallback callback = new MockAsyncCallback();
     cr.setInstanceName("localhost_%");
@@ -565,8 +554,7 @@
     crString = sw.toString();
     schedulerMessage.getRecord().setSimpleField("Criteria", crString);
     manager.getMessagingService().sendAndWait(cr2, schedulerMessage, callback, -1);
-    String msgIdPrime = callback._message.getResultMap()
-        .get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
+    String msgIdPrime = callback._message.getResultMap().get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
 
     HelixDataAccessor helixDataAccessor = manager.getHelixDataAccessor();
     Builder keyBuilder = helixDataAccessor.keyBuilder();
@@ -583,51 +571,48 @@
       crString = sw.toString();
       schedulerMessage.getRecord().setSimpleField("Criteria", crString);
       manager.getMessagingService().sendAndWait(cr2, schedulerMessage, callback, -1);
-      String msgId = callback._message.getResultMap()
-          .get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
+      String msgId = callback._message.getResultMap().get(DefaultSchedulerMessageHandlerFactory.SCHEDULER_MSG_ID);
       msgIds.add(msgId);
     }
     for (int i = 0; i < NODE_NR; i++) {
       String msgId = msgIds.get(i);
-      for (int j = 0; j < 100; j++) {
-        Thread.sleep(200);
-        PropertyKey controllerTaskStatus =
-            keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
+      PropertyKey controllerTaskStatus = keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
+      // Wait until all sub messages to be processed
+      Assert.assertTrue(TestHelper.verify(() -> {
         ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-        if (statusUpdate.getMapFields().containsKey("Summary")) {
-          break;
+        if (!statusUpdate.getMapFields().containsKey("Summary")) {
+          return false;
         }
-      }
-
-      // Add a half-second delay because it takes time for messages to be processed
-      Thread.sleep(500L);
-      PropertyKey controllerTaskStatus =
-          keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgId);
-      ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-      Assert.assertEquals("" + (_PARTITIONS * 3 / 5),
-          statusUpdate.getMapField("SentMessageCount").get("MessageCount"));
-      int messageResultCount = 0;
-      for (String key : statusUpdate.getMapFields().keySet()) {
-        if (key.startsWith("MessageResult")) {
-          messageResultCount++;
+        try {
+          if (_PARTITIONS * 3 / 5 != Integer.parseInt(
+              statusUpdate.getMapField("SentMessageCount").get("MessageCount"))) {
+            return false;
+          }
+        } catch (Exception ex) {
+          return false;
         }
-      }
-      Assert.assertEquals(messageResultCount, _PARTITIONS * 3 / 5);
+        int messageResultCount = 0;
+        for (String key : statusUpdate.getMapFields().keySet()) {
+          if (key.startsWith("MessageResult")) {
+            messageResultCount++;
+          }
+        }
+        return messageResultCount == _PARTITIONS * 3 / 5;
+      }, TestHelper.WAIT_DURATION));
     }
 
-    for (int j = 0; j < 100; j++) {
-      Thread.sleep(200);
-      PropertyKey controllerTaskStatus =
-          keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgIdPrime);
+    // Wait until the main message to be processed
+    PropertyKey controllerTaskStatus = keyBuilder.controllerTaskStatus(MessageType.SCHEDULER_MSG.name(), msgIdPrime);
+    Assert.assertTrue(TestHelper.verify(() -> {
       ZNRecord statusUpdate = helixDataAccessor.getProperty(controllerTaskStatus).getRecord();
-      if (statusUpdate.getMapFields().containsKey("Summary")) {
-        break;
+      if (!statusUpdate.getMapFields().containsKey("Summary")) {
+        return false;
       }
-    }
-    int count = 0;
-    for (Set<String> val : _factory._results.values()) {
-      count += val.size();
-    }
-    Assert.assertEquals(count, _PARTITIONS * 3 * 2);
+      int count = 0;
+      for (Set<String> val : _factory._results.values()) {
+        count += val.size();
+      }
+      return count == _PARTITIONS * 3 * 2;
+    }, TestHelper.WAIT_DURATION));
   }
 }