VirtualColumnRegistry reuse virtual column should take account of value type (#11546)
Co-authored-by: huangqixiang.871 <huangqixiang.871@bytedance.com>
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
index 28364bd..7d80c18 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
@@ -24,6 +24,7 @@
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
@@ -38,6 +39,7 @@
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -109,7 +111,12 @@
expression = null;
} else {
// if the filter or anywhere else defined a virtual column for us, re-use it
- VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
+ final RexNode resolutionArg = Expressions.fromFieldAccess(
+ rowSignature,
+ project,
+ Iterables.getOnlyElement(aggregateCall.getArgList())
+ );
+ VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression(), resolutionArg.getType());
fieldName = vc != null ? vc.getOutputName() : null;
expression = vc != null ? null : arg.getExpression();
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
index 2ce4a84..bdfa396 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
@@ -31,6 +31,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.stream.Collectors;
/**
@@ -40,7 +41,7 @@
public class VirtualColumnRegistry
{
private final RowSignature baseRowSignature;
- private final Map<String, VirtualColumn> virtualColumnsByExpression;
+ private final Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression;
private final Map<String, VirtualColumn> virtualColumnsByName;
private final String virtualColumnPrefix;
private int virtualColumnCounter;
@@ -48,7 +49,7 @@
private VirtualColumnRegistry(
RowSignature baseRowSignature,
String virtualColumnPrefix,
- Map<String, VirtualColumn> virtualColumnsByExpression,
+ Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression,
Map<String, VirtualColumn> virtualColumnsByName
)
{
@@ -85,7 +86,8 @@
ValueType valueType
)
{
- if (!virtualColumnsByExpression.containsKey(expression.getExpression())) {
+ ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression.getExpression(), valueType);
+ if (!virtualColumnsByExpression.containsKey(expressionWrapper)) {
final String virtualColumnName = virtualColumnPrefix + virtualColumnCounter++;
final VirtualColumn virtualColumn = expression.toVirtualColumn(
virtualColumnName,
@@ -93,7 +95,7 @@
plannerContext.getExprMacroTable()
);
virtualColumnsByExpression.put(
- expression.getExpression(),
+ expressionWrapper,
virtualColumn
);
virtualColumnsByName.put(
@@ -102,7 +104,7 @@
);
}
- return virtualColumnsByExpression.get(expression.getExpression());
+ return virtualColumnsByExpression.get(expressionWrapper);
}
/**
@@ -131,9 +133,10 @@
}
@Nullable
- public VirtualColumn getVirtualColumnByExpression(String expression)
+ public VirtualColumn getVirtualColumnByExpression(String expression, RelDataType type)
{
- return virtualColumnsByExpression.get(expression);
+ ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression, Calcites.getValueTypeForRelDataType(type));
+ return virtualColumnsByExpression.get(expressionWrapper);
}
/**
@@ -164,4 +167,35 @@
.map(this::getVirtualColumn)
.collect(Collectors.toList());
}
+
+ private static class ExpressionWrapper
+ {
+ private final String expression;
+ private final ValueType valueType;
+
+ public ExpressionWrapper(String expression, ValueType valueType)
+ {
+ this.expression = expression;
+ this.valueType = valueType;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ExpressionWrapper expressionWrapper = (ExpressionWrapper) o;
+ return Objects.equals(expression, expressionWrapper.expression) && valueType == expressionWrapper.valueType;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(expression, valueType);
+ }
+ }
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index bb1dcb8..3afbbb7 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -18818,4 +18818,44 @@
Collections.emptyList()
);
}
+
+ @Test
+ public void testCommonVirtualExpressionWithDifferentValueType() throws Exception
+ {
+ testQuery(
+ "select\n"
+ + " dim1,\n"
+ + " sum(cast(0 as bigint)) as s1,\n"
+ + " sum(cast(0 as double)) as s2\n"
+ + "from druid.foo\n"
+ + "where dim1 = 'none'\n"
+ + "group by dim1\n"
+ + "limit 1",
+ ImmutableList.of(new TopNQueryBuilder()
+ .dataSource(CalciteTests.DATASOURCE1)
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .filters(selector("dim1", "none", null))
+ .granularity(Granularities.ALL)
+ .virtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "'none'",
+ ValueType.STRING
+ )
+ )
+ .dimension(
+ new DefaultDimensionSpec("v0", "d0")
+ )
+ .aggregators(
+ aggregators(
+ new LongSumAggregatorFactory("a0", null, "0", ExprMacroTable.nil()),
+ new DoubleSumAggregatorFactory("a1", null, "0", ExprMacroTable.nil())
+ ))
+ .context(QUERY_CONTEXT_DEFAULT)
+ .metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))
+ .threshold(1)
+ .build()),
+ ImmutableList.of()
+ );
+ }
}