Fix incorrect result of exact topN on an inner join with limit (#11517)
diff --git a/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java b/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java
index b8b04ad..843d248 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java
@@ -316,6 +316,7 @@
if (ignoreAfterThreshold &&
query.getDimensionsFilter() == null &&
+ !storageAdapter.hasBuiltInFilters() &&
query.getIntervals().stream().anyMatch(interval -> interval.contains(storageAdapter.getInterval()))) {
endIndex = Math.min(endIndex, startIndex + query.getThreshold());
}
diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java
index 9e5dcd3..2793b27 100644
--- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java
+++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java
@@ -27,9 +27,11 @@
*/
public class TopNQueryConfig
{
+ public static final int DEFAULT_MIN_TOPN_THRESHOLD = 1000;
+
@JsonProperty
@Min(1)
- private int minTopNThreshold = 1000;
+ private int minTopNThreshold = DEFAULT_MIN_TOPN_THRESHOLD;
public int getMinTopNThreshold()
{
diff --git a/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java
index e7905b2..2aa4b77 100644
--- a/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java
+++ b/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java
@@ -76,4 +76,17 @@
int getNumRows();
DateTime getMaxIngestedEventTime();
Metadata getMetadata();
+
+ /**
+ * Returns true if this storage adapter can filter some rows out. The actual column cardinality can be lower than
+ * what {@link #getDimensionCardinality} returns if this returns true. Dimension selectors for such storage adapter
+ * can return non-contiguous dictionary IDs because the dictionary IDs in filtered rows will not be returned.
+ * Note that the number of rows accessible via this storage adapter will not necessarily decrease because of
+ * the built-in filters. For inner joins, for example, the number of joined rows can be larger than
+ * the number of rows in the base adapter even though this method returns true.
+ */
+ default boolean hasBuiltInFilters()
+ {
+ return false;
+ }
}
diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
index 86b7ef4..c056725 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java
@@ -227,6 +227,13 @@
}
@Override
+ public boolean hasBuiltInFilters()
+ {
+ return clauses.stream()
+ .anyMatch(clause -> clause.getJoinType() == JoinType.INNER && !clause.getCondition().isAlwaysTrue());
+ }
+
+ @Override
public boolean canVectorize(@Nullable Filter filter, VirtualColumns virtualColumns, boolean descending)
{
// HashJoinEngine isn't vectorized yet.
@@ -343,7 +350,7 @@
return PostJoinCursor.wrap(
retVal,
VirtualColumns.create(postJoinVirtualColumns),
- joinFilterSplit.getJoinTableFilter().isPresent() ? joinFilterSplit.getJoinTableFilter().get() : null
+ joinFilterSplit.getJoinTableFilter().orElse(null)
);
}
).withBaggage(joinablesCloser);
diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
index 26ba161..33199b3 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java
@@ -206,9 +206,14 @@
*/
protected HashJoinSegmentStorageAdapter makeFactToCountrySegment()
{
+ return makeFactToCountrySegment(JoinType.LEFT);
+ }
+
+ protected HashJoinSegmentStorageAdapter makeFactToCountrySegment(JoinType joinType)
+ {
return new HashJoinSegmentStorageAdapter(
factSegment.asStorageAdapter(),
- ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT)),
+ ImmutableList.of(factToCountryOnIsoCode(joinType)),
null
);
}
diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
index 10d0483..4ee41a2 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java
@@ -2266,4 +2266,101 @@
Assert.assertEquals(expectedPostJoin, actualPostJoin);
}
+ @Test
+ public void test_hasBuiltInFiltersForSingleJoinableClauseWithVariousJoinTypes()
+ {
+ Assert.assertTrue(makeFactToCountrySegment(JoinType.INNER).hasBuiltInFilters());
+ Assert.assertFalse(makeFactToCountrySegment(JoinType.LEFT).hasBuiltInFilters());
+ Assert.assertFalse(makeFactToCountrySegment(JoinType.RIGHT).hasBuiltInFilters());
+ Assert.assertFalse(makeFactToCountrySegment(JoinType.FULL).hasBuiltInFilters());
+ // cross join
+ Assert.assertFalse(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(
+ new JoinableClause(
+ FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX,
+ new IndexedTableJoinable(countriesTable),
+ JoinType.INNER,
+ JoinConditionAnalysis.forExpression(
+ "'true'",
+ FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX,
+ ExprMacroTable.nil()
+ )
+ )
+ ),
+ null
+ ).hasBuiltInFilters()
+ );
+ }
+
+ @Test
+ public void test_hasBuiltInFiltersForEmptyJoinableClause()
+ {
+ Assert.assertFalse(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(),
+ null
+ ).hasBuiltInFilters()
+ );
+ }
+
+ @Test
+ public void test_hasBuiltInFiltersForMultipleJoinableClausesWithVariousJoinTypes()
+ {
+ Assert.assertTrue(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(
+ factToRegion(JoinType.INNER),
+ regionToCountry(JoinType.LEFT)
+ ),
+ null
+ ).hasBuiltInFilters()
+ );
+
+ Assert.assertTrue(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(
+ factToRegion(JoinType.RIGHT),
+ regionToCountry(JoinType.INNER),
+ factToCountryOnNumber(JoinType.FULL)
+ ),
+ null
+ ).hasBuiltInFilters()
+ );
+
+ Assert.assertFalse(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(
+ factToRegion(JoinType.LEFT),
+ regionToCountry(JoinType.LEFT)
+ ),
+ null
+ ).hasBuiltInFilters()
+ );
+
+ Assert.assertFalse(
+ new HashJoinSegmentStorageAdapter(
+ factSegment.asStorageAdapter(),
+ ImmutableList.of(
+ factToRegion(JoinType.LEFT),
+ new JoinableClause(
+ FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX,
+ new IndexedTableJoinable(countriesTable),
+ JoinType.INNER,
+ JoinConditionAnalysis.forExpression(
+ "'true'",
+ FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX,
+ ExprMacroTable.nil()
+ )
+ )
+ ),
+ null
+ ).hasBuiltInFilters()
+ );
+ }
}
diff --git a/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java b/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java
index 94c6c59..97457b8 100644
--- a/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java
+++ b/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java
@@ -46,6 +46,7 @@
import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.context.ResponseContext.Key;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
+import org.apache.druid.query.topn.TopNQueryConfig;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
@@ -108,7 +109,11 @@
protected QueryRunnerBasedOnClusteredClientTestBase()
{
- conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(CLOSER, USE_PARALLEL_MERGE_POOL_CONFIGURED);
+ conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(
+ CLOSER,
+ USE_PARALLEL_MERGE_POOL_CONFIGURED,
+ () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD
+ );
toolChestWarehouse = new QueryToolChestWarehouse()
{
diff --git a/server/src/test/java/org/apache/druid/server/QueryStackTests.java b/server/src/test/java/org/apache/druid/server/QueryStackTests.java
index 074649f..9e7c234 100644
--- a/server/src/test/java/org/apache/druid/server/QueryStackTests.java
+++ b/server/src/test/java/org/apache/druid/server/QueryStackTests.java
@@ -80,6 +80,7 @@
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Set;
+import java.util.function.Supplier;
/**
* Utilities for creating query-stack objects for tests.
@@ -228,20 +229,30 @@
*/
public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate(final Closer closer)
{
- return createQueryRunnerFactoryConglomerate(closer, true);
+ return createQueryRunnerFactoryConglomerate(closer, true, () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD);
}
public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate(
final Closer closer,
- final boolean useParallelMergePoolConfigured
-
+ final Supplier<Integer> minTopNThresholdSupplier
)
{
- return createQueryRunnerFactoryConglomerate(closer,
- getProcessingConfig(
- useParallelMergePoolConfigured,
- DruidProcessingConfig.DEFAULT_NUM_MERGE_BUFFERS
- )
+ return createQueryRunnerFactoryConglomerate(closer, true, minTopNThresholdSupplier);
+ }
+
+ public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate(
+ final Closer closer,
+ final boolean useParallelMergePoolConfigured,
+ final Supplier<Integer> minTopNThresholdSupplier
+ )
+ {
+ return createQueryRunnerFactoryConglomerate(
+ closer,
+ getProcessingConfig(
+ useParallelMergePoolConfigured,
+ DruidProcessingConfig.DEFAULT_NUM_MERGE_BUFFERS
+ ),
+ minTopNThresholdSupplier
);
}
@@ -250,6 +261,19 @@
final DruidProcessingConfig processingConfig
)
{
+ return createQueryRunnerFactoryConglomerate(
+ closer,
+ processingConfig,
+ () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD
+ );
+ }
+
+ public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate(
+ final Closer closer,
+ final DruidProcessingConfig processingConfig,
+ final Supplier<Integer> minTopNThresholdSupplier
+ )
+ {
final CloseableStupidPool<ByteBuffer> stupidPool = new CloseableStupidPool<>(
"TopNQueryRunnerFactory-bufferPool",
() -> ByteBuffer.allocate(COMPUTE_BUFFER_SIZE)
@@ -308,7 +332,14 @@
TopNQuery.class,
new TopNQueryRunnerFactory(
stupidPool,
- new TopNQueryQueryToolChest(new TopNQueryConfig()),
+ new TopNQueryQueryToolChest(new TopNQueryConfig()
+ {
+ @Override
+ public int getMinTopNThreshold()
+ {
+ return minTopNThresholdSupplier.get();
+ }
+ }),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
)
)
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
index c8988cf..b4c3682 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
@@ -62,6 +62,7 @@
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.apache.druid.query.topn.TopNQueryConfig;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType;
@@ -248,6 +249,7 @@
public static QueryRunnerFactoryConglomerate conglomerate;
public static Closer resourceCloser;
+ public static int minTopNThreshold = TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD;
@Rule
public ExpectedException expectedException = ExpectedException.none();
@@ -444,7 +446,7 @@
public static void setUpClass()
{
resourceCloser = Closer.create();
- conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
+ conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser, () -> minTopNThreshold);
}
@AfterClass
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 9404b03..d05b150 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -354,6 +354,56 @@
}
@Test
+ public void testExactTopNOnInnerJoinWithLimit() throws Exception
+ {
+ // Adjust topN threshold, so that the topN engine keeps only 1 slot for aggregates, which should be enough
+ // to compute the query with limit 1.
+ minTopNThreshold = 1;
+ Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+ context.put(PlannerConfig.CTX_KEY_USE_APPROXIMATE_TOPN, false);
+ testQuery(
+ "select f1.\"dim4\", sum(\"m1\") from numfoo f1 inner join (\n"
+ + " select \"dim4\" from numfoo where dim4 <> 'a' group by 1\n"
+ + ") f2 on f1.\"dim4\" = f2.\"dim4\" group by 1 limit 1",
+ context, // turn on exact topN
+ ImmutableList.of(
+ new TopNQueryBuilder()
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .granularity(Granularities.ALL)
+ .dimension(new DefaultDimensionSpec("dim4", "_d0"))
+ .aggregators(new DoubleSumAggregatorFactory("a0", "m1"))
+ .metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC))
+ .threshold(1)
+ .dataSource(
+ JoinDataSource.create(
+ new TableDataSource("numfoo"),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimFilter(new NotDimFilter(new SelectorDimFilter("dim4", "a", null)))
+ .setDataSource(new TableDataSource("numfoo"))
+ .setDimensions(new DefaultDimensionSpec("dim4", "_d0"))
+ .setContext(context)
+ .build()
+ ),
+ "j0.",
+ "(\"dim4\" == \"j0._d0\")",
+ JoinType.INNER,
+ null,
+ ExprMacroTable.nil()
+ )
+ )
+ .context(context)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{"b", 15.0}
+ )
+ );
+ }
+
+ @Test
public void testJoinOuterGroupByAndSubqueryHasLimit() throws Exception
{
// Cannot vectorize JOIN operator.