blob: 32ba3591c97cba263069f37c9e1f2967fd8b1812 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet.io
import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, WarnIfNotDisposed}
import ml.dmlc.mxnet.IO._
import org.slf4j.LoggerFactory
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer
/**
 * DataIter built in MXNet, backed by a native (C++) data iterator.
 *
 * Finding out whether another batch exists requires advancing the native iterator,
 * so the batch fetched by such a probe is cached in `currentBatch` until `next()`
 * hands it out.
 *
 * @param handle the handle to the underlying C++ Data Iterator
 * @param dataName the key under which the data shape is reported by `provideData`
 * @param labelName the key under which the label shape is reported by `provideLabel`
 */
private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
                                dataName: String = "data",
                                labelName: String = "label")
  extends DataIter with WarnIfNotDisposed {

  private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

  // Batch already fetched from the native iterator but not yet returned by next();
  // null when nothing is cached. It is what makes hasNext implementable.
  // (Maybe this is not the best way to do this; fix me if a better way is found.)
  private var currentBatch: DataBatch = null

  // Shapes and batch size are probed once from the first batch, then the iterator is rewound.
  // The binders are deliberately untyped (the tuple type is ascribed once instead): typed
  // patterns are specified to match only non-null values, so the (null, null, 0) fallback
  // for an empty iterator could otherwise fail the pattern match.
  private val (_provideData, _provideLabel, _batchSize):
      (ListMap[String, Shape], ListMap[String, Shape], Int) =
    if (hasNext) {
      // hasNext has already advanced the native iterator and cached the first batch in
      // currentBatch. Advancing again here would drop that batch without disposing it and
      // would leave currentBatch null when the data set contains exactly one batch.
      val data = currentBatch.data(0)
      val label = currentBatch.label(0)
      // properties
      val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0))
      currentBatch.dispose()
      reset()
      res
    } else {
      // Empty iterator: no shapes are available.
      (null, null, 0)
    }

  private var disposed = false
  protected def isDisposed: Boolean = disposed

  /**
   * Release the native memory.
   * The object shall never be used after it is disposed.
   * Calling it more than once is harmless: the native handle is freed only the first time.
   */
  def dispose(): Unit = {
    if (!disposed) {
      // NOTE(review): unlike the other native calls in this class, the result of
      // mxDataIterFree is not passed through checkCall -- confirm this is intentional.
      _LIB.mxDataIterFree(handle)
      disposed = true
    }
  }

  /**
   * Reset the iterator: forget any cached batch and move the native iterator
   * back before the first batch.
   */
  override def reset(): Unit = {
    // The cached batch (if any) is only dropped here, not disposed.
    currentBatch = null
    checkCall(_LIB.mxDataIterBeforeFirst(handle))
  }

  /**
   * Return the next batch, fetching it from the native iterator unless one is already cached.
   * @return the next batch
   * @throws NoSuchElementException when the native iterator has no more batches
   */
  @throws(classOf[NoSuchElementException])
  override def next(): DataBatch = {
    if (currentBatch == null) {
      iterNext()
    }
    if (currentBatch != null) {
      // Hand the cached batch over and clear the cache so it is returned only once.
      val batch = currentBatch
      currentBatch = null
      batch
    } else {
      throw new NoSuchElementException
    }
  }

  /**
   * Iterate to next batch. On success the new batch is stored in `currentBatch`,
   * otherwise `currentBatch` is left null.
   * @return whether the move is successful
   */
  private def iterNext(): Boolean = {
    val status = new RefInt
    checkCall(_LIB.mxDataIterNext(handle, status))
    currentBatch = null
    val moved = status.value > 0
    if (moved) {
      currentBatch = new DataBatch(data = getData(), label = getLabel(),
                                   index = getIndex(), pad = getPad())
    }
    moved
  }

  /**
   * Get data of current batch.
   * @return the data of current batch, as a single non-writable NDArray
   */
  override def getData(): IndexedSeq[NDArray] = {
    val out = new NDArrayHandleRef
    checkCall(_LIB.mxDataIterGetData(handle, out))
    IndexedSeq(new NDArray(out.value, writable = false))
  }

  /**
   * Get label of current batch.
   * @return the label of current batch, as a single non-writable NDArray
   */
  override def getLabel(): IndexedSeq[NDArray] = {
    val out = new NDArrayHandleRef
    checkCall(_LIB.mxDataIterGetLabel(handle, out))
    IndexedSeq(new NDArray(out.value, writable = false))
  }

  /**
   * Get the index of current batch.
   * @return the index of current batch
   */
  override def getIndex(): IndexedSeq[Long] = {
    // Both buffers are handed to the native call; only outIndex is read back here.
    val outIndex = new ListBuffer[Long]
    val outSize = new RefLong
    checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize))
    outIndex.toIndexedSeq
  }

  /**
   * Get the number of padding examples in current batch.
   * @return number of padding examples in current batch
   */
  override def getPad(): MXUint = {
    val out = new MXUintRef
    checkCall(_LIB.mxDataIterGetPadNum(handle, out))
    out.value
  }

  // The name and shape of data provided by this iterator (null if the iterator was empty)
  override def provideData: ListMap[String, Shape] = _provideData

  // The name and shape of label provided by this iterator (null if the iterator was empty)
  override def provideLabel: ListMap[String, Shape] = _provideLabel

  /**
   * Whether another batch is available. When no batch is cached this advances the
   * native iterator and caches the fetched batch for the following `next()` call.
   */
  override def hasNext: Boolean = {
    if (currentBatch != null) {
      true
    } else {
      iterNext()
    }
  }

  // Size of the first dimension of the data in the first batch (0 if the iterator was empty)
  override def batchSize: Int = _batchSize
}
/**
 * DataPack whose iterator is obtained by delegating to `createIterator`.
 *
 * @param iterName the iterator name passed to `createIterator`
 * @param params the parameters passed to `createIterator`
 */
private[mxnet] class MXDataPack(iterName: String, params: Map[String, String]) extends DataPack {
  /**
   * Get data iterator.
   * @return the DataIter returned by `createIterator(iterName, params)`
   */
  override def iterator: DataIter = createIterator(iterName, params)
}