| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
test_that("list_compute_functions() works", {
  # The listing is a character vector and none of the returned names may
  # start with "hash_".
  fun_names <- list_compute_functions()

  expect_type(fun_names, "character")
  expect_false(any(grepl("^hash_", fun_names)))
})
| |
test_that("arrow_scalar_function() works", {
  # Each section builds a wrapper from a different spelling of the input/output
  # types and checks what ends up in `in_type` and `out_type`. Entries of
  # `out_type` are called with no arguments to obtain the resulting data type.

  # check in/out type as schema/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    schema(x = int32()),
    int64()
  )
  expect_equal(fun$in_type[[1]], schema(x = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as data type/data type
  fun <- arrow_scalar_function(
    function(context, x) x$cast(int64()),
    int32(),
    int64()
  )
  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as field/data type
  # The function body uses its own argument (`a_name`) rather than an
  # undefined `x`, so it would still be valid if it were ever evaluated.
  fun <- arrow_scalar_function(
    function(context, a_name) a_name$cast(int64()),
    field("a_name", int32()),
    int64()
  )
  expect_equal(fun$in_type[[1]], schema(a_name = int32()))
  expect_equal(fun$out_type[[1]](), int64())

  # check in/out type as lists
  fun <- arrow_scalar_function(
    function(context, x) x,
    list(int32(), int64()),
    list(int64(), int32()),
    auto_convert = TRUE
  )

  expect_equal(fun$in_type[[1]][[1]], field("", int32()))
  expect_equal(fun$in_type[[2]][[1]], field("", int64()))
  expect_equal(fun$out_type[[1]](), int64())
  expect_equal(fun$out_type[[2]](), int32())

  # a non-function `fun` must be rejected; the message is pinned by snapshot
  expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32()))
})
| |
test_that("arrow_scalar_function() works with auto_convert = TRUE", {
  # Build the wrapper, then call its `wrapper_fun` directly, passing an empty
  # list in place of a kernel context and a list holding one Scalar argument.
  wrapped <- arrow_scalar_function(
    function(context, x) x * 32,
    float64(),
    float64(),
    auto_convert = TRUE
  )

  fake_context <- list()
  scalar_args <- list(Scalar$create(2))
  actual <- wrapped$wrapper_fun(fake_context, scalar_args)

  expect_equal(actual, Array$create(2 * 32))
})
| |
test_that("register_scalar_function() adds a compute function to the registry", {
  skip_if_not(CanRunWithCapturedR())
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32"))

  # The new function must show up both in the package's function cache and
  # in the compute function listing.
  cached_names <- names(asNamespace("arrow")$.cache$functions)
  expect_true("times_32" %in% cached_names)
  expect_true("times_32" %in% list_compute_functions())

  # Direct invocation on an Array and on a Scalar
  from_array <- call_function("times_32", Array$create(1L, int32()))
  expect_equal(from_array, Array$create(32L, float64()))

  from_scalar <- call_function("times_32", Scalar$create(1L, int32()))
  expect_equal(from_scalar, Scalar$create(32L, float64()))

  skip_if_not_available("acero")

  # Invocation through a dplyr pipeline
  batch <- record_batch(a = 1L)
  collected <- dplyr::collect(dplyr::mutate(batch, b = times_32(a)))
  expect_identical(collected, tibble::tibble(a = 1L, b = 32.0))
})
| |
test_that("arrow_scalar_function() with bad return type errors", {
  skip_if_not(CanRunWithCapturedR())

  # Both functions below declare float64() as their output type but return
  # int32 data, so calling them is expected to raise an error.

  register_scalar_function(
    "times_32_bad_return_type_array",
    function(context, x) Array$create(x, int32()),
    int32(),
    float64()
  )
  # add = TRUE is required: without it the second on.exit() call further down
  # would replace this handler and this function would stay registered.
  on.exit(unregister_binding("times_32_bad_return_type_array"), add = TRUE)

  expect_error(
    call_function("times_32_bad_return_type_array", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )

  register_scalar_function(
    "times_32_bad_return_type_scalar",
    function(context, x) Scalar$create(x, int32()),
    int32(),
    float64()
  )
  on.exit(unregister_binding("times_32_bad_return_type_scalar"), add = TRUE)

  expect_error(
    call_function("times_32_bad_return_type_scalar", Array$create(1L)),
    "Expected return Array or Scalar with type 'double'"
  )
})
| |
test_that("register_scalar_function() can register multiple kernels", {
  skip_if_not(CanRunWithCapturedR())

  # One kernel per input type; the output type is computed from the input
  # types and is simply the type of the first argument.
  first_input_type <- function(in_types) in_types[[1]]

  register_scalar_function(
    "times_32",
    function(context, x) x * 32L,
    in_type = list(int32(), int64(), float64()),
    out_type = first_input_type,
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32"))

  # Each registered kernel must return a Scalar of the type it was given.
  for (scalar_type in list(int32(), int64(), float64())) {
    expect_equal(
      call_function("times_32", Scalar$create(1L, scalar_type)),
      Scalar$create(32L, scalar_type)
    )
  }
})
| |
test_that("register_scalar_function() errors for unsupported specifications", {
  # Empty input/output type lists describe zero kernels
  expect_error(
    register_scalar_function("no_kernels", function(...) NULL, list(), list()),
    regexp = "Can't register user-defined scalar function with 0 kernels"
  )

  # `fun` takes one argument where two are expected
  expect_error(
    register_scalar_function("wrong_n_args", function(x) NULL, int32(), int32()),
    regexp = "Expected `fun` to accept 2 argument\\(s\\)"
  )

  # One kernel takes a single argument, the other takes two
  mixed_arity_inputs <- list(
    float64(),
    schema(x = float64(), y = float64())
  )
  expect_error(
    register_scalar_function(
      "var_kernels",
      function(...) NULL,
      mixed_arity_inputs,
      float64()
    ),
    regexp = "Kernels for user-defined function must accept the same number of arguments"
  )
})
| |
test_that("user-defined functions work during multi-threaded execution", {
  skip_if_not(CanRunWithCapturedR())
  skip_if_not_available("dataset")
  # Skip on linux devel because:
  # TODO(ARROW-17283): Snappy has a UBSan issue that is fixed in the dev version
  # TODO(ARROW-17178): User-defined function-friendly ExecPlan execution has
  # occasional valgrind errors
  skip_on_linux_devel()

  n_rows <- 10000
  n_partitions <- 10
  example_df <- expand.grid(
    part = letters[seq_len(n_partitions)],
    value = seq_len(n_rows),
    stringsAsFactors = FALSE
  )

  # make sure values are different for each partition and that every row has
  # a row number, so results can be sorted back into the original order
  example_df$row_num <- seq_len(nrow(example_df))
  example_df$value <- example_df$value + match(example_df$part, letters)

  tf_dataset <- tempfile()
  tf_dest <- tempfile()
  # Both paths are passed to write_dataset() below, so remove them
  # recursively. add = TRUE keeps this handler alive when the
  # unregister_binding() handler is installed later.
  on.exit(unlink(c(tf_dataset, tf_dest), recursive = TRUE), add = TRUE)
  write_dataset(example_df, tf_dataset, partitioning = "part")

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32"), add = TRUE)

  expected <- example_df$value * 32

  # check a regular collect()
  result <- open_dataset(tf_dataset) |>
    dplyr::mutate(fun_result = times_32(value)) |>
    dplyr::collect() |>
    dplyr::arrange(row_num)

  expect_identical(result$fun_result, expected)

  # check a write_dataset()
  open_dataset(tf_dataset) |>
    dplyr::mutate(fun_result = times_32(value)) |>
    write_dataset(tf_dest)

  # read the written dataset back; sorting happens after collect(), so no
  # second collect() is needed
  result2 <- open_dataset(tf_dest) |>
    dplyr::collect() |>
    dplyr::arrange(row_num)

  expect_identical(result2$fun_result, expected)
})
| |
test_that("nested exec plans can contain user-defined functions", {
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32"))

  # Runs the UDF query through a record batch reader and materializes the
  # result as a Table.
  udf_via_reader <- function() {
    query <- dplyr::mutate(record_batch(a = 1:1000), b = times_32(a))
    as_arrow_table(as_record_batch_reader(query))
  }

  # Runs the UDF query, keeps the first 11 rows and collects them.
  udf_with_head <- function() {
    query <- dplyr::mutate(record_batch(a = 1:1000), fun_result = times_32(a))
    dplyr::collect(head(query, 11))
  }

  # The streamed result must match collecting the same query directly.
  expect_equal(
    udf_via_reader(),
    dplyr::collect(
      dplyr::mutate(record_batch(a = 1:1000), b = times_32(a)),
      as_data_frame = FALSE
    )
  )

  head_result <- udf_with_head()
  expect_equal(nrow(head_result), 11)
})
| |
test_that("head() on exec plan containing user-defined functions", {
  # Unconditionally skipped; see ARROW-18101.
  skip("ARROW-18101")
  skip_if_not_available("dataset")
  skip_if_not(CanRunWithCapturedR())

  register_scalar_function(
    "times_32",
    function(context, x) x * 32.0,
    int32(),
    float64(),
    auto_convert = TRUE
  )
  on.exit(unregister_binding("times_32"))

  # Take the first 11 rows from a reader built on top of the UDF query.
  query <- dplyr::mutate(record_batch(a = 1:1000), b = times_32(a))
  reader <- as_record_batch_reader(query)
  first_rows <- dplyr::collect(head(reader, 11))

  expect_equal(nrow(first_rows), 11)
})