// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::{
array::{ArrayRef, Float32Array, Float64Array},
datatypes::DataType,
record_batch::RecordBatch,
util::pretty,
};
use datafusion::prelude::*;
use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
use std::sync::Arc;
// create a local execution context with an in-memory table
fn create_context() -> Result<ExecutionContext> {
use arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float64, false),
]));
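// note that column "a" is f32 while the UDF below expects f64; DataFusion
// inserts the cast automatically (see the coercion note near the end of main)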
// define data.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)?;
// declare a new context. In the Spark API, this corresponds to a new `SparkSession`
let mut ctx = ExecutionContext::new();
// declare a table in memory. In the Spark API, this corresponds to `createDataFrame(...)`.
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
ctx.register_table("t", Arc::new(provider));
Ok(ctx)
}
/// In this example we will declare a UDF that takes two f64 arguments and returns their exponentiation, a^b, as f64
#[tokio::main]
async fn main() -> Result<()> {
let mut ctx = create_context()?;
// First, declare the actual implementation of the calculation
let pow = |args: &[ArrayRef]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result
// the argument count below is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);
// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = &args[0]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
let exponent = &args[1]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
// 2. perform the computation
let array = base
.iter()
.zip(exponent.iter())
.map(|(base, exponent)| {
match (base, exponent) {
// in arrow, any value can be null.
// Here we decide to make our UDF return null when either base or exponent is null.
(Some(base), Some(exponent)) => Some(base.powf(exponent)),
_ => None,
}
})
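// `Float64Array` implements `FromIterator<Option<f64>>`, so the optional
// results collect directly into a null-aware arrow array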
.collect::<Float64Array>();
// `Ok` because no error occurred during the calculation
// (note that `f64::powf` does not panic: a negative base with a fractional
// exponent yields NaN, so to reject such inputs we would return an Err instead)
// `Arc` because arrays are immutable, thread-safe trait objects.
Ok(Arc::new(array) as ArrayRef)
};
// the function above expects `ArrayRef` arguments, but DataFusion may pass scalar values to a UDF.
// thus, we use `make_scalar_function` to decorate the closure so that it can handle both arrays and scalar values.
let pow = make_scalar_function(pow);
// Next:
// * give it a name so that it shows nicely when the plan is printed
// * declare what input it expects
// * declare its return type
let pow = create_udf(
"pow",
// expects two f64
vec![DataType::Float64, DataType::Float64],
// returns f64
Arc::new(DataType::Float64),
pow,
);
// at this point, we can use it or register it, depending on the use-case:
// * if the UDF is expected to be used throughout the program in different contexts,
// we can register it, and call it later:
ctx.register_udf(pow.clone()); // clone is only required in this example because we show both usages
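// once registered, the UDF is also available to SQL queries over this context,
// along the lines of (a sketch, not run here):
// let sql_df = ctx.sql("SELECT pow(a, b) FROM t")?;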
// * if the UDF is expected to be used directly in the scope, `.call` it directly:
let expr = pow.call(vec![col("a"), col("b")]);
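// because `make_scalar_function` wrapped our closure, the UDF also accepts
// scalar arguments; for example, we could pass a literal for the exponent
// (an illustrative, unused expression; `lit` comes from the prelude):
let _expr_with_literal = pow.call(vec![col("a"), lit(2.0)]);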
// get a DataFrame from the context
let df = ctx.table("t")?;
// if `pow` is not in scope but was registered, we can retrieve it from the registry
let pow = df.registry().udf("pow")?;
// equivalent to expr
let expr1 = pow.call(vec![col("a"), col("b")]);
// equivalent to `SELECT pow(a, b), pow(a, b) AS pow1 FROM t`
let df = df.select(&[
expr,
// alias so that they have different column names
expr1.alias("pow1"),
])?;
// note that "a" is f32, not f64. DataFusion coerces the types to match the UDF's signature.
// execute the query
let results = df.collect().await?;
// print the results
pretty::print_batches(&results)?;
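// the two printed columns should contain identical values, since both
// expressions compute pow(a, b) over the same rows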
Ok(())
}