blob: b749a413c5b6c792ac05bd1f665b06b30137c38a [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' Judge if an object is mx.dataiter
#'
#' @return Logical indicator
#'
#' @export
is.mx.dataiter <- function(x) {
inherits(x, "Rcpp_MXNativeDataIter") ||
inherits(x, "Rcpp_MXArrayDataIter")
}
#' Extract a certain field from DataIter.
#'
#' @export
mx.io.extract <- function(iter, field) {
packer <- mx.nd.arraypacker()
iter$reset()
while (iter$iter.next()) {
dlist <- iter$value()
padded <- iter$num.pad()
data <- dlist[[field]]
oshape <- dim(data)
ndim <- length(oshape)
packer$push(mx.nd.slice(data, 0, oshape[[ndim]] - padded))
}
iter$reset()
return(packer$get())
}
#
#' Create MXDataIter compatible iterator from R's array
#'
#' @param data The data array.
#' @param label The label array.
#' @param batch.size The batch size used to pack the array.
#' @param shuffle Whether shuffle the data
#'
#' @export
mx.io.arrayiter <- function(data, label,
batch.size=128,
shuffle=FALSE) {
if (shuffle) {
shape <- dim(data)
if (is.null(shape)) {
num.data <- length(data)
} else {
ndim <- length(shape)
num.data <- shape[[ndim]]
}
unif.rnds <- as.array(mx.runif(c(num.data), ctx=mx.cpu()));
} else {
unif.rnds <- as.array(0)
}
mx.io.internal.arrayiter(as.array(data),
as.array(label),
unif.rnds,
batch.size,
shuffle)
}