# blob: 4734d44c7ea337b4085e0c8eaeb58d3aaf01502e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' @include arrow-package.R
# Base class for Array, ChunkedArray, and Scalar, for S3 method dispatch only.
# Does not exist in C++ class hierarchy
ArrowDatum <- R6Class(
  classname = "ArrowDatum",
  inherit = ArrowObject,
  public = list(
    # Convert this object to `target_type` by invoking the "cast" compute
    # function. `safe` and `...` are handed to cast_options(); the target
    # type is normalized with as_type() and stored as the `to_type` option.
    cast = function(target_type, safe = TRUE, ...) {
      cast_opts <- cast_options(safe, ...)
      cast_opts$to_type <- as_type(target_type)
      call_function("cast", self, options = cast_opts)
    }
  )
)
#' @export
length.ArrowDatum <- function(x) {
  # Delegate to the object's own length() method
  x$length()
}
#' @export
is.finite.ArrowDatum <- function(x) {
  # Element-wise finiteness test via the "is_finite" compute function
  finite_mask <- call_function("is_finite", x)
  # for compatibility with base::is.finite(), return FALSE for NA_real_:
  # positions where the mask itself is NA are combined with FALSE
  mask_present <- !is.na(finite_mask)
  finite_mask & mask_present
}
#' @export
is.infinite.ArrowDatum <- function(x) {
  # Element-wise infinity test via the "is_inf" compute function
  inf_mask <- call_function("is_inf", x)
  # for compatibility with base::is.infinite(), return FALSE for NA_real_:
  # positions where the mask itself is NA are combined with FALSE
  mask_present <- !is.na(inf_mask)
  inf_mask & mask_present
}
#' @export
is.na.ArrowDatum <- function(x) {
  # Types outside TYPES_WITH_NAN only need the null check
  type_has_nan <- x$type_id() %in% TYPES_WITH_NAN
  if (!type_has_nan) {
    return(call_function("is_null", x))
  }
  # TODO: if an option is added to the is_null kernel to treat NaN as NA,
  # use that to simplify the code here (ARROW-13367)
  # Until then, an element counts as NA when it is either NaN or null
  nan_mask <- call_function("is_nan", x)
  null_mask <- call_function("is_null", x)
  nan_mask | null_mask
}
#' @export
is.nan.ArrowDatum <- function(x) {
  type_has_nan <- x$type_id() %in% TYPES_WITH_NAN
  if (!type_has_nan) {
    # Types outside TYPES_WITH_NAN: answer FALSE for every element by
    # expanding a FALSE scalar to length(x)
    return(Scalar$create(FALSE)$as_array(length(x)))
  }
  # TODO: if an option is added to the is_nan kernel to treat NA as NaN,
  # use that to simplify the code here (ARROW-13366)
  # Until then, combine the NaN mask with the validity mask
  nan_mask <- call_function("is_nan", x)
  valid_mask <- call_function("is_valid", x)
  nan_mask & valid_mask
}
#' @export
as.vector.ArrowDatum <- function(x, mode) {
  # Convert with the object's as_vector() method; `mode` is part of the
  # signature but is not consulted. Any error raised during conversion is
  # routed to handle_embedded_nul_error.
  tryCatch(x$as_vector(), error = handle_embedded_nul_error)
}
#' @export
Ops.ArrowDatum <- function(e1, e2) {
  # S3 group generic for operators applied to ArrowDatum objects.
  #
  # e1, e2: operands; `e2` is missing when the operator is used in its unary
  #   form (e.g. `!x`, `-x`).
  # Returns the result of eval_array_expression() for the operator named by
  # `.Generic`; operators absent from `.array_function_map` raise an error.
  if (.Generic == "!") {
    eval_array_expression(.Generic, e1)
  } else if (.Generic %in% names(.array_function_map)) {
    if (missing(e2)) {
      # Unary use: pass the single operand only. Forwarding the missing `e2`
      # would fail as soon as eval_array_expression() collects its arguments
      # with list(...), so its one-argument "-" branch could never be reached.
      eval_array_expression(.Generic, e1)
    } else {
      eval_array_expression(.Generic, e1, e2)
    }
  } else {
    stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call. = FALSE)
  }
}
# Wrapper around call_function that:
# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked")
# (2) wraps R input args as Array or Scalar
eval_array_expression <- function(FUN,
                                  ...,
                                  args = list(...),
                                  options = empty_named_list()) {
  # FUN: R operator/function name, or a compute function name
  # ..., args: the operands (either spelled out or supplied as a list)
  # options: options list handed to call_function()

  # Unary minus is resolved before any wrapping: Arrow objects are negated
  # through "negate_checked", anything else with R's own unary minus
  is_unary_minus <- FUN == "-" && length(args) == 1L
  if (is_unary_minus) {
    operand <- args[[1]]
    if (!inherits(operand, "ArrowObject")) {
      return(-operand)
    }
    return(eval_array_expression("negate_checked", operand))
  }

  # Plain R inputs become Arrow objects (see .wrap_arrow)
  args <- lapply(args, .wrap_arrow, FUN)

  if (FUN == "%/%") {
    # In R, integer division works like floor(float division)
    # NOTE(review): the result is produced by a cast with
    # allow_float_truncate = TRUE; presumably that truncates rather than
    # floors, which would differ for negative quotients -- confirm
    float_quotient <- eval_array_expression("/", args = args, options = options)
    return(float_quotient$cast(int32(), allow_float_truncate = TRUE))
  }

  if (FUN == "%%") {
    # {e1 - e2 * ( e1 %/% e2 )}
    # ^^^ form doesn't work because Ops.Array evaluates eagerly,
    # but we can build that up
    int_quotient <- eval_array_expression("%/%", args = args)
    multiple <- eval_array_expression("*", int_quotient, args[[2]])
    # this cast is to ensure that the result of this and e1 are the same
    # (autocasting only applies to scalars)
    multiple <- multiple$cast(args[[1]]$type)
    return(eval_array_expression("-", args[[1]], multiple))
  }

  if (FUN == "/") {
    # In Arrow, "divide" is one function, which does integer division on
    # integer inputs and floating-point division on floats
    # TODO: omg so many ways it's wrong to assume these types
    args <- lapply(args, function(operand) operand$cast(float64()))
  }

  # Translate the R name when a mapping exists; otherwise use FUN as given
  compute_name <- .array_function_map[[FUN]] %||% FUN
  call_function(compute_name, args = args, options = options)
}
.wrap_arrow <- function(arg, fun) {
  # Arrow objects pass through untouched
  if (inherits(arg, "ArrowObject")) {
    return(arg)
  }
  # TODO: Array$create if lengths are equal?
  # Anything else is wrapped: as an Array for "%in%", as a Scalar otherwise
  if (fun == "%in%") {
    Array$create(arg)
  } else {
    Scalar$create(arg)
  }
}
#' @export
na.omit.ArrowDatum <- function(object, ...) {
  # Filter the object with the negated is.na() mask; `...` is not used
  not_missing <- !is.na(object)
  object$Filter(not_missing)
}
# na.exclude() reuses the na.omit() method as-is
#' @export
na.exclude.ArrowDatum <- na.omit.ArrowDatum
#' @export
na.fail.ArrowDatum <- function(object, ...) {
  # Return `object` unchanged when its null_count is zero; otherwise raise
  # an error. `...` is not used.
  has_nulls <- object$null_count > 0
  if (!has_nulls) {
    return(object)
  }
  stop("missing values in object", call. = FALSE)
}
filter_rows <- function(x, i, keep_na = TRUE, ...) {
  # General purpose function for [ row subsetting with R semantics
  # Based on the input for `i`, calls x$Filter, x$Slice, or x$Take
  #
  # x: object providing $Filter(), $Slice() and $Take(), plus either
  #   $num_rows or $length()
  # i: row selector: an R logical or numeric vector, or an Array of an
  #   integer or boolean type
  # keep_na: passed to x$Filter() as its second argument
  # ...: not used
  nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
  if (is.logical(i)) {
    if (isTRUE(i)) {
      # Shortcut without doing any work
      x
    } else {
      i <- rep_len(i, nrows) # For R recycling behavior; consider vctrs::vec_recycle()
      x$Filter(i, keep_na)
    }
  } else if (is.numeric(i)) {
    # The length guard matters: all(<empty> < 0) is vacuously TRUE, and
    # without it a zero-length index would be rewritten to "every row"
    # instead of selecting nothing.
    if (length(i) > 0 && all(i < 0)) {
      # in R, negative i means "everything but i"
      i <- setdiff(seq_len(nrows), -1 * i)
    }
    if (is.sliceable(i)) {
      x$Slice(i[1] - 1, length(i))
    } else if (all(i > 0)) {
      # Also reached for a zero-length index, which takes no rows
      x$Take(i - 1)
    } else {
      stop("Cannot mix positive and negative indices", call. = FALSE)
    }
  } else if (is.Array(i, INTEGER_TYPES)) {
    # NOTE: this doesn't do the - 1 offset
    x$Take(i)
  } else if (is.Array(i, "bool")) {
    x$Filter(i, keep_na)
  } else {
    # Unsupported cases
    if (is.Array(i)) {
      stop("Cannot extract rows with an Array of type ", i$type$ToString(), call. = FALSE)
    }
    # class(i) can have several elements; report the first so the message
    # is not a run-together concatenation of class names
    stop("Cannot extract rows with an object of class ", class(i)[1L], call. = FALSE)
  }
}
#' @export
`[.ArrowDatum` <- filter_rows
#' @importFrom utils head
#' @export
head.ArrowDatum <- function(x, n = 6L, ...) {
  # Return the leading elements of `x`; `n` must be a single number and
  # `...` is not used
  assert_is(n, c("numeric", "integer"))
  assert_that(length(n) == 1)
  total <- NROW(x)
  # head(x, negative) means all but the last n rows; a non-negative n is
  # capped at the number of available rows
  n_keep <- if (n < 0) max(total + n, 0) else min(total, n)
  if (n_keep == total) {
    # Everything is requested: hand back the input itself
    return(x)
  }
  x$Slice(0, n_keep)
}
#' @importFrom utils tail
#' @export
tail.ArrowDatum <- function(x, n = 6L, ...) {
  # Return the trailing elements of `x`; `n` must be a single number and
  # `...` is not used
  assert_is(n, c("numeric", "integer"))
  assert_that(length(n) == 1)
  total <- NROW(x)
  # Work out how many leading rows to skip: tail(x, negative) means all but
  # the first n rows; otherwise skip whatever precedes the last n rows
  n_skip <- if (n < 0) min(-n, total) else max(total - n, 0)
  if (n_skip == 0) {
    # Nothing to skip: hand back the input itself
    return(x)
  }
  x$Slice(n_skip)
}
is.sliceable <- function(i) {
  # Determine whether `i` can be expressed as a $Slice() command
  #
  # i: candidate row index
  # Returns TRUE only when `i` is a non-empty numeric vector of positive
  # values without NA that forms an ascending run of consecutive integers;
  # FALSE otherwise.
  is.numeric(i) &&
    length(i) > 0 &&
    !anyNA(i) &&
    all(i > 0) &&
    # `from:to` also counts downwards, so c(3, 2, 1) would match the
    # comparison below; a slice can only describe an ascending run
    i[1] <= i[length(i)] &&
    identical(as.integer(i), i[1]:i[length(i)])
}
#' @export
as.double.ArrowDatum <- function(x, ...) {
  # Convert to an R vector first, then coerce; `...` goes to as.double()
  r_values <- as.vector(x)
  as.double(r_values, ...)
}
#' @export
as.integer.ArrowDatum <- function(x, ...) {
  # Convert to an R vector first, then coerce; `...` goes to as.integer()
  r_values <- as.vector(x)
  as.integer(r_values, ...)
}
#' @export
as.character.ArrowDatum <- function(x, ...) {
  # Convert to an R vector first, then coerce; `...` goes to as.character()
  r_values <- as.vector(x)
  as.character(r_values, ...)
}
#' @export
sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) {
  # Arrow always sorts nulls at the end of the array. This corresponds to
  # sort(na.last = TRUE). For the other two cases (na.last = NA and
  # na.last = FALSE) we need to use workarounds.
  # TODO: Implement this more cleanly after ARROW-12063
  drop_missing <- is.na(na.last)
  if (drop_missing) {
    # Filter out NAs before sorting
    x <- x$Filter(!is.na(x))
  }
  if (drop_missing || na.last) {
    sorted_idx <- x$SortIndices(descending = decreasing)
    return(x$Take(sorted_idx))
  }
  # na.last = FALSE: create a new array that encodes missing values as 1 and
  # non-missing values as 0. Sort descending by that array first to get the
  # NAs at the beginning
  na_flag <- as.integer(is.na(x))
  tbl <- Table$create(x = x, `is_na` = na_flag)
  sorted_idx <- tbl$SortIndices(
    names = c("is_na", "x"),
    descending = c(TRUE, decreasing)
  )
  tbl$x$Take(sorted_idx)
}