SQL support for union datasources. (#10324)

* SQL support for union datasources.

Exposed via the "UNION ALL" operator. This means that there are now two
different implementations of UNION ALL: one at the top level of a query
that works by concatenating subquery results, and one at the table level
that works by creating a UnionDataSource.

The SQL documentation is updated to discuss these two use cases and how
they behave.

Future work could unify these by building support for a native datasource
that represents the union of multiple subqueries. (Today, UnionDataSource
can only represent the union of tables, not subqueries.)

* Fixes.

* Error message for sanity check.

* Additional test fixes.

* Add some error messages.
diff --git a/docs/querying/sql.md b/docs/querying/sql.md
index 8eb4c79..1595c27 100644
--- a/docs/querying/sql.md
+++ b/docs/querying/sql.md
@@ -145,12 +145,52 @@
 
 ### UNION ALL
 
-The "UNION ALL" operator can be used to fuse multiple queries together. Their results will be concatenated, and each
-query will run separately, back to back (not in parallel). Druid does not currently support "UNION" without "ALL".
-UNION ALL must appear at the very outer layer of a SQL query (it cannot appear in a subquery or in the FROM clause).
+The "UNION ALL" operator fuses multiple queries together. Druid SQL supports the UNION ALL operator in two situations:
+top-level and table-level. Queries that use UNION ALL in any other way will not be able to execute.
 
-Note that despite the similar name, UNION ALL is not the same thing as as [union datasource](datasource.md#union).
-UNION ALL allows unioning the results of queries, whereas union datasources allow unioning tables.
+#### Top-level
+
+UNION ALL can be used at the very top outer layer of a SQL query (not in a subquery, and not in the FROM clause). In
+this case, the underlying queries will be run separately, back to back, and their results will all be returned in
+one result set.
+
+For example:
+
+```
+SELECT COUNT(*) FROM tbl WHERE my_column = 'value1'
+UNION ALL
+SELECT COUNT(*) FROM tbl WHERE my_column = 'value2'
+```
+
+When UNION ALL occurs at the top level of a query like this, the results from the unioned queries are concatenated
+together and appear one after the other.
+
+#### Table-level
+
+UNION ALL can be used to query multiple tables at the same time. In this case, it must appear in the FROM clause,
+and the subqueries that are inputs to the UNION ALL operator must be simple table SELECTs (no expressions, column
+aliasing, etc). The query will run natively using a [union datasource](datasource.md#union).
+
+The same columns must be selected from each table in the same order, and those columns must either have the same types,
+or types that can be implicitly cast to each other (such as different numeric types). For this reason, it is generally
+more robust to write your queries to select specific columns. If you use `SELECT *`, you will need to modify your
+queries if a new column is added to one of the tables but not to the others.
+
+For example:
+
+```
+SELECT col1, COUNT(*)
+FROM (
+  SELECT col1, col2, col3 FROM tbl1
+  UNION ALL
+  SELECT col1, col2, col3 FROM tbl2
+)
+GROUP BY col1
+```
+
+When UNION ALL occurs at the table level, the rows from the unioned tables are not guaranteed to be processed in
+any particular order. They may be processed in an interleaved fashion. If you need a particular result ordering,
+use [ORDER BY](#order-by).
 
 ### EXPLAIN PLAN
 
@@ -754,7 +794,6 @@
 Additionally, some Druid native query features are not supported by the SQL language. Some unsupported Druid features
 include:
 
-- [Union datasources](datasource.html#union).
 - [Inline datasources](datasource.html#inline).
 - [Spatial filters](../development/geo.html).
 - [Query cancellation](querying.html#query-cancellation).
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
index 6faf219..5500e50 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
@@ -340,7 +340,7 @@
     // ideally this would involve JoinableFactory.isDirectlyJoinable to check that the global datasources
     // are in fact possibly joinable, but for now isGlobal is coupled to joinability
     return !(DruidRels.isScanOrMapping(right, false)
-             && DruidRels.dataSourceIfLeafRel(right).filter(DataSource::isGlobal).isPresent());
+             && DruidRels.druidTableIfLeafRel(right).filter(table -> table.getDataSource().isGlobal()).isPresent());
   }
 
   /**
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java
index 2ca30d8..cef9cb4 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidRels.java
@@ -19,7 +19,6 @@
 
 package org.apache.druid.sql.calcite.rel;
 
-import org.apache.druid.query.DataSource;
 import org.apache.druid.sql.calcite.table.DruidTable;
 
 import java.util.Optional;
@@ -29,10 +28,10 @@
   /**
    * Returns the DataSource involved in a leaf query of class {@link DruidQueryRel}.
    */
-  public static Optional<DataSource> dataSourceIfLeafRel(final DruidRel<?> druidRel)
+  public static Optional<DruidTable> druidTableIfLeafRel(final DruidRel<?> druidRel)
   {
     if (druidRel instanceof DruidQueryRel) {
-      return Optional.of(druidRel.getTable().unwrap(DruidTable.class).getDataSource());
+      return Optional.of(druidRel.getTable().unwrap(DruidTable.class));
     } else {
       return Optional.empty();
     }
@@ -42,12 +41,13 @@
    * Check if a druidRel is a simple table scan, or a projection that merely remaps columns without transforming them.
    * Like {@link #isScanOrProject} but more restrictive: only remappings are allowed.
    *
-   * @param druidRel  the rel to check
-   * @param canBeJoin consider a 'join' that doesn't do anything fancy to be a scan-or-mapping too.
+   * @param druidRel         the rel to check
+   * @param canBeJoinOrUnion consider a {@link DruidJoinQueryRel} or {@link DruidUnionDataSourceRel} as possible
+   *                         scans-and-mappings too.
    */
-  public static boolean isScanOrMapping(final DruidRel<?> druidRel, final boolean canBeJoin)
+  public static boolean isScanOrMapping(final DruidRel<?> druidRel, final boolean canBeJoinOrUnion)
   {
-    if (isScanOrProject(druidRel, canBeJoin)) {
+    if (isScanOrProject(druidRel, canBeJoinOrUnion)) {
       // Like isScanOrProject, but don't allow transforming projections.
       final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery();
       return partialQuery.getSelectProject() == null || partialQuery.getSelectProject().isMapping();
@@ -59,12 +59,14 @@
   /**
    * Check if a druidRel is a simple table scan or a scan + projection.
    *
-   * @param druidRel  the rel to check
-   * @param canBeJoin consider a 'join' that doesn't do anything fancy to be a scan-or-mapping too.
+   * @param druidRel         the rel to check
+   * @param canBeJoinOrUnion consider a {@link DruidJoinQueryRel} or {@link DruidUnionDataSourceRel} as possible
+   *                         scans-and-mappings too.
    */
-  private static boolean isScanOrProject(final DruidRel<?> druidRel, final boolean canBeJoin)
+  private static boolean isScanOrProject(final DruidRel<?> druidRel, final boolean canBeJoinOrUnion)
   {
-    if (druidRel instanceof DruidQueryRel || (canBeJoin && druidRel instanceof DruidJoinQueryRel)) {
+    if (druidRel instanceof DruidQueryRel || (canBeJoinOrUnion && (druidRel instanceof DruidJoinQueryRel
+                                                                   || druidRel instanceof DruidUnionDataSourceRel))) {
       final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery();
       final PartialDruidQuery.Stage stage = partialQuery.stage();
       return (stage == PartialDruidQuery.Stage.SCAN || stage == PartialDruidQuery.Stage.SELECT_PROJECT)
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java
new file mode 100644
index 0000000..9823946
--- /dev/null
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionDataSourceRel.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rel;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptCost;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelWriter;
+import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.query.DataSource;
+import org.apache.druid.query.TableDataSource;
+import org.apache.druid.query.UnionDataSource;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.sql.calcite.table.RowSignatures;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Represents a query on top of a {@link UnionDataSource}. This is used to represent a "UNION ALL" of regular table
+ * datasources.
+ *
+ * See {@link DruidUnionRel} for a version that can union any set of queries together (not just regular tables),
+ * but also must be the outermost rel of a query plan. In the future we expect that {@link UnionDataSource} will gain
+ * the ability to union query datasources together, and then this class could replace {@link DruidUnionRel}.
+ */
+public class DruidUnionDataSourceRel extends DruidRel<DruidUnionDataSourceRel>
+{
+  private static final TableDataSource DUMMY_DATA_SOURCE = new TableDataSource("__union__");
+
+  private final Union unionRel;
+  private final List<String> unionColumnNames;
+  private final PartialDruidQuery partialQuery;
+
+  private DruidUnionDataSourceRel(
+      final RelOptCluster cluster,
+      final RelTraitSet traitSet,
+      final Union unionRel,
+      final List<String> unionColumnNames,
+      final PartialDruidQuery partialQuery,
+      final QueryMaker queryMaker
+  )
+  {
+    super(cluster, traitSet, queryMaker);
+    this.unionRel = unionRel;
+    this.unionColumnNames = unionColumnNames;
+    this.partialQuery = partialQuery;
+  }
+
+  public static DruidUnionDataSourceRel create(
+      final Union unionRel,
+      final List<String> unionColumnNames,
+      final QueryMaker queryMaker
+  )
+  {
+    return new DruidUnionDataSourceRel(
+        unionRel.getCluster(),
+        unionRel.getTraitSet(),
+        unionRel,
+        unionColumnNames,
+        PartialDruidQuery.create(unionRel),
+        queryMaker
+    );
+  }
+
+  public List<String> getUnionColumnNames()
+  {
+    return unionColumnNames;
+  }
+
+  @Override
+  public PartialDruidQuery getPartialDruidQuery()
+  {
+    return partialQuery;
+  }
+
+  @Override
+  public DruidUnionDataSourceRel withPartialQuery(final PartialDruidQuery newQueryBuilder)
+  {
+    return new DruidUnionDataSourceRel(
+        getCluster(),
+        getTraitSet().plusAll(newQueryBuilder.getRelTraits()),
+        unionRel,
+        unionColumnNames,
+        newQueryBuilder,
+        getQueryMaker()
+    );
+  }
+
+  @Override
+  public Sequence<Object[]> runQuery()
+  {
+    // runQuery doesn't need to finalize aggregations, because the fact that runQuery is happening suggests this
+    // is the outermost query and it will actually get run as a native query. Druid's native query layer will
+    // finalize aggregations for the outermost query even if we don't explicitly ask it to.
+
+    return getQueryMaker().runQuery(toDruidQuery(false));
+  }
+
+  @Override
+  public DruidQuery toDruidQuery(final boolean finalizeAggregations)
+  {
+    final List<TableDataSource> dataSources = new ArrayList<>();
+    RowSignature signature = null;
+
+    for (final RelNode relNode : unionRel.getInputs()) {
+      final DruidRel<?> druidRel = (DruidRel<?>) relNode;
+      if (!DruidRels.isScanOrMapping(druidRel, false)) {
+        throw new CannotBuildQueryException(druidRel);
+      }
+
+      final DruidQuery query = druidRel.toDruidQuery(false);
+      final DataSource dataSource = query.getDataSource();
+      if (!(dataSource instanceof TableDataSource)) {
+        throw new CannotBuildQueryException(druidRel);
+      }
+
+      if (signature == null) {
+        signature = query.getOutputRowSignature();
+      }
+
+      if (signature.getColumnNames().equals(query.getOutputRowSignature().getColumnNames())) {
+        dataSources.add((TableDataSource) dataSource);
+      } else {
+        throw new CannotBuildQueryException(druidRel);
+      }
+    }
+
+    if (signature == null) {
+      // No inputs.
+      throw new CannotBuildQueryException(unionRel);
+    }
+
+    // Sanity check: the columns we think we're building off must equal the "unionColumnNames" registered at
+    // creation time.
+    if (!signature.getColumnNames().equals(unionColumnNames)) {
+      throw new CannotBuildQueryException(unionRel);
+    }
+
+    return partialQuery.build(
+        new UnionDataSource(dataSources),
+        signature,
+        getPlannerContext(),
+        getCluster().getRexBuilder(),
+        finalizeAggregations
+    );
+  }
+
+  @Override
+  public DruidQuery toDruidQueryForExplaining()
+  {
+    return partialQuery.build(
+        DUMMY_DATA_SOURCE,
+        RowSignatures.fromRelDataType(
+            unionRel.getRowType().getFieldNames(),
+            unionRel.getRowType()
+        ),
+        getPlannerContext(),
+        getCluster().getRexBuilder(),
+        false
+    );
+  }
+
+  @Override
+  public DruidUnionDataSourceRel asDruidConvention()
+  {
+    return new DruidUnionDataSourceRel(
+        getCluster(),
+        getTraitSet().replace(DruidConvention.instance()),
+        (Union) unionRel.copy(
+            unionRel.getTraitSet(),
+            unionRel.getInputs()
+                    .stream()
+                    .map(input -> RelOptRule.convert(input, DruidConvention.instance()))
+                    .collect(Collectors.toList())
+        ),
+        unionColumnNames,
+        partialQuery,
+        getQueryMaker()
+    );
+  }
+
+  @Override
+  public List<RelNode> getInputs()
+  {
+    return unionRel.getInputs();
+  }
+
+  @Override
+  public void replaceInput(int ordinalInParent, RelNode p)
+  {
+    unionRel.replaceInput(ordinalInParent, p);
+  }
+
+  @Override
+  public RelNode copy(final RelTraitSet traitSet, final List<RelNode> inputs)
+  {
+    return new DruidUnionDataSourceRel(
+        getCluster(),
+        traitSet,
+        (Union) unionRel.copy(unionRel.getTraitSet(), inputs),
+        unionColumnNames,
+        partialQuery,
+        getQueryMaker()
+    );
+  }
+
+  @Override
+  public Set<String> getDataSourceNames()
+  {
+    final Set<String> retVal = new HashSet<>();
+
+    for (final RelNode input : unionRel.getInputs()) {
+      retVal.addAll(((DruidRel<?>) input).getDataSourceNames());
+    }
+
+    return retVal;
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw)
+  {
+    final String queryString;
+    final DruidQuery druidQuery = toDruidQueryForExplaining();
+
+    try {
+      queryString = getQueryMaker().getJsonMapper().writeValueAsString(druidQuery.getQuery());
+    }
+    catch (JsonProcessingException e) {
+      throw new RuntimeException(e);
+    }
+
+    for (int i = 0; i < unionRel.getInputs().size(); i++) {
+      pw.input(StringUtils.format("input#%d", i), unionRel.getInputs().get(i));
+    }
+
+    return pw.item("query", queryString)
+             .item("signature", druidQuery.getOutputRowSignature());
+  }
+
+  @Override
+  protected RelDataType deriveRowType()
+  {
+    return partialQuery.getRowType();
+  }
+
+  @Override
+  public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadataQuery mq)
+  {
+    return planner.getCostFactory().makeZeroCost();
+  }
+}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionRel.java
index fb71f83..a83869e 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionRel.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidUnionRel.java
@@ -33,6 +33,7 @@
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.query.UnionDataSource;
 
 import javax.annotation.Nullable;
 import java.util.ArrayList;
@@ -40,6 +41,16 @@
 import java.util.Set;
 import java.util.stream.Collectors;
 
+/**
+ * Represents a "UNION ALL" of various input {@link DruidRel}. Note that this rel doesn't represent a real native query,
+ * but rather, it represents the concatenation of a series of native queries in the SQL layer. Therefore,
+ * {@link #getPartialDruidQuery()} returns null, and this rel cannot be built on top of. It must be the outer rel in a
+ * query plan.
+ *
+ * See {@link DruidUnionDataSourceRel} for a version that does a regular Druid query using a {@link UnionDataSource}.
+ * In the future we expect that {@link UnionDataSource} will gain the ability to union query datasources together, and
+ * then this rel could be replaced by {@link DruidUnionDataSourceRel}.
+ */
 public class DruidUnionRel extends DruidRel<DruidUnionRel>
 {
   private final RelDataType rowType;
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
index cb7fc05..b9f8f34 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
@@ -75,7 +75,15 @@
   public boolean matches(RelOptRuleCall call)
   {
     final Join join = call.rel(0);
-    return canHandleCondition(join.getCondition(), join.getLeft().getRowType());
+    final DruidRel<?> left = call.rel(1);
+    final DruidRel<?> right = call.rel(2);
+
+    // 1) Can handle the join condition as a native join.
+    // 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
+    // 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
+    return canHandleCondition(join.getCondition(), join.getLeft().getRowType())
+           && left.getPartialDruidQuery() != null
+           && right.getPartialDruidQuery() != null;
   }
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java
index 75f5038..6b449ed 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidRules.java
@@ -90,6 +90,7 @@
         DruidOuterQueryRule.PROJECT_AGGREGATE,
         DruidOuterQueryRule.AGGREGATE_SORT_PROJECT,
         DruidUnionRule.instance(),
+        DruidUnionDataSourceRule.instance(),
         DruidSortUnionRule.instance(),
         DruidJoinRule.instance()
     );
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java
index c350fef..0ef41fb 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidSortUnionRule.java
@@ -27,6 +27,9 @@
 
 import java.util.Collections;
 
+/**
+ * Rule that pushes LIMIT and OFFSET into a {@link DruidUnionRel}.
+ */
 public class DruidSortUnionRule extends RelOptRule
 {
   private static final DruidSortUnionRule INSTANCE = new DruidSortUnionRule();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
new file mode 100644
index 0000000..49e77ce
--- /dev/null
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Union;
+import org.apache.calcite.util.mapping.Mappings;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.query.TableDataSource;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.sql.calcite.rel.DruidQueryRel;
+import org.apache.druid.sql.calcite.rel.DruidRel;
+import org.apache.druid.sql.calcite.rel.DruidRels;
+import org.apache.druid.sql.calcite.rel.DruidUnionDataSourceRel;
+import org.apache.druid.sql.calcite.rel.PartialDruidQuery;
+import org.apache.druid.sql.calcite.table.DruidTable;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Creates a {@link DruidUnionDataSourceRel} from various {@link DruidQueryRel} inputs that represent simple
+ * table scans.
+ */
+public class DruidUnionDataSourceRule extends RelOptRule
+{
+  private static final DruidUnionDataSourceRule INSTANCE = new DruidUnionDataSourceRule();
+
+  private DruidUnionDataSourceRule()
+  {
+    super(
+        operand(
+            Union.class,
+            operand(DruidRel.class, none()),
+            operand(DruidQueryRel.class, none())
+        )
+    );
+  }
+
+  public static DruidUnionDataSourceRule instance()
+  {
+    return INSTANCE;
+  }
+
+  @Override
+  public boolean matches(RelOptRuleCall call)
+  {
+    final Union unionRel = call.rel(0);
+    final DruidRel<?> firstDruidRel = call.rel(1);
+    final DruidQueryRel secondDruidRel = call.rel(2);
+
+    // Can only do UNION ALL of inputs that have compatible schemas (or schema mappings).
+    return unionRel.all && isUnionCompatible(firstDruidRel, secondDruidRel);
+  }
+
+  @Override
+  public void onMatch(final RelOptRuleCall call)
+  {
+    final Union unionRel = call.rel(0);
+    final DruidRel<?> firstDruidRel = call.rel(1);
+    final DruidQueryRel secondDruidRel = call.rel(2);
+
+    if (firstDruidRel instanceof DruidUnionDataSourceRel) {
+      // Unwrap and flatten the inputs to the Union.
+      final RelNode newUnionRel = call.builder()
+                                      .pushAll(firstDruidRel.getInputs())
+                                      .push(secondDruidRel)
+                                      .union(true, firstDruidRel.getInputs().size() + 1)
+                                      .build();
+
+      call.transformTo(
+          DruidUnionDataSourceRel.create(
+              (Union) newUnionRel,
+              getColumnNamesIfTableOrUnion(firstDruidRel).get(),
+              firstDruidRel.getQueryMaker()
+          )
+      );
+    } else {
+      // Sanity check.
+      if (!(firstDruidRel instanceof DruidQueryRel)) {
+        throw new ISE("Expected first rel to be a DruidQueryRel, but it was %s", firstDruidRel.getClass().getName());
+      }
+
+      call.transformTo(
+          DruidUnionDataSourceRel.create(
+              unionRel,
+              getColumnNamesIfTableOrUnion(firstDruidRel).get(),
+              firstDruidRel.getQueryMaker()
+          )
+      );
+    }
+  }
+
+  private static boolean isUnionCompatible(final DruidRel<?> first, final DruidRel<?> second)
+  {
+    final Optional<List<String>> columnNames = getColumnNamesIfTableOrUnion(first);
+    return columnNames.isPresent() && columnNames.equals(getColumnNamesIfTableOrUnion(second));
+  }
+
+  static Optional<List<String>> getColumnNamesIfTableOrUnion(final DruidRel<?> druidRel)
+  {
+    final PartialDruidQuery partialQuery = druidRel.getPartialDruidQuery();
+
+    final Optional<DruidTable> druidTable =
+        DruidRels.druidTableIfLeafRel(druidRel)
+                 .filter(table -> table.getDataSource() instanceof TableDataSource);
+
+    if (druidTable.isPresent() && DruidRels.isScanOrMapping(druidRel, false)) {
+      // This rel is a table scan or mapping.
+
+      if (partialQuery.stage() == PartialDruidQuery.Stage.SCAN) {
+        return Optional.of(druidTable.get().getRowSignature().getColumnNames());
+      } else {
+        // Sanity check. Expected to be true due to the "scan or mapping" check.
+        if (partialQuery.stage() != PartialDruidQuery.Stage.SELECT_PROJECT) {
+          throw new ISE("Expected stage %s but got %s", PartialDruidQuery.Stage.SELECT_PROJECT, partialQuery.stage());
+        }
+
+        // Apply the mapping (with additional sanity checks).
+        final RowSignature tableSignature = druidTable.get().getRowSignature();
+        final Mappings.TargetMapping mapping = partialQuery.getSelectProject().getMapping();
+
+        if (mapping.getSourceCount() != tableSignature.size()) {
+          throw new ISE(
+              "Expected mapping with %d columns but got %d columns",
+              tableSignature.size(),
+              mapping.getSourceCount()
+          );
+        }
+
+        final List<String> retVal = new ArrayList<>();
+
+        for (int i = 0; i < mapping.getTargetCount(); i++) {
+          final int sourceField = mapping.getSourceOpt(i);
+          retVal.add(tableSignature.getColumnName(sourceField));
+        }
+
+        return Optional.of(retVal);
+      }
+    } else if (!druidTable.isPresent() && druidRel instanceof DruidUnionDataSourceRel) {
+      // This rel is a union itself.
+
+      return Optional.of(((DruidUnionDataSourceRel) druidRel).getUnionColumnNames());
+    } else {
+      return Optional.empty();
+    }
+  }
+}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
index 1aefb98..e97ed2b 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
@@ -28,13 +28,22 @@
 
 import java.util.List;
 
+/**
+ * Rule that creates a {@link DruidUnionRel} from some {@link DruidRel} inputs.
+ */
 public class DruidUnionRule extends RelOptRule
 {
   private static final DruidUnionRule INSTANCE = new DruidUnionRule();
 
   private DruidUnionRule()
   {
-    super(operand(Union.class, unordered(operand(DruidRel.class, any()))));
+    super(
+        operand(
+            Union.class,
+            operand(DruidRel.class, none()),
+            operand(DruidRel.class, none())
+        )
+    );
   }
 
   public static DruidUnionRule instance()
@@ -43,20 +52,29 @@
   }
 
   @Override
+  public boolean matches(RelOptRuleCall call)
+  {
+    // Make DruidUnionRule and DruidUnionDataSourceRule mutually exclusive.
+    return !DruidUnionDataSourceRule.instance().matches(call);
+  }
+
+  @Override
   public void onMatch(final RelOptRuleCall call)
   {
     final Union unionRel = call.rel(0);
-    final DruidRel someDruidRel = call.rel(1);
+    final DruidRel<?> someDruidRel = call.rel(1);
     final List<RelNode> inputs = unionRel.getInputs();
 
+    // Can only do UNION ALL.
     if (unionRel.all) {
-      // Can only do UNION ALL.
-      call.transformTo(DruidUnionRel.create(
-          someDruidRel.getQueryMaker(),
-          unionRel.getRowType(),
-          inputs,
-          -1
-      ));
+      call.transformTo(
+          DruidUnionRel.create(
+              someDruidRel.getQueryMaker(),
+              unionRel.getRowType(),
+              inputs,
+              -1
+          )
+      );
     }
   }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 2658661..d1e9c70 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -47,6 +47,7 @@
 import org.apache.druid.query.QueryException;
 import org.apache.druid.query.ResourceLimitExceededException;
 import org.apache.druid.query.TableDataSource;
+import org.apache.druid.query.UnionDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
@@ -3588,7 +3589,7 @@
   }
 
   @Test
-  public void testUnionAll() throws Exception
+  public void testUnionAllQueries() throws Exception
   {
     testQuery(
         "SELECT COUNT(*) FROM foo UNION ALL SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo",
@@ -3620,7 +3621,7 @@
   }
 
   @Test
-  public void testUnionAllWithLimit() throws Exception
+  public void testUnionAllQueriesWithLimit() throws Exception
   {
     testQuery(
         "SELECT * FROM ("
@@ -3647,6 +3648,431 @@
   }
 
   @Test
+  public void testUnionAllDifferentTablesWithMapping() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE3)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 2.0, 2L},
+            new Object[]{"1", "a", 8.0, 2L}
+        )
+    );
+  }
+
+  @Test
+  public void testJoinUnionAllDifferentTablesWithMapping() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM numfoo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE3)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 2.0, 2L},
+            new Object[]{"1", "a", 8.0, 2L}
+        )
+    );
+  }
+
+  @Test
+  public void testUnionAllTablesColumnCountMismatch() throws Exception
+  {
+    expectedException.expect(ValidationException.class);
+    expectedException.expectMessage("Column count mismatch in UNION ALL");
+
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM numfoo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(),
+        ImmutableList.of()
+    );
+  }
+
+  @Test
+  public void testUnionAllTablesColumnTypeMismatchFloatLong() throws Exception
+  {
+    // "m1" has a different type in foo and foo2 (float vs long), but this query is OK anyway because they can both
+    // be implicitly cast to double.
+
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo2 UNION ALL SELECT dim1, dim2, m1 FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'en'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE2),
+                                    new TableDataSource(CalciteTests.DATASOURCE1)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("en", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 1.0, 1L},
+            new Object[]{"1", "a", 4.0, 1L},
+            new Object[]{"druid", "en", 1.0, 1L}
+        )
+    );
+  }
+
+  @Test
+  public void testUnionAllTablesColumnTypeMismatchStringLong()
+  {
+    // "dim3" has a different type in foo and foo2 (string vs long), which requires a casting subquery, so this
+    // query cannot be planned.
+
+    assertQueryIsUnplannable(
+        "SELECT\n"
+        + "dim3, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim3, dim2, m1 FROM foo2 UNION ALL SELECT dim3, dim2, m1 FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'en'\n"
+        + "GROUP BY 1, 2"
+    );
+  }
+
+  @Test
+  public void testUnionAllTablesWhenMappingIsRequired()
+  {
+    // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery.
+
+    assertQueryIsUnplannable(
+        "SELECT\n"
+        + "c, COUNT(*)\n"
+        + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT dim2 AS c, m1 FROM numfoo)\n"
+        + "WHERE c = 'a' OR c = 'def'\n"
+        + "GROUP BY 1"
+    );
+  }
+
+  @Test
+  public void testUnionAllTablesWhenCastAndMappingIsRequired()
+  {
+    // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery.
+
+    assertQueryIsUnplannable(
+        "SELECT\n"
+        + "c, COUNT(*)\n"
+        + "FROM (SELECT dim1 AS c, m1 FROM foo UNION ALL SELECT cnt AS c, m1 FROM numfoo)\n"
+        + "WHERE c = 'a' OR c = 'def'\n"
+        + "GROUP BY 1"
+    );
+  }
+
+  @Test
+  public void testUnionAllSameTableTwice() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 2.0, 2L},
+            new Object[]{"1", "a", 8.0, 2L}
+        )
+    );
+  }
+
+  @Test
+  public void testUnionAllSameTableTwiceWithSameMapping() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 2.0, 2L},
+            new Object[]{"1", "a", 8.0, 2L}
+        )
+    );
+  }
+
+  @Test
+  public void testUnionAllSameTableTwiceWithDifferentMapping()
+  {
+    // Cannot plan this UNION ALL operation, because the column swap would require generating a subquery.
+
+    assertQueryIsUnplannable(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim2, dim1, m1 FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2"
+    );
+  }
+
+  @Test
+  public void testUnionAllSameTableThreeTimes() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 3.0, 3L},
+            new Object[]{"1", "a", 12.0, 3L}
+        )
+    );
+  }
+
+  @Test
+  public void testUnionAllThreeTablesColumnCountMismatch1() throws Exception
+  {
+    expectedException.expect(ValidationException.class);
+    expectedException.expectMessage("Column count mismatch in UNION ALL");
+
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(),
+        ImmutableList.of()
+    );
+  }
+
+  @Test
+  public void testUnionAllThreeTablesColumnCountMismatch2() throws Exception
+  {
+    expectedException.expect(ValidationException.class);
+    expectedException.expectMessage("Column count mismatch in UNION ALL");
+
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM numfoo UNION ALL SELECT * FROM foo UNION ALL SELECT * from foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(),
+        ImmutableList.of()
+    );
+  }
+
+  @Test
+  public void testUnionAllThreeTablesColumnCountMismatch3() throws Exception
+  {
+    expectedException.expect(ValidationException.class);
+    expectedException.expectMessage("Column count mismatch in UNION ALL");
+
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT * FROM foo UNION ALL SELECT * FROM foo UNION ALL SELECT * from numfoo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(),
+        ImmutableList.of()
+    );
+  }
+
+  @Test
+  public void testUnionAllSameTableThreeTimesWithSameMapping() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "dim1, dim2, SUM(m1), COUNT(*)\n"
+        + "FROM (SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo UNION ALL SELECT dim1, dim2, m1 FROM foo)\n"
+        + "WHERE dim2 = 'a' OR dim2 = 'def'\n"
+        + "GROUP BY 1, 2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new UnionDataSource(
+                                ImmutableList.of(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new TableDataSource(CalciteTests.DATASOURCE1)
+                                )
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim2", ImmutableList.of("def", "a"), null))
+                        .setDimensions(
+                            new DefaultDimensionSpec("dim1", "d0"),
+                            new DefaultDimensionSpec("dim2", "d1")
+                        )
+                        .setAggregatorSpecs(
+                            aggregators(
+                                new DoubleSumAggregatorFactory("a0", "m1"),
+                                new CountAggregatorFactory("a1")
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"", "a", 3.0, 3L},
+            new Object[]{"1", "a", 12.0, 3L}
+        )
+    );
+  }
+
+  @Test
   public void testPruneDeadAggregators() throws Exception
   {
     // Test for ProjectAggregatePruneUnusedCallRule.
@@ -7083,6 +7509,58 @@
   }
 
   @Test
+  public void testExactCountDistinctUsingSubqueryOnUnionAllTables() throws Exception
+  {
+    testQuery(
+        "SELECT\n"
+        + "  SUM(cnt),\n"
+        + "  COUNT(*)\n"
+        + "FROM (\n"
+        + "  SELECT dim2, SUM(cnt) AS cnt\n"
+        + "  FROM (SELECT * FROM druid.foo UNION ALL SELECT * FROM druid.foo)\n"
+        + "  GROUP BY dim2\n"
+        + ")",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            new QueryDataSource(
+                                GroupByQuery.builder()
+                                            .setDataSource(
+                                                new UnionDataSource(
+                                                    ImmutableList.of(
+                                                        new TableDataSource(CalciteTests.DATASOURCE1),
+                                                        new TableDataSource(CalciteTests.DATASOURCE1)
+                                                    )
+                                                )
+                                            )
+                                            .setInterval(querySegmentSpec(Filtration.eternity()))
+                                            .setGranularity(Granularities.ALL)
+                                            .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0")))
+                                            .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
+                                            .setContext(QUERY_CONTEXT_DEFAULT)
+                                            .build()
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setAggregatorSpecs(aggregators(
+                            new LongSumAggregatorFactory("_a0", "a0"),
+                            new CountAggregatorFactory("_a1")
+                        ))
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        NullHandling.replaceWithDefault() ?
+        ImmutableList.of(
+            new Object[]{12L, 3L}
+        ) :
+        ImmutableList.of(
+            new Object[]{12L, 4L}
+        )
+    );
+  }
+
+  @Test
   public void testMinMaxAvgDailyCountWithLimit() throws Exception
   {
     // Cannot vectorize due to virtual columns.
@@ -9036,6 +9514,52 @@
 
   @Test
   @Parameters(source = QueryContextForJoinProvider.class)
+  public void testJoinUnionTablesOnLookup(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize JOIN operator.
+    cannotVectorize();
+
+    testQuery(
+        "SELECT lookyloo.v, COUNT(*)\n"
+        + "FROM\n"
+        + "  (SELECT dim2 FROM foo UNION ALL SELECT dim2 FROM numfoo) u\n"
+        + "  LEFT JOIN lookup.lookyloo ON u.dim2 = lookyloo.k\n"
+        + "WHERE lookyloo.v <> 'xa'\n"
+        + "GROUP BY lookyloo.v",
+        queryContext,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            join(
+                                new UnionDataSource(
+                                    ImmutableList.of(
+                                        new TableDataSource(CalciteTests.DATASOURCE1),
+                                        new TableDataSource(CalciteTests.DATASOURCE3)
+                                    )
+                                ),
+                                new LookupDataSource("lookyloo"),
+                                "j0.",
+                                equalsCondition(DruidExpression.fromColumn("dim2"), DruidExpression.fromColumn("j0.k")),
+                                JoinType.LEFT
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setDimFilter(not(selector("j0.v", "xa", null)))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(dimensions(new DefaultDimensionSpec("j0.v", "d0")))
+                        .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
+                        .setContext(queryContext)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{NULL_STRING, 6L},
+            new Object[]{"xabc", 2L}
+        )
+    );
+  }
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
   public void testFilterAndGroupByLookupUsingJoinOperator(Map<String, Object> queryContext) throws Exception
   {
     // Cannot vectorize JOIN operator.
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelsTest.java
index 7e0ed30..17b91a5 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelsTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelsTest.java
@@ -19,21 +19,28 @@
 
 package org.apache.druid.sql.calcite.rel;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.rel.core.Filter;
 import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
+import org.apache.druid.sql.calcite.table.DruidTable;
 import org.easymock.EasyMock;
 import org.junit.Assert;
 import org.junit.Test;
 
 import javax.annotation.Nullable;
+import java.util.List;
+import java.util.function.Consumer;
 
 public class DruidRelsTest
 {
   @Test
   public void test_isScanOrMapping_scan()
   {
-    final DruidRel<?> rel = mockDruidRel(DruidQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null);
+    final DruidRel<?> rel = mockDruidRel(DruidQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null, null);
     Assert.assertTrue(DruidRels.isScanOrMapping(rel, true));
     Assert.assertTrue(DruidRels.isScanOrMapping(rel, false));
     EasyMock.verify(rel, rel.getPartialDruidQuery());
@@ -42,7 +49,16 @@
   @Test
   public void test_isScanOrMapping_scanJoin()
   {
-    final DruidRel<?> rel = mockDruidRel(DruidJoinQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null);
+    final DruidRel<?> rel = mockDruidRel(DruidJoinQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null, null);
+    Assert.assertTrue(DruidRels.isScanOrMapping(rel, true));
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
+    EasyMock.verify(rel, rel.getPartialDruidQuery());
+  }
+
+  @Test
+  public void test_isScanOrMapping_scanUnion()
+  {
+    final DruidRel<?> rel = mockDruidRel(DruidUnionDataSourceRel.class, PartialDruidQuery.Stage.SCAN, null, null, null);
     Assert.assertTrue(DruidRels.isScanOrMapping(rel, true));
     Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
     EasyMock.verify(rel, rel.getPartialDruidQuery());
@@ -51,7 +67,7 @@
   @Test
   public void test_isScanOrMapping_scanQuery()
   {
-    final DruidRel<?> rel = mockDruidRel(DruidOuterQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null);
+    final DruidRel<?> rel = mockDruidRel(DruidOuterQueryRel.class, PartialDruidQuery.Stage.SCAN, null, null, null);
     Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
     Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
     EasyMock.verify(rel, rel.getPartialDruidQuery());
@@ -60,10 +76,11 @@
   @Test
   public void test_isScanOrMapping_mapping()
   {
-    final Project project = mockProject(true);
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
     final DruidRel<?> rel = mockDruidRel(
         DruidQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         null
     );
@@ -76,10 +93,11 @@
   @Test
   public void test_isScanOrMapping_mappingJoin()
   {
-    final Project project = mockProject(true);
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
     final DruidRel<?> rel = mockDruidRel(
         DruidJoinQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         null
     );
@@ -90,12 +108,47 @@
   }
 
   @Test
+  public void test_isScanOrMapping_mappingUnion()
+  {
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
+    final DruidRel<?> rel = mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        project,
+        null
+    );
+    Assert.assertTrue(DruidRels.isScanOrMapping(rel, true));
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
+
+    EasyMock.verify(rel, rel.getPartialDruidQuery(), project);
+  }
+
+  @Test
+  public void test_isScanOrMapping_mappingQuery()
+  {
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
+    final DruidRel<?> rel = mockDruidRel(
+        DruidOuterQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        project,
+        null
+    );
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
+
+    EasyMock.verify(rel, rel.getPartialDruidQuery(), project);
+  }
+
+  @Test
   public void test_isScanOrMapping_nonMapping()
   {
-    final Project project = mockProject(false);
+    final Project project = mockNonMappingProject();
     final DruidRel<?> rel = mockDruidRel(
         DruidQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         null
     );
@@ -108,10 +161,28 @@
   @Test
   public void test_isScanOrMapping_nonMappingJoin()
   {
-    final Project project = mockProject(false);
+    final Project project = mockNonMappingProject();
     final DruidRel<?> rel = mockDruidRel(
         DruidJoinQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        project,
+        null
+    );
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
+
+    EasyMock.verify(rel, rel.getPartialDruidQuery(), project);
+  }
+
+  @Test
+  public void test_isScanOrMapping_nonMappingUnion()
+  {
+    final Project project = mockNonMappingProject();
+    final DruidRel<?> rel = mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         null
     );
@@ -124,10 +195,11 @@
   @Test
   public void test_isScanOrMapping_filterThenProject()
   {
-    final Project project = mockProject(true);
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
     final DruidRel<?> rel = mockDruidRel(
         DruidQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         mockFilter()
     );
@@ -140,10 +212,28 @@
   @Test
   public void test_isScanOrMapping_filterThenProjectJoin()
   {
-    final Project project = mockProject(true);
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
     final DruidRel<?> rel = mockDruidRel(
         DruidJoinQueryRel.class,
         PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        project,
+        mockFilter()
+    );
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
+    Assert.assertFalse(DruidRels.isScanOrMapping(rel, false));
+
+    EasyMock.verify(rel, rel.getPartialDruidQuery(), project);
+  }
+
+  @Test
+  public void test_isScanOrMapping_filterThenProjectUnion()
+  {
+    final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
+    final DruidRel<?> rel = mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
         project,
         mockFilter()
     );
@@ -160,6 +250,7 @@
         DruidQueryRel.class,
         PartialDruidQuery.Stage.WHERE_FILTER,
         null,
+        null,
         mockFilter()
     );
     Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
@@ -175,6 +266,7 @@
         DruidJoinQueryRel.class,
         PartialDruidQuery.Stage.WHERE_FILTER,
         null,
+        null,
         mockFilter()
     );
     Assert.assertFalse(DruidRels.isScanOrMapping(rel, true));
@@ -192,10 +284,11 @@
     );
 
     for (PartialDruidQuery.Stage stage : PartialDruidQuery.Stage.values()) {
-      final Project project = mockProject(true);
+      final Project project = mockMappingProject(ImmutableList.of(1, 0), 2);
       final DruidRel<?> rel = mockDruidRel(
           DruidQueryRel.class,
           stage,
+          null,
           project,
           null
       );
@@ -207,34 +300,66 @@
     }
   }
 
-  private static DruidRel<?> mockDruidRel(
+  public static DruidRel<?> mockDruidRel(
       final Class<? extends DruidRel<?>> clazz,
       final PartialDruidQuery.Stage stage,
+      @Nullable DruidTable druidTable,
+      @Nullable Project selectProject,
+      @Nullable Filter whereFilter
+  )
+  {
+    return mockDruidRel(clazz, rel -> {}, stage, druidTable, selectProject, whereFilter);
+  }
+
+  public static <T extends DruidRel<?>> T mockDruidRel(
+      final Class<T> clazz,
+      final Consumer<T> additionalExpectationsFunction,
+      final PartialDruidQuery.Stage stage,
+      @Nullable DruidTable druidTable,
       @Nullable Project selectProject,
       @Nullable Filter whereFilter
   )
   {
     // DruidQueryRels rely on a ton of Calcite stuff like RelOptCluster, RelOptTable, etc, which is quite verbose to
     // create real instances of. So, tragically, we'll use EasyMock.
-    final DruidRel<?> mockRel = EasyMock.mock(clazz);
     final PartialDruidQuery mockPartialQuery = EasyMock.mock(PartialDruidQuery.class);
     EasyMock.expect(mockPartialQuery.stage()).andReturn(stage).anyTimes();
     EasyMock.expect(mockPartialQuery.getSelectProject()).andReturn(selectProject).anyTimes();
     EasyMock.expect(mockPartialQuery.getWhereFilter()).andReturn(whereFilter).anyTimes();
+
+    final RelOptTable mockRelOptTable = EasyMock.mock(RelOptTable.class);
+    EasyMock.expect(mockRelOptTable.unwrap(DruidTable.class)).andReturn(druidTable).anyTimes();
+
+    final T mockRel = EasyMock.mock(clazz);
     EasyMock.expect(mockRel.getPartialDruidQuery()).andReturn(mockPartialQuery).anyTimes();
-    EasyMock.replay(mockRel, mockPartialQuery);
+    EasyMock.expect(mockRel.getTable()).andReturn(mockRelOptTable).anyTimes();
+    additionalExpectationsFunction.accept(mockRel);
+
+    EasyMock.replay(mockRel, mockPartialQuery, mockRelOptTable);
     return mockRel;
   }
 
-  private static Project mockProject(final boolean mapping)
+  public static Project mockMappingProject(final List<Integer> sources, final int sourceCount)
   {
     final Project mockProject = EasyMock.mock(Project.class);
-    EasyMock.expect(mockProject.isMapping()).andReturn(mapping).anyTimes();
+    EasyMock.expect(mockProject.isMapping()).andReturn(true).anyTimes();
+
+    final Mappings.PartialMapping mapping = new Mappings.PartialMapping(sources, sourceCount, MappingType.SURJECTION);
+
+    EasyMock.expect(mockProject.getMapping()).andReturn(mapping).anyTimes();
     EasyMock.replay(mockProject);
     return mockProject;
   }
 
-  private static Filter mockFilter()
+  public static Project mockNonMappingProject()
+  {
+    final Project mockProject = EasyMock.mock(Project.class);
+    EasyMock.expect(mockProject.isMapping()).andReturn(false).anyTimes();
+    EasyMock.replay(mockProject);
+    return mockProject;
+  }
+
+  public static Filter mockFilter()
   {
     return EasyMock.mock(Filter.class);
   }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRuleTest.java
new file mode 100644
index 0000000..3cf88b9
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRuleTest.java
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.util.mapping.Mappings;
+import org.apache.druid.query.TableDataSource;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
+import org.apache.druid.sql.calcite.rel.DruidOuterQueryRel;
+import org.apache.druid.sql.calcite.rel.DruidQueryRel;
+import org.apache.druid.sql.calcite.rel.DruidRel;
+import org.apache.druid.sql.calcite.rel.DruidRelsTest;
+import org.apache.druid.sql.calcite.rel.DruidUnionDataSourceRel;
+import org.apache.druid.sql.calcite.rel.PartialDruidQuery;
+import org.apache.druid.sql.calcite.table.DruidTable;
+import org.easymock.EasyMock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+public class DruidUnionDataSourceRuleTest
+{
+  private final DruidTable fooDruidTable = new DruidTable(
+      new TableDataSource("foo"),
+      RowSignature.builder()
+                  .addTimeColumn()
+                  .add("col1", ValueType.STRING)
+                  .add("col2", ValueType.LONG)
+                  .build(),
+      false,
+      false
+  );
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_tableScan()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidQueryRel.class,
+        PartialDruidQuery.Stage.SCAN,
+        fooDruidTable,
+        null,
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.of(ImmutableList.of("__time", "col1", "col2")),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_tableMapping()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        fooDruidTable,
+        DruidRelsTest.mockMappingProject(ImmutableList.of(1), 3),
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.of(ImmutableList.of("col1")),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_tableProject()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        fooDruidTable,
+        DruidRelsTest.mockNonMappingProject(),
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.empty(),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_tableFilterPlusMapping()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        fooDruidTable,
+        DruidRelsTest.mockMappingProject(ImmutableList.of(1), 3),
+        DruidRelsTest.mockFilter()
+    );
+
+    Assert.assertEquals(
+        Optional.empty(),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_unionScan()
+  {
+    final DruidUnionDataSourceRel druidRel = DruidRelsTest.mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        rel -> EasyMock.expect(rel.getUnionColumnNames()).andReturn(fooDruidTable.getRowSignature().getColumnNames()),
+        PartialDruidQuery.Stage.SCAN,
+        null,
+        null,
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.of(ImmutableList.of("__time", "col1", "col2")),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_unionMapping()
+  {
+    final Project project = DruidRelsTest.mockMappingProject(ImmutableList.of(2, 1), 3);
+    final Mappings.TargetMapping mapping = project.getMapping();
+    final String[] mappedColumnNames = new String[mapping.getTargetCount()];
+
+    final List<String> columnNames = fooDruidTable.getRowSignature().getColumnNames();
+    for (int i = 0; i < columnNames.size(); i++) {
+      mappedColumnNames[mapping.getTargetOpt(i)] = columnNames.get(i);
+    }
+
+    final DruidUnionDataSourceRel druidRel = DruidRelsTest.mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        rel -> EasyMock.expect(rel.getUnionColumnNames()).andReturn(Arrays.asList(mappedColumnNames)),
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        project,
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.of(ImmutableList.of("col2", "col1")),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_unionProject()
+  {
+    final DruidUnionDataSourceRel druidRel = DruidRelsTest.mockDruidRel(
+        DruidUnionDataSourceRel.class,
+        rel -> EasyMock.expect(rel.getUnionColumnNames()).andReturn(fooDruidTable.getRowSignature().getColumnNames()),
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        DruidRelsTest.mockNonMappingProject(),
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.of(ImmutableList.of("__time", "col1", "col2")),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_outerQuery()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidOuterQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        null,
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.empty(),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+
+  @Test
+  public void test_getColumnNamesIfTableOrUnion_join()
+  {
+    final DruidRel<?> druidRel = DruidRelsTest.mockDruidRel(
+        DruidJoinQueryRel.class,
+        PartialDruidQuery.Stage.SELECT_PROJECT,
+        null,
+        null,
+        null
+    );
+
+    Assert.assertEquals(
+        Optional.empty(),
+        DruidUnionDataSourceRule.getColumnNamesIfTableOrUnion(druidRel)
+    );
+  }
+}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java
index f19ac7f..d8088ae 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java
@@ -42,6 +42,7 @@
 import org.apache.druid.data.input.impl.InputRowParser;
 import org.apache.druid.data.input.impl.LongDimensionSchema;
 import org.apache.druid.data.input.impl.MapInputRowParser;
+import org.apache.druid.data.input.impl.StringDimensionSchema;
 import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
 import org.apache.druid.data.input.impl.TimestampSpec;
 import org.apache.druid.discovery.DiscoveryDruidNode;
@@ -71,6 +72,7 @@
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
 import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
+import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
 import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
 import org.apache.druid.query.expression.LookupEnabledTestExprMacroTable;
 import org.apache.druid.query.expression.LookupExprMacro;
@@ -311,6 +313,25 @@
       .withRollup(false)
       .build();
 
+  private static final IncrementalIndexSchema INDEX_SCHEMA_DIFFERENT_DIM3_M1_TYPES = new IncrementalIndexSchema.Builder()
+      .withDimensionsSpec(
+          new DimensionsSpec(
+              ImmutableList.of(
+                  new StringDimensionSchema("dim1"),
+                  new StringDimensionSchema("dim2"),
+                  new LongDimensionSchema("dim3")
+              )
+          )
+      )
+      .withMetrics(
+          new CountAggregatorFactory("cnt"),
+          new LongSumAggregatorFactory("m1", "m1"),
+          new DoubleSumAggregatorFactory("m2", "m2"),
+          new HyperUniquesAggregatorFactory("unique_dim1", "dim1")
+      )
+      .withRollup(false)
+      .build();
+
   private static final IncrementalIndexSchema INDEX_SCHEMA_WITH_X_COLUMNS = new IncrementalIndexSchema.Builder()
       .withMetrics(
           new CountAggregatorFactory("cnt_x"),
@@ -536,18 +557,21 @@
           .put("t", "2000-01-01")
           .put("dim1", "דרואיד")
           .put("dim2", "he")
+          .put("dim3", 10L)
           .put("m1", 1.0)
           .build(),
       ImmutableMap.<String, Object>builder()
           .put("t", "2000-01-01")
           .put("dim1", "druid")
           .put("dim2", "en")
+          .put("dim3", 11L)
           .put("m1", 1.0)
           .build(),
       ImmutableMap.<String, Object>builder()
           .put("t", "2000-01-01")
           .put("dim1", "друид")
           .put("dim2", "ru")
+          .put("dim3", 12L)
           .put("m1", 1.0)
           .build()
   );
@@ -775,7 +799,7 @@
         .create()
         .tmpDir(new File(tmpDir, "2"))
         .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
-        .schema(INDEX_SCHEMA)
+        .schema(INDEX_SCHEMA_DIFFERENT_DIM3_M1_TYPES)
         .rows(ROWS2)
         .buildMMappedIndex();