Optimize HelixTaskExecutor reset() in event of shutdown (#2183)

Optimize HelixTaskExecutor reset() in event of shutdown

Some instances maybe reset() multiple times during participant shutdown.
This commit refactor logic in HelixTaskExecutor to reduce unnecessary
method call.
diff --git a/helix-core/src/main/java/org/apache/helix/messaging/handling/HelixTaskExecutor.java b/helix-core/src/main/java/org/apache/helix/messaging/handling/HelixTaskExecutor.java
index 80b70bb..b21a065 100644
--- a/helix-core/src/main/java/org/apache/helix/messaging/handling/HelixTaskExecutor.java
+++ b/helix-core/src/main/java/org/apache/helix/messaging/handling/HelixTaskExecutor.java
@@ -28,6 +28,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.Timer;
 import java.util.TimerTask;
@@ -136,6 +137,8 @@
   private final ParticipantStatusMonitor _monitor;
   public static final String MAX_THREADS = "maxThreads";
 
+  // true if all partition state are "clean" as same after reset()
+  private volatile boolean _isCleanState = true;
   private MessageQueueMonitor _messageQueueMonitor;
   private GenericHelixController _controller;
   private Long _lastSessionSyncTime;
@@ -677,13 +680,10 @@
     }
   }
 
-  void reset() {
-    LOG.info("Reset HelixTaskExecutor");
-
-    if (_messageQueueMonitor != null) {
-      _messageQueueMonitor.reset();
-    }
-
+  /**
+   * Shutdown the registered thread pool executors. This method will be no-op if called repeatedly.
+   */
+  private void shutdownExecutors() {
     synchronized (_hdlrFtyRegistry) {
       for (String msgType : _hdlrFtyRegistry.keySet()) {
         // don't un-register factories, just shutdown all executors
@@ -694,17 +694,37 @@
           LOG.info("Reset executor for msgType: " + msgType + ", pool: " + pool);
           shutdownAndAwaitTermination(pool, item);
         }
-
-        if (item.factory() != null) {
-          try {
-            item.factory().reset();
-          } catch (Exception ex) {
-            LOG.error("Failed to reset the factory {} of message type {}.", item.factory().toString(),
-                msgType, ex);
-          }
-        }
       }
     }
+  }
+
+  synchronized void reset() {
+    if (_isCleanState) {
+      LOG.info("HelixTaskExecutor is in clean state, no need to reset again");
+      return;
+    }
+    LOG.info("Reset HelixTaskExecutor");
+
+    if (_messageQueueMonitor != null) {
+      _messageQueueMonitor.reset();
+    }
+
+    shutdownExecutors();
+
+    synchronized (_hdlrFtyRegistry) {
+      _hdlrFtyRegistry.values()
+          .stream()
+          .map(MsgHandlerFactoryRegistryItem::factory)
+          .distinct()
+          .filter(Objects::nonNull)
+          .forEach(factory -> {
+            try {
+              factory.reset();
+            } catch (Exception ex) {
+              LOG.error("Failed to reset the factory {}.", factory.toString(), ex);
+            }
+          });
+    }
     // threads pool specific to STATE_TRANSITION.Key specific pool are not shut down.
     // this is a potential area to improve. https://github.com/apache/helix/issues/1245
 
@@ -712,8 +732,7 @@
     // Log all tasks that fail to terminate
     for (String taskId : _taskMap.keySet()) {
       MessageTaskInfo info = _taskMap.get(taskId);
-      sb.append(
-          "Task: " + taskId + " fails to terminate. Message: " + info._task.getMessage() + "\n");
+      sb.append("Task: " + taskId + " fails to terminate. Message: " + info._task.getMessage() + "\n");
     }
 
     LOG.info(sb.toString());
@@ -724,6 +743,7 @@
     _knownMessageIds.clear();
 
     _lastSessionSyncTime = null;
+    _isCleanState = true;
   }
 
   void init() {
@@ -744,7 +764,7 @@
         _monitor.createExecutorMonitor(type, newPool);
         return newPool;
       });
-      LOG.info("Setup the thread pool for type: %s, isShutdown: %s", msgType, pool.isShutdown());
+      LOG.info("Setup the thread pool for type: {}, isShutdown: {}", msgType, pool.isShutdown());
     }
   }
 
@@ -835,6 +855,7 @@
       init();
       // continue to process messages
     }
+    _isCleanState = false;
 
     // if prefetch is disabled in MessageListenerCallback, we need to read all new messages from zk.
     if (messages == null || messages.isEmpty()) {
@@ -1442,7 +1463,7 @@
       nopMsg.setTgtName(instanceName);
       accessor
           .setProperty(accessor.keyBuilder().message(nopMsg.getTgtName(), nopMsg.getId()), nopMsg);
-      LOG.info("Send NO_OP message to {}}, msgId: {}.", nopMsg.getTgtName(), nopMsg.getId());
+      LOG.info("Send NO_OP message to {}, msgId: {}.", nopMsg.getTgtName(), nopMsg.getId());
     } catch (Exception e) {
       LOG.error("Failed to send NO_OP message to {}.", instanceName, e);
     }
@@ -1454,6 +1475,7 @@
     _isShuttingDown = true;
     _timer.cancel();
 
+    shutdownExecutors();
     reset();
     _monitor.shutDown();
     LOG.info("Shutdown HelixTaskExecutor finished");
diff --git a/helix-core/src/main/java/org/apache/helix/participant/HelixStateMachineEngine.java b/helix-core/src/main/java/org/apache/helix/participant/HelixStateMachineEngine.java
index 50a0782..1a16de4 100644
--- a/helix-core/src/main/java/org/apache/helix/participant/HelixStateMachineEngine.java
+++ b/helix-core/src/main/java/org/apache/helix/participant/HelixStateMachineEngine.java
@@ -164,6 +164,7 @@
   }
 
   private void loopStateModelFactories(Consumer<StateModel> consumer) {
+    // TODO: evaluate impact and consider parallelization
     for (Map<String, StateModelFactory<? extends StateModel>> ftyMap : _stateModelFactoryMap
         .values()) {
       for (StateModelFactory<? extends StateModel> stateModelFactory : ftyMap.values()) {
diff --git a/helix-core/src/test/java/org/apache/helix/messaging/handling/TestHelixTaskExecutor.java b/helix-core/src/test/java/org/apache/helix/messaging/handling/TestHelixTaskExecutor.java
index 7b23990..1a8cc5a 100644
--- a/helix-core/src/test/java/org/apache/helix/messaging/handling/TestHelixTaskExecutor.java
+++ b/helix-core/src/test/java/org/apache/helix/messaging/handling/TestHelixTaskExecutor.java
@@ -132,6 +132,21 @@
     }
   }
 
+  private class TestMessageHandlerFactory3 extends TestMessageHandlerFactory {
+    private boolean _resetDone = false;
+
+    @Override
+    public List<String> getMessageTypes() {
+      return ImmutableList.of("msgType1", "msgType2", "msgType3");
+    }
+
+    @Override
+    public void reset() {
+      Assert.assertFalse(_resetDone, "reset() should only be triggered once in TestMessageHandlerFactory3");
+      _resetDone = true;
+    }
+  }
+
   class CancellableHandlerFactory implements MultiTypeMessageHandlerFactory {
 
     int _handlersCreated = 0;
@@ -798,6 +813,29 @@
     System.out.println("END TestCMTaskExecutor.testHandlerResetTimeout()");
   }
 
+  @Test
+  public void testMsgHandlerRegistryAndShutdown() {
+    HelixTaskExecutor executor = new HelixTaskExecutor();
+    HelixManager manager = new MockClusterManager();
+    TestMessageHandlerFactory factory = new TestMessageHandlerFactory();
+    TestMessageHandlerFactory3 factoryMulti = new TestMessageHandlerFactory3();
+    executor.registerMessageHandlerFactory(factory, HelixTaskExecutor.DEFAULT_PARALLEL_TASKS, 200);
+    executor.registerMessageHandlerFactory(factoryMulti, HelixTaskExecutor.DEFAULT_PARALLEL_TASKS, 200);
+
+    final Message msg = new Message(factory.getMessageTypes().get(0), UUID.randomUUID().toString());
+    msg.setTgtSessionId("*");
+    msg.setTgtName("Localhost_1123");
+    msg.setSrcName("127.101.1.23_2234");
+
+    NotificationContext changeContext = new NotificationContext(manager);
+    changeContext.setChangeType(HelixConstants.ChangeType.MESSAGE);
+    executor.onMessage("some", Collections.singletonList(msg), changeContext);
+    Assert.assertEquals(executor._hdlrFtyRegistry.size(), 4);
+    // Ensure TestMessageHandlerFactory3 instance is reset and reset exactly once
+    executor.shutdown();
+    Assert.assertTrue(factoryMulti._resetDone, "TestMessageHandlerFactory3 should be reset");
+  }
+
   @Test()
   public void testNoRetry() throws InterruptedException {
     System.out.println("START " + TestHelper.getTestMethodName());