[CALCITE-4616] AggregateUnionTransposeRule causes row type mismatch when some inputs have unique grouping key

Close apache/calcite#2437
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateUnionTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateUnionTransposeRule.java
index 54d89c4..7d61d00 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateUnionTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateUnionTransposeRule.java
@@ -127,30 +127,34 @@
       return;
     }
 
-    // create corresponding aggregates on top of each union child
-    final RelBuilder relBuilder = call.builder();
-    int transformCount = 0;
+    boolean hasUniqueKeyInAllInputs = true;
     final RelMetadataQuery mq = call.getMetadataQuery();
     for (RelNode input : union.getInputs()) {
       boolean alreadyUnique =
           RelMdUtil.areColumnsDefinitelyUnique(mq, input,
               aggRel.getGroupSet());
 
-      relBuilder.push(input);
       if (!alreadyUnique) {
-        ++transformCount;
-        relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet()),
-            aggRel.getAggCallList());
+        hasUniqueKeyInAllInputs = false;
+        break;
       }
     }
 
-    if (transformCount == 0) {
+    if (hasUniqueKeyInAllInputs) {
       // none of the children could benefit from the push-down,
       // so bail out (preventing the infinite loop to which most
       // planners would succumb)
       return;
     }
 
+    // create corresponding aggregates on top of each union child
+    final RelBuilder relBuilder = call.builder();
+    for (RelNode input : union.getInputs()) {
+      relBuilder.push(input);
+      relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet()),
+          aggRel.getAggCallList());
+    }
+
     // create a new union whose children are the aggregates created above
     relBuilder.union(true, union.getInputs().size());
     relBuilder.aggregate(
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index c5bfadb..df4cd85 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -5269,6 +5269,35 @@
         .check();
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-4616">[CALCITE-4616]
+   * AggregateUnionTransposeRule causes row type mismatch when some inputs have
+   * unique grouping key</a>. */
+  @Test void testAggregateUnionTransposeWithOneInputUnique() {
+    final String sql = "select deptno, SUM(t) from (\n"
+        + "select deptno, 1 as t from sales.emp e1\n"
+        + "union all\n"
+        + "select distinct deptno, 2 as t from sales.emp e2)\n"
+        + "group by deptno";
+    sql(sql)
+        .withRule(CoreRules.AGGREGATE_UNION_TRANSPOSE)
+        .check();
+  }
+
+  /** If all inputs to UNION are already unique, AggregateUnionTransposeRule is
+   * a no-op. */
+  @Test void testAggregateUnionTransposeWithAllInputsUnique() {
+    final String sql = "select deptno, SUM(t) from (\n"
+        + "select distinct deptno, 1 as t from sales.emp e1\n"
+        + "union all\n"
+        + "select distinct deptno, 2 as t from sales.emp e2)\n"
+        + "group by deptno";
+    sql(sql)
+        .withRule(CoreRules.AGGREGATE_UNION_TRANSPOSE)
+        .checkUnchanged();
+  }
+
+
   @Test void testSortJoinTranspose1() {
     final String sql = "select * from sales.emp e left join (\n"
         + "  select * from sales.dept d) d on e.deptno = d.deptno\n"
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 01a9142..ec91207 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -1031,6 +1031,72 @@
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testAggregateUnionTransposeWithAllInputsUnique">
+    <Resource name="sql">
+      <![CDATA[select deptno, SUM(t) from (
+select deptno, 1 as t from sales.emp e1
+union all
+select distinct deptno, 2 as t from sales.emp e2)
+group by deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalUnion(all=[true])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(DEPTNO=[$7], T=[1])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(DEPTNO=[$7], T=[2])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalUnion(all=[true])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(DEPTNO=[$7], T=[1])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(DEPTNO=[$7], T=[2])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testAggregateUnionTransposeWithOneInputUnique">
+    <Resource name="sql">
+      <![CDATA[select deptno, SUM(t) from (
+select deptno, 1 as t from sales.emp e1
+union all
+select distinct deptno, 2 as t from sales.emp e2)
+group by deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalUnion(all=[true])
+    LogicalProject(DEPTNO=[$7], T=[1])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(DEPTNO=[$7], T=[2])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+  LogicalUnion(all=[true])
+    LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+      LogicalProject(DEPTNO=[$7], T=[1])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
+      LogicalAggregate(group=[{0, 1}])
+        LogicalProject(DEPTNO=[$7], T=[2])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testAll">
     <Resource name="sql">
       <![CDATA[select * from emp e1