blob: 9a2aabaa79c2e6648d20ba06c50ceb5c0561cbde [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! This example shows how to use the structures that DataFusion provides to perform
//! analysis on SQL queries and their plans.
//!
//! As a motivating example, we show how to count the number of JOINs in a query,
//! as well as how many join trees there are and the join count of each tree.
use std::sync::Arc;
use datafusion::common::Result;
use datafusion::{
datasource::MemTable,
execution::context::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::LogicalPlan;
use test_utils::tpcds::tpcds_schemas;
/// Returns how many join nodes (`Join` or `CrossJoin`) appear anywhere in `plan`.
///
/// Every node of the plan is visited with the `TreeNode` API, and each join
/// node found adds one to the result.
fn total_join_count(plan: &LogicalPlan) -> usize {
    let mut joins_seen = 0;
    // The visitor always answers `Continue`, so no subtree is ever skipped.
    plan.apply(|current| {
        let is_join = matches!(current, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_));
        if is_join {
            joins_seen += 1;
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // The visitor closure itself never returns an error.
    .unwrap();
    joins_seen
}
/// Counts every join in `plan` and reports the size of each join tree.
///
/// Join tree definition: the largest subtree consisting entirely of joins
/// (projections between two joins do not split a tree; see `count_tree`).
///
/// Returns `(total_joins, joins_per_tree)`, where `joins_per_tree` lists the
/// trees in the order they are discovered by a depth-first search.
///
/// For example, this plan:
///
/// ```text
///     JOIN
///     /  \
///    A   JOIN
///        /  \
///       B    C
/// ```
///
/// contains a single join tree `(A-B-C)` and yields `(2, [2])`, while this plan:
///
/// ```text
///     JOIN
///     /  \
///    A   GROUP
///          |
///         JOIN
///         /  \
///        B    C
/// ```
///
/// contains two join trees `(A-, B-C)` and yields `(2, [1, 1])`.
fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
    // Explicit depth-first stack. The first join reached along any path is
    // necessarily the root of a join tree.
    let mut pending = vec![plan];
    let mut tree_sizes = vec![];
    while let Some(current) = pending.pop() {
        match current {
            LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) => {
                // Consume the whole tree in one go, then keep searching
                // below its (non-join) inputs for further trees.
                let (joins_in_tree, tree_inputs) = count_tree(current);
                tree_sizes.push(joins_in_tree);
                pending.extend(tree_inputs);
            }
            other => pending.extend(other.inputs()),
        }
    }
    // Every join belongs to exactly one tree, so the total is the sum.
    let total: usize = tree_sizes.iter().sum();
    (total, tree_sizes)
}
/// Walks the join tree rooted at `join` using the `TreeNode` API and returns
/// `(join_count, inputs)`.
///
/// `join_count` is the number of join nodes (`Join` / `CrossJoin`) in the tree.
/// `inputs` are the first nodes below the tree that are neither joins nor
/// projections, in traversal order; their children are not explored here, so
/// the caller can search them for further join trees.
///
/// For example, given:
///
/// ```text
///     JOIN
///     /  \
///    A   GROUP
///          |
///         JOIN
///         /  \
///        B    C
/// ```
///
/// this returns `(1, [A, GROUP])`.
///
/// Projections are looked through instead of being reported as inputs.
/// Optimized plans push projections down as far as possible, which can place
/// a projection between two joins and make them look unrelated even though
/// they form one tree. This plan:
///
/// ```text
///     JOIN
///     /  \
///    A   PROJECTION
///          |
///         JOIN
///         /  \
///        B    C
/// ```
///
/// is therefore treated the same as:
///
/// ```text
///     JOIN
///     /  \
///    A   JOIN
///        /  \
///       B    C
/// ```
fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
    let mut joins = 0;
    let mut frontier = Vec::new();
    join.apply(|current| match current {
        // Descend through projections as if they were not there.
        LogicalPlan::Projection(_) => Ok(TreeNodeRecursion::Continue),
        // Joins are part of this tree: count them and keep descending.
        LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_) => {
            joins += 1;
            Ok(TreeNodeRecursion::Continue)
        }
        // Anything else is an input of the tree; record it and skip its
        // children (they are the caller's responsibility).
        other => {
            frontier.push(other);
            Ok(TreeNodeRecursion::Jump)
        }
    })
    // The visitor closure itself never returns an error.
    .unwrap();
    (joins, frontier)
}
#[tokio::main]
async fn main() -> Result<()> {
    // To show how we can count the joins in a SQL query we'll be using query 88
    // from the TPC-DS benchmark.
    //
    // Q88 has many joins, cross joins and multiple join trees, which makes it
    // perfect for our example:
    let tpcds_query_88 = "
select *
from
(select count(*) h8_30_to_9
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 8
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s1,
(select count(*) h9_to_9_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 9
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s2,
(select count(*) h9_30_to_10
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 9
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s3,
(select count(*) h10_to_10_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 10
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s4,
(select count(*) h10_30_to_11
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 10
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s5,
(select count(*) h11_to_11_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 11
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s6,
(select count(*) h11_30_to_12
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 11
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s7,
(select count(*) h12_to_12_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 12
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s8;";

    // First set up the config
    let config = SessionConfig::default();
    let ctx = SessionContext::new_with_config(config);

    // Register the tables referenced by the TPC-DS query. They are registered
    // with no data: this example only plans the query and never executes it.
    for table in tpcds_schemas() {
        ctx.register_table(
            table.name,
            Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?),
        )?;
    }

    // We can create a LogicalPlan from a SQL query like this
    let logical_plan = ctx.sql(tpcds_query_88).await?.into_optimized_plan()?;
    println!(
        "Optimized Logical Plan:\n\n{}\n",
        logical_plan.display_indent()
    );

    // We can get the total count. Query 88 has 31 joins (7 CROSS joins and
    // 24 INNER joins) over 32 input relations (8 subqueries x 4 tables each).
    // The local is deliberately not named `total_join_count` so it does not
    // shadow the function of that name.
    let join_count = total_join_count(&logical_plan);
    assert_eq!(31, join_count);
    println!("The plan has {join_count} joins.");

    // Furthermore, the 24 inner joins form 8 join trees of 3 joins each (one
    // per subquery), and the 7 cross joins that combine the subqueries form
    // one more tree. We can recover these groups with `count_trees`.
    let (tree_join_count, trees) = count_trees(&logical_plan);
    assert_eq!(
        (tree_join_count, &trees),
        // Query 88 is very straightforward: the cross-join tree sits at the
        // top of the plan, followed by the INNER join trees.
        (31, &vec![7, 3, 3, 3, 3, 3, 3, 3, 3])
    );
    println!(
        "And following join-trees (number represents join amount in tree): {trees:?}"
    );
    Ok(())
}