blob: 6d4bf1bc4594c9400fa8b60246d87dc9e7990377 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite;
import com.google.common.collect.ImmutableList;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.AllGranularity;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@RunWith(JUnitParamsRunner.class)
public class CalciteCorrelatedQueryTest extends BaseCalciteQueryTest
{
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubquery(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
queryContext = withLeftDirectAccessEnabled(queryContext);
testQuery(
"select country, ANY_VALUE(\n"
+ " select avg(\"users\") from (\n"
+ " select floor(__time to day), count(distinct user) \"users\" from visits f where f.country = visits.country group by 1\n"
+ " )\n"
+ " ) as \"DAU\"\n"
+ "from visits \n"
+ "group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.USERVISITDATASOURCE),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.USERVISITDATASOURCE)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimFilter(not(selector("country", null, null)))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new CardinalityAggregatorFactory(
"a0:a",
null,
Collections.singletonList(
new DefaultDimensionSpec(
"user",
"user"
)),
false,
true
))
.setPostAggregatorSpecs(Collections.singletonList(new HyperUniqueFinalizingPostAggregator(
"a0",
"a0:a"
)))
.setContext(
withTimestampResultContext(
queryContext,
"d0",
Granularities.DAY
)
)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongSumAggregatorFactory("_a0:sum", "a0"),
new CountAggregatorFactory("_a0:count")
)
.setPostAggregatorSpecs(Collections.singletonList(new ArithmeticPostAggregator(
"_a0",
"quotient",
Arrays
.asList(
new FieldAccessPostAggregator(null, "_a0:sum"),
new FieldAccessPostAggregator(null, "_a0:count")
)
)))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"India", 2L},
new Object[]{"USA", 1L},
new Object[]{"canada", 3L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithLeftFilter(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
queryContext = withLeftDirectAccessEnabled(queryContext);
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(*) \"users\" from visits f where f.country = visits.country group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B' and __time between '2021-01-01 01:00:00' AND '2021-01-02 23:59:59'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.USERVISITDATASOURCE),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.USERVISITDATASOURCE)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimFilter(not(selector("country", null, null)))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setContext(
withTimestampResultContext(
queryContext,
"d0",
Granularities.DAY
)
)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.of(
"2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z")))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 4L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithLeftFilter_leftDirectAccessDisabled(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(*) \"users\" from visits f where f.country = visits.country group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B' and __time between '2021-01-01 01:00:00' AND '2021-01-02 23:59:59'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new QueryDataSource(newScanQueryBuilder().dataSource(CalciteTests.USERVISITDATASOURCE)
.intervals(querySegmentSpec(Intervals.of(
"2021-01-01T01:00:00.000Z/2021-01-02T23:59:59.001Z")))
.filters(selector("city", "B", null))
.columns("__time", "city", "country")
.build()),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.USERVISITDATASOURCE)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimFilter(not(selector("country", null, null)))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new CountAggregatorFactory("a0"))
.setContext(
withTimestampResultContext(
queryContext,
"d0",
Granularities.DAY
)
)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 4L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithCorrelatedQueryFilter(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
queryContext = withLeftDirectAccessEnabled(queryContext);
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.USERVISITDATASOURCE),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.USERVISITDATASOURCE)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("user", null, null))
))
.setDimFilter(and(
selector("city", "A", null),
not(selector("country", null, null))
))
.setContext(
withTimestampResultContext(
queryContext,
"d0",
Granularities.DAY
)
)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 2L}
)
);
}
@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testCorrelatedSubqueryWithCorrelatedQueryFilter_Scan(Map<String, Object> queryContext) throws Exception
{
cannotVectorize();
queryContext = withLeftDirectAccessEnabled(queryContext);
testQuery(
"select country, ANY_VALUE(\n"
+ " select max(\"users\") from (\n"
+ " select floor(__time to day), count(user) \"users\" from visits f where f.country = visits.country and f.city = 'A' group by 1\n"
+ " )\n"
+ " ) as \"dailyVisits\"\n"
+ "from visits \n"
+ " where city = 'B'"
+ " group by 1",
queryContext,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.USERVISITDATASOURCE),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.USERVISITDATASOURCE)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setVirtualColumns(new ExpressionVirtualColumn(
"v0",
"timestamp_floor(\"__time\",'P1D',null,'UTC')",
ValueType.LONG,
TestExprMacroTable.INSTANCE
))
.setDimensions(
new DefaultDimensionSpec(
"v0",
"d0",
ValueType.LONG
),
new DefaultDimensionSpec(
"country",
"d1"
)
)
.setAggregatorSpecs(new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
not(selector("user", null, null))
))
.setDimFilter(and(
selector("city", "A", null),
not(selector("country", null, null))
))
.setContext(
withTimestampResultContext(
queryContext,
"d0",
Granularities.DAY
)
)
.setGranularity(new AllGranularity())
.build()
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("d1", "_d0"))
.setAggregatorSpecs(
new LongMaxAggregatorFactory("_a0", "a0")
)
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
"j0.",
equalsCondition(
DruidExpression.fromColumn("country"),
DruidExpression.fromColumn("j0._d0")
),
JoinType.LEFT,
selector("city", "B", null)
)
)
.setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
.setDimensions(new DefaultDimensionSpec("country", "d0"))
.setAggregatorSpecs(new LongAnyAggregatorFactory("a0", "j0._a0"))
.setGranularity(new AllGranularity())
.setContext(queryContext)
.build()
),
ImmutableList.of(
new Object[]{"canada", 2L}
)
);
}
private Map<String, Object> withTimestampResultContext(
Map<String, Object> input,
String timestampResultField,
Granularity granularity
)
{
Map<String, Object> output = new HashMap<>(input);
output.put(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD, timestampResultField);
output.put(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, granularity);
output.put(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX, 0);
return output;
}
}