blob: 10298fbaf0563a62ccb79cd006a9a56a030d2cf4 [file] [log] [blame]
# Internal predicate: TRUE when `x` inherits from either of the two
# Rcpp-backed iterator classes listed below, FALSE otherwise.
is.MXDataIter <- function(x) {
  iter.classes <- c("Rcpp_MXNativeDataIter", "Rcpp_MXArrayDataIter")
  # inherits() with a vector of class names is TRUE if any of them matches.
  inherits(x, iter.classes)
}
#' Check whether an object is an mx.dataiter
#'
#' Exported alias of the internal predicate \code{is.MXDataIter}, which tests
#' whether the object inherits from one of the Rcpp-backed iterator classes.
#'
#' @param x The object to test.
#'
#' @return Logical indicator
#'
#' @export
is.mx.dataiter <- is.MXDataIter
#' Extract a certain field from DataIter.
#'
#' Resets the iterator, visits every batch, slices the requested field so
#' that the padded entries reported by the iterator are left out, and packs
#' the slices together. The iterator is reset once more before returning.
#'
#' @param iter The data iterator to read from.
#' @param field Name or index of the entry to take from each batch's value
#'   list.
#' @return The packed result obtained from the array packer's \code{get()}.
#'
#' @export
mx.io.extract <- function(iter, field) {
  collector <- mx.nd.arraypacker()
  iter$reset()
  while (iter$iter.next()) {
    batch <- iter$value()
    num.padded <- iter$num.pad()
    chunk <- batch[[field]]
    chunk.dim <- dim(chunk)
    last.size <- chunk.dim[[length(chunk.dim)]]
    # Keep the range [0, last.size - num.padded); presumably mx.nd.slice cuts
    # along the last dimension, which drops the padded tail of the batch.
    kept <- mx.nd.slice(chunk, 0, last.size - num.padded)
    collector$push(kept)
  }
  # Leave the iterator rewound so the caller can reuse it.
  iter$reset()
  collector$get()
}
#
#' Create MXDataIter compatible iterator from R's array
#'
#' When \code{shuffle} is set, one uniform random number is drawn per element
#' of the last dimension of \code{data} (or per element of \code{data} when it
#' has no dimensions) and handed to the internal iterator constructor;
#' otherwise a single zero is passed in its place.
#'
#' @param data The data array.
#' @param label The label array.
#' @param batch.size The batch size used to pack the array.
#' @param shuffle Whether shuffle the data
#' @return The object returned by \code{mx.io.internal.arrayiter}.
#'
#' @export
mx.io.arrayiter <- function(data, label,
                            batch.size = 128,
                            shuffle = FALSE) {
  unif.rnds <- if (shuffle) {
    data.dim <- dim(data)
    # Size of the last dimension, or the plain length for dimensionless input.
    num.data <- if (is.null(data.dim)) {
      length(data)
    } else {
      data.dim[[length(data.dim)]]
    }
    as.array(mx.runif(num.data, ctx = mx.cpu()))
  } else {
    # Placeholder value passed when no shuffling is requested.
    as.array(0)
  }
  mx.io.internal.arrayiter(
    as.array(data),
    as.array(label),
    unif.rnds,
    batch.size,
    shuffle
  )
}