| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.mxnet.io |
| |
| import java.util.NoSuchElementException |
| |
| import org.apache.mxnet.Base._ |
| import org.apache.mxnet._ |
| import org.slf4j.LoggerFactory |
| |
| import scala.annotation.varargs |
| import scala.collection.immutable.ListMap |
| |
/**
 * Data iterator over in-memory NDArrays: walks the first axis of every data
 * and label array in fixed-size batches.
 *
 * The last batch is padded, discarded or rolled over when the number of
 * examples is not a multiple of the batch size. Roll over is intended for
 * training and can cause problems if used for prediction.
 *
 * @param data            named data arrays; must be non-null and non-empty. The first axis
 *                        of the first array defines the number of examples.
 * @param label           named label arrays, same layout as data, but not fed to the model
 *                        during testing. Must be non-null; pass IndexedSeq.empty for none.
 * @param dataBatchSize   number of examples per batch; must be positive
 * @param shuffle         whether to shuffle the data; only `false` is accepted for now
 * @param lastBatchHandle "pad", "discard" or "roll_over": how to handle the last batch
 */
class NDArrayIter(data: IndexedSeq[(String, NDArray)],
                  label: IndexedSeq[(String, NDArray)],
                  private val dataBatchSize: Int, shuffle: Boolean,
                  lastBatchHandle: String) extends DataIter {

  /**
   * Convenience constructor for unnamed arrays; names are derived by
   * IO.initData from `dataName` / `labelName`.
   */
  def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
           dataBatchSize: Int = 1, shuffle: Boolean = false,
           lastBatchHandle: String = "pad",
           dataName: String = "data", labelName: String = "label") {
    this(IO.initData(data, allowEmpty = false, dataName),
      IO.initData(label, allowEmpty = true, labelName),
      dataBatchSize, shuffle, lastBatchHandle)
  }

  private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])

  // Validated (and, for "discard", truncated) views of the constructor arguments.
  val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, NDArray)]) = {
    // data should not be null and size > 0
    require(data != null && data.nonEmpty,
      "data should not be null and data.size should not be zero")

    require(label != null,
      "label should not be null. Use IndexedSeq.empty if there are no labels")

    // shuffle is not supported currently
    require(!shuffle, "shuffle is not supported currently")

    // A non-positive batch size would divide by zero below or never advance the cursor.
    require(dataBatchSize > 0, "batch size should be positive")

    require(lastBatchHandle != null, "lastBatchHandle should not be null")

    if (lastBatchHandle == "discard") {
      // Drop the trailing examples that do not fill a complete batch.
      val dataSize = data(0)._2.shape(0)
      require(dataBatchSize <= dataSize,
        "batch_size need to be smaller than data size when not padding.")
      val keepSize = dataSize - dataSize % dataBatchSize

      def truncate(arrays: IndexedSeq[(String, NDArray)]): IndexedSeq[(String, NDArray)] =
        arrays.map { case (name, ndArray) => (name, ndArray.slice(0, keepSize)) }

      (truncate(data), truncate(label))
    } else {
      (data, label)
    }
  }

  // Number of examples: length of the first axis of the first data array.
  val numData = initData(0)._2.shape(0)
  val numSource: MXUint = initData.size
  // Start offset of the current batch; one batch before 0 until next() is first called.
  private var cursor = -dataBatchSize

  private val (_provideData: ListMap[String, Shape],
               _provideLabel: ListMap[String, Shape]) = {
    val pData = ListMap.empty[String, Shape] ++ initData.map(getShape)
    val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape)
    (pData, pLabel)
  }

  /**
   * Shape of one batch drawn from `ndArray`: the first axis is replaced by
   * the batch size, the remaining axes are kept.
   */
  private def batchShapeOf(ndArray: NDArray): Shape = {
    val full = ndArray.shape
    Shape(dataBatchSize) ++ full.slice(1, full.size)
  }

  /**
   * get shape via dataBatchSize
   * @param dataItem (name, array) pair
   * @return the name paired with the per-batch shape of the array
   */
  private def getShape(dataItem: (String, NDArray)): (String, Shape) = {
    (dataItem._1, batchShapeOf(dataItem._2))
  }

  /**
   * Ignore roll over data and set to start
   */
  def hardReset(): Unit = {
    cursor = -dataBatchSize
  }

  /**
   * reset the iterator
   *
   * NOTE(review): hasNext only lets next() advance while
   * cursor + dataBatchSize < numData, so cursor stays below numData and the
   * roll_over branch looks unreachable through normal iteration - confirm
   * the intended roll_over semantics.
   */
  override def reset(): Unit = {
    if (lastBatchHandle == "roll_over" && cursor > numData) {
      cursor = -dataBatchSize + (cursor % numData) % dataBatchSize
    } else {
      cursor = -dataBatchSize
    }
  }

  /** True while advancing by one batch still starts inside the data. */
  override def hasNext: Boolean = cursor + dataBatchSize < numData

  /**
   * Advance to the next batch and return it.
   * @throws NoSuchElementException when the iterator is exhausted
   */
  @throws(classOf[NoSuchElementException])
  override def next(): DataBatch = {
    if (!hasNext) {
      throw new NoSuchElementException
    }
    cursor += dataBatchSize
    new DataBatch(getData(), getLabel(), getIndex(), getPad())
  }

  /**
   * handle the last batch: build a full-size batch whose head is the
   * remaining examples and whose tail is filled from the start of the array
   * @param ndArray source array
   * @return a newly allocated array with batch-size first axis
   */
  private def _padData(ndArray: NDArray): NDArray = {
    val padNum = cursor + dataBatchSize - numData
    val validNum = dataBatchSize - padNum
    // Allocated before entering the collector scope; presumably so the result
    // outlives the temporary slices created inside it - confirm against
    // NDArrayCollector.
    val padded = NDArray.zeros(batchShapeOf(ndArray))
    NDArrayCollector.auto().withScope {
      val batch = ndArray.slice(cursor, numData)
      val padding = ndArray.slice(0, padNum)
      padded.slice(0, validNum).set(batch)
      padded.slice(validNum, dataBatchSize).set(padding)
      padded
    }
  }

  /**
   * Slice the current batch out of every array in `source`, padding when the
   * batch runs past the end of the data.
   */
  private def _getData(source: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
    require(cursor < numData, "DataIter needs reset.")
    val batchEnd = cursor + dataBatchSize
    if (batchEnd <= numData) {
      source.map { case (_, ndArray) => ndArray.slice(cursor, batchEnd) }
    } else {
      // padding
      source.map { case (_, ndArray) => _padData(ndArray) }
    }
  }

  /**
   * get data of current batch
   * @return the data of current batch
   */
  override def getData(): IndexedSeq[NDArray] = {
    _getData(initData)
  }

  /**
   * Get label of current batch
   * @return the label of current batch
   */
  override def getLabel(): IndexedSeq[NDArray] = {
    _getData(initLabel)
  }

  /**
   * the index of current batch
   * @return one index per batch slot: cursor until cursor + batch size (exclusive)
   */
  override def getIndex(): IndexedSeq[Long] = {
    // Exclusive upper bound: an inclusive range would yield batchSize + 1 indices.
    cursor.toLong until (cursor + dataBatchSize).toLong
  }

  /**
   * get the number of padding examples
   * in current batch
   * @return number of padding examples in current batch
   */
  override def getPad(): MXUint = {
    val overflow = cursor + batchSize - numData
    if (lastBatchHandle == "pad" && overflow > 0) {
      overflow
    } else {
      0
    }
  }

  // The name and shape of data provided by this iterator
  override def provideData: ListMap[String, Shape] = _provideData

  // The name and shape of label provided by this iterator
  override def provideLabel: ListMap[String, Shape] = _provideLabel

  override def batchSize: Int = dataBatchSize
}
| |
object NDArrayIter {

  /**
   * Fluent builder for [[NDArrayIter]].
   *
   * Starts with no data, no labels, a batch size of 1 and "pad" as the
   * last-batch handling; every setter returns the builder for chaining.
   */
  class Builder() {
    private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
    private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
    private var dataBatchSize: Int = 1
    private var lastBatchHandle: String = "pad"

    /** Append a named data array. */
    def addData(name: String, data: NDArray): Builder = {
      this.data = this.data :+ (name -> data)
      this
    }

    /** Append a named label array. */
    def addLabel(name: String, label: NDArray): Builder = {
      this.label = this.label :+ (name -> label)
      this
    }

    /** Set the number of examples per batch. */
    def setBatchSize(batchSize: Int): Builder = {
      dataBatchSize = batchSize
      this
    }

    /** Set how the last batch is handled. */
    def setLastBatchHandle(lastBatchHandle: String): Builder = {
      this.lastBatchHandle = lastBatchHandle
      this
    }

    /** Create the iterator from the collected settings; shuffle is always off. */
    def build(): NDArrayIter = {
      new NDArrayIter(data, label, dataBatchSize,
        shuffle = false, lastBatchHandle = lastBatchHandle)
    }
  }
}