// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
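
//! End-to-end test of filter pushdown into a custom `TableProvider`:
//! the provider inspects the pushed-down filter and returns only the
//! batches that can match it, and the resulting row counts are checked
//! through both the DataFrame API and SQL.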
use arrow::array::{as_primitive_array, Int32Builder, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::datasource::datasource::{
    Statistics, TableProvider, TableProviderFilterPushDown,
};
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logical_plan::Expr;
use datafusion::physical_plan::common::SizedRecordBatchStream;
use datafusion::physical_plan::{ExecutionPlan, Partitioning, SendableRecordBatchStream};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use std::sync::Arc;
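
/// Builds a single-column Int32 batch named "flag" containing `value`
/// repeated `num_rows` times.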
fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
    let mut builder = Int32Builder::new(num_rows);
    for _ in 0..num_rows {
        builder.append_value(value)?;
    }

    Ok(RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new(
            "flag",
            DataType::Int32,
            false,
        )])),
        vec![Arc::new(builder.finish())],
    )?)
}
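
/// A leaf `ExecutionPlan` that streams a fixed set of in-memory batches.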
#[derive(Debug)]
struct CustomPlan {
    schema: SchemaRef,
    batches: Vec<Arc<RecordBatch>>,
}

#[async_trait]
impl ExecutionPlan for CustomPlan {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_partitioning(&self) -> Partitioning {
        Partitioning::UnknownPartitioning(1)
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        // Leaf node: no child plans.
        vec![]
    }

    fn with_new_children(
        &self,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Never called, because this plan has no children to replace.
        unreachable!()
    }

    async fn execute(&self, _: usize) -> Result<SendableRecordBatchStream> {
        // Stream the pre-built batches back to the caller.
        Ok(Box::pin(SizedRecordBatchStream::new(
            self.schema(),
            self.batches.clone(),
        )))
    }
}
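
/// A `TableProvider` that serves different pre-built batches depending on
/// the literal value in the pushed-down `flag = <n>` filter.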
#[derive(Clone)]
struct CustomProvider {
    zero_batch: RecordBatch,
    one_batch: RecordBatch,
}

impl TableProvider for CustomProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.zero_batch.schema()
    }

    fn scan(
        &self,
        _: &Option<Vec<usize>>,
        _: usize,
        filters: &[Expr],
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // With `Exact` pushdown, the planner passes the `flag = <n>`
        // predicate here and adds no re-filtering on top.
        match &filters[0] {
            Expr::BinaryExpr { right, .. } => {
                let int_value = match &**right {
                    Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
                    _ => unimplemented!(),
                };

                // Return only the batches whose "flag" value matches the
                // literal; any other value yields no batches at all.
                Ok(Arc::new(CustomPlan {
                    schema: self.zero_batch.schema(),
                    batches: match int_value {
                        0 => vec![Arc::new(self.zero_batch.clone())],
                        1 => vec![Arc::new(self.one_batch.clone())],
                        _ => vec![],
                    },
                }))
            }
            _ => Ok(Arc::new(CustomPlan {
                schema: self.zero_batch.schema(),
                batches: vec![],
            })),
        }
    }

    fn statistics(&self) -> Statistics {
        Statistics::default()
    }

    fn supports_filter_pushdown(&self, _: &Expr) -> Result<TableProviderFilterPushDown> {
        // Declare exact pushdown: the provider fully applies the filter,
        // so DataFusion will not re-filter the returned batches.
        Ok(TableProviderFilterPushDown::Exact)
    }
}
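
/// Queries the provider for `flag = value` through both the DataFrame API
/// and SQL, asserting that each path counts `expected_count` rows.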
async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()> {
    let provider = CustomProvider {
        zero_batch: create_batch(0, 10)?,
        one_batch: create_batch(1, 5)?,
    };

    let mut ctx = ExecutionContext::new();

    // DataFrame API: filter on `flag`, then count the surviving rows.
    let df = ctx
        .read_table(Arc::new(provider.clone()))?
        .filter(col("flag").eq(lit(value)))?
        .aggregate(&[], &[count(col("flag"))])?;

    let results = df.collect().await?;
    let result_col: &UInt64Array = as_primitive_array(results[0].column(0));
    assert_eq!(result_col.value(0), expected_count);

    // SQL path: the same predicate should push down through the planner.
    ctx.register_table("data", Arc::new(provider));
    let sql_results = ctx
        .sql(&format!("select count(*) from data where flag = {}", value))?
        .collect()
        .await?;

    let sql_result_col: &UInt64Array = as_primitive_array(sql_results[0].column(0));
    assert_eq!(sql_result_col.value(0), expected_count);

    Ok(())
}
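
// `flag = 0` hits the 10-row batch, `flag = 1` the 5-row batch, and any
// other value matches nothing.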
#[tokio::test]
async fn test_filter_pushdown_results() -> Result<()> {
    assert_provider_row_count(0, 10).await?;
    assert_provider_row_count(1, 5).await?;
    assert_provider_row_count(2, 0).await?;

    Ok(())
}