HBASE-25307 ThreadLocal pooling leads to NullPointerException (#2685)

* PoolMap does not discard any elements anymore. If an element is put,
it always stores it. The reason: it stores expensive resources (rpc
connections) which would lead to resource leak if we simple discard it.
RpcClients can reference netty ByteBufs which are reference counted.
Resource cleanup is done by AbstractRpcClient.cleanupIdleConnections().
* PoolMap does not implement Map interface anymore, so ensuring
thread-safety has become easier. Put method is replaced with getOrCreate().
* ThreadLocalPool doesn't use ThreadLocal class anymore. It stores
resources on thread basis, but it doesn't remove values when a thread
exits. Again, proper cleanup is done by cleanupIdleConnections().

Signed-off-by: Sean Busbey <busbey@apache.org>
Signed-off-by: Wellington Chevreuil <wellington.chevreuil@gmail.com>
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java
index d94f2d3..acc82de 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java
@@ -101,7 +101,7 @@
       new ThreadFactoryBuilder().setNameFormat("Idle-Rpc-Conn-Sweeper-pool-%d").setDaemon(true)
         .setUncaughtExceptionHandler(Threads.LOGGING_EXCEPTION_HANDLER).build());
 
-  protected boolean running = true; // if client runs
+  private boolean running = true; // if client runs
 
   protected final Configuration conf;
   protected final String clusterId;
@@ -127,7 +127,7 @@
   protected final int readTO;
   protected final int writeTO;
 
-  protected final PoolMap<ConnectionId, T> connections;
+  private final PoolMap<ConnectionId, T> connections;
 
   private final AtomicInteger callIdCnt = new AtomicInteger(0);
 
@@ -209,7 +209,7 @@
           if (LOG.isTraceEnabled()) {
             LOG.trace("Cleanup idle connection to {}", conn.remoteId().address);
           }
-          connections.removeValue(conn.remoteId(), conn);
+          connections.remove(conn.remoteId(), conn);
           conn.cleanupConnection();
         }
       }
@@ -294,7 +294,14 @@
    * @return the maximum pool size
    */
   private static int getPoolSize(Configuration config) {
-    return config.getInt(HConstants.HBASE_CLIENT_IPC_POOL_SIZE, 1);
+    int poolSize = config.getInt(HConstants.HBASE_CLIENT_IPC_POOL_SIZE, 1);
+
+    if (poolSize <= 0) {
+      LOG.warn("{} must be positive. Using default value: 1", HConstants.HBASE_CLIENT_IPC_POOL_SIZE);
+      return 1;
+    } else {
+      return poolSize;
+    }
   }
 
   private int nextCallId() {
@@ -350,11 +357,7 @@
       if (!running) {
         throw new StoppedRpcClientException();
       }
-      conn = connections.get(remoteId);
-      if (conn == null) {
-        conn = createConnection(remoteId);
-        connections.put(remoteId, conn);
-      }
+      conn = connections.getOrCreate(remoteId, () -> createConnection(remoteId));
       conn.setLastTouched(EnvironmentEdgeManager.currentTime());
     }
     return conn;
@@ -453,7 +456,7 @@
             && remoteId.address.getHostName().equals(sn.getHostname())) {
           LOG.info("The server on " + sn.toString() + " is dead - stopping the connection "
               + connection.remoteId);
-          connections.removeValue(remoteId, connection);
+          connections.remove(remoteId, connection);
           connection.shutdown();
           connection.cleanupConnection();
         }
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java
index ba3d4cd..6d4babe 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java
@@ -343,13 +343,13 @@
   @Override
   public void run() {
     if (LOG.isTraceEnabled()) {
-      LOG.trace(threadName + ": starting, connections " + this.rpcClient.connections.size());
+      LOG.trace(threadName + ": starting");
     }
     while (waitForWork()) {
       readResponse();
     }
     if (LOG.isTraceEnabled()) {
-      LOG.trace(threadName + ": stopped, connections " + this.rpcClient.connections.size());
+      LOG.trace(threadName + ": stopped");
     }
   }
 
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/util/PoolMap.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/util/PoolMap.java
index cb3bf4a..897c0d8 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/util/PoolMap.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/util/PoolMap.java
@@ -18,17 +18,8 @@
  */
 package org.apache.hadoop.hbase.util;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.io.IOException;
+import java.util.*;
 
 import org.apache.yetus.audience.InterfaceAudience;
 
@@ -45,177 +36,102 @@
  * key. A size of {@link Integer#MAX_VALUE} is interpreted as an unbounded pool.
  * </p>
  *
+ * <p>
+ * PoolMap is thread-safe. It does not remove elements automatically. Unused resources
+ * must be closed and removed explicitly.
+ * </p>
+ *
  * @param <K>
  *          the type of the key to the resource
  * @param <V>
  *          the type of the resource being pooled
  */
 @InterfaceAudience.Private
-public class PoolMap<K, V> implements Map<K, V> {
-  private PoolType poolType;
+public class PoolMap<K, V> {
+  private final Map<K, Pool<V>> pools;
+  private final PoolType poolType;
+  private final int poolMaxSize;
 
-  private int poolMaxSize;
-
-  private Map<K, Pool<V>> pools = new ConcurrentHashMap<>();
-
-  public PoolMap(PoolType poolType) {
-    this.poolType = poolType;
+   public PoolMap(PoolType poolType, int poolMaxSize) {
+     pools = new HashMap<>();
+     this.poolType = poolType;
+     this.poolMaxSize = poolMaxSize;
   }
 
-  public PoolMap(PoolType poolType, int poolMaxSize) {
-    this.poolType = poolType;
-    this.poolMaxSize = poolMaxSize;
-  }
+  public V getOrCreate(K key, PoolResourceSupplier<V> supplier) throws IOException {
+     synchronized (pools) {
+       Pool<V> pool = pools.get(key);
 
-  @Override
-  public V get(Object key) {
-    Pool<V> pool = pools.get(key);
-    return pool != null ? pool.get() : null;
-  }
+       if (pool == null) {
+         pool = createPool();
+         pools.put(key, pool);
+       }
 
-  @Override
-  public V put(K key, V value) {
-    Pool<V> pool = pools.get(key);
-    if (pool == null) {
-      pools.put(key, pool = createPool());
-    }
-    return pool != null ? pool.put(value) : null;
-  }
+       try {
+         return pool.getOrCreate(supplier);
+       } catch (IOException | RuntimeException | Error e) {
+         if (pool.size() == 0) {
+           pools.remove(key);
+         }
 
-  @SuppressWarnings("unchecked")
-  @Override
-  public V remove(Object key) {
-    Pool<V> pool = pools.remove(key);
-    if (pool != null) {
-      removeValue((K) key, pool.get());
-    }
-    return null;
+         throw e;
+       }
+     }
   }
+  public boolean remove(K key, V value) {
+    synchronized (pools) {
+      Pool<V> pool = pools.get(key);
 
-  public boolean removeValue(K key, V value) {
-    Pool<V> pool = pools.get(key);
-    boolean res = false;
-    if (pool != null) {
-      res = pool.remove(value);
-      if (res && pool.size() == 0) {
+      if (pool == null) {
+        return false;
+      }
+
+      boolean removed = pool.remove(value);
+
+      if (removed && pool.size() == 0) {
         pools.remove(key);
       }
-    }
-    return res;
-  }
 
-  @Override
-  public Collection<V> values() {
-    Collection<V> values = new ArrayList<>();
-    for (Pool<V> pool : pools.values()) {
-      Collection<V> poolValues = pool.values();
-      if (poolValues != null) {
-        values.addAll(poolValues);
-      }
-    }
-    return values;
-  }
-
-  public Collection<V> values(K key) {
-    Collection<V> values = new ArrayList<>();
-    Pool<V> pool = pools.get(key);
-    if (pool != null) {
-      Collection<V> poolValues = pool.values();
-      if (poolValues != null) {
-        values.addAll(poolValues);
-      }
-    }
-    return values;
-  }
-
-
-  @Override
-  public boolean isEmpty() {
-    return pools.isEmpty();
-  }
-
-  @Override
-  public int size() {
-    return pools.size();
-  }
-
-  public int size(K key) {
-    Pool<V> pool = pools.get(key);
-    return pool != null ? pool.size() : 0;
-  }
-
-  @Override
-  public boolean containsKey(Object key) {
-    return pools.containsKey(key);
-  }
-
-  @Override
-  public boolean containsValue(Object value) {
-    if (value == null) {
-      return false;
-    }
-    for (Pool<V> pool : pools.values()) {
-      if (value.equals(pool.get())) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  @Override
-  public void putAll(Map<? extends K, ? extends V> map) {
-    for (Map.Entry<? extends K, ? extends V> entry : map.entrySet()) {
-      put(entry.getKey(), entry.getValue());
+      return removed;
     }
   }
 
-  @Override
-  public void clear() {
-    for (Pool<V> pool : pools.values()) {
-      pool.clear();
-    }
-    pools.clear();
-  }
+  public List<V> values() {
+    List<V> values = new ArrayList<>();
 
-  @Override
-  public Set<K> keySet() {
-    return pools.keySet();
-  }
-
-  @Override
-  public Set<Map.Entry<K, V>> entrySet() {
-    Set<Map.Entry<K, V>> entries = new HashSet<>();
-    for (Map.Entry<K, Pool<V>> poolEntry : pools.entrySet()) {
-      final K poolKey = poolEntry.getKey();
-      final Pool<V> pool = poolEntry.getValue();
-      if (pool != null) {
-        for (final V poolValue : pool.values()) {
-          entries.add(new Map.Entry<K, V>() {
-            @Override
-            public K getKey() {
-              return poolKey;
-            }
-
-            @Override
-            public V getValue() {
-              return poolValue;
-            }
-
-            @Override
-            public V setValue(V value) {
-              return pool.put(value);
-            }
-          });
+    synchronized (pools) {
+      for (Pool<V> pool : pools.values()) {
+        Collection<V> poolValues = pool.values();
+        if (poolValues != null) {
+          values.addAll(poolValues);
         }
       }
     }
-    return entries;
+
+    return values;
+  }
+
+  public void clear() {
+    synchronized (pools) {
+      for (Pool<V> pool : pools.values()) {
+        pool.clear();
+      }
+
+      pools.clear();
+    }
+  }
+
+  public interface PoolResourceSupplier<R> {
+     R get() throws IOException;
+  }
+
+  protected static <V> V createResource(PoolResourceSupplier<V> supplier) throws IOException {
+    V resource = supplier.get();
+    return Objects.requireNonNull(resource, "resource cannot be null.");
   }
 
   protected interface Pool<R> {
-    R get();
-
-    R put(R resource);
+    R getOrCreate(PoolResourceSupplier<R> supplier) throws IOException;
 
     boolean remove(R resource);
 
@@ -254,8 +170,9 @@
       return new RoundRobinPool<>(poolMaxSize);
     case ThreadLocal:
       return new ThreadLocalPool<>();
+    default:
+      return new RoundRobinPool<>(poolMaxSize);
     }
-    return null;
   }
 
   /**
@@ -275,43 +192,66 @@
    *
    */
   @SuppressWarnings("serial")
-  static class RoundRobinPool<R> extends CopyOnWriteArrayList<R> implements Pool<R> {
-    private int maxSize;
-    private int nextResource = 0;
+  static class RoundRobinPool<R> implements Pool<R> {
+    private final List<R> resources;
+    private final int maxSize;
+
+    private int nextIndex;
 
     public RoundRobinPool(int maxSize) {
+      if (maxSize <= 0) {
+        throw new IllegalArgumentException("maxSize must be positive");
+      }
+
+      resources = new ArrayList<>(maxSize);
       this.maxSize = maxSize;
     }
 
     @Override
-    public R put(R resource) {
-      if (super.size() < maxSize) {
-        add(resource);
-      }
-      return null;
-    }
+    public R getOrCreate(PoolResourceSupplier<R> supplier) throws IOException {
+      int size = resources.size();
+      R resource;
 
-    @Override
-    public R get() {
-      if (super.size() < maxSize) {
-        return null;
+      /* letting pool to grow */
+      if (size < maxSize) {
+        resource = createResource(supplier);
+        resources.add(resource);
+      } else {
+        resource = resources.get(nextIndex);
+
+        /* at this point size cannot be 0 */
+        nextIndex = (nextIndex + 1) % size;
       }
-      nextResource %= super.size();
-      R resource = get(nextResource++);
+
       return resource;
     }
 
     @Override
-    public Collection<R> values() {
-      return this;
+    public boolean remove(R resource) {
+      return resources.remove(resource);
     }
 
+    @Override
+    public void clear() {
+      resources.clear();
+    }
+
+    @Override
+    public Collection<R> values() {
+      return resources;
+    }
+
+    @Override
+    public int size() {
+      return resources.size();
+    }
   }
 
   /**
    * The <code>ThreadLocalPool</code> represents a {@link PoolMap.Pool} that
-   * builds on the {@link ThreadLocal} class. It essentially binds the resource
-   * to the thread from which it is accessed.
+   * works similarly to {@link ThreadLocal} class. It essentially binds the resource
+   * to the thread from which it is accessed. It doesn't remove resources when a thread exits,
+   * those resources must be closed manually.
    *
    * <p>
    * Note that the size of the pool is essentially bounded by the number of threads
@@ -321,62 +261,45 @@
    * @param <R>
    *          the type of the resource
    */
-  static class ThreadLocalPool<R> extends ThreadLocal<R> implements Pool<R> {
-    private static final Map<ThreadLocalPool<?>, AtomicInteger> poolSizes = new HashMap<>();
+  static class ThreadLocalPool<R> implements Pool<R> {
+    private final Map<Thread, R> resources;
 
     public ThreadLocalPool() {
+      resources = new HashMap<>();
     }
 
     @Override
-    public R put(R resource) {
-      R previousResource = get();
-      if (previousResource == null) {
-        AtomicInteger poolSize = poolSizes.get(this);
-        if (poolSize == null) {
-          poolSizes.put(this, poolSize = new AtomicInteger(0));
-        }
-        poolSize.incrementAndGet();
+    public R getOrCreate(PoolResourceSupplier<R> supplier) throws IOException {
+      Thread myself = Thread.currentThread();
+      R resource = resources.get(myself);
+
+      if (resource == null) {
+        resource = createResource(supplier);
+        resources.put(myself, resource);
       }
-      this.set(resource);
-      return previousResource;
-    }
 
-    @Override
-    public void remove() {
-      super.remove();
-      AtomicInteger poolSize = poolSizes.get(this);
-      if (poolSize != null) {
-        poolSize.decrementAndGet();
-      }
-    }
-
-    @Override
-    public int size() {
-      AtomicInteger poolSize = poolSizes.get(this);
-      return poolSize != null ? poolSize.get() : 0;
+      return resource;
     }
 
     @Override
     public boolean remove(R resource) {
-      R previousResource = super.get();
-      if (resource != null && resource.equals(previousResource)) {
-        remove();
-        return true;
-      } else {
-        return false;
-      }
+      /* remove can be called from any thread */
+      return resources.values().remove(resource);
+    }
+
+    @Override
+    public int size() {
+      return resources.size();
     }
 
     @Override
     public void clear() {
-      super.remove();
+      resources.clear();
     }
 
     @Override
     public Collection<R> values() {
-      List<R> values = new ArrayList<>();
-      values.add(get());
-      return values;
+      return resources.values();
     }
   }
 }
diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/PoolMapTestBase.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/PoolMapTestBase.java
index 1b24252..314cae9 100644
--- a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/PoolMapTestBase.java
+++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/PoolMapTestBase.java
@@ -17,9 +17,13 @@
  */
 package org.apache.hadoop.hbase.util;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
+import java.io.IOException;
+import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import org.apache.hadoop.hbase.util.PoolMap.PoolType;
 import org.junit.After;
 import org.junit.Before;
@@ -28,6 +32,7 @@
 
   protected PoolMap<String, String> poolMap;
 
+  protected static final int KEY_COUNT = 5;
   protected static final int POOL_SIZE = 3;
 
   @Before
@@ -35,27 +40,5 @@
     this.poolMap = new PoolMap<>(getPoolType(), POOL_SIZE);
   }
 
-  @After
-  public void tearDown() throws Exception {
-    this.poolMap.clear();
-  }
-
   protected abstract PoolType getPoolType();
-
-  protected void runThread(final String randomKey, final String randomValue,
-      final String expectedValue) throws InterruptedException {
-    final AtomicBoolean matchFound = new AtomicBoolean(false);
-    Thread thread = new Thread(new Runnable() {
-      @Override
-      public void run() {
-        poolMap.put(randomKey, randomValue);
-        String actualValue = poolMap.get(randomKey);
-        matchFound
-            .set(expectedValue == null ? actualValue == null : expectedValue.equals(actualValue));
-      }
-    });
-    thread.start();
-    thread.join();
-    assertTrue(matchFound.get());
-  }
 }
diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestRoundRobinPoolMap.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestRoundRobinPoolMap.java
index a71cf29..ef7cb4e 100644
--- a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestRoundRobinPoolMap.java
+++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestRoundRobinPoolMap.java
@@ -18,12 +18,19 @@
 package org.apache.hadoop.hbase.util;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.hadoop.hbase.HBaseClassTestRule;
 import org.apache.hadoop.hbase.testclassification.MiscTests;
 import org.apache.hadoop.hbase.testclassification.SmallTests;
@@ -45,58 +52,103 @@
   }
 
   @Test
-  public void testSingleThreadedClient() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
-    String randomKey = String.valueOf(rand.nextInt());
-    String randomValue = String.valueOf(rand.nextInt());
-    // As long as the pool is not full, we'll get null back.
-    // This forces the user to create new values that can be used to populate
-    // the pool.
-    runThread(randomKey, randomValue, null);
-    assertEquals(1, poolMap.size(randomKey));
+  public void testGetOrCreate() throws IOException {
+    String key = "key";
+    String value = "value";
+    String result = poolMap.getOrCreate(key, () -> value);
+
+    assertEquals(value, result);
+    assertEquals(1, poolMap.values().size());
   }
 
   @Test
-  public void testMultiThreadedClients() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
+  public void testMultipleKeys() throws IOException {
+    for (int i = 0; i < KEY_COUNT; i++) {
+      String key = Integer.toString(i);
+      String value = Integer.toString(2 * i);
+      String result = poolMap.getOrCreate(key, () -> value);
+
+      assertEquals(value, result);
+    }
+
+    assertEquals(KEY_COUNT, poolMap.values().size());
+  }
+
+  @Test
+  public void testMultipleValues() throws IOException {
+    String key = "key";
+
     for (int i = 0; i < POOL_SIZE; i++) {
-      String randomKey = String.valueOf(rand.nextInt());
-      String randomValue = String.valueOf(rand.nextInt());
-      // As long as the pool is not full, we'll get null back
-      runThread(randomKey, randomValue, null);
-      // As long as we use distinct keys, each pool will have one value
-      assertEquals(1, poolMap.size(randomKey));
+      String value = Integer.toString(i);
+      String result = poolMap.getOrCreate(key, () -> value);
+
+      assertEquals(value, result);
     }
-    poolMap.clear();
-    String randomKey = String.valueOf(rand.nextInt());
-    for (int i = 0; i < POOL_SIZE - 1; i++) {
-      String randomValue = String.valueOf(rand.nextInt());
-      // As long as the pool is not full, we'll get null back
-      runThread(randomKey, randomValue, null);
-      // since we use the same key, the pool size should grow
-      assertEquals(i + 1, poolMap.size(randomKey));
-    }
-    // at the end of the day, there should be as many values as we put
-    assertEquals(POOL_SIZE - 1, poolMap.size(randomKey));
+
+    assertEquals(POOL_SIZE, poolMap.values().size());
   }
 
   @Test
-  public void testPoolCap() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
-    String randomKey = String.valueOf(rand.nextInt());
-    List<String> randomValues = new ArrayList<>();
-    for (int i = 0; i < POOL_SIZE * 2; i++) {
-      String randomValue = String.valueOf(rand.nextInt());
-      randomValues.add(randomValue);
-      if (i < POOL_SIZE - 1) {
-        // As long as the pool is not full, we'll get null back
-        runThread(randomKey, randomValue, null);
-      } else {
-        // when the pool becomes full, we expect the value we get back to be
-        // what we put earlier, in round-robin order
-        runThread(randomKey, randomValue, randomValues.get((i - POOL_SIZE + 1) % POOL_SIZE));
-      }
+  public void testRoundRobin() throws IOException {
+    String key = "key";
+
+    for (int i = 0; i < POOL_SIZE; i++) {
+      String value = Integer.toString(i);
+      poolMap.getOrCreate(key, () -> value);
     }
-    assertEquals(POOL_SIZE, poolMap.size(randomKey));
+
+    assertEquals(POOL_SIZE, poolMap.values().size());
+
+    /* pool is filled, get() should return elements round robin order */
+    for (int i = 0; i < 2 * POOL_SIZE; i++) {
+      String expected = Integer.toString(i % POOL_SIZE);
+      assertEquals(expected, poolMap.getOrCreate(key, () -> {
+        throw new IOException("must not call me");
+      }));
+    }
+
+    assertEquals(POOL_SIZE, poolMap.values().size());
+  }
+
+  @Test
+  public void testMultiThreadedRoundRobin() throws ExecutionException, InterruptedException {
+    String key = "key";
+    AtomicInteger id = new AtomicInteger();
+    List<String> results = Collections.synchronizedList(new ArrayList<>());
+
+    Runnable runnable = () -> {
+      try {
+        for (int i = 0; i < POOL_SIZE; i++) {
+          String value = Integer.toString(id.getAndIncrement());
+          String result = poolMap.getOrCreate(key, () -> value);
+          results.add(result);
+
+          Thread.yield();
+        }
+      } catch (IOException e) {
+        throw new CompletionException(e);
+      }
+    };
+
+    CompletableFuture<Void> future1 = CompletableFuture.runAsync(runnable);
+    CompletableFuture<Void> future2 = CompletableFuture.runAsync(runnable);
+
+    /* test for successful completion */
+    future1.get();
+    future2.get();
+
+    assertEquals(POOL_SIZE, poolMap.values().size());
+
+    /* check every elements occur twice */
+    Collections.sort(results);
+    Iterator<String> iterator = results.iterator();
+
+    for (int i = 0; i < POOL_SIZE; i++) {
+      String next1 = iterator.next();
+      String next2 = iterator.next();
+      assertEquals(next1, next2);
+    }
+
+    assertFalse(iterator.hasNext());
   }
 }
diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestThreadLocalPoolMap.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestThreadLocalPoolMap.java
index 5f047c4..a1cb610 100644
--- a/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestThreadLocalPoolMap.java
+++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/util/TestThreadLocalPoolMap.java
@@ -19,9 +19,13 @@
 
 import static org.junit.Assert.assertEquals;
 
+import java.io.IOException;
 import java.util.Random;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.hadoop.hbase.HBaseClassTestRule;
 import org.apache.hadoop.hbase.testclassification.MiscTests;
 import org.apache.hadoop.hbase.testclassification.SmallTests;
@@ -43,42 +47,71 @@
   }
 
   @Test
-  public void testSingleThreadedClient() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
-    String randomKey = String.valueOf(rand.nextInt());
-    String randomValue = String.valueOf(rand.nextInt());
-    // As long as the pool is not full, we should get back what we put
-    runThread(randomKey, randomValue, randomValue);
-    assertEquals(1, poolMap.size(randomKey));
+  public void testGetOrCreate() throws IOException {
+    String key = "key";
+    String value = "value";
+    String result = poolMap.getOrCreate(key, () -> value);
+
+    assertEquals(value, result);
+    assertEquals(1, poolMap.values().size());
   }
 
   @Test
-  public void testMultiThreadedClients() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
-    // As long as the pool is not full, we should get back what we put
-    for (int i = 0; i < POOL_SIZE; i++) {
-      String randomKey = String.valueOf(rand.nextInt());
-      String randomValue = String.valueOf(rand.nextInt());
-      runThread(randomKey, randomValue, randomValue);
-      assertEquals(1, poolMap.size(randomKey));
+  public void testMultipleKeys() throws IOException {
+    for (int i = 0; i < KEY_COUNT; i++) {
+      String key = Integer.toString(i);
+      String value = Integer.toString(2 * i);
+      String result = poolMap.getOrCreate(key, () -> value);
+
+      assertEquals(value, result);
     }
-    String randomKey = String.valueOf(rand.nextInt());
-    for (int i = 0; i < POOL_SIZE; i++) {
-      String randomValue = String.valueOf(rand.nextInt());
-      runThread(randomKey, randomValue, randomValue);
-      assertEquals(i + 1, poolMap.size(randomKey));
-    }
+
+    assertEquals(KEY_COUNT, poolMap.values().size());
   }
 
   @Test
-  public void testPoolCap() throws InterruptedException, ExecutionException {
-    Random rand = ThreadLocalRandom.current();
-    String randomKey = String.valueOf(rand.nextInt());
-    for (int i = 0; i < POOL_SIZE * 2; i++) {
-      String randomValue = String.valueOf(rand.nextInt());
-      // as of HBASE-4150, pool limit is no longer used with ThreadLocalPool
-      runThread(randomKey, randomValue, randomValue);
-    }
-    assertEquals(POOL_SIZE * 2, poolMap.size(randomKey));
+  public void testFull() throws IOException {
+    String key = "key";
+    String value = "value";
+
+    String result = poolMap.getOrCreate(key, () -> value);
+    assertEquals(value, result);
+
+    String result2 = poolMap.getOrCreate(key, () -> {
+      throw new IOException("must not call me");
+    });
+
+    assertEquals(value, result2);
+    assertEquals(1, poolMap.values().size());
+  }
+
+  @Test
+  public void testLocality() throws ExecutionException, InterruptedException {
+    String key = "key";
+    AtomicInteger id = new AtomicInteger();
+
+    Runnable runnable = () -> {
+      try {
+        String myId = Integer.toString(id.getAndIncrement());
+
+        for (int i = 0; i < 3; i++) {
+          String result = poolMap.getOrCreate(key, () -> myId);
+          assertEquals(myId, result);
+
+          Thread.yield();
+        }
+      } catch (IOException e) {
+        throw new CompletionException(e);
+      }
+    };
+
+    CompletableFuture<Void> future1 = CompletableFuture.runAsync(runnable);
+    CompletableFuture<Void> future2 = CompletableFuture.runAsync(runnable);
+
+    /* test for successful completion */
+    future1.get();
+    future2.get();
+
+    assertEquals(2, poolMap.values().size());
   }
 }