Revert "Revert "Add async call retry to resolve the transient ZK connection issue. (#970)""

This reverts commit 370e277966f75a7fba45f5b96f7608c127b2905c.
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/ZkClient.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/ZkClient.java
index 562143f..89f9e32 100644
--- a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/ZkClient.java
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/ZkClient.java
@@ -27,7 +27,10 @@
 import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.helix.zookeeper.exception.ZkClientException;
 import org.apache.helix.zookeeper.zkclient.annotation.PreFetch;
+import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncCallMonitorContext;
 import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncCallbacks;
+import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncRetryCallContext;
+import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncRetryThread;
 import org.apache.helix.zookeeper.zkclient.exception.ZkBadVersionException;
 import org.apache.helix.zookeeper.zkclient.exception.ZkException;
 import org.apache.helix.zookeeper.zkclient.exception.ZkInterruptedException;
@@ -96,6 +99,10 @@
   private PathBasedZkSerializer _pathBasedZkSerializer;
   private ZkClientMonitor _monitor;
 
+  // To automatically retry the async operation, we need a separate thread other than the
+  // ZkEventThread. Otherwise the retry request might block the normal event processing.
+  protected final ZkAsyncRetryThread _asyncCallRetryThread;
+
   private class IZkDataListenerEntry {
     final IZkDataListener _dataListener;
     final boolean _prefetchData;
@@ -183,6 +190,9 @@
     _operationRetryTimeoutInMillis = operationRetryTimeout;
     _isNewSessionEventFired = false;
 
+    _asyncCallRetryThread = new ZkAsyncRetryThread(zkConnection.getServers());
+    _asyncCallRetryThread.start();
+
     connect(connectionTimeout, this);
 
     // initiate monitor
@@ -1736,15 +1746,23 @@
       data = (datat == null ? null : serialize(datat, path));
     } catch (ZkMarshallingError e) {
       cb.processResult(KeeperException.Code.MARSHALLINGERROR.intValue(), path,
-          new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT, 0, false), null);
+          new ZkAsyncCallMonitorContext(_monitor, startT, 0, false), null);
       return;
     }
+    doAsyncCreate(path, data, mode, startT, cb);
+  }
+
+  private void doAsyncCreate(final String path, final byte[] data, final CreateMode mode,
+      final long startT, final ZkAsyncCallbacks.CreateCallbackHandler cb) {
     retryUntilConnected(() -> {
       ((ZkConnection) getConnection()).getZookeeper()
-          .create(path, data, ZooDefs.Ids.OPEN_ACL_UNSAFE,
-              // Arrays.asList(DEFAULT_ACL),
-              mode, cb, new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT,
-                  data == null ? 0 : data.length, false));
+          .create(path, data, ZooDefs.Ids.OPEN_ACL_UNSAFE, mode, cb,
+              new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, _monitor, startT, 0, false) {
+                @Override
+                protected void doRetry() {
+                  doAsyncCreate(path, data, mode, System.currentTimeMillis(), cb);
+                }
+              });
       return null;
     });
   }
@@ -1758,50 +1776,66 @@
       data = serialize(datat, path);
     } catch (ZkMarshallingError e) {
       cb.processResult(KeeperException.Code.MARSHALLINGERROR.intValue(), path,
-          new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT, 0, false), null);
+          new ZkAsyncCallMonitorContext(_monitor, startT, 0, false), null);
       return;
     }
+    doAsyncSetData(path, data, version, startT, cb);
+  }
+
+  private void doAsyncSetData(final String path, byte[] data, final int version, final long startT,
+      final ZkAsyncCallbacks.SetDataCallbackHandler cb) {
     retryUntilConnected(() -> {
       ((ZkConnection) getConnection()).getZookeeper().setData(path, data, version, cb,
-          new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT,
-              data == null ? 0 : data.length, false));
+          new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, _monitor, startT,
+              data == null ? 0 : data.length, false) {
+            @Override
+            protected void doRetry() {
+              doAsyncSetData(path, data, version, System.currentTimeMillis(), cb);
+            }
+          });
       return null;
     });
   }
 
   public void asyncGetData(final String path, final ZkAsyncCallbacks.GetDataCallbackHandler cb) {
     final long startT = System.currentTimeMillis();
-    retryUntilConnected(new Callable<Object>() {
-      @Override
-      public Object call() throws Exception {
-        ((ZkConnection) getConnection()).getZookeeper().getData(path, null, cb,
-            new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT, 0, true));
-        return null;
-      }
+    retryUntilConnected(() -> {
+      ((ZkConnection) getConnection()).getZookeeper().getData(path, null, cb,
+          new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, _monitor, startT, 0, true) {
+            @Override
+            protected void doRetry() {
+              asyncGetData(path, cb);
+            }
+          });
+      return null;
     });
   }
 
   public void asyncExists(final String path, final ZkAsyncCallbacks.ExistsCallbackHandler cb) {
     final long startT = System.currentTimeMillis();
-    retryUntilConnected(new Callable<Object>() {
-      @Override
-      public Object call() throws Exception {
-        ((ZkConnection) getConnection()).getZookeeper().exists(path, null, cb,
-            new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT, 0, true));
-        return null;
-      }
+    retryUntilConnected(() -> {
+      ((ZkConnection) getConnection()).getZookeeper().exists(path, null, cb,
+          new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, _monitor, startT, 0, true) {
+            @Override
+            protected void doRetry() {
+              asyncExists(path, cb);
+            }
+          });
+      return null;
     });
   }
 
   public void asyncDelete(final String path, final ZkAsyncCallbacks.DeleteCallbackHandler cb) {
     final long startT = System.currentTimeMillis();
-    retryUntilConnected(new Callable<Object>() {
-      @Override
-      public Object call() throws Exception {
-        ((ZkConnection) getConnection()).getZookeeper().delete(path, -1, cb,
-            new ZkAsyncCallbacks.ZkAsyncCallContext(_monitor, startT, 0, false));
-        return null;
-      }
+    retryUntilConnected(() -> {
+      ((ZkConnection) getConnection()).getZookeeper().delete(path, -1, cb,
+          new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, _monitor, startT, 0, false) {
+            @Override
+            protected void doRetry() {
+              asyncDelete(path, cb);
+            }
+          });
+      return null;
     });
   }
 
@@ -1955,6 +1989,10 @@
         return;
       }
       setShutdownTrigger(true);
+      if (_asyncCallRetryThread != null) {
+        _asyncCallRetryThread.interrupt();
+        _asyncCallRetryThread.join(2000);
+      }
       _eventThread.interrupt();
       _eventThread.join(2000);
       if (isManagingZkConnection()) {
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/CancellableZkAsyncCallback.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/CancellableZkAsyncCallback.java
new file mode 100644
index 0000000..27d92e8
--- /dev/null
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/CancellableZkAsyncCallback.java
@@ -0,0 +1,8 @@
+package org.apache.helix.zookeeper.zkclient.callback;
+
+public interface CancellableZkAsyncCallback {
+  /**
+   * Notify all the callers that are waiting for the callback to cancel the wait.
+   */
+  void notifyCallers();
+}
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallMonitorContext.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallMonitorContext.java
new file mode 100644
index 0000000..bf2fd44
--- /dev/null
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallMonitorContext.java
@@ -0,0 +1,46 @@
+package org.apache.helix.zookeeper.zkclient.callback;
+
+import org.apache.helix.zookeeper.zkclient.metric.ZkClientMonitor;
+
+public class ZkAsyncCallMonitorContext {
+  private final long _startTimeMilliSec;
+  private final ZkClientMonitor _monitor;
+  private final boolean _isRead;
+  private int _bytes;
+
+  /**
+   * @param monitor           ZkClient monitor for update the operation result.
+   * @param startTimeMilliSec Operation initialization time.
+   * @param bytes             The data size in bytes that is involved in the operation.
+   * @param isRead            True if the operation is readonly.
+   */
+  public ZkAsyncCallMonitorContext(final ZkClientMonitor monitor, long startTimeMilliSec, int bytes,
+      boolean isRead) {
+    _monitor = monitor;
+    _startTimeMilliSec = startTimeMilliSec;
+    _bytes = bytes;
+    _isRead = isRead;
+  }
+
+  /**
+   * Update the operated data size in bytes.
+   * @param bytes
+   */
+  void setBytes(int bytes) {
+    _bytes = bytes;
+  }
+
+  /**
+   * Record the operation result into the specified ZkClient monitor.
+   * @param path
+   */
+  void recordAccess(String path) {
+    if (_monitor != null) {
+      if (_isRead) {
+        _monitor.record(path, _bytes, _startTimeMilliSec, ZkClientMonitor.AccessType.READ);
+      } else {
+        _monitor.record(path, _bytes, _startTimeMilliSec, ZkClientMonitor.AccessType.WRITE);
+      }
+    }
+  }
+}
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallbacks.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallbacks.java
index 04c4058..70dbab4 100644
--- a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallbacks.java
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncCallbacks.java
@@ -31,41 +31,35 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 public class ZkAsyncCallbacks {
   private static Logger LOG = LoggerFactory.getLogger(ZkAsyncCallbacks.class);
+  public static final int UNKNOWN_RET_CODE = 255;
 
   public static class GetDataCallbackHandler extends DefaultCallback implements DataCallback {
     public byte[] _data;
     public Stat _stat;
 
     @Override
-    public void handle() {
-      // TODO Auto-generated method stub
-    }
-
-    @Override
     public void processResult(int rc, String path, Object ctx, byte[] data, Stat stat) {
       if (rc == 0) {
         _data = data;
         _stat = stat;
         // update ctx with data size
-        if (_data != null && ctx != null && ctx instanceof ZkAsyncCallContext) {
-          ZkAsyncCallContext zkCtx = (ZkAsyncCallContext) ctx;
-          zkCtx._bytes = _data.length;
+        if (_data != null && ctx != null && ctx instanceof ZkAsyncCallMonitorContext) {
+          ((ZkAsyncCallMonitorContext) ctx).setBytes(_data.length);
         }
       }
       callback(rc, path, ctx);
     }
-  }
-
-  public static class SetDataCallbackHandler extends DefaultCallback implements StatCallback {
-    Stat _stat;
 
     @Override
     public void handle() {
       // TODO Auto-generated method stub
     }
+  }
+
+  public static class SetDataCallbackHandler extends DefaultCallback implements StatCallback {
+    Stat _stat;
 
     @Override
     public void processResult(int rc, String path, Object ctx, Stat stat) {
@@ -78,15 +72,15 @@
     public Stat getStat() {
       return _stat;
     }
-  }
-
-  public static class ExistsCallbackHandler extends DefaultCallback implements StatCallback {
-    public Stat _stat;
 
     @Override
     public void handle() {
       // TODO Auto-generated method stub
     }
+  }
+
+  public static class ExistsCallbackHandler extends DefaultCallback implements StatCallback {
+    public Stat _stat;
 
     @Override
     public void processResult(int rc, String path, Object ctx, Stat stat) {
@@ -95,6 +89,11 @@
       }
       callback(rc, path, ctx);
     }
+
+    @Override
+    public void handle() {
+      // TODO Auto-generated method stub
+    }
   }
 
   public static class CreateCallbackHandler extends DefaultCallback implements StringCallback {
@@ -122,44 +121,66 @@
   }
 
   /**
-   * Default callback for zookeeper async api
+   * Default callback for zookeeper async api.
    */
-  public static abstract class DefaultCallback {
-    AtomicBoolean _lock = new AtomicBoolean(false);
-    int _rc = -1;
+  public static abstract class DefaultCallback implements CancellableZkAsyncCallback {
+    AtomicBoolean _isOperationDone = new AtomicBoolean(false);
+    int _rc = UNKNOWN_RET_CODE;
 
     public void callback(int rc, String path, Object ctx) {
       if (rc != 0 && LOG.isDebugEnabled()) {
         LOG.debug(this + ", rc:" + Code.get(rc) + ", path: " + path);
       }
 
-      if (ctx != null && ctx instanceof ZkAsyncCallContext) {
-        ZkAsyncCallContext zkCtx = (ZkAsyncCallContext) ctx;
-        if (zkCtx._monitor != null) {
-          if (zkCtx._isRead) {
-            zkCtx._monitor.record(path, zkCtx._bytes, zkCtx._startTimeMilliSec,
-                ZkClientMonitor.AccessType.READ);
-          } else {
-            zkCtx._monitor.record(path, zkCtx._bytes, zkCtx._startTimeMilliSec,
-                ZkClientMonitor.AccessType.WRITE);
-          }
-        }
+      if (ctx != null && ctx instanceof ZkAsyncCallMonitorContext) {
+        ((ZkAsyncCallMonitorContext) ctx).recordAccess(path);
       }
 
       _rc = rc;
-      handle();
 
-      synchronized (_lock) {
-        _lock.set(true);
-        _lock.notify();
+      // If retry is requested by passing the retry callback context, do retry if necessary.
+      if (needRetry(rc)) {
+        if (ctx != null && ctx instanceof ZkAsyncRetryCallContext) {
+          try {
+            if (((ZkAsyncRetryCallContext) ctx).requestRetry()) {
+              // The retry operation will be done asynchronously. Once it is done, the same callback
+              // handler object shall be triggered to ensure the result is notified to the right
+              // caller(s).
+              return;
+            } else {
+              LOG.warn(
+                  "Cannot request to retry the operation. The retry request thread may have been stopped.");
+            }
+          } catch (Throwable t) {
+            LOG.error("Failed to request to retry the operation.", t);
+          }
+        } else {
+          LOG.warn(
+              "The provided callback context {} is not ZkAsyncRetryCallContext. Skip retrying.",
+              ctx.getClass().getName());
+        }
+      }
+
+      // If operation is done successfully or no retry needed, notify the caller(s).
+      try {
+        handle();
+      } finally {
+        markOperationDone();
       }
     }
 
+    public boolean isOperationDone() {
+      return _isOperationDone.get();
+    }
+
+    /**
+     * The blocking call that return true once the operation has been completed without retrying.
+     */
     public boolean waitForSuccess() {
       try {
-        synchronized (_lock) {
-          while (!_lock.get()) {
-            _lock.wait();
+        synchronized (_isOperationDone) {
+          while (!_isOperationDone.get()) {
+            _isOperationDone.wait();
           }
         }
       } catch (InterruptedException e) {
@@ -172,22 +193,52 @@
       return _rc;
     }
 
+    @Override
+    public void notifyCallers() {
+      LOG.warn("The callback {} has been cancelled.", this);
+      markOperationDone();
+    }
+
+    /**
+     * Additional callback handling.
+     */
     abstract public void handle();
-  }
 
-  public static class ZkAsyncCallContext {
-    private long _startTimeMilliSec;
-    private int _bytes;
-    private ZkClientMonitor _monitor;
-    private boolean _isRead;
+    private void markOperationDone() {
+      synchronized (_isOperationDone) {
+        _isOperationDone.set(true);
+        _isOperationDone.notifyAll();
+      }
+    }
 
-    public ZkAsyncCallContext(final ZkClientMonitor monitor, long startTimeMilliSec, int bytes,
-        boolean isRead) {
-      _monitor = monitor;
-      _startTimeMilliSec = startTimeMilliSec;
-      _bytes = bytes;
-      _isRead = isRead;
+    /**
+     * @param rc the return code
+     * @return true if the error is transient and the operation may succeed when being retried.
+     */
+    private boolean needRetry(int rc) {
+      try {
+        switch (Code.get(rc)) {
+        /** Connection to the server has been lost */
+        case CONNECTIONLOSS:
+          /** The session has been expired by the server */
+        case SESSIONEXPIRED:
+          /** Session moved to another server, so operation is ignored */
+        case SESSIONMOVED:
+          return true;
+        default:
+          return false;
+        }
+      } catch (ClassCastException | NullPointerException ex) {
+        LOG.error("Failed to handle unknown return code {}. Skip retrying.", rc, ex);
+        return false;
+      }
     }
   }
 
+  @Deprecated
+  public static class ZkAsyncCallContext extends ZkAsyncCallMonitorContext {
+    ZkAsyncCallContext(ZkClientMonitor monitor, long startTimeMilliSec, int bytes, boolean isRead) {
+      super(monitor, startTimeMilliSec, bytes, isRead);
+    }
+  }
 }
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryCallContext.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryCallContext.java
new file mode 100644
index 0000000..4a9402f
--- /dev/null
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryCallContext.java
@@ -0,0 +1,49 @@
+package org.apache.helix.zookeeper.zkclient.callback;
+
+import org.apache.helix.zookeeper.zkclient.metric.ZkClientMonitor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class ZkAsyncRetryCallContext extends ZkAsyncCallMonitorContext {
+  private static Logger LOG = LoggerFactory.getLogger(ZkAsyncRetryCallContext.class);
+  private final ZkAsyncRetryThread _retryThread;
+  private final CancellableZkAsyncCallback _cancellableCallback;
+
+  /**
+   * @param retryThread       The thread that executes the retry operation.
+   *                          Note that retry in the ZkEventThread is not allowed to avoid dead lock.
+   * @param callback          Cancellable asynchronous callback to notify when the retry is cancelled.
+   * @param monitor           ZkClient monitor for update the operation result.
+   * @param startTimeMilliSec Operation initialization time.
+   * @param bytes             The data size in bytes that is involved in the operation.
+   * @param isRead            True if the operation is readonly.
+   */
+  public ZkAsyncRetryCallContext(final ZkAsyncRetryThread retryThread,
+      final CancellableZkAsyncCallback callback, final ZkClientMonitor monitor,
+      long startTimeMilliSec, int bytes, boolean isRead) {
+    super(monitor, startTimeMilliSec, bytes, isRead);
+    _retryThread = retryThread;
+    _cancellableCallback = callback;
+  }
+
+  /**
+   * Request a retry.
+   *
+   * @return True if the request was sent successfully.
+   */
+  boolean requestRetry() {
+    return _retryThread.sendRetryRequest(this);
+  }
+
+  /**
+   * Notify the pending callback that retry has been cancelled.
+   */
+  void cancel() {
+    _cancellableCallback.notifyCallers();
+  }
+
+  /**
+   * The actual retry operation logic.
+   */
+  protected abstract void doRetry() throws Exception;
+}
diff --git a/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryThread.java b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryThread.java
new file mode 100644
index 0000000..c59d423
--- /dev/null
+++ b/zookeeper-api/src/main/java/org/apache/helix/zookeeper/zkclient/callback/ZkAsyncRetryThread.java
@@ -0,0 +1,57 @@
+package org.apache.helix.zookeeper.zkclient.callback;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.apache.helix.zookeeper.zkclient.exception.ZkInterruptedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ZkAsyncRetryThread extends Thread {
+  private static Logger LOG = LoggerFactory.getLogger(ZkAsyncRetryThread.class);
+  private BlockingQueue<ZkAsyncRetryCallContext> _retryContexts = new LinkedBlockingQueue<>();
+  private volatile boolean _isReady = true;
+
+  public ZkAsyncRetryThread(String name) {
+    setDaemon(true);
+    setName("ZkClient-AsyncCallback-Retry-" + getId() + "-" + name);
+  }
+
+  @Override
+  public void run() {
+    LOG.info("Starting ZkClient AsyncCallback retry thread.");
+    try {
+      while (!isInterrupted()) {
+        ZkAsyncRetryCallContext context = _retryContexts.take();
+        try {
+          context.doRetry();
+        } catch (InterruptedException | ZkInterruptedException e) {
+          // if interrupted, stop retrying and interrupt the thread.
+          context.cancel();
+          interrupt();
+        } catch (Throwable e) {
+          LOG.error("Error retrying callback " + context, e);
+        }
+      }
+    } catch (InterruptedException e) {
+      LOG.info("ZkClient AsyncCallback retry thread is interrupted.");
+    }
+    synchronized (this) {
+      // Mark ready to be false, so no new requests will be sent.
+      _isReady = false;
+      // Notify to all the callers waiting for the result.
+      for (ZkAsyncRetryCallContext context : _retryContexts) {
+        context.cancel();
+      }
+    }
+    LOG.info("Terminate ZkClient AsyncCallback retry thread.");
+  }
+
+  synchronized boolean sendRetryRequest(ZkAsyncRetryCallContext context) {
+    if (_isReady) {
+      _retryContexts.add(context);
+      return true;
+    }
+    return false;
+  }
+}
diff --git a/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/ZkTestBase.java b/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/ZkTestBase.java
index 51eda80..2b8b1b3 100644
--- a/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/ZkTestBase.java
+++ b/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/ZkTestBase.java
@@ -58,7 +58,7 @@
    * Multiple ZK references
    */
   // The following maps hold ZK connect string as keys
-  protected final Map<String, ZkServer> _zkServerMap = new HashMap<>();
+  protected static final Map<String, ZkServer> _zkServerMap = new HashMap<>();
   protected static int _numZk = 1; // Initial value
 
   @BeforeSuite
diff --git a/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/client/TestZkClientAsyncRetry.java b/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/client/TestZkClientAsyncRetry.java
new file mode 100644
index 0000000..4e5b06f
--- /dev/null
+++ b/zookeeper-api/src/test/java/org/apache/helix/zookeeper/impl/client/TestZkClientAsyncRetry.java
@@ -0,0 +1,405 @@
+package org.apache.helix.zookeeper.impl.client;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import org.apache.helix.zookeeper.datamodel.ZNRecord;
+import org.apache.helix.zookeeper.datamodel.serializer.ZNRecordSerializer;
+import org.apache.helix.zookeeper.impl.ZkTestBase;
+import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncCallbacks;
+import org.apache.helix.zookeeper.zkclient.callback.ZkAsyncRetryCallContext;
+import org.apache.helix.zookeeper.zkclient.exception.ZkInterruptedException;
+import org.apache.zookeeper.CreateMode;
+import org.apache.zookeeper.KeeperException;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.apache.helix.zookeeper.zkclient.callback.ZkAsyncCallbacks.UNKNOWN_RET_CODE;
+import static org.apache.zookeeper.KeeperException.Code.CONNECTIONLOSS;
+
+/**
+ * Note this is a whitebox test to test the async operation callback/context.
+ * We don't have a good way to simulate an async ZK operation failure in the server side yet.
+ */
+public class TestZkClientAsyncRetry extends ZkTestBase {
+  private final String TEST_ROOT = String.format("/%s", getClass().getSimpleName());
+  private final String NODE_PATH = TEST_ROOT + "/async";
+
+  private org.apache.helix.zookeeper.zkclient.ZkClient _zkClient;
+  private String _zkServerAddress;
+
+  @BeforeClass
+  public void beforeClass() {
+    _zkClient = _zkServerMap.values().iterator().next().getZkClient();
+    _zkServerAddress = _zkClient.getServers();
+    _zkClient.createPersistent(TEST_ROOT);
+  }
+
+  @AfterClass
+  public void afterClass() {
+    _zkClient.deleteRecursively(TEST_ROOT);
+    _zkClient.close();
+  }
+
+  private boolean waitAsyncOperation(ZkAsyncCallbacks.DefaultCallback callback, long timeout) {
+    final boolean[] ret = { false };
+    Thread waitThread = new Thread(() -> ret[0] = callback.waitForSuccess());
+    waitThread.start();
+    try {
+      waitThread.join(timeout);
+      waitThread.interrupt();
+      return ret[0];
+    } catch (InterruptedException e) {
+      return false;
+    }
+  }
+
+  @Test
+  public void testAsyncRetryCategories() {
+    MockAsyncZkClient testZkClient = new MockAsyncZkClient(_zkServerAddress);
+    try {
+      ZNRecord tmpRecord = new ZNRecord("tmpRecord");
+      tmpRecord.setSimpleField("foo", "bar");
+      // Loop all possible error codes to test async create.
+      // Only connectivity issues will be retried, the other issues will be return error immediately.
+      for (KeeperException.Code code : KeeperException.Code.values()) {
+        if (code == KeeperException.Code.OK) {
+          continue;
+        }
+        ZkAsyncCallbacks.CreateCallbackHandler createCallback =
+            new ZkAsyncCallbacks.CreateCallbackHandler();
+        Assert.assertEquals(createCallback.getRc(), UNKNOWN_RET_CODE);
+        testZkClient.setAsyncCallRC(code.intValue());
+        if (code == CONNECTIONLOSS || code == KeeperException.Code.SESSIONEXPIRED
+            || code == KeeperException.Code.SESSIONMOVED) {
+          // Async create will be pending due to the mock error rc is retryable.
+          testZkClient.asyncCreate(NODE_PATH, null, CreateMode.PERSISTENT, createCallback);
+          Assert.assertFalse(createCallback.isOperationDone());
+          Assert.assertEquals(createCallback.getRc(), code.intValue());
+          // Change the mock response
+          testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+          // Async retry will succeed now. Wait until the operation is successfully done and verify.
+          Assert.assertTrue(waitAsyncOperation(createCallback, 1000));
+          Assert.assertEquals(createCallback.getRc(), KeeperException.Code.OK.intValue());
+          Assert.assertTrue(testZkClient.exists(NODE_PATH));
+          Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+        } else {
+          // Async create will fail due to the mock error rc is not recoverable.
+          testZkClient.asyncCreate(NODE_PATH, null, CreateMode.PERSISTENT, createCallback);
+          Assert.assertTrue(waitAsyncOperation(createCallback, 1000));
+          Assert.assertEquals(createCallback.getRc(), code.intValue());
+          Assert.assertEquals(testZkClient.getAndResetRetryCount(), 0);
+        }
+        testZkClient.delete(NODE_PATH);
+        Assert.assertFalse(testZkClient.exists(NODE_PATH));
+      }
+    } finally {
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      testZkClient.close();
+      _zkClient.delete(NODE_PATH);
+    }
+  }
+
+  @Test(dependsOnMethods = "testAsyncRetryCategories")
+  public void testAsyncWriteRetry() {
+    MockAsyncZkClient testZkClient = new MockAsyncZkClient(_zkServerAddress);
+    try {
+      ZNRecord tmpRecord = new ZNRecord("tmpRecord");
+      tmpRecord.setSimpleField("foo", "bar");
+      testZkClient.createPersistent(NODE_PATH, tmpRecord);
+
+      // 1. Test async set retry
+      ZkAsyncCallbacks.SetDataCallbackHandler setCallback =
+          new ZkAsyncCallbacks.SetDataCallbackHandler();
+      Assert.assertEquals(setCallback.getRc(), UNKNOWN_RET_CODE);
+
+      tmpRecord.setSimpleField("test", "data");
+      testZkClient.setAsyncCallRC(CONNECTIONLOSS.intValue());
+      // Async set will be pending due to the mock error rc is retryable.
+      testZkClient.asyncSetData(NODE_PATH, tmpRecord, -1, setCallback);
+      Assert.assertFalse(setCallback.isOperationDone());
+      Assert.assertEquals(setCallback.getRc(), CONNECTIONLOSS.intValue());
+      // Change the mock return code.
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      // Async retry will succeed now. Wait until the operation is successfully done and verify.
+      Assert.assertTrue(waitAsyncOperation(setCallback, 1000));
+      Assert.assertEquals(setCallback.getRc(), KeeperException.Code.OK.intValue());
+      Assert.assertEquals(((ZNRecord) testZkClient.readData(NODE_PATH)).getSimpleField("test"),
+          "data");
+      Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+
+      // 2. Test async delete
+      ZkAsyncCallbacks.DeleteCallbackHandler deleteCallback =
+          new ZkAsyncCallbacks.DeleteCallbackHandler();
+      Assert.assertEquals(deleteCallback.getRc(), UNKNOWN_RET_CODE);
+
+      testZkClient.setAsyncCallRC(CONNECTIONLOSS.intValue());
+      // Async delete will be pending due to the mock error rc is retryable.
+      testZkClient.asyncDelete(NODE_PATH, deleteCallback);
+      Assert.assertFalse(deleteCallback.isOperationDone());
+      Assert.assertEquals(deleteCallback.getRc(), CONNECTIONLOSS.intValue());
+      // Change the mock return code.
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      // Async retry will succeed now. Wait until the operation is successfully done and verify.
+      Assert.assertTrue(waitAsyncOperation(deleteCallback, 1000));
+      Assert.assertEquals(deleteCallback.getRc(), KeeperException.Code.OK.intValue());
+      Assert.assertFalse(testZkClient.exists(NODE_PATH));
+      Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+    } finally {
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      testZkClient.close();
+      _zkClient.delete(NODE_PATH);
+    }
+  }
+
+  @Test(dependsOnMethods = "testAsyncWriteRetry")
+  public void testAsyncReadRetry() {
+    MockAsyncZkClient testZkClient = new MockAsyncZkClient(_zkServerAddress);
+    try {
+      ZNRecord tmpRecord = new ZNRecord("tmpRecord");
+      tmpRecord.setSimpleField("foo", "bar");
+      testZkClient.createPersistent(NODE_PATH, tmpRecord);
+
+      // 1. Test async exist check
+      ZkAsyncCallbacks.ExistsCallbackHandler existsCallback =
+          new ZkAsyncCallbacks.ExistsCallbackHandler();
+      Assert.assertEquals(existsCallback.getRc(), UNKNOWN_RET_CODE);
+
+      testZkClient.setAsyncCallRC(CONNECTIONLOSS.intValue());
+      // Async exist check will be pending due to the mock error rc is retryable.
+      testZkClient.asyncExists(NODE_PATH, existsCallback);
+      Assert.assertFalse(existsCallback.isOperationDone());
+      Assert.assertEquals(existsCallback.getRc(), CONNECTIONLOSS.intValue());
+      // Change the mock return code.
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      // Async retry will succeed now. Wait until the operation is successfully done and verify.
+      Assert.assertTrue(waitAsyncOperation(existsCallback, 1000));
+      Assert.assertEquals(existsCallback.getRc(), KeeperException.Code.OK.intValue());
+      Assert.assertTrue(existsCallback._stat != null);
+      Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+
+      // 2. Test async get
+      ZkAsyncCallbacks.GetDataCallbackHandler getCallback =
+          new ZkAsyncCallbacks.GetDataCallbackHandler();
+      Assert.assertEquals(getCallback.getRc(), UNKNOWN_RET_CODE);
+
+      testZkClient.setAsyncCallRC(CONNECTIONLOSS.intValue());
+      // Async get will be pending due to the mock error rc is retryable.
+      testZkClient.asyncGetData(NODE_PATH, getCallback);
+      Assert.assertFalse(getCallback.isOperationDone());
+      Assert.assertEquals(getCallback.getRc(), CONNECTIONLOSS.intValue());
+      // Change the mock return code.
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      // Async retry will succeed now. Wait until the operation is successfully done and verify.
+      Assert.assertTrue(waitAsyncOperation(getCallback, 1000));
+      Assert.assertEquals(getCallback.getRc(), KeeperException.Code.OK.intValue());
+      ZNRecord record = testZkClient.deserialize(getCallback._data, NODE_PATH);
+      Assert.assertEquals(record.getSimpleField("foo"), "bar");
+      Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+    } finally {
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      testZkClient.close();
+      _zkClient.delete(NODE_PATH);
+    }
+  }
+
+  @Test(dependsOnMethods = "testAsyncReadRetry")
+  public void testAsyncRequestCleanup() {
+    int cbCount = 10;
+    MockAsyncZkClient testZkClient = new MockAsyncZkClient(_zkServerAddress);
+    try {
+      ZNRecord tmpRecord = new ZNRecord("tmpRecord");
+      tmpRecord.setSimpleField("foo", "bar");
+      testZkClient.createPersistent(NODE_PATH, tmpRecord);
+
+      // Create 10 async exists check requests
+      ZkAsyncCallbacks.ExistsCallbackHandler[] existsCallbacks =
+          new ZkAsyncCallbacks.ExistsCallbackHandler[cbCount];
+      for (int i = 0; i < cbCount; i++) {
+        existsCallbacks[i] = new ZkAsyncCallbacks.ExistsCallbackHandler();
+      }
+      testZkClient.setAsyncCallRC(CONNECTIONLOSS.intValue());
+      // All async exist check calls will be pending due to the mock error rc is retryable.
+      for (ZkAsyncCallbacks.ExistsCallbackHandler cb : existsCallbacks) {
+        testZkClient.asyncExists(NODE_PATH, cb);
+        Assert.assertEquals(cb.getRc(), CONNECTIONLOSS.intValue());
+      }
+      // Wait for a while, no callback finishes
+      Assert.assertFalse(waitAsyncOperation(existsCallbacks[0], 1000));
+      for (ZkAsyncCallbacks.ExistsCallbackHandler cb : existsCallbacks) {
+        Assert.assertEquals(cb.getRc(), CONNECTIONLOSS.intValue());
+        Assert.assertFalse(cb.isOperationDone());
+      }
+      testZkClient.close();
+      // All callback retry will be cancelled because the zkclient is closed.
+      for (ZkAsyncCallbacks.ExistsCallbackHandler cb : existsCallbacks) {
+        Assert.assertTrue(waitAsyncOperation(cb, 1000));
+        Assert.assertEquals(cb.getRc(), CONNECTIONLOSS.intValue());
+      }
+      Assert.assertTrue(testZkClient.getAndResetRetryCount() >= 1);
+    } finally {
+      testZkClient.setAsyncCallRC(KeeperException.Code.OK.intValue());
+      testZkClient.close();
+      _zkClient.delete(NODE_PATH);
+    }
+  }
+
+  /**
+   * Mock client to whitebox test async functionality.
+   */
+  class MockAsyncZkClient extends ZkClient {
+    private static final long RETRY_INTERVAL_MS = 500;
+    private long _retryCount = 0;
+
+    /**
+     * If the specified return code is OK, call the real function.
+     * Otherwise, trigger the callback with the specified RC without triggering the real ZK call.
+     */
+    private int _asyncCallRetCode = KeeperException.Code.OK.intValue();
+
+    public MockAsyncZkClient(String zkAddress) {
+      super(zkAddress);
+      setZkSerializer(new ZNRecordSerializer());
+    }
+
+    public void setAsyncCallRC(int rc) {
+      _asyncCallRetCode = rc;
+    }
+
+    public long getAndResetRetryCount() {
+      long tmpCount = _retryCount;
+      _retryCount = 0;
+      return tmpCount;
+    }
+
+    @Override
+    public void asyncCreate(String path, Object datat, CreateMode mode,
+        ZkAsyncCallbacks.CreateCallbackHandler cb) {
+      if (_asyncCallRetCode == KeeperException.Code.OK.intValue()) {
+        super.asyncCreate(path, datat, mode, cb);
+        return;
+      } else {
+        cb.processResult(_asyncCallRetCode, path,
+            new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, null, 0, 0, false) {
+              @Override
+              protected void doRetry() {
+                _retryCount++;
+                try {
+                  Thread.sleep(RETRY_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                  throw new ZkInterruptedException(e);
+                }
+                asyncCreate(path, datat, mode, cb);
+              }
+            }, null);
+      }
+    }
+
+    @Override
+    public void asyncSetData(String path, Object datat, int version,
+        ZkAsyncCallbacks.SetDataCallbackHandler cb) {
+      if (_asyncCallRetCode == KeeperException.Code.OK.intValue()) {
+        super.asyncSetData(path, datat, version, cb);
+        return;
+      } else {
+        cb.processResult(_asyncCallRetCode, path,
+            new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, null, 0, 0, false) {
+              @Override
+              protected void doRetry() {
+                _retryCount++;
+                try {
+                  Thread.sleep(RETRY_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                  throw new ZkInterruptedException(e);
+                }
+                asyncSetData(path, datat, version, cb);
+              }
+            }, null);
+      }
+    }
+
+    @Override
+    public void asyncGetData(String path, ZkAsyncCallbacks.GetDataCallbackHandler cb) {
+      if (_asyncCallRetCode == KeeperException.Code.OK.intValue()) {
+        super.asyncGetData(path, cb);
+        return;
+      } else {
+        cb.processResult(_asyncCallRetCode, path,
+            new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, null, 0, 0, true) {
+              @Override
+              protected void doRetry() {
+                _retryCount++;
+                try {
+                  Thread.sleep(RETRY_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                  throw new ZkInterruptedException(e);
+                }
+                asyncGetData(path, cb);
+              }
+            }, null, null);
+      }
+    }
+
+    @Override
+    public void asyncExists(String path, ZkAsyncCallbacks.ExistsCallbackHandler cb) {
+      if (_asyncCallRetCode == KeeperException.Code.OK.intValue()) {
+        super.asyncExists(path, cb);
+        return;
+      } else {
+        cb.processResult(_asyncCallRetCode, path,
+            new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, null, 0, 0, true) {
+              @Override
+              protected void doRetry() {
+                _retryCount++;
+                try {
+                  Thread.sleep(RETRY_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                  throw new ZkInterruptedException(e);
+                }
+                asyncExists(path, cb);
+              }
+            }, null);
+      }
+    }
+
+    @Override
+    public void asyncDelete(String path, ZkAsyncCallbacks.DeleteCallbackHandler cb) {
+      if (_asyncCallRetCode == KeeperException.Code.OK.intValue()) {
+        super.asyncDelete(path, cb);
+        return;
+      } else {
+        cb.processResult(_asyncCallRetCode, path,
+            new ZkAsyncRetryCallContext(_asyncCallRetryThread, cb, null, 0, 0, false) {
+              @Override
+              protected void doRetry() {
+                _retryCount++;
+                try {
+                  Thread.sleep(RETRY_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                  throw new ZkInterruptedException(e);
+                }
+                asyncDelete(path, cb);
+              }
+            });
+      }
+    }
+  }
+}