# blob: 264cbfc9ba9295410907e38f17f5e45ab0d93091 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Utilities and Helpers
# Given a JList<T>, returns an R list containing the same elements, the number
# of which is optionally upper bounded by `logicalUpperBound` (by default,
# return all elements). Takes care of deserializations and type conversions.
#
# param
# jList Java object reference to the list; it is queried through its "size"
# and "get" methods.
# flatten If TRUE, the per-element results are concatenated one level deep
# (unlist(..., recursive = FALSE)) before being returned. It is forced
# to FALSE when an element is read in "row" mode (see below).
# logicalUpperBound Optional cap on the number of elements; NULL means no cap.
# serializedMode The modes handled below are "byte", "string" and "row".
# return value
# An R list with the converted elements.
convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
serializedMode = "byte") {
arrSize <- callJMethod(jList, "size")
# Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()):
# each partition is not dense-packed into one Array[Byte], and `arrSize`
# here corresponds to number of logical elements. Thus we can prune here.
if (serializedMode == "string" && !is.null(logicalUpperBound)) {
arrSize <- min(arrSize, logicalUpperBound)
}
# Java lists are 0-indexed, hence the 0 : (arrSize - 1) range. The range is
# only built when arrSize > 0, so it never counts backwards.
results <- if (arrSize > 0) {
lapply(0 : (arrSize - 1),
function(index) {
obj <- callJMethod(jList, "get", as.integer(index))
# Assume it is either an R object or a Java obj ref.
if (inherits(obj, "jobj")) {
if (isInstanceOf(obj, "scala.Tuple2")) {
# JavaPairRDD[Array[Byte], Array[Byte]].
# Both halves of the tuple are unserialized into a list(key, value).
keyBytes <- callJMethod(obj, "_1")
valBytes <- callJMethod(obj, "_2")
res <- list(unserialize(keyBytes),
unserialize(valBytes))
} else {
stop("utils.R: convertJListToRList only supports ",
"RDD[Array[Byte]] and ",
"JavaPairRDD[Array[Byte], Array[Byte]] for now")
}
} else {
if (inherits(obj, "raw")) {
# NOTE(review): only "byte" and "row" assign `res` for a raw `obj`; any
# other mode leaves `res` unset in this branch - confirm that callers
# never pass raw data in another mode.
if (serializedMode == "byte") {
# RDD[Array[Byte]]. `obj` is a whole partition.
res <- unserialize(obj)
# For serialized datasets, `obj` (and `rRaw`) here corresponds to
# one whole partition dense-packed together. We deserialize the
# whole partition first, then cap the number of elements to be returned.
} else if (serializedMode == "row") {
res <- readRowList(obj)
# For DataFrames that have been converted to RRDDs, we call readRowList
# which will read in each row of the RRDD as a list and deserialize
# each element.
flatten <<- FALSE
# The superassignment above rebinds the `flatten` argument of the
# enclosing convertJListToRList call, so the final `if (flatten)`
# below sees FALSE. This means we don't have to worry about the
# default argument in other functions, e.g. collect.
}
# TODO: is it possible to distinguish element boundary so that we can
# unserialize only what we need?
# The cap is applied to each deserialized element separately.
if (!is.null(logicalUpperBound)) {
res <- head(res, n = logicalUpperBound)
}
} else {
# obj is of a primitive Java type, is simplified to R's
# corresponding type.
res <- list(obj)
}
}
res
})
} else {
list()
}
if (flatten) {
as.list(unlist(results, recursive = FALSE))
} else {
as.list(results)
}
}
# Tells whether the object bound to `name` in environment `env` is an RDD,
# i.e. whether it inherits from class "RDD".
isRDD <- function(name, env) {
  inherits(get(name, envir = env), "RDD")
}
#' Compute the hashCode of an object
#'
#' Java-style function to compute the hashCode for the given object. Returns
#' an integer value.
#'
#' @details
#' This only works for integer, numeric and character types right now. Only
#' the first element of \code{key} is hashed. Any other kind of object
#' (including objects with more than one class, such as a matrix) produces a
#' warning and a hash code of 0.
#'
#' @param key the object to be hashed
#' @return the hash code as an integer
#' @examples
#'\dontrun{
#' hashCode(1L) # 1
#' hashCode(1.0) # 1072693248
#' hashCode("1") # 49
#'}
#' @note hashCode since 1.4.0
hashCode <- function(key) {
  # class() may return several names (e.g. c("matrix", "array")); comparing it
  # with `==` would then hand a length > 1 condition to `if`. identical()
  # always yields a single TRUE/FALSE.
  keyClass <- class(key)
  if (identical(keyClass, "integer")) {
    as.integer(key[[1]])
  } else if (identical(keyClass, "numeric")) {
    # Convert the double to long and then calculate the hash code:
    # split the 8 bytes of the double into two 32-bit words and XOR them.
    rawVec <- writeBin(key[[1]], con = raw())
    intBits <- packBits(rawToBits(rawVec), "integer")
    as.integer(bitwXor(intBits[2], intBits[1]))
  } else if (identical(keyClass, "character")) {
    # TODO: SPARK-7839 means we might not have the native library available
    # Hash the first element only, like the integer and numeric branches do.
    first <- key[[1]]
    if (nchar(first) == 0) {
      0L
    } else {
      # Byte values of the string; fold them with the usual 31 * h + b step.
      byteVals <- as.integer(charToRaw(first))
      hashC <- 0
      for (byteVal in byteVals) {
        hashC <- mult31AndAdd(hashC, byteVal)
      }
      as.integer(hashC)
    }
  } else {
    warning("Could not hash object, returning 0")
    0L
  }
}
# Helper that folds a 'numeric' value back towards the integer bounds, used to
# imitate C-like integer arithmetic.
# Values above .Machine$integer.max have (2 * integer.max + 2) subtracted;
# values below -integer.max have the same amount added; everything else is
# returned untouched.
wrapInt <- function(value) {
  intMax <- .Machine$integer.max
  if (value > intMax) {
    return(value - 2 * intMax - 2)
  }
  if (value < -1 * intMax) {
    return(2 * intMax + value + 2)
  }
  value
}
# Computes 31 * `val` + `addVal`, wrapping to integer bounds after every
# partial sum.
# 31 * val is built as the sum of val shifted left by 4, 3, 2, 1 and 0 bits;
# any NA produced by the shifts is treated as 0.
#
# TODO: this function does not handle integer overflow well
mult31AndAdd <- function(val, addVal) {
  terms <- c(bitwShiftL(val, c(4, 3, 2, 1, 0)), addVal)
  terms[is.na(terms)] <- 0
  addWrapped <- function(a, b) {
    wrapInt(as.numeric(a) + as.numeric(b))
  }
  Reduce(addWrapped, terms)
}
# Returns an RDD whose serializedMode is "byte".
# An RDD already in "byte" mode is returned as is; otherwise a new RDD is
# created by mapping the identity function over it.
serializeToBytes <- function(rdd) {
  if (!inherits(rdd, "RDD")) {
    stop("Argument 'rdd' is not an RDD type.")
  }
  if (getSerializedMode(rdd) == "byte") {
    return(rdd)
  }
  lapply(rdd, function(x) { x })
}
# Returns an RDD whose serializedMode is "string".
# An RDD already in "string" mode is returned as is; otherwise every element
# is converted with toString() into a new RDD.
serializeToString <- function(rdd) {
  if (!inherits(rdd, "RDD")) {
    stop("Argument 'rdd' is not an RDD type.")
  }
  if (getSerializedMode(rdd) == "string") {
    return(rdd)
  }
  strRdd <- lapply(rdd, function(x) { toString(x) })
  # force it to create jrdd using "string"
  getJRDD(strRdd, serializedMode = "string")
  strRdd
}
# Fast append to a list by using an accumulator.
# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r
#
# The accumulator (an environment) should have three fields: size, counter
# and data. The allocation cost is amortized by doubling the size of the
# data list every time it fills up.
addItemToAccumulator <- function(acc, item) {
  capacity <- acc$size
  if (acc$counter == capacity) {
    # Storage is full: double it before writing the new item.
    capacity <- capacity * 2
    acc$size <- capacity
    length(acc$data) <- capacity
  }
  slot <- acc$counter + 1
  acc$counter <- slot
  acc$data[[slot]] <- item
}
# Creates an empty accumulator for addItemToAccumulator(): an environment with
# counter = 0, size = 1 and a one-slot data list.
initAccumulator <- function() {
  acc <- new.env()
  assign("size", 1, envir = acc)
  assign("counter", 0, envir = acc)
  assign("data", list(NULL), envir = acc)
  acc
}
# Utility function that orders a list of key-value pairs by key (the first
# element of each pair). Used in unit tests.
sortKeyValueList <- function(kv_list, decreasing = FALSE) {
  keys <- sapply(kv_list, `[[`, 1)
  ordering <- order(keys, decreasing = decreasing)
  kv_list[ordering]
}
# Utility function to generate compact R lists from a grouped rdd.
# Used in Join-family functions.
# param:
#   tagged_list R list generated via groupByKey with tags (1L, 2L, ...);
#               each entry is list(tag, value)
#   cnull Boolean list where each element determines whether the
#         corresponding list should be converted to list(NULL) when it
#         ends up empty
# return value:
#   A list of two lists holding the values tagged 1 and 2 respectively.
genCompactLists <- function(tagged_list, cnull) {
  total <- length(tagged_list)
  buckets <- list(vector("list", total), vector("list", total))
  nextSlot <- list(1, 1)
  for (entry in tagged_list) {
    tag <- entry[[1]]
    slot <- nextSlot[[tag]]
    buckets[[tag]][[slot]] <- entry[[2]]
    nextSlot[[tag]] <- slot + 1
  }
  for (i in seq_len(2)) {
    filled <- nextSlot[[i]] - 1
    if (cnull[[i]] && filled == 0) {
      buckets[[i]] <- list(NULL)
    } else {
      # Trim the preallocated storage down to what was actually used.
      length(buckets[[i]]) <- filled
    }
  }
  buckets
}
# Utility function to merge compact R lists.
# Used in Join-family functions.
# param:
#   left/right Two compact lists ready for Cartesian product
# return value:
#   A list of list(l, r) pairs, one for every combination, with `right`
#   varying fastest.
mergeCompactLists <- function(left, right) {
  nRight <- length(right)
  product <- vector("list", length(left) * nRight)
  for (li in seq_along(left)) {
    for (ri in seq_along(right)) {
      product[[(li - 1) * nRight + ri]] <- list(left[[li]], right[[ri]])
    }
  }
  product
}
# Utility function that wraps the two operations above: it splits the tagged
# list with genCompactLists() and merges the two halves with
# mergeCompactLists().
# Used in Join-family functions.
# param (same as genCompactLists):
#   tagged_list R list generated via groupByKey with tags (1L, 2L, ...)
#   cnull Boolean list where each element determines whether the
#         corresponding list should be converted to list(NULL)
joinTaggedList <- function(tagged_list, cnull) {
  halves <- genCompactLists(tagged_list, cnull)
  left <- halves[[1]]
  right <- halves[[2]]
  mergeCompactLists(left, right)
}
# Utility function to reduce a key-value list with a predicate.
# Used in *ByKey functions.
# param
#   pair key-value pair; key/value are elements 1/2 and the hash is `$hash`
#   keys/vals environments holding keys/values, indexed by hash
#   updateOrCreatePred predicate; TRUE means "update", otherwise "create"
#   updateFn update or merge function for an existing pair, similar to
#            `mergeVal` in combineByKey
#   createFn create function for a new pair, similar to `createCombiner`
#            in combineByKey
updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) {
  hashVal <- pair$hash
  key <- pair[[1]]
  val <- pair[[2]]
  if (updateOrCreatePred(pair)) {
    # Existing entry: merge the new value into the stored one.
    current <- get(hashVal, envir = vals)
    merged <- do.call(updateFn, list(current, val))
    assign(hashVal, merged, envir = vals)
  } else {
    # New entry: store the initial value, then remember the key.
    created <- do.call(createFn, list(val))
    assign(hashVal, created, envir = vals)
    assign(hashVal, key, envir = keys)
  }
}
# Utility function to convert the keys and values environments into a list of
# list(key, value) pairs, one per name found in `keys` (in ls() order).
convertEnvsToList <- function(keys, vals) {
  hashes <- ls(keys)
  pairs <- vector("list", length(hashes))
  for (i in seq_along(hashes)) {
    h <- hashes[[i]]
    pairs[[i]] <- list(keys[[h]], vals[[h]])
  }
  pairs
}
# Utility function to merge 2 environments, with the bindings of `env2`
# overriding those of `env1`. `env1` is changed in place.
# Returns the list of copied values, in ls(env2) order.
overrideEnvs <- function(env1, env2) {
  copyBinding <- function(name) {
    env1[[name]] <- env2[[name]]
  }
  lapply(ls(env2), copyBinding)
}
# Utility function to capture named varargs into a new environment, one
# binding per argument name.
varargsToEnv <- function(...) {
  # Based on http://stackoverflow.com/a/3057419/4577954
  args <- list(...)
  argNames <- names(args)
  env <- new.env()
  for (i in seq_along(argNames)) {
    key <- argNames[[i]]
    env[[key]] <- args[[key]]
  }
  env
}
# Utility function to capture the varargs into an environment object, with all
# values converted into strings.
#
# Named arguments must be logical, numeric, character or NULL:
#   - logical values are stored as lower-case strings ("true"/"false"),
#   - NULL is stored as is,
#   - everything else is stored via as.character().
# Unnamed arguments are not stored; a single warning lists them.
# When a name is repeated, each occurrence is validated and the last one wins.
varargsToStrEnv <- function(...) {
  pairs <- list(...)
  nameList <- names(pairs)
  env <- new.env()
  ignoredNames <- list()

  if (is.null(nameList)) {
    # When all arguments are not named, names(..) returns NULL.
    ignoredNames <- pairs
  } else {
    for (i in seq_along(pairs)) {
      name <- nameList[i]
      if (identical(name, "")) {
        # When some of arguments are not named, name is "".
        ignoredNames <- append(ignoredNames, pairs[i])
      } else {
        # Fetch by position: a lookup by name would always return the first
        # argument carrying that name, so repeated names would never be
        # validated or stored.
        value <- pairs[[i]]
        if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) {
          stop("Unsupported type for ", name, " : ", toString(class(value)), ". ",
               "Supported types are logical, numeric, character and NULL.", call. = FALSE)
        }
        if (is.logical(value)) {
          env[[name]] <- tolower(as.character(value))
        } else if (is.null(value)) {
          env[[name]] <- value
        } else {
          env[[name]] <- as.character(value)
        }
      }
    }
  }

  if (length(ignoredNames) != 0) {
    warning("Unnamed arguments ignored: ", toString(ignoredNames), ".", call. = FALSE)
  }
  env
}
# Looks up the StorageLevel named by `newLevel` by calling the static method
# of that name on org.apache.spark.storage.StorageLevel.
#
# param
#   newLevel Name of the storage level; one of the choices listed in the
#            signature. Unambiguous abbreviations are accepted (match.arg),
#            and the first choice is used when the argument is omitted.
# return value
#   Whatever callJStatic() returns for the selected level.
getStorageLevel <- function(newLevel = c("DISK_ONLY",
                                         "DISK_ONLY_2",
                                         "DISK_ONLY_3",
                                         "MEMORY_AND_DISK",
                                         "MEMORY_AND_DISK_2",
                                         "MEMORY_AND_DISK_SER",
                                         "MEMORY_AND_DISK_SER_2",
                                         "MEMORY_ONLY",
                                         "MEMORY_ONLY_2",
                                         "MEMORY_ONLY_SER",
                                         "MEMORY_ONLY_SER_2",
                                         "OFF_HEAP")) {
  # Keep the value returned by match.arg(): it collapses the default choice
  # vector to a single name and expands abbreviations. Discarding it left
  # `newLevel` as the full vector on a default call.
  newLevel <- match.arg(newLevel)
  # Every accepted name is also the name of the static method to call.
  callJStatic("org.apache.spark.storage.StorageLevel", newLevel)
}
# Builds a printable description of a StorageLevel object.
# The level's flags and replication factor are read through callJMethod() and
# compared with the table of known levels below. For a known level the result
# is "<short name> - <toString() of the level>", otherwise just the toString()
# text.
storageLevelToString <- function(levelObj) {
  useDisk <- callJMethod(levelObj, "useDisk")
  useMemory <- callJMethod(levelObj, "useMemory")
  useOffHeap <- callJMethod(levelObj, "useOffHeap")
  deserialized <- callJMethod(levelObj, "deserialized")
  replication <- callJMethod(levelObj, "replication")

  # One row per known level: useDisk, useMemory, useOffHeap, deserialized,
  # replication.
  level <- function(disk, memory, offHeap, deser, repl) {
    list(flags = c(disk, memory, offHeap, deser), replication = repl)
  }
  knownLevels <- list(
    "NONE" = level(FALSE, FALSE, FALSE, FALSE, 1),
    "DISK_ONLY" = level(TRUE, FALSE, FALSE, FALSE, 1),
    "DISK_ONLY_2" = level(TRUE, FALSE, FALSE, FALSE, 2),
    "DISK_ONLY_3" = level(TRUE, FALSE, FALSE, FALSE, 3),
    "MEMORY_ONLY" = level(FALSE, TRUE, FALSE, TRUE, 1),
    "MEMORY_ONLY_2" = level(FALSE, TRUE, FALSE, TRUE, 2),
    "MEMORY_ONLY_SER" = level(FALSE, TRUE, FALSE, FALSE, 1),
    "MEMORY_ONLY_SER_2" = level(FALSE, TRUE, FALSE, FALSE, 2),
    "MEMORY_AND_DISK" = level(TRUE, TRUE, FALSE, TRUE, 1),
    "MEMORY_AND_DISK_2" = level(TRUE, TRUE, FALSE, TRUE, 2),
    "MEMORY_AND_DISK_SER" = level(TRUE, TRUE, FALSE, FALSE, 1),
    "MEMORY_AND_DISK_SER_2" = level(TRUE, TRUE, FALSE, FALSE, 2),
    "OFF_HEAP" = level(TRUE, TRUE, TRUE, FALSE, 1)
  )

  observedFlags <- c(as.logical(useDisk), as.logical(useMemory),
                     as.logical(useOffHeap), as.logical(deserialized))
  shortName <- NULL
  for (candidate in names(knownLevels)) {
    spec <- knownLevels[[candidate]]
    if (identical(observedFlags, spec$flags) && replication == spec$replication) {
      shortName <- candidate
      break
    }
  }

  fullInfo <- callJMethod(levelObj, "toString")
  if (is.null(shortName)) {
    fullInfo
  } else {
    paste(shortName, "-", fullInfo)
  }
}
# Utility for arguments that must be integer, where we still want to let the
# user type (for example) `5` instead of `5L` without a confusing error
# message. Returns as.integer(num); warns, naming the expression the caller
# passed, when the conversion changes the value.
numToInt <- function(num) {
  asInt <- as.integer(num)
  if (asInt != num) {
    # sys.call() is this call itself; its second element is the argument
    # expression as written by the caller.
    warning("Coercing ", as.list(sys.call())[[2L]], " to integer.")
  }
  asInt
}
# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
# param
#   node The current AST node in the traversal.
#   oldEnv The original function environment.
#   defVars An Accumulator of variables names defined in the function's calling
#           environment, including function argument and local variable names.
#   checkedFuncs An environment of function objects examined during
#                cleanClosure. It can be considered as a
#                "name"-to-"list of functions" mapping.
#   newEnv A new function environment to store necessary function
#          dependencies, an output argument.
processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
  nodeLen <- length(node)

  if (nodeLen > 1 && typeof(node) == "language") {
    # Recursive case: current AST node is an internal node, check for its children.
    if (length(node[[1]]) > 1) {
      for (i in 1:nodeLen) {
        processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
      }
    } else {
      # if node[[1]] is length of 1, check for some R special functions.
      nodeChar <- as.character(node[[1]])
      if (nodeChar == "{" || nodeChar == "(") {
        # Skip start symbol.
        for (i in 2:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "<-" || nodeChar == "=" ||
                   nodeChar == "<<-") {
        # Assignment Ops.
        defVar <- node[[2]]
        if (length(defVar) == 1 && typeof(defVar) == "symbol") {
          # Add the defined variable name into defVars.
          addItemToAccumulator(defVars, as.character(defVar))
        } else {
          processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
        }
        for (i in 3:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "function") {
        # Function definition.
        # Add parameter names.
        newArgs <- names(node[[2]])
        lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
        for (i in 3:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      } else if (nodeChar == "$") {
        # Skip the field.
        processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
      } else if (nodeChar == "::" || nodeChar == ":::") {
        # Only the name on the right-hand side of the operator is examined.
        processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
      } else {
        for (i in 1:nodeLen) {
          processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
        }
      }
    }
  } else if (nodeLen == 1 &&
               (typeof(node) == "symbol" || typeof(node) == "language")) {
    # Base case: current AST node is a leaf node and a symbol or a function call.
    nodeChar <- as.character(node)
    if (!nodeChar %in% defVars$data) {
      # Not a function parameter or local variable.
      func.env <- oldEnv
      topEnv <- parent.env(.GlobalEnv)
      # Search in function environment, and function's enclosing environments
      # up to global environment. There is no need to look into package environments
      # above the global or namespace environment that is not SparkR below the global,
      # as they are assumed to be loaded on workers.
      while (!identical(func.env, topEnv)) {
        # Namespaces other than "SparkR" will not be searched.
        if (!isNamespace(func.env) ||
              (getNamespaceName(func.env) == "SparkR" &&
                 !(nodeChar %in% getNamespaceExports("SparkR")) &&
                 # Note that generic S4 methods should not be set to the environment of
                 # cleaned closure. It does not work with R 4.0.0+. See also SPARK-31918.
                 nodeChar != "" && !methods::isGeneric(nodeChar, func.env))) {
          # Only include SparkR internals.

          # Set parameter 'inherits' to FALSE since we do not need to search in
          # attached package environments.
          if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
                       error = function(e) { FALSE })) {
            obj <- get(nodeChar, envir = func.env, inherits = FALSE)
            if (is.function(obj)) {
              # If the node is a function call.
              # Spell out FALSE: `F` is an ordinary variable that can be rebound.
              funcList <- mget(nodeChar, envir = checkedFuncs, inherits = FALSE,
                               ifnotfound = list(list(NULL)))[[1]]
              # One TRUE/FALSE per recorded function: TRUE when it is the same
              # function and its environment has the same parent as the
              # environment being searched.
              found <- vapply(funcList, function(func) {
                identical(func, obj) &&
                  # Also check if the parent environment is identical to current parent
                  identical(parent.env(environment(func)), func.env)
              }, logical(1))
              if (any(found)) {
                # If function has been examined ignore
                break
              }
              # Function has not been examined, record it and recursively clean its closure.
              assign(nodeChar,
                     if (is.null(funcList[[1]])) {
                       list(obj)
                     } else {
                       append(funcList, obj)
                     },
                     envir = checkedFuncs)
              obj <- cleanClosure(obj, checkedFuncs)
            }
            assign(nodeChar, obj, envir = newEnv)
            break
          }
        }

        # Continue to search in enclosure.
        func.env <- parent.env(func.env)
      }
    }
  }
}
# Utility function to get user defined function (UDF) dependencies (closure).
# More specifically, this function captures the values of free variables defined
# outside a UDF, and stores them in the function's environment.
# param
#   func A function whose closure needs to be captured. Anything that is not
#        a function is returned unchanged.
#   checkedFuncs An environment of function objects examined during
#                cleanClosure. It can be considered as a
#                "name"-to-"list of functions" mapping.
# return value
#   a new version of func that has a correct environment (closure).
cleanClosure <- function(func, checkedFuncs = new.env()) {
  if (is.function(func)) {
    newEnv <- new.env(parent = .GlobalEnv)
    func.body <- body(func)
    oldEnv <- environment(func)
    # defVars is an Accumulator of variables names defined in the function's calling
    # environment. First, function's arguments are added to defVars.
    defVars <- initAccumulator()
    argNames <- names(as.list(args(func)))
    # The last entry of as.list(args(func)) is the (NULL) body, not an
    # argument, so it is dropped. head(x, -1) is also safe when there are no
    # names at all, unlike 1:(length(argNames) - 1), which counts backwards
    # (1, 0, -1) for a function without arguments.
    for (argName in head(argNames, -1)) {
      addItemToAccumulator(defVars, argName)
    }
    # Recursively examine variables in the function body.
    processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv)
    environment(func) <- newEnv
  }
  func
}
# Append partition lengths to each partition in two input RDDs if needed.
# param
#   x An RDD.
#   other An RDD.
# return value
#   A list of two result RDDs.
appendPartitionLengths <- function(x, other) {
  modeX <- getSerializedMode(x)
  needsLengths <- modeX != getSerializedMode(other) || modeX == "byte"
  if (!needsLengths) {
    return(list(x, other))
  }

  # Append the number of elements in each partition to that partition so that we can later
  # know the boundary of elements from x and other.
  #
  # Note that this appending also serves the purpose of reserialization, because even if
  # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
  # as a single byte array. For example, partitions of an RDD generated from partitionBy()
  # may be encoded as multiple byte arrays.
  withLength <- function(part) {
    n <- length(part)
    # The stored count includes the appended element itself.
    part[[n + 1]] <- n + 1
    part
  }
  list(lapplyPartition(x, withLength), lapplyPartition(other, withLength))
}
# Perform zip or cartesian between elements from two RDDs in each partition
# param
#   rdd An RDD.
#   zip A boolean flag indicating this call is for zip operation or not.
# return value
#   A result RDD.
mergePartitions <- function(rdd, zip) {
  serializerMode <- getSerializedMode(rdd)
  partitionFunc <- function(partIndex, part) {
    len <- length(part)
    if (len == 0) {
      # Nothing to merge in an empty partition.
      return(part)
    }

    if (serializerMode == "byte") {
      # The partition holds the keys followed by their count, then the values
      # followed by their count; each count includes its own slot.
      lengthOfValues <- part[[len]]
      lengthOfKeys <- part[[len - lengthOfValues]]
      stopifnot(len == lengthOfKeys + lengthOfValues)

      # For zip operation, check if corresponding partitions
      # of both RDDs have the same number of elements.
      if (zip && lengthOfKeys != lengthOfValues) {
        stop("Can only zip RDDs with same number of elements ",
             "in each pair of corresponding partitions.")
      }

      keys <- if (lengthOfKeys > 1) {
        part[1 : (lengthOfKeys - 1)]
      } else {
        list()
      }
      values <- if (lengthOfValues > 1) {
        part[(lengthOfKeys + 1) : (len - 1)]
      } else {
        list()
      }

      if (!zip) {
        return(mergeCompactLists(keys, values))
      }
    } else {
      # Keys and values alternate: odd positions are keys, even are values.
      keys <- part[c(TRUE, FALSE)]
      values <- part[c(FALSE, TRUE)]
    }

    mapply(list, keys, values, SIMPLIFY = FALSE, USE.NAMES = FALSE)
  }

  PipelinedRDD(rdd, partitionFunc)
}
# Convert a named list to a "struct" so that SerDe won't confuse a normal
# named list with a struct. The input must have class "list" and carry names.
listToStruct <- function(list) {
  stopifnot(class(list) == "list",
            !is.null(names(list)))
  `class<-`(list, "struct")
}
# Convert a struct (as produced by listToStruct) back to a plain named list.
# Stops if `struct` is not of class "struct".
structToList <- function(struct) {
  # Check the argument itself; `class(list)` would test the base function
  # `list`, whose class is "function", so the check could never pass.
  stopifnot(inherits(struct, "struct"))
  class(struct) <- "list"
  struct
}
# Convert a named list to an environment to be passed to JVM.
# A list without names is accepted only when it is empty, and no name may
# be NA.
convertNamedListToEnv <- function(namedList) {
  # Make sure each item in the list has a name
  names <- names(namedList)
  stopifnot(
    if (is.null(names)) {
      length(namedList) == 0
    } else {
      !any(is.na(names))
    })
  env <- new.env()
  for (idx in seq_along(names)) {
    key <- names[[idx]]
    env[[key]] <- namedList[[key]]
  }
  env
}
# Assign a new environment for attach() and with() methods.
# Builds an environment holding one binding per column of `data`, named after
# the column and bound to `data[, column, drop = FALSE]`.
# Stops unless `data` has class "SparkDataFrame" and at least one column.
assignNewEnv <- function(data) {
  stopifnot(class(data) == "SparkDataFrame")

  cols <- columns(data)
  stopifnot(length(cols) > 0)

  env <- new.env()
  for (i in seq_along(cols)) {
    # Use FALSE rather than F: `F` is an ordinary variable that can be rebound.
    assign(x = cols[i], value = data[, cols[i], drop = FALSE], envir = env)
  }
  env
}
# Utility function to split by ',' and whitespace, dropping empty tokens.
splitString <- function(input) {
  tokens <- unlist(strsplit(input, ",|\\s"))
  tokens[nzchar(tokens)]
}
# Copies the named varargs into a new java.util.Properties object and returns
# it. Each name (as listed by ls() on the argument list) and its value are
# passed to "setProperty" as character strings.
varargsToJProperties <- function(...) {
  pairs <- list(...)
  props <- newJObject("java.util.Properties")
  if (length(pairs) > 0) {
    for (k in ls(pairs)) {
      callJMethod(props, "setProperty", as.character(k), as.character(pairs[[k]]))
    }
  }
  props
}
# Runs `script` with `combinedArgs` as a child process.
# On Windows the command line is handed to shell(); elsewhere system2() is
# used with the given stdout/stderr targets. `wait` is forwarded either way.
launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") {
  onWindows <- .Platform$OS.type == "windows"
  if (!onWindows) {
    # http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html
    # stdout = F means discard output
    # stdout = "" means to its console (default)
    # Note that the console of this child process might not be the same as the running R process.
    system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr)
  } else {
    scriptWithArgs <- paste(script, combinedArgs)
    # on Windows, intern = F seems to mean output to the console. (documentation on this is missing)
    shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait)
  }
}
# Returns the object stored as ".sparkRjsc" in `.sparkREnv`.
# Stops with an explanatory message when no such binding exists yet.
getSparkContext <- function() {
  if (exists(".sparkRjsc", envir = .sparkREnv)) {
    return(get(".sparkRjsc", envir = .sparkREnv))
  }
  stop("SparkR has not been initialized. Please call sparkR.session()")
}
# TRUE when `master` is exactly "local", "local[<digits>]" or "local[*]".
isMasterLocal <- function(master) {
  localPattern <- "^local(\\[([0-9]+|\\*)\\])?$"
  grepl(localPattern, master, perl = TRUE)
}
# TRUE when `master` ends with "<lower-case letters>-client".
isClientMode <- function(master) {
  clientPattern <- "([a-z]+)-client$"
  grepl(clientPattern, master, perl = TRUE)
}
# TRUE when the R_PROFILE_USER environment variable ends with "shell.R".
isSparkRShell <- function() {
  profilePath <- Sys.getenv("R_PROFILE_USER")
  grepl(".*shell\\.R$", profilePath, perl = TRUE)
}
# Works identically with `callJStatic(...)` but throws a pretty formatted
# exception: any error is passed to captureJVMException() together with the
# method name.
handledCallJStatic <- function(cls, method, ...) {
  onError <- function(e) {
    captureJVMException(e, method)
  }
  value <- tryCatch(callJStatic(cls, method, ...), error = onError)
  value
}
# Works identically with `callJMethod(...)` but throws a pretty formatted
# exception: any error is passed to captureJVMException() together with the
# method name.
handledCallJMethod <- function(obj, method, ...) {
  onError <- function(e) {
    captureJVMException(e, method)
  }
  value <- tryCatch(callJMethod(obj, method, ...), error = onError)
  value
}
# Re-raises an error `e` with a shorter, friendlier message.
#
# param
#   e The caught error (or its text); converted with as.character().
#   method Name of the method that was being called; it replaces the call
#          shown in a leading "Error in ...: " prefix.
#
# The text is searched for a fixed set of JVM exception class names, in the
# order listed below. For the first one found, the error raised consists of
# the text before the class name, a short label, and the exception message up
# to the first stack frame line. If none is found, the (possibly re-prefixed)
# text is raised unchanged. This function always stops.
captureJVMException <- function(e, method) {
  rawmsg <- as.character(e)
  stacktrace <- if (any(grepl("^Error in .*?: ", rawmsg))) {
    # If the exception message starts with "Error in ...", this is possibly
    # "Error in invokeJava(...)". Here, it replaces the characters to
    # `paste("Error in", method, ":")` in order to identify which function
    # was called in JVM side.
    pieces <- strsplit(rawmsg, "Error in .*?: ")[[1]]
    prefix <- paste("Error in", method, ":")
    paste(prefix[1], pieces[2])
  } else {
    # Otherwise, do not convert the error message just in case.
    rawmsg
  }

  # Exception marker -> label, checked in this order.
  # StreamingQueryException could wrap an IllegalArgumentException, so look for that first
  labels <- c(
    "org.apache.spark.sql.streaming.StreamingQueryException: " = "streaming query error - ",
    "java.lang.IllegalArgumentException: " = "illegal argument - ",
    "org.apache.spark.sql.AnalysisException: " = "analysis error - ",
    "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: " = "no such database - ",
    "org.apache.spark.sql.catalyst.analysis.NoSuchTableException: " = "no such table - ",
    "org.apache.spark.sql.catalyst.parser.ParseException: " = "parse error - "
  )

  for (marker in names(labels)) {
    if (any(grepl(marker, stacktrace, fixed = TRUE))) {
      parts <- strsplit(stacktrace, marker, fixed = TRUE)[[1]]
      # Extract "Error in ..." message.
      rmsg <- parts[1]
      # Extract the first message of JVM exception.
      first <- strsplit(parts[2], "\r?\n\tat")[[1]][1]
      stop(rmsg, labels[[marker]], first, call. = FALSE)
    }
  }

  stop(stacktrace, call. = FALSE)
}
# rbind a list of rows with raw (binary) columns
#
# @param inputData a list of rows, with each row a list
# @return data.frame with raw columns as lists
rbindRaws <- function(inputData) {
  # Raw columns are identified from the classes of the first row.
  firstRow <- inputData[[1]]
  isRawColumn <- sapply(firstRow, class) == "raw"
  # A dataframe with all list columns
  out <- as.data.frame(do.call(rbind, inputData))
  # Flatten every non-raw column; raw columns stay as list columns.
  plainColumns <- !isRawColumn
  out[plainColumns] <- lapply(out[plainColumns], unlist)
  out
}
# Get basename without extension from URL
basenameSansExtFromUrl <- function(url) {
  # split by '/': everything up to the last '/' is consumed as the separator
  pieces <- unlist(strsplit(url, "^.+/"))
  fileName <- tail(pieces, 1)
  # this is from file_path_sans_ext
  # first, remove any compression extension
  withoutCompression <- sub("[.](gz|bz2|xz)$", "", fileName)
  # then, strip extension by the last '.'
  sub("([^.]+)\\.[[:alnum:]]+$", "\\1", withoutCompression)
}
# TRUE when `x` is an atomic vector holding exactly one element.
isAtomicLengthOne <- function(x) {
  length(x) == 1 && is.atomic(x)
}
# TRUE when R reports the OS type "windows".
is_windows <- function() {
  identical(.Platform$OS.type, "windows")
}
# TRUE when the HADOOP_HOME environment variable is set to a non-empty value.
hadoop_home_set <- function() {
  nzchar(Sys.getenv("HADOOP_HOME"))
}
# TRUE on every non-Windows platform; on Windows, TRUE only when HADOOP_HOME
# is set.
windows_with_hadoop <- function() {
  if (is_windows()) {
    hadoop_home_set()
  } else {
    TRUE
  }
}
# get0 not supported before R 3.2.0
# Looks up the first name in `x` in `envir` (searching enclosing environments
# when `inherits` is TRUE) and returns its value, or `ifnotfound` when there
# is no such binding.
getOne <- function(x, envir, inherits = TRUE, ifnotfound = NULL) {
  firstName <- x[1L]
  found <- mget(firstName, envir = envir, inherits = inherits,
                ifnotfound = list(ifnotfound))
  found[[1L]]
}
# Returns a vector of parent directories, traversing up count times, starting with a full path
# e.g. traverseParentDirs("/Users/user/Library/Caches/spark/spark2.2", 1) should return
# this "/Users/user/Library/Caches/spark/spark2.2"
# and "/Users/user/Library/Caches/spark"
# The walk also stops early once a directory is its own parent.
traverseParentDirs <- function(x, count) {
  dirs <- x
  current <- x
  while (!(dirname(current) == current || count <= 0)) {
    current <- dirname(current)
    dirs <- c(dirs, current)
    count <- count - 1
  }
  dirs
}