blob: a943e314055f96ad3a4297e058a5ecc7fdd52af4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet
import ml.dmlc.mxnet.Base._
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
/**
 * Raw JNI entry points into the MXNet native library.
 *
 * Every member is a `@native` declaration with no Scala body; the implementation is
 * supplied by the native library. JNI binds each method to a native symbol derived
 * from the package, class and method name. The method names below must therefore
 * stay in sync with the native side, including the misspelled
 * `mxKVStoreSendCommmandToServers`. Do not rename a method here on its own.
 *
 * All bindings except `mxGetLastError` return an `Int`. Presumably this is a status
 * code, with the error text available from `mxGetLastError()`; confirm against the
 * C API. Arguments typed as `Ref*` holders or as mutable buffers (`ListBuffer`,
 * `ArrayBuffer`) appear to be out-parameters filled in by the native code.
 */
private[mxnet] class LibInfo {
  // ---- Library setup and error reporting ----
  @native def nativeLibInit(): Int
  @native def mxGetLastError(): String

  // ---- Operators ----
  @native def mxListAllOpNames(names: ListBuffer[String]): Int
  @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int

  // ---- NDArray ----
  @native def mxImperativeInvoke(
      creator: FunctionHandle,
      inputs: Array[NDArrayHandle],
      outputsGiven: Array[NDArrayHandle],
      outputs: ArrayBuffer[NDArrayHandle],
      numParams: Int,
      paramKeys: Array[String],
      paramVals: Array[String]): Int
  @native def mxNDArrayFree(handle: NDArrayHandle): Int
  @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
  @native def mxNDArrayCreateEx(
      shape: Array[Int],
      ndim: Int,
      devType: Int,
      devId: Int,
      delayAlloc: Int,
      dtype: Int,
      out: NDArrayHandleRef): Int
  @native def mxNDArrayWaitAll(): Int
  @native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
  @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
  @native def mxFuncDescribe(
      handle: FunctionHandle,
      nUsedVars: MXUintRef,
      nScalars: MXUintRef,
      nMutateVars: MXUintRef,
      typeMask: RefInt): Int
  @native def mxFuncGetInfo(
      handle: FunctionHandle,
      name: RefString,
      desc: RefString,
      numArgs: MXUintRef,
      argNames: ListBuffer[String],
      argTypes: ListBuffer[String],
      argDescs: ListBuffer[String]): Int
  @native def mxFuncInvoke(
      function: FunctionHandle,
      useVars: Array[NDArrayHandle],
      scalarArgs: Array[MXFloat],
      mutateVars: Array[NDArrayHandle]): Int
  @native def mxFuncInvokeEx(
      function: FunctionHandle,
      useVars: Array[NDArrayHandle],
      scalarArgs: Array[MXFloat],
      mutateVars: Array[NDArrayHandle],
      numParams: Int,
      paramKeys: Array[Array[Byte]],
      paramVals: Array[Array[Byte]]): Int
  @native def mxNDArrayGetShape(
      handle: NDArrayHandle,
      ndim: MXUintRef,
      data: ArrayBuffer[Int]): Int
  @native def mxNDArraySyncCopyToCPU(
      handle: NDArrayHandle,
      data: Array[Byte],
      size: Int): Int
  @native def mxNDArraySlice(
      handle: NDArrayHandle,
      start: MXUint,
      end: MXUint,
      sliceHandle: NDArrayHandleRef): Int
  @native def mxNDArrayAt(
      handle: NDArrayHandle,
      idx: MXUint,
      out: NDArrayHandleRef): Int
  @native def mxNDArrayReshape(
      handle: NDArrayHandle,
      nDim: Int,
      dims: Array[Int],
      reshapeHandle: NDArrayHandleRef): Int
  @native def mxNDArraySyncCopyFromCPU(
      handle: NDArrayHandle,
      source: Array[MXFloat],
      size: Int): Int
  @native def mxNDArrayLoad(
      fname: String,
      outSize: MXUintRef,
      handles: ArrayBuffer[NDArrayHandle],
      outNameSize: MXUintRef,
      names: ArrayBuffer[String]): Int
  @native def mxNDArraySave(
      fname: String,
      handles: Array[NDArrayHandle],
      keys: Array[String]): Int
  @native def mxNDArrayGetContext(
      handle: NDArrayHandle,
      devTypeId: RefInt,
      devId: RefInt): Int
  @native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: ArrayBuffer[Byte]): Int
  @native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: NDArrayHandleRef): Int
  @native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int

  // ---- KVStore server ----
  @native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int
  @native def mxKVStoreRunServer(
      handle: KVStoreHandle,
      controller: KVServerControllerCallback): Int
  @native def mxKVStoreGetNumDeadNode(
      handle: KVStoreHandle,
      nodeId: Int,
      number: RefInt): Int

  // ---- KVStore ----
  @native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int
  @native def mxKVStoreInit(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[Int],
      values: Array[NDArrayHandle]): Int
  @native def mxKVStoreInitEx(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[String],
      values: Array[NDArrayHandle]): Int
  @native def mxKVStorePush(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[Int],
      values: Array[NDArrayHandle],
      priority: Int): Int
  @native def mxKVStorePushEx(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[String],
      values: Array[NDArrayHandle],
      priority: Int): Int
  @native def mxKVStorePull(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[Int],
      outs: Array[NDArrayHandle],
      priority: Int): Int
  @native def mxKVStorePullEx(
      handle: KVStoreHandle,
      len: MXUint,
      keys: Array[String],
      outs: Array[NDArrayHandle],
      priority: Int): Int
  @native def mxKVStoreSetUpdater(
      handle: KVStoreHandle,
      updaterFunc: MXKVStoreUpdater): Int
  @native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
  @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
  // The "Commmand" spelling is part of the JNI-bound name; fix it only together
  // with the native implementation.
  @native def mxKVStoreSendCommmandToServers(
      handle: KVStoreHandle,
      head: Int,
      body: String): Int
  @native def mxKVStoreBarrier(handle: KVStoreHandle): Int
  @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
  // NOTE(review): the out-parameter is named `size` but presumably receives the rank;
  // the name is kept so any named-argument call sites keep compiling.
  @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
  @native def mxKVStoreSetBarrierBeforeExit(handle: KVStoreHandle, doBarrier: Int): Int
  @native def mxKVStoreFree(handle: KVStoreHandle): Int

  // ---- DataIter ----
  @native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int
  @native def mxDataIterCreateIter(
      handle: DataIterCreator,
      keys: Array[String],
      vals: Array[String],
      out: DataIterHandleRef): Int
  @native def mxDataIterGetIterInfo(
      creator: DataIterCreator,
      name: RefString,
      description: RefString,
      argNames: ListBuffer[String],
      argTypeInfos: ListBuffer[String],
      argDescriptions: ListBuffer[String]): Int
  @native def mxDataIterFree(handle: DataIterHandle): Int
  @native def mxDataIterBeforeFirst(handle: DataIterHandle): Int
  @native def mxDataIterNext(handle: DataIterHandle, out: RefInt): Int
  @native def mxDataIterGetLabel(handle: DataIterHandle, out: NDArrayHandleRef): Int
  @native def mxDataIterGetData(handle: DataIterHandle, out: NDArrayHandleRef): Int
  @native def mxDataIterGetIndex(
      handle: DataIterHandle,
      outIndex: ListBuffer[Long],
      outSize: RefLong): Int
  @native def mxDataIterGetPadNum(handle: DataIterHandle, out: MXUintRef): Int

  // ---- Executor ----
  @native def mxExecutorOutputs(
      handle: ExecutorHandle,
      outputs: ArrayBuffer[NDArrayHandle]): Int
  @native def mxExecutorFree(handle: ExecutorHandle): Int
  @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
  @native def mxExecutorBackward(handle: ExecutorHandle, grads: Array[NDArrayHandle]): Int
  @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
  @native def mxExecutorSetMonitorCallback(
      handle: ExecutorHandle,
      callback: MXMonitorCallback): Int

  // ---- Symbol ----
  @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int
  @native def mxSymbolGetAtomicSymbolInfo(
      handle: SymbolHandle,
      name: RefString,
      desc: RefString,
      numArgs: MXUintRef,
      argNames: ListBuffer[String],
      argTypes: ListBuffer[String],
      argDescs: ListBuffer[String],
      keyVarNumArgs: RefString): Int
  @native def mxSymbolCreateAtomicSymbol(
      handle: SymbolHandle,
      paramKeys: Array[String],
      paramVals: Array[String],
      symHandleRef: SymbolHandleRef): Int
  @native def mxSymbolSetAttr(handle: SymbolHandle, key: String, value: String): Int
  @native def mxSymbolListAttrShallow(
      handle: SymbolHandle,
      outSize: MXUintRef,
      out: ArrayBuffer[String]): Int
  @native def mxSymbolListAttr(
      handle: SymbolHandle,
      outSize: MXUintRef,
      out: ArrayBuffer[String]): Int
  @native def mxSymbolCompose(
      handle: SymbolHandle,
      name: String,
      keys: Array[String],
      args: Array[SymbolHandle]): Int
  @native def mxSymbolCreateVariable(name: String, out: SymbolHandleRef): Int
  @native def mxSymbolGetAttr(
      handle: SymbolHandle,
      key: String,
      ret: RefString,
      success: RefInt): Int
  @native def mxSymbolListArguments(handle: SymbolHandle, arguments: ArrayBuffer[String]): Int
  @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int
  @native def mxSymbolListAuxiliaryStates(
      handle: SymbolHandle,
      arguments: ArrayBuffer[String]): Int
  @native def mxSymbolListOutputs(handle: SymbolHandle, outputs: ArrayBuffer[String]): Int
  @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int
  @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int
  @native def mxSymbolGetInternals(handle: SymbolHandle, out: SymbolHandleRef): Int
  @native def mxSymbolInferType(
      handle: SymbolHandle,
      keys: Array[String],
      sdata: Array[Int],
      argTypeData: ListBuffer[Int],
      outTypeData: ListBuffer[Int],
      auxTypeData: ListBuffer[Int],
      complete: RefInt): Int
  @native def mxSymbolInferShape(
      handle: SymbolHandle,
      numArgs: MXUint,
      keys: Array[String],
      argIndPtr: Array[MXUint],
      argShapeData: Array[MXUint],
      inShapeData: ListBuffer[Array[Int]],
      outShapeData: ListBuffer[Array[Int]],
      auxShapeData: ListBuffer[Array[Int]],
      complete: RefInt): Int
  @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
  @native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
  @native def mxSymbolCreateFromJSON(json: String, handle: SymbolHandleRef): Int

  // The two bind entry points exceed the style checker's parameter limit,
  // so that rule is suspended just for them.
  // scalastyle:off parameterNum
  @native def mxExecutorBindX(
      handle: SymbolHandle,
      deviceTypeId: Int,
      deviceID: Int,
      numCtx: Int,
      ctxMapKeys: Array[String],
      ctxMapDevTypes: Array[Int],
      ctxMapDevIDs: Array[Int],
      numArgs: Int,
      argsHandle: Array[NDArrayHandle],
      argsGradHandle: Array[NDArrayHandle],
      reqsArray: Array[Int],
      auxArgsHandle: Array[NDArrayHandle],
      out: ExecutorHandleRef): Int
  @native def mxExecutorBindEX(
      handle: SymbolHandle,
      deviceTypeId: Int,
      deviceID: Int,
      numCtx: Int,
      ctxMapKeys: Array[String],
      ctxMapDevTypes: Array[Int],
      ctxMapDevIDs: Array[Int],
      numArgs: Int,
      argsHandle: Array[NDArrayHandle],
      argsGradHandle: Array[NDArrayHandle],
      reqsArray: Array[Int],
      auxArgsHandle: Array[NDArrayHandle],
      sharedExec: ExecutorHandle,
      out: ExecutorHandleRef): Int
  // scalastyle:on parameterNum

  @native def mxSymbolSaveToFile(handle: SymbolHandle, fname: String): Int
  @native def mxSymbolCreateFromFile(fname: String, handle: SymbolHandleRef): Int
  @native def mxSymbolFree(handle: SymbolHandle): Int

  // ---- Random ----
  @native def mxRandomSeed(seed: Int): Int

  // ---- Shutdown ----
  @native def mxNotifyShutdown(): Int

  // ---- RecordIO ----
  @native def mxRecordIOWriterCreate(uri: String, out: RecordIOHandleRef): Int
  @native def mxRecordIOReaderCreate(uri: String, out: RecordIOHandleRef): Int
  @native def mxRecordIOWriterFree(handle: RecordIOHandle): Int
  @native def mxRecordIOReaderFree(handle: RecordIOHandle): Int
  @native def mxRecordIOWriterWriteRecord(
      handle: RecordIOHandle,
      buf: String,
      size: Int): Int
  @native def mxRecordIOReaderReadRecord(handle: RecordIOHandle, buf: RefString): Int
  @native def mxRecordIOWriterTell(handle: RecordIOHandle, pos: RefInt): Int
  @native def mxRecordIOReaderSeek(handle: RecordIOHandle, pos: Int): Int

  // ---- Rtc ----
  @native def mxRtcCreate(
      name: String,
      inputNames: Array[String],
      outputNames: Array[String],
      inputs: Array[NDArrayHandle],
      outputs: Array[NDArrayHandle],
      kernel: String,
      out: RtcHandleRef): Int
  @native def mxRtcPush(
      handle: RtcHandle,
      inputs: Array[NDArrayHandle],
      outputs: Array[NDArrayHandle],
      gridDimX: Int,
      gridDimY: Int,
      gridDimZ: Int,
      blockDimX: Int,
      blockDimY: Int,
      blockDimZ: Int): Int
  @native def mxRtcFree(handle: RtcHandle): Int

  // ---- Custom operators ----
  @native def mxCustomOpRegister(regName: String, opProp: CustomOpProp): Int

  // ---- Profiler ----
  @native def mxSetProfilerConfig(mode: Int, fileName: String): Int
  @native def mxSetProfilerState(state: Int): Int
  @native def mxDumpProfile(): Int
}