# blob: fd69c852e684ca80624efede532fa6c78e7e92b7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' @include expression.R
#' @include record-batch.R
#' @include table.R
arrow_dplyr_query <- function(.data) {
  # Wrap an Arrow data object (Table, RecordBatch, or Dataset) in a container
  # that also carries the state of the user's dplyr query: the selected
  # (possibly renamed) columns, the row filter, and the grouping variables.
  # For most dplyr methods,
  # method.Table == method.RecordBatch == method.Dataset == method.arrow_dplyr_query
  # which works because every method passes .data through this function first.
  #
  # Input that is already an arrow_dplyr_query is returned untouched.
  if (!inherits(.data, "arrow_dplyr_query")) {
    query_state <- list(
      # The query holds a clone of the source object
      .data = .data$clone(),
      # Named character vector:
      # * the values are the column names as they exist in the data
      # * the names are what the columns should be called in the end
      #   (i.e. this is where renaming is recorded)
      selected_columns = set_names(names(.data)),
      # TRUE means "no filter yet"; once filtered this will be an Expression
      filtered_rows = TRUE,
      # Character vector of grouping columns (as renamed). They will be kept
      # when data is pulled into R.
      group_by_vars = character()
    )
    .data <- structure(query_state, class = "arrow_dplyr_query")
  }
  .data
}
#' @export
print.arrow_dplyr_query <- function(x, ...) {
  # Print a summary of the query: the class of the source object, the selected
  # fields (under their renamed names) with their types, and any filter and
  # grouping that have been recorded. Returns x invisibly.
  selected <- x$selected_columns
  source_schema <- x$.data$schema
  field_strings <- map_chr(selected, ~source_schema$GetFieldByName(.)$ToString())
  # Drop the field names as they are in the source data ...
  field_types <- sub("^.*?: ", "", field_strings)
  # ... and put the (possibly renamed) query names in front instead
  listing <- paste(names(selected), field_types, sep = ": ", collapse = "\n")

  cat(class(x$.data)[1], " (query)\n", sep = "")
  cat(listing, "\n", sep = "")
  cat("\n")

  # filtered_rows is TRUE until a filter has been set
  has_filter <- !isTRUE(x$filtered_rows)
  if (has_filter) {
    filter_string <- if (query_on_dataset(x)) {
      x$filtered_rows$ToString()
    } else {
      .format_array_expression(x$filtered_rows)
    }
    cat("* Filter: ", filter_string, "\n", sep = "")
  }

  if (length(x$group_by_vars) > 0) {
    grouping <- paste(x$group_by_vars, collapse = ", ")
    cat("* Grouped by ", grouping, "\n", sep = "")
  }

  cat("See $.data for the source Arrow object\n")
  invisible(x)
}
# These are the names reflecting all select/rename, not what is in Arrow
#' @export
names.arrow_dplyr_query <- function(x) {
  # selected_columns maps final name -> source column, so its names are the
  # column names the query result will have
  selected <- x$selected_columns
  names(selected)
}
#' @export
dim.arrow_dplyr_query <- function(x) {
  # Dimensions of the query result, as c(rows, cols).
  #
  # x: an arrow_dplyr_query
  # Returns a length-2 vector. The row count is NA (with a warning) when a
  # filter has been set on a Dataset-backed query, because it is not known
  # without scanning the data.
  cols <- length(names(x))
  # Use the full element name: `x$filtered` only resolved to `filtered_rows`
  # through `$` partial matching, which silently yields NULL as soon as another
  # element shares the prefix and warns under options(warnPartialMatchDollar).
  if (isTRUE(x$filtered_rows)) {
    # No filter has been set, so all rows of the source object are kept
    rows <- x$.data$num_rows
  } else if (query_on_dataset(x)) {
    warning("Number of rows unknown; returning NA", call. = FALSE)
    # TODO: https://issues.apache.org/jira/browse/ARROW-9697
    rows <- NA_integer_
  } else {
    # Evaluate the filter expression to a BooleanArray and count
    rows <- as.integer(sum(eval_array_expression(x$filtered_rows), na.rm = TRUE))
  }
  c(rows, cols)
}
#' @export
as.data.frame.arrow_dplyr_query <- function(x,
                                            row.names = NULL,
                                            optional = FALSE,
                                            ...) {
  # Converting a query to a data.frame is the same as collecting it into R.
  # row.names and optional are part of the as.data.frame() signature but are
  # not used here; only ... is forwarded.
  collect.arrow_dplyr_query(x, as_data_frame = TRUE, ...)
}
#' @export
head.arrow_dplyr_query <- function(x, n = 6L, ...) {
  # First n rows of the query result.
  # Dataset-backed queries are delegated to the Dataset method.
  if (query_on_dataset(x)) {
    return(head.Dataset(x, n, ...))
  }
  # Otherwise evaluate the query, keeping the result as an Arrow object
  result <- collect.arrow_dplyr_query(x, as_data_frame = FALSE)
  if (inherits(result, "arrow_dplyr_query")) {
    # collect() handed back a query wrapper: truncate the data it wraps
    result$.data <- head(result$.data, n)
  } else {
    result <- head(result, n)
  }
  result
}
#' @export
tail.arrow_dplyr_query <- function(x, n = 6L, ...) {
  # Last n rows of the query result.
  # Dataset-backed queries are delegated to the Dataset method.
  if (query_on_dataset(x)) {
    return(tail.Dataset(x, n, ...))
  }
  # Otherwise evaluate the query, keeping the result as an Arrow object
  result <- collect.arrow_dplyr_query(x, as_data_frame = FALSE)
  if (inherits(result, "arrow_dplyr_query")) {
    # collect() handed back a query wrapper: truncate the data it wraps
    result$.data <- tail(result$.data, n)
  } else {
    result <- tail(result, n)
  }
  result
}
#' @export
`[.arrow_dplyr_query` <- function(x, i, j, ..., drop = FALSE) {
  # Subsetting is only available for Dataset-backed queries, where it is
  # delegated to the Dataset method; any other query raises an error telling
  # the user to collect first.
  if (!query_on_dataset(x)) {
    stop(
      "[ method not implemented for queries. Call 'collect(x, as_data_frame = FALSE)' first",
      call. = FALSE
    )
  }
  `[.Dataset`(x, i, j, ..., drop = FALSE)
}
# The following S3 methods are registered on load if dplyr is present
tbl_vars.arrow_dplyr_query <- function(x) {
  # The variables of the query are the selected columns under their
  # (possibly renamed) final names
  selected <- x$selected_columns
  names(selected)
}
select.arrow_dplyr_query <- function(.data, ...) {
  # Record a column selection on the query; no data is touched.
  # ... holds the selection expressions, forwarded to column_select().
  query <- arrow_dplyr_query(.data)
  column_select(query, !!!enquos(...))
}
# The same implementation serves the bare Arrow objects
select.RecordBatch <- select.arrow_dplyr_query
select.Table <- select.arrow_dplyr_query
select.Dataset <- select.arrow_dplyr_query
#' @importFrom tidyselect vars_rename
rename.arrow_dplyr_query <- function(.data, ...) {
  # Record a column renaming on the query; no data is touched.
  # Same machinery as select(), but resolved with tidyselect::vars_rename.
  query <- arrow_dplyr_query(.data)
  column_select(query, !!!enquos(...), .FUN = vars_rename)
}
# The same implementation serves the bare Arrow objects
rename.RecordBatch <- rename.arrow_dplyr_query
rename.Table <- rename.arrow_dplyr_query
rename.Dataset <- rename.arrow_dplyr_query
column_select <- function(.data, ..., .FUN = vars_select) {
  # Apply a select/rename to the query state.
  #
  # .data: an arrow_dplyr_query
  # ...:   selection expressions
  # .FUN:  either tidyselect::vars_select or tidyselect::vars_rename
  # Returns the query with selected_columns (and, if needed, group_by_vars)
  # updated.
  #
  # .FUN works on names(.data), i.e. the column names after any renaming that
  # has already happened. Its result maps new name -> current name.
  selection <- .FUN(names(.data), !!!enquos(...))

  # Index the existing mapping by current name so the result still points at
  # the original data columns, even after several renaming steps
  .data$selected_columns <- set_names(
    .data$selected_columns[selection],
    names(selection)
  )

  # Renamed columns have to be propagated to the other query parameters
  renamed <- selection[names(selection) != selection]
  if (length(renamed) > 0) {
    # Swap old group names for their new names
    groups <- .data$group_by_vars
    is_renamed <- groups %in% renamed
    groups[is_renamed] <- names(renamed)[match(groups[is_renamed], renamed)]
    .data$group_by_vars <- groups
    # Filters need no adjustment: they hold references to Arrow objects
  }
  .data
}
filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
  # Record row filters on the query.
  #
  # Each filter in ... is evaluated against a data mask to produce an
  # Expression with references to Arrays (if .data is Table/RecordBatch) or
  # Fields (if .data is a Dataset). Filters that cannot be expressed that way
  # are an error for Datasets; for Table/RecordBatch the data is pulled into R
  # (with a warning) and those filters are applied by dplyr.
  # TODO something with the .preserve argument
  filts <- quos(...)
  if (length(filts) == 0) {
    # Nothing to do
    return(.data)
  }

  .data <- arrow_dplyr_query(.data)
  mask <- filter_mask(.data)

  # Error handler for evaluating one filter.
  # Bad input (which would fail in regular dplyr anyway) is re-raised;
  # anything else is assumed to be unsupported by Arrow and is returned as a
  # "try-error" so it can be dealt with below.
  on_eval_error <- function(e) {
    msg <- conditionMessage(e)
    # TODO: internationalization?
    is_user_error <- grepl("object '.*'.not.found", msg) ||
      grepl('could not find function ".*"', msg)
    if (is_user_error) {
      stop(e)
    }
    invisible(structure(msg, class = "try-error", condition = e))
  }

  # This yields an Expression for each filter, as long as the functions it
  # uses are implemented in Arrow
  filters <- lapply(filts, function(f) {
    tryCatch(eval_tidy(f, mask), error = on_eval_error)
  })

  unsupported <- map_lgl(filters, ~inherits(., "try-error"))
  if (!any(unsupported)) {
    return(set_filters(.data, filters))
  }

  bads <- oxford_paste(map_chr(filts, as_label)[unsupported], quote = FALSE)
  if (query_on_dataset(.data)) {
    # Abort. We don't want to auto-collect if this is a Dataset because that
    # could blow up, too big.
    stop(
      "Filter expression not supported for Arrow Datasets: ", bads,
      "\nCall collect() first to pull data into R.",
      call. = FALSE
    )
  }

  # TODO: only show this in some debug mode?
  warning(
    "Filter expression not implemented in Arrow: ", bads, "; pulling data into R",
    immediate. = TRUE,
    call. = FALSE
  )
  # Set any valid filters first, then collect and then apply the invalid ones in R
  .data <- set_filters(.data, filters[!unsupported])
  dplyr::filter(dplyr::collect(.data), !!!filts[unsupported])
}
# The same implementation serves the bare Arrow objects
filter.RecordBatch <- filter.arrow_dplyr_query
filter.Table <- filter.arrow_dplyr_query
filter.Dataset <- filter.arrow_dplyr_query
# Create a data mask for evaluating a filter expression
filter_mask <- function(.data) {
  # Build the data mask in which filter expressions are evaluated.
  # The mask holds one function per entry of .array_function_map and one
  # binding per selected column.
  # TODO: define functions in env once, outside of this function
  # filter_env <- env(parent = if (data_is_dataset) function_env1 else function_env2)
  on_dataset <- query_on_dataset(.data)

  # Factory for the operator functions: Dataset queries build Expressions,
  # Table/RecordBatch queries build array expressions
  comp_func <- function(operator) {
    force(operator)
    if (on_dataset) {
      function(e1, e2) make_expression(operator, e1, e2)
    } else {
      function(e1, e2) build_array_expression(operator, e1, e2)
    }
  }

  # Column references: a field reference for Datasets, the column of the
  # source object otherwise
  var_binder <- if (on_dataset) {
    function(x) Expression$field_ref(x)
  } else {
    function(x) .data$.data[[x]]
  }

  mask_env <- env()
  # First add the functions
  operators <- set_names(names(.array_function_map))
  env_bind(mask_env, !!!lapply(operators, comp_func))
  # Then add the column references
  # Renaming is handled automatically because selected_columns is named by
  # the final column names
  env_bind(mask_env, !!!lapply(.data$selected_columns, var_binder))
  new_data_mask(mask_env)
}
set_filters <- function(.data, expressions) {
  # AND a list of filter Expressions together and record them on the query.
  #
  # .data:       an arrow_dplyr_query
  # expressions: a list of Expressions; may be empty
  # Returns .data with filtered_rows updated.
  if (length(expressions) == 0) {
    # Nothing to add. Without this guard Reduce() returns NULL for an empty
    # list, and assigning NULL would delete filtered_rows (or AND the existing
    # filter with NULL). filter() gets here when none of the requested filters
    # is supported by Arrow.
    return(.data)
  }
  new_filter <- Reduce("&", expressions)
  if (isTRUE(.data$filtered_rows)) {
    # TRUE is default (i.e. no filter yet), so we don't need to & with it
    .data$filtered_rows <- new_filter
  } else {
    .data$filtered_rows <- .data$filtered_rows & new_filter
  }
  .data
}
collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
  # Evaluate the query: pull only the selected rows and columns.
  #
  # x:             an arrow_dplyr_query
  # as_data_frame: convert the result with as.data.frame()? If FALSE the
  #                result stays an Arrow object.
  # The result is passed through restore_dplyr_features() so that renaming
  # and grouping are carried over.
  x <- ensure_group_vars(x)
  result <- if (query_on_dataset(x)) {
    # See dataset.R for Dataset and Scanner(Builder) classes
    Scanner$create(x)$ToTable()
  } else {
    # This is a Table/RecordBatch. See record-batch.R for the [ method
    x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE]
  }
  if (as_data_frame) {
    result <- as.data.frame(result)
  }
  restore_dplyr_features(result, x)
}
# collect() on a Table or RecordBatch reuses their as.data.frame() methods
collect.Table <- as.data.frame.Table
collect.RecordBatch <- as.data.frame.RecordBatch
collect.Dataset <- function(x, ...) {
  # A Dataset is wrapped in a query first and then collected like any query
  query <- arrow_dplyr_query(x)
  dplyr::collect(query, ...)
}
ensure_group_vars <- function(x) {
  # Before pulling data from Arrow, make sure all group vars are in the
  # projection. Anything that is not an arrow_dplyr_query is returned as is.
  if (!inherits(x, "arrow_dplyr_query")) {
    return(x)
  }
  # Group vars that are not among the selected columns get appended
  missing_groups <- setdiff(dplyr::group_vars(x), names(x))
  x$selected_columns <- c(x$selected_columns, set_names(missing_groups))
  x
}
restore_dplyr_features <- function(df, query) {
  # An arrow_dplyr_query holds some attributes that Arrow doesn't know about.
  # After calling collect(), make sure these features (renamed columns and
  # grouping) are carried over to the result.
  #
  # df:    the collected result (a data.frame or an Arrow object)
  # query: the arrow_dplyr_query it was collected from
  final_names <- names(query)
  has_groups <- length(query$group_by_vars) > 0
  names_differ <- !identical(names(df), final_names)

  if (!is.data.frame(df)) {
    if (has_groups || names_differ) {
      # This is a Table, via collect(as_data_frame = FALSE): wrap it in a new
      # query that carries the names and groups
      df <- arrow_dplyr_query(df)
      names(df$selected_columns) <- final_names
      df$group_by_vars <- query$group_by_vars
    }
    return(df)
  }

  # In case variables were renamed, apply those names
  if (names_differ && ncol(df) > 0) {
    names(df) <- final_names
  }
  # Preserve groupings, if present
  if (has_groups) {
    df <- dplyr::grouped_df(df, dplyr::group_vars(query))
  }
  df
}
#' @importFrom tidyselect vars_pull
pull.arrow_dplyr_query <- function(.data, var = -1) {
  # Extract a single column of the query result.
  # var is resolved with tidyselect::vars_pull against the query's names.
  .data <- arrow_dplyr_query(.data)
  pulled <- vars_pull(names(.data), !!enquo(var))
  # Narrow the projection to that one column before collecting
  .data$selected_columns <- set_names(.data$selected_columns[pulled], pulled)
  collected <- dplyr::collect(.data)
  collected[[1]]
}
# The same implementation serves the bare Arrow objects
pull.RecordBatch <- pull.arrow_dplyr_query
pull.Table <- pull.arrow_dplyr_query
pull.Dataset <- pull.arrow_dplyr_query
summarise.arrow_dplyr_query <- function(.data, ...) {
  # Summarise the query result. The aggregation itself is done by dplyr after
  # collecting; only the columns it needs are pulled into R.
  #
  # .data: an arrow_dplyr_query, Table, or RecordBatch (Datasets raise an error)
  # ...:   summary expressions, passed on to dplyr::summarise()
  .data <- arrow_dplyr_query(.data)
  if (query_on_dataset(.data)) {
    not_implemented_for_dataset("summarize()")
  }
  # Only retain the columns we need to do our aggregations
  vars_to_keep <- unique(c(
    unlist(lapply(quos(...), all.vars)), # vars referenced in summarise
    dplyr::group_vars(.data)             # vars needed for grouping
  ))
  # all.vars() also reports names that are not columns: variables from the
  # calling environment, or summaries defined earlier in the same call
  # (total = sum(x), avg = total / n()). Selecting those would fail, so keep
  # only names that exist in the query. Group vars missing from the projection
  # are added back by collect().
  vars_to_keep <- intersect(vars_to_keep, names(.data))
  .data <- dplyr::select(.data, vars_to_keep)
  # TODO: determine whether work can be pushed down to Arrow
  dplyr::summarise(dplyr::collect(.data), ...)
}
summarise.Dataset <- summarise.Table <- summarise.RecordBatch <- summarise.arrow_dplyr_query
group_by.arrow_dplyr_query <- function(.data, ..., .add = FALSE, add = .add) {
  # Record grouping variables on the query; no data is touched.
  # The group names are computed by dplyr::group_by_prepare().
  .data <- arrow_dplyr_query(.data)
  # dplyr >= 1.0 names the argument `.add`; older versions call it `add`
  dplyr_has_dot_add <- ".add" %in% names(formals(dplyr::group_by))
  prepared <- if (dplyr_has_dot_add) {
    dplyr::group_by_prepare(.data, ..., .add = .add)
  } else {
    dplyr::group_by_prepare(.data, ..., add = add)
  }
  .data$group_by_vars <- prepared$group_names
  .data
}
# The same implementation serves the bare Arrow objects
group_by.RecordBatch <- group_by.arrow_dplyr_query
group_by.Table <- group_by.arrow_dplyr_query
group_by.Dataset <- group_by.arrow_dplyr_query
groups.arrow_dplyr_query <- function(x) {
  # The grouping variables of the query, as a list of symbols
  group_names <- dplyr::group_vars(x)
  syms(group_names)
}
# Bare Arrow objects carry no grouping
groups.RecordBatch <- function(x) NULL
groups.Table <- groups.RecordBatch
groups.Dataset <- groups.RecordBatch
group_vars.arrow_dplyr_query <- function(x) {
  # The names of the grouping variables recorded on the query
  x$group_by_vars
}
# Bare Arrow objects carry no grouping
group_vars.RecordBatch <- function(x) NULL
group_vars.Table <- group_vars.RecordBatch
group_vars.Dataset <- group_vars.RecordBatch
ungroup.arrow_dplyr_query <- function(x, ...) {
  # Drop all grouping from the query; everything else is left as it was
  x[["group_by_vars"]] <- character(0)
  x
}
# Bare Arrow objects are never grouped, so ungrouping hands them straight back
ungroup.RecordBatch <- force
ungroup.Table <- force
ungroup.Dataset <- force
mutate.arrow_dplyr_query <- function(.data, ...) {
  # mutate() is not pushed down to Arrow: the query is collected into R and
  # dplyr does the work. Datasets raise an error instead of being collected.
  .data <- arrow_dplyr_query(.data)
  if (query_on_dataset(.data)) {
    not_implemented_for_dataset("mutate()")
  }
  # TODO: see if we can defer evaluating the expressions and not collect here.
  # It's different from filters (as currently implemented) because the basic
  # vector transformation functions aren't yet implemented in Arrow C++.
  in_memory <- dplyr::collect(.data)
  dplyr::mutate(in_memory, ...)
}
# The same implementation serves the bare Arrow objects
mutate.RecordBatch <- mutate.arrow_dplyr_query
mutate.Table <- mutate.arrow_dplyr_query
mutate.Dataset <- mutate.arrow_dplyr_query
# TODO: add transmute() that does what summarise() does (select only the vars we need)
arrange.arrow_dplyr_query <- function(.data, ...) {
  # arrange() is not pushed down to Arrow: the query is collected into R and
  # dplyr does the sorting. Datasets raise an error instead of being collected.
  .data <- arrow_dplyr_query(.data)
  if (query_on_dataset(.data)) {
    not_implemented_for_dataset("arrange()")
  }
  in_memory <- dplyr::collect(.data)
  dplyr::arrange(in_memory, ...)
}
# The same implementation serves the bare Arrow objects
arrange.RecordBatch <- arrange.arrow_dplyr_query
arrange.Table <- arrange.arrow_dplyr_query
arrange.Dataset <- arrange.arrow_dplyr_query
query_on_dataset <- function(x) {
  # TRUE when the object wrapped by the query is a Dataset
  wrapped <- x$.data
  inherits(wrapped, "Dataset")
}
not_implemented_for_dataset <- function(method) {
  # Raise the standard error for dplyr verbs that cannot run on a Dataset.
  # method: the verb as it should appear in the message, e.g. "mutate()"
  reason <- " is not currently implemented for Arrow Datasets. "
  remedy <- "Call collect() first to pull data into R."
  stop(method, reason, remedy, call. = FALSE)
}