/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.rule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableProvider;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.DataContext;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.Enumerable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.Linq4j;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelCollations;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelFieldCollation;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelRoot;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.TableScan;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.SortProjectTransposeRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.ScannableTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Statistic;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Statistics;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.impl.AbstractSchema;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.impl.AbstractTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Planner;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Programs;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSets;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.ImmutableBitSet;
import org.junit.Assert;
import org.junit.Test;
/**
 * This test ensures that we are reordering joins and get a plan similar to Join(large, Join(small,
 * medium)) instead of Join(small, Join(medium, large)).
 */
public class JoinReorderingTest {

  /**
   * Returns the Calcite rules every planner-level test needs to turn the parsed SQL into an
   * Enumerable physical plan. The individual tests add exactly one join-reordering rule on top.
   */
  private static RuleSet basePrepareRules() {
    return RuleSets.ofList(
        SortProjectTransposeRule.INSTANCE,
        EnumerableRules.ENUMERABLE_JOIN_RULE,
        EnumerableRules.ENUMERABLE_PROJECT_RULE,
        EnumerableRules.ENUMERABLE_SORT_RULE,
        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
  }

  /** Returns a new {@link RuleSet} containing all rules of {@code base} plus {@code rule}. */
  private static RuleSet withRule(RuleSet base, RelOptRule rule) {
    return RuleSets.ofList(
        ImmutableList.<RelOptRule>builder().addAll(base).add(rule).build());
  }

  /** Sanity check: the row-count statistics of the three test tables match the rows inserted. */
  @Test
  public void testTableSizes() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);

    Assert.assertEquals(
        1d,
        tableProvider
            .buildBeamSqlTable(tableProvider.getTable("small_table"))
            .getTableStatistics(null)
            .getRowCount(),
        0.01);
    Assert.assertEquals(
        3d,
        tableProvider
            .buildBeamSqlTable(tableProvider.getTable("medium_table"))
            .getTableStatistics(null)
            .getRowCount(),
        0.01);
    Assert.assertEquals(
        100d,
        tableProvider
            .buildBeamSqlTable(tableProvider.getTable("large_table"))
            .getTableStatistics(null)
            .getRowCount(),
        0.01);
  }

  /** {@link BeamJoinAssociateRule} alone should move the large table to the top of the joins. */
  @Test
  public void testBeamJoinAssociationRule() throws Exception {
    RuleSet prepareRules = basePrepareRules();
    String sqlQuery =
        "select * from \"tt\".\"large_table\" as large_table "
            + " JOIN \"tt\".\"medium_table\" as medium_table on large_table.\"medium_key\" = medium_table.\"large_key\" "
            + " JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";

    RelNode originalPlan = transform(sqlQuery, prepareRules);
    RelNode optimizedPlan =
        transform(sqlQuery, withRule(prepareRules, BeamJoinAssociateRule.INSTANCE));

    // Without the reordering rule the small table ends up on top; with it the large one does.
    assertTopTableInJoins(originalPlan, "small_table");
    assertTopTableInJoins(optimizedPlan, "large_table");
  }

  /** {@link BeamJoinPushThroughJoinRule#LEFT} should move the large table to the top. */
  @Test
  public void testBeamJoinPushThroughJoinRuleLeft() throws Exception {
    RuleSet prepareRules = basePrepareRules();
    String sqlQuery =
        "select * from \"tt\".\"large_table\" as large_table "
            + " JOIN \"tt\".\"medium_table\" as medium_table on large_table.\"medium_key\" = medium_table.\"large_key\" "
            + " JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";

    RelNode originalPlan = transform(sqlQuery, prepareRules);
    RelNode optimizedPlan =
        transform(sqlQuery, withRule(prepareRules, BeamJoinPushThroughJoinRule.LEFT));

    assertTopTableInJoins(originalPlan, "small_table");
    assertTopTableInJoins(optimizedPlan, "large_table");
  }

  /** {@link BeamJoinPushThroughJoinRule#RIGHT} should move the large table to the top. */
  @Test
  public void testBeamJoinPushThroughJoinRuleRight() throws Exception {
    RuleSet prepareRules = basePrepareRules();
    String sqlQuery =
        "select * from \"tt\".\"medium_table\" as medium_table "
            + " JOIN \"tt\".\"large_table\" as large_table on large_table.\"medium_key\" = medium_table.\"large_key\" "
            + " JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";

    RelNode originalPlan = transform(sqlQuery, prepareRules);
    RelNode optimizedPlan =
        transform(sqlQuery, withRule(prepareRules, BeamJoinPushThroughJoinRule.RIGHT));

    assertTopTableInJoins(originalPlan, "small_table");
    assertTopTableInJoins(optimizedPlan, "large_table");
  }

  /** The default Beam rule sets should reorder Join(Join(large, medium), small). */
  @Test
  public void testSystemReorderingLargeMediumSmall() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);
    BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);

    // This is Join(Join(large, medium), small) which should be converted to a join that large table
    // is on the top.
    BeamRelNode parsedQuery =
        env.parseQuery(
            "select * from large_table "
                + " JOIN medium_table on large_table.medium_key = medium_table.large_key "
                + " JOIN small_table on medium_table.small_key = small_table.medium_key ");
    assertTopTableInJoins(parsedQuery, "large_table");
  }

  /** The default Beam rule sets should reorder Join(Join(medium, large), small). */
  @Test
  public void testSystemReorderingMediumLargeSmall() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);
    BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);

    // This is Join(Join(medium, large), small) which should be converted to a join that large table
    // is on the top.
    BeamRelNode parsedQuery =
        env.parseQuery(
            "select * from medium_table "
                + " JOIN large_table on large_table.medium_key = medium_table.large_key "
                + " JOIN small_table on medium_table.small_key = small_table.medium_key ");
    assertTopTableInJoins(parsedQuery, "large_table");
  }

  /**
   * With every join-reordering rule stripped from the default rule sets, the plan must keep the
   * written join order (small table on top), proving the reordering is caused by those rules.
   */
  @Test
  public void testSystemNotReorderingWithoutRules() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);

    // Flatten the default rule sets and drop every rule that can change the join order.
    List<RelOptRule> ruleSet =
        Arrays.stream(BeamRuleSets.getRuleSets())
            .flatMap(rules -> StreamSupport.stream(rules.spliterator(), false))
            .filter(rule -> !(rule instanceof BeamJoinPushThroughJoinRule))
            .filter(rule -> !(rule instanceof BeamJoinAssociateRule))
            .filter(rule -> !(rule instanceof JoinCommuteRule))
            .collect(Collectors.toList());

    BeamSqlEnv env =
        BeamSqlEnv.builder(tableProvider)
            .setPipelineOptions(PipelineOptionsFactory.create())
            .setRuleSets(new RuleSet[] {RuleSets.ofList(ruleSet)})
            .build();

    // This is Join(Join(medium, large), small); without the reordering rules the small table
    // stays on top.
    BeamRelNode parsedQuery =
        env.parseQuery(
            "select * from medium_table "
                + " JOIN large_table on large_table.medium_key = medium_table.large_key "
                + " JOIN small_table on medium_table.small_key = small_table.medium_key ");
    assertTopTableInJoins(parsedQuery, "small_table");
  }

  /** A query already ordered with the large table on top should not be changed. */
  @Test
  public void testSystemNotReorderingMediumSmallLarge() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);
    BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);

    // This is a correct ordered join because large table is on the top. It should not change that.
    BeamRelNode parsedQuery =
        env.parseQuery(
            "select * from medium_table "
                + " JOIN small_table on medium_table.small_key = small_table.medium_key "
                + " JOIN large_table on large_table.medium_key = medium_table.large_key ");
    assertTopTableInJoins(parsedQuery, "large_table");
  }

  /** A query already ordered with the large table on top should not be changed. */
  @Test
  public void testSystemNotReorderingSmallMediumLarge() {
    TestTableProvider tableProvider = new TestTableProvider();
    createThreeTables(tableProvider);
    BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);

    // This is a correct ordered join because large table is on the top. It should not change that.
    BeamRelNode parsedQuery =
        env.parseQuery(
            "select * from small_table "
                + " JOIN medium_table on medium_table.small_key = small_table.medium_key "
                + " JOIN large_table on large_table.medium_key = medium_table.large_key ");
    assertTopTableInJoins(parsedQuery, "large_table");
  }

  /**
   * Parses {@code sql} against {@link ThreeTablesSchema} (registered as schema "tt") and runs the
   * Calcite planner with the given rules, returning the resulting Enumerable physical plan.
   */
  private RelNode transform(String sql, RuleSet prepareRules) throws Exception {
    final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
    final SchemaPlus defSchema = rootSchema.add("tt", new ThreeTablesSchema());
    final FrameworkConfig config =
        Frameworks.newConfigBuilder()
            .parserConfig(SqlParser.Config.DEFAULT)
            .defaultSchema(defSchema)
            .traitDefs(ConventionTraitDef.INSTANCE, RelCollationTraitDef.INSTANCE)
            .programs(Programs.of(prepareRules))
            .build();
    Planner planner = Frameworks.getPlanner(config);
    SqlNode parse = planner.parse(sql);
    SqlNode validate = planner.validate(parse);
    RelRoot planRoot = planner.rel(validate);
    RelNode planBefore = planRoot.rel;
    RelTraitSet desiredTraits = planBefore.getTraitSet().replace(EnumerableConvention.INSTANCE);
    // Program index 0 is the single program registered via Programs.of(prepareRules) above.
    return planner.transform(0, desiredTraits, planBefore);
  }

  /**
   * Asserts that the table scanned "on top" of the join tree in {@code parsedQuery} is
   * {@code expectedTableName}. The top table is the right input of the first join encountered
   * (descending through input 0), or — if the right input is itself a join — the first table
   * scan found under the left input.
   */
  private void assertTopTableInJoins(RelNode parsedQuery, String expectedTableName) {
    // Descend to the first Join node.
    RelNode firstJoin = parsedQuery;
    while (!(firstJoin instanceof Join)) {
      firstJoin = firstJoin.getInput(0);
    }

    // Find the first Join or TableScan on the right side.
    RelNode topRight = ((Join) firstJoin).getRight();
    while (!(topRight instanceof Join) && !(topRight instanceof TableScan)) {
      topRight = topRight.getInput(0);
    }

    if (topRight instanceof TableScan) {
      Assert.assertTrue(topRight.getDescription().contains(expectedTableName));
    } else {
      // The right side is a nested join, so the top table must be on the left side.
      RelNode topLeft = ((Join) firstJoin).getLeft();
      while (!(topLeft instanceof TableScan)) {
        topLeft = topLeft.getInput(0);
      }
      Assert.assertTrue(topLeft.getDescription().contains(expectedTableName));
    }
  }

  /**
   * Registers small_table (1 row), medium_table (3 rows) and large_table (100 rows) with the
   * provider; the row counts drive the cost-based join reordering under test.
   */
  private void createThreeTables(TestTableProvider tableProvider) {
    BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);

    env.executeDdl("CREATE EXTERNAL TABLE small_table (id INTEGER, medium_key INTEGER) TYPE text");
    env.executeDdl(
        "CREATE EXTERNAL TABLE medium_table ("
            + "id INTEGER,"
            + "small_key INTEGER,"
            + "large_key INTEGER"
            + ") TYPE text");
    env.executeDdl(
        "CREATE EXTERNAL TABLE large_table ("
            + "id INTEGER,"
            + "medium_key INTEGER"
            + ") TYPE text");

    Row row =
        Row.withSchema(tableProvider.getTable("small_table").getSchema()).addValues(1, 1).build();
    tableProvider.addRows("small_table", row);

    for (int i = 0; i < 3; i++) {
      row =
          Row.withSchema(tableProvider.getTable("medium_table").getSchema())
              .addValues(i, 1, 2)
              .build();
      tableProvider.addRows("medium_table", row);
    }

    for (int i = 0; i < 100; i++) {
      row =
          Row.withSchema(tableProvider.getTable("large_table").getSchema()).addValues(i, 2).build();
      tableProvider.addRows("large_table", row);
    }
  }
}
/**
 * A standalone Calcite schema holding a small (2 rows), a medium (3 rows) and a large (100 rows)
 * table, each clustered on its integer primary key, so planner tests can derive row-count
 * statistics and collations without a Beam pipeline.
 */
final class ThreeTablesSchema extends AbstractSchema {
  private final ImmutableMap<String, Table> tables;

  public ThreeTablesSchema() {
    ArrayList<Object[]> mediumData = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
      mediumData.add(new Object[] {i, 1, 2});
    }
    ArrayList<Object[]> largeData = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
      largeData.add(new Object[] {i, 2});
    }
    tables =
        ImmutableMap.<String, Table>builder()
            .put(
                "small_table",
                new PkClusteredTable(
                    factory ->
                        new RelDataTypeFactory.Builder(factory)
                            .add("id", factory.createJavaType(int.class))
                            .add("medium_key", factory.createJavaType(int.class))
                            .build(),
                    ImmutableBitSet.of(0),
                    Arrays.asList(new Object[] {1, 1}, new Object[] {2, 1})))
            .put(
                "medium_table",
                new PkClusteredTable(
                    factory ->
                        new RelDataTypeFactory.Builder(factory)
                            .add("id", factory.createJavaType(int.class))
                            .add("small_key", factory.createJavaType(int.class))
                            .add("large_key", factory.createJavaType(int.class))
                            .build(),
                    ImmutableBitSet.of(0),
                    mediumData))
            .put(
                "large_table",
                new PkClusteredTable(
                    factory ->
                        new RelDataTypeFactory.Builder(factory)
                            .add("id", factory.createJavaType(int.class))
                            .add("medium_key", factory.createJavaType(int.class))
                            .build(),
                    ImmutableBitSet.of(0),
                    largeData))
            .build();
  }

  @Override
  protected Map<String, Table> getTableMap() {
    return tables;
  }

  /** A table sorted (ascending direction and nulls last) on the primary key. */
  private static class PkClusteredTable extends AbstractTable implements ScannableTable {
    // Bit i is set iff column i belongs to the primary key.
    private final ImmutableBitSet pkColumns;
    // In-memory rows; one Object[] per row, in clustered (primary-key) order.
    private final List<Object[]> data;
    // Builds the row type lazily so each caller's type factory is used.
    private final Function<RelDataTypeFactory, RelDataType> typeBuilder;

    PkClusteredTable(
        Function<RelDataTypeFactory, RelDataType> dataTypeBuilder,
        ImmutableBitSet pkColumns,
        List<Object[]> data) {
      this.data = data;
      this.typeBuilder = dataTypeBuilder;
      this.pkColumns = pkColumns;
    }

    /** Reports the exact row count, the primary key, and an ascending/nulls-last collation. */
    @Override
    public Statistic getStatistic() {
      List<RelFieldCollation> collationFields = new ArrayList<>();
      for (Integer key : pkColumns) {
        collationFields.add(
            new RelFieldCollation(
                key, RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.LAST));
      }
      return Statistics.of(
          data.size(),
          ImmutableList.of(pkColumns),
          ImmutableList.of(RelCollations.of(collationFields)));
    }

    @Override
    public RelDataType getRowType(final RelDataTypeFactory typeFactory) {
      return typeBuilder.apply(typeFactory);
    }

    @Override
    public Enumerable<Object[]> scan(final DataContext root) {
      return Linq4j.asEnumerable(data);
    }
  }
}