// blob: 1259f90d64496b6d09244dc2b5e9df4a347294bf [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
use arrow::{
array::{
ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt32Array,
},
datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type},
record_batch::RecordBatch,
};
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};
/// This example shows how to use the full `AggregateUDFImpl` API to implement a user
/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements
/// a function `accumulator` that returns the `Accumulator` instance.
///
/// To do so, we must implement the `AggregateUDFImpl` trait.
///
/// The function is exposed under the name `geo_mean` and computes the geometric mean
/// of a single `Float64` argument.
#[derive(Debug, Clone)]
struct GeoMeanUdaf {
/// The argument types and volatility of the function, built once in `new`
signature: Signature,
}
impl GeoMeanUdaf {
    /// Build the UDAF together with the signature it advertises to DataFusion.
    fn new() -> Self {
        // The function accepts exactly one argument, of type `Float64`.
        let arg_types = vec![DataType::Float64];
        // `Immutable`: the function is deterministic, so the same input always
        // yields the same result.
        let signature = Signature::exact(arg_types, Volatility::Immutable);
        Self { signature }
    }
}
impl AggregateUDFImpl for GeoMeanUdaf {
    /// Expose `self` as `Any` so the `AggregateUDFImpl` trait object can be downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The name of this function.
    fn name(&self) -> &str {
        "geo_mean"
    }

    /// The "signature" of this function, i.e. the argument types it accepts.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The type of the value produced by this function: always `Float64`,
    /// whatever the argument types are.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let output_type = DataType::Float64;
        Ok(output_type)
    }

    /// Factory for the row-wise accumulator.
    ///
    /// Even when `GroupsAccumulator` is supported, DataFusion will use this row
    /// oriented accumulator when the aggregate function is used as a window
    /// function or when there are only aggregates (no GROUP BY columns) in the plan.
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let row_accumulator = GeometricMean::new();
        Ok(Box::new(row_accumulator))
    }

    /// Describes the intermediate state: a nullable `prod` field of the return type
    /// followed by a nullable `UInt32` field `n`. The accumulator's `state()` must
    /// match the types listed here.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let prod = Field::new("prod", args.return_type.clone(), true);
        let n = Field::new("n", DataType::UInt32, true);
        Ok(vec![prod, n])
    }

    /// Tell DataFusion that this aggregate supports the more performant
    /// `GroupsAccumulator`, which is used when the query has grouping columns.
    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        true
    }

    /// Factory for the accumulator that handles all groups at once.
    fn create_groups_accumulator(
        &self,
        _args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        let groups_accumulator = GeometricMeanGroupsAccumulator::new();
        Ok(Box::new(groups_accumulator))
    }
}
/// An aggregate carries state across many rows, so that state lives in a `struct`:
/// the number of values seen so far and their running product.
#[derive(Debug)]
struct GeometricMean {
    n: u32,
    prod: f64,
}

impl GeometricMean {
    /// Start an empty aggregation: nothing counted yet, and the product at the
    /// multiplicative identity.
    pub fn new() -> Self {
        Self { prod: 1.0, n: 0 }
    }
}
// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
// to use them.
impl Accumulator for GeometricMean {
    // Serializes the state as `[prod, n]`, which DataFusion uses to pass it between
    // execution stages. The order must match the fields declared in `state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::from(self.prod),
            ScalarValue::from(self.n),
        ])
    }

    // Returns the final value of this aggregator: `prod^(1/n)`, the geometric mean.
    // When no value was accumulated the result is NULL rather than the meaningless
    // `1.0^(1/0)`; this matches what the groups accumulator emits for an empty group.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.n == 0 {
            return Ok(ScalarValue::Float64(None));
        }
        let value = self.prod.powf(1.0 / f64::from(self.n));
        Ok(ScalarValue::from(value))
    }

    // Updates the state for a batch of input rows: every non-NULL value of the first
    // column is multiplied into the product and counted. NULL rows are skipped instead
    // of aborting the query.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let arr = &values[0];
        for index in 0..arr.len() {
            match ScalarValue::try_from_array(arr, index)? {
                ScalarValue::Float64(Some(value)) => {
                    self.prod *= value;
                    self.n += 1;
                }
                // a NULL input does not contribute to the mean
                ScalarValue::Float64(None) => {}
                // the signature only admits Float64, so any other type is a planner bug
                other => unreachable!("geo_mean expects Float64 input, got {:?}", other),
            }
        }
        Ok(())
    }

    // Merges the output of `Self::state()` from other instances of this accumulator
    // into this accumulator's state. `states[0]` holds partial products and `states[1]`
    // the matching partial counts.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        let (prods, counts) = (&states[0], &states[1]);
        for index in 0..prods.len() {
            let prod = ScalarValue::try_from_array(prods, index)?;
            let n = ScalarValue::try_from_array(counts, index)?;
            match (prod, n) {
                (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => {
                    self.prod *= prod;
                    self.n += n;
                }
                // both state fields are declared nullable; a NULL partial state
                // carries no values and is ignored
                (ScalarValue::Float64(_), ScalarValue::UInt32(_)) => {}
                other => unreachable!(
                    "geo_mean expects (Float64, UInt32) state, got {:?}",
                    other
                ),
            }
        }
        Ok(())
    }

    // Memory used by this accumulator; the state is two plain numbers, so the size
    // of the struct itself is everything.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
}
// Build a local session context that exposes an in-memory table named `t`.
fn create_context() -> Result<SessionContext> {
    use datafusion::datasource::MemTable;

    // The table has two non-nullable Float32 columns, `a` and `b`.
    let fields = vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    // The data is split over two partitions: three rows in the first ...
    let first_partition = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
        ],
    )?;
    // ... and a single row in the second.
    let second_partition = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Float32Array::from(vec![64.0])),
            Arc::new(Float32Array::from(vec![2.0])),
        ],
    )?;
    let partitions = vec![vec![first_partition], vec![second_partition]];

    // The in-memory table; in the Spark API this corresponds to createDataFrame(...).
    let provider = MemTable::try_new(schema, partitions)?;

    // A fresh context; in the Spark API this corresponds to a new Spark SQL session.
    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(provider))?;
    Ok(ctx)
}
/// Define a `GroupsAccumulator` for GeometricMean
/// which handles accumulator state for multiple groups at once.
/// This API is significantly more complicated than `Accumulator`, which manages
/// the state for a single group, but for queries with a large number of groups
/// can be significantly faster. See the `GroupsAccumulator` documentation for
/// more information.
struct GeometricMeanGroupsAccumulator {
/// The type of the per-group partial products emitted by `state`
prod_data_type: DataType,
/// The type of the geometric means returned by `evaluate`
return_data_type: DataType,
/// Count per group (use u32 to make UInt32Array)
counts: Vec<u32>,
/// product per group, stored as the native type (not `ScalarValue`)
prods: Vec<f64>,
/// Track nulls in the input / filters
null_state: NullState,
}
impl GeometricMeanGroupsAccumulator {
    /// Create an accumulator that tracks no groups yet. Both the partial products
    /// and the final result are `Float64`.
    fn new() -> Self {
        let prod_data_type = DataType::Float64;
        let return_data_type = DataType::Float64;
        Self {
            prod_data_type,
            return_data_type,
            counts: Vec::new(),
            prods: Vec::new(),
            null_state: NullState::new(),
        }
    }
}
impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
/// Updates the accumulator state given input. DataFusion provides `group_indices`,
/// the groups that each row in `values` belongs to as well as an optional filter of which rows passed.
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 1, "single argument to update_batch");
let values = values[0].as_primitive::<Float64Type>();
// increment counts, update sums
self.counts.resize(total_num_groups, 0);
self.prods.resize(total_num_groups, 1.0);
// Use the `NullState` structure to generate specialized code for null / non null input elements
self.null_state.accumulate(
group_indices,
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
self.counts[group_index] += 1;
},
);
Ok(())
}
/// Merge the results from previous invocations of `evaluate` into this accumulator's state
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 2, "two arguments to merge_batch");
// first batch is counts, second is partial sums
let partial_prods = values[0].as_primitive::<Float64Type>();
let partial_counts = values[1].as_primitive::<UInt32Type>();
// update counts with partial counts
self.counts.resize(total_num_groups, 0);
self.null_state.accumulate(
group_indices,
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
self.counts[group_index] += partial_count;
},
);
// update prods
self.prods.resize(total_num_groups, 1.0);
self.null_state.accumulate(
group_indices,
partial_prods,
opt_filter,
total_num_groups,
|group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| {
let prod = &mut self.prods[group_index];
*prod = prod.mul_wrapping(new_value);
},
);
Ok(())
}
/// Generate output, as specified by `emit_to` and update the intermediate state
fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
let counts = emit_to.take_needed(&mut self.counts);
let prods = emit_to.take_needed(&mut self.prods);
let nulls = self.null_state.build(emit_to);
assert_eq!(nulls.len(), prods.len());
assert_eq!(counts.len(), prods.len());
// don't evaluate geometric mean with null inputs to avoid errors on null values
let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
let iter = prods.into_iter().zip(counts).zip(nulls.iter());
for ((prod, count), is_valid) in iter {
if is_valid {
builder.append_value(prod.powf(1.0 / count as f64))
} else {
builder.append_null();
}
}
builder.finish()
} else {
let geo_mean: Vec<<Float64Type as ArrowPrimitiveType>::Native> = prods
.into_iter()
.zip(counts)
.map(|(prod, count)| prod.powf(1.0 / count as f64))
.collect::<Vec<_>>();
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
.with_data_type(self.return_data_type.clone())
};
Ok(Arc::new(array))
}
// return arrays for counts and prods
fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build(emit_to);
let nulls = Some(nulls);
let counts = emit_to.take_needed(&mut self.counts);
let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy
let prods = emit_to.take_needed(&mut self.prods);
let prods = PrimitiveArray::<Float64Type>::new(prods.into(), nulls) // zero copy
.with_data_type(self.prod_data_type.clone());
Ok(vec![
Arc::new(prods) as ArrayRef,
Arc::new(counts) as ArrayRef,
])
}
fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<u32>()
+ self.prods.capacity() * std::mem::size_of::<Float64Type>()
}
}
/// Registers `geo_mean` and runs it twice: through SQL with a `GROUP BY`, and through
/// the DataFrame API without grouping columns, checking the latter result.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    // create the AggregateUDF
    let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
    ctx.register_udaf(geometric_mean.clone());

    // SQL API: a query with a grouping column
    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
    sql_df.show().await?;

    // get a DataFrame from the context
    // table `t` has two f32 columns; column `a` holds {2,4,8,64}, whose geometric mean is 8.0.
    let df = ctx.table("t").await?;

    // perform the aggregation
    let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;

    // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.

    // execute the query
    let results = df.collect().await?;

    // downcast the array to the expected type
    let result = as_float64_array(results[0].column(0))?;

    // verify that the calculation is correct; the tolerance is relative to the expected
    // value because an absolute `f64::EPSILON` is smaller than one ULP at 8.0
    let expected = 8.0_f64;
    let actual = result.value(0);
    assert!(
        (actual - expected).abs() <= 1e-12 * expected,
        "expected a geometric mean of {expected}, got {actual}"
    );
    println!("The geometric mean of [2,4,8,64] is {}", result.value(0));

    Ok(())
}