| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.mxnet |
| |
| import java.nio.{ByteBuffer, ByteOrder} |
| |
| import org.apache.mxnet.Base._ |
| import org.apache.mxnet.DType.DType |
| import org.slf4j.LoggerFactory |
| |
| import scala.collection.mutable |
| import scala.collection.mutable.{ArrayBuffer, ListBuffer} |
| import scala.ref.WeakReference |
| |
| /** |
| * NDArray API of mxnet |
| */ |
| @AddNDArrayFunctions(false) |
| object NDArray { |
  // Implicitly unwrap a multi-output function result to its first NDArray,
  // so operator invocations can be used where a single NDArray is expected.
  implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
  private val logger = LoggerFactory.getLogger(classOf[NDArray])

  // Operator-name -> function table, built once from the native op registry.
  private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
| |
| private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = { |
| froms.foreach { from => |
| val weakRef = new WeakReference(from) |
| tos.foreach { to => |
| to.dependencies.put(from.handle, weakRef) |
| // we add all dep's dep to prevent (recursively) recomputing at runtime. |
| to.dependencies ++= from.dependencies |
| } |
| } |
| } |
| |
  // private[mxnet] def genericNDArrayFunctionInvoke(
  /**
   * Used by NDArrayMacro.
   * Invoke this function by passing in parameters.
   * Parameters
   * ----------
   * @param funcName Name of the registered operator to invoke.
   * @param args Positional arguments of input scalars and NDArray
   * @param kwargs Key-value arguments of input scalars
   * @return The result NDArrays of result of computation.
   */
  def genericNDArrayFunctionInvoke(
    funcName: String, args: Seq[Any], kwargs: Map[String, Any] = null): NDArrayFuncReturn = {
    val function = functions(funcName)
    // Split `args`: NDArrays (and the arrays inside an NDArrayFuncReturn) are
    // passed to the backend by handle; any other value is stringified and
    // later matched positionally to the operator's scalar-argument names.
    val ndArgs = ArrayBuffer.empty[NDArray]
    val posArgs = ArrayBuffer.empty[String]
    args.foreach {
      case arr: NDArray =>
        ndArgs.append(arr)
      case arrFunRet: NDArrayFuncReturn =>
        arrFunRet.arr.foreach(ndArgs.append(_))
      case arg =>
        posArgs.append(arg.toString)
    }

    require(posArgs.length <= function.arguments.length,
      s"len(posArgs) = ${posArgs.length}, should be less or equal to len(arguments) " +
      s"= ${function.arguments.length}")
    // Merge user kwargs with positional scalars keyed by the operator's
    // declared argument names. "out" is stripped here because it is handled
    // separately below as the output-variable list, not an operator parameter.
    val updatedKwargs: Map[String, String] =
      (Option(kwargs).getOrElse(Map.empty[String, String])
        ++ function.arguments.slice(0, posArgs.length).zip(posArgs) - "out"
      ).map { case (k, v) => k -> v.toString }

    // Resolve the optional "out" kwarg into pre-allocated output handles.
    val (oriOutputs, outputVars) =
      if (kwargs != null && kwargs.contains("out")) {
        val output = kwargs("out")
        output match {
          case nd: NDArray => (Array(nd), Array(nd.handle))
          case ndFuncRet: NDArrayFuncReturn => (ndFuncRet.arr, ndFuncRet.arr.map(_.handle))
          // NOTE(review): unchecked due to type erasure — any Seq reaches this
          // case; its elements are assumed to be NDArray.
          case ndArr: Seq[NDArray] => (ndArr.toArray, ndArr.toArray.map(_.handle))
          case _ => throw new IllegalArgumentException(
            "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
        }
      } else {
        // No "out": the backend allocates fresh output arrays.
        (null, null)
      }

    val outputs = ArrayBuffer.empty[NDArrayHandle]
    checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
      outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
    // When outputs were backend-allocated, wrap them and record their
    // dependency on the inputs so GC cannot free the inputs prematurely.
    new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
      val outputArrs = outputs.map(new NDArray(_)).toArray
      addDependency(ndArgs.toArray, outputArrs)
      outputArrs
    })
  }
| |
| /** |
| * Return a new empty handle. |
| * Empty handle can be used to hold result |
| * |
| * @return a new empty ndarray handle |
| */ |
| private def newEmptyHandle(): NDArrayHandle = { |
| val hdl = new NDArrayHandleRef |
| checkCall(_LIB.mxNDArrayCreateNone(hdl)) |
| hdl.value |
| } |
| |
  /**
   * Return a new handle with specified shape and context.
   * Empty handle is only used to hold results
   *
   * @param shape      shape of the array to allocate
   * @param ctx        device context (type id + device id) for the allocation
   * @param delayAlloc when true the backend defers the actual buffer
   *                   allocation until first use
   * @param dtype      element type, defaults to Float32
   * @return a new empty ndarray handle
   */
  private def newAllocHandle(shape: Shape,
                             ctx: Context,
                             delayAlloc: Boolean,
                             dtype: DType = DType.Float32): NDArrayHandle = {
    val hdl = new NDArrayHandleRef
    checkCall(_LIB.mxNDArrayCreateEx(
      shape.toArray,
      shape.length,
      ctx.deviceTypeid,
      ctx.deviceId,
      // native API takes an int flag, not a boolean
      if (delayAlloc) 1 else 0,
      dtype.id,
      hdl))
    hdl.value
  }
| |
  /**
   * Wait all async operation to finish in MXNet
   * This function is used for benchmark only
   */
  def waitall(): Unit = {
    // Blocks until every pending operation in the engine has completed.
    checkCall(_LIB.mxNDArrayWaitAll())
  }
| |
| // List and add all the atomic symbol functions to current module. |
| private def initNDArrayModule(): Map[String, NDArrayFunction] = { |
| val opNames = ListBuffer.empty[String] |
| checkCall(_LIB.mxListAllOpNames(opNames)) |
| opNames.map(opName => { |
| val opHandle = new RefLong |
| checkCall(_LIB.nnGetOpHandle(opName, opHandle)) |
| makeNDArrayFunction(opHandle.value, opName) |
| }).toMap |
| } |
| |
| // Create an atomic symbol function by handle and function name. |
| private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) |
| : (String, NDArrayFunction) = { |
| val name = new RefString |
| val desc = new RefString |
| val keyVarNumArgs = new RefString |
| val numArgs = new RefInt |
| val argNames = ListBuffer.empty[String] |
| val argTypes = ListBuffer.empty[String] |
| val argDescs = ListBuffer.empty[String] |
| |
| checkCall(_LIB.mxSymbolGetAtomicSymbolInfo( |
| handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)) |
| val arguments = (argTypes zip argNames).filter { case (dtype, _) => |
| !(dtype.startsWith("NDArray") || dtype.startsWith("Symbol") |
| || dtype.startsWith("NDArray-or-Symbol")) |
| }.map { case (_, argName) => |
| argName |
| } |
| (aliasName, new NDArrayFunction(handle, arguments.toList)) |
| } |
| |
| /** |
| * One hot encoding indices into matrix out. |
| * @param indices An NDArray containing indices of the categorical features. |
| * @param out The result holder of the encoding. |
| * @return Same as out. |
| */ |
| def onehotEncode(indices: NDArray, out: NDArray): NDArray = { |
| NDArray.genericNDArrayFunctionInvoke( |
| "_onehot_encode", Seq(indices, out), Map("out" -> out))(0) |
| } |
| |
| /** |
| * Create an empty uninitialized new NDArray, with specified shape. |
| * |
| * @param shape shape of the NDArray. |
| * @param ctx The context of the NDArray, default to current default context. |
| * |
| * @return The created NDArray. |
| */ |
| def empty(shape: Shape, ctx: Context = null, dtype: DType = Base.MX_REAL_TYPE): NDArray = { |
| val context = if (ctx == null) Context.defaultCtx else ctx |
| new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false, dtype)) |
| } |
| |
| def empty(shape: Int *): NDArray = empty(Shape(shape: _*)) |
| |
| def empty(ctx: Context, shape: Int *): NDArray = empty(Shape(shape: _*), ctx) |
| |
| /** |
| * Create a new NDArray filled with 0, with specified shape. |
| * |
| * @param shape shape of the NDArray. |
| * @param ctx The context of the NDArray, default to current default context. |
| * |
| * @return The created NDArray. |
| */ |
| def zeros(shape: Shape, ctx: Context = null, dtype: DType = Base.MX_REAL_TYPE): NDArray = { |
| val arr = empty(shape, ctx, dtype) |
| arr.set(0f) |
| arr |
| } |
| |
| def zeros(shape: Int *): NDArray = zeros(Shape(shape: _*)) |
| |
| def zeros(ctx: Context, shape: Int *): NDArray = zeros(Shape(shape: _*), ctx) |
| |
| /** |
| * Create a new NDArray filled with 1, with specified shape. |
| * @param shape shape of the NDArray. |
| * @param ctx The context of the NDArray, default to current default context. |
| * @return The created NDArray. |
| */ |
| def ones(shape: Shape, ctx: Context = null, dtype: DType = Base.MX_REAL_TYPE): NDArray = { |
| val arr = empty(shape, ctx, dtype) |
| arr.set(1f) |
| arr |
| } |
| |
| def ones(shape: Int *): NDArray = ones(Shape(shape: _*)) |
| |
| def ones(ctx: Context, shape: Int *): NDArray = ones(Shape(shape: _*), ctx) |
| |
| /** |
| * Create a new NDArray filled with given value, with specified shape. |
| * @param shape shape of the NDArray. |
| * @param value value to be filled with |
| * @param ctx The context of the NDArray, default to current default context |
| */ |
| def full(shape: Shape, value: Float, ctx: Context = null): NDArray = { |
| val arr = empty(shape, ctx) |
| arr.set(value) |
| arr |
| } |
| |
  // Perform power operator
  def power(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
  }

  // Element-wise lhs ** scalar.
  def power(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
  }

  // Element-wise scalar ** rhs; "_rpower_scalar" applies the scalar as base.
  def power(lhs: Float, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
  }

  // Perform maximum operator
  def maximum(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
  }

  def maximum(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
  }

  // Scalar-first overload reuses "_maximum_scalar": max is commutative,
  // so no reversed op is needed (unlike power above).
  def maximum(lhs: Float, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
  }

  // Perform minimum operator
  def minimum(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
  }

  def minimum(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
  }

  // Scalar-first overload: min is commutative (see maximum above).
  def minimum(lhs: Float, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
   * For each element in input arrays, return 1(true) if corresponding elements are same,
   * otherwise return 0(false).
   */
  def equal(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
  }

  def equal(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **not equal to** (!=) comparison operation
   * with broadcasting.
   * For each element in input arrays, return 1(true) if corresponding elements are different,
   * otherwise return 0(false).
   */
  def notEqual(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
  }

  def notEqual(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **greater than** (>) comparison operation
   * with broadcasting.
   * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
   * otherwise return 0(false).
   */
  def greater(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
  }

  def greater(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **greater than or equal to** (>=) comparison
   * operation with broadcasting.
   * For each element in input arrays, return 1(true) if lhs elements are greater than equal to rhs,
   * otherwise return 0(false).
   */
  def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
  }

  def greaterEqual(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **lesser than** (<) comparison operation
   * with broadcasting.
   * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
   * otherwise return 0(false).
   */
  def lesser(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
  }

  def lesser(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
  }

  /**
   * Returns the result of element-wise **lesser than or equal to** (<=) comparison
   * operation with broadcasting.
   * For each element in input arrays, return 1(true) if lhs elements are
   * lesser than equal to rhs, otherwise return 0(false).
   */
  def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
  }

  def lesserEqual(lhs: NDArray, rhs: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
  }
| |
| /** |
| * Create a new NDArray that copies content from source_array. |
| * @param sourceArr Source data to create NDArray from. |
| * @param shape shape of the NDArray |
| * @param ctx The context of the NDArray, default to current default context. |
| * @return The created NDArray. |
| */ |
| def array(sourceArr: Array[Float], shape: Shape, ctx: Context = null): NDArray = { |
| val arr = empty(shape, ctx) |
| arr.set(sourceArr) |
| arr |
| } |
| |
| /** |
| * Returns evenly spaced values within a given interval. |
| * Values are generated within the half-open interval [`start`, `stop`). In other |
| * words, the interval includes `start` but excludes `stop`. |
| * @param start Start of interval. The default start value is 0. |
| * @param stop End of interval. |
| * @param step Spacing between values. The default step size is 1. |
| * @param repeat Number of times to repeat each element. The default repeat count is 1. |
| * @param ctx Device context. Default context is the current default context. |
| * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. |
| * @return NDArray of evenly spaced values in the specified range. |
| */ |
| def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, |
| repeat: Int = 1, ctx: Context = Context.defaultCtx, |
| dType: DType = Base.MX_REAL_TYPE): NDArray = { |
| val params = Map("start" -> start, "step" -> step, |
| "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) |
| val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) |
| NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) |
| } |
| |
| /** |
| * Concatenate a list of NDArrays along the specified dimension. |
| * @param arrays Arrays to be concatenate. |
| * They must have identical shape except the first dimension. |
| * They also must have the same data type. |
| * @param axis The axis along which to concatenate. |
| * @param alwaysCopy Default `True`. When not `True`, |
| * if the arrays only contain one `NDArray`, |
| * that element will be returned directly, avoid copying. |
| * @return An `NDArray` that lives on the same context as `arrays[0].context`. |
| */ |
| def concatenate(arrays: Seq[NDArray], axis: Int = 0, alwaysCopy: Boolean = true): NDArray = { |
| require(arrays.size > 0) |
| |
| val array0 = arrays(0) |
| if (!alwaysCopy && arrays.size == 1) { |
| array0 |
| } else { |
| val shapeRest1 = array0.shape.slice(0, axis) |
| val shapeRest2 = array0.shape.slice(axis + 1, array0.shape.length) |
| val dtype = array0.dtype |
| |
| val shapeAxis = |
| arrays.map(arr => { |
| require(shapeRest1 == arr.shape.slice(0, axis)) |
| require(shapeRest2 == arr.shape.slice(axis + 1, arr.shape.length)) |
| require(dtype == arr.dtype) |
| arr.shape(axis) |
| }).sum |
| val retShape = shapeRest1 ++ Shape(shapeAxis) ++ shapeRest2 |
| val ret = NDArray.empty(retShape, ctx = array0.context, dtype = dtype) |
| |
| var idx = 0 |
| val begin = Array.fill(retShape.length)(0) |
| val end = retShape.toArray |
| for (arr <- arrays) { |
| if (axis == 0) { |
| ret.slice(idx, idx + arr.shape(0)).set(arr).dispose() |
| } else { |
| begin(axis) = idx |
| end(axis) = idx + arr.shape(axis) |
| NDArray._crop_assign(Map("out" -> ret, |
| "begin" -> Shape(begin), |
| "end" -> Shape(end)))(ret, arr) |
| } |
| idx += arr.shape(axis) |
| } |
| ret |
| } |
| } |
| |
| def concatenate(arrays: NDArray *): NDArray = { |
| concatenate(arrays.toSeq) |
| } |
| |
  /**
   * Load ndarray from binary file.
   *
   * You can also use pickle to do the job if you only work on python.
   * The advantage of load/save is the file is language agnostic.
   * This means the file saved using save can be loaded by other language binding of mxnet.
   * You also get the benefit being able to directly load/save from cloud storage(S3, HDFS)
   *
   * @param fname
   *     The name of the file.Can be S3 or HDFS address (remember built with S3 support).
   *     Example of fname:
   *     - `s3://my-bucket/path/my-s3-ndarray`
   *     - `hdfs://my-bucket/path/my-hdfs-ndarray`
   *     - `/path-to/my-local-ndarray`
   * @return dict of str->NDArray to be saved
   */
  def load(fname: String): (Array[String], Array[NDArray]) = {
    val outSize = new MXUintRef
    val outNameSize = new MXUintRef
    val handles = ArrayBuffer.empty[NDArrayHandle]
    val names = ArrayBuffer.empty[String]
    checkCall(_LIB.mxNDArrayLoad(fname, outSize, handles, outNameSize, names))
    // Either every array is named, or none are (names array empty).
    require(outNameSize.value == 0 || outNameSize.value == outSize.value)
    (names.toArray, handles.map(new NDArray(_)).toArray)
  }

  /**
   * Load a file saved with names into a name -> NDArray map.
   * Fails if the file was saved without names.
   */
  def load2Map(fname: String): Map[String, NDArray] = {
    val (keys, vals) = load(fname)
    require(keys.length == vals.length, "Loaded NDArrays have no name")
    (keys zip vals).toMap
  }

  /** Load a file, discarding any saved names. */
  def load2Array(fname: String): Array[NDArray] = {
    load(fname)._2
  }
| |
  /**
   * Save list of NDArray or dict of str->NDArray to binary file.
   *
   * You can also use pickle to do the job if you only work on python.
   * The advantage of load/save is the file is language agnostic.
   * This means the file saved using save can be loaded by other language binding of mxnet.
   * You also get the benefit being able to directly load/save from cloud storage(S3, HDFS)
   *
   * @param fname
   *     The name of the file.Can be S3 or HDFS address (remember built with S3 support).
   *     Example of fname:
   *     - `s3://my-bucket/path/my-s3-ndarray`
   *     - `hdfs://my-bucket/path/my-hdfs-ndarray`
   *     - `/path-to/my-local-ndarray`
   * @param data dict of str->NDArray
   */
  def save(fname: String, data: Map[String, NDArray]): Unit = {
    val keys = data.keys.toArray
    val handles = data.values.map(_.handle).toArray
    save(fname, keys, handles)
  }

  /** Save arrays without names (loadable via [[load2Array]] only). */
  def save(fname: String, data: Traversable[NDArray]): Unit = {
    // keys = null signals an unnamed save to the native layer.
    save(fname, null, data.map(_.handle).toArray)
  }

  // Shared native-save entry point; `keys` may be null for unnamed saves.
  private def save(fname: String, keys: Array[String], handles: Array[NDArrayHandle]): Unit = {
    checkCall(_LIB.mxNDArraySave(fname, handles, keys))
  }
| |
| def deserialize(bytes: Array[Byte]): NDArray = { |
| val handleRef = new NDArrayHandleRef |
| checkCall(_LIB.mxNDArrayLoadFromRawBytes(bytes, handleRef)) |
| new NDArray(handleRef.value) |
| } |
| |
| // TODO: imdecode |
| } |
| |
| /** |
| * NDArray object in mxnet. |
| * NDArray is basic ndarray/Tensor like data structure in mxnet. <br /> |
| * <b> |
| * WARNING: it is your responsibility to clear this object through dispose(). |
| * </b> |
| */ |
| class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, |
| val writable: Boolean = true) extends WarnIfNotDisposed { |
  // record arrays who construct this array instance
  // we use weak reference to prevent gc blocking
  private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]
  // Guard flag for dispose(); exposed via isDisposed for WarnIfNotDisposed.
  private var disposed = false
  def isDisposed: Boolean = disposed
| |
| def serialize(): Array[Byte] = { |
| val buf = ArrayBuffer.empty[Byte] |
| checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf)) |
| buf.toArray |
| } |
| |
  /**
   * Release the native memory. <br />
   * The NDArrays it depends on will NOT be disposed. <br />
   * The object shall never be used after it is disposed.
   */
  def dispose(): Unit = {
    if (!disposed) {
      // Free the native buffer, then drop dependency references so the
      // input arrays become collectible.
      _LIB.mxNDArrayFree(handle)
      dependencies.clear()
      disposed = true
    }
  }
| |
  /**
   * Dispose all NDArrays who help to construct this array. <br />
   * e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
   * @return this array
   */
  def disposeDeps(): NDArray = {
    // Delegates with an empty exclusion list: everything is disposed.
    disposeDepsExcept()
  }
| |
  /**
   * Dispose all NDArrays who help to construct this array, excepts those in the arguments. <br />
   * e.g. (a * b + c).disposeDepsExcept(a, b)
   * will dispose c and a * b.
   * Note that a, b's dependencies will not be disposed either.
   * @param arrs arrays (and, transitively, their own dependencies) to spare
   * @return this array
   */
  def disposeDepsExcept(arrs: NDArray*): NDArray = {
    if (dependencies != null) {
      // Handles that must survive: the given arrays plus everything those
      // arrays themselves depend on.
      val excepts = mutable.HashSet.empty[Long]
      arrs.foreach { arr =>
        excepts += arr.handle
        excepts ++= arr.dependencies.keys
      }
      dependencies.retain { case (addr, weak) =>
        if (excepts.contains(addr)) {
          true
        } else {
          // Dispose only if the weak reference is still alive, then drop
          // the entry from the map.
          weak.get.foreach(_.dispose())
          false
        }
      }
    }
    this
  }
| |
  /**
   * Peform an synchronize copy from the array.
   * @param source The data source we should like to copy from.
   *               Must contain exactly `size` elements (row-major).
   */
  private def syncCopyfrom(source: Array[Float]): Unit = {
    require(source.length == size,
      s"array size (${source.length}) do not match the size of NDArray ($size)")
    checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
  }
| |
| /** |
| * Return a sliced NDArray that shares memory with current one. |
| * NDArray only support continuous slicing on axis 0 |
| * |
| * @param start Starting index of slice. |
| * @param stop Finishing index of slice. |
| * |
| * @return a sliced NDArray that shares memory with current one. |
| */ |
| def slice(start: Int, stop: Int): NDArray = { |
| val sliceHandle = new NDArrayHandleRef |
| checkCall(_LIB.mxNDArraySlice(handle, start, stop, sliceHandle)) |
| new NDArray(handle = sliceHandle.value, writable = this.writable) |
| } |
| |
| def slice(range: (Int, Int)): NDArray = { |
| slice(range._1, range._2) |
| } |
| |
| /** |
| * Return a sliced NDArray at the ith position of axis0 |
| * @param i |
| * @return a sliced NDArray that shares memory with current one. |
| */ |
| def slice(i: Int): NDArray = { |
| slice(i, i + 1) |
| } |
| |
| /** |
| * Return a sub NDArray that shares memory with current one. |
| * the first axis will be rolled up, which causes its shape different from slice(i, i+1) |
| * @param idx index of sub array. |
| */ |
| def at(idx: Int): NDArray = { |
| val handleRef = new NDArrayHandleRef() |
| checkCall(_LIB.mxNDArrayAt(this.handle, idx, handleRef)) |
| new NDArray(handle = handleRef.value, writable = this.writable) |
| } |
| |
  // Get transpose of current NDArray
  def T: NDArray = {
    // Transpose (without an axes argument) is only supported for 2-D here.
    require(this.shape.size == 2, "Only 2D matrix is allowed to be transposed")
    NDArray.genericNDArrayFunctionInvoke("transpose", Seq(this))
  }
| |
| /** |
| * Get data type of current NDArray. |
| * @return class representing type of current ndarray |
| */ |
| def dtype: DType = { |
| val mxDtype = new RefInt |
| checkCall(_LIB.mxNDArrayGetDType(handle, mxDtype)) |
| DType(mxDtype.value) |
| } |
| |
| /** |
| * Return a copied numpy array of current array with specified type. |
| * @param dtype Desired type of result array. |
| * @return A copy of array content. |
| */ |
| def asType(dtype: DType): NDArray = { |
| val res = NDArray.empty(this.shape, ctx = this.context, dtype = dtype) |
| this.copyTo(res) |
| res |
| } |
| |
| /** |
| * Return a reshaped NDArray that shares memory with current one. |
| * @param dims New shape. |
| * |
| * @return a reshaped NDArray that shares memory with current one. |
| */ |
| def reshape(dims: Array[Int]): NDArray = { |
| val reshapeHandle = new NDArrayHandleRef |
| checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle)) |
| new NDArray(handle = reshapeHandle.value, writable = this.writable) |
| } |
| |
| /** |
| * Return a reshaped NDArray that shares memory with current one. |
| * @param dims New shape. |
| * |
| * @return a reshaped NDArray that shares memory with current one. |
| */ |
| def reshape(dims: Shape): NDArray = { |
| reshape(dims.toArray) |
| } |
| |
  /**
   * Block until all pending writes operations on current NDArray are finished.
   * This function will return when all the pending writes to the current
   * NDArray finishes. There can still be pending read going on when the
   * function returns.
   */
  def waitToRead(): Unit = {
    checkCall(_LIB.mxNDArrayWaitToRead(handle))
  }
| |
| /** |
| * Get context of current NDArray. |
| * @return The context of current NDArray. |
| */ |
| def context: Context = { |
| val devTypeId = new RefInt |
| val devId = new RefInt |
| checkCall(_LIB.mxNDArrayGetContext(handle, devTypeId, devId)) |
| new Context(Context.devtype2str(devTypeId.value), devId.value) |
| } |
| |
  /**
   * Set the values of the NDArray
   * @param value Value to set
   * @return Current NDArray
   */
  def set(value: Float): NDArray = {
    require(writable, "trying to assign to a readonly NDArray")
    NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this))
    this
  }

  /**
   * Copy the contents of `other` into this array.
   * @return this array (copyTo writes into `this` and returns its target)
   */
  def set(other: NDArray): NDArray = {
    require(writable, "trying to assign to a readonly NDArray")
    other.copyTo(this)
  }

  /**
   * Fill this array from a flat row-major Float array of exactly `size` elements.
   * @return this array
   */
  def set(other: Array[Float]): NDArray = {
    require(writable, "trying to assign to a readonly NDArray")
    syncCopyfrom(other)
    this
  }
| |
  // Arithmetic, power, comparison and modulo operators. The plain forms
  // allocate a new result array; the op= forms write into `this` (via
  // out=this) and require the array to be writable.

  def +(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other))
  }

  def +(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other))
  }

  def +=(other: NDArray): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to add to a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other), Map("out" -> this))
    this
  }

  def +=(other: Float): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to add to a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other), Map("out" -> this))
    this
  }

  def -(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other))
  }

  def -(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other))
  }

  def -=(other: NDArray): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other), Map("out" -> this))
    this
  }

  def -=(other: Float): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other), Map("out" -> this))
    this
  }

  def *(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other))
  }

  def *(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other))
  }

  // Unary negation implemented as multiplication by -1.
  def unary_-(): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, -1f))
  }

  def *=(other: NDArray): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other), Map("out" -> this))
    this
  }

  def *=(other: Float): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other), Map("out" -> this))
    this
  }

  def /(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other))
  }

  def /(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other))
  }

  def /=(other: NDArray): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to divide from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other), Map("out" -> this))
    this
  }

  def /=(other: Float): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to divide from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other), Map("out" -> this))
    this
  }

  def **(other: NDArray): NDArray = {
    NDArray.power(this, other)
  }

  def **(other: Float): NDArray = {
    NDArray.power(this, other)
  }

  // NOTE(review): unlike every other op= form, **= does not check `writable`
  // before writing into `this` — confirm whether that is intentional.
  def **=(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this))
  }

  def **=(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this))
  }

  def >(other: NDArray): NDArray = {
    NDArray.greater(this, other)
  }

  def >(other: Float): NDArray = {
    NDArray.greater(this, other)
  }

  def >=(other: NDArray): NDArray = {
    NDArray.greaterEqual(this, other)
  }

  def >=(other: Float): NDArray = {
    NDArray.greaterEqual(this, other)
  }

  def <(other: NDArray): NDArray = {
    NDArray.lesser(this, other)
  }

  def <(other: Float): NDArray = {
    NDArray.lesser(this, other)
  }

  def <=(other: NDArray): NDArray = {
    NDArray.lesserEqual(this, other)
  }

  def <=(other: Float): NDArray = {
    NDArray.lesserEqual(this, other)
  }

  def %(other: NDArray): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other))
  }

  def %(other: Float): NDArray = {
    NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other))
  }

  def %=(other: NDArray): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other), Map("out" -> this))
    this
  }

  def %=(other: Float): NDArray = {
    if (!writable) {
      throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
    }
    NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other), Map("out" -> this))
    this
  }
| |
  /**
   * Return a copied flat java array of current array (row-major).
   * @return A copy of array content.
   */
  def toArray: Array[Float] = {
    internal.toFloatArray
  }

  /**
   * Synchronously copy the raw contents to CPU and wrap them for typed access.
   * @return an NDArrayInternal holding the raw bytes and this array's dtype
   */
  def internal: NDArrayInternal = {
    val myType = dtype
    // Byte buffer sized for `size` elements of this dtype; note the native
    // call below is passed the element count `size`, not the byte count.
    val arrLength = DType.numOfBytes(myType) * size
    val arr = Array.ofDim[Byte](arrLength)
    checkCall(_LIB.mxNDArraySyncCopyToCPU(handle, arr, size))
    new NDArrayInternal(arr, myType)
  }
| |
| /** |
| * Return a CPU scalar(float) of current ndarray. |
| * This ndarray must have shape (1,) |
| * |
| * @return The scalar representation of the ndarray. |
| */ |
| def toScalar: Float = { |
| require(shape == Shape(1), "The current array is not a scalar") |
| this.toArray(0) |
| } |
| |
| /** |
| * Copy the content of current array to other. |
| * |
| * @param other Target NDArray or context we want to copy data to. |
| * @return The copy target NDArray |
| */ |
| def copyTo(other: NDArray): NDArray = { |
| if (other.handle == this.handle) { |
| NDArray.logger.warn("copy an array to itself, is it intended ?") |
| } else { |
| NDArray.genericNDArrayFunctionInvoke("_copyto", Seq(this), Map("out" -> other)) |
| } |
| other |
| } |
| |
| /** |
| * Copy the content of current array to a new NDArray in the context. |
| * |
| * @param ctx Target context we want to copy data to. |
| * @return The copy target NDArray |
| */ |
| def copyTo(ctx: Context): NDArray = { |
| val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true)) |
| copyTo(ret) |
| } |
| |
| /** |
| * Clone the current array |
| * @return the copied NDArray in the same context |
| */ |
| def copy(): NDArray = copyTo(this.context) |
| |
| /** |
| * Get shape of current NDArray. |
| * @return an array representing shape of current ndarray |
| */ |
| def shape: Shape = { |
| val ndim = new MXUintRef |
| val data = ArrayBuffer[Int]() |
| checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data)) |
| require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}") |
| Shape(data) |
| } |
| |
  /** Total number of elements in the current NDArray (product of all shape dimensions). */
  def size: Int = shape.product
| |
| /** |
| * Return an `NDArray` that lives in the target context. If the array |
| * is already in that context, `self` is returned. Otherwise, a copy is made. |
| * @param context The target context we want the return value to live in. |
| * @return A copy or `self` as an `NDArray` that lives in the target context. |
| */ |
| def asInContext(context: Context): NDArray = { |
| if (this.context == context) this else this.copyTo(context) |
| } |
| |
| override def equals(o: Any): Boolean = o match { |
| case that: NDArray => |
| that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray) |
| case _ => false |
| } |
| |
| override def hashCode: Int = { |
| // TODO: naive implementation |
| shape.hashCode + toArray.hashCode |
| } |
| } |
| |
private[mxnet] object NDArrayConversions {
  // Implicit lifts so a numeric literal on the LEFT of an operator
  // (e.g. `2f - nd`) becomes an NDArrayConversions wrapper carrying the
  // scalar as a Float. Int and Double are narrowed/converted to Float.
  implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
  implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat)
  implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x)
}
| |
/**
 * Wrapper allowing a scalar on the left-hand side of NDArray operators,
 * e.g. `2f / nd`. Each operator forwards to the equivalent NDArray call
 * with the operands arranged so the scalar semantics are preserved.
 */
private[mxnet] class NDArrayConversions(val value: Float) {
  /** `value + other`; addition commutes, so delegate to `other + value`. */
  def +(other: NDArray): NDArray = other + value
  def +(other: NDArrayFuncReturn): NDArray = other.head + value

  /** `value - other`, via the reversed-operand native scalar op. */
  def -(other: NDArray): NDArray =
    NDArray.genericNDArrayFunctionInvoke("_rminus_scalar", Seq(other, value))
  def -(other: NDArrayFuncReturn): NDArray =
    NDArray.genericNDArrayFunctionInvoke("_rminus_scalar", Seq(other.head, value))

  /** `value * other`; multiplication commutes, so delegate to `other * value`. */
  def *(other: NDArray): NDArray = other * value
  def *(other: NDArrayFuncReturn): NDArray = other.head * value

  /** `value / other`, via the reversed-operand native scalar op. */
  def /(other: NDArray): NDArray =
    NDArray.genericNDArrayFunctionInvoke("_rdiv_scalar", Seq(other, value))
  def /(other: NDArrayFuncReturn): NDArray =
    NDArray.genericNDArrayFunctionInvoke("_rdiv_scalar", Seq(other.head, value))

  /** `value ** other`: the scalar raised to each element's power. */
  def **(other: NDArray): NDArray = NDArray.power(value, other)
  def **(other: NDArrayFuncReturn): NDArray = NDArray.power(value, other.head)

  // Comparisons swap operands, so each operator maps to its mirror image:
  // `value > other` is exactly `other < value`, and so on.
  def >(other: NDArray): NDArray = NDArray.lesser(other, value)
  def >(other: NDArrayFuncReturn): NDArray = NDArray.lesser(other.head, value)

  def >=(other: NDArray): NDArray = NDArray.lesserEqual(other, value)
  def >=(other: NDArrayFuncReturn): NDArray = NDArray.lesserEqual(other.head, value)

  def <(other: NDArray): NDArray = NDArray.greater(other, value)
  def <(other: NDArrayFuncReturn): NDArray = NDArray.greater(other.head, value)

  def <=(other: NDArray): NDArray = NDArray.greaterEqual(other, value)
  def <=(other: NDArrayFuncReturn): NDArray = NDArray.greaterEqual(other.head, value)
}
| |
| private case class NDArrayFunction(handle: NDArrayHandle, arguments: List[String]) |
| |
/**
 * Wrapper over the array of NDArrays returned by a native function call.
 * For the common single-output case it behaves like its first element:
 * the bulk of NDArray's API is mirrored here by delegating to `head`.
 */
private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) {
  /** First output, i.e. `apply(0)`; null if there are no outputs. */
  def head: NDArray = apply(0)
  /** Sole output; requires exactly one result. */
  def get: NDArray = {
    require(arr.length == 1, s"return array length = ${arr.length}")
    head
  }
  /** i-th output, or null when `arr` is null or `i` is out of range. */
  def apply(i: Int): NDArray = {
    if (arr == null || arr.length <= i) {
      null
    } else {
      arr(i)
    }
  }

  // copy methods from NDArray — every call below simply forwards to `head`,
  // so a single-output return value can be used as if it were the NDArray.
  def isDisposed: Boolean = head.isDisposed
  def serialize(): Array[Byte] = head.serialize()
  def dispose(): Unit = head.dispose()
  def disposeDeps(): NDArray = head.disposeDeps()
  def disposeDepsExcept(arrs: NDArray*): NDArray = head.disposeDepsExcept(arrs: _*)
  def slice(start: Int, stop: Int): NDArray = head.slice(start, stop)
  def slice(range: (Int, Int)): NDArray = head.slice(range)
  def slice(i: Int): NDArray = head.slice(i)
  def reshape(dims: Array[Int]): NDArray = head.reshape(dims)
  def waitToRead(): Unit = head.waitToRead()
  def context: Context = head.context
  def set(value: Float): NDArray = head.set(value)
  def set(other: NDArray): NDArray = head.set(other)
  def set(other: Array[Float]): NDArray = head.set(other)
  def +(other: NDArray): NDArray = head + other
  def +(other: Float): NDArray = head + other
  def +=(other: NDArray): NDArray = head += other
  def +=(other: Float): NDArray = head += other
  def -(other: NDArray): NDArray = head - other
  def -(other: Float): NDArray = head - other
  def -=(other: NDArray): NDArray = head -= other
  def -=(other: Float): NDArray = head -= other
  def *(other: NDArray): NDArray = head * other
  def *(other: Float): NDArray = head * other
  def unary_-(): NDArray = -head
  def *=(other: NDArray): NDArray = head *= other
  def *=(other: Float): NDArray = head *= other
  def /(other: NDArray): NDArray = head / other
  def **(other: NDArray): NDArray = head ** other
  def **(other: Float): NDArray = head ** other
  def >(other: NDArray): NDArray = head > other
  def >(other: Float): NDArray = head > other
  def >=(other: NDArray): NDArray = head >= other
  def >=(other: Float): NDArray = head >= other
  def <(other: NDArray): NDArray = head < other
  def <(other: Float): NDArray = head < other
  def <=(other: NDArray): NDArray = head <= other
  def <=(other: Float): NDArray = head <= other
  def toArray: Array[Float] = head.toArray
  def toScalar: Float = head.toScalar
  def copyTo(other: NDArray): NDArray = head.copyTo(other)
  def copyTo(ctx: Context): NDArray = head.copyTo(ctx)
  def copy(): NDArray = head.copy()
  def shape: Shape = head.shape
  def size: Int = head.size
  def asInContext(context: Context): NDArray = head.asInContext(context)
}
| |
/**
 * Typed view over a little-endian byte copy of an NDArray's data.
 * Provides conversions to native JVM numeric arrays; float16 is not
 * convertible and is rejected by each conversion.
 */
private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private val dtype: DType) {
  // Size in bytes of one element of `dtype`.
  private val unitSize = DType.numOfBytes(dtype)
  require(internal.length > 0 && internal.length % unitSize == 0,
    s"$dtype size $unitSize cannot divide byte array size ${internal.length}")
  // Number of elements encoded in the byte array.
  // Elements are read via absolute ByteBuffer gets instead of materializing
  // one small byte array per element, avoiding O(n) short-lived allocations.
  private val numUnits = internal.length / unitSize

  /** The raw little-endian bytes backing this view (not a defensive copy). */
  def getRaw: Array[Byte] = internal

  /**
   * Decode the content as doubles.
   * @throws IllegalArgumentException if dtype is float16
   */
  def toDoubleArray: Array[Double] = {
    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
    val bb = wrapBytes(internal)
    dtype match {
      case DType.Float32 => Array.tabulate(numUnits)(i => bb.getFloat(i * unitSize).toDouble)
      case DType.Float64 => Array.tabulate(numUnits)(i => bb.getDouble(i * unitSize))
      case DType.Int32 => Array.tabulate(numUnits)(i => bb.getInt(i * unitSize).toDouble)
      // mask with 0xFF: JVM bytes are signed, but UInt8 values 128..255
      // must decode as positive numbers
      case DType.UInt8 => internal.map(b => (b & 0xFF).toDouble)
    }
  }

  /**
   * Decode the content as floats.
   * @throws IllegalArgumentException if dtype is float16
   */
  def toFloatArray: Array[Float] = {
    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
    val bb = wrapBytes(internal)
    dtype match {
      case DType.Float32 => Array.tabulate(numUnits)(i => bb.getFloat(i * unitSize))
      case DType.Float64 => Array.tabulate(numUnits)(i => bb.getDouble(i * unitSize).toFloat)
      case DType.Int32 => Array.tabulate(numUnits)(i => bb.getInt(i * unitSize).toFloat)
      // mask with 0xFF to keep UInt8 values 128..255 positive
      case DType.UInt8 => internal.map(b => (b & 0xFF).toFloat)
    }
  }

  /**
   * Decode the content as ints.
   * @throws IllegalArgumentException if dtype is float16
   */
  def toIntArray: Array[Int] = {
    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
    val bb = wrapBytes(internal)
    dtype match {
      case DType.Float32 => Array.tabulate(numUnits)(i => bb.getFloat(i * unitSize).toInt)
      case DType.Float64 => Array.tabulate(numUnits)(i => bb.getDouble(i * unitSize).toInt)
      case DType.Int32 => Array.tabulate(numUnits)(i => bb.getInt(i * unitSize))
      // mask with 0xFF to keep UInt8 values 128..255 positive
      case DType.UInt8 => internal.map(_ & 0xFF)
    }
  }

  /**
   * Decode the content as (narrowed) bytes; UInt8 data is returned as a copy.
   * @throws IllegalArgumentException if dtype is float16
   */
  def toByteArray: Array[Byte] = {
    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
    // NOTE: a former `DType.Float16 | DType.Float32` alternative was dead code —
    // the require above already rejects Float16.
    val bb = wrapBytes(internal)
    dtype match {
      case DType.Float32 => Array.tabulate(numUnits)(i => bb.getFloat(i * unitSize).toByte)
      case DType.Float64 => Array.tabulate(numUnits)(i => bb.getDouble(i * unitSize).toByte)
      case DType.Int32 => Array.tabulate(numUnits)(i => bb.getInt(i * unitSize).toByte)
      case DType.UInt8 => internal.clone()
    }
  }

  // Wrap bytes in a little-endian ByteBuffer (MXNet stores data little-endian).
  private def wrapBytes(bytes: Array[Byte]): ByteBuffer = {
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.LITTLE_ENDIAN)
    bb
  }
}