[CALCITE-4772] PushProjector should retain alias when handling RexCall (YuKong)
Close apache/calcite#2516
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/PushProjector.java b/core/src/main/java/org/apache/calcite/rel/rules/PushProjector.java
index 00f4ce3..6ac717d 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/PushProjector.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/PushProjector.java
@@ -562,11 +562,15 @@
new RexUtil.FixNullabilityShuttle(
projChild.getCluster().getRexBuilder(), typeList);
newExpr = newExpr.accept(fixer);
-
- newProjects.add(
- Pair.of(
- newExpr,
- SqlUtil.deriveAliasFromOrdinal(preserveExpOrdinal++)));
+ final String originalFieldName = findOriginalFieldName(projExpr);
+ final String newAlias;
+ if (originalFieldName != null) {
+ newAlias = originalFieldName;
+ } else {
+ newAlias = SqlUtil.deriveAliasFromOrdinal(preserveExpOrdinal);
+ }
+ newProjects.add(Pair.of(newExpr, newAlias));
+ preserveExpOrdinal++;
}
return (Project) relBuilder.push(projChild)
@@ -574,6 +578,16 @@
.build();
}
+ private @Nullable String findOriginalFieldName(RexNode originRexNode) {
+ if (origProj == null) {
+ return null;
+ }
+ int idx = origProj.getProjects().indexOf(originRexNode);
+ if (idx < 0) {
+ return null;
+ }
+ return origProj.getRowType().getFieldList().get(idx).getName();
+ }
/**
* Determines how much each input reference needs to be adjusted as a result
* of projection.
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index b87a0a5..6959648 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -3466,8 +3466,8 @@
+ "join \"customer\" as c\n"
+ " on s.\"customer_id\" = c.\"customer_id\"\n"
+ "group by s.\"customer_id\"";
- final String expected = "SELECT \"t\".\"customer_id\", SUM(\"t\".\"EXPR$0\")\n"
- + "FROM (SELECT \"customer_id\", \"store_sales\" * \"store_cost\" AS \"EXPR$0\"\n"
+ final String expected = "SELECT \"t\".\"customer_id\", SUM(\"t\".\"$f1\")\n"
+ + "FROM (SELECT \"customer_id\", \"store_sales\" * \"store_cost\" AS \"$f1\"\n"
+ "FROM \"foodmart\".\"sales_fact_1997\") AS \"t\"\n"
+ "INNER JOIN (SELECT \"customer_id\"\n"
+ "FROM \"foodmart\".\"customer\") AS \"t0\" ON \"t\".\"customer_id\" = \"t0\".\"customer_id\"\n"
@@ -3476,6 +3476,21 @@
sql(sql).optimize(rules, null).ok(expected);
}
+ @Test void testMultiplicationRetainsExplicitAlias() {
+ final String sql = "select s.\"customer_id\", s.\"store_sales\" * s.\"store_cost\" as \"total\""
+ + "from \"sales_fact_1997\" as s\n"
+ + "join \"customer\" as c\n"
+ + " on s.\"customer_id\" = c.\"customer_id\"\n";
+ final String expected = "SELECT \"t\".\"customer_id\", \"t\".\"total\"\n"
+ + "FROM (SELECT \"customer_id\", \"store_sales\" * \"store_cost\" AS \"total\"\n"
+ + "FROM \"foodmart\".\"sales_fact_1997\") AS \"t\"\n"
+ + "INNER JOIN (SELECT \"customer_id\"\n"
+ + "FROM \"foodmart\".\"customer\") AS \"t0\" ON \"t\".\"customer_id\" = \"t0\""
+ + ".\"customer_id\"";
+ RuleSet rules = RuleSets.ofList(CoreRules.PROJECT_JOIN_TRANSPOSE);
+ sql(sql).optimize(rules, null).ok(expected);
+ }
+
@Test void testRankFunctionForPrintingOfFrameBoundary() {
String query = "SELECT rank() over (order by \"hire_date\") FROM \"employee\"";
String expected = "SELECT RANK() OVER (ORDER BY \"hire_date\")\n"
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 987802e..01a9142 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -4202,7 +4202,7 @@
LogicalJoin(condition=[$4], joinType=[left])
LogicalProject(DEPTNO=[$0], NAME=[$1])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
- LogicalProject(DEPTNO=[$0], NAME=[$1], $f2=[>($0, 10)], EXPR$0=[>($0, 10)])
+ LogicalProject(DEPTNO=[$0], NAME=[$1], $f2=[>($0, 10)], $f4=[>($0, 10)])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
@@ -8763,7 +8763,7 @@
LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
LogicalProject(EXPR$1=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[full])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), *(-1, $5), $5)])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), *(-1, $5), $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
@@ -8792,7 +8792,7 @@
LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
LogicalProject(EXPR$1=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[inner])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), 11, *(-1, $5))])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), 11, *(-1, $5))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
@@ -8821,7 +8821,7 @@
LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
LogicalProject(EXPR$1=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[inner])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), *(-1, $5), $5)])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), *(-1, $5), $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
@@ -8874,7 +8874,7 @@
LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
LogicalProject(EXPR$1=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[left])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), 11, *(-1, $5))])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), 11, *(-1, $5))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
@@ -8963,7 +8963,7 @@
LogicalJoin(condition=[=($1, $0)], joinType=[left])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), *(-1, $5), $5)])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), *(-1, $5), $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
@@ -9020,7 +9020,7 @@
LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
LogicalProject(EXPR$1=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[right])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), *(-1, $5), $5)])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), *(-1, $5), $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
@@ -9051,7 +9051,7 @@
LogicalJoin(condition=[=($1, $0)], joinType=[right])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), 11, *(-1, $5))])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), 11, *(-1, $5))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
@@ -9080,7 +9080,7 @@
LogicalJoin(condition=[=($1, $0)], joinType=[right])
LogicalProject(ENAME=[$0])
LogicalTableScan(table=[[CATALOG, SALES, BONUS]])
- LogicalProject(ENAME=[$1], EXPR$0=[CASE(<($5, 11), *(-1, $5), $5)])
+ LogicalProject(ENAME=[$1], EXPR$1=[CASE(<($5, 11), *(-1, $5), $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>