blob: f75dc79eba00649299380a97127616ebb9a5b6b0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::expr_rewriter::rewrite_expr;
use datafusion_expr::{
AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource,
};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion_sql::sqlparser::parser::Parser;
use datafusion_sql::TableReference;
use std::any::Any;
use std::sync::Arc;
/// Parses a SQL query, plans it, and runs a single custom optimizer rule
/// over the resulting logical plan, printing the plan before and after.
///
/// # Errors
///
/// Returns an error if parsing, planning, or optimization fails.
pub fn main() -> Result<()> {
    // parse the SQL text into statements using the PostgreSQL dialect
    let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32";
    let statements = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
    let statement = statements[0].clone();

    // produce a logical plan using the datafusion-sql crate
    let context_provider = MyContextProvider::default();
    let planner = SqlToRel::new(&context_provider);
    let logical_plan = planner.sql_statement_to_plan(statement)?;
    println!(
        "Unoptimized Logical Plan:\n\n{}\n",
        logical_plan.display_indent()
    );

    // now run the optimizer with our custom rule as the only rule;
    // failing rules are not skipped
    let config = OptimizerContext::default().with_skip_failing_rules(false);
    let optimizer = Optimizer::with_rules(vec![Arc::new(MyRule {})]);
    let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?;
    println!(
        "Optimized Logical Plan:\n\n{}\n",
        optimized_plan.display_indent()
    );

    Ok(())
}
/// Observer callback handed to the optimizer: prints the name of the rule
/// that was applied together with the indented rendering of the given plan.
fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) {
    let rule_name = rule.name();
    let rendered_plan = plan.display_indent();
    println!("After applying rule '{}':\n{}\n", rule_name, rendered_plan);
}
/// Example optimizer rule that rewrites the predicate of every `Filter`
/// node using [`my_rewrite`].
struct MyRule {}

impl OptimizerRule for MyRule {
    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "my_rule"
    }

    /// Optimizes the children of `plan` first, then rewrites the predicate
    /// if the resulting node is a `Filter`.
    ///
    /// Returns `Ok(Some(_))` with the new plan when the node is a filter or
    /// when child optimization produced a new plan, and `Ok(None)` when
    /// neither applies.
    fn try_optimize(
        &self,
        plan: &LogicalPlan,
        config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        // recurse down and optimize children first
        let children_optimized = utils::optimize_children(self, plan, config)?;

        // inspect the child-optimized plan when there is one, otherwise
        // fall back to the plan we were given
        let candidate = children_optimized.as_ref().unwrap_or(plan);

        if let LogicalPlan::Filter(filter) = candidate {
            let predicate = my_rewrite(filter.predicate.clone())?;
            let rewritten = Filter::try_new(predicate, filter.input.clone())?;
            return Ok(Some(LogicalPlan::Filter(rewritten)));
        }

        // not a filter: hand back whatever child optimization produced
        Ok(children_optimized)
    }
}
/// Uses `rewrite_expr` to replace every `BETWEEN` expression in `expr`
/// with the equivalent pair of comparisons.
///
/// * `x BETWEEN low AND high` becomes `x >= low AND x <= high`
/// * `x NOT BETWEEN low AND high` becomes `x < low OR x > high`
///
/// All other expressions are returned unchanged.
fn my_rewrite(expr: Expr) -> Result<Expr> {
    // the closure is invoked for all sub expressions
    rewrite_expr(expr, |node| match node {
        Expr::Between(Between {
            expr: operand,
            negated,
            low,
            high,
        }) => {
            // move the operands out of their boxes
            let (operand, low, high) = (*operand, *low, *high);
            // the operand appears in both comparisons, so one copy is needed
            let lhs = operand.clone();
            Ok(if negated {
                lhs.lt(low).or(operand.gt(high))
            } else {
                lhs.gt_eq(low).and(operand.lt_eq(high))
            })
        }
        other => Ok(other),
    })
}
/// Minimal `ContextProvider` for the SQL planner: it knows a single table,
/// `person`, and no functions or variables.
#[derive(Default)]
struct MyContextProvider {
    /// Configuration options exposed to the planner.
    options: ConfigOptions,
}

impl ContextProvider for MyContextProvider {
    /// Resolves `name` to a table source.
    ///
    /// Only a table named `person` is recognised; it has the non-nullable
    /// columns `name` (`Utf8`) and `age` (`UInt8`). Any other name yields a
    /// planning error.
    fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
        if name.table() != "person" {
            return Err(DataFusionError::Plan(String::from("table not found")));
        }

        let fields = vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("age", DataType::UInt8, false),
        ];
        let schema = Arc::new(Schema::new(fields));
        Ok(Arc::new(MyTableSource { schema }))
    }

    /// No scalar functions are registered.
    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
        None
    }

    /// No aggregate functions are registered.
    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
        None
    }

    /// No variables are defined.
    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
        None
    }

    /// Returns the configuration options held by this provider.
    fn options(&self) -> &ConfigOptions {
        &self.options
    }
}
struct MyTableSource {
schema: SchemaRef,
}
impl TableSource for MyTableSource {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}