| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use std::{collections::HashMap, sync::Arc}; |
| |
| use arrow::datatypes::{DataType, Field, Schema}; |
| |
| use datafusion_common::config::ConfigOptions; |
| use datafusion_common::{Result, TableReference, plan_err}; |
| use datafusion_expr::WindowUDF; |
| use datafusion_expr::planner::ExprPlanner; |
| use datafusion_expr::{ |
| AggregateUDF, ScalarUDF, TableSource, logical_plan::builder::LogicalTableSource, |
| }; |
| use datafusion_functions::core::planner::CoreFunctionPlanner; |
| use datafusion_functions_aggregate::count::count_udaf; |
| use datafusion_functions_aggregate::sum::sum_udaf; |
| use datafusion_sql::{ |
| planner::{ContextProvider, SqlToRel}, |
| sqlparser::{dialect::GenericDialect, parser::Parser}, |
| }; |
| |
| fn main() { |
| let sql = "SELECT \ |
| c.id, c.first_name, c.last_name, \ |
| COUNT(*) as num_orders, \ |
| sum(o.price) AS total_price, \ |
| sum(o.price * s.sales_tax) AS state_tax \ |
| FROM customer c \ |
| JOIN state s ON c.state = s.id \ |
| JOIN orders o ON c.id = o.customer_id \ |
| WHERE o.price > 0 \ |
| AND c.last_name LIKE 'G%' \ |
| GROUP BY 1, 2, 3 \ |
| ORDER BY state_tax DESC"; |
| |
| // parse the SQL |
| let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... |
| let ast = Parser::parse_sql(&dialect, sql).unwrap(); |
| let statement = &ast[0]; |
| |
| // create a logical query plan |
| let context_provider = MyContextProvider::new() |
| .with_udaf(sum_udaf()) |
| .with_udaf(count_udaf()) |
| .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); |
| let sql_to_rel = SqlToRel::new(&context_provider); |
| let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); |
| |
| // show the plan |
| println!("{plan}"); |
| } |
| |
| struct MyContextProvider { |
| options: ConfigOptions, |
| tables: HashMap<String, Arc<dyn TableSource>>, |
| udafs: HashMap<String, Arc<AggregateUDF>>, |
| expr_planners: Vec<Arc<dyn ExprPlanner>>, |
| } |
| |
| impl MyContextProvider { |
| fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self { |
| self.udafs.insert(udaf.name().to_string(), udaf); |
| self |
| } |
| |
| fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self { |
| self.expr_planners.push(planner); |
| self |
| } |
| |
| fn new() -> Self { |
| let mut tables = HashMap::new(); |
| tables.insert( |
| "customer".to_string(), |
| create_table_source(vec![ |
| Field::new("id", DataType::Int32, false), |
| Field::new("first_name", DataType::Utf8, false), |
| Field::new("last_name", DataType::Utf8, false), |
| Field::new("state", DataType::Utf8, false), |
| ]), |
| ); |
| tables.insert( |
| "state".to_string(), |
| create_table_source(vec![ |
| Field::new("id", DataType::Int32, false), |
| Field::new("sales_tax", DataType::Decimal128(10, 2), false), |
| ]), |
| ); |
| tables.insert( |
| "orders".to_string(), |
| create_table_source(vec![ |
| Field::new("id", DataType::Int32, false), |
| Field::new("customer_id", DataType::Int32, false), |
| Field::new("item_id", DataType::Int32, false), |
| Field::new("quantity", DataType::Int32, false), |
| Field::new("price", DataType::Decimal128(10, 2), false), |
| ]), |
| ); |
| Self { |
| tables, |
| options: Default::default(), |
| udafs: Default::default(), |
| expr_planners: vec![], |
| } |
| } |
| } |
| |
| fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> { |
| Arc::new(LogicalTableSource::new(Arc::new( |
| Schema::new_with_metadata(fields, HashMap::new()), |
| ))) |
| } |
| |
| impl ContextProvider for MyContextProvider { |
| fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> { |
| match self.tables.get(name.table()) { |
| Some(table) => Ok(Arc::clone(table)), |
| _ => plan_err!("Table not found: {}", name.table()), |
| } |
| } |
| |
| fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> { |
| None |
| } |
| |
| fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> { |
| self.udafs.get(name).cloned() |
| } |
| |
| fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { |
| None |
| } |
| |
| fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> { |
| None |
| } |
| |
| fn options(&self) -> &ConfigOptions { |
| &self.options |
| } |
| |
| fn udf_names(&self) -> Vec<String> { |
| Vec::new() |
| } |
| |
| fn udaf_names(&self) -> Vec<String> { |
| Vec::new() |
| } |
| |
| fn udwf_names(&self) -> Vec<String> { |
| Vec::new() |
| } |
| |
| fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] { |
| &self.expr_planners |
| } |
| } |