# blob: 06b8e90c139289b6da0dfe30862d196ea4ae0d80 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Attach mxnet with library() rather than require(): library() raises an error
# when the package is missing, whereas require() only returns FALSE and lets
# the file continue until the first mx.* call fails.
library(mxnet)
# testthat context label under which the tests in this file are reported.
context("io")
# NOTE(review): get_data.R presumably defines the dataset helpers called by the
# tests below (GetMNIST_ubyte, GetCifar10) - confirm against that file.
source("get_data.R")
test_that("MNISTIter", {
  # Make sure the MNIST ubyte files referenced below are in place.
  GetMNIST_ubyte()

  samples_per_batch <- 100
  mnist_iter <- mx.io.MNISTIter(
    image = "data/train-images-idx3-ubyte",
    label = "data/train-labels-idx1-ubyte",
    data.shape = c(784),
    batch.size = samples_per_batch,
    shuffle = TRUE,
    flat = TRUE,
    silent = 0,
    seed = 10
  )

  # Part 1: one complete pass must yield 60000 / batch size batches.
  mnist_iter$reset()
  batches_seen <- 0
  repeat {
    if (!mnist_iter$iter.next()) {
      break
    }
    batches_seen <- batches_seen + 1
  }
  expect_equal(batches_seen, 60000 / samples_per_batch)

  # Part 2: after a reset, the first batch must carry the same labels as the
  # first batch of the previous pass, even though shuffle = TRUE.
  mnist_iter$reset()
  mnist_iter$iter.next()
  labels_before <- as.array(mnist_iter$value()$label)

  # Move a few batches further in so the reset below has something to undo.
  for (step in seq_len(4)) {
    mnist_iter$iter.next()
  }

  mnist_iter$reset()
  mnist_iter$iter.next()
  labels_after <- as.array(mnist_iter$value()$label)

  expect_equal(labels_before, labels_after)
})
test_that("Cifar10Rec", {
  # Make sure the CIFAR-10 record, list and mean files used below are in place.
  GetCifar10()

  cifar_iter <- mx.io.ImageRecordIter(
    path.imgrec = "./data/cifar/train.rec",
    path.imglist = "./data/cifar/train.lst",
    mean.img = "./data/cifar/cifar10_mean.bin",
    batch.size = 100,
    data.shape = c(28, 28, 3),
    rand.crop = TRUE,
    rand.mirror = TRUE
  )

  # Histogram of labels over one full pass; label k is tallied in slot k + 1
  # because R vectors are 1-indexed.
  class_hits <- rep(0, 10)
  cifar_iter$reset()
  repeat {
    if (!cifar_iter$iter.next()) {
      break
    }
    batch_labels <- as.array(cifar_iter$value()$label)
    for (lab in batch_labels) {
      slot <- lab + 1
      class_hits[slot] <- class_hits[slot] + 1
    }
  }

  # The test expects each of the 10 classes to occur exactly 5000 times.
  expect_equal(class_hits, rep(5000, 10))
})
test_that("mx.io.arrayiter", {
  # Synthetic input: a 100 x 100 matrix holding 1..10000 and labels 1..100.
  features <- matrix(1:10000, nrow = 100, ncol = 100)
  targets <- 1:100

  # Part 1: a batch size that divides the sample count evenly (100 / 20).
  even_iter <- mx.io.arrayiter(features, targets, batch.size = 20, shuffle = FALSE)
  even_iter$reset()
  batches_seen <- 0
  repeat {
    if (!even_iter$iter.next()) {
      break
    }
    batches_seen <- batches_seen + 1
  }
  expect_equal(batches_seen, 100 / 20)

  # Part 2: a batch size that does not divide the sample count (100 vs 30).
  # round(targets / 10) collapses the labels into the values 0..10, hence the
  # 11 histogram slots; label k is tallied in slot k + 1.
  targets <- round(targets / 10)
  uneven_iter <- mx.io.arrayiter(features, targets, batch.size = 30, shuffle = FALSE)
  bin_hits <- rep(0, 11)
  uneven_iter$reset()
  repeat {
    if (!uneven_iter$iter.next()) {
      break
    }
    batch_labels <- as.array(uneven_iter$value()$label)
    for (lab in batch_labels) {
      slot <- lab + 1
      bin_hits[slot] <- bin_hits[slot] + 1
    }
  }

  # The expected totals add up to 120 (four batches of 30) rather than 100, so
  # the final batch is evidently delivered at full size.
  # NOTE(review): the surplus in bins 7, 8 and 9 (4 + 11 + 5 = 20) is consistent
  # with samples 71-90 being counted twice - confirm against the padding
  # behaviour of mx.io.arrayiter.
  expect_equal(bin_hits, c(5, 9, 11, 9, 11, 9, 11, 13, 22, 14, 6))
})