feat: add FFI support for user-defined functions (#1145)
* Intermediate work adding ffi scalar udf
* Add scalar UDF and example
* Add aggregate udf via ffi
* Initial commit for window ffi integration
* Remove unused import
diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst
index c1f9806..a40af12 100644
--- a/docs/source/contributor-guide/ffi.rst
+++ b/docs/source/contributor-guide/ffi.rst
@@ -176,7 +176,7 @@
``TableProvider`` PyCapsule to have this capsule accessible by calling a function named
``__datafusion_table_provider__``. You can see a complete working example of how to
share a ``TableProvider`` from one python library to DataFusion Python in the
-`repository examples folder <https://github.com/apache/datafusion-python/tree/main/examples/ffi-table-provider>`_.
+`repository examples folder <https://github.com/apache/datafusion-python/tree/main/examples/datafusion-ffi-example>`_.
This section has been written using ``TableProvider`` as an example. It is the first
extension that has been written using this approach and the most thoroughly implemented.
diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock
index e5a1ca8..1b4ca6b 100644
--- a/examples/datafusion-ffi-example/Cargo.lock
+++ b/examples/datafusion-ffi-example/Cargo.lock
@@ -323,6 +323,8 @@
checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b"
dependencies = [
"bitflags",
+ "serde",
+ "serde_json",
]
[[package]]
@@ -748,9 +750,9 @@
[[package]]
name = "datafusion"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080"
+checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96"
dependencies = [
"arrow",
"arrow-ipc",
@@ -775,7 +777,6 @@
"datafusion-functions-nested",
"datafusion-functions-table",
"datafusion-functions-window",
- "datafusion-macros",
"datafusion-optimizer",
"datafusion-physical-expr",
"datafusion-physical-expr-common",
@@ -790,7 +791,7 @@
"object_store",
"parking_lot",
"parquet",
- "rand",
+ "rand 0.9.1",
"regex",
"sqlparser",
"tempfile",
@@ -803,9 +804,9 @@
[[package]]
name = "datafusion-catalog"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66"
+checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee"
dependencies = [
"arrow",
"async-trait",
@@ -829,9 +830,9 @@
[[package]]
name = "datafusion-catalog-listing"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00"
+checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038"
dependencies = [
"arrow",
"async-trait",
@@ -852,9 +853,9 @@
[[package]]
name = "datafusion-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c"
+checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12"
dependencies = [
"ahash",
"arrow",
@@ -876,9 +877,9 @@
[[package]]
name = "datafusion-common-runtime"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b"
+checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2"
dependencies = [
"futures",
"log",
@@ -887,9 +888,9 @@
[[package]]
name = "datafusion-datasource"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1"
+checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b"
dependencies = [
"arrow",
"async-compression",
@@ -912,7 +913,7 @@
"log",
"object_store",
"parquet",
- "rand",
+ "rand 0.9.1",
"tempfile",
"tokio",
"tokio-util",
@@ -923,9 +924,9 @@
[[package]]
name = "datafusion-datasource-csv"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6"
+checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc"
dependencies = [
"arrow",
"async-trait",
@@ -948,9 +949,9 @@
[[package]]
name = "datafusion-datasource-json"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352"
+checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c"
dependencies = [
"arrow",
"async-trait",
@@ -973,9 +974,9 @@
[[package]]
name = "datafusion-datasource-parquet"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923"
+checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b"
dependencies = [
"arrow",
"async-trait",
@@ -998,21 +999,21 @@
"object_store",
"parking_lot",
"parquet",
- "rand",
+ "rand 0.9.1",
"tokio",
]
[[package]]
name = "datafusion-doc"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292"
+checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea"
[[package]]
name = "datafusion-execution"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84"
+checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a"
dependencies = [
"arrow",
"dashmap",
@@ -1022,16 +1023,16 @@
"log",
"object_store",
"parking_lot",
- "rand",
+ "rand 0.9.1",
"tempfile",
"url",
]
[[package]]
name = "datafusion-expr"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964"
+checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b"
dependencies = [
"arrow",
"chrono",
@@ -1050,9 +1051,9 @@
[[package]]
name = "datafusion-expr-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839"
+checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a"
dependencies = [
"arrow",
"datafusion-common",
@@ -1063,9 +1064,9 @@
[[package]]
name = "datafusion-ffi"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c"
+checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef"
dependencies = [
"abi_stable",
"arrow",
@@ -1073,7 +1074,9 @@
"async-ffi",
"async-trait",
"datafusion",
+ "datafusion-functions-aggregate-common",
"datafusion-proto",
+ "datafusion-proto-common",
"futures",
"log",
"prost",
@@ -1082,10 +1085,24 @@
]
[[package]]
+name = "datafusion-ffi-example"
+version = "0.2.0"
+dependencies = [
+ "arrow",
+ "arrow-array",
+ "arrow-schema",
+ "async-trait",
+ "datafusion",
+ "datafusion-ffi",
+ "pyo3",
+ "pyo3-build-config",
+]
+
+[[package]]
name = "datafusion-functions"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e"
+checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357"
dependencies = [
"arrow",
"arrow-buffer",
@@ -1103,7 +1120,7 @@
"itertools",
"log",
"md-5",
- "rand",
+ "rand 0.9.1",
"regex",
"sha2",
"unicode-segmentation",
@@ -1112,9 +1129,9 @@
[[package]]
name = "datafusion-functions-aggregate"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba"
+checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4"
dependencies = [
"ahash",
"arrow",
@@ -1133,9 +1150,9 @@
[[package]]
name = "datafusion-functions-aggregate-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4"
+checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696"
dependencies = [
"ahash",
"arrow",
@@ -1146,9 +1163,9 @@
[[package]]
name = "datafusion-functions-nested"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537"
+checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b"
dependencies = [
"arrow",
"arrow-ord",
@@ -1167,9 +1184,9 @@
[[package]]
name = "datafusion-functions-table"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a"
+checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b"
dependencies = [
"arrow",
"async-trait",
@@ -1183,10 +1200,11 @@
[[package]]
name = "datafusion-functions-window"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193"
+checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4"
dependencies = [
+ "arrow",
"datafusion-common",
"datafusion-doc",
"datafusion-expr",
@@ -1200,9 +1218,9 @@
[[package]]
name = "datafusion-functions-window-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313"
+checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe"
dependencies = [
"datafusion-common",
"datafusion-physical-expr-common",
@@ -1210,9 +1228,9 @@
[[package]]
name = "datafusion-macros"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1"
+checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8"
dependencies = [
"datafusion-expr",
"quote",
@@ -1221,9 +1239,9 @@
[[package]]
name = "datafusion-optimizer"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515"
+checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058"
dependencies = [
"arrow",
"chrono",
@@ -1240,9 +1258,9 @@
[[package]]
name = "datafusion-physical-expr"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65"
+checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6"
dependencies = [
"ahash",
"arrow",
@@ -1262,9 +1280,9 @@
[[package]]
name = "datafusion-physical-expr-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863"
+checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df"
dependencies = [
"ahash",
"arrow",
@@ -1276,9 +1294,9 @@
[[package]]
name = "datafusion-physical-optimizer"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2"
+checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae"
dependencies = [
"arrow",
"datafusion-common",
@@ -1295,9 +1313,9 @@
[[package]]
name = "datafusion-physical-plan"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55"
+checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23"
dependencies = [
"ahash",
"arrow",
@@ -1325,9 +1343,9 @@
[[package]]
name = "datafusion-proto"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83"
+checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6"
dependencies = [
"arrow",
"chrono",
@@ -1341,9 +1359,9 @@
[[package]]
name = "datafusion-proto-common"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f"
+checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e"
dependencies = [
"arrow",
"datafusion-common",
@@ -1352,9 +1370,9 @@
[[package]]
name = "datafusion-session"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f"
+checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49"
dependencies = [
"arrow",
"async-trait",
@@ -1376,9 +1394,9 @@
[[package]]
name = "datafusion-sql"
-version = "47.0.0"
+version = "48.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607"
+checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df"
dependencies = [
"arrow",
"bigdecimal",
@@ -1442,20 +1460,6 @@
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
-name = "ffi-table-provider"
-version = "0.1.0"
-dependencies = [
- "arrow",
- "arrow-array",
- "arrow-schema",
- "async-trait",
- "datafusion",
- "datafusion-ffi",
- "pyo3",
- "pyo3-build-config",
-]
-
-[[package]]
name = "fixedbitset"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1489,6 +1493,12 @@
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
+name = "foldhash"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
+
+[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1666,6 +1676,11 @@
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash",
+]
[[package]]
name = "heck"
@@ -2271,12 +2286,14 @@
[[package]]
name = "petgraph"
-version = "0.7.1"
+version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
+checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca"
dependencies = [
"fixedbitset",
+ "hashbrown 0.15.3",
"indexmap",
+ "serde",
]
[[package]]
@@ -2305,7 +2322,7 @@
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
- "rand",
+ "rand 0.8.5",
]
[[package]]
@@ -2484,19 +2501,27 @@
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
- "libc",
+ "rand_core 0.6.4",
+]
+
+[[package]]
+name = "rand"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
+dependencies = [
"rand_chacha",
- "rand_core",
+ "rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
-version = "0.3.1"
+version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
- "rand_core",
+ "rand_core 0.9.3",
]
[[package]]
@@ -2504,8 +2529,14 @@
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
+
+[[package]]
+name = "rand_core"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
- "getrandom 0.2.16",
+ "getrandom 0.3.3",
]
[[package]]
@@ -3032,9 +3063,9 @@
[[package]]
name = "uuid"
-version = "1.16.0"
+version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
+checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
dependencies = [
"getrandom 0.3.3",
"js-sys",
diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml
index 3191635..b26ab48 100644
--- a/examples/datafusion-ffi-example/Cargo.toml
+++ b/examples/datafusion-ffi-example/Cargo.toml
@@ -16,13 +16,13 @@
# under the License.
[package]
-name = "ffi-table-provider"
-version = "0.1.0"
+name = "datafusion-ffi-example"
+version = "0.2.0"
edition = "2021"
[dependencies]
-datafusion = { version = "47.0.0" }
-datafusion-ffi = { version = "47.0.0" }
+datafusion = { version = "48.0.0" }
+datafusion-ffi = { version = "48.0.0" }
pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] }
arrow = { version = "55.0.0" }
arrow-array = { version = "55.0.0" }
diff --git a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py
new file mode 100644
index 0000000..7ea6b29
--- /dev/null
+++ b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pyarrow as pa
+from datafusion import SessionContext, col, udaf
+from datafusion_ffi_example import MySumUDF
+
+
+def setup_context_with_table():
+ ctx = SessionContext()
+
+ # Pick numbers here so we get the same value in both groups
+ # since we cannot be certain of the output order of batches
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array([1, 2, 3, None], type=pa.int64()),
+ pa.array([1, 1, 2, 2], type=pa.int64()),
+ ],
+ names=["a", "b"],
+ )
+ ctx.register_record_batches("test_table", [[batch]])
+ return ctx
+
+
+def test_ffi_aggregate_register():
+ ctx = setup_context_with_table()
+ my_udaf = udaf(MySumUDF())
+ ctx.register_udaf(my_udaf)
+
+ result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect()
+
+ assert len(result) == 2
+ assert result[0].num_columns == 1
+
+ result = [r.column(0) for r in result]
+ expected = [
+ pa.array([3], type=pa.int64()),
+ pa.array([3], type=pa.int64()),
+ ]
+
+ assert result == expected
+
+
+def test_ffi_aggregate_call_directly():
+ ctx = setup_context_with_table()
+ my_udaf = udaf(MySumUDF())
+
+ result = (
+ ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect()
+ )
+
+ assert len(result) == 2
+ assert result[0].num_columns == 2
+
+ result = [r.column(1) for r in result]
+ expected = [
+ pa.array([3], type=pa.int64()),
+ pa.array([3], type=pa.int64()),
+ ]
+
+ assert result == expected
diff --git a/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py
new file mode 100644
index 0000000..0c949c3
--- /dev/null
+++ b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pyarrow as pa
+from datafusion import SessionContext, col, udf
+from datafusion_ffi_example import IsNullUDF
+
+
+def setup_context_with_table():
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3, None])],
+ names=["a"],
+ )
+ ctx.register_record_batches("test_table", [[batch]])
+ return ctx
+
+
+def test_ffi_scalar_register():
+ ctx = setup_context_with_table()
+ my_udf = udf(IsNullUDF())
+ ctx.register_udf(my_udf)
+
+ result = ctx.sql("select my_custom_is_null(a) from test_table").collect()
+
+ assert len(result) == 1
+ assert result[0].num_columns == 1
+ print(result)
+
+ result = [r.column(0) for r in result]
+ expected = [
+ pa.array([False, False, False, True], type=pa.bool_()),
+ ]
+
+ assert result == expected
+
+
+def test_ffi_scalar_call_directly():
+ ctx = setup_context_with_table()
+ my_udf = udf(IsNullUDF())
+
+ result = ctx.table("test_table").select(my_udf(col("a"))).collect()
+
+ assert len(result) == 1
+ assert result[0].num_columns == 1
+ print(result)
+
+ result = [r.column(0) for r in result]
+ expected = [
+ pa.array([False, False, False, True], type=pa.bool_()),
+ ]
+
+ assert result == expected
diff --git a/examples/datafusion-ffi-example/python/tests/_test_window_udf.py b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py
new file mode 100644
index 0000000..7d96994
--- /dev/null
+++ b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pyarrow as pa
+from datafusion import SessionContext, col, udwf
+from datafusion_ffi_example import MyRankUDF
+
+
+def setup_context_with_table():
+ ctx = SessionContext()
+
+    # Use deliberately unsorted values so the rank window function
+    # must actually order the rows rather than rely on input order
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array([40, 10, 30, 20], type=pa.int64()),
+ ],
+ names=["a"],
+ )
+ ctx.register_record_batches("test_table", [[batch]])
+ return ctx
+
+
+def test_ffi_window_register():
+ ctx = setup_context_with_table()
+ my_udwf = udwf(MyRankUDF())
+ ctx.register_udwf(my_udwf)
+
+ result = ctx.sql(
+ "select a, my_custom_rank() over (order by a) from test_table"
+ ).collect()
+ assert len(result) == 1
+ assert result[0].num_columns == 2
+
+ results = [
+ (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4)
+ ]
+ results.sort()
+
+ expected = [
+ (10, 1),
+ (20, 2),
+ (30, 3),
+ (40, 4),
+ ]
+ assert results == expected
+
+
+def test_ffi_window_call_directly():
+ ctx = setup_context_with_table()
+ my_udwf = udwf(MyRankUDF())
+
+ result = (
+ ctx.table("test_table")
+ .select(col("a"), my_udwf().order_by(col("a")).build())
+ .collect()
+ )
+
+ assert len(result) == 1
+ assert result[0].num_columns == 2
+
+ results = [
+ (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4)
+ ]
+ results.sort()
+
+ expected = [
+ (10, 1),
+ (20, 2),
+ (30, 3),
+ (40, 4),
+ ]
+ assert results == expected
diff --git a/examples/datafusion-ffi-example/src/aggregate_udf.rs b/examples/datafusion-ffi-example/src/aggregate_udf.rs
new file mode 100644
index 0000000..9481fe9
--- /dev/null
+++ b/examples/datafusion-ffi-example/src/aggregate_udf.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::DataType;
+use datafusion::error::Result as DataFusionResult;
+use datafusion::functions_aggregate::sum::Sum;
+use datafusion::logical_expr::function::AccumulatorArgs;
+use datafusion::logical_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
+use datafusion_ffi::udaf::FFI_AggregateUDF;
+use pyo3::types::PyCapsule;
+use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use std::any::Any;
+use std::sync::Arc;
+
+#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)]
+#[derive(Debug, Clone)]
+pub(crate) struct MySumUDF {
+ inner: Arc<Sum>,
+}
+
+#[pymethods]
+impl MySumUDF {
+ #[new]
+ fn new() -> Self {
+ Self {
+ inner: Arc::new(Sum::new()),
+ }
+ }
+
+ fn __datafusion_aggregate_udf__<'py>(
+ &self,
+ py: Python<'py>,
+ ) -> PyResult<Bound<'py, PyCapsule>> {
+ let name = cr"datafusion_aggregate_udf".into();
+
+ let func = Arc::new(AggregateUDF::from(self.clone()));
+ let provider = FFI_AggregateUDF::from(func);
+
+ PyCapsule::new(py, provider, Some(name))
+ }
+}
+
+impl AggregateUDFImpl for MySumUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "my_custom_sum"
+ }
+
+ fn signature(&self) -> &Signature {
+ self.inner.signature()
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+ self.inner.return_type(arg_types)
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> DataFusionResult<Box<dyn Accumulator>> {
+ self.inner.accumulator(acc_args)
+ }
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
+ self.inner.coerce_types(arg_types)
+ }
+}
diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs
index 54e61cf..cd26169 100644
--- a/examples/datafusion-ffi-example/src/catalog_provider.rs
+++ b/examples/datafusion-ffi-example/src/catalog_provider.rs
@@ -24,7 +24,6 @@
catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider,
},
- common::exec_err,
datasource::MemTable,
error::{DataFusionError, Result},
};
diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs
index 3a4cf22..79af276 100644
--- a/examples/datafusion-ffi-example/src/lib.rs
+++ b/examples/datafusion-ffi-example/src/lib.rs
@@ -16,18 +16,27 @@
// under the License.
use crate::catalog_provider::MyCatalogProvider;
+use crate::aggregate_udf::MySumUDF;
+use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
+use crate::window_udf::MyRankUDF;
use pyo3::prelude::*;
pub(crate) mod catalog_provider;
+pub(crate) mod aggregate_udf;
+pub(crate) mod scalar_udf;
pub(crate) mod table_function;
pub(crate) mod table_provider;
+pub(crate) mod window_udf;
#[pymodule]
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<MyTableProvider>()?;
m.add_class::<MyTableFunction>()?;
m.add_class::<MyCatalogProvider>()?;
+ m.add_class::<IsNullUDF>()?;
+ m.add_class::<MySumUDF>()?;
+ m.add_class::<MyRankUDF>()?;
Ok(())
}
diff --git a/examples/datafusion-ffi-example/src/scalar_udf.rs b/examples/datafusion-ffi-example/src/scalar_udf.rs
new file mode 100644
index 0000000..7276666
--- /dev/null
+++ b/examples/datafusion-ffi-example/src/scalar_udf.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::{Array, BooleanArray};
+use arrow_schema::DataType;
+use datafusion::common::ScalarValue;
+use datafusion::error::Result as DataFusionResult;
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
+ Volatility,
+};
+use datafusion_ffi::udf::FFI_ScalarUDF;
+use pyo3::types::PyCapsule;
+use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use std::any::Any;
+use std::sync::Arc;
+
+#[pyclass(name = "IsNullUDF", module = "datafusion_ffi_example", subclass)]
+#[derive(Debug, Clone)]
+pub(crate) struct IsNullUDF {
+ signature: Signature,
+}
+
+#[pymethods]
+impl IsNullUDF {
+ #[new]
+ fn new() -> Self {
+ Self {
+ signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
+ }
+ }
+
+ fn __datafusion_scalar_udf__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
+ let name = cr"datafusion_scalar_udf".into();
+
+ let func = Arc::new(ScalarUDF::from(self.clone()));
+ let provider = FFI_ScalarUDF::from(func);
+
+ PyCapsule::new(py, provider, Some(name))
+ }
+}
+
+impl ScalarUDFImpl for IsNullUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "my_custom_is_null"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
+ Ok(DataType::Boolean)
+ }
+
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
+ let input = &args.args[0];
+
+ Ok(match input {
+ ColumnarValue::Array(arr) => match arr.is_nullable() {
+ true => {
+ let nulls = arr.nulls().unwrap();
+ let nulls = BooleanArray::from_iter(nulls.iter().map(|x| Some(!x)));
+ ColumnarValue::Array(Arc::new(nulls))
+ }
+ false => ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
+ },
+ ColumnarValue::Scalar(sv) => {
+ ColumnarValue::Scalar(ScalarValue::Boolean(Some(sv == &ScalarValue::Null)))
+ }
+ })
+ }
+}
diff --git a/examples/datafusion-ffi-example/src/window_udf.rs b/examples/datafusion-ffi-example/src/window_udf.rs
new file mode 100644
index 0000000..e0d3979
--- /dev/null
+++ b/examples/datafusion-ffi-example/src/window_udf.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::{DataType, FieldRef};
+use datafusion::error::Result as DataFusionResult;
+use datafusion::functions_window::rank::rank_udwf;
+use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
+use datafusion::logical_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl};
+use datafusion_ffi::udwf::FFI_WindowUDF;
+use pyo3::types::PyCapsule;
+use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Example window UDF exported to Python: wraps DataFusion's built-in `rank`
+/// under a custom name to demonstrate UDWF export over the FFI boundary.
+#[pyclass(name = "MyRankUDF", module = "datafusion_ffi_example", subclass)]
+#[derive(Debug, Clone)]
+pub(crate) struct MyRankUDF {
+    // Delegate: the stock `rank` window UDF all trait methods forward to.
+    inner: Arc<WindowUDF>,
+}
+
+#[pymethods]
+impl MyRankUDF {
+    /// Python constructor: wraps DataFusion's built-in `rank` window function.
+    #[new]
+    fn new() -> Self {
+        Self { inner: rank_udwf() }
+    }
+
+    /// Export this UDWF as a PyCapsule so consumers can import it through the
+    /// FFI bindings.
+    fn __datafusion_window_udf__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
+        // The capsule name must match what the consuming side validates
+        // against ("datafusion_window_udf").
+        let name = cr"datafusion_window_udf".into();
+
+        let func = Arc::new(WindowUDF::from(self.clone()));
+        let provider = FFI_WindowUDF::from(func);
+
+        PyCapsule::new(py, provider, Some(name))
+    }
+}
+
+/// Forwards every trait method to the wrapped built-in `rank` implementation;
+/// only the name differs so the function registers as `my_custom_rank`.
+impl WindowUDFImpl for MyRankUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "my_custom_rank"
+    }
+
+    /// Reuse the signature of the wrapped `rank` UDWF.
+    fn signature(&self) -> &Signature {
+        self.inner.signature()
+    }
+
+    /// Delegate evaluator creation to the wrapped implementation via
+    /// `WindowUDF::inner()`, which exposes the underlying `WindowUDFImpl`.
+    fn partition_evaluator(
+        &self,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> DataFusionResult<Box<dyn PartitionEvaluator>> {
+        self.inner
+            .inner()
+            .partition_evaluator(partition_evaluator_args)
+    }
+
+    /// Delegate output field construction to the wrapped implementation.
+    fn field(&self, field_args: WindowUDFFieldArgs) -> DataFusionResult<FieldRef> {
+        self.inner.inner().field(field_args)
+    }
+
+    /// Delegate argument type coercion to the wrapped UDWF.
+    fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
+        self.inner.coerce_types(arg_types)
+    }
+}
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index dd634c7..bd686ac 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -22,7 +22,7 @@
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
import pyarrow as pa
@@ -77,6 +77,12 @@
return self.name.lower()
+class ScalarUDFExportable(Protocol):
+ """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
+
+ def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105
+
+
class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).
@@ -96,6 +102,9 @@
See helper method :py:func:`udf` for argument details.
"""
+ if hasattr(func, "__datafusion_scalar_udf__"):
+ self._udf = df_internal.ScalarUDF.from_pycapsule(func)
+ return
if isinstance(input_types, pa.DataType):
input_types = [input_types]
self._udf = df_internal.ScalarUDF(
@@ -134,6 +143,10 @@
name: Optional[str] = None,
) -> ScalarUDF: ...
+ @overload
+ @staticmethod
+ def udf(func: ScalarUDFExportable) -> ScalarUDF: ...
+
@staticmethod
def udf(*args: Any, **kwargs: Any): # noqa: D417
"""Create a new User-Defined Function (UDF).
@@ -147,7 +160,10 @@
Args:
func (Callable, optional): Only needed when calling as a function.
- Skip this argument when using ``udf`` as a decorator.
+ Skip this argument when using ``udf`` as a decorator. If you have a Rust
+ backed ScalarUDF within a PyCapsule, you can pass this parameter
+ and ignore the rest. They will be determined directly from the
+ underlying function. See the online documentation for more information.
input_types (list[pa.DataType]): The data types of the arguments
to ``func``. This list must be of the same length as the number of
arguments.
@@ -215,12 +231,31 @@
return decorator
+ if hasattr(args[0], "__datafusion_scalar_udf__"):
+ return ScalarUDF.from_pycapsule(args[0])
+
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return _function(*args, **kwargs)
# Case 2: Used as a decorator with parameters
return _decorator(*args, **kwargs)
+    @staticmethod
+    def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
+        """Create a Scalar UDF from ScalarUDF PyCapsule object.
+
+        This function will instantiate a Scalar UDF that uses a DataFusion
+        ScalarUDF that is exported via the FFI bindings.
+        """
+        name = str(func.__class__)
+        # __init__ short-circuits when the object exposes
+        # __datafusion_scalar_udf__, so the remaining arguments are unused
+        # placeholders.
+        return ScalarUDF(
+            name=name,
+            func=func,
+            input_types=None,
+            return_type=None,
+            volatility=None,
+        )
+
class Accumulator(metaclass=ABCMeta):
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
@@ -242,6 +277,12 @@
"""Return the resultant value."""
+class AggregateUDFExportable(Protocol):
+ """Type hint for object that has __datafusion_aggregate_udf__ PyCapsule."""
+
+ def __datafusion_aggregate_udf__(self) -> object: ... # noqa: D105
+
+
class AggregateUDF:
"""Class for performing scalar user-defined functions (UDF).
@@ -263,6 +304,9 @@
See :py:func:`udaf` for a convenience function and argument
descriptions.
"""
+ if hasattr(accumulator, "__datafusion_aggregate_udf__"):
+ self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
+ return
self._udaf = df_internal.AggregateUDF(
name,
accumulator,
@@ -307,7 +351,7 @@
) -> AggregateUDF: ...
@staticmethod
- def udaf(*args: Any, **kwargs: Any): # noqa: D417
+ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
"""Create a new User-Defined Aggregate Function (UDAF).
This class allows you to define an aggregate function that can be used in
@@ -364,6 +408,10 @@
Args:
accum: The accumulator python function. Only needed when calling as a
function. Skip this argument when using ``udaf`` as a decorator.
+ If you have a Rust backed AggregateUDF within a PyCapsule, you can
+ pass this parameter and ignore the rest. They will be determined
+ directly from the underlying function. See the online documentation
+ for more information.
input_types: The data types of the arguments to ``accum``.
return_type: The data type of the return value.
state_type: The data types of the intermediate accumulation.
@@ -422,12 +470,32 @@
return decorator
+ if hasattr(args[0], "__datafusion_aggregate_udf__"):
+ return AggregateUDF.from_pycapsule(args[0])
+
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return _function(*args, **kwargs)
# Case 2: Used as a decorator with parameters
return _decorator(*args, **kwargs)
+    @staticmethod
+    def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
+        """Create an Aggregate UDF from AggregateUDF PyCapsule object.
+
+        This function will instantiate an Aggregate UDF that uses a DataFusion
+        AggregateUDF that is exported via the FFI bindings.
+        """
+        name = str(func.__class__)
+        # __init__ short-circuits when the object exposes
+        # __datafusion_aggregate_udf__, so the remaining arguments are unused
+        # placeholders.
+        return AggregateUDF(
+            name=name,
+            accumulator=func,
+            input_types=None,
+            return_type=None,
+            state_type=None,
+            volatility=None,
+        )
+
class WindowEvaluator:
"""Evaluator class for user-defined window functions (UDWF).
@@ -588,6 +656,12 @@
return False
+class WindowUDFExportable(Protocol):
+ """Type hint for object that has __datafusion_window_udf__ PyCapsule."""
+
+ def __datafusion_window_udf__(self) -> object: ... # noqa: D105
+
+
class WindowUDF:
"""Class for performing window user-defined functions (UDF).
@@ -608,6 +682,9 @@
See :py:func:`udwf` for a convenience function and argument
descriptions.
"""
+ if hasattr(func, "__datafusion_window_udf__"):
+ self._udwf = df_internal.WindowUDF.from_pycapsule(func)
+ return
self._udwf = df_internal.WindowUDF(
name, func, input_types, return_type, str(volatility)
)
@@ -683,7 +760,10 @@
Args:
func: Only needed when calling as a function. Skip this argument when
- using ``udwf`` as a decorator.
+ using ``udwf`` as a decorator. If you have a Rust backed WindowUDF
+ within a PyCapsule, you can pass this parameter and ignore the rest.
+ They will be determined directly from the underlying function. See
+ the online documentation for more information.
input_types: The data types of the arguments.
return_type: The data type of the return value.
volatility: See :py:class:`Volatility` for allowed values.
@@ -692,6 +772,9 @@
Returns:
A user-defined window function that can be used in window function calls.
"""
+ if hasattr(args[0], "__datafusion_window_udf__"):
+ return WindowUDF.from_pycapsule(args[0])
+
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return WindowUDF._create_window_udf(*args, **kwargs)
@@ -759,6 +842,22 @@
return decorator
+    @staticmethod
+    def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
+        """Create a Window UDF from WindowUDF PyCapsule object.
+
+        This function will instantiate a Window UDF that uses a DataFusion
+        WindowUDF that is exported via the FFI bindings.
+        """
+        name = str(func.__class__)
+        # __init__ short-circuits when the object exposes
+        # __datafusion_window_udf__, so the remaining arguments are unused
+        # placeholders.
+        return WindowUDF(
+            name=name,
+            func=func,
+            input_types=None,
+            return_type=None,
+            volatility=None,
+        )
+
class TableFunction:
"""Class for performing user-defined table functions (UDTF).
diff --git a/src/functions.rs b/src/functions.rs
index b40500b..eeef483 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -682,7 +682,7 @@
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
}
-// We handle first_value explicitly because the signature expects an order_by
+// We handle last_value explicitly because the signature expects an order_by
// https://github.com/apache/datafusion/issues/12376
#[pyfunction]
#[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))]
diff --git a/src/udaf.rs b/src/udaf.rs
index 34a9cd5..78f4e2b 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -19,6 +19,10 @@
use pyo3::{prelude::*, types::PyTuple};
+use crate::common::data_type::PyScalarValue;
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
+use crate::expr::PyExpr;
+use crate::utils::{parse_volatility, validate_pycapsule};
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
@@ -27,11 +31,8 @@
use datafusion::logical_expr::{
create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF,
};
-
-use crate::common::data_type::PyScalarValue;
-use crate::errors::to_datafusion_err;
-use crate::expr::PyExpr;
-use crate::utils::parse_volatility;
+use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
+use pyo3::types::PyCapsule;
#[derive(Debug)]
struct RustAccumulator {
@@ -183,6 +184,26 @@
Ok(Self { function })
}
+    /// Construct from any Python object that exports an `FFI_AggregateUDF`
+    /// through a `__datafusion_aggregate_udf__` PyCapsule (e.g. a UDAF
+    /// implemented in an external Rust-backed library).
+    #[staticmethod]
+    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.hasattr("__datafusion_aggregate_udf__")? {
+            // The dunder is a method returning the capsule, not the capsule
+            // itself, so call it and validate the capsule's name.
+            let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
+
+            // SAFETY: validate_pycapsule confirmed the capsule name, so the
+            // payload is expected to be an `FFI_AggregateUDF` placed there by
+            // the exporting library.
+            let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
+            let udaf: ForeignAggregateUDF = udaf.try_into()?;
+
+            Ok(Self {
+                function: udaf.into(),
+            })
+        } else {
+            Err(crate::errors::PyDataFusionError::Common(
+                "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
+            ))
+        }
+    }
+
/// creates a new PyExpr with the call of the udf
#[pyo3(signature = (*args))]
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
diff --git a/src/udf.rs b/src/udf.rs
index 574c9d7..de1e3f1 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -17,6 +17,8 @@
use std::sync::Arc;
+use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
+use pyo3::types::PyCapsule;
use pyo3::{prelude::*, types::PyTuple};
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
@@ -29,8 +31,9 @@
use datafusion::logical_expr::{create_udf, ColumnarValue};
use crate::errors::to_datafusion_err;
+use crate::errors::{py_datafusion_err, PyDataFusionResult};
use crate::expr::PyExpr;
-use crate::utils::parse_volatility;
+use crate::utils::{parse_volatility, validate_pycapsule};
/// Create a Rust callable function from a python function that expects pyarrow arrays
fn pyarrow_function_to_rust(
@@ -105,6 +108,26 @@
Ok(Self { function })
}
+    /// Construct from any Python object that exports an `FFI_ScalarUDF`
+    /// through a `__datafusion_scalar_udf__` PyCapsule (e.g. a scalar UDF
+    /// implemented in an external Rust-backed library).
+    #[staticmethod]
+    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.hasattr("__datafusion_scalar_udf__")? {
+            // The dunder is a method returning the capsule, not the capsule
+            // itself, so call it and validate the capsule's name.
+            let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_scalar_udf")?;
+
+            // SAFETY: validate_pycapsule confirmed the capsule name, so the
+            // payload is expected to be an `FFI_ScalarUDF` placed there by
+            // the exporting library.
+            let udf = unsafe { capsule.reference::<FFI_ScalarUDF>() };
+            let udf: ForeignScalarUDF = udf.try_into()?;
+
+            Ok(Self {
+                function: udf.into(),
+            })
+        } else {
+            Err(crate::errors::PyDataFusionError::Common(
+                "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(),
+            ))
+        }
+    }
+
/// creates a new PyExpr with the call of the udf
#[pyo3(signature = (*args))]
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
diff --git a/src/udwf.rs b/src/udwf.rs
index a0c8cc5..4fb9891 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -27,16 +27,17 @@
use pyo3::prelude::*;
use crate::common::data_type::PyScalarValue;
-use crate::errors::to_datafusion_err;
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::PyExpr;
-use crate::utils::parse_volatility;
+use crate::utils::{parse_volatility, validate_pycapsule};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{
PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
};
-use pyo3::types::{PyList, PyTuple};
+use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF};
+use pyo3::types::{PyCapsule, PyList, PyTuple};
#[derive(Debug)]
struct RustPartitionEvaluator {
@@ -245,6 +246,26 @@
Ok(self.function.call(args).into())
}
+    /// Construct from any Python object that exports an `FFI_WindowUDF`
+    /// through a `__datafusion_window_udf__` PyCapsule (e.g. a window UDF
+    /// implemented in an external Rust-backed library).
+    #[staticmethod]
+    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.hasattr("__datafusion_window_udf__")? {
+            // The dunder is a method returning the capsule, not the capsule
+            // itself, so call it and validate the capsule's name.
+            let capsule = func.getattr("__datafusion_window_udf__")?.call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_window_udf")?;
+
+            // SAFETY: validate_pycapsule confirmed the capsule name, so the
+            // payload is expected to be an `FFI_WindowUDF` placed there by
+            // the exporting library.
+            let udwf = unsafe { capsule.reference::<FFI_WindowUDF>() };
+            let udwf: ForeignWindowUDF = udwf.try_into()?;
+
+            Ok(Self {
+                function: udwf.into(),
+            })
+        } else {
+            Err(crate::errors::PyDataFusionError::Common(
+                "__datafusion_window_udf__ does not exist on WindowUDF object.".to_string(),
+            ))
+        }
+    }
+
fn __repr__(&self) -> PyResult<String> {
Ok(format!("WindowUDF({})", self.function.name()))
}