Close zkClients created by TaskStateModelFactory (#1678)

This PR closes the previously unclosed zkClients in TaskStateModelFactory, also added retry timeout and log to the logic.
diff --git a/helix-core/src/main/java/org/apache/helix/task/TaskStateModelFactory.java b/helix-core/src/main/java/org/apache/helix/task/TaskStateModelFactory.java
index aded724..7b9dd74 100644
--- a/helix-core/src/main/java/org/apache/helix/task/TaskStateModelFactory.java
+++ b/helix-core/src/main/java/org/apache/helix/task/TaskStateModelFactory.java
@@ -19,6 +19,7 @@
  * under the License.
  */
 
+import java.time.Duration;
 import java.util.Map;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
@@ -51,6 +52,9 @@
 public class TaskStateModelFactory extends StateModelFactory<TaskStateModel> {
   private static Logger LOG = LoggerFactory.getLogger(TaskStateModelFactory.class);
 
+  // Unit in minutes. Need a retry timeout to prevent zkClient from hanging infinitely.
+  private static final int ZKCLIENT_OPERATION_RETRY_TIMEOUT = 5;
+
   private final HelixManager _manager;
   private final Map<String, TaskFactory> _taskFactoryRegistry;
   private final ScheduledExecutorService _taskExecutor;
@@ -58,16 +62,7 @@
   private ThreadPoolExecutorMonitor _monitor;
 
   public TaskStateModelFactory(HelixManager manager, Map<String, TaskFactory> taskFactoryRegistry) {
-    this(manager, taskFactoryRegistry, Executors.newScheduledThreadPool(TaskUtil
-        .getTargetThreadPoolSize(createZkClient(manager), manager.getClusterName(),
-            manager.getInstanceName()), new ThreadFactory() {
-      private AtomicInteger threadId = new AtomicInteger(0);
-
-      @Override
-      public Thread newThread(Runnable r) {
-        return new Thread(r, "TaskStateModelFactory-task_thread-" + threadId.getAndIncrement());
-      }
-    }));
+    this(manager, taskFactoryRegistry, createThreadPoolExecutor(manager));
   }
 
   // DO NOT USE! This size of provided thread pool will not be reflected to controller
@@ -134,12 +129,6 @@
    * Create a RealmAwareZkClient to get thread pool sizes
    */
   protected static RealmAwareZkClient createZkClient(HelixManager manager) {
-    // TODO: revisit the logic here - we are creating a connection although we already have a
-    // manager. We cannot use the connection within manager because some users connect the manager
-    // after registering the state model factory (in which case we cannot use manager's connection),
-    // and some connect the manager before registering the state model factory (in which case we
-    // can use manager's connection). We need to think about the right order and determine if we
-    // want to enforce it, which may cause backward incompatibility.
     if (!(manager instanceof ZKHelixManager)) {
       // TODO: None-ZKHelixManager cannot initialize this class. After interface rework of
       // HelixManager, the initialization should be allowed.
@@ -148,6 +137,9 @@
     }
     RealmAwareZkClient.RealmAwareZkClientConfig clientConfig =
         new RealmAwareZkClient.RealmAwareZkClientConfig().setZkSerializer(new ZNRecordSerializer());
+    // Set operation retry timeout to prevent hanging infinitely
+    clientConfig
+        .setOperationRetryTimeout(Duration.ofMinutes(ZKCLIENT_OPERATION_RETRY_TIMEOUT).toMillis());
     String zkAddress = manager.getMetadataStoreConnectionString();
 
     if (Boolean.getBoolean(SystemPropertyKeys.MULTI_ZK_ENABLED) || zkAddress == null) {
@@ -172,8 +164,40 @@
       }
     }
 
-    return SharedZkClientFactory.getInstance().buildZkClient(
-        new HelixZkClient.ZkConnectionConfig(zkAddress),
-        clientConfig.createHelixZkClientConfig().setZkSerializer(new ZNRecordSerializer()));
+    // Note: operation retry timeout doesn't take effect due to github.com/apache/helix/issues/1682
+    return SharedZkClientFactory.getInstance()
+        .buildZkClient(new HelixZkClient.ZkConnectionConfig(zkAddress),
+            clientConfig.createHelixZkClientConfig());
+  }
+
+  private static ScheduledExecutorService createThreadPoolExecutor(HelixManager manager) {
+    // TODO: revisit the logic here - we are creating a connection although we already have a
+    // manager. We cannot use the connection within manager because some users connect the manager
+    // after registering the state model factory (in which case we cannot use manager's connection),
+    // and some connect the manager before registering the state model factory (in which case we
+    // can use manager's connection). We need to think about the right order and determine if we
+    // want to enforce it, which may cause backward incompatibility.
+    RealmAwareZkClient zkClient = createZkClient(manager);
+    int targetThreadPoolSize;
+
+    // Ensure the zkClient is closed after reading the pool size;
+    try {
+      targetThreadPoolSize = TaskUtil
+          .getTargetThreadPoolSize(zkClient, manager.getClusterName(), manager.getInstanceName());
+    } finally {
+      zkClient.close();
+    }
+
+    LOG.info(
+        "Obtained target thread pool size: {} from cluster {} for instance {}. Creating thread pool.",
+        targetThreadPoolSize, manager.getClusterName(), manager.getInstanceName());
+    return Executors.newScheduledThreadPool(targetThreadPoolSize, new ThreadFactory() {
+      private AtomicInteger threadId = new AtomicInteger(0);
+
+      @Override
+      public Thread newThread(Runnable r) {
+        return new Thread(r, "TaskStateModelFactory-task_thread-" + threadId.getAndIncrement());
+      }
+    });
   }
 }
diff --git a/helix-core/src/test/java/org/apache/helix/task/TestTaskStateModelFactory.java b/helix-core/src/test/java/org/apache/helix/task/TestTaskStateModelFactory.java
index bb64989..38286b4 100644
--- a/helix-core/src/test/java/org/apache/helix/task/TestTaskStateModelFactory.java
+++ b/helix-core/src/test/java/org/apache/helix/task/TestTaskStateModelFactory.java
@@ -24,6 +24,7 @@
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.helix.HelixManager;
 import org.apache.helix.SystemPropertyKeys;
 import org.apache.helix.integration.manager.MockParticipantManager;
 import org.apache.helix.integration.task.TaskTestBase;
@@ -35,6 +36,7 @@
 import org.apache.helix.zookeeper.api.client.RealmAwareZkClient;
 import org.apache.helix.zookeeper.constant.RoutingDataReaderType;
 import org.apache.helix.zookeeper.impl.client.FederatedZkClient;
+import org.apache.helix.zookeeper.impl.factory.SharedZkClientFactory;
 import org.apache.helix.zookeeper.routing.RoutingDataManager;
 import org.mockito.Mockito;
 import org.testng.Assert;
@@ -90,29 +92,20 @@
         testMSDSServerEndpointKey);
 
     RoutingDataManager.getInstance().reset();
-    RealmAwareZkClient zkClient = TaskStateModelFactory.createZkClient(anyParticipantManager);
-    Assert.assertEquals(TaskUtil
-        .getTargetThreadPoolSize(zkClient, anyParticipantManager.getClusterName(),
-            anyParticipantManager.getInstanceName()), TEST_TARGET_TASK_THREAD_POOL_SIZE);
-    Assert.assertTrue(zkClient instanceof FederatedZkClient);
+    verifyThreadPoolSizeAndZkClientClass(anyParticipantManager, TEST_TARGET_TASK_THREAD_POOL_SIZE,
+        FederatedZkClient.class);
 
     // Turn off multiZk mode in System config, and remove zkAddress
     System.setProperty(SystemPropertyKeys.MULTI_ZK_ENABLED, "false");
     ZKHelixManager participantManager = Mockito.spy(anyParticipantManager);
     when(participantManager.getMetadataStoreConnectionString()).thenReturn(null);
-    zkClient = TaskStateModelFactory.createZkClient(participantManager);
-    Assert.assertEquals(TaskUtil
-        .getTargetThreadPoolSize(zkClient, anyParticipantManager.getClusterName(),
-            anyParticipantManager.getInstanceName()), TEST_TARGET_TASK_THREAD_POOL_SIZE);
-    Assert.assertTrue(zkClient instanceof FederatedZkClient);
+    verifyThreadPoolSizeAndZkClientClass(participantManager, TEST_TARGET_TASK_THREAD_POOL_SIZE,
+        FederatedZkClient.class);
 
     // Test no connection config case
     when(participantManager.getRealmAwareZkConnectionConfig()).thenReturn(null);
-    zkClient = TaskStateModelFactory.createZkClient(participantManager);
-    Assert.assertEquals(TaskUtil
-        .getTargetThreadPoolSize(zkClient, anyParticipantManager.getClusterName(),
-            anyParticipantManager.getInstanceName()), TEST_TARGET_TASK_THREAD_POOL_SIZE);
-    Assert.assertTrue(zkClient instanceof FederatedZkClient);
+    verifyThreadPoolSizeAndZkClientClass(participantManager, TEST_TARGET_TASK_THREAD_POOL_SIZE,
+        FederatedZkClient.class);
 
     // Remove server endpoint key and use connection config to specify endpoint
     System.clearProperty(SystemPropertyKeys.MSDS_SERVER_ENDPOINT_KEY);
@@ -122,11 +115,8 @@
             .setRoutingDataSourceEndpoint(testMSDSServerEndpointKey)
             .setRoutingDataSourceType(RoutingDataReaderType.HTTP.name()).build();
     when(participantManager.getRealmAwareZkConnectionConfig()).thenReturn(connectionConfig);
-    zkClient = TaskStateModelFactory.createZkClient(participantManager);
-    Assert.assertEquals(TaskUtil
-        .getTargetThreadPoolSize(zkClient, anyParticipantManager.getClusterName(),
-            anyParticipantManager.getInstanceName()), TEST_TARGET_TASK_THREAD_POOL_SIZE);
-    Assert.assertTrue(zkClient instanceof FederatedZkClient);
+    verifyThreadPoolSizeAndZkClientClass(participantManager, TEST_TARGET_TASK_THREAD_POOL_SIZE,
+        FederatedZkClient.class);
 
     // Restore system properties
     if (prevMultiZkEnabled == null) {
@@ -151,10 +141,8 @@
     // Turn off multiZk mode in System config
     System.setProperty(SystemPropertyKeys.MULTI_ZK_ENABLED, "false");
 
-    RealmAwareZkClient zkClient = TaskStateModelFactory.createZkClient(anyParticipantManager);
-    Assert.assertEquals(TaskUtil
-        .getTargetThreadPoolSize(zkClient, anyParticipantManager.getClusterName(),
-            anyParticipantManager.getInstanceName()), TEST_TARGET_TASK_THREAD_POOL_SIZE);
+    verifyThreadPoolSizeAndZkClientClass(anyParticipantManager, TEST_TARGET_TASK_THREAD_POOL_SIZE,
+        SharedZkClientFactory.InnerSharedZkClient.class);
 
     // Restore system properties
     if (prevMultiZkEnabled == null) {
@@ -169,4 +157,16 @@
   public void testZkClientCreationNonZKManager() {
     TaskStateModelFactory.createZkClient(new MockManager());
   }
+
+  private void verifyThreadPoolSizeAndZkClientClass(HelixManager helixManager, int threadPoolSize,
+      Class<?> zkClientClass) {
+    RealmAwareZkClient zkClient = TaskStateModelFactory.createZkClient(helixManager);
+    try {
+      Assert.assertEquals(TaskUtil.getTargetThreadPoolSize(zkClient, helixManager.getClusterName(),
+          helixManager.getInstanceName()), threadPoolSize);
+      Assert.assertEquals(zkClient.getClass(), zkClientClass);
+    } finally {
+      zkClient.close();
+    }
+  }
 }