InDimFilter: Fix NPE involving certain Set types. (#11169)

* InDimFilter: Fix NPE involving certain Set types.

Normally, InDimFilters that come from JSON have HashSets for "values".
However, programmatically-generated filters (like the ones from #11068)
may use other set types. Some set types, like TreeSets with natural
ordering, will throw NPE on "contains(null)", which causes the
InDimFilter's ValueMatcher to throw NPE if it encounters a null value.

This patch adds code to detect if the values set can support
contains(null), and if not, wrap that in a null-checking lambda.

Also included:

- Remove unneeded NullHandling.needsEmptyToNull method.
- Update IndexedTableJoinable to generate a TreeSet that does not
  require lambda-wrapping. (This particular TreeSet is how I noticed
  the bug in the first place.)

* Test fixes.

* Improve test coverage
diff --git a/core/src/main/java/org/apache/druid/common/config/NullHandling.java b/core/src/main/java/org/apache/druid/common/config/NullHandling.java
index 51cb265..bd0f0ee 100644
--- a/core/src/main/java/org/apache/druid/common/config/NullHandling.java
+++ b/core/src/main/java/org/apache/druid/common/config/NullHandling.java
@@ -94,11 +94,6 @@
     //CHECKSTYLE.ON: Regexp
   }
 
-  public static boolean needsEmptyToNull(@Nullable String value)
-  {
-    return replaceWithDefault() && Strings.isNullOrEmpty(value);
-  }
-
   @Nullable
   public static String defaultStringValue()
   {
diff --git a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
index ca12c2c..577e255 100644
--- a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
+++ b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java
@@ -31,7 +31,6 @@
 import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
-import com.google.common.collect.Ordering;
 import com.google.common.collect.Range;
 import com.google.common.collect.RangeSet;
 import com.google.common.collect.Sets;
@@ -48,6 +47,7 @@
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Comparators;
 import org.apache.druid.query.BitmapResultFactory;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.query.extraction.ExtractionFn;
@@ -76,6 +76,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.SortedSet;
 
 public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
 {
@@ -91,6 +92,16 @@
   @JsonIgnore
   private final Supplier<byte[]> cacheKeySupplier;
 
+  /**
+   * Creates a new filter.
+   *
+   * @param dimension    column to search
+   * @param values       set of values to match. This collection may be reused to avoid copying a big collection.
+   *                     Therefore, callers should <b>not</b> modify the collection after it is passed to this
+   *                     constructor.
+   * @param extractionFn extraction function to apply to the column before checking against "values"
+   * @param filterTuning optional tuning
+   */
   @JsonCreator
   public InDimFilter(
       @JsonProperty("dimension") String dimension,
@@ -111,10 +122,12 @@
   }
 
   /**
+   * Creates a new filter without an extraction function or any special filter tuning.
    *
-   * @param dimension
-   * @param values This collection instance can be reused if possible to avoid copying a big collection.
-   *               Callers should <b>not</b> modify the collection after it is passed to this constructor.
+   * @param dimension column to search
+   * @param values    set of values to match. This collection may be reused to avoid copying a big collection.
+   *                  Therefore, callers should <b>not</b> modify the collection after it is passed to this
+   *                  constructor.
    */
   public InDimFilter(
       String dimension,
@@ -408,8 +421,17 @@
 
   private byte[] computeCacheKey()
   {
-    final List<String> sortedValues = new ArrayList<>(values);
-    sortedValues.sort(Comparator.nullsFirst(Ordering.natural()));
+    final Collection<String> sortedValues;
+
+    if (values instanceof SortedSet && isNaturalOrder(((SortedSet<String>) values).comparator())) {
+      // Avoid copying "values" when it is already in the order we need for cache key computation.
+      sortedValues = values;
+    } else {
+      final List<String> sortedValuesList = new ArrayList<>(values);
+      sortedValuesList.sort(Comparators.naturalNullsFirst());
+      sortedValues = sortedValuesList;
+    }
+
     final Hasher hasher = Hashing.sha256().newHasher();
     for (String v : sortedValues) {
       if (v == null) {
@@ -464,6 +486,17 @@
     return this;
   }
 
+  /**
+   * Returns true if the comparator is null or the singleton {@link Comparators#naturalNullsFirst()}. Useful for
+   * detecting if a sorted set is in natural order or not.
+   *
+   * May return false negatives (i.e. there are naturally-ordered comparators that will return false here).
+   */
+  private static <T> boolean isNaturalOrder(@Nullable final Comparator<T> comparator)
+  {
+    return comparator == null || Comparators.naturalNullsFirst().equals(comparator);
+  }
+
   private static Iterable<ImmutableBitmap> getBitmapIterable(final Set<String> values, final BitmapIndex bitmapIndex)
   {
     return Filters.bitmapsFromIndexes(getBitmapIndexIterable(values, bitmapIndex), bitmapIndex);
@@ -489,6 +522,32 @@
     };
   }
 
+  private static Predicate<String> createStringPredicate(final Set<String> values)
+  {
+    Preconditions.checkNotNull(values, "values");
+
+    try {
+      // Check to see if values.contains(null) will throw a NullPointerException. Jackson JSON deserialization won't
+      // lead to this (it will create a HashSet, which can accept nulls). But when InDimFilters are created
+      // programmatically as a result of optimizations like rewriting inner joins as filters, the passed-in Set may
+      // not be able to accept nulls. We don't want to copy the Sets (since they may be large) so instead we'll wrap
+      // it in a null-checking lambda if needed.
+
+      //noinspection ResultOfMethodCallIgnored
+      values.contains(null);
+
+      // Safe to do values.contains(null).
+      return values::contains;
+    }
+    catch (NullPointerException ignored) {
+      // Fall through
+    }
+
+    // Not safe to do values.contains(null); must return a wrapper.
+    // Return false for null, since an exception means the set cannot accept null (and therefore does not include it).
+    return value -> value != null && values.contains(value);
+  }
+
   private static DruidLongPredicate createLongPredicate(final Set<String> values)
   {
     LongArrayList longs = new LongArrayList(values.size());
@@ -499,7 +558,6 @@
       }
     }
 
-
     final LongOpenHashSet longHashSet = new LongOpenHashSet(longs);
     return longHashSet::contains;
   }
@@ -537,6 +595,7 @@
   {
     private final ExtractionFn extractionFn;
     private final Set<String> values;
+    private final Supplier<Predicate<String>> stringPredicateSupplier;
     private final Supplier<DruidLongPredicate> longPredicateSupplier;
     private final Supplier<DruidFloatPredicate> floatPredicateSupplier;
     private final Supplier<DruidDoublePredicate> doublePredicateSupplier;
@@ -553,6 +612,7 @@
       // only once. Pass in a common long predicate supplier to all filters created by .toFilter(), so that we only
       // compute the long hashset/array once per query. This supplier must be thread-safe, since this DimFilter will be
       // accessed in the query runners.
+      this.stringPredicateSupplier = Suppliers.memoize(() -> createStringPredicate(values));
       this.longPredicateSupplier = Suppliers.memoize(() -> createLongPredicate(values));
       this.floatPredicateSupplier = Suppliers.memoize(() -> createFloatPredicate(values));
       this.doublePredicateSupplier = Suppliers.memoize(() -> createDoublePredicate(values));
@@ -562,9 +622,10 @@
     public Predicate<String> makeStringPredicate()
     {
       if (extractionFn != null) {
-        return input -> values.contains(extractionFn.apply(input));
+        final Predicate<String> stringPredicate = stringPredicateSupplier.get();
+        return input -> stringPredicate.apply(extractionFn.apply(input));
       } else {
-        return values::contains;
+        return stringPredicateSupplier.get();
       }
     }
 
@@ -572,7 +633,8 @@
     public DruidLongPredicate makeLongPredicate()
     {
       if (extractionFn != null) {
-        return input -> values.contains(extractionFn.apply(input));
+        final Predicate<String> stringPredicate = stringPredicateSupplier.get();
+        return input -> stringPredicate.apply(extractionFn.apply(input));
       } else {
         return longPredicateSupplier.get();
       }
@@ -582,7 +644,8 @@
     public DruidFloatPredicate makeFloatPredicate()
     {
       if (extractionFn != null) {
-        return input -> values.contains(extractionFn.apply(input));
+        final Predicate<String> stringPredicate = stringPredicateSupplier.get();
+        return input -> stringPredicate.apply(extractionFn.apply(input));
       } else {
         return floatPredicateSupplier.get();
       }
@@ -592,9 +655,11 @@
     public DruidDoublePredicate makeDoublePredicate()
     {
       if (extractionFn != null) {
-        return input -> values.contains(extractionFn.apply(input));
+        final Predicate<String> stringPredicate = stringPredicateSupplier.get();
+        return input -> stringPredicate.apply(extractionFn.apply(input));
+      } else {
+        return doublePredicateSupplier.get();
       }
-      return input -> doublePredicateSupplier.get().applyDouble(input);
     }
 
     @Override
diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
index e59b4fe..c230b25 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
@@ -21,6 +21,7 @@
 
 import it.unimi.dsi.fastutil.ints.IntList;
 import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.guava.Comparators;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.DimensionHandlerUtils;
@@ -103,7 +104,10 @@
     try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) {
       // Sorted set to encourage "in" filters that result from this method to do dictionary lookups in order.
       // The hopes are that this will improve locality and therefore improve performance.
-      final Set<String> allValues = new TreeSet<>();
+      //
+      // Note: we are using Comparators.naturalNullsFirst() because it prevents the need for lambda-wrapping in
+      // InDimFilter's "createStringPredicate" method.
+      final Set<String> allValues = new TreeSet<>(Comparators.naturalNullsFirst());
 
       for (int i = 0; i < table.numRows(); i++) {
         final String s = DimensionHandlerUtils.convertObjectToString(reader.read(i));
diff --git a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
index 9cf5b8f..3bc53f6 100644
--- a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
+++ b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
@@ -22,10 +22,17 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSortedSet;
+import com.google.common.collect.Ordering;
 import com.google.common.collect.Sets;
 import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.data.input.MapBasedRow;
 import org.apache.druid.jackson.DefaultObjectMapper;
 import org.apache.druid.query.extraction.RegexDimExtractionFn;
+import org.apache.druid.segment.RowAdapters;
+import org.apache.druid.segment.RowBasedColumnSelectorFactory;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.column.ValueType;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.junit.Assert;
 import org.junit.Test;
@@ -33,7 +40,10 @@
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Set;
+import java.util.TreeSet;
 
 public class InDimFilterTest extends InitializedNullHandlingTest
 {
@@ -88,6 +98,20 @@
   }
 
   @Test
+  public void testGetCacheKeyReturningSameKeyForSetsOfDifferentTypesAndComparators()
+  {
+    final Set<String> reverseOrderSet = new TreeSet<>(Ordering.natural().reversed());
+    final InDimFilter dimFilter1 = new InDimFilter("dim", Sets.newTreeSet(Arrays.asList("v1", "v2")));
+    final InDimFilter dimFilter2 = new InDimFilter("dim", Sets.newHashSet("v2", "v1"));
+    final InDimFilter dimFilter3 = new InDimFilter("dim", ImmutableSortedSet.copyOf(Arrays.asList("v2", "v1")));
+    reverseOrderSet.addAll(Arrays.asList("v1", "v2"));
+    final InDimFilter dimFilter4 = new InDimFilter("dim", reverseOrderSet);
+    Assert.assertArrayEquals(dimFilter1.getCacheKey(), dimFilter2.getCacheKey());
+    Assert.assertArrayEquals(dimFilter1.getCacheKey(), dimFilter3.getCacheKey());
+    Assert.assertArrayEquals(dimFilter1.getCacheKey(), dimFilter4.getCacheKey());
+  }
+
+  @Test
   public void testGetCacheKeyDifferentKeysForListOfStringsAndSingleStringOfLists()
   {
     final InDimFilter inDimFilter1 = new InDimFilter("dimTest", Arrays.asList("good", "bad"), null);
@@ -144,4 +168,41 @@
     final InDimFilter filter = new InDimFilter("dim", Collections.singleton("v1"), null);
     Assert.assertEquals(new SelectorDimFilter("dim", "v1", null), filter.optimize());
   }
+
+  @Test
+  public void testContainsNullWhenValuesSetIsTreeSet()
+  {
+    // Regression test for NullPointerException caused by programmatically-generated InDimFilters that use
+    // TreeSets with natural comparators. These Sets throw NullPointerException on contains(null).
+    // InDimFilter wraps these contains methods in null-checking lambdas.
+
+    final TreeSet<String> values = new TreeSet<>();
+    values.add("foo");
+    values.add("bar");
+
+    final InDimFilter filter = new InDimFilter("dim", values, null);
+
+    final Map<String, Object> row = new HashMap<>();
+    row.put("dim", null);
+
+    final RowBasedColumnSelectorFactory<MapBasedRow> columnSelectorFactory = RowBasedColumnSelectorFactory.create(
+        RowAdapters.standardRow(),
+        () -> new MapBasedRow(0, row),
+        RowSignature.builder().add("dim", ValueType.STRING).build(),
+        true
+    );
+
+    final ValueMatcher matcher = filter.toFilter().makeMatcher(columnSelectorFactory);
+
+    // This would throw an exception without InDimFilter's null-checking lambda wrapping.
+    Assert.assertFalse(matcher.matches());
+
+    row.put("dim", "foo");
+    // Now it should match.
+    Assert.assertTrue(matcher.matches());
+
+    row.put("dim", "fox");
+    // Now it *shouldn't* match.
+    Assert.assertFalse(matcher.matches());
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/segment/filter/InFilterTest.java b/processing/src/test/java/org/apache/druid/segment/filter/InFilterTest.java
index aac0e63..145839f 100644
--- a/processing/src/test/java/org/apache/druid/segment/filter/InFilterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/filter/InFilterTest.java
@@ -396,7 +396,12 @@
     EqualsVerifier.forClass(InDimFilter.InFilterDruidPredicateFactory.class)
                   .usingGetClass()
                   .withNonnullFields("values")
-                  .withIgnoredFields("longPredicateSupplier", "floatPredicateSupplier", "doublePredicateSupplier")
+                  .withIgnoredFields(
+                      "longPredicateSupplier",
+                      "floatPredicateSupplier",
+                      "doublePredicateSupplier",
+                      "stringPredicateSupplier"
+                  )
                   .verify();
   }