blob: 77b0ce2bae194eb5bec45a3c9eeb36a512731e12 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::any::Any;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource};
use pyo3::prelude::*;
use datafusion_optimizer::utils::split_conjunction;
use super::{data_type::DataTypeMap, function::SqlFunction};
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlSchema {
#[pyo3(get, set)]
pub name: String,
#[pyo3(get, set)]
pub tables: Vec<SqlTable>,
#[pyo3(get, set)]
pub views: Vec<SqlView>,
#[pyo3(get, set)]
pub functions: Vec<SqlFunction>,
}
#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlTable {
#[pyo3(get, set)]
pub name: String,
#[pyo3(get, set)]
pub columns: Vec<(String, DataTypeMap)>,
#[pyo3(get, set)]
pub primary_key: Option<String>,
#[pyo3(get, set)]
pub foreign_keys: Vec<String>,
#[pyo3(get, set)]
pub indexes: Vec<String>,
#[pyo3(get, set)]
pub constraints: Vec<String>,
#[pyo3(get, set)]
pub statistics: SqlStatistics,
#[pyo3(get, set)]
pub filepaths: Option<Vec<String>>,
}
#[pymethods]
impl SqlTable {
#[new]
pub fn new(
table_name: String,
columns: Vec<(String, DataTypeMap)>,
row_count: f64,
filepaths: Option<Vec<String>>,
) -> Self {
Self {
name: table_name,
columns,
primary_key: None,
foreign_keys: Vec::new(),
indexes: Vec::new(),
constraints: Vec::new(),
statistics: SqlStatistics::new(row_count),
filepaths,
}
}
}
#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlView {
#[pyo3(get, set)]
pub name: String,
#[pyo3(get, set)]
pub definition: String, // SQL code that defines the view
}
#[pymethods]
impl SqlSchema {
#[new]
pub fn new(schema_name: &str) -> Self {
Self {
name: schema_name.to_owned(),
tables: Vec::new(),
views: Vec::new(),
functions: Vec::new(),
}
}
pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
for tbl in &self.tables {
if tbl.name.eq(table_name) {
return Some(tbl.clone());
}
}
None
}
pub fn add_table(&mut self, table: SqlTable) {
self.tables.push(table);
}
pub fn drop_table(&mut self, table_name: String) {
self.tables.retain(|x| !x.name.eq(&table_name));
}
}
/// SqlTable wrapper that is compatible with DataFusion logical query plans
pub struct SqlTableSource {
schema: SchemaRef,
statistics: Option<SqlStatistics>,
filepaths: Option<Vec<String>>,
}
impl SqlTableSource {
/// Initialize a new `EmptyTable` from a schema
pub fn new(
schema: SchemaRef,
statistics: Option<SqlStatistics>,
filepaths: Option<Vec<String>>,
) -> Self {
Self {
schema,
statistics,
filepaths,
}
}
/// Access optional statistics associated with this table source
pub fn statistics(&self) -> Option<&SqlStatistics> {
self.statistics.as_ref()
}
/// Access optional filepath associated with this table source
#[allow(dead_code)]
pub fn filepaths(&self) -> Option<&Vec<String>> {
self.filepaths.as_ref()
}
}
/// Implement TableSource, used in the logical query plan and in logical query optimizations
impl TableSource for SqlTableSource {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn supports_filter_pushdown(
&self,
filter: &Expr,
) -> datafusion_common::Result<TableProviderFilterPushDown> {
let filters = split_conjunction(filter);
if filters.iter().all(|f| is_supported_push_down_expr(f)) {
// Push down filters to the tablescan operation if all are supported
Ok(TableProviderFilterPushDown::Exact)
} else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
// Partially apply the filter in the TableScan but retain
// the Filter operator in the plan as well
Ok(TableProviderFilterPushDown::Inexact)
} else {
Ok(TableProviderFilterPushDown::Unsupported)
}
}
fn table_type(&self) -> datafusion_expr::TableType {
datafusion_expr::TableType::Base
}
#[allow(deprecated)]
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> datafusion_common::Result<Vec<TableProviderFilterPushDown>> {
filters
.iter()
.map(|f| self.supports_filter_pushdown(f))
.collect()
}
fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
None
}
}
fn is_supported_push_down_expr(_expr: &Expr) -> bool {
// For now we support all kinds of expr's at this level
true
}
#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlStatistics {
row_count: f64,
}
#[pymethods]
impl SqlStatistics {
#[new]
pub fn new(row_count: f64) -> Self {
Self { row_count }
}
#[pyo3(name = "getRowCount")]
pub fn get_row_count(&self) -> f64 {
self.row_count
}
}