Enable rewriting certain inner joins as filters. (#11068)

* Enable rewriting certain inner joins as filters.

The main logic for the rewrite lives in JoinableFactoryWrapper's
createSegmentMapFn method. The requirements are (an illustrative sketch
of the resulting filter follows the list):

- It must be an inner equi-join.
- The right-hand columns referenced by the condition must not contain any
  duplicate values. (If they did, the inner join would not be guaranteed
  to return at most one row for each left-hand-side row.)
- No columns from the right-hand side can be used by anything other than
  the join condition itself.
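
To make the rewrite concrete, here is an illustrative sketch. The table,
column, and key values are hypothetical; only the InDimFilter
construction mirrors this patch (see convertJoinToFilter below). A query
on "fact INNER JOIN lookup l ON fact.country = l.k" that reads no "l."
columns, against a lookup whose keys are unique, behaves the same as a
query on "fact" alone with this filter:

    import com.google.common.collect.ImmutableSet;
    import org.apache.druid.query.filter.InDimFilter;

    public class JoinToFilterExample
    {
      public static void main(String[] args)
      {
        // Stand-in for the lookup's unique, nonnull key set. In the real
        // rewrite this comes from Joinable.getNonNullColumnValuesIfAllUnique.
        final ImmutableSet<String> lookupKeys = ImmutableSet.of("US", "GB", "DE");

        // The join clause is dropped and replaced by an "in" filter on the
        // left-hand column of the equi-condition.
        final InDimFilter joinAsFilter = new InDimFilter("country", lookupKeys);

        System.out.println(joinAsFilter);
      }
    }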

HashJoinSegmentStorageAdapter is also modified to pass through to
the base adapter (even allowing vectorization!) in the case where 100%
of join clauses could be rewritten as filters.

In support of this goal:

- Add Query getRequiredColumns() method to help us figure out whether
  the right-hand side of a join datasource is being used or not.
- Add JoinConditionAnalysis getRequiredColumns() method to help us
  figure out if the right-hand side of a join is being used by later
  join clauses acting on the same base.
- Add Joinable getNonNullColumnValuesIfAllUnique method to enable
  retrieving the set of values that will form the "in" filter.
- Add LookupExtractor canGetKeySet() and keySet() methods to support
  LookupJoinable in its efforts to implement the new Joinable method.
  (An end-to-end sketch of these methods appears at the bottom of this
  message.)
- Add "enableRewriteJoinToFilter" feature flag to
  JoinFilterRewriteConfig. The default is disabled.
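
Since the flag is off by default, a query must opt in through its
context. A minimal sketch (the helper method name is hypothetical; the
key and default are the ones added to QueryContexts in this patch):

    import com.google.common.collect.ImmutableMap;
    import org.apache.druid.query.Query;
    import org.apache.druid.query.QueryContexts;

    public class EnableRewriteJoinToFilterExample
    {
      public static <T> Query<T> withRewriteEnabled(final Query<T> query)
      {
        // DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER is false, so the rewrite only
        // happens when the context flag is set explicitly.
        return query.withOverriddenContext(
            ImmutableMap.<String, Object>of(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true)
        );
      }
    }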

* Test improvements.

* Test fixes.

* Avoid slow size() call.

* Remove invalid test.

* Fix style.

* Fix mistaken default.

* Small fixes.

* Fix logic error.
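
As an end-to-end sketch of the new LookupExtractor and Joinable methods
(the lookup contents here are hypothetical): MapLookupExtractor reports
canGetKeySet() = true, so wrapping it in a LookupJoinable lets the
rewrite retrieve the "in" filter values from its key column "k".

    import com.google.common.collect.ImmutableMap;
    import org.apache.druid.common.config.NullHandling;
    import org.apache.druid.query.extraction.MapLookupExtractor;
    import org.apache.druid.segment.join.lookup.LookupJoinable;

    public class LookupKeySetExample
    {
      public static void main(String[] args)
      {
        // Required outside of a full Druid server; mirrors what tests do.
        NullHandling.initializeForTests();

        final LookupJoinable joinable = LookupJoinable.wrap(
            new MapLookupExtractor(
                ImmutableMap.of("US", "United States", "GB", "Great Britain"),
                false
            )
        );

        // Prints Optional[[US, GB]]: both keys are nonnull and unique, and
        // there are fewer than maxNumValues of them.
        System.out.println(joinable.getNonNullColumnValuesIfAllUnique("k", 100));
      }
    }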
diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java
index 1ed1c37..c2dd8ae 100644
--- a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java
@@ -186,6 +186,7 @@
                     enableFilterPushdown,
                     enableFilterRewrite,
                     enableFilterRewriteValueFilters,
+                    QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER,
                     QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
                 ),
                 clauses,
diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java
index bf257fd..ee6b7e3 100644
--- a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java
+++ b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java
@@ -150,6 +150,7 @@
                     false,
                     false,
                     false,
+                    false,
                     0
                 ),
                 joinableClausesLookupStringKey,
@@ -185,6 +186,7 @@
                     false,
                     false,
                     false,
+                    false,
                     0
                 ),
                 joinableClausesLookupLongKey,
@@ -220,6 +222,7 @@
                     false,
                     false,
                     false,
+                    false,
                     0
                 ),
                 joinableClausesLookupLongKey,
@@ -255,6 +258,7 @@
                     false,
                     false,
                     false,
+                    false,
                     0
                 ),
                 joinableClausesIndexedTableLongKey,
diff --git a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java
index af346e2..2bffc36 100644
--- a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java
+++ b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java
@@ -30,6 +30,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -112,12 +113,24 @@
   }
 
   @Override
+  public boolean canGetKeySet()
+  {
+    return false;
+  }
+
+  @Override
   public Iterable<Map.Entry<String, String>> iterable()
   {
     throw new UnsupportedOperationException("Cannot iterate");
   }
 
   @Override
+  public Set<String> keySet()
+  {
+    throw new UnsupportedOperationException("Cannot get key set");
+  }
+
+  @Override
   public byte[] getCacheKey()
   {
     return LookupExtractionModule.getRandomCacheKey();
diff --git a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java
index 375f3d0..84c20d5 100644
--- a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java
+++ b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java
@@ -37,6 +37,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -174,12 +175,24 @@
   }
 
   @Override
+  public boolean canGetKeySet()
+  {
+    return false;
+  }
+
+  @Override
   public Iterable<Map.Entry<String, String>> iterable()
   {
     throw new UnsupportedOperationException("Cannot iterate");
   }
 
   @Override
+  public Set<String> keySet()
+  {
+    throw new UnsupportedOperationException("Cannot get key set");
+  }
+
+  @Override
   public byte[] getCacheKey()
   {
     return LookupExtractionModule.getRandomCacheKey();
diff --git a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java
index 93e147d..0a28454 100644
--- a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java
+++ b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java
@@ -26,7 +26,9 @@
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.easymock.EasyMock;
 import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -40,6 +42,9 @@
   LoadingCache reverseLookupCache = EasyMock.createStrictMock(LoadingCache.class);
   LoadingLookup loadingLookup = new LoadingLookup(dataFetcher, lookupCache, reverseLookupCache);
 
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
   @Test
   public void testApplyEmptyOrNull() throws ExecutionException
   {
@@ -123,4 +128,17 @@
   {
     Assert.assertFalse(Arrays.equals(loadingLookup.getCacheKey(), loadingLookup.getCacheKey()));
   }
+
+  @Test
+  public void testCanGetKeySet()
+  {
+    Assert.assertFalse(loadingLookup.canGetKeySet());
+  }
+
+  @Test
+  public void testKeySet()
+  {
+    expectedException.expect(UnsupportedOperationException.class);
+    loadingLookup.keySet();
+  }
 }
diff --git a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java
index c276b74..715100d 100644
--- a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java
+++ b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java
@@ -34,7 +34,9 @@
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
@@ -62,6 +64,9 @@
 
   private static final long POLL_PERIOD = 1000L;
 
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
   @JsonTypeName("mock")
   private static class MockDataFetcher implements DataFetcher
   {
@@ -204,6 +209,19 @@
     Assert.assertFalse(Arrays.equals(pollingLookup2.getCacheKey(), pollingLookup.getCacheKey()));
   }
 
+  @Test
+  public void testCanGetKeySet()
+  {
+    Assert.assertFalse(pollingLookup.canGetKeySet());
+  }
+
+  @Test
+  public void testKeySet()
+  {
+    expectedException.expect(UnsupportedOperationException.class);
+    pollingLookup.keySet();
+  }
+
   private void assertMapLookup(Map<String, String> map, LookupExtractor lookup)
   {
     for (Map.Entry<String, String> entry : map.entrySet()) {
diff --git a/processing/src/main/java/org/apache/druid/query/Queries.java b/processing/src/main/java/org/apache/druid/query/Queries.java
index e25a88e..58de469 100644
--- a/processing/src/main/java/org/apache/druid/query/Queries.java
+++ b/processing/src/main/java/org/apache/druid/query/Queries.java
@@ -26,11 +26,16 @@
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.PostAggregator;
+import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.query.filter.DimFilter;
 import org.apache.druid.query.planning.DataSourceAnalysis;
 import org.apache.druid.query.planning.PreJoinableClause;
 import org.apache.druid.query.spec.MultipleSpecificSegmentSpec;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.VirtualColumns;
+import org.apache.druid.segment.column.ColumnHolder;
 
+import javax.annotation.Nullable;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -219,4 +224,73 @@
 
     return retVal;
   }
+
+  /**
+   * Helper for implementations of {@link Query#getRequiredColumns()}. Returns the set of columns that will be read
+   * out of a datasource by a query that uses the provided objects in the usual way.
+   *
+   * The returned set always contains {@code __time}, no matter what.
+   *
+   * If the virtual columns, filter, dimensions, aggregators, or additional columns refer to a virtual column, then the
+   * inputs of the virtual column will be returned instead of the name of the virtual column itself. Therefore, the
+   * returned set will never contain the names of any virtual columns.
+   *
+   * @param virtualColumns    virtual columns whose inputs should be included.
+   * @param filter            optional filter whose inputs should be included.
+   * @param dimensions        dimension specs whose inputs should be included.
+   * @param aggregators       aggregators whose inputs should be included.
+   * @param additionalColumns additional columns to include. Each of these will be added to the returned set, unless it
+   *                          refers to a virtual column, in which case the virtual column inputs will be added instead.
+   */
+  public static Set<String> computeRequiredColumns(
+      final VirtualColumns virtualColumns,
+      @Nullable final DimFilter filter,
+      final List<DimensionSpec> dimensions,
+      final List<AggregatorFactory> aggregators,
+      final List<String> additionalColumns
+  )
+  {
+    final Set<String> requiredColumns = new HashSet<>();
+
+    // Everyone needs __time (it's used by intervals filters).
+    requiredColumns.add(ColumnHolder.TIME_COLUMN_NAME);
+
+    // Add the physical inputs of each virtual column, skipping any input that is itself a virtual column, so that
+    // the returned set never contains virtual column names.
+    for (VirtualColumn virtualColumn : virtualColumns.getVirtualColumns()) {
+      for (String column : virtualColumn.requiredColumns()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    if (filter != null) {
+      for (String column : filter.getRequiredColumns()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    for (DimensionSpec dimensionSpec : dimensions) {
+      if (!virtualColumns.exists(dimensionSpec.getDimension())) {
+        requiredColumns.add(dimensionSpec.getDimension());
+      }
+    }
+
+    for (AggregatorFactory aggregator : aggregators) {
+      for (String column : aggregator.requiredFields()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    for (String column : additionalColumns) {
+      if (!virtualColumns.exists(column)) {
+        requiredColumns.add(column);
+      }
+    }
+
+    return requiredColumns;
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java
index 93b24ce..fc12d5e 100644
--- a/processing/src/main/java/org/apache/druid/query/Query.java
+++ b/processing/src/main/java/org/apache/druid/query/Query.java
@@ -46,6 +46,7 @@
 import javax.annotation.Nullable;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ExecutorService;
 
@@ -193,4 +194,20 @@
   {
     return VirtualColumns.EMPTY;
   }
+
+  /**
+   * Returns the set of columns that this query will need to access out of its datasource.
+   *
+   * This method does not "look into" what the datasource itself is doing. For example, if a query is built on a
+   * {@link QueryDataSource}, this method will not return the columns used by that subquery. As another example, if a
+   * query is built on a {@link JoinDataSource}, this method will not return the columns from the underlying datasources
+   * that are used by the join condition, unless those columns are also used by this query in other ways.
+   *
+   * Returns null if the set of required columns cannot be known ahead of time.
+   */
+  @Nullable
+  default Set<String> getRequiredColumns()
+  {
+    return null;
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 6edba68..f6528a6d 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -54,6 +54,7 @@
   public static final String JOIN_FILTER_PUSH_DOWN_KEY = "enableJoinFilterPushDown";
   public static final String JOIN_FILTER_REWRITE_ENABLE_KEY = "enableJoinFilterRewrite";
   public static final String JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY = "enableJoinFilterRewriteValueColumnFilters";
+  public static final String REWRITE_JOIN_TO_FILTER_ENABLE_KEY = "enableRewriteJoinToFilter";
   public static final String JOIN_FILTER_REWRITE_MAX_SIZE_KEY = "joinFilterRewriteMaxSize";
  // This flag controls whether a SQL join query with a left scan should be attempted as direct table access
   // instead of being wrapped inside a query. With direct table access enabled, druid can push down the join operation to
@@ -80,6 +81,7 @@
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN = true;
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_REWRITE = true;
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS = false;
+  public static final boolean DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER = false;
   public static final long DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE = 10000;
   public static final boolean DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT = false;
   public static final boolean DEFAULT_USE_FILTER_CNF = false;
@@ -274,6 +276,7 @@
   {
     return parseInt(query, BROKER_PARALLELISM, defaultValue);
   }
+
   public static <T> boolean getEnableJoinFilterRewriteValueColumnFilters(Query<T> query)
   {
     return parseBoolean(
@@ -283,6 +286,15 @@
     );
   }
 
+  public static <T> boolean getEnableRewriteJoinToFilter(Query<T> query)
+  {
+    return parseBoolean(
+        query,
+        REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
+        DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
+    );
+  }
+
   public static <T> long getJoinFilterRewriteMaxSize(Query<T> query)
   {
     return parseLong(query, JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
diff --git a/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java b/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
index 2309602..b001615 100644
--- a/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
+++ b/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
@@ -35,6 +35,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 @JsonTypeName("map")
@@ -129,12 +130,24 @@
   }
 
   @Override
+  public boolean canGetKeySet()
+  {
+    return true;
+  }
+
+  @Override
   public Iterable<Map.Entry<String, String>> iterable()
   {
     return map.entrySet();
   }
 
   @Override
+  public Set<String> keySet()
+  {
+    return Collections.unmodifiableSet(map.keySet());
+  }
+
+  @Override
   public boolean equals(Object o)
   {
     if (this == o) {
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
index b4fb075..7923e8e 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
@@ -73,6 +73,7 @@
 import javax.annotation.Nullable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
@@ -778,6 +779,19 @@
     return postProcessingFn.apply(results);
   }
 
+  @Nullable
+  @Override
+  public Set<String> getRequiredColumns()
+  {
+    return Queries.computeRequiredColumns(
+        virtualColumns,
+        dimFilter,
+        dimensions,
+        aggregatorSpecs,
+        Collections.emptyList()
+    );
+  }
+
   @Override
   public GroupByQuery withOverriddenContext(Map<String, Object> contextOverride)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java b/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
index f806a55..f24a965 100644
--- a/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
+++ b/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
@@ -29,6 +29,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
 @JsonSubTypes(value = {
@@ -106,6 +107,11 @@
   public abstract boolean canIterate();
 
   /**
+   * Returns true if this lookup extractor's {@link #keySet()} method will return a valid set.
+   */
+  public abstract boolean canGetKeySet();
+
+  /**
    * Returns an Iterable that iterates over the keys and values in this lookup extractor.
    *
    * @throws UnsupportedOperationException if {@link #canIterate()} returns false.
@@ -113,6 +119,13 @@
   public abstract Iterable<Map.Entry<String, String>> iterable();
 
   /**
+   * Returns a Set of all keys in this lookup extractor. The returned Set will not change.
+   *
+   * @throws UnsupportedOperationException if {@link #canGetKeySet()} returns false.
+   */
+  public abstract Set<String> keySet();
+
+  /**
    * Create a cache key for use in results caching
    *
    * @return A byte array that can be used to uniquely identify if results of a prior lookup can use the cached values
diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
index 347e675..067bdff 100644
--- a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
@@ -31,6 +31,7 @@
 import org.apache.druid.query.BaseQuery;
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.Druids;
+import org.apache.druid.query.Queries;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.filter.DimFilter;
 import org.apache.druid.query.spec.QuerySegmentSpec;
@@ -38,10 +39,12 @@
 import org.apache.druid.segment.column.ColumnHolder;
 
 import javax.annotation.Nullable;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 public class ScanQuery extends BaseQuery<ScanResultValue>
 {
@@ -311,6 +314,24 @@
     );
   }
 
+  @Nullable
+  @Override
+  public Set<String> getRequiredColumns()
+  {
+    if (columns == null || columns.isEmpty()) {
+      // We don't know what columns we require. We'll find out when the segment shows up.
+      return null;
+    } else {
+      return Queries.computeRequiredColumns(
+          virtualColumns,
+          dimFilter,
+          Collections.emptyList(),
+          Collections.emptyList(),
+          columns
+      );
+    }
+  }
+
   public ScanQuery withOffset(final long newOffset)
   {
     return Druids.ScanQueryBuilder.copy(this).offset(newOffset).build();
diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java
index 4756707..63c12de 100644
--- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java
@@ -38,10 +38,13 @@
 import org.apache.druid.query.spec.QuerySegmentSpec;
 import org.apache.druid.segment.VirtualColumns;
 
+import javax.annotation.Nullable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  */
@@ -157,6 +160,19 @@
     return getContextBoolean(SKIP_EMPTY_BUCKETS, false);
   }
 
+  @Nullable
+  @Override
+  public Set<String> getRequiredColumns()
+  {
+    return Queries.computeRequiredColumns(
+        virtualColumns,
+        dimFilter,
+        Collections.emptyList(),
+        aggregatorSpecs,
+        Collections.emptyList()
+    );
+  }
+
   @Override
   public TimeseriesQuery withQuerySegmentSpec(QuerySegmentSpec querySegmentSpec)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java
index 3218139..7724a6d 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java
@@ -37,10 +37,13 @@
 import org.apache.druid.query.spec.QuerySegmentSpec;
 import org.apache.druid.segment.VirtualColumns;
 
+import javax.annotation.Nullable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  */
@@ -156,6 +159,19 @@
     return postAggregatorSpecs;
   }
 
+  @Nullable
+  @Override
+  public Set<String> getRequiredColumns()
+  {
+    return Queries.computeRequiredColumns(
+        virtualColumns,
+        dimFilter,
+        Collections.singletonList(dimensionSpec),
+        aggregatorSpecs,
+        Collections.emptyList()
+    );
+  }
+
   public void initTopNAlgorithmSelector(TopNAlgorithmSelector selector)
   {
     if (dimensionSpec.getExtractionFn() != null) {
diff --git a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java
index 03209d0..abecb0d 100644
--- a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java
+++ b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java
@@ -59,6 +59,7 @@
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.NoSuchElementException;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
@@ -486,14 +487,14 @@
 
   /**
    * Create a filter representing an AND relationship across a list of filters. Deduplicates filters, flattens stacks,
-   * and removes literal "false" filters.
+   * and removes null filters and literal "false" filters.
    *
    * @param filters List of filters
    *
    * @return If "filters" has more than one filter remaining after processing, returns {@link AndFilter}.
    * If "filters" has a single element remaining after processing, return that filter alone.
    *
-   * @throws IllegalArgumentException if "filters" is empty
+   * @throws IllegalArgumentException if "filters" is empty or only contains nulls
    */
   public static Filter and(final List<Filter> filters)
   {
@@ -501,15 +502,18 @@
   }
 
   /**
-   * Like {@link #and}, but returns an empty Optional instead of throwing an exception if "filters" is empty.
+   * Like {@link #and}, but returns an empty Optional instead of throwing an exception if "filters" is empty
+   * or only contains nulls.
    */
   public static Optional<Filter> maybeAnd(List<Filter> filters)
   {
-    if (filters.isEmpty()) {
+    final List<Filter> nonNullFilters = nonNull(filters);
+
+    if (nonNullFilters.isEmpty()) {
       return Optional.empty();
     }
 
-    final LinkedHashSet<Filter> filtersToUse = flattenAndChildren(filters);
+    final LinkedHashSet<Filter> filtersToUse = flattenAndChildren(nonNullFilters);
 
     if (filtersToUse.isEmpty()) {
       assert !filters.isEmpty();
@@ -527,7 +531,7 @@
 
   /**
    * Create a filter representing an OR relationship across a list of filters. Deduplicates filters, flattens stacks,
-   * and removes literal "false" filters.
+   * and removes null filters and literal "false" filters.
    *
    * @param filters List of filters
    *
@@ -542,18 +546,21 @@
   }
 
   /**
-   * Like {@link #or}, but returns an empty Optional instead of throwing an exception if "filters" is empty.
+   * Like {@link #or}, but returns an empty Optional instead of throwing an exception if "filters" is empty
+   * or only contains nulls.
    */
   public static Optional<Filter> maybeOr(final List<Filter> filters)
   {
-    if (filters.isEmpty()) {
+    final List<Filter> nonNullFilters = nonNull(filters);
+
+    if (nonNullFilters.isEmpty()) {
       return Optional.empty();
     }
 
-    final LinkedHashSet<Filter> filtersToUse = flattenOrChildren(filters);
+    final LinkedHashSet<Filter> filtersToUse = flattenOrChildren(nonNullFilters);
 
     if (filtersToUse.isEmpty()) {
-      assert !filters.isEmpty();
+      assert !nonNullFilters.isEmpty();
       // Original "filters" list must have been 100% literally-false filters.
       return Optional.of(FalseFilter.instance());
     } else if (filtersToUse.stream().anyMatch(filter -> filter instanceof TrueFilter)) {
@@ -595,6 +602,20 @@
     return valueMatcher.matches();
   }
 
+
+  /**
+   * Returns a list equivalent to the input list, but with nulls removed. If the original list has no nulls,
+   * it is returned directly.
+   */
+  private static List<Filter> nonNull(final List<Filter> filters)
+  {
+    if (filters.stream().anyMatch(Objects::isNull)) {
+      return filters.stream().filter(Objects::nonNull).collect(Collectors.toList());
+    } else {
+      return filters;
+    }
+  }
+
   /**
    * Flattens children of an AND, removes duplicates, and removes literally-true filters.
    */
diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java
index 2002ee10..34ac51c 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java
@@ -66,9 +66,9 @@
     this.clauses = clauses;
     this.joinFilterPreAnalysis = joinFilterPreAnalysis;
 
-    // Verify 'clauses' is nonempty (otherwise it's a waste to create this object, and the caller should know)
-    if (clauses.isEmpty()) {
-      throw new IAE("'clauses' is empty, no need to create HashJoinSegment");
+    // Verify this virtual segment is doing something useful (otherwise it's a waste to create this object)
+    if (clauses.isEmpty() && baseFilter == null) {
+      throw new IAE("'clauses' and 'baseFilter' are both empty, no need to create HashJoinSegment");
     }
   }
 
diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
index 4df490f..86b7ef4 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
@@ -37,16 +37,19 @@
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.data.Indexed;
 import org.apache.druid.segment.data.ListIndexed;
+import org.apache.druid.segment.filter.Filters;
 import org.apache.druid.segment.join.filter.JoinFilterAnalyzer;
 import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis;
 import org.apache.druid.segment.join.filter.JoinFilterPreAnalysisKey;
 import org.apache.druid.segment.join.filter.JoinFilterSplit;
+import org.apache.druid.segment.vector.VectorCursor;
 import org.joda.time.DateTime;
 import org.joda.time.Interval;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -56,6 +59,8 @@
 public class HashJoinSegmentStorageAdapter implements StorageAdapter
 {
   private final StorageAdapter baseAdapter;
+
+  @Nullable
   private final Filter baseFilter;
   private final List<JoinableClause> clauses;
   private final JoinFilterPreAnalysis joinFilterPreAnalysis;
@@ -84,7 +89,7 @@
    */
   HashJoinSegmentStorageAdapter(
       final StorageAdapter baseAdapter,
-      final Filter baseFilter,
+      @Nullable final Filter baseFilter,
       final List<JoinableClause> clauses,
       final JoinFilterPreAnalysis joinFilterPreAnalysis
   )
@@ -222,6 +227,43 @@
   }
 
   @Override
+  public boolean canVectorize(@Nullable Filter filter, VirtualColumns virtualColumns, boolean descending)
+  {
+    // HashJoinEngine isn't vectorized yet.
+    // However, we can still vectorize if there are no clauses, since that means all we need to do is apply
+    // a base filter. That's easy enough!
+    return clauses.isEmpty() && baseAdapter.canVectorize(baseFilterAnd(filter), virtualColumns, descending);
+  }
+
+  @Nullable
+  @Override
+  public VectorCursor makeVectorCursor(
+      @Nullable Filter filter,
+      Interval interval,
+      VirtualColumns virtualColumns,
+      boolean descending,
+      int vectorSize,
+      @Nullable QueryMetrics<?> queryMetrics
+  )
+  {
+    if (!canVectorize(filter, virtualColumns, descending)) {
+      throw new ISE("Cannot vectorize. Check 'canVectorize' before calling 'makeVectorCursor'.");
+    }
+
+    // Should have been checked by canVectorize.
+    assert clauses.isEmpty();
+
+    return baseAdapter.makeVectorCursor(
+        baseFilterAnd(filter),
+        interval,
+        virtualColumns,
+        descending,
+        vectorSize,
+        queryMetrics
+    );
+  }
+
+  @Override
   public Sequence<Cursor> makeCursors(
       @Nullable final Filter filter,
       @Nonnull final Interval interval,
@@ -231,6 +273,19 @@
       @Nullable final QueryMetrics<?> queryMetrics
   )
   {
+    final Filter combinedFilter = baseFilterAnd(filter);
+
+    if (clauses.isEmpty()) {
+      return baseAdapter.makeCursors(
+          combinedFilter,
+          interval,
+          virtualColumns,
+          gran,
+          descending,
+          queryMetrics
+      );
+    }
+
     // Filter pre-analysis key implied by the call to "makeCursors". We need to sanity-check that it matches
     // the actual pre-analysis that was done. Note: we can't infer a rewrite config from the "makeCursors" call (it
     // requires access to the query context) so we'll need to skip sanity-checking it, by re-using the one present
@@ -240,7 +295,7 @@
             joinFilterPreAnalysis.getKey().getRewriteConfig(),
             clauses,
             virtualColumns,
-            filter
+            combinedFilter
         );
 
     final JoinFilterPreAnalysisKey keyCached = joinFilterPreAnalysis.getKey();
@@ -363,4 +418,10 @@
                 .filter(clause -> clause.includesColumn(column))
                 .findFirst();
   }
+
+  @Nullable
+  private Filter baseFilterAnd(@Nullable final Filter other)
+  {
+    return Filters.maybeAnd(Arrays.asList(baseFilter, other)).orElse(null);
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
index 23875a0..53460b7 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
@@ -29,6 +29,7 @@
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -58,6 +59,7 @@
   private final boolean isAlwaysTrue;
   private final boolean canHashJoin;
   private final Set<String> rightKeyColumns;
+  private final Set<String> requiredColumns;
 
   private JoinConditionAnalysis(
       final String originalExpression,
@@ -80,6 +82,7 @@
                                                                     ExprUtils.nilBindings()).asBoolean());
     canHashJoin = nonEquiConditions.stream().allMatch(Expr::isLiteral);
     rightKeyColumns = getEquiConditions().stream().map(Equality::getRightColumn).collect(Collectors.toSet());
+    requiredColumns = computeRequiredColumns(rightPrefix, equiConditions, nonEquiConditions);
   }
 
   /**
@@ -192,6 +195,15 @@
     return rightKeyColumns;
   }
 
+  /**
+   * Returns the set of column names required by this join condition. Columns from the right-hand side are returned
+   * with their prefixes included.
+   */
+  public Set<String> getRequiredColumns()
+  {
+    return requiredColumns;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -217,4 +229,24 @@
   {
     return originalExpression;
   }
+
+  private static Set<String> computeRequiredColumns(
+      final String rightPrefix,
+      final List<Equality> equiConditions,
+      final List<Expr> nonEquiConditions
+  )
+  {
+    final Set<String> requiredColumns = new HashSet<>();
+
+    for (Equality equality : equiConditions) {
+      requiredColumns.add(rightPrefix + equality.getRightColumn());
+      requiredColumns.addAll(equality.getLeftExpr().analyzeInputs().getRequiredBindings());
+    }
+
+    for (Expr expr : nonEquiConditions) {
+      requiredColumns.addAll(expr.analyzeInputs().getRequiredBindings());
+    }
+
+    return requiredColumns;
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
index f22134b..25957f7 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
@@ -86,6 +86,15 @@
   );
 
   /**
+   * Returns all nonnull values from a particular column if they are all unique, if there are "maxNumValues" or fewer,
+   * and if the column exists and supports this operation. Otherwise, returns an empty Optional.
+   *
+   * @param columnName   name of the column
+   * @param maxNumValues maximum number of values to return
+   */
+  Optional<Set<String>> getNonNullColumnValuesIfAllUnique(String columnName, int maxNumValues);
+
+  /**
    * Searches a column from this Joinable for a particular value, finds rows that match,
    * and returns values of a second column for those rows.
    *
@@ -93,9 +102,9 @@
    * @param searchColumnValue Target value of the search column. This is the value that is being filtered on.
    * @param retrievalColumnName The column to retrieve values from. This is the column that is being joined against.
    * @param maxCorrelationSetSize Maximum number of values to retrieve. If we detect that more values would be
-   *                              returned than this limit, return an empty set.
+   *                              returned than this limit, return absent.
   * @param allowNonKeyColumnSearch If true, allow searches on non-key columns. If this is false,
-   *                                a search on a non-key column should return an empty set.
+   *                                a search on a non-key column returns absent.
    * @return The set of correlated column values. If we cannot determine correlated values, return absent.
    *
    * In case either the search or retrieval column names are not found, this will return absent.
diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
index b076b1a..f462b93 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
@@ -19,12 +19,22 @@
 
 package org.apache.druid.segment.join;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Sets;
+import com.google.common.primitives.Ints;
 import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.query.filter.Filter;
+import org.apache.druid.query.filter.InDimFilter;
 import org.apache.druid.query.planning.DataSourceAnalysis;
 import org.apache.druid.query.planning.PreJoinableClause;
 import org.apache.druid.segment.SegmentReference;
@@ -36,8 +46,13 @@
 import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig;
 import org.apache.druid.utils.JvmUtils;
 
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
@@ -61,6 +76,7 @@
    * Creates a Function that maps base segments to {@link HashJoinSegment} if needed (i.e. if the number of join
    * clauses is > 0). If mapping is not needed, this method will return {@link Function#identity()}.
    *
+   * @param baseFilter         Filter to apply before the join takes place
    * @param clauses            Pre-joinable clauses
    * @param cpuTimeAccumulator An accumulator that we will add CPU nanos to; this is part of the function to encourage
    *                           callers to remember to track metrics on CPU time required for creation of Joinables
@@ -70,7 +86,7 @@
    *                           query from the end user.
    */
   public Function<SegmentReference, SegmentReference> createSegmentMapFn(
-      final Filter baseFilter,
+      @Nullable final Filter baseFilter,
       final List<PreJoinableClause> clauses,
       final AtomicLong cpuTimeAccumulator,
       final Query<?> query
@@ -84,22 +100,48 @@
             return Function.identity();
           } else {
             final JoinableClauses joinableClauses = JoinableClauses.createClauses(clauses, joinableFactory);
+            final JoinFilterRewriteConfig filterRewriteConfig = JoinFilterRewriteConfig.forQuery(query);
+
+            // Pick off any join clauses that can be converted into filters.
+            final Set<String> requiredColumns = query.getRequiredColumns();
+            final Filter baseFilterToUse;
+            final List<JoinableClause> clausesToUse;
+
+            if (requiredColumns != null && filterRewriteConfig.isEnableRewriteJoinToFilter()) {
+              final Pair<List<Filter>, List<JoinableClause>> conversionResult = convertJoinsToFilters(
+                  joinableClauses.getJoinableClauses(),
+                  requiredColumns,
+                  Ints.checkedCast(Math.min(filterRewriteConfig.getFilterRewriteMaxSize(), Integer.MAX_VALUE))
+              );
+
+              baseFilterToUse =
+                  Filters.maybeAnd(
+                      Lists.newArrayList(
+                          Iterables.concat(
+                              Collections.singleton(baseFilter),
+                              conversionResult.lhs
+                          )
+                      )
+                  ).orElse(null);
+              clausesToUse = conversionResult.rhs;
+            } else {
+              baseFilterToUse = baseFilter;
+              clausesToUse = joinableClauses.getJoinableClauses();
+            }
+
+            // Analyze remaining join clauses to see if filters on them can be pushed down.
             final JoinFilterPreAnalysis joinFilterPreAnalysis = JoinFilterAnalyzer.computeJoinFilterPreAnalysis(
                 new JoinFilterPreAnalysisKey(
-                    JoinFilterRewriteConfig.forQuery(query),
-                    joinableClauses.getJoinableClauses(),
+                    filterRewriteConfig,
+                    clausesToUse,
                     query.getVirtualColumns(),
-                    Filters.toFilter(query.getFilter())
+                    Filters.maybeAnd(Arrays.asList(baseFilterToUse, Filters.toFilter(query.getFilter())))
+                           .orElse(null)
                 )
             );
 
             return baseSegment ->
-                new HashJoinSegment(
-                    baseSegment,
-                    baseFilter,
-                    joinableClauses.getJoinableClauses(),
-                    joinFilterPreAnalysis
-                );
+                new HashJoinSegment(baseSegment, baseFilterToUse, clausesToUse, joinFilterPreAnalysis);
           }
         }
     );
@@ -116,7 +158,9 @@
    * in the JOIN is not cacheable.
    *
    * @param dataSourceAnalysis for the join datasource
+   *
    * @return the optional cache key to be used as part of query cache key
+   *
   * @throws IAE if this operation is called on a non-join data source
    */
   public Optional<byte[]> computeJoinDataSourceCacheKey(
@@ -148,4 +192,112 @@
     return Optional.of(keyBuilder.build());
   }
 
+
+  /**
+   * Converts each join clause that can be converted into a filter, and returns the rest as-is.
+   *
+   * See {@link #convertJoinToFilter} for details on the logic.
+   */
+  @VisibleForTesting
+  static Pair<List<Filter>, List<JoinableClause>> convertJoinsToFilters(
+      final List<JoinableClause> clauses,
+      final Set<String> requiredColumns,
+      final int maxNumFilterValues
+  )
+  {
+    final List<Filter> filterList = new ArrayList<>();
+    final List<JoinableClause> clausesToUse = new ArrayList<>();
+
+    // Join clauses may depend on other, earlier join clauses.
+    // We track that using a Multiset, because we'll need to remove required columns one by one as we convert clauses,
+    // and multiple clauses may refer to the same column.
+    final Multiset<String> columnsRequiredByJoinClauses = HashMultiset.create();
+
+    for (JoinableClause clause : clauses) {
+      for (String column : clause.getCondition().getRequiredColumns()) {
+        columnsRequiredByJoinClauses.add(column, 1);
+      }
+    }
+
+    // Walk through the list of clauses, picking off any from the start of the list that can be converted to filters.
+    boolean atStart = true;
+    for (JoinableClause clause : clauses) {
+      if (atStart) {
+        // Remove this clause from columnsRequiredByJoinClauses. It's ok if it relies on itself.
+        for (String column : clause.getCondition().getRequiredColumns()) {
+          columnsRequiredByJoinClauses.remove(column, 1);
+        }
+
+        final Optional<Filter> filter =
+            convertJoinToFilter(
+                clause,
+                Sets.union(requiredColumns, columnsRequiredByJoinClauses.elementSet()),
+                maxNumFilterValues
+            );
+
+        if (filter.isPresent()) {
+          filterList.add(filter.get());
+        } else {
+          clausesToUse.add(clause);
+          atStart = false;
+        }
+      } else {
+        clausesToUse.add(clause);
+      }
+    }
+
+    // Sanity check. If this exception is ever thrown, it's a bug.
+    if (filterList.size() + clausesToUse.size() != clauses.size()) {
+      throw new ISE("Lost a join clause during planning");
+    }
+
+    return Pair.of(filterList, clausesToUse);
+  }
+
+  /**
+   * Converts a join clause into an "in" filter if possible.
+   *
+   * The requirements are:
+   *
+   * - it must be an INNER equi-join
+   * - the right-hand columns referenced by the condition must not have any duplicate values
+   * - no columns from the right-hand side can appear in "requiredColumns"
+   */
+  @VisibleForTesting
+  static Optional<Filter> convertJoinToFilter(
+      final JoinableClause clause,
+      final Set<String> requiredColumns,
+      final int maxNumFilterValues
+  )
+  {
+    if (clause.getJoinType() == JoinType.INNER
+        && requiredColumns.stream().noneMatch(clause::includesColumn)
+        && clause.getCondition().getNonEquiConditions().isEmpty()
+        && clause.getCondition().getEquiConditions().size() > 0) {
+      final List<Filter> filters = new ArrayList<>();
+      int numValues = maxNumFilterValues;
+
+      for (final Equality condition : clause.getCondition().getEquiConditions()) {
+        final String leftColumn = condition.getLeftExpr().getBindingIfIdentifier();
+
+        if (leftColumn == null) {
+          return Optional.empty();
+        }
+
+        final Optional<Set<String>> columnValuesForFilter =
+            clause.getJoinable().getNonNullColumnValuesIfAllUnique(condition.getRightColumn(), numValues);
+
+        if (columnValuesForFilter.isPresent()) {
+          numValues -= columnValuesForFilter.get().size();
+          filters.add(Filters.toFilter(new InDimFilter(leftColumn, columnValuesForFilter.get())));
+        } else {
+          return Optional.empty();
+        }
+      }
+
+      return Optional.of(Filters.and(filters));
+    }
+
+    return Optional.empty();
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
index ec18f03..88bf00b 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
@@ -48,6 +48,12 @@
   private final boolean enableRewriteValueColumnFilters;
 
   /**
+   * Whether to enable eliminating entire inner join clauses by rewriting them into filters on the base segment.
+   * In production this should generally be {@code QueryContexts.getEnableRewriteJoinToFilter(query)}.
+   */
+  private final boolean enableRewriteJoinToFilter;
+
+  /**
    * The max allowed size of correlated value sets for RHS rewrites. In production
    * This should generally be {@code QueryContexts.getJoinFilterRewriteMaxSize(query)}.
    */
@@ -57,12 +63,14 @@
       boolean enableFilterPushDown,
       boolean enableFilterRewrite,
       boolean enableRewriteValueColumnFilters,
+      boolean enableRewriteJoinToFilter,
       long filterRewriteMaxSize
   )
   {
     this.enableFilterPushDown = enableFilterPushDown;
     this.enableFilterRewrite = enableFilterRewrite;
     this.enableRewriteValueColumnFilters = enableRewriteValueColumnFilters;
+    this.enableRewriteJoinToFilter = enableRewriteJoinToFilter;
     this.filterRewriteMaxSize = filterRewriteMaxSize;
   }
 
@@ -72,6 +80,7 @@
         QueryContexts.getEnableJoinFilterPushDown(query),
         QueryContexts.getEnableJoinFilterRewrite(query),
         QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query),
+        QueryContexts.getEnableRewriteJoinToFilter(query),
         QueryContexts.getJoinFilterRewriteMaxSize(query)
     );
   }
@@ -91,6 +100,11 @@
     return enableRewriteValueColumnFilters;
   }
 
+  public boolean isEnableRewriteJoinToFilter()
+  {
+    return enableRewriteJoinToFilter;
+  }
+
   public long getFilterRewriteMaxSize()
   {
     return filterRewriteMaxSize;
@@ -106,10 +120,11 @@
       return false;
     }
     JoinFilterRewriteConfig that = (JoinFilterRewriteConfig) o;
-    return enableFilterPushDown == that.enableFilterPushDown &&
-           enableFilterRewrite == that.enableFilterRewrite &&
-           enableRewriteValueColumnFilters == that.enableRewriteValueColumnFilters &&
-           filterRewriteMaxSize == that.filterRewriteMaxSize;
+    return enableFilterPushDown == that.enableFilterPushDown
+           && enableFilterRewrite == that.enableFilterRewrite
+           && enableRewriteValueColumnFilters == that.enableRewriteValueColumnFilters
+           && enableRewriteJoinToFilter == that.enableRewriteJoinToFilter
+           && filterRewriteMaxSize == that.filterRewriteMaxSize;
   }
 
   @Override
@@ -119,7 +134,20 @@
         enableFilterPushDown,
         enableFilterRewrite,
         enableRewriteValueColumnFilters,
+        enableRewriteJoinToFilter,
         filterRewriteMaxSize
     );
   }
+
+  @Override
+  public String toString()
+  {
+    return "JoinFilterRewriteConfig{" +
+           "enableFilterPushDown=" + enableFilterPushDown +
+           ", enableFilterRewrite=" + enableFilterRewrite +
+           ", enableRewriteValueColumnFilters=" + enableRewriteValueColumnFilters +
+           ", enableRewriteJoinToFilter=" + enableRewriteJoinToFilter +
+           ", filterRewriteMaxSize=" + filterRewriteMaxSize +
+           '}';
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java
index 109da85..2d3c43d 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java
@@ -21,6 +21,8 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.query.lookup.LookupExtractor;
 import org.apache.druid.segment.ColumnSelectorFactory;
@@ -34,6 +36,7 @@
 import javax.annotation.Nullable;
 import java.io.Closeable;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
@@ -93,6 +96,39 @@
   }
 
   @Override
+  public Optional<Set<String>> getNonNullColumnValuesIfAllUnique(String columnName, int maxNumValues)
+  {
+    if (LookupColumnSelectorFactory.KEY_COLUMN.equals(columnName) && extractor.canGetKeySet()) {
+      final Set<String> keys = extractor.keySet();
+
+      final Set<String> nullEquivalentValues = new HashSet<>();
+      nullEquivalentValues.add(null);
+      if (NullHandling.replaceWithDefault()) {
+        nullEquivalentValues.add(NullHandling.defaultStringValue());
+      }
+
+      // size() of Sets.difference is slow; avoid it.
+      int nonNullKeys = keys.size();
+
+      for (String value : nullEquivalentValues) {
+        if (keys.contains(value)) {
+          nonNullKeys--;
+        }
+      }
+
+      if (nonNullKeys > maxNumValues) {
+        return Optional.empty();
+      } else if (nonNullKeys == keys.size()) {
+        return Optional.of(keys);
+      } else {
+        return Optional.of(Sets.difference(keys, nullEquivalentValues));
+      }
+    } else {
+      return Optional.empty();
+    }
+  }
+
+  @Override
   public Optional<Set<String>> getCorrelatedColumnValues(
       String searchColumnName,
       String searchColumnValue,
diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
index 4faaf54..e59b4fe 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
@@ -20,8 +20,10 @@
 package org.apache.druid.segment.join.table;
 
 import it.unimi.dsi.fastutil.ints.IntList;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.DimensionHandlerUtils;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.join.JoinConditionAnalysis;
 import org.apache.druid.segment.join.JoinMatcher;
@@ -35,6 +37,7 @@
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
+import java.util.TreeSet;
 
 public class IndexedTableJoinable implements Joinable
 {
@@ -89,6 +92,42 @@
   }
 
   @Override
+  public Optional<Set<String>> getNonNullColumnValuesIfAllUnique(final String columnName, final int maxNumValues)
+  {
+    final int columnPosition = table.rowSignature().indexOf(columnName);
+
+    if (columnPosition < 0) {
+      return Optional.empty();
+    }
+
+    try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) {
+      // Sorted set to encourage "in" filters that result from this method to do dictionary lookups in order.
+      // The hope is that this will improve locality, and therefore performance.
+      final Set<String> allValues = new TreeSet<>();
+
+      for (int i = 0; i < table.numRows(); i++) {
+        final String s = DimensionHandlerUtils.convertObjectToString(reader.read(i));
+
+        if (!NullHandling.isNullOrEquivalent(s)) {
+          if (!allValues.add(s)) {
+            // Duplicate found. Since the values are not all unique, we must return an empty Optional.
+            return Optional.empty();
+          }
+
+          if (allValues.size() > maxNumValues) {
+            return Optional.empty();
+          }
+        }
+      }
+
+      return Optional.of(allValues);
+    }
+    catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
   public Optional<Set<String>> getCorrelatedColumnValues(
       String searchColumnName,
       String searchColumnValue,
@@ -112,7 +151,7 @@
         IntList rowIndex = index.find(searchColumnValue);
         for (int i = 0; i < rowIndex.size(); i++) {
           int rowNum = rowIndex.getInt(i);
-          String correlatedDimVal = Objects.toString(reader.read(rowNum), null);
+          String correlatedDimVal = DimensionHandlerUtils.convertObjectToString(reader.read(rowNum));
           correlatedValues.add(correlatedDimVal);
 
           if (correlatedValues.size() > maxCorrelationSetSize) {
@@ -132,7 +171,7 @@
         for (int i = 0; i < table.numRows(); i++) {
           String dimVal = Objects.toString(dimNameReader.read(i), null);
           if (searchColumnValue.equals(dimVal)) {
-            String correlatedDimVal = Objects.toString(correlatedColumnReader.read(i), null);
+            String correlatedDimVal = DimensionHandlerUtils.convertObjectToString(correlatedColumnReader.read(i));
             correlatedValues.add(correlatedDimVal);
             if (correlatedValues.size() > maxCorrelationSetSize) {
               return Optional.empty();
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java
index b6f76b4..cec90ed 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java
@@ -21,11 +21,13 @@
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Ordering;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import nl.jqno.equalsverifier.Warning;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.math.expr.ExprMacroTable;
 import org.apache.druid.query.BaseQuery;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryRunnerTestHelper;
@@ -40,6 +42,7 @@
 import org.apache.druid.query.spec.QuerySegmentSpec;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -81,6 +84,33 @@
   }
 
   @Test
+  public void testGetRequiredColumns()
+  {
+    final GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
+        .setVirtualColumns(new ExpressionVirtualColumn("v", "\"other\"", ValueType.STRING, ExprMacroTable.nil()))
+        .setDimensions(new DefaultDimensionSpec("quality", "alias"), DefaultDimensionSpec.of("v"))
+        .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index"))
+        .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
+        .setPostAggregatorSpecs(ImmutableList.of(new FieldAccessPostAggregator("x", "idx")))
+        .setLimitSpec(
+            new DefaultLimitSpec(
+                ImmutableList.of(new OrderByColumnSpec(
+                    "alias",
+                    OrderByColumnSpec.Direction.ASCENDING,
+                    StringComparators.LEXICOGRAPHIC
+                )),
+                100
+            )
+        )
+        .build();
+
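+    // Only query inputs are required: "quality" (dimension), "other" (input to virtual column "v"), and
+    // "index" (aggregator input), plus "__time". Output names like "alias", "idx", and "x" are not inputs.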
+    Assert.assertEquals(ImmutableSet.of("__time", "quality", "other", "index"), query.getRequiredColumns());
+  }
+
+  @Test
   public void testRowOrderingMixTypes()
   {
     final GroupByQuery query = GroupByQuery.builder()
diff --git a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java
index 1854883..7972725 100644
--- a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java
@@ -269,4 +269,47 @@
     // This should throw an ISE
     List<ScanResultValue> res = borkedSequence.toList();
   }
+
+  @Test
+  public void testGetRequiredColumnsWithNoColumns()
+  {
+    final ScanQuery query =
+        Druids.newScanQueryBuilder()
+              .order(ScanQuery.Order.DESCENDING)
+              .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST)
+              .dataSource("some src")
+              .intervals(intervalSpec)
+              .build();
+
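+    // With no column list, the scan reads all columns, so the required set cannot be enumerated (null).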
+    Assert.assertNull(query.getRequiredColumns());
+  }
+
+  @Test
+  public void testGetRequiredColumnsWithEmptyColumns()
+  {
+    final ScanQuery query =
+        Druids.newScanQueryBuilder()
+              .order(ScanQuery.Order.DESCENDING)
+              .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST)
+              .dataSource("some src")
+              .intervals(intervalSpec)
+              .columns(Collections.emptyList())
+              .build();
+
+    Assert.assertNull(query.getRequiredColumns());
+  }
+
+  @Test
+  public void testGetRequiredColumnsWithColumns()
+  {
+    final ScanQuery query =
+        Druids.newScanQueryBuilder()
+              .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST)
+              .dataSource("some src")
+              .intervals(intervalSpec)
+              .columns("foo", "bar")
+              .build();
+
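+    // The listed columns are required, plus "__time".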
+    Assert.assertEquals(ImmutableSet.of("__time", "foo", "bar"), query.getRequiredColumns());
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java
index 3108802..54bebf8 100644
--- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java
@@ -20,10 +20,15 @@
 package org.apache.druid.query.timeseries;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableSet;
+import org.apache.druid.math.expr.ExprMacroTable;
 import org.apache.druid.query.Druids;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryRunnerTestHelper;
+import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
 import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -54,13 +59,13 @@
   public void testQuerySerialization() throws IOException
   {
     Query query = Druids.newTimeseriesQueryBuilder()
-        .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
-        .granularity(QueryRunnerTestHelper.DAY_GRAN)
-        .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
-        .aggregators(QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.INDEX_DOUBLE_SUM)
-        .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
-        .descending(descending)
-        .build();
+                        .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+                        .granularity(QueryRunnerTestHelper.DAY_GRAN)
+                        .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+                        .aggregators(QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.INDEX_DOUBLE_SUM)
+                        .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
+                        .descending(descending)
+                        .build();
 
     String json = JSON_MAPPER.writeValueAsString(query);
     Query serdeQuery = JSON_MAPPER.readValue(json, Query.class);
@@ -68,4 +73,32 @@
     Assert.assertEquals(query, serdeQuery);
   }
 
+  @Test
+  public void testGetRequiredColumns()
+  {
+    final TimeseriesQuery query =
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+              .granularity(QueryRunnerTestHelper.DAY_GRAN)
+              .virtualColumns(
+                  new ExpressionVirtualColumn(
+                      "index",
+                      "\"fieldFromVirtualColumn\"",
+                      ValueType.LONG,
+                      ExprMacroTable.nil()
+                  )
+              )
+              .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+              .aggregators(
+                  QueryRunnerTestHelper.ROWS_COUNT,
+                  QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
+                  QueryRunnerTestHelper.INDEX_LONG_MAX,
+                  new LongSumAggregatorFactory("beep", "aField")
+              )
+              .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT)
+              .descending(descending)
+              .build();
+
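+    // The virtual column shadows the physical "index" column, so aggregators reading "index" require its
+    // input "fieldFromVirtualColumn" instead; the "beep" aggregator requires "aField".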
+    Assert.assertEquals(ImmutableSet.of("__time", "fieldFromVirtualColumn", "aField"), query.getRequiredColumns());
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java
index 82c77b1..5aede90 100644
--- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java
@@ -20,19 +20,27 @@
 package org.apache.druid.query.topn;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
+import org.apache.druid.math.expr.ExprMacroTable;
 import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryRunnerTestHelper;
 import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
+import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.dimension.ExtractionDimensionSpec;
 import org.apache.druid.query.dimension.LegacyDimensionSpec;
 import org.apache.druid.query.extraction.MapLookupExtractor;
 import org.apache.druid.query.lookup.LookupExtractionFn;
 import org.apache.druid.query.ordering.StringComparators;
 import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -240,4 +248,22 @@
     String json = JSON_MAPPER.writeValueAsString(query);
     JSON_MAPPER.readValue(json, Query.class);
   }
+
+  @Test
+  public void testGetRequiredColumns()
+  {
+    final TopNQuery query = new TopNQueryBuilder()
+        .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+        .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD)
+        .virtualColumns(new ExpressionVirtualColumn("v", "\"other\"", ValueType.STRING, ExprMacroTable.nil()))
+        .dimension(DefaultDimensionSpec.of("v"))
+        .aggregators(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index"))
+        .granularity(QueryRunnerTestHelper.DAY_GRAN)
+        .postAggregators(ImmutableList.of(new FieldAccessPostAggregator("x", "idx")))
+        .metric(new NumericTopNMetricSpec("idx"))
+        .threshold(100)
+        .build();
+
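+    // "v" reads "other" and the "idx" aggregator reads "index"; post-aggregators and output names add nothing.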
+    Assert.assertEquals(ImmutableSet.of("__time", "other", "index"), query.getRequiredColumns());
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
index d5dc9a2..26ba161 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
@@ -30,6 +30,7 @@
 import org.apache.druid.segment.VirtualColumn;
 import org.apache.druid.segment.VirtualColumns;
 import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.filter.Filters;
 import org.apache.druid.segment.join.filter.JoinFilterAnalyzer;
 import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis;
 import org.apache.druid.segment.join.filter.JoinFilterPreAnalysisKey;
@@ -48,6 +49,7 @@
 import org.junit.rules.TemporaryFolder;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 
 public class BaseHashJoinSegmentStorageAdapterTest
@@ -56,6 +58,7 @@
       true,
       true,
       true,
+      QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER,
       QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
   );
 
@@ -235,12 +238,16 @@
       VirtualColumns virtualColumns
   )
   {
+    // The seemingly useless "Filters.maybeAnd" call is here to dedupe filters, flatten stacks, etc., in the same
+    // way that JoinableFactoryWrapper's segmentMapFn does.
+    final Filter filterToUse = Filters.maybeAnd(Collections.singletonList(originalFilter)).orElse(null);
+
     return JoinFilterAnalyzer.computeJoinFilterPreAnalysis(
         new JoinFilterPreAnalysisKey(
             DEFAULT_JOIN_FILTER_REWRITE_CONFIG,
             joinableClauses,
             virtualColumns,
-            originalFilter
+            filterToUse
         )
     );
   }
diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
index 68e6426..10d0483 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
@@ -2024,20 +2024,22 @@
   @Test
   public void test_makeCursors_factToCountryLeftWithBaseFilter()
   {
+    final Filter baseFilter = Filters.or(Arrays.asList(
+        new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
+        new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
+    ));
+
     List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT));
 
     JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
-        null,
+        baseFilter,
         joinableClauses,
         VirtualColumns.EMPTY
     );
     JoinTestHelper.verifyCursors(
         new HashJoinSegmentStorageAdapter(
             factSegment.asStorageAdapter(),
-            Filters.or(Arrays.asList(
-                new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
-                new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
-            )),
+            baseFilter,
             joinableClauses,
             joinFilterPreAnalysis
         ).makeCursors(
@@ -2067,19 +2069,21 @@
   @Test
   public void test_makeCursors_factToCountryInnerWithBaseFilter()
   {
+    final Filter baseFilter = Filters.or(Arrays.asList(
+        new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
+        new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
+    ));
+
     List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.INNER));
     JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
-        null,
+        baseFilter,
         joinableClauses,
         VirtualColumns.EMPTY
     );
     JoinTestHelper.verifyCursors(
         new HashJoinSegmentStorageAdapter(
             factSegment.asStorageAdapter(),
-            Filters.or(Arrays.asList(
-                new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
-                new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
-            )),
+            baseFilter,
             joinableClauses,
             joinFilterPreAnalysis
         ).makeCursors(
@@ -2108,19 +2112,21 @@
   @Test
   public void test_makeCursors_factToCountryRightWithBaseFilter()
   {
+    final Filter baseFilter = Filters.or(Arrays.asList(
+        new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
+        new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
+    ));
+
     List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.RIGHT));
     JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
-        null,
+        baseFilter,
         joinableClauses,
         VirtualColumns.EMPTY
     );
     JoinTestHelper.verifyCursors(
         new HashJoinSegmentStorageAdapter(
             factSegment.asStorageAdapter(),
-            Filters.or(Arrays.asList(
-                new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
-                new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
-            )),
+            baseFilter,
             joinableClauses,
             joinFilterPreAnalysis
         ).makeCursors(
@@ -2166,19 +2172,21 @@
   @Test
   public void test_makeCursors_factToCountryFullWithBaseFilter()
   {
+    final Filter baseFilter = Filters.or(Arrays.asList(
+        new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
+        new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
+    ));
+
     List<JoinableClause> joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.FULL));
     JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
-        null,
+        baseFilter,
         joinableClauses,
         VirtualColumns.EMPTY
     );
     JoinTestHelper.verifyCursors(
         new HashJoinSegmentStorageAdapter(
             factSegment.asStorageAdapter(),
-            Filters.or(Arrays.asList(
-                new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(),
-                new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter()
-            )),
+            baseFilter,
             joinableClauses,
             joinFilterPreAnalysis
         ).makeCursors(
diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java
index 9a56b3b..581c9a1 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java
@@ -22,13 +22,11 @@
 import com.google.common.collect.ImmutableList;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.math.expr.ExprMacroTable;
-import org.apache.druid.query.QueryContexts;
 import org.apache.druid.segment.QueryableIndex;
 import org.apache.druid.segment.QueryableIndexSegment;
 import org.apache.druid.segment.ReferenceCountingSegment;
 import org.apache.druid.segment.SegmentReference;
 import org.apache.druid.segment.StorageAdapter;
-import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig;
 import org.apache.druid.segment.join.table.IndexedTableJoinable;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.apache.druid.timeline.SegmentId;
@@ -49,14 +47,6 @@
 
 public class HashJoinSegmentTest extends InitializedNullHandlingTest
 {
-  private static final JoinFilterRewriteConfig DEFAULT_JOIN_FILTER_REWRITE_CONFIG =
-      new JoinFilterRewriteConfig(
-          true,
-          true,
-          true,
-          QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
-      );
-
   @Rule
   public TemporaryFolder temporaryFolder = new TemporaryFolder();
 
@@ -205,7 +195,7 @@
   public void test_constructor_noClauses()
   {
     expectedException.expect(IllegalArgumentException.class);
-    expectedException.expectMessage("'clauses' is empty, no need to create HashJoinSegment");
+    expectedException.expectMessage("'clauses' and 'baseFilter' are both empty, no need to create HashJoinSegment");
 
     List<JoinableClause> joinableClauses = ImmutableList.of();
 
diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java
index 875f686..1ab6c09 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java
@@ -275,13 +275,22 @@
   }
 
   @Test
+  public void test_getRequiredColumns()
+  {
+    final String expression = "(x == \"j.y\") && ((x + y == \"j.z\") || (z == \"j.zz\"))";
+    final JoinConditionAnalysis analysis = analyze(expression);
+
+    Assert.assertEquals(ImmutableSet.of("x", "j.y", "y", "j.z", "z", "j.zz"), analysis.getRequiredColumns());
+  }
+
+  @Test
   public void test_equals()
   {
     EqualsVerifier.forClass(JoinConditionAnalysis.class)
                   .usingGetClass()
                   .withIgnoredFields(
                           // These fields are tightly coupled with originalExpression
-                          "equiConditions", "nonEquiConditions",
+                          "equiConditions", "nonEquiConditions", "requiredColumns",
                          // These fields are calculated from other fields in the class
                           "isAlwaysTrue", "isAlwaysFalse", "canHashJoin", "rightKeyColumns")
                   .verify();
diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java
index e422423..694edde 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java
@@ -2092,6 +2092,7 @@
                 false,
                 true,
                 true,
+                QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER,
                 QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
             ),
             joinableClauses.getJoinableClauses(),
@@ -2171,6 +2172,7 @@
                 true,
                 false,
                 true,
+                QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER,
                 QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
             ),
             joinableClauses.getJoinableClauses(),
@@ -2591,6 +2593,7 @@
                 true,
                 true,
                 true,
+                QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER,
                 QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
             ),
             joinableClauses,
diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
index 94067c3..70491ad 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
@@ -21,25 +21,30 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterators;
+import com.google.common.collect.Sets;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.common.config.NullHandlingTest;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.math.expr.ExprMacroTable;
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.GlobalTableDataSource;
 import org.apache.druid.query.LookupDataSource;
-import org.apache.druid.query.QueryContexts;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.TestQuery;
 import org.apache.druid.query.extraction.MapLookupExtractor;
 import org.apache.druid.query.filter.FalseDimFilter;
+import org.apache.druid.query.filter.Filter;
+import org.apache.druid.query.filter.InDimFilter;
 import org.apache.druid.query.filter.TrueDimFilter;
 import org.apache.druid.query.planning.DataSourceAnalysis;
 import org.apache.druid.query.planning.PreJoinableClause;
 import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
 import org.apache.druid.segment.SegmentReference;
-import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig;
 import org.apache.druid.segment.join.lookup.LookupJoinable;
 import org.easymock.EasyMock;
 import org.junit.Assert;
@@ -51,22 +56,31 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
-public class JoinableFactoryWrapperTest
+public class JoinableFactoryWrapperTest extends NullHandlingTest
 {
-  private static final JoinFilterRewriteConfig DEFAULT_JOIN_FILTER_REWRITE_CONFIG = new JoinFilterRewriteConfig(
-      QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN,
-      QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE,
-      QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS,
-      QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE
-  );
-
   private static final JoinableFactoryWrapper NOOP_JOINABLE_FACTORY_WRAPPER = new JoinableFactoryWrapper(
       NoopJoinableFactory.INSTANCE);
 
+  private static final Map<String, String> TEST_LOOKUP =
+      ImmutableMap.<String, String>builder()
+          .put("MX", "Mexico")
+          .put("NO", "Norway")
+          .put("SV", "El Salvador")
+          .put("US", "United States")
+          .put("", "Empty key")
+          .build();
+
+  private static final Set<String> TEST_LOOKUP_KEYS =
+      NullHandling.sqlCompatible()
+      ? TEST_LOOKUP.keySet()
+      : Sets.difference(TEST_LOOKUP.keySet(), Collections.singleton(""));
+
   @Rule
   public ExpectedException expectedException = ExpectedException.none();
 
@@ -428,6 +442,300 @@
     JoinPrefixUtils.checkPrefixesForDuplicatesAndShadowing(prefixes);
   }
 
+  @Test
+  public void test_convertJoinsToFilters_convertInnerJoin()
+  {
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(
+            new JoinableClause(
+                "j.",
+                LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+                JoinType.INNER,
+                JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+            )
+        ),
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
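+    // The right-hand keys are unique and referenced only by the join condition, so the inner join collapses
+    // to an "in" filter on x and no join clauses remain.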
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS)),
+            ImmutableList.of()
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_convertTwoInnerJoins()
+  {
+    final ImmutableList<JoinableClause> clauses = ImmutableList.of(
+        new JoinableClause(
+            "j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "_j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("x == \"_j.k\"", "_j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "__j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.LEFT,
+            JoinConditionAnalysis.forExpression("x == \"__j.k\"", "__j.", ExprMacroTable.nil())
+        )
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        clauses,
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS), new InDimFilter("x", TEST_LOOKUP_KEYS)),
+            ImmutableList.of(clauses.get(2))
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertTooManyValues()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.INNER,
+        JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(
+            clause
+        ),
+        ImmutableSet.of("x"),
+        2
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertLeftJoin()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.LEFT,
+        JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(clause),
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertWhenColumnIsUsed()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.INNER,
+        JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(clause),
+        ImmutableSet.of("x", "j.k"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertLhsFunctions()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.INNER,
+        JoinConditionAnalysis.forExpression("concat(x,'') == \"j.k\"", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(clause),
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertRhsFunctions()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.INNER,
+        JoinConditionAnalysis.forExpression("x == concat(\"j.k\",'')", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(clause),
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertNonEquiJoin()
+  {
+    final JoinableClause clause = new JoinableClause(
+        "j.",
+        LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+        JoinType.INNER,
+        JoinConditionAnalysis.forExpression("x != \"j.k\"", "j.", ExprMacroTable.nil())
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        ImmutableList.of(clause),
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            ImmutableList.of(clause)
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertJoinsDependedOnByLaterJoins()
+  {
+    final ImmutableList<JoinableClause> clauses = ImmutableList.of(
+        new JoinableClause(
+            "j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "_j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("\"j.k\" == \"_j.k\"", "_j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "__j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.LEFT,
+            JoinConditionAnalysis.forExpression("x == \"__j.k\"", "__j.", ExprMacroTable.nil())
+        )
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        clauses,
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(),
+            clauses
+        ),
+        conversion
+    );
+  }
+
+  @Test
+  public void test_convertJoinsToFilters_dontConvertJoinsDependedOnByLaterJoins2()
+  {
+    final ImmutableList<JoinableClause> clauses = ImmutableList.of(
+        new JoinableClause(
+            "j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "_j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.INNER,
+            JoinConditionAnalysis.forExpression("x == \"_j.k\"", "_j.", ExprMacroTable.nil())
+        ),
+        new JoinableClause(
+            "__j.",
+            LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)),
+            JoinType.LEFT,
+            JoinConditionAnalysis.forExpression("\"_j.v\" == \"__j.k\"", "__j.", ExprMacroTable.nil())
+        )
+    );
+
+    final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+        clauses,
+        ImmutableSet.of("x"),
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Pair.of(
+            ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS)),
+            clauses.subList(1, clauses.size())
+        ),
+        conversion
+    );
+  }
+
   private PreJoinableClause makeGlobalPreJoinableClause(String tableName, String expression, String prefix)
   {
     return makeGlobalPreJoinableClause(tableName, expression, prefix, JoinType.LEFT);
diff --git a/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java b/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java
new file mode 100644
index 0000000..5d0b2f8
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.join.filter.rewrite;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.junit.Test;
+
+public class JoinFilterRewriteConfigTest
+{
+  @Test
+  public void testEquals()
+  {
+    EqualsVerifier.forClass(JoinFilterRewriteConfig.class).usingGetClass().verify();
+  }
+}
diff --git a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java
index 2037f77..4b1dcda 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java
@@ -21,6 +21,8 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.common.config.NullHandlingTest;
 import org.apache.druid.query.lookup.LookupExtractor;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ValueType;
@@ -35,12 +37,13 @@
 import org.mockito.junit.MockitoJUnitRunner;
 
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 
 @RunWith(MockitoJUnitRunner.class)
-public class LookupJoinableTest
+public class LookupJoinableTest extends NullHandlingTest
 {
   private static final String UNKNOWN_COLUMN = "UNKNOWN_COLUMN";
   private static final String SEARCH_KEY_VALUE = "SEARCH_KEY_VALUE";
@@ -56,9 +59,17 @@
   @Before
   public void setUp()
   {
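+    // Include "" and null so the mocked key set exercises null/empty handling in both null-handling modes.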
+    final Set<String> keyValues = new HashSet<>();
+    keyValues.add("foo");
+    keyValues.add("bar");
+    keyValues.add("");
+    keyValues.add(null);
+
     Mockito.doReturn(SEARCH_VALUE_VALUE).when(extractor).apply(SEARCH_KEY_VALUE);
     Mockito.doReturn(ImmutableList.of(SEARCH_KEY_VALUE)).when(extractor).unapply(SEARCH_VALUE_VALUE);
     Mockito.doReturn(ImmutableList.of()).when(extractor).unapply(SEARCH_VALUE_UNKNOWN);
+    Mockito.doReturn(true).when(extractor).canGetKeySet();
+    Mockito.doReturn(keyValues).when(extractor).keySet();
     target = LookupJoinable.wrap(extractor);
   }
 
@@ -124,7 +135,8 @@
             SEARCH_KEY_VALUE,
             LookupColumnSelectorFactory.VALUE_COLUMN,
             0,
-            false);
+            false
+        );
 
     Assert.assertFalse(correlatedValues.isPresent());
   }
@@ -138,10 +150,12 @@
             SEARCH_KEY_VALUE,
             UNKNOWN_COLUMN,
             0,
-            false);
+            false
+        );
 
     Assert.assertFalse(correlatedValues.isPresent());
   }
+
   @Test
   public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue()
   {
@@ -150,7 +164,8 @@
         SEARCH_KEY_VALUE,
         LookupColumnSelectorFactory.KEY_COLUMN,
         0,
-        false);
+        false
+    );
     Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues);
   }
 
@@ -162,7 +177,8 @@
         SEARCH_KEY_VALUE,
         LookupColumnSelectorFactory.VALUE_COLUMN,
         0,
-        false);
+        false
+    );
     Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues);
   }
 
@@ -174,7 +190,8 @@
         SEARCH_KEY_NULL_VALUE,
         LookupColumnSelectorFactory.VALUE_COLUMN,
         0,
-        false);
+        false
+    );
     Assert.assertEquals(Optional.of(Collections.singleton(null)), correlatedValues);
   }
 
@@ -186,14 +203,16 @@
         SEARCH_VALUE_VALUE,
         LookupColumnSelectorFactory.VALUE_COLUMN,
         10,
-        false);
+        false
+    );
     Assert.assertEquals(Optional.empty(), correlatedValues);
     correlatedValues = target.getCorrelatedColumnValues(
         LookupColumnSelectorFactory.VALUE_COLUMN,
         SEARCH_VALUE_VALUE,
         LookupColumnSelectorFactory.KEY_COLUMN,
         10,
-        false);
+        false
+    );
     Assert.assertEquals(Optional.empty(), correlatedValues);
   }
 
@@ -205,7 +224,8 @@
         SEARCH_VALUE_VALUE,
         LookupColumnSelectorFactory.VALUE_COLUMN,
         0,
-        true);
+        true
+    );
     Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues);
   }
 
@@ -217,7 +237,8 @@
         SEARCH_VALUE_VALUE,
         LookupColumnSelectorFactory.KEY_COLUMN,
         10,
-        true);
+        true
+    );
     Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues);
   }
 
@@ -234,7 +255,8 @@
         SEARCH_VALUE_VALUE,
         LookupColumnSelectorFactory.KEY_COLUMN,
         0,
-        true);
+        true
+    );
     Assert.assertEquals(Optional.empty(), correlatedValues);
   }
 
@@ -246,7 +268,46 @@
         SEARCH_VALUE_UNKNOWN,
         LookupColumnSelectorFactory.KEY_COLUMN,
         10,
-        true);
+        true
+    );
     Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues);
   }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(
+        LookupColumnSelectorFactory.VALUE_COLUMN,
+        Integer.MAX_VALUE
+    );
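+    // Only the key column is known to be unique; lookup values may repeat, so no set is returned for values.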
+
+    Assert.assertEquals(Optional.empty(), values);
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(
+        LookupColumnSelectorFactory.KEY_COLUMN,
+        Integer.MAX_VALUE
+    );
+
+    Assert.assertEquals(
+        Optional.of(
+            NullHandling.replaceWithDefault() ? ImmutableSet.of("foo", "bar") : ImmutableSet.of("foo", "bar", "")
+        ),
+        values
+    );
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(
+        LookupColumnSelectorFactory.KEY_COLUMN,
+        1
+    );
+
+    Assert.assertEquals(Optional.empty(), values);
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java
index 5f54aa2..a9b1ae5 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java
@@ -50,6 +50,7 @@
   private static final String PREFIX = "j.";
   private static final String KEY_COLUMN = "str";
   private static final String VALUE_COLUMN = "long";
+  private static final String ALL_SAME_COLUMN = "allsame";
   private static final String UNKNOWN_COLUMN = "unknown";
   private static final String SEARCH_KEY_NULL_VALUE = "baz";
   private static final String SEARCH_KEY_VALUE = "foo";
@@ -84,13 +85,14 @@
 
   private final InlineDataSource inlineDataSource = InlineDataSource.fromIterable(
       ImmutableList.of(
-          new Object[]{"foo", 1L},
-          new Object[]{"bar", 2L},
-          new Object[]{"baz", null}
+          new Object[]{"foo", 1L, 1L},
+          new Object[]{"bar", 2L, 1L},
+          new Object[]{"baz", null, 1L}
       ),
       RowSignature.builder()
-                  .add("str", ValueType.STRING)
-                  .add("long", ValueType.LONG)
+                  .add(KEY_COLUMN, ValueType.STRING)
+                  .add(VALUE_COLUMN, ValueType.LONG)
+                  .add(ALL_SAME_COLUMN, ValueType.LONG)
                   .build()
   );
 
@@ -113,7 +115,7 @@
   @Test
   public void getAvailableColumns()
   {
-    Assert.assertEquals(ImmutableList.of("str", "long"), target.getAvailableColumns());
+    Assert.assertEquals(ImmutableList.of(KEY_COLUMN, VALUE_COLUMN, ALL_SAME_COLUMN), target.getAvailableColumns());
   }
 
   @Test
@@ -340,4 +342,50 @@
         true);
     Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues);
   }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnValues()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(VALUE_COLUMN, Integer.MAX_VALUE);
+
+    Assert.assertEquals(Optional.of(ImmutableSet.of("1", "2")), values);
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique("nonexistent", Integer.MAX_VALUE);
+
+    Assert.assertEquals(Optional.empty(), values);
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(KEY_COLUMN, Integer.MAX_VALUE);
+
+    Assert.assertEquals(
+        Optional.of(ImmutableSet.of("foo", "bar", "baz")),
+        values
+    );
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(ALL_SAME_COLUMN, Integer.MAX_VALUE);
+
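+    // Every row holds the same value, so the uniqueness check fails.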
+    Assert.assertEquals(
+        Optional.empty(),
+        values
+    );
+  }
+
+  @Test
+  public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty()
+  {
+    final Optional<Set<String>> values = target.getNonNullColumnValuesIfAllUnique(KEY_COLUMN, 1);
+
+    Assert.assertEquals(Optional.empty(), values);
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
index 9c816dc..19e5860 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
@@ -859,6 +859,14 @@
     skipVectorize = true;
   }
 
+  protected static boolean isRewriteJoinToFilter(final Map<String, Object> queryContext)
+  {
+    return (boolean) queryContext.getOrDefault(
+        QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
+        QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
+    );
+  }
+
   /**
    * This is a provider of query contexts that should be used by join tests.
    * It tests various configs that can be passed to join queries. All the configs provided by this provider should
@@ -872,23 +880,48 @@
       return new Object[]{
           // default behavior
           QUERY_CONTEXT_DEFAULT,
-          // filter value re-writes enabled
+          // all rewrites enabled
           new ImmutableMap.Builder<String, Object>()
               .putAll(QUERY_CONTEXT_DEFAULT)
               .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
               .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true)
               .build(),
-          // rewrite values enabled but filter re-writes disabled.
-          // This should be drive the same behavior as the previous config
+          // filter-on-value-column rewrites disabled, everything else enabled
+          new ImmutableMap.Builder<String, Object>()
+              .putAll(QUERY_CONTEXT_DEFAULT)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true)
+              .build(),
+          // filter rewrites fully disabled, join-to-filter enabled
+          new ImmutableMap.Builder<String, Object>()
+              .putAll(QUERY_CONTEXT_DEFAULT)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true)
+              .build(),
+          // filter rewrites disabled, but value column filters still set to true (it should be ignored and this should
+          // behave the same as the previous context)
           new ImmutableMap.Builder<String, Object>()
               .putAll(QUERY_CONTEXT_DEFAULT)
               .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
               .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true)
               .build(),
-          // filter re-writes disabled
+          // filter rewrites fully enabled, join-to-filter disabled
           new ImmutableMap.Builder<String, Object>()
               .putAll(QUERY_CONTEXT_DEFAULT)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, false)
+              .build(),
+          // all rewrites disabled
+          new ImmutableMap.Builder<String, Object>()
+              .putAll(QUERY_CONTEXT_DEFAULT)
+              .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false)
               .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false)
+              .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, false)
               .build(),
           };
     }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index cbd875a..e47a89a 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -370,13 +370,17 @@
   }
 
   @Test
-  public void testJoinOuterGroupByAndSubqueryNoLimit() throws Exception
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testJoinOuterGroupByAndSubqueryNoLimit(Map<String, Object> queryContext) throws Exception
   {
-    // Cannot vectorize JOIN operator.
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
 
     testQuery(
         "SELECT dim2, AVG(m2) FROM (SELECT * FROM foo AS t1 INNER JOIN foo AS t2 ON t1.m1 = t2.m1) AS t3 GROUP BY dim2",
+        queryContext,
         ImmutableList.of(
             GroupByQuery.builder()
                         .setDataSource(
@@ -390,6 +394,7 @@
                                         .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
                                         .context(QUERY_CONTEXT_DEFAULT)
                                         .build()
+                                        .withOverriddenContext(queryContext)
                                 ),
                                 "j0.",
                                 equalsCondition(
@@ -431,6 +436,7 @@
                         )
                         .setContext(QUERY_CONTEXT_DEFAULT)
                         .build()
+                        .withOverriddenContext(queryContext)
         ),
         NullHandling.sqlCompatible()
         ? ImmutableList.of(
@@ -4306,12 +4312,17 @@
   }
 
   @Test
-  public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testUnionAllTwoQueriesLeftQueryIsJoin(Map<String, Object> queryContext) throws Exception
   {
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
 
     testQuery(
         "(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k)  UNION ALL SELECT SUM(cnt) FROM foo",
+        queryContext,
         ImmutableList.of(
             Druids.newTimeseriesQueryBuilder()
                   .dataSource(
@@ -4326,7 +4337,8 @@
                   .granularity(Granularities.ALL)
                   .aggregators(aggregators(new CountAggregatorFactory("a0")))
                   .context(TIMESERIES_CONTEXT_DEFAULT)
-                  .build(),
+                  .build()
+                  .withOverriddenContext(queryContext),
             Druids.newTimeseriesQueryBuilder()
                   .dataSource(CalciteTests.DATASOURCE1)
                   .intervals(querySegmentSpec(Filtration.eternity()))
@@ -4334,18 +4346,24 @@
                   .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
                   .context(TIMESERIES_CONTEXT_DEFAULT)
                   .build()
+                  .withOverriddenContext(queryContext)
         ),
         ImmutableList.of(new Object[]{1L}, new Object[]{6L})
     );
   }
 
   @Test
-  public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testUnionAllTwoQueriesRightQueryIsJoin(Map<String, Object> queryContext) throws Exception
   {
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
 
     testQuery(
         "(SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) ",
+        queryContext,
         ImmutableList.of(
             Druids.newTimeseriesQueryBuilder()
                   .dataSource(CalciteTests.DATASOURCE1)
@@ -4353,7 +4371,8 @@
                   .granularity(Granularities.ALL)
                   .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
                   .context(TIMESERIES_CONTEXT_DEFAULT)
-                  .build(),
+                  .build()
+                  .withOverriddenContext(queryContext),
             Druids.newTimeseriesQueryBuilder()
                   .dataSource(
                       join(
@@ -4368,6 +4387,7 @@
                   .aggregators(aggregators(new CountAggregatorFactory("a0")))
                   .context(TIMESERIES_CONTEXT_DEFAULT)
                   .build()
+                  .withOverriddenContext(queryContext)
         ),
         ImmutableList.of(new Object[]{6L}, new Object[]{1L})
     );
@@ -8362,8 +8382,10 @@
   @Parameters(source = QueryContextForJoinProvider.class)
   public void testTopNFilterJoin(Map<String, Object> queryContext) throws Exception
   {
-    // Cannot vectorize JOIN operator.
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
 
     // Filters on top N values of some dimension by using an inner join.
     testQuery(
@@ -13456,8 +13478,10 @@
   @Parameters(source = QueryContextForJoinProvider.class)
   public void testUsingSubqueryAsPartOfAndFilter(Map<String, Object> queryContext) throws Exception
   {
-    // Cannot vectorize JOIN operator.
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
 
     testQuery(
         "SELECT dim1, dim2, COUNT(*) FROM druid.foo\n"
@@ -13925,6 +13949,234 @@
   }
 
   @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testTwoSemiJoinsSimultaneously(Map<String, Object> queryContext) throws Exception
+  {
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
+
+    testQuery(
+        "SELECT dim1, COUNT(*) FROM foo\n"
+        + "WHERE dim1 IN ('abc', 'def')"
+        + "AND __time IN (SELECT MAX(__time) FROM foo WHERE cnt = 1)\n"
+        + "AND __time IN (SELECT MAX(__time) FROM foo WHERE cnt <> 2)\n"
+        + "GROUP BY 1",
+        queryContext,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            join(
+                                join(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new QueryDataSource(
+                                        Druids.newTimeseriesQueryBuilder()
+                                              .dataSource(CalciteTests.DATASOURCE1)
+                                              .intervals(querySegmentSpec(Filtration.eternity()))
+                                              .granularity(Granularities.ALL)
+                                              .filters(selector("cnt", "1", null))
+                                              .aggregators(new LongMaxAggregatorFactory("a0", "__time"))
+                                              .context(TIMESERIES_CONTEXT_DEFAULT)
+                                              .build()
+                                    ),
+                                    "j0.",
+                                    "(\"__time\" == \"j0.a0\")",
+                                    JoinType.INNER
+                                ),
+                                new QueryDataSource(
+                                    Druids.newTimeseriesQueryBuilder()
+                                          .dataSource(CalciteTests.DATASOURCE1)
+                                          .intervals(querySegmentSpec(Filtration.eternity()))
+                                          .granularity(Granularities.ALL)
+                                          .filters(not(selector("cnt", "2", null)))
+                                          .aggregators(new LongMaxAggregatorFactory("a0", "__time"))
+                                          .context(TIMESERIES_CONTEXT_DEFAULT)
+                                          .build()
+                                ),
+                                "_j0.",
+                                "(\"__time\" == \"_j0.a0\")",
+                                JoinType.INNER
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(in("dim1", ImmutableList.of("abc", "def"), null))
+                        .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING)))
+                        .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
+                        .setContext(queryContext)
+                        .build()
+        ),
+        ImmutableList.of(new Object[]{"abc", 1L})
+    );
+  }
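+
+  // Both join clauses above are inner equi-joins on the same base whose right
+  // sides are referenced only by their conditions, so with
+  // "enableRewriteJoinToFilter" set they should both collapse into filters on
+  // the base table; conceptually (values are placeholders computed from the
+  // subqueries at runtime):
+  //
+  //   and(
+  //       in("dim1", ImmutableList.of("abc", "def"), null),
+  //       in("__time", <max __time where cnt = 1>, null),
+  //       in("__time", <max __time where cnt <> 2>, null)
+  //   )
+  //
+  // leaving no join clauses at all, which is what allows vectorization.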
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testSemiAndAntiJoinSimultaneouslyUsingWhereInSubquery(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize: the NOT IN plans to a cross join on "1" plus a LEFT
+    // join, neither of which is an inner equi-join, so no rewrite can remove
+    // the join operator.
+    cannotVectorize();
+
+    testQuery(
+        "SELECT dim1, COUNT(*) FROM foo\n"
+        + "WHERE dim1 IN ('abc', 'def')\n"
+        + "AND __time IN (SELECT MAX(__time) FROM foo)\n"
+        + "AND __time NOT IN (SELECT MIN(__time) FROM foo)\n"
+        + "GROUP BY 1",
+        queryContext,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            join(
+                                join(
+                                    join(
+                                        new TableDataSource(CalciteTests.DATASOURCE1),
+                                        new QueryDataSource(
+                                            Druids.newTimeseriesQueryBuilder()
+                                                  .dataSource(CalciteTests.DATASOURCE1)
+                                                  .intervals(querySegmentSpec(Filtration.eternity()))
+                                                  .granularity(Granularities.ALL)
+                                                  .aggregators(new LongMaxAggregatorFactory("a0", "__time"))
+                                                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                                                  .build()
+                                        ),
+                                        "j0.",
+                                        "(\"__time\" == \"j0.a0\")",
+                                        JoinType.INNER
+                                    ),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    .setDataSource(
+                                                        new QueryDataSource(
+                                                            Druids.newTimeseriesQueryBuilder()
+                                                                  .dataSource(CalciteTests.DATASOURCE1)
+                                                                  .intervals(querySegmentSpec(Filtration.eternity()))
+                                                                  .granularity(Granularities.ALL)
+                                                                  .aggregators(
+                                                                      new LongMinAggregatorFactory("a0", "__time")
+                                                                  )
+                                                                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                                                                  .build()
+                                                        )
+                                                    )
+                                                    .setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    .setGranularity(Granularities.ALL)
+                                                    .setAggregatorSpecs(
+                                                        new CountAggregatorFactory("_a0"),
+                                                        NullHandling.sqlCompatible()
+                                                        ? new FilteredAggregatorFactory(
+                                                            new CountAggregatorFactory("_a1"),
+                                                            not(selector("a0", null, null))
+                                                        )
+                                                        : new CountAggregatorFactory("_a1")
+                                                    )
+                                                    .setContext(QUERY_CONTEXT_DEFAULT)
+                                                    .build()
+                                    ),
+                                    "_j0.",
+                                    "1",
+                                    JoinType.INNER
+                                ),
+                                new QueryDataSource(
+                                    Druids.newTimeseriesQueryBuilder()
+                                          .dataSource(CalciteTests.DATASOURCE1)
+                                          .intervals(querySegmentSpec(Filtration.eternity()))
+                                          .granularity(Granularities.ALL)
+                                          .aggregators(new LongMinAggregatorFactory("a0", "__time"))
+                                          .postAggregators(expressionPostAgg("p0", "1"))
+                                          .context(TIMESERIES_CONTEXT_DEFAULT)
+                                          .build()
+                                ),
+                                "__j0.",
+                                "(\"__time\" == \"__j0.a0\")",
+                                JoinType.LEFT
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(
+                            and(
+                                in("dim1", ImmutableList.of("abc", "def"), null),
+                                or(
+                                    selector("_j0._a0", "0", null),
+                                    and(selector("__j0.p0", null, null), expressionFilter("(\"_j0._a1\" >= \"_j0._a0\")"))
+                                )
+                            )
+                        )
+                        .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING)))
+                        .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
+                        .setContext(queryContext)
+                        .build()
+        ),
+        ImmutableList.of(new Object[]{"abc", 1L})
+    );
+  }
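+
+  // Unlike the pure semi-join cases, the NOT IN here plans to a cross join on
+  // "1" plus a LEFT join with a null check (see the expected query above), and
+  // the join-to-filter rewrite only applies to inner equi-joins; conceptually
+  // (names are illustrative, not the real code):
+  //
+  //   boolean eligible = joinType == JoinType.INNER
+  //                      && conditionIsEquality
+  //                      && rightColumnsUsedOnlyByCondition;
+  //
+  // so these clauses survive and the query cannot vectorize under any context,
+  // hence the unconditional cannotVectorize() above.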
+
+  @Test
+  @Parameters(source = QueryContextForJoinProvider.class)
+  public void testSemiAndAntiJoinSimultaneouslyUsingExplicitJoins(Map<String, Object> queryContext) throws Exception
+  {
+    // Cannot vectorize: the LEFT join cannot be rewritten to a filter, so a
+    // join operator always remains.
+    cannotVectorize();
+
+    testQuery(
+        "SELECT dim1, COUNT(*) FROM\n"
+        + "foo\n"
+        + "INNER JOIN (SELECT MAX(__time) t FROM foo) t0 on t0.t = foo.__time\n"
+        + "LEFT JOIN (SELECT MIN(__time) t FROM foo) t1 on t1.t = foo.__time\n"
+        + "WHERE dim1 IN ('abc', 'def') AND t1.t is null\n"
+        + "GROUP BY 1",
+        queryContext,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(
+                            join(
+                                join(
+                                    new TableDataSource(CalciteTests.DATASOURCE1),
+                                    new QueryDataSource(
+                                        Druids.newTimeseriesQueryBuilder()
+                                              .dataSource(CalciteTests.DATASOURCE1)
+                                              .intervals(querySegmentSpec(Filtration.eternity()))
+                                              .granularity(Granularities.ALL)
+                                              .aggregators(new LongMaxAggregatorFactory("a0", "__time"))
+                                              .context(TIMESERIES_CONTEXT_DEFAULT)
+                                              .build()
+                                    ),
+                                    "j0.",
+                                    "(\"__time\" == \"j0.a0\")",
+                                    JoinType.INNER
+                                ),
+                                new QueryDataSource(
+                                    Druids.newTimeseriesQueryBuilder()
+                                          .dataSource(CalciteTests.DATASOURCE1)
+                                          .intervals(querySegmentSpec(Filtration.eternity()))
+                                          .granularity(Granularities.ALL)
+                                          .aggregators(new LongMinAggregatorFactory("a0", "__time"))
+                                          .context(TIMESERIES_CONTEXT_DEFAULT)
+                                          .build()
+                                ),
+                                "_j0.",
+                                "(\"__time\" == \"_j0.a0\")",
+                                JoinType.LEFT
+                            )
+                        )
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimFilter(
+                            and(
+                                in("dim1", ImmutableList.of("abc", "def"), null),
+                                selector("_j0.a0", null, null)
+                            )
+                        )
+                        .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING)))
+                        .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0")))
+                        .setContext(queryContext)
+                        .build()
+        ),
+        ImmutableList.of(new Object[]{"abc", 1L})
+    );
+  }
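+
+  // The INNER clause here is rewrite-eligible, but the LEFT clause is not, and
+  // HashJoinSegmentStorageAdapter only passes through to the base adapter (the
+  // path that can vectorize) when 100% of the clauses were rewritten; a rough
+  // shape of that decision (names are illustrative):
+  //
+  //   if (clausesRemainingAfterRewrite.isEmpty()) {
+  //     return baseAdapter.makeCursors(...);  // may vectorize
+  //   }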
+
+  @Test
   public void testSemiJoinWithOuterTimeExtractAggregateWithOrderBy() throws Exception
   {
     // Cannot vectorize due to virtual columns.
@@ -14006,8 +14258,10 @@
   @Parameters(source = QueryContextForJoinProvider.class)
   public void testInAggregationSubquery(Map<String, Object> queryContext) throws Exception
   {
-    // Cannot vectorize JOIN operator.
-    cannotVectorize();
+    // Fully removing the join allows this query to vectorize.
+    if (!isRewriteJoinToFilter(queryContext)) {
+      cannotVectorize();
+    }
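+
+    // Note that the expected native query below still contains the join either
+    // way: the rewrite is not a planner transformation, it happens later at
+    // segment-mapping time (JoinableFactoryWrapper's segmentMapFn), so only
+    // the vectorization expectation differs between contexts.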
 
     testQuery(
         "SELECT DISTINCT __time FROM druid.foo WHERE __time IN (SELECT MAX(__time) FROM druid.foo)",
@@ -14025,6 +14279,7 @@
                                           .aggregators(new LongMaxAggregatorFactory("a0", "__time"))
                                           .context(TIMESERIES_CONTEXT_DEFAULT)
                                           .build()
+                                          .withOverriddenContext(queryContext)
                                 ),
                                 "j0.",
                                 equalsCondition(
@@ -14037,8 +14292,9 @@
                         .setInterval(querySegmentSpec(Filtration.eternity()))
                         .setGranularity(Granularities.ALL)
                         .setDimensions(dimensions(new DefaultDimensionSpec("__time", "d0", ValueType.LONG)))
-                        .setContext(queryContext)
+                        .setContext(QUERY_CONTEXT_DEFAULT)
                         .build()
+                        .withOverriddenContext(queryContext)
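+                        // Expected context: QUERY_CONTEXT_DEFAULT first, then
+                        // the parameterized entries layered on via
+                        // withOverriddenContext, presumably mirroring how the
+                        // planner-built query carries its context.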
         ),
         ImmutableList.of(
             new Object[]{timestamp("2001-01-03")}