GEODE-8634: Fix AsyncInvocationTimeoutDistributedTest flakiness (#5799)

Reset latch and threadId before each test.
diff --git a/geode-dunit/src/distributedTest/java/org/apache/geode/test/dunit/tests/AsyncInvocationTimeoutDistributedTest.java b/geode-dunit/src/distributedTest/java/org/apache/geode/test/dunit/tests/AsyncInvocationTimeoutDistributedTest.java
index af740f6..fbbf836 100644
--- a/geode-dunit/src/distributedTest/java/org/apache/geode/test/dunit/tests/AsyncInvocationTimeoutDistributedTest.java
+++ b/geode-dunit/src/distributedTest/java/org/apache/geode/test/dunit/tests/AsyncInvocationTimeoutDistributedTest.java
@@ -28,6 +28,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.junit.After;
+import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 
@@ -38,35 +39,37 @@
 @SuppressWarnings("serial")
 public class AsyncInvocationTimeoutDistributedTest implements Serializable {
 
-  private static final long TIMEOUT_MILLIS = getTimeout().toMillis();
-
-  private static final AtomicReference<Long> threadId = new AtomicReference<>();
-  private static final AtomicReference<CountDownLatch> latch = new AtomicReference<>();
+  private static final AtomicReference<Long> THREAD_ID =
+      new AtomicReference<>(0L);
+  private static final AtomicReference<CountDownLatch> LATCH =
+      new AtomicReference<>(new CountDownLatch(0));
 
   @Rule
-  public DistributedRule distributedRule = new DistributedRule();
+  public DistributedRule distributedRule = new DistributedRule(1);
+
+  @Before
+  public void setUp() {
+    getVM(0).invoke(() -> {
+      LATCH.set(new CountDownLatch(1));
+      THREAD_ID.set(0L);
+    });
+  }
 
   @After
   public void tearDown() {
-    getVM(0).invoke(() -> {
-      CountDownLatch latchInVM0 = latch.get();
-      while (latchInVM0 != null && latchInVM0.getCount() > 0) {
-        latchInVM0.countDown();
-      }
-    });
+    getVM(0).invoke(() -> LATCH.get().countDown());
   }
 
   @Test
-  public void await_runnable_timeout_includesStackTraceAsCause() {
+  public void awaitWithRunnableTimeoutExceptionIncludesRemoteStackTraceAsCause() {
     AsyncInvocation<Void> hangInVM0 = getVM(0).invokeAsync(() -> {
-      latch.set(new CountDownLatch(1));
-      threadId.set(Thread.currentThread().getId());
-      latch.get().await(TIMEOUT_MILLIS, MILLISECONDS);
+      THREAD_ID.set(Thread.currentThread().getId());
+      LATCH.get().await(getTimeout().toMillis(), MILLISECONDS);
     });
 
     long remoteThreadId = getVM(0).invoke(() -> {
-      await().until(() -> threadId.get() > 0);
-      return threadId.get();
+      await().until(() -> THREAD_ID.get() > 0);
+      return THREAD_ID.get();
     });
 
     Throwable thrown = catchThrowable(() -> hangInVM0.await(1, SECONDS));
@@ -79,17 +82,16 @@
   }
 
   @Test
-  public void await_callable_timeout_includesStackTraceAsCause() {
+  public void awaitWithCallableTimeoutExceptionIncludesRemoteStackTraceAsCause() {
     AsyncInvocation<Integer> hangInVM0 = getVM(0).invokeAsync(() -> {
-      latch.set(new CountDownLatch(1));
-      threadId.set(Thread.currentThread().getId());
-      latch.get().await(TIMEOUT_MILLIS, MILLISECONDS);
+      THREAD_ID.set(Thread.currentThread().getId());
+      LATCH.get().await(getTimeout().toMillis(), MILLISECONDS);
       return 42;
     });
 
     long remoteThreadId = getVM(0).invoke(() -> {
-      await().until(() -> threadId.get() > 0);
-      return threadId.get();
+      await().until(() -> THREAD_ID.get() > 0);
+      return THREAD_ID.get();
     });
 
     Throwable thrown = catchThrowable(() -> hangInVM0.await(1, SECONDS));
@@ -102,17 +104,16 @@
   }
 
   @Test
-  public void get_callable_timeout_includesStackTraceAsCause() {
+  public void getWithCallableTimeoutExceptionIncludesRemoteStackTraceAsCause() {
     AsyncInvocation<Integer> hangInVM0 = getVM(0).invokeAsync(() -> {
-      latch.set(new CountDownLatch(1));
-      threadId.set(Thread.currentThread().getId());
-      latch.get().await(TIMEOUT_MILLIS, MILLISECONDS);
+      THREAD_ID.set(Thread.currentThread().getId());
+      LATCH.get().await(getTimeout().toMillis(), MILLISECONDS);
       return 42;
     });
 
     long remoteThreadId = getVM(0).invoke(() -> {
-      await().until(() -> threadId.get() > 0);
-      return threadId.get();
+      await().until(() -> THREAD_ID.get() > 0);
+      return THREAD_ID.get();
     });
 
     Throwable thrown = catchThrowable(() -> hangInVM0.get(1, SECONDS));