blob: aedc511c62fefe4d98a98739389c925ef6ea4e79 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{any::Any, sync::Arc};
use arrow_schema::{Field, Schema};
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
AggregateUDFImpl, GroupsAccumulator, Signature,
};
/// `better_avg`: a user-defined aggregate function that exists only to
/// demonstrate [`AggregateUDFImpl::simplify`].
///
/// The type never aggregates anything itself. Its `simplify` hook replaces
/// every `better_avg(...)` call with a call to the built-in `avg` aggregate,
/// keeping all of the original call's arguments and modifiers.
#[derive(Debug, Clone)]
struct BetterAvgUdaf {
    /// Argument types this function accepts (configured in `new`).
    signature: Signature,
}
impl BetterAvgUdaf {
    /// Creates a new instance of the `BetterAvgUdaf` struct.
    ///
    /// The signature accepts exactly one `Float64` argument and is
    /// [`Volatility::Immutable`]: the same inputs always produce the same output.
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
        }
    }
}
impl AggregateUDFImpl for BetterAvgUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"better_avg"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!("should not be invoked")
}
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}
// we override method, to return new expression which would substitute
// user defined function call
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
// as an example for this functionality we replace UDF function
// with build-in aggregate function to illustrate the use
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
avg_udaf(),
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
)))
};
Some(Box::new(simplify))
}
}
/// Builds a [`SessionContext`] with one registered in-memory table named `t`.
///
/// `t` has two non-nullable `Float32` columns, `a` and `b`, split across two
/// partitions:
/// - column `a` holds `[2, 4, 8]` in the first partition and `[16]` in the second;
/// - column `b` is `2` in every row.
fn create_context() -> Result<SessionContext> {
    use datafusion::datasource::MemTable;

    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ]));

    // Builds one record batch (one partition's worth of rows) against `schema`.
    let make_batch = |a: Vec<f32>, b: Vec<f32>| {
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(a)),
                Arc::new(Float32Array::from(b)),
            ],
        )
    };
    let first_partition = make_batch(vec![2.0, 4.0, 8.0], vec![2.0, 2.0, 2.0])?;
    let second_partition = make_batch(vec![16.0], vec![2.0])?;

    // Expose the batches as an in-memory table; in the Spark API this
    // corresponds to `createDataFrame(...)`.
    let partitions = vec![vec![first_partition], vec![second_partition]];
    let table = MemTable::try_new(schema, partitions)?;

    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(table))?;
    Ok(ctx)
}
/// Calls `better_avg` through both the SQL API and the DataFrame API and
/// checks that each returns the average of column `a`, `(2 + 4 + 8 + 16) / 4 = 7.5`.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
    ctx.register_udaf(better_avg.clone());

    // SQL API: every row has b = 2, so GROUP BY b produces a single group.
    let result = ctx
        .sql("SELECT better_avg(a) FROM t group by b")
        .await?
        .collect()
        .await?;

    // `assert_batches_eq!` compares the pretty-printed table line by line,
    // so the value row must be padded to the same width as the header row.
    let expected = [
        "+-----------------+",
        "| better_avg(t.a) |",
        "+-----------------+",
        "| 7.5             |",
        "+-----------------+",
    ];
    assert_batches_eq!(expected, &result);

    // DataFrame API: aggregate with no grouping, calling the UDAF directly.
    let df = ctx.table("t").await?;
    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
    let results = df.collect().await?;
    let result = as_float64_array(results[0].column(0))?;
    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
    println!("The average of [2,4,8,16] is {}", result.value(0));

    Ok(())
}