Automatically determine numShards for parallel ingestion hash partitioning (#10419)

* Automatically determine numShards for parallel ingestion hash partitioning

* Fix inspection, tests, coverage

* Docs and some PR comments

* Adjust locking

* Use HllSketch instead of HyperLogLogCollector

* Fix tests

* Address some PR comments

* Fix granularity bug

* Small doc fix
diff --git a/core/src/main/java/org/apache/druid/indexer/partitions/HashedPartitionsSpec.java b/core/src/main/java/org/apache/druid/indexer/partitions/HashedPartitionsSpec.java
index 1ce2749..d78636c 100644
--- a/core/src/main/java/org/apache/druid/indexer/partitions/HashedPartitionsSpec.java
+++ b/core/src/main/java/org/apache/druid/indexer/partitions/HashedPartitionsSpec.java
@@ -33,7 +33,7 @@
 
 public class HashedPartitionsSpec implements DimensionBasedPartitionsSpec
 {
-  static final String NAME = "hashed";
+  public static final String NAME = "hashed";
   @VisibleForTesting
   static final String NUM_SHARDS = "numShards";
 
@@ -160,7 +160,7 @@
   @Override
   public String getForceGuaranteedRollupIncompatiblityReason()
   {
-    return getNumShards() == null ? NUM_SHARDS + " must be specified" : FORCE_GUARANTEED_ROLLUP_COMPATIBLE;
+    return FORCE_GUARANTEED_ROLLUP_COMPATIBLE;
   }
 
   @Override
diff --git a/core/src/main/java/org/apache/druid/timeline/partition/HashBasedNumberedShardSpec.java b/core/src/main/java/org/apache/druid/timeline/partition/HashBasedNumberedShardSpec.java
index a1ec5a9..f564c2f 100644
--- a/core/src/main/java/org/apache/druid/timeline/partition/HashBasedNumberedShardSpec.java
+++ b/core/src/main/java/org/apache/druid/timeline/partition/HashBasedNumberedShardSpec.java
@@ -47,7 +47,7 @@
 
 public class HashBasedNumberedShardSpec extends NumberedShardSpec
 {
-  static final List<String> DEFAULT_PARTITION_DIMENSIONS = ImmutableList.of();
+  public static final List<String> DEFAULT_PARTITION_DIMENSIONS = ImmutableList.of();
 
   private static final HashFunction HASH_FUNCTION = Hashing.murmur3_32();
 
@@ -159,8 +159,7 @@
     }
   }
 
-  @VisibleForTesting
-  static List<Object> getGroupKey(final List<String> partitionDimensions, final long timestamp, final InputRow inputRow)
+  public static List<Object> getGroupKey(final List<String> partitionDimensions, final long timestamp, final InputRow inputRow)
   {
     if (partitionDimensions.isEmpty()) {
       return Rows.toGroupKey(timestamp, inputRow);
diff --git a/docs/ingestion/native-batch.md b/docs/ingestion/native-batch.md
index 9b71825..5a58a7f 100644
--- a/docs/ingestion/native-batch.md
+++ b/docs/ingestion/native-batch.md
@@ -294,11 +294,18 @@
 |property|description|default|required?|
 |--------|-----------|-------|---------|
 |type|This should always be `hashed`|none|yes|
-|numShards|Directly specify the number of shards to create. If this is specified and `intervals` is specified in the `granularitySpec`, the index task can skip the determine intervals/partitions pass through the data.|null|yes|
+|numShards|Directly specify the number of shards to create. If this is specified and `intervals` is specified in the `granularitySpec`, the index task can skip the determine intervals/partitions pass through the data. This property and `targetRowsPerSegment` cannot both be set.|null|no|
 |partitionDimensions|The dimensions to partition on. Leave blank to select all dimensions.|null|no|
+|targetRowsPerSegment|A target row count for each partition. If `numShards` is left unspecified, the Parallel task will determine a partition count automatically such that each partition has a row count close to the target, assuming evenly distributed keys in the input data. A target per-segment row count of 5 million is used if both `numShards` and `targetRowsPerSegment` are null. |null (or 5,000,000 if both `numShards` and `targetRowsPerSegment` are null)|no|
 
 The Parallel task with hash-based partitioning is similar to [MapReduce](https://en.wikipedia.org/wiki/MapReduce).
-The task runs in 2 phases, i.e., `partial segment generation` and `partial segment merge`.
+The task runs in up to 3 phases: `partial dimension cardinality`, `partial segment generation`, and `partial segment merge`.
+- The `partial dimension cardinality` phase is an optional phase that only runs if `numShards` is not specified.
+The Parallel task splits the input data based on the split hint spec and assigns each split to a worker task.
+Each worker task (type `partial_dimension_cardinality`) gathers a cardinality estimate of the partitioning dimensions for
+each time chunk. The Parallel task aggregates these estimates from the worker tasks, determines the highest
+cardinality across all of the time chunks in the input data, and divides this cardinality by `targetRowsPerSegment` to
+automatically determine `numShards`.
 - In the `partial segment generation` phase, just like the Map phase in MapReduce,
 the Parallel task splits the input data based on the split hint spec
 and assigns each split to a worker task. Each worker task (type `partial_index_generate`) reads the assigned split,
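A rough, standalone illustration (in Java, not part of the patch; the values are hypothetical) of the `numShards` arithmetic described in the docs hunk above: the supervisor effectively performs a ceiling division of the highest per-time-chunk cardinality estimate by the target row count.

```java
public class NumShardsArithmeticSketch
{
  public static void main(String[] args)
  {
    // Hypothetical highest per-time-chunk cardinality estimate gathered by the worker tasks.
    long maxCardinality = 12_000_000L;
    // Documented default target when neither numShards nor targetRowsPerSegment is set.
    int targetRowsPerSegment = 5_000_000;

    long numShards = maxCardinality / targetRowsPerSegment;
    if (maxCardinality % targetRowsPerSegment != 0) {
      numShards += 1; // round up so each shard stays close to (or under) the target
    }
    System.out.println(numShards); // prints 3
  }
}
```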
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/IndexTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/IndexTask.java
index 5e93675..16e03f1 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/IndexTask.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/IndexTask.java
@@ -131,8 +131,9 @@
 
 public class IndexTask extends AbstractBatchIndexTask implements ChatHandler
 {
+  public static final HashFunction HASH_FUNCTION = Hashing.murmur3_128();
+
   private static final Logger log = new Logger(IndexTask.class);
-  private static final HashFunction HASH_FUNCTION = Hashing.murmur3_128();
   private static final String TYPE = "index";
 
   private static String makeGroupId(IndexIngestionSpec ingestionSchema)
@@ -599,7 +600,8 @@
       if (partitionsSpec.getType() == SecondaryPartitionType.HASH) {
         return PartialHashSegmentGenerateTask.createHashPartitionAnalysisFromPartitionsSpec(
             granularitySpec,
-            (HashedPartitionsSpec) partitionsSpec
+            (HashedPartitionsSpec) partitionsSpec,
+            null // not overriding numShards
         );
       } else if (partitionsSpec.getType() == SecondaryPartitionType.LINEAR) {
         return createLinearPartitionAnalysis(granularitySpec, (DynamicPartitionsSpec) partitionsSpec);
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/Task.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/Task.java
index 20a7da6..197d901 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/Task.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/Task.java
@@ -28,6 +28,7 @@
 import org.apache.druid.indexing.common.config.TaskConfig;
 import org.apache.druid.indexing.common.task.batch.parallel.LegacySinglePhaseSubTask;
 import org.apache.druid.indexing.common.task.batch.parallel.ParallelIndexSupervisorTask;
+import org.apache.druid.indexing.common.task.batch.parallel.PartialDimensionCardinalityTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialDimensionDistributionTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialGenericSegmentMergeTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialHashSegmentGenerateTask;
@@ -62,6 +63,7 @@
     // for backward compatibility
     @Type(name = SinglePhaseSubTask.OLD_TYPE_NAME, value = LegacySinglePhaseSubTask.class),
     @Type(name = PartialHashSegmentGenerateTask.TYPE, value = PartialHashSegmentGenerateTask.class),
+    @Type(name = PartialDimensionCardinalityTask.TYPE, value = PartialDimensionCardinalityTask.class),
     @Type(name = PartialRangeSegmentGenerateTask.TYPE, value = PartialRangeSegmentGenerateTask.class),
     @Type(name = PartialDimensionDistributionTask.TYPE, value = PartialDimensionDistributionTask.class),
     @Type(name = PartialGenericSegmentMergeTask.TYPE, value = PartialGenericSegmentMergeTask.class),
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReport.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReport.java
new file mode 100644
index 0000000..83502f5
--- /dev/null
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReport.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common.task.batch.parallel;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.apache.datasketches.hll.HllSketch;
+import org.joda.time.Interval;
+
+import java.util.Map;
+import java.util.Objects;
+
+public class DimensionCardinalityReport implements SubTaskReport
+{
+  // We choose logK=11 because the following link shows that HllSketch with K=2048 has roughly the same
+  // serialized size as HyperLogLogCollector.
+  // http://datasketches.apache.org/docs/HLL/HllSketchVsDruidHyperLogLogCollector.html
+  public static final int HLL_SKETCH_LOG_K = 11;
+
+  static final String TYPE = "dimension_cardinality";
+  private static final String PROP_CARDINALITIES = "cardinalities";
+
+
+  private final String taskId;
+
+  /**
+   * A map of intervals to byte arrays, representing {@link HllSketch} objects,
+   * serialized using {@link HllSketch#toCompactByteArray()}.
+   *
+   * The HllSketch objects should be created with the HLL_SKETCH_LOG_K constant defined in this class.
+   *
+   * The sketches are used to determine cardinality estimates for each interval.
+   */
+  private final Map<Interval, byte[]> intervalToCardinalities;
+
+  @JsonCreator
+  public DimensionCardinalityReport(
+      @JsonProperty("taskId") String taskId,
+      @JsonProperty(PROP_CARDINALITIES) Map<Interval, byte[]> intervalToCardinalities
+  )
+  {
+    this.taskId = taskId;
+    this.intervalToCardinalities = intervalToCardinalities;
+  }
+
+  @Override
+  @JsonProperty
+  public String getTaskId()
+  {
+    return taskId;
+  }
+
+  @JsonProperty(PROP_CARDINALITIES)
+  public Map<Interval, byte[]> getIntervalToCardinalities()
+  {
+    return intervalToCardinalities;
+  }
+
+  @Override
+  public String toString()
+  {
+    return "DimensionCardinalityReport{" +
+           "taskId='" + taskId + '\'' +
+           ", intervalToCardinalities=" + intervalToCardinalities +
+           '}';
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    DimensionCardinalityReport that = (DimensionCardinalityReport) o;
+    return Objects.equals(getTaskId(), that.getTaskId()) &&
+           Objects.equals(getIntervalToCardinalities(), that.getIntervalToCardinalities());
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(getTaskId(), getIntervalToCardinalities());
+  }
+
+  public static HllSketch createHllSketchForReport()
+  {
+    return new HllSketch(HLL_SKETCH_LOG_K);
+  }
+}
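For context, a minimal, self-contained sketch of the round trip that produces and consumes the byte arrays in this report, assuming only the DataSketches calls already visible in the diff (`HllSketch#toCompactByteArray` and `HllSketch#wrap`); the class and variable names are illustrative.

```java
import java.nio.charset.StandardCharsets;

import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.memory.Memory;

public class HllSketchRoundTripSketch
{
  public static void main(String[] args)
  {
    // Same logK the report uses (HLL_SKETCH_LOG_K == 11).
    HllSketch sketch = new HllSketch(11);
    sketch.update("group-key-1".getBytes(StandardCharsets.UTF_8));
    sketch.update("group-key-2".getBytes(StandardCharsets.UTF_8));

    // Serialize for the report payload...
    byte[] bytes = sketch.toCompactByteArray();

    // ...and read it back the way the supervisor does.
    HllSketch readBack = HllSketch.wrap(Memory.wrap(bytes));
    System.out.println((long) readBack.getEstimate()); // ~2
  }
}
```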
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTask.java
index 7b72895..dd0e759 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTask.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTask.java
@@ -30,11 +30,15 @@
 import com.google.common.collect.Multimap;
 import it.unimi.dsi.fastutil.objects.Object2IntMap;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.datasketches.hll.Union;
+import org.apache.datasketches.memory.Memory;
 import org.apache.druid.data.input.FiniteFirehoseFactory;
 import org.apache.druid.data.input.InputFormat;
 import org.apache.druid.data.input.InputSource;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.partitions.HashedPartitionsSpec;
 import org.apache.druid.indexer.partitions.PartitionsSpec;
 import org.apache.druid.indexer.partitions.SingleDimensionPartitionsSpec;
 import org.apache.druid.indexing.common.Counters;
@@ -271,14 +275,30 @@
   }
 
   @VisibleForTesting
-  PartialHashSegmentGenerateParallelIndexTaskRunner createPartialHashSegmentGenerateRunner(TaskToolbox toolbox)
+  PartialDimensionCardinalityParallelIndexTaskRunner createPartialDimensionCardinalityRunner(TaskToolbox toolbox)
+  {
+    return new PartialDimensionCardinalityParallelIndexTaskRunner(
+        toolbox,
+        getId(),
+        getGroupId(),
+        ingestionSchema,
+        getContext()
+    );
+  }
+
+  @VisibleForTesting
+  PartialHashSegmentGenerateParallelIndexTaskRunner createPartialHashSegmentGenerateRunner(
+      TaskToolbox toolbox,
+      Integer numShardsOverride
+  )
   {
     return new PartialHashSegmentGenerateParallelIndexTaskRunner(
         toolbox,
         getId(),
         getGroupId(),
         ingestionSchema,
-        getContext()
+        getContext(),
+        numShardsOverride
     );
   }
 
@@ -499,17 +519,67 @@
 
   private TaskStatus runHashPartitionMultiPhaseParallel(TaskToolbox toolbox) throws Exception
   {
-    // 1. Partial segment generation phase
-    ParallelIndexTaskRunner<PartialHashSegmentGenerateTask, GeneratedPartitionsReport<GenericPartitionStat>> indexingRunner
-        = createRunner(toolbox, this::createPartialHashSegmentGenerateRunner);
+    TaskState state;
 
-    TaskState state = runNextPhase(indexingRunner);
+    if (!(ingestionSchema.getTuningConfig().getPartitionsSpec() instanceof HashedPartitionsSpec)) {
+      // only range and hash partitioning are supported for multi-phase parallel ingestion, see runMultiPhaseParallel()
+      throw new ISE(
+          "forceGuaranteedRollup is set but partitionsSpec [%s] is not a single_dim or hash partition spec.",
+          ingestionSchema.getTuningConfig().getPartitionsSpec()
+      );
+    }
+
+    final Integer numShardsOverride;
+    HashedPartitionsSpec partitionsSpec = (HashedPartitionsSpec) ingestionSchema.getTuningConfig().getPartitionsSpec();
+    if (partitionsSpec.getNumShards() == null) {
+      // 0. need to determine numShards by scanning the data
+      LOG.info("numShards is unspecified, beginning %s phase.", PartialDimensionCardinalityTask.TYPE);
+      ParallelIndexTaskRunner<PartialDimensionCardinalityTask, DimensionCardinalityReport> cardinalityRunner =
+          createRunner(
+              toolbox,
+              this::createPartialDimensionCardinalityRunner
+          );
+
+      if (cardinalityRunner == null) {
+        throw new ISE("Could not create cardinality runner for hash partitioning.");
+      }
+
+      state = runNextPhase(cardinalityRunner);
+      if (state.isFailure()) {
+        return TaskStatus.failure(getId());
+      }
+
+      int effectiveMaxRowsPerSegment = partitionsSpec.getMaxRowsPerSegment() == null
+                                       ? PartitionsSpec.DEFAULT_MAX_ROWS_PER_SEGMENT
+                                       : partitionsSpec.getMaxRowsPerSegment();
+      LOG.info("effective maxRowsPerSegment is: " + effectiveMaxRowsPerSegment);
+
+      if (cardinalityRunner.getReports() == null) {
+        throw new ISE("Could not determine cardinalities for hash partitioning.");
+      }
+      numShardsOverride = determineNumShardsFromCardinalityReport(
+          cardinalityRunner.getReports().values(),
+          effectiveMaxRowsPerSegment
+      );
+
+      LOG.info("Automatically determined numShards: " + numShardsOverride);
+    } else {
+      numShardsOverride = null;
+    }
+
+    // 1. Partial segment generation phase
+    ParallelIndexTaskRunner<PartialHashSegmentGenerateTask, GeneratedPartitionsReport<GenericPartitionStat>> indexingRunner =
+        createRunner(
+            toolbox,
+            f -> createPartialHashSegmentGenerateRunner(toolbox, numShardsOverride)
+        );
+
+    state = runNextPhase(indexingRunner);
     if (state.isFailure()) {
       return TaskStatus.failure(getId());
     }
 
     // 2. Partial segment merge phase
-
     // partition (interval, partitionId) -> partition locations
     Map<Pair<Interval, Integer>, List<GenericPartitionLocation>> partitionToLocations =
         groupGenericPartitionLocationsPerPartition(indexingRunner.getReports());
@@ -582,6 +652,50 @@
     return TaskStatus.fromCode(getId(), mergeState);
   }
 
+  @VisibleForTesting
+  public static int determineNumShardsFromCardinalityReport(
+      Collection<DimensionCardinalityReport> reports,
+      int maxRowsPerSegment
+  )
+  {
+    // aggregate all the sub-reports
+    Map<Interval, Union> finalCollectors = new HashMap<>();
+    reports.forEach(report -> {
+      Map<Interval, byte[]> intervalToCardinality = report.getIntervalToCardinalities();
+      for (Map.Entry<Interval, byte[]> entry : intervalToCardinality.entrySet()) {
+        Union union = finalCollectors.computeIfAbsent(
+            entry.getKey(),
+            (key) -> {
+              return new Union(DimensionCardinalityReport.HLL_SKETCH_LOG_K);
+            }
+        );
+        HllSketch entryHll = HllSketch.wrap(Memory.wrap(entry.getValue()));
+        union.update(entryHll);
+      }
+    });
+
+    // determine the highest cardinality in any interval
+    long maxCardinality = 0;
+    for (Union union : finalCollectors.values()) {
+      maxCardinality = Math.max(maxCardinality, (long) union.getEstimate());
+    }
+
+    LOG.info("Estimated max cardinality: " + maxCardinality);
+
+    // determine numShards based on maxRowsPerSegment and the highest per-interval cardinality
+    long numShards = maxCardinality / maxRowsPerSegment;
+    if (maxCardinality % maxRowsPerSegment != 0) {
+      // if there's a remainder, add 1 so we stay under maxRowsPerSegment
+      numShards += 1;
+    }
+    try {
+      return Math.toIntExact(numShards);
+    }
+    catch (ArithmeticException ae) {
+      throw new ISE("Estimated numShards [%s] exceeds integer bounds.", numShards);
+    }
+  }
+
   private Map<Interval, PartitionBoundaries> determineAllRangePartitions(Collection<DimensionDistributionReport> reports)
   {
     Multimap<Interval, StringDistribution> intervalToDistributions = ArrayListMultimap.create();
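A standalone sketch of `determineNumShardsFromCardinalityReport`'s merge-and-divide flow, using the same DataSketches `Union` calls shown above; the task names, inputs, and `maxRowsPerSegment` value are hypothetical.

```java
import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.hll.Union;
import org.apache.datasketches.memory.Memory;

public class MergeCardinalityReportsSketch
{
  public static void main(String[] args)
  {
    // Two per-task sketches for the same time chunk, serialized as they would be in sub-task reports.
    HllSketch taskA = new HllSketch(11);
    taskA.update(1L);
    taskA.update(2L);

    HllSketch taskB = new HllSketch(11);
    taskB.update(2L);
    taskB.update(3L);

    // Merge the serialized sketches per interval, as the supervisor does.
    Union union = new Union(11);
    union.update(HllSketch.wrap(Memory.wrap(taskA.toCompactByteArray())));
    union.update(HllSketch.wrap(Memory.wrap(taskB.toCompactByteArray())));

    long estimatedCardinality = (long) union.getEstimate(); // ~3 distinct group keys
    int maxRowsPerSegment = 2;                              // hypothetical target

    long numShards = estimatedCardinality / maxRowsPerSegment;
    if (estimatedCardinality % maxRowsPerSegment != 0) {
      numShards += 1;
    }
    System.out.println(numShards); // 2
  }
}
```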
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityParallelIndexTaskRunner.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityParallelIndexTaskRunner.java
new file mode 100644
index 0000000..25b0c0c
--- /dev/null
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityParallelIndexTaskRunner.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common.task.batch.parallel;
+
+import org.apache.druid.data.input.InputSplit;
+import org.apache.druid.indexing.common.TaskToolbox;
+
+import java.util.Map;
+
+/**
+ * {@link ParallelIndexTaskRunner} for the phase to determine cardinalities of dimension values in
+ * multi-phase parallel indexing.
+ */
+class PartialDimensionCardinalityParallelIndexTaskRunner
+    extends InputSourceSplitParallelIndexTaskRunner<PartialDimensionCardinalityTask, DimensionCardinalityReport>
+{
+  private static final String PHASE_NAME = "partial dimension cardinality";
+
+  PartialDimensionCardinalityParallelIndexTaskRunner(
+      TaskToolbox toolbox,
+      String taskId,
+      String groupId,
+      ParallelIndexIngestionSpec ingestionSchema,
+      Map<String, Object> context
+  )
+  {
+    super(
+        toolbox,
+        taskId,
+        groupId,
+        ingestionSchema,
+        context
+    );
+  }
+
+  @Override
+  public String getName()
+  {
+    return PHASE_NAME;
+  }
+
+  @Override
+  SubTaskSpec<PartialDimensionCardinalityTask> createSubTaskSpec(
+      String id,
+      String groupId,
+      String supervisorTaskId,
+      Map<String, Object> context,
+      InputSplit split,
+      ParallelIndexIngestionSpec subTaskIngestionSpec
+  )
+  {
+    return new SubTaskSpec<PartialDimensionCardinalityTask>(
+        id,
+        groupId,
+        supervisorTaskId,
+        context,
+        split
+    )
+    {
+      @Override
+      public PartialDimensionCardinalityTask newSubTask(int numAttempts)
+      {
+        return new PartialDimensionCardinalityTask(
+            null,
+            getGroupId(),
+            null,
+            getSupervisorTaskId(),
+            numAttempts,
+            subTaskIngestionSpec,
+            getContext(),
+            getToolbox().getJsonMapper()
+        );
+      }
+    };
+  }
+}
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTask.java
new file mode 100644
index 0000000..6ba5ad3
--- /dev/null
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTask.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common.task.batch.parallel;
+
+import com.fasterxml.jackson.annotation.JacksonInject;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Preconditions;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.druid.data.input.InputFormat;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.InputSource;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.partitions.HashedPartitionsSpec;
+import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.indexing.common.actions.TaskActionClient;
+import org.apache.druid.indexing.common.task.AbstractBatchIndexTask;
+import org.apache.druid.indexing.common.task.ClientBasedTaskInfoProvider;
+import org.apache.druid.indexing.common.task.TaskResource;
+import org.apache.druid.java.util.common.granularity.Granularity;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.java.util.common.parsers.CloseableIterator;
+import org.apache.druid.segment.incremental.ParseExceptionHandler;
+import org.apache.druid.segment.incremental.RowIngestionMeters;
+import org.apache.druid.segment.indexing.DataSchema;
+import org.apache.druid.segment.indexing.granularity.GranularitySpec;
+import org.apache.druid.timeline.partition.HashBasedNumberedShardSpec;
+import org.joda.time.DateTime;
+import org.joda.time.Interval;
+
+import javax.annotation.Nullable;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class PartialDimensionCardinalityTask extends PerfectRollupWorkerTask
+{
+  public static final String TYPE = "partial_dimension_cardinality";
+  private static final Logger LOG = new Logger(PartialDimensionCardinalityTask.class);
+
+  private final int numAttempts;
+  private final ParallelIndexIngestionSpec ingestionSchema;
+  private final String supervisorTaskId;
+
+  private final ObjectMapper jsonMapper;
+
+  @JsonCreator
+  PartialDimensionCardinalityTask(
+      // id shouldn't be null except when this task is created by ParallelIndexSupervisorTask
+      @JsonProperty("id") @Nullable String id,
+      @JsonProperty("groupId") final String groupId,
+      @JsonProperty("resource") final TaskResource taskResource,
+      @JsonProperty("supervisorTaskId") final String supervisorTaskId,
+      @JsonProperty("numAttempts") final int numAttempts, // zero-based counting
+      @JsonProperty("spec") final ParallelIndexIngestionSpec ingestionSchema,
+      @JsonProperty("context") final Map<String, Object> context,
+      @JacksonInject ObjectMapper jsonMapper
+  )
+  {
+    super(
+        getOrMakeId(id, TYPE, ingestionSchema.getDataSchema().getDataSource()),
+        groupId,
+        taskResource,
+        ingestionSchema.getDataSchema(),
+        ingestionSchema.getTuningConfig(),
+        context
+    );
+
+    Preconditions.checkArgument(
+        ingestionSchema.getTuningConfig().getPartitionsSpec() instanceof HashedPartitionsSpec,
+        "%s partitionsSpec required",
+        HashedPartitionsSpec.NAME
+    );
+
+    this.numAttempts = numAttempts;
+    this.ingestionSchema = ingestionSchema;
+    this.supervisorTaskId = supervisorTaskId;
+    this.jsonMapper = jsonMapper;
+  }
+
+  @JsonProperty
+  private int getNumAttempts()
+  {
+    return numAttempts;
+  }
+
+  @JsonProperty("spec")
+  private ParallelIndexIngestionSpec getIngestionSchema()
+  {
+    return ingestionSchema;
+  }
+
+  @JsonProperty
+  private String getSupervisorTaskId()
+  {
+    return supervisorTaskId;
+  }
+
+  @Override
+  public String getType()
+  {
+    return TYPE;
+  }
+
+  @Override
+  public boolean isReady(TaskActionClient taskActionClient) throws Exception
+  {
+    return tryTimeChunkLock(
+        taskActionClient,
+        getIngestionSchema().getDataSchema().getGranularitySpec().inputIntervals()
+    );
+  }
+
+  @Override
+  public TaskStatus runTask(TaskToolbox toolbox) throws Exception
+  {
+    DataSchema dataSchema = ingestionSchema.getDataSchema();
+    GranularitySpec granularitySpec = dataSchema.getGranularitySpec();
+    ParallelIndexTuningConfig tuningConfig = ingestionSchema.getTuningConfig();
+
+    HashedPartitionsSpec partitionsSpec = (HashedPartitionsSpec) tuningConfig.getPartitionsSpec();
+    Preconditions.checkNotNull(partitionsSpec, "partitionsSpec required in tuningConfig");
+
+    List<String> partitionDimensions = partitionsSpec.getPartitionDimensions();
+    if (partitionDimensions == null) {
+      partitionDimensions = HashBasedNumberedShardSpec.DEFAULT_PARTITION_DIMENSIONS;
+    }
+
+    InputSource inputSource = ingestionSchema.getIOConfig().getNonNullInputSource(
+        ingestionSchema.getDataSchema().getParser()
+    );
+    InputFormat inputFormat = inputSource.needsFormat()
+                              ? ParallelIndexSupervisorTask.getInputFormat(ingestionSchema)
+                              : null;
+    final RowIngestionMeters buildSegmentsMeters = toolbox.getRowIngestionMetersFactory().createRowIngestionMeters();
+    final ParseExceptionHandler parseExceptionHandler = new ParseExceptionHandler(
+        buildSegmentsMeters,
+        tuningConfig.isLogParseExceptions(),
+        tuningConfig.getMaxParseExceptions(),
+        tuningConfig.getMaxSavedParseExceptions()
+    );
+
+    try (
+        final CloseableIterator<InputRow> inputRowIterator = AbstractBatchIndexTask.inputSourceReader(
+            toolbox.getIndexingTmpDir(),
+            dataSchema,
+            inputSource,
+            inputFormat,
+            AbstractBatchIndexTask.defaultRowFilter(granularitySpec),
+            buildSegmentsMeters,
+            parseExceptionHandler
+        );
+    ) {
+      Map<Interval, byte[]> cardinalities = determineCardinalities(
+          inputRowIterator,
+          granularitySpec,
+          partitionDimensions
+      );
+
+      sendReport(
+          toolbox,
+          new DimensionCardinalityReport(getId(), cardinalities)
+      );
+    }
+
+    return TaskStatus.success(getId());
+  }
+
+  private Map<Interval, byte[]> determineCardinalities(
+      CloseableIterator<InputRow> inputRowIterator,
+      GranularitySpec granularitySpec,
+      List<String> partitionDimensions
+  )
+  {
+    Map<Interval, HllSketch> intervalToCardinalities = new HashMap<>();
+    while (inputRowIterator.hasNext()) {
+      InputRow inputRow = inputRowIterator.next();
+      //noinspection ConstantConditions (null rows are filtered out by FilteringCloseableInputRowIterator)
+      DateTime timestamp = inputRow.getTimestamp();
+      //noinspection OptionalGetWithoutIsPresent (InputRowIterator returns rows with present intervals)
+      Interval interval = granularitySpec.bucketInterval(timestamp).get();
+      Granularity queryGranularity = granularitySpec.getQueryGranularity();
+
+      HllSketch hllSketch = intervalToCardinalities.computeIfAbsent(
+          interval,
+          (intervalKey) -> {
+            return DimensionCardinalityReport.createHllSketchForReport();
+          }
+      );
+      List<Object> groupKey = HashBasedNumberedShardSpec.getGroupKey(
+          partitionDimensions,
+          queryGranularity.bucketStart(timestamp).getMillis(),
+          inputRow
+      );
+
+      try {
+        hllSketch.update(
+            jsonMapper.writeValueAsBytes(groupKey)
+        );
+      }
+      catch (JsonProcessingException jpe) {
+        throw new RuntimeException(jpe);
+      }
+    }
+
+    // Serialize the sketches for sending to the supervisor task
+    Map<Interval, byte[]> newMap = new HashMap<>();
+    for (Map.Entry<Interval, HllSketch> entry : intervalToCardinalities.entrySet()) {
+      newMap.put(entry.getKey(), entry.getValue().toCompactByteArray());
+    }
+    return newMap;
+  }
+
+  private void sendReport(TaskToolbox toolbox, DimensionCardinalityReport report)
+  {
+    final ParallelIndexSupervisorTaskClient taskClient = toolbox.getSupervisorTaskClientFactory().build(
+        new ClientBasedTaskInfoProvider(toolbox.getIndexingServiceClient()),
+        getId(),
+        1, // always use a single http thread
+        ingestionSchema.getTuningConfig().getChatHandlerTimeout(),
+        ingestionSchema.getTuningConfig().getChatHandlerNumRetries()
+    );
+    taskClient.report(supervisorTaskId, report);
+  }
+
+}
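A minimal sketch of what `determineCardinalities` feeds into each `HllSketch`: the JSON-serialized group key for every row, so identical keys are counted once. The hand-built lists below stand in for `HashBasedNumberedShardSpec.getGroupKey` output (bucketed timestamp plus partition-dimension values); they are illustrative, not taken from the patch.

```java
import java.util.Arrays;
import java.util.List;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.datasketches.hll.HllSketch;

public class GroupKeyCardinalitySketch
{
  public static void main(String[] args) throws Exception
  {
    ObjectMapper jsonMapper = new ObjectMapper();
    HllSketch sketch = new HllSketch(11);

    // Stand-ins for getGroupKey results: bucketed timestamp followed by dimension values.
    List<List<Object>> groupKeys = Arrays.asList(
        Arrays.asList(0L, "dim1-a", "dim2-x"),
        Arrays.asList(0L, "dim1-a", "dim2-x"), // duplicate key, counted only once
        Arrays.asList(0L, "dim1-b", "dim2-y")
    );

    for (List<Object> groupKey : groupKeys) {
      sketch.update(jsonMapper.writeValueAsBytes(groupKey));
    }
    System.out.println((long) sketch.getEstimate()); // ~2 distinct group keys
  }
}
```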
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateParallelIndexTaskRunner.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateParallelIndexTaskRunner.java
index 17d3b36..39024a1 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateParallelIndexTaskRunner.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateParallelIndexTaskRunner.java
@@ -32,15 +32,19 @@
 {
   private static final String PHASE_NAME = "partial segment generation";
 
+  private Integer numShardsOverride;
+
   PartialHashSegmentGenerateParallelIndexTaskRunner(
       TaskToolbox toolbox,
       String taskId,
       String groupId,
       ParallelIndexIngestionSpec ingestionSchema,
-      Map<String, Object> context
+      Map<String, Object> context,
+      Integer numShardsOverride
   )
   {
     super(toolbox, taskId, groupId, ingestionSchema, context);
+    this.numShardsOverride = numShardsOverride;
   }
 
   @Override
@@ -77,7 +81,8 @@
             supervisorTaskId,
             numAttempts,
             subTaskIngestionSpec,
-            context
+            context,
+            numShardsOverride
         );
       }
     };
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTask.java
index 98dae99..ff0090c 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTask.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTask.java
@@ -56,6 +56,7 @@
   private final int numAttempts;
   private final ParallelIndexIngestionSpec ingestionSchema;
   private final String supervisorTaskId;
+  private final Integer numShardsOverride;
 
   @JsonCreator
   public PartialHashSegmentGenerateTask(
@@ -66,7 +67,8 @@
       @JsonProperty("supervisorTaskId") final String supervisorTaskId,
       @JsonProperty("numAttempts") final int numAttempts, // zero-based counting
       @JsonProperty(PROP_SPEC) final ParallelIndexIngestionSpec ingestionSchema,
-      @JsonProperty("context") final Map<String, Object> context
+      @JsonProperty("context") final Map<String, Object> context,
+      @Nullable @JsonProperty("numShardsOverride") final Integer numShardsOverride
   )
   {
     super(
@@ -82,6 +84,7 @@
     this.numAttempts = numAttempts;
     this.ingestionSchema = ingestionSchema;
     this.supervisorTaskId = supervisorTaskId;
+    this.numShardsOverride = numShardsOverride;
   }
 
   @JsonProperty
@@ -130,7 +133,7 @@
         getId(),
         granularitySpec,
         new SupervisorTaskAccess(supervisorTaskId, taskClient),
-        createHashPartitionAnalysisFromPartitionsSpec(granularitySpec, partitionsSpec)
+        createHashPartitionAnalysisFromPartitionsSpec(granularitySpec, partitionsSpec, numShardsOverride)
     );
   }
 
@@ -165,13 +168,21 @@
    */
   public static HashPartitionAnalysis createHashPartitionAnalysisFromPartitionsSpec(
       GranularitySpec granularitySpec,
-      @Nonnull HashedPartitionsSpec partitionsSpec
+      @Nonnull HashedPartitionsSpec partitionsSpec,
+      @Nullable Integer numShardsOverride
   )
   {
     final SortedSet<Interval> intervals = granularitySpec.bucketIntervals().get();
-    final int numBucketsPerInterval = partitionsSpec.getNumShards() == null
-                                      ? 1
-                                      : partitionsSpec.getNumShards();
+
+    final int numBucketsPerInterval;
+    if (numShardsOverride != null) {
+      numBucketsPerInterval = numShardsOverride;
+    } else {
+      numBucketsPerInterval = partitionsSpec.getNumShards() == null
+                              ? 1
+                              : partitionsSpec.getNumShards();
+    }
+
     final HashPartitionAnalysis partitionAnalysis = new HashPartitionAnalysis(partitionsSpec);
     intervals.forEach(interval -> partitionAnalysis.updateBucket(interval, numBucketsPerInterval));
     return partitionAnalysis;
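As a quick reference for the selection above, a minimal standalone sketch (the helper is hypothetical, not part of the patch) of how the bucket count per interval is chosen: an explicit override from the cardinality phase wins, then the user-specified `numShards`, else 1.

```java
public class NumBucketsPrecedenceSketch
{
  // Mirrors createHashPartitionAnalysisFromPartitionsSpec's selection of numBucketsPerInterval.
  static int numBucketsPerInterval(Integer numShardsOverride, Integer specNumShards)
  {
    if (numShardsOverride != null) {
      return numShardsOverride;
    }
    return specNumShards == null ? 1 : specNumShards;
  }

  public static void main(String[] args)
  {
    System.out.println(numBucketsPerInterval(3, 7));       // 3: override determined by the cardinality phase
    System.out.println(numBucketsPerInterval(null, 7));    // 7: numShards from the partitions spec
    System.out.println(numBucketsPerInterval(null, null)); // 1: fallback when neither is set
  }
}
```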
diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/SubTaskReport.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/SubTaskReport.java
index 26f20f6..51bcbc8 100644
--- a/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/SubTaskReport.java
+++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/SubTaskReport.java
@@ -31,6 +31,7 @@
 @JsonSubTypes(value = {
     @Type(name = PushedSegmentsReport.TYPE, value = PushedSegmentsReport.class),
     @Type(name = DimensionDistributionReport.TYPE, value = DimensionDistributionReport.class),
+    @Type(name = DimensionCardinalityReport.TYPE, value = DimensionCardinalityReport.class),
     @Type(name = GeneratedPartitionsMetadataReport.TYPE, value = GeneratedPartitionsMetadataReport.class)
 })
 public interface SubTaskReport
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java
index 5ef3120..af22160 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java
@@ -535,7 +535,8 @@
         new NamedType(PartialHashSegmentGenerateTask.class, PartialHashSegmentGenerateTask.TYPE),
         new NamedType(PartialRangeSegmentGenerateTask.class, PartialRangeSegmentGenerateTask.TYPE),
         new NamedType(PartialGenericSegmentMergeTask.class, PartialGenericSegmentMergeTask.TYPE),
-        new NamedType(PartialDimensionDistributionTask.class, PartialDimensionDistributionTask.TYPE)
+        new NamedType(PartialDimensionDistributionTask.class, PartialDimensionDistributionTask.TYPE),
+        new NamedType(PartialDimensionCardinalityTask.class, PartialDimensionCardinalityTask.TYPE)
     );
   }
 
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReportTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReportTest.java
new file mode 100644
index 0000000..102b5f8
--- /dev/null
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/DimensionCardinalityReportTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common.task.batch.parallel;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableMap;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.druid.hll.HyperLogLogCollector;
+import org.apache.druid.indexing.common.task.IndexTask;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.segment.TestHelper;
+import org.joda.time.Interval;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+public class DimensionCardinalityReportTest
+{
+  private static final ObjectMapper OBJECT_MAPPER = ParallelIndexTestingFactory.createObjectMapper();
+
+  private DimensionCardinalityReport target;
+
+  @Before
+  public void setup()
+  {
+    Interval interval = Intervals.ETERNITY;
+    HyperLogLogCollector collector = HyperLogLogCollector.makeLatestCollector();
+    Map<Interval, byte[]> intervalToCardinality = Collections.singletonMap(interval, collector.toByteArray());
+    String taskId = "abc";
+    target = new DimensionCardinalityReport(taskId, intervalToCardinality);
+  }
+
+  @Test
+  public void serializesDeserializes()
+  {
+    TestHelper.testSerializesDeserializes(OBJECT_MAPPER, target);
+  }
+
+  @Test
+  public void abidesEqualsContract()
+  {
+    EqualsVerifier.forClass(DimensionCardinalityReport.class)
+                  .usingGetClass()
+                  .verify();
+  }
+
+  @Test
+  public void testSupervisorDetermineNumShardsFromCardinalityReport()
+  {
+    List<DimensionCardinalityReport> reports = new ArrayList<>();
+
+    HllSketch collector1 = DimensionCardinalityReport.createHllSketchForReport();
+    collector1.update(IndexTask.HASH_FUNCTION.hashLong(1L).asBytes());
+    collector1.update(IndexTask.HASH_FUNCTION.hashLong(200L).asBytes());
+    DimensionCardinalityReport report1 = new DimensionCardinalityReport(
+        "taskA",
+        ImmutableMap.of(
+            Intervals.of("1970-01-01T00:00:00.000Z/1970-01-02T00:00:00.000Z"),
+            collector1.toCompactByteArray()
+        )
+    );
+    reports.add(report1);
+
+    HllSketch collector2 = DimensionCardinalityReport.createHllSketchForReport();
+    collector2.update(IndexTask.HASH_FUNCTION.hashLong(1000L).asBytes());
+    collector2.update(IndexTask.HASH_FUNCTION.hashLong(30000L).asBytes());
+    DimensionCardinalityReport report2 = new DimensionCardinalityReport(
+        "taskB",
+        ImmutableMap.of(
+            Intervals.of("1970-01-01T00:00:00.000Z/1970-01-02T00:00:00.000Z"),
+            collector2.toCompactByteArray()
+        )
+    );
+    reports.add(report2);
+
+    // Separate interval with only 1 value
+    HllSketch collector3 = DimensionCardinalityReport.createHllSketchForReport();
+    collector3.update(IndexTask.HASH_FUNCTION.hashLong(99000L).asBytes());
+    DimensionCardinalityReport report3 = new DimensionCardinalityReport(
+        "taskC",
+        ImmutableMap.of(
+            Intervals.of("1970-01-02T00:00:00.000Z/1970-01-03T00:00:00.000Z"),
+            collector3.toCompactByteArray()
+        )
+    );
+    reports.add(report3);
+
+    // first interval in test has cardinality 4
+    int numShards = ParallelIndexSupervisorTask.determineNumShardsFromCardinalityReport(
+        reports,
+        1
+    );
+    Assert.assertEquals(4L, numShards);
+
+    numShards = ParallelIndexSupervisorTask.determineNumShardsFromCardinalityReport(
+        reports,
+        2
+    );
+    Assert.assertEquals(2L, numShards);
+
+    numShards = ParallelIndexSupervisorTask.determineNumShardsFromCardinalityReport(
+        reports,
+        3
+    );
+    Assert.assertEquals(2L, numShards);
+
+    numShards = ParallelIndexSupervisorTask.determineNumShardsFromCardinalityReport(
+        reports,
+        4
+    );
+    Assert.assertEquals(1L, numShards);
+
+    numShards = ParallelIndexSupervisorTask.determineNumShardsFromCardinalityReport(
+        reports,
+        5
+    );
+    Assert.assertEquals(1L, numShards);
+  }
+}
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/HashPartitionMultiPhaseParallelIndexingTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/HashPartitionMultiPhaseParallelIndexingTest.java
index 45f40bd..8bc6c14 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/HashPartitionMultiPhaseParallelIndexingTest.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/HashPartitionMultiPhaseParallelIndexingTest.java
@@ -44,6 +44,7 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.io.Writer;
@@ -83,29 +84,37 @@
   );
   private static final Interval INTERVAL_TO_INDEX = Intervals.of("2017-12/P1M");
 
-  @Parameterized.Parameters(name = "{0}, useInputFormatApi={1}")
+  @Parameterized.Parameters(
+      name = "lockGranularity={0}, useInputFormatApi={1}, maxNumConcurrentSubTasks={2}, numShards={3}"
+  )
   public static Iterable<Object[]> constructorFeeder()
   {
     return ImmutableList.of(
-        new Object[]{LockGranularity.TIME_CHUNK, false, 2},
-        new Object[]{LockGranularity.TIME_CHUNK, true, 2},
-        new Object[]{LockGranularity.TIME_CHUNK, true, 1},
-        new Object[]{LockGranularity.SEGMENT, true, 2}
+        new Object[]{LockGranularity.TIME_CHUNK, false, 2, 2},
+        new Object[]{LockGranularity.TIME_CHUNK, true, 2, 2},
+        new Object[]{LockGranularity.TIME_CHUNK, true, 1, 2},
+        new Object[]{LockGranularity.SEGMENT, true, 2, 2},
+        new Object[]{LockGranularity.TIME_CHUNK, true, 2, null},
+        new Object[]{LockGranularity.TIME_CHUNK, true, 1, null},
+        new Object[]{LockGranularity.SEGMENT, true, 2, null}
     );
   }
 
   private final int maxNumConcurrentSubTasks;
+  private final Integer numShards;
 
   private File inputDir;
 
   public HashPartitionMultiPhaseParallelIndexingTest(
       LockGranularity lockGranularity,
       boolean useInputFormatApi,
-      int maxNumConcurrentSubTasks
+      int maxNumConcurrentSubTasks,
+      @Nullable Integer numShards
   )
   {
     super(lockGranularity, useInputFormatApi);
     this.maxNumConcurrentSubTasks = maxNumConcurrentSubTasks;
+    this.numShards = numShards;
   }
 
   @Before
@@ -135,11 +144,16 @@
   public void testRun() throws Exception
   {
     final Set<DataSegment> publishedSegments = runTestTask(
-        new HashedPartitionsSpec(null, 2, ImmutableList.of("dim1", "dim2")),
+        new HashedPartitionsSpec(null, numShards, ImmutableList.of("dim1", "dim2")),
         TaskState.SUCCESS,
         false
     );
-    assertHashedPartition(publishedSegments);
+
+    // we don't specify maxRowsPerSegment, so it defaults to DEFAULT_MAX_ROWS_PER_SEGMENT (5 million);
+    // the test data is far smaller than that, so expect only 1 shard when numShards is not set.
+    int expectedSegmentCount = numShards != null ? numShards : 1;
+
+    assertHashedPartition(publishedSegments, expectedSegmentCount);
   }
 
   @Test
@@ -148,7 +162,7 @@
     final Set<DataSegment> publishedSegments = new HashSet<>();
     publishedSegments.addAll(
         runTestTask(
-            new HashedPartitionsSpec(null, 2, ImmutableList.of("dim1", "dim2")),
+            new HashedPartitionsSpec(null, numShards, ImmutableList.of("dim1", "dim2")),
             TaskState.SUCCESS,
             false
         )
@@ -235,7 +249,7 @@
     }
   }
 
-  private void assertHashedPartition(Set<DataSegment> publishedSegments) throws IOException
+  private void assertHashedPartition(Set<DataSegment> publishedSegments, int expectedNumSegments) throws IOException
   {
     final Map<Interval, List<DataSegment>> intervalToSegments = new HashMap<>();
     publishedSegments.forEach(
@@ -243,7 +257,7 @@
     );
     final File tempSegmentDir = temporaryFolder.newFolder();
     for (List<DataSegment> segmentsInInterval : intervalToSegments.values()) {
-      Assert.assertEquals(2, segmentsInInterval.size());
+      Assert.assertEquals(expectedNumSegments, segmentsInInterval.size());
       for (DataSegment segment : segmentsInInterval) {
         List<ScanResultValue> results = querySegment(segment, ImmutableList.of("dim1", "dim2"), tempSegmentDir);
         final int hash = HashBasedNumberedShardSpec.hash(getObjectMapper(), (List<Object>) results.get(0).getEvents());
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTaskSerdeTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTaskSerdeTest.java
index faa8002..e6b144f 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTaskSerdeTest.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTaskSerdeTest.java
@@ -21,6 +21,7 @@
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.jsontype.NamedType;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.data.input.impl.CsvInputFormat;
 import org.apache.druid.data.input.impl.DimensionsSpec;
 import org.apache.druid.data.input.impl.LocalInputSource;
@@ -55,6 +56,10 @@
 
 public class ParallelIndexSupervisorTaskSerdeTest
 {
+  static {
+    NullHandling.initializeForTests();
+  }
+
   private static final ObjectMapper OBJECT_MAPPER = createObjectMapper();
   private static final List<Interval> INTERVALS = Collections.singletonList(Intervals.of("2018/2019"));
 
@@ -108,13 +113,8 @@
   @Test
   public void forceGuaranteedRollupWithHashPartitionsMissingNumShards()
   {
-    expectedException.expect(IllegalStateException.class);
-    expectedException.expectMessage(
-        "forceGuaranteedRollup is incompatible with partitionsSpec: numShards must be specified"
-    );
-
     Integer numShards = null;
-    new ParallelIndexSupervisorTaskBuilder()
+    ParallelIndexSupervisorTask task = new ParallelIndexSupervisorTaskBuilder()
         .ingestionSpec(
             new ParallelIndexIngestionSpecBuilder()
                 .forceGuaranteedRollup(true)
@@ -123,6 +123,9 @@
                 .build()
         )
         .build();
+
+    PartitionsSpec partitionsSpec = task.getIngestionSchema().getTuningConfig().getPartitionsSpec();
+    Assert.assertThat(partitionsSpec, CoreMatchers.instanceOf(HashedPartitionsSpec.class));
   }
 
   @Test
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTaskTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTaskTest.java
new file mode 100644
index 0000000..0ad4dde
--- /dev/null
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialDimensionCardinalityTaskTest.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.indexing.common.task.batch.parallel;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import org.apache.datasketches.hll.HllSketch;
+import org.apache.datasketches.memory.Memory;
+import org.apache.druid.client.indexing.NoopIndexingServiceClient;
+import org.apache.druid.data.input.InputFormat;
+import org.apache.druid.data.input.InputSource;
+import org.apache.druid.data.input.impl.InlineInputSource;
+import org.apache.druid.indexer.TaskState;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexer.partitions.DynamicPartitionsSpec;
+import org.apache.druid.indexer.partitions.HashedPartitionsSpec;
+import org.apache.druid.indexer.partitions.PartitionsSpec;
+import org.apache.druid.indexer.partitions.SingleDimensionPartitionsSpec;
+import org.apache.druid.indexing.common.TaskInfoProvider;
+import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.indexing.common.stats.DropwizardRowIngestionMetersFactory;
+import org.apache.druid.indexing.common.task.IndexTaskClientFactory;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.incremental.ParseExceptionHandler;
+import org.apache.druid.segment.indexing.DataSchema;
+import org.apache.druid.segment.indexing.granularity.UniformGranularitySpec;
+import org.apache.druid.testing.junit.LoggerCaptureRule;
+import org.apache.logging.log4j.core.LogEvent;
+import org.easymock.Capture;
+import org.easymock.EasyMock;
+import org.hamcrest.Matchers;
+import org.joda.time.Duration;
+import org.joda.time.Interval;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+@RunWith(Enclosed.class)
+public class PartialDimensionCardinalityTaskTest
+{
+  private static final ObjectMapper OBJECT_MAPPER = ParallelIndexTestingFactory.createObjectMapper();
+  private static final HashedPartitionsSpec HASHED_PARTITIONS_SPEC = HashedPartitionsSpec.defaultSpec();
+
+  public static class ConstructorTest
+  {
+    @Rule
+    public ExpectedException exception = ExpectedException.none();
+
+    @Test
+    public void requiresForceGuaranteedRollup()
+    {
+      exception.expect(IllegalArgumentException.class);
+      exception.expectMessage("forceGuaranteedRollup must be set");
+
+      ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+          .forceGuaranteedRollup(false)
+          .partitionsSpec(new DynamicPartitionsSpec(null, null))
+          .build();
+
+      new PartialDimensionCardinalityTaskBuilder()
+          .tuningConfig(tuningConfig)
+          .build();
+    }
+
+    @Test
+    public void requiresHashedPartitions()
+    {
+      exception.expect(IllegalArgumentException.class);
+      exception.expectMessage("hashed partitionsSpec required");
+
+      PartitionsSpec partitionsSpec = new SingleDimensionPartitionsSpec(null, 1, "a", false);
+      ParallelIndexTuningConfig tuningConfig =
+          new ParallelIndexTestingFactory.TuningConfigBuilder().partitionsSpec(partitionsSpec).build();
+
+      new PartialDimensionCardinalityTaskBuilder()
+          .tuningConfig(tuningConfig)
+          .build();
+    }
+
+    @Test
+    public void requiresGranularitySpecInputIntervals()
+    {
+      exception.expect(IllegalArgumentException.class);
+      exception.expectMessage("Missing intervals in granularitySpec");
+
+      DataSchema dataSchema = ParallelIndexTestingFactory.createDataSchema(Collections.emptyList());
+
+      new PartialDimensionCardinalityTaskBuilder()
+          .dataSchema(dataSchema)
+          .build();
+    }
+
+    @Test
+    public void serializesDeserializes()
+    {
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .build();
+      TestHelper.testSerializesDeserializes(OBJECT_MAPPER, task);
+    }
+
+    @Test
+    public void hasCorrectPrefixForAutomaticId()
+    {
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .id(ParallelIndexTestingFactory.AUTOMATIC_ID)
+          .build();
+      Assert.assertThat(task.getId(), Matchers.startsWith(PartialDimensionCardinalityTask.TYPE));
+    }
+  }
+
+  public static class RunTaskTest
+  {
+    @Rule
+    public ExpectedException exception = ExpectedException.none();
+
+    @Rule
+    public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    @Rule
+    public LoggerCaptureRule logger = new LoggerCaptureRule(ParseExceptionHandler.class);
+
+    private Capture<SubTaskReport> reportCapture;
+    private TaskToolbox taskToolbox;
+
+    @Before
+    public void setup()
+    {
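+      // Capture the sub-task report that the task sends to the supervisor via the mocked task client.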
+      reportCapture = Capture.newInstance();
+      ParallelIndexSupervisorTaskClient taskClient = EasyMock.mock(ParallelIndexSupervisorTaskClient.class);
+      taskClient.report(EasyMock.eq(ParallelIndexTestingFactory.SUPERVISOR_TASK_ID), EasyMock.capture(reportCapture));
+      EasyMock.replay(taskClient);
+      taskToolbox = EasyMock.mock(TaskToolbox.class);
+      EasyMock.expect(taskToolbox.getIndexingTmpDir()).andStubReturn(temporaryFolder.getRoot());
+      EasyMock.expect(taskToolbox.getSupervisorTaskClientFactory()).andReturn(
+          new IndexTaskClientFactory<ParallelIndexSupervisorTaskClient>()
+          {
+            @Override
+            public ParallelIndexSupervisorTaskClient build(
+                TaskInfoProvider taskInfoProvider,
+                String callerId,
+                int numThreads,
+                Duration httpTimeout,
+                long numRetries
+            )
+            {
+              return taskClient;
+            }
+          }
+      );
+      EasyMock.expect(taskToolbox.getIndexingServiceClient()).andReturn(new NoopIndexingServiceClient());
+      EasyMock.expect(taskToolbox.getRowIngestionMetersFactory()).andReturn(new DropwizardRowIngestionMetersFactory());
+      EasyMock.replay(taskToolbox);
+    }
+
+    @Test
+    public void requiresPartitionDimension() throws Exception
+    {
+      exception.expect(IllegalArgumentException.class);
+      exception.expectMessage("partitionDimension must be specified");
+
+      ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+          .partitionsSpec(
+              new ParallelIndexTestingFactory.SingleDimensionPartitionsSpecBuilder().partitionDimension(null).build()
+          )
+          .build();
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .tuningConfig(tuningConfig)
+          .build();
+
+      task.runTask(taskToolbox);
+    }
+
+    @Test
+    public void logsParseExceptionsIfEnabled() throws Exception
+    {
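+      // An out-of-range timestamp makes the row unparseable, which should be logged when logParseExceptions is enabled.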
+      long invalidTimestamp = Long.MAX_VALUE;
+      InputSource inlineInputSource = new InlineInputSource(
+          ParallelIndexTestingFactory.createRow(invalidTimestamp, "a")
+      );
+      ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+          .partitionsSpec(HASHED_PARTITIONS_SPEC)
+          .logParseExceptions(true)
+          .build();
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .inputSource(inlineInputSource)
+          .tuningConfig(tuningConfig)
+          .build();
+
+      task.runTask(taskToolbox);
+
+      List<LogEvent> logEvents = logger.getLogEvents();
+      Assert.assertEquals(1, logEvents.size());
+      String logMessage = logEvents.get(0).getMessage().getFormattedMessage();
+      Assert.assertThat(logMessage, Matchers.containsString("Encountered parse exception"));
+    }
+
+    @Test
+    public void doesNotLogParseExceptionsIfDisabled() throws Exception
+    {
+      ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+          .partitionsSpec(HASHED_PARTITIONS_SPEC)
+          .logParseExceptions(false)
+          .build();
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .tuningConfig(tuningConfig)
+          .build();
+
+      task.runTask(taskToolbox);
+
+      Assert.assertEquals(Collections.emptyList(), logger.getLogEvents());
+    }
+
+    @Test
+    public void failsWhenTooManyParseExceptions() throws Exception
+    {
+      ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+          .partitionsSpec(HASHED_PARTITIONS_SPEC)
+          .maxParseExceptions(0)
+          .build();
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .tuningConfig(tuningConfig)
+          .build();
+
+      exception.expect(RuntimeException.class);
+      exception.expectMessage("Max parse exceptions[0] exceeded");
+
+      task.runTask(taskToolbox);
+    }
+
+    @Test
+    public void sendsCorrectReportWhenRowHasMultipleDimensionValues()
+    {
+      InputSource inlineInputSource = new InlineInputSource(
+          ParallelIndexTestingFactory.createRow(0, Arrays.asList("a", "b"))
+      );
+      PartialDimensionCardinalityTaskBuilder taskBuilder = new PartialDimensionCardinalityTaskBuilder()
+          .inputSource(inlineInputSource);
+
+      DimensionCardinalityReport report = runTask(taskBuilder);
+
+      Assert.assertEquals(ParallelIndexTestingFactory.ID, report.getTaskId());
+      Map<Interval, byte[]> intervalToCardinalities = report.getIntervalToCardinalities();
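+      // The single input row yields one group key, so the sketch's cardinality estimate should be 1.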
+      byte[] hllSketchBytes = Iterables.getOnlyElement(intervalToCardinalities.values());
+      HllSketch hllSketch = HllSketch.wrap(Memory.wrap(hllSketchBytes));
+      Assert.assertNotNull(hllSketch);
+      Assert.assertEquals(1L, (long) hllSketch.getEstimate());
+    }
+
+    @Test
+    public void sendsCorrectReportWithMultipleIntervalsInData()
+    {
+      // Segment granularity is DAY, query granularity is HOUR
+      InputSource inlineInputSource = new InlineInputSource(
+          ParallelIndexTestingFactory.createRow(DateTimes.of("1970-01-01T00:00:00.001Z").getMillis(), "a") + "\n" +
+          ParallelIndexTestingFactory.createRow(DateTimes.of("1970-01-02T03:46:40.000Z").getMillis(), "b") + "\n" +
+          ParallelIndexTestingFactory.createRow(DateTimes.of("1970-01-02T03:46:40.000Z").getMillis(), "c") + "\n" +
+          ParallelIndexTestingFactory.createRow(DateTimes.of("1970-01-02T04:02:40.000Z").getMillis(), "b") + "\n" +
+          ParallelIndexTestingFactory.createRow(DateTimes.of("1970-01-02T05:19:10.000Z").getMillis(), "b")
+      );
+      PartialDimensionCardinalityTaskBuilder taskBuilder = new PartialDimensionCardinalityTaskBuilder()
+          .inputSource(inlineInputSource);
+
+      DimensionCardinalityReport report = runTask(taskBuilder);
+
+      Assert.assertEquals(ParallelIndexTestingFactory.ID, report.getTaskId());
+      Map<Interval, byte[]> intervalToCardinalities = report.getIntervalToCardinalities();
+      Assert.assertEquals(2, intervalToCardinalities.size());
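+      // Day 1 contains a single row; day 2's four rows bucket to HOUR query granularity,
+      // giving four distinct (hour, dimension) group keys.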
+
+      byte[] hllSketchBytes;
+      HllSketch hllSketch;
+      hllSketchBytes = intervalToCardinalities.get(Intervals.of("1970-01-01T00:00:00.000Z/1970-01-02T00:00:00.000Z"));
+      hllSketch = HllSketch.wrap(Memory.wrap(hllSketchBytes));
+      Assert.assertNotNull(hllSketch);
+      Assert.assertEquals(1L, (long) hllSketch.getEstimate());
+
+      hllSketchBytes = intervalToCardinalities.get(Intervals.of("1970-01-02T00:00:00.000Z/1970-01-03T00:00:00.000Z"));
+      hllSketch = HllSketch.wrap(Memory.wrap(hllSketchBytes));
+      Assert.assertNotNull(hllSketch);
+      Assert.assertEquals(4L, (long) hllSketch.getEstimate());
+    }
+
+    @Test
+    public void returnsSuccessIfNoExceptions() throws Exception
+    {
+      PartialDimensionCardinalityTask task = new PartialDimensionCardinalityTaskBuilder()
+          .build();
+
+      TaskStatus taskStatus = task.runTask(taskToolbox);
+
+      Assert.assertEquals(ParallelIndexTestingFactory.ID, taskStatus.getId());
+      Assert.assertEquals(TaskState.SUCCESS, taskStatus.getStatusCode());
+    }
+
+    private DimensionCardinalityReport runTask(PartialDimensionCardinalityTaskBuilder taskBuilder)
+    {
+      try {
+        taskBuilder.build()
+                   .runTask(taskToolbox);
+      }
+      catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+
+      return (DimensionCardinalityReport) reportCapture.getValue();
+    }
+  }
+
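+  // Builds PartialDimensionCardinalityTask instances with test defaults; tests override only the pieces they exercise.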
+  private static class PartialDimensionCardinalityTaskBuilder
+  {
+    private static final InputFormat INPUT_FORMAT = ParallelIndexTestingFactory.getInputFormat();
+
+    private String id = ParallelIndexTestingFactory.ID;
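+    // The default input source intentionally contains a row with an invalid timestamp;
+    // tests that need parseable data supply their own via inputSource().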
+    private InputSource inputSource = new InlineInputSource("row-with-invalid-timestamp");
+    private ParallelIndexTuningConfig tuningConfig = new ParallelIndexTestingFactory.TuningConfigBuilder()
+        .partitionsSpec(HASHED_PARTITIONS_SPEC)
+        .build();
+    private DataSchema dataSchema =
+        ParallelIndexTestingFactory
+            .createDataSchema(ParallelIndexTestingFactory.INPUT_INTERVALS)
+            .withGranularitySpec(
+                new UniformGranularitySpec(
+                    Granularities.DAY,
+                    Granularities.HOUR,
+                    ImmutableList.of(Intervals.of("1970-01-01T00:00:00Z/P10D"))
+                )
+            );
+
+    @SuppressWarnings("SameParameterValue")
+    PartialDimensionCardinalityTaskBuilder id(String id)
+    {
+      this.id = id;
+      return this;
+    }
+
+    PartialDimensionCardinalityTaskBuilder inputSource(InputSource inputSource)
+    {
+      this.inputSource = inputSource;
+      return this;
+    }
+
+    PartialDimensionCardinalityTaskBuilder tuningConfig(ParallelIndexTuningConfig tuningConfig)
+    {
+      this.tuningConfig = tuningConfig;
+      return this;
+    }
+
+    PartialDimensionCardinalityTaskBuilder dataSchema(DataSchema dataSchema)
+    {
+      this.dataSchema = dataSchema;
+      return this;
+    }
+
+    PartialDimensionCardinalityTask build()
+    {
+      ParallelIndexIngestionSpec ingestionSpec =
+          ParallelIndexTestingFactory.createIngestionSpec(inputSource, INPUT_FORMAT, tuningConfig, dataSchema);
+
+      return new PartialDimensionCardinalityTask(
+          id,
+          ParallelIndexTestingFactory.GROUP_ID,
+          ParallelIndexTestingFactory.TASK_RESOURCE,
+          ParallelIndexTestingFactory.SUPERVISOR_TASK_ID,
+          ParallelIndexTestingFactory.NUM_ATTEMPTS,
+          ingestionSpec,
+          ParallelIndexTestingFactory.CONTEXT,
+          OBJECT_MAPPER
+      );
+    }
+  }
+}
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTaskTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTaskTest.java
index 3643c74..ac32c81 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTaskTest.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PartialHashSegmentGenerateTaskTest.java
@@ -60,7 +60,8 @@
         ParallelIndexTestingFactory.SUPERVISOR_TASK_ID,
         ParallelIndexTestingFactory.NUM_ATTEMPTS,
         INGESTION_SPEC,
-        ParallelIndexTestingFactory.CONTEXT
+        ParallelIndexTestingFactory.CONTEXT,
+        null
     );
   }
 
@@ -93,7 +94,8 @@
                 Granularities.NONE,
                 intervals
             ),
-            new HashedPartitionsSpec(null, expectedNumBuckets, null)
+            new HashedPartitionsSpec(null, expectedNumBuckets, null),
+            null
         );
     Assert.assertEquals(intervals.size(), partitionAnalysis.getNumTimePartitions());
     for (Interval interval : intervals) {
diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PerfectRollupWorkerTaskTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PerfectRollupWorkerTaskTest.java
index 4e041d6..98912b4 100644
--- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PerfectRollupWorkerTaskTest.java
+++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/PerfectRollupWorkerTaskTest.java
@@ -56,11 +56,8 @@
   }
 
   @Test
-  public void failsWithInvalidPartitionsSpec()
+  public void succeedsWithUnspecifiedNumShards()
   {
-    exception.expect(IllegalArgumentException.class);
-    exception.expectMessage("forceGuaranteedRollup is incompatible with partitionsSpec");
-
     new PerfectRollupWorkerTaskBuilder()
         .partitionsSpec(HashedPartitionsSpec.defaultSpec())
         .build();
diff --git a/integration-tests/src/test/java/org/apache/druid/tests/indexer/AbstractITBatchIndexTest.java b/integration-tests/src/test/java/org/apache/druid/tests/indexer/AbstractITBatchIndexTest.java
index 9a36015..d25e35f 100644
--- a/integration-tests/src/test/java/org/apache/druid/tests/indexer/AbstractITBatchIndexTest.java
+++ b/integration-tests/src/test/java/org/apache/druid/tests/indexer/AbstractITBatchIndexTest.java
@@ -23,6 +23,7 @@
 import com.google.inject.Inject;
 import org.apache.commons.io.IOUtils;
 import org.apache.druid.indexer.partitions.SecondaryPartitionType;
+import org.apache.druid.indexing.common.task.batch.parallel.PartialDimensionCardinalityTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialDimensionDistributionTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialGenericSegmentMergeTask;
 import org.apache.druid.indexing.common.task.batch.parallel.PartialHashSegmentGenerateTask;
@@ -321,6 +322,7 @@
                     } else {
                       return t.getType().equalsIgnoreCase(PartialHashSegmentGenerateTask.TYPE)
                              || t.getType().equalsIgnoreCase(PartialDimensionDistributionTask.TYPE)
+                             || t.getType().equalsIgnoreCase(PartialDimensionCardinalityTask.TYPE)
                              || t.getType().equalsIgnoreCase(PartialRangeSegmentGenerateTask.TYPE)
                              || t.getType().equalsIgnoreCase(PartialGenericSegmentMergeTask.TYPE);
                     }