Support zone based virtual topology assignment algorithm (#2986)

Support zone based virtual topology assignment algorithm
diff --git a/helix-core/src/main/java/org/apache/helix/cloud/constants/VirtualTopologyGroupConstants.java b/helix-core/src/main/java/org/apache/helix/cloud/constants/VirtualTopologyGroupConstants.java
index d97173a..f248c9b 100644
--- a/helix-core/src/main/java/org/apache/helix/cloud/constants/VirtualTopologyGroupConstants.java
+++ b/helix-core/src/main/java/org/apache/helix/cloud/constants/VirtualTopologyGroupConstants.java
@@ -24,7 +24,13 @@
   public static final String GROUP_NAME = "virtualTopologyGroupName";
   public static final String GROUP_NUMBER = "virtualTopologyGroupNumber";
   public static final String AUTO_MAINTENANCE_MODE_DISABLED = "autoMaintenanceModeDisabled";
+  public static final String ASSIGNMENT_ALGORITHM_TYPE = "assignmentAlgorithmType";
   public static final String GROUP_NAME_SPLITTER = "_";
   public static final String PATH_NAME_SPLITTER = "/";
   public static final String VIRTUAL_FAULT_ZONE_TYPE = "virtualZone";
+  public static final String FORCE_RECOMPUTE = "forceRecompute";
+
+  public enum VirtualGroupAssignmentAlgorithm {
+    ZONE_BASED, INSTANCE_BASED
+  }
 }
diff --git a/helix-core/src/main/java/org/apache/helix/cloud/topology/FaultZoneBasedVirtualGroupAssignmentAlgorithm.java b/helix-core/src/main/java/org/apache/helix/cloud/topology/FaultZoneBasedVirtualGroupAssignmentAlgorithm.java
new file mode 100644
index 0000000..4b22fc9
--- /dev/null
+++ b/helix-core/src/main/java/org/apache/helix/cloud/topology/FaultZoneBasedVirtualGroupAssignmentAlgorithm.java
@@ -0,0 +1,173 @@
+package org.apache.helix.cloud.topology;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+import org.apache.commons.math3.util.Pair;
+
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualGroupId;
+
+/**
+ * A virtual group assignment algorithm that assigns zones and their instances to virtual groups
+ * a way that preserves existing zone-to-group assignments whenever possible, and balances any
+ * remaining unassigned zones across the least-loaded groups. If the requested number of groups
+ * differs from the existing assignment, a new distribution is computed. Otherwise, if a zone
+ * already exists in the provided assignment, all its instances (including newly discovered ones)
+ * are placed in the same group, ensuring no zone is split across multiple virtual groups.
+ */
+public class FaultZoneBasedVirtualGroupAssignmentAlgorithm implements VirtualGroupAssignmentAlgorithm {
+
+  private static final FaultZoneBasedVirtualGroupAssignmentAlgorithm _instance =
+      new FaultZoneBasedVirtualGroupAssignmentAlgorithm();
+
+  private FaultZoneBasedVirtualGroupAssignmentAlgorithm() {
+  }
+
+  public static FaultZoneBasedVirtualGroupAssignmentAlgorithm getInstance() {
+    return _instance;
+  }
+
+  @Override
+  public Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
+      Map<String, Set<String>> zoneMapping, Map<String, Set<String>> virtualGroupToInstancesMap) {
+    // 1. If the number of requested virtual groups differs from the current assignment size,
+    //    we must do a fresh assignment (the existing distribution is invalid).
+    if (numGroups != virtualGroupToInstancesMap.size()) {
+      Map<String, Set<String>> newAssignment = new HashMap<>();
+      for (int i = 0; i < numGroups; i++) {
+        newAssignment.put(computeVirtualGroupId(i, virtualGroupName), new HashSet<>());
+      }
+
+      // Assign all zones from scratch in a balanced manner.
+      distributeUnassignedZones(newAssignment, new ArrayList<>(zoneMapping.keySet()), zoneMapping);
+      return constructResult(newAssignment, zoneMapping);
+    }
+
+    // 2. Find unassigned zones. If there is any, incrementally assign them to the least-loaded
+    //    virtual group.
+    // Build instance-to-zone mapping for quick zone lookups.
+    Map<String, String> instanceToZoneMapping = new HashMap<>();
+    for (Map.Entry<String, Set<String>> entry : zoneMapping.entrySet()) {
+      for (String instance : entry.getValue()) {
+        instanceToZoneMapping.put(instance, entry.getKey());
+      }
+    }
+
+    // Copy zoneMapping for tracking which zones are unassigned.
+    Set<String> unassignedZones = new HashSet<>(zoneMapping.keySet());
+
+    // Build virtual group -> zone mapping and remove assigned zones from the unassigned list
+    Map<String, Set<String>> virtualGroupToZoneMapping = new HashMap<>();
+    for (Map.Entry<String, Set<String>> entry : virtualGroupToInstancesMap.entrySet()) {
+      virtualGroupToZoneMapping.putIfAbsent(entry.getKey(), new HashSet<>());
+      for (String instance : entry.getValue()) {
+        String zone = instanceToZoneMapping.get(instance);
+        virtualGroupToZoneMapping.get(entry.getKey()).add(zone);
+        unassignedZones.remove(zone);
+      }
+    }
+
+    // If there are no unassigned zones, return the result as is.
+    if (unassignedZones.isEmpty()) {
+      return constructResult(virtualGroupToZoneMapping, zoneMapping);
+    }
+
+    // Distribute unassigned zones to keep the overall distribution balanced.
+    distributeUnassignedZones(virtualGroupToZoneMapping, new ArrayList<>(unassignedZones),
+        zoneMapping);
+    return constructResult(virtualGroupToZoneMapping, zoneMapping);
+  }
+
+  @Override
+  public Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
+      Map<String, Set<String>> zoneMapping) {
+    return computeAssignment(numGroups, virtualGroupName, zoneMapping, new HashMap<>());
+  }
+
+  /**
+   * Distributes unassigned zones across virtual groups in a balanced manner.
+   * Assigns heavier zones first to the current least-loaded group.
+   *
+   * @param virtualGroupToZoneMapping Current assignment of virtual group -> set of zones.
+   * @param unassignedZones            List of zones that have not been assigned to any group.
+   * @param zoneMapping                Mapping of physical zone -> set of instances.
+   */
+  private void distributeUnassignedZones(
+      Map<String, Set<String>> virtualGroupToZoneMapping, List<String> unassignedZones,
+      Map<String, Set<String>> zoneMapping) {
+
+    // Priority queue sorted by current load of the virtual group
+    // We always assign new zones to the group with the smallest load to keep them balanced.
+    Queue<String> minHeap = new PriorityQueue<>(
+        Comparator.comparingInt(vg ->
+            virtualGroupToZoneMapping.get(vg).stream()
+                .map(zoneMapping::get)
+                .mapToInt(Set::size)
+                .sum()
+        )
+    );
+    // Seed the min-heap with existing groups
+    minHeap.addAll(virtualGroupToZoneMapping.keySet());
+
+    // Sort unassigned zones by descending number of unassigned instances, assigning "heavier" zones first.
+    unassignedZones.sort(Comparator.comparingInt(zone -> zoneMapping.get(zone).size())
+        .reversed());
+
+    // Assign each zone to the least-loaded group
+    for (String zone : unassignedZones) {
+      String leastLoadVg = minHeap.poll();
+      virtualGroupToZoneMapping.get(leastLoadVg).add(zone);
+      minHeap.offer(leastLoadVg);
+    }
+  }
+
+  /**
+   * Constructs the final result by mapping virtual groups to their instances.
+   *
+   * @param vgToZonesMapping    Mapping of virtual group -> set of zones.
+   * @param zoneToInstancesMapping Mapping of zone -> set of instances.
+   * @return Mapping of virtual group -> set of instances.
+   */
+  private Map<String, Set<String>> constructResult(Map<String, Set<String>> vgToZonesMapping,
+      Map<String, Set<String>> zoneToInstancesMapping) {
+    Map<String, Set<String>> result = new HashMap<>();
+    for (Map.Entry<String, Set<String>> entry : vgToZonesMapping.entrySet()) {
+      Set<String> instances = new HashSet<>();
+      for (String zone : entry.getValue()) {
+        instances.addAll(zoneToInstancesMapping.get(zone));
+      }
+      result.put(entry.getKey(), instances);
+    }
+    return result;
+  }
+}
diff --git a/helix-core/src/main/java/org/apache/helix/cloud/topology/FifoVirtualGroupAssignmentAlgorithm.java b/helix-core/src/main/java/org/apache/helix/cloud/topology/FifoVirtualGroupAssignmentAlgorithm.java
index 23da847..1b1b3ed 100644
--- a/helix-core/src/main/java/org/apache/helix/cloud/topology/FifoVirtualGroupAssignmentAlgorithm.java
+++ b/helix-core/src/main/java/org/apache/helix/cloud/topology/FifoVirtualGroupAssignmentAlgorithm.java
@@ -30,6 +30,7 @@
 import org.apache.helix.cloud.constants.VirtualTopologyGroupConstants;
 import org.apache.helix.util.HelixUtil;
 
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualGroupId;
 
 /**
  * A strategy that densely assign virtual groups with input instance list, it doesn't move to the next one until
@@ -49,7 +50,7 @@
 
   @Override
   public Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
-      Map<String, Set<String>> zoneMapping) {
+      Map<String, Set<String>> zoneMapping, Map<String, Set<String>> virtualGroupToInstancesMap) {
     List<String> sortedInstances = HelixUtil.sortAndFlattenZoneMapping(zoneMapping);
     Map<String, Set<String>> assignment = new HashMap<>();
     // #instances = instancesPerGroupBase * numGroups + residuals
@@ -73,7 +74,9 @@
     return ImmutableMap.copyOf(assignment);
   }
 
-  private static String computeVirtualGroupId(int groupIndex, String virtualGroupName) {
-    return virtualGroupName + VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER + groupIndex;
+  @Override
+  public Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
+      Map<String, Set<String>> zoneMapping) {
+    return computeAssignment(numGroups, virtualGroupName, zoneMapping, new HashMap<>());
   }
 }
diff --git a/helix-core/src/main/java/org/apache/helix/cloud/topology/VirtualGroupAssignmentAlgorithm.java b/helix-core/src/main/java/org/apache/helix/cloud/topology/VirtualGroupAssignmentAlgorithm.java
index 8d6c97f..d5ac818 100644
--- a/helix-core/src/main/java/org/apache/helix/cloud/topology/VirtualGroupAssignmentAlgorithm.java
+++ b/helix-core/src/main/java/org/apache/helix/cloud/topology/VirtualGroupAssignmentAlgorithm.java
@@ -24,6 +24,19 @@
 
 
 public interface VirtualGroupAssignmentAlgorithm {
+  /**
+   * Compute the assignment for each virtual topology group.
+   *
+   * @param numGroups number of the virtual groups
+   * @param virtualGroupName virtual group name
+   * @param zoneMapping current zone mapping from zoneId to instanceIds
+   * @param virtualGroupToInstancesMap  current virtual group mapping from virtual group Id to instancesIds
+   * @return the assignment as mapping from virtual group ID to instanceIds
+   */
+  default Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
+      Map<String, Set<String>> zoneMapping, Map<String, Set<String>> virtualGroupToInstancesMap) {
+    return computeAssignment(numGroups, virtualGroupName, zoneMapping);
+  }
 
   /**
    * Compute the assignment for each virtual topology group.
@@ -33,6 +46,7 @@
    * @param zoneMapping current zone mapping from zoneId to instanceIds
    * @return the assignment as mapping from virtual group ID to instanceIds
    */
+  @Deprecated
   Map<String, Set<String>> computeAssignment(int numGroups, String virtualGroupName,
       Map<String, Set<String>> zoneMapping);
 }
diff --git a/helix-core/src/main/java/org/apache/helix/util/VirtualTopologyUtil.java b/helix-core/src/main/java/org/apache/helix/util/VirtualTopologyUtil.java
new file mode 100644
index 0000000..876163b
--- /dev/null
+++ b/helix-core/src/main/java/org/apache/helix/util/VirtualTopologyUtil.java
@@ -0,0 +1,38 @@
+package org.apache.helix.util;
+
+import io.netty.util.internal.StringUtil;
+import org.apache.helix.cloud.constants.VirtualTopologyGroupConstants;
+
+public class VirtualTopologyUtil {
+  public static String computeVirtualGroupId(int groupIndex, String virtualGroupName) {
+    return virtualGroupName + VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER + groupIndex;
+  }
+
+  /**
+   * Ensures the provided fault zone type string ends with
+   * the virtual fault zone type suffix.
+   *
+   * @param oldFaultZoneType The original fault zone type. Must not be null or empty.
+   * @return The fault zone type string with the virtual fault zone type appended if necessary.
+   * @throws IllegalArgumentException if {@code oldFaultZoneType} is null or empty
+   */
+  public static String computeVirtualFaultZoneTypeKey(String oldFaultZoneType) {
+    if (StringUtil.isNullOrEmpty(oldFaultZoneType)) {
+      throw new IllegalArgumentException("The old fault zone type is null or empty");
+    }
+
+    String suffix = VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER
+        + VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE;
+
+    // If already ends with splitter + VIRTUAL_FAULT_ZONE_TYPE, return as-is
+    if (oldFaultZoneType.endsWith(suffix)) {
+      return oldFaultZoneType;
+    }
+
+    // Otherwise, remove any existing suffix parts beyond the first splitter, if needed
+    String[] segments = oldFaultZoneType.split(VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER);
+    String baseName = segments[0];
+
+    return baseName + suffix;
+  }
+}
diff --git a/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestFaultZoneBasedVirtualGroupAssignment.java b/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestFaultZoneBasedVirtualGroupAssignment.java
new file mode 100644
index 0000000..85463fe
--- /dev/null
+++ b/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestFaultZoneBasedVirtualGroupAssignment.java
@@ -0,0 +1,169 @@
+package org.apache.helix.cloud.virtualTopologyGroup;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.helix.cloud.topology.FaultZoneBasedVirtualGroupAssignmentAlgorithm;
+import org.apache.helix.cloud.topology.VirtualGroupAssignmentAlgorithm;
+import org.testng.Assert;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualGroupId;
+
+public class TestFaultZoneBasedVirtualGroupAssignment {
+
+  private static final String GROUP_NAME = "test_virtual_group";
+  private static final int ZONE_NUMBER = 20;
+  private static final int INSTANCES_PER_ZONE = 5;
+  private Map<String, Set<String>> _zoneMapping = new HashMap<>();
+  private Map<String, Set<String>> _zoneMapping2 = new HashMap<>();
+
+  @BeforeTest
+  public void prepare() {
+    _zoneMapping = new HashMap<>();
+    _zoneMapping2 = new HashMap<>();
+    int instanceIdx = 0;
+    for (int i = 0; i < ZONE_NUMBER; i++) {
+      String zone = "zone_" + i;
+      _zoneMapping.computeIfAbsent(zone, k -> new HashSet<>());
+      _zoneMapping2.computeIfAbsent(zone, k -> new HashSet<>());
+      for (int j = 0; j < INSTANCES_PER_ZONE; j++) {
+        String instance = "instance_" + instanceIdx++;
+        _zoneMapping.get(zone).add(instance);
+        _zoneMapping2.get(zone).add(instance);
+      }
+    }
+    // Add a branch zone zone_20 to zoneMapping2
+    _zoneMapping2.computeIfAbsent("zone_20", k -> new HashSet<>());
+    for (int j = 0; j < INSTANCES_PER_ZONE; j++) {
+      String instance = "instance_" + instanceIdx++;
+      _zoneMapping2.get("zone_" + (ZONE_NUMBER)).add(instance);
+    }
+  }
+
+  @Test(dataProvider = "getMappingTests")
+  public void testAssignmentScheme(int numGroups, Map<String, Set<String>> expected,
+      VirtualGroupAssignmentAlgorithm algorithm, Map<String, Set<String>> zoneMapping,
+      Map<String, Set<String>> virtualMapping) {
+    Assert.assertEquals(
+        algorithm.computeAssignment(numGroups, GROUP_NAME, zoneMapping, virtualMapping), expected);
+  }
+
+  @DataProvider
+  public Object[][] getMappingTests() {
+    VirtualGroupAssignmentAlgorithm algorithm = FaultZoneBasedVirtualGroupAssignmentAlgorithm.getInstance();
+
+    // The virtual groups should be balanced across zones
+    Map<String, Set<String>> virtualMapping = new HashMap<>();
+
+    virtualMapping.put(computeVirtualGroupId(0, GROUP_NAME), new HashSet<>());
+    virtualMapping.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_5"));
+    virtualMapping.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_1"));
+    virtualMapping.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_16"));
+    virtualMapping.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_7"));
+    virtualMapping.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_14"));
+
+    virtualMapping.put(computeVirtualGroupId(1, GROUP_NAME), new HashSet<>());
+    virtualMapping.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_0"));
+    virtualMapping.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_12"));
+    virtualMapping.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_3"));
+    virtualMapping.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_18"));
+    virtualMapping.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_10"));
+
+    virtualMapping.put(computeVirtualGroupId(2, GROUP_NAME), new HashSet<>());
+    virtualMapping.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_17"));
+    virtualMapping.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_9"));
+    virtualMapping.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_11"));
+    virtualMapping.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_19"));
+    virtualMapping.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_4"));
+
+    virtualMapping.put(computeVirtualGroupId(3, GROUP_NAME), new HashSet<>());
+    virtualMapping.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_13"));
+    virtualMapping.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_6"));
+    virtualMapping.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_2"));
+    virtualMapping.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_15"));
+    virtualMapping.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_8"));
+
+
+    Map<String, Set<String>> virtualMapping2 = new HashMap<>();
+    virtualMapping2.put(computeVirtualGroupId(0, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_1"));
+    virtualMapping2.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping.get("zone_12"));
+
+    virtualMapping2.put(computeVirtualGroupId(1, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_0"));
+    virtualMapping2.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_5"));
+    virtualMapping2.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping.get("zone_13"));
+
+    virtualMapping2.put(computeVirtualGroupId(2, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_17"));
+    virtualMapping2.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_6"));
+    virtualMapping2.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping.get("zone_15"));
+
+    virtualMapping2.put(computeVirtualGroupId(3, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_19"));
+    virtualMapping2.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_8"));
+    virtualMapping2.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping.get("zone_9"));
+
+    virtualMapping2.put(computeVirtualGroupId(4, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping.get("zone_7"));
+    virtualMapping2.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping.get("zone_10"));
+    virtualMapping2.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping.get("zone_18"));
+
+    virtualMapping2.put(computeVirtualGroupId(5, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping.get("zone_3"));
+    virtualMapping2.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping.get("zone_16"));
+    virtualMapping2.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping.get("zone_14"));
+
+    virtualMapping2.put(computeVirtualGroupId(6, GROUP_NAME), new HashSet<>());
+    virtualMapping2.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping.get("zone_11"));
+    virtualMapping2.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping.get("zone_2"));
+    virtualMapping2.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping.get("zone_4"));
+
+
+    Map<String, Set<String>> virtualMapping3 = new HashMap<>();
+    virtualMapping3.put(computeVirtualGroupId(0, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping2.get("zone_1"));
+    virtualMapping3.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping2.get("zone_12"));
+    virtualMapping3.get(computeVirtualGroupId(0, GROUP_NAME)).addAll(_zoneMapping2.get("zone_20"));
+
+    virtualMapping3.put(computeVirtualGroupId(1, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping2.get("zone_0"));
+    virtualMapping3.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping2.get("zone_5"));
+    virtualMapping3.get(computeVirtualGroupId(1, GROUP_NAME)).addAll(_zoneMapping2.get("zone_13"));
+
+    virtualMapping3.put(computeVirtualGroupId(2, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping2.get("zone_17"));
+    virtualMapping3.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping2.get("zone_6"));
+    virtualMapping3.get(computeVirtualGroupId(2, GROUP_NAME)).addAll(_zoneMapping2.get("zone_15"));
+
+    virtualMapping3.put(computeVirtualGroupId(3, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping2.get("zone_19"));
+    virtualMapping3.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping2.get("zone_8"));
+    virtualMapping3.get(computeVirtualGroupId(3, GROUP_NAME)).addAll(_zoneMapping2.get("zone_9"));
+
+    virtualMapping3.put(computeVirtualGroupId(4, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping2.get("zone_7"));
+    virtualMapping3.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping2.get("zone_10"));
+    virtualMapping3.get(computeVirtualGroupId(4, GROUP_NAME)).addAll(_zoneMapping2.get("zone_18"));
+
+    virtualMapping3.put(computeVirtualGroupId(5, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping2.get("zone_3"));
+    virtualMapping3.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping2.get("zone_16"));
+    virtualMapping3.get(computeVirtualGroupId(5, GROUP_NAME)).addAll(_zoneMapping2.get("zone_14"));
+
+    virtualMapping3.put(computeVirtualGroupId(6, GROUP_NAME), new HashSet<>());
+    virtualMapping3.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping2.get("zone_11"));
+    virtualMapping3.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping2.get("zone_2"));
+    virtualMapping3.get(computeVirtualGroupId(6, GROUP_NAME)).addAll(_zoneMapping2.get("zone_4"));
+
+    return new Object[][]{{4, virtualMapping, algorithm, _zoneMapping, new HashMap<>()},
+        {7, virtualMapping2, algorithm, _zoneMapping, new HashMap<>()},
+        // Should incrementally add the new zone to the virtual groups
+        {7, virtualMapping3, algorithm, _zoneMapping2, virtualMapping2}};
+  }
+}
diff --git a/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestVirtualTopologyGroupAssignment.java b/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestVirtualTopologyGroupAssignment.java
index 54f4365..a77a47c 100644
--- a/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestVirtualTopologyGroupAssignment.java
+++ b/helix-core/src/test/java/org/apache/helix/cloud/virtualTopologyGroup/TestVirtualTopologyGroupAssignment.java
@@ -21,11 +21,12 @@
 
 import com.google.common.collect.Sets;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import org.apache.helix.cloud.constants.VirtualTopologyGroupConstants;
+
 import org.apache.helix.cloud.topology.FifoVirtualGroupAssignmentAlgorithm;
 import org.apache.helix.cloud.topology.VirtualGroupAssignmentAlgorithm;
 import org.apache.helix.util.HelixUtil;
@@ -34,6 +35,8 @@
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualGroupId;
+
 public class TestVirtualTopologyGroupAssignment {
 
   private static final String GROUP_NAME = "test_virtual_group";
@@ -61,34 +64,32 @@
   @Test(dataProvider = "getMappingTests")
   public void testAssignmentScheme(int numGroups, Map<String, Set<String>> expected,
       VirtualGroupAssignmentAlgorithm algorithm) {
-    Assert.assertEquals(algorithm.computeAssignment(numGroups, GROUP_NAME, _zoneMapping), expected);
+    Assert.assertEquals(algorithm.computeAssignment(numGroups, GROUP_NAME, _zoneMapping,
+        Collections.emptyMap()), expected);
   }
 
   @DataProvider
   public Object[][] getMappingTests() {
     Map<String, Set<String>> virtualMapping = new HashMap<>();
     VirtualGroupAssignmentAlgorithm algorithm = FifoVirtualGroupAssignmentAlgorithm.getInstance();
-    virtualMapping.put(computeVirtualGroupId(0), Sets.newHashSet("1", "2", "3", "4", "5"));
-    virtualMapping.put(computeVirtualGroupId(1), Sets.newHashSet("6", "7", "8", "9"));
-    virtualMapping.put(computeVirtualGroupId(2), Sets.newHashSet("a", "b", "c", "d"));
-    Assert.assertEquals(algorithm.computeAssignment(3, GROUP_NAME, _zoneMapping),
+    virtualMapping.put(computeVirtualGroupId(0, GROUP_NAME), Sets.newHashSet("1", "2", "3", "4", "5"));
+    virtualMapping.put(computeVirtualGroupId(1, GROUP_NAME), Sets.newHashSet("6", "7", "8", "9"));
+    virtualMapping.put(computeVirtualGroupId(2, GROUP_NAME), Sets.newHashSet("a", "b", "c", "d"));
+    Assert.assertEquals(algorithm.computeAssignment(3, GROUP_NAME, _zoneMapping, Collections.emptyMap()),
         virtualMapping);
     Map<String, Set<String>> virtualMapping2 = new HashMap<>();
-    virtualMapping2.put(computeVirtualGroupId(0), Sets.newHashSet("1", "2"));
-    virtualMapping2.put(computeVirtualGroupId(1), Sets.newHashSet("3", "4"));
-    virtualMapping2.put(computeVirtualGroupId(2), Sets.newHashSet("5", "6"));
-    virtualMapping2.put(computeVirtualGroupId(3), Sets.newHashSet("7", "8"));
-    virtualMapping2.put(computeVirtualGroupId(4), Sets.newHashSet("9", "a"));
-    virtualMapping2.put(computeVirtualGroupId(5), Sets.newHashSet("b"));
-    virtualMapping2.put(computeVirtualGroupId(6), Sets.newHashSet("c"));
-    virtualMapping2.put(computeVirtualGroupId(7), Sets.newHashSet("d"));
+    virtualMapping2.put(computeVirtualGroupId(0, GROUP_NAME), Sets.newHashSet("1", "2"));
+    virtualMapping2.put(computeVirtualGroupId(1, GROUP_NAME), Sets.newHashSet("3", "4"));
+    virtualMapping2.put(computeVirtualGroupId(2, GROUP_NAME), Sets.newHashSet("5", "6"));
+    virtualMapping2.put(computeVirtualGroupId(3, GROUP_NAME), Sets.newHashSet("7", "8"));
+    virtualMapping2.put(computeVirtualGroupId(4, GROUP_NAME), Sets.newHashSet("9", "a"));
+    virtualMapping2.put(computeVirtualGroupId(5, GROUP_NAME), Sets.newHashSet("b"));
+    virtualMapping2.put(computeVirtualGroupId(6, GROUP_NAME), Sets.newHashSet("c"));
+    virtualMapping2.put(computeVirtualGroupId(7, GROUP_NAME), Sets.newHashSet("d"));
+
     return new Object[][] {
         {3, virtualMapping, algorithm},
         {8, virtualMapping2, algorithm}
     };
   }
-
-  private static String computeVirtualGroupId(int groupIndex) {
-    return GROUP_NAME + VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER + groupIndex;
-  }
 }
diff --git a/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterService.java b/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterService.java
index db93571..89a49e9 100644
--- a/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterService.java
+++ b/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterService.java
@@ -36,6 +36,15 @@
   ClusterTopology getClusterTopology(String cluster);
 
   /**
+   * Get the topology of a virtual cluster. If useRealTopology is true, return the real topology
+   * of the cluster. If useRealTopology is false, return the virtual topology of the cluster.
+   * @param cluster the cluster name
+   * @param useRealTopology whether to use the real topology or the virtual topology
+   * @return the cluster topology
+   */
+  ClusterTopology getTopologyOfVirtualCluster(String cluster, boolean useRealTopology);
+
+  /**
    * Get cluster basic information
    * @param clusterId
    * @return
diff --git a/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterServiceImpl.java b/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterServiceImpl.java
index a152c3e..527aa92 100644
--- a/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterServiceImpl.java
+++ b/helix-rest/src/main/java/org/apache/helix/rest/server/service/ClusterServiceImpl.java
@@ -21,6 +21,7 @@
 
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -30,6 +31,7 @@
 import org.apache.helix.ConfigAccessor;
 import org.apache.helix.HelixDataAccessor;
 import org.apache.helix.PropertyKey;
+import org.apache.helix.cloud.constants.VirtualTopologyGroupConstants;
 import org.apache.helix.model.ClusterConfig;
 import org.apache.helix.model.InstanceConfig;
 import org.apache.helix.model.LiveInstance;
@@ -48,18 +50,41 @@
   @Override
   public ClusterTopology getClusterTopology(String cluster) {
     String zoneField = _configAccessor.getClusterConfig(cluster).getFaultZoneType();
+    return getTopologyUnderDomainType(zoneField, cluster);
+  }
+
+  @Override
+  public ClusterTopology getTopologyOfVirtualCluster(String cluster, boolean useRealTopology) {
+    String virtualZoneField = _configAccessor.getClusterConfig(cluster).getFaultZoneType();
+    String faultZone = virtualZoneField.split(VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER)[0];
+    if (useRealTopology) {
+      // If the user requested to use real topology, return the real topology
+      return getTopologyUnderDomainType(faultZone, cluster);
+    }
+
+    String virtualZoneSuffix = VirtualTopologyGroupConstants.GROUP_NAME_SPLITTER
+        + VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE;
+    // If the cluster doesn't have a virtual topology but the user requested, return empty
+    // topology, indicating that virtual topology is not enabled
+    if (!virtualZoneField.endsWith(virtualZoneSuffix)) {
+      return new ClusterTopology(cluster, new ArrayList<>(), new HashSet<>());
+    }
+    return getTopologyUnderDomainType(virtualZoneField, cluster);
+  }
+
+  private ClusterTopology getTopologyUnderDomainType(String faultZone, String clusterId) {
     PropertyKey.Builder keyBuilder = _dataAccessor.keyBuilder();
     List<InstanceConfig> instanceConfigs =
         _dataAccessor.getChildValues(keyBuilder.instanceConfigs(), true);
     Map<String, List<ClusterTopology.Instance>> instanceMapByZone = new HashMap<>();
     if (instanceConfigs != null && !instanceConfigs.isEmpty()) {
       for (InstanceConfig instanceConfig : instanceConfigs) {
-        if (!instanceConfig.getDomainAsMap().containsKey(zoneField)) {
+        if (!instanceConfig.getDomainAsMap().containsKey(faultZone)) {
           continue;
         }
         final String instanceName = instanceConfig.getInstanceName();
         final ClusterTopology.Instance instance = new ClusterTopology.Instance(instanceName);
-        final String zoneId = instanceConfig.getDomainAsMap().get(zoneField);
+        final String zoneId = instanceConfig.getDomainAsMap().get(faultZone);
         if (instanceMapByZone.containsKey(zoneId)) {
           instanceMapByZone.get(zoneId).add(instance);
         } else {
@@ -79,7 +104,7 @@
     }
 
     // Get all the instances names
-    return new ClusterTopology(cluster, zones,
+    return new ClusterTopology(clusterId, zones,
         instanceConfigs.stream().map(InstanceConfig::getInstanceName).collect(Collectors.toSet()));
   }
 
diff --git a/helix-rest/src/main/java/org/apache/helix/rest/server/service/VirtualTopologyGroupService.java b/helix-rest/src/main/java/org/apache/helix/rest/server/service/VirtualTopologyGroupService.java
index 2fd5f28..efc36ea 100644
--- a/helix-rest/src/main/java/org/apache/helix/rest/server/service/VirtualTopologyGroupService.java
+++ b/helix-rest/src/main/java/org/apache/helix/rest/server/service/VirtualTopologyGroupService.java
@@ -22,6 +22,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -34,9 +35,9 @@
 import org.apache.helix.HelixException;
 import org.apache.helix.PropertyPathBuilder;
 import org.apache.helix.cloud.constants.VirtualTopologyGroupConstants;
+import org.apache.helix.cloud.topology.FaultZoneBasedVirtualGroupAssignmentAlgorithm;
 import org.apache.helix.cloud.topology.FifoVirtualGroupAssignmentAlgorithm;
 import org.apache.helix.cloud.topology.VirtualGroupAssignmentAlgorithm;
-import org.apache.helix.model.CloudConfig;
 import org.apache.helix.model.ClusterConfig;
 import org.apache.helix.model.ClusterTopologyConfig;
 import org.apache.helix.model.InstanceConfig;
@@ -46,6 +47,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualFaultZoneTypeKey;
 
 /**
  * Service for virtual topology group.
@@ -60,7 +62,7 @@
   private final ClusterService _clusterService;
   private final ConfigAccessor _configAccessor;
   private final HelixDataAccessor _dataAccessor;
-  private final VirtualGroupAssignmentAlgorithm _assignmentAlgorithm;
+  private VirtualGroupAssignmentAlgorithm _assignmentAlgorithm;
 
   public VirtualTopologyGroupService(HelixAdmin helixAdmin, ClusterService clusterService,
       ConfigAccessor configAccessor, HelixDataAccessor dataAccessor) {
@@ -68,7 +70,7 @@
     _clusterService = clusterService;
     _configAccessor = configAccessor;
     _dataAccessor = dataAccessor;
-    _assignmentAlgorithm = FifoVirtualGroupAssignmentAlgorithm.getInstance();
+    _assignmentAlgorithm = FifoVirtualGroupAssignmentAlgorithm.getInstance(); // default assignment algorithm
   }
 
   /**
@@ -86,10 +88,27 @@
    *                     -- if set false or not set, the cluster will automatically enter maintenance mode and exit after
    *                     the call succeeds. It won't proceed if the cluster is already in maintenance mode.
    *                     Either case, the cluster must be in maintenance mode before config change.
+   *                     {@link VirtualTopologyGroupConstants#ASSIGNMENT_ALGORITHM_TYPE} is optional, default to INSTANCE_BASED.
+   *                     {@link VirtualTopologyGroupConstants#FORCE_RECOMPUTE} is optional, default to false.
+   *                     -- if set true, the virtual topology group will be recomputed from scratch by ignoring the existing
+   *                     virtual topology group information.
+   *                     -- if set false or not set, the virtual topology group will be incrementally computed based on the
+   *                     existing virtual topology group information if possible.
    */
   public void addVirtualTopologyGroup(String clusterName, Map<String, String> customFields) {
     // validation
     ClusterConfig clusterConfig = _configAccessor.getClusterConfig(clusterName);
+    // Collect the real topology of the cluster and the virtual topology of the cluster
+    ClusterTopology clusterTopology = _clusterService.getTopologyOfVirtualCluster(clusterName, true);
+    // If forceRecompute is set to true, we will recompute the virtual topology group from scratch
+    // by ignoring the existing virtual topology group information.
+    String forceRecompute = customFields.getOrDefault(VirtualTopologyGroupConstants.FORCE_RECOMPUTE, "false");
+    boolean forceRecomputeFlag = Boolean.parseBoolean(forceRecompute);
+    ClusterTopology virtualTopology =
+        forceRecomputeFlag ? new ClusterTopology(clusterName, Collections.emptyList(),
+            Collections.emptySet())
+            : _clusterService.getTopologyOfVirtualCluster(clusterName, false);
+
     Preconditions.checkState(clusterConfig.isTopologyAwareEnabled(),
         "Topology-aware rebalance is not enabled in cluster " + clusterName);
     String groupName = customFields.get(VirtualTopologyGroupConstants.GROUP_NAME);
@@ -103,14 +122,39 @@
     } catch (NumberFormatException ex) {
       throw new IllegalArgumentException("virtualTopologyGroupNumber " + groupNumberStr + " is not an integer.", ex);
     }
+
+    String algorithm = customFields.get(VirtualTopologyGroupConstants.ASSIGNMENT_ALGORITHM_TYPE);
+    algorithm = algorithm == null ? VirtualTopologyGroupConstants.VirtualGroupAssignmentAlgorithm.INSTANCE_BASED.toString() : algorithm;
+    if (algorithm != null) {
+      VirtualTopologyGroupConstants.VirtualGroupAssignmentAlgorithm algorithmEnum = null;
+      try {
+        algorithmEnum =
+            VirtualTopologyGroupConstants.VirtualGroupAssignmentAlgorithm.valueOf(algorithm);
+      } catch (Exception e) {
+        throw new IllegalArgumentException(
+            "Failed to instantiate assignment algorithm " + algorithm, e);
+      }
+      switch (algorithmEnum) {
+        case ZONE_BASED:
+          Preconditions.checkArgument(numGroups <= clusterTopology.getZones().size(),
+              "Number of virtual groups cannot be greater than the number of zones.");
+          _assignmentAlgorithm = FaultZoneBasedVirtualGroupAssignmentAlgorithm.getInstance();
+          break;
+        case INSTANCE_BASED:
+          Preconditions.checkArgument(numGroups <= clusterTopology.getAllInstances().size(),
+              "Number of virtual groups cannot be greater than the number of instances.");
+          _assignmentAlgorithm = FifoVirtualGroupAssignmentAlgorithm.getInstance();
+          break;
+        default:
+          throw new IllegalArgumentException("Unsupported assignment algorithm " + algorithm);
+      }
+    }
     LOG.info("Computing virtual topology group for cluster {} with param {}", clusterName, customFields);
 
     // compute group assignment
-    ClusterTopology clusterTopology = _clusterService.getClusterTopology(clusterName);
-    Preconditions.checkArgument(numGroups <= clusterTopology.getAllInstances().size(),
-        "Number of virtual groups cannot be greater than the number of instances.");
     Map<String, Set<String>> assignment =
-        _assignmentAlgorithm.computeAssignment(numGroups, groupName, clusterTopology.toZoneMapping());
+        _assignmentAlgorithm.computeAssignment(numGroups, groupName,
+            clusterTopology.toZoneMapping(), virtualTopology.toZoneMapping());
 
     boolean autoMaintenanceModeDisabled = Boolean.parseBoolean(
         customFields.getOrDefault(VirtualTopologyGroupConstants.AUTO_MAINTENANCE_MODE_DISABLED, "false"));
@@ -137,7 +181,7 @@
   private void updateConfigs(String clusterName, ClusterConfig clusterConfig, Map<String, Set<String>> assignment) {
     List<String> zkPaths = new ArrayList<>();
     List<DataUpdater<ZNRecord>> updaters = new ArrayList<>();
-    createInstanceConfigUpdater(clusterName, assignment).forEach((zkPath, updater) -> {
+    createInstanceConfigUpdater(clusterConfig, assignment).forEach((zkPath, updater) -> {
       zkPaths.add(zkPath);
       updaters.add(updater);
     });
@@ -151,7 +195,7 @@
     // update cluster config
     String virtualTopologyString = computeVirtualTopologyString(clusterConfig);
     clusterConfig.setTopology(virtualTopologyString);
-    clusterConfig.setFaultZoneType(VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE);
+    clusterConfig.setFaultZoneType(computeVirtualFaultZoneTypeKey(clusterConfig.getFaultZoneType()));
     _configAccessor.updateClusterConfig(clusterName, clusterConfig);
     LOG.info("Successfully update instance and cluster config for {}", clusterName);
   }
@@ -160,28 +204,28 @@
   static String computeVirtualTopologyString(ClusterConfig clusterConfig) {
     ClusterTopologyConfig clusterTopologyConfig = ClusterTopologyConfig.createFromClusterConfig(clusterConfig);
     String endNodeType = clusterTopologyConfig.getEndNodeType();
-    String[] splits = new String[] {"", VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE, endNodeType};
+    String[] splits = new String[] {"", computeVirtualFaultZoneTypeKey(clusterConfig.getFaultZoneType()), endNodeType};
     return String.join(VirtualTopologyGroupConstants.PATH_NAME_SPLITTER, splits);
   }
 
   /**
    * Create updater for instance config for async update.
-   * @param clusterName cluster name of the instances.
+   * @param clusterConfig cluster config for the cluster which the instance reside.
    * @param assignment virtual group assignment.
    * @return a map from instance zkPath to its {@link DataUpdater} to update.
    */
   @VisibleForTesting
   static Map<String, DataUpdater<ZNRecord>> createInstanceConfigUpdater(
-      String clusterName, Map<String, Set<String>> assignment) {
+      ClusterConfig clusterConfig, Map<String, Set<String>> assignment) {
     Map<String, DataUpdater<ZNRecord>> updaters = new HashMap<>();
     for (Map.Entry<String, Set<String>> entry : assignment.entrySet()) {
       String virtualGroup = entry.getKey();
       for (String instanceName : entry.getValue()) {
-        String path = PropertyPathBuilder.instanceConfig(clusterName, instanceName);
+        String path = PropertyPathBuilder.instanceConfig(clusterConfig.getClusterName(), instanceName);
         updaters.put(path, currentData -> {
           InstanceConfig instanceConfig = new InstanceConfig(currentData);
           Map<String, String> domainMap = instanceConfig.getDomainAsMap();
-          domainMap.put(VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE, virtualGroup);
+          domainMap.put(computeVirtualFaultZoneTypeKey(clusterConfig.getFaultZoneType()), virtualGroup);
           instanceConfig.setDomain(domainMap);
           return instanceConfig.getRecord();
         });
diff --git a/helix-rest/src/test/java/org/apache/helix/rest/server/TestClusterAccessor.java b/helix-rest/src/test/java/org/apache/helix/rest/server/TestClusterAccessor.java
index e882cf0..4ae7b3d 100644
--- a/helix-rest/src/test/java/org/apache/helix/rest/server/TestClusterAccessor.java
+++ b/helix-rest/src/test/java/org/apache/helix/rest/server/TestClusterAccessor.java
@@ -75,6 +75,8 @@
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import static org.apache.helix.cloud.azure.AzureConstants.AZURE_FAULT_ZONE_TYPE;
+
 public class TestClusterAccessor extends AbstractTestClass {
 
   private static final String VG_CLUSTER = "vgCluster";
@@ -221,32 +223,6 @@
             "/instance:TestCluster_1localhost_12927"))));
   }
 
-  @Test(dataProvider = "prepareVirtualTopologyTests", dependsOnMethods = "testGetClusters")
-  public void testAddVirtualTopologyGroup(String requestParam, int numGroups,
-      Map<String, String> instanceToGroup) throws IOException {
-    post("clusters/" + VG_CLUSTER,
-        ImmutableMap.of("command", "addVirtualTopologyGroup"),
-        Entity.entity(requestParam, MediaType.APPLICATION_JSON_TYPE),
-        Response.Status.OK.getStatusCode());
-    Map<String, Object> topology = getMapResponseFromRest(String.format("clusters/%s/topology", VG_CLUSTER));
-    Assert.assertTrue(topology.containsKey("zones"));
-    Assert.assertEquals(((List) topology.get("zones")).size(), numGroups);
-
-    ClusterConfig clusterConfig = getClusterConfigFromRest(VG_CLUSTER);
-    String expectedTopology = "/" + VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE + "/hostname";
-    Assert.assertEquals(clusterConfig.getTopology(), expectedTopology);
-    Assert.assertEquals(clusterConfig.getFaultZoneType(), VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE);
-
-    HelixDataAccessor helixDataAccessor = new ZKHelixDataAccessor(VG_CLUSTER, _baseAccessor);
-    for (Map.Entry<String, String> entry : instanceToGroup.entrySet()) {
-      InstanceConfig instanceConfig =
-          helixDataAccessor.getProperty(helixDataAccessor.keyBuilder().instanceConfig(entry.getKey()));
-      String expectedGroup = entry.getValue();
-      Assert.assertEquals(instanceConfig.getDomainAsMap().get(VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE),
-          expectedGroup);
-    }
-  }
-
   @Test(dependsOnMethods = "testGetClusters")
   public void testVirtualTopologyGroupMaintenanceMode() throws JsonProcessingException {
     setupClusterForVirtualTopology(VG_CLUSTER);
@@ -270,6 +246,36 @@
     Assert.assertTrue(isMaintenanceModeEnabled(VG_CLUSTER));
   }
 
+  @Test(dataProvider = "prepareVirtualTopologyTests", dependsOnMethods = "testVirtualTopologyGroupMaintenanceMode")
+  public void testAddVirtualTopologyGroup(String requestParam, int numGroups,
+      Map<String, String> instanceToGroup) throws IOException {
+    System.out.println("Start test :" + TestHelper.getTestMethodName() + " with requestParam: " + requestParam);
+    post("clusters/" + VG_CLUSTER,
+        ImmutableMap.of("command", "addVirtualTopologyGroup"),
+        Entity.entity(requestParam, MediaType.APPLICATION_JSON_TYPE),
+        Response.Status.OK.getStatusCode());
+    Map<String, Object> topology = getMapResponseFromRest(String.format("clusters/%s/topology", VG_CLUSTER));
+    Assert.assertTrue(topology.containsKey("zones"));
+    Assert.assertEquals(((List) topology.get("zones")).size(), numGroups,
+        "virtual groups not created as expected. Need " + numGroups + " groups but got "
+            + (topology.get("zones")));
+
+    ClusterConfig clusterConfig = getClusterConfigFromRest(VG_CLUSTER);
+    String expectedFaultZoneType = AZURE_FAULT_ZONE_TYPE + "_" + VirtualTopologyGroupConstants.VIRTUAL_FAULT_ZONE_TYPE;
+    String expectedTopology = "/" + expectedFaultZoneType + "/hostname";
+    Assert.assertEquals(clusterConfig.getTopology(), expectedTopology);
+    Assert.assertEquals(clusterConfig.getFaultZoneType(), expectedFaultZoneType);
+
+    HelixDataAccessor helixDataAccessor = new ZKHelixDataAccessor(VG_CLUSTER, _baseAccessor);
+    for (Map.Entry<String, String> entry : instanceToGroup.entrySet()) {
+      InstanceConfig instanceConfig =
+          helixDataAccessor.getProperty(helixDataAccessor.keyBuilder().instanceConfig(entry.getKey()));
+      String expectedGroup = entry.getValue();
+      Assert.assertEquals(instanceConfig.getDomainAsMap().get(expectedFaultZoneType),
+          expectedGroup);
+    }
+  }
+
   private boolean isMaintenanceModeEnabled(String clusterName) throws JsonProcessingException {
     String body =
         get("clusters/" + clusterName + "/maintenance", null, Response.Status.OK.getStatusCode(), true);
@@ -281,6 +287,14 @@
     setupClusterForVirtualTopology(VG_CLUSTER);
     String test1 = "{\"virtualTopologyGroupNumber\":\"7\",\"virtualTopologyGroupName\":\"vgTest\"}";
     String test2 = "{\"virtualTopologyGroupNumber\":\"9\",\"virtualTopologyGroupName\":\"vgTest\"}";
+    // Split 5 zones into 2 virtual groups, expect 0-1-2 in virtual group 0, 3-4 in virtual group 1
+    String test3 = "{\"virtualTopologyGroupNumber\":\"2\",\"virtualTopologyGroupName\":\"vgTest\","
+        + "\"assignmentAlgorithmType\":\"ZONE_BASED\"}";
+    String test4 = "{\"virtualTopologyGroupNumber\":\"5\",\"virtualTopologyGroupName\":\"vgTest\","
+        + "\"assignmentAlgorithmType\":\"ZONE_BASED\"}";
+    String test5 = "{\"virtualTopologyGroupNumber\":\"2\",\"virtualTopologyGroupName\":\"vgTest\","
+        + "\"assignmentAlgorithmType\":\"ZONE_BASED\",\"forceRecompute\""
+        + ":\"true\"}";
     return new Object[][] {
         {test1, 7, ImmutableMap.of(
             "vgCluster_localhost_12918", "vgTest_0",
@@ -297,7 +311,29 @@
             "vgCluster_localhost_12918", "vgTest_0",
             "vgCluster_localhost_12919", "vgTest_0",
             "vgCluster_localhost_12925", "vgTest_4",
-            "vgCluster_localhost_12927", "vgTest_6")}
+            "vgCluster_localhost_12927", "vgTest_6")},
+        {test3, 2, ImmutableMap.of(
+            "vgCluster_localhost_12918", "vgTest_0",
+            "vgCluster_localhost_12919", "vgTest_0",
+            "vgCluster_localhost_12925", "vgTest_1",
+            "vgCluster_localhost_12927", "vgTest_0")},
+        {test4, 5, ImmutableMap.of(
+            "vgCluster_localhost_12918", "vgTest_4",
+            "vgCluster_localhost_12919", "vgTest_4",
+            "vgCluster_localhost_12925", "vgTest_2",
+            "vgCluster_localhost_12927", "vgTest_1")},
+        // repeat test3 for deterministic and test for decreasing numGroups
+        {test3, 2, ImmutableMap.of(
+            "vgCluster_localhost_12918", "vgTest_0",
+            "vgCluster_localhost_12919", "vgTest_0",
+            "vgCluster_localhost_12925", "vgTest_1",
+            "vgCluster_localhost_12927", "vgTest_0")},
+        // Force recompute to reassign instances to virtual groups
+        {test5, 2, ImmutableMap.of(
+            "vgCluster_localhost_12918", "vgTest_0",
+            "vgCluster_localhost_12919", "vgTest_0",
+            "vgCluster_localhost_12925", "vgTest_1",
+            "vgCluster_localhost_12927", "vgTest_0")},
     };
   }
 
@@ -940,7 +976,7 @@
 
     ClusterConfig clusterConfigFromZk = _configAccessor.getClusterConfig(clusterName);
     Assert.assertEquals(clusterConfigFromZk.getTopology(), AzureConstants.AZURE_TOPOLOGY);
-    Assert.assertEquals(clusterConfigFromZk.getFaultZoneType(), AzureConstants.AZURE_FAULT_ZONE_TYPE);
+    Assert.assertEquals(clusterConfigFromZk.getFaultZoneType(), AZURE_FAULT_ZONE_TYPE);
     Assert.assertTrue(clusterConfigFromZk.isTopologyAwareEnabled());
   }
 
diff --git a/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestClusterService.java b/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestClusterService.java
index 23182d2..8bd414d 100644
--- a/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestClusterService.java
+++ b/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestClusterService.java
@@ -120,6 +120,44 @@
     Assert.assertTrue(mock.clusterService.isClusterTopologyAware(TEST_CLUSTER));
   }
 
+  @Test
+  public void testGetVirtualTopology() {
+    InstanceConfig instanceConfig1 = new InstanceConfig("instance0");
+    instanceConfig1.setDomain("helixZoneId=zone0, helixZoneId_virtualZone=virtualZone0");
+    InstanceConfig instanceConfig2 = new InstanceConfig("instance1");
+    instanceConfig2.setDomain("helixZoneId=zone1, helixZoneId_virtualZone=virtualZone1");
+    InstanceConfig instanceConfig3 = new InstanceConfig("instance3");
+    instanceConfig3.setDomain("helixZoneId=zone3");
+    List<HelixProperty> instanceConfigs = ImmutableList.of(instanceConfig1, instanceConfig2, instanceConfig3);
+
+    Mock mock = new Mock();
+    ClusterConfig mockConfig = new ClusterConfig(TEST_CLUSTER);
+    mockConfig.setFaultZoneType("helixZoneId_virtualZone");
+    when(mock.configAccessor.getClusterConfig(TEST_CLUSTER)).thenReturn(mockConfig);
+    when(mock.dataAccessor.keyBuilder()).thenReturn(new PropertyKey.Builder(TEST_CLUSTER));
+    when(mock.dataAccessor.getChildValues(any(PropertyKey.class), anyBoolean()))
+        .thenReturn(instanceConfigs);
+
+    // When use `getClusterTopology` on a virtual topology cluster, it shall return topology
+    // based on the configured fault zone type
+    ClusterTopology clusterTopology = mock.clusterService.getClusterTopology(TEST_CLUSTER);
+    Assert.assertEquals(clusterTopology.getZones().size(), 2);
+    Assert.assertEquals(clusterTopology.getClusterId(), TEST_CLUSTER);
+    Assert.assertEquals(clusterTopology.getZones().get(0).getInstances().size(), 1);
+
+    // When use `getVirtualClusterTopology` on a virtual topology cluster, it shall return the
+    // virtual topology
+    clusterTopology = mock.clusterService.getTopologyOfVirtualCluster(TEST_CLUSTER, true);
+    Assert.assertEquals(clusterTopology.getZones().size(), 3);
+    Assert.assertEquals(clusterTopology.getClusterId(), TEST_CLUSTER);
+
+    // When use `getVirtualClusterTopology` on a virtual topology cluster, it shall return the
+    // virtual topology
+    clusterTopology = mock.clusterService.getTopologyOfVirtualCluster(TEST_CLUSTER, false);
+    Assert.assertEquals(clusterTopology.getZones().size(), 2);
+    Assert.assertEquals(clusterTopology.getClusterId(), TEST_CLUSTER);
+  }
+
   private final class Mock {
     private HelixDataAccessor dataAccessor = mock(HelixDataAccessor.class);
     private ConfigAccessor configAccessor = mock(ConfigAccessor.class);
diff --git a/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestVirtualTopologyGroupService.java b/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestVirtualTopologyGroupService.java
index 3ffc29c..685b534 100644
--- a/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestVirtualTopologyGroupService.java
+++ b/helix-rest/src/test/java/org/apache/helix/rest/server/service/TestVirtualTopologyGroupService.java
@@ -22,6 +22,8 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -29,10 +31,7 @@
 import org.apache.helix.ConfigAccessor;
 import org.apache.helix.HelixAdmin;
 import org.apache.helix.HelixDataAccessor;
-import org.apache.helix.HelixException;
 import org.apache.helix.cloud.azure.AzureConstants;
-import org.apache.helix.cloud.constants.CloudProvider;
-import org.apache.helix.model.CloudConfig;
 import org.apache.helix.model.ClusterConfig;
 import org.apache.helix.model.HelixConfigScope;
 import org.apache.helix.model.InstanceConfig;
@@ -40,12 +39,15 @@
 import org.apache.helix.rest.server.json.cluster.ClusterTopology;
 import org.apache.helix.zookeeper.datamodel.ZNRecord;
 import org.apache.helix.zookeeper.zkclient.DataUpdater;
+import org.mockito.ArgumentMatchers;
 import org.testng.Assert;
 import org.testng.annotations.BeforeTest;
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import static org.apache.helix.cloud.azure.AzureConstants.AZURE_FAULT_ZONE_TYPE;
 import static org.apache.helix.cloud.constants.VirtualTopologyGroupConstants.*;
+import static org.apache.helix.util.VirtualTopologyUtil.computeVirtualFaultZoneTypeKey;
 import static org.mockito.Mockito.*;
 
 
@@ -53,6 +55,7 @@
   private static final String TEST_CLUSTER = "Test_Cluster";
   private static final String TEST_CLUSTER0 = "TestCluster_0";
   private static final String TEST_CLUSTER1 = "TestCluster_1";
+  private static final String FAULT_ZONE_TYPE = "helixZoneId";
 
   private final ConfigAccessor _configAccessor = mock(ConfigAccessor.class);
   private final HelixDataAccessor _dataAccessor = mock(HelixDataAccessor.class);
@@ -75,13 +78,18 @@
 
     assignment.put("virtual_group_0", ImmutableSet.of("instance_0", "instance_1"));
     assignment.put("virtual_group_1", ImmutableSet.of("instance_2"));
-    _updaterMap = VirtualTopologyGroupService.createInstanceConfigUpdater(TEST_CLUSTER, assignment);
+    ClusterConfig testClusterConfig = new ClusterConfig(TEST_CLUSTER);
+    testClusterConfig.setFaultZoneType(FAULT_ZONE_TYPE);
+    testClusterConfig.setTopology("/helixZoneId");
+    testClusterConfig.setTopologyAwareEnabled(true);
+    when(_configAccessor.getClusterConfig(TEST_CLUSTER)).thenReturn(testClusterConfig);
+    _updaterMap = VirtualTopologyGroupService.createInstanceConfigUpdater(testClusterConfig, assignment);
 
-    ClusterConfig clusterConfig = new ClusterConfig(TEST_CLUSTER0);
-    clusterConfig.setFaultZoneType(AzureConstants.AZURE_FAULT_ZONE_TYPE);
-    clusterConfig.setTopology(AzureConstants.AZURE_TOPOLOGY);
-    clusterConfig.setTopologyAwareEnabled(true);
-    when(_configAccessor.getClusterConfig(TEST_CLUSTER0)).thenReturn(clusterConfig);
+    ClusterConfig testClusterConfig0 = new ClusterConfig(TEST_CLUSTER0);
+    testClusterConfig0.setFaultZoneType(AZURE_FAULT_ZONE_TYPE);
+    testClusterConfig0.setTopology(AzureConstants.AZURE_TOPOLOGY);
+    testClusterConfig0.setTopologyAwareEnabled(true);
+    when(_configAccessor.getClusterConfig(TEST_CLUSTER0)).thenReturn(testClusterConfig0);
 
     _helixAdmin = mock(HelixAdmin.class);
     when(_helixAdmin.isInMaintenanceMode(anyString())).thenReturn(true);
@@ -91,6 +99,10 @@
     when(_dataAccessor.updateChildren(anyList(), anyList(), anyInt())).thenReturn(results);
     ClusterService clusterService = mock(ClusterService.class);
     when(clusterService.getClusterTopology(anyString())).thenReturn(prepareClusterTopology());
+    when(clusterService.getTopologyOfVirtualCluster(anyString(), ArgumentMatchers.eq(true))).thenReturn(
+        prepareClusterTopology());
+    when(clusterService.getTopologyOfVirtualCluster(anyString(), ArgumentMatchers.eq(false))).thenReturn(
+        new ClusterTopology(TEST_CLUSTER0, Collections.emptyList(), Collections.emptySet()));
     _service = new VirtualTopologyGroupService(_helixAdmin, clusterService, _configAccessor, _dataAccessor);
   }
 
@@ -134,6 +146,14 @@
         GROUP_NAME, "test-group", GROUP_NUMBER, "10", AUTO_MAINTENANCE_MODE_DISABLED, "true"));
   }
 
+  @Test(expectedExceptions = IllegalArgumentException.class,
+      expectedExceptionsMessageRegExp = "Number of virtual groups cannot be greater than the number of zones.*")
+  public void testFaultZoneBasedVirtualGroupAssignment() {
+    _service.addVirtualTopologyGroup(TEST_CLUSTER0, ImmutableMap.of(
+        GROUP_NAME, "test-group", GROUP_NUMBER, "3", AUTO_MAINTENANCE_MODE_DISABLED, "true",
+        ASSIGNMENT_ALGORITHM_TYPE, "ZONE_BASED"));
+  }
+
   @Test(expectedExceptions = IllegalArgumentException.class)
   public void testParamValidation() {
     _service.addVirtualTopologyGroup(TEST_CLUSTER0, ImmutableMap.of(GROUP_NUMBER, "2"));
@@ -150,11 +170,11 @@
   public Object[][] instanceTestProvider() {
     return new Object[][] {
         {computeZkPath("instance_0"), _instanceConfig0,
-            ImmutableMap.of("helixZoneId", "zone0", VIRTUAL_FAULT_ZONE_TYPE, "virtual_group_0")},
+            ImmutableMap.of(FAULT_ZONE_TYPE, "zone0", computeVirtualFaultZoneTypeKey(FAULT_ZONE_TYPE), "virtual_group_0")},
         {computeZkPath("instance_1"), _instanceConfig1,
-            ImmutableMap.of("helixZoneId", "zone0", VIRTUAL_FAULT_ZONE_TYPE, "virtual_group_0")},
+            ImmutableMap.of(FAULT_ZONE_TYPE, "zone0", computeVirtualFaultZoneTypeKey(FAULT_ZONE_TYPE), "virtual_group_0")},
         {computeZkPath("instance_2"), _instanceConfig2,
-            ImmutableMap.of("helixZoneId", "zone1", VIRTUAL_FAULT_ZONE_TYPE, "virtual_group_1")}
+            ImmutableMap.of(FAULT_ZONE_TYPE, "zone1", computeVirtualFaultZoneTypeKey(FAULT_ZONE_TYPE), "virtual_group_1")}
     };
   }
 
@@ -163,8 +183,9 @@
     ClusterConfig testConfig = new ClusterConfig("testId");
     testConfig.setTopologyAwareEnabled(true);
     testConfig.setTopology("/zone/instance");
+    testConfig.setFaultZoneType("zone");
     Assert.assertEquals(VirtualTopologyGroupService.computeVirtualTopologyString(testConfig),
-        "/virtualZone/instance");
+        "/zone_virtualZone/instance");
   }
 
   private static ClusterTopology prepareClusterTopology() {