fix count and average SQL aggregators on constant virtual columns (#11208)

* fix count and average SQL aggregators on constant virtual columns

* style

* even better, why are we tracking virtual columns in aggregations at all if we have a virtual column registry

* oops missed a few

* remove unused

* this will fix it
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
index 3a5be13..4be77f6 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
@@ -24,8 +24,8 @@
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
-import org.apache.druid.sql.calcite.aggregation.Aggregation;
 import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import java.nio.ByteBuffer;
 
@@ -80,30 +80,24 @@
   }
 
   public static boolean matchingAggregatorFactoryExists(
+      final VirtualColumnRegistry virtualColumnRegistry,
       final DruidExpression input,
       final Integer compression,
-      final Aggregation existing,
       final TDigestSketchAggregatorFactory factory
   )
   {
     // Check input for equivalence.
     final boolean inputMatches;
-    final VirtualColumn virtualInput = existing.getVirtualColumns()
-                                               .stream()
-                                               .filter(
-                                                   virtualColumn ->
-                                                       virtualColumn.getOutputName()
-                                                                    .equals(factory.getFieldName())
-                                               )
-                                               .findFirst()
-                                               .orElse(null);
+    final VirtualColumn virtualInput =
+        virtualColumnRegistry.findVirtualColumns(factory.requiredFields())
+                             .stream()
+                             .findFirst()
+                             .orElse(null);
 
     if (virtualInput == null) {
-      inputMatches = input.isDirectColumnAccess()
-                     && input.getDirectColumn().equals(factory.getFieldName());
+      inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(factory.getFieldName());
     } else {
-      inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
-                                                             .equals(input.getExpression());
+      inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
     }
     return inputMatches && compression == factory.getCompression();
   }
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
index 7104148..ee89702 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
@@ -47,7 +47,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@@ -112,9 +111,9 @@
         if (factory instanceof TDigestSketchAggregatorFactory) {
           final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory;
           final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
+              virtualColumnRegistry,
               input,
               compression,
-              existing,
               (TDigestSketchAggregatorFactory) factory
           );
 
@@ -129,8 +128,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       aggregatorFactory = new TDigestSketchAggregatorFactory(
           aggName,
@@ -143,7 +140,6 @@
           input,
           ValueType.FLOAT
       );
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new TDigestSketchAggregatorFactory(
           aggName,
           virtualColumn.getOutputName(),
@@ -151,10 +147,7 @@
       );
     }
 
-    return Aggregation.create(
-        virtualColumns,
-        aggregatorFactory
-    );
+    return Aggregation.create(aggregatorFactory);
   }
 
   private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
index 66e2c35..314edfa 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@@ -123,9 +122,9 @@
       for (AggregatorFactory factory : existing.getAggregatorFactories()) {
         if (factory instanceof TDigestSketchAggregatorFactory) {
           final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
+              virtualColumnRegistry,
               input,
               compression,
-              existing,
               (TDigestSketchAggregatorFactory) factory
           );
 
@@ -148,8 +147,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       aggregatorFactory = new TDigestSketchAggregatorFactory(
           sketchName,
@@ -162,7 +159,6 @@
           input,
           ValueType.FLOAT
       );
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new TDigestSketchAggregatorFactory(
           sketchName,
           virtualColumn.getOutputName(),
@@ -171,7 +167,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         new TDigestSketchToQuantilePostAggregator(
             name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
index 5054495..1628dd1 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
@@ -29,12 +29,10 @@
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
-import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.sql.calcite.aggregation.Aggregation;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 
 import java.util.Collections;
-import java.util.List;
 
 public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
 {
@@ -51,12 +49,10 @@
   protected Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   )
   {
     return Aggregation.create(
-        virtualColumns,
         Collections.singletonList(aggregatorFactory),
         finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
             name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
index 2f08cf0..3669bbd 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
@@ -45,7 +45,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@@ -115,7 +114,6 @@
       tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name();
     }
 
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
     final AggregatorFactory aggregatorFactory;
     final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
 
@@ -150,7 +148,6 @@
             dataType
         );
         dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
-        virtualColumns.add(virtualColumn);
       }
 
       aggregatorFactory = new HllSketchBuildAggregatorFactory(
@@ -165,7 +162,6 @@
     return toAggregation(
         name,
         finalizeAggregations,
-        virtualColumns,
         aggregatorFactory
     );
   }
@@ -173,7 +169,6 @@
   protected abstract Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   );
 }
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
index 80d0c57..f5da4d7 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
@@ -28,12 +28,10 @@
 import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.query.aggregation.AggregatorFactory;
-import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.sql.calcite.aggregation.Aggregation;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 
 import java.util.Collections;
-import java.util.List;
 
 public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
 {
@@ -50,12 +48,10 @@
   protected Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   )
   {
     return Aggregation.create(
-        virtualColumns,
         Collections.singletonList(aggregatorFactory),
         null
     );
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
index c2e27da..69a07dd 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@@ -132,22 +131,16 @@
 
           // Check input for equivalence.
           final boolean inputMatches;
-          final VirtualColumn virtualInput = existing.getVirtualColumns()
-                                                     .stream()
-                                                     .filter(
-                                                         virtualColumn ->
-                                                             virtualColumn.getOutputName()
-                                                                          .equals(theFactory.getFieldName())
-                                                     )
-                                                     .findFirst()
-                                                     .orElse(null);
+          final VirtualColumn virtualInput =
+              virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+                                   .stream()
+                                   .findFirst()
+                                   .orElse(null);
 
           if (virtualInput == null) {
-            inputMatches = input.isDirectColumnAccess()
-                           && input.getDirectColumn().equals(theFactory.getFieldName());
+            inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(theFactory.getFieldName());
           } else {
-            inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
-                                                                   .equals(input.getExpression());
+            inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
           }
 
           final boolean matches = inputMatches
@@ -172,8 +165,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       aggregatorFactory = new DoublesSketchAggregatorFactory(
           histogramName,
@@ -186,7 +177,6 @@
           input,
           ValueType.FLOAT
       );
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new DoublesSketchAggregatorFactory(
           histogramName,
           virtualColumn.getOutputName(),
@@ -195,7 +185,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         new DoublesSketchToQuantilePostAggregator(
             name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
index 17b94dc..76d3e9c 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
@@ -47,7 +47,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@@ -110,8 +109,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       aggregatorFactory = new DoublesSketchAggregatorFactory(
           histogramName,
@@ -124,7 +121,6 @@
           input,
           ValueType.FLOAT
       );
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new DoublesSketchAggregatorFactory(
           histogramName,
           virtualColumn.getOutputName(),
@@ -133,7 +129,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         null
     );
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
index 8c2cdc7..e44dc49 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
@@ -29,12 +29,10 @@
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
-import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.sql.calcite.aggregation.Aggregation;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 
 import java.util.Collections;
-import java.util.List;
 
 public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
 {
@@ -51,12 +49,10 @@
   protected Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   )
   {
     return Aggregation.create(
-        virtualColumns,
         Collections.singletonList(aggregatorFactory),
         finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
             name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
index ed2aafc..b71a8cc 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
@@ -44,7 +44,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@@ -94,7 +93,6 @@
       sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
     }
 
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
     final AggregatorFactory aggregatorFactory;
     final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
 
@@ -130,7 +128,6 @@
             dataType
         );
         dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
-        virtualColumns.add(virtualColumn);
       }
 
       aggregatorFactory = new SketchMergeAggregatorFactory(
@@ -146,7 +143,6 @@
     return toAggregation(
         name,
         finalizeAggregations,
-        virtualColumns,
         aggregatorFactory
     );
   }
@@ -154,7 +150,6 @@
   protected abstract Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   );
 }
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
index 184adf5..ac2a4ba 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
@@ -28,12 +28,10 @@
 import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.query.aggregation.AggregatorFactory;
-import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.sql.calcite.aggregation.Aggregation;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 
 import java.util.Collections;
-import java.util.List;
 
 public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
 {
@@ -50,12 +48,10 @@
   protected Aggregation toAggregation(
       String name,
       boolean finalizeAggregations,
-      List<VirtualColumn> virtualColumns,
       AggregatorFactory aggregatorFactory
   )
   {
     return Aggregation.create(
-        virtualColumns,
         Collections.singletonList(aggregatorFactory),
         null
     );
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
index ff7dbd4..f934520 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
@@ -50,7 +50,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class BloomFilterSqlAggregator implements SqlAggregator
@@ -115,15 +114,10 @@
 
           // Check input for equivalence.
           final boolean inputMatches;
-          final VirtualColumn virtualInput =
-              existing.getVirtualColumns()
-                      .stream()
-                      .filter(virtualColumn ->
-                                  virtualColumn.getOutputName().equals(theFactory.getField().getOutputName())
-                      )
-                      .findFirst()
-                      .orElse(null);
-
+          final VirtualColumn virtualInput = virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+                                                                  .stream()
+                                                                  .findFirst()
+                                                                  .orElse(null);
           if (virtualInput == null) {
             if (input.isDirectColumnAccess()) {
               inputMatches =
@@ -150,7 +144,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
 
     ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType());
     final DimensionSpec spec;
@@ -173,7 +166,6 @@
           input,
           inputOperand.getType()
       );
-      virtualColumns.add(virtualColumn);
       spec = new DefaultDimensionSpec(
           virtualColumn.getOutputName(),
           StringUtils.format("%s:%s", name, virtualColumn.getOutputName())
@@ -186,10 +178,7 @@
         maxNumEntries
     );
 
-    return Aggregation.create(
-        virtualColumns,
-        aggregatorFactory
-    );
+    return Aggregation.create(aggregatorFactory);
   }
 
   private static class BloomFilterSqlAggFunction extends SqlAggFunction
diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
index b29482d..eb3ae43 100644
--- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
+++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@@ -188,15 +187,11 @@
 
           // Check input for equivalence.
           final boolean inputMatches;
-          final VirtualColumn virtualInput = existing.getVirtualColumns()
-                                                     .stream()
-                                                     .filter(
-                                                         virtualColumn ->
-                                                             virtualColumn.getOutputName()
-                                                                          .equals(theFactory.getFieldName())
-                                                     )
-                                                     .findFirst()
-                                                     .orElse(null);
+          final VirtualColumn virtualInput =
+              virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+                                   .stream()
+                                   .findFirst()
+                                   .orElse(null);
 
           if (virtualInput == null) {
             inputMatches = input.isDirectColumnAccess()
@@ -224,8 +219,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
           histogramName,
@@ -242,7 +235,6 @@
           input,
           ValueType.FLOAT
       );
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
           histogramName,
           virtualColumn.getOutputName(),
@@ -255,7 +247,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         new QuantilePostAggregator(name, histogramName, probability)
     );
diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
index e00e83b..529fa10 100644
--- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
+++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
@@ -51,7 +51,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public class QuantileSqlAggregator implements SqlAggregator
@@ -137,15 +136,11 @@
 
           // Check input for equivalence.
           final boolean inputMatches;
-          final VirtualColumn virtualInput = existing.getVirtualColumns()
-                                                     .stream()
-                                                     .filter(
-                                                         virtualColumn ->
-                                                             virtualColumn.getOutputName()
-                                                                          .equals(theFactory.getFieldName())
-                                                     )
-                                                     .findFirst()
-                                                     .orElse(null);
+          final VirtualColumn virtualInput =
+              virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+                                   .stream()
+                                   .findFirst()
+                                   .orElse(null);
 
           if (virtualInput == null) {
             inputMatches = input.isDirectColumnAccess()
@@ -173,8 +168,6 @@
     }
 
     // No existing match found. Create a new one.
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (input.isDirectColumnAccess()) {
       if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) {
         aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
@@ -200,7 +193,6 @@
     } else {
       final VirtualColumn virtualColumn =
           virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT);
-      virtualColumns.add(virtualColumn);
       aggregatorFactory = new ApproximateHistogramAggregatorFactory(
           histogramName,
           virtualColumn.getOutputName(),
@@ -213,7 +205,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         new QuantilePostAggregator(name, histogramName, probability)
     );
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
index 07c8849..617067a 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -48,7 +48,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 
 public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@@ -84,7 +83,6 @@
     final AggregatorFactory aggregatorFactory;
     final RelDataType dataType = inputOperand.getType();
     final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType);
-    final List<VirtualColumn> virtualColumns = new ArrayList<>();
     final DimensionSpec dimensionSpec;
     final String aggName = StringUtils.format("%s:agg", name);
     final SqlAggFunction func = calciteFunction();
@@ -98,7 +96,6 @@
       VirtualColumn virtualColumn =
           virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType);
       dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
-      virtualColumns.add(virtualColumn);
     }
 
     switch (inputType) {
@@ -135,7 +132,6 @@
     }
 
     return Aggregation.create(
-        virtualColumns,
         ImmutableList.of(aggregatorFactory),
         postAggregator
     );
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
index d8de2d7..852af4f 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
@@ -23,6 +23,8 @@
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import org.apache.druid.query.PerSegmentQueryOptimizationContext;
 import org.apache.druid.query.filter.DimFilter;
 import org.apache.druid.query.filter.Filter;
@@ -166,7 +168,10 @@
   @Override
   public List<String> requiredFields()
   {
-    return delegate.requiredFields();
+    return ImmutableList.copyOf(
+        // use a set to get rid of dupes
+        ImmutableSet.<String>builder().addAll(delegate.requiredFields()).addAll(filter.getRequiredColumns()).build()
+    );
   }
 
   @Override
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
index 879a252..1f26547 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
@@ -19,11 +19,14 @@
 
 package org.apache.druid.query.aggregation;
 
+import com.google.common.collect.ImmutableList;
+import org.apache.druid.query.filter.SelectorDimFilter;
 import org.apache.druid.query.filter.TrueDimFilter;
+import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.junit.Assert;
 import org.junit.Test;
 
-public class FilteredAggregatorFactoryTest
+public class FilteredAggregatorFactoryTest extends InitializedNullHandlingTest
 {
   @Test
   public void testSimpleNaming()
@@ -44,4 +47,16 @@
         null
     ).getName());
   }
+
+  @Test
+  public void testRequiredFields()
+  {
+    Assert.assertEquals(
+        ImmutableList.of("x", "y"),
+        new FilteredAggregatorFactory(
+            new LongSumAggregatorFactory("x", "x"),
+            new SelectorDimFilter("y", "wat", null)
+        ).requiredFields()
+    );
+  }
 }
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
index 41ce161..f28ae18 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
@@ -29,7 +29,6 @@
 import org.apache.druid.query.aggregation.PostAggregator;
 import org.apache.druid.query.filter.AndDimFilter;
 import org.apache.druid.query.filter.DimFilter;
-import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.sql.calcite.filtration.Filtration;
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -44,17 +43,14 @@
 
 public class Aggregation
 {
-  private final List<VirtualColumn> virtualColumns;
   private final List<AggregatorFactory> aggregatorFactories;
   private final PostAggregator postAggregator;
 
   private Aggregation(
-      final List<VirtualColumn> virtualColumns,
       final List<AggregatorFactory> aggregatorFactories,
       final PostAggregator postAggregator
   )
   {
-    this.virtualColumns = Preconditions.checkNotNull(virtualColumns, "virtualColumns");
     this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories");
     this.postAggregator = postAggregator;
 
@@ -88,19 +84,10 @@
     }
   }
 
-  public static Aggregation create(final List<VirtualColumn> virtualColumns, final AggregatorFactory aggregatorFactory)
-  {
-    return new Aggregation(
-        virtualColumns,
-        ImmutableList.of(aggregatorFactory),
-        null
-    );
-  }
 
   public static Aggregation create(final AggregatorFactory aggregatorFactory)
   {
     return new Aggregation(
-        ImmutableList.of(),
         ImmutableList.of(aggregatorFactory),
         null
     );
@@ -108,7 +95,7 @@
 
   public static Aggregation create(final PostAggregator postAggregator)
   {
-    return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
+    return new Aggregation(Collections.emptyList(), postAggregator);
   }
 
   public static Aggregation create(
@@ -116,21 +103,19 @@
       final PostAggregator postAggregator
   )
   {
-    return new Aggregation(ImmutableList.of(), aggregatorFactories, postAggregator);
+    return new Aggregation(aggregatorFactories, postAggregator);
   }
 
-  public static Aggregation create(
-      final List<VirtualColumn> virtualColumns,
-      final List<AggregatorFactory> aggregatorFactories,
-      final PostAggregator postAggregator
-  )
+  public List<String> getRequiredColumns()
   {
-    return new Aggregation(virtualColumns, aggregatorFactories, postAggregator);
-  }
-
-  public List<VirtualColumn> getVirtualColumns()
-  {
-    return virtualColumns;
+    Set<String> columns = new HashSet<>();
+    for (AggregatorFactory agg : aggregatorFactories) {
+      columns.addAll(agg.requiredFields());
+    }
+    if (postAggregator != null) {
+      columns.addAll(postAggregator.getDependentFields());
+    }
+    return ImmutableList.copyOf(columns);
   }
 
   public List<AggregatorFactory> getAggregatorFactories()
@@ -181,21 +166,10 @@
                                                     .optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature())
                                                     .getDimFilter();
 
-    Set<VirtualColumn> aggVirtualColumnsPlusFilterColumns = new HashSet<>(virtualColumns);
-    for (String column : baseOptimizedFilter.getRequiredColumns()) {
-      if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
-        aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
-      }
-    }
     final List<AggregatorFactory> newAggregators = new ArrayList<>();
     for (AggregatorFactory agg : aggregatorFactories) {
       if (agg instanceof FilteredAggregatorFactory) {
         final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg;
-        for (String column : filteredAgg.getFilter().getRequiredColumns()) {
-          if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
-            aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
-          }
-        }
         newAggregators.add(
             new FilteredAggregatorFactory(
                 filteredAgg.getAggregator(),
@@ -209,7 +183,7 @@
       }
     }
 
-    return new Aggregation(new ArrayList<>(aggVirtualColumnsPlusFilterColumns), newAggregators, postAggregator);
+    return new Aggregation(newAggregators, postAggregator);
   }
 
   @Override
@@ -222,23 +196,21 @@
       return false;
     }
     final Aggregation that = (Aggregation) o;
-    return Objects.equals(virtualColumns, that.virtualColumns) &&
-           Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
+    return Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
            Objects.equals(postAggregator, that.postAggregator);
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(virtualColumns, aggregatorFactories, postAggregator);
+    return Objects.hash(aggregatorFactories, postAggregator);
   }
 
   @Override
   public String toString()
   {
     return "Aggregation{" +
-           "virtualColumns=" + virtualColumns +
-           ", aggregatorFactories=" + aggregatorFactories +
+           "aggregatorFactories=" + aggregatorFactories +
            ", postAggregator=" + postAggregator +
            '}';
   }
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
index 73bceca..bbf0189 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
@@ -52,7 +52,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
@@ -94,7 +93,6 @@
       return null;
     }
 
-    final List<VirtualColumn> myvirtualColumns = new ArrayList<>();
     final AggregatorFactory aggregatorFactory;
     final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
 
@@ -120,7 +118,6 @@
         VirtualColumn virtualColumn =
             virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType);
         dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
-        myvirtualColumns.add(virtualColumn);
       }
 
       aggregatorFactory = new CardinalityAggregatorFactory(
@@ -133,7 +130,6 @@
     }
 
     return Aggregation.create(
-        myvirtualColumns,
         Collections.singletonList(aggregatorFactory),
         finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null
     );
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
index 96e51bb..0f80daa 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
@@ -53,7 +53,6 @@
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
 
 import javax.annotation.Nullable;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -134,19 +133,15 @@
           break;
       }
     }
-    List<VirtualColumn> virtualColumns = new ArrayList<>();
-
     if (arg.isDirectColumnAccess()) {
       fieldName = arg.getDirectColumn();
     } else {
       VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType);
-      virtualColumns.add(vc);
       fieldName = vc.getOutputName();
     }
 
     if (aggregateCall.isDistinct()) {
       return Aggregation.create(
-          virtualColumns,
           new ExpressionLambdaAggregatorFactory(
               name,
               ImmutableSet.of(fieldName),
@@ -163,7 +158,6 @@
       );
     } else {
       return Aggregation.create(
-          virtualColumns,
           new ExpressionLambdaAggregatorFactory(
               name,
               ImmutableSet.of(fieldName),
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
index 3b97344..28364bd 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
@@ -31,6 +31,7 @@
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
 import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
+import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.column.ValueType;
 import org.apache.druid.sql.calcite.aggregation.Aggregation;
@@ -78,37 +79,7 @@
       return null;
     }
 
-    final String fieldName;
-    final String expression;
-    final DruidExpression arg = Iterables.getOnlyElement(arguments);
-
-    if (arg.isDirectColumnAccess()) {
-      fieldName = arg.getDirectColumn();
-      expression = null;
-    } else {
-      fieldName = null;
-      expression = arg.getExpression();
-    }
-
-    final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
-
-    final ValueType sumType;
-    // Use 64-bit sum regardless of the type of the AVG aggregator.
-    if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
-      sumType = ValueType.LONG;
-    } else {
-      sumType = ValueType.DOUBLE;
-    }
-
-    final String sumName = Calcites.makePrefixedName(name, "sum");
     final String countName = Calcites.makePrefixedName(name, "count");
-    final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
-        sumType,
-        sumName,
-        fieldName,
-        expression,
-        macroTable
-    );
     final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
         countName,
         plannerContext,
@@ -119,6 +90,38 @@
         project
     );
 
+    final String fieldName;
+    final String expression;
+    final DruidExpression arg = Iterables.getOnlyElement(arguments);
+
+
+    final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
+    final ValueType sumType;
+    // Use 64-bit sum regardless of the type of the AVG aggregator.
+    if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
+      sumType = ValueType.LONG;
+    } else {
+      sumType = ValueType.DOUBLE;
+    }
+
+    if (arg.isDirectColumnAccess()) {
+      fieldName = arg.getDirectColumn();
+      expression = null;
+    } else {
+      // if the filter or anywhere else defined a virtual column for us, re-use it
+      VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
+      fieldName = vc != null ? vc.getOutputName() : null;
+      expression = vc != null ? null : arg.getExpression();
+    }
+    final String sumName = Calcites.makePrefixedName(name, "sum");
+    final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
+        sumType,
+        sumName,
+        fieldName,
+        expression,
+        macroTable
+    );
+
     return Aggregation.create(
         ImmutableList.of(sum, count),
         new ArithmeticPostAggregator(
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
index f674798..442b9ea 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
@@ -134,7 +134,7 @@
     } else {
       // Not COUNT(*), not distinct
       // COUNT(x) should count all non-null values of x.
-      return Aggregation.create(createCountAggregatorFactory(
+      AggregatorFactory theCount = createCountAggregatorFactory(
             name,
             plannerContext,
             rowSignature,
@@ -142,7 +142,9 @@
             rexBuilder,
             aggregateCall,
             project
-      ));
+      );
+
+      return Aggregation.create(theCount);
     }
   }
 }
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
index 8ec917e..be9a69f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
@@ -64,9 +64,7 @@
 import javax.annotation.Nullable;
 import java.util.Collections;
 import java.util.List;
-import java.util.Objects;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 public class EarliestLatestAnySqlAggregator implements SqlAggregator
 {
@@ -209,9 +207,6 @@
     }
 
     return Aggregation.create(
-        Stream.of(virtualColumnRegistry.getVirtualColumn(fieldName))
-              .filter(Objects::nonNull)
-              .collect(Collectors.toList()),
         Collections.singletonList(
             aggregatorType.createAggregatorFactory(
                 aggregatorName,
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
index c503f96..251340f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
@@ -636,7 +636,7 @@
       }
 
       for (Aggregation aggregation : grouping.getAggregations()) {
-        virtualColumns.addAll(aggregation.getVirtualColumns());
+        virtualColumns.addAll(virtualColumnRegistry.findVirtualColumns(aggregation.getRequiredColumns()));
       }
     }
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
index 7244c75..2ce4a84 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
@@ -29,7 +29,9 @@
 
 import javax.annotation.Nullable;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 /**
  * Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered
@@ -128,6 +130,12 @@
     return virtualColumnsByName.get(virtualColumnName);
   }
 
+  @Nullable
+  public VirtualColumn getVirtualColumnByExpression(String expression)
+  {
+    return virtualColumnsByExpression.get(expression);
+  }
+
   /**
    * Get a signature representing the base signature plus all registered virtual columns.
    */
@@ -145,4 +153,15 @@
 
     return builder.build();
   }
+
+  /**
+   * Given a list of column names, find any corresponding {@link VirtualColumn} with the same name
+   */
+  public List<VirtualColumn> findVirtualColumns(List<String> allColumns)
+  {
+    return allColumns.stream()
+                     .filter(this::isVirtualColumnDefined)
+                     .map(this::getVirtualColumn)
+                     .collect(Collectors.toList());
+  }
 }
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
index 2489ed7..cf596ce 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
@@ -117,7 +117,6 @@
         if (doesMatch) {
           existingAggregationsWithSameFilter.add(
               Aggregation.create(
-                  existingAggregation.getVirtualColumns(),
                   existingAggregation.getAggregatorFactories().stream()
                                      .map(factory -> ((FilteredAggregatorFactory) factory).getAggregator())
                                      .collect(Collectors.toList()),
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 1efd99b..976d61a 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -50,6 +50,7 @@
 import org.apache.druid.query.ResourceLimitExceededException;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.UnionDataSource;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
@@ -112,6 +113,8 @@
 import org.apache.druid.query.topn.InvertedTopNMetricSpec;
 import org.apache.druid.query.topn.NumericTopNMetricSpec;
 import org.apache.druid.query.topn.TopNQueryBuilder;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.VirtualColumns;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.column.ValueType;
 import org.apache.druid.segment.join.JoinType;
@@ -18926,4 +18929,76 @@
         expectedResults
     );
   }
+
+  @Test
+  public void testCountAndAverageByConstantVirtualColumn() throws Exception
+  {
+    List<VirtualColumn> virtualColumns;
+    List<AggregatorFactory> aggs;
+    if (useDefault) {
+      aggs = ImmutableList.of(
+          new FilteredAggregatorFactory(
+              new CountAggregatorFactory("a0"),
+              not(selector("v0", null, null))
+          ),
+          new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
+          new CountAggregatorFactory("a1:count")
+      );
+      virtualColumns = ImmutableList.of(
+          expressionVirtualColumn("v0", "'10.1'", ValueType.STRING)
+      );
+    } else {
+      aggs = ImmutableList.of(
+          new FilteredAggregatorFactory(
+              new CountAggregatorFactory("a0"),
+              not(selector("v0", null, null))
+          ),
+          new LongSumAggregatorFactory("a1:sum", "v1"),
+          new FilteredAggregatorFactory(
+              new CountAggregatorFactory("a1:count"),
+              not(selector("v1", null, null))
+          )
+      );
+      virtualColumns = ImmutableList.of(
+          expressionVirtualColumn("v0", "'10.1'", ValueType.STRING),
+          expressionVirtualColumn("v1", "325323", ValueType.LONG)
+      );
+
+    }
+    testQuery(
+        "SELECT dim5, COUNT(dim1), AVG(l1) FROM druid.numfoo WHERE dim1 = '10.1' AND l1 = 325323 GROUP BY dim5",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE3)
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setDimFilter(
+                            and(
+                                selector("dim1", "10.1", null),
+                                selector("l1", "325323", null)
+                            )
+                        )
+                        .setGranularity(Granularities.ALL)
+                        .setVirtualColumns(VirtualColumns.create(virtualColumns))
+                        .setDimensions(new DefaultDimensionSpec("dim5", "_d0", ValueType.STRING))
+                        .setAggregatorSpecs(aggs)
+                        .setPostAggregatorSpecs(
+                            ImmutableList.of(
+                                new ArithmeticPostAggregator(
+                                    "a1",
+                                    "quotient",
+                                    ImmutableList.of(
+                                        new FieldAccessPostAggregator(null, "a1:sum"),
+                                        new FieldAccessPostAggregator(null, "a1:count")
+                                    )
+                                )
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"ab", 1L, 325323L}
+        )
+    );
+  }
 }