blob: 8d6ed7f1a743bb50a3e43e271c26879ff07a5d3c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! SQL Query Planner (produces logical plan from SQL AST)
use std::str::FromStr;
use std::sync::Arc;
use crate::datasource::TableProvider;
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType,
StringifiedPlan, ToDFSchema,
};
use crate::scalar::ScalarValue;
use crate::{
error::{DataFusionError, Result},
physical_plan::udaf::AggregateUDF,
};
use crate::{
physical_plan::udf::ScalarUDF,
physical_plan::{aggregates, functions},
sql::parser::{CreateExternalTable, FileType, Statement as DFStatement},
};
use arrow::datatypes::*;
use crate::prelude::JoinType;
use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, Expr as SQLExpr, FunctionArg, Join,
JoinConstraint, JoinOperator, Query, Select, SelectItem, SetExpr, TableFactor,
TableWithJoins, UnaryOperator, Value,
};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{OrderByExpr, Statement};
use sqlparser::parser::ParserError::ParserError;
use super::utils::{
can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr,
find_aggregate_exprs, find_column_exprs, rebase_expr,
};
/// The ContextProvider trait allows the query planner to obtain meta-data about tables and
/// functions referenced in SQL statements
pub trait ContextProvider {
/// Getter for a datasource
///
/// Returns `None` when no table with the given name is registered.
fn get_table_provider(
&self,
name: &str,
) -> Option<Arc<dyn TableProvider + Send + Sync>>;
/// Getter for a UDF description
///
/// Returns `None` when no scalar UDF with the given name is registered.
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
///
/// Returns `None` when no aggregate UDF with the given name is registered.
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>>;
}
/// SQL query planner
///
/// Translates a sqlparser AST into a DataFusion `LogicalPlan`, resolving
/// table, UDF and UDAF names through the supplied [`ContextProvider`].
pub struct SqlToRel<'a, S: ContextProvider> {
// catalog used to resolve tables and functions during planning
schema_provider: &'a S,
}
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Create a new query planner
pub fn new(schema_provider: &'a S) -> Self {
SqlToRel { schema_provider }
}
/// Generate a logical plan from a DataFusion SQL statement
pub fn statement_to_plan(&self, statement: &DFStatement) -> Result<LogicalPlan> {
    match statement {
        // CREATE EXTERNAL TABLE has its own AST node in the DataFusion dialect
        DFStatement::CreateExternalTable(s) => self.external_table_to_plan(s),
        // everything else is a plain sqlparser statement
        DFStatement::Statement(s) => self.sql_statement_to_plan(s),
    }
}
/// Generate a logical plan from an SQL statement
///
/// Only `EXPLAIN` and `SELECT` (query) statements are supported; any
/// other statement kind produces a `NotImplemented` error.
pub fn sql_statement_to_plan(&self, sql: &Statement) -> Result<LogicalPlan> {
    match sql {
        Statement::Explain {
            verbose,
            statement,
            // the ANALYZE flag is currently ignored
            analyze: _,
        } => self.explain_statement_to_plan(*verbose, statement),
        Statement::Query(query) => self.query_to_plan(query),
        _ => Err(DataFusionError::NotImplemented(
            "Only SELECT statements are implemented".to_string(),
        )),
    }
}
/// Generate a logical plan from an SQL query
///
/// Builds the plan for the query body, then applies ORDER BY and
/// LIMIT (in that order) on top of it.
pub fn query_to_plan(&self, query: &Query) -> Result<LogicalPlan> {
let plan = match &query.body {
SetExpr::Select(s) => self.select_to_plan(s.as_ref()),
// set operations (UNION, INTERSECT, ...) and VALUES are unsupported
_ => Err(DataFusionError::NotImplemented(format!(
"Query {} not implemented yet",
query.body
))),
}?;
// sort before limiting so LIMIT returns the top rows of the sorted result
let plan = self.order_by(&plan, &query.order_by)?;
self.limit(&plan, &query.limit)
}
/// Generate a logical plan from a CREATE EXTERNAL TABLE statement
pub fn external_table_to_plan(
    &self,
    statement: &CreateExternalTable,
) -> Result<LogicalPlan> {
    let CreateExternalTable {
        name,
        columns,
        file_type,
        has_header,
        location,
    } = statement;

    // semantic checks: CSV requires an explicit schema, while Parquet is
    // self-describing and must not have one; NDJSON accepts either
    match *file_type {
        FileType::CSV => {
            if columns.is_empty() {
                return Err(DataFusionError::Plan(
                    "Column definitions required for CSV files. None found".into(),
                ));
            }
        }
        FileType::Parquet => {
            if !columns.is_empty() {
                return Err(DataFusionError::Plan(
                    "Column definitions can not be specified for PARQUET files."
                        .into(),
                ));
            }
        }
        FileType::NdJson => {}
    };

    let schema = self.build_schema(columns)?;

    Ok(LogicalPlan::CreateExternalTable {
        schema: schema.to_dfschema_ref()?,
        name: name.clone(),
        location: location.clone(),
        file_type: *file_type,
        has_header: *has_header,
    })
}
/// Generate a plan for EXPLAIN ... that will print out a plan
///
/// The resulting `Explain` node wraps the inner plan together with its
/// pretty-printed string form.
pub fn explain_statement_to_plan(
    &self,
    verbose: bool,
    statement: &Statement,
) -> Result<LogicalPlan> {
    // plan the wrapped statement first so planning errors surface here
    let plan = self.sql_statement_to_plan(statement)?;

    let stringified_plans = vec![StringifiedPlan::new(
        PlanType::LogicalPlan,
        format!("{:#?}", plan),
    )];

    // EXPLAIN has its own fixed output schema (plan type / plan text)
    let schema = LogicalPlan::explain_schema();
    let plan = Arc::new(plan);

    Ok(LogicalPlan::Explain {
        verbose,
        plan,
        stringified_plans,
        schema: schema.to_dfschema_ref()?,
    })
}
fn build_schema(&self, columns: &Vec<SQLColumnDef>) -> Result<Schema> {
let mut fields = Vec::new();
for column in columns {
let data_type = self.make_data_type(&column.data_type)?;
let allow_null = column
.options
.iter()
.any(|x| x.option == ColumnOption::Null);
fields.push(Field::new(&column.name.value, data_type, allow_null));
}
Ok(Schema::new(fields))
}
/// Maps the SQL type to the corresponding Arrow `DataType`
///
/// Used for CREATE EXTERNAL TABLE column definitions.
///
/// NOTE(review): this mapping diverges from `convert_data_type` (used
/// for CAST): here `Timestamp` maps to `Date64` and `Float(_)` to
/// `Float32`, while `convert_data_type` maps them to
/// `Timestamp(Nanosecond, None)` and `Float64` — confirm the divergence
/// is intentional.
fn make_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
match sql_type {
SQLDataType::BigInt => Ok(DataType::Int64),
SQLDataType::Int => Ok(DataType::Int32),
SQLDataType::SmallInt => Ok(DataType::Int16),
// all character types are represented as UTF-8 strings
SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => {
Ok(DataType::Utf8)
}
// DECIMAL is approximated as a 64-bit float (precision is lost)
SQLDataType::Decimal(_, _) => Ok(DataType::Float64),
SQLDataType::Float(_) => Ok(DataType::Float32),
SQLDataType::Real | SQLDataType::Double => Ok(DataType::Float64),
SQLDataType::Boolean => Ok(DataType::Boolean),
SQLDataType::Date => Ok(DataType::Date32(DateUnit::Day)),
SQLDataType::Time => Ok(DataType::Time64(TimeUnit::Millisecond)),
SQLDataType::Timestamp => Ok(DataType::Date64(DateUnit::Millisecond)),
_ => Err(DataFusionError::NotImplemented(format!(
"The SQL data type {:?} is not implemented",
sql_type
))),
}
}
/// Create one logical plan per item in the FROM clause.
///
/// An empty FROM clause (e.g. `SELECT 1`) yields a single empty
/// relation that produces one row.
fn plan_from_tables(&self, from: &[TableWithJoins]) -> Result<Vec<LogicalPlan>> {
    if from.is_empty() {
        Ok(vec![LogicalPlanBuilder::empty(true).build()?])
    } else {
        from.iter()
            .map(|t| self.plan_table_with_joins(t))
            .collect::<Result<Vec<_>>>()
    }
}
/// Plan a table reference together with any JOINs chained onto it.
fn plan_table_with_joins(&self, t: &TableWithJoins) -> Result<LogicalPlan> {
    // joins are left-associative: fold each JOIN onto the plan built so far
    t.joins
        .iter()
        .try_fold(self.create_relation(&t.relation)?, |left, join| {
            self.parse_relation_join(&left, join)
        })
}
/// Plan a single JOIN clause against the plan built so far (`left`).
///
/// Only INNER, LEFT OUTER and RIGHT OUTER joins are supported.
fn parse_relation_join(
    &self,
    left: &LogicalPlan,
    join: &Join,
) -> Result<LogicalPlan> {
    let right = self.create_relation(&join.relation)?;
    // map the sqlparser join operator onto a (constraint, JoinType) pair
    let supported = match &join.join_operator {
        JoinOperator::LeftOuter(constraint) => Some((constraint, JoinType::Left)),
        JoinOperator::RightOuter(constraint) => Some((constraint, JoinType::Right)),
        JoinOperator::Inner(constraint) => Some((constraint, JoinType::Inner)),
        _ => None,
    };
    match supported {
        Some((constraint, join_type)) => {
            self.parse_join(left, &right, constraint, join_type)
        }
        None => Err(DataFusionError::NotImplemented(format!(
            "Unsupported JOIN operator {:?}",
            join.join_operator
        ))),
    }
}
/// Convert a join constraint into a key-based equi-join between `left`
/// and `right`.
///
/// Only `ON` with equality conditions (optionally AND-ed) and `USING`
/// are supported; NATURAL JOIN is not implemented.
fn parse_join(
&self,
left: &LogicalPlan,
right: &LogicalPlan,
constraint: &JoinConstraint,
join_type: JoinType,
) -> Result<LogicalPlan> {
match constraint {
JoinConstraint::On(sql_expr) => {
let mut keys: Vec<(String, String)> = vec![];
// the ON expression may reference columns from either input,
// so resolve it against the combined schema
let join_schema = left.schema().join(&right.schema())?;
// parse ON expression
let expr = self.sql_to_rex(sql_expr, &join_schema)?;
// extract join keys
// NOTE(review): each key pair is used as (left, right) in
// source order — the ON condition appears to require the left
// table's column first; confirm this is intended
extract_join_keys(&expr, &mut keys)?;
let left_keys: Vec<&str> =
keys.iter().map(|pair| pair.0.as_str()).collect();
let right_keys: Vec<&str> =
keys.iter().map(|pair| pair.1.as_str()).collect();
// return the logical plan representing the join
LogicalPlanBuilder::from(&left)
.join(&right, join_type, &left_keys, &right_keys)?
.build()
}
JoinConstraint::Using(idents) => {
// USING(a, b): identical column names on both sides
let keys: Vec<&str> = idents.iter().map(|x| x.value.as_str()).collect();
LogicalPlanBuilder::from(&left)
.join(&right, join_type, &keys, &keys)?
.build()
}
JoinConstraint::Natural => {
// https://issues.apache.org/jira/browse/ARROW-10727
Err(DataFusionError::NotImplemented(
"NATURAL JOIN is not supported (https://issues.apache.org/jira/browse/ARROW-10727)".to_string(),
))
}
}
}
/// Convert a FROM-clause table factor into a logical plan.
fn create_relation(&self, relation: &TableFactor) -> Result<LogicalPlan> {
    match relation {
        // a named table, resolved through the schema provider
        TableFactor::Table { name, .. } => {
            let table_name = name.to_string();
            let provider = self
                .schema_provider
                .get_table_provider(&table_name)
                .ok_or_else(|| {
                    DataFusionError::Plan(format!(
                        "no provider found for table {}",
                        name
                    ))
                })?;
            LogicalPlanBuilder::scan(&table_name, provider, None)?.build()
        }
        // a derived table: (SELECT ...) AS alias
        TableFactor::Derived { subquery, .. } => self.query_to_plan(subquery),
        // a parenthesized join: FROM (a JOIN b ON ...)
        TableFactor::NestedJoin(table_with_joins) => {
            self.plan_table_with_joins(table_with_joins)
        }
        // @todo Support TableFactory::TableFunction?
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported ast node {:?} in create_relation",
            relation
        ))),
    }
}
/// Generate a logical plan from an SQL select
///
/// Plans the FROM clause (inferring inner equi-joins between multiple
/// FROM items from `a.x = b.y` predicates in the WHERE clause), then the
/// filter, aggregation and projection.
fn select_to_plan(&self, select: &Select) -> Result<LogicalPlan> {
if select.having.is_some() {
return Err(DataFusionError::NotImplemented(
"HAVING is not implemented yet".to_string(),
));
}
// one plan per FROM item; multiple items form an implicit join below
let plans = self.plan_from_tables(&select.from)?;
let plan = match &select.selection {
Some(predicate_expr) => {
// build join schema
let mut fields = vec![];
for plan in &plans {
fields.extend_from_slice(&plan.schema().fields());
}
let join_schema = DFSchema::new(fields)?;
let filter_expr = self.sql_to_rex(predicate_expr, &join_schema)?;
// look for expressions of the form `<column> = <column>`
let mut possible_join_keys = vec![];
extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;
let mut all_join_keys = vec![];
// left-fold the FROM items into a chain of inner joins
let mut left = plans[0].clone();
for right in plans.iter().skip(1) {
let left_schema = left.schema();
let right_schema = right.schema();
let mut join_keys = vec![];
// keep the equality pairs whose columns live on opposite
// sides of this join (accepting either order)
for (l, r) in &possible_join_keys {
if left_schema.field_with_unqualified_name(l).is_ok()
&& right_schema.field_with_unqualified_name(r).is_ok()
{
join_keys.push((l.as_str(), r.as_str()));
} else if left_schema.field_with_unqualified_name(r).is_ok()
&& right_schema.field_with_unqualified_name(l).is_ok()
{
join_keys.push((r.as_str(), l.as_str()));
}
}
if join_keys.is_empty() {
// no usable equi-join condition: this would be a
// cartesian product, which is not supported
return Err(DataFusionError::NotImplemented(
"Cartesian joins are not supported".to_string(),
));
} else {
let left_keys: Vec<_> =
join_keys.iter().map(|(l, _)| *l).collect();
let right_keys: Vec<_> =
join_keys.iter().map(|(_, r)| *r).collect();
let builder = LogicalPlanBuilder::from(&left);
left = builder
.join(right, JoinType::Inner, &left_keys, &right_keys)?
.build()?;
}
all_join_keys.extend_from_slice(&join_keys);
}
// remove join expressions from filter
match remove_join_expressions(&filter_expr, &all_join_keys)? {
Some(filter_expr) => {
LogicalPlanBuilder::from(&left).filter(filter_expr)?.build()
}
// the whole predicate was consumed by the join keys
_ => Ok(left),
}
}
None => {
if plans.len() == 1 {
Ok(plans[0].clone())
} else {
// without a WHERE clause there are no join keys to infer
Err(DataFusionError::NotImplemented(
"Cartesian joins are not supported".to_string(),
))
}
}
};
let plan = plan?;
// The SELECT expressions, with wildcards expanded.
let select_exprs = self.prepare_select_exprs(&plan, &select.projection)?;
// All of the aggregate expressions (deduplicated).
let aggr_exprs = find_aggregate_exprs(&select_exprs);
let (plan, select_exprs_post_aggr) =
if !select.group_by.is_empty() || !aggr_exprs.is_empty() {
self.aggregate(&plan, &select_exprs, &select.group_by, &aggr_exprs)?
} else {
(plan, select_exprs)
};
self.project(&plan, select_exprs_post_aggr, false)
}
/// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions.
///
/// Wildcards are expanded into the concrete list of columns.
fn prepare_select_exprs(
    &self,
    plan: &LogicalPlan,
    projection: &[SelectItem],
) -> Result<Vec<Expr>> {
    let input_schema = plan.schema();

    // first convert every SELECT item (failing fast on the first error),
    // then expand any wildcards against the input schema
    Ok(projection
        .iter()
        .map(|expr| self.sql_select_to_rex(expr, &input_schema))
        .collect::<Result<Vec<Expr>>>()?
        .into_iter()
        .flat_map(|expr| expand_wildcard(&expr, &input_schema))
        .collect::<Vec<Expr>>())
}
/// Wrap a plan in a projection
///
/// If the `force` argument is `false`, the projection is applied only when
/// necessary, i.e., when the input fields are different than the
/// projection. Note that if the input fields are the same, but out of
/// order, the projection will be applied.
fn project(
    &self,
    input: &LogicalPlan,
    expr: Vec<Expr>,
    force: bool,
) -> Result<LogicalPlan> {
    self.validate_schema_satisfies_exprs(&input.schema(), &expr)?;
    let plan = LogicalPlanBuilder::from(input).project(expr)?.build()?;

    // a table scan is always projected; otherwise only project when the
    // field list actually changes
    let needed = match input {
        _ if force => true,
        LogicalPlan::TableScan { .. } => true,
        _ => plan.schema().fields() != input.schema().fields(),
    };

    if needed {
        Ok(plan)
    } else {
        Ok(input.clone())
    }
}
/// Create an aggregate plan and rewrite the SELECT expressions against it.
///
/// Returns the aggregate plan together with the SELECT expressions
/// rewritten to reference the columns the aggregation produces. Fails if
/// the projection references values that are neither grouping keys nor
/// aggregates.
fn aggregate(
    &self,
    input: &LogicalPlan,
    select_exprs: &[Expr],
    group_by: &[SQLExpr],
    aggr_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>)> {
    let group_by_exprs = group_by
        .iter()
        .map(|e| self.sql_to_rex(e, &input.schema()))
        .collect::<Result<Vec<Expr>>>()?;

    // the columns the aggregation outputs: grouping keys, then aggregates
    let aggr_projection_exprs = group_by_exprs
        .iter()
        .chain(aggr_exprs.iter())
        .cloned()
        .collect::<Vec<Expr>>();

    let plan = LogicalPlanBuilder::from(&input)
        .aggregate(group_by_exprs, aggr_exprs.to_vec())?
        .build()?;

    // After aggregation, these are all of the columns that will be
    // available to next phases of planning.
    let column_exprs_post_aggr = aggr_projection_exprs
        .iter()
        .map(|expr| expr_as_column_expr(expr, input))
        .collect::<Result<Vec<Expr>>>()?;

    // Rewrite the SELECT expression to use the columns produced by the
    // aggregation.
    let select_exprs_post_aggr = select_exprs
        .iter()
        .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input))
        .collect::<Result<Vec<Expr>>>()?;

    if !can_columns_satisfy_exprs(&column_exprs_post_aggr, &select_exprs_post_aggr)? {
        return Err(DataFusionError::Plan(
            "Projection references non-aggregate values".to_owned(),
        ));
    }

    Ok((plan, select_exprs_post_aggr))
}
/// Wrap a plan in a limit
fn limit(&self, input: &LogicalPlan, limit: &Option<SQLExpr>) -> Result<LogicalPlan> {
    // no LIMIT clause: pass the input plan through unchanged
    let limit_expr = match limit {
        Some(expr) => expr,
        None => return Ok(input.clone()),
    };

    // LIMIT must evaluate to a literal integer
    let n = match self.sql_to_rex(limit_expr, &input.schema())? {
        Expr::Literal(ScalarValue::Int64(Some(n))) => n as usize,
        _ => {
            return Err(DataFusionError::Plan(
                "Unexpected expression for LIMIT clause".to_string(),
            ))
        }
    };

    LogicalPlanBuilder::from(&input).limit(n)?.build()
}
/// Wrap the logical plan in a sort
fn order_by(
    &self,
    plan: &LogicalPlan,
    order_by: &[OrderByExpr],
) -> Result<LogicalPlan> {
    // no ORDER BY clause: pass the plan through unchanged
    if order_by.is_empty() {
        return Ok(plan.clone());
    }

    let input_schema = plan.schema();
    let order_by_rex: Result<Vec<Expr>> = order_by
        .iter()
        .map(|e| {
            Ok(Expr::Sort {
                expr: Box::new(self.sql_to_rex(&e.expr, &input_schema)?),
                // by default asc
                asc: e.asc.unwrap_or(true),
                // by default nulls first to be consistent with spark
                nulls_first: e.nulls_first.unwrap_or(true),
            })
        })
        .collect();

    LogicalPlanBuilder::from(&plan).sort(order_by_rex?)?.build()
}
/// Validate the schema provides all of the columns referenced in the expressions.
fn validate_schema_satisfies_exprs(
    &self,
    schema: &DFSchema,
    exprs: &[Expr],
) -> Result<()> {
    find_column_exprs(exprs)
        .iter()
        .try_for_each(|col| match col {
            Expr::Column(name) => {
                schema.field_with_unqualified_name(name).map_err(|_| {
                    DataFusionError::Plan(format!(
                        "Invalid identifier '{}' for schema {}",
                        name,
                        schema.to_string()
                    ))
                })?;
                Ok(())
            }
            // find_column_exprs is expected to return only column
            // expressions; anything else is an internal invariant failure
            _ => Err(DataFusionError::Internal("Not a column".to_string())),
        })
}
/// Generate a relational expression from a select SQL expression
fn sql_select_to_rex(&self, sql: &SelectItem, schema: &DFSchema) -> Result<Expr> {
    match sql {
        // bare expression: SELECT a + b
        SelectItem::UnnamedExpr(expr) => self.sql_to_rex(expr, schema),
        // aliased expression: SELECT a + b AS c
        SelectItem::ExprWithAlias { expr, alias } => {
            let rex = self.sql_to_rex(expr, schema)?;
            Ok(Alias(Box::new(rex), alias.value.clone()))
        }
        // SELECT *
        SelectItem::Wildcard => Ok(Expr::Wildcard),
        // SELECT table.*
        SelectItem::QualifiedWildcard(_) => Err(DataFusionError::NotImplemented(
            "Qualified wildcards are not supported".to_string(),
        )),
    }
}
/// Generate a relational expression from a SQL expression
///
/// Unlike `sql_expr_to_logical_expr`, the result is validated against
/// `schema`: every column it references must resolve to a field.
pub fn sql_to_rex(&self, sql: &SQLExpr, schema: &DFSchema) -> Result<Expr> {
let expr = self.sql_expr_to_logical_expr(sql)?;
// reject references to columns that do not exist in the input schema
self.validate_schema_satisfies_exprs(schema, &vec![expr.clone()])?;
Ok(expr)
}
/// Generate a logical expression from a SQL function argument
fn sql_fn_arg_to_logical_expr(&self, sql: &FunctionArg) -> Result<Expr> {
    // for named notation (name => value) only the value is used
    let value = match sql {
        FunctionArg::Named { name: _, arg } => arg,
        FunctionArg::Unnamed(value) => value,
    };
    self.sql_expr_to_logical_expr(value)
}
/// Generate a logical expression from a SQL expression
///
/// Column references are NOT validated here; `sql_to_rex` is the
/// schema-validated entry point.
fn sql_expr_to_logical_expr(&self, sql: &SQLExpr) -> Result<Expr> {
    match sql {
        // numeric literal: prefer i64, fall back to f64; report an error
        // (instead of panicking) if the token parses as neither
        SQLExpr::Value(Value::Number(n)) => match n.parse::<i64>() {
            Ok(n) => Ok(lit(n)),
            Err(_) => n.parse::<f64>().map(|f| lit(f)).map_err(|_| {
                DataFusionError::Internal(format!(
                    "Cannot parse '{}' as i64 or f64",
                    n
                ))
            }),
        },

        SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),

        // SQL NULL is represented as a null Utf8 scalar
        SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),

        SQLExpr::Identifier(ref id) => {
            // identifiers starting with '@' denote scalar variables;
            // `starts_with` avoids panicking on an empty identifier
            if id.value.starts_with('@') {
                let var_names = vec![id.value.clone()];
                Ok(Expr::ScalarVariable(var_names))
            } else {
                Ok(Expr::Column(id.value.to_string()))
            }
        }

        SQLExpr::CompoundIdentifier(ids) => {
            let var_names: Vec<String> =
                ids.iter().map(|id| id.value.clone()).collect();
            // only variable references (e.g. @name.part) are supported;
            // qualified column names (table.column) are not implemented.
            // an empty identifier list falls through to the error below.
            if var_names
                .first()
                .map(|name| name.starts_with('@'))
                .unwrap_or(false)
            {
                Ok(Expr::ScalarVariable(var_names))
            } else {
                Err(DataFusionError::NotImplemented(format!(
                    "Unsupported compound identifier '{:?}'",
                    var_names,
                )))
            }
        }

        SQLExpr::Wildcard => Ok(Expr::Wildcard),

        // CASE [operand] WHEN cond THEN res ... [ELSE res] END
        SQLExpr::Case {
            operand,
            conditions,
            results,
            else_result,
        } => {
            let expr = if let Some(e) = operand {
                Some(Box::new(self.sql_expr_to_logical_expr(e)?))
            } else {
                None
            };
            let when_expr = conditions
                .iter()
                .map(|e| self.sql_expr_to_logical_expr(e))
                .collect::<Result<Vec<_>>>()?;
            let then_expr = results
                .iter()
                .map(|e| self.sql_expr_to_logical_expr(e))
                .collect::<Result<Vec<_>>>()?;
            let else_expr = if let Some(e) = else_result {
                Some(Box::new(self.sql_expr_to_logical_expr(e)?))
            } else {
                None
            };

            Ok(Expr::Case {
                expr,
                // pair each WHEN with its THEN positionally (zip would
                // silently truncate on a length mismatch)
                when_then_expr: when_expr
                    .iter()
                    .zip(then_expr.iter())
                    .map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned())))
                    .collect(),
                else_expr,
            })
        }

        SQLExpr::Cast {
            ref expr,
            ref data_type,
        } => Ok(Expr::Cast {
            expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
            data_type: convert_data_type(data_type)?,
        }),

        // e.g. DATE '2020-01-01' is planned as CAST('2020-01-01' AS Date32)
        SQLExpr::TypedString {
            ref data_type,
            ref value,
        } => Ok(Expr::Cast {
            expr: Box::new(lit(&**value)),
            data_type: convert_data_type(data_type)?,
        }),

        SQLExpr::IsNull(ref expr) => {
            Ok(Expr::IsNull(Box::new(self.sql_expr_to_logical_expr(expr)?)))
        }

        SQLExpr::IsNotNull(ref expr) => Ok(Expr::IsNotNull(Box::new(
            self.sql_expr_to_logical_expr(expr)?,
        ))),

        SQLExpr::UnaryOp { ref op, ref expr } => match op {
            UnaryOperator::Not => {
                Ok(Expr::Not(Box::new(self.sql_expr_to_logical_expr(expr)?)))
            }
            // unary plus is a no-op
            UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?),
            UnaryOperator::Minus => {
                match expr.as_ref() {
                    // optimization: if it's a number literal, we apply the negative operator
                    // here directly to calculate the new literal.
                    SQLExpr::Value(Value::Number(n)) => match n.parse::<i64>() {
                        Ok(n) => Ok(lit(-n)),
                        Err(_) => Ok(lit(-n
                            .parse::<f64>()
                            .map_err(|_e| {
                                DataFusionError::Internal(format!(
                                    "negative operator can be only applied to integer and float operands, got: {}",
                                    n))
                            })?)),
                    },
                    // not a literal, apply negative operator on expression
                    _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr)?))),
                }
            }
            _ => Err(DataFusionError::NotImplemented(format!(
                "Unsupported SQL unary operator {:?}",
                op
            ))),
        },

        SQLExpr::Between {
            ref expr,
            ref negated,
            ref low,
            ref high,
        } => Ok(Expr::Between {
            expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
            negated: *negated,
            low: Box::new(self.sql_expr_to_logical_expr(low)?),
            high: Box::new(self.sql_expr_to_logical_expr(high)?),
        }),

        SQLExpr::InList {
            ref expr,
            ref list,
            ref negated,
        } => {
            let list_expr = list
                .iter()
                .map(|e| self.sql_expr_to_logical_expr(e))
                .collect::<Result<Vec<_>>>()?;

            Ok(Expr::InList {
                expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
                list: list_expr,
                negated: *negated,
            })
        }

        SQLExpr::BinaryOp {
            ref left,
            ref op,
            ref right,
        } => {
            // map the SQL operator onto the DataFusion operator enum
            let operator = match *op {
                BinaryOperator::Gt => Ok(Operator::Gt),
                BinaryOperator::GtEq => Ok(Operator::GtEq),
                BinaryOperator::Lt => Ok(Operator::Lt),
                BinaryOperator::LtEq => Ok(Operator::LtEq),
                BinaryOperator::Eq => Ok(Operator::Eq),
                BinaryOperator::NotEq => Ok(Operator::NotEq),
                BinaryOperator::Plus => Ok(Operator::Plus),
                BinaryOperator::Minus => Ok(Operator::Minus),
                BinaryOperator::Multiply => Ok(Operator::Multiply),
                BinaryOperator::Divide => Ok(Operator::Divide),
                BinaryOperator::Modulus => Ok(Operator::Modulus),
                BinaryOperator::And => Ok(Operator::And),
                BinaryOperator::Or => Ok(Operator::Or),
                BinaryOperator::Like => Ok(Operator::Like),
                BinaryOperator::NotLike => Ok(Operator::NotLike),
                _ => Err(DataFusionError::NotImplemented(format!(
                    "Unsupported SQL binary operator {:?}",
                    op
                ))),
            }?;

            Ok(Expr::BinaryExpr {
                left: Box::new(self.sql_expr_to_logical_expr(left)?),
                op: operator,
                right: Box::new(self.sql_expr_to_logical_expr(right)?),
            })
        }

        SQLExpr::Function(function) => {
            let name: String = function.name.to_string();

            // resolution order: scalar built-in, aggregate built-in,
            // scalar UDF, then UDAF
            if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {
                let args = function
                    .args
                    .iter()
                    .map(|a| self.sql_fn_arg_to_logical_expr(a))
                    .collect::<Result<Vec<Expr>>>()?;

                return Ok(Expr::ScalarFunction { fun, args });
            };

            if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) {
                // COUNT(*) and COUNT(<number literal>) are rewritten to
                // COUNT of the constant 1
                let args = if fun == aggregates::AggregateFunction::Count {
                    function
                        .args
                        .iter()
                        .map(|a| match a {
                            FunctionArg::Unnamed(SQLExpr::Value(Value::Number(
                                _,
                            ))) => Ok(lit(1_u8)),
                            FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)),
                            _ => self.sql_fn_arg_to_logical_expr(a),
                        })
                        .collect::<Result<Vec<Expr>>>()?
                } else {
                    function
                        .args
                        .iter()
                        .map(|a| self.sql_fn_arg_to_logical_expr(a))
                        .collect::<Result<Vec<Expr>>>()?
                };

                return Ok(Expr::AggregateFunction {
                    fun,
                    distinct: function.distinct,
                    args,
                });
            };

            // finally, user-defined functions (UDF) and UDAF
            match self.schema_provider.get_function_meta(&name) {
                Some(fm) => {
                    let args = function
                        .args
                        .iter()
                        .map(|a| self.sql_fn_arg_to_logical_expr(a))
                        .collect::<Result<Vec<Expr>>>()?;

                    Ok(Expr::ScalarUDF { fun: fm, args })
                }
                None => match self.schema_provider.get_aggregate_meta(&name) {
                    Some(fm) => {
                        let args = function
                            .args
                            .iter()
                            .map(|a| self.sql_fn_arg_to_logical_expr(a))
                            .collect::<Result<Vec<Expr>>>()?;

                        Ok(Expr::AggregateUDF { fun: fm, args })
                    }
                    _ => Err(DataFusionError::Plan(format!(
                        "Invalid function '{}'",
                        name
                    ))),
                },
            }
        }

        // parenthesized expression: the grouping has no semantic effect
        SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e),

        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported ast node {:?} in sqltorel",
            sql
        ))),
    }
}
}
/// Remove join expressions from a filter expression
///
/// Returns `None` when the entire expression consisted of join predicates.
fn remove_join_expressions(
    expr: &Expr,
    join_columns: &[(&str, &str)],
) -> Result<Option<Expr>> {
    if let Expr::BinaryExpr { left, op, right } = expr {
        match op {
            Operator::Eq => {
                // drop `l = r` when (l, r) is a known join key pair,
                // accepting either column order
                if let (Expr::Column(l), Expr::Column(r)) =
                    (left.as_ref(), right.as_ref())
                {
                    let pair = (l.as_str(), r.as_str());
                    let flipped = (r.as_str(), l.as_str());
                    if join_columns.contains(&pair) || join_columns.contains(&flipped)
                    {
                        return Ok(None);
                    }
                }
            }
            Operator::And => {
                // recurse into conjunctions, re-joining whatever survives
                let lhs = remove_join_expressions(left, join_columns)?;
                let rhs = remove_join_expressions(right, join_columns)?;
                return Ok(match (lhs, rhs) {
                    (Some(a), Some(b)) => Some(and(a, b)),
                    (Some(a), None) => Some(a),
                    (None, Some(b)) => Some(b),
                    (None, None) => None,
                });
            }
            _ => {}
        }
    }
    // any other expression is kept as-is
    Ok(Some(expr.clone()))
}
/// Parse equijoin ON condition which could be a single Eq or multiple conjunctive Eqs
///
/// Examples
///
/// foo = bar
/// foo = bar AND bar = baz AND ...
///
fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<()> {
match expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
accum.push((l.to_owned(), r.to_owned()));
Ok(())
}
other => Err(DataFusionError::SQL(ParserError(format!(
"Unsupported expression '{:?}' in JOIN condition",
other
)))),
},
Operator::And => {
extract_join_keys(left, accum)?;
extract_join_keys(right, accum)
}
other => Err(DataFusionError::SQL(ParserError(format!(
"Unsupported expression '{:?}' in JOIN condition",
other
)))),
},
other => Err(DataFusionError::SQL(ParserError(format!(
"Unsupported expression '{:?}' in JOIN condition",
other
)))),
}
}
/// Extract join keys from a WHERE clause
///
/// Collects every `col = col` term reachable through conjunctions;
/// anything else is silently ignored (it stays in the filter).
fn extract_possible_join_keys(
    expr: &Expr,
    accum: &mut Vec<(String, String)>,
) -> Result<()> {
    if let Expr::BinaryExpr { left, op, right } = expr {
        match op {
            Operator::Eq => {
                if let (Expr::Column(l), Expr::Column(r)) =
                    (left.as_ref(), right.as_ref())
                {
                    accum.push((l.to_owned(), r.to_owned()));
                }
            }
            Operator::And => {
                extract_possible_join_keys(left, accum)?;
                extract_possible_join_keys(right, accum)?;
            }
            _ => {}
        }
    }
    Ok(())
}
/// Convert SQL data type to relational representation of data type
pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
    let data_type = match sql {
        SQLDataType::Boolean => DataType::Boolean,
        SQLDataType::SmallInt => DataType::Int16,
        SQLDataType::Int => DataType::Int32,
        SQLDataType::BigInt => DataType::Int64,
        // all floating-point SQL types map to 64-bit floats
        SQLDataType::Float(_) | SQLDataType::Real => DataType::Float64,
        SQLDataType::Double => DataType::Float64,
        SQLDataType::Char(_) | SQLDataType::Varchar(_) => DataType::Utf8,
        SQLDataType::Timestamp => DataType::Timestamp(TimeUnit::Nanosecond, None),
        SQLDataType::Date => DataType::Date32(DateUnit::Day),
        other => {
            return Err(DataFusionError::NotImplemented(format!(
                "Unsupported SQL type {:?}",
                other
            )))
        }
    };
    Ok(data_type)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::datasource::empty::EmptyTable;
use crate::{logical_plan::create_udf, sql::parser::DFParser};
use functions::ScalarFunctionImplementation;
const PERSON_COLUMN_NAMES: &str =
"id, first_name, last_name, age, state, salary, birth_date";
#[test]
fn select_no_relation() {
quick_test(
"SELECT 1",
"Projection: Int64(1)\
\n EmptyRelation",
);
}
#[test]
fn select_column_does_not_exist() {
let sql = "SELECT doesnotexist FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
format!(
"Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")",
PERSON_COLUMN_NAMES
),
format!("{:?}", err)
);
}
#[test]
fn select_repeated_column() {
let sql = "SELECT age, age FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"Plan(\"Projections require unique expression names but the expression \\\"#age\\\" at position 0 and \\\"#age\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")",
format!("{:?}", err)
);
}
#[test]
fn select_wildcard_with_repeated_column() {
let sql = "SELECT *, age FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"Plan(\"Projections require unique expression names but the expression \\\"#age\\\" at position 3 and \\\"#age\\\" at position 7 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")",
format!("{:?}", err)
);
}
#[test]
fn select_wildcard_with_repeated_column_but_is_aliased() {
quick_test(
"SELECT *, first_name AS fn from person",
"Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #first_name AS fn\
\n TableScan: person projection=None",
);
}
#[test]
fn select_scalar_func_with_literal_no_relation() {
quick_test(
"SELECT sqrt(9)",
"Projection: sqrt(Int64(9))\
\n EmptyRelation",
);
}
#[test]
fn select_simple_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO'";
let expected = "Projection: #id, #first_name, #last_name\
\n Filter: #state Eq Utf8(\"CO\")\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_filter_column_does_not_exist() {
let sql = "SELECT first_name FROM person WHERE doesnotexist = 'A'";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
format!(
"Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")",
PERSON_COLUMN_NAMES
),
format!("{:?}", err)
);
}
#[test]
fn select_filter_cannot_use_alias() {
let sql = "SELECT first_name AS x FROM person WHERE x = 'A'";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
format!(
"Plan(\"Invalid identifier \\\'x\\\' for schema {}\")",
PERSON_COLUMN_NAMES
),
format!("{:?}", err)
);
}
#[test]
fn select_neg_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE NOT state";
let expected = "Projection: #id, #first_name, #last_name\
\n Filter: NOT #state\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_compound_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65";
let expected = "Projection: #id, #first_name, #last_name\
\n Filter: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And #age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn test_timestamp_filter() {
let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)";
let expected = "Projection: #state\
\n Filter: #birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn test_date_filter() {
let sql =
"SELECT state FROM person WHERE birth_date < CAST ('2020-01-01' as date)";
let expected = "Projection: #state\
\n Filter: #birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32(Day))\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_all_boolean_operators() {
let sql = "SELECT age, first_name, last_name \
FROM person \
WHERE age = 21 \
AND age != 21 \
AND age > 21 \
AND age >= 21 \
AND age < 65 \
AND age <= 65";
let expected = "Projection: #age, #first_name, #last_name\
\n Filter: #age Eq Int64(21) \
And #age NotEq Int64(21) \
And #age Gt Int64(21) \
And #age GtEq Int64(21) \
And #age Lt Int64(65) \
And #age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_between() {
let sql = "SELECT state FROM person WHERE age BETWEEN 21 AND 65";
let expected = "Projection: #state\
\n Filter: #age BETWEEN Int64(21) AND Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_between_negated() {
let sql = "SELECT state FROM person WHERE age NOT BETWEEN 21 AND 65";
let expected = "Projection: #state\
\n Filter: #age NOT BETWEEN Int64(21) AND Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_nested() {
let sql = "SELECT fn2, last_name
FROM (
SELECT fn1 as fn2, last_name, birth_date
FROM (
SELECT first_name AS fn1, last_name, birth_date, age
FROM person
)
)";
let expected = "Projection: #fn2, #last_name\
\n Projection: #fn1 AS fn2, #last_name, #birth_date\
\n Projection: #first_name AS fn1, #last_name, #birth_date, #age\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_nested_with_filters() {
let sql = "SELECT fn1, age
FROM (
SELECT first_name AS fn1, age
FROM person
WHERE age > 20
)
WHERE fn1 = 'X' AND age < 30";
let expected = "Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\
\n Projection: #first_name AS fn1, #age\
\n Filter: #age Gt Int64(20)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
fn select_binary_expr() {
    // Arithmetic in the select list becomes a projection expression.
    quick_test(
        "SELECT age + salary from person",
        "Projection: #age Plus #salary\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_binary_expr_nested() {
    // Parenthesised arithmetic keeps its operator structure.
    quick_test(
        "SELECT (age + salary)/2 from person",
        "Projection: #age Plus #salary Divide Int64(2)\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_wildcard_with_groupby() {
    // `SELECT *` folds into the aggregate when every column is a group key.
    let sql = "SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date";
    let expected = "Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date]], aggr=[[]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
    // Same behaviour when the input is a subquery projection.
    let sql = "SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name";
    let expected = "Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\
    \n Projection: #first_name, #last_name\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate() {
    // A lone aggregate produces an Aggregate node with no group-by keys.
    let sql = "SELECT MIN(age) FROM person";
    let expected = "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn test_sum_aggregate() {
    // SUM plans the same way as any other ungrouped aggregate.
    let sql = "SELECT SUM(age) from person";
    let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_column_does_not_exist() {
    // Aggregating an unknown column must fail identifier resolution.
    let err = logical_plan("SELECT MIN(doesnotexist) FROM person")
        .expect_err("query should have failed");
    let expected = format!(
        "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")",
        PERSON_COLUMN_NAMES
    );
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_repeated_aggregate() {
    // Duplicate, unaliased aggregates collide on the output name.
    let err = logical_plan("SELECT MIN(age), MIN(age) FROM person")
        .expect_err("query should have failed");
    let expected = "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age)\\\" at position 0 and \\\"#MIN(age)\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")";
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_repeated_aggregate_with_single_alias() {
    // Aliasing one of the duplicates disambiguates the output names.
    let sql = "SELECT MIN(age), MIN(age) AS a FROM person";
    let expected = "Projection: #MIN(age), #MIN(age) AS a\
    \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() {
    // Distinct aliases allow the same aggregate expression twice.
    let sql = "SELECT MIN(age) AS a, MIN(age) AS b FROM person";
    let expected = "Projection: #MIN(age) AS a, #MIN(age) AS b\
    \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_repeated_aggregate_with_repeated_aliases() {
    // Identical aliases still collide even though the exprs are aliased.
    let err = logical_plan("SELECT MIN(age) AS a, MIN(age) AS a FROM person")
        .expect_err("query should have failed");
    let expected = "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age) AS a\\\" at position 0 and \\\"#MIN(age) AS a\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")";
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby() {
    // Group key plus aggregates in select order need no extra projection.
    let sql = "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state";
    let expected = "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_with_aliases() {
    // Aliases force a projection above the aggregate.
    let sql = "SELECT state AS a, MIN(age) AS b FROM person GROUP BY state";
    let expected = "Projection: #state AS a, #MIN(age) AS b\
    \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_with_aliases_repeated() {
    // Reusing the same alias for key and aggregate is rejected.
    let err =
        logical_plan("SELECT state AS a, MIN(age) AS a FROM person GROUP BY state")
            .expect_err("query should have failed");
    let expected = "Plan(\"Projections require unique expression names but the expression \\\"#state AS a\\\" at position 0 and \\\"#MIN(age) AS a\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")";
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby_column_unselected() {
    // The unselected group key is projected away above the aggregate.
    let sql = "SELECT MIN(age), MAX(age) FROM person GROUP BY state";
    let expected = "Projection: #MIN(age), #MAX(age)\
    \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() {
    // An unknown GROUP BY column must fail identifier resolution.
    let err = logical_plan("SELECT SUM(age) FROM person GROUP BY doesnotexist")
        .expect_err("query should have failed");
    let expected = format!(
        "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")",
        PERSON_COLUMN_NAMES
    );
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby_and_column_in_aggregate_does_not_exist() {
    // An unknown column inside the aggregate must fail resolution too.
    let err = logical_plan("SELECT SUM(doesnotexist) FROM person GROUP BY first_name")
        .expect_err("query should have failed");
    let expected = format!(
        "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")",
        PERSON_COLUMN_NAMES
    );
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
    // A column may serve as both group key and aggregate input.
    let sql = "SELECT MAX(first_name) FROM person GROUP BY first_name";
    let expected = "Projection: #MAX(first_name)\
    \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_cannot_use_alias() {
    // GROUP BY cannot reference a select-list alias.
    let err = logical_plan("SELECT state AS x, MAX(age) FROM person GROUP BY x")
        .expect_err("query should have failed");
    let expected = format!(
        "Plan(\"Invalid identifier \\\'x\\\' for schema {}\")",
        PERSON_COLUMN_NAMES
    );
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby_aggregate_repeated() {
    // Duplicate aggregates clash even alongside a group key.
    let err =
        logical_plan("SELECT state, MIN(age), MIN(age) FROM person GROUP BY state")
            .expect_err("query should have failed");
    let expected = "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age)\\\" at position 1 and \\\"#MIN(age)\\\" at position 2 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")";
    assert_eq!(expected, format!("{:?}", err));
}
#[test]
fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() {
    // One alias is enough to make the duplicate aggregate names unique.
    let sql = "SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state";
    let expected = "Projection: #state, #MIN(age), #MIN(age) AS ma\
    \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_non_column_expression_unselected() {
    // An expression group key that is not selected is projected away.
    let sql = "SELECT MIN(first_name) FROM person GROUP BY age + 1";
    let expected = "Projection: #MIN(first_name)\
    \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_non_column_expression_selected_and_resolvable(
) {
    // Selecting the grouping expression verbatim needs no projection.
    let sql = "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1";
    let expected = "Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
    // Reordering the select list forces a projection.
    let sql = "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1";
    let expected = "Projection: #MIN(first_name), #age Plus Int64(1)\
    \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_resolvable()
{
    // Arbitrary expressions over the grouping expression are resolvable.
    let sql = "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1";
    let expected = "Projection: #age Plus Int64(1) Divide Int64(2) Multiply #age Plus Int64(1), #MIN(first_name)\
    \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_not_resolvable(
) {
    // `age + 9` cannot be derived from the `age + 1` group key, so this fails.
    let err = logical_plan(
        "SELECT ((age + 1) / 2) * (age + 9), MIN(first_name) FROM person GROUP BY age + 1",
    )
    .expect_err("query should have failed");
    assert_eq!(
        format!("{:?}", err),
        "Plan(\"Projection references non-aggregate values\")"
    );
}
#[test]
fn select_simple_aggregate_with_groupby_non_column_expression_and_its_column_selected(
) {
    // Selecting the bare column is invalid when grouping on `age + 1`.
    let err = logical_plan("SELECT age, MIN(first_name) FROM person GROUP BY age + 1")
        .expect_err("query should have failed");
    assert_eq!(
        format!("{:?}", err),
        "Plan(\"Projection references non-aggregate values\")"
    );
}
#[test]
fn select_simple_aggregate_nested_in_binary_expr_with_groupby() {
    // Aggregates may appear inside larger select-list expressions.
    let sql = "SELECT state, MIN(age) < 10 FROM person GROUP BY state";
    let expected = "Projection: #state, #MIN(age) Lt Int64(10)\
    \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_simple_aggregate_and_nested_groupby_column() {
    // Expressions over a plain group key are computed in a projection.
    let sql = "SELECT age + 1, MAX(first_name) FROM person GROUP BY age";
    let expected = "Projection: #age Plus Int64(1), #MAX(first_name)\
    \n Aggregate: groupBy=[[#age]], aggr=[[MAX(#first_name)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_aggregate_compounded_with_groupby_column() {
    // A group key and an aggregate can be combined arithmetically.
    let sql = "SELECT age + MIN(salary) FROM person GROUP BY age";
    let expected = "Projection: #age Plus #MIN(salary)\
    \n Aggregate: groupBy=[[#age]], aggr=[[MIN(#salary)]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_aggregate_with_non_column_inner_expression_with_groupby() {
    // The aggregate argument may itself be an expression.
    let sql = "SELECT state, MIN(age + 1) FROM person GROUP BY state";
    let expected = "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age Plus Int64(1))]]\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn test_wildcard() {
    // `SELECT *` expands to every schema column in declaration order.
    let sql = "SELECT * from person";
    let expected = "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_count_one() {
    // COUNT(1) is normalised to COUNT(UInt8(1)).
    quick_test(
        "SELECT COUNT(1) FROM person",
        "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_count_column() {
    // COUNT over a column keeps the column reference.
    quick_test(
        "SELECT COUNT(id) FROM person",
        "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_scalar_func() {
    // Scalar functions appear by name in the projection.
    quick_test(
        "SELECT sqrt(age) FROM person",
        "Projection: sqrt(#age)\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_aliased_scalar_func() {
    // Aliases on scalar function calls are preserved.
    quick_test(
        "SELECT sqrt(age) AS square_people FROM person",
        "Projection: sqrt(#age) AS square_people\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_where_nullif_division() {
    // nullif appears only where the query used it: in the WHERE clause.
    quick_test(
        "SELECT c3/(c4+c5) \
         FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1",
        "Projection: #c3 Divide #c4 Plus #c5\
        \n Filter: #c3 Divide nullif(#c4 Plus #c5, Int64(0)) Gt Float64(0.1)\
        \n TableScan: aggregate_test_100 projection=None",
    );
}
#[test]
fn select_where_with_negative_operator() {
    // Unary minus is kept as an explicit negation expression.
    quick_test(
        "SELECT c3 FROM aggregate_test_100 WHERE c3 > -0.1 AND -c4 > 0",
        "Projection: #c3\
        \n Filter: #c3 Gt Float64(-0.1) And (- #c4) Gt Int64(0)\
        \n TableScan: aggregate_test_100 projection=None",
    );
}
#[test]
fn select_where_with_positive_operator() {
    // Unary plus is dropped during planning.
    quick_test(
        "SELECT c3 FROM aggregate_test_100 WHERE c3 > +0.1 AND +c4 > 0",
        "Projection: #c3\
        \n Filter: #c3 Gt Float64(0.1) And #c4 Gt Int64(0)\
        \n TableScan: aggregate_test_100 projection=None",
    );
}
#[test]
fn select_order_by() {
    // ORDER BY defaults to ascending with nulls first.
    quick_test(
        "SELECT id FROM person ORDER BY id",
        "Sort: #id ASC NULLS FIRST\
        \n Projection: #id\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_order_by_desc() {
    // DESC still defaults to nulls first.
    quick_test(
        "SELECT id FROM person ORDER BY id DESC",
        "Sort: #id DESC NULLS FIRST\
        \n Projection: #id\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_order_by_nulls_last() {
    // NULLS LAST is honoured with an explicit direction...
    let sql = "SELECT id FROM person ORDER BY id DESC NULLS LAST";
    let expected = "Sort: #id DESC NULLS LAST\
    \n Projection: #id\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
    // ...and with the implicit ascending default.
    let sql = "SELECT id FROM person ORDER BY id NULLS LAST";
    let expected = "Sort: #id ASC NULLS LAST\
    \n Projection: #id\
    \n TableScan: person projection=None";
    quick_test(sql, expected);
}
#[test]
fn select_group_by() {
    // GROUP BY with no aggregates still builds an Aggregate node.
    quick_test(
        "SELECT state FROM person GROUP BY state",
        "Aggregate: groupBy=[[#state]], aggr=[[]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_group_by_columns_not_in_select() {
    // Group keys absent from the select list are projected away.
    quick_test(
        "SELECT MAX(age) FROM person GROUP BY state",
        "Projection: #MAX(age)\
        \n Aggregate: groupBy=[[#state]], aggr=[[MAX(#age)]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_group_by_count_star() {
    // COUNT(*) is rewritten to COUNT(UInt8(1)).
    quick_test(
        "SELECT state, COUNT(*) FROM person GROUP BY state",
        "Aggregate: groupBy=[[#state]], aggr=[[COUNT(UInt8(1))]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_group_by_needs_projection() {
    // Output columns out of aggregate order require a projection on top.
    quick_test(
        "SELECT COUNT(state), state FROM person GROUP BY state",
        "Projection: #COUNT(state), #state\
        \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\
        \n TableScan: person projection=None",
    );
}
#[test]
fn select_7480_1() {
    // Group keys may be a superset of the selected columns.
    quick_test(
        "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13",
        "Projection: #c1, #MIN(c12)\
        \n Aggregate: groupBy=[[#c1, #c13]], aggr=[[MIN(#c12)]]\
        \n TableScan: aggregate_test_100 projection=None",
    );
}
#[test]
fn select_7480_2() {
    // Selecting a column that is neither grouped nor aggregated fails.
    let err =
        logical_plan("SELECT c1, c13, MIN(c12) FROM aggregate_test_100 GROUP BY c1")
            .expect_err("query should have failed");
    assert_eq!(
        format!("{:?}", err),
        "Plan(\"Projection references non-aggregate values\")"
    );
}
#[test]
fn create_external_table_csv() {
    // CSV external tables with column definitions plan successfully.
    quick_test(
        "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'",
        "CreateExternalTable: \"t\"",
    );
}
#[test]
fn create_external_table_csv_no_schema() {
    // CSV requires explicit column definitions.
    let err = logical_plan("CREATE EXTERNAL TABLE t STORED AS CSV LOCATION 'foo.csv'")
        .expect_err("query should have failed");
    assert_eq!(
        format!("{:?}", err),
        "Plan(\"Column definitions required for CSV files. None found\")"
    );
}
#[test]
fn create_external_table_parquet() {
    // Column definitions on a Parquet external table are rejected.
    let err = logical_plan(
        "CREATE EXTERNAL TABLE t(c1 int) STORED AS PARQUET LOCATION 'foo.parquet'",
    )
    .expect_err("query should have failed");
    assert_eq!(
        format!("{:?}", err),
        "Plan(\"Column definitions can not be specified for PARQUET files.\")"
    );
}
#[test]
fn create_external_table_parquet_no_schema() {
    // Parquet external tables may omit the column list.
    quick_test(
        "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'",
        "CreateExternalTable: \"t\"",
    );
}
#[test]
fn equijoin_explicit_syntax() {
    // An explicit JOIN ... ON equality becomes a Join node over both scans.
    quick_test(
        "SELECT id, order_id \
         FROM person \
         JOIN orders \
         ON id = customer_id",
        "Projection: #id, #order_id\
        \n Join: id = customer_id\
        \n TableScan: person projection=None\
        \n TableScan: orders projection=None",
    );
}
#[test]
fn equijoin_explicit_syntax_3_tables() {
    // Chained joins nest left-deep in the plan.
    quick_test(
        "SELECT id, order_id, l_description \
         FROM person \
         JOIN orders ON id = customer_id \
         JOIN lineitem ON o_item_id = l_item_id",
        "Projection: #id, #order_id, #l_description\
        \n Join: o_item_id = l_item_id\
        \n Join: id = customer_id\
        \n TableScan: person projection=None\
        \n TableScan: orders projection=None\
        \n TableScan: lineitem projection=None",
    );
}
#[test]
fn select_typedstring() {
    // A typed string literal becomes an explicit CAST.
    quick_test(
        "SELECT date '2020-12-10' AS date FROM person",
        "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day)) AS date\
        \n TableScan: person projection=None",
    );
}
/// Parse `sql` with the DataFusion dialect and plan it against the mock
/// catalog.
///
/// Panics if the SQL fails to parse (a test-authoring error); planning
/// errors are returned so negative tests can assert on them.
fn logical_plan(sql: &str) -> Result<LogicalPlan> {
    let planner = SqlToRel::new(&MockContextProvider {});
    // `sql` is already `&str`; passing `&sql` was a needless double borrow.
    let ast = DFParser::parse_sql(sql).expect("SQL should parse");
    planner.statement_to_plan(&ast[0])
}
/// Plan `sql`, render the plan with the `Debug` formatter, and compare the
/// result against `expected`; panics if planning fails or the text differs.
fn quick_test(sql: &str, expected: &str) {
    let actual = format!("{:?}", logical_plan(sql).unwrap());
    assert_eq!(expected, actual);
}
/// Stateless catalog stub used by the planner tests; schemas are supplied by
/// its `ContextProvider` impl below.
struct MockContextProvider {}
impl ContextProvider for MockContextProvider {
fn get_table_provider(
&self,
name: &str,
) -> Option<Arc<dyn TableProvider + Send + Sync>> {
let schema = match name {
"person" => Some(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("first_name", DataType::Utf8, false),
Field::new("last_name", DataType::Utf8, false),
Field::new("age", DataType::Int32, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::Float64, false),
Field::new(
"birth_date",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
])),
"orders" => Some(Schema::new(vec![
Field::new("order_id", DataType::UInt32, false),
Field::new("customer_id", DataType::UInt32, false),
Field::new("o_item_id", DataType::Utf8, false),
Field::new("qty", DataType::Int32, false),
Field::new("price", DataType::Float64, false),
])),
"lineitem" => Some(Schema::new(vec![
Field::new("l_item_id", DataType::UInt32, false),
Field::new("l_description", DataType::Utf8, false),
])),
"aggregate_test_100" => Some(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c3", DataType::Int8, false),
Field::new("c4", DataType::Int16, false),
Field::new("c5", DataType::Int32, false),
Field::new("c6", DataType::Int64, false),
Field::new("c7", DataType::UInt8, false),
Field::new("c8", DataType::UInt16, false),
Field::new("c9", DataType::UInt32, false),
Field::new("c10", DataType::UInt64, false),
Field::new("c11", DataType::Float32, false),
Field::new("c12", DataType::Float64, false),
Field::new("c13", DataType::Utf8, false),
])),
_ => None,
};
schema.map(|s| -> Arc<dyn TableProvider + Send + Sync> {
Arc::new(EmptyTable::new(Arc::new(s)))
})
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
let f: ScalarFunctionImplementation =
Arc::new(|_| Err(DataFusionError::NotImplemented("".to_string())));
match name {
"my_sqrt" => Some(Arc::new(create_udf(
"my_sqrt",
vec![DataType::Float64],
Arc::new(DataType::Float64),
f,
))),
_ => None,
}
}
fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
unimplemented!()
}
}
}