HBASE-27999 Implement cache prefetch aware load balancer (#5376)

Signed-off-by: Wellington Chevreuil <wchevreuil@apache.org>
diff --git a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerClusterState.java b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerClusterState.java
index a7ae8b4..4b3809c 100644
--- a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerClusterState.java
+++ b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerClusterState.java
@@ -34,6 +34,7 @@
 import org.apache.hadoop.hbase.client.RegionReplicaUtil;
 import org.apache.hadoop.hbase.master.RackManager;
 import org.apache.hadoop.hbase.net.Address;
+import org.apache.hadoop.hbase.util.Pair;
 import org.apache.yetus.audience.InterfaceAudience;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -114,6 +115,12 @@
   private float[][] rackLocalities;
   // Maps localityType -> region -> [server|rack]Index with highest locality
   private int[][] regionsToMostLocalEntities;
+  // Maps (regionIndex, serverIndex) -> cache ratio of the region on that server
+  private Map<Pair<Integer, Integer>, Float> regionIndexServerIndexRegionCachedRatio;
+  // Maps regionIndex -> serverIndex with best region cache ratio
+  private int[] regionServerIndexWithBestRegionCachedRatio;
+  // Maps regionEncodedName -> (ServerName of the old host, cache ratio of the region there)
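+  // (illustrative, hypothetical entry: "3f2a..." -> (rs1.example.com,16020,1594296432, 0.8f),
+  // i.e. 80% of that region's data was cached on rs1 before the region moved away)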
+  Map<String, Pair<ServerName, Float>> regionCacheRatioOnOldServerMap;
 
   static class DefaultRackManager extends RackManager {
     @Override
@@ -125,13 +132,20 @@
   BalancerClusterState(Map<ServerName, List<RegionInfo>> clusterState,
     Map<String, Deque<BalancerRegionLoad>> loads, RegionHDFSBlockLocationFinder regionFinder,
     RackManager rackManager) {
-    this(null, clusterState, loads, regionFinder, rackManager);
+    this(null, clusterState, loads, regionFinder, rackManager, null);
+  }
+
+  protected BalancerClusterState(Map<ServerName, List<RegionInfo>> clusterState,
+    Map<String, Deque<BalancerRegionLoad>> loads, RegionHDFSBlockLocationFinder regionFinder,
+    RackManager rackManager, Map<String, Pair<ServerName, Float>> oldRegionServerRegionCacheRatio) {
+    this(null, clusterState, loads, regionFinder, rackManager, oldRegionServerRegionCacheRatio);
   }
 
   @SuppressWarnings("unchecked")
   BalancerClusterState(Collection<RegionInfo> unassignedRegions,
     Map<ServerName, List<RegionInfo>> clusterState, Map<String, Deque<BalancerRegionLoad>> loads,
-    RegionHDFSBlockLocationFinder regionFinder, RackManager rackManager) {
+    RegionHDFSBlockLocationFinder regionFinder, RackManager rackManager,
+    Map<String, Pair<ServerName, Float>> oldRegionServerRegionCacheRatio) {
     if (unassignedRegions == null) {
       unassignedRegions = Collections.emptyList();
     }
@@ -145,6 +159,8 @@
     tables = new ArrayList<>();
     this.rackManager = rackManager != null ? rackManager : new DefaultRackManager();
 
+    this.regionCacheRatioOnOldServerMap = oldRegionServerRegionCacheRatio;
+
     numRegions = 0;
 
     List<List<Integer>> serversPerHostList = new ArrayList<>();
@@ -542,6 +558,142 @@
   }
 
   /**
+   * Returns the total size, in MB, of the hFiles from the most recent RegionLoad for the region
+   */
+  public int getTotalRegionHFileSizeMB(int region) {
+    Deque<BalancerRegionLoad> load = regionLoads[region];
+    if (load == null || load.isEmpty()) {
+      // The region has no recorded load, and hence no actual data on disk
+      return 0;
+    }
+    return load.getLast().getRegionSizeMB();
+  }
+
+  /**
+   * Returns the weighted cache ratio of a region on the given region server: the total hFile
+   * size of the region in MB, multiplied by its cache ratio on that server.
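+   * For example (illustrative numbers), a region with 512 MB of hFiles that is 50% cached on a
+   * server has a weighted cache ratio of 256 on that server.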
+   */
+  public float getOrComputeWeightedRegionCacheRatio(int region, int server) {
+    return getTotalRegionHFileSizeMB(region) * getOrComputeRegionCacheRatio(region, server);
+  }
+
+  /**
+   * Returns the cache ratio of a region on the given region server. If the region is not
+   * currently hosted on the given region server, check whether it was previously hosted there
+   * and, if so, return the old cache ratio.
+   */
+  protected float getRegionCacheRatioOnRegionServer(int region, int regionServerIndex) {
+    float regionCacheRatio = 0.0f;
+
+    // Get the current region cache ratio if the region is hosted on the server regionServerIndex
+    for (int regionIndex : regionsPerServer[regionServerIndex]) {
+      if (region != regionIndex) {
+        continue;
+      }
+
+      Deque<BalancerRegionLoad> regionLoadList = regionLoads[regionIndex];
+
+      // The region is currently hosted on this region server. Get the region cache ratio for this
+      // region on this server
+      regionCacheRatio =
+        regionLoadList == null ? 0.0f : regionLoadList.getLast().getCurrentRegionCacheRatio();
+
+      return regionCacheRatio;
+    }
+
+    // The region is not currently hosted on this server. Check if it was cached on this server
+    // earlier. This can happen when the server was shut down and the cache was persisted.
+    // Search by region name and server name rather than by index ids, as the ids may change
+    // when a server is marked as dead or a new server is added.
+    String regionEncodedName = regions[region].getEncodedName();
+    ServerName serverName = servers[regionServerIndex];
+    if (
+      regionCacheRatioOnOldServerMap != null
+        && regionCacheRatioOnOldServerMap.containsKey(regionEncodedName)
+    ) {
+      Pair<ServerName, Float> cacheRatioOfRegionOnServer =
+        regionCacheRatioOnOldServerMap.get(regionEncodedName);
+      if (ServerName.isSameAddress(cacheRatioOfRegionOnServer.getFirst(), serverName)) {
+        regionCacheRatio = cacheRatioOfRegionOnServer.getSecond();
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Old cache ratio found for region {} on server {}: {}", regionEncodedName,
+            serverName, regionCacheRatio);
+        }
+      }
+    }
+    return regionCacheRatio;
+  }
+
+  /**
+   * Populate the maps containing information about how much a region is cached on a region server.
+   */
+  private void computeRegionServerRegionCacheRatio() {
+    regionIndexServerIndexRegionCachedRatio = new HashMap<>();
+    regionServerIndexWithBestRegionCachedRatio = new int[numRegions];
+
+    for (int region = 0; region < numRegions; region++) {
+      float bestRegionCacheRatio = 0.0f;
+      int serverWithBestRegionCacheRatio = 0;
+      for (int server = 0; server < numServers; server++) {
+        float regionCacheRatio = getRegionCacheRatioOnRegionServer(region, server);
+        if (regionCacheRatio > 0.0f || server == regionIndexToServerIndex[region]) {
+        // A zero cache ratio on a server carries no information, so record the ratio only when
+        // it is greater than 0, or when the server is the one currently hosting the region, so
+        // that the current placement always has an entry.
+          Pair<Integer, Integer> regionServerPair = new Pair<>(region, server);
+          regionIndexServerIndexRegionCachedRatio.put(regionServerPair, regionCacheRatio);
+        }
+        if (regionCacheRatio > bestRegionCacheRatio) {
+          serverWithBestRegionCacheRatio = server;
+          bestRegionCacheRatio = regionCacheRatio;
+        } else if (
+          regionCacheRatio == bestRegionCacheRatio && server == regionIndexToServerIndex[region]
+        ) {
+          // If two servers have the same region cache ratio, the server currently hosting the
+          // region should retain it
+          serverWithBestRegionCacheRatio = server;
+        }
+      }
+      regionServerIndexWithBestRegionCachedRatio[region] = serverWithBestRegionCacheRatio;
+      Pair<Integer, Integer> regionServerPair =
+        new Pair<>(region, regionIndexToServerIndex[region]);
+      Float tempRegionCacheRatio = regionIndexServerIndexRegionCachedRatio.get(regionServerPair);
+      if (tempRegionCacheRatio != null && tempRegionCacheRatio > bestRegionCacheRatio) {
+        LOG.warn(
+          "INVALID CONDITION: region {} on server {} cache ratio {} is greater than the "
+            + "best region cache ratio {} on server {}",
+          regions[region].getEncodedName(), servers[regionIndexToServerIndex[region]],
+          tempRegionCacheRatio, bestRegionCacheRatio, servers[serverWithBestRegionCacheRatio]);
+      }
+    }
+  }
+
+  protected float getOrComputeRegionCacheRatio(int region, int server) {
+    if (
+      regionServerIndexWithBestRegionCachedRatio == null
+        || regionIndexServerIndexRegionCachedRatio.isEmpty()
+    ) {
+      computeRegionServerRegionCacheRatio();
+    }
+
+    Pair<Integer, Integer> regionServerPair = new Pair<>(region, server);
+    return regionIndexServerIndexRegionCachedRatio.getOrDefault(regionServerPair, 0.0f);
+  }
+
+  public int[] getOrComputeServerWithBestRegionCachedRatio() {
+    if (
+      regionServerIndexWithBestRegionCachedRatio == null
+        || regionIndexServerIndexRegionCachedRatio.isEmpty()
+    ) {
+      computeRegionServerRegionCacheRatio();
+    }
+    return regionServerIndexWithBestRegionCachedRatio;
+  }
+
+  /**
    * Maps region index to rack index
    */
   public int getRackForRegion(int region) {
diff --git a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerRegionLoad.java b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerRegionLoad.java
index ffb36cb..33d00e3 100644
--- a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerRegionLoad.java
+++ b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BalancerRegionLoad.java
@@ -34,6 +34,8 @@
   private final long writeRequestsCount;
   private final int memStoreSizeMB;
   private final int storefileSizeMB;
+  private final int regionSizeMB;
+  private final float currentRegionCacheRatio;
 
   BalancerRegionLoad(RegionMetrics regionMetrics) {
     readRequestsCount = regionMetrics.getReadRequestCount();
@@ -41,6 +43,8 @@
     writeRequestsCount = regionMetrics.getWriteRequestCount();
     memStoreSizeMB = (int) regionMetrics.getMemStoreSize().get(Size.Unit.MEGABYTE);
     storefileSizeMB = (int) regionMetrics.getStoreFileSize().get(Size.Unit.MEGABYTE);
+    regionSizeMB = (int) regionMetrics.getRegionSizeMB().get(Size.Unit.MEGABYTE);
+    currentRegionCacheRatio = regionMetrics.getCurrentRegionCachedRatio();
   }
 
   public long getReadRequestsCount() {
@@ -62,4 +66,12 @@
   public int getStorefileSizeMB() {
     return storefileSizeMB;
   }
+
+  public int getRegionSizeMB() {
+    return regionSizeMB;
+  }
+
+  public float getCurrentRegionCacheRatio() {
+    return currentRegionCacheRatio;
+  }
 }
diff --git a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BaseLoadBalancer.java b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BaseLoadBalancer.java
index a4560cc..5451686 100644
--- a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BaseLoadBalancer.java
+++ b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/BaseLoadBalancer.java
@@ -232,7 +232,8 @@
         clusterState.put(server, Collections.emptyList());
       }
     }
-    return new BalancerClusterState(regions, clusterState, null, this.regionFinder, rackManager);
+    return new BalancerClusterState(regions, clusterState, null, this.regionFinder, rackManager,
+      null);
   }
 
   private List<ServerName> findIdleServers(List<ServerName> servers) {
diff --git a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/CacheAwareLoadBalancer.java b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/CacheAwareLoadBalancer.java
new file mode 100644
index 0000000..d73769a
--- /dev/null
+++ b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/CacheAwareLoadBalancer.java
@@ -0,0 +1,479 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+/**
+ * An implementation of the {@link org.apache.hadoop.hbase.master.LoadBalancer} that assigns
+ * regions based on how much of each region's data is cached on a given server. A region can move
+ * across region servers whenever a region server shuts down or crashes. The region server
+ * preserves the cache periodically and restores it when it is restarted. This balancer maintains
+ * the amount by which each region is cached on each region server and, during a balancer run,
+ * generates a region plan that takes this cache information into account and tries to move the
+ * regions so that the cache is minimally impacted.
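+ * <p>
+ * A minimal configuration sketch for enabling this balancer (the plugin key is the standard
+ * "hbase.master.loadbalancer.class"; the persistent-path key is BUCKET_CACHE_PERSISTENT_PATH_KEY
+ * from HConstants, and the cache cost function disables itself when it is unset):
+ *
+ * <pre>
+ * hbase.master.loadbalancer.class =
+ *     org.apache.hadoop.hbase.master.balancer.CacheAwareLoadBalancer
+ * hbase.bucketcache.persistent.path = /path/to/bucketcache.map
+ * </pre>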
+ */
+
+import static org.apache.hadoop.hbase.HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.ClusterMetrics;
+import org.apache.hadoop.hbase.RegionMetrics;
+import org.apache.hadoop.hbase.ServerMetrics;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.Size;
+import org.apache.hadoop.hbase.client.RegionInfo;
+import org.apache.hadoop.hbase.util.Pair;
+import org.apache.yetus.audience.InterfaceAudience;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@InterfaceAudience.Private
+public class CacheAwareLoadBalancer extends StochasticLoadBalancer {
+  private static final Logger LOG = LoggerFactory.getLogger(CacheAwareLoadBalancer.class);
+
+  private Configuration configuration;
+
+  public enum GeneratorFunctionType {
+    LOAD,
+    CACHE_RATIO
+  }
+
+  @Override
+  public synchronized void loadConf(Configuration configuration) {
+    this.configuration = configuration;
+    this.costFunctions = new ArrayList<>();
+    super.loadConf(configuration);
+  }
+
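+  /**
+   * The generators are registered at the index given by their GeneratorFunctionType ordinal, so
+   * that the updateWeight() overrides in the cost functions below can address them by ordinal.
+   */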
+  @Override
+  protected List<CandidateGenerator> createCandidateGenerators() {
+    List<CandidateGenerator> candidateGenerators = new ArrayList<>(2);
+    candidateGenerators.add(GeneratorFunctionType.LOAD.ordinal(),
+      new CacheAwareSkewnessCandidateGenerator());
+    candidateGenerators.add(GeneratorFunctionType.CACHE_RATIO.ordinal(),
+      new CacheAwareCandidateGenerator());
+    return candidateGenerators;
+  }
+
+  @Override
+  protected List<CostFunction> createCostFunctions(Configuration configuration) {
+    List<CostFunction> costFunctions = new ArrayList<>();
+    addCostFunction(costFunctions, new CacheAwareRegionSkewnessCostFunction(configuration));
+    addCostFunction(costFunctions, new CacheAwareCostFunction(configuration));
+    return costFunctions;
+  }
+
+  private void addCostFunction(List<CostFunction> costFunctions, CostFunction costFunction) {
+    if (costFunction.getMultiplier() > 0) {
+      costFunctions.add(costFunction);
+    }
+  }
+
+  @Override
+  public void updateClusterMetrics(ClusterMetrics clusterMetrics) {
+    this.clusterStatus = clusterMetrics;
+    updateRegionLoad();
+  }
+
+  /**
+   * Collect the cache statistics for all regions from all the active region servers.
+   */
+  private void updateRegionLoad() {
+    loads = new HashMap<>();
+    regionCacheRatioOnOldServerMap = new HashMap<>();
+    Map<String, Pair<ServerName, Integer>> regionCacheRatioOnCurrentServerMap = new HashMap<>();
+
+    // Build current region cache statistics
+    clusterStatus.getLiveServerMetrics().forEach((ServerName sn, ServerMetrics sm) -> {
+      // Create a map of region and the server where it is currently hosted
+      sm.getRegionMetrics().forEach((byte[] regionName, RegionMetrics rm) -> {
+        String regionEncodedName = RegionInfo.encodeRegionName(regionName);
+
+        Deque<BalancerRegionLoad> rload = new ArrayDeque<>();
+
+        // Get the total size of the hFiles in this region
+        int regionSizeMB = (int) rm.getRegionSizeMB().get(Size.Unit.MEGABYTE);
+
+        rload.add(new BalancerRegionLoad(rm));
+        // Maintain a map of region to its total size. This is needed to calculate the cache
+        // ratios for the regions cached on old region servers
+        regionCacheRatioOnCurrentServerMap.put(regionEncodedName, new Pair<>(sn, regionSizeMB));
+        loads.put(regionEncodedName, rload);
+      });
+    });
+
+    // Build cache statistics for the regions hosted previously on old region servers
+    clusterStatus.getLiveServerMetrics().forEach((ServerName sn, ServerMetrics sm) -> {
+      // Find if a region was previously hosted on a server other than the one it is currently
+      // hosted on.
+      sm.getRegionCachedInfo().forEach((String regionEncodedName, Integer regionSizeInCache) -> {
+        // If the region is found in regionCacheRatioOnCurrentServerMap, it is currently hosted on
+        // this server
+        if (regionCacheRatioOnCurrentServerMap.containsKey(regionEncodedName)) {
+          ServerName currentServer =
+            regionCacheRatioOnCurrentServerMap.get(regionEncodedName).getFirst();
+          if (!ServerName.isSameAddress(currentServer, sn)) {
+            int regionSizeMB =
+              regionCacheRatioOnCurrentServerMap.get(regionEncodedName).getSecond();
+            float regionCacheRatioOnOldServer =
+              regionSizeMB == 0 ? 0.0f : (float) regionSizeInCache / regionSizeMB;
+            regionCacheRatioOnOldServerMap.put(regionEncodedName,
+              new Pair<>(sn, regionCacheRatioOnOldServer));
+          }
+        }
+      });
+    });
+  }
+
+  private RegionInfo getRegionInfoByEncodedName(BalancerClusterState cluster, String regionName) {
+    Optional<RegionInfo> regionInfoOptional =
+      Arrays.stream(cluster.regions).filter((RegionInfo ri) -> {
+        return regionName.equals(ri.getEncodedName());
+      }).findFirst();
+
+    if (regionInfoOptional.isPresent()) {
+      return regionInfoOptional.get();
+    }
+    return null;
+  }
+
+  private class CacheAwareCandidateGenerator extends CandidateGenerator {
+    @Override
+    protected BalanceAction generate(BalancerClusterState cluster) {
+      // Move the regions to the servers they were previously hosted on based on the cache ratio
+      if (!regionCacheRatioOnOldServerMap.isEmpty()) {
+        Map.Entry<String, Pair<ServerName, Float>> regionCacheRatioServerMap =
+          regionCacheRatioOnOldServerMap.entrySet().iterator().next();
+        // Get the server where this region was previously hosted
+        String regionEncodedName = regionCacheRatioServerMap.getKey();
+        RegionInfo regionInfo = getRegionInfoByEncodedName(cluster, regionEncodedName);
+        if (regionInfo == null) {
+          LOG.warn("Region {} not found", regionEncodedName);
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+          return BalanceAction.NULL_ACTION;
+        }
+        if (regionInfo.isMetaRegion() || regionInfo.getTable().isSystemTable()) {
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+          return BalanceAction.NULL_ACTION;
+        }
+        int regionIndex = cluster.regionsToIndex.get(regionInfo);
+        Integer oldServerIndex = cluster.serversToIndex
+          .get(regionCacheRatioOnOldServerMap.get(regionEncodedName).getFirst().getAddress());
+        // serversToIndex returns null when the old server is no longer in the cluster
+        if (oldServerIndex == null || oldServerIndex < 0) {
+          LOG.warn("Server previously hosting region {} not found", regionEncodedName);
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+          return BalanceAction.NULL_ACTION;
+        }
+
+        float oldRegionCacheRatio =
+          cluster.getOrComputeRegionCacheRatio(regionIndex, oldServerIndex);
+        int currentServerIndex = cluster.regionIndexToServerIndex[regionIndex];
+        float currentRegionCacheRatio =
+          cluster.getOrComputeRegionCacheRatio(regionIndex, currentServerIndex);
+
+        BalanceAction action = generatePlan(cluster, regionIndex, currentServerIndex,
+          currentRegionCacheRatio, oldServerIndex, oldRegionCacheRatio);
+        regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+        return action;
+      }
+      return BalanceAction.NULL_ACTION;
+    }
+
+    private BalanceAction generatePlan(BalancerClusterState cluster, int regionIndex,
+      int currentServerIndex, float cacheRatioOnCurrentServer, int oldServerIndex,
+      float cacheRatioOnOldServer) {
+      return moveRegionToOldServer(cluster, regionIndex, currentServerIndex,
+        cacheRatioOnCurrentServer, oldServerIndex, cacheRatioOnOldServer)
+          ? getAction(currentServerIndex, regionIndex, oldServerIndex, -1)
+          : BalanceAction.NULL_ACTION;
+    }
+
+    private boolean moveRegionToOldServer(BalancerClusterState cluster, int regionIndex,
+      int currentServerIndex, float cacheRatioOnCurrentServer, int oldServerIndex,
+      float cacheRatioOnOldServer) {
+      // Bail out if either server index is invalid. This can happen when another candidate
+      // generator has already moved the region
+      if (currentServerIndex < 0 || oldServerIndex < 0) {
+        return false;
+      }
+
+      float cacheRatioDiffThreshold = 0.6f;
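+      // Worked example with illustrative numbers: if a region is 30% cached on the current
+      // server and 80% cached on the old server, then 0.3 / 0.8 = 0.375 < 0.6 and the region is
+      // moved back; at 60% vs 80%, 0.6 / 0.8 = 0.75 >= 0.6 and the region stays where it is.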
+
+      // Conditions for moving the region
+
+      // If the region is fully cached on the old server, move the region back
+      if (cacheRatioOnOldServer == 1.0f) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Region {} moved to the old server {} as it is fully cached there",
+            cluster.regions[regionIndex].getEncodedName(), cluster.servers[oldServerIndex]);
+        }
+        return true;
+      }
+
+      // Move the region back to the old server if it is cached equally on both servers
+      if (cacheRatioOnCurrentServer == cacheRatioOnOldServer) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(
+            "Region {} moved from {} to {} as the region is cached {} equally on both servers",
+            cluster.regions[regionIndex].getEncodedName(), cluster.servers[currentServerIndex],
+            cluster.servers[oldServerIndex], cacheRatioOnCurrentServer);
+        }
+        return true;
+      }
+
+      // If the region is not fully cached on either server, move it back to the old server if
+      // its cache ratio on the current server is still much less than on the old server
+      if (
+        cacheRatioOnOldServer > 0.0f
+          && cacheRatioOnCurrentServer / cacheRatioOnOldServer < cacheRatioDiffThreshold
+      ) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(
+            "Region {} moved from {} to {} as region cache ratio {} is better than the current "
+              + "cache ratio {}",
+            cluster.regions[regionIndex].getEncodedName(), cluster.servers[currentServerIndex],
+            cluster.servers[oldServerIndex], cacheRatioOnCurrentServer, cacheRatioOnOldServer);
+        }
+        return true;
+      }
+
+      if (LOG.isDebugEnabled()) {
+        LOG.debug(
+          "Region {} not moved from {} to {} with current cache ratio {} and old cache ratio {}",
+          cluster.regions[regionIndex].getEncodedName(), cluster.servers[currentServerIndex],
+          cluster.servers[oldServerIndex], cacheRatioOnCurrentServer, cacheRatioOnOldServer);
+      }
+      return false;
+    }
+  }
+
+  private class CacheAwareSkewnessCandidateGenerator extends LoadCandidateGenerator {
+    @Override
+    BalanceAction pickRandomRegions(BalancerClusterState cluster, int thisServer, int otherServer) {
+      // First move all the regions which were hosted previously on some other server back to their
+      // old servers
+      if (!regionCacheRatioOnOldServerMap.isEmpty()) {
+        // Get the first entry from the historical cache ratio map
+        Map.Entry<String, Pair<ServerName, Float>> regionEntry =
+          regionCacheRatioOnOldServerMap.entrySet().iterator().next();
+        String regionEncodedName = regionEntry.getKey();
+
+        RegionInfo regionInfo = getRegionInfoByEncodedName(cluster, regionEncodedName);
+        if (regionInfo == null) {
+          LOG.warn("Region {} does not exist", regionEncodedName);
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+          return BalanceAction.NULL_ACTION;
+        }
+        if (regionInfo.isMetaRegion() || regionInfo.getTable().isSystemTable()) {
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+          return BalanceAction.NULL_ACTION;
+        }
+
+        int regionIndex = cluster.regionsToIndex.get(regionInfo);
+
+        // Get the current host name for this region
+        thisServer = cluster.regionIndexToServerIndex[regionIndex];
+
+        // Get the old server index; serversToIndex returns null when the old server is no
+        // longer in the cluster
+        Integer oldServerIndex =
+          cluster.serversToIndex.get(regionEntry.getValue().getFirst().getAddress());
+        otherServer = oldServerIndex == null ? -1 : oldServerIndex;
+
+        regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+
+        if (otherServer < 0) {
+          // The old server is no longer online and hence the region cannot be moved back
+          // to it
+          if (LOG.isDebugEnabled()) {
+            LOG.debug(
+              "CacheAwareSkewnessCandidateGenerator: Region {} not moved to the old "
+                + "server {} as the server does not exist",
+              regionEncodedName, regionEntry.getValue().getFirst().getHostname());
+          }
+          return BalanceAction.NULL_ACTION;
+        }
+
+        if (LOG.isDebugEnabled()) {
+          LOG.debug(
+            "CacheAwareSkewnessCandidateGenerator: Region {} moved from {} to {} as it "
+              + "was hosted their earlier",
+            regionEncodedName, cluster.servers[thisServer].getHostname(),
+            cluster.servers[otherServer].getHostname());
+        }
+
+        return getAction(thisServer, regionIndex, otherServer, -1);
+      }
+
+      if (thisServer < 0 || otherServer < 0) {
+        return BalanceAction.NULL_ACTION;
+      }
+
+      int regionIndexToMove = pickLeastCachedRegion(cluster, thisServer);
+      if (regionIndexToMove < 0) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("CacheAwareSkewnessCandidateGenerator: No region found for movement");
+        }
+        return BalanceAction.NULL_ACTION;
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug(
+          "CacheAwareSkewnessCandidateGenerator: Region {} moved from {} to {} as it is "
+            + "least cached on current server",
+          cluster.regions[regionIndexToMove].getEncodedName(),
+          cluster.servers[thisServer].getHostname(), cluster.servers[otherServer].getHostname());
+      }
+      return getAction(thisServer, regionIndexToMove, otherServer, -1);
+    }
+
+    private int pickLeastCachedRegion(BalancerClusterState cluster, int thisServer) {
+      float minCacheRatio = Float.MAX_VALUE;
+      int leastCachedRegion = -1;
+      for (int i = 0; i < cluster.regionsPerServer[thisServer].length; i++) {
+        int regionIndex = cluster.regionsPerServer[thisServer][i];
+
+        float cacheRatioOnCurrentServer =
+          cluster.getOrComputeRegionCacheRatio(regionIndex, thisServer);
+        if (cacheRatioOnCurrentServer < minCacheRatio) {
+          minCacheRatio = cacheRatioOnCurrentServer;
+          leastCachedRegion = regionIndex;
+        }
+      }
+      return leastCachedRegion;
+    }
+  }
+
+  static class CacheAwareRegionSkewnessCostFunction extends CostFunction {
+    static final String REGION_COUNT_SKEW_COST_KEY =
+      "hbase.master.balancer.stochastic.regionCountCost";
+    static final float DEFAULT_REGION_COUNT_SKEW_COST = 20;
+    private final DoubleArrayCost cost = new DoubleArrayCost();
+
+    CacheAwareRegionSkewnessCostFunction(Configuration conf) {
+      // Load multiplier should be the greatest as it is the most general way to balance data.
+      this.setMultiplier(conf.getFloat(REGION_COUNT_SKEW_COST_KEY, DEFAULT_REGION_COUNT_SKEW_COST));
+    }
+
+    @Override
+    void prepare(BalancerClusterState cluster) {
+      super.prepare(cluster);
+      cost.prepare(cluster.numServers);
+      cost.applyCostsChange(costs -> {
+        for (int i = 0; i < cluster.numServers; i++) {
+          costs[i] = cluster.regionsPerServer[i].length;
+        }
+      });
+    }
+
+    @Override
+    protected double cost() {
+      return cost.cost();
+    }
+
+    @Override
+    protected void regionMoved(int region, int oldServer, int newServer) {
+      cost.applyCostsChange(costs -> {
+        costs[oldServer] = cluster.regionsPerServer[oldServer].length;
+        costs[newServer] = cluster.regionsPerServer[newServer].length;
+      });
+    }
+
+    public final void updateWeight(double[] weights) {
+      weights[GeneratorFunctionType.LOAD.ordinal()] += cost();
+    }
+  }
+
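+  /**
+   * Penalizes placements that lose cached data. The cost is 1 minus the ratio of the summed
+   * weighted cache ratios (cached MB) of the current placement to those of the best possible
+   * placement. Illustrative example: if the current placement keeps 600 MB of the regions' data
+   * cached while the best placement would keep 800 MB, the cost is 1 - 600 / 800 = 0.25.
+   */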
+  static class CacheAwareCostFunction extends CostFunction {
+    private static final String CACHE_COST_KEY = "hbase.master.balancer.stochastic.cacheCost";
+    private double cacheRatio;
+    private double bestCacheRatio;
+
+    private static final float DEFAULT_CACHE_COST = 20;
+
+    CacheAwareCostFunction(Configuration conf) {
+      boolean isPersistentCache = conf.get(BUCKET_CACHE_PERSISTENT_PATH_KEY) != null;
+      // Disable the CacheAwareCostFunction if the cached file list persistence is not enabled
+      this.setMultiplier(
+        !isPersistentCache ? 0.0f : conf.getFloat(CACHE_COST_KEY, DEFAULT_CACHE_COST));
+      bestCacheRatio = 0.0;
+      cacheRatio = 0.0;
+    }
+
+    @Override
+    void prepare(BalancerClusterState cluster) {
+      super.prepare(cluster);
+      cacheRatio = 0.0;
+      bestCacheRatio = 0.0;
+
+      for (int region = 0; region < cluster.numRegions; region++) {
+        cacheRatio += cluster.getOrComputeWeightedRegionCacheRatio(region,
+          cluster.regionIndexToServerIndex[region]);
+        bestCacheRatio += cluster.getOrComputeWeightedRegionCacheRatio(region,
+          getServerWithBestCacheRatioForRegion(region));
+      }
+
+      cacheRatio = bestCacheRatio == 0 ? 1.0 : cacheRatio / bestCacheRatio;
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("CacheAwareCostFunction: Cost: {}", 1 - cacheRatio);
+      }
+    }
+
+    @Override
+    protected double cost() {
+      return scale(0, 1, 1 - cacheRatio);
+    }
+
+    @Override
+    protected void regionMoved(int region, int oldServer, int newServer) {
+      double regionCacheRatioOnOldServer =
+        cluster.getOrComputeWeightedRegionCacheRatio(region, oldServer);
+      double regionCacheRatioOnNewServer =
+        cluster.getOrComputeWeightedRegionCacheRatio(region, newServer);
+      double cacheRatioDiff = regionCacheRatioOnNewServer - regionCacheRatioOnOldServer;
+      double normalizedDelta = bestCacheRatio == 0.0 ? 0.0 : cacheRatioDiff / bestCacheRatio;
+      cacheRatio += normalizedDelta;
+      if (LOG.isDebugEnabled() && (cacheRatio < 0.0 || cacheRatio > 1.0)) {
+        LOG.debug(
+          "CacheAwareCostFunction:regionMoved:region:{}:from:{}:to:{}:regionCacheRatioOnOldServer:{}:"
+            + "regionCacheRatioOnNewServer:{}:bestRegionCacheRatio:{}:cacheRatio:{}",
+          cluster.regions[region].getEncodedName(), cluster.servers[oldServer].getHostname(),
+          cluster.servers[newServer].getHostname(), regionCacheRatioOnOldServer,
+          regionCacheRatioOnNewServer, bestCacheRatio, cacheRatio);
+      }
+    }
+
+    private int getServerWithBestCacheRatioForRegion(int region) {
+      return cluster.getOrComputeServerWithBestRegionCachedRatio()[region];
+    }
+
+    @Override
+    public final void updateWeight(double[] weights) {
+      weights[GeneratorFunctionType.CACHE_RATIO.ordinal()] += cost();
+    }
+  }
+}
diff --git a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
index edf049e..e5cd544 100644
--- a/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
+++ b/hbase-balancer/src/main/java/org/apache/hadoop/hbase/master/balancer/StochasticLoadBalancer.java
@@ -42,6 +42,7 @@
 import org.apache.hadoop.hbase.master.RackManager;
 import org.apache.hadoop.hbase.master.RegionPlan;
 import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
+import org.apache.hadoop.hbase.util.Pair;
 import org.apache.hadoop.hbase.util.ReflectionUtils;
 import org.apache.yetus.audience.InterfaceAudience;
 import org.slf4j.Logger;
@@ -136,8 +137,10 @@
   private long maxRunningTime = DEFAULT_MAX_RUNNING_TIME;
   private int numRegionLoadsToRemember = DEFAULT_KEEP_REGION_LOADS;
   private float minCostNeedBalance = DEFAULT_MIN_COST_NEED_BALANCE;
+  Map<String, Pair<ServerName, Float>> regionCacheRatioOnOldServerMap = new HashMap<>();
 
-  private List<CostFunction> costFunctions; // FindBugs: Wants this protected; IS2_INCONSISTENT_SYNC
+  protected List<CostFunction> costFunctions; // FindBugs: IS2_INCONSISTENT_SYNC
   // To save currently configed sum of multiplier. Defaulted at 1 for cases that carry high cost
   private float sumMultiplier;
   // to save and report costs to JMX
@@ -224,6 +227,24 @@
     return candidateGenerators;
   }
 
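+  /**
+   * Creates the ordered list of cost functions. Subclasses (such as the CacheAwareLoadBalancer
+   * introduced in this patch) can override this method to install their own cost functions.
+   */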
+  protected List<CostFunction> createCostFunctions(Configuration conf) {
+    List<CostFunction> costFunctions = new ArrayList<>();
+    addCostFunction(costFunctions, new RegionCountSkewCostFunction(conf));
+    addCostFunction(costFunctions, new PrimaryRegionCountSkewCostFunction(conf));
+    addCostFunction(costFunctions, new MoveCostFunction(conf, provider));
+    addCostFunction(costFunctions, localityCost);
+    addCostFunction(costFunctions, rackLocalityCost);
+    addCostFunction(costFunctions, new TableSkewCostFunction(conf));
+    addCostFunction(costFunctions, regionReplicaHostCostFunction);
+    addCostFunction(costFunctions, regionReplicaRackCostFunction);
+    addCostFunction(costFunctions, new ReadRequestCostFunction(conf));
+    addCostFunction(costFunctions, new CPRequestCostFunction(conf));
+    addCostFunction(costFunctions, new WriteRequestCostFunction(conf));
+    addCostFunction(costFunctions, new MemStoreSizeCostFunction(conf));
+    addCostFunction(costFunctions, new StoreFileCostFunction(conf));
+    return costFunctions;
+  }
+
   @Override
   protected void loadConf(Configuration conf) {
     super.loadConf(conf);
@@ -242,20 +263,7 @@
 
     regionReplicaHostCostFunction = new RegionReplicaHostCostFunction(conf);
     regionReplicaRackCostFunction = new RegionReplicaRackCostFunction(conf);
-    costFunctions = new ArrayList<>();
-    addCostFunction(new RegionCountSkewCostFunction(conf));
-    addCostFunction(new PrimaryRegionCountSkewCostFunction(conf));
-    addCostFunction(new MoveCostFunction(conf, provider));
-    addCostFunction(localityCost);
-    addCostFunction(rackLocalityCost);
-    addCostFunction(new TableSkewCostFunction(conf));
-    addCostFunction(regionReplicaHostCostFunction);
-    addCostFunction(regionReplicaRackCostFunction);
-    addCostFunction(new ReadRequestCostFunction(conf));
-    addCostFunction(new CPRequestCostFunction(conf));
-    addCostFunction(new WriteRequestCostFunction(conf));
-    addCostFunction(new MemStoreSizeCostFunction(conf));
-    addCostFunction(new StoreFileCostFunction(conf));
+    this.costFunctions = createCostFunctions(conf);
     loadCustomCostFunctions(conf);
 
     curFunctionCosts = new double[costFunctions.size()];
@@ -459,8 +467,8 @@
     // The clusterState that is given to this method contains the state
     // of all the regions in the table(s) (that's true today)
     // Keep track of servers to iterate through them.
-    BalancerClusterState cluster =
-      new BalancerClusterState(loadOfOneTable, loads, finder, rackManager);
+    BalancerClusterState cluster = new BalancerClusterState(loadOfOneTable, loads, finder,
+      rackManager, regionCacheRatioOnOldServerMap);
 
     long startTime = EnvironmentEdgeManager.currentTime();
 
@@ -568,7 +576,7 @@
     return null;
   }
 
-  private void sendRejectionReasonToRingBuffer(Supplier<String> reason,
+  protected void sendRejectionReasonToRingBuffer(Supplier<String> reason,
     List<CostFunction> costFunctions) {
     provider.recordBalancerRejection(() -> {
       BalancerRejection.Builder builder = new BalancerRejection.Builder().setReason(reason.get());
@@ -627,14 +635,14 @@
     }
   }
 
-  private void addCostFunction(CostFunction costFunction) {
+  private void addCostFunction(List<CostFunction> costFunctions, CostFunction costFunction) {
     float multiplier = costFunction.getMultiplier();
     if (multiplier > 0) {
       costFunctions.add(costFunction);
     }
   }
 
-  private String functionCost() {
+  protected String functionCost() {
     StringBuilder builder = new StringBuilder();
     for (CostFunction c : costFunctions) {
       builder.append(c.getClass().getSimpleName());
@@ -655,6 +663,12 @@
     return builder.toString();
   }
 
+  @RestrictedApi(explanation = "Should only be called in tests", link = "",
+      allowedOnPath = ".*(/src/test/.*|StochasticLoadBalancer).java")
+  List<CostFunction> getCostFunctions() {
+    return costFunctions;
+  }
+
   private String totalCostsPerFunc() {
     StringBuilder builder = new StringBuilder();
     for (CostFunction c : costFunctions) {
diff --git a/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/BalancerTestBase.java b/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/BalancerTestBase.java
index 9ea1c94..4a996e7 100644
--- a/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/BalancerTestBase.java
+++ b/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/BalancerTestBase.java
@@ -23,6 +23,7 @@
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.NavigableSet;
@@ -376,6 +377,19 @@
     return servers;
   }
 
+  protected Map<ServerName, List<RegionInfo>> mockClusterServersUnsorted(int[] mockCluster,
+    int numTables) {
+    int numServers = mockCluster.length;
+    Map<ServerName, List<RegionInfo>> servers = new LinkedHashMap<>();
+    for (int i = 0; i < numServers; i++) {
+      int numRegions = mockCluster[i];
+      ServerAndLoad sal = randomServer(0);
+      List<RegionInfo> regions = randomRegions(numRegions, numTables);
+      servers.put(sal.getServerName(), regions);
+    }
+    return servers;
+  }
+
   protected TreeMap<ServerName, List<RegionInfo>> mockUniformClusterServers(int[] mockCluster) {
     int numServers = mockCluster.length;
     TreeMap<ServerName, List<RegionInfo>> servers = new TreeMap<>();
diff --git a/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java b/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
index 21f3a3b..cc16cfe 100644
--- a/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
+++ b/hbase-balancer/src/test/java/org/apache/hadoop/hbase/master/balancer/TestStochasticLoadBalancer.java
@@ -139,6 +139,8 @@
       when(rl.getWriteRequestCount()).thenReturn(0L);
       when(rl.getMemStoreSize()).thenReturn(Size.ZERO);
       when(rl.getStoreFileSize()).thenReturn(Size.ZERO);
+      when(rl.getRegionSizeMB()).thenReturn(Size.ZERO);
+      when(rl.getCurrentRegionCachedRatio()).thenReturn(0.0f);
       regionLoadMap.put(info.getRegionName(), rl);
     }
     when(serverMetrics.getRegionMetrics()).thenReturn(regionLoadMap);
@@ -213,6 +215,8 @@
       when(rl.getWriteRequestCount()).thenReturn(0L);
       when(rl.getMemStoreSize()).thenReturn(Size.ZERO);
       when(rl.getStoreFileSize()).thenReturn(new Size(i, Size.Unit.MEGABYTE));
+      when(rl.getRegionSizeMB()).thenReturn(Size.ZERO);
+      when(rl.getCurrentRegionCachedRatio()).thenReturn(0.0f);
 
       Map<byte[], RegionMetrics> regionLoadMap = new TreeMap<>(Bytes.BYTES_COMPARATOR);
       regionLoadMap.put(Bytes.toBytes(REGION_KEY), rl);
diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java
index 1247997..2aa9ecf 100644
--- a/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java
+++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java
@@ -1336,6 +1336,18 @@
   public static final String BUCKET_CACHE_SIZE_KEY = "hbase.bucketcache.size";
 
   /**
+   * If the chosen ioengine can persist its state across restarts, the path to the file to persist
+   * to. This file is NOT the data file. It is a file into which we will serialize the map of what
+   * is in the data file. For example, if you pass the following argument as
+   * BUCKET_CACHE_IOENGINE_KEY ("hbase.bucketcache.ioengine"),
+   * <code>file:/tmp/bucketcache.data </code>, then we will write the bucketcache data to the file
+   * <code>/tmp/bucketcache.data</code> but the metadata on where the data is in the supplied file
+   * is an in-memory map that needs to be persisted across restarts. Where to store this in-memory
+   * state is what you supply here: e.g. <code>/tmp/bucketcache.map</code>.
+   */
+  public static final String BUCKET_CACHE_PERSISTENT_PATH_KEY = "hbase.bucketcache.persistent.path";
+
+  /**
    * HConstants for fast fail on the client side follow
    */
   /**
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/BlockCacheFactory.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/BlockCacheFactory.java
index 38a296a..6956d58 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/BlockCacheFactory.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/BlockCacheFactory.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hbase.io.hfile;
 
 import static org.apache.hadoop.hbase.HConstants.BUCKET_CACHE_IOENGINE_KEY;
+import static org.apache.hadoop.hbase.HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY;
 import static org.apache.hadoop.hbase.HConstants.BUCKET_CACHE_SIZE_KEY;
 
 import java.io.IOException;
@@ -47,18 +48,6 @@
   public static final String BLOCKCACHE_POLICY_KEY = "hfile.block.cache.policy";
   public static final String BLOCKCACHE_POLICY_DEFAULT = "LRU";
 
-  /**
-   * If the chosen ioengine can persist its state across restarts, the path to the file to persist
-   * to. This file is NOT the data file. It is a file into which we will serialize the map of what
-   * is in the data file. For example, if you pass the following argument as
-   * BUCKET_CACHE_IOENGINE_KEY ("hbase.bucketcache.ioengine"),
-   * <code>file:/tmp/bucketcache.data </code>, then we will write the bucketcache data to the file
-   * <code>/tmp/bucketcache.data</code> but the metadata on where the data is in the supplied file
-   * is an in-memory map that needs to be persisted across restarts. Where to store this in-memory
-   * state is what you supply here: e.g. <code>/tmp/bucketcache.map</code>.
-   */
-  public static final String BUCKET_CACHE_PERSISTENT_PATH_KEY = "hbase.bucketcache.persistent.path";
-
   public static final String BUCKET_CACHE_WRITER_THREADS_KEY = "hbase.bucketcache.writer.threads";
 
   public static final String BUCKET_CACHE_WRITER_QUEUE_KEY = "hbase.bucketcache.writer.queuelength";
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/bucket/BucketCache.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/bucket/BucketCache.java
index 64162bb..0faf510 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/bucket/BucketCache.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/io/hfile/bucket/BucketCache.java
@@ -342,6 +342,7 @@
       } catch (IOException ioex) {
         backingMap.clear();
         fullyCachedFiles.clear();
+        regionCachedSizeMap.clear();
         LOG.error("Can't restore from file[" + persistencePath + "] because of ", ioex);
       }
     }
@@ -1477,6 +1478,7 @@
       // If persistent ioengine and a path, we will serialize out the backingMap.
       this.backingMap.clear();
       this.fullyCachedFiles.clear();
+      this.regionCachedSizeMap.clear();
     }
   }
 
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancer.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancer.java
new file mode 100644
index 0000000..3ecd8dc
--- /dev/null
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancer.java
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.ClusterMetrics;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.RegionMetrics;
+import org.apache.hadoop.hbase.ServerMetrics;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.Size;
+import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.client.RegionInfo;
+import org.apache.hadoop.hbase.client.TableDescriptor;
+import org.apache.hadoop.hbase.client.TableDescriptorBuilder;
+import org.apache.hadoop.hbase.master.RegionPlan;
+import org.apache.hadoop.hbase.testclassification.LargeTests;
+import org.apache.hadoop.hbase.util.Bytes;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.hbase.thirdparty.com.google.common.collect.Lists;
+
+@Category({ LargeTests.class })
+public class TestCacheAwareLoadBalancer extends BalancerTestBase {
+  @ClassRule
+  public static final HBaseClassTestRule CLASS_RULE =
+    HBaseClassTestRule.forClass(TestCacheAwareLoadBalancer.class);
+
+  private static final Logger LOG = LoggerFactory.getLogger(TestCacheAwareLoadBalancer.class);
+
+  private static CacheAwareLoadBalancer loadBalancer;
+
+  static List<ServerName> servers;
+
+  static List<TableDescriptor> tableDescs;
+
+  static Map<TableName, String> tableMap = new HashMap<>();
+
+  static TableName[] tables = new TableName[] { TableName.valueOf("dt1"), TableName.valueOf("dt2"),
+    TableName.valueOf("dt3"), TableName.valueOf("dt4") };
+
+  private static List<ServerName> generateServers(int numServers) {
+    List<ServerName> servers = new ArrayList<>(numServers);
+    Random rand = ThreadLocalRandom.current();
+    for (int i = 0; i < numServers; i++) {
+      String host = "server" + rand.nextInt(100000);
+      int port = rand.nextInt(60000);
+      servers.add(ServerName.valueOf(host, port, -1));
+    }
+    return servers;
+  }
+
+  private static List<TableDescriptor> constructTableDesc() {
+    List<TableDescriptor> tds = Lists.newArrayList();
+    for (int i = 0; i < tables.length; i++) {
+      TableDescriptor htd = TableDescriptorBuilder.newBuilder(tables[i]).build();
+      tds.add(htd);
+    }
+    return tds;
+  }
+
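+  /**
+   * Builds a mocked ServerMetrics in which every region in regionsOnServer reports the given
+   * current cache ratio and region size, and every region in oldRegionCacheInfo is reported via
+   * getRegionCachedInfo() as having oldRegionCachedSize MB cached on this server from an earlier
+   * assignment.
+   */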
+  private ServerMetrics mockServerMetricsWithRegionCacheInfo(ServerName server,
+    List<RegionInfo> regionsOnServer, float currentCacheRatio, List<RegionInfo> oldRegionCacheInfo,
+    int oldRegionCachedSize, int regionSize) {
+    ServerMetrics serverMetrics = mock(ServerMetrics.class);
+    Map<byte[], RegionMetrics> regionLoadMap = new TreeMap<>(Bytes.BYTES_COMPARATOR);
+    for (RegionInfo info : regionsOnServer) {
+      RegionMetrics rl = mock(RegionMetrics.class);
+      when(rl.getReadRequestCount()).thenReturn(0L);
+      when(rl.getWriteRequestCount()).thenReturn(0L);
+      when(rl.getMemStoreSize()).thenReturn(Size.ZERO);
+      when(rl.getStoreFileSize()).thenReturn(Size.ZERO);
+      when(rl.getCurrentRegionCachedRatio()).thenReturn(currentCacheRatio);
+      when(rl.getRegionSizeMB()).thenReturn(new Size(regionSize, Size.Unit.MEGABYTE));
+      regionLoadMap.put(info.getRegionName(), rl);
+    }
+    when(serverMetrics.getRegionMetrics()).thenReturn(regionLoadMap);
+    Map<String, Integer> oldCacheRatioMap = new HashMap<>();
+    for (RegionInfo info : oldRegionCacheInfo) {
+      oldCacheRatioMap.put(info.getEncodedName(), oldRegionCachedSize);
+    }
+    when(serverMetrics.getRegionCachedInfo()).thenReturn(oldCacheRatioMap);
+    return serverMetrics;
+  }
+
+  @BeforeClass
+  public static void beforeAllTests() throws Exception {
+    servers = generateServers(3);
+    tableDescs = constructTableDesc();
+    Configuration conf = HBaseConfiguration.create();
+    conf.set(HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY, "prefetch_file_list");
+    loadBalancer = new CacheAwareLoadBalancer();
+    loadBalancer.setClusterInfoProvider(new DummyClusterInfoProvider(conf));
+    loadBalancer.loadConf(conf);
+  }
+
+  @Test
+  public void testRegionsNotCachedOnOldServerAndCurrentServer() throws Exception {
+    // The regions are not cached on old server as well as the current server. This causes
+    // skewness in the region allocation which should be fixed by the balancer
+
+    Map<ServerName, List<RegionInfo>> clusterState = new HashMap<>();
+    ServerName server0 = servers.get(0);
+    ServerName server1 = servers.get(1);
+    ServerName server2 = servers.get(2);
+
+    // Simulate that the regions previously hosted by server1 are now hosted on server0
+    List<RegionInfo> regionsOnServer0 = randomRegions(10);
+    List<RegionInfo> regionsOnServer1 = randomRegions(0);
+    List<RegionInfo> regionsOnServer2 = randomRegions(5);
+
+    clusterState.put(server0, regionsOnServer0);
+    clusterState.put(server1, regionsOnServer1);
+    clusterState.put(server2, regionsOnServer2);
+
+    // Mock cluster metrics
+    Map<ServerName, ServerMetrics> serverMetricsMap = new TreeMap<>();
+    serverMetricsMap.put(server0, mockServerMetricsWithRegionCacheInfo(server0, regionsOnServer0,
+      0.0f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server1, mockServerMetricsWithRegionCacheInfo(server1, regionsOnServer1,
+      0.0f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server2, mockServerMetricsWithRegionCacheInfo(server2, regionsOnServer2,
+      0.0f, new ArrayList<>(), 0, 10));
+    ClusterMetrics clusterMetrics = mock(ClusterMetrics.class);
+    when(clusterMetrics.getLiveServerMetrics()).thenReturn(serverMetricsMap);
+    loadBalancer.updateClusterMetrics(clusterMetrics);
+
+    Map<TableName, Map<ServerName, List<RegionInfo>>> loadOfAllTable =
+      (Map) mockClusterServersWithTables(clusterState);
+    List<RegionPlan> plans = loadBalancer.balanceCluster(loadOfAllTable);
+    Set<RegionInfo> regionsMovedFromServer0 = new HashSet<>();
+    Map<ServerName, List<RegionInfo>> targetServers = new HashMap<>();
+    for (RegionPlan plan : plans) {
+      if (plan.getSource().equals(server0)) {
+        regionsMovedFromServer0.add(plan.getRegionInfo());
+        if (!targetServers.containsKey(plan.getDestination())) {
+          targetServers.put(plan.getDestination(), new ArrayList<>());
+        }
+        targetServers.get(plan.getDestination()).add(plan.getRegionInfo());
+      }
+    }
+    // should move 5 regions from server0 to server1
+    assertEquals(5, regionsMovedFromServer0.size());
+    assertEquals(5, targetServers.get(server1).size());
+  }
+
+  @Test
+  public void testRegionsPartiallyCachedOnOldServerAndNotCachedOnCurrentServer() throws Exception {
+    // The regions are partially cached on old server but not cached on the current server
+
+    Map<ServerName, List<RegionInfo>> clusterState = new HashMap<>();
+    ServerName server0 = servers.get(0);
+    ServerName server1 = servers.get(1);
+    ServerName server2 = servers.get(2);
+
+    // Simulate that the regions previously hosted by server1 are now hosted on server0
+    List<RegionInfo> regionsOnServer0 = randomRegions(10);
+    List<RegionInfo> regionsOnServer1 = randomRegions(0);
+    List<RegionInfo> regionsOnServer2 = randomRegions(5);
+
+    clusterState.put(server0, regionsOnServer0);
+    clusterState.put(server1, regionsOnServer1);
+    clusterState.put(server2, regionsOnServer2);
+
+    // Mock cluster metrics
+
+    // Mock that 5 of the regions currently on server0 were previously hosted on server1
+    List<RegionInfo> oldCachedRegions = regionsOnServer0.subList(5, regionsOnServer0.size());
+
+    Map<ServerName, ServerMetrics> serverMetricsMap = new TreeMap<>();
+    serverMetricsMap.put(server0, mockServerMetricsWithRegionCacheInfo(server0, regionsOnServer0,
+      0.0f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server1, mockServerMetricsWithRegionCacheInfo(server1, regionsOnServer1,
+      0.0f, oldCachedRegions, 6, 10));
+    serverMetricsMap.put(server2, mockServerMetricsWithRegionCacheInfo(server2, regionsOnServer2,
+      0.0f, new ArrayList<>(), 0, 10));
+    ClusterMetrics clusterMetrics = mock(ClusterMetrics.class);
+    when(clusterMetrics.getLiveServerMetrics()).thenReturn(serverMetricsMap);
+    loadBalancer.updateClusterMetrics(clusterMetrics);
+
+    Map<TableName, Map<ServerName, List<RegionInfo>>> loadOfAllTable =
+      (Map) mockClusterServersWithTables(clusterState);
+    List<RegionPlan> plans = loadBalancer.balanceCluster(loadOfAllTable);
+    Set<RegionInfo> regionsMovedFromServer0 = new HashSet<>();
+    Map<ServerName, List<RegionInfo>> targetServers = new HashMap<>();
+    for (RegionPlan plan : plans) {
+      if (plan.getSource().equals(server0)) {
+        regionsMovedFromServer0.add(plan.getRegionInfo());
+        if (!targetServers.containsKey(plan.getDestination())) {
+          targetServers.put(plan.getDestination(), new ArrayList<>());
+        }
+        targetServers.get(plan.getDestination()).add(plan.getRegionInfo());
+      }
+    }
+    // should move 5 regions from server0 to server1
+    assertEquals(5, regionsMovedFromServer0.size());
+    assertEquals(5, targetServers.get(server1).size());
+    assertTrue(targetServers.get(server1).containsAll(oldCachedRegions));
+  }
+
+  @Test
+  public void testRegionsFullyCachedOnOldServerAndNotCachedOnCurrentServers() throws Exception {
+    // The regions are fully cached on the old server but not cached on the current server
+
+    Map<ServerName, List<RegionInfo>> clusterState = new HashMap<>();
+    ServerName server0 = servers.get(0);
+    ServerName server1 = servers.get(1);
+    ServerName server2 = servers.get(2);
+
+    // Simulate that the regions previously hosted by server1 are now hosted on server0
+    List<RegionInfo> regionsOnServer0 = randomRegions(10);
+    List<RegionInfo> regionsOnServer1 = randomRegions(0);
+    List<RegionInfo> regionsOnServer2 = randomRegions(5);
+
+    clusterState.put(server0, regionsOnServer0);
+    clusterState.put(server1, regionsOnServer1);
+    clusterState.put(server2, regionsOnServer2);
+
+    // Mock cluster metrics
+
+    // Mock that the last 5 regions on server0 were previously hosted on server1
+    List<RegionInfo> oldCachedRegions = regionsOnServer0.subList(5, regionsOnServer0.size());
+
+    Map<ServerName, ServerMetrics> serverMetricsMap = new TreeMap<>();
+    serverMetricsMap.put(server0, mockServerMetricsWithRegionCacheInfo(server0, regionsOnServer0,
+      0.0f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server1, mockServerMetricsWithRegionCacheInfo(server1, regionsOnServer1,
+      0.0f, oldCachedRegions, 10, 10));
+    serverMetricsMap.put(server2, mockServerMetricsWithRegionCacheInfo(server2, regionsOnServer2,
+      0.0f, new ArrayList<>(), 0, 10));
+    ClusterMetrics clusterMetrics = mock(ClusterMetrics.class);
+    when(clusterMetrics.getLiveServerMetrics()).thenReturn(serverMetricsMap);
+    loadBalancer.updateClusterMetrics(clusterMetrics);
+
+    Map<TableName, Map<ServerName, List<RegionInfo>>> loadOfAllTable =
+      (Map) mockClusterServersWithTables(clusterState);
+    List<RegionPlan> plans = loadBalancer.balanceCluster(loadOfAllTable);
+    Set<RegionInfo> regionsMovedFromServer0 = new HashSet<>();
+    Map<ServerName, List<RegionInfo>> targetServers = new HashMap<>();
+    for (RegionPlan plan : plans) {
+      if (plan.getSource().equals(server0)) {
+        regionsMovedFromServer0.add(plan.getRegionInfo());
+        targetServers.computeIfAbsent(plan.getDestination(), k -> new ArrayList<>())
+          .add(plan.getRegionInfo());
+      }
+    }
+    // should move 5 regions from server0 to server1
+    assertEquals(5, regionsMovedFromServer0.size());
+    assertEquals(5, targetServers.get(server1).size());
+    assertTrue(targetServers.get(server1).containsAll(oldCachedRegions));
+  }
+
+  @Test
+  public void testRegionsFullyCachedOnOldAndCurrentServers() throws Exception {
+    // The regions are fully cached on both the old and the current servers
+
+    Map<ServerName, List<RegionInfo>> clusterState = new HashMap<>();
+    ServerName server0 = servers.get(0);
+    ServerName server1 = servers.get(1);
+    ServerName server2 = servers.get(2);
+
+    // Simulate that the regions previously hosted by server1 are now hosted on server0
+    List<RegionInfo> regionsOnServer0 = randomRegions(10);
+    List<RegionInfo> regionsOnServer1 = randomRegions(0);
+    List<RegionInfo> regionsOnServer2 = randomRegions(5);
+
+    clusterState.put(server0, regionsOnServer0);
+    clusterState.put(server1, regionsOnServer1);
+    clusterState.put(server2, regionsOnServer2);
+
+    // Mock cluster metrics
+
+    // Mock that the last 5 regions on server0 were previously hosted on server1
+    List<RegionInfo> oldCachedRegions = regionsOnServer0.subList(5, regionsOnServer0.size());
+
+    Map<ServerName, ServerMetrics> serverMetricsMap = new TreeMap<>();
+    serverMetricsMap.put(server0, mockServerMetricsWithRegionCacheInfo(server0, regionsOnServer0,
+      1.0f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server1, mockServerMetricsWithRegionCacheInfo(server1, regionsOnServer1,
+      1.0f, oldCachedRegions, 10, 10));
+    serverMetricsMap.put(server2, mockServerMetricsWithRegionCacheInfo(server2, regionsOnServer2,
+      1.0f, new ArrayList<>(), 0, 10));
+    ClusterMetrics clusterMetrics = mock(ClusterMetrics.class);
+    when(clusterMetrics.getLiveServerMetrics()).thenReturn(serverMetricsMap);
+    loadBalancer.updateClusterMetrics(clusterMetrics);
+
+    Map<TableName, Map<ServerName, List<RegionInfo>>> loadOfAllTable =
+      (Map) mockClusterServersWithTables(clusterState);
+    List<RegionPlan> plans = loadBalancer.balanceCluster(loadOfAllTable);
+    Set<RegionInfo> regionsMovedFromServer0 = new HashSet<>();
+    Map<ServerName, List<RegionInfo>> targetServers = new HashMap<>();
+    for (RegionPlan plan : plans) {
+      if (plan.getSource().equals(server0)) {
+        regionsMovedFromServer0.add(plan.getRegionInfo());
+        targetServers.computeIfAbsent(plan.getDestination(), k -> new ArrayList<>())
+          .add(plan.getRegionInfo());
+      }
+    }
+    // should move 5 regions from server0 to server1
+    assertEquals(5, regionsMovedFromServer0.size());
+    assertEquals(5, targetServers.get(server1).size());
+    assertTrue(targetServers.get(server1).containsAll(oldCachedRegions));
+  }
+
+  @Test
+  public void testRegionsPartiallyCachedOnOldServerAndCurrentServer() throws Exception {
+    // The regions are partially cached on both the old and the current servers
+
+    Map<ServerName, List<RegionInfo>> clusterState = new HashMap<>();
+    ServerName server0 = servers.get(0);
+    ServerName server1 = servers.get(1);
+    ServerName server2 = servers.get(2);
+
+    // Simulate that the regions previously hosted by server1 are now hosted on server0
+    List<RegionInfo> regionsOnServer0 = randomRegions(10);
+    List<RegionInfo> regionsOnServer1 = randomRegions(0);
+    List<RegionInfo> regionsOnServer2 = randomRegions(5);
+
+    clusterState.put(server0, regionsOnServer0);
+    clusterState.put(server1, regionsOnServer1);
+    clusterState.put(server2, regionsOnServer2);
+
+    // Mock cluster metrics
+
+    // Mock that the last 5 regions on server0 were previously hosted on server1
+    List<RegionInfo> oldCachedRegions = regionsOnServer0.subList(5, regionsOnServer0.size());
+
+    Map<ServerName, ServerMetrics> serverMetricsMap = new TreeMap<>();
+    serverMetricsMap.put(server0, mockServerMetricsWithRegionCacheInfo(server0, regionsOnServer0,
+      0.2f, new ArrayList<>(), 0, 10));
+    serverMetricsMap.put(server1, mockServerMetricsWithRegionCacheInfo(server1, regionsOnServer1,
+      0.0f, oldCachedRegions, 6, 10));
+    serverMetricsMap.put(server2, mockServerMetricsWithRegionCacheInfo(server2, regionsOnServer2,
+      1.0f, new ArrayList<>(), 0, 10));
+    ClusterMetrics clusterMetrics = mock(ClusterMetrics.class);
+    when(clusterMetrics.getLiveServerMetrics()).thenReturn(serverMetricsMap);
+    loadBalancer.updateClusterMetrics(clusterMetrics);
+
+    Map<TableName, Map<ServerName, List<RegionInfo>>> loadOfAllTable =
+      (Map) mockClusterServersWithTables(clusterState);
+    List<RegionPlan> plans = loadBalancer.balanceCluster(loadOfAllTable);
+    Set<RegionInfo> regionsMovedFromServer0 = new HashSet<>();
+    Map<ServerName, List<RegionInfo>> targetServers = new HashMap<>();
+    for (RegionPlan plan : plans) {
+      if (plan.getSource().equals(server0)) {
+        regionsMovedFromServer0.add(plan.getRegionInfo());
+        targetServers.computeIfAbsent(plan.getDestination(), k -> new ArrayList<>())
+          .add(plan.getRegionInfo());
+      }
+    }
+    assertEquals(5, regionsMovedFromServer0.size());
+    assertEquals(5, targetServers.get(server1).size());
+    assertTrue(targetServers.get(server1).containsAll(oldCachedRegions));
+  }
+}
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancerCostFunctions.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancerCostFunctions.java
new file mode 100644
index 0000000..448e576
--- /dev/null
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestCacheAwareLoadBalancerCostFunctions.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.master.balancer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseClassTestRule;
+import org.apache.hadoop.hbase.HConstants;
+import org.apache.hadoop.hbase.ServerName;
+import org.apache.hadoop.hbase.testclassification.MasterTests;
+import org.apache.hadoop.hbase.testclassification.MediumTests;
+import org.apache.hadoop.hbase.util.Pair;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+@Category({ MasterTests.class, MediumTests.class })
+public class TestCacheAwareLoadBalancerCostFunctions extends StochasticBalancerTestBase {
+
+  @ClassRule
+  public static final HBaseClassTestRule CLASS_RULE =
+    HBaseClassTestRule.forClass(TestCacheAwareLoadBalancerCostFunctions.class);
+
+  // Mapping of test -> expected cache cost
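+  // (0 means the current placement is already optimal for cache locality; like other stochastic
+  // cost functions, the value is scaled between 0 and 1)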
+  private final float[] expectedCacheCost = { 0.0f, 0.0f, 0.5f, 1.0f, 0.0f, 0.572f, 0.0f, 0.075f };
+
+  /**
+   * Data sets for testCacheCost. For each test: [test][0][server] = the number of regions hosted
+   * on that server; [test][region + 1][0] = the server the region is currently hosted on;
+   * [test][region + 1][server + 1] = the percentage of the region's data cached on that server.
+   */
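+  // For example, in Test 3 below the row { 0, 50, 0, 100 } reads: the region is hosted on
+  // server 0, is 50% cached there, and was 100% cached on server 2 before the move.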
+  private final int[][][] clusterRegionCacheRatioMocks = new int[][][] {
+    // Test 1: each region is entirely on server that hosts it
+    // Cost of moving the regions in this case should be high as the regions are fully cached
+    // on the server they are currently hosted on
+    new int[][] {
+      new int[] { 2, 1, 1 }, // Servers 0, 1 and 2 host 2, 1 and 1 region(s) respectively
+      new int[] { 0, 100, 0, 0 }, // region 0 is hosted and cached only on server 0
+      new int[] { 0, 100, 0, 0 }, // region 1 is hosted and cached only on server 0
+      new int[] { 1, 0, 100, 0 }, // region 2 is hosted and cached only on server 1
+      new int[] { 2, 0, 0, 100 }, // region 3 is hosted and cached only on server 2
+    },
+
+    // Test 2: each region is cached completely on the server it is currently hosted on,
+    // but it was also cached on some other server historically.
+    // Cost of moving the regions in this case should be high as the regions are fully cached
+    // on the server they are currently hosted on. Although the regions were previously hosted
+    // and cached on some other server, since they are completely cached on the new server,
+    // there is no need to move them back to the previous server
+    new int[][] {
+      new int[] { 1, 2, 1 }, // Servers 0, 1 and 2 host 1, 2 and 1 region(s) respectively
+      new int[] { 0, 100, 0, 100 }, // region 0 is hosted and currently cached on server 0,
+      // but was previously cached completely on server 2
+      new int[] { 1, 100, 100, 0 }, // region 1 is hosted and currently cached on server 1,
+      // but was previously cached completely on server 0
+      new int[] { 1, 0, 100, 100 }, // region 2 is hosted and currently cached on server 1,
+      // but was previously cached on server 2
+      new int[] { 2, 0, 100, 100 }, // region 3 is hosted and currently cached on server 2,
+      // but was previously cached on server 1
+    },
+
+    // Test 3: The regions were hosted and fully cached on a server, but were later moved to
+    // another server by the server crash procedure. The regions are partially cached on the
+    // server they are currently hosted on
+    new int[][] {
+      new int[] { 1, 2, 1 },
+      new int[] { 0, 50, 0, 100 }, // Region 0 is currently hosted and partially cached on
+      // server 0, but was fully cached on server 2 previously
+      new int[] { 1, 100, 50, 0 }, // Region 1 is currently hosted and partially cached on
+      // server 1, but was fully cached on server 0 previously
+      new int[] { 1, 0, 50, 100 }, // Region 2 is currently hosted and partially cached on
+      // server 1, but was fully cached on server 2 previously
+      new int[] { 2, 0, 100, 50 }, // Region 3 is currently hosted and partially cached on
+      // server 2, but was fully cached on server 1 previously
+    },
+
+    // Test 4: The regions were hosted and fully cached on a server, but were later moved to
+    // another server by the server crash procedure. The regions are not at all cached on the
+    // server they are currently hosted on
+    new int[][] {
+      new int[] { 1, 1, 2 },
+      new int[] { 0, 0, 0, 100 }, // Region 0 is currently hosted but not cached on server 0,
+      // but was fully cached on server 2 previously
+      new int[] { 1, 100, 0, 0 }, // Region 1 is currently hosted but not cached on server 1,
+      // but was fully cached on server 0 previously
+      new int[] { 2, 0, 100, 0 }, // Region 2 is currently hosted but not cached on server 2,
+      // but was fully cached on server 1 previously
+      new int[] { 2, 100, 0, 0 }, // Region 3 is currently hosted but not cached on server 2,
+      // but was fully cached on server 0 previously
+    },
+
+    // Test 5: The regions were partially cached on their old servers before moving, and are
+    // also only partially cached on the servers currently hosting them
+    new int[][] {
+      new int[] { 2, 1, 1 },
+      new int[] { 0, 50, 50, 0 }, // Region 0 is hosted and partially cached on server 0, but
+      // was previously hosted and partially cached on server 1
+      new int[] { 0, 50, 0, 50 }, // Region 1 is hosted and partially cached on server 0, but
+      // was previously hosted and partially cached on server 2
+      new int[] { 1, 0, 50, 50 }, // Region 2 is hosted and partially cached on server 1, but
+      // was previously hosted and partially cached on server 2
+      new int[] { 2, 0, 50, 50 }, // Region 3 is hosted and partially cached on server 2, but
+      // was previously hosted and partially cached on server 1
+    },
+
+    // Test 6: The regions are cached less on their new servers than they were on the servers
+    // that hosted them before the move
+    new int[][] {
+      new int[] { 1, 2, 1 },
+      new int[] { 0, 30, 70, 0 }, // Region 0 is hosted and cached 30% on server 0, but was
+      // previously hosted and cached 70% on server 1
+      new int[] { 1, 70, 30, 0 }, // Region 1 is hosted and cached 30% on server 1, but was
+      // previously hosted and cached 70% on server 0
+      new int[] { 1, 0, 30, 70 }, // Region 2 is hosted and cached 30% on server 1, but was
+      // previously hosted and cached 70% on server 2
+      new int[] { 2, 0, 70, 30 }, // Region 3 is hosted and cached 30% on server 2, but was
+      // previously hosted and cached 70% on server 1
+    },
+
+    // Test 7: The regions are cached more on their new servers than they were on the servers
+    // that hosted them before the move
+    new int[][] {
+      new int[] { 2, 1, 1 },
+      new int[] { 0, 80, 20, 0 }, // Region 0 is hosted and 80% cached on server 0, but was
+      // previously hosted and 20% cached on server 1
+      new int[] { 0, 80, 0, 20 }, // Region 1 is hosted and 80% cached on server 0, but was
+      // previously hosted and 20% cached on server 2
+      new int[] { 1, 20, 80, 0 }, // Region 2 is hosted and 80% cached on server 1, but was
+      // previously hosted and 20% cached on server 0
+      new int[] { 2, 0, 20, 80 }, // Region 3 is hosted and 80% cached on server 2, but was
+      // previously hosted and 20% cached on server 1
+    },
+
+    // Test 8: The regions are randomly assigned to servers, with some regions historically
+    // hosted on other region servers
+    new int[][] {
+      new int[] { 1, 2, 1 },
+      new int[] { 0, 34, 0, 58 }, // Region 0 is hosted and partially cached on server 0, but
+      // was previously hosted and partially cached on server 2;
+      // current cache ratio < historical cache ratio
+      new int[] { 1, 78, 100, 0 }, // Region 1 is hosted and fully cached on server 1, but was
+      // previously hosted and partially cached on server 0;
+      // current cache ratio > historical cache ratio
+      new int[] { 1, 66, 66, 0 }, // Region 2 is hosted and partially cached on server 1, but
+      // was previously hosted and partially cached on server 0;
+      // current cache ratio == historical cache ratio
+      new int[] { 2, 0, 0, 96 }, // Region 3 is hosted and partially cached on server 2, with
+      // no historical cache ratio
+    }, };
+
+  private static Configuration storedConfiguration;
+
+  private CacheAwareLoadBalancer loadBalancer = new CacheAwareLoadBalancer();
+
+  @BeforeClass
+  public static void saveInitialConfiguration() {
+    storedConfiguration = new Configuration(conf);
+  }
+
+  @Before
+  public void beforeEachTest() {
+    conf = new Configuration(storedConfiguration);
+    loadBalancer.loadConf(conf);
+  }
+
+  @Test
+  public void testVerifyCacheAwareSkewnessCostFunctionEnabled() {
+    CacheAwareLoadBalancer lb = new CacheAwareLoadBalancer();
+    lb.loadConf(conf);
+    assertTrue(Arrays.asList(lb.getCostFunctionNames())
+      .contains(CacheAwareLoadBalancer.CacheAwareRegionSkewnessCostFunction.class.getSimpleName()));
+  }
+
+  @Test
+  public void testVerifyCacheAwareSkewnessCostFunctionDisabled() {
+    conf.setFloat(
+      CacheAwareLoadBalancer.CacheAwareRegionSkewnessCostFunction.REGION_COUNT_SKEW_COST_KEY, 0.0f);
+
+    CacheAwareLoadBalancer lb = new CacheAwareLoadBalancer();
+    lb.loadConf(conf);
+
+    assertFalse(Arrays.asList(lb.getCostFunctionNames())
+      .contains(CacheAwareLoadBalancer.CacheAwareRegionSkewnessCostFunction.class.getSimpleName()));
+  }
+
+  @Test
+  public void testVerifyCacheCostFunctionEnabled() {
+    conf.set(HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY, "/tmp/prefetch.persistence");
+
+    CacheAwareLoadBalancer lb = new CacheAwareLoadBalancer();
+    lb.loadConf(conf);
+
+    assertTrue(Arrays.asList(lb.getCostFunctionNames())
+      .contains(CacheAwareLoadBalancer.CacheAwareCostFunction.class.getSimpleName()));
+  }
+
+  @Test
+  public void testVerifyCacheCostFunctionDisabledByNoBucketCachePersistence() {
+    assertFalse(Arrays.asList(loadBalancer.getCostFunctionNames())
+      .contains(CacheAwareLoadBalancer.CacheAwareCostFunction.class.getSimpleName()));
+  }
+
+  @Test
+  public void testVerifyCacheCostFunctionDisabledByNoMultiplier() {
+    conf.set(HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY, "/tmp/prefetch.persistence");
+    conf.setFloat("hbase.master.balancer.stochastic.cacheCost", 0.0f);
+    assertFalse(Arrays.asList(loadBalancer.getCostFunctionNames())
+      .contains(CacheAwareLoadBalancer.CacheAwareCostFunction.class.getSimpleName()));
+  }
+
+  @Test
+  public void testCacheCost() {
+    conf.set(HConstants.BUCKET_CACHE_PERSISTENT_PATH_KEY, "/tmp/prefetch.persistence");
+    CacheAwareLoadBalancer.CacheAwareCostFunction costFunction =
+      new CacheAwareLoadBalancer.CacheAwareCostFunction(conf);
+
+    for (int test = 0; test < clusterRegionCacheRatioMocks.length; test++) {
+      int[][] clusterRegionLocations = clusterRegionCacheRatioMocks[test];
+      MockClusterForCacheCost cluster = new MockClusterForCacheCost(clusterRegionLocations);
+      costFunction.prepare(cluster);
+      double cost = costFunction.cost();
+      assertEquals(expectedCacheCost[test], cost, 0.01);
+    }
+  }
+
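+  /**
+   * A minimal BalancerClusterState stub that derives current and historical region cache ratios
+   * from the int matrices above instead of from live server metrics.
+   */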
+  private class MockClusterForCacheCost extends BalancerClusterState {
+    private final Map<Pair<Integer, Integer>, Float> regionServerCacheRatio = new HashMap<>();
+
+    public MockClusterForCacheCost(int[][] regionsArray) {
+      // regionsArray[0] is an array where index = serverIndex and value = number of regions hosted
+      super(mockClusterServersUnsorted(regionsArray[0], 1), null, null, null, null);
+      Map<String, Pair<ServerName, Float>> oldCacheRatio = new HashMap<>();
+      for (int i = 1; i < regionsArray.length; i++) {
+        int regionIndex = i - 1;
+        for (int j = 1; j < regionsArray[i].length; j++) {
+          int serverIndex = j - 1;
+          float cacheRatio = (float) regionsArray[i][j] / 100;
+          regionServerCacheRatio.put(new Pair<>(regionIndex, serverIndex), cacheRatio);
+          if (cacheRatio > 0.0f && serverIndex != regionsArray[i][0]) {
+            // This is the historical cacheRatio value
+            oldCacheRatio.put(regions[regionIndex].getEncodedName(),
+              new Pair<>(servers[serverIndex], cacheRatio));
+          }
+        }
+      }
+      regionCacheRatioOnOldServerMap = oldCacheRatio;
+    }
+
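+    // Report a uniform HFile size so the computed cost depends only on the mocked cache
+    // ratios, not on region sizes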
+    @Override
+    public int getTotalRegionHFileSizeMB(int region) {
+      return 1;
+    }
+
+    @Override
+    protected float getRegionCacheRatioOnRegionServer(int region, int regionServerIndex) {
+      float cacheRatio = 0.0f;
+
+      // Get the cache ratio if the region is currently hosted on this server
+      if (regionServerIndex == regionIndexToServerIndex[region]) {
+        return regionServerCacheRatio.get(new Pair<>(region, regionServerIndex));
+      }
+
+      // Region is not currently hosted on this server. Check if the region was cached on this
+      // server earlier. This can happen when the server was shutdown and its cache was persisted.
+      // Search by the region's encoded name and the server name, not by the region and server
+      // indexes, as those indexes may change when a server is marked dead or a new server is
+      // added.
+      String regionEncodedName = regions[region].getEncodedName();
+      ServerName serverName = servers[regionServerIndex];
+      if (
+        regionCacheRatioOnOldServerMap != null
+          && regionCacheRatioOnOldServerMap.containsKey(regionEncodedName)
+      ) {
+        Pair<ServerName, Float> serverCacheRatio =
+          regionCacheRatioOnOldServerMap.get(regionEncodedName);
+        if (ServerName.isSameAddress(serverName, serverCacheRatio.getFirst())) {
+          cacheRatio = serverCacheRatio.getSecond();
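+          // Consume the historical entry so later lookups for this region fall back to 0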
+          regionCacheRatioOnOldServerMap.remove(regionEncodedName);
+        }
+      }
+      return cacheRatio;
+    }
+  }
+}
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
index 7480452..67ef296 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/master/balancer/TestRSGroupBasedLoadBalancerWithStochasticLoadBalancerAsInternal.java
@@ -86,6 +86,8 @@
       when(rl.getWriteRequestCount()).thenReturn(0L);
       when(rl.getMemStoreSize()).thenReturn(Size.ZERO);
       when(rl.getStoreFileSize()).thenReturn(Size.ZERO);
+      when(rl.getRegionSizeMB()).thenReturn(Size.ZERO);
+      when(rl.getCurrentRegionCachedRatio()).thenReturn(0.0f);
       regionLoadMap.put(info.getRegionName(), rl);
     }
     when(serverMetrics.getRegionMetrics()).thenReturn(regionLoadMap);