# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Register the conditional-function bindings (`%in%`, `coalesce()`,
# `ifelse()`, `if_else()`, `case_when()`) via `register_binding()`. Each
# binding translates its R call into an Arrow compute `Expression`.
register_bindings_conditional <- function() {
  # `x %in% table` becomes the Arrow "is_in" function with `table` as the
  # value set; nulls in the value set are skipped.
  register_binding("%in%", function(x, table) {
    # We use `is_in` here, unlike with Arrays, which use `is_in_meta_binary`
    value_set <- Array$create(table)
    # If possible, `table` should be the same type as `x`.
    # Try downcasting here; otherwise Acero may upcast x to table's type.
    # This is best-effort: if the cast fails, keep the uncast value set.
    value_set <- tryCatch(
      cast_or_parse(value_set, x$type()),
      error = function(e) value_set
    )

    Expression$create("is_in", x,
      options = list(
        value_set = value_set,
        skip_nulls = TRUE
      )
    )
  })

  # `coalesce(...)`: first non-null value across the arguments.
  register_binding("dplyr::coalesce", function(...) {
    args <- list2(...)
    n_args <- length(args)
    if (n_args < 1) {
      abort("At least one argument must be supplied to coalesce()")
    }

    # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all*
    # the values are NaN, we should return NaN, not NA, so don't replace
    # NaN with NA in the final (or only) argument.
    # The final argument is identified by position rather than by tagging it
    # with an attribute: Expressions are R6 objects (environments), so an
    # attribute set on one is shared by every reference to the same object,
    # e.g. both `x`s in `coalesce(x, y, x)`.
    # TODO: if an option is added to the coalesce kernel to treat NaN as NA,
    # use that to simplify the code here (ARROW-13389)
    for (i in seq_len(n_args)) {
      arg <- args[[i]]
      if (!inherits(arg, "Expression")) {
        arg <- Expression$scalar(arg)
      }

      if (i < n_args && arg$type_id() %in% TYPES_WITH_NAN) {
        # store the NA_real_ in the same type as arg to avoid casting
        # smaller float types to larger float types
        NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type()))
        arg <- Expression$create(
          "if_else",
          Expression$create("is_nan", arg),
          NA_expr,
          arg
        )
      }
      args[[i]] <- arg
    }
    Expression$create("coalesce", args = args)
  })

  # Although base R ifelse() allows `yes` and `no` to be different classes,
  # they should ideally share a type for Arrow's "if_else", so
  # cast_scalars_to_common_type() is applied to them first.
  register_binding("base::ifelse", function(test, yes, no) {
    args <- list(test, yes, no)
    # For if_else, the first arg should be a bool Expression, and we don't
    # want to consider that when casting the other args to the same type.
    # But ideally `yes` and `no` args should be the same type.
    args[-1] <- cast_scalars_to_common_type(args[-1])

    Expression$create("if_else", args = args)
  })

  # dplyr::if_else() is base::ifelse() plus an optional `missing` value that
  # replaces the result wherever `condition` is NA.
  register_binding("dplyr::if_else", function(condition, true, false, missing = NULL) {
    out <- call_binding("base::ifelse", condition, true, false)
    if (!is.null(missing)) {
      out <- call_binding(
        "base::ifelse",
        call_binding("is.na", condition),
        missing,
        out
      )
    }
    out
  })

  # case_when(): each `lhs ~ rhs` formula contributes one condition (packed
  # into a "make_struct") and one value; `.default` is appended as an
  # always-TRUE final case.
  register_binding("dplyr::case_when", function(..., .default = NULL, .ptype = NULL, .size = NULL) {
    if (!is.null(.ptype)) {
      arrow_not_supported("`case_when()` with `.ptype` specified")
    }

    if (!is.null(.size)) {
      arrow_not_supported("`case_when()` with `.size` specified")
    }

    formulas <- list2(...)
    n <- length(formulas)
    if (n == 0) {
      abort("No cases provided in case_when()")
    }
    query <- vector("list", n)
    value <- vector("list", n)
    mask <- caller_env()
    for (i in seq_len(n)) {
      f <- formulas[[i]]
      # A one-sided formula (`~ x`) has no f[[3]]; reject it here with a
      # clear message instead of failing on an out-of-bounds subscript.
      if (!inherits(f, "formula") || length(f) != 3L) {
        abort("Each argument to case_when() must be a two-sided formula")
      }
      query[[i]] <- arrow_eval(f[[2]], mask)
      value[[i]] <- arrow_eval(f[[3]], mask)
      # arrow_eval() may return a "try-error"; report it for the left side
      # too, before the logical-type check would mask the real cause.
      if (inherits(query[[i]], "try-error")) {
        abort(handle_arrow_not_supported(query[[i]], format_expr(f[[2]])))
      }
      if (!call_binding("is.logical", query[[i]])) {
        abort("Left side of each formula in case_when() must be a logical expression")
      }
      if (inherits(value[[i]], "try-error")) {
        abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
      }
    }
    if (!is.null(.default)) {
      # length() of an Expression (an R6 environment) counts its bindings,
      # not its data size, so only length-check plain R values.
      if (!inherits(.default, "Expression") && length(.default) != 1) {
        abort(paste0("`.default` must have size 1, not size ", length(.default), "."))
      }

      # Use `[[<-` to append one element: `[<-` would try to coerce the
      # value to a list, which fails for environment-based objects.
      query[[n + 1]] <- TRUE
      value[[n + 1]] <- .default
    }
    Expression$create(
      "case_when",
      args = c(
        Expression$create(
          "make_struct",
          args = query,
          options = list(field_names = as.character(seq_along(query)))
        ),
        value
      )
    )
  }, notes = "`.ptype` and `.size` arguments not supported")
}