fix count and average SQL aggregators on constant virtual columns (#11208)
* fix count and average SQL aggregators on constant virtual columns
* style
* even better, why are we tracking virtual columns in aggregations at all if we have a virtual column registry
* oops missed a few
* remove unused
* this will fix it
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
index 3a5be13..4be77f6 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/TDigestSketchUtils.java
@@ -24,8 +24,8 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
-import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import java.nio.ByteBuffer;
@@ -80,30 +80,24 @@
}
public static boolean matchingAggregatorFactoryExists(
+ final VirtualColumnRegistry virtualColumnRegistry,
final DruidExpression input,
final Integer compression,
- final Aggregation existing,
final TDigestSketchAggregatorFactory factory
)
{
// Check input for equivalence.
final boolean inputMatches;
- final VirtualColumn virtualInput = existing.getVirtualColumns()
- .stream()
- .filter(
- virtualColumn ->
- virtualColumn.getOutputName()
- .equals(factory.getFieldName())
- )
- .findFirst()
- .orElse(null);
+ final VirtualColumn virtualInput =
+ virtualColumnRegistry.findVirtualColumns(factory.requiredFields())
+ .stream()
+ .findFirst()
+ .orElse(null);
if (virtualInput == null) {
- inputMatches = input.isDirectColumnAccess()
- && input.getDirectColumn().equals(factory.getFieldName());
+ inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(factory.getFieldName());
} else {
- inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
- .equals(input.getExpression());
+ inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
}
return inputMatches && compression == factory.getCompression();
}
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
index 7104148..ee89702 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestGenerateSketchSqlAggregator.java
@@ -47,7 +47,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class TDigestGenerateSketchSqlAggregator implements SqlAggregator
@@ -112,9 +111,9 @@
if (factory instanceof TDigestSketchAggregatorFactory) {
final TDigestSketchAggregatorFactory theFactory = (TDigestSketchAggregatorFactory) factory;
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
+ virtualColumnRegistry,
input,
compression,
- existing,
(TDigestSketchAggregatorFactory) factory
);
@@ -129,8 +128,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName,
@@ -143,7 +140,6 @@
input,
ValueType.FLOAT
);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory(
aggName,
virtualColumn.getOutputName(),
@@ -151,10 +147,7 @@
);
}
- return Aggregation.create(
- virtualColumns,
- aggregatorFactory
- );
+ return Aggregation.create(aggregatorFactory);
}
private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
diff --git a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
index 66e2c35..314edfa 100644
--- a/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
+++ b/extensions-contrib/tdigestsketch/src/main/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class TDigestSketchQuantileSqlAggregator implements SqlAggregator
@@ -123,9 +122,9 @@
for (AggregatorFactory factory : existing.getAggregatorFactories()) {
if (factory instanceof TDigestSketchAggregatorFactory) {
final boolean matches = TDigestSketchUtils.matchingAggregatorFactoryExists(
+ virtualColumnRegistry,
input,
compression,
- existing,
(TDigestSketchAggregatorFactory) factory
);
@@ -148,8 +147,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName,
@@ -162,7 +159,6 @@
input,
ValueType.FLOAT
);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new TDigestSketchAggregatorFactory(
sketchName,
virtualColumn.getOutputName(),
@@ -171,7 +167,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
new TDigestSketchToQuantilePostAggregator(
name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
index 5054495..1628dd1 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java
@@ -29,12 +29,10 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
-import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
-import java.util.List;
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
@@ -51,12 +49,10 @@
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
- virtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
index 2f08cf0..3669bbd 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java
@@ -45,7 +45,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
@@ -115,7 +114,6 @@
tgtHllType = HllSketchAggregatorFactory.DEFAULT_TGT_HLL_TYPE.name();
}
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@@ -150,7 +148,6 @@
dataType
);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
- virtualColumns.add(virtualColumn);
}
aggregatorFactory = new HllSketchBuildAggregatorFactory(
@@ -165,7 +162,6 @@
return toAggregation(
name,
finalizeAggregations,
- virtualColumns,
aggregatorFactory
);
}
@@ -173,7 +169,6 @@
protected abstract Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
);
}
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
index 80d0c57..f5da4d7 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java
@@ -28,12 +28,10 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
-import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
-import java.util.List;
public class HllSketchObjectSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
@@ -50,12 +48,10 @@
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
- virtualColumns,
Collections.singletonList(aggregatorFactory),
null
);
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
index c2e27da..69a07dd 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
@@ -132,22 +131,16 @@
// Check input for equivalence.
final boolean inputMatches;
- final VirtualColumn virtualInput = existing.getVirtualColumns()
- .stream()
- .filter(
- virtualColumn ->
- virtualColumn.getOutputName()
- .equals(theFactory.getFieldName())
- )
- .findFirst()
- .orElse(null);
+ final VirtualColumn virtualInput =
+ virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+ .stream()
+ .findFirst()
+ .orElse(null);
if (virtualInput == null) {
- inputMatches = input.isDirectColumnAccess()
- && input.getDirectColumn().equals(theFactory.getFieldName());
+ inputMatches = input.isDirectColumnAccess() && input.getDirectColumn().equals(theFactory.getFieldName());
} else {
- inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression()
- .equals(input.getExpression());
+ inputMatches = ((ExpressionVirtualColumn) virtualInput).getExpression().equals(input.getExpression());
}
final boolean matches = inputMatches
@@ -172,8 +165,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
@@ -186,7 +177,6 @@
input,
ValueType.FLOAT
);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@@ -195,7 +185,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
new DoublesSketchToQuantilePostAggregator(
name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
index 17b94dc..76d3e9c 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java
@@ -47,7 +47,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class DoublesSketchObjectSqlAggregator implements SqlAggregator
@@ -110,8 +109,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
@@ -124,7 +121,6 @@
input,
ValueType.FLOAT
);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new DoublesSketchAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@@ -133,7 +129,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
null
);
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
index 8c2cdc7..e44dc49 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java
@@ -29,12 +29,10 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
-import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
-import java.util.List;
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
@@ -51,12 +49,10 @@
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
- virtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(
name,
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
index ed2aafc..b71a8cc 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java
@@ -44,7 +44,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public abstract class ThetaSketchBaseSqlAggregator implements SqlAggregator
@@ -94,7 +93,6 @@
sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
}
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@@ -130,7 +128,6 @@
dataType
);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
- virtualColumns.add(virtualColumn);
}
aggregatorFactory = new SketchMergeAggregatorFactory(
@@ -146,7 +143,6 @@
return toAggregation(
name,
finalizeAggregations,
- virtualColumns,
aggregatorFactory
);
}
@@ -154,7 +150,6 @@
protected abstract Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
);
}
diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
index 184adf5..ac2a4ba 100644
--- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
+++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java
@@ -28,12 +28,10 @@
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.aggregation.AggregatorFactory;
-import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import java.util.Collections;
-import java.util.List;
public class ThetaSketchObjectSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
@@ -50,12 +48,10 @@
protected Aggregation toAggregation(
String name,
boolean finalizeAggregations,
- List<VirtualColumn> virtualColumns,
AggregatorFactory aggregatorFactory
)
{
return Aggregation.create(
- virtualColumns,
Collections.singletonList(aggregatorFactory),
null
);
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
index ff7dbd4..f934520 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java
@@ -50,7 +50,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class BloomFilterSqlAggregator implements SqlAggregator
@@ -115,15 +114,10 @@
// Check input for equivalence.
final boolean inputMatches;
- final VirtualColumn virtualInput =
- existing.getVirtualColumns()
- .stream()
- .filter(virtualColumn ->
- virtualColumn.getOutputName().equals(theFactory.getField().getOutputName())
- )
- .findFirst()
- .orElse(null);
-
+ final VirtualColumn virtualInput = virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+ .stream()
+ .findFirst()
+ .orElse(null);
if (virtualInput == null) {
if (input.isDirectColumnAccess()) {
inputMatches =
@@ -150,7 +144,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
ValueType valueType = Calcites.getValueTypeForRelDataType(inputOperand.getType());
final DimensionSpec spec;
@@ -173,7 +166,6 @@
input,
inputOperand.getType()
);
- virtualColumns.add(virtualColumn);
spec = new DefaultDimensionSpec(
virtualColumn.getOutputName(),
StringUtils.format("%s:%s", name, virtualColumn.getOutputName())
@@ -186,10 +178,7 @@
maxNumEntries
);
- return Aggregation.create(
- virtualColumns,
- aggregatorFactory
- );
+ return Aggregation.create(aggregatorFactory);
}
private static class BloomFilterSqlAggFunction extends SqlAggFunction
diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
index b29482d..eb3ae43 100644
--- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
+++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregator.java
@@ -50,7 +50,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class FixedBucketsHistogramQuantileSqlAggregator implements SqlAggregator
@@ -188,15 +187,11 @@
// Check input for equivalence.
final boolean inputMatches;
- final VirtualColumn virtualInput = existing.getVirtualColumns()
- .stream()
- .filter(
- virtualColumn ->
- virtualColumn.getOutputName()
- .equals(theFactory.getFieldName())
- )
- .findFirst()
- .orElse(null);
+ final VirtualColumn virtualInput =
+ virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+ .stream()
+ .findFirst()
+ .orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
@@ -224,8 +219,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName,
@@ -242,7 +235,6 @@
input,
ValueType.FLOAT
);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new FixedBucketsHistogramAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@@ -255,7 +247,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability)
);
diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
index e00e83b..529fa10 100644
--- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
+++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java
@@ -51,7 +51,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public class QuantileSqlAggregator implements SqlAggregator
@@ -137,15 +136,11 @@
// Check input for equivalence.
final boolean inputMatches;
- final VirtualColumn virtualInput = existing.getVirtualColumns()
- .stream()
- .filter(
- virtualColumn ->
- virtualColumn.getOutputName()
- .equals(theFactory.getFieldName())
- )
- .findFirst()
- .orElse(null);
+ final VirtualColumn virtualInput =
+ virtualColumnRegistry.findVirtualColumns(theFactory.requiredFields())
+ .stream()
+ .findFirst()
+ .orElse(null);
if (virtualInput == null) {
inputMatches = input.isDirectColumnAccess()
@@ -173,8 +168,6 @@
}
// No existing match found. Create a new one.
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (input.isDirectColumnAccess()) {
if (rowSignature.getColumnType(input.getDirectColumn()).orElse(null) == ValueType.COMPLEX) {
aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(
@@ -200,7 +193,6 @@
} else {
final VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, ValueType.FLOAT);
- virtualColumns.add(virtualColumn);
aggregatorFactory = new ApproximateHistogramAggregatorFactory(
histogramName,
virtualColumn.getOutputName(),
@@ -213,7 +205,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
new QuantilePostAggregator(name, histogramName, probability)
);
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
index 07c8849..617067a 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -48,7 +48,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
public abstract class BaseVarianceSqlAggregator implements SqlAggregator
@@ -84,7 +83,6 @@
final AggregatorFactory aggregatorFactory;
final RelDataType dataType = inputOperand.getType();
final ValueType inputType = Calcites.getValueTypeForRelDataType(dataType);
- final List<VirtualColumn> virtualColumns = new ArrayList<>();
final DimensionSpec dimensionSpec;
final String aggName = StringUtils.format("%s:agg", name);
final SqlAggFunction func = calciteFunction();
@@ -98,7 +96,6 @@
VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
- virtualColumns.add(virtualColumn);
}
switch (inputType) {
@@ -135,7 +132,6 @@
}
return Aggregation.create(
- virtualColumns,
ImmutableList.of(aggregatorFactory),
postAggregator
);
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
index d8de2d7..852af4f 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/FilteredAggregatorFactory.java
@@ -23,6 +23,8 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import org.apache.druid.query.PerSegmentQueryOptimizationContext;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.Filter;
@@ -166,7 +168,10 @@
@Override
public List<String> requiredFields()
{
- return delegate.requiredFields();
+ return ImmutableList.copyOf(
+ // use a set to get rid of dupes
+ ImmutableSet.<String>builder().addAll(delegate.requiredFields()).addAll(filter.getRequiredColumns()).build()
+ );
}
@Override
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
index 879a252..1f26547 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/FilteredAggregatorFactoryTest.java
@@ -19,11 +19,14 @@
package org.apache.druid.query.aggregation;
+import com.google.common.collect.ImmutableList;
+import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.filter.TrueDimFilter;
+import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
-public class FilteredAggregatorFactoryTest
+public class FilteredAggregatorFactoryTest extends InitializedNullHandlingTest
{
@Test
public void testSimpleNaming()
@@ -44,4 +47,16 @@
null
).getName());
}
+
+ @Test
+ public void testRequiredFields()
+ {
+ Assert.assertEquals(
+ ImmutableList.of("x", "y"),
+ new FilteredAggregatorFactory(
+ new LongSumAggregatorFactory("x", "x"),
+ new SelectorDimFilter("y", "wat", null)
+ ).requiredFields()
+ );
+ }
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
index 41ce161..f28ae18 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/Aggregation.java
@@ -29,7 +29,6 @@
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter;
-import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -44,17 +43,14 @@
public class Aggregation
{
- private final List<VirtualColumn> virtualColumns;
private final List<AggregatorFactory> aggregatorFactories;
private final PostAggregator postAggregator;
private Aggregation(
- final List<VirtualColumn> virtualColumns,
final List<AggregatorFactory> aggregatorFactories,
final PostAggregator postAggregator
)
{
- this.virtualColumns = Preconditions.checkNotNull(virtualColumns, "virtualColumns");
this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories");
this.postAggregator = postAggregator;
@@ -88,19 +84,10 @@
}
}
- public static Aggregation create(final List<VirtualColumn> virtualColumns, final AggregatorFactory aggregatorFactory)
- {
- return new Aggregation(
- virtualColumns,
- ImmutableList.of(aggregatorFactory),
- null
- );
- }
public static Aggregation create(final AggregatorFactory aggregatorFactory)
{
return new Aggregation(
- ImmutableList.of(),
ImmutableList.of(aggregatorFactory),
null
);
@@ -108,7 +95,7 @@
public static Aggregation create(final PostAggregator postAggregator)
{
- return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
+ return new Aggregation(Collections.emptyList(), postAggregator);
}
public static Aggregation create(
@@ -116,21 +103,19 @@
final PostAggregator postAggregator
)
{
- return new Aggregation(ImmutableList.of(), aggregatorFactories, postAggregator);
+ return new Aggregation(aggregatorFactories, postAggregator);
}
- public static Aggregation create(
- final List<VirtualColumn> virtualColumns,
- final List<AggregatorFactory> aggregatorFactories,
- final PostAggregator postAggregator
- )
+ public List<String> getRequiredColumns()
{
- return new Aggregation(virtualColumns, aggregatorFactories, postAggregator);
- }
-
- public List<VirtualColumn> getVirtualColumns()
- {
- return virtualColumns;
+ Set<String> columns = new HashSet<>();
+ for (AggregatorFactory agg : aggregatorFactories) {
+ columns.addAll(agg.requiredFields());
+ }
+ if (postAggregator != null) {
+ columns.addAll(postAggregator.getDependentFields());
+ }
+ return ImmutableList.copyOf(columns);
}
public List<AggregatorFactory> getAggregatorFactories()
@@ -181,21 +166,10 @@
.optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature())
.getDimFilter();
- Set<VirtualColumn> aggVirtualColumnsPlusFilterColumns = new HashSet<>(virtualColumns);
- for (String column : baseOptimizedFilter.getRequiredColumns()) {
- if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
- aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
- }
- }
final List<AggregatorFactory> newAggregators = new ArrayList<>();
for (AggregatorFactory agg : aggregatorFactories) {
if (agg instanceof FilteredAggregatorFactory) {
final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg;
- for (String column : filteredAgg.getFilter().getRequiredColumns()) {
- if (virtualColumnRegistry.isVirtualColumnDefined(column)) {
- aggVirtualColumnsPlusFilterColumns.add(virtualColumnRegistry.getVirtualColumn(column));
- }
- }
newAggregators.add(
new FilteredAggregatorFactory(
filteredAgg.getAggregator(),
@@ -209,7 +183,7 @@
}
}
- return new Aggregation(new ArrayList<>(aggVirtualColumnsPlusFilterColumns), newAggregators, postAggregator);
+ return new Aggregation(newAggregators, postAggregator);
}
@Override
@@ -222,23 +196,21 @@
return false;
}
final Aggregation that = (Aggregation) o;
- return Objects.equals(virtualColumns, that.virtualColumns) &&
- Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
+ return Objects.equals(aggregatorFactories, that.aggregatorFactories) &&
Objects.equals(postAggregator, that.postAggregator);
}
@Override
public int hashCode()
{
- return Objects.hash(virtualColumns, aggregatorFactories, postAggregator);
+ return Objects.hash(aggregatorFactories, postAggregator);
}
@Override
public String toString()
{
return "Aggregation{" +
- "virtualColumns=" + virtualColumns +
- ", aggregatorFactories=" + aggregatorFactories +
+ "aggregatorFactories=" + aggregatorFactories +
", postAggregator=" + postAggregator +
'}';
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
index 73bceca..bbf0189 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ApproxCountDistinctSqlAggregator.java
@@ -52,7 +52,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -94,7 +93,6 @@
return null;
}
- final List<VirtualColumn> myvirtualColumns = new ArrayList<>();
final AggregatorFactory aggregatorFactory;
final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;
@@ -120,7 +118,6 @@
VirtualColumn virtualColumn =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, dataType);
dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
- myvirtualColumns.add(virtualColumn);
}
aggregatorFactory = new CardinalityAggregatorFactory(
@@ -133,7 +130,6 @@
}
return Aggregation.create(
- myvirtualColumns,
Collections.singletonList(aggregatorFactory),
finalizeAggregations ? new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName()) : null
);
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
index 96e51bb..0f80daa 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java
@@ -53,7 +53,6 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
-import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -134,19 +133,15 @@
break;
}
}
- List<VirtualColumn> virtualColumns = new ArrayList<>();
-
if (arg.isDirectColumnAccess()) {
fieldName = arg.getDirectColumn();
} else {
VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType);
- virtualColumns.add(vc);
fieldName = vc.getOutputName();
}
if (aggregateCall.isDistinct()) {
return Aggregation.create(
- virtualColumns,
new ExpressionLambdaAggregatorFactory(
name,
ImmutableSet.of(fieldName),
@@ -163,7 +158,6 @@
);
} else {
return Aggregation.create(
- virtualColumns,
new ExpressionLambdaAggregatorFactory(
name,
ImmutableSet.of(fieldName),
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
index 3b97344..28364bd 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
@@ -31,6 +31,7 @@
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
+import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
@@ -78,37 +79,7 @@
return null;
}
- final String fieldName;
- final String expression;
- final DruidExpression arg = Iterables.getOnlyElement(arguments);
-
- if (arg.isDirectColumnAccess()) {
- fieldName = arg.getDirectColumn();
- expression = null;
- } else {
- fieldName = null;
- expression = arg.getExpression();
- }
-
- final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
-
- final ValueType sumType;
- // Use 64-bit sum regardless of the type of the AVG aggregator.
- if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
- sumType = ValueType.LONG;
- } else {
- sumType = ValueType.DOUBLE;
- }
-
- final String sumName = Calcites.makePrefixedName(name, "sum");
final String countName = Calcites.makePrefixedName(name, "count");
- final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
- sumType,
- sumName,
- fieldName,
- expression,
- macroTable
- );
final AggregatorFactory count = CountSqlAggregator.createCountAggregatorFactory(
countName,
plannerContext,
@@ -119,6 +90,38 @@
project
);
+ final String fieldName;
+ final String expression;
+ final DruidExpression arg = Iterables.getOnlyElement(arguments);
+
+
+ final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
+ final ValueType sumType;
+ // Use 64-bit sum regardless of the type of the AVG aggregator.
+ if (SqlTypeName.INT_TYPES.contains(aggregateCall.getType().getSqlTypeName())) {
+ sumType = ValueType.LONG;
+ } else {
+ sumType = ValueType.DOUBLE;
+ }
+
+ if (arg.isDirectColumnAccess()) {
+ fieldName = arg.getDirectColumn();
+ expression = null;
+ } else {
+ // if the filter or anywhere else defined a virtual column for us, re-use it
+ VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
+ fieldName = vc != null ? vc.getOutputName() : null;
+ expression = vc != null ? null : arg.getExpression();
+ }
+ final String sumName = Calcites.makePrefixedName(name, "sum");
+ final AggregatorFactory sum = SumSqlAggregator.createSumAggregatorFactory(
+ sumType,
+ sumName,
+ fieldName,
+ expression,
+ macroTable
+ );
+
return Aggregation.create(
ImmutableList.of(sum, count),
new ArithmeticPostAggregator(
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
index f674798..442b9ea 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/CountSqlAggregator.java
@@ -134,7 +134,7 @@
} else {
// Not COUNT(*), not distinct
// COUNT(x) should count all non-null values of x.
- return Aggregation.create(createCountAggregatorFactory(
+ AggregatorFactory theCount = createCountAggregatorFactory(
name,
plannerContext,
rowSignature,
@@ -142,7 +142,9 @@
rexBuilder,
aggregateCall,
project
- ));
+ );
+
+ return Aggregation.create(theCount);
}
}
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
index 8ec917e..be9a69f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
@@ -64,9 +64,7 @@
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
-import java.util.Objects;
import java.util.stream.Collectors;
-import java.util.stream.Stream;
public class EarliestLatestAnySqlAggregator implements SqlAggregator
{
@@ -209,9 +207,6 @@
}
return Aggregation.create(
- Stream.of(virtualColumnRegistry.getVirtualColumn(fieldName))
- .filter(Objects::nonNull)
- .collect(Collectors.toList()),
Collections.singletonList(
aggregatorType.createAggregatorFactory(
aggregatorName,
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
index c503f96..251340f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java
@@ -636,7 +636,7 @@
}
for (Aggregation aggregation : grouping.getAggregations()) {
- virtualColumns.addAll(aggregation.getVirtualColumns());
+ virtualColumns.addAll(virtualColumnRegistry.findVirtualColumns(aggregation.getRequiredColumns()));
}
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
index 7244c75..2ce4a84 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
@@ -29,7 +29,9 @@
import javax.annotation.Nullable;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
/**
* Provides facilities to create and re-use {@link VirtualColumn} definitions for dimensions, filters, and filtered
@@ -128,6 +130,12 @@
return virtualColumnsByName.get(virtualColumnName);
}
+ @Nullable
+ public VirtualColumn getVirtualColumnByExpression(String expression)
+ {
+ return virtualColumnsByExpression.get(expression);
+ }
+
/**
* Get a signature representing the base signature plus all registered virtual columns.
*/
@@ -145,4 +153,15 @@
return builder.build();
}
+
+ /**
+ * Given a list of column names, find any corresponding {@link VirtualColumn} with the same name
+ */
+ public List<VirtualColumn> findVirtualColumns(List<String> allColumns)
+ {
+ return allColumns.stream()
+ .filter(this::isVirtualColumnDefined)
+ .map(this::getVirtualColumn)
+ .collect(Collectors.toList());
+ }
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
index 2489ed7..cf596ce 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/GroupByRules.java
@@ -117,7 +117,6 @@
if (doesMatch) {
existingAggregationsWithSameFilter.add(
Aggregation.create(
- existingAggregation.getVirtualColumns(),
existingAggregation.getAggregatorFactories().stream()
.map(factory -> ((FilteredAggregatorFactory) factory).getAggregator())
.collect(Collectors.toList()),
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 1efd99b..976d61a 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -50,6 +50,7 @@
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.UnionDataSource;
+import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
@@ -112,6 +113,8 @@
import org.apache.druid.query.topn.InvertedTopNMetricSpec;
import org.apache.druid.query.topn.NumericTopNMetricSpec;
import org.apache.druid.query.topn.TopNQueryBuilder;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType;
@@ -18926,4 +18929,76 @@
expectedResults
);
}
+
+ @Test
+ public void testCountAndAverageByConstantVirtualColumn() throws Exception
+ {
+ List<VirtualColumn> virtualColumns;
+ List<AggregatorFactory> aggs;
+ if (useDefault) {
+ aggs = ImmutableList.of(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ not(selector("v0", null, null))
+ ),
+ new LongSumAggregatorFactory("a1:sum", null, "325323", TestExprMacroTable.INSTANCE),
+ new CountAggregatorFactory("a1:count")
+ );
+ virtualColumns = ImmutableList.of(
+ expressionVirtualColumn("v0", "'10.1'", ValueType.STRING)
+ );
+ } else {
+ aggs = ImmutableList.of(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ not(selector("v0", null, null))
+ ),
+ new LongSumAggregatorFactory("a1:sum", "v1"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a1:count"),
+ not(selector("v1", null, null))
+ )
+ );
+ virtualColumns = ImmutableList.of(
+ expressionVirtualColumn("v0", "'10.1'", ValueType.STRING),
+ expressionVirtualColumn("v1", "325323", ValueType.LONG)
+ );
+
+ }
+ testQuery(
+ "SELECT dim5, COUNT(dim1), AVG(l1) FROM druid.numfoo WHERE dim1 = '10.1' AND l1 = 325323 GROUP BY dim5",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE3)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setDimFilter(
+ and(
+ selector("dim1", "10.1", null),
+ selector("l1", "325323", null)
+ )
+ )
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(VirtualColumns.create(virtualColumns))
+ .setDimensions(new DefaultDimensionSpec("dim5", "_d0", ValueType.STRING))
+ .setAggregatorSpecs(aggs)
+ .setPostAggregatorSpecs(
+ ImmutableList.of(
+ new ArithmeticPostAggregator(
+ "a1",
+ "quotient",
+ ImmutableList.of(
+ new FieldAccessPostAggregator(null, "a1:sum"),
+ new FieldAccessPostAggregator(null, "a1:count")
+ )
+ )
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{"ab", 1L, 325323L}
+ )
+ );
+ }
}