VirtualColumnRegistry reuse virtual column should take account of value type (#11546)

Co-authored-by: huangqixiang.871 <huangqixiang.871@bytedance.com>
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
index 28364bd..7d80c18 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/AvgSqlAggregator.java
@@ -24,6 +24,7 @@
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.Project;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.SqlTypeName;
@@ -38,6 +39,7 @@
 import org.apache.druid.sql.calcite.aggregation.Aggregations;
 import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
 import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.expression.Expressions;
 import org.apache.druid.sql.calcite.planner.Calcites;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -109,7 +111,12 @@
       expression = null;
     } else {
       // if the filter or anywhere else defined a virtual column for us, re-use it
-      VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression());
+      final RexNode resolutionArg = Expressions.fromFieldAccess(
+          rowSignature,
+          project,
+          Iterables.getOnlyElement(aggregateCall.getArgList())
+      );
+      VirtualColumn vc = virtualColumnRegistry.getVirtualColumnByExpression(arg.getExpression(), resolutionArg.getType());
       fieldName = vc != null ? vc.getOutputName() : null;
       expression = vc != null ? null : arg.getExpression();
     }
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
index 2ce4a84..bdfa396 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/VirtualColumnRegistry.java
@@ -31,6 +31,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.stream.Collectors;
 
 /**
@@ -40,7 +41,7 @@
 public class VirtualColumnRegistry
 {
   private final RowSignature baseRowSignature;
-  private final Map<String, VirtualColumn> virtualColumnsByExpression;
+  private final Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression;
   private final Map<String, VirtualColumn> virtualColumnsByName;
   private final String virtualColumnPrefix;
   private int virtualColumnCounter;
@@ -48,7 +49,7 @@
   private VirtualColumnRegistry(
       RowSignature baseRowSignature,
       String virtualColumnPrefix,
-      Map<String, VirtualColumn> virtualColumnsByExpression,
+      Map<ExpressionWrapper, VirtualColumn> virtualColumnsByExpression,
       Map<String, VirtualColumn> virtualColumnsByName
   )
   {
@@ -85,7 +86,8 @@
       ValueType valueType
   )
   {
-    if (!virtualColumnsByExpression.containsKey(expression.getExpression())) {
+    ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression.getExpression(), valueType);
+    if (!virtualColumnsByExpression.containsKey(expressionWrapper)) {
       final String virtualColumnName = virtualColumnPrefix + virtualColumnCounter++;
       final VirtualColumn virtualColumn = expression.toVirtualColumn(
           virtualColumnName,
@@ -93,7 +95,7 @@
           plannerContext.getExprMacroTable()
       );
       virtualColumnsByExpression.put(
-          expression.getExpression(),
+          expressionWrapper,
           virtualColumn
       );
       virtualColumnsByName.put(
@@ -102,7 +104,7 @@
       );
     }
 
-    return virtualColumnsByExpression.get(expression.getExpression());
+    return virtualColumnsByExpression.get(expressionWrapper);
   }
 
   /**
@@ -131,9 +133,10 @@
   }
 
   @Nullable
-  public VirtualColumn getVirtualColumnByExpression(String expression)
+  public VirtualColumn getVirtualColumnByExpression(String expression, RelDataType type)
   {
-    return virtualColumnsByExpression.get(expression);
+    ExpressionWrapper expressionWrapper = new ExpressionWrapper(expression, Calcites.getValueTypeForRelDataType(type));
+    return virtualColumnsByExpression.get(expressionWrapper);
   }
 
   /**
@@ -164,4 +167,35 @@
                      .map(this::getVirtualColumn)
                      .collect(Collectors.toList());
   }
+
+  private static class ExpressionWrapper
+  {
+    private final String expression;
+    private final ValueType valueType;
+
+    public ExpressionWrapper(String expression, ValueType valueType)
+    {
+      this.expression = expression;
+      this.valueType = valueType;
+    }
+
+    @Override
+    public boolean equals(Object o)
+    {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      ExpressionWrapper expressionWrapper = (ExpressionWrapper) o;
+      return Objects.equals(expression, expressionWrapper.expression) && valueType == expressionWrapper.valueType;
+    }
+
+    @Override
+    public int hashCode()
+    {
+      return Objects.hash(expression, valueType);
+    }
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index bb1dcb8..3afbbb7 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -18818,4 +18818,44 @@
         Collections.emptyList()
     );
   }
+
+  @Test
+  public void testCommonVirtualExpressionWithDifferentValueType() throws Exception
+  {
+    testQuery(
+        "select\n"
+        + " dim1,\n"
+        + " sum(cast(0 as bigint)) as s1,\n"
+        + " sum(cast(0 as double)) as s2\n"
+        + "from druid.foo\n"
+        + "where dim1 = 'none'\n"
+        + "group by dim1\n"
+        + "limit 1",
+        ImmutableList.of(new TopNQueryBuilder()
+                             .dataSource(CalciteTests.DATASOURCE1)
+                             .intervals(querySegmentSpec(Filtration.eternity()))
+                             .filters(selector("dim1", "none", null))
+                             .granularity(Granularities.ALL)
+                             .virtualColumns(
+                                 expressionVirtualColumn(
+                                     "v0",
+                                     "'none'",
+                                     ValueType.STRING
+                                 )
+                             )
+                             .dimension(
+                                 new DefaultDimensionSpec("v0", "d0")
+                             )
+                             .aggregators(
+                                 aggregators(
+                                     new LongSumAggregatorFactory("a0", null, "0", ExprMacroTable.nil()),
+                                     new DoubleSumAggregatorFactory("a1", null, "0", ExprMacroTable.nil())
+                                 ))
+                             .context(QUERY_CONTEXT_DEFAULT)
+                             .metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))
+                             .threshold(1)
+                             .build()),
+        ImmutableList.of()
+    );
+  }
 }