// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow_schema::DataType;
use datafusion::prelude::SessionContext;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{assert_batches_eq, Result, ScalarValue};
use datafusion_expr::{
BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use std::any::Any;
use std::sync::Arc;
/// This example demonstrates how to add your own [`OptimizerRule`]
/// to DataFusion.
///
/// [`OptimizerRule`]s transform [`LogicalPlan`]s into an equivalent (but
/// hopefully faster) form.
///
/// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for
/// changing plan semantics.
#[tokio::main]
pub async fn main() -> Result<()> {
// DataFusion includes many built-in OptimizerRules for tasks such as outer
// to inner join conversion and constant folding.
//
// Note that you can change the order of optimizer rules using the
// lower-level `SessionState` API.
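//
// A minimal sketch (not compiled here; assumes a DataFusion version that
// provides `SessionStateBuilder`):
//
//   let state = SessionStateBuilder::new()
//       .with_optimizer_rules(vec![Arc::new(MyOptimizerRule {})])
//       .build();
//   let ctx = SessionContext::new_with_state(state);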
let ctx = SessionContext::new();
ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {}));
// Now, let's plan and run queries with the new rule
ctx.register_batch("person", person_batch())?;
let sql = "SELECT * FROM person WHERE age = 22";
let plan = ctx.sql(sql).await?.into_optimized_plan()?;
// We can see the effect of our rewrite in the optimized plan: the filter
// predicate has been rewritten to a call to `my_eq`
assert_eq!(
plan.display_indent().to_string(),
"Filter: my_eq(person.age, Int32(22))\
\n TableScan: person projection=[name, age]"
);
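// (Without this rule, the optimized plan would instead contain
// `Filter: person.age = Int32(22)`.)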
// The query below does not respect the filter `WHERE age = 22` because
// the predicate has been rewritten to a call to the `my_eq` UDF, which
// always returns true.
//
// The output verifies this: all three rows are returned, not just the
// one with age = 22
assert_batches_eq!(
[
"+--------+-----+",
"| name | age |",
"+--------+-----+",
"| Andy | 11 |",
"| Andrew | 22 |",
"| Oleks | 33 |",
"+--------+-----+",
],
&ctx.sql(sql).await?.collect().await?
);
// However, the rule doesn't trigger for queries with predicates other
// than `=`, so this filter is applied as written
assert_batches_eq!(
[
"+-------+-----+",
"| name | age |",
"+-------+-----+",
"| Andy | 11 |",
"| Oleks | 33 |",
"+-------+-----+",
],
&ctx.sql("SELECT * FROM person WHERE age <> 22")
.await?
.collect()
.await?
);
Ok(())
}
/// An example OptimizerRule that replaces all `col = <const>` predicates with a
/// user-defined function
struct MyOptimizerRule {}
impl OptimizerRule for MyOptimizerRule {
fn name(&self) -> &str {
"my_optimizer_rule"
}
// New OptimizerRules should use the "rewrite" API, as it is more efficient:
// it takes ownership of the plan rather than copying it
fn supports_rewrite(&self) -> bool {
true
}
/// Ask the optimizer to handle the plan recursion. `rewrite` will be called
/// on each plan node.
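///
/// `BottomUp` visits each node's children before the node itself, while
/// `TopDown` visits parents first. Returning `None` means the rule
/// handles any necessary recursion itself.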
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
plan.map_expressions(|expr| {
// This closure is called for all expressions in the current plan
//
// For example, given a plan like `SELECT a + b, 5 + 10`
//
// The closure would be called twice:
// 1. once for `a + b`
// 2. once for `5 + 10`
self.rewrite_expr(expr)
})
}
}
impl MyOptimizerRule {
/// Rewrites an Expr, replacing all `<col> = <const>` expressions with
/// a call to the `my_eq` UDF
fn rewrite_expr(&self, expr: Expr) -> Result<Transformed<Expr>> {
// do a bottom up rewrite of the expression tree
expr.transform_up(|expr| {
// This closure is called for each subtree of the expression
match expr {
Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => {
// destructure the expression
let BinaryExpr { left, op: _, right } = binary_expr;
// rewrite to `my_eq(left, right)`
let udf = ScalarUDF::new_from_impl(MyEq::new());
let call = udf.call(vec![*left, *right]);
Ok(Transformed::yes(call))
}
_ => Ok(Transformed::no(expr)),
}
})
// Note that the TreeNode API handles propagating the transformed flag
// and errors up the call chain
}
}
/// Return true if the expression is an equality comparison between
/// literals or column references
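///
/// For example, `age = 22` and `age = name` match, while `age + 1 = 22`
/// does not (its left side is neither a literal nor a column)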
fn is_binary_eq(binary_expr: &BinaryExpr) -> bool {
binary_expr.op == Operator::Eq
&& is_lit_or_col(binary_expr.left.as_ref())
&& is_lit_or_col(binary_expr.right.as_ref())
}
/// Return true if the expression is a literal or column reference
fn is_lit_or_col(expr: &Expr) -> bool {
matches!(expr, Expr::Column(_) | Expr::Literal(_))
}
/// A simple user-defined filter function
#[derive(Debug, Clone)]
struct MyEq {
signature: Signature,
}
impl MyEq {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Stable),
}
}
}
impl ScalarUDFImpl for MyEq {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"my_eq"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
// This example simply returns `true`, which is not what a real
// implementation would do.
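//
// A real implementation would compare the two arguments, for example
// with the arrow `eq` kernel (a sketch, not compiled here):
//
//   let arrays = ColumnarValue::values_to_arrays(_args)?;
//   let result = arrow::compute::kernels::cmp::eq(&arrays[0], &arrays[1])?;
//   Ok(ColumnarValue::Array(Arc::new(result)))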
Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
}
}
/// Return a RecordBatch with made up data
fn person_batch() -> RecordBatch {
let name: ArrayRef =
Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"]));
let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33]));
RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
}