// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{
    array::{ArrayRef, AsArray, Float64Array},
    datatypes::Float64Type,
};
use arrow_schema::DataType;
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::ScalarValue;
use datafusion_expr::{PartitionEvaluator, Volatility, WindowFrame};

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
    // declare a new context. In the Spark API, this corresponds to a new SparkSession
    let ctx = SessionContext::new();

    // register a CSV file as a table named `cars` (the queries below use
    // its `car`, `speed`, and `time` columns). In the Spark API, this
    // roughly corresponds to reading the file with `spark.read.csv(...)`
    // and registering it as a temporary view.
    //
    // Print the current working directory first, so the relative path
    // below is easy to debug if the file cannot be found.
    println!("pwd: {}", std::env::current_dir().unwrap().display());
    let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
    let read_options = CsvReadOptions::default().has_header(true);

    ctx.register_csv("cars", &csv_path, read_options).await?;
    Ok(ctx)
}

/// In this example we declare a user defined window function that computes
/// a moving average, and then run it using both SQL and the DataFrame API
#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context().await?;

    // here is where we define the UDWF. We also declare its signature:
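    // (arguments, in order: the name used to invoke it from SQL, the input
    // argument type, the return type, the volatility, and a factory that
    // creates a new `PartitionEvaluator` for each partition)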
let smooth_it = create_udwf(
"smooth_it",
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(make_partition_evaluator),
);

    // register the window function with DataFusion so we can call it
    ctx.register_udwf(smooth_it.clone());

    // Use SQL to show the input data
    let df = ctx.sql("SELECT * FROM cars").await?;
    // print the results
    df.show().await?;

    // Use SQL to run the new window function:
    //
    // `PARTITION BY car`: each distinct value of car (red and green)
    // should be treated as a separate partition (and will result in
    // creating a new `PartitionEvaluator`)
    //
    // `ORDER BY time`: within each partition ('green' or 'red') the
    // rows will be ordered by the value in the `time` column
    //
    // `evaluate` is invoked with a window frame defined by the
    // SQL. In this case:
//
// The first invocation will be passed row 0, the first row in the
// partition.
//
// The second invocation will be passed rows 0 and 1, the first
// two rows in the partition.
//
// etc.
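    //
    // As an illustrative (hypothetical) trace, assuming distinct `time`
    // values, the calls would look like:
    //
    //   evaluate(values, 0..1) // row 0
    //   evaluate(values, 0..2) // rows 0 and 1
    //   evaluate(values, 0..3) // rows 0, 1 and 2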
    let df = ctx
        .sql(
            "SELECT \
                car, \
                speed, \
                smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed, \
                time \
            FROM cars \
            ORDER BY car",
        )
        .await?;
// print the results
df.show().await?;

    // this time, call the new window function with an explicit
    // window frame so `evaluate` will be invoked with each window.
    //
    // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation
    // sees at most 3 rows: the row before, the current row, and the
    // row after.
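    //
    // As an illustrative (hypothetical) trace, for a partition with 4 rows
    // the frames passed to `evaluate` would be 0..2, 0..3, 1..4, and 2..4.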
    let df = ctx.sql(
        "SELECT \
            car, \
            speed, \
            smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed, \
            time \
        FROM cars \
        ORDER BY car",
    ).await?;
// print the results
df.show().await?;

    // Now, run the function using the DataFrame API:
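    //
    // `DataFrame::window` evaluates the given window expressions and adds
    // the results as new columns (the trailing comments below show the
    // corresponding pieces of the SQL `OVER` clause)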
let window_expr = smooth_it
.call(vec![col("speed")]) // smooth_it(speed)
.partition_by(vec![col("car")]) // PARTITION BY car
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
.window_frame(WindowFrame::new(None))
.build()?;
let df = ctx.table("cars").await?.window(vec![window_expr])?;
// print the results
df.show().await?;
Ok(())
}

/// Create a `PartitionEvaluator` to evaluate this function on a new
/// partition.
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}

/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct value of `PARTITION BY` (each car type in our example)
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
fn new() -> Self {
Self {}
}
}

/// Different evaluation methods are called depending on the various
/// settings of WindowUDF. This example uses the simplest and most
/// general, `evaluate`. See `PartitionEvaluator` for the other more
/// advanced uses.
impl PartitionEvaluator for MyPartitionEvaluator {
    /// Tell DataFusion that this window function's result depends on the
    /// window frame. Because this returns `true`, DataFusion will call
    /// `evaluate` once per row with the frame's `Range`.
fn uses_window_frame(&self) -> bool {
true
}

    /// This function is called once per input row.
    ///
    /// `range` specifies which indexes of `values` should be
    /// considered for the calculation.
    ///
    /// Note this is the SLOWEST, but simplest, way to evaluate a
    /// window function. It is much faster to implement
    /// `evaluate_all` or `evaluate_all_with_rank`, if possible
fn evaluate(
&mut self,
values: &[ArrayRef],
range: &std::ops::Range<usize>,
) -> Result<ScalarValue> {
        // The input argument is an array of floating point numbers over
        // which to calculate the moving average
        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
        let range_len = range.end - range.start;

        // our smoothing function will average all the values in the current
        // window frame (the rows in `range`); for brevity it reads the raw
        // value buffer and ignores the null (validity) mask
let output = if range_len > 0 {
let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
Some(sum / range_len as f64)
} else {
None
};
Ok(ScalarValue::Float64(output))
}
}
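
/// A minimal sketch (not part of the original example) of the faster
/// whole-partition API mentioned in the `evaluate` docs above. It assumes
/// a function whose result does NOT depend on the window frame:
/// `evaluate_all` is called once per partition and returns one output value
/// per input row. The struct name is hypothetical; only
/// `PartitionEvaluator::evaluate_all` itself comes from DataFusion.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct WholePartitionAvg {}

impl PartitionEvaluator for WholePartitionAvg {
    fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
        // average the whole partition once (ignoring nulls for brevity)...
        let avg = if num_rows > 0 {
            Some(arr.values().iter().sum::<f64>() / num_rows as f64)
        } else {
            None
        };
        // ...and repeat that average for every row in the partition
        Ok(Arc::new(Float64Array::from(vec![avg; num_rows])) as ArrayRef)
    }
}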