Introduce segment assignment strategy interface #9047 (#9309)

{
  "segmentAssignmentConfigMap": {
    "OFFLINE": {
      "segmentAssignmentStrategy": "balanced/replicaGroup/allServers"
    }
  }
  ...
}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/config/TableConfigUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/config/TableConfigUtils.java
index eeed756..9735a0f 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/utils/config/TableConfigUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/config/TableConfigUtils.java
@@ -44,6 +44,7 @@
 import org.apache.pinot.spi.config.table.UpsertConfig;
 import org.apache.pinot.spi.config.table.assignment.InstanceAssignmentConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.apache.pinot.spi.config.table.assignment.SegmentAssignmentConfig;
 import org.apache.pinot.spi.config.table.ingestion.BatchIngestionConfig;
 import org.apache.pinot.spi.config.table.ingestion.IngestionConfig;
 import org.apache.pinot.spi.config.table.ingestion.StreamIngestionConfig;
@@ -165,9 +166,17 @@
           new TypeReference<Map<InstancePartitionsType, String>>() { });
     }
 
+    Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap = null;
+    String segmentAssignmentConfigMapString = simpleFields.get(TableConfig.SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY);
+    if (segmentAssignmentConfigMapString != null) {
+      segmentAssignmentConfigMap = JsonUtils.stringToObject(segmentAssignmentConfigMapString,
+          new TypeReference<Map<String, SegmentAssignmentConfig>>() { });
+    }
+
     return new TableConfig(tableName, tableType, validationConfig, tenantConfig, indexingConfig, customConfig,
-        quotaConfig, taskConfig, routingConfig, queryConfig, instanceAssignmentConfigMap, fieldConfigList, upsertConfig,
-        dedupConfig, ingestionConfig, tierConfigList, isDimTable, tunerConfigList, instancePartitionsMap);
+        quotaConfig, taskConfig, routingConfig, queryConfig, instanceAssignmentConfigMap,
+        fieldConfigList, upsertConfig, dedupConfig, ingestionConfig, tierConfigList, isDimTable,
+        tunerConfigList, instancePartitionsMap, segmentAssignmentConfigMap);
   }
 
   public static ZNRecord toZNRecord(TableConfig tableConfig)
@@ -234,6 +243,12 @@
       simpleFields.put(TableConfig.INSTANCE_PARTITIONS_MAP_CONFIG_KEY,
           JsonUtils.objectToString(tableConfig.getInstancePartitionsMap()));
     }
+    Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap =
+        tableConfig.getSegmentAssignmentConfigMap();
+    if (segmentAssignmentConfigMap != null) {
+      simpleFields
+          .put(TableConfig.SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY, JsonUtils.objectToString(segmentAssignmentConfigMap));
+    }
 
     ZNRecord znRecord = new ZNRecord(tableConfig.getTableName());
     znRecord.setSimpleFields(simpleFields);
diff --git a/pinot-common/src/test/java/org/apache/pinot/common/utils/config/TableConfigSerDeTest.java b/pinot-common/src/test/java/org/apache/pinot/common/utils/config/TableConfigSerDeTest.java
index a4e15c4..aad62db 100644
--- a/pinot-common/src/test/java/org/apache/pinot/common/utils/config/TableConfigSerDeTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/common/utils/config/TableConfigSerDeTest.java
@@ -365,6 +365,7 @@
     assertNull(tableConfig.getRoutingConfig());
     assertNull(tableConfig.getQueryConfig());
     assertNull(tableConfig.getInstanceAssignmentConfigMap());
+    assertNull(tableConfig.getSegmentAssignmentConfigMap());
     assertNull(tableConfig.getFieldConfigList());
 
     // Serialize
@@ -380,6 +381,7 @@
     assertFalse(tableConfigJson.has(TableConfig.ROUTING_CONFIG_KEY));
     assertFalse(tableConfigJson.has(TableConfig.QUERY_CONFIG_KEY));
     assertFalse(tableConfigJson.has(TableConfig.INSTANCE_ASSIGNMENT_CONFIG_MAP_KEY));
+    assertFalse(tableConfigJson.has(TableConfig.SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY));
     assertFalse(tableConfigJson.has(TableConfig.FIELD_CONFIG_LIST_KEY));
   }
 
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/BaseSegmentAssignment.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/BaseSegmentAssignment.java
index e84d107..da8d503 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/BaseSegmentAssignment.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/BaseSegmentAssignment.java
@@ -20,19 +20,18 @@
 
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
-import java.util.Set;
 import java.util.TreeMap;
 import javax.annotation.Nullable;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.helix.HelixManager;
 import org.apache.pinot.common.assignment.InstancePartitions;
 import org.apache.pinot.common.tier.Tier;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.SegmentAssignmentStrategy;
 import org.apache.pinot.spi.config.table.ReplicaGroupStrategyConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
 import org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -68,11 +67,13 @@
   protected String _tableNameWithType;
   protected int _replication;
   protected String _partitionColumn;
+  protected TableConfig _tableConfig;
 
   @Override
   public void init(HelixManager helixManager, TableConfig tableConfig) {
     _helixManager = helixManager;
     _tableNameWithType = tableConfig.getTableName();
+    _tableConfig = tableConfig;
     _replication = getReplication(tableConfig);
     ReplicaGroupStrategyConfig replicaGroupStrategyConfig =
         tableConfig.getValidationConfig().getReplicaGroupStrategyConfig();
@@ -93,62 +94,12 @@
   protected abstract int getReplication(TableConfig tableConfig);
 
   /**
-   * Helper method to check whether the number of replica-groups matches the table replication for replica-group based
-   * instance partitions. Log a warning if they do not match and use the one inside the instance partitions. The
-   * mismatch can happen when table is not configured correctly (table replication and numReplicaGroups does not match
-   * or replication changed without reassigning instances).
-   */
-  protected void checkReplication(InstancePartitions instancePartitions) {
-    int numReplicaGroups = instancePartitions.getNumReplicaGroups();
-    if (numReplicaGroups != _replication) {
-      _logger.warn(
-          "Number of replica-groups in instance partitions {}: {} does not match replication in table config: {} for "
-              + "table: {}, using: {}", instancePartitions.getInstancePartitionsName(), numReplicaGroups, _replication,
-          _tableNameWithType, numReplicaGroups);
-    }
-  }
-
-  /**
-   * Helper method to assign instances based on the current assignment and instance partitions.
-   */
-  protected List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment,
-      InstancePartitions instancePartitions) {
-    int numReplicaGroups = instancePartitions.getNumReplicaGroups();
-    int numPartitions = instancePartitions.getNumPartitions();
-
-    if (numReplicaGroups == 1 && numPartitions == 1) {
-      // Non-replica-group based assignment
-
-      return SegmentAssignmentUtils.assignSegmentWithoutReplicaGroup(currentAssignment, instancePartitions,
-          _replication);
-    } else {
-      // Replica-group based assignment
-
-      checkReplication(instancePartitions);
-
-      int partitionId;
-      if (_partitionColumn == null || numPartitions == 1) {
-        partitionId = 0;
-      } else {
-        // Uniformly spray the segment partitions over the instance partitions
-        partitionId = getSegmentPartitionId(segmentName) % numPartitions;
-      }
-
-      return SegmentAssignmentUtils.assignSegmentWithReplicaGroup(currentAssignment, instancePartitions, partitionId);
-    }
-  }
-
-  /**
-   * Returns the partition id of the segment.
-   */
-  protected abstract int getSegmentPartitionId(String segmentName);
-
-  /**
    * Rebalances tiers and returns a pair of tier assignments and non-tier assignment.
    */
   protected Pair<List<Map<String, Map<String, String>>>, Map<String, Map<String, String>>> rebalanceTiers(
       Map<String, Map<String, String>> currentAssignment, @Nullable List<Tier> sortedTiers,
-      @Nullable Map<String, InstancePartitions> tierInstancePartitionsMap, boolean bootstrap) {
+      @Nullable Map<String, InstancePartitions> tierInstancePartitionsMap, boolean bootstrap,
+      SegmentAssignmentStrategy segmentAssignmentStrategy, InstancePartitionsType instancePartitionsType) {
     if (sortedTiers == null) {
       return Pair.of(null, currentAssignment);
     }
@@ -175,7 +126,8 @@
 
       _logger.info("Rebalancing tier: {} for table: {} with bootstrap: {}, instance partitions: {}", tierName,
           _tableNameWithType, bootstrap, tierInstancePartitions);
-      newTierAssignments.add(reassignSegments(tierName, tierCurrentAssignment, tierInstancePartitions, bootstrap));
+      newTierAssignments.add(reassignSegments(tierName, tierCurrentAssignment, tierInstancePartitions, bootstrap,
+          segmentAssignmentStrategy, instancePartitionsType));
     }
 
     return Pair.of(newTierAssignments, tierSegmentAssignment.getNonTierSegmentAssignment());
@@ -185,7 +137,8 @@
    * Rebalances segments in the current assignment using the instancePartitions and returns new assignment
    */
   protected Map<String, Map<String, String>> reassignSegments(String instancePartitionType,
-      Map<String, Map<String, String>> currentAssignment, InstancePartitions instancePartitions, boolean bootstrap) {
+      Map<String, Map<String, String>> currentAssignment, InstancePartitions instancePartitions, boolean bootstrap,
+      SegmentAssignmentStrategy segmentAssignmentStrategy, InstancePartitionsType instancePartitionsType) {
     Map<String, Map<String, String>> newAssignment;
     if (bootstrap) {
       _logger.info("Bootstrapping segment assignment for {} segments of table: {}", instancePartitionType,
@@ -194,60 +147,16 @@
       // When bootstrap is enabled, start with an empty assignment and reassign all segments
       newAssignment = new TreeMap<>();
       for (String segment : currentAssignment.keySet()) {
-        List<String> assignedInstances = assignSegment(segment, newAssignment, instancePartitions);
-        newAssignment.put(segment,
-            SegmentAssignmentUtils.getInstanceStateMap(assignedInstances, SegmentStateModel.ONLINE));
+        List<String> assignedInstances =
+            segmentAssignmentStrategy.assignSegment(segment, newAssignment, instancePartitions, instancePartitionsType);
+        newAssignment
+            .put(segment, SegmentAssignmentUtils.getInstanceStateMap(assignedInstances, SegmentStateModel.ONLINE));
       }
     } else {
-      int numReplicaGroups = instancePartitions.getNumReplicaGroups();
-      int numPartitions = instancePartitions.getNumPartitions();
-
-      if (numReplicaGroups == 1 && numPartitions == 1) {
-        // Non-replica-group based assignment
-
-        List<String> instances =
-            SegmentAssignmentUtils.getInstancesForNonReplicaGroupBasedAssignment(instancePartitions, _replication);
-        newAssignment =
-            SegmentAssignmentUtils.rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, instances,
-                _replication);
-      } else {
-        // Replica-group based assignment
-
-        checkReplication(instancePartitions);
-
-        if (_partitionColumn == null || numPartitions == 1) {
-          // NOTE: Shuffle the segments within the current assignment to avoid moving only new segments to the new added
-          //       servers, which might cause hotspot servers because queries tend to hit the new segments. Use the
-          //       table name hash as the random seed for the shuffle so that the result is deterministic.
-          List<String> segments = new ArrayList<>(currentAssignment.keySet());
-          Collections.shuffle(segments, new Random(_tableNameWithType.hashCode()));
-
-          newAssignment = new TreeMap<>();
-          SegmentAssignmentUtils.rebalanceReplicaGroupBasedPartition(currentAssignment, instancePartitions, 0, segments,
-              newAssignment);
-        } else {
-          Map<Integer, List<String>> instancePartitionIdToSegmentsMap =
-              getInstancePartitionIdToSegmentsMap(currentAssignment.keySet(), instancePartitions.getNumPartitions());
-
-          // NOTE: Shuffle the segments within the current assignment to avoid moving only new segments to the new added
-          //       servers, which might cause hotspot servers because queries tend to hit the new segments. Use the
-          //       table name hash as the random seed for the shuffle so that the result is deterministic.
-          Random random = new Random(_tableNameWithType.hashCode());
-          for (List<String> segments : instancePartitionIdToSegmentsMap.values()) {
-            Collections.shuffle(segments, random);
-          }
-
-          return SegmentAssignmentUtils.rebalanceReplicaGroupBasedTable(currentAssignment, instancePartitions,
-              instancePartitionIdToSegmentsMap);
-        }
-      }
+      // Use segment assignment strategy
+      newAssignment =
+          segmentAssignmentStrategy.reassignSegments(currentAssignment, instancePartitions, instancePartitionsType);
     }
     return newAssignment;
   }
-
-  /**
-   * Returns the instance partitions for the given segments.
-   */
-  protected abstract Map<Integer, List<String>> getInstancePartitionIdToSegmentsMap(Set<String> segments,
-      int numInstancePartitions);
 }
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineSegmentAssignment.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineSegmentAssignment.java
index f944ba7..48cb028 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineSegmentAssignment.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineSegmentAssignment.java
@@ -19,21 +19,17 @@
 package org.apache.pinot.controller.helix.core.assignment.segment;
 
 import com.google.common.base.Preconditions;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.configuration.Configuration;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.assignment.InstancePartitions;
-import org.apache.pinot.common.metadata.ZKMetadataProvider;
-import org.apache.pinot.common.metadata.segment.SegmentZKMetadata;
 import org.apache.pinot.common.tier.Tier;
-import org.apache.pinot.segment.spi.partition.metadata.ColumnPartitionMetadata;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.AllServersSegmentAssignmentStrategy;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.SegmentAssignmentStrategy;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.SegmentAssignmentStrategyFactory;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
 import org.apache.pinot.spi.utils.RebalanceConfigConstants;
@@ -55,50 +51,46 @@
     InstancePartitions instancePartitions = instancePartitionsMap.get(InstancePartitionsType.OFFLINE);
     Preconditions.checkState(instancePartitions != null, "Failed to find OFFLINE instance partitions for table: %s",
         _tableNameWithType);
+    // Gets Segment assignment strategy for instance partitions
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(_helixManager, _tableConfig, InstancePartitionsType.OFFLINE.toString(),
+            instancePartitions);
     _logger.info("Assigning segment: {} with instance partitions: {} for table: {}", segmentName, instancePartitions,
         _tableNameWithType);
-    List<String> instancesAssigned = assignSegment(segmentName, currentAssignment, instancePartitions);
+    List<String> instancesAssigned = segmentAssignmentStrategy
+        .assignSegment(segmentName, currentAssignment, instancePartitions, InstancePartitionsType.OFFLINE);
     _logger.info("Assigned segment: {} to instances: {} for table: {}", segmentName, instancesAssigned,
         _tableNameWithType);
     return instancesAssigned;
   }
 
   @Override
-  protected int getSegmentPartitionId(String segmentName) {
-    SegmentZKMetadata segmentZKMetadata =
-        ZKMetadataProvider.getSegmentZKMetadata(_helixManager.getHelixPropertyStore(), _tableNameWithType, segmentName);
-    Preconditions.checkState(segmentZKMetadata != null,
-        "Failed to find segment ZK metadata for segment: %s of table: %s", segmentName, _tableNameWithType);
-    return getPartitionId(segmentZKMetadata);
-  }
-
-  private int getPartitionId(SegmentZKMetadata segmentZKMetadata) {
-    String segmentName = segmentZKMetadata.getSegmentName();
-    ColumnPartitionMetadata partitionMetadata =
-        segmentZKMetadata.getPartitionMetadata().getColumnPartitionMap().get(_partitionColumn);
-    Preconditions.checkState(partitionMetadata != null,
-        "Segment ZK metadata for segment: %s of table: %s does not contain partition metadata for column: %s",
-        segmentName, _tableNameWithType, _partitionColumn);
-    Set<Integer> partitions = partitionMetadata.getPartitions();
-    Preconditions.checkState(partitions.size() == 1,
-        "Segment ZK metadata for segment: %s of table: %s contains multiple partitions for column: %s", segmentName,
-        _tableNameWithType, _partitionColumn);
-    return partitions.iterator().next();
-  }
-
-  @Override
   public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
       Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap, @Nullable List<Tier> sortedTiers,
       @Nullable Map<String, InstancePartitions> tierInstancePartitionsMap, Configuration config) {
     InstancePartitions offlineInstancePartitions = instancePartitionsMap.get(InstancePartitionsType.OFFLINE);
-    Preconditions.checkState(offlineInstancePartitions != null,
-        "Failed to find OFFLINE instance partitions for table: %s", _tableNameWithType);
+    Preconditions
+        .checkState(offlineInstancePartitions != null, "Failed to find OFFLINE instance partitions for table: %s",
+            _tableNameWithType);
+    // Gets Segment assignment strategy for instance partitions
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(_helixManager, _tableConfig, InstancePartitionsType.OFFLINE.toString(),
+            offlineInstancePartitions);
+    // TODO: Right now as per tier assignment, different instances will be picked up for different tiers which
+    // would produce incorrect results for Dim tables. In future, we need some preconditions to check if
+    // tierPartitionMap has single tier for Dim tables and remove below check
+    // See https://github.com/apache/pinot/issues/9047
+    if (segmentAssignmentStrategy instanceof AllServersSegmentAssignmentStrategy) {
+      return segmentAssignmentStrategy
+          .reassignSegments(currentAssignment, offlineInstancePartitions, InstancePartitionsType.OFFLINE);
+    }
     boolean bootstrap =
         config.getBoolean(RebalanceConfigConstants.BOOTSTRAP, RebalanceConfigConstants.DEFAULT_BOOTSTRAP);
 
     // Rebalance tiers first
     Pair<List<Map<String, Map<String, String>>>, Map<String, Map<String, String>>> pair =
-        rebalanceTiers(currentAssignment, sortedTiers, tierInstancePartitionsMap, bootstrap);
+        rebalanceTiers(currentAssignment, sortedTiers, tierInstancePartitionsMap, bootstrap, segmentAssignmentStrategy,
+            InstancePartitionsType.OFFLINE);
     List<Map<String, Map<String, String>>> newTierAssignments = pair.getLeft();
     Map<String, Map<String, String>> nonTierAssignment = pair.getRight();
 
@@ -106,9 +98,9 @@
         offlineInstancePartitions, bootstrap);
     Map<String, Map<String, String>> newAssignment =
         reassignSegments(InstancePartitionsType.OFFLINE.toString(), nonTierAssignment, offlineInstancePartitions,
-            bootstrap);
+            bootstrap, segmentAssignmentStrategy, InstancePartitionsType.OFFLINE);
 
-    // add tier assignments, if available
+    // Add tier assignments, if available
     if (CollectionUtils.isNotEmpty(newTierAssignments)) {
       newTierAssignments.forEach(newAssignment::putAll);
     }
@@ -117,27 +109,4 @@
         SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
     return newAssignment;
   }
-
-  @Override
-  protected Map<Integer, List<String>> getInstancePartitionIdToSegmentsMap(Set<String> segments,
-      int numInstancePartitions) {
-    // Fetch partition id from segment ZK metadata
-    List<SegmentZKMetadata> segmentsZKMetadata =
-        ZKMetadataProvider.getSegmentsZKMetadata(_helixManager.getHelixPropertyStore(), _tableNameWithType);
-
-    Map<Integer, List<String>> instancePartitionIdToSegmentsMap = new HashMap<>();
-    Set<String> segmentsWithoutZKMetadata = new HashSet<>(segments);
-    for (SegmentZKMetadata segmentZKMetadata : segmentsZKMetadata) {
-      String segmentName = segmentZKMetadata.getSegmentName();
-      if (segmentsWithoutZKMetadata.remove(segmentName)) {
-        int partitionId = getPartitionId(segmentZKMetadata);
-        int instancePartitionId = partitionId % numInstancePartitions;
-        instancePartitionIdToSegmentsMap.computeIfAbsent(instancePartitionId, k -> new ArrayList<>()).add(segmentName);
-      }
-    }
-    Preconditions.checkState(segmentsWithoutZKMetadata.isEmpty(), "Failed to find ZK metadata for segments: %s",
-        segmentsWithoutZKMetadata);
-
-    return instancePartitionIdToSegmentsMap;
-  }
 }
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeSegmentAssignment.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeSegmentAssignment.java
index ba4d7fb..44bec27 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeSegmentAssignment.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/RealtimeSegmentAssignment.java
@@ -20,10 +20,8 @@
 
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.TreeMap;
 import javax.annotation.Nullable;
 import org.apache.commons.collections.CollectionUtils;
@@ -31,7 +29,8 @@
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.assignment.InstancePartitions;
 import org.apache.pinot.common.tier.Tier;
-import org.apache.pinot.common.utils.SegmentUtils;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.SegmentAssignmentStrategy;
+import org.apache.pinot.controller.helix.core.assignment.segment.strategy.SegmentAssignmentStrategyFactory;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
 import org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel;
@@ -90,9 +89,20 @@
     InstancePartitions instancePartitions = typeToInstancePartitions.getValue();
     _logger.info("Assigning segment: {} with instance partitions: {} for table: {}", segmentName, instancePartitions,
         _tableNameWithType);
-    List<String> instancesAssigned =
-        instancePartitionsType == InstancePartitionsType.CONSUMING ? assignConsumingSegment(segmentName,
-            instancePartitions) : assignSegment(segmentName, currentAssignment, instancePartitions);
+
+    // TODO: remove this check after we also refactor consuming segments assignment strategy
+    // See https://github.com/apache/pinot/issues/9047
+    List<String> instancesAssigned;
+    if (instancePartitionsType == InstancePartitionsType.COMPLETED) {
+      // Gets Segment assignment strategy for instance partitions
+      SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+          .getSegmentAssignmentStrategy(_helixManager, _tableConfig, instancePartitionsType.toString(),
+              instancePartitions);
+      instancesAssigned = segmentAssignmentStrategy
+          .assignSegment(segmentName, currentAssignment, instancePartitions, InstancePartitionsType.COMPLETED);
+    } else {
+      instancesAssigned = assignConsumingSegment(segmentName, instancePartitions);
+    }
     _logger.info("Assigned segment: {} to instances: {} for table: {}", segmentName, instancesAssigned,
         _tableNameWithType);
     return instancesAssigned;
@@ -102,7 +112,8 @@
    * Helper method to assign instances for CONSUMING segment based on the segment partition id and instance partitions.
    */
   private List<String> assignConsumingSegment(String segmentName, InstancePartitions instancePartitions) {
-    int segmentPartitionId = getSegmentPartitionId(segmentName);
+    int segmentPartitionId = SegmentAssignmentUtils
+        .getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager, _partitionColumn);
     int numReplicaGroups = instancePartitions.getNumReplicaGroups();
     int numPartitions = instancePartitions.getNumPartitions();
 
@@ -125,8 +136,14 @@
       return instancesAssigned;
     } else {
       // Replica-group based assignment
-
-      checkReplication(instancePartitions);
+      // TODO: Refactor check replication this for segment assignment strategy in follow up PR
+      // See https://github.com/apache/pinot/issues/9047
+      if (numReplicaGroups != _replication) {
+        _logger.warn(
+            "Number of replica-groups in instance partitions {}: {} does not match replication in table config: {} for "
+                + "table: {}, using: {}", instancePartitions.getInstancePartitionsName(), numReplicaGroups,
+            _replication, _tableNameWithType, numReplicaGroups);
+      }
       List<String> instancesAssigned = new ArrayList<>(numReplicaGroups);
 
       if (numPartitions == 1) {
@@ -156,35 +173,34 @@
   }
 
   @Override
-  protected int getSegmentPartitionId(String segmentName) {
-    Integer segmentPartitionId =
-        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, _tableNameWithType, _helixManager, _partitionColumn);
-    if (segmentPartitionId == null) {
-      // This case is for the uploaded segments for which there's no partition information.
-      // A random, but consistent, partition id is calculated based on the hash code of the segment name.
-      // Note that '% 10K' is used to prevent having partition ids with large value which will be problematic later in
-      // instance assignment formula.
-      segmentPartitionId = Math.abs(segmentName.hashCode() % 10_000);
-    }
-    return segmentPartitionId;
-  }
-
-  @Override
   public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
       Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap, @Nullable List<Tier> sortedTiers,
       @Nullable Map<String, InstancePartitions> tierInstancePartitionsMap, Configuration config) {
     InstancePartitions completedInstancePartitions = instancePartitionsMap.get(InstancePartitionsType.COMPLETED);
     InstancePartitions consumingInstancePartitions = instancePartitionsMap.get(InstancePartitionsType.CONSUMING);
-    Preconditions.checkState(consumingInstancePartitions != null,
-        "Failed to find CONSUMING instance partitions for table: %s", _tableNameWithType);
-    boolean includeConsuming = config.getBoolean(RebalanceConfigConstants.INCLUDE_CONSUMING,
-        RebalanceConfigConstants.DEFAULT_INCLUDE_CONSUMING);
+    Preconditions
+        .checkState(consumingInstancePartitions != null, "Failed to find CONSUMING instance partitions for table: %s",
+            _tableNameWithType);
+    boolean includeConsuming = config
+        .getBoolean(RebalanceConfigConstants.INCLUDE_CONSUMING, RebalanceConfigConstants.DEFAULT_INCLUDE_CONSUMING);
     boolean bootstrap =
         config.getBoolean(RebalanceConfigConstants.BOOTSTRAP, RebalanceConfigConstants.DEFAULT_BOOTSTRAP);
 
+    // TODO: remove this check after we also refactor consuming segments assignment strategy
+    // See https://github.com/apache/pinot/issues/9047
+    SegmentAssignmentStrategy segmentAssignmentStrategy = null;
+    if (completedInstancePartitions != null) {
+      // Gets Segment assignment strategy for instance partitions
+      segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+          .getSegmentAssignmentStrategy(_helixManager, _tableConfig, InstancePartitionsType.COMPLETED.toString(),
+              completedInstancePartitions);
+    }
+
     // Rebalance tiers first
     Pair<List<Map<String, Map<String, String>>>, Map<String, Map<String, String>>> pair =
-        rebalanceTiers(currentAssignment, sortedTiers, tierInstancePartitionsMap, bootstrap);
+        rebalanceTiers(currentAssignment, sortedTiers, tierInstancePartitionsMap, bootstrap, segmentAssignmentStrategy,
+            InstancePartitionsType.COMPLETED);
+
     List<Map<String, Map<String, String>>> newTierAssignments = pair.getLeft();
     Map<String, Map<String, String>> nonTierAssignment = pair.getRight();
 
@@ -202,10 +218,10 @@
     if (completedInstancePartitions != null) {
       // When COMPLETED instance partitions are provided, reassign COMPLETED segments in a balanced way (relocate
       // COMPLETED segments to offload them from CONSUMING instances to COMPLETED instances)
-      _logger.info("Reassigning COMPLETED segments with COMPLETED instance partitions for table: {}",
-          _tableNameWithType);
+      _logger
+          .info("Reassigning COMPLETED segments with COMPLETED instance partitions for table: {}", _tableNameWithType);
       newAssignment = reassignSegments(InstancePartitionsType.COMPLETED.toString(), completedSegmentAssignment,
-          completedInstancePartitions, bootstrap);
+          completedInstancePartitions, bootstrap, segmentAssignmentStrategy, InstancePartitionsType.COMPLETED);
     } else {
       // When COMPLETED instance partitions are not provided, reassign COMPLETED segments the same way as CONSUMING
       // segments with CONSUMING instance partitions (ensure COMPLETED segments are served by the correct instances when
@@ -227,8 +243,8 @@
     Map<String, Map<String, String>> consumingSegmentAssignment =
         completedConsumingOfflineSegmentAssignment.getConsumingSegmentAssignment();
     if (includeConsuming) {
-      _logger.info("Reassigning CONSUMING segments with CONSUMING instance partitions for table: {}",
-          _tableNameWithType);
+      _logger
+          .info("Reassigning CONSUMING segments with CONSUMING instance partitions for table: {}", _tableNameWithType);
 
       for (String segmentName : consumingSegmentAssignment.keySet()) {
         List<String> instancesAssigned = assignConsumingSegment(segmentName, consumingInstancePartitions);
@@ -253,15 +269,4 @@
         SegmentAssignmentUtils.getNumSegmentsToBeMovedPerInstance(currentAssignment, newAssignment));
     return newAssignment;
   }
-
-  @Override
-  protected Map<Integer, List<String>> getInstancePartitionIdToSegmentsMap(Set<String> segments,
-      int numInstancePartitions) {
-    Map<Integer, List<String>> instancePartitionIdToSegmentsMap = new HashMap<>();
-    for (String segmentName : segments) {
-      int instancePartitionId = getSegmentPartitionId(segmentName) % numInstancePartitions;
-      instancePartitionIdToSegmentsMap.computeIfAbsent(instancePartitionId, k -> new ArrayList<>()).add(segmentName);
-    }
-    return instancePartitionIdToSegmentsMap;
-  }
 }
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignment.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignment.java
index 92e4980..9ccf9a8 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignment.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignment.java
@@ -31,16 +31,11 @@
 
 /**
  * Interface for segment assignment and table rebalance.
- * <p>
- * TODO: Add SegmentAssignmentStrategy interface and support custom segment assignment strategy (e.g. cost based segment
- *       assignment). SegmentAssignmentStrategy should not be coupled with SegmentAssignment, and SegmentAssignment
- *       should be able to choose the segment assignment strategy based on the configuration.
  */
 public interface SegmentAssignment {
 
   /**
    * Initializes the segment assignment.
-   *
    * @param helixManager Helix manager
    * @param tableConfig Table config
    */
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentFactory.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentFactory.java
index 75f337c..a49f7b9 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentFactory.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentFactory.java
@@ -33,8 +33,7 @@
   public static SegmentAssignment getSegmentAssignment(HelixManager helixManager, TableConfig tableConfig) {
     SegmentAssignment segmentAssignment;
     if (tableConfig.getTableType() == TableType.OFFLINE) {
-      segmentAssignment =
-          tableConfig.isDimTable() ? new OfflineDimTableSegmentAssignment() : new OfflineSegmentAssignment();
+      segmentAssignment = new OfflineSegmentAssignment();
     } else {
       segmentAssignment = new RealtimeSegmentAssignment();
     }
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java
index 3cb5890..bbe3f2f 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/SegmentAssignmentUtils.java
@@ -22,15 +22,22 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.PriorityQueue;
 import java.util.Set;
 import java.util.TreeMap;
+import javax.annotation.Nullable;
+import org.apache.helix.HelixManager;
 import org.apache.helix.controller.rebalancer.strategy.AutoRebalanceStrategy;
 import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.common.metadata.ZKMetadataProvider;
+import org.apache.pinot.common.metadata.segment.SegmentZKMetadata;
 import org.apache.pinot.common.tier.Tier;
+import org.apache.pinot.common.utils.SegmentUtils;
+import org.apache.pinot.segment.spi.partition.metadata.ColumnPartitionMetadata;
 import org.apache.pinot.spi.utils.CommonConstants.Helix.StateModel.SegmentStateModel;
 import org.apache.pinot.spi.utils.Pairs;
 
@@ -74,10 +81,10 @@
    */
   public static List<String> getInstancesForNonReplicaGroupBasedAssignment(InstancePartitions instancePartitions,
       int replication) {
-    Preconditions.checkState(
-        instancePartitions.getNumReplicaGroups() == 1 && instancePartitions.getNumPartitions() == 1,
-        "Instance partitions: %s should contain 1 replica and 1 partition for non-replica-group based assignment",
-        instancePartitions.getInstancePartitionsName());
+    Preconditions
+        .checkState(instancePartitions.getNumReplicaGroups() == 1 && instancePartitions.getNumPartitions() == 1,
+            "Instance partitions: %s should contain 1 replica and 1 partition for non-replica-group based assignment",
+            instancePartitions.getInstancePartitionsName());
     List<String> instances = instancePartitions.getInstances(0, 0);
     int numInstances = instances.size();
     Preconditions.checkState(numInstances >= replication,
@@ -127,8 +134,8 @@
     int numReplicaGroups = instancePartitions.getNumReplicaGroups();
     List<String> instancesAssigned = new ArrayList<>(numReplicaGroups);
     for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
-      instancesAssigned.add(
-          instancePartitions.getInstances(partitionId, replicaGroupId).get(instanceIdWithLeastSegmentsAssigned));
+      instancesAssigned
+          .add(instancePartitions.getInstances(partitionId, replicaGroupId).get(instanceIdWithLeastSegmentsAssigned));
     }
     return instancesAssigned;
   }
@@ -212,8 +219,8 @@
       for (String instanceName : currentAssignment.get(segmentName).keySet()) {
         Integer instanceId = instanceNameToIdMap.get(instanceName);
         if (instanceId != null && numSegmentsAssignedPerInstance[instanceId] < targetNumSegmentsPerInstance) {
-          newAssignment.put(segmentName,
-              getReplicaGroupBasedInstanceStateMap(instancePartitions, partitionId, instanceId));
+          newAssignment
+              .put(segmentName, getReplicaGroupBasedInstanceStateMap(instancePartitions, partitionId, instanceId));
           numSegmentsAssignedPerInstance[instanceId]++;
           segmentAssigned = true;
           break;
@@ -248,8 +255,8 @@
     Map<String, String> instanceStateMap = new TreeMap<>();
     int numReplicaGroups = instancePartitions.getNumReplicaGroups();
     for (int replicaGroupId = 0; replicaGroupId < numReplicaGroups; replicaGroupId++) {
-      instanceStateMap.put(instancePartitions.getInstances(partitionId, replicaGroupId).get(instanceId),
-          SegmentStateModel.ONLINE);
+      instanceStateMap
+          .put(instancePartitions.getInstances(partitionId, replicaGroupId).get(instanceId), SegmentStateModel.ONLINE);
     }
     return instanceStateMap;
   }
@@ -388,4 +395,90 @@
       return _nonTierSegmentAssignment;
     }
   }
+
+  /**
+   * Returns a partition id for offline table
+   */
+  public static int getOfflineSegmentPartitionId(String segmentName, String offlineTableName, HelixManager helixManager,
+      @Nullable String partitionColumn) {
+    SegmentZKMetadata segmentZKMetadata =
+        ZKMetadataProvider.getSegmentZKMetadata(helixManager.getHelixPropertyStore(), offlineTableName, segmentName);
+    Preconditions
+        .checkState(segmentZKMetadata != null, "Failed to find segment ZK metadata for segment: %s of table: %s",
+            segmentName, offlineTableName);
+    return getPartitionId(segmentZKMetadata, offlineTableName, partitionColumn);
+  }
+
+  private static int getPartitionId(SegmentZKMetadata segmentZKMetadata, String offlineTableName,
+      @Nullable String partitionColumn) {
+    String segmentName = segmentZKMetadata.getSegmentName();
+    ColumnPartitionMetadata partitionMetadata =
+        segmentZKMetadata.getPartitionMetadata().getColumnPartitionMap().get(partitionColumn);
+    Preconditions.checkState(partitionMetadata != null,
+        "Segment ZK metadata for segment: %s of table: %s does not contain partition metadata for column: %s",
+        segmentName, offlineTableName, partitionColumn);
+    Set<Integer> partitions = partitionMetadata.getPartitions();
+    Preconditions.checkState(partitions.size() == 1,
+        "Segment ZK metadata for segment: %s of table: %s contains multiple partitions for column: %s", segmentName,
+        offlineTableName, partitionColumn);
+    return partitions.iterator().next();
+  }
+
+  /**
+   * Returns map of instance partition id to segments for offline tables
+   */
+  public static Map<Integer, List<String>> getOfflineInstancePartitionIdToSegmentsMap(Set<String> segments,
+      int numInstancePartitions, String offlineTableName, HelixManager helixManager, @Nullable String partitionColumn) {
+    // Fetch partition id from segment ZK metadata
+    List<SegmentZKMetadata> segmentsZKMetadata =
+        ZKMetadataProvider.getSegmentsZKMetadata(helixManager.getHelixPropertyStore(), offlineTableName);
+
+    Map<Integer, List<String>> instancePartitionIdToSegmentsMap = new HashMap<>();
+    Set<String> segmentsWithoutZKMetadata = new HashSet<>(segments);
+    for (SegmentZKMetadata segmentZKMetadata : segmentsZKMetadata) {
+      String segmentName = segmentZKMetadata.getSegmentName();
+      if (segmentsWithoutZKMetadata.remove(segmentName)) {
+        int partitionId = getPartitionId(segmentZKMetadata, offlineTableName, partitionColumn);
+        int instancePartitionId = partitionId % numInstancePartitions;
+        instancePartitionIdToSegmentsMap.computeIfAbsent(instancePartitionId, k -> new ArrayList<>()).add(segmentName);
+      }
+    }
+    Preconditions.checkState(segmentsWithoutZKMetadata.isEmpty(), "Failed to find ZK metadata for segments: %s",
+        segmentsWithoutZKMetadata);
+
+    return instancePartitionIdToSegmentsMap;
+  }
+
+  /**
+   * Returns a partition id for realtime table
+   */
+  public static int getRealtimeSegmentPartitionId(String segmentName, String realtimeTableName,
+      HelixManager helixManager, @Nullable String partitionColumn) {
+    Integer segmentPartitionId =
+        SegmentUtils.getRealtimeSegmentPartitionId(segmentName, realtimeTableName, helixManager, partitionColumn);
+    if (segmentPartitionId == null) {
+      // This case is for the uploaded segments for which there's no partition information.
+      // A random, but consistent, partition id is calculated based on the hash code of the segment name.
+      // Note that '% 10K' is used to prevent having partition ids with large value which will be problematic later in
+      // instance assignment formula.
+      segmentPartitionId = Math.abs(segmentName.hashCode() % 10_000);
+    }
+    return segmentPartitionId;
+  }
+
+  /**
+   * Returns map of instance partition id to segments for realtime tables
+   */
+  public static Map<Integer, List<String>> getRealtimeInstancePartitionIdToSegmentsMap(Set<String> segments,
+      int numInstancePartitions, String realtimeTableName, HelixManager helixManager,
+      @Nullable String partitionColumn) {
+    Map<Integer, List<String>> instancePartitionIdToSegmentsMap = new HashMap<>();
+    for (String segmentName : segments) {
+      int instancePartitionId =
+          getRealtimeSegmentPartitionId(segmentName, realtimeTableName, helixManager, partitionColumn)
+              % numInstancePartitions;
+      instancePartitionIdToSegmentsMap.computeIfAbsent(instancePartitionId, k -> new ArrayList<>()).add(segmentName);
+    }
+    return instancePartitionIdToSegmentsMap;
+  }
 }
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignment.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategy.java
similarity index 80%
rename from pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignment.java
rename to pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategy.java
index d31b829..fee2b78 100644
--- a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignment.java
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategy.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.controller.helix.core.assignment.segment;
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
 
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
@@ -24,17 +24,17 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
-import javax.annotation.Nullable;
-import org.apache.commons.configuration.Configuration;
 import org.apache.helix.HelixManager;
 import org.apache.pinot.common.assignment.InstancePartitions;
-import org.apache.pinot.common.tier.Tier;
 import org.apache.pinot.common.utils.config.TagNameUtils;
 import org.apache.pinot.common.utils.helix.HelixHelper;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.TenantConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
 import org.apache.pinot.spi.utils.CommonConstants;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -51,7 +51,8 @@
  *   </li>
  * </ul>
  */
-public class OfflineDimTableSegmentAssignment implements SegmentAssignment {
+public class AllServersSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(AllServersSegmentAssignmentStrategy.class);
 
   private HelixManager _helixManager;
   private String _offlineTableName;
@@ -59,15 +60,15 @@
 
   @Override
   public void init(HelixManager helixManager, TableConfig tableConfig) {
-    Preconditions.checkState(tableConfig.isDimTable(), "Not a dimension table: %s" + _offlineTableName);
     _helixManager = helixManager;
     _offlineTableName = tableConfig.getTableName();
     _tenantConfig = tableConfig.getTenantConfig();
+    LOGGER.info("Initialized AllServersSegmentAssignmentStrategy for table: {}", _offlineTableName);
   }
 
   @Override
   public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment,
-      Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap) {
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
     String serverTag = _tenantConfig.getServer();
     Set<String> instances = HelixHelper.getServerInstancesForTenant(_helixManager, serverTag);
     int numInstances = instances.size();
@@ -77,9 +78,9 @@
   }
 
   @Override
-  public Map<String, Map<String, String>> rebalanceTable(Map<String, Map<String, String>> currentAssignment,
-      Map<InstancePartitionsType, InstancePartitions> instancePartitionsMap, @Nullable List<Tier> sortedTiers,
-      @Nullable Map<String, InstancePartitions> tierInstancePartitionsMap, Configuration config) {
+  public Map<String, Map<String, String>> reassignSegments(Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
+
     String serverTag = _tenantConfig.getServer();
     Set<String> instances = HelixHelper.getServerInstancesForTenant(_helixManager, serverTag);
     Map<String, Map<String, String>> newAssignment = new TreeMap<>();
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..58f1552
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategy.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
+
+import com.google.common.base.Preconditions;
+import java.util.List;
+import java.util.Map;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils;
+import org.apache.pinot.spi.config.table.SegmentsValidationAndRetentionConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Balance num Segment assignment strategy class for offline segment assignment
+ * <ul>
+ *   <li>
+ *     <p>This segment assignment strategy is used when table replication/ num_replica_groups = 1.</p>
+ *   </li>
+ * </ul>
+ */
+public class BalancedNumSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(BalancedNumSegmentAssignmentStrategy.class);
+
+  private String _tableNameWithType;
+  private int _replication;
+
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _tableNameWithType = tableConfig.getTableName();
+    SegmentsValidationAndRetentionConfig validationAndRetentionConfig = tableConfig.getValidationConfig();
+    Preconditions.checkState(validationAndRetentionConfig != null, "Validation Config is null");
+    _replication = validationAndRetentionConfig.getReplicationNumber();
+    LOGGER.info("Initialized BalancedNumSegmentAssignmentStrategy for table: " + "{} with replication: {}",
+        _tableNameWithType, _replication);
+  }
+
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
+    validateSegmentAssignmentStrategy(instancePartitions);
+    return SegmentAssignmentUtils.assignSegmentWithoutReplicaGroup(currentAssignment, instancePartitions, _replication);
+  }
+
+  @Override
+  public Map<String, Map<String, String>> reassignSegments(Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
+    validateSegmentAssignmentStrategy(instancePartitions);
+    Map<String, Map<String, String>> newAssignment;
+    List<String> instances =
+        SegmentAssignmentUtils.getInstancesForNonReplicaGroupBasedAssignment(instancePartitions, _replication);
+    newAssignment =
+        SegmentAssignmentUtils.rebalanceTableWithHelixAutoRebalanceStrategy(currentAssignment, instances, _replication);
+    return newAssignment;
+  }
+
+  private void validateSegmentAssignmentStrategy(InstancePartitions instancePartitions) {
+    int numReplicaGroups = instancePartitions.getNumReplicaGroups();
+    int numPartitions = instancePartitions.getNumPartitions();
+    // Non-replica-group based assignment should have numReplicaGroups and numPartitions = 1
+    Preconditions.checkState(numReplicaGroups == 1,
+        "Replica groups should be 1 in order to use BalanceNumSegmentAssignmentStrategy");
+    Preconditions.checkState(numPartitions == 1,
+        "Replica groups should be 1 in order to use BalanceNumSegmentAssignmentStrategy");
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategy.java
new file mode 100644
index 0000000..1db91cb
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategy.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils;
+import org.apache.pinot.spi.config.table.ReplicaGroupStrategyConfig;
+import org.apache.pinot.spi.config.table.SegmentsValidationAndRetentionConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+class ReplicaGroupSegmentAssignmentStrategy implements SegmentAssignmentStrategy {
+  private static final Logger LOGGER = LoggerFactory.getLogger(ReplicaGroupSegmentAssignmentStrategy.class);
+
+  private static HelixManager _helixManager;
+  private static String _tableName;
+  private static String _partitionColumn;
+  private int _replication;
+  private TableConfig _tableConfig;
+
+  @Override
+  public void init(HelixManager helixManager, TableConfig tableConfig) {
+    _helixManager = helixManager;
+    _tableConfig = tableConfig;
+    _tableName = tableConfig.getTableName();
+    SegmentsValidationAndRetentionConfig validationAndRetentionConfig = tableConfig.getValidationConfig();
+    Preconditions.checkState(validationAndRetentionConfig != null, "Validation Config is null");
+    _replication = validationAndRetentionConfig.getReplicationNumber();
+    ReplicaGroupStrategyConfig replicaGroupStrategyConfig =
+        validationAndRetentionConfig.getReplicaGroupStrategyConfig();
+    _partitionColumn = replicaGroupStrategyConfig != null ? replicaGroupStrategyConfig.getPartitionColumn() : null;
+    if (_partitionColumn == null) {
+      LOGGER.info("Initialized ReplicaGroupSegmentAssignmentStrategy "
+          + "with replication: {} without partition column for table: {} ", _replication, _tableName);
+    } else {
+      LOGGER.info("Initialized ReplicaGroupSegmentAssignmentStrategy "
+          + "with replication: {} and partition column: {} for table: {}", _replication, _partitionColumn, _tableName);
+    }
+  }
+
+  /**
+   * Assigns the segment for the replica-group based segment assignment strategy and returns the assigned instances.
+   */
+  @Override
+  public List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
+    int numPartitions = instancePartitions.getNumPartitions();
+    checkReplication(instancePartitions, _replication, _tableName);
+    int partitionId;
+    if (_partitionColumn == null || numPartitions == 1) {
+      partitionId = 0;
+    } else {
+      // Uniformly spray the segment partitions over the instance partitions
+      if (_tableConfig.getTableType() == TableType.OFFLINE) {
+        partitionId = SegmentAssignmentUtils
+            .getOfflineSegmentPartitionId(segmentName, _tableName, _helixManager, _partitionColumn) % numPartitions;
+      } else {
+        partitionId = SegmentAssignmentUtils
+            .getRealtimeSegmentPartitionId(segmentName, _tableName, _helixManager, _partitionColumn) % numPartitions;
+      }
+    }
+    return SegmentAssignmentUtils.assignSegmentWithReplicaGroup(currentAssignment, instancePartitions, partitionId);
+  }
+
+  @Override
+  public Map<String, Map<String, String>> reassignSegments(Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType) {
+    Map<String, Map<String, String>> newAssignment;
+    int numPartitions = instancePartitions.getNumPartitions();
+
+    checkReplication(instancePartitions, _replication, _tableName);
+
+    if (_partitionColumn == null || numPartitions == 1) {
+      // NOTE: Shuffle the segments within the current assignment to avoid moving only new segments to the new added
+      //       servers, which might cause hotspot servers because queries tend to hit the new segments. Use the
+      //       table name hash as the random seed for the shuffle so that the result is deterministic.
+      List<String> segments = new ArrayList<>(currentAssignment.keySet());
+      Collections.shuffle(segments, new Random(_tableName.hashCode()));
+
+      newAssignment = new TreeMap<>();
+      SegmentAssignmentUtils
+          .rebalanceReplicaGroupBasedPartition(currentAssignment, instancePartitions, 0, segments, newAssignment);
+      return newAssignment;
+    } else {
+      Map<Integer, List<String>> instancePartitionIdToSegmentsMap;
+      if (_tableConfig.getTableType() == TableType.OFFLINE) {
+        instancePartitionIdToSegmentsMap = SegmentAssignmentUtils
+            .getOfflineInstancePartitionIdToSegmentsMap(currentAssignment.keySet(),
+                instancePartitions.getNumPartitions(), _tableName, _helixManager, _partitionColumn);
+      } else {
+        instancePartitionIdToSegmentsMap = SegmentAssignmentUtils
+            .getRealtimeInstancePartitionIdToSegmentsMap(currentAssignment.keySet(),
+                instancePartitions.getNumPartitions(), _tableName, _helixManager, _partitionColumn);
+      }
+
+      // NOTE: Shuffle the segments within the current assignment to avoid moving only new segments to the new added
+      //       servers, which might cause hotspot servers because queries tend to hit the new segments. Use the
+      //       table name hash as the random seed for the shuffle so that the result is deterministic.
+      Random random = new Random(_tableName.hashCode());
+      for (List<String> segments : instancePartitionIdToSegmentsMap.values()) {
+        Collections.shuffle(segments, random);
+      }
+
+      return SegmentAssignmentUtils
+          .rebalanceReplicaGroupBasedTable(currentAssignment, instancePartitions, instancePartitionIdToSegmentsMap);
+    }
+  }
+
+  /**
+   * Helper method to check whether the number of replica-groups matches the table replication for replica-group based
+   * instance partitions. Log a warning if they do not match and use the one inside the instance partitions. The
+   * mismatch can happen when table is not configured correctly (table replication and numReplicaGroups does not match
+   * or replication changed without reassigning instances).
+   */
+  private static void checkReplication(InstancePartitions instancePartitions, int replication, String tableName) {
+    int numReplicaGroups = instancePartitions.getNumReplicaGroups();
+    if (numReplicaGroups != replication) {
+      LOGGER.warn(
+          "Number of replica-groups in instance partitions {}: {} does not match replication in table config: {} for "
+              + "table: {}, using: {}", instancePartitions.getInstancePartitionsName(), numReplicaGroups, replication,
+          tableName, numReplicaGroups);
+    }
+  }
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategy.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategy.java
new file mode 100644
index 0000000..36222dd
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategy.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+
+
+/**
+ * Interface for segment assignment strategies
+ */
+public interface SegmentAssignmentStrategy {
+
+  /**
+   * Initializes the segment assignment strategy.
+   *
+   * @param helixManager Helix manager
+   * @param tableConfig Table config
+   */
+  void init(HelixManager helixManager, TableConfig tableConfig);
+
+  /**
+   * Assigns segment to instances. The assignment strategy will be configured in
+   * OfflineSegmentAssignment and RealtimeSegmentAssignment classes and depending on type of
+   * assignment strategy, this function will be called to assign a new segment
+   *
+   * @param segmentName Name of the segment to be assigned
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @param instancePartitions Instance partitions
+   * @return List of instances to assign the segment to
+   */
+  List<String> assignSegment(String segmentName, Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType);
+
+  /**
+   * Re-assigns segment to instances. The assignment strategy will be configured in
+   * OfflineSegmentAssignment and RealtimeSegmentAssignment classes and depending on type of
+   * assignment strategy, this function will be called to re-assign a segment
+   * when the InstancePartitions has been changed.
+   *
+   * @param currentAssignment Current segment assignment of the table (map from segment name to instance state map)
+   * @param instancePartitions Instance partitions
+   * @return Rebalanced assignment for the segments per assignment strategy
+   */
+  Map<String, Map<String, String>> reassignSegments(Map<String, Map<String, String>> currentAssignment,
+      InstancePartitions instancePartitions, InstancePartitionsType instancePartitionsType);
+}
diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactory.java b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactory.java
new file mode 100644
index 0000000..3324327
--- /dev/null
+++ b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactory.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
+
+import com.google.common.base.Preconditions;
+import java.util.Map;
+import org.apache.helix.HelixManager;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.config.table.assignment.SegmentAssignmentConfig;
+import org.apache.pinot.spi.utils.CommonConstants.Segment.AssignmentStrategy;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Factory for SegmentAssignmentStrategy
+ */
+public class SegmentAssignmentStrategyFactory {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(SegmentAssignmentStrategyFactory.class);
+
+  private SegmentAssignmentStrategyFactory() {
+  }
+
+  /**
+   * Determine Segment Assignment strategy
+   */
+  public static SegmentAssignmentStrategy getSegmentAssignmentStrategy(HelixManager helixManager,
+      TableConfig tableConfig, String assignmentType, InstancePartitions instancePartitions) {
+    String assignmentStrategy = null;
+
+    TableType currentTableType = tableConfig.getTableType();
+    // TODO: Handle segment assignment strategy in future for CONSUMING segments in follow up PR
+    // See https://github.com/apache/pinot/issues/9047
+    // Accommodate new changes for assignment strategy
+    Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap = tableConfig.getSegmentAssignmentConfigMap();
+
+    if (tableConfig.isDimTable()) {
+      // Segment Assignment Strategy for DIM tables
+      Preconditions.checkState(currentTableType == TableType.OFFLINE,
+          "All Servers Segment assignment Strategy is only applicable to Dim OfflineTables");
+      SegmentAssignmentStrategy segmentAssignmentStrategy = new AllServersSegmentAssignmentStrategy();
+      segmentAssignmentStrategy.init(helixManager, tableConfig);
+      return segmentAssignmentStrategy;
+    } else {
+      // Try to determine segment assignment strategy from table config
+      if (segmentAssignmentConfigMap != null) {
+        SegmentAssignmentConfig segmentAssignmentConfig;
+        // Use the pre defined segment assignment strategy
+        segmentAssignmentConfig = segmentAssignmentConfigMap.get(assignmentType.toUpperCase());
+        // Segment assignment config is only applicable to offline tables and completed segments of real time tables
+        if (segmentAssignmentConfig != null) {
+          assignmentStrategy = segmentAssignmentConfig.getAssignmentStrategy().toLowerCase();
+        }
+      }
+    }
+
+    // Use the existing information to determine segment assignment strategy
+    SegmentAssignmentStrategy segmentAssignmentStrategy;
+    if (assignmentStrategy == null) {
+      // Calculate numReplicaGroups and numPartitions to determine segment assignment strategy
+      Preconditions
+          .checkState(instancePartitions != null, "Failed to find instance partitions for segment assignment strategy");
+      int numReplicaGroups = instancePartitions.getNumReplicaGroups();
+      int numPartitions = instancePartitions.getNumPartitions();
+
+      if (numReplicaGroups == 1 && numPartitions == 1) {
+        segmentAssignmentStrategy = new BalancedNumSegmentAssignmentStrategy();
+      } else {
+        segmentAssignmentStrategy = new ReplicaGroupSegmentAssignmentStrategy();
+      }
+    } else {
+      // Set segment assignment strategy depending on strategy set in table config
+      switch (assignmentStrategy) {
+        case AssignmentStrategy.REPLICA_GROUP_SEGMENT_ASSIGNMENT_STRATEGY:
+          segmentAssignmentStrategy = new ReplicaGroupSegmentAssignmentStrategy();
+          break;
+        case AssignmentStrategy.BALANCE_NUM_SEGMENT_ASSIGNMENT_STRATEGY:
+        default:
+          segmentAssignmentStrategy = new BalancedNumSegmentAssignmentStrategy();
+          break;
+      }
+    }
+    segmentAssignmentStrategy.init(helixManager, tableConfig);
+    return segmentAssignmentStrategy;
+  }
+}
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignmentTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategyTest.java
similarity index 73%
rename from pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignmentTest.java
rename to pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategyTest.java
index 79ec716..0f64d60 100644
--- a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineDimTableSegmentAssignmentTest.java
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/AllServersSegmentAssignmentStrategyTest.java
@@ -16,10 +16,11 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.controller.helix.core.assignment.segment;
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
 
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
@@ -29,8 +30,14 @@
 import org.apache.helix.PropertyKey;
 import org.apache.helix.model.InstanceConfig;
 import org.apache.helix.zookeeper.datamodel.ZNRecord;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.segment.OfflineSegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentFactory;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentTestUtils;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
 import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
@@ -43,7 +50,7 @@
 import static org.testng.Assert.assertTrue;
 
 
-public class OfflineDimTableSegmentAssignmentTest {
+public class AllServersSegmentAssignmentStrategyTest {
   private static final String INSTANCE_NAME_PREFIX = "instance_";
   private static final int NUM_INSTANCES = 10;
   private static final List<String> INSTANCES =
@@ -56,19 +63,35 @@
 
   private SegmentAssignment _segmentAssignment;
   private HelixManager _helixManager;
+  private static final String INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME);
+  private static final String COMPLETED_INSTANCE_NAME_PREFIX = "completedInstance_";
+  private static final String COMPLETED_INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.COMPLETED.getInstancePartitionsName(RAW_TABLE_NAME);
+  private static final List<String> COMPLETED_INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(COMPLETED_INSTANCE_NAME_PREFIX, NUM_INSTANCES);
+
+  private Map<InstancePartitionsType, InstancePartitions> _instancePartitionsMap = new HashMap<>();
 
   @BeforeClass
   public void setup() {
     TableConfig tableConfig =
         new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).setIsDimTable(true).build();
-
     _helixManager = mock(HelixManager.class);
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME);
+    instancePartitions.setInstances(0, 0, INSTANCES);
+
+    InstancePartitions completedInstancePartitions = new InstancePartitions(COMPLETED_INSTANCE_PARTITIONS_NAME);
+    completedInstancePartitions.setInstances(0, 0, COMPLETED_INSTANCES);
+
+    _instancePartitionsMap.put(InstancePartitionsType.OFFLINE, instancePartitions);
+    _instancePartitionsMap.put(InstancePartitionsType.COMPLETED, completedInstancePartitions);
     _segmentAssignment = SegmentAssignmentFactory.getSegmentAssignment(_helixManager, tableConfig);
   }
 
   @Test
   public void testFactory() {
-    assertTrue(_segmentAssignment instanceof OfflineDimTableSegmentAssignment);
+    assertTrue(_segmentAssignment instanceof OfflineSegmentAssignment);
   }
 
   @Test
@@ -85,7 +108,7 @@
     when(dataAccessor.getChildValues(builder.instanceConfigs(), true)).thenReturn(instanceConfigList);
     when(_helixManager.getHelixDataAccessor()).thenReturn(dataAccessor);
 
-    List<String> instances = _segmentAssignment.assignSegment(SEGMENT_NAME, new TreeMap(), new TreeMap());
+    List<String> instances = _segmentAssignment.assignSegment(SEGMENT_NAME, new TreeMap(), _instancePartitionsMap);
     assertEquals(instances.size(), NUM_INSTANCES);
     assertEqualsNoOrder(instances.toArray(), INSTANCES.toArray());
 
@@ -101,7 +124,7 @@
     when(dataAccessor.getChildValues(builder.instanceConfigs(), true)).thenReturn(instanceConfigList);
 
     Map<String, Map<String, String>> newAssignment =
-        _segmentAssignment.rebalanceTable(currentAssignment, new TreeMap<>(), null, null, null);
+        _segmentAssignment.rebalanceTable(currentAssignment, _instancePartitionsMap, null, null, null);
     assertEquals(newAssignment.get(SEGMENT_NAME).size(), NUM_INSTANCES - 1);
   }
 
@@ -119,7 +142,7 @@
     when(dataAccessor.getChildValues(builder.instanceConfigs(), true)).thenReturn(instanceConfigList);
     when(_helixManager.getHelixDataAccessor()).thenReturn(dataAccessor);
 
-    List<String> instances = _segmentAssignment.assignSegment(SEGMENT_NAME, new TreeMap(), new TreeMap());
+    List<String> instances = _segmentAssignment.assignSegment(SEGMENT_NAME, new TreeMap(), _instancePartitionsMap);
     assertEquals(instances.size(), NUM_INSTANCES);
     assertEqualsNoOrder(instances.toArray(), INSTANCES.toArray());
   }
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineNonReplicaGroupSegmentAssignmentTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategyTest.java
similarity index 84%
rename from pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineNonReplicaGroupSegmentAssignmentTest.java
rename to pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategyTest.java
index c6cdeea..8cef7ef 100644
--- a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineNonReplicaGroupSegmentAssignmentTest.java
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/BalancedNumSegmentAssignmentStrategyTest.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.controller.helix.core.assignment.segment;
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -27,6 +27,11 @@
 import org.apache.commons.configuration.BaseConfiguration;
 import org.apache.commons.configuration.Configuration;
 import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.segment.OfflineSegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentFactory;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentTestUtils;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.config.table.TableType;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
@@ -40,7 +45,7 @@
 import static org.testng.Assert.assertTrue;
 
 
-public class OfflineNonReplicaGroupSegmentAssignmentTest {
+public class BalancedNumSegmentAssignmentStrategyTest {
   private static final int NUM_REPLICAS = 3;
   private static final String SEGMENT_NAME_PREFIX = "segment_";
   private static final int NUM_SEGMENTS = 100;
@@ -53,7 +58,6 @@
   private static final String RAW_TABLE_NAME = "assignmentTable";
   private static final String INSTANCE_PARTITIONS_NAME =
       InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME);
-
   private SegmentAssignment _segmentAssignment;
   private Map<InstancePartitionsType, InstancePartitions> _instancePartitionsMap;
 
@@ -96,8 +100,8 @@
         assertEquals(instancesAssigned.get(replicaId), INSTANCES.get(expectedAssignedInstanceId));
         expectedAssignedInstanceId = (expectedAssignedInstanceId + 1) % NUM_INSTANCES;
       }
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
   }
 
@@ -107,8 +111,8 @@
     for (String segmentName : SEGMENTS) {
       List<String> instancesAssigned =
           _segmentAssignment.assignSegment(segmentName, currentAssignment, _instancePartitionsMap);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     // There should be 100 segments assigned
@@ -125,8 +129,9 @@
     Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
     assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
     // Current assignment should already be balanced
-    assertEquals(_segmentAssignment.rebalanceTable(currentAssignment, _instancePartitionsMap, null, null,
-        new BaseConfiguration()), currentAssignment);
+    assertEquals(_segmentAssignment
+            .rebalanceTable(currentAssignment, _instancePartitionsMap, null, null, new BaseConfiguration()),
+        currentAssignment);
   }
 
   @Test
@@ -135,8 +140,8 @@
     for (String segmentName : SEGMENTS) {
       List<String> instancesAssigned =
           _segmentAssignment.assignSegment(segmentName, currentAssignment, _instancePartitionsMap);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     // Bootstrap table should reassign all segments based on their alphabetical order
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategyTest.java
similarity index 85%
rename from pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentTest.java
rename to pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategyTest.java
index ab7c2c6..f2bf9fc 100644
--- a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/OfflineReplicaGroupSegmentAssignmentTest.java
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/ReplicaGroupSegmentAssignmentStrategyTest.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.controller.helix.core.assignment.segment;
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
 
 import com.google.common.collect.ImmutableMap;
 import java.util.ArrayList;
@@ -34,6 +34,11 @@
 import org.apache.pinot.common.metadata.ZKMetadataProvider;
 import org.apache.pinot.common.metadata.segment.SegmentPartitionMetadata;
 import org.apache.pinot.common.metadata.segment.SegmentZKMetadata;
+import org.apache.pinot.controller.helix.core.assignment.segment.OfflineSegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignment;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentFactory;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentTestUtils;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentUtils;
 import org.apache.pinot.segment.spi.partition.metadata.ColumnPartitionMetadata;
 import org.apache.pinot.spi.config.table.ReplicaGroupStrategyConfig;
 import org.apache.pinot.spi.config.table.TableConfig;
@@ -57,7 +62,7 @@
 
 
 @SuppressWarnings("unchecked")
-public class OfflineReplicaGroupSegmentAssignmentTest {
+public class ReplicaGroupSegmentAssignmentStrategyTest {
   private static final int NUM_REPLICAS = 3;
   private static final String SEGMENT_NAME_PREFIX = "segment_";
   private static final int NUM_SEGMENTS = 12;
@@ -126,9 +131,9 @@
           any(), anyInt())).thenReturn(segmentZKMetadataZNRecord);
       segmentZKMetadataZNRecords.add(segmentZKMetadataZNRecord);
     }
-    when(propertyStoreWithPartitions.getChildren(
-        eq(ZKMetadataProvider.constructPropertyStorePathForResource(OFFLINE_TABLE_NAME_WITH_PARTITION)), any(),
-        anyInt(), anyInt(), anyInt())).thenReturn(segmentZKMetadataZNRecords);
+    when(propertyStoreWithPartitions
+        .getChildren(eq(ZKMetadataProvider.constructPropertyStorePathForResource(OFFLINE_TABLE_NAME_WITH_PARTITION)),
+            any(), anyInt(), anyInt(), anyInt())).thenReturn(segmentZKMetadataZNRecords);
     HelixManager helixManagerWithPartitions = mock(HelixManager.class);
     when(helixManagerWithPartitions.getHelixPropertyStore()).thenReturn(propertyStoreWithPartitions);
 
@@ -176,8 +181,8 @@
     Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
     for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
       String segmentName = SEGMENTS.get(segmentId);
-      List<String> instancesAssigned = _segmentAssignmentWithoutPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithoutPartition);
+      List<String> instancesAssigned = _segmentAssignmentWithoutPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithoutPartition);
       assertEquals(instancesAssigned.size(), NUM_REPLICAS);
 
       // Segment 0 should be assigned to instance 0, 6, 12
@@ -195,8 +200,8 @@
         assertEquals(instancesAssigned.get(replicaGroupId), INSTANCES.get(expectedAssignedInstanceId));
       }
 
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
   }
 
@@ -207,8 +212,8 @@
     int numInstancesPerPartition = numInstancesPerReplicaGroup / NUM_PARTITIONS;
     for (int segmentId = 0; segmentId < NUM_SEGMENTS; segmentId++) {
       String segmentName = SEGMENTS.get(segmentId);
-      List<String> instancesAssigned = _segmentAssignmentWithPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithPartition);
+      List<String> instancesAssigned = _segmentAssignmentWithPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithPartition);
       assertEquals(instancesAssigned.size(), NUM_REPLICAS);
 
       // Segment 0 (partition 0) should be assigned to instance 0, 6, 12
@@ -228,8 +233,8 @@
         assertEquals(instancesAssigned.get(replicaGroupId), INSTANCES.get(expectedAssignedInstanceId));
       }
 
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
   }
 
@@ -237,10 +242,10 @@
   public void testTableBalancedWithoutPartition() {
     Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
     for (String segmentName : SEGMENTS) {
-      List<String> instancesAssigned = _segmentAssignmentWithoutPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithoutPartition);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      List<String> instancesAssigned = _segmentAssignmentWithoutPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithoutPartition);
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     assertEquals(currentAssignment.size(), NUM_SEGMENTS);
@@ -255,19 +260,20 @@
     Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
     assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
     // Current assignment should already be balanced
-    assertEquals(
-        _segmentAssignmentWithoutPartition.rebalanceTable(currentAssignment, _instancePartitionsMapWithoutPartition,
-            null, null, new BaseConfiguration()), currentAssignment);
+    assertEquals(_segmentAssignmentWithoutPartition
+            .rebalanceTable(currentAssignment, _instancePartitionsMapWithoutPartition, null, null,
+                new BaseConfiguration()),
+        currentAssignment);
   }
 
   @Test
   public void testTableBalancedWithPartition() {
     Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
     for (String segmentName : SEGMENTS) {
-      List<String> instancesAssigned = _segmentAssignmentWithPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithPartition);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      List<String> instancesAssigned = _segmentAssignmentWithPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithPartition);
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     assertEquals(currentAssignment.size(), NUM_SEGMENTS);
@@ -282,27 +288,27 @@
     Arrays.fill(expectedNumSegmentsAssignedPerInstance, numSegmentsPerInstance);
     assertEquals(numSegmentsAssignedPerInstance, expectedNumSegmentsAssignedPerInstance);
     // Current assignment should already be balanced
-    assertEquals(
-        _segmentAssignmentWithPartition.rebalanceTable(currentAssignment, _instancePartitionsMapWithPartition, null,
-            null, new BaseConfiguration()), currentAssignment);
+    assertEquals(_segmentAssignmentWithPartition
+            .rebalanceTable(currentAssignment, _instancePartitionsMapWithPartition, null, null,
+                new BaseConfiguration()),
+        currentAssignment);
   }
 
   @Test
   public void testBootstrapTableWithoutPartition() {
     Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
     for (String segmentName : SEGMENTS) {
-      List<String> instancesAssigned = _segmentAssignmentWithoutPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithoutPartition);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      List<String> instancesAssigned = _segmentAssignmentWithoutPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithoutPartition);
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     // Bootstrap table should reassign all segments based on their alphabetical order
     Configuration rebalanceConfig = new BaseConfiguration();
     rebalanceConfig.setProperty(RebalanceConfigConstants.BOOTSTRAP, true);
-    Map<String, Map<String, String>> newAssignment =
-        _segmentAssignmentWithoutPartition.rebalanceTable(currentAssignment, _instancePartitionsMapWithoutPartition,
-            null, null, rebalanceConfig);
+    Map<String, Map<String, String>> newAssignment = _segmentAssignmentWithoutPartition
+        .rebalanceTable(currentAssignment, _instancePartitionsMapWithoutPartition, null, null, rebalanceConfig);
     assertEquals(newAssignment.size(), NUM_SEGMENTS);
     List<String> sortedSegments = new ArrayList<>(SEGMENTS);
     sortedSegments.sort(null);
@@ -315,18 +321,17 @@
   public void testBootstrapTableWithPartition() {
     Map<String, Map<String, String>> currentAssignment = new TreeMap<>();
     for (String segmentName : SEGMENTS) {
-      List<String> instancesAssigned = _segmentAssignmentWithPartition.assignSegment(segmentName, currentAssignment,
-          _instancePartitionsMapWithPartition);
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      List<String> instancesAssigned = _segmentAssignmentWithPartition
+          .assignSegment(segmentName, currentAssignment, _instancePartitionsMapWithPartition);
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     // Bootstrap table should reassign all segments based on their alphabetical order within the partition
     Configuration rebalanceConfig = new BaseConfiguration();
     rebalanceConfig.setProperty(RebalanceConfigConstants.BOOTSTRAP, true);
-    Map<String, Map<String, String>> newAssignment =
-        _segmentAssignmentWithPartition.rebalanceTable(currentAssignment, _instancePartitionsMapWithPartition, null,
-            null, rebalanceConfig);
+    Map<String, Map<String, String>> newAssignment = _segmentAssignmentWithPartition
+        .rebalanceTable(currentAssignment, _instancePartitionsMapWithPartition, null, null, rebalanceConfig);
     assertEquals(newAssignment.size(), NUM_SEGMENTS);
     int numSegmentsPerPartition = NUM_SEGMENTS / NUM_PARTITIONS;
     String[][] partitionIdToSegmentsMap = new String[NUM_PARTITIONS][numSegmentsPerPartition];
@@ -355,12 +360,12 @@
     String instance1 = INSTANCE_NAME_PREFIX + "1";
     String instance2 = INSTANCE_NAME_PREFIX + "2";
     Map<String, Map<String, String>> unbalancedAssignment = new TreeMap<>();
-    SEGMENTS.forEach(segName -> unbalancedAssignment.put(segName,
-        ImmutableMap.of(instance0, SegmentStateModel.ONLINE, instance1, SegmentStateModel.ONLINE, instance2,
+    SEGMENTS.forEach(segName -> unbalancedAssignment.put(segName, ImmutableMap
+        .of(instance0, SegmentStateModel.ONLINE, instance1, SegmentStateModel.ONLINE, instance2,
             SegmentStateModel.ONLINE)));
-    Map<String, Map<String, String>> balancedAssignment =
-        _segmentAssignmentWithPartition.rebalanceTable(unbalancedAssignment, _instancePartitionsMapWithoutPartition,
-            null, null, new BaseConfiguration());
+    Map<String, Map<String, String>> balancedAssignment = _segmentAssignmentWithPartition
+        .rebalanceTable(unbalancedAssignment, _instancePartitionsMapWithoutPartition, null, null,
+            new BaseConfiguration());
     int[] actualNumSegmentsAssignedPerInstance =
         SegmentAssignmentUtils.getNumSegmentsAssignedPerInstance(balancedAssignment, INSTANCES);
     int[] expectedNumSegmentsAssignedPerInstance = new int[NUM_INSTANCES];
@@ -385,9 +390,9 @@
           any(), anyInt())).thenReturn(segmentZKMetadataZNRecord);
       segmentZKMetadataZNRecords.add(segmentZKMetadataZNRecord);
     }
-    when(propertyStore.getChildren(
-        eq(ZKMetadataProvider.constructPropertyStorePathForResource(OFFLINE_TABLE_NAME_WITH_PARTITION)), any(),
-        anyInt(), anyInt(), anyInt())).thenReturn(segmentZKMetadataZNRecords);
+    when(propertyStore
+        .getChildren(eq(ZKMetadataProvider.constructPropertyStorePathForResource(OFFLINE_TABLE_NAME_WITH_PARTITION)),
+            any(), anyInt(), anyInt(), anyInt())).thenReturn(segmentZKMetadataZNRecords);
     HelixManager helixManager = mock(HelixManager.class);
     when(helixManager.getHelixPropertyStore()).thenReturn(propertyStore);
 
@@ -440,8 +445,8 @@
           (segmentId % NUM_INSTANCES) / NUM_PARTITIONS + partitionId * numInstancesPerPartition;
       assertEquals(instancesAssigned.get(0), INSTANCES.get(expectedAssignedInstanceId));
 
-      currentAssignment.put(segmentName,
-          SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
+      currentAssignment
+          .put(segmentName, SegmentAssignmentUtils.getInstanceStateMap(instancesAssigned, SegmentStateModel.ONLINE));
     }
 
     // Current assignment should already be balanced
diff --git a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactoryTest.java b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactoryTest.java
new file mode 100644
index 0000000..7d0fd95
--- /dev/null
+++ b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/assignment/segment/strategy/SegmentAssignmentStrategyFactoryTest.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.controller.helix.core.assignment.segment.strategy;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.assignment.InstancePartitions;
+import org.apache.pinot.controller.helix.core.assignment.segment.SegmentAssignmentTestUtils;
+import org.apache.pinot.spi.config.table.ReplicaGroupStrategyConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.apache.pinot.spi.config.table.assignment.SegmentAssignmentConfig;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+
+/**
+ * Tests the {@link SegmentAssignmentStrategyFactory#getSegmentAssignmentStrategy} method
+ */
+public class SegmentAssignmentStrategyFactoryTest {
+
+  private static final int NUM_REPLICAS = 3;
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String INSTANCE_PARTITIONS_NAME =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME);
+  private static final String INSTANCE_NAME_PREFIX = "instance_";
+  private static final int NUM_INSTANCES = 10;
+  private static final List<String> INSTANCES =
+      SegmentAssignmentTestUtils.getNameList(INSTANCE_NAME_PREFIX, NUM_INSTANCES);
+  private static final String RAW_TABLE_NAME_WITH_PARTITION = "testTableWithPartition";
+  private static final String INSTANCE_PARTITIONS_NAME_WITH_PARTITION =
+      InstancePartitionsType.OFFLINE.getInstancePartitionsName(RAW_TABLE_NAME_WITH_PARTITION);
+  private static final int NUM_PARTITIONS = 3;
+  private static final String PARTITION_COLUMN = "partitionColumn";
+
+  private SegmentAssignmentStrategyFactoryTest() {
+  }
+
+  @Test
+  public void testSegmentAssignmentStrategyFromTableConfig() {
+    // Set segment assignment config map in table config for balanced num segment assignment strategy
+    Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap = new HashMap<>();
+    segmentAssignmentConfigMap.put(InstancePartitionsType.OFFLINE.toString(), new SegmentAssignmentConfig("Balanced"));
+    TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME)
+        .setSegmentAssignmentConfigMap(segmentAssignmentConfigMap).build();
+
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME);
+    instancePartitions.setInstances(0, 0, INSTANCES);
+
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(null, tableConfig, InstancePartitionsType.OFFLINE.toString(), instancePartitions);
+    Assert.assertNotNull(segmentAssignmentStrategy);
+    Assert.assertTrue(segmentAssignmentStrategy instanceof BalancedNumSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testSegmentAssignmentStrategyForDimTable() {
+    TableConfig tableConfig =
+        new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).setIsDimTable(true).build();
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(null, tableConfig, InstancePartitionsType.OFFLINE.toString(), null);
+    Assert.assertNotNull(segmentAssignmentStrategy);
+    Assert.assertTrue(segmentAssignmentStrategy instanceof AllServersSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testBalancedNumSegmentAssignmentStrategyforOfflineTables() {
+    TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME);
+    instancePartitions.setInstances(0, 0, INSTANCES);
+
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(null, tableConfig, InstancePartitionsType.OFFLINE.toString(), instancePartitions);
+    Assert.assertNotNull(segmentAssignmentStrategy);
+    Assert.assertTrue(segmentAssignmentStrategy instanceof BalancedNumSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testBalancedNumSegmentAssignmentStrategyforRealtimeTables() {
+    TableConfig tableConfig = new TableConfigBuilder(TableType.REALTIME).setTableName(RAW_TABLE_NAME).build();
+
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME);
+    instancePartitions.setInstances(0, 0, INSTANCES);
+
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(null, tableConfig, InstancePartitionsType.COMPLETED.toString(),
+            instancePartitions);
+    Assert.assertNotNull(segmentAssignmentStrategy);
+    Assert.assertTrue(segmentAssignmentStrategy instanceof BalancedNumSegmentAssignmentStrategy);
+  }
+
+  @Test
+  public void testReplicaGroupSegmentAssignmentStrategyForBackwardCompatibility() {
+    int numInstancesPerReplicaGroup = NUM_INSTANCES / NUM_REPLICAS;
+    int numInstancesPerPartition = numInstancesPerReplicaGroup / NUM_REPLICAS;
+    ReplicaGroupStrategyConfig replicaGroupStrategyConfig =
+        new ReplicaGroupStrategyConfig(PARTITION_COLUMN, numInstancesPerPartition);
+    TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME_WITH_PARTITION)
+        .setNumReplicas(NUM_REPLICAS).setSegmentAssignmentStrategy("ReplicaGroup")
+        .setReplicaGroupStrategyConfig(replicaGroupStrategyConfig).build();
+
+    // {
+    //   0_0=[instance_0, instance_1], 1_0=[instance_2, instance_3], 2_0=[instance_4, instance_5],
+    //   0_1=[instance_6, instance_7], 1_1=[instance_8, instance_9], 2_1=[instance_10, instance_11],
+    //   0_2=[instance_12, instance_13], 1_2=[instance_14, instance_15], 2_2=[instance_16, instance_17]
+    // }
+    InstancePartitions instancePartitions = new InstancePartitions(INSTANCE_PARTITIONS_NAME_WITH_PARTITION);
+
+    int instanceIdToAdd = 0;
+    for (int replicaGroupId = 0; replicaGroupId < NUM_REPLICAS; replicaGroupId++) {
+      for (int partitionId = 0; partitionId < NUM_PARTITIONS; partitionId++) {
+        List<String> instancesForPartition = new ArrayList<>(numInstancesPerPartition);
+        for (int i = 0; i < numInstancesPerPartition; i++) {
+          instancesForPartition.add(INSTANCES.get(instanceIdToAdd++));
+        }
+        instancePartitions.setInstances(partitionId, replicaGroupId, instancesForPartition);
+      }
+    }
+
+    SegmentAssignmentStrategy segmentAssignmentStrategy = SegmentAssignmentStrategyFactory
+        .getSegmentAssignmentStrategy(null, tableConfig, InstancePartitionsType.OFFLINE.toString(), instancePartitions);
+    Assert.assertNotNull(segmentAssignmentStrategy);
+    Assert.assertTrue(segmentAssignmentStrategy instanceof ReplicaGroupSegmentAssignmentStrategy);
+  }
+}
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/TableConfig.java b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/TableConfig.java
index f41c19d..e9e0032 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/TableConfig.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/TableConfig.java
@@ -29,6 +29,7 @@
 import org.apache.pinot.spi.config.BaseJsonConfig;
 import org.apache.pinot.spi.config.table.assignment.InstanceAssignmentConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.apache.pinot.spi.config.table.assignment.SegmentAssignmentConfig;
 import org.apache.pinot.spi.config.table.ingestion.IngestionConfig;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 
@@ -48,6 +49,7 @@
   public static final String QUERY_CONFIG_KEY = "query";
   public static final String INSTANCE_ASSIGNMENT_CONFIG_MAP_KEY = "instanceAssignmentConfigMap";
   public static final String INSTANCE_PARTITIONS_MAP_CONFIG_KEY = "instancePartitionsMap";
+  public static final String SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY = "segmentAssignmentConfigMap";
   public static final String FIELD_CONFIG_LIST_KEY = "fieldConfigList";
   public static final String UPSERT_CONFIG_KEY = "upsertConfig";
   public static final String DEDUP_CONFIG_KEY = "dedupConfig";
@@ -89,6 +91,7 @@
   @JsonPropertyDescription(value = "Point to an existing instance partitions")
   private Map<InstancePartitionsType, String> _instancePartitionsMap;
 
+  private Map<String, SegmentAssignmentConfig> _segmentAssignmentConfigMap;
   private List<FieldConfig> _fieldConfigList;
 
   @JsonPropertyDescription(value = "upsert related config")
@@ -128,7 +131,9 @@
       @JsonProperty(IS_DIM_TABLE_KEY) boolean dimTable,
       @JsonProperty(TUNER_CONFIG_LIST_KEY) @Nullable List<TunerConfig> tunerConfigList,
       @JsonProperty(INSTANCE_PARTITIONS_MAP_CONFIG_KEY) @Nullable
-          Map<InstancePartitionsType, String> instancePartitionsMap) {
+          Map<InstancePartitionsType, String> instancePartitionsMap,
+      @JsonProperty(SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY) @Nullable
+          Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap) {
     Preconditions.checkArgument(tableName != null, "'tableName' must be configured");
     Preconditions.checkArgument(!tableName.contains(TABLE_NAME_FORBIDDEN_SUBSTRING),
         "'tableName' cannot contain double underscore ('__')");
@@ -158,6 +163,7 @@
     _dimTable = dimTable;
     _tunerConfigList = tunerConfigList;
     _instancePartitionsMap = instancePartitionsMap;
+    _segmentAssignmentConfigMap = segmentAssignmentConfigMap;
   }
 
   @JsonProperty(TABLE_NAME_KEY)
@@ -338,4 +344,14 @@
   public void setTunerConfigsList(List<TunerConfig> tunerConfigList) {
     _tunerConfigList = tunerConfigList;
   }
+
+  @JsonProperty(SEGMENT_ASSIGNMENT_CONFIG_MAP_KEY)
+  @Nullable
+  public Map<String, SegmentAssignmentConfig> getSegmentAssignmentConfigMap() {
+    return _segmentAssignmentConfigMap;
+  }
+
+  public void setSegmentAssignmentConfigMap(Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap) {
+    _segmentAssignmentConfigMap = segmentAssignmentConfigMap;
+  }
 }
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/assignment/SegmentAssignmentConfig.java b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/assignment/SegmentAssignmentConfig.java
new file mode 100644
index 0000000..7c7064c
--- /dev/null
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/config/table/assignment/SegmentAssignmentConfig.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.spi.config.table.assignment;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import org.apache.pinot.spi.config.BaseJsonConfig;
+
+
+public class SegmentAssignmentConfig extends BaseJsonConfig {
+
+  @JsonPropertyDescription("Configuration for Segment Assignment Strategy")
+  private final String _assignmentStrategy;
+
+  @JsonCreator
+  public SegmentAssignmentConfig(@JsonProperty(value = "segmentAssignmentStrategy") String assignmentStrategy) {
+    _assignmentStrategy = assignmentStrategy;
+  }
+
+  public String getAssignmentStrategy() {
+    return _assignmentStrategy;
+  }
+}
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index 25973c3..5b29a00 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -768,8 +768,9 @@
     public static final String METADATA_URI_FOR_PEER_DOWNLOAD = "";
 
     public static class AssignmentStrategy {
-      public static final String BALANCE_NUM_SEGMENT_ASSIGNMENT_STRATEGY = "BalanceNumSegmentAssignmentStrategy";
-      public static final String REPLICA_GROUP_SEGMENT_ASSIGNMENT_STRATEGY = "ReplicaGroupSegmentAssignmentStrategy";
+      public static final String BALANCE_NUM_SEGMENT_ASSIGNMENT_STRATEGY = "balanced";
+      public static final String REPLICA_GROUP_SEGMENT_ASSIGNMENT_STRATEGY = "replicagroup";
+      public static final String DIM_TABLE_SEGMENT_ASSIGNMENT_STRATEGY = "allservers";
     }
 
     public static class BuiltInVirtualColumn {
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java
index e884d81..ee9c227 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/builder/TableConfigBuilder.java
@@ -44,6 +44,7 @@
 import org.apache.pinot.spi.config.table.UpsertConfig;
 import org.apache.pinot.spi.config.table.assignment.InstanceAssignmentConfig;
 import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
+import org.apache.pinot.spi.config.table.assignment.SegmentAssignmentConfig;
 import org.apache.pinot.spi.config.table.ingestion.IngestionConfig;
 
 
@@ -109,6 +110,7 @@
   private QueryConfig _queryConfig;
   private Map<InstancePartitionsType, InstanceAssignmentConfig> _instanceAssignmentConfigMap;
   private Map<InstancePartitionsType, String> _instancePartitionsMap;
+  private Map<String, SegmentAssignmentConfig> _segmentAssignmentConfigMap;
   private List<FieldConfig> _fieldConfigList;
 
   private UpsertConfig _upsertConfig;
@@ -380,6 +382,12 @@
     return this;
   }
 
+  public TableConfigBuilder setSegmentAssignmentConfigMap(
+      Map<String, SegmentAssignmentConfig> segmentAssignmentConfigMap) {
+    _segmentAssignmentConfigMap = segmentAssignmentConfigMap;
+    return this;
+  }
+
   public TableConfig build() {
     // Validation config
     SegmentsValidationAndRetentionConfig validationConfig = new SegmentsValidationAndRetentionConfig();
@@ -431,7 +439,7 @@
 
     return new TableConfig(_tableName, _tableType.toString(), validationConfig, tenantConfig, indexingConfig,
         _customConfig, _quotaConfig, _taskConfig, _routingConfig, _queryConfig, _instanceAssignmentConfigMap,
-        _fieldConfigList, _upsertConfig, _dedupConfig, _ingestionConfig, _tierConfigList, _isDimTable,
-        _tunerConfigList, _instancePartitionsMap);
+        _fieldConfigList, _upsertConfig, _dedupConfig, _ingestionConfig, _tierConfigList, _isDimTable, _tunerConfigList,
+        _instancePartitionsMap, _segmentAssignmentConfigMap);
   }
 }