Fix left join SQL queries with IS NOT NULL filter (#11434)

This PR fixes the incorrect results for query : 

SELECT dim1, l1.k FROM foo LEFT JOIN (select k || '' as k from lookup.lookyloo group by 1) l1 ON foo.dim1 = l1.k WHERE l1.k IS NOT NULL (in CalciteQueryTests)
In the current code, the WHERE clause gets removed from the top of the left join and is pushed to the table foo
leading to incorrect results.
The fix for such a situation is done by :

Converting such left joins into inner joins (since logically the mentioned left join query is equivalent to an inner join) using Calcite while maintaining that the druid execution layer can execute such inner joins.
Preferring converted inner joins over original left joins in our cost model
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java
index f8da5cb..ee7dee1 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java
@@ -32,6 +32,7 @@
   public static final String CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT = "useApproximateCountDistinct";
   public static final String CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT = "useGroupingSetForExactDistinct";
   public static final String CTX_KEY_USE_APPROXIMATE_TOPN = "useApproximateTopN";
+  public static final String CTX_COMPUTE_INNER_JOIN_COST_AS_FILTER = "computeInnerJoinCostAsFilter";
 
   @JsonProperty
   private Period metadataRefreshPeriod = new Period("PT1M");
@@ -63,6 +64,9 @@
   @JsonProperty
   private boolean useGroupingSetForExactDistinct = false;
 
+  @JsonProperty
+  private boolean computeInnerJoinCostAsFilter = true;
+
   public long getMetadataSegmentPollPeriod()
   {
     return metadataSegmentPollPeriod;
@@ -120,6 +124,11 @@
     return serializeComplexValues;
   }
 
+  public boolean isComputeInnerJoinCostAsFilter()
+  {
+    return computeInnerJoinCostAsFilter;
+  }
+
   public PlannerConfig withOverrides(final Map<String, Object> context)
   {
     if (context == null) {
@@ -150,6 +159,9 @@
     newConfig.metadataSegmentCacheEnable = isMetadataSegmentCacheEnable();
     newConfig.metadataSegmentPollPeriod = getMetadataSegmentPollPeriod();
     newConfig.serializeComplexValues = shouldSerializeComplexValues();
+    newConfig.computeInnerJoinCostAsFilter = getContextBoolean(context,
+                                                               CTX_COMPUTE_INNER_JOIN_COST_AS_FILTER,
+                                                               computeInnerJoinCostAsFilter);
     return newConfig;
   }
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
index 1b28a98..0c53257 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidJoinQueryRel.java
@@ -53,6 +53,7 @@
 import org.apache.druid.sql.calcite.expression.DruidExpression;
 import org.apache.druid.sql.calcite.expression.Expressions;
 import org.apache.druid.sql.calcite.planner.Calcites;
+import org.apache.druid.sql.calcite.planner.PlannerConfig;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.table.RowSignatures;
 
@@ -72,6 +73,7 @@
   private final Filter leftFilter;
   private final PartialDruidQuery partialQuery;
   private final Join joinRel;
+  private final PlannerConfig plannerConfig;
   private RelNode left;
   private RelNode right;
 
@@ -90,6 +92,7 @@
     this.right = joinRel.getRight();
     this.leftFilter = leftFilter;
     this.partialQuery = partialQuery;
+    this.plannerConfig = queryMaker.getPlannerContext().getPlannerConfig();
   }
 
   /**
@@ -316,6 +319,9 @@
       cost = CostEstimates.COST_JOIN_SUBQUERY;
     } else {
       cost = partialQuery.estimateCost();
+      if (joinRel.getJoinType() == JoinRelType.INNER && plannerConfig.isComputeInnerJoinCostAsFilter()) {
+        cost *= CostEstimates.MULTIPLIER_FILTER; // treating inner join like a filter on left table
+      }
     }
 
     if (computeRightRequiresSubquery(getSomeDruidChild(right))) {
@@ -351,7 +357,7 @@
     return !DruidRels.isScanOrMapping(left, true);
   }
 
-  private static boolean computeRightRequiresSubquery(final DruidRel<?> right)
+  public static boolean computeRightRequiresSubquery(final DruidRel<?> right)
   {
     // Right requires a subquery unless it's a scan or mapping on top of a global datasource.
     // ideally this would involve JoinableFactory.isDirectlyJoinable to check that the global datasources
@@ -385,7 +391,7 @@
     return Pair.of(rightPrefix, signatureBuilder.build());
   }
 
-  private static DruidRel<?> getSomeDruidChild(final RelNode child)
+  public static DruidRel<?> getSomeDruidChild(final RelNode child)
   {
     if (child instanceof DruidRel) {
       return (DruidRel<?>) child;
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
index b277a8a..b692a98 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
@@ -35,21 +35,25 @@
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexSlot;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.query.LookupDataSource;
 import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
 import org.apache.druid.sql.calcite.rel.DruidQueryRel;
 import org.apache.druid.sql.calcite.rel.DruidRel;
 import org.apache.druid.sql.calcite.rel.PartialDruidQuery;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.Stack;
 import java.util.stream.Collectors;
 
@@ -87,7 +91,7 @@
     // 1) Can handle the join condition as a native join.
     // 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
     // 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
-    return canHandleCondition(join.getCondition(), join.getLeft().getRowType())
+    return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right)
            && left.getPartialDruidQuery() != null
            && right.getPartialDruidQuery() != null;
   }
@@ -108,7 +112,7 @@
 
     // Already verified to be present in "matches", so just call "get".
     // Can't be final, because we're going to reassign it up to a couple of times.
-    ConditionAnalysis conditionAnalysis = analyzeCondition(join.getCondition(), join.getLeft().getRowType()).get();
+    ConditionAnalysis conditionAnalysis = analyzeCondition(join.getCondition(), join.getLeft().getRowType(), right).get();
     final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel);
 
     if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT
@@ -195,21 +199,22 @@
    * Returns whether {@link #analyzeCondition} would return something.
    */
   @VisibleForTesting
-  static boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType)
+  static boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right)
   {
-    return analyzeCondition(condition, leftRowType).isPresent();
+    return analyzeCondition(condition, leftRowType, right).isPresent();
   }
 
   /**
    * If this condition is an AND of some combination of (1) literals; (2) equality conditions of the form
    * {@code f(LeftRel) = RightColumn}, then return a {@link ConditionAnalysis}.
    */
-  private static Optional<ConditionAnalysis> analyzeCondition(final RexNode condition, final RelDataType leftRowType)
+  private static Optional<ConditionAnalysis> analyzeCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right)
   {
     final List<RexNode> subConditions = decomposeAnd(condition);
     final List<Pair<RexNode, RexInputRef>> equalitySubConditions = new ArrayList<>();
     final List<RexLiteral> literalSubConditions = new ArrayList<>();
     final int numLeftFields = leftRowType.getFieldCount();
+    final Set<RexInputRef> rightColumns = new HashSet<>();
 
     for (RexNode subCondition : subConditions) {
       if (RexUtil.isLiteral(subCondition, true)) {
@@ -243,15 +248,32 @@
 
       if (isLeftExpression(operands.get(0), numLeftFields) && isRightInputRef(operands.get(1), numLeftFields)) {
         equalitySubConditions.add(Pair.of(operands.get(0), (RexInputRef) operands.get(1)));
+        rightColumns.add((RexInputRef) operands.get(1));
       } else if (isRightInputRef(operands.get(0), numLeftFields)
                  && isLeftExpression(operands.get(1), numLeftFields)) {
         equalitySubConditions.add(Pair.of(operands.get(1), (RexInputRef) operands.get(0)));
+        rightColumns.add((RexInputRef) operands.get(0));
       } else {
         // Cannot handle this condition.
         return Optional.empty();
       }
     }
 
+    // if the the right side requires a subquery, then even lookup will transformed to a QueryDataSource
+    // thereby allowing join conditions on both k and v columns of the lookup
+    if (right != null && !DruidJoinQueryRel.computeRightRequiresSubquery(DruidJoinQueryRel.getSomeDruidChild(right))
+        && right instanceof DruidQueryRel) {
+      DruidQueryRel druidQueryRel = (DruidQueryRel) right;
+      if (druidQueryRel.getDruidTable().getDataSource() instanceof LookupDataSource) {
+        long distinctRightColumns = rightColumns.stream().map(RexSlot::getIndex).distinct().count();
+        if (distinctRightColumns > 1) {
+          // it means that the join's right side is lookup and the join condition contains both key and value columns of lookup.
+          // currently, the lookup datasource in the native engine doesn't support using value column in the join condition.
+          return Optional.empty();
+        }
+      }
+    }
+
     return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions));
   }
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRule.java
index ca9ed46..42759a8 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRule.java
@@ -37,25 +37,30 @@
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.Pair;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Objects;
 
 /**
  * This class is a copy (with modification) of {@link FilterJoinRule}. Specifically, this class contains a
  * subset of code from {@link FilterJoinRule} for the codepath involving {@link FilterJoinRule#FILTER_ON_JOIN}
- * Everything has been keep as-is from {@link FilterJoinRule} except for the modification
- * of {@link #classifyFilters(List, JoinRelType, boolean, List)} method called in the
+ * Everything has been keep as-is from {@link FilterJoinRule} except for :
+ * 1. the modification of {@link #classifyFilters(List, JoinRelType, boolean, List)} method called in the
  * {@link #perform(RelOptRuleCall, Filter, Join)} method of this class.
+ * 2. removing redundant 'IS NOT NULL' filters from inner join filter condition
  * The {@link #classifyFilters(List, JoinRelType, boolean, List)} method is based of {@link RelOptUtil#classifyFilters}.
  * The difference is that the modfied method use in thsi class will not not push filters to the children.
  * Hence, filters will either stay where they are or are pushed to the join (if they originated from above the join).
  *
- * This modification is needed due to the bug described in https://github.com/apache/druid/pull/9773
- * This class and it's modification can be removed, switching back to the default Rule provided in Calcite's
- * {@link FilterJoinRule} when https://github.com/apache/druid/issues/9843 is resolved.
+ * The modification of {@link #classifyFilters(List, JoinRelType, boolean, List)} is needed due to the bug described in
+ * https://github.com/apache/druid/pull/9773. This class and it's modification can be removed, switching back to the
+ * default Rule provided in Calcite's {@link FilterJoinRule} when https://github.com/apache/druid/issues/9843 is resolved.
  */
 
 public abstract class FilterJoinExcludePushToChildRule extends FilterJoinRule
@@ -180,6 +185,9 @@
       filterPushed = true;
     }
 
+    // once the filters are pushed to join from top, try to remove redudant 'IS NOT NULL' filters
+    removeRedundantIsNotNullFilters(joinFilters, joinType, NullHandling.sqlCompatible());
+
     // if nothing actually got pushed and there is nothing leftover,
     // then this rule is a no-op
     if ((!filterPushed && joinType == join.getJoinType()) || joinFilters.isEmpty()) {
@@ -292,4 +300,51 @@
     // Did anything change?
     return !filtersToRemove.isEmpty();
   }
+
+  /**
+   * This tries to find all the 'IS NOT NULL' filters in an inner join whose checking column is also
+   * a part of an equi-condition between the two tables. It removes such 'IS NOT NULL' filters from join since
+   * the equi-condition will never return true for null input, thus making the 'IS NOT NULL' filter a no-op.
+   * @param joinFilters
+   * @param joinType
+   * @param isSqlCompatible
+   */
+  static void removeRedundantIsNotNullFilters(List<RexNode> joinFilters, JoinRelType joinType, boolean isSqlCompatible)
+  {
+    if (joinType != JoinRelType.INNER || !isSqlCompatible) {
+      return; // only works for inner joins in SQL mode
+    }
+
+    ImmutableList.Builder<RexNode> isNotNullFiltersBuilder = ImmutableList.builder();
+    ImmutableList.Builder<Pair<RexNode, RexNode>> equalityFiltersOperandBuilder = ImmutableList.builder();
+
+    joinFilters.stream().filter(joinFilter -> joinFilter instanceof RexCall).forEach(joinFilter -> {
+      if (joinFilter.isA(SqlKind.IS_NOT_NULL)) {
+        isNotNullFiltersBuilder.add(joinFilter);
+      } else if (joinFilter.isA(SqlKind.EQUALS)) {
+        List<RexNode> operands = ((RexCall) joinFilter).getOperands();
+        if (operands.size() == 2 && operands.stream().noneMatch(Objects::isNull)) {
+          equalityFiltersOperandBuilder.add(new Pair<>(operands.get(0), operands.get(1)));
+        }
+      }
+    });
+
+    List<Pair<RexNode, RexNode>> equalityFilters = equalityFiltersOperandBuilder.build();
+    ImmutableList.Builder<RexNode> removableFilters = ImmutableList.builder();
+    for (RexNode isNotNullFilter : isNotNullFiltersBuilder.build()) {
+      List<RexNode> operands = ((RexCall) isNotNullFilter).getOperands();
+      boolean canDrop = false;
+      for (Pair<RexNode, RexNode> equalityFilterOperands : equalityFilters) {
+        if ((equalityFilterOperands.lhs != null && equalityFilterOperands.lhs.equals(operands.get(0))) ||
+            (equalityFilterOperands.rhs != null && equalityFilterOperands.rhs.equals(operands.get(0)))) {
+          canDrop = true;
+          break;
+        }
+      }
+      if (canDrop) {
+        removableFilters.add(isNotNullFilter);
+      }
+    }
+    joinFilters.removeAll(removableFilters.build());
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
index 15937a5..c8988cf 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
@@ -958,6 +958,13 @@
               .build(),
           };
     }
+
+    public static Map<String, Object> withOverrides(Map<String, Object> originalContext, Map<String, Object> overrides)
+    {
+      Map<String, Object> contextWithOverrides = new HashMap<>(originalContext);
+      contextWithOverrides.putAll(overrides);
+      return contextWithOverrides;
+    }
   }
 
   protected Map<String, Object> withLeftDirectAccessEnabled(Map<String, Object> context)
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index e3ac2e2..9404b03 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -145,6 +145,8 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 
+import static org.apache.druid.query.QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY;
+
 @RunWith(JUnitParamsRunner.class)
 public class CalciteQueryTest extends BaseCalciteQueryTest
 {
@@ -385,7 +387,7 @@
                                 )
                                 .intervals(querySegmentSpec(Filtration.eternity()))
                                 .limit(10)
-                                .columns("__time", "cnt", "dim1", "dim2", "dim3", "j0.m1", "m1", "m2", "unique_dim1")
+                                .columns("dim2", "j0.m1", "m1", "m2")
                                 .context(QUERY_CONTEXT_DEFAULT)
                                 .build()
                         )
@@ -16847,6 +16849,40 @@
   @Parameters(source = QueryContextForJoinProvider.class)
   public void testInnerJoinOnTwoInlineDataSourcesWithOuterWhere(Map<String, Object> queryContext) throws Exception
   {
+    Druids.ScanQueryBuilder baseScanBuilder = newScanQueryBuilder()
+        .dataSource(
+            join(
+                new QueryDataSource(
+                    newScanQueryBuilder()
+                        .dataSource(CalciteTests.DATASOURCE1)
+                        .intervals(querySegmentSpec(Filtration.eternity()))
+                        .filters(new SelectorDimFilter("dim1", "10.1", null))
+                        .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
+                        .columns(ImmutableList.of("__time", "v0"))
+                        .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                        .context(queryContext)
+                        .build()
+                ),
+                new QueryDataSource(
+                    newScanQueryBuilder()
+                        .dataSource(CalciteTests.DATASOURCE1)
+                        .intervals(querySegmentSpec(Filtration.eternity()))
+                        .filters(new SelectorDimFilter("dim1", "10.1", null))
+                        .columns(ImmutableList.of("dim1"))
+                        .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                        .context(queryContext)
+                        .build()
+                ),
+                "j0.",
+                equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
+                JoinType.INNER
+            )
+        )
+        .intervals(querySegmentSpec(Filtration.eternity()))
+        .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
+        .columns("__time", "_v0")
+        .context(queryContext);
+
     testQuery(
         "with abc as\n"
         + "(\n"
@@ -16855,42 +16891,8 @@
         + "SELECT t1.dim1, t1.\"__time\" from abc as t1 INNER JOIN abc as t2 on t1.dim1 = t2.dim1 WHERE t1.dim1 = '10.1'\n",
         queryContext,
         ImmutableList.of(
-            newScanQueryBuilder()
-                .dataSource(
-                    join(
-                        new QueryDataSource(
-                            newScanQueryBuilder()
-                                .dataSource(CalciteTests.DATASOURCE1)
-                                .intervals(querySegmentSpec(Filtration.eternity()))
-                                .filters(new SelectorDimFilter("dim1", "10.1", null))
-                                .virtualColumns(expressionVirtualColumn("v0", "\'10.1\'", ValueType.STRING))
-                                .columns(ImmutableList.of("__time", "v0"))
-                                .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
-                                .context(queryContext)
-                                .build()
-                        ),
-                        new QueryDataSource(
-                            newScanQueryBuilder()
-                                .dataSource(CalciteTests.DATASOURCE1)
-                                .intervals(querySegmentSpec(Filtration.eternity()))
-                                .filters(new SelectorDimFilter("dim1", "10.1", null))
-                                .columns(ImmutableList.of("dim1"))
-                                .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
-                                .context(queryContext)
-                                .build()
-                        ),
-                        "j0.",
-                        equalsCondition(DruidExpression.fromColumn("v0"), DruidExpression.fromColumn("j0.dim1")),
-                        JoinType.INNER
-                    )
-                )
-                .intervals(querySegmentSpec(Filtration.eternity()))
-                .virtualColumns(expressionVirtualColumn("_v0", "\'10.1\'", ValueType.STRING))
-                .columns("__time", "_v0")
-                .filters(new NotDimFilter(new SelectorDimFilter("v0", null, null)))
-                .context(queryContext)
-                .build()
-        ),
+            NullHandling.sqlCompatible() ? baseScanBuilder.build() :
+            baseScanBuilder.filters(new NotDimFilter(new SelectorDimFilter("v0", null, null))).build()),
         ImmutableList.of(
             new Object[]{"10.1", 946771200000L}
         )
@@ -18032,4 +18034,181 @@
         )
     );
   }
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testLeftJoinSubqueryWithNullKeyFilter(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize due to 'concat' expression.
+    cannotVectorize();
+
+    ScanQuery nullCompatibleModePlan = newScanQueryBuilder()
+        .dataSource(
+            join(
+                new TableDataSource(CalciteTests.DATASOURCE1),
+                new QueryDataSource(
+                    GroupByQuery
+                        .builder()
+                        .setDataSource(new LookupDataSource("lookyloo"))
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setVirtualColumns(
+                            expressionVirtualColumn("v0", "concat(\"k\",'')", ValueType.STRING)
+                        )
+                        .setDimensions(new DefaultDimensionSpec("v0", "d0"))
+                        .build()
+                ),
+                "j0.",
+                equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.d0")),
+                JoinType.INNER
+            )
+        )
+        .intervals(querySegmentSpec(Filtration.eternity()))
+        .columns("dim1", "j0.d0")
+        .context(queryContext)
+        .build();
+
+    ScanQuery nonNullCompatibleModePlan = newScanQueryBuilder()
+        .dataSource(
+            join(
+                new TableDataSource(CalciteTests.DATASOURCE1),
+                new QueryDataSource(
+                    GroupByQuery
+                        .builder()
+                        .setDataSource(new LookupDataSource("lookyloo"))
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setVirtualColumns(
+                            expressionVirtualColumn("v0", "concat(\"k\",'')", ValueType.STRING)
+                        )
+                        .setDimensions(new DefaultDimensionSpec("v0", "d0"))
+                        .build()
+                ),
+                "j0.",
+                equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.d0")),
+                JoinType.LEFT
+            )
+        )
+        .intervals(querySegmentSpec(Filtration.eternity()))
+        .columns("dim1", "j0.d0")
+        .filters(new NotDimFilter(new SelectorDimFilter("j0.d0", null, null)))
+        .context(queryContext)
+        .build();
+
+    boolean isJoinFilterRewriteEnabled = queryContext.getOrDefault(JOIN_FILTER_REWRITE_ENABLE_KEY, true).toString().equals("true");
+    testQuery(
+        "SELECT dim1, l1.k\n"
+        + "FROM foo\n"
+        + "LEFT JOIN (select k || '' as k from lookup.lookyloo group by 1) l1 ON foo.dim1 = l1.k\n"
+        + "WHERE l1.k IS NOT NULL\n",
+        queryContext,
+        ImmutableList.of(NullHandling.sqlCompatible() ? nullCompatibleModePlan : nonNullCompatibleModePlan),
+        NullHandling.sqlCompatible() || !isJoinFilterRewriteEnabled ? ImmutableList.of(new Object[]{"abc", "abc"}) : ImmutableList.of(
+                         new Object[]{"10.1", ""}, // this result is incorrect. TODO : fix this result when the JoinFilterAnalyzer bug is fixed
+                         new Object[]{"2", ""},
+                         new Object[]{"1", ""},
+                         new Object[]{"def", ""},
+                         new Object[]{"abc", "abc"})
+    );
+  }
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testLeftJoinSubqueryWithSelectorFilter(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize due to 'concat' expression.
+    cannotVectorize();
+
+    // disable the cost model where inner join is treated like a filter
+    // this leads to cost(left join) < cost(converted inner join) for the below query
+    queryContext = QueryContextForJoinProvider.withOverrides(queryContext,
+                                                             ImmutableMap.of("computeInnerJoinCostAsFilter", "false"));
+    testQuery(
+        "SELECT dim1, l1.k\n"
+        + "FROM foo\n"
+        + "LEFT JOIN (select k || '' as k from lookup.lookyloo group by 1) l1 ON foo.dim1 = l1.k\n"
+        + "WHERE l1.k = 'abc'\n",
+        queryContext,
+        ImmutableList.of(
+            newScanQueryBuilder()
+                .dataSource(
+                    join(
+                        new TableDataSource(CalciteTests.DATASOURCE1),
+                        new QueryDataSource(
+                            GroupByQuery
+                                .builder()
+                                .setDataSource(new LookupDataSource("lookyloo"))
+                                .setInterval(querySegmentSpec(Filtration.eternity()))
+                                .setGranularity(Granularities.ALL)
+                                .setVirtualColumns(
+                                    expressionVirtualColumn("v0", "concat(\"k\",'')", ValueType.STRING)
+                                )
+                                .setDimensions(new DefaultDimensionSpec("v0", "d0"))
+                                .build()
+                        ),
+                        "j0.",
+                        equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.d0")),
+                        JoinType.LEFT
+                    )
+                )
+                .intervals(querySegmentSpec(Filtration.eternity()))
+                .columns("dim1", "j0.d0")
+                .filters(selector("j0.d0", "abc", null))
+                .context(queryContext)
+                .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"abc", "abc"}
+        )
+    );
+  }
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testInnerJoinSubqueryWithSelectorFilter(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize due to 'concat' expression.
+    cannotVectorize();
+
+    testQuery(
+        "SELECT dim1, l1.k "
+        + "FROM foo INNER JOIN (select k || '' as k from lookup.lookyloo group by 1) l1 "
+        + "ON foo.dim1 = l1.k and l1.k = 'abc'",
+        queryContext,
+        ImmutableList.of(
+            newScanQueryBuilder()
+                .dataSource(
+                    join(
+                        new TableDataSource(CalciteTests.DATASOURCE1),
+                        new QueryDataSource(
+                            GroupByQuery
+                                .builder()
+                                .setDataSource(new LookupDataSource("lookyloo"))
+                                .setInterval(querySegmentSpec(Filtration.eternity()))
+                                .setGranularity(Granularities.ALL)
+                                .setVirtualColumns(
+                                    expressionVirtualColumn("v0", "concat(\"k\",'')", ValueType.STRING)
+                                )
+                                .setDimensions(new DefaultDimensionSpec("v0", "d0"))
+                                .build()
+                        ),
+                        "j0.",
+                        StringUtils.format(
+                            "(%s && %s)",
+                            equalsCondition(DruidExpression.fromColumn("dim1"), DruidExpression.fromColumn("j0.d0")),
+                            equalsCondition(DruidExpression.fromExpression("'abc'"), DruidExpression.fromColumn("j0.d0"))
+                        ),
+                        JoinType.INNER
+                    )
+                )
+                .intervals(querySegmentSpec(Filtration.eternity()))
+                .columns("dim1", "j0.d0")
+                .context(queryContext)
+                .build()
+        ),
+        ImmutableList.of(
+            new Object[]{"abc", "abc"}
+        )
+    );
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
index dd706ff..42c6ba7 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
@@ -67,7 +67,8 @@
                 rexBuilder.makeInputRef(joinType, 0),
                 rexBuilder.makeInputRef(joinType, 1)
             ),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -86,7 +87,8 @@
                 ),
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
             ),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -105,7 +107,8 @@
                     rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
                 )
             ),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -120,7 +123,8 @@
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0),
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0)
             ),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -135,7 +139,8 @@
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1),
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
             ),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -146,7 +151,8 @@
     Assert.assertTrue(
         DruidJoinRule.canHandleCondition(
             rexBuilder.makeLiteral(true),
-            leftType
+            leftType,
+            null
         )
     );
   }
@@ -157,7 +163,8 @@
     Assert.assertTrue(
         DruidJoinRule.canHandleCondition(
             rexBuilder.makeLiteral(false),
-            leftType
+            leftType,
+            null
         )
     );
   }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRuleTest.java
new file mode 100644
index 0000000..7daf5c1
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterJoinExcludePushToChildRuleTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_NOT_NULL;
+
+public class FilterJoinExcludePushToChildRuleTest
+{
+  private final RexBuilder rexBuilder = new RexBuilder(new JavaTypeFactoryImpl());
+  private final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE);
+
+  @Test
+  public void testRemoveRedundantIsNotNullFiltersWithSQLCompatibility()
+  {
+    RexNode equalityFilter = rexBuilder.makeCall(
+        SqlStdOperatorTable.EQUALS,
+        rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0),
+        rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1));
+    RexNode isNotNullFilterOnJoinColumn = rexBuilder.makeCall(IS_NOT_NULL,
+                                                  ImmutableList.of(rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)));
+    RexNode isNotNullFilterOnNonJoinColumn = rexBuilder.makeCall(IS_NOT_NULL,
+                                                  ImmutableList.of(rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 2)));
+    List<RexNode> joinFilters = new ArrayList<>();
+    joinFilters.add(equalityFilter);
+
+    FilterJoinExcludePushToChildRule.removeRedundantIsNotNullFilters(joinFilters, JoinRelType.INNER, true);
+    Assert.assertEquals(joinFilters.size(), 1);
+    Assert.assertEquals("Equality Filter changed", joinFilters.get(0), equalityFilter);
+
+    // add IS NOT NULL filter on a join column
+    joinFilters.add(isNotNullFilterOnNonJoinColumn);
+    joinFilters.add(isNotNullFilterOnJoinColumn);
+    Assert.assertEquals(joinFilters.size(), 3);
+    FilterJoinExcludePushToChildRule.removeRedundantIsNotNullFilters(joinFilters, JoinRelType.INNER, true);
+    Assert.assertEquals(joinFilters.size(), 2);
+    Assert.assertEquals("Equality Filter changed", joinFilters.get(0), equalityFilter);
+    Assert.assertEquals("IS NOT NULL filter on non-join column changed", joinFilters.get(1), isNotNullFilterOnNonJoinColumn);
+  }
+}