blob: 3270ba7399e50b6103df8790e88402b9117ec963 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet.io
import ml.dmlc.mxnet.{DataBatch, DataIter, NDArray, Shape}
import org.slf4j.LoggerFactory
import java.util.concurrent.Semaphore
import scala.collection.immutable.ListMap
/**
* Base class for prefetching iterators. Takes one or more DataIters
* and combine them with prefetching.
*
* @param iters list of DataIters
* @param dataNames
* @param labelNames
*/
class PrefetchingIter(
iters: IndexedSeq[DataIter],
dataNames: IndexedSeq[Map[String, String]] = null,
labelNames: IndexedSeq[Map[String, String]] = null) extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[PrefetchingIter])
require(iters.nonEmpty, "Iters length must be greater than 0")
private val _provideData: ListMap[String, Shape] = {
if (dataNames == null) {
iters.map(_.provideData).foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
acc ++ elem
}
} else {
iters.zipWithIndex.map(tu => (tu._1.provideData, tu._2))
.map(m => m._1.map(t => (dataNames(m._2)(t._1), t._2)))
.foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
acc ++ elem
}
}
}
private val _provideLabel: ListMap[String, Shape] = {
if (labelNames == null) {
iters.map(_.provideLabel).foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
acc ++ elem
}
} else {
iters.zipWithIndex.map(tu => (tu._1.provideLabel, tu._2))
.map(m => m._1.map(t => (labelNames(m._2)(t._1), t._2)))
.foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
acc ++ elem
}
}
}
private val _batchSize: Int = this._provideData.toList(0)._2(0)
private val dataReady: IndexedSeq[Semaphore] =
(0 until iters.length).map(i => new Semaphore(0))
private val dataTaken: IndexedSeq[Semaphore] =
(0 until iters.length).map(i => new Semaphore(1))
@volatile private var started: Boolean = true
private var currentBatch: DataBatch = null
private val nextBatch: Array[DataBatch] = (0 until iters.length).map { i =>
new DataBatch(null, null, null, 0)
}.toArray
// thread entry
def prefetchFunc(i: Int): Runnable = new Runnable {
override def run(): Unit = {
while (started) {
dataTaken(i).acquire()
if (started) {
try {
nextBatch(i) = iters(i).next()
} catch {
case ex: NoSuchElementException => nextBatch(i) = null
}
}
dataReady(i).release()
}
}
}
private val prefetchThreads =
for (i <- 0 until iters.length) yield new Thread(prefetchFunc(i))
prefetchThreads.foreach(_.start())
override def next(): DataBatch = currentBatch
/**
* reset the iterator
*/
override def reset(): Unit = {
for (e <- dataReady) e.acquire()
for (i <- iters) i.reset()
for (e <- dataTaken) e.release()
}
override def batchSize: Int = this._batchSize
/**
* get data of current batch
* @return the data of current batch
*/
override def getData(): IndexedSeq[NDArray] = currentBatch.data
/**
* Get label of current batch
* @return the label of current batch
*/
override def getLabel(): IndexedSeq[NDArray] = currentBatch.label
/**
* the index of current batch
* @return
*/
override def getIndex(): IndexedSeq[Long] = currentBatch.index
// The name and shape of label provided by this iterator
override def provideLabel: ListMap[String, Shape] = this._provideLabel
/**
* get the number of padding examples
* in current batch
* @return number of padding examples in current batch
*/
override def getPad(): Int = this.currentBatch.pad
// The name and shape of data provided by this iterator
override def provideData: ListMap[String, Shape] = this._provideData
override def hasNext: Boolean = {
for (e <- dataReady) e.acquire()
if (nextBatch(0) == null) {
for (i <- nextBatch) {
assert(i == null, "Number of entry mismatches between iterators")
}
for (e <- dataReady) e.release()
false
} else {
for (batch <- nextBatch) {
assert(batch.pad == nextBatch(0).pad,
"Number of entry mismatches between iterators")
}
val datas = for (batch <- nextBatch) yield batch.data
val labels = for (batch <- nextBatch) yield batch.label
currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
labels.toIndexedSeq.flatten,
nextBatch(0).index,
nextBatch(0).pad)
for (e <- dataTaken) e.release()
true
}
}
/**
* Stop all its internal prefetching threads.
* The object shall never be used after it is disposed.
*/
def dispose(): Unit = {
started = false
for (e <- dataTaken) e.release()
for (t <- prefetchThreads) t.join()
}
}