blob: 06a29fa6b7445cffd93f7cf8441896c72df3a302 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' Assert that `x`, once coerced with `as.vector()`, equals `y`.
#'
#' Any extra arguments in `...` are forwarded to `expect_equal()`.
expect_as_vector <- function(x, y, ...) {
  expect_equal(object = as.vector(x), expected = y, ...)
}
#' Assert that `x`, once coerced with `as.data.frame()`, equals `y`.
#'
#' Any extra arguments in `...` are forwarded to `expect_equal()`.
expect_data_frame <- function(x, y, ...) {
  expect_equal(object = as.data.frame(x), expected = y, ...)
}
#' Assert that `object` inherits from `class` and is also an R6 object.
#'
#' The specific class is checked first, then the generic "R6" class.
expect_r6_class <- function(object, class) {
  expect_s3_class(object = object, class = class)
  expect_s3_class(object = object, class = "R6")
}
#' Mask `testthat::expect_equal()` in order to compare ArrowObjects using their
#' `Equals` methods from the C++ library.
#'
#' When both `object` and `expected` inherit from "ArrowObject", the comparison
#' is delegated to `all.equal()` with `check.attributes = !ignore_attr` and the
#' result is checked with `expect_true()`; `...` is not used on that path.
#' Otherwise every argument is forwarded to `testthat::expect_equal()`.
#'
#' @param label Optional label for failure messages. If `NULL` on the
#'   ArrowObject path, a label of the form "<object expr> == <expected expr>"
#'   is built from the unevaluated arguments.
expect_equal <- function(object, expected, ignore_attr = FALSE, ..., info = NULL, label = NULL) {
  if (inherits(object, "ArrowObject") && inherits(expected, "ArrowObject")) {
    # Respect a caller-supplied label; only synthesize one when none is given.
    if (is.null(label)) {
      mc <- match.call()
      label <- paste(rlang::as_label(mc[["object"]]), "==", rlang::as_label(mc[["expected"]]))
    }
    expect_true(
      all.equal(object, expected, check.attributes = !ignore_attr),
      info = info,
      label = label
    )
  } else {
    testthat::expect_equal(object, expected, ignore_attr = ignore_attr, ..., info = info, label = label)
  }
}
#' Assert that two objects have equal Arrow types.
#'
#' If either argument is an Array (per `is.Array()`), it is replaced by its
#' `$type` before comparing; any other value is compared as is. Extra
#' arguments in `...` are forwarded to `expect_equal()`.
expect_type_equal <- function(object, expected, ...) {
  object <- if (is.Array(object)) object$type else object
  expected <- if (is.Array(expected)) expected$type else expected
  expect_equal(object, expected, ...)
}
#' Assert that `object` raises an "'arg' ..." argument-matching error.
#'
#' @param object Expression expected to error.
#' @param values Character vector of allowed choices. They must appear in the
#'   error message double-quoted (via `dQuote()`) and separated by ", ".
expect_match_arg_error <- function(object, values = c()) {
  quoted_choices <- paste(dQuote(values), collapse = ", ")
  expected_pattern <- paste0("'arg' .*", quoted_choices)
  expect_error(object, expected_pattern)
}
# Alias: deprecation warnings are asserted exactly like any other warning, so
# `expect_deprecated()` is bound to whatever `expect_warning` is visible when
# this file is evaluated.
expect_deprecated <- expect_warning
#' Mask `testthat::verify_output()` so that it is skipped on conda builds.
#'
#' The R platform string (`R.Version()$platform`) is searched for "conda"; on a
#' match the test is skipped, otherwise all arguments go to
#' `testthat::verify_output()` unchanged.
verify_output <- function(...) {
  running_on_conda <- isTRUE(grepl("conda", R.Version()$platform))
  if (running_on_conda) {
    skip("On conda")
  }
  testthat::verify_output(...)
}
#' Ensure that dplyr methods on Arrow objects return the same as for data frames
#'
#' `expr` is evaluated with `.input` bound, in turn, to `tbl` itself, to
#' `record_batch(tbl)` and to `arrow_table(tbl)`. The data-frame result is the
#' reference; each Arrow evaluation is wrapped in `expect_warning()` and its
#' result is compared to the reference with `expect_equal()`.
#'
#' @param expr A dplyr pipeline which must have `.input` as its start
#' @param tbl A tibble or data.frame which will be substituted for `.input`
#' @param skip_record_batch If not `NULL`, the RecordBatch comparison is not
#'   run and this message is included in the final `skip()`
#' @param skip_table If not `NULL`, the Table comparison is not run and this
#'   message is included in the final `skip()`
#' @param warning The expected warning from the RecordBatch and Table
#'   evaluations, passed to `expect_warning()`. Special values:
#'   * `NA` (the default) asserts that no warning is raised
#'   * `TRUE` is shorthand for the
#'     "not supported in Arrow; pulling data into R" message
#' @param ... additional arguments, passed to `expect_equal()`
compare_dplyr_binding <- function(expr,
                                  tbl,
                                  skip_record_batch = NULL,
                                  skip_table = NULL,
                                  warning = NA,
                                  ...) {
  # Capture the pipeline unevaluated so it can be run against several inputs
  expr <- rlang::enquo(expr)

  # Reference result: plain dplyr on the data.frame
  expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = tbl)))

  # `TRUE` is shorthand for the standard "pulling data into R" warning
  if (isTRUE(warning)) {
    warning <- "not supported in Arrow; pulling data into R"
  }

  skip_msg <- NULL

  # RecordBatch comparison, unless the caller asked to skip it
  if (!is.null(skip_record_batch)) {
    skip_msg <- c(skip_msg, skip_record_batch)
  } else {
    expect_warning(
      via_batch <- rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = record_batch(tbl)))
      ),
      warning
    )
    expect_equal(via_batch, expected, ...)
  }

  # Table comparison, unless the caller asked to skip it
  if (!is.null(skip_table)) {
    skip_msg <- c(skip_msg, skip_table)
  } else {
    expect_warning(
      via_table <- rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = arrow_table(tbl)))
      ),
      warning
    )
    expect_equal(via_table, expected, ...)
  }

  # Report every skipped comparison together, after the others have run
  if (!is.null(skip_msg)) {
    skip(paste(skip_msg, collapse = "\n"))
  }
}
#' Assert that Arrow dplyr methods error in the same way as methods on data.frame
#'
#' `expr` is first evaluated with `.input` bound to `tbl`. That evaluation must
#' error; its (possibly trimmed) message is then used as the `regexp` that the
#' same expression must error with when `.input` is a RecordBatch and a Table.
#'
#' @param expr A dplyr pipeline which must have `.input` as its start
#' @param tbl A tibble or data.frame which will be substituted for `.input`
#' @param ... additional arguments, passed to `expect_error()`
compare_dplyr_error <- function(expr,
                                tbl,
                                ...) {
  # ensure we have supplied tbl
  force(tbl)
  expr <- rlang::enquo(expr)
  msg <- tryCatch(
    rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = tbl))),
    error = function(e) {
      msg <- conditionMessage(e)
      # Errors mentioning "Problem while computing" are assumed to carry the
      # informative message on their parent condition. Only unwrap when a
      # parent is actually present; conditionMessage(NULL) would itself error.
      if (grepl("Problem while computing", msg[1]) && !is.null(e$parent)) {
        msg <- conditionMessage(e$parent)
      }
      # The error here is of the form:
      #
      # Problem with `filter()` input `..1`.
      # x object 'b_var' not found
      # ℹ Input `..1` is `chr == b_var`.
      #
      # but what we really care about is the `x` block
      # so (temporarily) let's pull those blocks out when we find them
      pattern <- i18ize_error_messages()
      if (grepl(pattern, msg)) {
        msg <- sub(paste0("^.*(", pattern, ").*$"), "\\1", msg)
      }
      msg
    }
  )
  # make sure msg is a character object (i.e. there has been an error)
  # If it did not error, we would get a data.frame or whatever
  # This expectation will tell us "dplyr on data.frame errored is not TRUE"
  expect_true(identical(typeof(msg), "character"), label = "dplyr on data.frame errored")
  if (!identical(typeof(msg), "character")) {
    # There is no reference message to match against; stop here instead of
    # passing a non-string `regexp` to expect_error() below.
    return(invisible(NULL))
  }
  expect_error(
    rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = record_batch(tbl)))
    ),
    msg,
    ...
  )
  expect_error(
    rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = arrow_table(tbl)))
    ),
    msg,
    ...
  )
}
#' Check that an expression gives the same result on Arrow arrays as on R vectors
#'
#' Compares the output of running `expr` on an R vector against the output of
#' the same expression run on an Arrow Array and on a two-chunk ChunkedArray
#' built from that vector. Each Arrow result is checked with
#' `expect_as_vector()`.
#'
#' @param expr A vectorized R expression which must have `.input` as its start
#' @param vec A vector which will be substituted for `.input`
#' @param skip_array The skip message to show (if you should skip the Array test)
#' @param skip_chunked_array The skip message to show (if you should skip the ChunkedArray test)
#' @param ignore_attr Ignore differences in attributes? Passed by name as
#'   `ignore_attr` to `expect_equal()` via `expect_as_vector()`.
#' @param ... additional arguments, passed to `expect_as_vector()`
compare_expression <- function(expr,
                               vec,
                               skip_array = NULL,
                               skip_chunked_array = NULL,
                               ignore_attr = FALSE,
                               ...) {
  expr <- rlang::enquo(expr)
  expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = vec)))
  skip_msg <- NULL
  if (is.null(skip_array)) {
    via_array <- rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = Array$create(vec)))
    )
    # Pass ignore_attr by name: expect_as_vector() only has `...` after its
    # first two arguments, so a positional value would rely on argument order.
    expect_as_vector(via_array, expected, ignore_attr = ignore_attr, ...)
  } else {
    skip_msg <- c(skip_msg, skip_array)
  }
  if (is.null(skip_chunked_array)) {
    # split input vector into two to exercise ChunkedArray with >1 chunk
    split_vector <- split_vector_as_list(vec)
    via_chunked <- rlang::eval_tidy(
      expr,
      rlang::new_data_mask(rlang::env(.input = ChunkedArray$create(split_vector[[1]], split_vector[[2]])))
    )
    expect_as_vector(via_chunked, expected, ignore_attr = ignore_attr, ...)
  } else {
    skip_msg <- c(skip_msg, skip_chunked_array)
  }
  if (!is.null(skip_msg)) {
    skip(paste(skip_msg, collapse = "\n"))
  }
}
#' Assert that Arrow Arrays and ChunkedArrays error in the same way as R vectors
#'
#' Compares the error message generated when running `expr` on an R vector
#' against the error generated by running the same expression on an Arrow
#' Array and on a two-chunk ChunkedArray built from that vector.
#'
#' @param expr An R expression which must have `.input` as its start
#' @param vec A vector which will be substituted for `.input`
#' @param skip_array The skip message to show (if you should skip the Array test)
#' @param skip_chunked_array The skip message to show (if you should skip the ChunkedArray test)
#' @param ... additional arguments, passed to `expect_error()`
compare_expression_error <- function(expr,
                                     vec,
                                     skip_array = NULL,
                                     skip_chunked_array = NULL,
                                     ...) {
  expr <- rlang::enquo(expr)
  msg <- tryCatch(
    rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(.input = vec))),
    error = function(e) {
      msg <- conditionMessage(e)
      # When the message matches the pattern from i18ize_error_messages(),
      # keep only the matched part so it can serve as expect_error()'s regexp.
      pattern <- i18ize_error_messages()
      if (grepl(pattern, msg)) {
        msg <- sub(paste0("^.*(", pattern, ").*$"), "\\1", msg)
      }
      msg
    }
  )
  # msg is only a character string if evaluating on `vec` raised an error
  expect_true(identical(typeof(msg), "character"), label = "vector errored")
  if (!identical(typeof(msg), "character")) {
    # There is no reference message to match against; stop here instead of
    # passing a non-string `regexp` to expect_error() below.
    return(invisible(NULL))
  }
  skip_msg <- NULL
  if (is.null(skip_array)) {
    expect_error(
      rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = Array$create(vec)))
      ),
      msg,
      ...
    )
  } else {
    skip_msg <- c(skip_msg, skip_array)
  }
  if (is.null(skip_chunked_array)) {
    # split input vector into two to exercise ChunkedArray with >1 chunk
    split_vector <- split_vector_as_list(vec)
    expect_error(
      rlang::eval_tidy(
        expr,
        rlang::new_data_mask(rlang::env(.input = ChunkedArray$create(split_vector[[1]], split_vector[[2]])))
      ),
      msg,
      ...
    )
  } else {
    skip_msg <- c(skip_msg, skip_chunked_array)
  }
  if (!is.null(skip_msg)) {
    skip(paste(skip_msg, collapse = "\n"))
  }
}
#' Split a vector into two consecutive halves.
#'
#' Returns a list of two elements: the first `length(vec) %/% 2` elements of
#' `vec`, then the remaining ones (so for odd lengths the second half is one
#' longer). Either half may be empty, e.g. for inputs of length 0 or 1.
split_vector_as_list <- function(vec) {
  n <- length(vec)
  in_first_half <- seq_len(n) <= n %/% 2
  list(vec[in_first_half], vec[!in_first_half])
}