// blob: 475775a599f62038828a1e84308779e32d635ce6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! This example shows how to create and use "Async UDFs" in DataFusion.
//!
//! Async UDFs allow you to perform asynchronous operations, such as
//! making network requests. This can be used for tasks like fetching
//! data from an external API such as a LLM service or an external database.
use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use datafusion::assert_batches_eq;
use datafusion::common::cast::as_string_view_array;
use datafusion::common::error::Result;
use datafusion::common::not_impl_err;
use datafusion::common::utils::take_function_args;
use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::prelude::{SessionConfig, SessionContext};
use std::any::Any;
use std::sync::Arc;
/// Registers [`AskLLM`] as an asynchronous user defined function, invokes it
/// from a SQL query, and then inspects the resulting plan with `EXPLAIN`.
///
/// # Errors
///
/// Propagates any error from building or registering the sample table, or
/// from planning and executing either query.
///
/// # Panics
///
/// Panics (via `assert_batches_eq!`) if the formatted query output differs
/// from the expected tables below.
pub async fn async_udf() -> Result<()> {
    // Use a hard coded parallelism level of 4 so the explain plan
    // is consistent across machines.
    let config = SessionConfig::new().with_target_partitions(4);
    let ctx =
        SessionContext::from(SessionStateBuilder::new().with_config(config).build());

    // Similarly to regular UDFs, you create an AsyncScalarUDF by implementing
    // `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`.
    let ask_llm = AskLLM::new();
    let udf = AsyncScalarUDF::new(Arc::new(ask_llm));

    // Async UDFs are registered with the SessionContext, using the same
    // `register_udf` method as regular UDFs.
    ctx.register_udf(udf.into_scalar_udf());

    // Create a table named 'animal' with some sample data
    ctx.register_batch("animal", animal()?)?;

    // You can use the async UDF as normal in SQL queries
    //
    // Note: Async UDFs can currently be used in the select list and filter conditions.
    let results = ctx
        .sql("select * from animal a where ask_llm(a.name, 'Is this animal furry?')")
        .await?
        .collect()
        .await?;

    // `assert_batches_eq!` compares the pretty-printed output line by line, so
    // every data row must be padded to the same width as the border rows.
    assert_batches_eq!(
        [
            "+----+------+",
            "| id | name |",
            "+----+------+",
            "| 1  | cat  |",
            "| 2  | dog  |",
            "+----+------+",
        ],
        &results
    );

    // While the interface is the same for both normal and async UDFs, you can
    // use `EXPLAIN` output to see that the async UDF uses a special
    // `AsyncFuncExec` node in the physical plan:
    let results = ctx
        .sql("explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')")
        .await?
        .collect()
        .await?;

    // Plan nodes are indented two spaces per nesting level and each cell is
    // padded to the column width; both are significant for the comparison.
    assert_batches_eq!(
        [
            "+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
            "| plan_type     | plan                                                                                                                           |",
            "+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
            "| logical_plan  | SubqueryAlias: a                                                                                                               |",
            "|               |   Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\"))                                            |",
            "|               |     TableScan: animal projection=[id, name]                                                                                    |",
            "| physical_plan | CoalesceBatchesExec: target_batch_size=8192                                                                                    |",
            "|               |   FilterExec: __async_fn_0@2, projection=[id@0, name@1]                                                                        |",
            "|               |     RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1                                                       |",
            "|               |       AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |",
            "|               |         CoalesceBatchesExec: target_batch_size=8192                                                                            |",
            "|               |           DataSourceExec: partitions=1, partition_sizes=[1]                                                                    |",
            "|               |                                                                                                                                |",
            "+---------------+--------------------------------------------------------------------------------------------------------------------------------+",
        ],
        &results
    );

    Ok(())
}
/// Returns a sample `RecordBatch` representing an "animal" table with two columns:
fn animal() -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
let name_array = Arc::new(StringArray::from(vec![
"cat", "dog", "fish", "bird", "snake",
]));
Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
}
/// A stand-in for an LLM-backed predicate, exposed as an async UDF.
///
/// The function takes two string arguments, some content and a question about
/// that content, and produces a boolean saying whether the "LLM" considers the
/// content to match the question.
///
/// To keep the example self-contained no LLM service is actually contacted;
/// a real implementation could replace the simulated answer with a remote call.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AskLLM {
    /// Argument types and volatility reported through `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl AskLLM {
    /// Creates the UDF with an exact `(Utf8View, Utf8View)` signature that is
    /// marked `Volatility::Volatile`.
    pub fn new() -> Self {
        let arg_types = vec![DataType::Utf8View, DataType::Utf8View];
        let signature = Signature::exact(arg_types, Volatility::Volatile);
        Self { signature }
    }
}

impl Default for AskLLM {
    /// Equivalent to [`AskLLM::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Async UDFs still implement the regular `ScalarUDFImpl` trait, which
/// describes the function itself: its name, its signature and its return type.
impl ScalarUDFImpl for AskLLM {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The name under which the function is called from SQL.
    fn name(&self) -> &str {
        "ask_llm"
    }

    /// The signature built in [`AskLLM::new`].
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The function always produces `Boolean`, whatever the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// The synchronous entry point is not supported: this is an async UDF, so
    /// this method is not expected to be called directly and always returns a
    /// "not implemented" error.
    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        not_impl_err!("AskLLM can only be called from async contexts")
    }
}
/// In addition to [`ScalarUDFImpl`], we also need to implement the
/// [`AsyncScalarUDFImpl`] trait, which supplies the asynchronous entry point.
#[async_trait]
impl AsyncScalarUDFImpl for AskLLM {
    /// The `invoke_async_with_args` method is similar to `invoke_with_args`,
    /// but it returns a `Future` that resolves to the result.
    ///
    /// Since this signature is `async`, it can do any `async` operations, such
    /// as network requests. This method is run on the same tokio `Runtime` that
    /// is processing the query, so you may wish to make actual network requests
    /// on a different `Runtime`, as explained in the `thread_pools.rs` example
    /// in this directory.
    ///
    /// The simulated answer for a row is `true` when the content contains
    /// "cat" or "dog" and the question contains "furry", `false` otherwise,
    /// and null when either input is null.
    ///
    /// # Errors
    ///
    /// Returns an error if the arguments cannot be converted to arrays, if
    /// `take_function_args` rejects the argument list (two are expected), or
    /// if either argument cannot be viewed as a string-view array.
    async fn invoke_async_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        // in a real UDF you would likely want to special case constant
        // arguments to improve performance, but this example converts the
        // arguments to arrays for simplicity.
        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
        let [content_column, question_column] =
            take_function_args(self.name(), arrays)?;

        // In a real function, you would use a library such as `reqwest` here to
        // make an async HTTP request. Credentials and other configuration would
        // come from DataFusion's `ConfigOptions`. This method takes no separate
        // `ConfigOptions` parameter.
        // NOTE(review): presumably the options are reachable through
        // `ScalarFunctionArgs` (e.g. a `config_options` field); confirm against
        // the DataFusion version in use.

        // In this example, we will simulate the LLM response by comparing the two
        // input arguments using some static strings
        let content_column = as_string_view_array(&content_column)?;
        let question_column = as_string_view_array(&question_column)?;

        let result_array: BooleanArray = content_column
            .iter()
            .zip(question_column.iter())
            .map(|(content, question)| {
                // If either value is null, the answer is null as well.
                let (content, question) = (content?, question?);

                // Simulate an LLM response by checking the arguments against
                // some hardcoded conditions.
                let is_known_furry = content.contains("cat") || content.contains("dog");
                Some(is_known_furry && question.contains("furry"))
            })
            .collect();

        Ok(ColumnarValue::from(Arc::new(result_array) as ArrayRef))
    }
}