| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use datafusion::{ |
| arrow::{ |
| array::{ArrayRef, Float32Array, Float64Array}, |
| datatypes::DataType, |
| record_batch::RecordBatch, |
| }, |
| logical_expr::Volatility, |
| }; |
| |
| use datafusion::error::Result; |
| use datafusion::prelude::*; |
| use datafusion_common::cast::as_float64_array; |
| use datafusion_expr::ColumnarValue; |
| use std::sync::Arc; |
| |
| /// create local execution context with an in-memory table: |
| /// |
| /// ```text |
| /// +-----+-----+ |
| /// | a | b | |
| /// +-----+-----+ |
| /// | 2.1 | 1.0 | |
| /// | 3.1 | 2.0 | |
| /// | 4.1 | 3.0 | |
| /// | 5.1 | 4.0 | |
| /// +-----+-----+ |
| /// ``` |
| fn create_context() -> Result<SessionContext> { |
| // define data. |
| let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); |
| let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); |
| let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; |
| |
| // declare a new context. In spark API, this corresponds to a new spark SQLsession |
| let ctx = SessionContext::new(); |
| |
| // declare a table in memory. In spark API, this corresponds to createDataFrame(...). |
| ctx.register_batch("t", batch)?; |
| Ok(ctx) |
| } |
| |
| /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b |
| #[tokio::main] |
| async fn main() -> Result<()> { |
| let ctx = create_context()?; |
| |
| // First, declare the actual implementation of the calculation |
| let pow = Arc::new(|args: &[ColumnarValue]| { |
| // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: |
| // 1. cast the values to the type we want |
| // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result |
| |
| // this is guaranteed by DataFusion based on the function's signature. |
| assert_eq!(args.len(), 2); |
| |
| // Expand the arguments to arrays (this is simple, but inefficient for |
| // single constant values). |
| let args = ColumnarValue::values_to_arrays(args)?; |
| |
| // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! |
| let base = as_float64_array(&args[0]).expect("cast failed"); |
| let exponent = as_float64_array(&args[1]).expect("cast failed"); |
| |
| // The array lengths is guaranteed by DataFusion. We assert here to make it obvious. |
| assert_eq!(exponent.len(), base.len()); |
| |
| // 2. perform the computation |
| let array = base |
| .iter() |
| .zip(exponent.iter()) |
| .map(|(base, exponent)| { |
| match (base, exponent) { |
| // in arrow, any value can be null. |
| // Here we decide to make our UDF to return null when either base or exponent is null. |
| (Some(base), Some(exponent)) => Some(base.powf(exponent)), |
| _ => None, |
| } |
| }) |
| .collect::<Float64Array>(); |
| |
| // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) |
| // `Arc` because arrays are immutable, thread-safe, trait objects. |
| Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) |
| }); |
| |
| // Next: |
| // * give it a name so that it shows nicely when the plan is printed |
| // * declare what input it expects |
| // * declare its return type |
| let pow = create_udf( |
| "pow", |
| // expects two f64 |
| vec![DataType::Float64, DataType::Float64], |
| // returns f64 |
| Arc::new(DataType::Float64), |
| Volatility::Immutable, |
| pow, |
| ); |
| |
| // at this point, we can use it or register it, depending on the use-case: |
| // * if the UDF is expected to be used throughout the program in different contexts, |
| // we can register it, and call it later: |
| ctx.register_udf(pow.clone()); // clone is only required in this example because we show both usages |
| |
| // * if the UDF is expected to be used directly in the scope, `.call` it directly: |
| let expr = pow.call(vec![col("a"), col("b")]); |
| |
| // get a DataFrame from the context |
| let df = ctx.table("t").await?; |
| |
| // if we do not have `pow` in the scope and we registered it, we can get it from the registry |
| let pow = df.registry().udf("pow")?; |
| // equivalent to expr |
| let expr1 = pow.call(vec![col("a"), col("b")]); |
| |
| // equivalent to `'SELECT pow(a, b), pow(a, b) AS pow1 FROM t'` |
| let df = df.select(vec![ |
| expr, |
| // alias so that they have different column names |
| expr1.alias("pow1"), |
| ])?; |
| |
| // note that "b" is f32, not f64. DataFusion coerces the types to match the UDF's signature. |
| |
| // print the results |
| df.show().await?; |
| |
| // Given that `pow` is registered in the context, we can also use it in SQL: |
| let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?; |
| |
| // print the results |
| sql_df.show().await?; |
| |
| Ok(()) |
| } |