// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! # TABLESAMPLE Example
//!
//! This example demonstrates implementing SQL `TABLESAMPLE` support using
//! DataFusion's extensibility APIs.
//!
//! This is a working `TABLESAMPLE` implementation that can serve as a starting
//! point for your own projects. It also works as a template for adding other
//! custom SQL operators, covering the full pipeline from parsing to execution.
//!
//! It shows how to:
//!
//! 1. **Parse** `TABLESAMPLE` syntax via a custom [`RelationPlanner`]
//! 2. **Plan** sampling as a custom logical node ([`TableSamplePlanNode`])
//! 3. **Execute** sampling via a custom physical operator ([`SampleExec`])
//!
//! ## Supported Syntax
//!
//! ```sql
//! -- Bernoulli sampling (each row has an N% chance of selection)
//! SELECT * FROM t TABLESAMPLE BERNOULLI(10 PERCENT)
//!
//! -- Fractional sampling (0.0 to 1.0)
//! SELECT * FROM t TABLESAMPLE (0.1)
//!
//! -- Row count limit
//! SELECT * FROM t TABLESAMPLE (100 ROWS)
//!
//! -- Reproducible sampling with a seed
//! SELECT * FROM t TABLESAMPLE (10 PERCENT) REPEATABLE(42)
//! ```
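//!
//! Each form lowers to a different plan: `N ROWS` (and a bare value >= 1.0)
//! becomes a plain `Limit`, while percentages, bare fractions, and
//! `BUCKET x OUT OF y` become a [`TableSamplePlanNode`] carrying the
//! fraction `N / 100`, `N`, or `x / y` respectively.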
//!
//! ## Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                            SQL Query                            │
//! │ SELECT * FROM t TABLESAMPLE BERNOULLI(10 PERCENT) REPEATABLE(1) │
//! └─────────────────────────────────────────────────────────────────┘
//!                                  │
//!                                  ▼
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                        TableSamplePlanner                       │
//! │   (RelationPlanner: parses TABLESAMPLE, creates logical node)   │
//! └─────────────────────────────────────────────────────────────────┘
//!                                  │
//!                                  ▼
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                        TableSamplePlanNode                      │
//! │         (UserDefinedLogicalNode: stores sampling params)        │
//! └─────────────────────────────────────────────────────────────────┘
//!                                  │
//!                                  ▼
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                    TableSampleExtensionPlanner                  │
//! │       (ExtensionPlanner: creates physical execution plan)       │
//! └─────────────────────────────────────────────────────────────────┘
//!                                  │
//!                                  ▼
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                            SampleExec                           │
//! │     (ExecutionPlan: performs actual row sampling at runtime)    │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
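//!
//! ## Wiring It Up
//!
//! A minimal sketch of the registration steps (it mirrors [`table_sample`]
//! below; `my_table` stands in for any registered table):
//!
//! ```rust,ignore
//! // Physical planning: converts TableSamplePlanNode into SampleExec
//! let state = SessionStateBuilder::new()
//!     .with_default_features()
//!     .with_query_planner(Arc::new(TableSampleQueryPlanner))
//!     .build();
//! let ctx = SessionContext::new_with_state(state);
//! // Logical planning: intercepts TABLESAMPLE while the SQL AST is planned
//! ctx.register_relation_planner(Arc::new(TableSamplePlanner))?;
//! let df = ctx
//!     .sql("SELECT * FROM my_table TABLESAMPLE (0.1) REPEATABLE(42)")
//!     .await?;
//! ```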
use std::{
any::Any,
fmt::{self, Debug, Formatter},
hash::{Hash, Hasher},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use arrow::datatypes::{Float64Type, Int64Type};
use arrow::{
array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array},
compute,
};
use arrow_schema::SchemaRef;
use futures::{
ready,
stream::{Stream, StreamExt},
};
use rand::{Rng, SeedableRng, rngs::StdRng};
use async_trait::async_trait;
use datafusion::optimizer::simplify_expressions::simplify_literal::parse_literal;
use datafusion::{
execution::{
RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder,
TaskContext, context::QueryPlanner,
},
physical_expr::EquivalenceProperties,
physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput},
},
physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner},
prelude::*,
};
use datafusion_common::{
DFSchemaRef, DataFusionError, Result, Statistics, internal_err, not_impl_err,
plan_datafusion_err, plan_err,
};
use datafusion_expr::{
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
logical_plan::{Extension, LogicalPlan, LogicalPlanBuilder},
planner::{
PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning,
},
};
use datafusion_sql::sqlparser::ast::{
self, TableFactor, TableSampleMethod, TableSampleUnit,
};
use insta::assert_snapshot;
// ============================================================================
// Example Entry Point
// ============================================================================
/// Runs the TABLESAMPLE examples demonstrating various sampling techniques.
pub async fn table_sample() -> Result<()> {
// Build session with custom query planner for physical planning
let state = SessionStateBuilder::new()
.with_default_features()
.with_query_planner(Arc::new(TableSampleQueryPlanner))
.build();
let ctx = SessionContext::new_with_state(state);
// Register custom relation planner for logical planning
ctx.register_relation_planner(Arc::new(TableSamplePlanner))?;
register_sample_data(&ctx)?;
println!("TABLESAMPLE Example");
println!("===================\n");
run_examples(&ctx).await
}
async fn run_examples(ctx: &SessionContext) -> Result<()> {
// Example 1: Baseline - full table scan
let results = run_example(
ctx,
"Example 1: Full table (baseline)",
"SELECT * FROM sample_data",
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+
| column1 | column2 |
+---------+---------+
| 1       | row_1   |
| 2       | row_2   |
| 3       | row_3   |
| 4       | row_4   |
| 5       | row_5   |
| 6       | row_6   |
| 7       | row_7   |
| 8       | row_8   |
| 9       | row_9   |
| 10      | row_10  |
+---------+---------+
");
// Example 2: Percentage-based Bernoulli sampling
// REPEATABLE(seed) ensures deterministic results for snapshot testing
let results = run_example(
ctx,
"Example 2: BERNOULLI percentage sampling",
"SELECT * FROM sample_data TABLESAMPLE BERNOULLI(30 PERCENT) REPEATABLE(123)",
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+
| column1 | column2 |
+---------+---------+
| 1       | row_1   |
| 2       | row_2   |
| 7       | row_7   |
| 8       | row_8   |
+---------+---------+
");
// Example 3: Fractional sampling (0.0 to 1.0)
// REPEATABLE(seed) ensures deterministic results for snapshot testing
let results = run_example(
ctx,
"Example 3: Fractional sampling",
"SELECT * FROM sample_data TABLESAMPLE (0.5) REPEATABLE(456)",
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+
| column1 | column2 |
+---------+---------+
| 2       | row_2   |
| 4       | row_4   |
| 8       | row_8   |
+---------+---------+
");
// Example 4: Row count limit (deterministic, no seed needed)
let results = run_example(
ctx,
"Example 4: Row count limit",
"SELECT * FROM sample_data TABLESAMPLE (3 ROWS)",
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+
| column1 | column2 |
+---------+---------+
| 1       | row_1   |
| 2       | row_2   |
| 3       | row_3   |
+---------+---------+
");
// Example 5: Sampling combined with filtering
let results = run_example(
ctx,
"Example 5: Sampling with WHERE clause",
"SELECT * FROM sample_data TABLESAMPLE (5 ROWS) WHERE column1 > 2",
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+
| column1 | column2 |
+---------+---------+
| 3       | row_3   |
| 4       | row_4   |
| 5       | row_5   |
+---------+---------+
");
// Example 6: Sampling in JOIN queries
// REPEATABLE(seed) ensures deterministic results for snapshot testing
let results = run_example(
ctx,
"Example 6: Sampling in JOINs",
r#"SELECT t1.column1, t2.column1, t1.column2, t2.column2
FROM sample_data t1 TABLESAMPLE (0.7) REPEATABLE(789)
JOIN sample_data t2 TABLESAMPLE (0.7) REPEATABLE(123)
ON t1.column1 = t2.column1"#,
)
.await?;
assert_snapshot!(results, @r"
+---------+---------+---------+---------+
| column1 | column1 | column2 | column2 |
+---------+---------+---------+---------+
| 2       | 2       | row_2   | row_2   |
| 5       | 5       | row_5   | row_5   |
| 7       | 7       | row_7   | row_7   |
| 8       | 8       | row_8   | row_8   |
| 10      | 10      | row_10  | row_10  |
+---------+---------+---------+---------+
");
Ok(())
}
/// Helper to run a single example query and capture results.
async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<String> {
println!("{title}:\n{sql}\n");
let df = ctx.sql(sql).await?;
println!("{}\n", df.logical_plan().display_indent());
let batches = df.collect().await?;
let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string();
println!("{results}\n");
Ok(results)
}
/// Register test data: 10 rows with column1=1..10 and column2="row_1".."row_10"
fn register_sample_data(ctx: &SessionContext) -> Result<()> {
let column1: ArrayRef = Arc::new(Int32Array::from((1..=10).collect::<Vec<i32>>()));
let column2: ArrayRef = Arc::new(StringArray::from(
(1..=10).map(|i| format!("row_{i}")).collect::<Vec<_>>(),
));
let batch =
RecordBatch::try_from_iter(vec![("column1", column1), ("column2", column2)])?;
ctx.register_batch("sample_data", batch)?;
Ok(())
}
// ============================================================================
// Logical Planning: TableSamplePlanner + TableSamplePlanNode
// ============================================================================
/// Relation planner that intercepts `TABLESAMPLE` clauses in SQL and creates
/// [`TableSamplePlanNode`] logical nodes.
#[derive(Debug)]
struct TableSamplePlanner;
impl RelationPlanner for TableSamplePlanner {
fn plan_relation(
&self,
relation: TableFactor,
context: &mut dyn RelationPlannerContext,
) -> Result<RelationPlanning> {
// Only handle Table relations with TABLESAMPLE clause
let TableFactor::Table {
sample: Some(sample),
alias,
name,
args,
with_hints,
version,
with_ordinality,
partitions,
json_path,
index_hints,
} = relation
else {
return Ok(RelationPlanning::Original(Box::new(relation)));
};
// Extract sample spec (handles both before/after alias positions)
let sample = match sample {
ast::TableSampleKind::BeforeTableAlias(s)
| ast::TableSampleKind::AfterTableAlias(s) => s,
};
// Validate sampling method
if let Some(method) = &sample.name
&& *method != TableSampleMethod::Bernoulli
&& *method != TableSampleMethod::Row
{
return not_impl_err!(
"Sampling method {} is not supported (only BERNOULLI and ROW)",
method
);
}
// Offset sampling (ClickHouse-style) not supported
if sample.offset.is_some() {
return not_impl_err!(
"TABLESAMPLE with OFFSET is not supported (requires total row count)"
);
}
// Parse optional REPEATABLE seed
let seed = sample
.seed
.map(|s| {
s.value.to_string().parse::<u64>().map_err(|_| {
plan_datafusion_err!("REPEATABLE seed must be an integer")
})
})
.transpose()?;
// Plan the underlying table without the sample clause
let base_relation = TableFactor::Table {
sample: None,
alias: alias.clone(),
name,
args,
with_hints,
version,
with_ordinality,
partitions,
json_path,
index_hints,
};
let input = context.plan(base_relation)?;
        // Handle bucket sampling (TABLESAMPLE(BUCKET x OUT OF y)); without
        // bucketing metadata we interpret it as sampling an x/y fraction of rows
if let Some(bucket) = sample.bucket {
if bucket.on.is_some() {
return not_impl_err!(
"TABLESAMPLE BUCKET with ON clause requires CLUSTERED BY table"
);
}
let bucket_num: u64 =
bucket.bucket.to_string().parse().map_err(|_| {
plan_datafusion_err!("bucket number must be an integer")
})?;
            let total: u64 =
                bucket.total.to_string().parse().map_err(|_| {
                    plan_datafusion_err!("bucket total must be an integer")
                })?;
            // Reject degenerate specs up front; otherwise the x/y fraction
            // would fall outside [0.0, 1.0] and only surface later as an
            // internal error from SampleExec::try_new
            if total == 0 || bucket_num > total {
                return plan_err!(
                    "BUCKET x OUT OF y requires 0 < y and x <= y, got {} OUT OF {}",
                    bucket_num,
                    total
                );
            }
            let fraction = bucket_num as f64 / total as f64;
let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan();
return Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
plan, alias,
))));
}
// Handle quantity-based sampling
let Some(quantity) = sample.quantity else {
return plan_err!(
"TABLESAMPLE requires a quantity (percentage, fraction, or row count)"
);
};
let quantity_value_expr = context.sql_to_expr(quantity.value, input.schema())?;
match quantity.unit {
// TABLESAMPLE (N ROWS) - exact row limit
Some(TableSampleUnit::Rows) => {
let rows: i64 = parse_literal::<Int64Type>(&quantity_value_expr)?;
if rows < 0 {
return plan_err!("row count must be non-negative, got {}", rows);
}
let plan = LogicalPlanBuilder::from(input)
.limit(0, Some(rows as usize))?
.build()?;
Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
plan, alias,
))))
}
// TABLESAMPLE (N PERCENT) - percentage sampling
Some(TableSampleUnit::Percent) => {
let percent: f64 = parse_literal::<Float64Type>(&quantity_value_expr)?;
let fraction = percent / 100.0;
let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan();
Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
plan, alias,
))))
}
// TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0
None => {
let value = parse_literal::<Float64Type>(&quantity_value_expr)?;
if value < 0.0 {
return plan_err!("sample value must be non-negative, got {}", value);
}
let plan = if value >= 1.0 {
// Interpret as row limit
LogicalPlanBuilder::from(input)
.limit(0, Some(value as usize))?
.build()?
} else {
// Interpret as fraction
TableSamplePlanNode::new(input, value, seed).into_plan()
};
Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new(
plan, alias,
))))
}
}
}
}
/// Custom logical plan node representing a TABLESAMPLE operation.
///
/// Stores sampling parameters (bounds, seed) and wraps the input plan.
/// Gets converted to [`SampleExec`] during physical planning.
#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)]
struct TableSamplePlanNode {
input: LogicalPlan,
lower_bound: HashableF64,
upper_bound: HashableF64,
seed: u64,
}
impl TableSamplePlanNode {
/// Create a new sampling node with the given fraction (0.0 to 1.0).
fn new(input: LogicalPlan, fraction: f64, seed: Option<u64>) -> Self {
Self {
input,
lower_bound: HashableF64(0.0),
upper_bound: HashableF64(fraction),
seed: seed.unwrap_or_else(rand::random),
}
}
/// Wrap this node in a LogicalPlan::Extension.
fn into_plan(self) -> LogicalPlan {
LogicalPlan::Extension(Extension {
node: Arc::new(self),
})
}
}
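// A quick illustration: `TableSamplePlanNode::new(input, 0.3, Some(42)).into_plan()`
// wraps `input` in a `LogicalPlan::Extension` that samples roughly 30% of its
// rows with a fixed seed.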
impl UserDefinedLogicalNodeCore for TableSamplePlanNode {
fn name(&self) -> &str {
"TableSample"
}
fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}
fn schema(&self) -> &DFSchemaRef {
self.input.schema()
}
fn expressions(&self) -> Vec<Expr> {
vec![]
}
fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"Sample: bounds=[{}, {}], seed={}",
self.lower_bound.0, self.upper_bound.0, self.seed
)
}
fn with_exprs_and_inputs(
&self,
_exprs: Vec<Expr>,
mut inputs: Vec<LogicalPlan>,
) -> Result<Self> {
Ok(Self {
input: inputs.swap_remove(0),
lower_bound: self.lower_bound,
upper_bound: self.upper_bound,
seed: self.seed,
})
}
}
/// Wrapper for `f64` that compares and hashes by bit pattern.
///
/// Needed because [`UserDefinedLogicalNodeCore`] requires `Hash` and `Eq`,
/// neither of which `f64` implements.
#[derive(Debug, Clone, Copy, PartialOrd)]
struct HashableF64(f64);
impl PartialEq for HashableF64 {
fn eq(&self, other: &Self) -> bool {
self.0.to_bits() == other.0.to_bits()
}
}
impl Eq for HashableF64 {}
impl Hash for HashableF64 {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.to_bits().hash(state);
}
}
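// Note: bitwise equality means NaN == NaN here, unlike `f64`'s `PartialEq`;
// for comparing plan nodes, that is the behavior we want.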
// ============================================================================
// Physical Planning: TableSampleQueryPlanner + TableSampleExtensionPlanner
// ============================================================================
/// Custom query planner that registers [`TableSampleExtensionPlanner`] to
/// convert [`TableSamplePlanNode`] into [`SampleExec`].
#[derive(Debug)]
struct TableSampleQueryPlanner;
#[async_trait]
impl QueryPlanner for TableSampleQueryPlanner {
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
session_state: &SessionState,
) -> Result<Arc<dyn ExecutionPlan>> {
let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
TableSampleExtensionPlanner,
)]);
planner
.create_physical_plan(logical_plan, session_state)
.await
}
}
/// Extension planner that converts [`TableSamplePlanNode`] to [`SampleExec`].
struct TableSampleExtensionPlanner;
#[async_trait]
impl ExtensionPlanner for TableSampleExtensionPlanner {
async fn plan_extension(
&self,
_planner: &dyn PhysicalPlanner,
node: &dyn UserDefinedLogicalNode,
_logical_inputs: &[&LogicalPlan],
physical_inputs: &[Arc<dyn ExecutionPlan>],
_session_state: &SessionState,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
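        // Downcast to see whether this is our node. Returning Ok(None) tells
        // the planner this extension does not handle `node`, leaving it to
        // other registered ExtensionPlanners.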
let Some(sample_node) = node.as_any().downcast_ref::<TableSamplePlanNode>()
else {
return Ok(None);
};
let exec = SampleExec::try_new(
Arc::clone(&physical_inputs[0]),
sample_node.lower_bound.0,
sample_node.upper_bound.0,
sample_node.seed,
)?;
Ok(Some(Arc::new(exec)))
}
}
// ============================================================================
// Physical Execution: SampleExec + BernoulliSampler
// ============================================================================
/// Physical execution plan that samples rows from its input using Bernoulli sampling.
///
/// Each row is independently selected with probability `(upper_bound - lower_bound)`
/// and appears at most once.
#[derive(Debug, Clone)]
pub struct SampleExec {
input: Arc<dyn ExecutionPlan>,
lower_bound: f64,
upper_bound: f64,
seed: u64,
metrics: ExecutionPlanMetricsSet,
cache: PlanProperties,
}
impl SampleExec {
/// Create a new SampleExec with Bernoulli sampling (without replacement).
///
/// # Arguments
/// * `input` - The input execution plan
/// * `lower_bound` - Lower bound of sampling range (typically 0.0)
/// * `upper_bound` - Upper bound of sampling range (0.0 to 1.0)
/// * `seed` - Random seed for reproducible sampling
pub fn try_new(
input: Arc<dyn ExecutionPlan>,
lower_bound: f64,
upper_bound: f64,
seed: u64,
) -> Result<Self> {
if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound {
return internal_err!(
"Sampling bounds must satisfy 0.0 <= lower <= upper <= 1.0, got [{}, {}]",
lower_bound,
upper_bound
);
}
let cache = PlanProperties::new(
EquivalenceProperties::new(input.schema()),
input.properties().partitioning.clone(),
input.properties().emission_type,
input.properties().boundedness,
);
Ok(Self {
input,
lower_bound,
upper_bound,
seed,
metrics: ExecutionPlanMetricsSet::new(),
cache,
})
}
/// Create a sampler for the given partition.
fn create_sampler(&self, partition: usize) -> BernoulliSampler {
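        // Offset the base seed by the partition index so each partition
        // draws a distinct but deterministic random stream.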
let seed = self.seed.wrapping_add(partition as u64);
BernoulliSampler::new(self.lower_bound, self.upper_bound, seed)
}
}
impl DisplayAs for SampleExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
write!(
f,
"SampleExec: bounds=[{}, {}], seed={}",
self.lower_bound, self.upper_bound, self.seed
)
}
}
impl ExecutionPlan for SampleExec {
fn name(&self) -> &'static str {
"SampleExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.cache
}
fn maintains_input_order(&self) -> Vec<bool> {
// Sampling preserves row order (rows are filtered, not reordered)
vec![true]
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
mut children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::try_new(
children.swap_remove(0),
self.lower_bound,
self.upper_bound,
self.seed,
)?))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(SampleStream {
input: self.input.execute(partition, context)?,
sampler: self.create_sampler(partition),
metrics: BaselineMetrics::new(&self.metrics, partition),
}))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
let mut stats = self.input.partition_statistics(partition)?;
let ratio = self.upper_bound - self.lower_bound;
// Scale statistics by sampling ratio (inexact due to randomness)
stats.num_rows = stats
.num_rows
.map(|n| (n as f64 * ratio) as usize)
.to_inexact();
stats.total_byte_size = stats
.total_byte_size
.map(|n| (n as f64 * ratio) as usize)
.to_inexact();
Ok(stats)
}
}
/// Bernoulli sampler: includes each row with probability `(upper - lower)`.
/// This is sampling **without replacement**: each row appears at most once.
struct BernoulliSampler {
lower_bound: f64,
upper_bound: f64,
rng: StdRng,
}
impl BernoulliSampler {
fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self {
Self {
lower_bound,
upper_bound,
rng: StdRng::seed_from_u64(seed),
}
}
fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
let range = self.upper_bound - self.lower_bound;
if range <= 0.0 {
return Ok(RecordBatch::new_empty(batch.schema()));
}
// Select rows where random value falls in [lower, upper)
let indices: Vec<u32> = (0..batch.num_rows())
.filter(|_| {
let r: f64 = self.rng.random();
r >= self.lower_bound && r < self.upper_bound
})
.map(|i| i as u32)
.collect();
if indices.is_empty() {
return Ok(RecordBatch::new_empty(batch.schema()));
}
compute::take_record_batch(batch, &UInt32Array::from(indices))
.map_err(DataFusionError::from)
}
}
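// Sanity-checking the math: for a batch of n rows, the expected number of
// sampled rows is n * (upper_bound - lower_bound). With 10 rows and bounds
// [0.0, 0.3], that is 3 rows on average, though any count from 0 to 10 can
// occur on a given run.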
/// Stream adapter that applies sampling to each batch.
struct SampleStream {
input: SendableRecordBatchStream,
sampler: BernoulliSampler,
metrics: BaselineMetrics,
}
impl Stream for SampleStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
let elapsed = self.metrics.elapsed_compute().clone();
let _timer = elapsed.timer();
let result = self.sampler.sample(&batch);
Poll::Ready(Some(result.record_output(&self.metrics)))
}
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
impl RecordBatchStream for SampleStream {
fn schema(&self) -> SchemaRef {
self.input.schema()
}
}