add NDArrayCollector
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index 2f79b58..181b232 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -167,7 +167,7 @@
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
- ndHandles.toArray.map(new NDArray(_))
+ // Wrap the executor's output handles without registering them with the thread's
+ // NDArrayCollector (addToCollector = false), so an enclosing withScope block neither
+ // collects nor disposes them. NOTE(review): presumably the executor keeps ownership
+ // of these outputs - confirm.
+ ndHandles.toArray.map(new NDArray(_, addToCollector = false))
}
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
index 8e53d65..c8a251d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
@@ -51,7 +51,7 @@
override def invoke(name: String, arr: NDArrayHandle): Unit = {
// wrapper for executor callback
if (activated) {
- val array = new NDArray(arr, writable = false)
+ val array = new NDArray(arr, writable = false, addToCollector = false)
val elem = (step, name, statFunc(array))
queue += elem
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index e8c687e..844621d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -547,11 +547,15 @@
* </b>
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
- val writable: Boolean = true) extends WarnIfNotDisposed {
+ val writable: Boolean = true,
+ addToCollector: Boolean = true) extends WarnIfNotDisposed {
+ if (addToCollector) {
+ NDArrayCollector.collect(this)
+ }
// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
- private var disposed = false
+ @volatile private var disposed = false
def isDisposed: Boolean = disposed
def serialize(): Array[Byte] = {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
new file mode 100644
index 0000000..b5ae44b
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.slf4j.LoggerFactory
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+/**
+ * A collector to store NDArrays.
+ * It provides a scope in which the allocated NDArrays can either <br />
+ * - be disposed automatically when the code block finishes, or <br />
+ * - simply be collected for future usage.
+ * <br />
+ * If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
+ * the collector is smart enough NOT to collect or dispose the returned NDArray. <br />
+ * However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside,
+ * (e.g., store to a global variable and use later, pass to another thread, etc.) <br />
+ * Usage Example:
+ * <pre>
+ * val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+ * val res = NDArrayCollector.auto().withScope {
+ * (NDArray.relu(a) + a).toArray
+ * }
+ * </pre>
+ * In the case above, the intermediate NDArrays
+ * (created by <em>NDArray.relu</em> and <em>+</em>) will be disposed automatically. <br />
+ * Users can also choose to dispose of the collected NDArrays manually later: <br />
+ * <pre>
+ * val collector = NDArrayCollector.manual()
+ * val res = collector.withScope {
+ * (NDArray.relu(a) + a).toArray
+ * }
+ * collector.foreach(_.dispose())
+ * </pre>
+ * For Java users: <br />
+ * <pre>
+ * NDArray a = NDArray.array(new float[]{-1f, 0f, 1f, 2f, 3f, 4f},
+ * Shape.create(2, 3), Context.cpu(0));
+ * float[] sliced = NDArrayCollector.auto().withScope(
+ * new scala.runtime.AbstractFunction0<float[]>() {
+ * @Override
+ * public float[] apply() {
+ * return a.slice(0, 1).toArray();
+ * }
+ * });
+ * </pre>
+ */
+object NDArrayCollector {
+  private val logger = LoggerFactory.getLogger(classOf[NDArrayCollector])
+
+  // Collector that receives the NDArrays created on the current thread.
+  // The initial per-thread value neither collects nor disposes anything;
+  // withScope temporarily swaps in the collector whose scope is running.
+  private val currCollector: ThreadLocal[NDArrayCollector] = new ThreadLocal[NDArrayCollector] {
+    override def initialValue(): NDArrayCollector =
+      new NDArrayCollector(autoDispose = false, doCollect = false)
+  }
+
+  /**
+   * Build a collector whose <em>withScope</em> disposes every NDArray it collected
+   * once the code block ends.
+   * @return an auto-disposable collector.
+   */
+  def auto(): NDArrayCollector = new NDArrayCollector(autoDispose = true)
+
+  /**
+   * Build a collector that keeps the collected NDArrays so that the caller can
+   * dispose of them later, e.g. with <em>collector.foreach(_.dispose())</em>.
+   * @return a manually-disposable collector.
+   */
+  def manual(): NDArrayCollector = new NDArrayCollector(autoDispose = false)
+
+  /**
+   * Register NDArrays with the collector that is active on the calling thread.
+   * Outside of any <em>withScope</em> block this does nothing, because the
+   * thread's default collector does not collect.
+   * @param ndArray NDArrays to be collected.
+   */
+  @varargs def collect(ndArray: NDArray*): Unit = {
+    val active = currCollector.get()
+    active.add(ndArray: _*)
+  }
+}
+
+/**
+ * Holds the NDArrays collected while one of its <em>withScope</em> blocks is running.
+ * Instances are obtained through <em>NDArrayCollector.auto()</em> or
+ * <em>NDArrayCollector.manual()</em>.
+ *
+ * @param autoDispose whether <em>withScope</em> disposes (and then forgets) the collected
+ *                    arrays when its code block exits.
+ * @param doCollect   whether arrays passed to <em>add</em> are recorded at all; the
+ *                    per-thread default collector is built with <em>false</em>.
+ */
+class NDArrayCollector private(private val autoDispose: Boolean = true,
+                               private val doCollect: Boolean = true) {
+  // Collected arrays keyed by native handle, so each handle is stored at most once.
+  private val arrays: mutable.Map[Long, NDArray] = mutable.HashMap.empty[Long, NDArray]
+
+  // Record the given arrays; a no-op when this collector was built with doCollect = false.
+  private def add(nd: NDArray*): Unit = {
+    if (doCollect) nd.foreach(arr => arrays.put(arr.handle, arr))
+  }
+
+  /**
+   * Clear the collector, forgetting every collected NDArray without disposing of it.
+   */
+  def clear(): Unit = {
+    arrays.clear()
+  }
+
+  /**
+   * Iterate over the collected NDArrays and apply the user-defined function to each NDArray.
+   * @param f the function that is applied for its side-effect to every NDArray.
+   *          The result of function <em>f</em> is discarded.
+   */
+  def foreach(f: NDArray => Unit): Unit = {
+    arrays.values.foreach(f(_))
+  }
+
+  /**
+   * @return how many unique NDArrays are collected.
+   */
+  def size: Int = arrays.size
+
+  /**
+   * Create a code scope; NDArrays allocated within this scope will be collected.
+   * The collected NDArrays will be either <br />
+   * - disposed automatically when the code block finishes, whether it returns normally
+   *   or throws (when using <em>auto</em>), or <br />
+   * - stored for later access (when using <em>manual</em>) <br />
+   * If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
+   * it is smart enough NOT to collect or dispose the returned NDArray. <br />
+   * Note that the returned NDArray is not handed over to an enclosing collector either. <br />
+   * However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside.
+   * @param body code block to be executed within the scope.
+   * @tparam T return type of the function <em>body</em>.
+   * @return The result of function <em>body</em>.
+   */
+  def withScope[T](body: => T): T = {
+    val old = NDArrayCollector.currCollector.get()
+    NDArrayCollector.currCollector.set(this)
+    try {
+      val ret = body
+
+      // Keep the returned array(s) alive by dropping them from this collector.
+      ret match {
+        case ndRet: NDArray =>
+          arrays.remove(ndRet.handle)
+        case ndarrays: NDArrayFuncReturn =>
+          ndarrays.arr.foreach(nd => arrays.remove(nd.handle))
+        case _ => // do nothing
+      }
+      ret
+    } finally {
+      // Reinstate the previous collector first, so the thread is not left pointing at
+      // this collector even if one of the dispose() calls below fails.
+      NDArrayCollector.currCollector.set(old)
+      // Disposing here (not only on the success path) releases the arrays allocated
+      // before an exception escaped from body, instead of leaking them.
+      if (autoDispose) {
+        foreach(_.dispose())
+        clear()
+      }
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
index 6630d5f..f2abe5e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
@@ -72,9 +72,9 @@
val tensors = (0 until 5).toArray.map( x => ArrayBuffer[NDArray]() )
for (i <- 0 until numNdarray) {
if (tags(i) == 1 || tags(i) == 4) {
- tensors(tags(i)) += new NDArray(ndarraies(i), writable = true)
+ tensors(tags(i)) += new NDArray(ndarraies(i), writable = true, addToCollector = false)
} else {
- tensors(tags(i)) += new NDArray(ndarraies(i), writable = false)
+ tensors(tags(i)) += new NDArray(ndarraies(i), writable = false, addToCollector = false)
}
}
val reqEnum = Array("null", "write", "inplace", "add")
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index ed3c5ad..0f1fca6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -149,14 +149,15 @@
*/
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
- val newArray = NDArray.zeros(ndArray.slice(0, dataBatchSize).shape)
- val batch = ndArray.slice(cursor, numData)
- val padding = ndArray.slice(0, padNum)
- newArray.slice(0, dataBatchSize - padNum).set(batch).dispose()
- newArray.slice(dataBatchSize - padNum, dataBatchSize).set(padding).dispose()
- batch.dispose()
- padding.dispose()
- newArray
+ // Shape of one full batch: first dimension is dataBatchSize, remaining dimensions as in ndArray.
+ val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
+ // Allocated before the auto scope below, so that scope neither collects nor disposes it.
+ val newArray = NDArray.zeros(shape)
+ // Temporary NDArrays created inside this block (the slices of ndArray and newArray)
+ // are collected by the auto collector and disposed when the block exits.
+ NDArrayCollector.auto().withScope {
+ // Remaining rows of the data, from cursor until numData.
+ val batch = ndArray.slice(cursor, numData)
+ // Padding taken from the start of the data: its first padNum rows.
+ val padding = ndArray.slice(0, padNum)
+ newArray.slice(0, dataBatchSize - padNum).set(batch)
+ newArray.slice(dataBatchSize - padNum, dataBatchSize).set(padding)
+ newArray
+ }
}
private def _getData(data: IndexedSeq[(String, NDArray)]): IndexedSeq[NDArray] = {
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala
new file mode 100644
index 0000000..55fa26b
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+class NDArrayCollectorSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  test("auto dispose") {
+    val input = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+    var relu: NDArray = null
+    var sum: NDArray = null
+
+    val result = NDArrayCollector.auto().withScope {
+      relu = NDArray.relu(input) // [0, 0, 1, 2, 3, 4]
+      sum = input + relu // [-1, 0, 2, 4, 6, 8]
+      sum.slice(0, 1)
+    }
+
+    // Intermediates allocated inside the scope are released once it exits...
+    relu.isDisposed shouldBe true
+    sum.isDisposed shouldBe true
+    // ...while the NDArray returned from the scope stays usable.
+    result.isDisposed shouldBe false
+
+    result.toArray shouldBe Array(-1f, 0f, 2f)
+  }
+
+  test("manually dispose") {
+    val input = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+    var relu: NDArray = null
+    var sum: NDArray = null
+
+    val collector = NDArrayCollector.manual()
+    val result = collector.withScope {
+      relu = NDArray.relu(input) // [0, 0, 1, 2, 3, 4]
+      sum = input + relu // [-1, 0, 2, 4, 6, 8]
+      sum.slice(0, 1)
+    }
+
+    result.toArray shouldBe Array(-1f, 0f, 2f)
+
+    // Only the two intermediates are tracked; the returned slice is left out.
+    collector.size shouldBe 2
+    relu.isDisposed shouldBe false
+    sum.isDisposed shouldBe false
+    result.isDisposed shouldBe false
+
+    // A manual collector leaves disposal to the caller.
+    collector.foreach(_.dispose())
+    relu.isDisposed shouldBe true
+    sum.isDisposed shouldBe true
+    result.isDisposed shouldBe false
+
+    collector.clear()
+    collector.size shouldBe 0
+  }
+}