Java2.0 proto (#20461)
* ADD: first commit
* ADD: load local libraries
* UPDATE: use header files of MXNet 2.0
* ADD: load binaries from environment variable, java properties or jar files.
* ADD: add symbol loading and closing
add module integration
* ADD: [WIP] Component MxNDArray
* ADD: [WIP] Component MxNDArray
* ADD: Component MxNDArray. Pass static compilation check
* ADD: Component CachedOp
* REMOVE: module api which is no use
* FIX: dependency missing
* ADD: [WIP] add test cases for NdArray and CachedOp
* ADD: [WIP] add test cases for NdArray and CachedOp
* ADD: implement of the forward function for MxSymbolblock
* ADD: implement of the forward function for MxSymbolblock
* ADD: Sample model downloader for MLP
* ADD: doc
* ADD: Front-end module for inference, class MxModel, Predictor and so on.
* FIX: Mxnet crash when process exits.
* FIX: remove and initialize 3rdparty directory
* FIX: revert version of submodules: dlpack, dmlc-core, googletest, ps-lite
* Revert "FIX: remove and initialize 3rdparty directory"
This reverts commit d097675e
* FIX: redownload files in 3rdparty
* FIX: reset --hard the version of a few submodules
* FIX: reset --hard the version of a few submodules
* FIX: reset --hard the version of a few submodules
* PERF: [WIP] optimize code structure and memory management and
* ADD: add copyright; remove Mx prefix for some classes
* ADD: add copyright
* FIX: group name, path to find header files
* UPDATE: README.md
* ADD: copyright
* ADD: copyright
* ADD: package-info
ADD: ci
ADD: ci
ADD: make modification to trigger ci
ADD: ci
ADD: ci
ADD: ci
ADD: ci
ADD: gradlew
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
FIX: build failure
* FIX: ci config file
* UPDATE: remove ParameterStore and some scripts
UPDATE: remove Initializer.java
* UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
FIX: issues in Resource close methods
FIX: issues in Resource close methods
FIX: issues in Resource close methods
UPDATE: remove scripts for dev
* FIX: loading on Linux platform
* # This is a combination of 18 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
* # This is a combination of 27 commits.
parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
# This is a combination of 21 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
# This is the commit message #19:
UPDATE: ci for java-package
# This is the commit message #20:
UPDATE: ci for java-package
# This is the commit message #21:
UPDATE: ci for java-package
# This is the commit message #22:
UPDATE: ci for java-package
# This is the commit message #23:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #24:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #25:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #26:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #27:
UPDATE: jenkins ci scripts for java-package
* MERGE: resolve conflicts
* MERGE: resolve conflicts
* # This is a combination of 35 commits.
parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
# This is a combination of 21 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
# This is the commit message #19:
UPDATE: ci for java-package
# This is the commit message #20:
UPDATE: ci for java-package
# This is the commit message #21:
UPDATE: ci for java-package
# This is the commit message #22:
UPDATE: ci for java-package
# This is the commit message #23:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #24:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #25:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #26:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #27:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #28:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #30:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #31:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #32:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #33:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #34:
FIX: issues in Resource close methods
# This is the commit message #35:
FIX: issues in Resource close methods
* parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
FIX: loading on Linux platform
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
FIX: issues in Resource close methods
FIX: issues in Resource close methods
FIX: issues in Resource close methods
DOC: add doc
STYLE: change code style for pmd check
FIX: avoid the register for a signal handler twice
STYLE: pass pmd check
UPDATE: remove unused scripts
* FIX: solve problems before merge
* UPDATE: remove useless files
* FIX: licence to apache
* FIX: sanity check
* FIX: sanity check
* FIX: sanity check
* FIX: remove unused files
* FIX: remove unused files
* DOC: add document
* FIX: doesn't work on osx
* FIX: clang static check
* FIX: sanity
* FIX: skip signal handler registration when building java package
* FIX: remove DataType String
* FIX: add license
Co-authored-by: cspchen <cspchen@amazon.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b901f41..c6e0a4e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,6 +91,7 @@
option(BUILD_EXTENSION_PATH "Path to extension to build" "")
option(BUILD_CYTHON_MODULES "Build cython modules." OFF)
option(LOG_FATAL_THROW "Log exceptions but do not abort" ON)
+option(BUILD_JAVA_NATIVE "Skip signal handler registration for Java Binding" OFF)
cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF)
cmake_dependent_option(USE_CCACHE "Attempt using CCache to wrap the compilation" ON "UNIX" OFF)
cmake_dependent_option(MXNET_FORCE_SHARED_CRT "Build with dynamic CRT on Windows (/MD)" ON "MXNET_BUILD_SHARED_LIBS" OFF)
@@ -970,6 +971,10 @@
target_compile_definitions(mxnet PUBLIC MXNET_USE_CPP_PACKAGE=1)
endif()
+if(BUILD_JAVA_NATIVE)
+ add_compile_definitions(SKIP_SIGNAL_HANDLER_REGISTRATION=1)
+endif()
+
if(NOT CMAKE_BUILD_TYPE STREQUAL "Distribution")
# Staticbuild applies linker version script to hide private symbols, breaking unit tests
add_subdirectory(tests)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 70934f6..1e75ab2 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -310,6 +310,23 @@
build_ubuntu_cpu_openblas
}
+build_ubuntu_cpu_and_test_java() {
+ build_ubuntu_cpu_openblas_java
+ java_package_integration_test
+}
+
+java_package_integration_test() {
+ set -ex
+ # build the java project (make sure you are using java 11)
+ cd /work/mxnet/java-package
+ ./gradlew build -x javadoc
+ # package the native library built above into jars
+ ./gradlew :native:buildLocalLibraryJarDefault
+ ./gradlew :native:mkl-linuxJar
+ # run the integration test
+ ./gradlew :integration:run
+}
+
build_ubuntu_cpu_openblas() {
set -ex
cd /work/build
@@ -327,6 +344,24 @@
ninja
}
+build_ubuntu_cpu_openblas_java() {
+ set -ex
+ cd /work/build
+ CC=gcc-7 CXX=g++-7 CXXFLAGS="-Wno-error=strict-overflow" cmake \
+ -DBUILD_JAVA_NATIVE=ON \
+ -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+ -DUSE_CUDA=OFF \
+ -DUSE_ONEDNN=OFF \
+ -DUSE_BLAS=Open \
+ -DUSE_TVM_OP=ON \
+ -DUSE_DIST_KVSTORE=ON \
+ -DENABLE_TESTCOVERAGE=ON \
+ -DBUILD_CYTHON_MODULES=ON \
+ -DBUILD_EXTENSION_PATH=/work/mxnet/example/extensions/lib_external_ops \
+ -G Ninja /work/mxnet
+ ninja
+}
+
build_ubuntu_cpu_mkl() {
set -ex
cd /work/build
diff --git a/java-package/Develop.md b/java-package/Develop.md
new file mode 100644
index 0000000..8bc0ffe
--- /dev/null
+++ b/java-package/Develop.md
@@ -0,0 +1,109 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Development Tips
+
+## Set up the Project
+### Step 1. Obtain MXNet Library
+The first step is to obtain the MXNet library. We recommend building it from source; alternatively, you can download a
+pre-built library (see below).
+#### Build from source
+Refer to [Build From Source](https://mxnet.apache.org/get_started/build_from_source#building-mxnet)
+For MacOS users:
+- Prepare
+```shell
+# Install OS X Developer Tools
+$ xcode-select --install
+
+# Install Homebrew
+$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+
+# Install dependencies
+$ brew install cmake ninja ccache opencv
+```
+- Clone 3rd party projects
+```shell
+# Clone MXNet's third-party dependencies (required)
+$ git submodule update --init --recursive
+```
+- Build MXNet
+```shell
+# select and copy cmake configure files for macos
+$ cp config/darwin.cmake config.cmake
+
+# create a build directory for the project
+$ mkdir build; cd build
+
+# cmake
+$ cmake ..
+$ cmake --build .
+```
+Libraries will be generated under the directory _build/_.
+
+For Linux users:
+Docker might help you build libraries on different platforms. You can get help from [README for CI](../ci/README.md).
+For example, you can build MXNet on Ubuntu with the following command:
+```shell
+$ python3 ci/build.py -p ubuntu_cpu
+```
+#### Download Pre-built Library
+You can also find the MXNet library inside an installed MXNet package, such as the Python module. However, MXNet 2.0
+has not been released yet, which is why we recommend building it from source.
+```shell
+# download the MXNet Python module (note: MXNet 2.0 has not been released yet)
+$ pip3 install mxnet==1.7.0.post2
+# find the location of the installed module
+$ python
+Python 3.6.8 |Anaconda, Inc.| (default, Dec 29 2018, 19:04:46)
+>>> import mxnet
+>>> mxnet
+<module 'mxnet' from '/Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/__init__.py'>
+>>> quit()
+# you can locate the module under /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/
+$ ls /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/ | grep libmxnet
+libmxnet.dylib
+```
+The compiled library is the file named _libmxnet.*_. On macOS the file has the suffix _.dylib_; on Linux it has the
+suffix _.so_; on Windows it has the suffix _.dll_.
+
+### Step 2. Build MXNet Native Lib for Java
+The project uses Gradle to manage dependencies and to build. The MXNet library has to be packaged into a jar file so
+that it can be loaded into the JVM.
+```shell
+$ cd java-package
+# Build the project
+$ ./gradlew build
+# Create gradle tasks to package mxnet library into jar
+# The task name has the form {flavor}-{platform}Jar
+# MacOS -> mkl-osxJar
+# Linux -> mkl-linuxJar
+# Windows -> mkl-winJar
+$ ./gradlew :native:buildLocalLibraryJarDefault
+# Package the native library for macOS
+$ ./gradlew :native:mkl-osxJar
+# Check the lib for osx
+$ ls native/build/libs | grep osx
+native-2.0.0-SNAPSHOT-osx-x86_64.jar
+```
+The jar file _native-2.0.0-SNAPSHOT-osx-x86_64.jar_ is the output lib file.
+
+### Step 3. Run Integration Test
+When you run the integration test task, the built MXNet native library jar is added to the classpath automatically.
+```shell
+$ ./gradlew :integration:run
+
+```
\ No newline at end of file
diff --git a/java-package/README.md b/java-package/README.md
new file mode 100644
index 0000000..eb5b901
--- /dev/null
+++ b/java-package/README.md
@@ -0,0 +1,58 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Java Package for MXNet 2.0
+
+## Requirements
+
+## Install
+
+## Scripts
+- Customize the MXNet library path by setting an environment variable:
+```bash
+export MXNET_LIBRARY_PATH=/path/to/anaconda3/lib/python3.8/site-packages/mxnet/
+```
+
+
+## Tests
+Run the integration test, which performs a rough inference run with an MXNet model:
+```bash
+./gradlew :integration:run
+```
+
+## Example
+
+```java
+try (MxResource base = BaseMxResource.getSystemMxResource())
+ {
+ Model model = Model.loadModel(Item.MLP);
+// Alternatively, load a local model: Model.loadModel("test", Paths.get("/path/to/mlp/"));
+ Predictor<NDList, NDList> predictor = model.newPredictor();
+ NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+ NDList inputs = new NDList();
+ inputs.add(input);
+ NDList result = predictor.predict(inputs);
+ NDArray expected = NDArray.create(
+ base,
+ new float[]{4.93476f, -0.76084447f, 0.37713608f, 0.6605506f, -1.3485785f, -0.8736369f
+ , 0.018061712f, -1.3274033f, 1.0609543f, 0.24042489f}, new Shape(1, 10));
+ Assertions.assertAlmostEquals(result.get(0), expected);
+
+ } catch (IOException e) {
+ logger.error(e.getMessage(), e);
+ }
+```
\ No newline at end of file
diff --git a/java-package/build.gradle b/java-package/build.gradle
new file mode 100644
index 0000000..909dbbb
--- /dev/null
+++ b/java-package/build.gradle
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id "com.github.spotbugs" version "4.2.0" apply true
+}
+
+defaultTasks 'build'
+
+allprojects {
+ group 'org.apache.mxnet'
+ boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
+ version = "${java_package_version}" + (isRelease ? "" : "-SNAPSHOT")
+
+ repositories {
+// Dependencies resolve from Maven Central and the Sonatype snapshots
+// repository below. (The DJL Maven repository, https://mlrepo.djl.ai/maven/,
+// is left disabled here.)
+ mavenCentral()
+ maven {
+ url 'https://oss.sonatype.org/content/repositories/snapshots/'
+ }
+ }
+
+ apply plugin: 'idea'
+ idea {
+ module {
+ outputDir = file('build/classes/java/main')
+ testOutputDir = file('build/classes/java/test')
+ // make IntelliJ use the same class output directories as Gradle
+ }
+ }
+}
+
+def javaProjects() {
+ subprojects.findAll { p -> new File(p.projectDir, "src/main").exists() }
+}
+
+configure(javaProjects()) {
+ apply plugin: 'java-library'
+ sourceCompatibility = 1.8
+ targetCompatibility = 1.8
+ compileJava.options.encoding = "UTF-8"
+ compileTestJava.options.encoding = "UTF-8"
+ if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+ compileJava.options.compilerArgs.addAll(["--release", "8"])
+ }
+
+ apply plugin: 'eclipse'
+
+ eclipse {
+ jdt.file.withProperties { props ->
+ props.setProperty "org.eclipse.jdt.core.circularClasspath", "warning"
+ }
+ classpath {
+ sourceSets.test.java {
+ srcDirs = ["src/test/java"]
+ exclude "**/package-info.java"
+ }
+ }
+ }
+
+ apply from: file("${rootProject.projectDir}/tools/gradle/java-formatter.gradle")
+ apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle")
+
+ test {
+ // tests may load models and large native arrays; give the test JVM a 4 GB heap
+ maxHeapSize = "4096m"
+ doFirst {
+ if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+ jvmArgs = [
+ '--add-opens', "java.base/jdk.internal.loader=ALL-UNNAMED"
+ ]
+ }
+ }
+
+ useTestNG() {
+// suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
+ }
+
+ testLogging {
+ showStandardStreams = true
+ events "passed", "skipped", "failed", "standardOut", "standardError"
+ }
+
+ doFirst {
+ systemProperties System.getProperties()
+ systemProperties.remove("user.dir")
+ systemProperty "org.apache.mxnet.logging.level", "debug"
+ systemProperty "org.slf4j.simpleLogger.defaultLogLevel", "debug"
+ systemProperty "org.slf4j.simpleLogger.log.org.mortbay.log", "warn"
+ systemProperty "disableProgressBar", "true"
+ systemProperty "nightly", System.getProperty("nightly", "false")
+// e.g. systemProperty "java.library.path", "/path/to/incubator-mxnet/build"
+ if (gradle.startParameter.offline) {
+ systemProperty "offline", "true"
+ }
+ }
+ }
+
+ compileJava {
+ options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+ }
+
+ compileTestJava {
+ options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+ }
+
+ jar {
+ manifest {
+ attributes("Automatic-Module-Name": "org.apache.mxnet.${project.name.replace('-', '_')}")
+ }
+ }
+}
+
+apply from: file("${rootProject.projectDir}/tools/gradle/jacoco.gradle")
+apply from: file("${rootProject.projectDir}/tools/gradle/stats.gradle")
diff --git a/java-package/example/build.gradle b/java-package/example/build.gradle
new file mode 100644
index 0000000..a6831ee
--- /dev/null
+++ b/java-package/example/build.gradle
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'java'
+}
+
+group = 'incubator-mxnet.java-package'
+version = '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+
+dependencies {
+ testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api', version: '5.7.0'
+ testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine', version: '5.7.0'
+}
+
+test {
+ useJUnitPlatform()
+}
\ No newline at end of file
diff --git a/java-package/gradle.properties b/java-package/gradle.properties
new file mode 100644
index 0000000..9d32099
--- /dev/null
+++ b/java-package/gradle.properties
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+org.gradle.daemon=true
+org.gradle.jvmargs=-Xmx2048M
+
+systemProp.org.gradle.internal.http.socketTimeout=120000
+systemProp.org.gradle.internal.http.connectionTimeout=60000
+
+# FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308
+systemProp.org.gradle.internal.publish.checksums.insecure=true
+
+java_package_version=0.0.1
+mxnet_version=2.0.0
+api_version=0.0.1
+jnarator_version=0.0.1
+
+antlr_version=4.7.2
+commons_cli_version=1.4
+commons_compress_version=1.20
+commons_csv_version=1.8
+gson_version=2.8.6
+jna_version=5.3.0
+netty_version=4.1.51.Final
+slf4j_version=1.7.30
+log4j_slf4j_version=2.13.3
+testng_version=7.1.0
+powermock_version=2.0.7
+
diff --git a/java-package/gradle/wrapper/gradle-wrapper.jar b/java-package/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000..e708b1c
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/java-package/gradle/wrapper/gradle-wrapper.properties b/java-package/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000..689c068
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
diff --git a/java-package/gradlew b/java-package/gradlew
new file mode 100755
index 0000000..ebb6c09
--- /dev/null
+++ b/java-package/gradlew
@@ -0,0 +1,186 @@
+#!/usr/bin/env sh
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME to the directory that contains this script
+# Resolve links: $0 may be a symlink
+PRG="$0"
+# Follow the link chain; a relative link target is resolved against the link's own directory.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# "maximum"/"max" uses the hard file-descriptor limit; set MAX_FD to a number to use that value instead.
+MAX_FD="maximum"
+
+warn () {
+ echo "$*"
+}
+
+die () {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+ NONSTOP* )
+ nonstop=true
+ ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Raise the open-file-descriptor limit where possible (skipped on Cygwin, Darwin and NonStop).
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # Build a regex from the top-level root directories; arguments matching it are converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Convert the arguments; plain /bin/sh has no arrays, so results go into args$i (only up to 9 are re-set below)
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### non-zero if the argument is an option (starts with '-')
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### convert only path-like arguments that are not options
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=`expr $i + 1`
+ done
+ case $i in
+ 0) set -- ;;
+ 1) set -- "$args0" ;;
+ 2) set -- "$args0" "$args1" ;;
+ 3) set -- "$args0" "$args1" "$args2" ;;
+ 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Escape application args: wrap each in single quotes (escaping embedded quotes) for the eval below
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
+}
+APP_ARGS=`save "$@"`
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+exec "$JAVACMD" "$@"
diff --git a/java-package/gradlew.bat b/java-package/gradlew.bat
new file mode 100644
index 0000000..bcf7fb7
--- /dev/null
+++ b/java-package/gradlew.bat
@@ -0,0 +1,92 @@
+@REM
+@REM Licensed to the Apache Software Foundation (ASF) under one
+@REM or more contributor license agreements. See the NOTICE file
+@REM distributed with this work for additional information
+@REM regarding copyright ownership. The ASF licenses this file
+@REM to you under the Apache License, Version 2.0 (the
+@REM "License"); you may not use this file except in compliance
+@REM with the License. You may obtain a copy of the License at
+@REM
+@REM http://www.apache.org/licenses/LICENSE-2.0
+@REM
+@REM Unless required by applicable law or agreed to in writing,
+@REM software distributed under the License is distributed on an
+@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+@REM KIND, either express or implied. See the License for the
+@REM specific language governing permissions and limitations
+@REM under the License.
+@REM
+
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+@rem JAVA_HOME is not set: fall back to the java.exe found on PATH.
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+@rem JAVA_HOME is set: strip any quotes from it and use its bin\java.exe.
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:execute
+@rem Set up the command line; only the Gradle wrapper jar is on the classpath.
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
+
+:end
+@rem Exit code 0 jumps to :mainEnd (closing the local scope); anything else falls into :fail.
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set GRADLE_EXIT_CONSOLE to end the whole cmd.exe process with "exit 1" instead of
+rem returning 1 from this batch script only ("exit /b 1").
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/java-package/integration/build.gradle b/java-package/integration/build.gradle
new file mode 100644
index 0000000..fed8860
--- /dev/null
+++ b/java-package/integration/build.gradle
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'application'
+ id 'jacoco'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+// Entry point defaults to IntegrationTest; override it with -Dmain=<fully.qualified.Class>.
+application {
+ mainClassName = System.getProperty("main", "org.apache.mxnet.integration.IntegrationTest")
+}
+
+dependencies {
+ api "commons-cli:commons-cli:${commons_cli_version}"
+ api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+
+ api project(":mxnet-engine")
+ implementation "org.testng:testng:${testng_version}"
+// TestNG is on the main classpath because the test classes live in src/main (run by IntegrationTest).
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+}
+// Forward Gradle's system properties (minus user.dir) to the app. NOTE: -Xverify:none warns on JDK 13+.
+run {
+ systemProperties System.getProperties()
+ systemProperties.remove("user.dir")
+ systemProperty("file.encoding", "UTF-8")
+ jvmArgs "-Xverify:none"
+}
+
+checkstyleMain {
+ // Checkstyle is skipped for the integration test sources.
+ exclude 'org/apache/mxnet/integration/**'
+}
+// NOTE(review): the TestNG test task below is disabled; tests currently run only via IntegrationTest.
+//test {
+//
+// useTestNG()
+// filter {
+// includeTestsMatching "org.apache.mxnet.integration.tests.engine.*"
+// }
+//
+//}
+
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
new file mode 100644
index 0000000..96d0ad3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.mxnet.integration.util.Arguments;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.SkipException;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+/**
+ * Minimal runner that executes TestNG-annotated integration tests without the TestNG engine.
+ *
+ * <p>Test classes are discovered in the code source (directory or jar) of the {@code source}
+ * class, instantiated through their public no-argument constructor, and their {@link Test},
+ * {@link BeforeClass}, {@link AfterClass}, {@link BeforeTest} and {@link AfterTest} methods are
+ * invoked reflectively.
+ */
+public class IntegrationTest {
+
+    private static final Logger logger = LoggerFactory.getLogger(IntegrationTest.class);
+
+    private final Class<?> source;
+
+    /**
+     * Creates a runner that discovers test classes next to {@code source}.
+     *
+     * @param source a class whose code source location is scanned for test classes
+     */
+    public IntegrationTest(Class<?> source) {
+        this.source = source;
+    }
+
+    /**
+     * Command line entry point; see {@link Arguments#getOptions()} for the supported options.
+     *
+     * @param args the command line arguments
+     */
+    public static void main(String[] args) {
+        new IntegrationTest(IntegrationTest.class).runTests(args);
+        // TODO: not elegant solution to native library crash
+        // System.exit(0);
+    }
+
+    /**
+     * Parses {@code args} and runs the selected tests.
+     *
+     * <p>The suite always runs at least once. When a duration (in minutes) is given it is
+     * repeated until that much time has elapsed, stopping early after a failing iteration.
+     *
+     * @param args the command line arguments
+     * @return {@code true} if no executed test failed
+     */
+    public boolean runTests(String[] args) {
+        Options options = Arguments.getOptions();
+        try {
+            DefaultParser parser = new DefaultParser();
+            CommandLine cmd = parser.parse(options, args, null, false);
+            Arguments arguments = new Arguments(cmd);
+
+            Duration duration = Duration.ofMinutes(arguments.getDuration());
+            List<TestClass> tests = listTests(arguments, source);
+
+            // do/while runs the suite at least once, even for a zero duration, and cannot spin
+            // without work when a run takes under a millisecond. nanoTime values are compared by
+            // subtraction so the deadline check is overflow-safe.
+            long deadline = System.nanoTime() + duration.toNanos();
+            boolean testsPassed;
+            do {
+                testsPassed = runTests(tests);
+            } while (testsPassed && System.nanoTime() - deadline < 0);
+            return testsPassed;
+        } catch (ParseException e) {
+            HelpFormatter formatter = new HelpFormatter();
+            formatter.setLeftPadding(1);
+            formatter.setWidth(120);
+            formatter.printHelp(e.getMessage(), options);
+            return false;
+        } catch (Throwable t) {
+            logger.error("Unexpected error", t);
+            return false;
+        }
+    }
+
+    /**
+     * Runs every test of every class once and logs a summary.
+     *
+     * @param tests the test classes to execute
+     * @return {@code true} if no test failed
+     */
+    private boolean runTests(List<TestClass> tests) {
+        Map<TestResult, Integer> totals = new ConcurrentHashMap<>();
+        for (TestClass testClass : tests) {
+            logger.info("Running test {} ...", testClass.getName());
+            int testCount = testClass.getTestCount();
+
+            try {
+                if (!testClass.beforeClass()) {
+                    // A failed class setup counts every test of that class as failed.
+                    totals.merge(TestResult.FAILED, testCount, Integer::sum);
+                    continue;
+                }
+
+                for (int i = 0; i < testCount; ++i) {
+                    TestResult result = testClass.runTest(i);
+                    totals.merge(result, 1, Integer::sum);
+                }
+            } finally {
+                testClass.afterClass();
+            }
+        }
+
+        int totalFailed = totals.getOrDefault(TestResult.FAILED, 0);
+        int totalPassed = totals.getOrDefault(TestResult.SUCCESS, 0);
+        int totalSkipped = totals.getOrDefault(TestResult.SKIPPED, 0);
+        int totalUnsupported = totals.getOrDefault(TestResult.UNSUPPORTED, 0);
+        if (totalSkipped > 0) {
+            logger.info("Skipped: {} tests", totalSkipped);
+        }
+        if (totalUnsupported > 0) {
+            logger.info("Unsupported: {} tests", totalUnsupported);
+        }
+        if (totalFailed > 0) {
+            logger.error("Failed {} out of {} tests", totalFailed, totalFailed + totalPassed);
+        } else {
+            logger.info("Passed all {} tests", totalPassed);
+        }
+        return totalFailed == 0;
+    }
+
+    /**
+     * Resolves the tests to run: the single class given by {@code --class-name} (either fully
+     * qualified or relative to the test package), or every class found in the test package.
+     */
+    private static List<TestClass> listTests(Arguments arguments, Class<?> source)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        String className = arguments.getClassName();
+        String methodName = arguments.getMethodName();
+        List<TestClass> tests = new ArrayList<>();
+        try {
+            if (className != null) {
+                String qualifiedName =
+                        className.startsWith(arguments.getPackageName())
+                                ? className
+                                : arguments.getPackageName() + className;
+                getTestsInClass(Class.forName(qualifiedName), methodName).ifPresent(tests::add);
+            } else {
+                for (Class<?> clazz : listTestClasses(arguments, source)) {
+                    getTestsInClass(clazz, methodName).ifPresent(tests::add);
+                }
+            }
+        } catch (ReflectiveOperationException | IOException | URISyntaxException e) {
+            logger.error("Failed to resolve test class.", e);
+            throw e;
+        }
+        return tests;
+    }
+
+    /**
+     * Instantiates {@code clazz} and collects its annotated test and lifecycle methods.
+     *
+     * @param clazz candidate test class
+     * @param methodName if non-null, only the {@link Test} method with this name is collected
+     * @return the collected tests, or empty if {@code clazz} is abstract or has no public
+     *     no-argument constructor
+     * @throws ReflectiveOperationException if the constructor throws
+     */
+    private static Optional<TestClass> getTestsInClass(Class<?> clazz, String methodName)
+            throws ReflectiveOperationException {
+        if (Modifier.isAbstract(clazz.getModifiers())) {
+            return Optional.empty(); // also covers interfaces
+        }
+        Constructor<?> ctor;
+        try {
+            ctor = clazz.getConstructor();
+        } catch (NoSuchMethodException e) {
+            logger.debug("Skipping {}: no public no-argument constructor", clazz.getName());
+            return Optional.empty();
+        }
+        TestClass testClass = new TestClass(ctor.newInstance());
+
+        for (Method method : clazz.getDeclaredMethods()) {
+            Test testMethod = method.getAnnotation(Test.class);
+            if (testMethod != null) {
+                if (testMethod.enabled()
+                        && (methodName == null || methodName.equals(method.getName()))) {
+                    testClass.addTestMethod(method);
+                }
+            } else if (method.isAnnotationPresent(BeforeClass.class)) {
+                testClass.addBeforeClass(method);
+            } else if (method.isAnnotationPresent(AfterClass.class)) {
+                testClass.addAfterClass(method);
+            } else if (method.isAnnotationPresent(BeforeTest.class)) {
+                testClass.addBeforeTest(method);
+            } else if (method.isAnnotationPresent(AfterTest.class)) {
+                testClass.addAfterTest(method);
+            }
+        }
+        return Optional.of(testClass);
+    }
+
+    /**
+     * Lists the top-level classes of the configured test package that live in the same code
+     * source (a class directory or a jar) as {@code clazz}.
+     */
+    private static List<Class<?>> listTestClasses(Arguments arguments, Class<?> clazz)
+            throws IOException, ClassNotFoundException, URISyntaxException {
+        URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+        if (!"file".equalsIgnoreCase(url.getProtocol())) {
+            return Collections.emptyList();
+        }
+
+        String packageName = arguments.getPackageName();
+        List<Class<?>> classList = new ArrayList<>();
+        Path classPath = Paths.get(url.toURI());
+        if (Files.isDirectory(classPath)) {
+            Collection<Path> files;
+            // Files.walk keeps directory handles open until the stream is closed.
+            try (Stream<Path> walk = Files.walk(classPath)) {
+                files =
+                        walk.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
+                                .collect(Collectors.toList());
+            }
+            for (Path file : files) {
+                String className = classPath.relativize(file).toString();
+                className = className.substring(0, className.lastIndexOf('.'));
+                addTestClass(classList, className.replace(File.separatorChar, '.'), packageName);
+            }
+        } else if (url.getPath().toLowerCase().endsWith(".jar")) {
+            try (JarFile jarFile = new JarFile(classPath.toFile())) {
+                Enumeration<JarEntry> entries = jarFile.entries();
+                while (entries.hasMoreElements()) {
+                    String fileName = entries.nextElement().getName();
+                    if (fileName.endsWith(".class")) {
+                        fileName = fileName.substring(0, fileName.lastIndexOf('.'));
+                        addTestClass(classList, fileName.replace('/', '.'), packageName);
+                    }
+                }
+            }
+        }
+        return classList;
+    }
+
+    /**
+     * Loads {@code className} into {@code classList} when it is a top-level class of
+     * {@code packageName}. Nested and anonymous classes ({@code $} in the name) are skipped.
+     */
+    private static void addTestClass(
+            List<Class<?>> classList, String className, String packageName)
+            throws ClassNotFoundException {
+        if (!className.startsWith(packageName) || className.contains("$")) {
+            return;
+        }
+        try {
+            classList.add(Class.forName(className));
+        } catch (ExceptionInInitializerError e) {
+            // A class whose static initializer fails cannot be run; report it and move on.
+            logger.warn("Failed to initialize test class {}", className, e);
+        }
+    }
+
+    /** A test class instance together with its reflectively discovered annotated methods. */
+    private static final class TestClass {
+
+        private final Object object;
+        private final List<Method> testMethods = new ArrayList<>();
+        private final List<Method> beforeClass = new ArrayList<>();
+        private final List<Method> afterClass = new ArrayList<>();
+        private final List<Method> beforeTest = new ArrayList<>();
+        private final List<Method> afterTest = new ArrayList<>();
+
+        public TestClass(Object object) {
+            this.object = object;
+        }
+
+        public void addTestMethod(Method method) {
+            testMethods.add(method);
+        }
+
+        public void addBeforeClass(Method method) {
+            beforeClass.add(method);
+        }
+
+        public void addAfterClass(Method method) {
+            afterClass.add(method);
+        }
+
+        public void addBeforeTest(Method method) {
+            beforeTest.add(method);
+        }
+
+        public void addAfterTest(Method method) {
+            afterTest.add(method);
+        }
+
+        /** Runs the {@code @BeforeClass} methods, stopping at the first failure. */
+        public boolean beforeClass() {
+            return invokeAll(beforeClass, true);
+        }
+
+        /** Runs every {@code @AfterClass} method, even if an earlier one fails. */
+        public void afterClass() {
+            invokeAll(afterClass, false);
+        }
+
+        /** Runs the {@code @BeforeTest} methods, stopping at the first failure. */
+        public boolean beforeTest() {
+            return invokeAll(beforeTest, true);
+        }
+
+        /** Runs every {@code @AfterTest} method, even if an earlier one fails. */
+        public void afterTest() {
+            invokeAll(afterTest, false);
+        }
+
+        /**
+         * Invokes each method on the test instance, logging failures.
+         *
+         * @param methods the lifecycle methods to invoke, in order
+         * @param stopOnFailure whether to stop at the first failing method
+         * @return {@code true} if every invoked method completed normally
+         */
+        private boolean invokeAll(List<Method> methods, boolean stopOnFailure) {
+            boolean success = true;
+            for (Method method : methods) {
+                try {
+                    method.invoke(object);
+                } catch (InvocationTargetException | IllegalAccessException e) {
+                    logger.error("", unwrap(e));
+                    success = false;
+                    if (stopOnFailure) {
+                        break;
+                    }
+                }
+            }
+            return success;
+        }
+
+        /**
+         * Runs one test method surrounded by the {@code @BeforeTest}/{@code @AfterTest} methods.
+         *
+         * @param index position of the test method, in {@code [0, getTestCount())}
+         * @return the outcome; an exception listed in {@link Test#expectedExceptions()} passes
+         */
+        public TestResult runTest(int index) {
+            if (!beforeTest()) {
+                return TestResult.FAILED;
+            }
+
+            Method method = testMethods.get(index);
+            try {
+                long begin = System.nanoTime();
+                method.invoke(object);
+                // Convert nanoseconds to milliseconds.
+                String time = String.format("%.3f", (System.nanoTime() - begin) / 1_000_000f);
+                logger.info(
+                        "Test {}.{} PASSED, duration: {} ms", getName(), method.getName(), time);
+                return TestResult.SUCCESS;
+            } catch (IllegalAccessException | InvocationTargetException e) {
+                if (expectedException(method, e)) {
+                    logger.info("Test {}.{} PASSED", getName(), method.getName());
+                    return TestResult.SUCCESS;
+                }
+                Throwable cause = unwrap(e);
+                if (cause instanceof SkipException) {
+                    logger.info("Test {}.{} SKIPPED", getName(), method.getName());
+                    return TestResult.SKIPPED;
+                }
+                if (cause instanceof UnsupportedOperationException) {
+                    logger.info("Test {}.{} UNSUPPORTED", getName(), method.getName());
+                    logger.trace("", cause);
+                    return TestResult.UNSUPPORTED;
+                }
+                logger.error("Test {}.{} FAILED", getName(), method.getName());
+                logger.error("", cause);
+                return TestResult.FAILED;
+            } finally {
+                afterTest();
+            }
+        }
+
+        public int getTestCount() {
+            return testMethods.size();
+        }
+
+        public String getName() {
+            return object.getClass().getName();
+        }
+
+        /** Returns whether the failure matches one of the test's {@code expectedExceptions}. */
+        private static boolean expectedException(Method method, Exception e) {
+            Class<?>[] exceptions = method.getAnnotation(Test.class).expectedExceptions();
+            Throwable exception = e.getCause();
+            for (Class<?> c : exceptions) {
+                if (c.isInstance(exception)) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        /**
+         * Returns the exception thrown by the invoked method, or {@code e} itself when there is
+         * no cause (for example an {@link IllegalAccessException}).
+         */
+        private static Throwable unwrap(ReflectiveOperationException e) {
+            Throwable cause = e.getCause();
+            return cause != null ? cause : e;
+        }
+    }
+
+    /** Outcome of a single test method. */
+    public enum TestResult {
+        SUCCESS,
+        FAILED,
+        SKIPPED,
+        UNSUPPORTED
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
new file mode 100644
index 0000000..f97e77a
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
new file mode 100644
index 0000000..8beffb2
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.engine;
+
+import java.io.IOException;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Model;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Predictor;
+import org.apache.mxnet.integration.tests.jna.JnaUtilTest;
+import org.apache.mxnet.integration.util.Assertions;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.repository.Item;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.annotations.Test;
+
+/** End-to-end check that the MLP model loads and produces the expected prediction. */
+public class ModelTest {
+    private static final Logger logger = LoggerFactory.getLogger(ModelTest.class);
+
+    /**
+     * Loads the MLP model, predicts on an all-ones {@code (1, 28, 28)} input and compares the
+     * first output with reference values.
+     *
+     * @throws IOException if the model cannot be loaded; propagated so the test fails instead of
+     *     silently passing
+     */
+    @Test
+    public void modelLoadAndPredictTest() throws IOException {
+        try (MxResource base = BaseMxResource.getSystemMxResource()) {
+            Model model = Model.loadModel(Item.MLP);
+            Predictor<NDList, NDList> predictor = model.newPredictor();
+            NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+            NDList inputs = new NDList();
+            inputs.add(input);
+            NDList result = predictor.predict(inputs);
+            NDArray expected =
+                    NDArray.create(
+                            base,
+                            new float[] {
+                                4.93476f,
+                                -0.76084447f,
+                                0.37713608f,
+                                0.6605506f,
+                                -1.3485785f,
+                                -0.8736369f,
+                                0.018061712f,
+                                -1.3274033f,
+                                1.0609543f,
+                                0.24042489f
+                            },
+                            new Shape(1, 10));
+            Assertions.assertAlmostEquals(result.get(0), expected);
+        } catch (IOException e) {
+            logger.error(e.getMessage(), e);
+            throw e;
+        }
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
new file mode 100644
index 0000000..cc6a916
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration.tests.engine;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
new file mode 100644
index 0000000..e751140
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.jna;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+/** Integration tests exercising the JNA bindings of the MXNet engine. */
+public class JnaUtilTest {
+
+    private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class);
+
+    /**
+     * Builds a {@link SymbolBlock} from the MLP symbol file, binds the parameters loaded from the
+     * params file, runs one forward pass and checks that the system resource tracks no
+     * sub-resources afterwards.
+     *
+     * @throws IOException if the model files cannot be read
+     */
+    @Test
+    public void doForwardTest() throws IOException {
+        try (MxResource base = BaseMxResource.getSystemMxResource()) {
+            Path modelPath = Repository.initRepository(Item.MLP);
+            Path symbolPath = modelPath.resolve("mlp-symbol.json");
+            Path paramsPath = modelPath.resolve("mlp-0000.params");
+            Symbol symbol = Symbol.loadSymbol(base, symbolPath);
+            SymbolBlock block = new SymbolBlock(base, symbol);
+            Device device = Device.defaultIfNull();
+            NDList mxNDArray = JnaUtils.loadNdArray(base, paramsPath, device);
+
+            // Bind each loaded array to the block parameter of the same name; whatever is left in
+            // the map afterwards has no stored value and is registered as a block input.
+            List<Parameter> parameters = block.getAllParameters();
+            Map<String, Parameter> map = new ConcurrentHashMap<>();
+            parameters.forEach(p -> map.put(p.getName(), p));
+
+            for (NDArray nd : mxNDArray) {
+                String key = nd.getName();
+                if (key == null) {
+                    throw new IllegalArgumentException(
+                            "Array names must be present in parameter file");
+                }
+                // Strip an optional "<prefix>:" from the entry name before the lookup.
+                int colon = key.indexOf(':');
+                String paramName = colon >= 0 ? key.substring(colon + 1) : key;
+                Parameter parameter = map.remove(paramName);
+                if (parameter == null) {
+                    throw new IllegalArgumentException(
+                            "Parameter file entry " + key + " does not match any block parameter");
+                }
+                parameter.setArray(nd);
+            }
+            block.setInputNames(new ArrayList<>(map.keySet()));
+
+            NDArray arr = NDArray.create(base, new Shape(1, 28, 28), device).ones();
+            block.forward(new NDList(arr), new PairList<>(), device);
+            logger.info(
+                    "Number of MxResource managed by baseMxResource: {}",
+                    BaseMxResource.getSystemMxResource().getSubResource().size());
+        } catch (IOException e) {
+            logger.error(e.getMessage(), e);
+            throw e;
+        }
+        Assert.assertEquals(BaseMxResource.getSystemMxResource().getSubResource().size(), 0);
+    }
+
+    /**
+     * Creates NDArrays from every supported primitive array type, both with an explicit
+     * {@code (3, 4)} shape and without one, and checks the data round-trips unchanged. Afterwards
+     * no sub-resource of the system resource may remain open.
+     */
+    @Test
+    public void createNdArray() {
+        try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+            int[] originIntegerArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+            float[] originFloatArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+            double[] originDoubleArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+            long[] originLongArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+            boolean[] originBooleanArray = {
+                true, false, false, true, true, true, true, false, false, true, true, true
+            };
+            byte[] originByteArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+            // With an explicit shape. TestNG's assertEquals takes (actual, expected).
+            NDArray intArray = NDArray.create(base, originIntegerArray, new Shape(3, 4));
+            NDArray floatArray = NDArray.create(base, originFloatArray, new Shape(3, 4));
+            NDArray doubleArray = NDArray.create(base, originDoubleArray, new Shape(3, 4));
+            NDArray longArray = NDArray.create(base, originLongArray, new Shape(3, 4));
+            NDArray booleanArray = NDArray.create(base, originBooleanArray, new Shape(3, 4));
+            NDArray byteArray = NDArray.create(base, originByteArray, new Shape(3, 4));
+            Assert.assertEquals(intArray.toIntArray(), originIntegerArray);
+            Assert.assertEquals(floatArray.toFloatArray(), originFloatArray);
+            Assert.assertEquals(doubleArray.toDoubleArray(), originDoubleArray);
+            Assert.assertEquals(longArray.toLongArray(), originLongArray);
+            Assert.assertEquals(booleanArray.toBooleanArray(), originBooleanArray);
+            Assert.assertEquals(byteArray.toByteArray(), originByteArray);
+
+            // Without a shape.
+            NDArray intArray2 = NDArray.create(base, originIntegerArray);
+            NDArray floatArray2 = NDArray.create(base, originFloatArray);
+            NDArray doubleArray2 = NDArray.create(base, originDoubleArray);
+            NDArray longArray2 = NDArray.create(base, originLongArray);
+            NDArray booleanArray2 = NDArray.create(base, originBooleanArray);
+            NDArray byteArray2 = NDArray.create(base, originByteArray);
+            Assert.assertEquals(intArray2.toIntArray(), originIntegerArray);
+            Assert.assertEquals(floatArray2.toFloatArray(), originFloatArray);
+            Assert.assertEquals(doubleArray2.toDoubleArray(), originDoubleArray);
+            Assert.assertEquals(longArray2.toLongArray(), originLongArray);
+            Assert.assertEquals(booleanArray2.toBooleanArray(), originBooleanArray);
+            Assert.assertEquals(byteArray2.toByteArray(), originByteArray);
+        } catch (ClassCastException e) {
+            logger.error(e.getMessage(), e);
+            throw e;
+        }
+
+        BaseMxResource base = BaseMxResource.getSystemMxResource();
+        int countNotReleased = 0;
+        for (MxResource mxResource : base.getSubResource().values()) {
+            if (!mxResource.getClosed()) {
+                ++countNotReleased;
+            }
+        }
+        Assert.assertEquals(countNotReleased, 0);
+    }
+
+    /** Loads the MLP parameter file and logs how many sub-resources are tracked before/after. */
+    @Test
+    public void loadNdArray() throws IOException {
+        try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+            Path modelPath = Repository.initRepository(Item.MLP);
+            Path paramsPath = modelPath.resolve("mlp-0000.params");
+            NDList mxNDArray =
+                    JnaUtils.loadNdArray(
+                            base, Paths.get(paramsPath.toUri()), Device.defaultIfNull(null));
+            logger.info("{}", mxNDArray);
+            logger.info(
+                    "The amount of sub resources managed by BaseMxResource: {}",
+                    base.getSubResource().size());
+        } catch (IOException e) {
+            logger.error(e.getMessage(), e);
+            throw e;
+        }
+        logger.info(
+                "The amount of sub resources managed by BaseMxResource: {}",
+                BaseMxResource.getSystemMxResource().getSubResource().size());
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
new file mode 100644
index 0000000..15771f3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.util;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+
+/** Command-line options accepted by the integration test runner. */
+public class Arguments {
+
+    /** Package used when {@code --package-name} is absent; note the trailing dot. */
+    private static final String DEFAULT_PACKAGE_NAME = "org.apache.mxnet.integration.tests.";
+
+    private final String methodName;
+    private final String className;
+    private final String packageName;
+    private final int duration;
+    private final int iteration;
+
+    /**
+     * Reads the runner options from a parsed command line.
+     *
+     * <p>Absent options fall back to defaults: package name {@code
+     * "org.apache.mxnet.integration.tests."}, duration {@code 0} and iteration {@code 1}; class
+     * and method name stay {@code null}. A non-empty package name that does not end with {@code
+     * '.'} gets one appended, so a user-supplied value has the same form as the default (which
+     * presumably is used as a prefix for the class name by the caller — confirm).
+     *
+     * @param cmd the command line, parsed against {@link #getOptions()}
+     * @throws NumberFormatException if {@code --duration} or {@code --iteration} is not an
+     *     integer
+     */
+    public Arguments(CommandLine cmd) {
+        methodName = cmd.getOptionValue("method-name");
+        className = cmd.getOptionValue("class-name");
+        packageName =
+                normalizePackageName(cmd.getOptionValue("package-name", DEFAULT_PACKAGE_NAME));
+        duration = parseIntOption(cmd, "duration", 0);
+        iteration = parseIntOption(cmd, "iteration", 1);
+    }
+
+    /** Returns {@code name} with a trailing '.' unless it is empty or already ends with one. */
+    private static String normalizePackageName(String name) {
+        if (name.isEmpty() || name.endsWith(".")) {
+            return name;
+        }
+        return name + '.';
+    }
+
+    /** Parses the integer value of option {@code name}, or returns {@code defaultValue}. */
+    private static int parseIntOption(CommandLine cmd, String name, int defaultValue) {
+        String value = cmd.getOptionValue(name);
+        if (value == null) {
+            return defaultValue;
+        }
+        return Integer.parseInt(value.trim());
+    }
+
+    /**
+     * Builds the option definitions understood by {@link #Arguments(CommandLine)}.
+     *
+     * @return the options {@code -d/--duration}, {@code -n/--iteration}, {@code
+     *     -p/--package-name}, {@code -c/--class-name} and {@code -m/--method-name}, each taking
+     *     one argument
+     */
+    public static Options getOptions() {
+        Options options = new Options();
+        options.addOption(
+                Option.builder("d")
+                        .longOpt("duration")
+                        .hasArg()
+                        .argName("DURATION")
+                        .desc("Duration of the test.")
+                        .build());
+        options.addOption(
+                Option.builder("n")
+                        .longOpt("iteration")
+                        .hasArg()
+                        .argName("ITERATION")
+                        .desc("Number of iterations in each test.")
+                        .build());
+        options.addOption(
+                Option.builder("p")
+                        .longOpt("package-name")
+                        .hasArg()
+                        .argName("PACKAGE-NAME")
+                        .desc("Name of the package to run")
+                        .build());
+        options.addOption(
+                Option.builder("c")
+                        .longOpt("class-name")
+                        .hasArg()
+                        .argName("CLASS-NAME")
+                        .desc("Name of the class to run")
+                        .build());
+        options.addOption(
+                Option.builder("m")
+                        .longOpt("method-name")
+                        .hasArg()
+                        .argName("METHOD-NAME")
+                        .desc("Name of the method to run")
+                        .build());
+        return options;
+    }
+
+    /** @return the {@code --duration} value, or {@code 0} when absent */
+    public int getDuration() {
+        return duration;
+    }
+
+    /** @return the {@code --iteration} value, or {@code 1} when absent */
+    public int getIteration() {
+        return iteration;
+    }
+
+    /** @return the package name, always ending with '.' unless it was given as empty */
+    public String getPackageName() {
+        return packageName;
+    }
+
+    /** @return the {@code --class-name} value, or {@code null} when absent */
+    public String getClassName() {
+        return className;
+    }
+
+    /** @return the {@code --method-name} value, or {@code null} when absent */
+    public String getMethodName() {
+        return methodName;
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
new file mode 100644
index 0000000..b342b45
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.testng.Assert;
+
+/** Tolerance-based and in-place assertions for NDArray test results. */
+public final class Assertions {
+    /** Default relative tolerance used when none is given. */
+    private static final double RTOL = 1e-5;
+    /** Default absolute tolerance used when none is given. */
+    private static final double ATOL = 1e-3;
+
+    private Assertions() {}
+
+    /** Builds an "Expected/Actual" failure message without a leading description. */
+    private static <T> String getDefaultErrorMessage(T actual, T expected) {
+        return getDefaultErrorMessage(actual, expected, null);
+    }
+
+    /**
+     * Builds a failure message: the optional {@code errorMessage}, then the expected and the actual
+     * value, each on its own line.
+     */
+    private static <T> String getDefaultErrorMessage(T actual, T expected, String errorMessage) {
+        StringBuilder sb = new StringBuilder(100);
+        if (errorMessage != null) {
+            sb.append(errorMessage);
+        }
+        sb.append(System.lineSeparator())
+                .append("Expected: ")
+                .append(expected)
+                .append(System.lineSeparator())
+                .append("Actual: ")
+                .append(actual);
+        return sb.toString();
+    }
+
+    /**
+     * Returns whether {@code actual} lies within {@code atol + rtol * |expected|} of {@code
+     * expected}.
+     *
+     * <p>Values that {@link Double#compare} treats as identical (including two NaNs or two
+     * infinities of the same sign) are close. Otherwise a NaN or infinity on either side is never
+     * close: the tolerance test uses {@code <=} so that a NaN difference fails rather than slipping
+     * through a {@code >} comparison.
+     */
+    private static boolean isClose(double actual, double expected, double rtol, double atol) {
+        if (Double.compare(actual, expected) == 0) {
+            return true;
+        }
+        if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
+            return false;
+        }
+        return Math.abs(actual - expected) <= atol + rtol * Math.abs(expected);
+    }
+
+    /** Asserts two NDArrays are element-wise close using the default tolerances. */
+    public static void assertAlmostEquals(NDArray actual, NDArray expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /** Asserts two NDLists are pairwise element-wise close using the default tolerances. */
+    public static void assertAlmostEquals(NDList actual, NDList expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /** Asserts two doubles are close using the default tolerances. */
+    public static void assertAlmostEquals(double actual, double expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /**
+     * Asserts that {@code actual} is within {@code atol + rtol * |expected|} of {@code expected}.
+     *
+     * @throws AssertionError if the values are not close; a NaN only matches another NaN
+     */
+    public static void assertAlmostEquals(
+            double actual, double expected, double rtol, double atol) {
+        if (!isClose(actual, expected, rtol, atol)) {
+            throw new AssertionError(getDefaultErrorMessage(actual, expected));
+        }
+    }
+
+    /**
+     * Asserts both lists have the same size and that each pair of arrays is element-wise close.
+     *
+     * @throws AssertionError on a size mismatch or the first non-close pair
+     */
+    public static void assertAlmostEquals(
+            NDList actual, NDList expected, double rtol, double atol) {
+        Assert.assertEquals(
+                actual.size(),
+                expected.size(),
+                getDefaultErrorMessage(
+                        actual.size(), expected.size(), "The NDLists have different sizes"));
+        int size = actual.size();
+        for (int i = 0; i < size; i++) {
+            assertAlmostEquals(actual.get(i), expected.get(i), rtol, atol);
+        }
+    }
+
+    /**
+     * Asserts both arrays have the same shape and that every element of {@code actual} is within
+     * {@code atol + rtol * |expected|} of the corresponding element of {@code expected}.
+     *
+     * @throws AssertionError on a shape mismatch or at the first non-close element
+     */
+    public static void assertAlmostEquals(
+            NDArray actual, NDArray expected, double rtol, double atol) {
+        if (!actual.getShape().equals(expected.getShape())) {
+            throw new AssertionError(
+                    getDefaultErrorMessage(
+                            actual.getShape(),
+                            expected.getShape(),
+                            "The shape of two NDArray are different!"));
+        }
+        Number[] actualValues = actual.toArray();
+        Number[] expectedValues = expected.toArray();
+        for (int i = 0; i < actualValues.length; i++) {
+            double a = actualValues[i].doubleValue();
+            double b = expectedValues[i].doubleValue();
+            if (!isClose(a, b, rtol, atol)) {
+                throw new AssertionError(
+                        getDefaultErrorMessage(a, b, "The NDArrays differ at flat index " + i));
+            }
+        }
+    }
+
+    /**
+     * Asserts {@code actual} equals {@code expected} and is the very same object as {@code
+     * original}, i.e. the operation under test ran in place.
+     */
+    public static void assertInPlaceEquals(NDArray actual, NDArray expected, NDArray original) {
+        Assert.assertEquals(
+                actual, expected, getDefaultErrorMessage(actual, expected, "Assert Equal failed!"));
+        Assert.assertSame(
+                original,
+                actual,
+                getDefaultErrorMessage(actual, original, "Assert Inplace failed!"));
+    }
+
+    /** Like {@link #assertInPlaceEquals} but compares values with the default tolerances. */
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original) {
+        assertInPlaceAlmostEquals(actual, expected, original, RTOL, ATOL);
+    }
+
+    /**
+     * Asserts {@code actual} is element-wise close to {@code expected} and is the very same object
+     * as {@code original}.
+     */
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original, double rtol, double atol) {
+        assertAlmostEquals(actual, expected, rtol, atol);
+        Assert.assertSame(
+                original,
+                actual,
+                getDefaultErrorMessage(actual, original, "Assert Inplace failed!"));
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
new file mode 100644
index 0000000..e897c27
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/** Reflection-driven helper that exercises simple accessors to raise test coverage. */
+public final class CoverageUtils {
+
+    private CoverageUtils() {}
+
+    /**
+     * Loads every class found in the code source (class directory or jar) of {@code baseClass}
+     * and calls its simple accessors.
+     *
+     * <p>An instance is taken from the first enum constant, or built by trying each public
+     * constructor with mock arguments (the last one that succeeds is used). Then, on the declared
+     * methods: zero-argument {@code get*}, {@code is*}, {@code toString} and {@code hashCode};
+     * one-argument {@code set*} and {@code fromValue}; and one-argument {@code equals} (with
+     * itself, {@code null} and a mock) are invoked. Failures of individual constructors or methods
+     * are ignored because this is a best-effort coverage helper.
+     *
+     * @param baseClass a class whose code source location is scanned
+     * @throws IOException if the class directory or jar cannot be read
+     * @throws ReflectiveOperationException if a listed class cannot be found
+     * @throws URISyntaxException if the code source location is not a valid URI
+     */
+    public static void testGetterSetters(Class<?> baseClass)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        for (Class<?> clazz : getClasses(baseClass)) {
+            Object obj = createInstance(clazz);
+            if (obj != null) {
+                invokeAccessors(clazz, obj);
+            }
+        }
+    }
+
+    /** Returns an instance of {@code clazz} built from mocks, or {@code null} if none could be. */
+    private static Object createInstance(Class<?> clazz) {
+        if (clazz.isEnum()) {
+            return firstEnumConstant(clazz);
+        }
+        Object obj = null;
+        for (Constructor<?> con : clazz.getConstructors()) {
+            try {
+                Class<?>[] types = con.getParameterTypes();
+                Object[] args = new Object[types.length];
+                for (int i = 0; i < args.length; ++i) {
+                    args[i] = getMockInstance(types[i], true);
+                }
+                con.setAccessible(true);
+                obj = con.newInstance(args);
+            } catch (ReflectiveOperationException ignore) {
+                // this constructor rejected the mock arguments; keep trying the others
+            }
+        }
+        return obj;
+    }
+
+    /** Invokes the accessor-like methods declared by {@code clazz} on {@code obj}. */
+    private static void invokeAccessors(Class<?> clazz, Object obj) {
+        for (Method method : clazz.getDeclaredMethods()) {
+            String methodName = method.getName();
+            int parameterCount = method.getParameterCount();
+            try {
+                if (parameterCount == 0
+                        && (methodName.startsWith("get")
+                                || methodName.startsWith("is")
+                                || "toString".equals(methodName)
+                                || "hashCode".equals(methodName))) {
+                    method.invoke(obj);
+                } else if (parameterCount == 1
+                        && (methodName.startsWith("set") || "fromValue".equals(methodName))) {
+                    Class<?> type = method.getParameterTypes()[0];
+                    method.invoke(obj, getMockInstance(type, true));
+                } else if (parameterCount == 1 && "equals".equals(methodName)) {
+                    method.invoke(obj, obj);
+                    method.invoke(obj, (Object) null);
+                    Class<?> type = method.getParameterTypes()[0];
+                    method.invoke(obj, getMockInstance(type, true));
+                }
+            } catch (ReflectiveOperationException | IllegalArgumentException ignore) {
+                // best effort: inaccessible, throwing, or argument-incompatible methods are skipped
+            }
+        }
+    }
+
+    /**
+     * Loads, with initialization, every class in the directory tree or jar that {@code clazz} was
+     * loaded from, using a child-first loader over the current class path. Classes that fail with
+     * an {@link Error} while loading are skipped. Returns an empty list for non-file locations.
+     */
+    private static List<Class<?>> getClasses(Class<?> clazz)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+        if (!"file".equalsIgnoreCase(url.getProtocol())) {
+            return Collections.emptyList();
+        }
+
+        Path classPath = Paths.get(url.toURI());
+        List<String> classNames;
+        if (Files.isDirectory(classPath)) {
+            classNames = listClassNamesInDirectory(classPath);
+        } else if (url.getPath().toLowerCase(Locale.ROOT).endsWith(".jar")) {
+            classNames = listClassNamesInJar(classPath);
+        } else {
+            return Collections.emptyList();
+        }
+
+        ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader();
+        ClassLoader cl = new TestClassLoader(getClassPathUrls(appClassLoader), appClassLoader);
+        List<Class<?>> classList = new ArrayList<>(classNames.size());
+        for (String className : classNames) {
+            try {
+                classList.add(Class.forName(className, true, cl));
+            } catch (Error ignore) {
+                // classes that fail to link or initialize are skipped
+            }
+        }
+        return classList;
+    }
+
+    /** Lists the binary names of all {@code .class} files below {@code dir}. */
+    private static List<String> listClassNamesInDirectory(Path dir) throws IOException {
+        Collection<Path> files;
+        // Files.walk keeps directory handles open until the stream is closed.
+        try (Stream<Path> stream = Files.walk(dir)) {
+            files =
+                    stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
+                            .collect(Collectors.toList());
+        }
+        List<String> names = new ArrayList<>(files.size());
+        for (Path file : files) {
+            String name = dir.relativize(file).toString();
+            name = name.substring(0, name.lastIndexOf('.'));
+            names.add(name.replace(File.separatorChar, '.'));
+        }
+        return names;
+    }
+
+    /** Lists the binary names of all {@code .class} entries in the jar at {@code jar}. */
+    private static List<String> listClassNamesInJar(Path jar) throws IOException {
+        List<String> names = new ArrayList<>();
+        try (JarFile jarFile = new JarFile(jar.toFile())) {
+            Enumeration<JarEntry> en = jarFile.entries();
+            while (en.hasMoreElements()) {
+                String fileName = en.nextElement().getName();
+                if (fileName.endsWith(".class")) {
+                    fileName = fileName.substring(0, fileName.lastIndexOf('.'));
+                    names.add(fileName.replace('/', '.'));
+                }
+            }
+        }
+        return names;
+    }
+
+    /**
+     * Returns the class path URLs of {@code loader}.
+     *
+     * <p>A {@link URLClassLoader} is asked directly. Otherwise the JDK-internal {@code ucp} field
+     * of the loader is read reflectively; if that is impossible (missing field, or access denied
+     * by strong encapsulation on newer JDKs) the {@code java.class.path} system property is used.
+     */
+    private static URL[] getClassPathUrls(ClassLoader loader) throws IOException {
+        if (loader instanceof URLClassLoader) {
+            return ((URLClassLoader) loader).getURLs();
+        }
+        try {
+            Field field = loader.getClass().getDeclaredField("ucp");
+            field.setAccessible(true);
+            Object ucp = field.get(loader);
+            Method method = ucp.getClass().getDeclaredMethod("getURLs");
+            return (URL[]) method.invoke(ucp);
+        } catch (ReflectiveOperationException | RuntimeException e) {
+            // RuntimeException covers InaccessibleObjectException (Java 9+), which cannot be named
+            // here while keeping Java 8 source compatibility; fall back to the declared class path.
+            String[] entries = System.getProperty("java.class.path", "").split(File.pathSeparator);
+            List<URL> urls = new ArrayList<>(entries.length);
+            for (String entry : entries) {
+                if (!entry.isEmpty()) {
+                    urls.add(Paths.get(entry).toUri().toURL());
+                }
+            }
+            return urls.toArray(new URL[0]);
+        }
+    }
+
+    /** Returns the first constant of enum {@code clazz}, or {@code null} if it declares none. */
+    private static Object firstEnumConstant(Class<?> clazz) {
+        Object[] constants = clazz.getEnumConstants();
+        return constants == null || constants.length == 0 ? null : constants[0];
+    }
+
+    /**
+     * Returns a placeholder value for a parameter of type {@code clazz}: zero/false/'0' for
+     * primitives, empty String/List/Set/Map where assignable, the first enum constant, a no-op
+     * proxy for interfaces, and, if {@code useConstructor}, an instance built from the first public
+     * constructor that accepts nested mocks. Returns {@code null} otherwise.
+     */
+    private static Object getMockInstance(Class<?> clazz, boolean useConstructor) {
+        if (clazz.isPrimitive()) {
+            if (clazz == Boolean.TYPE) {
+                return Boolean.TRUE;
+            }
+            if (clazz == Character.TYPE) {
+                return '0';
+            }
+            if (clazz == Byte.TYPE) {
+                return (byte) 0;
+            }
+            if (clazz == Short.TYPE) {
+                return (short) 0;
+            }
+            if (clazz == Integer.TYPE) {
+                return 0;
+            }
+            if (clazz == Long.TYPE) {
+                return 0L;
+            }
+            if (clazz == Float.TYPE) {
+                return 0f;
+            }
+            if (clazz == Double.TYPE) {
+                return 0d;
+            }
+        }
+
+        if (clazz.isAssignableFrom(String.class)) {
+            return "";
+        }
+
+        if (clazz.isAssignableFrom(List.class)) {
+            return new ArrayList<>();
+        }
+
+        if (clazz.isAssignableFrom(Set.class)) {
+            return new HashSet<>();
+        }
+
+        if (clazz.isAssignableFrom(Map.class)) {
+            return new HashMap<>();
+        }
+
+        if (clazz.isEnum()) {
+            return firstEnumConstant(clazz);
+        }
+
+        if (clazz.isInterface()) {
+            return newProxyInstance(clazz);
+        }
+
+        if (useConstructor) {
+            for (Constructor<?> con : clazz.getConstructors()) {
+                try {
+                    Class<?>[] types = con.getParameterTypes();
+                    Object[] args = new Object[types.length];
+                    for (int i = 0; i < args.length; ++i) {
+                        args[i] = getMockInstance(types[i], false);
+                    }
+                    con.setAccessible(true);
+                    return con.newInstance(args);
+                } catch (ReflectiveOperationException ignore) {
+                    // try the next constructor
+                }
+            }
+        }
+
+        return null;
+    }
+
+    /** Returns a proxy implementing {@code clazz} whose every method returns {@code null}. */
+    @SuppressWarnings({"rawtypes", "PMD.UseProperClassLoader"})
+    private static Object newProxyInstance(Class<?> clazz) {
+        ClassLoader cl = clazz.getClassLoader();
+        return Proxy.newProxyInstance(cl, new Class[] {clazz}, (proxy, method, args) -> null);
+    }
+
+    /**
+     * Child-first class loader: classes available on its URLs are defined by this loader rather
+     * than the parent; anything else is delegated to the parent, or to the system class loader
+     * when there is no parent.
+     */
+    private static final class TestClassLoader extends URLClassLoader {
+
+        public TestClassLoader(URL[] urls, ClassLoader parent) {
+            super(urls, parent);
+        }
+
+        /**
+         * Returns the class already defined by this loader, else defines it from this loader's
+         * URLs, else delegates. Checking {@link #findLoadedClass} first avoids a {@link
+         * LinkageError} from defining the same class twice.
+         */
+        @Override
+        public Class<?> loadClass(String name) throws ClassNotFoundException {
+            Class<?> loaded = findLoadedClass(name);
+            if (loaded != null) {
+                return loaded;
+            }
+            try {
+                return findClass(name);
+            } catch (ClassNotFoundException e) {
+                ClassLoader classLoader = getParent();
+                if (classLoader == null) {
+                    classLoader = getSystemClassLoader();
+                }
+                return classLoader.loadClass(name);
+            }
+        }
+    }
+}
diff --git a/java-package/integration/src/main/resources/log4j2.xml b/java-package/integration/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..ff05a01
--- /dev/null
+++ b/java-package/integration/src/main/resources/log4j2.xml
@@ -0,0 +1,34 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<!--
+  ~ Console-only logging for the integration tests, formatted as "[LEVEL] - message".
+  ~ Loggers under org.apache.mxnet (the package of the MXNet Java binding) default to debug and
+  ~ can be overridden with -Dorg.apache.mxnet.level=<level>; everything else logs at info.
+  -->
+<Configuration status="INFO">
+    <Appenders>
+        <Console name="console" target="SYSTEM_OUT">
+            <PatternLayout
+                    pattern="[%-5level] - %msg%n"/>
+        </Console>
+    </Appenders>
+    <Loggers>
+        <Root level="info" additivity="false">
+            <AppenderRef ref="console"/>
+        </Root>
+        <!-- additivity="false": events are written once here instead of again via Root -->
+        <Logger name="org.apache.mxnet" level="${sys:org.apache.mxnet.level:-debug}" additivity="false">
+            <AppenderRef ref="console"/>
+        </Logger>
+    </Loggers>
+</Configuration>
diff --git a/java-package/jnarator/build.gradle b/java-package/jnarator/build.gradle
new file mode 100644
index 0000000..abc7cc0
--- /dev/null
+++ b/java-package/jnarator/build.gradle
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Build script for jnarator: the Gradle 'antlr' plugin generates a parser from the grammar under
+// src/main/antlr, and the jar task below packages a runnable fat jar.
+plugins {
+ id 'antlr'
+}
+
+dependencies {
+ // ANTLR tool used by the 'antlr' plugin to generate the parser/lexer sources.
+ antlr "org.antlr:antlr4:${antlr_version}"
+
+ api "commons-cli:commons-cli:${commons_cli_version}"
+ // Runtime library the generated parser/lexer classes depend on.
+ api "org.antlr:antlr4-runtime:${antlr_version}"
+ api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+
+ // NOTE(review): log4j-slf4j-impl (api above) and slf4j-simple are both SLF4J bindings, so tests
+ // see two bindings on the classpath — confirm which one is intended.
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ testRuntimeOnly project(":mxnet-engine")
+// NOTE(review): disabled; the coordinate below has an empty group id and would need one
+// before it could be re-enabled.
+// testRuntimeOnly ":mxnet-native-auto:${mxnet_version}"
+}
+
+checkstyleMain.source = 'src/main/java'
+
+checkstyleMain {
+ // skip check style for this package
+ // NOTE(review): if all main sources live under org/apache/mxnet/jnarator, this disables
+ // checkstyle for the whole main source set — confirm that is intended.
+ exclude 'org/apache/mxnet/jnarator/**'
+}
+
+pmdMain.source = 'src/main/java'
+//pmdMain.ignoreFailures(true)
+// SpotBugs findings are reported but never fail the build.
+spotbugs.ignoreFailures = true
+
+// Fat jar: every runtime dependency is unpacked into this jar so it can be run with `java -jar`.
+jar {
+ manifest {
+ attributes (
+ "Main-Class" : "org.apache.mxnet.jnarator.Main",
+ // Presumably set because some unpacked dependencies are multi-release jars — confirm.
+ "Multi-Release" : true
+ )
+ }
+ includeEmptyDirs = false
+ // Files present in several dependencies (e.g. META-INF entries) are all copied, not de-duplicated.
+ duplicatesStrategy = DuplicatesStrategy.INCLUDE
+ from configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
+}
diff --git a/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4 b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
new file mode 100644
index 0000000..6206c13
--- /dev/null
+++ b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
@@ -0,0 +1,923 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// C grammar used by jnarator. Presumably it parses the MXNet C API headers to drive JNA binding
+// generation — confirm against org.apache.mxnet.jnarator.Main.
+// The header actions below place the generated parser and lexer in package
+// org.apache.mxnet.jnarator.parser.
+grammar C;
+
+@parser::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+@lexer::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+
+// ---- Expressions ----
+// Each binary-operator rule below is written in terms of the next tighter-binding rule, so the
+// chain from castExpression up to expression encodes C operator precedence.
+primaryExpression
+ : Identifier
+ | Constant
+ | StringLiteral+
+ | '(' expression ')'
+ | genericSelection
+ | '__extension__'? '(' compoundStatement ')' // GCC statement expression: ({ ... })
+ | '__builtin_va_arg' '(' unaryExpression ',' typeName ')'
+ | '__builtin_offsetof' '(' typeName ',' unaryExpression ')'
+ ;
+
+// C11 _Generic selection.
+genericSelection
+ : '_Generic' '(' assignmentExpression ',' genericAssocList ')'
+ ;
+
+genericAssocList
+ : genericAssociation
+ | genericAssocList ',' genericAssociation
+ ;
+
+genericAssociation
+ : typeName ':' assignmentExpression
+ | 'default' ':' assignmentExpression
+ ;
+
+// The '(' typeName ')' '{' ... '}' alternatives are C99 compound literals.
+postfixExpression
+ : primaryExpression
+ | postfixExpression '[' expression ']'
+ | postfixExpression '(' argumentExpressionList? ')'
+ | postfixExpression '.' Identifier
+ | postfixExpression '->' Identifier
+ | postfixExpression '++'
+ | postfixExpression '--'
+ | '(' typeName ')' '{' initializerList '}'
+ | '(' typeName ')' '{' initializerList ',' '}'
+ | '__extension__' '(' typeName ')' '{' initializerList '}'
+ | '__extension__' '(' typeName ')' '{' initializerList ',' '}'
+ ;
+
+argumentExpressionList
+ : assignmentExpression
+ | argumentExpressionList ',' assignmentExpression
+ ;
+
+unaryExpression
+ : postfixExpression
+ | '++' unaryExpression
+ | '--' unaryExpression
+ | unaryOperator castExpression
+ | 'sizeof' unaryExpression
+ | 'sizeof' '(' typeName ')'
+ | '_Alignof' '(' typeName ')'
+ | '&&' Identifier // GCC extension address of label
+ ;
+
+unaryOperator
+ : '&' | '*' | '+' | '-' | '~' | '!'
+ ;
+
+castExpression
+ : '(' typeName ')' castExpression
+ | '__extension__' '(' typeName ')' castExpression
+ | unaryExpression
+ | DigitSequence // NOTE(review): inherited with the bare comment "for"; purpose unclear — confirm
+ ;
+
+multiplicativeExpression
+ : castExpression
+ | multiplicativeExpression '*' castExpression
+ | multiplicativeExpression '/' castExpression
+ | multiplicativeExpression '%' castExpression
+ ;
+
+additiveExpression
+ : multiplicativeExpression
+ | additiveExpression '+' multiplicativeExpression
+ | additiveExpression '-' multiplicativeExpression
+ ;
+
+shiftExpression
+ : additiveExpression
+ | shiftExpression '<<' additiveExpression
+ | shiftExpression '>>' additiveExpression
+ ;
+
+relationalExpression
+ : shiftExpression
+ | relationalExpression '<' shiftExpression
+ | relationalExpression '>' shiftExpression
+ | relationalExpression '<=' shiftExpression
+ | relationalExpression '>=' shiftExpression
+ ;
+
+equalityExpression
+ : relationalExpression
+ | equalityExpression '==' relationalExpression
+ | equalityExpression '!=' relationalExpression
+ ;
+
+andExpression
+ : equalityExpression
+ | andExpression '&' equalityExpression
+ ;
+
+exclusiveOrExpression
+ : andExpression
+ | exclusiveOrExpression '^' andExpression
+ ;
+
+inclusiveOrExpression
+ : exclusiveOrExpression
+ | inclusiveOrExpression '|' exclusiveOrExpression
+ ;
+
+logicalAndExpression
+ : inclusiveOrExpression
+ | logicalAndExpression '&&' inclusiveOrExpression
+ ;
+
+logicalOrExpression
+ : logicalAndExpression
+ | logicalOrExpression '||' logicalAndExpression
+ ;
+
+conditionalExpression
+ : logicalOrExpression ('?' expression ':' conditionalExpression)?
+ ;
+
+assignmentExpression
+ : conditionalExpression
+ | unaryExpression assignmentOperator assignmentExpression
+ | DigitSequence // NOTE(review): inherited with the bare comment "for"; purpose unclear — confirm
+ ;
+
+assignmentOperator
+ : '=' | '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+ ;
+
+// Comma expression.
+expression
+ : assignmentExpression
+ | expression ',' assignmentExpression
+ ;
+
+constantExpression
+ : conditionalExpression
+ ;
+
+// ---- Declarations ----
+declaration
+ : declarationSpecifiers initDeclaratorList ';'
+ | declarationSpecifiers ';'
+ | staticAssertDeclaration
+ ;
+
+declarationSpecifiers
+ : declarationSpecifier+
+ ;
+
+// Same shape as declarationSpecifiers; kept as a separate rule for the abstract-declarator
+// alternative of parameterDeclaration, presumably so it yields a distinct parse-tree node — confirm.
+declarationSpecifiers2
+ : declarationSpecifier+
+ ;
+
+declarationSpecifier
+ : storageClassSpecifier
+ | typeSpecifier
+ | typeQualifier
+ | functionSpecifier
+ | alignmentSpecifier
+ ;
+
+initDeclaratorList
+ : initDeclarator
+ | initDeclaratorList ',' initDeclarator
+ ;
+
+initDeclarator
+ : declarator
+ | declarator '=' initializer
+ ;
+
+storageClassSpecifier
+ : 'typedef'
+ | 'extern'
+ | 'static'
+ | '_Thread_local'
+ | 'auto'
+ | 'register'
+ ;
+
+// Besides the C11 type specifiers this accepts SSE vector types, GCC __typeof__, and (last
+// alternative, non-standard) a pointer directly following a type specifier.
+typeSpecifier
+ : ('void'
+ | 'char'
+ | 'short'
+ | 'int'
+ | 'long'
+ | 'float'
+ | 'double'
+ | 'signed'
+ | 'unsigned'
+ | '_Bool'
+ | '_Complex'
+ | '__m128'
+ | '__m128d'
+ | '__m128i')
+ | '__extension__' '(' ('__m128' | '__m128d' | '__m128i') ')'
+ | atomicTypeSpecifier
+ | structOrUnionSpecifier
+ | enumSpecifier
+ | typedefName
+ | '__typeof__' '(' constantExpression ')' // GCC extension
+ | typeSpecifier pointer
+ ;
+
+structOrUnionSpecifier
+ : structOrUnion Identifier? '{' structDeclarationList '}'
+ | structOrUnion Identifier
+ ;
+
+structOrUnion
+ : 'struct'
+ | 'union'
+ ;
+
+structDeclarationList
+ : structDeclaration
+ | structDeclarationList structDeclaration
+ ;
+
+structDeclaration
+ : specifierQualifierList structDeclaratorList? ';'
+ | staticAssertDeclaration
+ ;
+
+specifierQualifierList
+ : typeSpecifier specifierQualifierList?
+ | typeQualifier specifierQualifierList?
+ ;
+
+structDeclaratorList
+ : structDeclarator
+ | structDeclaratorList ',' structDeclarator
+ ;
+
+// The ':' constantExpression form is a bit-field width.
+structDeclarator
+ : declarator
+ | declarator? ':' constantExpression
+ ;
+
+enumSpecifier
+ : 'enum' Identifier? '{' enumeratorList '}'
+ | 'enum' Identifier? '{' enumeratorList ',' '}'
+ | 'enum' Identifier
+ ;
+
+enumeratorList
+ : enumerator
+ | enumeratorList ',' enumerator
+ ;
+
+enumerator
+ : enumerationConstant
+ | enumerationConstant '=' constantExpression
+ ;
+
+enumerationConstant
+ : Identifier
+ ;
+
+atomicTypeSpecifier
+ : '_Atomic' '(' typeName ')'
+ ;
+
+typeQualifier
+ : 'const'
+ | 'restrict'
+ | 'volatile'
+ | '_Atomic'
+ ;
+
+// Includes the MSVC-style '__stdcall' and '__declspec(...)' as well as GCC attributes.
+functionSpecifier
+ : ('inline'
+ | '_Noreturn'
+ | '__inline__' // GCC extension
+ | '__stdcall')
+ | gccAttributeSpecifier
+ | '__declspec' '(' Identifier ')'
+ ;
+
+alignmentSpecifier
+ : '_Alignas' '(' typeName ')'
+ | '_Alignas' '(' constantExpression ')'
+ ;
+
+declarator
+ : pointer? directDeclarator gccDeclaratorExtension*
+ ;
+
+directDeclarator
+ : Identifier
+ | '(' declarator ')'
+ | directDeclarator '[' typeQualifierList? assignmentExpression? ']'
+ | directDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+ | directDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+ | directDeclarator '[' typeQualifierList? '*' ']'
+ | directDeclarator '(' parameterTypeList ')'
+ | directDeclarator '(' identifierList? ')'
+ | Identifier ':' DigitSequence // bit field
+ | '(' typeSpecifier? pointer directDeclarator ')' // function pointer like: (__cdecl *f)
+ ;
+
+// GCC asm labels and __attribute__ lists that may trail a declarator.
+gccDeclaratorExtension
+ : '__asm' '(' StringLiteral+ ')'
+ | gccAttributeSpecifier
+ ;
+
+gccAttributeSpecifier
+ : '__attribute__' '(' '(' gccAttributeList ')' ')'
+ ;
+
+gccAttributeList
+ : gccAttribute (',' gccAttribute)*
+ | // empty
+ ;
+
+gccAttribute
+ : ~(',' | '(' | ')') // relaxed def for "identifier or reserved word"
+ ('(' argumentExpressionList? ')')?
+ | // empty
+ ;
+
+// NOTE(review): not referenced by any rule in the visible part of this grammar.
+nestedParenthesesBlock
+ : ( ~('(' | ')')
+ | '(' nestedParenthesesBlock ')'
+ )*
+ ;
+
+pointer
+ : '*' typeQualifierList?
+ | '*' typeQualifierList? pointer
+ | '^' typeQualifierList? // Blocks language extension
+ | '^' typeQualifierList? pointer // Blocks language extension
+ ;
+
+typeQualifierList
+ : typeQualifier
+ | typeQualifierList typeQualifier
+ ;
+
+// A trailing ',' '...' marks a variadic parameter list.
+parameterTypeList
+ : parameterList
+ | parameterList ',' '...'
+ ;
+
+parameterList
+ : parameterDeclaration
+ | parameterList ',' parameterDeclaration
+ ;
+
+parameterDeclaration
+ : declarationSpecifiers declarator
+ | declarationSpecifiers2 abstractDeclarator?
+ ;
+
+identifierList
+ : Identifier
+ | identifierList ',' Identifier
+ ;
+
+typeName
+ : specifierQualifierList abstractDeclarator?
+ ;
+
+// Declarator without a name, as used in type names and unnamed parameters.
+abstractDeclarator
+ : pointer
+ | pointer? directAbstractDeclarator gccDeclaratorExtension*
+ ;
+
+directAbstractDeclarator
+ : '(' abstractDeclarator ')' gccDeclaratorExtension*
+ | '[' typeQualifierList? assignmentExpression? ']'
+ | '[' 'static' typeQualifierList? assignmentExpression ']'
+ | '[' typeQualifierList 'static' assignmentExpression ']'
+ | '[' '*' ']'
+ | '(' parameterTypeList? ')' gccDeclaratorExtension*
+ | directAbstractDeclarator '[' typeQualifierList? assignmentExpression? ']'
+ | directAbstractDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+ | directAbstractDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+ | directAbstractDeclarator '[' '*' ']'
+ | directAbstractDeclarator '(' parameterTypeList? ')' gccDeclaratorExtension*
+ ;
+
+typedefName
+ : Identifier
+ ;
+
+initializer
+ : assignmentExpression
+ | '{' initializerList '}'
+ | '{' initializerList ',' '}'
+ ;
+
+initializerList
+ : designation? initializer
+ | initializerList ',' designation? initializer
+ ;
+
+// C99 designated initializers: [index] = ... or .field = ...
+designation
+ : designatorList '='
+ ;
+
+designatorList
+ : designator
+ | designatorList designator
+ ;
+
+designator
+ : '[' constantExpression ']'
+ | '.' Identifier
+ ;
+
+staticAssertDeclaration
+ : '_Static_assert' '(' constantExpression ',' StringLiteral+ ')' ';'
+ ;
+
+// ---- Statements ----
+// The last alternative is GCC inline assembly: __asm__ volatile ( code : outputs : inputs ... );
+statement
+ : labeledStatement
+ | compoundStatement
+ | expressionStatement
+ | selectionStatement
+ | iterationStatement
+ | jumpStatement
+ | ('__asm' | '__asm__') ('volatile' | '__volatile__') '(' (logicalOrExpression (',' logicalOrExpression)*)? (':' (logicalOrExpression (',' logicalOrExpression)*)?)* ')' ';'
+ ;
+
+labeledStatement
+ : Identifier ':' statement
+ | 'case' constantExpression ':' statement
+ | 'default' ':' statement
+ ;
+
+compoundStatement
+ : '{' blockItemList? '}'
+ ;
+
+blockItemList
+ : blockItem
+ | blockItemList blockItem
+ ;
+
+blockItem
+ : statement
+ | declaration
+ ;
+
+expressionStatement
+ : expression? ';'
+ ;
+
+selectionStatement
+ : 'if' '(' expression ')' statement ('else' statement)?
+ | 'switch' '(' expression ')' statement
+ ;
+
+// While/Do/For refer to the named keyword tokens defined in the lexer section.
+iterationStatement
+ : While '(' expression ')' statement
+ | Do statement While '(' expression ')' ';'
+ | For '(' forCondition ')' statement
+ ;
+
+// Earlier forms of the 'for' alternative, superseded by forCondition:
+// | 'for' '(' expression? ';' expression? ';' forUpdate? ')' statement
+// | For '(' declaration expression? ';' expression? ')' statement
+
+// Header of a for loop: a declaration or an expression, then the condition and the update.
+forCondition
+ : forDeclaration ';' forExpression? ';' forExpression?
+ | expression? ';' forExpression? ';' forExpression?
+ ;
+
+forDeclaration
+ : declarationSpecifiers initDeclaratorList
+ | declarationSpecifiers
+ ;
+
+forExpression
+ : assignmentExpression
+ | forExpression ',' assignmentExpression
+ ;
+
+jumpStatement
+ : 'goto' Identifier ';'
+ | 'continue' ';'
+ | 'break' ';'
+ | 'return' expression? ';'
+ | 'goto' unaryExpression ';' // GCC extension
+ ;
+
+// ---- Translation unit ----
+// Presumably the start rule (the only visible rule that consumes EOF). It also accepts a closing
+// '}' instead of EOF — presumably to tolerate the end of an extern "C" { ... } wrapper; confirm.
+compilationUnit
+ : translationUnit? ( EOF | '}' )
+ ;
+
+translationUnit
+ : externalDeclaration
+ | translationUnit externalDeclaration
+ ;
+
+externalDeclaration
+ : functionDefinition
+ | declaration
+ | ';' // stray ;
+ ;
+
+// The optional declarationList supports K&R-style parameter declarations.
+functionDefinition
+ : declarationSpecifiers? declarator declarationList? compoundStatement
+ ;
+
+declarationList
+ : declaration
+ | declarationList declaration
+ ;
+
+// C keyword tokens. When a keyword and Identifier match the same number of characters, ANTLR
+// picks the rule defined first, so these keywords take precedence over Identifier (defined
+// later in this grammar). Keep them above Identifier.
+Auto : 'auto';
+Break : 'break';
+Case : 'case';
+Char : 'char';
+Const : 'const';
+Continue : 'continue';
+Default : 'default';
+Do : 'do';
+Double : 'double';
+Else : 'else';
+Enum : 'enum';
+Extern : 'extern';
+Float : 'float';
+For : 'for';
+Goto : 'goto';
+If : 'if';
+Inline : 'inline';
+Int : 'int';
+Long : 'long';
+Register : 'register';
+Restrict : 'restrict';
+Return : 'return';
+Short : 'short';
+Signed : 'signed';
+Sizeof : 'sizeof';
+Static : 'static';
+Struct : 'struct';
+Switch : 'switch';
+Typedef : 'typedef';
+Union : 'union';
+Unsigned : 'unsigned';
+Void : 'void';
+Volatile : 'volatile';
+While : 'while';
+
+// C11 keywords (underscore-prefixed spellings).
+Alignas : '_Alignas';
+Alignof : '_Alignof';
+Atomic : '_Atomic';
+Bool : '_Bool';
+Complex : '_Complex';
+Generic : '_Generic';
+Imaginary : '_Imaginary';
+Noreturn : '_Noreturn';
+StaticAssert : '_Static_assert';
+ThreadLocal : '_Thread_local';
+
+// Punctuator tokens. The lexer always prefers the longest match, so multi-character operators
+// such as '<<=' win over their prefixes ('<<', '<').
+
+// Brackets.
+LeftParen : '(';
+RightParen : ')';
+LeftBracket : '[';
+RightBracket : ']';
+LeftBrace : '{';
+RightBrace : '}';
+
+// Relational and shift operators.
+Less : '<';
+LessEqual : '<=';
+Greater : '>';
+GreaterEqual : '>=';
+LeftShift : '<<';
+RightShift : '>>';
+
+// Arithmetic operators.
+Plus : '+';
+PlusPlus : '++';
+Minus : '-';
+MinusMinus : '--';
+Star : '*';
+Div : '/';
+Mod : '%';
+
+// Bitwise and logical operators.
+And : '&';
+Or : '|';
+AndAnd : '&&';
+OrOr : '||';
+Caret : '^';
+Not : '!';
+Tilde : '~';
+
+// Conditional operator and separators.
+Question : '?';
+Colon : ':';
+Semi : ';';
+Comma : ',';
+
+// Assignment operators.
+Assign : '=';
+// '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+StarAssign : '*=';
+DivAssign : '/=';
+ModAssign : '%=';
+PlusAssign : '+=';
+MinusAssign : '-=';
+LeftShiftAssign : '<<=';
+RightShiftAssign : '>>=';
+AndAssign : '&=';
+XorAssign : '^=';
+OrAssign : '|=';
+
+// Equality operators.
+Equal : '==';
+NotEqual : '!=';
+
+// Member access and variadic ellipsis.
+Arrow : '->';
+Dot : '.';
+Ellipsis : '...';
+
+// Identifier: a letter, underscore or universal character name, followed by any mix of those
+// and decimal digits.
+Identifier
+ : IdentifierNondigit
+ ( IdentifierNondigit
+ | Digit
+ )*
+ ;
+
+// Characters allowed anywhere in an identifier (including its first position).
+fragment
+IdentifierNondigit
+ : Nondigit
+ | UniversalCharacterName
+ //| // other implementation-defined characters...
+ ;
+
+// ASCII letters and underscore.
+fragment
+Nondigit
+ : [a-zA-Z_]
+ ;
+
+fragment
+Digit
+ : [0-9]
+ ;
+
+// Universal character name: \u plus four hex digits, or \U plus eight hex digits.
+fragment
+UniversalCharacterName
+ : '\\u' HexQuad
+ | '\\U' HexQuad HexQuad
+ ;
+
+// Exactly four hexadecimal digits.
+fragment
+HexQuad
+ : HexadecimalDigit HexadecimalDigit HexadecimalDigit HexadecimalDigit
+ ;
+
+// Numeric and character constants. The EnumerationConstant alternative is disabled, so
+// enumerator names are lexed as Identifier.
+Constant
+ : IntegerConstant
+ | FloatingConstant
+ //| EnumerationConstant
+ | CharacterConstant
+ ;
+
+// Integer constant in decimal, octal or hexadecimal form with an optional suffix, or a binary
+// constant. NOTE(review): the binary form accepts no IntegerSuffix here.
+fragment
+IntegerConstant
+ : DecimalConstant IntegerSuffix?
+ | OctalConstant IntegerSuffix?
+ | HexadecimalConstant IntegerSuffix?
+ | BinaryConstant
+ ;
+
+// Binary constant such as 0b1010 (a compiler extension, not part of ISO C11).
+fragment
+BinaryConstant
+ : '0' [bB] [0-1]+
+ ;
+
+// Decimal constant: a non-zero digit followed by any digits.
+fragment
+DecimalConstant
+ : NonzeroDigit Digit*
+ ;
+
+// Octal constant: '0' followed by octal digits. A lone '0' is matched by this rule.
+fragment
+OctalConstant
+ : '0' OctalDigit*
+ ;
+
+// Hexadecimal constant: 0x/0X followed by at least one hex digit.
+fragment
+HexadecimalConstant
+ : HexadecimalPrefix HexadecimalDigit+
+ ;
+
+fragment
+HexadecimalPrefix
+ : '0' [xX]
+ ;
+
+fragment
+NonzeroDigit
+ : [1-9]
+ ;
+
+fragment
+OctalDigit
+ : [0-7]
+ ;
+
+fragment
+HexadecimalDigit
+ : [0-9a-fA-F]
+ ;
+
+// Accepted suffix combinations: u, ul, ull, l, lu, ll, llu (u and l in either case; the
+// long-long part must be 'll' or 'LL').
+fragment
+IntegerSuffix
+ : UnsignedSuffix LongSuffix?
+ | UnsignedSuffix LongLongSuffix
+ | LongSuffix UnsignedSuffix?
+ | LongLongSuffix UnsignedSuffix?
+ ;
+
+fragment
+UnsignedSuffix
+ : [uU]
+ ;
+
+fragment
+LongSuffix
+ : [lL]
+ ;
+
+fragment
+LongLongSuffix
+ : 'll' | 'LL'
+ ;
+
+// Floating constant in decimal or hexadecimal form.
+fragment
+FloatingConstant
+ : DecimalFloatingConstant
+ | HexadecimalFloatingConstant
+ ;
+
+// Decimal floating constant: a fractional constant with an optional exponent, or a digit
+// sequence with a mandatory exponent; either may carry an f/l suffix.
+fragment
+DecimalFloatingConstant
+ : FractionalConstant ExponentPart? FloatingSuffix?
+ | DigitSequence ExponentPart FloatingSuffix?
+ ;
+
+// Hexadecimal floating constant; the binary exponent (p/P) is mandatory.
+fragment
+HexadecimalFloatingConstant
+ : HexadecimalPrefix HexadecimalFractionalConstant BinaryExponentPart FloatingSuffix?
+ | HexadecimalPrefix HexadecimalDigitSequence BinaryExponentPart FloatingSuffix?
+ ;
+
+// Digits containing a decimal point: "1.5", ".5" or "1.".
+fragment
+FractionalConstant
+ : DigitSequence? '.' DigitSequence
+ | DigitSequence '.'
+ ;
+
+// Decimal exponent: e or E, optional sign, decimal digits.
+fragment
+ExponentPart
+ : 'e' Sign? DigitSequence
+ | 'E' Sign? DigitSequence
+ ;
+
+fragment
+Sign
+ : '+' | '-'
+ ;
+
+// NOTE(review): unlike the surrounding rules this one is not a fragment, so a run of digits is
+// emitted as a standalone DigitSequence token whenever it matches more text than Constant
+// (e.g. "089", where OctalConstant only matches "0"). Confirm this is intended.
+DigitSequence
+ : Digit+
+ ;
+
+// Hexadecimal digits containing a point: "1.8", ".8" or "1.".
+fragment
+HexadecimalFractionalConstant
+ : HexadecimalDigitSequence? '.' HexadecimalDigitSequence
+ | HexadecimalDigitSequence '.'
+ ;
+
+// Binary exponent of a hexadecimal floating constant: p or P, optional sign, decimal digits.
+fragment
+BinaryExponentPart
+ : 'p' Sign? DigitSequence
+ | 'P' Sign? DigitSequence
+ ;
+
+fragment
+HexadecimalDigitSequence
+ : HexadecimalDigit+
+ ;
+
+fragment
+FloatingSuffix
+ : 'f' | 'l' | 'F' | 'L'
+ ;
+
+// Character constant, optionally prefixed with L, u or U. Because CCharSequence allows more
+// than one character, multi-character constants such as 'ab' are also accepted.
+fragment
+CharacterConstant
+ : '\'' CCharSequence '\''
+ | 'L\'' CCharSequence '\''
+ | 'u\'' CCharSequence '\''
+ | 'U\'' CCharSequence '\''
+ ;
+
+fragment
+CCharSequence
+ : CChar+
+ ;
+
+// Any character except single quote, backslash or a line break, or an escape sequence.
+fragment
+CChar
+ : ~['\\\r\n]
+ | EscapeSequence
+ ;
+
+fragment
+EscapeSequence
+ : SimpleEscapeSequence
+ | OctalEscapeSequence
+ | HexadecimalEscapeSequence
+ | UniversalCharacterName
+ ;
+
+// Backslash followed by one of: ' " ? a b f n r t v \
+fragment
+SimpleEscapeSequence
+ : '\\' ['"?abfnrtv\\]
+ ;
+
+// Backslash followed by one to three octal digits.
+fragment
+OctalEscapeSequence
+ : '\\' OctalDigit
+ | '\\' OctalDigit OctalDigit
+ | '\\' OctalDigit OctalDigit OctalDigit
+ ;
+
+// \x followed by one or more hexadecimal digits.
+fragment
+HexadecimalEscapeSequence
+ : '\\x' HexadecimalDigit+
+ ;
+
+// String literal with an optional encoding prefix (u8, u, U or L); the body may be empty.
+StringLiteral
+ : EncodingPrefix? '"' SCharSequence? '"'
+ ;
+
+fragment
+EncodingPrefix
+ : 'u8'
+ | 'u'
+ | 'U'
+ | 'L'
+ ;
+
+fragment
+SCharSequence
+ : SChar+
+ ;
+
+// Any character except double quote, backslash or a line break; an escape sequence; or a
+// backslash-newline line continuation (LF or CRLF).
+fragment
+SChar
+ : ~["\\\r\n]
+ | EscapeSequence
+ | '\\\n' // Added line
+ | '\\\r\n' // Added line
+ ;
+
+// Preprocessor text is discarded: a '#' and the rest of its line are skipped. No macro
+// expansion happens and #include directives are not followed.
+// NOTE(review): a directive continued with a trailing backslash only has its first line
+// skipped; the continuation lines are lexed as ordinary C.
+LineAfterPreprocessing
+ : '#' ~[\r\n]*
+ -> skip
+ ;
+
+// Spaces and tabs, plus header-specific text that is dropped before parsing:
+// - the macro names MXNET_DLL and NNVM_DLL (presumably export/visibility macros);
+// - the opening `extern "C" {` (its closing '}' is accepted by compilationUnit);
+// - DEFAULT(...) annotations, optionally with whitespace before '('.
+// NOTE(review): Identifier is defined earlier and wins ties of equal length, so a macro name is
+// skipped here only when this rule matches more text, e.g. when it is followed by a space or tab.
+// NOTE(review): the non-greedy `.*?` stops at the first ')', so DEFAULT(...) arguments that
+// contain nested parentheses are cut short.
+Whitespace
+ : ( [ \t]
+ | 'MXNET_DLL'
+ | 'NNVM_DLL'
+ | 'extern "C" {'
+ | 'DEFAULT' Whitespace? '(' .*? ')'
+ )+
+ -> skip
+ ;
+
+// Line breaks (CR, LF or CRLF) are skipped.
+Newline
+ : ( '\r' '\n'?
+ | '\n'
+ )
+ -> skip
+ ;
+
+// Block and line comments are skipped.
+BlockComment
+ : '/*' .*? '*/'
+ -> skip
+ ;
+
+LineComment
+ : '//' ~[\r\n]*
+ -> skip
+ ;
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
new file mode 100644
index 0000000..3a2b5e0
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/** Static helpers for inspecting parse trees produced by {@link CParser}. */
+public final class AntlrUtils {
+
+    private AntlrUtils() {}
+
+    /**
+     * Returns whether the first declaration specifier is the {@code typedef} storage class.
+     *
+     * @param specs the declaration specifiers of a declaration; may be {@code null}
+     * @return {@code true} if the declaration is a typedef
+     */
+    public static boolean isTypeDef(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.StorageClassSpecifierContext storage = spec.storageClassSpecifier();
+        return storage != null && storage.Typedef() != null;
+    }
+
+    /**
+     * Returns the text of every declaration specifier after the first, joined by single spaces.
+     *
+     * <p>For a typedef the first specifier is the {@code typedef} keyword itself, so the result
+     * is the text of the remaining specifiers.
+     *
+     * @param specs the declaration specifiers of a declaration
+     * @return the joined text, or an empty string when there is at most one specifier
+     */
+    public static String getTypeDefValue(CParser.DeclarationSpecifiersContext specs) {
+        List<String> list = new ArrayList<>();
+        for (int i = 1; i < specs.getChildCount(); ++i) {
+            list.add(specs.getChild(i).getText());
+        }
+        return String.join(" ", list);
+    }
+
+    /**
+     * Returns whether the first declaration specifier is an {@code enum} type specifier.
+     *
+     * @param specs the declaration specifiers of a declaration; may be {@code null}
+     * @return {@code true} if the declaration starts with an enum specifier
+     */
+    public static boolean isEnum(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        return type != null && type.enumSpecifier() != null;
+    }
+
+    /**
+     * Returns whether the first declaration specifier is a {@code struct} or {@code union}
+     * type specifier.
+     *
+     * @param specs the declaration specifiers of a declaration; may be {@code null}
+     * @return {@code true} if the declaration starts with a struct or union specifier
+     */
+    public static boolean isStructOrUnion(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        return type != null && type.structOrUnionSpecifier() != null;
+    }
+
+    /**
+     * Renders a parse tree as a JSON-like string for debugging.
+     *
+     * <p>Each rule node becomes {@code "ContextClassName" : {children}} and each terminal becomes
+     * {@code "v" : "text"}, with children separated by commas. Terminal text is not escaped.
+     *
+     * @param tree the tree to render
+     * @return the rendered text
+     */
+    public static String getText(ParseTree tree) {
+        StringBuilder sb = new StringBuilder();
+        getText(sb, tree);
+        return sb.toString();
+    }
+
+    private static void getText(StringBuilder sb, ParseTree tree) {
+        if (tree instanceof TerminalNode) {
+            sb.append("\"v\" : \"").append(tree.getText()).append('"');
+            return;
+        }
+        sb.append('"');
+        sb.append(tree.getClass().getSimpleName()).append("\" : {");
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            getText(sb, tree.getChild(i));
+            if (i < tree.getChildCount() - 1) {
+                sb.append(',');
+            }
+        }
+        sb.append('}');
+    }
+
+    /**
+     * Converts a snake_case name to UpperCamelCase, e.g. {@code "foo_bar"} to {@code "FooBar"}.
+     *
+     * <p>Empty segments produced by leading, trailing or repeated underscores are dropped, so
+     * {@code "_foo__bar"} becomes {@code "FooBar"} and an empty name yields an empty string.
+     *
+     * @param name the snake_case name
+     * @return the camel-cased name
+     */
+    public static String toCamelCase(String name) {
+        StringBuilder sb = new StringBuilder(name.length());
+        for (String token : name.split("_")) {
+            if (token.isEmpty()) {
+                // Leading or doubled underscores yield empty tokens; charAt(0) would throw.
+                continue;
+            }
+            sb.append(Character.toUpperCase(token.charAt(0))).append(token, 1, token.length());
+        }
+        return sb.toString();
+    }
+
+    /**
+     * Returns the first child of {@code specs} if it is a declaration specifier, else null.
+     *
+     * <p>{@code RuleContext.isEmpty()} only reports whether a context has no invoking state; it
+     * does not mean "has no children". The child count is therefore checked explicitly.
+     */
+    private static CParser.DeclarationSpecifierContext firstSpecifier(
+            CParser.DeclarationSpecifiersContext specs) {
+        if (specs == null || specs.getChildCount() == 0) {
+            return null;
+        }
+        ParseTree first = specs.getChild(0);
+        return first instanceof CParser.DeclarationSpecifierContext
+                ? (CParser.DeclarationSpecifierContext) first
+                : null;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
new file mode 100644
index 0000000..3b1c1fe
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/**
+ * A C data type parsed from a header: a type name, a {@code const} flag, a pointer depth and a
+ * function-pointer flag. {@link #map(Map, Set)} converts it to the Java type name used in the
+ * generated JNA bindings.
+ */
+public class DataType {
+
+    private boolean isConst;
+    private boolean functionPointer;
+    private int pointerCount;
+    private StringBuilder type = new StringBuilder(); // NOPMD
+
+    /**
+     * Returns whether the type carries a {@code const} qualifier.
+     *
+     * @return {@code true} if the type is {@code const}
+     */
+    public boolean isConst() {
+        return isConst;
+    }
+
+    /** Marks the type as {@code const}. */
+    public void setConst() {
+        isConst = true;
+    }
+
+    /**
+     * Returns the function-pointer flag.
+     *
+     * @return {@code true} if the type was flagged as a function pointer
+     */
+    public boolean isFunctionPointer() {
+        return functionPointer;
+    }
+
+    /**
+     * Sets the function-pointer flag.
+     *
+     * @param functionPointer the new flag value
+     */
+    public void setFunctionPointer(boolean functionPointer) {
+        this.functionPointer = functionPointer;
+    }
+
+    /**
+     * Returns the number of pointer levels ({@code *}) applied to the type.
+     *
+     * @return the pointer depth
+     */
+    public int getPointerCount() {
+        return pointerCount;
+    }
+
+    /**
+     * Sets the number of pointer levels.
+     *
+     * @param pointerCount the pointer depth
+     */
+    public void setPointerCount(int pointerCount) {
+        this.pointerCount = pointerCount;
+    }
+
+    /** Adds one pointer level. */
+    public void increasePointerCount() {
+        ++pointerCount;
+    }
+
+    /**
+     * Returns the type name text.
+     *
+     * @return the type name
+     */
+    public String getType() {
+        return type.toString();
+    }
+
+    /**
+     * Replaces the type name text.
+     *
+     * @param typeName the new type name
+     */
+    public void setType(String typeName) {
+        type.setLength(0);
+        type.append(typeName);
+    }
+
+    /**
+     * Appends a word to the type name, inserting a space when the name is already non-empty
+     * (e.g. {@code "unsigned"} then {@code "int"} gives {@code "unsigned int"}).
+     *
+     * @param name the word to append
+     */
+    public void appendTypeName(String name) {
+        if (type.length() > 0) {
+            type.append(' ');
+        }
+        type.append(name);
+    }
+
+    /**
+     * Maps this C type to the Java type name used in the generated JNA bindings.
+     *
+     * <p>If the type name is a non-callback typedef in {@code map}, it is first replaced by the
+     * typedef value. Any {@code const} word in the value sets the const flag, and each {@code *}
+     * adds a pointer level. NOTE: this rewrites this object's type name, const flag and pointer
+     * count in place, so calling {@code map} again may resolve a further typedef level.
+     *
+     * @param map typedef name to definition
+     * @param structs names of known structs; a single pointer to one of them maps to
+     *     {@code <Name>.ByReference}
+     * @return the Java type name
+     */
+    public String map(Map<String, TypeDefine> map, Set<String> structs) {
+        String typeName = type.toString().trim();
+        TypeDefine typeDefine = map.get(typeName);
+        boolean isStruct = structs.contains(typeName);
+        if (typeDefine != null && !typeDefine.isCallBack()) {
+            typeName = typeDefine.getValue();
+
+            // Literal (non-regex) removal of const qualifiers; any removal marks the type const.
+            String mapped = typeName.replace("const ", "").replace(" const", "");
+            if (mapped.length() < typeName.length()) {
+                isConst = true;
+            }
+            typeName = mapped;
+            // Every '*' in the typedef value is one more pointer level.
+            mapped = typeName.replace("*", "");
+            pointerCount += typeName.length() - mapped.length();
+            typeName = mapped;
+            setType(typeName);
+        }
+
+        if (pointerCount > 2) {
+            return "PointerByReference";
+        }
+
+        typeName = baseTypeMapping(typeName);
+
+        if (pointerCount == 2) {
+            // const char ** is passed as an array of Java strings.
+            if (isConst && "char".equals(typeName)) {
+                return "String[]";
+            }
+            return "PointerByReference";
+        }
+
+        if (pointerCount == 1) {
+            switch (typeName) {
+                case "byte":
+                    return "ByteBuffer";
+                case "NativeSize":
+                    return "NativeSizeByReference";
+                case "int":
+                    if (isConst) {
+                        return "int[]";
+                    }
+                    return "IntBuffer";
+                case "long":
+                    if (isConst) {
+                        return "long[]";
+                    }
+                    return "LongBuffer";
+                case "char":
+                    if (isConst) {
+                        return "String";
+                    }
+                    return "ByteBuffer";
+                case "float":
+                    return "FloatBuffer";
+                case "void":
+                    return "Pointer";
+                default:
+                    if (isStruct) {
+                        return typeName + ".ByReference";
+                    }
+                    return "Pointer";
+            }
+        }
+        return typeName;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        if (isConst) {
+            sb.append("const ");
+        }
+        sb.append(type);
+        if (pointerCount > 0) {
+            sb.append(' ');
+            for (int i = 0; i < pointerCount; ++i) {
+                sb.append('*');
+            }
+        }
+        return sb.toString();
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Only the type name text is compared; the const flag and pointer depth are ignored.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        DataType dataType = (DataType) o;
+        return type.toString().equals(dataType.type.toString());
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        // StringBuilder does not override hashCode(); hash the text so that objects that are
+        // equal() (same type name) also share a hash code.
+        return Objects.hash(type.toString());
+    }
+
+    /**
+     * Builds a DataType by walking a type-specifier parse tree.
+     *
+     * @param tree the tree to walk; {@code null} yields an empty DataType
+     * @return the parsed type
+     */
+    static DataType parse(ParseTree tree) {
+        DataType dataType = new DataType();
+        parseTypeSpec(dataType, tree);
+        return dataType;
+    }
+
+    /**
+     * Splits a list of declaration specifiers into data types.
+     *
+     * <p>Type qualifiers are folded into the DataType currently being built ({@code const} sets
+     * the flag, other qualifiers are appended to the name). Every other specifier completes the
+     * current DataType and starts a new one. A specifier without a type specifier (e.g. a
+     * storage class) therefore yields a DataType with an empty name.
+     *
+     * @param list the declaration specifiers, in source order
+     * @return the collected data types
+     */
+    static List<DataType> parseDataTypes(List<CParser.DeclarationSpecifierContext> list) {
+        List<DataType> ret = new ArrayList<>();
+        DataType dataType = new DataType();
+        for (CParser.DeclarationSpecifierContext spec : list) {
+            CParser.TypeQualifierContext qualifier = spec.typeQualifier();
+            if (qualifier != null) {
+                String qualifierName = qualifier.getText();
+                if ("const".equals(qualifierName)) {
+                    dataType.setConst();
+                } else {
+                    dataType.appendTypeName(qualifierName);
+                }
+                continue;
+            }
+
+            CParser.TypeSpecifierContext type = spec.typeSpecifier();
+            parseTypeSpec(dataType, type);
+            ret.add(dataType);
+            dataType = new DataType();
+        }
+
+        return ret;
+    }
+
+    /**
+     * Walks {@code tree} depth-first and accumulates its type information into {@code dataType}.
+     *
+     * <p>struct/union keywords are ignored. A typedef name is used only if no name has been
+     * collected yet. Terminals are classified as {@code const}, {@code *} (one pointer level) or
+     * a type-name word.
+     */
+    private static void parseTypeSpec(DataType dataType, ParseTree tree) {
+        if (tree == null) {
+            return;
+        }
+
+        if (tree instanceof CParser.StructOrUnionContext) {
+            return;
+        }
+        if (tree instanceof CParser.TypedefNameContext) {
+            if (dataType.getType().isEmpty()) {
+                dataType.appendTypeName(tree.getText());
+            }
+            return;
+        }
+
+        if (tree instanceof TerminalNode) {
+            String value = tree.getText();
+            if ("const".equals(value)) {
+                dataType.setConst();
+            } else if ("*".equals(value)) {
+                dataType.increasePointerCount();
+            } else {
+                dataType.appendTypeName(value);
+            }
+            return;
+        }
+
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            parseTypeSpec(dataType, tree.getChild(i));
+        }
+    }
+
+    /** Maps a C base type name to its Java/JNA counterpart; unknown names pass through. */
+    private static String baseTypeMapping(String type) {
+        switch (type) {
+            case "uint64_t":
+            case "int64_t":
+            case "long":
+                return "long";
+            case "uint32_t":
+            case "unsigned int":
+            case "unsigned":
+            case "int":
+                return "int";
+            case "bool":
+                return "byte";
+            case "size_t":
+                return "NativeSize";
+            case "char":
+            case "void":
+            case "float":
+            default:
+                return type;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
new file mode 100644
index 0000000..3cdc9c9
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/** A C function declaration: return type, name and parameter list. */
+public class FuncInfo {
+
+    private String name;
+    private DataType returnType;
+    private List<Parameter> parameters = new ArrayList<>();
+
+    /**
+     * Returns the function name.
+     *
+     * @return the name, or {@code null} if not set
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Sets the function name.
+     *
+     * @param name the function name
+     */
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    /**
+     * Returns the declared return type.
+     *
+     * @return the return type, or {@code null} if not set
+     */
+    public DataType getReturnType() {
+        return returnType;
+    }
+
+    /**
+     * Sets the return type.
+     *
+     * @param returnType the return type
+     */
+    public void setReturnType(DataType returnType) {
+        this.returnType = returnType;
+    }
+
+    /**
+     * Returns the parameter list (the live, mutable list).
+     *
+     * @return the parameters
+     */
+    public List<Parameter> getParameters() {
+        return parameters;
+    }
+
+    /**
+     * Replaces the parameter list.
+     *
+     * @param parameters the new parameters
+     */
+    public void setParameters(List<Parameter> parameters) {
+        this.parameters = parameters;
+    }
+
+    /**
+     * Appends a parameter.
+     *
+     * @param parameter the parameter to add
+     */
+    public void addParameter(Parameter parameter) {
+        parameters.add(parameter);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        sb.append(returnType).append(' ').append(name).append('(');
+        if (parameters != null) {
+            boolean first = true;
+            for (Parameter param : parameters) {
+                if (first) {
+                    first = false;
+                } else {
+                    sb.append(", ");
+                }
+                sb.append(param);
+            }
+        }
+
+        sb.append(");");
+        return sb.toString();
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Two declarations are equal when their names and parameter lists are equal; the return
+     * type is not compared.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        FuncInfo funcInfo = (FuncInfo) o;
+        // Null-safe: name is unset on a fresh instance and parameters may be replaced with null.
+        return Objects.equals(name, funcInfo.name)
+                && Objects.equals(parameters, funcInfo.parameters);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(name);
+    }
+
+    /**
+     * Builds a FuncInfo from a function declaration.
+     *
+     * <p>The first DataType from the declaration specifiers becomes the return type. When the
+     * specifiers yield a second DataType, its type text is taken as the function name, and the
+     * declarator is read as a single typed parameter. Otherwise the declarator supplies the name
+     * and the parameter list. NOTE(review): the two branches reflect different parse shapes
+     * produced by this grammar; confirm against the grammar's directDeclarator rule.
+     *
+     * @param ctx the declaration context
+     * @return the parsed function info
+     */
+    static FuncInfo parse(CParser.DeclarationContext ctx) {
+        FuncInfo info = new FuncInfo();
+
+        List<CParser.DeclarationSpecifierContext> specs =
+                ctx.declarationSpecifiers().declarationSpecifier();
+        List<DataType> dataTypes = DataType.parseDataTypes(specs);
+        info.setReturnType(dataTypes.get(0));
+        if (dataTypes.size() > 1) {
+            info.setName(dataTypes.get(1).getType());
+        }
+
+        CParser.InitDeclaratorContext init = ctx.initDeclaratorList().initDeclarator();
+        CParser.DirectDeclaratorContext declarator = init.declarator().directDeclarator();
+
+        CParser.DirectDeclaratorContext name = declarator.directDeclarator();
+        if (info.getName() == null) {
+            info.setName(name.getText());
+            CParser.ParameterTypeListContext paramListCtx = declarator.parameterTypeList();
+            if (paramListCtx != null) {
+                Parameter.parseParams(info.getParameters(), paramListCtx);
+            }
+        } else {
+            DataType dataType = new DataType();
+            CParser.TypeSpecifierContext type = declarator.typeSpecifier();
+            dataType.appendTypeName(type.getText());
+            if (declarator.pointer() != null) {
+                dataType.increasePointerCount();
+            }
+            Parameter param = new Parameter(dataType, name.getText());
+            info.addParameter(param);
+        }
+
+        return info;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
new file mode 100644
index 0000000..7be963d
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+
+/**
+ * Writes JNA binding sources for a parsed C header into a Java package directory. The
+ * generated sources are structure classes, the library interface and the {@code size_t}
+ * helper types.
+ *
+ * <p>{@link #init(String)} must be called before any of the {@code write*} methods.
+ */
+public class JnaGenerator {
+
+    private Path dir;
+    private final String packageName;
+    private final String libName;
+    private String className;
+    private final Map<String, TypeDefine> typedefMap;
+    private final Set<String> structs;
+    private final Properties mapping;
+
+    /**
+     * Creates a generator.
+     *
+     * @param libName native library name; its camel-cased form plus {@code "Library"} names the
+     *     generated interface
+     * @param packageName Java package of the generated sources
+     * @param typedefMap typedef name to definition, used for type mapping and callbacks
+     * @param structs names of known structs
+     * @param mapping Java type overrides, keyed by function/callback name or by
+     *     {@code "<name>.<param>"}
+     */
+    public JnaGenerator(
+            String libName,
+            String packageName,
+            Map<String, TypeDefine> typedefMap,
+            Set<String> structs,
+            Properties mapping) {
+        this.libName = libName;
+        this.packageName = packageName;
+        this.typedefMap = typedefMap;
+        this.structs = structs;
+        this.mapping = mapping;
+    }
+
+    /**
+     * Creates the package directory under {@code output} and derives the interface name.
+     *
+     * @param output root source directory
+     * @throws IOException if the directory cannot be created
+     */
+    public void init(String output) throws IOException {
+        String[] tokens = packageName.split("\\.");
+        dir = Paths.get(output, tokens);
+        Files.createDirectories(dir);
+        className = AntlrUtils.toCamelCase(libName) + "Library";
+    }
+
+    /**
+     * Writes one JNA {@code Structure} subclass per entry of {@code structMap}.
+     *
+     * <p>Each generated class contains the public fields, constructors, getFieldOrder(),
+     * accessors, ByReference/ByValue subclasses and one Callback interface per
+     * function-pointer field.
+     *
+     * @param structMap struct name to its fields, in declaration order
+     * @throws IOException if a file cannot be written
+     */
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    public void writeStructure(Map<String, List<TypeDefine>> structMap) throws IOException {
+        for (Map.Entry<String, List<TypeDefine>> entry : structMap.entrySet()) {
+            String name = entry.getKey();
+            Path path = dir.resolve(name + ".java");
+            try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+                writer.append("package ").append(packageName).append(";\n\n");
+
+                Set<String> importSet = new HashSet<>();
+                importSet.add("com.sun.jna.Pointer");
+                importSet.add("com.sun.jna.Structure");
+                importSet.add("java.util.List");
+
+                // Field name -> Java field type, in declaration order.
+                Map<String, String> fieldNames = new LinkedHashMap<>();
+                // Callback field name -> Java return type of its generated apply() method.
+                Map<String, String> callbackReturnTypes = new LinkedHashMap<>();
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    String fieldName = typeDefine.getValue();
+                    String typeName;
+                    if (typeDefine.isCallBack()) {
+                        typeName = AntlrUtils.toCamelCase(fieldName) + "Callback";
+                        importSet.add("com.sun.jna.Callback");
+                        List<Parameter> params = typeDefine.getParameters();
+                        if (params != null) {
+                            for (Parameter param : params) {
+                                String type = param.getType().map(typedefMap, structs);
+                                addImports(importSet, type);
+                                // writeParameters() prefers a mapping override; import it too.
+                                String override =
+                                        mapping.getProperty(fieldName + '.' + param.getName());
+                                if (override != null) {
+                                    addImports(importSet, override);
+                                }
+                            }
+                        }
+                        // Resolve the return type once here so it can be imported; it was
+                        // previously only resolved while writing, after imports were emitted.
+                        String returnType = mapping.getProperty(typeName);
+                        if (returnType == null) {
+                            returnType = typeDefine.getDataType().map(typedefMap, structs);
+                        }
+                        addImports(importSet, returnType);
+                        callbackReturnTypes.put(fieldName, returnType);
+                    } else {
+                        typeName = typeDefine.getDataType().map(typedefMap, structs);
+                        addImports(importSet, typeName);
+                    }
+                    fieldNames.put(fieldName, typeName);
+                }
+
+                int fieldCount = fieldNames.size();
+                if (fieldCount < 2) {
+                    importSet.add("java.util.Collections");
+                } else {
+                    importSet.add("java.util.Arrays");
+                }
+
+                List<String> imports = new ArrayList<>(importSet.size());
+                imports.addAll(importSet);
+                Collections.sort(imports);
+                for (String imp : imports) {
+                    writer.append("import ").append(imp).append(";\n");
+                }
+
+                writer.append("\npublic class ").append(name).append(" extends Structure {\n");
+                if (fieldCount > 0) {
+                    writer.write("\n");
+                }
+                for (Map.Entry<String, String> field : fieldNames.entrySet()) {
+                    writer.append(" public ").append(field.getValue()).append(' ');
+                    writer.append(field.getKey()).append(";\n");
+                }
+
+                writer.append("\n public ").append(name).append("() {\n");
+                writer.append(" }\n");
+                writer.append("\n public ").append(name).append("(Pointer peer) {\n");
+                writer.append(" super(peer);\n");
+                writer.append(" }\n");
+
+                writer.append("\n @Override\n");
+                writer.append(" protected List<String> getFieldOrder() {\n");
+                switch (fieldNames.size()) {
+                    case 0:
+                        writer.append(" return Collections.emptyList();\n");
+                        break;
+                    case 1:
+                        writer.append(" return Collections.singletonList(");
+                        String firstField = fieldNames.keySet().iterator().next();
+                        writer.append('"').append(firstField).append("\");\n");
+                        break;
+                    default:
+                        writer.append(" return Arrays.asList(");
+                        boolean first = true;
+                        for (String fieldName : fieldNames.keySet()) {
+                            if (first) {
+                                first = false;
+                            } else {
+                                writer.write(", ");
+                            }
+                            writer.append('"').append(fieldName).append('"');
+                        }
+                        writer.append(");\n");
+                        break;
+                }
+                writer.append(" }\n");
+
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    String fieldName = typeDefine.getValue();
+                    String typeName = fieldNames.get(fieldName);
+                    String getterName;
+                    if (!typeDefine.isCallBack()) {
+                        getterName = AntlrUtils.toCamelCase(fieldName);
+                    } else {
+                        // Callback accessors are named after the callback type, e.g. setFooCallback.
+                        getterName = typeName;
+                    }
+
+                    writer.append("\n public void set").append(getterName).append('(');
+                    writer.append(typeName).append(' ').append(fieldName).append(") {\n");
+                    writer.append(" this.").append(fieldName).append(" = ");
+                    writer.append(fieldName).append(";\n");
+                    writer.append(" }\n");
+                    writer.append("\n public ").append(typeName).append(" get");
+                    writer.append(getterName).append("() {\n");
+                    writer.append(" return ").append(fieldName).append(";\n");
+                    writer.append(" }\n");
+                }
+
+                writer.append("\n public static final class ByReference extends ");
+                writer.append(name).append(" implements Structure.ByReference {}\n");
+
+                writer.append("\n public static final class ByValue extends ");
+                writer.append(name).append(" implements Structure.ByValue {}\n");
+
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    if (typeDefine.isCallBack()) {
+                        String fieldName = typeDefine.getValue();
+                        String callbackName = fieldNames.get(fieldName);
+                        String returnType = callbackReturnTypes.get(fieldName);
+
+                        writer.append("\n public interface ").append(callbackName);
+                        writer.append(" extends Callback {\n");
+                        writer.append(" ").append(returnType).append(" apply(");
+                        writeParameters(writer, fieldName, typeDefine.getParameters());
+                        writer.append(");\n");
+                        writer.append(" }\n");
+                    }
+                }
+
+                writer.append("}\n");
+            }
+        }
+    }
+
+    /**
+     * Writes the library interface containing enums, typedef'd callback interfaces and one
+     * method per native function.
+     *
+     * @param functions the functions to declare
+     * @param enumMap enum name to its constant names, in declaration order
+     * @throws IOException if the file cannot be written
+     */
+    public void writeLibrary(Collection<FuncInfo> functions, Map<String, List<String>> enumMap)
+            throws IOException {
+        try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve(className + ".java"))) {
+            writer.append("package ").append(packageName).append(";\n\n");
+
+            // Fixed import list covering the Java types DataType.map() can produce.
+            writer.append("import com.sun.jna.Callback;\n");
+            writer.append("import com.sun.jna.Library;\n");
+            writer.append("import com.sun.jna.Pointer;\n");
+            writer.append("import com.sun.jna.ptr.PointerByReference;\n");
+            writer.append("import java.nio.ByteBuffer;\n");
+            writer.append("import java.nio.FloatBuffer;\n");
+            writer.append("import java.nio.IntBuffer;\n");
+            writer.append("import java.nio.LongBuffer;\n");
+
+            writer.append("\npublic interface ").append(className).append(" extends Library {\n\n");
+
+            for (Map.Entry<String, List<String>> entry : enumMap.entrySet()) {
+                String name = entry.getKey();
+                writer.append("\n enum ").append(name).append(" {\n");
+                List<String> fields = entry.getValue();
+                int count = 0;
+                for (String field : fields) {
+                    writer.append(" ").append(field);
+                    if (++count < fields.size()) {
+                        writer.append(',');
+                    }
+                    writer.append('\n');
+                }
+                writer.append(" }\n");
+            }
+
+            for (TypeDefine typeDefine : typedefMap.values()) {
+                if (typeDefine.isCallBack()) {
+                    // For typedef'd callbacks the DataType holds the callback name and the
+                    // typedef value holds the return type.
+                    String callbackName = typeDefine.getDataType().getType();
+                    String returnType = mapping.getProperty(callbackName);
+                    if (returnType == null) {
+                        returnType = typeDefine.getValue();
+                    }
+                    writer.append("\n interface ").append(callbackName);
+                    writer.append(" extends Callback {\n");
+                    writer.append(" ").append(returnType).append(" apply(");
+                    writeParameters(writer, callbackName, typeDefine.getParameters());
+                    writer.append(");\n");
+                    writer.append(" }\n");
+                }
+            }
+
+            for (FuncInfo info : functions) {
+                writeFunction(writer, info);
+            }
+            writer.append("}\n");
+        }
+    }
+
+    /**
+     * Writes {@code NativeSize} (a {@code size_t}-sized IntegerType) and
+     * {@code NativeSizeByReference}.
+     *
+     * @throws IOException if a file cannot be written
+     */
+    public void writeNativeSize() throws IOException {
+        try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve("NativeSize.java"))) {
+            writer.append("package ").append(packageName).append(";\n\n");
+            writer.append("import com.sun.jna.IntegerType;\n");
+            writer.append("import com.sun.jna.Native;\n\n");
+
+            writer.append("public class NativeSize extends IntegerType {\n\n");
+            writer.append(" private static final long serialVersionUID = 1L;\n\n");
+            writer.append(" public static final int SIZE = Native.SIZE_T_SIZE;\n\n");
+            writer.append(" public NativeSize() {\n");
+            writer.append(" this(0);\n");
+            writer.append(" }\n\n");
+            writer.append(" public NativeSize(long value) {\n");
+            writer.append(" super(SIZE, value);\n");
+            writer.append(" }\n");
+            writer.append("}\n");
+        }
+
+        Path path = dir.resolve("NativeSizeByReference.java");
+        try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+            writer.append("package ").append(packageName).append(";\n\n");
+            writer.append("import com.sun.jna.ptr.ByReference;\n\n");
+            writer.append("public class NativeSizeByReference extends ByReference {\n\n");
+            writer.append(" public NativeSizeByReference() {\n");
+            writer.append(" this(new NativeSize(0));\n");
+            writer.append(" }\n\n");
+            writer.append(" public NativeSizeByReference(NativeSize value) {\n");
+            writer.append(" super(NativeSize.SIZE);\n");
+            writer.append(" setValue(value);\n");
+            writer.append(" }\n\n");
+            writer.append(" public void setValue(NativeSize value) {\n");
+            writer.append(" if (NativeSize.SIZE == 4) {\n");
+            writer.append(" getPointer().setInt(0, value.intValue());\n");
+            writer.append(" } else if (NativeSize.SIZE == 8) {\n");
+            writer.append(" getPointer().setLong(0, value.longValue());\n");
+            writer.append(" } else {\n");
+            writer.append(
+                    " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+            writer.append(" }\n");
+            writer.append(" }\n\n");
+            writer.append(" public NativeSize getValue() {\n");
+            writer.append(" if (NativeSize.SIZE == 4) {\n");
+            writer.append(" return new NativeSize(getPointer().getInt(0));\n");
+            writer.append(" } else if (NativeSize.SIZE == 8) {\n");
+            writer.append(" return new NativeSize(getPointer().getLong(0));\n");
+            writer.append(" } else {\n");
+            writer.append(
+                    " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+            writer.append(" }\n");
+            writer.append(" }\n");
+            writer.append("}\n");
+        }
+    }
+
+    /** Writes one interface method; a mapping entry keyed by the function name overrides the return type. */
+    private void writeFunction(BufferedWriter writer, FuncInfo info) throws IOException {
+        String funcName = info.getName();
+        String returnType = mapping.getProperty(funcName);
+        if (returnType == null) {
+            returnType = info.getReturnType().map(typedefMap, structs);
+        }
+        writer.append("\n ").append(returnType).append(' ');
+        writer.append(funcName).append('(');
+        writeParameters(writer, funcName, info.getParameters());
+        writer.append(");\n");
+    }
+
+    /**
+     * Writes a comma-separated parameter list. A type override is looked up under
+     * {@code "<funcName>.<paramName>"}. A parameter whose type is {@code void} (as in
+     * {@code f(void)}) is written as nothing.
+     */
+    private void writeParameters(BufferedWriter writer, String funcName, List<Parameter> parameters)
+            throws IOException {
+        if (parameters != null) {
+            boolean first = true;
+            for (Parameter param : parameters) {
+                if (first) {
+                    first = false;
+                } else {
+                    writer.append(", ");
+                }
+                String paramName = param.getName();
+                String type = mapping.getProperty(funcName + '.' + paramName);
+                if (type == null) {
+                    type = param.getType().map(typedefMap, structs);
+                }
+                if (!"void".equals(type)) {
+                    writer.append(type).append(' ');
+                    writer.append(paramName);
+                }
+            }
+        }
+    }
+
+    /** Adds the import needed for a JNA by-reference or NIO buffer type; other types need none. */
+    private static void addImports(Set<String> importSet, String typeName) {
+        switch (typeName) {
+            case "ByReference":
+            case "ByteByReference":
+            case "DoubleByReference":
+            case "FloatByReference":
+            case "IntByReference":
+            case "LongByReference":
+            case "NativeLongByReference":
+            case "PointerByReference":
+            case "ShortByReference":
+                importSet.add("com.sun.jna.ptr." + typeName);
+                break;
+            case "ByteBuffer":
+            case "DoubleBuffer":
+            case "FloatBuffer":
+            case "IntBuffer":
+            case "LongBuffer":
+            case "ShortBuffer":
+                importSet.add("java.nio." + typeName);
+                break;
+            default:
+                break;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
new file mode 100644
index 0000000..16a350e
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.ParseTreeWalker;
+import org.apache.mxnet.jnarator.parser.CBaseListener;
+import org.apache.mxnet.jnarator.parser.CLexer;
+import org.apache.mxnet.jnarator.parser.CParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JnaParser {
+
+ static final Logger logger = LoggerFactory.getLogger(Main.class);
+
+ Map<String, List<TypeDefine>> structMap;
+ Map<String, List<String>> enumMap;
+ List<FuncInfo> functions;
+ Map<String, TypeDefine> typedefMap;
+ private Set<String> functionNames;
+
+ public JnaParser() {
+ structMap = new LinkedHashMap<>();
+ enumMap = new LinkedHashMap<>();
+ functions = new ArrayList<>();
+ typedefMap = new LinkedHashMap<>();
+ functionNames = new HashSet<>();
+ }
+
+ public void parse(String headerFile) {
+ try {
+ CLexer lexer = new CLexer(CharStreams.fromFileName(headerFile));
+ CommonTokenStream tokens = new CommonTokenStream(lexer);
+ CParser parser = new CParser(tokens);
+ ParseTree tree = parser.compilationUnit();
+
+ ParseTreeWalker walker = new ParseTreeWalker();
+ CBaseListener listener =
+ new CBaseListener() {
+
+ /** {@inheritDoc} */
+ @Override
+ public void enterDeclaration(CParser.DeclarationContext ctx) {
+ CParser.DeclarationSpecifiersContext specs =
+ ctx.declarationSpecifiers();
+ CParser.InitDeclaratorListContext init = ctx.initDeclaratorList();
+
+ if (AntlrUtils.isTypeDef(specs)) {
+ TypeDefine value = TypeDefine.parse(init, specs);
+ typedefMap.put(value.getDataType().getType(), value);
+ } else if (AntlrUtils.isStructOrUnion(specs)) {
+ CParser.DeclarationSpecifierContext spec =
+ (CParser.DeclarationSpecifierContext) specs.getChild(0);
+ CParser.TypeSpecifierContext type = spec.typeSpecifier();
+ CParser.StructOrUnionSpecifierContext struct =
+ type.structOrUnionSpecifier();
+ String name = struct.Identifier().getText();
+ List<TypeDefine> fields = new ArrayList<>();
+
+ CParser.StructDeclarationListContext list =
+ struct.structDeclarationList();
+ parseStructFields(fields, list);
+
+ structMap.put(name, fields);
+ } else if (AntlrUtils.isEnum(specs)) {
+ CParser.DeclarationSpecifierContext spec =
+ (CParser.DeclarationSpecifierContext) specs.getChild(0);
+ CParser.TypeSpecifierContext type = spec.typeSpecifier();
+ CParser.EnumSpecifierContext enumSpecifierContext =
+ type.enumSpecifier();
+ String name = enumSpecifierContext.Identifier().getText();
+ List<String> fields = new ArrayList<>();
+ parseEnum(fields, ctx);
+ enumMap.put(name, fields);
+ } else {
+ FuncInfo info = FuncInfo.parse(ctx);
+ if (checkDuplicate(info)) {
+ logger.warn("Duplicate function: {}.", info.getName());
+ } else {
+ functions.add(info);
+ }
+ }
+ }
+ };
+ walker.walk(listener, tree);
+ } catch (IOException e) {
+ logger.error("", e);
+ }
+ }
+
+ void parseStructFields(List<TypeDefine> fields, ParseTree tree) {
+ if (tree instanceof CParser.StructDeclarationContext) {
+ CParser.StructDeclarationContext ctx = (CParser.StructDeclarationContext) tree;
+ CParser.SpecifierQualifierListContext qualifierList = ctx.specifierQualifierList();
+ DataType dataType = DataType.parse(qualifierList);
+
+ TypeDefine typeDefine = new TypeDefine();
+ fields.add(typeDefine);
+
+ typeDefine.setDataType(dataType);
+
+ CParser.StructDeclaratorListContext name = ctx.structDeclaratorList();
+ if (name != null) {
+ typeDefine.setCallBack(true);
+
+ CParser.DirectDeclaratorContext dd =
+ name.structDeclarator().declarator().directDeclarator();
+ CParser.DirectDeclaratorContext nameCtx =
+ dd.directDeclarator().declarator().directDeclarator();
+ String fieldName = nameCtx.getText();
+ typeDefine.setValue(fieldName);
+
+ CParser.ParameterTypeListContext paramListCtx = dd.parameterTypeList();
+ if (paramListCtx != null) {
+ Parameter.parseParams(typeDefine.getParameters(), paramListCtx);
+ }
+ } else {
+ CParser.SpecifierQualifierListContext nameList =
+ qualifierList.specifierQualifierList();
+ if (nameList.specifierQualifierList() != null) {
+ typeDefine.setValue(nameList.specifierQualifierList().getText());
+ } else {
+ typeDefine.setValue(nameList.getText());
+ }
+ }
+ return;
+ }
+
+ for (int i = 0; i < tree.getChildCount(); i++) {
+ parseStructFields(fields, tree.getChild(i));
+ }
+ }
+
+ void parseEnum(List<String> fields, ParseTree ctx) {
+ if (ctx instanceof CParser.EnumerationConstantContext) {
+ fields.add(ctx.getText());
+ return;
+ }
+
+ for (int i = 0; i < ctx.getChildCount(); i++) {
+ parseEnum(fields, ctx.getChild(i));
+ }
+ }
+
+ public Map<String, List<TypeDefine>> getStructMap() {
+ return structMap;
+ }
+
+ public Map<String, List<String>> getEnumMap() {
+ return enumMap;
+ }
+
+ public List<FuncInfo> getFunctions() {
+ return functions;
+ }
+
+ public Map<String, TypeDefine> getTypedefMap() {
+ return typedefMap;
+ }
+
+ boolean checkDuplicate(FuncInfo function) {
+ if (!functionNames.add(function.getName())) {
+ for (FuncInfo info : functions) {
+ if (function.equals(info)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
new file mode 100644
index 0000000..39aa8fc
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class Main {
+
+ private static final Logger logger = LoggerFactory.getLogger(Main.class);
+
+ private Main() {}
+
+ public static void main(String[] args) {
+ Options options = Config.getOptions();
+ try {
+ DefaultParser cmdParser = new DefaultParser();
+ CommandLine cmd = cmdParser.parse(options, args, null, false);
+ Config config = new Config(cmd);
+
+ String output = config.getOutput();
+ String packageName = config.getPackageName();
+ String library = config.getLibrary();
+ String[] headerFiles = config.getHeaderFiles();
+ String mappingFile = config.getMappingFile();
+
+ Path dir = Paths.get(output);
+ Files.createDirectories(dir);
+
+ Properties mapping = new Properties();
+ if (mappingFile != null) {
+ Path file = Paths.get(mappingFile);
+ if (Files.notExists(file)) {
+ logger.error("mapping file does not exists: {}", mappingFile);
+ System.exit(-1); // NOPMD
+ }
+ try (InputStream in = Files.newInputStream(file)) {
+ mapping.load(in);
+ }
+ }
+
+ JnaParser jnaParser = new JnaParser();
+ Map<String, TypeDefine> typedefMap = jnaParser.getTypedefMap();
+ Map<String, List<TypeDefine>> structMap = jnaParser.getStructMap();
+ JnaGenerator generator =
+ new JnaGenerator(library, packageName, typedefMap, structMap.keySet(), mapping);
+ generator.init(output);
+
+ for (String headerFile : headerFiles) {
+ jnaParser.parse(headerFile);
+ }
+
+ generator.writeNativeSize();
+ generator.writeStructure(structMap);
+ generator.writeLibrary(jnaParser.getFunctions(), jnaParser.getEnumMap());
+ } catch (ParseException e) {
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.setLeftPadding(1);
+ formatter.setWidth(120);
+ formatter.printHelp(e.getMessage(), options);
+ System.exit(-1); // NOPMD
+ } catch (Throwable t) {
+ logger.error("", t);
+ System.exit(-1); // NOPMD
+ }
+ }
+
+ public static final class Config {
+
+ private String library;
+ private String packageName;
+ private String output;
+ private String[] headerFiles;
+ private String mappingFile;
+
+ public Config(CommandLine cmd) {
+ library = cmd.getOptionValue("library");
+ packageName = cmd.getOptionValue("package");
+ output = cmd.getOptionValue("output");
+ headerFiles = cmd.getOptionValues("header");
+ mappingFile = cmd.getOptionValue("mappingFile");
+ }
+
+ public static Options getOptions() {
+ Options options = new Options();
+ options.addOption(
+ Option.builder("l")
+ .longOpt("library")
+ .hasArg()
+ .required()
+ .argName("LIBRARY")
+ .desc("library name")
+ .build());
+ options.addOption(
+ Option.builder("p")
+ .longOpt("package")
+ .required()
+ .hasArg()
+ .argName("PACKAGE")
+ .desc("Java package name")
+ .build());
+ options.addOption(
+ Option.builder("o")
+ .longOpt("output")
+ .required()
+ .hasArg()
+ .argName("OUTPUT")
+ .desc("output directory")
+ .build());
+ options.addOption(
+ Option.builder("f")
+ .longOpt("header")
+ .required()
+ .hasArgs()
+ .argName("HEADER")
+ .desc("Header files")
+ .build());
+ options.addOption(
+ Option.builder("m")
+ .longOpt("mappingFile")
+ .hasArg()
+ .argName("MAPPING_FILE")
+ .desc("Type mappingFile config file")
+ .build());
+ return options;
+ }
+
+ public String getLibrary() {
+ return library;
+ }
+
+ public String getPackageName() {
+ return packageName;
+ }
+
+ public String getOutput() {
+ return output;
+ }
+
+ public String[] getHeaderFiles() {
+ return headerFiles;
+ }
+
+ public String getMappingFile() {
+ return mappingFile;
+ }
+ }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
new file mode 100644
index 0000000..f46e5e7
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.List;
+import java.util.Objects;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class Parameter {
+
+ private DataType type;
+ private String name;
+
+ public Parameter(DataType type, String name) {
+ this.type = type;
+ this.name = name;
+ }
+
+ public DataType getType() {
+ return type;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return type.toString() + ' ' + name;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Parameter parameter = (Parameter) o;
+ return type.equals(parameter.type);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Objects.hash(type);
+ }
+
+ static void parseParams(List<Parameter> params, ParseTree ctx) {
+ if (ctx instanceof CParser.ParameterDeclarationContext) {
+ CParser.ParameterDeclarationContext declarationContext =
+ (CParser.ParameterDeclarationContext) ctx;
+ CParser.DeclarationSpecifiersContext spec = declarationContext.declarationSpecifiers();
+ DataType dataType;
+ if (spec == null) {
+ dataType = DataType.parse(declarationContext.declarationSpecifiers2());
+ } else {
+ dataType = DataType.parse(spec);
+ }
+
+ CParser.DeclaratorContext declarator = declarationContext.declarator();
+
+ String name;
+ if (declarator != null) {
+ CParser.PointerContext pointer = declarator.pointer();
+ if (pointer != null) {
+ dataType.increasePointerCount();
+ }
+ name = declarator.directDeclarator().getText();
+ } else {
+ name = "arg" + (params.size() + 1);
+ }
+
+ Parameter param = new Parameter(dataType, name);
+ params.add(param);
+ return;
+ }
+ for (int i = 0; i < ctx.getChildCount(); i++) {
+ parseParams(params, ctx.getChild(i));
+ }
+ }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
new file mode 100644
index 0000000..1f30d21
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/**
+ * A parsed C {@code typedef} or struct member: its data type, a value string and, for callback
+ * (function pointer) declarations, the callback's parameters.
+ */
+public class TypeDefine {
+
+    private DataType dataType;
+    private boolean callBack;
+    private String value;
+    private final List<Parameter> parameters = new ArrayList<>();
+
+    /**
+     * Returns the declared data type.
+     *
+     * @return the data type
+     */
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    /**
+     * Sets the declared data type.
+     *
+     * @param dataType the data type
+     */
+    public void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    /**
+     * Tells whether this declaration is a callback (function pointer).
+     *
+     * @return {@code true} for callbacks
+     */
+    public boolean isCallBack() {
+        return callBack;
+    }
+
+    /**
+     * Marks whether this declaration is a callback (function pointer).
+     *
+     * @param callBack {@code true} for callbacks
+     */
+    public void setCallBack(boolean callBack) {
+        this.callBack = callBack;
+    }
+
+    /**
+     * Returns the value string associated with this declaration.
+     *
+     * @return the value
+     */
+    public String getValue() {
+        return value;
+    }
+
+    /**
+     * Sets the value string associated with this declaration.
+     *
+     * @param value the value
+     */
+    public void setValue(String value) {
+        this.value = value;
+    }
+
+    /**
+     * Returns the mutable list of callback parameters (empty for non-callbacks).
+     *
+     * @return the parameter list
+     */
+    public List<Parameter> getParameters() {
+        return parameters;
+    }
+
+    /**
+     * Builds a {@link TypeDefine} from a {@code typedef} declaration.
+     *
+     * <p>The defined name becomes the data type's type. The value is the text of every
+     * declaration specifier after the first one (presumably the {@code typedef} keyword), joined
+     * with single spaces.
+     *
+     * @param init the declaration's init-declarator list
+     * @param specs the declaration's specifiers
+     * @return the parsed typedef
+     */
+    static TypeDefine parse(
+            CParser.InitDeclaratorListContext init, CParser.DeclarationSpecifiersContext specs) {
+        TypeDefine result = new TypeDefine();
+        DataType dataType = new DataType();
+        result.setDataType(dataType);
+
+        CParser.DirectDeclaratorContext declared =
+                init.initDeclarator().declarator().directDeclarator();
+        CParser.DirectDeclaratorContext inner = declared.directDeclarator();
+        if (inner != null) {
+            // A nested direct declarator means a parenthesized declarator, as in a
+            // function-pointer typedef "typedef R (*Name)(args)".
+            result.setCallBack(true);
+            dataType.setType(inner.declarator().directDeclarator().getText());
+            Parameter.parseParams(result.getParameters(), declared.parameterTypeList());
+        } else {
+            dataType.setType(declared.getText());
+        }
+
+        StringBuilder text = new StringBuilder();
+        for (int i = 1; i < specs.getChildCount(); ++i) {
+            if (i > 1) {
+                text.append(' ');
+            }
+            text.append(specs.getChild(i).getText());
+        }
+        result.setValue(text.toString());
+        return result;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
new file mode 100644
index 0000000..a6cff70
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains classes to generate the Apache MXNet (incubating) native interface. */
+package org.apache.mxnet.jnarator;
diff --git a/java-package/jnarator/src/main/resources/log4j2.xml b/java-package/jnarator/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..4818a95
--- /dev/null
+++ b/java-package/jnarator/src/main/resources/log4j2.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<!-- Log4j 2 configuration for the jnarator generator (takes effect only if a Log4j 2 SLF4J
+     binding is on the classpath): events at DEBUG and above go to standard output as
+     "[LEVEL] message", followed by any stack trace. -->
+<Configuration status="WARN">
+ <Appenders>
+ <Console name="Console" target="SYSTEM_OUT">
+ <PatternLayout pattern="[%highlight{%-5level}] %msg%n%throwable"/>
+ </Console>
+ </Appenders>
+ <Loggers>
+ <!-- NOTE(review): additivity likely has no effect on the root logger, which has no parent. -->
+ <Root level="DEBUG" additivity="false">
+ <AppenderRef ref="Console"/>
+ </Root>
+ </Loggers>
+</Configuration>
diff --git a/java-package/mxnet-engine/build.gradle b/java-package/mxnet-engine/build.gradle
new file mode 100644
index 0000000..c874f72
--- /dev/null
+++ b/java-package/mxnet-engine/build.gradle
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Java plugin, this module's Maven coordinates, and dependency resolution from Maven Central.
+plugins {
+ id 'java'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+
+// Maps the JVM's os.name to the OS classifier used in the native jar file name
+// ("win", "osx" or "linux"); any other OS name is returned unchanged.
+def getOsName() {
+    def osName = System.properties['os.name']
+    // On Windows os.name is e.g. "Windows 10", so that check must ignore case.
+    if (osName.toLowerCase(Locale.ROOT).contains('windows')) {
+        return "win"
+    } else if (osName.contains('Mac OS X')) {
+        return "osx"
+    } else if (osName.contains('Linux')) {
+        return "linux"
+    } else {
+        return osName
+    }
+}
+
+// Library dependencies. NOTE(review): the `api` configuration comes from the `java-library`
+// plugin, not the `java` plugin applied in this file; confirm a parent build script applies it.
+dependencies {
+ api "com.google.code.gson:gson:${gson_version}"
+ api "net.java.dev.jna:jna:${jna_version}"
+ api "org.apache.commons:commons-compress:${commons_compress_version}"
+ api "org.slf4j:slf4j-api:${slf4j_version}"
+
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ // Binds SLF4J at runtime to avoid: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
+ // NOTE(review): a logging binding on the implementation classpath is also forced on consumers.
+ implementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ def osName = getOsName()
+ // Native MXNet jar (x86_64 only) expected in the :native project's build/libs directory.
+ // NOTE(review): nothing here makes compilation depend on the task that builds it.
+ implementation files("${project(':native').buildDir}/libs/native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar")
+// implementation fileTree(dir: "${project(':native').buildDir}/lib", includes: ["native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar"])
+}
+
+// Compile the JNA sources generated by the jnarator task together with the hand-written ones.
+sourceSets {
+ main {
+ java {
+ srcDirs = ['src/main/java', 'build/generated-src']
+ }
+ }
+}
+
+// Lint only hand-written sources, not the generated ones.
+checkstyleMain.source = 'src/main/java'
+pmdMain.source = 'src/main/java'
+
+// Generates the JNA bindings (package org.apache.mxnet.jna) for the MXNet and NNVM C APIs into
+// build/generated-src by running the jnarator jar.
+task jnarator(dependsOn: ":jnarator:jar") {
+    // Header paths are relative to this project's directory, which is also the working
+    // directory of javaexec below.
+    def headerFiles = [
+            "../../include/mxnet/c_api.h",
+            "../../include/nnvm/c_api.h"
+    ]
+    def mappingFile = "${project.projectDir}/src/main/jna/mapping.properties"
+    // Declared inputs make Gradle rerun the generator when a header or the mapping changes.
+    // With outputs only, the task would stay up-to-date as long as the output dir is untouched.
+    inputs.files headerFiles
+    inputs.file mappingFile
+    outputs.dir "${project.buildDir}/generated-src"
+    doLast {
+        File jnaGenerator = project(":jnarator").jar.outputs.files.singleFile
+        javaexec {
+            main = "-jar"
+            args = [
+                    jnaGenerator.absolutePath,
+                    "-l",
+                    "mxnet",
+                    "-p",
+                    "org.apache.mxnet.jna",
+                    "-o",
+                    "${project.buildDir}/generated-src",
+                    "-m",
+                    mappingFile,
+                    "-f"
+            ] + headerFiles
+        }
+    }
+}
+
+// TestNG test configuration:
+// - prepends src/test/bin to PATH;
+// - caps the test JVM heap at 6G;
+// - echoes each test's name and its stdout/stderr to the lifecycle log.
+// NOTE(review): the PATH entry is joined with ':', which is not the Windows path separator.
+test {
+ useTestNG() {
+ useDefaultListeners = true
+ }
+ environment "PATH", "src/test/bin:${environment.PATH}"
+// environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}"
+ maxHeapSize = '6G'
+ testLogging.showStandardStreams = true
+ beforeTest { descriptor ->
+ logger.lifecycle("Running test: " + descriptor)
+ }
+ failFast = false
+ onOutput { descriptor, event ->
+ logger.lifecycle("Test: " + descriptor + " produced standard out/err: " + event.message )
+ }
+// debugOptions {
+// enabled = true
+// port = 4455
+// server = true
+// suspend = true
+// }
+// filter {
+// includeTestsMatching("*Test")
+// }
+}
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//import java.util.regex.Matcher
+//import java.util.regex.Pattern
+
+//def checkForUpdate(String path, String url) {
+// def expected = new URL(url).text
+// def actual = new File("${project.projectDir}/src/main/include/${path}").text
+// if (!actual.equals(expected)) {
+// def fileName = path.replaceAll("[/\\\\]", '_')
+// file("${project.projectDir}/build").mkdirs()
+// (file("${project.projectDir}/build/${fileName}")).text = expected
+// logger.warn("[\033[31mWARN\033[0m ] Header file has been changed in open source project: ${path}.")
+// }
+//}
+
+//task checkHeaderFile() {
+// outputs.files "${project.buildDir}/mxnet_c_api.h", "${project.buildDir}/nnvm_c_api.h"
+// doFirst {
+// if (gradle.startParameter.offline) {
+// logger.warn("[\033[31mWARN\033[0m ] Ignore header validation in offline mode.")
+// return
+// }
+//
+// def mxnetUrl = "https://raw.githubusercontent.com/apache/incubator-mxnet/v1.7.x/"
+// checkForUpdate("mxnet/c_api.h", "${mxnetUrl}/include/mxnet/c_api.h")
+// def content = new URL("https://github.com/apache/incubator-mxnet/tree/v1.7.x/3rdparty").text
+//
+// Pattern pattern = Pattern.compile("href=\"/apache/incubator-tvm/tree/([a-z0-9]+)\"")
+// Matcher m = pattern.matcher(content);
+// if (!m.find()) {
+// throw new GradleException("Failed to retrieve submodule hash for tvm from github")
+// }
+// String hash = m.group(1);
+//
+// def nnvmUrl = "https://raw.githubusercontent.com/apache/incubator-tvm/${hash}"
+// checkForUpdate("nnvm/c_api.h", "${nnvmUrl}/nnvm/include/nnvm/c_api.h")
+// }
+//}
+
+// Run the jnarator task before compiling, since build/generated-src is part of the main sources.
+compileJava.dependsOn(jnarator)
+
+// TODO
+//publishing {
+// publications {
+// maven(MavenPublication) {
+// pom {
+// name = "DJL Engine Adapter for Apache MXNet"
+// description = "Deep Java Library (DJL) Engine Adapter for Apache MXNet"
+// url = "http://www.djl.ai/mxnet/${project.name}"
+// }
+// }
+// }
+//}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
new file mode 100644
index 0000000..13ae88d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.jna.JnaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The top-level {@link MxResource} instance, with no parent resource to manage. The {@link
+ * BaseMxResource} singleton is created lazily on first access through {@link
+ * #getSystemMxResource()}, for example when a {@link Model} instance is loaded for the first time.
+ */
+public final class BaseMxResource extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(BaseMxResource.class);
+
+    // Only accessed inside the synchronized getSystemMxResource().
+    private static BaseMxResource systemMxResource;
+
+    /**
+     * Prepares the MXNet engine for use from Java and registers a JVM shutdown hook.
+     *
+     * <p>Callers should normally obtain the shared instance via {@link #getSystemMxResource()}.
+     */
+    protected BaseMxResource() {
+        super();
+        // Workaround MXNet engine lazy initialization issue
+        JnaUtils.getAllOpNames();
+
+        JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);
+
+        // Workaround MXNet shutdown crash issue: call JnaUtils.waitAll() when the JVM exits
+        Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD
+    }
+
+    /**
+     * Returns the singleton {@code systemMxResource} instance, creating it on the first call.
+     *
+     * @return the top-level {@link BaseMxResource} instance
+     */
+    public static synchronized BaseMxResource getSystemMxResource() {
+        if (systemMxResource == null) {
+            systemMxResource = new BaseMxResource();
+        }
+        return systemMxResource;
+    }
+
+    /**
+     * Calls {@code JnaUtils.waitAll()} and then frees every sub-resource managed by this
+     * instance. Calling it again after the instance is closed has no effect.
+     *
+     * <p>NOTE(review): the closed singleton stays cached, so later {@link #getSystemMxResource()}
+     * calls return this closed instance; confirm that is intended.
+     */
+    @Override
+    public void close() {
+        if (getClosed()) {
+            return;
+        }
+        // Parameterized messages keep the uid as-is and are only formatted when debug is enabled.
+        logger.debug("Start to free BaseMxResource instance: {}", getUid());
+        // Only the sub-resources are freed here.
+        JnaUtils.waitAll();
+        super.freeSubResources();
+        setClosed(true);
+        logger.debug("Finish to free BaseMxResource instance: {}", getUid());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
new file mode 100644
index 0000000..eef058e
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code CachedOp} is an internal helper that provides the core functionality to execute a
+ * {@link SymbolBlock}.
+ *
+ * <p>We don't recommend users interact with this class directly. Users should use {@link Predictor}
+ * instead. CachedOp is an operator that simplifies calling and analyzing the input shape. It
+ * requires minimum input to do inference because most of the information can be obtained from the
+ * model itself.
+ */
+public class CachedOp extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(CachedOp.class);
+
+    // All graph inputs in slot order; forward() sizes its input array by this list.
+    private final List<Parameter> parameters;
+    private final PairList<String, Integer> dataIndices;
+    private final Map<String, Integer> dataIndicesMap;
+    private final List<Integer> paramIndices;
+
+    /**
+     * Creates an instance of {@link CachedOp}.
+     *
+     * <p>It can be created by using {@link JnaUtils#createCachedOp(SymbolBlock, MxResource)}
+     *
+     * @param parent the MxResource object to manage this instance of CachedOp
+     * @param handle the C handle of the CachedOp
+     * @param parameters the parameter values
+     * @param paramIndices the parameters required by the model and their corresponding location
+     * @param dataIndices the input data names required by the model and their corresponding
+     *     location
+     */
+    public CachedOp(
+            MxResource parent,
+            Pointer handle,
+            List<Parameter> parameters,
+            List<Integer> paramIndices,
+            PairList<String, Integer> dataIndices) {
+        super(parent, handle);
+        this.parameters = parameters;
+        this.dataIndices = dataIndices;
+        this.paramIndices = paramIndices;
+        this.dataIndicesMap = dataIndices.toMap();
+    }
+
+    /**
+     * Runs the cached graph with {@code data} as its inputs.
+     *
+     * <p>Parameter arrays are moved to the device of the first input (or the default device when
+     * {@code data} is empty) and placed in their slots. Each input array is placed by its name, or
+     * by its position when it has no name. Any data slot that is still empty is filled with a new
+     * array of {@code Shape(batchSize)}, where {@code batchSize} is the first dimension of the
+     * first input.
+     *
+     * @param data the input in {@link NDList} format
+     * @return the outputs of the cached op as an {@link NDList}
+     * @throws NullPointerException if a parameter required by the model has no array
+     * @throws IllegalArgumentException if an input name is unknown, or if a data slot must be
+     *     defaulted while {@code data} is empty, so no batch size can be inferred
+     */
+    public NDList forward(NDList data) {
+        // one slot per graph input; parameters and data share this array
+        NDArray[] allInputsNDArray = new NDArray[parameters.size()];
+        // parameters follow the device of the first input
+        Device device = data.isEmpty() ? Device.defaultIfNull() : data.head().getDevice();
+        // fill allInputsNDArray with parameter values on correct device
+        for (int index : paramIndices) {
+            NDArray value = parameters.get(index).getArray();
+            if (value == null) {
+                throw new NullPointerException("Failed to find parameter from parameterStore");
+            }
+            value.setDevice(device);
+            allInputsNDArray[index] = value;
+        }
+
+        // fill allInputsNDArray with data values
+        int position = 0;
+        for (NDArray array : data) {
+            // TODO: NDArray name doesn't match. To confirm the format of input name
+            // (a "prefix:name" form may need array.getName().split(":")[1]).
+            String inputName = array.getName();
+            // if inputName not provided, value will follow the default order
+            allInputsNDArray[indexOf(inputName, position++)] = array;
+        }
+
+        // any data slot still empty gets a default Shape(batchSize) array
+        for (Pair<String, Integer> pair : dataIndices) {
+            int slot = pair.getValue();
+            if (allInputsNDArray[slot] != null) {
+                continue;
+            }
+            String key = pair.getKey();
+            if (data.isEmpty()) {
+                // there is no first input to take the batch size from
+                throw new IllegalArgumentException(
+                        "Input " + key + " not found and no input to infer batch size from");
+            }
+            // TODO: Do we need to set default to the input?
+            long batchSize = data.head().getShape().get(0);
+            if (!"prob_label".equals(key) && !"softmax_label".equals(key)) {
+                logger.warn(
+                        "Input {} not found, set NDArray to Shape({}) by default", key, batchSize);
+            }
+            // TODO: consider how to manage MxNDArray generated during inference
+            allInputsNDArray[slot] = NDArray.create(this, new Shape(batchSize), device);
+        }
+        NDArray[] result = JnaUtils.cachedOpInvoke(getParent(), getHandle(), allInputsNDArray);
+        return new NDList(result);
+    }
+
+    /** Frees sub-resources, then the native CachedOp handle; repeated calls are no-ops. */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug("Start to free CachedOp instance: {}", getUid());
+            super.freeSubResources();
+            // getAndSet(null) clears the handle so it is only freed once
+            Pointer pointer = handle.getAndSet(null);
+            if (pointer != null) {
+                JnaUtils.freeCachedOp(pointer);
+            }
+            setClosed(true);
+            logger.debug("Finish to free CachedOp instance: {}", getUid());
+        }
+    }
+
+    /**
+     * Resolves the input slot for one data array.
+     *
+     * @param inputName the array name, or {@code null} to use {@code position}
+     * @param position the position of the array within the input {@link NDList}
+     * @return the index of the slot in the combined input array
+     * @throws IllegalArgumentException if {@code inputName} is not a known input name
+     */
+    private int indexOf(String inputName, int position) {
+        if (inputName == null) {
+            return dataIndices.valueAt(position);
+        }
+
+        Integer index = dataIndicesMap.get(inputName);
+        if (index == null) {
+            throw new IllegalArgumentException(
+                    "Unknown input name: "
+                            + inputName
+                            + ", expected inputs: "
+                            + dataIndicesMap.keySet().toString());
+        }
+        return index;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
new file mode 100644
index 0000000..7fb74c6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.cuda.CudaUtils;
+
+/**
+ * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@link
+ * org.apache.mxnet.ndarray.NDArray}.
+ *
+ * <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with
+ * deviceType and deviceId provided.
+ *
+ * <p>Instances are immutable and cached: {@link #of(String, int)} returns the same instance for
+ * equal arguments, and every CPU request returns the single CPU device.
+ */
+public final class Device {
+
+    // Cache of non-CPU devices keyed by "type-id"; must be initialized before GPU below.
+    private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();
+
+    // The only CPU instance; the id of a CPU device is not meaningful.
+    private static final Device CPU = new Device(Type.CPU, -1);
+
+    private static final Device GPU = Device.of(Type.GPU, 0);
+
+    // Default used by defaultIfNull(); currently always CPU.
+    private static final Device DEFAULT_DEVICE = CPU;
+
+    private final String deviceType;
+
+    private final int deviceId;
+
+    /**
+     * Creates a {@code Device} with basic information.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can
+     *     choose which GPU to process the NDArray
+     */
+    private Device(String deviceType, int deviceId) {
+        this.deviceType = deviceType;
+        this.deviceId = deviceId;
+    }
+
+    /**
+     * Returns a {@code Device} with device type and device id.
+     *
+     * <p>CPU requests always return the shared CPU device regardless of {@code deviceId}; other
+     * devices are cached, so equal arguments yield the same instance.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware
+     * @return a {@code Device} instance
+     * @throws NullPointerException if {@code deviceType} is null
+     */
+    public static Device of(String deviceType, int deviceId) {
+        Objects.requireNonNull(deviceType, "deviceType");
+        if (Type.CPU.equals(deviceType)) {
+            return CPU;
+        }
+        String key = deviceType + '-' + deviceId;
+        return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId));
+    }
+
+    /**
+     * Returns the device type of the Device.
+     *
+     * @return the device type of the Device
+     */
+    public String getDeviceType() {
+        return deviceType;
+    }
+
+    /**
+     * Returns the {@code deviceId} of the Device.
+     *
+     * @return the {@code deviceId} of the Device
+     * @throws IllegalStateException if this is a CPU device, which has no device id
+     */
+    public int getDeviceId() {
+        if (Type.CPU.equals(deviceType)) {
+            throw new IllegalStateException("CPU doesn't have device id");
+        }
+        return deviceId;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        if (Type.CPU.equals(deviceType)) {
+            return deviceType + "()";
+        }
+        return deviceType + '(' + deviceId + ')';
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        Device device = (Device) o;
+        // CPU devices compare by type only; their id is not meaningful
+        if (Type.CPU.equals(deviceType)) {
+            return Objects.equals(deviceType, device.deviceType);
+        }
+        return deviceId == device.deviceId && Objects.equals(deviceType, device.deviceType);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        // ignore deviceId for CPU so that hashCode() stays consistent with equals()
+        if (Type.CPU.equals(deviceType)) {
+            return Objects.hashCode(deviceType);
+        }
+        return Objects.hash(deviceType, deviceId);
+    }
+
+    /**
+     * Returns the default CPU Device.
+     *
+     * @return the default CPU Device
+     */
+    public static Device cpu() {
+        return CPU;
+    }
+
+    /**
+     * Returns the default GPU Device, the GPU with device id 0.
+     *
+     * @return the default GPU Device
+     */
+    public static Device gpu() {
+        return GPU;
+    }
+
+    /**
+     * Returns the GPU {@code Device} with the specified {@code deviceId}.
+     *
+     * @param deviceId the GPU device ID
+     * @return the cached GPU {@code Device} with the specified {@code deviceId}
+     */
+    public static Device gpu(int deviceId) {
+        return of(Type.GPU, deviceId);
+    }
+
+    /**
+     * Returns an array of all available devices.
+     *
+     * <p>If GPUs are available, it will return an array with one {@code Device} per available
+     * GPU. Else, it will return an array with a single CPU device.
+     *
+     * @return an array of devices
+     */
+    public static Device[] getDevices() {
+        return getDevices(Integer.MAX_VALUE);
+    }
+
+    /**
+     * Returns an array of devices given the maximum number of GPUs to use.
+     *
+     * <p>If GPUs are available, it will return an array of {@code Device} of size
+     * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device.
+     *
+     * @param maxGpus the max number of GPUs to use. Use 0 for no GPUs.
+     * @return an array of devices
+     */
+    public static Device[] getDevices(int maxGpus) {
+        int count = getGpuCount();
+        if (maxGpus <= 0 || count <= 0) {
+            return new Device[] {CPU};
+        }
+
+        count = Math.min(maxGpus, count);
+        Device[] devices = new Device[count];
+        for (int i = 0; i < devices.length; ++i) {
+            devices[i] = gpu(i);
+        }
+        return devices;
+    }
+
+    /**
+     * Returns the number of GPUs available in the system.
+     *
+     * @return the number of GPUs available in the system
+     */
+    public static int getGpuCount() {
+        return CudaUtils.getGpuCount();
+    }
+
+    /**
+     * Returns the default device used in Engine.
+     *
+     * <p>This is currently always the CPU device, whether or not GPUs are available.
+     *
+     * @return a {@link Device}
+     */
+    private static Device defaultDevice() {
+        return DEFAULT_DEVICE;
+    }
+
+    /**
+     * Returns the given device or the default if it is null.
+     *
+     * @param device the device to try to return
+     * @return the given device or the default if it is null
+     */
+    public static Device defaultIfNull(Device device) {
+        if (device != null) {
+            return device;
+        }
+        return defaultDevice();
+    }
+
+    /**
+     * Returns the default device.
+     *
+     * @return the default device
+     */
+    public static Device defaultIfNull() {
+        return defaultIfNull(null);
+    }
+
+    /** Contains device type string constants. */
+    public interface Type {
+        String CPU = "cpu";
+        String GPU = "gpu";
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
new file mode 100644
index 0000000..02ad120
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/** {@code DeviceType} is a class used to map the Device name to their corresponding type number. */
+public final class DeviceType {
+
+    // MXNet device type codes returned by toDeviceType and accepted by fromDeviceType.
+    private static final int CPU_CODE = 1;
+    private static final int GPU_CODE = 2;
+    private static final int CPU_PINNED_CODE = 3;
+
+    private static final String CPU_PINNED = "cpu_pinned";
+
+    private DeviceType() {}
+
+    /**
+     * Converts a {@link Device} to the corresponding MXNet device number.
+     *
+     * @param device the java {@link Device}
+     * @return the MXNet device number: 1 for CPU, 2 for GPU, 3 for {@code cpu_pinned}
+     * @throws IllegalArgumentException if the device is null or its type is not supported
+     */
+    public static int toDeviceType(Device device) {
+        if (device == null) {
+            throw new IllegalArgumentException("Unsupported device: null");
+        }
+
+        String deviceType = device.getDeviceType();
+        if (Device.Type.CPU.equals(deviceType)) {
+            return CPU_CODE;
+        }
+        if (Device.Type.GPU.equals(deviceType)) {
+            return GPU_CODE;
+        }
+        if (CPU_PINNED.equals(deviceType)) {
+            return CPU_PINNED_CODE;
+        }
+        throw new IllegalArgumentException("Unsupported device: " + device.toString());
+    }
+
+    /**
+     * Converts from an MXNet device number to a {@link Device} type string.
+     *
+     * <p>{@code cpu_pinned} (3) is reported as CPU: it is hidden from frontend users, while
+     * advanced users can still create it to pass through the engine.
+     *
+     * @param deviceType the MXNet device number
+     * @return the corresponding device type string, {@link Device.Type#CPU} or {@link
+     *     Device.Type#GPU}
+     * @throws IllegalArgumentException if {@code deviceType} is not a supported device number
+     */
+    public static String fromDeviceType(int deviceType) {
+        switch (deviceType) {
+            case CPU_CODE:
+            case CPU_PINNED_CODE:
+                return Device.Type.CPU;
+            case GPU_CODE:
+                return Device.Type.GPU;
+            default:
+                throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
new file mode 100644
index 0000000..c388de3
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/**
+ * An enum that indicates whether gradient is required.
+ *
+ * <p>Each constant pairs a string type with an integer value. The values (0, 1, 3) presumably
+ * mirror MXNet's native gradient request codes; confirm against the MXNet C API before relying
+ * on that.
+ */
+public enum GradReq {
+    NULL("null", 0),
+    WRITE("write", 1),
+    ADD("add", 3);
+
+    private final String type;
+    private final int value;
+
+    GradReq(String type, int value) {
+        this.type = type;
+        this.value = value;
+    }
+
+    /**
+     * Gets the type of this {@code GradReq}.
+     *
+     * @return the type
+     */
+    public String getType() {
+        return type;
+    }
+
+    /**
+     * Gets the value of this {@code GradReq}.
+     *
+     * @return the value
+     */
+    public int getValue() {
+        return value;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
new file mode 100644
index 0000000..90cceb4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
@@ -0,0 +1,501 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.translate.NoOpTranslator;
+import org.apache.mxnet.translate.Translator;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A model is a collection of artifacts that is created by the training process.
+ *
+ * <p>Model contains methods to load and process a model object. In addition, it provides MXNet
+ * specific functionality, such as {@link #getSymbolBlock()} to obtain the symbolic graph, whose
+ * parameter arrays are filled from the model's parameter file when the model is loaded.
+ */
+public class Model extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Model.class);
+    protected Path modelDir;
+    protected SymbolBlock symbolBlock;
+    protected String modelName;
+    protected DataType dataType;
+    protected PairList<String, Shape> inputData;
+    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
+    protected Map<String, String> properties = new ConcurrentHashMap<>();
+
+    /**
+     * Creates a model managed by the top-level {@link BaseMxResource}.
+     *
+     * @param name the model name, also the default file prefix used when loading
+     * @param device the device to use, or {@code null} for the default device
+     */
+    Model(String name, Device device) {
+        this(BaseMxResource.getSystemMxResource(), name, device);
+    }
+
+    private Model(MxResource parent, String name, Device device) {
+        super(parent);
+        setDevice(Device.defaultIfNull(device));
+        setDataType(DataType.FLOAT32);
+        setModelName(name);
+    }
+
+    /**
+     * Creates a default {@link Predictor} instance, with {@link NoOpTranslator} as the translator.
+     *
+     * @return {@link Predictor}
+     */
+    public Predictor<NDList, NDList> newPredictor() {
+        Translator<NDList, NDList> noOpTranslator = new NoOpTranslator();
+        return newPredictor(noOpTranslator);
+    }
+
+    /**
+     * Creates a {@link Predictor} instance with a specific {@link Translator}.
+     *
+     * @param translator {@link Translator} used to convert inputs and outputs into {@link NDList}
+     *     to get inferred
+     * @param <I> the input type
+     * @param <O> the output type
+     * @return {@link Predictor}
+     */
+    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
+        return new Predictor<>(this, translator);
+    }
+
+    /**
+     * Creates and loads a model named {@code "model"} from the model directory.
+     *
+     * @param modelPath {@code Path} model directory
+     * @return loaded {@code Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Path modelPath) throws IOException {
+        return loadModel("model", modelPath);
+    }
+
+    /**
+     * Creates and loads a model from a repository {@link Item}.
+     *
+     * @param modelItem {@link Item} describing the model to load
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Item modelItem) throws IOException {
+        Model model = createModel(modelItem);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates and loads a model with a model name from the model directory.
+     *
+     * @param modelName {@link String} model name
+     * @param modelPath {@link Path} model directory
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(String modelName, Path modelPath) throws IOException {
+        Model model = createModel(modelName, modelPath);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates a model with a specific model name and model directory, without loading it. By
+     * default, the {@link Model} instance is managed by the top-level {@link BaseMxResource}.
+     *
+     * @param modelName the model name
+     * @param modelDir the local model directory
+     * @return {@link Model}
+     */
+    static Model createModel(String modelName, Path modelDir) {
+        Model model = new Model(modelName, Device.defaultIfNull());
+        model.setModelDir(modelDir);
+        return model;
+    }
+
+    /**
+     * Creates a model for a repository item, without loading it. The model directory comes from
+     * {@code Repository.initRepository(item)}, which downloads the model or finds its local path.
+     *
+     * @param item {@link Item} sample model to be created
+     * @return created {@link Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    static Model createModel(Item item) throws IOException {
+        Path modelDir = Repository.initRepository(item);
+        return createModel(item.getName(), modelDir);
+    }
+
+    /**
+     * Loads the symbol and parameters of this model from its model directory.
+     *
+     * @throws IOException when IO operation fails in loading a resource
+     * @throws FileNotFoundException if the model directory is not assigned
+     */
+    public void initial() throws IOException {
+        if (getModelDir() == null) {
+            throw new FileNotFoundException("Model path is not defined!");
+        }
+        load(getModelDir());
+    }
+
+    /**
+     * Loads the model from the {@code modelPath}, using the model name as the file prefix.
+     *
+     * @param modelPath the directory or file path of the model location
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public void load(Path modelPath) throws IOException {
+        load(modelPath, null, null);
+    }
+
+    /**
+     * Loads the MXNet model from a specified location.
+     *
+     * <p>MXNet Model looks for {MODEL_NAME}-symbol.json and {MODEL_NAME}-{EPOCH}.params files in
+     * the specified directory. If the parameter file cannot be resolved with {@code prefix}, the
+     * directory name is tried as the prefix instead. By default, it will pick up the latest epoch
+     * of the parameter file. However, users can explicitly specify an epoch to be loaded:
+     *
+     * <pre>{@code
+     * Map<String, String> options = new HashMap<>();
+     * options.put("epoch", "3");
+     * model.load(modelPath, "squeezenet", options);
+     * }</pre>
+     *
+     * <p>If the {@code "MxOptimizeFor"} option is present, its value and the model device are
+     * passed to {@code SymbolBlock.optimizeFor} after the parameters are loaded.
+     *
+     * @param modelPath the directory of the model
+     * @param prefix the model file name or path prefix, or {@code null} to use the model name
+     * @param options load model options, may be {@code null}
+     * @throws IOException Exception for file loading
+     */
+    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
+        modelDir = modelPath.toAbsolutePath();
+        if (prefix == null) {
+            prefix = modelName;
+        }
+        Path paramFile = paramPathResolver(prefix, options);
+        if (paramFile == null) {
+            // fall back to the directory name as the file prefix
+            prefix = modelDir.toFile().getName();
+            paramFile = paramPathResolver(prefix, options);
+            if (paramFile == null) {
+                throw new FileNotFoundException(
+                        "Parameter file with prefix: " + prefix + " not found in: " + modelDir);
+            }
+        }
+
+        if (getSymbolBlock() == null) {
+            // load MxSymbolBlock
+            Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
+            if (Files.notExists(symbolFile)) {
+                throw new FileNotFoundException(
+                        "Symbol file not found: "
+                                + symbolFile
+                                + ", please set block manually for imperative model.");
+            }
+
+            // TODO: change default name "data" to model-specific one
+            setMxSymbolBlock(SymbolBlock.createMxSymbolBlock(this, symbolFile));
+        }
+        loadParameters(paramFile);
+        // TODO: Check if Symbol has all names that params file have
+        if (options != null && options.containsKey("MxOptimizeFor")) {
+            String optimization = (String) options.get("MxOptimizeFor");
+            getSymbolBlock().optimizeFor(optimization, getDevice());
+        }
+    }
+
+    /**
+     * Resolves the parameter file {@code <prefix>-<epoch>.params} in the model directory.
+     *
+     * @param prefix the model file prefix
+     * @param options load options; an {@code "epoch"} entry selects the epoch explicitly
+     * @return the parameter file path, or {@code null} if looking up the epoch throws {@link
+     *     FileNotFoundException}
+     * @throws IOException when another IO operation fails while looking up the epoch
+     */
+    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
+        try {
+            int epoch = getEpoch(prefix, options);
+            return getModelDir()
+                    .resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch));
+        } catch (FileNotFoundException ignored) {
+            // null tells the caller to try another prefix
+            return null;
+        }
+    }
+
+    /**
+     * Returns the epoch given by the {@code "epoch"} option, or else the one computed by {@code
+     * Utils.getCurrentEpoch} for the model directory and {@code prefix}.
+     */
+    private int getEpoch(String prefix, Map<String, ?> options) throws IOException {
+        if (options != null) {
+            Object epochOption = options.get("epoch");
+            if (epochOption != null) {
+                return Integer.parseInt(epochOption.toString());
+            }
+        }
+        return Utils.getCurrentEpoch(getModelDir(), prefix);
+    }
+
+    /**
+     * Loads the arrays in {@code paramFile} into the matching parameters of the symbol block.
+     *
+     * <p>Parameters of the block that are not in the file are registered as the block's input
+     * names. The model data type is taken from the first array in the file, if any.
+     *
+     * @param paramFile the MXNet parameter file to load
+     * @throws IllegalArgumentException if an array in the file has no name, or its name does not
+     *     match any parameter of the symbol block
+     */
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    private void loadParameters(Path paramFile) {
+        NDList paramNDlist = JnaUtils.loadNdArray(this, paramFile, getDevice());
+
+        List<Parameter> parameters = getSymbolBlock().getAllParameters();
+        // insertion-ordered, so the leftover input names keep the block's parameter order
+        Map<String, Parameter> map = new LinkedHashMap<>();
+        parameters.forEach(p -> map.put(p.getName(), p));
+
+        for (NDArray nd : paramNDlist) {
+            String key = nd.getName();
+            if (key == null) {
+                throw new IllegalArgumentException(
+                        "Array names must be present in parameter file");
+            }
+
+            String paramName = toParameterName(key);
+            Parameter parameter = map.remove(paramName);
+            if (parameter == null) {
+                throw new IllegalArgumentException(
+                        "Parameter "
+                                + paramName
+                                + " from "
+                                + paramFile
+                                + " is not defined in the symbol block");
+            }
+            parameter.setArray(nd);
+        }
+        // parameters without a stored array are treated as model inputs
+        getSymbolBlock().setInputNames(new ArrayList<>(map.keySet()));
+
+        // TODO: Find a better way to infer model DataType from SymbolBlock.
+        if (!paramNDlist.isEmpty()) {
+            dataType = paramNDlist.head().getDataType();
+        }
+        logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
+    }
+
+    /**
+     * Strips the prefix up to and including the first {@code ':'} from a parameter file key.
+     *
+     * @param key the array name stored in the parameter file
+     * @return the part after the first {@code ':'}, or {@code key} itself if it has no colon
+     */
+    private static String toParameterName(String key) {
+        int colon = key.indexOf(':');
+        return colon < 0 ? key : key.substring(colon + 1);
+    }
+
+    /**
+     * Get the modelDir from the Model.
+     *
+     * @return {@link Path} modelDir for the Model
+     */
+    public Path getModelDir() {
+        return modelDir;
+    }
+
+    /**
+     * Set the modelDir for the Model.
+     *
+     * @param modelDir {@link Path}
+     */
+    public void setModelDir(Path modelDir) {
+        this.modelDir = modelDir;
+    }
+
+    /**
+     * Get the symbolBlock of the Model.
+     *
+     * @return {@link SymbolBlock}
+     */
+    public SymbolBlock getSymbolBlock() {
+        return symbolBlock;
+    }
+
+    /**
+     * Set the symbolBlock for the Model.
+     *
+     * @param symbolBlock {@link SymbolBlock}
+     */
+    public void setMxSymbolBlock(SymbolBlock symbolBlock) {
+        this.symbolBlock = symbolBlock;
+    }
+
+    /**
+     * Get the name of the Model.
+     *
+     * @return modelName
+     */
+    public String getModelName() {
+        return modelName;
+    }
+
+    /**
+     * Set the model name for the Model.
+     *
+     * @param modelName for the Model
+     */
+    public final void setModelName(String modelName) {
+        this.modelName = modelName;
+    }
+
+    /**
+     * Get data type for the Model.
+     *
+     * @return {@link DataType}
+     */
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    /**
+     * Set data type for the Model.
+     *
+     * @param dataType {@link DataType}
+     */
+    public final void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    /**
+     * Get input data of the Model.
+     *
+     * @return {@link PairList} inputData
+     */
+    public PairList<String, Shape> getInputData() {
+        return inputData;
+    }
+
+    /**
+     * Set input data for the Model.
+     *
+     * @param inputData {@link PairList}
+     */
+    public void setInputData(PairList<String, Shape> inputData) {
+        this.inputData = inputData;
+    }
+
+    /**
+     * Get the Artifact Object from artifacts by key.
+     *
+     * @param key for the Artifact Object
+     * @return Artifact {@link Object} instance
+     */
+    public Object getArtifact(String key) {
+        return artifacts.get(key);
+    }
+
+    /**
+     * Set the Artifact Object for artifacts.
+     *
+     * @param key for the Artifact
+     * @param artifact {@link Object}
+     */
+    public void setArtifact(String key, Object artifact) {
+        artifacts.put(key, artifact);
+    }
+
+    /**
+     * Get the property from properties by key.
+     *
+     * @param key {@link String}
+     * @return {@link String} property
+     */
+    public String getProperty(String key) {
+        return properties.get(key);
+    }
+
+    /**
+     * Set the property for the Model.
+     *
+     * @param key for the property
+     * @param property value of the property
+     */
+    public void setProperties(String key, String property) {
+        this.properties.put(key, property);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        if (device == null) {
+            return super.getDevice();
+        }
+        return device;
+    }
+
+    /**
+     * Frees all sub-resources and drops references to the symbol block, artifacts and
+     * properties; calls after the first one are no-ops.
+     */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug("Start to free Model instance: {}", getModelName());
+            // release sub resources
+            super.freeSubResources();
+            // release itself
+            this.symbolBlock = null;
+            this.artifacts = null;
+            this.properties = null;
+            setClosed(true);
+            logger.debug("Finish to free Model instance: {}", getModelName());
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
new file mode 100644
index 0000000..b96941f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.types.DataType;
+
+/** Helper to convert between {@link DataType} and the MXNet internal DataTypes. */
+public final class MxDataType {
+
+    /** Maps each supported {@link DataType} to its MXNet type name; unmodifiable. */
+    private static final Map<DataType, String> TO_MX = createMapToMx();
+
+    /** Inverse of {@link #TO_MX}, derived from it so the two lookups cannot disagree. */
+    private static final Map<String, DataType> FROM_MX = createMapFromMx(TO_MX);
+
+    private MxDataType() {}
+
+    /**
+     * Builds the {@link DataType} to MXNet type-name mapping.
+     *
+     * @return an unmodifiable map
+     */
+    private static Map<DataType, String> createMapToMx() {
+        Map<DataType, String> map = new ConcurrentHashMap<>();
+        map.put(DataType.FLOAT32, "float32");
+        map.put(DataType.FLOAT64, "float64");
+        map.put(DataType.INT32, "int32");
+        map.put(DataType.INT64, "int64");
+        map.put(DataType.UINT8, "uint8");
+        return Collections.unmodifiableMap(map);
+    }
+
+    /**
+     * Inverts {@code forward} into an MXNet type-name to {@link DataType} mapping.
+     *
+     * @param forward the {@link DataType} to MXNet type-name mapping to invert
+     * @return an unmodifiable map
+     */
+    private static Map<String, DataType> createMapFromMx(Map<DataType, String> forward) {
+        Map<String, DataType> map = new ConcurrentHashMap<>();
+        for (Map.Entry<DataType, String> entry : forward.entrySet()) {
+            map.put(entry.getValue(), entry.getKey());
+        }
+        return Collections.unmodifiableMap(map);
+    }
+
+    /**
+     * Converts a MXNet type String into a {@link DataType}.
+     *
+     * @param mxType the MXNet type name to convert, for example {@code "float32"}
+     * @return the matching {@link DataType}, or {@code null} if {@code mxType} is not supported
+     */
+    public static DataType fromMx(String mxType) {
+        return FROM_MX.get(mxType);
+    }
+
+    /**
+     * Converts a {@link DataType} into the corresponding MXNet type String.
+     *
+     * @param jType the java {@link DataType} to convert
+     * @return the MXNet type name, or {@code null} if {@code jType} is not supported
+     */
+    public static String toMx(DataType jType) {
+        return TO_MX.get(jType);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
new file mode 100644
index 0000000..5740d9b
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.NativeResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An auto closable resource whose life cycle can be managed by its parent {@link MxResource}
+ * instance. Meanwhile, it manages the life cycle of its child {@link MxResource} instances.
+ */
+public class MxResource extends NativeResource<Pointer> {
+
+    private static final Logger logger = LoggerFactory.getLogger(MxResource.class);
+
+    /**
+     * Whether this instance has been closed. Kept per instance rather than {@code static} so that
+     * closing one resource does not make every other resource skip its own cleanup.
+     */
+    private boolean closed;
+
+    protected Device device;
+
+    private MxResource parent;
+
+    private ConcurrentHashMap<String, MxResource> subResources;
+
+    /** Creates a resource without a parent; it is not registered with any other resource. */
+    protected MxResource() {
+        super();
+        setParent(null);
+    }
+
+    /**
+     * Creates a resource with the given uid and registers it as a sub resource of {@code parent},
+     * or of the system resource when {@code parent} is {@code null}.
+     *
+     * @param parent the owning {@link MxResource}, may be {@code null}
+     * @param uid the unique id of this resource
+     */
+    protected MxResource(MxResource parent, String uid) {
+        super(uid);
+        setParent(parent);
+        registerWith(parent);
+    }
+
+    /**
+     * Creates a resource with a random uid and registers it as a sub resource of {@code parent},
+     * or of the system resource when {@code parent} is {@code null}.
+     *
+     * @param parent the owning {@link MxResource}, may be {@code null}
+     */
+    protected MxResource(MxResource parent) {
+        this(parent, UUID.randomUUID().toString());
+    }
+
+    /**
+     * Creates a resource wrapping a native handle and registers it as a sub resource of {@code
+     * parent}, or of the system resource when {@code parent} is {@code null}.
+     *
+     * @param parent the owning {@link MxResource}, may be {@code null}
+     * @param handle the native handle to wrap
+     */
+    protected MxResource(MxResource parent, Pointer handle) {
+        super(handle);
+        setParent(parent);
+        registerWith(parent);
+    }
+
+    /**
+     * Registers this instance as a sub resource of {@code parent}, or of the system resource.
+     *
+     * @param parent the owning {@link MxResource}, may be {@code null}
+     */
+    private void registerWith(MxResource parent) {
+        MxResource owner = parent != null ? parent : BaseMxResource.getSystemMxResource();
+        owner.addSubResource(this);
+    }
+
+    /**
+     * Add the sub {@link MxResource} under the current instance.
+     *
+     * @param subMxResource the instance to be added
+     */
+    public void addSubResource(MxResource subMxResource) {
+        getSubResource().put(subMxResource.getUid(), subMxResource);
+    }
+
+    /**
+     * Free all sub {@link MxResource} instances of the current instance. An exception thrown
+     * while closing one sub resource is logged and does not stop the others from being closed.
+     */
+    public void freeSubResources() {
+        if (subResourceInitialized()) {
+            for (MxResource subResource : subResources.values()) {
+                try {
+                    subResource.close();
+                } catch (Exception e) {
+                    logger.error("MxResource close failed.", e);
+                }
+            }
+            subResources = null;
+        }
+    }
+
+    /**
+     * Check whether {@code subResources} has been initialized.
+     *
+     * @return {@code true} if the sub resource map exists
+     */
+    public boolean subResourceInitialized() {
+        return subResources != null;
+    }
+
+    /**
+     * Get the {@code subResources} of the {@link MxResource}, creating the map on first use.
+     *
+     * @return subResources
+     */
+    public ConcurrentHashMap<String, MxResource> getSubResource() {
+        if (!subResourceInitialized()) {
+            subResources = new ConcurrentHashMap<>();
+        }
+        return subResources;
+    }
+
+    protected final void setParent(MxResource parent) {
+        this.parent = parent;
+    }
+
+    /**
+     * Get parent {@link MxResource} of the current instance.
+     *
+     * @return the parent {@link MxResource}, or {@code null} if none was given
+     */
+    public MxResource getParent() {
+        return this.parent;
+    }
+
+    /**
+     * Set the {@link Device} for the {@link MxResource}.
+     *
+     * @param device {@link Device}
+     */
+    public void setDevice(Device device) {
+        this.device = device;
+    }
+
+    /**
+     * Returns the {@link Device} of this {@code MxResource}.
+     *
+     * <p>A device set through {@link #setDevice(Device)} takes precedence; otherwise the parent's
+     * device (or {@code null} without a parent) is passed through {@code Device.defaultIfNull}.
+     *
+     * @return the {@link Device} of this {@code MxResource}
+     */
+    public Device getDevice() {
+        if (device != null) {
+            return device;
+        }
+        Device parentDevice = getParent() == null ? null : getParent().getDevice();
+        return Device.defaultIfNull(parentDevice);
+    }
+
+    /**
+     * Sets closed for this MxResource instance.
+     *
+     * @param isClosed whether this {@link MxResource} is closed
+     */
+    public final void setClosed(boolean isClosed) {
+        this.closed = isClosed;
+    }
+
+    /**
+     * Get the attribute closed for the MxResource to check out whether it is closed.
+     *
+     * @return closed
+     */
+    public boolean getClosed() {
+        return closed;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        freeSubResources();
+        setClosed(true);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
new file mode 100644
index 0000000..6869cf6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/**
+ * An {@code MxResourceList} represents a sequence of {@link MxResource}s with names.
+ *
+ * <p>Each {@link MxResource} in this list can optionally have a name. You can use the name to look
+ * up an MxResource in the MxResourceList.
+ *
+ * @see MxResource
+ */
+public class MxResourceList extends PairList<String, MxResource> {
+
+    /** Creates an empty {@code MxResourceList}. */
+    public MxResourceList() {}
+
+    /**
+     * Constructs an empty {@code MxResourceList} with the specified initial capacity.
+     *
+     * @param initialCapacity the initial capacity of the list
+     * @throws IllegalArgumentException if the specified initial capacity is negative
+     */
+    public MxResourceList(int initialCapacity) {
+        super(initialCapacity);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} from the specified keys and values.
+     *
+     * @param keys the key list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @param values the value list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @throws IllegalArgumentException if the keys and values size are different
+     */
+    public MxResourceList(List<String> keys, List<MxResource> values) {
+        super(keys, values);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} containing the elements of the specified list of pairs.
+     *
+     * @param list the list containing the elements to be placed into this {@code MxResourceList}
+     */
+    public MxResourceList(List<Pair<String, MxResource>> list) {
+        super(list);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} containing the entries of the specified map.
+     *
+     * @param map the map containing keys and values
+     */
+    public MxResourceList(Map<String, MxResource> map) {
+        super(map);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
new file mode 100644
index 0000000..3362e03
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.StringJoiner;
+import java.util.function.IntFunction;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+
+/** An internal helper for creating the MXNet operator parameters. */
+public class OpParams extends PairList<String, Object> {
+
+    /** MXNet context string sent for any CPU device, regardless of its device id. */
+    private static final String MXNET_CPU = "cpu(0)";
+
+    /**
+     * Sets the Shape parameter, replacing any previously set shape. A {@code null} shape is
+     * ignored.
+     *
+     * @param shape the shape to set
+     */
+    public void setShape(Shape shape) {
+        if (shape != null) {
+            setParam("shape", shape.toString());
+        }
+    }
+
+    /**
+     * Sets the device to use for the operation. Any CPU device is sent as {@code cpu(0)}; other
+     * devices are sent as their {@code toString()} value.
+     *
+     * @param device the device to use for the operation
+     */
+    public void setDevice(Device device) {
+        setParam("ctx", "cpu".equals(device.getDeviceType()) ? MXNET_CPU : device.toString());
+    }
+
+    /**
+     * Sets the dataType to use for the operation; a {@code null} dataType is ignored.
+     *
+     * @param dataType the dataType to use for the operation
+     */
+    public void setDataType(org.apache.mxnet.ndarray.types.DataType dataType) {
+        if (dataType != null) {
+            setParam("dtype", MxDataType.toMx(dataType));
+        }
+    }
+
+    /**
+     * Sets the sparseFormat to use for the operation; a {@code null} sparseFormat is ignored.
+     *
+     * @param sparseFormat the sparseFormat to use for the operation
+     */
+    public void setSparseFormat(SparseFormat sparseFormat) {
+        if (sparseFormat != null) {
+            setParam("stype", String.valueOf(sparseFormat.getValue()));
+        }
+    }
+
+    /**
+     * Sets a (potentially existing) parameter to a new value.
+     *
+     * @param paramName the parameter name to update
+     * @param value the value to set the parameter to
+     */
+    public void setParam(String paramName, String value) {
+        remove(paramName);
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter; a {@code null} shape is ignored.
+     *
+     * @param paramName the name of the new parameter
+     * @param shape the value of the new parameter
+     */
+    public void addParam(String paramName, Shape shape) {
+        if (shape != null) {
+            add(paramName, shape.toString());
+        }
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, String value) {
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, int value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, long value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, double value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, float value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter, sent as {@code True} or {@code False}.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, boolean value) {
+        add(paramName, value ? "True" : "False");
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, Number value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, int... tuple) {
+        addTuple(paramName, tuple.length, i -> String.valueOf(tuple[i]));
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, long... tuple) {
+        addTuple(paramName, tuple.length, i -> String.valueOf(tuple[i]));
+    }
+
+    /**
+     * Adds a parameter with tuple value, formatted as {@code (a, b, c)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, float... tuple) {
+        addTuple(paramName, tuple.length, i -> String.valueOf(tuple[i]));
+    }
+
+    /**
+     * Adds {@code paramName} with a tuple value of the form {@code (e0, e1, ...)}.
+     *
+     * @param paramName the name of the new parameter
+     * @param length the number of tuple elements
+     * @param element renders the tuple element at the given index as a String
+     */
+    private void addTuple(String paramName, int length, IntFunction<String> element) {
+        StringJoiner joiner = new StringJoiner(", ", "(", ")");
+        for (int i = 0; i < length; ++i) {
+            joiner.add(element.apply(i));
+        }
+        add(paramName, joiner.toString());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
new file mode 100644
index 0000000..739f6e0
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.mxnet.exception.TranslateException;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.translate.Translator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code Predictor} class provides a session for model inference.
+ *
+ * <p>You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on
+ * a {@link Model}.
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ * @see Model
+ * @see Translator
+ * @see <a href="http://docs.djl.ai/docs/development/memory_management.html">The guide on memory
+ *     management</a>
+ * @see <a
+ *     href="https://github.com/deepjavalibrary/djl/blob/master/examples/docs/multithread_inference.md">The
+ *     guide on running multi-threaded inference</a>
+ * @see <a href="http://docs.djl.ai/docs/development/inference_performance_optimization.html">The
+ *     guide on inference performance optimization</a>
+ */
+public class Predictor<I, O> extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);
+
+    private final Translator<I, O> translator;
+    private final Model model;
+
+    /**
+     * Creates a new instance of {@code Predictor} with the given {@link Model} and {@link
+     * Translator}. The predictor is registered as a sub resource of {@code model}.
+     *
+     * @param model the model on which the predictions are based
+     * @param translator the translator to be used
+     */
+    public Predictor(Model model, Translator<I, O> translator) {
+        super(model);
+        this.model = model;
+        this.translator = translator;
+    }
+
+    /**
+     * Predicts a batch of items: each input is translated, forwarded through the model's symbol
+     * block and translated back.
+     *
+     * @param input the inputs
+     * @return the outputs, one per input and in input order
+     * @throws TranslateException if an error occurs while processing the inputs or outputs
+     */
+    public List<O> predict(List<I> input) {
+        NDList[] ndLists = processInputs(input);
+        for (int i = 0; i < ndLists.length; ++i) {
+            ndLists[i] = forward(ndLists[i]);
+        }
+        return processOutput(ndLists);
+    }
+
+    /**
+     * Predicts a single item for inference.
+     *
+     * @param input input data
+     * @return O the output object defined by the user
+     * @throws TranslateException if an error occurs during prediction
+     */
+    public O predict(I input) {
+        return predict(Collections.singletonList(input)).get(0);
+    }
+
+    /** Runs one translated input through the model's symbol block. */
+    private NDList forward(NDList ndList) {
+        logger.trace("Predictor input data: {}", ndList);
+        return model.getSymbolBlock().forward(ndList);
+    }
+
+    // TODO: add batch predict
+
+    /** Translates every input with the {@link Translator}, preserving order. */
+    private NDList[] processInputs(List<I> inputs) {
+        int batchSize = inputs.size();
+        NDList[] preprocessed = new NDList[batchSize];
+        try {
+            for (int i = 0; i < batchSize; ++i) {
+                preprocessed[i] = translator.processInput(inputs.get(i));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurs when process input items.", e);
+            throw toTranslateException(e);
+        }
+        return preprocessed;
+    }
+
+    /** Translates every forwarded {@link NDList} back with the {@link Translator}, in order. */
+    private List<O> processOutput(NDList[] ndLists) {
+        List<O> outputs = new ArrayList<>(ndLists.length);
+        try {
+            for (NDList ndList : ndLists) {
+                outputs.add(translator.processOutput(ndList));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurs when process output items.", e);
+            throw toTranslateException(e);
+        }
+        return outputs;
+    }
+
+    /**
+     * Returns {@code e} unchanged if it already is a {@link TranslateException}, so translator
+     * failures are not wrapped twice; otherwise wraps it.
+     *
+     * @param e the exception raised while translating
+     * @return the {@link TranslateException} to throw
+     */
+    private static TranslateException toTranslateException(Exception e) {
+        return e instanceof TranslateException ? (TranslateException) e : new TranslateException(e);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
new file mode 100644
index 0000000..7777add
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Symbol} is an internal helper for symbolic model graphs used by the {@link
+ * org.apache.mxnet.nn.SymbolBlock}.
+ *
+ * @see org.apache.mxnet.nn.SymbolBlock
+ */
+public class Symbol extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Symbol.class);
+
+    /** Output names, fetched on the first {@link #getOutputNames()} call and then reused. */
+    private String[] outputs;
+
+    protected Symbol(MxResource parent, Pointer handle) {
+        super(parent, handle);
+    }
+
+    static Symbol loadFromFile(MxResource parent, String path) {
+        Pointer p = JnaUtils.createSymbolFromFile(path);
+        return new Symbol(parent, p);
+    }
+
+    /**
+     * Load {@link Symbol} from the given {@link Path}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param path the {@link Path} to load the {@link Symbol}
+     * @return {@link Symbol}
+     */
+    public static Symbol loadSymbol(MxResource parent, Path path) {
+        return loadFromFile(parent, path.toAbsolutePath().toString());
+    }
+
+    /**
+     * Loads a symbol from a json string.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param json the json string of the symbol.
+     * @return the new symbol
+     */
+    public static Symbol loadJson(MxResource parent, String json) {
+        Pointer pointer = JnaUtils.createSymbolFromString(json);
+        return new Symbol(parent, pointer);
+    }
+
+    /**
+     * Returns the symbol outputs. The names are fetched once and cached.
+     *
+     * @return the symbol outputs
+     */
+    public String[] getOutputNames() {
+        if (this.outputs == null) {
+            this.outputs = JnaUtils.listSymbolOutputs(getHandle());
+        }
+        return this.outputs;
+    }
+
+    /**
+     * Frees the sub resources and the native symbol handle; does nothing if already closed.
+     *
+     * <p>The uid is logged rather than {@link #toJsonString()}, which needs the native handle that
+     * this method releases.
+     */
+    @Override
+    public void close() {
+        if (getClosed()) {
+            return;
+        }
+        logger.debug("Start to free Symbol instance: {}", getUid());
+        super.freeSubResources();
+        Pointer pointer = handle.getAndSet(null);
+        if (pointer != null) {
+            JnaUtils.freeSymbol(pointer);
+        }
+        setClosed(true);
+        logger.debug("Finish to free Symbol instance: {}", getUid());
+    }
+
+    /**
+     * Returns the internal output symbol at the given index.
+     *
+     * @param index the index into the outputs of {@link #getInternals()}
+     * @return the symbol output as a new symbol
+     */
+    public Symbol get(int index) {
+        Pointer pointer = JnaUtils.getSymbolOutput(getInternals().getHandle(), index);
+        return new Symbol(getParent(), pointer);
+    }
+
+    /**
+     * Returns the output symbol with the given name.
+     *
+     * @param name the name of the symbol to return
+     * @return the output symbol
+     * @throws IllegalArgumentException Thrown if no output matches the name
+     */
+    public Symbol get(String name) {
+        String[] out = getInternalOutputNames();
+        int index = Utils.indexOf(out, name);
+        if (index < 0) {
+            throw new IllegalArgumentException("Cannot find output that matches name: " + name);
+        }
+        return get(index);
+    }
+
+    /**
+     * Returns the symbol argument names.
+     *
+     * @return the symbol argument names
+     */
+    public String[] getArgNames() {
+        return JnaUtils.listSymbolArguments(getHandle());
+    }
+
+    /**
+     * Returns the MXNet auxiliary states for the symbol.
+     *
+     * @return the MXNet auxiliary states for the symbol
+     */
+    public String[] getAuxNames() {
+        return JnaUtils.listSymbolAuxiliaryStates(getHandle());
+    }
+
+    /**
+     * Returns the symbol names.
+     *
+     * @return the symbol names
+     */
+    public String[] getAllNames() {
+        return JnaUtils.listSymbolNames(getHandle());
+    }
+
+    /**
+     * Returns the names of the internal outputs that are not listed by {@link #getAllNames()}.
+     *
+     * @return a list of names, in internal output order
+     */
+    public List<String> getLayerNames() {
+        String[] outputNames = getInternalOutputNames();
+        String[] allNames = getAllNames();
+        Set<String> allNamesSet = new LinkedHashSet<>(Arrays.asList(allNames));
+        // drop names also listed by getAllNames() (parameter fields), keeping the output layers
+        return Arrays.stream(outputNames)
+                .filter(n -> !allNamesSet.contains(n))
+                .collect(Collectors.toList());
+    }
+
+    private String[] getInternalOutputNames() {
+        return JnaUtils.listSymbolOutputs(getInternals().getHandle());
+    }
+
+    /**
+     * Returns the symbol internals.
+     *
+     * @return the symbol internals symbol
+     */
+    public Symbol getInternals() {
+        Pointer pointer = JnaUtils.getSymbolInternals(getHandle());
+        return new Symbol(getParent(), pointer);
+    }
+
+    /**
+     * Infers the shapes for all parameters inside a symbol from the given input shapes.
+     *
+     * @param pairs the given input name and shape
+     * @return a map of argument, auxiliary state and output names to their shapes
+     * @throws IllegalArgumentException if no shapes could be inferred
+     */
+    public Map<String, Shape> inferShape(PairList<String, Shape> pairs) {
+        List<List<Shape>> shapes = JnaUtils.inferShape(this, pairs);
+        if (shapes == null) {
+            throw new IllegalArgumentException("Cannot infer shape based on the data provided!");
+        }
+        // index 0, 1 and 2 hold the argument, output and auxiliary shapes respectively
+        List<Shape> argShapes = shapes.get(0);
+        List<Shape> outputShapes = shapes.get(1);
+        List<Shape> auxShapes = shapes.get(2);
+        String[] argNames = getArgNames();
+        String[] auxNames = getAuxNames();
+        String[] outputNames = getOutputNames();
+        Map<String, Shape> shapesMap = new ConcurrentHashMap<>();
+        for (int i = 0; i < argNames.length; i++) {
+            shapesMap.put(argNames[i], argShapes.get(i));
+        }
+        for (int i = 0; i < auxNames.length; i++) {
+            shapesMap.put(auxNames[i], auxShapes.get(i));
+        }
+        for (int i = 0; i < outputNames.length; i++) {
+            shapesMap.put(outputNames[i], outputShapes.get(i));
+        }
+        return shapesMap;
+    }
+
+    /**
+     * [Experimental] Add customized optimization on the Symbol.
+     *
+     * <p>This method can be used with EIA or TensorRT for model acceleration.
+     *
+     * @param backend backend name
+     * @param device the device assigned
+     * @return optimized Symbol
+     */
+    public Symbol optimizeFor(String backend, Device device) {
+        return new Symbol(getParent(), JnaUtils.optimizeFor(this, backend, device));
+    }
+
+    /**
+     * Converts Symbol to json string for saving purpose.
+     *
+     * @return the json string
+     */
+    public String toJsonString() {
+        return JnaUtils.getSymbolString(getHandle());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
new file mode 100644
index 0000000..e88f05d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.engine;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
new file mode 100644
index 0000000..0123a50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that a native error is raised from the underlying MXNet library. */
+public class BaseException extends RuntimeException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message. The detail message is saved for later retrieval by the
+     *     {@link #getMessage()} method.
+     */
+    public BaseException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message (which is saved for later retrieval by the {@link
+     *     #getMessage()} method)
+     * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+     *     or unknown.)
+     */
+    public BaseException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} (which typically contains the class and detail
+     * message of {@code cause}). This constructor is useful for exceptions that are little more
+     * than wrappers for other throwables (for example, {@link
+     * java.security.PrivilegedActionException}).
+     *
+     * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+     *     or unknown.)
+     */
+    public BaseException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
new file mode 100644
index 0000000..f6f26fc
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that a JNA call into the MXNet native library did not work as expected. */
+public class JnaCallException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ */
+ public JnaCallException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public JnaCallException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public JnaCallException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
new file mode 100644
index 0000000..5455633
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that model parameters are malformed or not in the expected format. */
+public class MalformedModelException extends ModelException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ */
+ public MalformedModelException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public MalformedModelException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public MalformedModelException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
new file mode 100644
index 0000000..94a12e7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Base class for model-related exceptions, such as {@link MalformedModelException}. */
+public class ModelException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ */
+ public ModelException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public ModelException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public ModelException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
new file mode 100644
index 0000000..b17758a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that the translate pipeline does not work as expected. */
+public class TranslateException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ */
+ public TranslateException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public TranslateException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public TranslateException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
new file mode 100644
index 0000000..d464bfd
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the exceptions thrown by the Apache MXNet Java front-end. */
+package org.apache.mxnet.exception;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
new file mode 100644
index 0000000..a2da116
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A FunctionInfo represents an operator (i.e. a function) within the MXNet Engine. */
+public class FunctionInfo {
+
+ private final Pointer handle;
+ private final String name;
+ private final PairList<String, String> arguments;
+
+ private static final Logger logger = LoggerFactory.getLogger(FunctionInfo.class);
+
+ FunctionInfo(Pointer pointer, String functionName, PairList<String, String> arguments) {
+ this.handle = pointer;
+ this.name = functionName;
+ this.arguments = arguments;
+ }
+
+ /**
+ * Returns the name of the operator.
+ *
+ * @return the name of the operator
+ */
+ public String getFunctionName() {
+ return name;
+ }
+
+ /**
+ * Returns the names of the params to the operator.
+ *
+ * @return the names of the params to the operator
+ */
+ public List<String> getArgumentNames() {
+ return arguments.keys();
+ }
+
+ /**
+ * Returns the types of the operator arguments.
+ *
+ * @return the types of the operator arguments
+ */
+ public List<String> getArgumentTypes() {
+ return arguments.values();
+ }
+
+ /**
+ * Calls an operator with the given arguments, writing the results into {@code dest}.
+ *
+ * @param src the input NDArray(s) to the operator
+ * @param dest the destination NDArray(s) to be overwritten with the result of the operator
+ * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
+ * String>}
+ * @return the number of outputs reported by the native invocation (not an error code)
+ */
+ public int invoke(NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+ checkDevices(src);
+ checkDevices(dest);
+ return JnaUtils.imperativeInvoke(handle, src, dest, params).size();
+ }
+
+ /**
+ * Calls an operator with the given arguments and wraps its outputs in new NDArrays.
+ *
+ * @param parent the {@link MxResource} passed to every NDArray created for an output
+ * @param src the input NDArray(s) to the operator
+ * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
+ * String>}
+ * @return the output NDArrays of the operator
+ */
+ public NDArray[] invoke(MxResource parent, NDArray[] src, PairList<String, ?> params) {
+ checkDevices(src);
+ PairList<Pointer, SparseFormat> pairList =
+ JnaUtils.imperativeInvoke(handle, src, null, params);
+ // each output is a (handle, storage format) pair; non-dense outputs keep their format
+ return pairList.stream()
+ .map(pair -> pair.getValue() == SparseFormat.DENSE
+ ? NDArray.create(parent, pair.getKey())
+ : NDArray.create(parent, pair.getKey(), pair.getValue()))
+ .toArray(NDArray[]::new);
+ }
+
+ /** Warns (only when debug logging is on) if {@code arrays} is not all on one device. */
+ private void checkDevices(NDArray[] arrays) {
+ if (arrays == null || arrays.length < 2 || !logger.isDebugEnabled()) {
+ return;
+ }
+ Device device = arrays[0].getDevice();
+ for (int i = 1; i < arrays.length; ++i) {
+ if (!device.equals(arrays[i].getDevice())) {
+ logger.warn(
+ "Please make sure all the NDArrays are in the same device. You can call toDevice() to move the NDArray to the desired Device.");
+ break; // a single warning per call is enough
+ }
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
new file mode 100644
index 0000000..46cff30
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
@@ -0,0 +1,893 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import com.sun.jna.ptr.PointerByReference;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.CachedOp;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.DeviceType;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.exception.JnaCallException;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class containing utilities to interact with the MXNet Engine's Java Native Access (JNA) layer.
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class JnaUtils {
+
+ private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class);
+
+ public static final MxnetLibrary LIB = LibUtils.loadLibrary();
+ // Reusable out-parameter holders; the reset hook (presumably run on recycle) nulls the value.
+ public static final ObjectPool<PointerByReference> REFS =
+ new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));
+
+ private static final String[] OP_NAME_PREFIX = {
+ "_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"
+ };
+ // Static fields initialize in declaration order, so LIB and REFS are ready before these run.
+ private static final Map<String, FunctionInfo> OPS = getNdArrayFunctions();
+ // NOTE(review): an exception in either initializer surfaces as ExceptionInInitializerError.
+
+ private static final Set<String> FEATURES = getFeaturesInternal();
+
+ public static final String[] EMPTY_ARRAY = new String[0];
+
+ private JnaUtils() {
+ // static utility class; never instantiated
+ }
+
+ /** Numpy mode statuses; setNumpyMode passes ordinal() to MXNet, so keep this order. */
+ public enum NumpyMode {
+ OFF,
+ THREAD_LOCAL_ON,
+ GLOBAL_ON
+ }
+
+ public static void waitAll() {
+ checkCall(LIB.MXNDArrayWaitAll());
+ }
+
+ /** Sets MXNet's numpy mode; the native call receives {@code mode.ordinal()}. */
+ public static void setNumpyMode(NumpyMode mode) {
+ checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), IntBuffer.allocate(1)));
+ }
+
+ /////////////////////////////////
+ // Related to CachedOp
+ /////////////////////////////////
+ /**
+ * Builds a {@link CachedOp} for {@code block}. Initialized parameters are reported to MXNet
+ * through {@code param_indices}; the rest are assumed to be data inputs ({@code data_indices}).
+ */
+ public static CachedOp createCachedOp(SymbolBlock block, MxResource parent) {
+ Symbol symbol = block.getSymbol();
+ List<Parameter> parameters = block.getAllParameters();
+
+ // positions (within all inputs) of the data inputs, keyed by name, and of the parameters
+ PairList<String, Integer> dataIndices = new PairList<>();
+ List<Integer> paramIndices = new ArrayList<>();
+ int position = 0;
+ for (Parameter parameter : parameters) {
+ if (!parameter.isInitialized()) {
+ dataIndices.add(parameter.getName(), position); // assumed to be a data input
+ } else {
+ paramIndices.add(position);
+ }
+ position++;
+ }
+
+ // static_alloc and static_shape are enabled by default
+ String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
+ String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
+ Pointer symbolHandle = symbol.getHandle();
+ PointerByReference cachedOpRef = REFS.acquire();
+ checkCall(
+ LIB.MXCreateCachedOp(
+ symbolHandle, keys.length, keys, values, cachedOpRef, (byte) 0));
+ Pointer cachedOpHandle = cachedOpRef.getValue();
+ REFS.recycle(cachedOpRef);
+
+ return new CachedOp(parent, cachedOpHandle, parameters, paramIndices, dataIndices);
+ }
+
+ public static void freeCachedOp(Pointer handle) {
+ checkCall(LIB.MXFreeCachedOp(handle));
+ }
+
+ /////////////////////////////////
+ // About Symbol
+ /////////////////////////////////
+ public static Pointer createSymbolFromFile(String path) {
+ PointerByReference symbolRef = REFS.acquire();
+ checkCall(LIB.MXSymbolCreateFromFile(path, symbolRef));
+ Pointer symbolHandle = symbolRef.getValue();
+ REFS.recycle(symbolRef);
+ return symbolHandle;
+ }
+
+ public static Pointer createSymbolFromString(String json) {
+ PointerByReference symbolRef = REFS.acquire();
+ checkCall(LIB.MXSymbolCreateFromJSON(json, symbolRef));
+ Pointer symbolHandle = symbolRef.getValue();
+ REFS.recycle(symbolRef);
+ return symbolHandle;
+ }
+
+ /** Lists the outputs of {@code symbol} as reported by MXSymbolListOutputs. */
+ public static String[] listSymbolOutputs(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ checkCall(LIB.MXSymbolListOutputs(symbol, count, namesRef));
+ String[] outputs = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return outputs;
+ }
+
+ public static String printSymbol(Pointer symbol) {
+ String[] text = new String[1];
+ checkCall(LIB.NNSymbolPrint(symbol, text));
+ return text[0];
+ }
+
+ public static void freeSymbol(Pointer symbol) {
+ checkCall(LIB.MXSymbolFree(symbol));
+ }
+
+ /** Lists the arguments of {@code symbol} (MXSymbolListArguments). */
+ public static String[] listSymbolArguments(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+
+ checkCall(LIB.MXSymbolListArguments(symbol, count, namesRef));
+ String[] arguments = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return arguments;
+ }
+
+ /** Lists the auxiliary states of {@code symbol} (MXSymbolListAuxiliaryStates). */
+ public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+
+ checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, count, namesRef));
+ String[] auxStates = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return auxStates;
+ }
+
+ /** Lists the input names of {@code symbol} (NNSymbolListInputNames). */
+ public static String[] listSymbolNames(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ // NOTE(review): option 0 presumably selects all inputs (arguments and aux states); confirm
+ checkCall(LIB.NNSymbolListInputNames(symbol, 0, count, namesRef));
+ String[] names = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return names;
+ }
+
+ public static Pointer getSymbolInternals(Pointer symbol) {
+ PointerByReference internalsRef = REFS.acquire();
+ checkCall(LIB.MXSymbolGetInternals(symbol, internalsRef));
+ Pointer internals = internalsRef.getValue();
+ REFS.recycle(internalsRef);
+ return internals;
+ }
+
+ private static List<Shape> recoverShape(
+ NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
+ int shapeCount = (int) size.getValue().longValue();
+ List<Shape> result = new ArrayList<>(shapeCount);
+ if (shapeCount == 0) {
+ return result;
+ }
+ int[] dims = nDim.getValue().getIntArray(0, shapeCount);
+ int flattenedLength = Arrays.stream(dims).sum();
+ // assumes MXNet packs all dimensions contiguously, starting at the first shape's pointer
+ long[] flattenedShapes = data.getValue().getPointer(0).getLongArray(0, flattenedLength);
+ int offset = 0;
+ for (int dim : dims) {
+ result.add(new Shape(Arrays.copyOfRange(flattenedShapes, offset, offset + dim)));
+ offset += dim;
+ }
+ return result;
+ }
+
+ /**
+ * Infers the shapes of {@code symbol}'s arguments, outputs and auxiliary states from the
+ * known shapes in {@code args}.
+ *
+ * @param symbol the symbol whose shapes are inferred
+ * @param args known shapes, keyed by argument name
+ * @return the argument, output and auxiliary shape lists, in that order, or {@code null} when
+ * MXNet reports that the inference is incomplete
+ */
+ public static List<List<Shape>> inferShape(Symbol symbol, PairList<String, Shape> args) {
+ Pointer handler = symbol.getHandle();
+ int numArgs = args.size();
+ String[] keys = args.keys().toArray(new String[0]);
+ // CSR layout: shape i is flattened[indPtr[i], indPtr[i + 1]), so indPtr must be cumulative
+ long[] indPtr = new long[numArgs + 1];
+ Shape flattened = new Shape();
+ for (int i = 0; i < numArgs; i++) {
+ Shape shape = args.valueAt(i);
+ indPtr[i + 1] = indPtr[i] + shape.dimension();
+ flattened = flattened.addAll(shape);
+ }
+ long[] flattenedShapeArray = flattened.getShape();
+
+ NativeSizeByReference inShapeSize = new NativeSizeByReference();
+ PointerByReference inShapeNDim = REFS.acquire();
+ PointerByReference inShapeData = REFS.acquire();
+ NativeSizeByReference outShapeSize = new NativeSizeByReference();
+ PointerByReference outShapeNDim = REFS.acquire();
+ PointerByReference outShapeData = REFS.acquire();
+ NativeSizeByReference auxShapeSize = new NativeSizeByReference();
+ PointerByReference auxShapeNDim = REFS.acquire();
+ PointerByReference auxShapeData = REFS.acquire();
+ IntBuffer complete = IntBuffer.allocate(1);
+ checkCall(
+ LIB.MXSymbolInferShape64(
+ handler, numArgs, keys, indPtr, flattenedShapeArray,
+ inShapeSize, inShapeNDim, inShapeData,
+ outShapeSize, outShapeNDim, outShapeData,
+ auxShapeSize, auxShapeNDim, auxShapeData,
+ complete));
+ // a zero 'complete' flag presumably means MXNet needs more shape information
+ List<List<Shape>> shapes = null;
+ if (complete.get() != 0) {
+ shapes =
+ Arrays.asList(
+ recoverShape(inShapeSize, inShapeNDim, inShapeData),
+ recoverShape(outShapeSize, outShapeNDim, outShapeData),
+ recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
+ }
+
+ // recoverShape copied the native data, so the pooled references can be returned now
+ Arrays.asList(inShapeNDim, inShapeData, outShapeNDim, outShapeData, auxShapeNDim, auxShapeData)
+ .forEach(REFS::recycle);
+ return shapes;
+ }
+
+ /**
+ * Optimizes {@code current} for {@code backend} on {@code device} and returns the handle of
+ * the resulting symbol. All optional inputs of MXOptimizeForBackend are passed empty.
+ */
+ public static Pointer optimizeFor(Symbol current, String backend, Device device) {
+ // TODO: Support partition on parameters
+ PointerByReference returnedSymbolHandle = REFS.acquire();
+ // placeholders: parameters are neither passed in nor updated
+ PointerByReference[] placeHolders = new PointerByReference[6];
+ Arrays.setAll(placeHolders, i -> REFS.acquire());
+ // one slot each instead of zero, in case the native call writes to these out-parameters
+ IntBuffer newArgsCount = IntBuffer.allocate(1);
+ IntBuffer newAuxCount = IntBuffer.allocate(1);
+ // TODO: check that the 22nd argument, (byte) 0, matches the native parameter type
+ checkCall(
+ LIB.MXOptimizeForBackend(
+ current.getHandle(),
+ backend,
+ DeviceType.toDeviceType(device),
+ returnedSymbolHandle,
+ 0,
+ placeHolders[0],
+ 0,
+ placeHolders[1],
+ 0,
+ new String[0],
+ new String[0],
+ 0,
+ new String[0],
+ new long[0],
+ new int[0],
+ 0,
+ new String[0],
+ new int[0],
+ 0,
+ new String[0],
+ new int[0],
+ (byte) 0,
+ newArgsCount,
+ placeHolders[2],
+ placeHolders[3],
+ newAuxCount,
+ placeHolders[4],
+ placeHolders[5]));
+ Pointer ptr = returnedSymbolHandle.getValue();
+ REFS.recycle(returnedSymbolHandle);
+ Arrays.stream(placeHolders).forEach(REFS::recycle);
+ return ptr;
+ }
+
+ public static String getSymbolString(Pointer symbol) {
+ String[] json = new String[1];
+ checkCall(LIB.MXSymbolSaveToJSON(symbol, json));
+ return json[0];
+ }
+
+ public static Pointer getSymbolOutput(Pointer symbol, int index) {
+ PointerByReference outputRef = REFS.acquire();
+ checkCall(LIB.MXSymbolGetOutput(symbol, index, outputRef));
+ Pointer output = outputRef.getValue();
+ REFS.recycle(outputRef);
+ return output;
+ }
+
+ /**
+ * Loads the NDArrays saved in the file at {@code path} and places them on {@code device}.
+ *
+ * <p>MXNet always loads NDArrays on the CPU, so for any other device the loaded arrays are
+ * moved with {@link NDList#toDevice} and the CPU-side list is closed.
+ *
+ * @param parent the {@link MxResource} passed to each created NDArray
+ * @param path the file to load
+ * @param device the device the returned NDArrays should be on
+ * @return the loaded NDArrays, named when the file stores names
+ * @throws IllegalStateException if the file stores a different number of names and arrays
+ */
+ public static NDList loadNdArray(MxResource parent, Path path, Device device) {
+ PairList<String, Pointer> loaded = loadNdArrayFromFile(path.toString());
+ List<String> names = loaded.keys();
+ List<Pointer> handles = loaded.values();
+
+ NDList ndList = new NDList();
+ for (int i = 0; i < handles.size(); i++) {
+ NDArray array = NDArray.create(parent, handles.get(i));
+ if (names.get(i) != null) {
+ array.setName(names.get(i));
+ }
+ ndList.add(array);
+ }
+
+ // MXNet always load NDArray on CPU
+ if (Device.cpu().equals(device)) {
+ return ndList;
+ }
+
+ NDList ret = ndList.toDevice(device, true);
+ ndList.close();
+ return ret;
+ }
+
+ /**
+ * Loads the raw NDArray handles saved in the file at {@code path}.
+ *
+ * @param path the file to load
+ * @return (name, handle) pairs in file order; every name is {@code null} if the file stores
+ * no names
+ * @throws IllegalStateException if the file stores a different number of names and arrays
+ */
+ public static PairList<String, Pointer> loadNdArrayFromFile(String path) {
+ IntBuffer handleSize = IntBuffer.allocate(1);
+ IntBuffer namesSize = IntBuffer.allocate(1);
+ PointerByReference handlesRef = REFS.acquire();
+ PointerByReference namesRef = REFS.acquire();
+ checkCall(LIB.MXNDArrayLoad(path, handleSize, handlesRef, namesSize, namesRef));
+ // TODO : construct NDArray Objects
+ int handleCount = handleSize.get();
+ int nameCount = namesSize.get();
+ // a native array may be NULL when it is empty, so it is only dereferenced when non-empty
+ Pointer[] handles =
+ handleCount == 0 ? new Pointer[0] : handlesRef.getValue().getPointerArray(0, handleCount);
+ String[] names = nameCount == 0 ? null : namesRef.getValue().getStringArray(0, nameCount);
+ // return the pooled references before validating, so every path recycles them
+ REFS.recycle(namesRef);
+ REFS.recycle(handlesRef);
+
+ if (nameCount > 0 && nameCount != handleCount) {
+ throw new IllegalStateException(
+ "Mismatch between names and arrays in checkpoint file: " + path);
+ }
+
+ PairList<String, Pointer> pairList = new PairList<>();
+ for (int i = 0; i < handleCount; i++) {
+ pairList.add(names == null ? null : names[i], handles[i]);
+ }
+ return pairList;
+ }
+
+ public static void freeNdArray(Pointer handle) {
+ checkCall(LIB.MXNDArrayFree(handle));
+ }
+
+ /** Loads an NDArray from {@code size} bytes of {@code buf} starting at {@code offset}. */
+ public static Pointer loadNdArrayFromByteArray(byte[] buf, int offset, int size) {
+ // the bytes are first copied into native memory for the call
+ Memory nativeCopy = new Memory(size);
+ nativeCopy.write(0, buf, offset, size);
+ PointerByReference arrayRef = REFS.acquire();
+ checkCall(LIB.MXNDArrayLoadFromRawBytes(nativeCopy, new NativeSize(size), arrayRef));
+ Pointer arrayHandle = arrayRef.getValue();
+ REFS.recycle(arrayRef);
+ return arrayHandle;
+ }
+
+ /** Reads the remaining bytes of {@code byteBuffer} and loads an NDArray from them. */
+ public static Pointer loadNdArrayFromByteBuffer(ByteBuffer byteBuffer) {
+ // TODO: hand direct buffers to native code without the intermediate byte[] copy
+ // size by remaining(), not limit(), so a buffer whose position is not 0 does not underflow
+ byte[] bytes = new byte[byteBuffer.remaining()];
+ byteBuffer.get(bytes);
+ return loadNdArrayFromByteArray(bytes, 0, bytes.length);
+ }
+
+ public static ByteBuffer saveNdArrayAsByteBuffer(Pointer ndArray) {
+ NativeSizeByReference size = new NativeSizeByReference();
+ PointerByReference ref = new PointerByReference();
+ checkCall(LIB.MXNDArraySaveRawBytes(ndArray, size, ref));
+ return ref.getValue().getByteBuffer(0, size.getValue().longValue());
+ }
+
+    /**
+     * Serializes an NDArray into MXNet's raw byte format and copies the result onto the heap.
+     *
+     * @param ndArray handle of the NDArray to serialize
+     * @return a heap copy of the serialized bytes
+     */
+    public static byte[] saveNdArrayAsByteArray(Pointer ndArray) {
+        ByteBuffer view = saveNdArrayAsByteBuffer(ndArray);
+        int length = view.limit();
+        byte[] copy = new byte[length];
+        view.get(copy);
+        return copy;
+    }
+
+    /**
+     * Copies data out of an NDArray into caller-provided native memory via
+     * {@code MXNDArraySyncCopyToCPU}.
+     *
+     * @param ndArray source NDArray handle; must not be closed
+     * @param data destination memory; must not be null
+     * @param len passed through as the native size argument (presumably an element count —
+     *     confirm against the C API)
+     * @throws IllegalArgumentException if either pointer is null
+     */
+    public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
+        checkNDArray(ndArray, "copy from");
+        checkNDArray(data, "copy to");
+        checkCall(LIB.MXNDArraySyncCopyToCPU(ndArray, data, new NativeSize(len)));
+    }
+
+    /**
+     * Copies data from a direct NIO buffer into an NDArray via
+     * {@code MXNDArraySyncCopyFromCPU}.
+     *
+     * @param ndArray destination NDArray handle; must not be closed
+     * @param data source buffer; JNA can only take the address of a direct buffer
+     * @param len passed through as the native size argument (presumably an element count —
+     *     confirm against the C API)
+     * @throws IllegalArgumentException if {@code ndArray} is null
+     */
+    public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
+        // Reject a closed handle up front, as every other NDArray accessor here does.
+        checkNDArray(ndArray, "copy to");
+        Pointer pointer = Native.getDirectBufferPointer(data);
+        checkCall(LIB.MXNDArraySyncCopyFromCPU(ndArray, pointer, new NativeSize(len)));
+    }
+
+ public static void waitToRead(Pointer ndArray) {
+ checkNDArray(ndArray, "wait to read");
+ checkCall(LIB.MXNDArrayWaitToRead(ndArray));
+ }
+
+ public static void waitToWrite(Pointer ndArray) {
+ checkNDArray(ndArray, "wait to write");
+ checkCall(LIB.MXNDArrayWaitToWrite(ndArray));
+ }
+
+    /**
+     * Returns the handle produced by {@code MXNDArrayDetach} for the given NDArray.
+     *
+     * @param handle the NDArray handle; must not be closed
+     * @return the handle returned by the native call
+     * @throws IllegalArgumentException if {@code handle} is null
+     */
+    public static Pointer detachGradient(Pointer handle) {
+        // Guard before acquiring a pooled ref so a closed handle cannot strand it.
+        checkNDArray(handle, "detach");
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXNDArrayDetach(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Returns the gradient handle attached to the given NDArray via {@code MXNDArrayGetGrad}.
+     *
+     * @param handle the NDArray handle; must not be closed
+     * @return the gradient handle returned by the native call
+     * @throws IllegalArgumentException if {@code handle} is null
+     */
+    public static Pointer getGradient(Pointer handle) {
+        // Validate first: acquiring the pooled ref before a failing check never returned it.
+        checkNDArray(handle, "get the gradient for");
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXNDArrayGetGrad(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    /**
+     * Marks NDArrays as autograd variables via {@code MXAutogradMarkVariables}.
+     *
+     * @param numVar number of variables described by the arrays below
+     * @param varHandles pointer to the variable handle array
+     * @param reqsArray per-variable gradient request codes
+     * @param gradHandles pointer to the gradient handle array
+     */
+    public static void autogradMarkVariables(
+            int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
+        PointerByReference variables = REFS.acquire();
+        variables.setValue(varHandles);
+        PointerByReference gradients = REFS.acquire();
+        gradients.setValue(gradHandles);
+
+        checkCall(LIB.MXAutogradMarkVariables(numVar, variables, reqsArray, gradients));
+
+        REFS.recycle(variables);
+        REFS.recycle(gradients);
+    }
+
+    /**
+     * Builds a lookup of every registered operator, keyed by its name with any known prefix
+     * stripped (see {@link #getOpNamePrefix}).
+     *
+     * @return a concurrent map from stripped operator name to its {@code FunctionInfo}
+     */
+    public static Map<String, FunctionInfo> getNdArrayFunctions() {
+        Set<String> allOps = JnaUtils.getAllOpNames();
+        Map<String, FunctionInfo> functions = new ConcurrentHashMap<>();
+
+        PointerByReference opHandleRef = REFS.acquire();
+        for (String opName : allOps) {
+            checkCall(LIB.NNGetOpHandle(opName, opHandleRef));
+            String functionName = getOpNamePrefix(opName);
+            Pointer opHandle = opHandleRef.getValue();
+            functions.put(functionName, getFunctionByName(opName, functionName, opHandle));
+            // Clear the slot before the next lookup reuses it.
+            opHandleRef.setValue(null);
+        }
+        REFS.recycle(opHandleRef);
+        return functions;
+    }
+
+    /**
+     * Invokes an operator imperatively via {@code MXImperativeInvoke}.
+     *
+     * @param function the operator handle
+     * @param src input NDArrays
+     * @param dest pre-allocated output NDArrays, or {@code null} to have no outputs passed in
+     * @param params operator parameters; values are converted with {@code toString()}; may be
+     *     {@code null}
+     * @return the output handles paired with their storage formats
+     */
+    public static PairList<Pointer, SparseFormat> imperativeInvoke(
+            Pointer function, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        String[] keys;
+        String[] values;
+        if (params == null) {
+            keys = EMPTY_ARRAY;
+            values = EMPTY_ARRAY;
+        } else {
+            keys = params.keyArray(EMPTY_ARRAY);
+            values = params.values().stream().map(Object::toString).toArray(String[]::new);
+        }
+
+        PointerArray srcArray = toPointerArray(src);
+        PointerArray destArray = toPointerArray(dest);
+        PointerByReference destRef = REFS.acquire();
+        PointerByReference destSType = REFS.acquire();
+        try {
+            destRef.setValue(destArray);
+            IntBuffer numOutputs = IntBuffer.allocate(1);
+            // When outputs are supplied, num_outputs must describe all of them; MXNet checks it
+            // against the operator's output count, so a fixed 1 breaks multi-output calls.
+            numOutputs.put(0, dest == null ? 1 : dest.length);
+
+            checkCall(
+                    LIB.MXImperativeInvoke(
+                            function,
+                            src.length,
+                            srcArray,
+                            numOutputs,
+                            destRef,
+                            keys.length,
+                            keys,
+                            values,
+                            destSType));
+
+            int numOfOutputs = numOutputs.get(0);
+            Pointer[] ptrArray = destRef.getValue().getPointerArray(0, numOfOutputs);
+            int[] sTypes = destSType.getValue().getIntArray(0, numOfOutputs);
+            PairList<Pointer, SparseFormat> pairList = new PairList<>();
+            for (int i = 0; i < numOfOutputs; i++) {
+                pairList.add(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
+            }
+            return pairList;
+        } finally {
+            // Return pooled objects even when the native call fails.
+            REFS.recycle(destRef);
+            REFS.recycle(destSType);
+            if (srcArray != null) {
+                srcArray.recycle();
+            }
+            if (destArray != null) {
+                destArray.recycle();
+            }
+        }
+    }
+
+    /**
+     * Packs the native handles of the given NDArrays into a pooled {@link PointerArray}.
+     *
+     * @param vals the NDArrays, or {@code null}
+     * @return the packed handles, or {@code null} when {@code vals} is {@code null}
+     */
+    private static PointerArray toPointerArray(NDArray[] vals) {
+        if (vals == null) {
+            return null;
+        }
+        Pointer[] handles = Arrays.stream(vals).map(NDArray::getHandle).toArray(Pointer[]::new);
+        return PointerArray.of(handles);
+    }
+
+    /**
+     * Looks up a registered operator by name.
+     *
+     * @param opName the operator name
+     * @return the operator's {@code FunctionInfo}
+     * @throws IllegalArgumentException if no operator with that name is registered
+     */
+    public static FunctionInfo op(String opName) {
+        if (OPS.containsKey(opName)) {
+            return OPS.get(opName);
+        }
+        throw new IllegalArgumentException("Unknown operator: " + opName);
+    }
+
+    /**
+     * Queries an operator's argument names and types via {@code MXSymbolGetAtomicSymbolInfo}.
+     *
+     * @param name the operator's registered name
+     * @param functionName the name to store in the returned info
+     * @param handle the operator handle
+     * @return the operator handle, name, and (argument name, argument type) pairs
+     */
+    public static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
+        String[] nameRef = {name};
+        String[] description = new String[1];
+        String[] keyVarArgs = new String[1];
+        String[] returnType = new String[1];
+        IntBuffer argCountBuf = IntBuffer.allocate(1);
+        PointerByReference namesRef = REFS.acquire();
+        PointerByReference typesRef = REFS.acquire();
+        PointerByReference descsRef = REFS.acquire();
+
+        checkCall(
+                LIB.MXSymbolGetAtomicSymbolInfo(
+                        handle,
+                        nameRef,
+                        description,
+                        argCountBuf,
+                        namesRef,
+                        typesRef,
+                        descsRef,
+                        keyVarArgs,
+                        returnType));
+
+        PairList<String, String> arguments = new PairList<>();
+        int argCount = argCountBuf.get();
+        if (argCount != 0) {
+            String charset = StandardCharsets.UTF_8.name();
+            String[] argNames = namesRef.getValue().getStringArray(0, argCount, charset);
+            String[] argTypes = typesRef.getValue().getStringArray(0, argCount, charset);
+            for (int i = 0; i < argNames.length; i++) {
+                arguments.add(argNames[i], argTypes[i]);
+            }
+        }
+
+        REFS.recycle(namesRef);
+        REFS.recycle(typesRef);
+        REFS.recycle(descsRef);
+        return new FunctionInfo(handle, functionName, arguments);
+    }
+
+    /**
+     * Lists the names of all operators registered in the native library.
+     *
+     * @return the operator names, decoded as UTF-8
+     */
+    public static Set<String> getAllOpNames() {
+        IntBuffer countBuf = IntBuffer.allocate(1);
+        PointerByReference namesRef = REFS.acquire();
+        checkCall(LIB.MXListAllOpNames(countBuf, namesRef));
+
+        int count = countBuf.get();
+        Pointer[] namePointers = namesRef.getValue().getPointerArray(0, count);
+        Set<String> opNames = new HashSet<>();
+        for (int i = 0; i < namePointers.length; i++) {
+            opNames.add(namePointers[i].getString(0, StandardCharsets.UTF_8.name()));
+        }
+        REFS.recycle(namesRef);
+        return opNames;
+    }
+
+ public static String getOpNamePrefix(String name) {
+ for (String prefix : OP_NAME_PREFIX) {
+ if (name.startsWith(prefix)) {
+ return name.substring(prefix.length());
+ }
+ }
+ return name;
+ }
+
+ public static DataType getDataTypeOfNdArray(Pointer handle) {
+ IntBuffer dataType = IntBuffer.allocate(1);
+ checkNDArray(handle, "get the data type of");
+ checkCall(LIB.MXNDArrayGetDType(handle, dataType));
+ return DataType.values()[dataType.get()];
+ }
+
+    /**
+     * Returns the device an NDArray lives on.
+     *
+     * @param handle the NDArray handle; must not be closed
+     * @return the device built from the native device type and id
+     */
+    public static Device getDeviceOfNdArray(Pointer handle) {
+        checkNDArray(handle, "get the device of");
+        IntBuffer typeCode = IntBuffer.allocate(1);
+        IntBuffer id = IntBuffer.allocate(1);
+        checkCall(LIB.MXNDArrayGetContext(handle, typeCode, id));
+        String deviceTypeStr = DeviceType.fromDeviceType(typeCode.get(0));
+        // The id reported by MXNet is forwarded unchanged for every device type, CPU included.
+        return Device.of(deviceTypeStr, id.get(0));
+    }
+
+    /**
+     * Returns the shape of an NDArray using the 32-bit shape API.
+     *
+     * @param handle the NDArray handle; must not be closed
+     * @return the shape; an empty {@code Shape} when the reported rank is 0
+     */
+    public static Shape getShapeOfNdArray(Pointer handle) {
+        IntBuffer rankBuf = IntBuffer.allocate(1);
+        PointerByReference dimsRef = REFS.acquire();
+        checkNDArray(handle, "get the shape of");
+        checkCall(LIB.MXNDArrayGetShape(handle, rankBuf, dimsRef));
+
+        int rank = rankBuf.get();
+        Shape result;
+        if (rank == 0) {
+            result = new Shape();
+        } else {
+            int[] dims = dimsRef.getValue().getIntArray(0, rank);
+            result = new Shape(Arrays.stream(dims).asLongStream().toArray());
+        }
+        REFS.recycle(dimsRef);
+        return result;
+    }
+
+    /**
+     * Returns the shape of an NDArray using the 64-bit shape API.
+     *
+     * @param handle the NDArray handle; must not be closed
+     * @return the shape; an empty {@code Shape} when the reported rank is 0
+     * @throws IllegalArgumentException if {@code handle} is null
+     */
+    public static Shape getShape64OfNdArray(Pointer handle) {
+        checkNDArray(handle, "get the shape64 of");
+        IntBuffer dim = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+        try {
+            checkCall(LIB.MXNDArrayGetShape64(handle, dim, ref));
+            int nDim = dim.get(0);
+            if (nDim == 0) {
+                return new Shape();
+            }
+            // The 64-bit API reports each dimension as an int64; reading int32 values here would
+            // split every dimension into two halves and return a wrong shape.
+            return new Shape(ref.getValue().getLongArray(0, nDim));
+        } finally {
+            REFS.recycle(ref);
+        }
+    }
+
+ public static SparseFormat getStorageType(Pointer handle) {
+ IntBuffer type = IntBuffer.allocate(1);
+ checkNDArray(handle, "get the storage type of");
+ checkCall(LIB.MXNDArrayGetStorageType(handle, type));
+ return SparseFormat.fromValue(type.get());
+ }
+
+    /**
+     * Creates a dense NDArray via {@code MXNDArrayCreate}.
+     *
+     * @param device the target device
+     * @param shape the array shape; every dimension must fit in an {@code int}
+     * @param dataType the element type (its ordinal is passed as the native dtype code)
+     * @param size passed as the native ndim argument (presumably {@code shape.dimension()} —
+     *     confirm at call sites)
+     * @param delayedAlloc whether to request delayed allocation
+     * @return the new NDArray handle
+     * @throws ArithmeticException if a dimension overflows {@code int}
+     */
+    public static Pointer createNdArray(
+            Device device, Shape shape, DataType dataType, int size, boolean delayedAlloc) {
+        int deviceType = DeviceType.toDeviceType(device);
+        // Device type code 1 (presumably CPU) is sent with id -1; others forward their own id.
+        int deviceId;
+        if (deviceType == 1) {
+            deviceId = -1;
+        } else {
+            deviceId = device.getDeviceId();
+        }
+        int[] dims = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+
+        PointerByReference ref = REFS.acquire();
+        checkCall(
+                LIB.MXNDArrayCreate(
+                        dims,
+                        size,
+                        deviceType,
+                        deviceId,
+                        delayedAlloc ? 1 : 0,
+                        dataType.ordinal(),
+                        ref));
+        Pointer created = ref.getValue();
+        REFS.recycle(ref);
+        return created;
+    }
+
+    /**
+     * Creates a sparse NDArray via {@code MXNDArrayCreateSparseEx}.
+     *
+     * @param fmt the sparse storage format
+     * @param device the target device
+     * @param shape the array shape; every dimension must fit in an {@code int}
+     * @param dtype the element type (its ordinal is passed as the native dtype code)
+     * @param auxDTypes element types of the auxiliary arrays
+     * @param auxShapes shapes of the auxiliary arrays, parallel to {@code auxDTypes}
+     * @param delayedAlloc whether to request delayed allocation
+     * @return the new NDArray handle
+     * @throws ArithmeticException if any dimension overflows {@code int}
+     */
+    public static Pointer createSparseNdArray(
+            SparseFormat fmt,
+            Device device,
+            Shape shape,
+            DataType dtype,
+            DataType[] auxDTypes,
+            Shape[] auxShapes,
+            boolean delayedAlloc) {
+        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+        int deviceType = DeviceType.toDeviceType(device);
+        int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+        int delay = delayedAlloc ? 1 : 0;
+        IntBuffer auxDTypesInt =
+                IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(DataType::ordinal).toArray());
+        IntBuffer auxNDims =
+                IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
+        // The aux shapes are sent as one flat list described by auxNDims, so pass every
+        // dimension rather than only the first; for 1-D aux shapes the result is unchanged.
+        int[] auxShapesInt =
+                Arrays.stream(auxShapes)
+                        .flatMapToLong(auxShape -> Arrays.stream(auxShape.getShape()))
+                        .mapToInt(Math::toIntExact)
+                        .toArray();
+
+        PointerByReference ref = REFS.acquire();
+        checkCall(
+                LIB.MXNDArrayCreateSparseEx(
+                        fmt.getValue(),
+                        shapeArray,
+                        shapeArray.length,
+                        deviceType,
+                        deviceId,
+                        delay,
+                        dtype.ordinal(),
+                        auxDTypes.length,
+                        auxDTypesInt,
+                        auxNDims,
+                        auxShapesInt,
+                        ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+ public static void ndArraySyncCopyFromNdArray(NDArray dest, NDArray src, int location) {
+ checkCall(LIB.MXNDArraySyncCopyFromNDArray(dest.getHandle(), src.getHandle(), location));
+ }
+
+    /**
+     * Returns the version number reported by {@code MXGetVersion}.
+     *
+     * @return the native library version
+     */
+    public static int getVersion() {
+        IntBuffer out = IntBuffer.allocate(1);
+        checkCall(LIB.MXGetVersion(out));
+        return out.get(0);
+    }
+
+    /**
+     * Runs a cached operator on the given inputs via {@code MXInvokeCachedOp}.
+     *
+     * <p>The default device type is taken from the first input; the default device id is always
+     * 0. {@code inputs} must therefore be non-empty.
+     *
+     * @param parent the resource that will own the returned NDArrays
+     * @param cachedOpHandle the cached operator handle
+     * @param inputs the input NDArrays
+     * @return the outputs, created as sparse NDArrays when their storage type code is non-zero
+     */
+    public static NDArray[] cachedOpInvoke(
+            MxResource parent, Pointer cachedOpHandle, NDArray[] inputs) {
+        IntBuffer outputCount = IntBuffer.allocate(1);
+        PointerArray inputArray = toPointerArray(inputs);
+        PointerByReference outputsRef = REFS.acquire();
+        PointerByReference stypesRef = REFS.acquire();
+        Device device = inputs[0].getDevice();
+        // TODO: check the init value of default_dev_type and default_dev_id
+        checkCall(
+                LIB.MXInvokeCachedOp(
+                        cachedOpHandle,
+                        inputs.length,
+                        inputArray,
+                        DeviceType.toDeviceType(device),
+                        0,
+                        outputCount,
+                        outputsRef,
+                        stypesRef));
+
+        int count = outputCount.get();
+        Pointer[] handles = outputsRef.getValue().getPointerArray(0, count);
+        int[] stypes = stypesRef.getValue().getIntArray(0, count);
+        NDArray[] outputs = new NDArray[count];
+        for (int i = 0; i < count; i++) {
+            if (stypes[i] == 0) {
+                outputs[i] = NDArray.create(parent, handles[i]);
+            } else {
+                outputs[i] =
+                        NDArray.create(parent, handles[i], SparseFormat.fromValue(stypes[i]));
+            }
+        }
+
+        REFS.recycle(outputsRef);
+        REFS.recycle(stypesRef);
+        inputArray.recycle();
+        return outputs;
+    }
+
+    /**
+     * Rejects a null NDArray handle, which indicates the array was already closed.
+     *
+     * @param pointer the handle to check
+     * @param msg the attempted action, spliced into the error message
+     * @throws IllegalArgumentException if {@code pointer} is null
+     */
+    private static void checkNDArray(Pointer pointer, String msg) {
+        if (pointer != null) {
+            return;
+        }
+        throw new IllegalArgumentException(
+                "Tried to " + msg + " an MXNet NDArray that was already closed");
+    }
+
+ public static void checkCall(int ret) {
+ if (ret != 0) {
+ logger.error("MXNet engine call failed: " + getLastError());
+ throw new JnaCallException("MXNet engine call failed: " + getLastError());
+ }
+ }
+
+    /**
+     * Returns the message from {@code MXGetLastError}.
+     *
+     * @return the native error message
+     */
+    private static String getLastError() {
+        String message = LIB.MXGetLastError();
+        return message;
+    }
+
+    /**
+     * Decodes a native {@code char**} of the given length into Java strings (UTF-8).
+     *
+     * @param ref reference holding the pointer array
+     * @param size number of strings to read
+     * @return the decoded strings; an empty array when {@code size} is 0
+     */
+    private static String[] toStringArray(PointerByReference ref, int size) {
+        if (size == 0) {
+            return new String[0];
+        }
+        Pointer[] entries = ref.getValue().getPointerArray(0, size);
+        String[] decoded = new String[size];
+        int index = 0;
+        for (Pointer entry : entries) {
+            decoded[index++] = entry.getString(0, StandardCharsets.UTF_8.name());
+        }
+        return decoded;
+    }
+
+    /**
+     * Queries {@code MXLibInfoFeatures} and collects the names of features whose enabled flag is
+     * 1.
+     *
+     * @return the enabled feature names; an empty immutable set when none are reported
+     */
+    private static Set<String> getFeaturesInternal() {
+        PointerByReference ref = REFS.acquire();
+        NativeSizeByReference outSize = new NativeSizeByReference();
+        checkCall(LIB.MXLibInfoFeatures(ref, outSize));
+
+        int count = outSize.getValue().intValue();
+        Set<String> enabled;
+        if (count == 0) {
+            enabled = Collections.emptySet();
+        } else {
+            LibFeature first = new LibFeature(ref.getValue());
+            first.read();
+            LibFeature[] all = (LibFeature[]) first.toArray(count);
+            enabled = new HashSet<>();
+            for (LibFeature feature : all) {
+                if (feature.getEnabled() == 1) {
+                    enabled.add(feature.getName());
+                }
+            }
+        }
+        REFS.recycle(ref);
+        return enabled;
+    }
+
+    /**
+     * Returns the feature set held in {@code FEATURES}.
+     *
+     * @return the enabled native library features
+     */
+    public static Set<String> getFeatures() {
+        Set<String> features = FEATURES;
+        return features;
+    }
+
+ public static boolean autogradIsTraining() {
+ ByteBuffer isTraining = ByteBuffer.allocate(1);
+ checkCall(LIB.MXAutogradIsTraining(isTraining));
+ return isTraining.get(0) == 1;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
new file mode 100644
index 0000000..d110398
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Library;
+import com.sun.jna.Native;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.Platform;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Utilities for finding the MXNet Engine binary on the System.
+ *
+ * <p>The Engine will be searched for in the following order:
+ *
+ * <ol>
+ *   <li>In the path specified by the MXNET_LIBRARY_PATH environment variable
+ *   <li>In the paths listed in the {@code java.library.path} system property
+ *   <li>In a jar file location in the classpath. These jars can be created with the mxnet-native
+ *       module; the matching library is extracted into the engine cache directory.
+ *   <li>Otherwise the bare library name is passed to JNA's {@code Native.load}.
+ * </ol>
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class LibUtils {
+
+    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
+
+    private static final String LIB_NAME = "mxnet";
+
+    private static final String MXNET_LIBRARY_PATH = "MXNET_LIBRARY_PATH";
+
+    private static final String MXNET_PROPERTIES_FILE_PATH = "native/lib/mxnet.properties";
+
+    private LibUtils() {}
+
+    /**
+     * Locates and loads the MXNet native library.
+     *
+     * @return the JNA binding for the loaded library
+     */
+    public static MxnetLibrary loadLibrary() {
+        String libName = getLibName();
+        logger.debug("Loading mxnet library from: {}", libName);
+        if (System.getProperty("os.name").startsWith("Linux")) {
+            logger.info("Loading on Linux platform");
+            Map<String, Integer> options = new ConcurrentHashMap<>();
+            int rtld = 1; // Linux RTLD lazy + local
+            options.put(Library.OPTION_OPEN_FLAGS, rtld);
+            return Native.load(libName, MxnetLibrary.class, options);
+        }
+        return Native.load(libName, MxnetLibrary.class);
+    }
+
+    /**
+     * Resolves the library to load, following the search order in the class documentation.
+     *
+     * @return an absolute library path, or the bare library name as a last resort
+     */
+    public static String getLibName() {
+        String libName = findOverrideLibrary();
+        if (libName == null) {
+            libName = LibUtils.findLibraryInClasspath();
+            if (libName == null) {
+                libName = LIB_NAME;
+            }
+        }
+        return libName;
+    }
+
+    /** Searches MXNET_LIBRARY_PATH, then java.library.path, for the native library. */
+    private static String findOverrideLibrary() {
+        String libPath = System.getenv(MXNET_LIBRARY_PATH);
+        if (libPath != null) {
+            String libName = findLibraryInPath(libPath);
+            if (libName != null) {
+                return libName;
+            }
+        }
+
+        libPath = System.getProperty("java.library.path");
+        if (libPath != null) {
+            return findLibraryInPath(libPath);
+        }
+        return null;
+    }
+
+    /**
+     * Extracts the library from a native jar on the classpath matching this platform.
+     *
+     * @return the extracted library path, or {@code null} when no native jar is present
+     * @throws IllegalStateException if native jars exist but none matches this platform, or
+     *     their properties cannot be read
+     */
+    private static synchronized String findLibraryInClasspath() {
+        Enumeration<URL> urls = getUrls();
+        // No native jars
+        if (!urls.hasMoreElements()) {
+            logger.debug("mxnet.properties not found in class path.");
+            return null;
+        }
+
+        // Find the mxnet library version that matches local system platform
+        // throw exception if no one matches
+        Platform systemPlatform = Platform.fromSystem();
+        try {
+            while (urls.hasMoreElements()) {
+                URL url = urls.nextElement();
+                Platform platform = Platform.fromUrl(url);
+                if (!platform.isPlaceholder() && platform.matches(systemPlatform)) {
+                    return loadLibraryFromClasspath(platform);
+                }
+            }
+        } catch (IOException e) {
+            throw new IllegalStateException(
+                    "Failed to read MXNet native library jar properties", e);
+        }
+
+        throw new IllegalStateException(
+                "Your MXNet native library jar does not match your operating system. Make sure that the Maven Dependency Classifier matches your system type.");
+    }
+
+    /**
+     * Finds every native-jar properties file visible on the classpath.
+     *
+     * @return the matching resource URLs; never {@code null}
+     */
+    private static Enumeration<URL> getUrls() {
+        ClassLoader loader = Thread.currentThread().getContextClassLoader();
+        if (loader == null) {
+            // The context class loader is optional; fall back to the loader of this class.
+            loader = LibUtils.class.getClassLoader();
+        }
+        try {
+            return loader.getResources(MXNET_PROPERTIES_FILE_PATH);
+        } catch (IOException e) {
+            logger.warn(
+                    "IO Exception occurs when try to find the file {}",
+                    MXNET_PROPERTIES_FILE_PATH,
+                    e);
+            // Treat an unreadable classpath as one without native jars; returning null here
+            // made the caller fail with a NullPointerException.
+            return Collections.emptyEnumeration();
+        }
+    }
+
+    /**
+     * Extracts the platform's libraries into the engine cache, reusing a previous extraction.
+     *
+     * @param platform the matching native platform description
+     * @return the absolute path of the MXNet library in the cache
+     * @throws IllegalStateException if extraction fails or a listed library is missing
+     */
+    private static String loadLibraryFromClasspath(Platform platform) {
+        Path tmp = null;
+        try {
+            String libName = System.mapLibraryName(LIB_NAME);
+            Path cacheFolder = Utils.getEngineCacheDir(LIB_NAME);
+            logger.debug("Using cache dir: {}", cacheFolder);
+
+            Path dir = cacheFolder.resolve(platform.getVersion() + platform.getClassifier());
+            Path path = dir.resolve(libName);
+            if (Files.exists(path)) {
+                return path.toAbsolutePath().toString();
+            }
+            Files.createDirectories(cacheFolder);
+            // Extract into a temp dir first so a partial extraction never appears under dir.
+            tmp = Files.createTempDirectory(cacheFolder, "tmp");
+            for (String file : platform.getLibraries()) {
+                String libPath = "/native/lib/" + file;
+                try (InputStream is = LibUtils.class.getResourceAsStream(libPath)) {
+                    if (is == null) {
+                        throw new IOException("Native library not found in classpath: " + libPath);
+                    }
+                    logger.info("Extracting {} to cache ...", file);
+                    Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
+                }
+            }
+
+            Utils.moveQuietly(tmp, dir);
+            return path.toAbsolutePath().toString();
+        } catch (IOException e) {
+            throw new IllegalStateException("Failed to extract MXNet native library", e);
+        } finally {
+            if (tmp != null) {
+                Utils.deleteQuietly(tmp);
+            }
+        }
+    }
+
+    /**
+     * Searches a path-separator-delimited list of files/directories for the MXNet library.
+     *
+     * @param libPath the path list to search
+     * @return the absolute path of the first match, or {@code null} if none is found
+     */
+    private static String findLibraryInPath(String libPath) {
+        String[] paths = libPath.split(File.pathSeparator);
+        List<String> mappedLibNames;
+        if (com.sun.jna.Platform.isMac()) {
+            mappedLibNames = Arrays.asList("libmxnet.dylib", "libmxnet.jnilib", "libmxnet.so");
+        } else {
+            mappedLibNames = Collections.singletonList(System.mapLibraryName(LIB_NAME));
+        }
+
+        for (String path : paths) {
+            File p = new File(path);
+            if (!p.exists()) {
+                continue;
+            }
+            for (String name : mappedLibNames) {
+                // An entry may name the library file itself rather than its directory.
+                if (p.isFile() && p.getName().endsWith(name)) {
+                    return p.getAbsolutePath();
+                }
+
+                File file = new File(path, name);
+                if (file.exists() && file.isFile()) {
+                    return file.getAbsolutePath();
+                }
+            }
+        }
+        return null;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
new file mode 100644
index 0000000..d43e475
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+
+/**
+ * A pooled, NUL-terminated native copy of a Java string ({@code const char*}), used to pass a
+ * string as a native function argument.
+ */
+final class NativeString {
+
+    private static final ObjectPool<NativeString> POOL = new ObjectPool<>(null, null);
+
+    private Memory pointer;
+
+    /**
+     * Allocates native memory one byte larger than {@code data} and fills it.
+     *
+     * @param data the already-encoded bytes of the string
+     */
+    private NativeString(byte[] data) {
+        pointer = new Memory(data.length + 1);
+        setData(data);
+    }
+
+    /** Writes {@code data} at offset 0 followed by a single NUL byte. */
+    private void setData(byte[] data) {
+        int end = data.length;
+        pointer.write(0, data, 0, end);
+        pointer.setByte(end, (byte) 0);
+    }
+
+    /**
+     * Returns a native string holding {@code string} encoded with {@code encoding}, reusing a
+     * pooled instance when its buffer is large enough.
+     *
+     * @param string the string value
+     * @param encoding the charset used to encode it
+     * @return a {@code NativeString} holding the encoded bytes and terminator
+     */
+    public static NativeString of(String string, Charset encoding) {
+        byte[] data = string.getBytes(encoding);
+        NativeString pooled = POOL.acquire();
+        // Reuse only when the buffer fits the bytes plus the terminator.
+        boolean fits = pooled != null && pooled.pointer.size() > data.length;
+        if (!fits) {
+            return new NativeString(data);
+        }
+        pooled.setData(data);
+        return pooled;
+    }
+
+    /** Hands this instance back to the pool for later reuse. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    /**
+     * Returns the native memory holding the string.
+     *
+     * @return the peer pointer
+     */
+    public Pointer getPointer() {
+        return pointer;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
new file mode 100644
index 0000000..573acd4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+/**
+ * A generic, thread-safe object pool backed by a {@link ConcurrentLinkedQueue}.
+ *
+ * @param <T> the type of object to put in the pool
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public class ObjectPool<T> {
+
+    private final Queue<T> queue;
+    private final Supplier<T> supplier;
+    private final Consumer<T> consumer;
+
+    /**
+     * Creates an empty pool.
+     *
+     * @param supplier creates a new object when the pool is empty; may be {@code null}, in which
+     *     case {@link #acquire} returns {@code null} on an empty pool
+     * @param consumer resets an object as it is recycled; may be {@code null}
+     */
+    public ObjectPool(Supplier<T> supplier, Consumer<T> consumer) {
+        queue = new ConcurrentLinkedQueue<>();
+        this.supplier = supplier;
+        this.consumer = consumer;
+    }
+
+    /**
+     * Takes an object from the pool, or creates one with the supplier when the pool is empty.
+     *
+     * @return a pooled or new object, or {@code null} if the pool is empty and no supplier is set
+     */
+    public T acquire() {
+        T item = queue.poll();
+        if (item == null && supplier != null) {
+            return supplier.get();
+        }
+        return item;
+    }
+
+    /**
+     * Resets {@code item} with the consumer (if any) and returns it to the pool.
+     *
+     * @param item the object to recycle; must not be {@code null}
+     */
+    public void recycle(T item) {
+        if (consumer != null) {
+            consumer.accept(item);
+        }
+        queue.add(item);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
new file mode 100644
index 0000000..a864b64
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Function;
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+
+/**
+ * An abstraction for a native pointer array data type ({@code void**}).
+ *
+ * <p>Instances are pooled: {@link #of} reuses a recycled array whose capacity is large enough.
+ * The entries written by the latest {@link #of} call are always followed by a {@code null}
+ * terminator.
+ *
+ * @see Pointer
+ * @see com.sun.jna.ptr.PointerByReference
+ * @see Function
+ */
+@SuppressWarnings("checkstyle:EqualsHashCode")
+final class PointerArray extends Memory {
+
+    private static final ObjectPool<PointerArray> POOL = new ObjectPool<>(null, null);
+
+    /** Number of pointer slots, excluding the terminator, this buffer was allocated for. */
+    private final int length;
+
+    /**
+     * Constructs a {@link Memory} buffer PointerArray given the Pointers to include in it.
+     *
+     * @param arg the pointers to include in the array
+     */
+    private PointerArray(Pointer... arg) {
+        super(Native.POINTER_SIZE * (arg.length + 1));
+        length = arg.length;
+        setPointers(arg);
+    }
+
+    /**
+     * Acquires a pooled {@code PointerArray} object if available, otherwise a new instance is
+     * created.
+     *
+     * @param arg the pointers to include in the array
+     * @return a {@code PointerArray} object
+     */
+    public static PointerArray of(Pointer... arg) {
+        PointerArray array = POOL.acquire();
+        if (array != null && array.length >= arg.length) {
+            array.setPointers(arg);
+            return array;
+        }
+        return new PointerArray(arg);
+    }
+
+    /** Recycles this instance and returns it back to the pool. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    private void setPointers(Pointer[] pointers) {
+        for (int i = 0; i < pointers.length; i++) {
+            setPointer((long) i * Native.POINTER_SIZE, pointers[i]);
+        }
+        // Terminate right after the entries just written. Terminating at the capacity left stale
+        // pointers from a previous use visible when a larger pooled array was reused.
+        setPointer((long) Native.POINTER_SIZE * pointers.length, null);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        return o == this;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
new file mode 100644
index 0000000..00ffa8f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+
+/** An abstraction for a native string array data type ({@code char**}). */
+@SuppressWarnings("checkstyle:EqualsHashCode")
+final class StringArray extends Memory {
+
+ private static final Charset ENCODING = Native.DEFAULT_CHARSET;
+ private static final ObjectPool<StringArray> POOL = new ObjectPool<>(null, null);
+ /** Hold all {@code NativeString}, avoid be GCed. */
+ private List<NativeString> natives; // NOPMD
+
+ private int length;
+
+ /**
+ * Create a native array of strings.
+ *
+ * @param strings the strings
+ */
+ private StringArray(String[] strings) {
+ super((strings.length + 1) * Native.POINTER_SIZE);
+ natives = new ArrayList<>();
+ length = strings.length;
+ setPointers(strings);
+ }
+
+ private void setPointers(String[] strings) {
+ for (NativeString ns : natives) {
+ ns.recycle();
+ }
+ natives.clear();
+ for (int i = 0; i < strings.length; ++i) {
+ Pointer p = null;
+ if (strings[i] != null) {
+ NativeString ns = NativeString.of(strings[i], ENCODING);
+ natives.add(ns);
+ p = ns.getPointer();
+ }
+ setPointer(Native.POINTER_SIZE * i, p);
+ }
+ setPointer(Native.POINTER_SIZE * strings.length, null);
+ }
+
+ /**
+ * Acquires a pooled {@code StringArray} object if available, otherwise a new instance is
+ * created.
+ *
+ * @param strings the pointers to include in the array
+ * @return a {@code StringArray} object
+ */
+ public static StringArray of(String[] strings) {
+ StringArray array = POOL.acquire();
+ if (array != null && array.length >= strings.length) {
+ array.setPointers(strings);
+ return array;
+ }
+ return new StringArray(strings);
+ }
+
+ /** Recycles this instance and return it back to the poll. */
+ public void recycle() {
+ POOL.recycle(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ return this == o;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
new file mode 100644
index 0000000..d524d50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the JNA bindings and native-memory helpers used to call the Apache MXNet C API. */
+package org.apache.mxnet.jna;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
new file mode 100644
index 0000000..f775f44
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
@@ -0,0 +1,3455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import java.util.Arrays;
+import java.util.stream.IntStream;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.Float16Utils;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class representing an n-dimensional array.
+ *
+ * <p>NDArray is the core data structure for all mathematical computations. An NDArray represents a
+ * multidimensional, fixed-size homogeneous array. Its behavior closely mirrors that of a NumPy
+ * ndarray, while the computation itself is carried out by the native MXNet engine.
+ */
+public class NDArray extends MxResource {
+
+ private static final Logger logger = LoggerFactory.getLogger(NDArray.class);
+
+ // NOTE(review): the MAX_* limits are not referenced in this part of the file; presumably they
+ // bound the output of a printing helper further down — confirm.
+ private static final int MAX_SIZE = 100;
+ private static final int MAX_DEPTH = 10;
+ private static final int MAX_ROWS = 10;
+ private static final int MAX_COLUMNS = 20;
+ private static final NDArray[] EMPTY = new NDArray[0];
+
+ // Optional name; duplicate() copies it to the new array.
+ private String name;
+ // device, sparseFormat, dataType and shape cache metadata. When not supplied at construction,
+ // their getters fetch them lazily from the native handle.
+ private Device device;
+ private SparseFormat sparseFormat;
+ private DataType dataType;
+ private Shape shape;
+ // A Boolean object (rather than a primitive) encodes three states: true, false, and null,
+ // where null means the flag has not yet been read from the native engine (see hasGradient()).
+ private Boolean hasGradient;
+ // Cached result of JnaUtils.getVersion(), populated lazily by getVersion().
+ private Integer version;
+ // Helper bound to this array; created in the NDArray(MxResource, Pointer) constructor.
+ private NDArrayEx mxNDArrayEx;
+
+ /**
+  * Constructs an {@link NDArray} from a native handle (internal).
+  *
+  * <p>The array is managed by the resource returned by {@link
+  * BaseMxResource#getSystemMxResource()}.
+  *
+  * @param handle the pointer to the native NDArray memory
+  */
+ protected NDArray(Pointer handle) {
+     // Delegate so that mxNDArrayEx is initialized the same way as in every other constructor.
+     this(BaseMxResource.getSystemMxResource(), handle);
+ }
+
+ /**
+  * Constructs an {@link NDArray} from a native handle and metadata (internal; use the {@link
+  * #create} factory methods instead).
+  *
+  * @param parent the parent {@link MxResource} that manages the life cycle of the {@link NDArray}
+  * @param handle the pointer to the native NDArray memory
+  * @param device the device the new array will be located on
+  * @param shape the shape of the new array
+  * @param dataType the dataType of the new array
+  * @param hasGradient the gradient status of the new array
+  * @throws IllegalArgumentException if any dimension of {@code shape} is negative
+  */
+ NDArray(
+         MxResource parent,
+         Pointer handle,
+         Device device,
+         Shape shape,
+         DataType dataType,
+         boolean hasGradient) {
+     this(parent, handle);
+     setParent(parent);
+     this.device = device;
+     // Reject negative dimensions before caching the shape as this array's metadata.
+     if (Arrays.stream(shape.getShape()).anyMatch(s -> s < 0)) {
+         throw new IllegalArgumentException("The shape must be >= 0");
+     }
+     this.shape = shape;
+     this.dataType = dataType;
+     this.hasGradient = hasGradient;
+     if (parent != null) {
+         parent.addSubResource(this);
+     }
+ }
+
+ /**
+  * Wraps an existing native NDArray handle (internal).
+  *
+  * <p>No shape, data type or device is supplied here; the corresponding getters read them from
+  * the native handle on first use.
+  *
+  * @param parent the parent {@link MxResource} that manages the life cycle of the {@link NDArray}
+  * @param handle the pointer to the native NDArray memory
+  */
+ NDArray(MxResource parent, Pointer handle) {
+     super(parent, handle);
+     mxNDArrayEx = new NDArrayEx(this);
+ }
+
+ /**
+  * Wraps an existing native NDArray handle whose storage format is already known (internal).
+  *
+  * @param parent the parent {@link MxResource} that manages the life cycle of the {@link NDArray}
+  * @param handle the pointer to the native NDArray memory
+  * @param fmt the {@link SparseFormat} to cache, so {@link #getSparseFormat()} need not query it
+  */
+ NDArray(MxResource parent, Pointer handle, SparseFormat fmt) {
+     this(parent, handle);
+     sparseFormat = fmt;
+ }
+
+ /**
+ * Creates an NDArray with the given Native Memory Pointer and parent MxResource.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param handle the array's native memory pointer
+ * @return the created array
+ */
+ public static NDArray create(MxResource parent, Pointer handle) {
+ return new NDArray(parent, handle);
+ }
+
+ /**
+ * Creates an NDArray with the given Native Memory Pointer and parent MxResource.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param handle the array's native memory pointer
+ * @param fmt the sparse format
+ * @return the created array
+ */
+ public static NDArray create(MxResource parent, Pointer handle, SparseFormat fmt) {
+ return new NDArray(parent, handle, fmt);
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+ * parent {@link MxResource}, {@link Shape}, {@link Device} and {@code hasGradient}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, Shape shape, Device device) {
+ return create(parent, shape, DataType.FLOAT32, device);
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+ * parent {@link MxResource} and {@link Shape}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, Shape shape) {
+ return create(parent, shape, DataType.FLOAT32, Device.defaultIfNull());
+ }
+
+ /**
+  * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+  * MxResource}, {@link Shape}, {@link DataType}, {@link Device} and {@code hasGradient}.
+  *
+  * @param parent the parent {@link MxResource}
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}; if {@code null}, the device chosen
+  *     by {@code Device.defaultIfNull(Device)} is used
+  * @param hasGradient true if the gradient calculation is required for this {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ public static NDArray create(
+         MxResource parent, Shape shape, DataType dataType, Device device, boolean hasGradient) {
+     // Resolve a null device once so the native allocation and the cached metadata agree.
+     Device target = Device.defaultIfNull(device);
+     Pointer handle =
+             JnaUtils.createNdArray(target, shape, dataType, shape.dimension(), hasGradient);
+     return new NDArray(parent, handle, target, shape, dataType, hasGradient);
+ }
+
+ /**
+  * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+  * MxResource}, {@link Shape}, {@link DataType}, {@link Device}.
+  *
+  * @param parent the parent {@link MxResource}
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}; if {@code null}, the device chosen
+  *     by {@code Device.defaultIfNull(Device)} is used
+  * @return a new instance of {@link NDArray}
+  */
+ public static NDArray create(MxResource parent, Shape shape, DataType dataType, Device device) {
+     // Resolve a null device once so the native allocation uses the same device that is cached
+     // as this array's metadata.
+     Device target = Device.defaultIfNull(device);
+     Pointer handle = JnaUtils.createNdArray(target, shape, dataType, shape.dimension(), false);
+     return new NDArray(parent, handle, target, shape, dataType, false);
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the {@link Number} that needs to be set
+ * @return a new instance of {@link NDArray}
+ * @throws IllegalArgumentException when the Type of data is not expected
+ */
+ public static NDArray create(MxResource parent, Number data) {
+ if (data instanceof Integer) {
+ return create(parent, data.intValue());
+ } else if (data instanceof Float) {
+ return create(parent, data.floatValue());
+ } else if (data instanceof Double) {
+ return create(parent, data.doubleValue());
+ } else if (data instanceof Long) {
+ return create(parent, data.longValue());
+ } else if (data instanceof Byte) {
+ return create(parent, data.byteValue());
+ } else {
+ throw new IllegalArgumentException("Short conversion not supported!");
+ }
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and float
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[] data, Shape shape) {
+ return create(parent, FloatBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and int
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[] data, Shape shape) {
+ return create(parent, IntBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and
+ * double array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[] data, Shape shape) {
+ return create(parent, DoubleBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and long
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[] data, Shape shape) {
+ return create(parent, LongBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and byte
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[] data, Shape shape) {
+ return create(parent, ByteBuffer.wrap(data), shape);
+ }
+
+ /**
+  * Creates a {@link DataType#BOOLEAN} {@link NDArray} of the given {@link Shape} from a boolean
+  * array, storing each value as one byte ({@code 1} for true, {@code 0} for false).
+  *
+  * @param parent the parent {@link MxResource} to manage this instance
+  * @param data the boolean values to copy into the array
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ public static NDArray create(MxResource parent, boolean[] data, Shape shape) {
+     byte[] encoded = new byte[data.length];
+     int pos = 0;
+     for (boolean flag : data) {
+         encoded[pos++] = flag ? (byte) 1 : (byte) 0;
+     }
+     return create(parent, ByteBuffer.wrap(encoded), shape, DataType.BOOLEAN);
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float data) {
+ return create(parent, new float[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int data) {
+ return create(parent, new int[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the double data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double data) {
+ return create(parent, new double[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the long data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long data) {
+ return create(parent, new long[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the byte data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte data) {
+ return create(parent, new byte[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the boolean data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean data) {
+
+ return create(parent, new boolean[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the bool array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[][] data) {
+ FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length);
+ for (float[] d : data) {
+ buffer.put(d);
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[][] data) {
+ IntBuffer buffer = IntBuffer.allocate(data.length * data[0].length);
+ for (int[] d : data) {
+ buffer.put(d);
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[][] data) {
+ DoubleBuffer buffer = DoubleBuffer.allocate(data.length * data[0].length);
+ for (double[] d : data) {
+ buffer.put(d);
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2-D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[][] data) {
+ LongBuffer buffer = LongBuffer.allocate(data.length * data[0].length);
+ for (long[] d : data) {
+ buffer.put(d);
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2-D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[][] data) {
+ ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+ for (byte[] d : data) {
+ buffer.put(d);
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2-D {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the boolean array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean[][] data) {
+ ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+ for (boolean[] d : data) {
+ for (boolean b : d) {
+ buffer.put((byte) (b ? 1 : 0));
+ }
+ }
+ buffer.rewind();
+ return create(parent, buffer, new Shape(data.length, data[0].length), DataType.BOOLEAN);
+ }
+
+ /**
+  * Creates and initializes an {@link NDArray} with the specified {@link Shape}.
+  *
+  * <p>The {@link DataType} is inferred from the concrete buffer type via {@code
+  * DataType.fromBuffer}.
+  *
+  * @param parent the parent {@link MxResource} instance
+  * @param data the data to initialize the {@code NDArray} with
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ static NDArray create(MxResource parent, Buffer data, Shape shape) {
+     return create(parent, data, shape, DataType.fromBuffer(data));
+ }
+
+ /**
+  * Creates an {@link NDArray} of the given {@link Shape} and {@link DataType} on the device
+  * returned by {@code Device.defaultIfNull()} and fills it from {@code data}.
+  *
+  * @param parent the parent {@link MxResource} instance
+  * @param data the data to copy into the new array
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @return a new, initialized instance of {@link NDArray}
+  * @throws IllegalArgumentException if {@code data} has fewer remaining elements than the shape
+  *     requires
+  */
+ static NDArray create(MxResource parent, Buffer data, Shape shape, DataType dataType) {
+     NDArray array = create(parent, shape, dataType, Device.defaultIfNull());
+     try {
+         array.set(data);
+     } catch (RuntimeException e) {
+         // Release the freshly allocated native memory rather than leaving it to the parent.
+         try {
+             array.close();
+         } catch (RuntimeException suppressed) {
+             e.addSuppressed(suppressed);
+         }
+         throw e;
+     }
+     return array;
+ }
+
+ /**
+  * Returns the name assigned to this {@code NDArray}, or {@code null} if none was set.
+  *
+  * @return the name of this {@code NDArray}
+  */
+ public String getName() {
+     return this.name;
+ }
+
+ /**
+  * Assigns a name to this {@code NDArray}.
+  *
+  * @param name the new name of the {@code NDArray}
+  */
+ public void setName(String name) {
+     this.name = name;
+ }
+
+ /**
+  * Returns the {@link DataType} of this {@code NDArray}, reading it from the native handle on
+  * first access and caching it afterwards.
+  *
+  * @return the {@link DataType} of this {@code NDArray}
+  */
+ public DataType getDataType() {
+     if (dataType != null) {
+         return dataType;
+     }
+     dataType = JnaUtils.getDataTypeOfNdArray(getHandle());
+     return dataType;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Device getDevice() {
+     if (device != null) {
+         return device;
+     }
+     // Not supplied at construction: query the native handle once and cache the answer.
+     device = JnaUtils.getDeviceOfNdArray(getHandle());
+     return device;
+ }
+
+ /**
+  * Returns the {@link Shape} of this {@code NDArray}, reading it from the native handle on first
+  * access and caching it afterwards.
+  *
+  * @return the {@link Shape} of this {@code NDArray}
+  */
+ public Shape getShape() {
+     if (shape != null) {
+         return shape;
+     }
+     shape = JnaUtils.getShapeOfNdArray(getHandle());
+     return shape;
+ }
+
+ /**
+  * Returns the {@link SparseFormat} (storage type) of this {@code NDArray}, reading it from the
+  * native handle on first access and caching it afterwards.
+  *
+  * @return the {@link SparseFormat} of this {@code NDArray}
+  */
+ public SparseFormat getSparseFormat() {
+     if (sparseFormat != null) {
+         return sparseFormat;
+     }
+     sparseFormat = JnaUtils.getStorageType(getHandle());
+     return sparseFormat;
+ }
+
+ /**
+  * Returns the version reported by {@code JnaUtils.getVersion()}, caching it after the first
+  * call.
+  *
+  * <p>NOTE(review): {@code JnaUtils.getVersion()} is called without this array's handle, so the
+  * value is presumably library-wide rather than specific to this {@code NDArray} — confirm.
+  *
+  * @return the cached version value
+  */
+ public Integer getVersion() {
+     if (version != null) {
+         return version;
+     }
+     version = JnaUtils.getVersion();
+     return version;
+ }
+
+ /**
+  * Allocates a new array under this array's parent with the given metadata, names it, and
+  * deep-copies this array's contents into it via {@link #copyTo(NDArray)}.
+  *
+  * @param shape the {@link Shape} of the copy
+  * @param dataType the {@link DataType} of the copy
+  * @param device the {@link Device} of the copy
+  * @param name the name to assign to the copy
+  * @return the new, populated array
+  */
+ private NDArray duplicate(Shape shape, DataType dataType, Device device, String name) {
+     // TODO get copy parameter
+     NDArray copy = create(getParent(), shape, dataType, device);
+     copy.setName(name);
+     copyTo(copy);
+     return copy;
+ }
+
+ /**
+  * Returns a deep copy of this {@code NDArray} with the same shape, data type, device and name.
+  *
+  * @return a copy of this {@code NDArray}
+  */
+ NDArray duplicate() {
+     return duplicate(getShape(), getDataType(), getDevice(), getName());
+ }
+
+ /**
+ * Moves this {@code NDArray} to a different {@link Device}.
+ *
+ * @param device the {@link Device} to be set
+ * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray}
+ * @return the result {@code NDArray} with the new {@link Device}
+ */
+ public NDArray toDevice(Device device, boolean copy) {
+ if (device.equals(getDevice()) && !copy) {
+ return this;
+ }
+ return duplicate(getShape(), getDataType(), device, getName());
+ }
+
+ /**
+ * Converts this {@code NDArray} to a different {@link DataType}.
+ *
+ * @param dataType the {@link DataType} to be set
+ * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray}
+ * @return the result {@code NDArray} with the new {@link DataType}
+ */
+ public NDArray toType(DataType dataType, boolean copy) {
+ if (dataType.equals(getDataType()) && !copy) {
+ return this;
+ }
+ return duplicate(getShape(), dataType, getDevice(), getName());
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ * @see #zeros(Shape, DataType, Device)
+ */
+ public NDArray zeros(Shape shape, DataType dataType) {
+ return fill("_npi_zeros", shape, dataType);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with the same {@link Shape} and {@link DataType}
+ * filled with zeros.
+ *
+ * @return a new instance of {@link NDArray}
+ * @see #zeros(Shape, DataType, Device)
+ */
+ public NDArray zeros() {
+ return zeros(getShape(), getDataType());
+ }
+
+ /**
+  * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+  * {@link DataType} filled with zeros.
+  *
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}; {@code null} means this array's
+  *     device
+  * @return a new instance of {@link NDArray}
+  */
+ NDArray zeros(Shape shape, DataType dataType, Device device) {
+     if (device == null || device.equals(getDevice())) {
+         return zeros(shape, dataType);
+     }
+     // fill() creates the result on the device of the array it is called on, so go through a
+     // temporary array on the requested device and release it once the result exists.
+     try (NDArray onDevice = create(getParent(), shape, dataType, device)) {
+         return onDevice.zeros(shape, dataType);
+     }
+ }
+
+ /**
+  * Requests a zero-filled array with this array's shape, data type and device and converts it
+  * to the given storage format.
+  *
+  * @param format the {@link SparseFormat} of the returned array
+  * @return the zero-filled array converted by {@code toSparse(format)}
+  */
+ private NDArray createGradient(SparseFormat format) {
+ // The dense zeros array is an intermediate and is closed when the block exits.
+ // NOTE(review): this assumes toSparse() always returns a new array; if it can return its
+ // receiver (e.g. when the format already matches), the returned array is closed here — confirm.
+ try (NDArray zeros = this.zeros(getShape(), getDataType(), getDevice())) {
+ return zeros.toSparse(format);
+ }
+ }
+
+ /**
+  * Invokes a creation operator (such as {@code _npi_zeros} or {@code _npi_ones}) with the given
+  * shape and data type on this array's {@link Device}.
+  *
+  * @param opName the native operator name, expected to start with an underscore
+  * @param shape the {@link Shape} of the result
+  * @param dataType the {@link DataType} of the result
+  * @return the newly created array, produced via {@code invoke} with this array's parent
+  * @throws IllegalArgumentException if {@code shape} is {@code null}
+  */
+ private NDArray fill(String opName, Shape shape, DataType dataType) {
+     if (shape == null) {
+         throw new IllegalArgumentException("Shape is required for " + opName.substring(1));
+     }
+     OpParams params = new OpParams();
+     params.addParam("shape", shape);
+     // Use the getter: the device field stays null for arrays wrapped from a native handle
+     // until it is queried.
+     params.setDevice(getDevice());
+     params.setDataType(dataType);
+     return invoke(getParent(), opName, params);
+ }
+
+ /**
+  * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+  *
+  * @param parent the parent {@link MxResource}
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ public static NDArray ones(MxResource parent, Shape shape, DataType dataType, Device device) {
+     // ones() returns a freshly created array, so the uninitialized template only carries the
+     // target metadata and is released right away instead of lingering under the parent.
+     try (NDArray template = create(parent, shape, dataType, device)) {
+         return template.ones();
+     }
+ }
+
+ /**
+  * Creates a one-filled {@link NDArray} of the given {@link Shape} and {@link DataType} using
+  * the {@code _npi_ones} operator.
+  *
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ NDArray ones(Shape shape, DataType dataType) {
+     String operator = "_npi_ones";
+     return fill(operator, shape, dataType);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with same {@link Shape} and {@link DataType} filled
+ * with ones.
+ *
+ * @return a new instance of {@link NDArray}
+ */
+ public NDArray ones() {
+ return ones(getShape(), getDataType());
+ }
+ /**
+  * Creates a one-filled {@link DataType#FLOAT32} {@link NDArray} of the given {@link Shape}.
+  *
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @return a new instance of {@link NDArray}
+  */
+ NDArray ones(Shape shape) {
+     DataType defaultType = DataType.FLOAT32;
+     return ones(shape, defaultType);
+ }
+
+ /**
+  * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+  * {@link DataType} filled with ones.
+  *
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}; {@code null} means this array's
+  *     device
+  * @return a new instance of {@link NDArray}
+  */
+ NDArray ones(Shape shape, DataType dataType, Device device) {
+     if (device == null || device.equals(getDevice())) {
+         return ones(shape, dataType);
+     }
+     // The template only carries the target device; ones() returns a separate array, so the
+     // template is released immediately instead of lingering under the parent.
+     try (NDArray template = create(getParent(), shape, dataType, device)) {
+         return template.ones();
+     }
+ }
+
+ /**
+  * Returns the gradient {@code NDArray} attached to this {@code NDArray}.
+  *
+  * @return a new {@code NDArray} wrapping the native gradient handle, under this array's parent
+  * @throws IllegalStateException when hasGradient is false
+  */
+ public NDArray getGradient() {
+     if (!hasGradient()) {
+         throw new IllegalStateException(
+                 "No gradient attached to this MxNDArray, please call array.requiredGradient()"
+                         + " on your MxNDArray or block.setInitializer() on your Block");
+     }
+     Pointer pointer = JnaUtils.getGradient(getHandle());
+     return create(getParent(), pointer);
+ }
+
+ /**
+  * Reports whether a gradient is attached to this {@code NDArray}.
+  *
+  * <p>Unless set at construction, the flag is derived once from whether the native engine
+  * returns a non-null gradient handle, and then cached.
+  *
+  * @return {@code true} if gradient calculation is required for this {@code NDArray}
+  */
+ public boolean hasGradient() {
+     if (hasGradient != null) {
+         return hasGradient;
+     }
+     Pointer gradient = JnaUtils.getGradient(getHandle());
+     hasGradient = gradient != null;
+     return hasGradient;
+ }
+
+ /**
+  * Returns a new {@code NDArray}, under this array's parent, wrapping the handle returned by
+  * {@code JnaUtils.detachGradient}, so that gradients do not propagate through it.
+  *
+  * @return an NDArray equal to this that stops gradient propagation through it
+  */
+ public NDArray stopGradient() {
+     Pointer detached = JnaUtils.detachGradient(getHandle());
+     return create(getParent(), detached);
+ }
+
+ /**
+ * Converts this {@code NDArray} to a String array.
+ *
+ * <p>This method is only applicable to the String typed NDArray and not for printing purpose
+ *
+ * @return Array of Strings
+ */
+ public String[] toStringArray() {
+ throw new UnsupportedOperationException("String MxNDArray is not supported!");
+ }
+
+ /**
+  * Copies the contents of this dense {@code NDArray} into a newly allocated direct
+  * {@link ByteBuffer}.
+  *
+  * @return a direct buffer of {@code size * bytesPerElement} bytes holding the array data
+  * @throws IllegalStateException if this array is not stored in {@link SparseFormat#DENSE}
+  */
+ public ByteBuffer toByteBuffer() {
+     SparseFormat format = getSparseFormat();
+     if (format != SparseFormat.DENSE) {
+         throw new IllegalStateException("Require Dense MxNDArray, actual " + getSparseFormat());
+     }
+     Shape arrayShape = getShape();
+     DataType elementType = getDataType();
+     long elementCount = arrayShape.size();
+     long byteCount = elementType.getNumOfBytes() * elementCount;
+     ByteBuffer out = NDSerializer.allocateDirect(Math.toIntExact(byteCount));
+     Pointer target = Native.getDirectBufferPointer(out);
+     // The native copy takes the element count, not the byte count.
+     JnaUtils.syncCopyToCPU(getHandle(), target, Math.toIntExact(elementCount));
+     return out;
+ }
+
+ /**
+  * Returns the total number of elements in this {@code NDArray}.
+  *
+  * @return the element count of this array's {@link Shape}
+  */
+ long size() {
+     Shape current = getShape();
+     return current.size();
+ }
+
+ /**
+  * Returns the size of this {@code NDArray} along the given axis, as reported by {@code
+  * Shape.size(int)}.
+  *
+  * @param axis the axis to query
+  * @return the size reported by this array's {@link Shape} for {@code axis}
+  */
+ long size(int axis) {
+     Shape current = getShape();
+     return current.size(axis);
+ }
+
+ /**
+  * Sets this {@code NDArray} value from {@link Buffer}.
+  *
+  * <p>Exactly {@link #size()} elements are read, starting at the buffer's current position.
+  * Direct buffers are handed to the native copy as-is. Heap buffers are type-checked against
+  * this array's {@link DataType} and staged through a direct buffer.
+  *
+  * @param data the input buffered data
+  * @throws IllegalArgumentException if {@code data} has fewer remaining elements than this array
+  * @throws IllegalStateException if a heap buffer's element type does not match this array
+  * @throws UnsupportedOperationException if the buffer's data type cannot be staged
+  */
+ public void set(Buffer data) {
+     int size = Math.toIntExact(size());
+     if (data.remaining() < size) {
+         throw new IllegalArgumentException(
+                 "The MxNDArray size is: " + size + ", but buffer size is: " + data.remaining());
+     }
+     if (data.isDirect()) {
+         JnaUtils.syncCopyFromCPU(getHandle(), data, size);
+         return;
+     }
+
+     // Limit relative to the current position so exactly `size` elements are copied even when
+     // the buffer is not positioned at 0; limit(size) alone would silently truncate the input.
+     data.limit(data.position() + size);
+     // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
+     DataType inputType = DataType.fromBuffer(data);
+     validate(inputType);
+
+     int numOfBytes = inputType.getNumOfBytes();
+     ByteBuffer buf = NDSerializer.allocateDirect(size * numOfBytes);
+
+     switch (inputType) {
+         case FLOAT32:
+             buf.asFloatBuffer().put((FloatBuffer) data);
+             break;
+         case FLOAT64:
+             buf.asDoubleBuffer().put((DoubleBuffer) data);
+             break;
+         case UINT8:
+         case INT8:
+         case BOOLEAN:
+             buf.put((ByteBuffer) data);
+             break;
+         case INT32:
+             buf.asIntBuffer().put((IntBuffer) data);
+             break;
+         case INT64:
+             buf.asLongBuffer().put((LongBuffer) data);
+             break;
+         case FLOAT16:
+         default:
+             throw new UnsupportedOperationException("data type is not supported!");
+     }
+     buf.rewind();
+     JnaUtils.syncCopyFromCPU(getHandle(), buf, size);
+ }
+
+ /**
+  * Checks that a buffer's inferred {@link DataType} is compatible with this array's.
+  *
+  * @param inputType the data type inferred from the input buffer
+  * @throws IllegalStateException if the types differ and are not a permitted byte alias
+  */
+ private void validate(DataType inputType) {
+     DataType expected = getDataType();
+     if (expected == inputType) {
+         return;
+     }
+     // Inferring a DataType from a ByteBuffer always yields INT8, so an INT8 buffer is also
+     // accepted for UINT8 and BOOLEAN arrays.
+     boolean byteAlias =
+             inputType == DataType.INT8
+                     && (expected == DataType.UINT8 || expected == DataType.BOOLEAN);
+     if (!byteAlias) {
+         throw new IllegalStateException(
+                 "DataType mismatch, required: " + expected + ", actual: " + inputType);
+     }
+ }
+
+ /**
+  * Tells whether this {@code NDArray} is a scalar, as decided by {@code Shape.isScalar()}.
+  *
+  * @return {@code true} if this array's {@link Shape} reports itself as scalar
+  */
+ boolean isScalar() {
+     Shape current = getShape();
+     return current.isScalar();
+ }
+
+ /**
+  * Tests whether every element of this {@code NDArray} is non-zero or {@code true}.
+  *
+  * <p>The array is cast to {@link DataType#BOOLEAN}, its elements are summed, and the sum is
+  * compared with {@link #size()}.
+  *
+  * @return an {@code NDArray} holding the result of that comparison
+  */
+ NDArray all() {
+     // result of sum operation is int64 now
+     NDArray asBoolean = toType(DataType.BOOLEAN, false);
+     NDArray truthyCount = asBoolean.sum();
+     return truthyCount.eq(size());
+ }
+
+ /**
+  * Deep-copies the contents of this {@code NDArray} into {@code ndArray}.
+  *
+  * <p>Both arrays must have identical shapes. The copy is done by the {@code _npi_copyto}
+  * operator, with this array as input and {@code ndArray} as output.
+  *
+  * @param ndArray the destination array, overwritten in place
+  * @throws IllegalArgumentException if the two shapes differ
+  */
+ public void copyTo(NDArray ndArray) {
+     Shape inShape = getShape();
+     Shape destShape = ndArray.getShape();
+     boolean sameShape = Arrays.equals(inShape.getShape(), destShape.getShape());
+     if (!sameShape) {
+         throw new IllegalArgumentException(
+                 "shape are diff. Required: " + destShape + ", Actual " + inShape);
+     }
+     NDArray[] sources = {this};
+     NDArray[] targets = {ndArray};
+     JnaUtils.op("_npi_copyto").invoke(sources, targets, null);
+ }
+
+ /**
+  * Masks this {@code NDArray} with a boolean index along axis {@code 0}.
+  *
+  * @param index boolean {@code NDArray} mask
+  * @return the result {@code NDArray}
+  * @see #booleanMask(NDArray, int)
+  */
+ NDArray booleanMask(NDArray index) {
+     final int defaultAxis = 0;
+     return booleanMask(index, defaultAxis);
+ }
+
+ /**
+  * Returns the portion of this {@code NDArray} selected by the boolean {@code index} along the
+  * given axis.
+  *
+  * <p>The leading dimensions covered by {@code index} are flattened into one dimension before
+  * masking. The flattened index is passed to {@code _npi_boolean_mask}.
+  *
+  * @param index boolean {@code NDArray} mask
+  * @param axis an integer that represents the axis of {@code NDArray} to mask from
+  * @return the result {@code NDArray}
+  * @throws IllegalArgumentException if this array or {@code index} is a scalar
+  */
+ public NDArray booleanMask(NDArray index, int axis) {
+     if (isScalar() || index.isScalar()) {
+         throw new IllegalArgumentException("booleanMask didn't support scalar!");
+     }
+     // TODO remove reshape when MXNet numpy support multi-dim index
+     // and boolean MxNDArray reshape
+     Shape remainingDims = getShape().slice(index.getShape().dimension());
+     // target shape {-1, remainingDims...}
+     long[] targetShape = new long[remainingDims.dimension() + 1];
+     targetShape[0] = -1;
+     System.arraycopy(remainingDims.getShape(), 0, targetShape, 1, remainingDims.dimension());
+     OpParams params = new OpParams();
+     params.addParam("axis", axis);
+     NDArray intIndex = index.toType(DataType.INT32, false);
+     try (NDArray reshaped = this.reshape(new Shape(targetShape));
+             NDArray reshapedIndex = intIndex.reshape(-1);
+             NDArray result =
+                     invoke(
+                             getParent(),
+                             "_npi_boolean_mask",
+                             new NDArray[] {reshaped, reshapedIndex},
+                             params)) {
+         return result.reshape(targetShape);
+     } finally {
+         // The INT32 cast was previously leaked. With copy=false it may presumably be the
+         // caller's own index array, which must not be closed.
+         if (intIndex != index) {
+             intIndex.close();
+         }
+     }
+ }
+
+ /**
+  * Replaces every element outside its row's valid sequence with the constant {@code value}.
+  *
+  * <p>This array must have the form [batch_size, max_sequence_length, ...], and the result has
+  * the same shape. {@code sequenceLength} must have shape [batch_size] and holds each batch
+  * entry's valid length. The operator is invoked with {@code axis=1}, so axis 1 is treated as
+  * the sequence axis.
+  *
+  * @param sequenceLength per-batch valid lengths, of shape [batch_size]
+  * @param value the constant written outside each sequence
+  * @return the masked {@code NDArray}
+  * @throws IllegalArgumentException if this array has fewer than 2 dimensions or a zero-sized
+  *     dimension, or if {@code sequenceLength} does not have shape [batch_size]
+  */
+ public NDArray sequenceMask(NDArray sequenceLength, float value) {
+     Shape shape = getShape();
+     if (shape.dimension() < 2 || shape.isScalar() || shape.hasZeroDimension()) {
+         throw new IllegalArgumentException(
+                 "sequenceMask is not supported for MxNDArray with less than 2 dimensions");
+     }
+     Shape batchShape = new Shape(shape.get(0));
+     if (!sequenceLength.getShape().equals(batchShape)) {
+         throw new IllegalArgumentException("SequenceLength must be of shape [batchSize]");
+     }
+     OpParams maskParams = new OpParams();
+     maskParams.add("axis", 1);
+     maskParams.add("use_sequence_length", true);
+     maskParams.add("value", value);
+     NDList inputs = new NDList(this, sequenceLength);
+     return invoke(getParent(), "_npx_sequence_mask", inputs, maskParams).head();
+ }
+
+ /**
+  * Zeroes every element outside its row's valid sequence.
+  *
+  * <p>Equivalent to {@code sequenceMask(sequenceLength, 0)}. This array must have the form
+  * [batch_size, max_sequence_length, ...], and {@code sequenceLength} must have shape
+  * [batch_size].
+  *
+  * @param sequenceLength per-batch valid lengths, of shape [batch_size]
+  * @return the masked {@code NDArray}
+  */
+ public NDArray sequenceMask(NDArray sequenceLength) {
+     final float fill = 0f;
+     return sequenceMask(sequenceLength, fill);
+ }
+
+ /**
+  * Creates an {@code NDArray} shaped like this one and filled with zeros.
+  *
+  * <p>Delegates to {@code _npi_full_like} with {@code fill_value=0}.
+  *
+  * @return a new zero-filled {@code NDArray}
+  */
+ public NDArray zerosLike() {
+     OpParams fill = new OpParams();
+     fill.addParam("fill_value", 0);
+     NDArray template = this;
+     return invoke(getParent(), "_npi_full_like", template, fill);
+ }
+
+ /**
+  * Creates an {@code NDArray} shaped like this one and filled with ones.
+  *
+  * <p>Delegates to {@code _npi_full_like} with {@code fill_value=1}.
+  *
+  * @return a new one-filled {@code NDArray}
+  */
+ public NDArray onesLike() {
+     OpParams fill = new OpParams();
+     fill.addParam("fill_value", 1);
+     NDArray template = this;
+     return invoke(getParent(), "_npi_full_like", template, fill);
+ }
+
+ /**
+  * Returns the portion of this array selected by {@code index}, resolved by this array's
+  * indexer.
+  *
+  * @param index the index selecting the portion to return
+  * @return the selected {@code NDArray}
+  */
+ NDArray get(NDIndex index) {
+     return this.getNDArrayInternal().getIndexer().get(this, index);
+ }
+
+ /**
+  * Returns the portion of this array selected by fixed integer indices.
+  *
+  * @param indices one index per leading dimension
+  * @return the selected {@code NDArray}
+  */
+ NDArray get(long... indices) {
+     NDIndex index = new NDIndex(indices);
+     return get(index);
+ }
+
+ /**
+  * Returns the single element selected by {@code indices} as an {@code NDArray}.
+  *
+  * @param indices one index per leading dimension
+  * @return an {@code NDArray} with exactly one element
+  * @throws IllegalArgumentException if the indices select more or fewer than one element
+  */
+ NDArray getScalar(long... indices) {
+     NDArray value = get(new NDIndex(indices));
+     if (value.size() != 1) {
+         // Free the rejected intermediate before throwing (it was leaked before). Guard against
+         // the indexer presumably returning this array itself.
+         if (value != this) {
+             value.close();
+         }
+         throw new IllegalArgumentException("The supplied Index does not produce a scalar");
+     }
+     return value;
+ }
+
+ /**
+  * Reads the single element selected by {@code indices} as a Java {@code boolean}.
+  *
+  * @param indices one index per leading dimension
+  * @return the selected element
+  * @throws IllegalArgumentException if the indices do not select exactly one element
+  */
+ boolean getBoolean(long... indices) {
+     NDArray scalar = getScalar(indices);
+     try {
+         return scalar.toBooleanArray()[0];
+     } finally {
+         // The scalar intermediate was previously leaked. Never close this array itself.
+         if (scalar != this) {
+             scalar.close();
+         }
+     }
+ }
+
+ /**
+  * Returns {@code true} if all elements in this {@code NDArray} are equal to the {@link Number}.
+  *
+  * @param number the number to compare; {@code null} yields {@code false}
+  * @return the boolean result
+  */
+ public boolean contentEquals(Number number) {
+     if (number == null) {
+         return false;
+     }
+     // Close both the element-wise comparison and the all() reduction. The reduction result
+     // was previously leaked.
+     try (NDArray result = eq(number);
+             NDArray allEqual = result.all()) {
+         return allEqual.getBoolean();
+     }
+ }
+
+ /**
+  * Returns {@code true} if all elements in this {@code NDArray} are equal to the other {@link
+  * NDArray}.
+  *
+  * <p>Arrays with different shapes or data types are never equal.
+  *
+  * @param other the other {@code NDArray} to compare; {@code null} yields {@code false}
+  * @return the boolean result
+  */
+ public boolean contentEquals(NDArray other) {
+     if (other == null || !shapeEquals(other)) {
+         return false;
+     }
+     if (getDataType() != other.getDataType()) {
+         return false;
+     }
+     // The eq() result and the all() reduction were previously leaked. Only the INT32 cast
+     // was being closed.
+     try (NDArray equal = eq(other)) {
+         NDArray asInt = equal.toType(DataType.INT32, false);
+         try (NDArray allEqual = asInt.all()) {
+             return allEqual.getBoolean();
+         } finally {
+             // with copy=false the cast may be `equal` itself, already closed by the outer try
+             if (asInt != equal) {
+                 asInt.close();
+             }
+         }
+     }
+ }
+
+ /**
+  * Compares each element with {@code n} for equality ({@code _npi_equal_scalar}).
+  *
+  * @param n the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Equals" results
+  */
+ public NDArray eq(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_equal_scalar", this, params);
+ }
+
+ /**
+  * Compares this array with {@code other} element-wise for equality ({@code _npi_equal}).
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} of element-wise "Equals" results
+  */
+ public NDArray eq(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_equal", operands, null);
+ }
+
+ /**
+  * Compares each element with {@code n} for inequality ({@code _npi_not_equal_scalar}).
+  *
+  * @param n the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Not equals" results
+  */
+ public NDArray neq(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_not_equal_scalar", this, params);
+ }
+
+ /**
+  * Compares this array with {@code other} element-wise for inequality ({@code _npi_not_equal}).
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} of element-wise "Not equals" results
+  */
+ public NDArray neq(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_not_equal", operands, null);
+ }
+
+ /**
+  * Tests whether each element is greater than {@code other} ({@code _npi_greater_scalar}).
+  *
+  * @param other the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Greater" results
+  */
+ public NDArray gt(Number other) {
+     String scalar = other.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_greater_scalar", this, params);
+ }
+
+ /**
+  * Tests element-wise whether this array is greater than {@code other} ({@code _npi_greater}).
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} of element-wise "Greater Than" results
+  */
+ public NDArray gt(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_greater", operands, null);
+ }
+
+ /**
+  * Tests whether each element is greater than or equal to {@code other}
+  * ({@code _npi_greater_equal_scalar}).
+  *
+  * @param other the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Greater or equals" results
+  */
+ public NDArray gte(Number other) {
+     String scalar = other.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_greater_equal_scalar", this, params);
+ }
+
+ /**
+  * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+  */
+ public NDArray gte(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_greater_equal", operands, null);
+ }
+
+ /**
+  * Tests whether each element is less than {@code other} ({@code _npi_less_scalar}).
+  *
+  * @param other the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Less" results
+  */
+ public NDArray lt(Number other) {
+     String scalar = other.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_less_scalar", this, params);
+ }
+
+ /**
+  * Tests element-wise whether this array is less than {@code other} ({@code _npi_less}).
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} of element-wise "Less" results
+  */
+ public NDArray lt(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_less", operands, null);
+ }
+
+ /**
+  * Tests whether each element is less than or equal to {@code other}
+  * ({@code _npi_less_equal_scalar}).
+  *
+  * @param other the number to compare against
+  * @return the boolean {@code NDArray} of element-wise "Less or equals" results
+  */
+ public NDArray lte(Number other) {
+     String scalar = other.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_less_equal_scalar", this, params);
+ }
+
+ /**
+  * Tests element-wise whether this array is less than or equal to {@code other}
+  * ({@code _npi_less_equal}).
+  *
+  * @param other the {@code NDArray} to compare against
+  * @return the boolean {@code NDArray} of element-wise "Less or equals" results
+  */
+ public NDArray lte(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_less_equal", operands, null);
+ }
+
+ /**
+  * Adds {@code n} to every element ({@code _npi_add_scalar}), returning a new array.
+  *
+  * @param n the number to add
+  * @return the result {@code NDArray}
+  */
+ public NDArray add(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_add_scalar", this, params);
+ }
+
+ /**
+  * Adds the other {@code NDArray} to this {@code NDArray} element-wise ({@code _npi_add}).
+  *
+  * @param other the {@code NDArray} to add
+  * @return the result {@code NDArray}
+  */
+ public NDArray add(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_add", operands, null);
+ }
+
+ /**
+  * Subtracts a number from every element of this {@code NDArray}
+  * ({@code _npi_subtract_scalar}).
+  *
+  * @param n the number to subtract
+  * @return the result {@code NDArray}
+  */
+ public NDArray sub(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_subtract_scalar", this, params);
+ }
+
+ /**
+  * Subtracts {@code other} from this array element-wise ({@code _npi_subtract}).
+  *
+  * @param other the {@code NDArray} to subtract
+  * @return the result {@code NDArray}
+  */
+ public NDArray sub(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_subtract", operands, null);
+ }
+
+ /**
+  * Multiplies every element by {@code n} ({@code _npi_multiply_scalar}).
+  *
+  * @param n the number to multiply by
+  * @return the result {@code NDArray}
+  */
+ public NDArray mul(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_multiply_scalar", this, params);
+ }
+
+ /**
+  * Multiplies this {@code NDArray} by the other {@code NDArray} element-wise
+  * ({@code _npi_multiply}).
+  *
+  * @param other the {@code NDArray} to multiply by
+  * @return the result {@code NDArray}
+  */
+ public NDArray mul(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_multiply", operands, null);
+ }
+
+ /**
+  * Divides every element by {@code n} using true division ({@code _npi_true_divide_scalar}).
+  *
+  * @param n the number to divide by
+  * @return the result {@code NDArray}
+  */
+ public NDArray div(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_true_divide_scalar", this, params);
+ }
+
+ /**
+  * Divides this array by {@code other} element-wise using true division
+  * ({@code _npi_true_divide}).
+  *
+  * @param other the {@code NDArray} to divide by
+  * @return the result {@code NDArray}
+  */
+ public NDArray div(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_true_divide", operands, null);
+ }
+
+ /**
+  * Computes the remainder of dividing every element by {@code n} ({@code _npi_mod_scalar}).
+  *
+  * @param n the divisor number
+  * @return the result {@code NDArray}
+  */
+ public NDArray mod(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_mod_scalar", this, params);
+ }
+
+ /**
+  * Computes the element-wise remainder of dividing this array by {@code other}
+  * ({@code _npi_mod}).
+  *
+  * @param other the divisor {@code NDArray}
+  * @return the result {@code NDArray}
+  */
+ public NDArray mod(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_mod", operands, null);
+ }
+
+ /**
+  * Raises every element to the power {@code n} ({@code _npi_power_scalar}).
+  *
+  * @param n the exponent
+  * @return the result {@code NDArray}
+  */
+ public NDArray pow(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_power_scalar", this, params);
+ }
+
+ /**
+  * Raises each element of this array to the power of the matching element of {@code other}
+  * ({@code _npi_power}).
+  *
+  * @param other the {@code NDArray} of exponents
+  * @return the result {@code NDArray}
+  */
+ public NDArray pow(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_power", operands, null);
+ }
+
+ /**
+  * Adds {@code n} to every element, writing the result back into this array.
+  *
+  * @param n the number to add
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray addi(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_add_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Adds {@code other} to this array element-wise, writing the result back into this array.
+  *
+  * @param other the {@code NDArray} to add
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray addi(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_add", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Subtracts {@code n} from every element, writing the result back into this array.
+  *
+  * @param n the number to subtract
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray subi(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_subtract_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Subtracts {@code other} from this array element-wise, writing the result back into this
+  * array.
+  *
+  * @param other the {@code NDArray} to subtract
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray subi(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_subtract", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Multiplies every element by {@code n}, writing the result back into this array.
+  *
+  * @param n the number to multiply by
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray muli(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_multiply_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Multiplies this array by {@code other} element-wise, writing the result back into this
+  * array.
+  *
+  * @param other the {@code NDArray} to multiply by
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray muli(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_multiply", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Divides every element by {@code n} (true division), writing the result back into this
+  * array.
+  *
+  * @param n the number to divide values by
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray divi(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_true_divide_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Divides this array by {@code other} element-wise (true division), writing the result back
+  * into this array.
+  *
+  * @param other the {@code NDArray} to divide by
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray divi(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_true_divide", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Replaces every element with its remainder modulo {@code n}, in place.
+  *
+  * @param n the divisor number
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray modi(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_mod_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Replaces each element with its remainder modulo the matching element of {@code other}, in
+  * place.
+  *
+  * @param other the divisor {@code NDArray}
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray modi(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_mod", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Raises every element to the power {@code n}, in place.
+  *
+  * @param n the exponent
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray powi(Number n) {
+     OpParams params = new OpParams();
+     params.add("scalar", n.toString());
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_power_scalar", inputs, outputs, params);
+     return this;
+ }
+
+ /**
+  * Raises each element to the power of the matching element of {@code other}, in place.
+  *
+  * @param other the {@code NDArray} of exponents
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray powi(NDArray other) {
+     NDArray[] inputs = {this, other};
+     NDArray[] outputs = {this};
+     invoke("_npi_power", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Computes the sign of each element ({@code _npi_sign}).
+  *
+  * @return a new {@code NDArray} of element-wise signs
+  */
+ public NDArray sign() {
+     String opName = "_npi_sign";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Replaces each element with its sign, in place.
+  *
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray signi() {
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_sign", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Takes the larger of each element and {@code n} ({@code _npi_maximum_scalar}).
+  *
+  * @param n the number to be compared
+  * @return the element-wise maximum of this {@code NDArray} and {@code n}
+  */
+ public NDArray maximum(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_maximum_scalar", this, params);
+ }
+
+ /**
+  * Takes the larger of each pair of matching elements ({@code _npi_maximum}).
+  *
+  * @param other the {@code NDArray} to be compared
+  * @return the element-wise maximum of this {@code NDArray} and {@code other}
+  */
+ public NDArray maximum(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_maximum", operands, null);
+ }
+
+ /**
+  * Takes the smaller of each element and {@code n} ({@code _npi_minimum_scalar}).
+  *
+  * @param n the number to be compared
+  * @return the element-wise minimum of this {@code NDArray} and {@code n}
+  */
+ public NDArray minimum(Number n) {
+     String scalar = n.toString();
+     OpParams params = new OpParams();
+     params.add("scalar", scalar);
+     return invoke(getParent(), "_npi_minimum_scalar", this, params);
+ }
+
+ /**
+  * Returns the minimum of this {@code NDArray} and the other {@code NDArray} element-wise
+  * ({@code _npi_minimum}).
+  *
+  * @param other the {@code NDArray} to be compared
+  * @return the minimum of this {@code NDArray} and the other {@code NDArray} element-wise
+  */
+ public NDArray minimum(NDArray other) {
+     NDArray[] operands = {this, other};
+     return invoke(getParent(), "_npi_minimum", operands, null);
+ }
+
+ /**
+  * Negates each element ({@code _npi_negative}).
+  *
+  * @return a new {@code NDArray} of negated values
+  */
+ public NDArray neg() {
+     String opName = "_npi_negative";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Negates each element in place.
+  *
+  * @return this {@code NDArray}, updated in place
+  */
+ public NDArray negi() {
+     NDArray[] inputs = {this};
+     NDArray[] outputs = {this};
+     invoke("_npi_negative", inputs, outputs, null);
+     return this;
+ }
+
+ /**
+  * Computes the absolute value of each element ({@code _npi_absolute}).
+  *
+  * @return a new {@code NDArray} of absolute values
+  */
+ public NDArray abs() {
+     String opName = "_npi_absolute";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Squares each element ({@code _npi_square}).
+  *
+  * @return a new {@code NDArray} of squared values
+  */
+ public NDArray square() {
+     String opName = "_npi_square";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the square root of each element ({@code _npi_sqrt}).
+  *
+  * @return a new {@code NDArray} of square roots
+  */
+ public NDArray sqrt() {
+     String opName = "_npi_sqrt";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the cube root of each element ({@code _npi_cbrt}).
+  *
+  * @return a new {@code NDArray} of cube roots
+  */
+ public NDArray cbrt() {
+     String opName = "_npi_cbrt";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Rounds each element down to the nearest integer value ({@code _npi_floor}).
+  *
+  * @return a new {@code NDArray} of floored values
+  */
+ public NDArray floor() {
+     String opName = "_npi_floor";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Rounds each element up to the nearest integer value ({@code _npi_ceil}).
+  *
+  * @return a new {@code NDArray} of ceiling values
+  */
+ public NDArray ceil() {
+     String opName = "_npi_ceil";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Rounds each element using the {@code round} operator.
+  *
+  * @return a new {@code NDArray} of rounded values
+  */
+ public NDArray round() {
+     String opName = "round";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Truncates each element ({@code _npi_trunc}).
+  *
+  * @return a new {@code NDArray} of truncated values
+  */
+ public NDArray trunc() {
+     String opName = "_npi_trunc";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the exponential of each element ({@code _npi_exp}).
+  *
+  * @return a new {@code NDArray} of exponentials
+  */
+ public NDArray exp() {
+     String opName = "_npi_exp";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the natural logarithm of each element ({@code _npi_log}).
+  *
+  * @return a new {@code NDArray} of natural logarithms
+  */
+ public NDArray log() {
+     String opName = "_npi_log";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the base-10 logarithm of each element ({@code _npi_log10}).
+  *
+  * @return a new {@code NDArray} of base-10 logarithms
+  */
+ public NDArray log10() {
+     String opName = "_npi_log10";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the base-2 logarithm of each element ({@code _npi_log2}).
+  *
+  * @return a new {@code NDArray} of base-2 logarithms
+  */
+ public NDArray log2() {
+     String opName = "_npi_log2";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the trigonometric sine of each element ({@code _npi_sin}).
+  *
+  * @return a new {@code NDArray} of sines
+  */
+ public NDArray sin() {
+     String opName = "_npi_sin";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the trigonometric cosine of each element ({@code _npi_cos}).
+  *
+  * @return a new {@code NDArray} of cosines
+  */
+ public NDArray cos() {
+     String opName = "_npi_cos";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the trigonometric tangent of each element ({@code _npi_tan}).
+  *
+  * @return a new {@code NDArray} of tangents
+  */
+ public NDArray tan() {
+     String opName = "_npi_tan";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse sine of each element ({@code _npi_arcsin}).
+  *
+  * @return a new {@code NDArray} of arcsines
+  */
+ public NDArray asin() {
+     String opName = "_npi_arcsin";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse cosine of each element ({@code _npi_arccos}).
+  *
+  * @return a new {@code NDArray} of arccosines
+  */
+ public NDArray acos() {
+     String opName = "_npi_arccos";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse tangent of each element ({@code _npi_arctan}).
+  *
+  * @return a new {@code NDArray} of arctangents
+  */
+ public NDArray atan() {
+     String opName = "_npi_arctan";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the hyperbolic sine of each element ({@code _npi_sinh}).
+  *
+  * @return a new {@code NDArray} of hyperbolic sines
+  */
+ public NDArray sinh() {
+     String opName = "_npi_sinh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the hyperbolic cosine of each element ({@code _npi_cosh}).
+  *
+  * @return a new {@code NDArray} of hyperbolic cosines
+  */
+ public NDArray cosh() {
+     String opName = "_npi_cosh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the hyperbolic tangent of each element ({@code _npi_tanh}).
+  *
+  * @return a new {@code NDArray} of hyperbolic tangents
+  */
+ public NDArray tanh() {
+     String opName = "_npi_tanh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse hyperbolic sine of each element ({@code _npi_arcsinh}).
+  *
+  * @return a new {@code NDArray} of inverse hyperbolic sines
+  */
+ public NDArray asinh() {
+     String opName = "_npi_arcsinh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse hyperbolic cosine of each element ({@code _npi_arccosh}).
+  *
+  * @return a new {@code NDArray} of inverse hyperbolic cosines
+  */
+ public NDArray acosh() {
+     String opName = "_npi_arccosh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Computes the inverse hyperbolic tangent of each element ({@code _npi_arctanh}).
+  *
+  * @return a new {@code NDArray} of inverse hyperbolic tangents
+  */
+ public NDArray atanh() {
+     String opName = "_npi_arctanh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Converts each element from radians to degrees ({@code _npi_degrees}).
+  *
+  * @return a new {@code NDArray} expressed in degrees
+  */
+ public NDArray toDegrees() {
+     String opName = "_npi_degrees";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Converts each element from degrees to radians ({@code _npi_radians}).
+  *
+  * @return a new {@code NDArray} expressed in radians
+  */
+ public NDArray toRadians() {
+     String opName = "_npi_radians";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Reduces this {@code NDArray} to its maximum ({@code _np_max} with no axis argument).
+  *
+  * @return the maximum of this {@code NDArray}
+  */
+ public NDArray max() {
+     String opName = "_np_max";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Returns the maximum of this {@code NDArray} along given axes.
+  *
+  * @param axes the axes along which to operate
+  * @return the maximum of this {@code NDArray} with the specified axes removed from the Shape
+  *     containing the max
+  * @see NDArray#max(int[], boolean)
+  */
+ public NDArray max(int[] axes) {
+     // Delegate like min/sum/prod/mean(int[]) do, so keepdims=false is passed explicitly
+     // instead of relying on the operator's default.
+     return max(axes, false);
+ }
+
+ /**
+  * Reduces this {@code NDArray} to its maximum over {@code axes} ({@code _np_max}).
+  *
+  * @param axes the axes along which to operate
+  * @param keepDims {@code true} to keep the reduced axes with size 1, {@code false} to remove
+  *     them from the output shape
+  * @return the maximum of this {@code NDArray}
+  */
+ public NDArray max(int[] axes, boolean keepDims) {
+     OpParams reduceParams = new OpParams();
+     reduceParams.addParam("keepdims", keepDims);
+     reduceParams.addTupleParam("axis", axes);
+     return invoke(getParent(), "_np_max", this, reduceParams);
+ }
+
+ /**
+  * Reduces this {@code NDArray} to its minimum ({@code _np_min} with no axis argument).
+  *
+  * @return the minimum of this {@code NDArray}
+  */
+ public NDArray min() {
+     String opName = "_np_min";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Reduces this {@code NDArray} to its minimum over {@code axes}, dropping the reduced axes.
+  *
+  * @param axes the axes along which to operate
+  * @return the minimum of this {@code NDArray} with the specified axes removed from the Shape
+  * @see NDArray#min(int[], boolean)
+  */
+ public NDArray min(int[] axes) {
+     final boolean keepDims = false;
+     return min(axes, keepDims);
+ }
+
+ /**
+  * Reduces this {@code NDArray} to its minimum over {@code axes} ({@code _np_min}).
+  *
+  * @param axes the axes along which to operate
+  * @param keepDims {@code true} to keep the reduced axes with size 1, {@code false} to remove
+  *     them from the output shape
+  * @return the minimum of this {@code NDArray}
+  */
+ public NDArray min(int[] axes, boolean keepDims) {
+     OpParams reduceParams = new OpParams();
+     reduceParams.addParam("keepdims", keepDims);
+     reduceParams.addTupleParam("axis", axes);
+     return invoke(getParent(), "_np_min", this, reduceParams);
+ }
+
+ /**
+  * Returns the sum of this {@code NDArray}.
+  *
+  * <p>On Windows, non-floating arrays are summed in {@code FLOAT32} and then cast back. The
+  * original type is restored, except that {@code BOOLEAN} input yields {@code INT64}.
+  *
+  * @return the sum of this {@code NDArray}
+  */
+ public NDArray sum() {
+     // TODO current windows doesn't support boolean MxNDArray
+     // Lower-case with Locale.ROOT so the OS check does not depend on the default locale, and
+     // match the prefix rather than any "win" substring.
+     String osName = System.getProperty("os.name", "");
+     boolean onWindows = osName.toLowerCase(java.util.Locale.ROOT).startsWith("win");
+     if (onWindows) {
+         DataType target = getDataType();
+         if (!target.isFloating()) {
+             DataType resultType = target == DataType.BOOLEAN ? DataType.INT64 : target;
+             try (NDArray asFloat = toType(DataType.FLOAT32, false);
+                     NDArray floatSum = invoke(getParent(), "_np_sum", asFloat, null)) {
+                 return floatSum.toType(resultType, false);
+             }
+         }
+     }
+     return invoke(getParent(), "_np_sum", this, null);
+ }
+
+ /**
+  * Sums this {@code NDArray} over {@code axes}, dropping the reduced axes.
+  *
+  * @param axes the axes along which to operate
+  * @return the sum of this {@code NDArray} with the specified axes removed from the Shape
+  * @see NDArray#sum(int[], boolean)
+  */
+ public NDArray sum(int[] axes) {
+     final boolean keepDims = false;
+     return sum(axes, keepDims);
+ }
+
+ /**
+  * Sums this {@code NDArray} over {@code axes} ({@code _np_sum}).
+  *
+  * @param axes the axes along which to operate
+  * @param keepDims {@code true} to keep the reduced axes with size 1, {@code false} to remove
+  *     them from the output shape
+  * @return the sum of this {@code NDArray}
+  */
+ public NDArray sum(int[] axes, boolean keepDims) {
+     OpParams reduceParams = new OpParams();
+     reduceParams.addParam("keepdims", keepDims);
+     reduceParams.addTupleParam("axis", axes);
+     return invoke(getParent(), "_np_sum", this, reduceParams);
+ }
+
+ /**
+  * Multiplies all elements of this {@code NDArray} together ({@code _np_prod}, no axis).
+  *
+  * @return the product of this {@code NDArray}
+  */
+ public NDArray prod() {
+     String opName = "_np_prod";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Returns the product of this {@code NDArray} elements over the given axes, dropping the
+  * reduced axes.
+  *
+  * @param axes the axes along which to operate
+  * @return the product of this {@code NDArray} with the specified axes removed from the Shape
+  *     containing the prod
+  * @see NDArray#prod(int[], boolean)
+  */
+ // public, like the sibling reduction overloads sum/mean/min/max(int[]); this was the only one
+ // left package-private
+ public NDArray prod(int[] axes) {
+     return prod(axes, false);
+ }
+
+ /**
+  * Multiplies the elements of this {@code NDArray} over {@code axes} ({@code _np_prod}).
+  *
+  * @param axes the axes along which to operate
+  * @param keepDims {@code true} to keep the reduced axes with size 1, {@code false} to remove
+  *     them from the output shape
+  * @return the product of this {@code NDArray}
+  */
+ public NDArray prod(int[] axes, boolean keepDims) {
+     OpParams reduceParams = new OpParams();
+     reduceParams.addParam("keepdims", keepDims);
+     reduceParams.addTupleParam("axis", axes);
+     return invoke(getParent(), "_np_prod", this, reduceParams);
+ }
+
+ /**
+  * Averages all elements of this {@code NDArray} ({@code _npi_mean}, no axis).
+  *
+  * @return the average of this {@code NDArray}
+  */
+ public NDArray mean() {
+     String opName = "_npi_mean";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Averages this {@code NDArray} over {@code axes}, dropping the reduced axes.
+  *
+  * @param axes the axes along which to operate
+  * @return the average of this {@code NDArray} with the specified axes removed from the Shape
+  * @see NDArray#mean(int[], boolean)
+  */
+ public NDArray mean(int[] axes) {
+     final boolean keepDims = false;
+     return mean(axes, keepDims);
+ }
+
+ /**
+ * Computes the arithmetic mean of this {@code NDArray} along the given axes.
+ *
+ * @param axes the axes to average over
+ * @param keepDims when {@code true} the reduced axes stay in the output with size 1; when
+ * {@code false} they are dropped from the output shape
+ * @return the mean of this {@code NDArray} along {@code axes}
+ */
+ public NDArray mean(int[] axes, boolean keepDims) {
+ OpParams meanParams = new OpParams();
+ meanParams.addTupleParam("axis", axes);
+ meanParams.addParam("keepdims", keepDims);
+ NDArray average = invoke(getParent(), "_npi_mean", this, meanParams);
+ return average;
+ }
+
+ /**
+ * Rotates an array by 90 degrees in the plane specified by axes.
+ *
+ * @param times number of times the array is rotated by 90 degrees
+ * @param axes the two axes defining the plane of rotation; the two axes must be different
+ * @return the rotated NDArray
+ * @throws IllegalArgumentException if {@code axes} does not contain exactly two entries
+ */
+ public NDArray rotate90(int times, int[] axes) {
+ if (axes.length != 2) {
+ // report the offending count so callers can see what was passed
+ throw new IllegalArgumentException(
+ "rotate90 requires exactly 2 axes, but got " + axes.length);
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("axes", axes);
+ params.addParam("k", times);
+ return invoke(getParent(), "_npi_rot90", this, params);
+ }
+
+ /**
+ * Sums the elements along a diagonal of this {@code NDArray}.
+ *
+ * <p>For a 2-D {@code NDArray} this is the sum of the elements a[i,i+offset] over all i. For
+ * higher-dimensional input, axis1 and axis2 select the 2-D sub-arrays whose traces are
+ * computed; the {@link Shape} of the result equals this {@code NDArray}'s shape with axis1 and
+ * axis2 removed.
+ *
+ * @param offset offset of the diagonal from the main diagonal; may be positive or negative
+ * @param axis1 the axis used as the first axis of the 2-D sub-arrays whose diagonals are
+ * taken
+ * @param axis2 the axis used as the second axis of the 2-D sub-arrays whose diagonals are
+ * taken
+ * @return the diagonal sums of this {@code NDArray}
+ */
+ public NDArray trace(int offset, int axis1, int axis2) {
+ OpParams traceParams = new OpParams();
+ traceParams.addParam("offset", offset);
+ traceParams.addParam("axis1", axis1);
+ traceParams.addParam("axis2", axis2);
+ NDArray diagonalSum = invoke(getParent(), "_np_trace", this, traceParams);
+ return diagonalSum;
+ }
+
+ /**
+ * Sums the elements along a diagonal of this {@code NDArray}, using axes 0 and 1 as the plane
+ * of the 2-D sub-arrays.
+ *
+ * <p>For a 2-D {@code NDArray} this is the sum of the elements a[i,i+offset] over all i. For
+ * higher-dimensional input the traces of the 2-D sub-arrays spanned by axes 0 and 1 are
+ * returned, and the {@link Shape} of the result equals this {@code NDArray}'s shape with those
+ * two axes removed.
+ *
+ * @param offset offset of the diagonal from the main diagonal; may be positive or negative
+ * @return the diagonal sums of this {@code NDArray}
+ */
+ public NDArray trace(int offset) {
+ final int firstAxis = 0;
+ final int secondAxis = 1;
+ return trace(offset, firstAxis, secondAxis);
+ }
+
+ /**
+ * Splits this {@code NDArray} into multiple sub-{@code NDArray}s at the given indices along
+ * the given axis.
+ *
+ * <p>Following numpy, each index marks where a new section starts. If the first index is not
+ * {@code 0}, a leading {@code 0} is prepended so that the first section starts at the
+ * beginning of the axis. The split axis is kept in every output ({@code squeeze_axis=false}).
+ *
+ * @param indices the indices along {@code axis} at which to split; when empty, an {@link
+ * NDList} containing only this {@code NDArray} is returned
+ * @param axis the axis to split along
+ * @return an {@link NDList} holding the sub-{@code NDArray}s produced by the split
+ */
+ public NDList split(long[] indices, int axis) {
+ if (indices.length == 0) {
+ return new NDList(this);
+ }
+ // follow the numpy behavior: make sure the first section starts at index 0
+ long[] sectionStarts = indices;
+ if (indices[0] != 0) {
+ // a new long[] is zero-initialised, so element 0 is already 0
+ sectionStarts = new long[indices.length + 1];
+ System.arraycopy(indices, 0, sectionStarts, 1, indices.length);
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("indices", sectionStarts);
+ params.addParam("axis", axis);
+ params.addParam("squeeze_axis", false);
+ return invoke(getParent(), "_npi_split", new NDList(this), params);
+ }
+
+ /**
+ * Flattens this {@code NDArray} into a 1-D {@code NDArray} in row-major order.
+ *
+ * <p>To flatten in column-major order, first transpose this {@code NDArray}.
+ *
+ * @return a 1-D {@code NDArray} of equal size
+ */
+ public NDArray flatten() {
+ // pass the element count straight to the long-based Shape constructor; narrowing it with
+ // Math.toIntExact would throw for arrays with more than Integer.MAX_VALUE elements
+ return reshape(new Shape(size()));
+ }
+
+ /**
+ * Gives this {@code NDArray} a new {@link Shape} without changing its data.
+ *
+ * <p>To match the shape of another NDArray, call {@code a.reshape(b.getShape()) }.
+ *
+ * @param shape the target {@link Shape}; its size must equal the current size
+ * @return the reshaped {@code NDArray}
+ * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+ * the current shape
+ */
+ public NDArray reshape(Shape shape) {
+ OpParams reshapeParams = new OpParams();
+ reshapeParams.addParam("newshape", shape);
+ NDArray reshaped = invoke(getParent(), "_np_reshape", this, reshapeParams);
+ return reshaped;
+ }
+
+ /**
+ * Reshapes this {@code NDArray} to the given {@link Shape}.
+ *
+ * @param newShape the long array to reshape into. Must have equal size to the current shape
+ * @return a reshaped {@code NDArray}
+ * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+ * the current shape
+ */
+ public NDArray reshape(long... newShape) {
+ return reshape(new Shape(newShape));
+ }
+
+ /**
+ * Inserts a new size-1 axis into the {@link Shape} of this {@code NDArray}.
+ *
+ * <p>The new axis appears at position {@code axis} in the expanded shape.
+ *
+ * @param axis the position in the expanded axes where the new axis is placed
+ * @return the result {@code NDArray}, which has one more dimension than this {@code NDArray}
+ */
+ public NDArray expandDims(int axis) {
+ OpParams expandParams = new OpParams();
+ expandParams.addParam("axis", axis);
+ NDArray expanded = invoke(getParent(), "_npi_expand_dims", this, expandParams);
+ return expanded;
+ }
+
+ /**
+ * Removes all singleton dimensions from this {@code NDArray} {@link Shape}.
+ *
+ * @return a result {@code NDArray} of same size and data without singleton dimensions
+ */
+ public NDArray squeeze() {
+ return invoke(getParent(), "_np_squeeze", this, null);
+ }
+
+ /**
+ * Drops the size-1 dimensions located at the given axes.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
+ * jshell&gt; array;
+ * ND: (1, 3, 1) cpu() float32
+ * [[[0.],
+ * [1.],
+ * [2.],
+ * ],
+ * ]
+ * jshell&gt; array.squeeze(new int[] {0, 2});
+ * ND: (3) cpu() float32
+ * [0., 1., 2.]
+ * </pre>
+ *
+ * @param axes the axes whose singleton dimensions are removed
+ * @return a result {@code NDArray} with the same size and data but without the given axes in
+ * its shape
+ * @throws IllegalArgumentException thrown if any of the given axes are not a singleton
+ * dimension
+ */
+ public NDArray squeeze(int[] axes) {
+ OpParams squeezeParams = new OpParams();
+ squeezeParams.addTupleParam("axis", axes);
+ NDArray squeezed = invoke(getParent(), "_np_squeeze", this, squeezeParams);
+ return squeezed;
+ }
+
+ /**
+ * Returns the truth value of this {@code NDArray} AND the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical AND operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalAnd(NDArray other) {
+ return broadcastLogical("broadcast_logical_and", other);
+ }
+
+ /**
+ * Computes the truth value of this {@code NDArray} OR the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical OR operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalOr(NDArray other) {
+ return broadcastLogical("broadcast_logical_or", other);
+ }
+
+ /**
+ * Computes the truth value of this {@code NDArray} XOR the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical XOR operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalXor(NDArray other) {
+ return broadcastLogical("broadcast_logical_xor", other);
+ }
+
+ /**
+ * Invokes one of the legacy {@code broadcast_logical_*} operators on this {@code NDArray} and
+ * {@code other}. This helper is shared by {@link #logicalAnd}, {@link #logicalOr} and {@link
+ * #logicalXor}.
+ *
+ * <p>Boolean operands are converted to {@link DataType#INT32} before the call, and the result
+ * is converted to {@link DataType#BOOLEAN}. NOTE(review): presumably the legacy operators do
+ * not accept boolean inputs; confirm.
+ *
+ * @param operation the native operator name
+ * @param other the right-hand operand
+ * @return the boolean result {@code NDArray}
+ */
+ private NDArray broadcastLogical(String operation, NDArray other) {
+ // TODO switch to numpy op, although current op support zero-dim, scalar
+ NDArray lhs = (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this;
+ NDArray rhs =
+ (other.getDataType() == DataType.BOOLEAN)
+ ? other.toType(DataType.INT32, false)
+ : other;
+ return invoke(getParent(), operation, new NDArray[] {lhs, rhs}, null)
+ .toType(DataType.BOOLEAN, false);
+ }
+
+ /**
+ * Computes the truth value of NOT this {@code NDArray} element-wise.
+ *
+ * @return the boolean {@code NDArray}
+ */
+ public NDArray logicalNot() {
+ return invoke(getParent(), "_npi_logical_not", this, null);
+ }
+
+ /**
+ * Computes the indices that would sort this {@code NDArray} along the given axis.
+ *
+ * <p>This is an indirect sort. The returned {@code NDArray} of indices has the same {@link
+ * Shape} as this {@code NDArray}.
+ *
+ * @param axis the axis to sort along
+ * @param ascending {@code true} to sort in ascending order
+ * @return a {@code NDArray} of indices into this {@code NDArray} along the axis; its DataType is
+ * always {@link DataType#INT64}
+ */
+ public NDArray argSort(int axis, boolean ascending) {
+ OpParams sortParams = new OpParams();
+ sortParams.addParam("axis", axis);
+ // be careful that MXNet numpy argsort op didn't officially support this param
+ sortParams.addParam("is_ascend", ascending);
+ sortParams.setDataType(DataType.INT64);
+ NDArray order = invoke(getParent(), "_npi_argsort", this, sortParams);
+ return order;
+ }
+
+ /**
+ * Sorts this {@code NDArray} along the given axis.
+ *
+ * @param axis the axis to sort along
+ * @return the sorted {@code NDArray}
+ */
+ public NDArray sort(int axis) {
+ OpParams params = new OpParams();
+ params.addParam("axis", axis);
+ return invoke(getParent(), "_npi_sort", this, params);
+ }
+
+ /**
+ * Sorts the flattened {@code NDArray}.
+ *
+ * @return the sorted {@code NDArray}
+ */
+ public NDArray sort() {
+ return invoke(getParent(), "_npi_sort", this, null);
+ }
+
+ /**
+ * Applies the softmax function along the given axis.
+ *
+ * <p>An empty {@code NDArray} yields a new FLOAT32 {@code NDArray} of the same shape without
+ * calling the native operator.
+ *
+ * @param axis the axis along which to apply softmax
+ * @return the result {@code NDArray}
+ * @see <a href="https://en.wikipedia.org/wiki/Softmax_function">softmax</a>
+ */
+ public NDArray softmax(int axis) {
+ // short-circuit empty input to work around the MXNet softmax op bug on GPU
+ if (!isEmpty()) {
+ OpParams softmaxParams = new OpParams();
+ softmaxParams.addParam("axis", axis);
+ return invoke(getParent(), "_npx_softmax", this, softmaxParams);
+ }
+ return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+ }
+
+ /**
+ * Applies softmax followed by a logarithm along the given axis.
+ *
+ * <p>Mathematically this equals calling softmax and then log. The single fused operator is
+ * faster than two separate operators and numerically more stable when computing gradients. An
+ * empty {@code NDArray} yields a new FLOAT32 {@code NDArray} of the same shape.
+ *
+ * @param axis the axis along which to apply
+ * @return the result {@code NDArray}
+ */
+ public NDArray logSoftmax(int axis) {
+ // short-circuit empty input to work around the MXNet logsoftmax op bug on GPU
+ if (!isEmpty()) {
+ OpParams logSoftmaxParams = new OpParams();
+ logSoftmaxParams.addParam("axis", axis);
+ return invoke(getParent(), "_npx_log_softmax", this, logSoftmaxParams);
+ }
+ return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+ }
+
+ /**
+ * Returns the cumulative sum of the elements in the flattened {@code NDArray}.
+ *
+ * @return the cumulative sum of the elements in the flattened {@code NDArray}
+ */
+ public NDArray cumSum() {
+ return invoke(getParent(), "_np_cumsum", this, null);
+ }
+
+ /**
+ * Computes the running total of the elements along the given axis.
+ *
+ * @param axis the axis along which the cumulative sum is computed
+ * @return the cumulative sum along {@code axis}
+ */
+ public NDArray cumSum(int axis) {
+ OpParams cumSumParams = new OpParams();
+ cumSumParams.addParam("axis", axis);
+ NDArray runningTotal = invoke(getParent(), "_np_cumsum", this, cumSumParams);
+ return runningTotal;
+ }
+
+ /**
+ * Replaces the native handle of this NDArray with the handle of {@code replaced}, then closes
+ * {@code replaced}.
+ *
+ * <p>The handle previously held by this NDArray, if any, is waited on for reading and freed.
+ * Passing this NDArray itself is a no-op. Before this guard, it nulled the handle and then
+ * closed this array.
+ *
+ * <p>Please use with caution, this method will make the input argument unusable.
+ *
+ * @param replaced the handle provider that will be killed
+ */
+ public void intern(NDArray replaced) {
+ if (replaced == this) {
+ return;
+ }
+ // take over the native handle; replaced is left with a null handle, so its close() below
+ // does not free the handle this NDArray now owns
+ Pointer oldHandle = handle.getAndSet(replaced.handle.getAndSet(null));
+ if (oldHandle != null) {
+ JnaUtils.waitToRead(oldHandle);
+ JnaUtils.freeNdArray(oldHandle);
+ }
+ // dereference old ndarray
+ replaced.close();
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+ * entries are infinite, or {@code false} where they are not infinite.
+ *
+ * <p>Implemented with the numpy-compatible {@code _npi_isinf} operator, mirroring how {@code
+ * isNaN()} uses {@code _npi_isnan}.
+ *
+ * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s entries
+ * are infinite
+ */
+ public NDArray isInfinite() {
+ return invoke(getParent(), "_npi_isinf", this, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+ * entries are NaN, or {@code false} where they are not NaN.
+ *
+ * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s {@link
+ * NDArray} are NaN
+ */
+ public NDArray isNaN() {
+ return invoke(getParent(), "_npi_isnan", this, null);
+ }
+
+ /**
+ * Converts this {@code NDArray} to dense storage.
+ *
+ * <p>A dense {@code NDArray} is duplicated; a sparse one is cast to {@code
+ * SparseFormat.DENSE}.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray toDense() {
+ return isSparse() ? castStorage(SparseFormat.DENSE) : duplicate();
+ }
+
+ /**
+ * Returns a sparse representation of {@code NDArray}.
+ *
+ * @param fmt the {@link SparseFormat} of this {@code NDArray}
+ * @return the result {@code NDArray}
+ */
+ public NDArray toSparse(SparseFormat fmt) {
+ if (fmt != SparseFormat.DENSE
+ && fmt != SparseFormat.CSR
+ && fmt != SparseFormat.ROW_SPARSE) {
+ throw new UnsupportedOperationException(fmt + " is not supported");
+ }
+ if (fmt == getSparseFormat()) {
+ return duplicate();
+ }
+ return castStorage(fmt);
+ }
+
+ // Converts this NDArray to the storage type identified by fmt via the cast_storage operator.
+ private NDArray castStorage(SparseFormat fmt) {
+ OpParams storageParams = new OpParams();
+ storageParams.setParam("stype", fmt.getType());
+ NDArray converted = invoke(getParent(), "cast_storage", this, storageParams);
+ return converted;
+ }
+
+ /**
+ * Tiles this {@code NDArray} by repeating it {@code repeats} times along every dimension.
+ *
+ * <p>An {@code NDArray} with no elements is returned as a duplicate. A scalar is treated as
+ * having one dimension.
+ *
+ * @param repeats the number of times to repeat for each dimension
+ * @return a NDArray that has been tiled
+ */
+ public NDArray tile(long repeats) {
+ if (isEmpty()) {
+ return duplicate();
+ }
+ int rank = isScalar() ? 1 : getShape().dimension();
+ long[] perAxis = new long[rank];
+ for (int i = 0; i < rank; i++) {
+ perAxis[i] = repeats;
+ }
+ return tile(perAxis);
+ }
+
+ /**
+ * Tiles this {@code NDArray} using a separate repeat count for each axis.
+ *
+ * @param repeats the number of times to repeat along each axis
+ * @return a {@code NDArray} that has been tiled
+ */
+ public NDArray tile(long[] repeats) {
+ OpParams tileParams = new OpParams();
+ tileParams.addTupleParam("reps", repeats);
+ NDArray tiled = invoke(getParent(), "_npi_tile", this, tileParams);
+ return tiled;
+ }
+
+ /**
+ * Tiles this {@code NDArray} {@code repeats} times along a single axis, leaving the other
+ * axes unchanged.
+ *
+ * @param axis the axis to repeat; negative values count from the end
+ * @param repeats the number of times to repeat along {@code axis}
+ * @return a {@code NDArray} that has been tiled
+ * @throws IllegalArgumentException thrown for invalid axis, or when this {@code NDArray} is a
+ * scalar
+ */
+ public NDArray tile(int axis, long repeats) {
+ if (isScalar()) {
+ throw new IllegalArgumentException("scalar didn't support specifying axis");
+ }
+ int rank = getShape().dimension();
+ long[] perAxis = new long[rank];
+ for (int i = 0; i < rank; i++) {
+ perAxis[i] = 1;
+ }
+ perAxis[withAxis(axis)] = repeats;
+ return tile(perAxis);
+ }
+
+ /**
+ * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times to match
+ * the desired shape.
+ *
+ * <p>If the desired {@link Shape}has fewer dimensions than this {@code NDArray}, it will tile
+ * against the last axis.
+ *
+ * @param desiredShape the {@link Shape}that should be converted to
+ * @return a {@code NDArray} that has been tiled
+ */
+ public NDArray tile(Shape desiredShape) {
+ return tile(repeatsToMatchShape(desiredShape));
+ }
+
+ /**
+ * Normalizes a possibly negative axis index to the range {@code [0, dimension)}.
+ *
+ * <p>Previously, out-of-range axes silently wrapped around via {@code floorMod}, and a 0-dim
+ * shape raised ArithmeticException. Now both cases raise the IllegalArgumentException that
+ * callers document.
+ *
+ * @param axis the axis index; values in {@code [-dimension, dimension)} are accepted
+ * @return the equivalent non-negative axis index
+ * @throws IllegalArgumentException if {@code axis} is outside {@code [-dimension, dimension)}
+ */
+ private int withAxis(int axis) {
+ int dimension = getShape().dimension();
+ if (axis < -dimension || axis >= dimension) {
+ throw new IllegalArgumentException(
+ "Invalid axis " + axis + " for NDArray with " + dimension + " dimension(s)");
+ }
+ return Math.floorMod(axis, dimension);
+ }
+
+ /**
+ * Computes the per-axis repeat counts that turn this NDArray's shape into {@code
+ * desiredShape}.
+ *
+ * <p>If {@code desiredShape} has fewer dimensions, the missing leading dimensions are taken
+ * from the current shape (repeat count 1).
+ *
+ * @param desiredShape the target shape
+ * @return one repeat count per dimension of this NDArray
+ * @throws IllegalArgumentException if {@code desiredShape} has more dimensions than this
+ * NDArray, or is not an exact multiple of the current shape
+ */
+ private long[] repeatsToMatchShape(Shape desiredShape) {
+ Shape curShape = getShape();
+ int dimension = curShape.dimension();
+ int desiredDimension = desiredShape.dimension();
+ if (desiredDimension > dimension) {
+ throw new IllegalArgumentException("The desired shape has too many dimensions");
+ }
+ Shape target = desiredShape;
+ if (desiredDimension < dimension) {
+ int additionalDimensions = dimension - desiredDimension;
+ target = curShape.slice(0, additionalDimensions).addAll(desiredShape);
+ }
+ long[] repeats = new long[dimension];
+ for (int i = 0; i < dimension; i++) {
+ long current = curShape.get(i);
+ long desired = target.get(i);
+ if (current == 0 || desired % current != 0) {
+ throw new IllegalArgumentException(
+ "The desired shape is not a multiple of the original shape");
+ }
+ // divisibility was just verified, so integer division is exact; the previous
+ // double-based ceil/round lost precision for values above 2^53
+ repeats[i] = desired / current;
+ }
+ return repeats;
+ }
+
+ /**
+ * Repeats each element of this {@code NDArray} {@code repeats} times along every dimension.
+ *
+ * <p>An {@code NDArray} with no elements is returned as a duplicate. A scalar is treated as
+ * having one dimension.
+ *
+ * @param repeats the number of times to repeat for each axis
+ * @return an {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(long repeats) {
+ if (isEmpty()) {
+ return duplicate();
+ }
+ int rank = isScalar() ? 1 : getShape().dimension();
+ long[] perAxis = new long[rank];
+ for (int i = 0; i < rank; i++) {
+ perAxis[i] = repeats;
+ }
+ return repeat(perAxis);
+ }
+
+ /**
+ * Repeats each element of this {@code NDArray} {@code repeats} times along a single axis.
+ *
+ * @param axis the axis to repeat; negative values count from the end
+ * @param repeats the number of times to repeat along {@code axis}
+ * @return an {@code NDArray} that has been repeated
+ * @throws IllegalArgumentException thrown for invalid axis
+ */
+ public NDArray repeat(int axis, long repeats) {
+ int rank = getShape().dimension();
+ long[] perAxis = new long[rank];
+ for (int i = 0; i < rank; i++) {
+ perAxis[i] = 1;
+ }
+ perAxis[withAxis(axis)] = repeats;
+ return repeat(perAxis);
+ }
+
+ /**
+ * Repeats each element of this {@code NDArray} using a separate repeat count per axis.
+ *
+ * <p>The counts are aligned with the trailing axes. Axes whose count is 1 or less are skipped.
+ * If nothing is repeated, this {@code NDArray} itself is returned.
+ *
+ * @param repeats the number of times to repeat along each axis
+ * @return a {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(long[] repeats) {
+ // TODO get rid of for loop once bug in MXNet np.repeat is fixed
+ int baseAxis = getShape().dimension() - repeats.length;
+ NDArray current = this;
+ for (int i = 0; i < repeats.length; i++) {
+ if (repeats[i] <= 1) {
+ continue;
+ }
+ OpParams repeatParams = new OpParams();
+ repeatParams.addParam("repeats", repeats[i]);
+ repeatParams.addParam("axis", baseAxis + i);
+ NDArray next = invoke(getParent(), "_np_repeat", current, repeatParams);
+ // release intermediate results, but never this NDArray itself
+ if (current != this) {
+ current.close();
+ }
+ current = next;
+ }
+ return current;
+ }
+
+ /**
+ * Repeats element of this {@code NDArray} to match the desired shape.
+ *
+ * <p>If the desired {@link Shape} has fewer dimensions that the array, it will repeat against
+ * the last axis.
+ *
+ * @param desiredShape the {@link Shape} that should be converted to
+ * @return an {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(Shape desiredShape) {
+ return repeat(repeatsToMatchShape(desiredShape));
+ }
+
+ /**
+ * Dot product of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * <ul>
+ * <li>If both this {@code NDArray} and the other {@code NDArray} are 1-D {@code NDArray}s, it
+ * is inner product of vectors (without complex conjugation).
+ * <li>If both this {@code NDArray} and the other {@code NDArray} are 2-D {@code NDArray}s, it
+ * is matrix multiplication.
+ * <li>If either this {@code NDArray} or the other {@code NDArray} is 0-D {@code NDArray}
+ * (scalar), it is equivalent to mul.
+ * <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is 1-D
+ * {@code NDArray}, it is a sum product over the last axis of those.
+ * <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is M-D
+ * {@code NDArray}(where M>=2), it is a sum product over the last axis of this
+ * {@code NDArray} and the second-to-last axis of the other {@code NDArray}
+ * </ul>
+ *
+ * @param other the other {@code NDArray} to perform dot product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray dot(NDArray other) {
+ return invoke(getParent(), "_np_dot", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Product matrix of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * @param other the other {@code NDArray} to perform matrix product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray matMul(NDArray other) {
+ if (isScalar() || other.isScalar()) {
+ throw new IllegalArgumentException("scalar is not allowed for matMul()");
+ }
+ return invoke(getParent(), "_npi_matmul", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Limits the values of this {@code NDArray} to the interval {@code [min, max]}.
+ *
+ * <p>Values outside the interval are clipped to its edges. For example, with an interval of
+ * [0, 1], values smaller than 0 become 0 and values larger than 1 become 1.
+ *
+ * @param min the minimum value
+ * @param max the maximum value
+ * @return an {@code NDArray} with the elements of this {@code NDArray}, where values &lt; min
+ * are replaced with min and values &gt; max with max
+ */
+ public NDArray clip(Number min, Number max) {
+ OpParams bounds = new OpParams();
+ bounds.addParam("a_min", min);
+ bounds.addParam("a_max", max);
+ NDArray clipped = invoke(getParent(), "_npi_clip", this, bounds);
+ return clipped;
+ }
+
+ /**
+ * Exchanges two axes of this {@code NDArray}.
+ *
+ * @param axis1 the first axis
+ * @param axis2 the second axis
+ * @return the {@code NDArray} with {@code axis1} and {@code axis2} swapped
+ */
+ public NDArray swapAxes(int axis1, int axis2) {
+ OpParams swapParams = new OpParams();
+ swapParams.addParam("dim1", axis1);
+ swapParams.addParam("dim2", axis2);
+ NDArray swapped = invoke(getParent(), "_npi_swapaxes", this, swapParams);
+ return swapped;
+ }
+
+ /**
+ * Reverses the order of elements along the given axes.
+ *
+ * <p>The shape of the array is preserved; only the element order changes.
+ *
+ * @param axes the axes to flip on
+ * @return the flipped array
+ */
+ public NDArray flip(int... axes) {
+ OpParams flipParams = new OpParams();
+ flipParams.addTupleParam("axis", axes);
+ NDArray flipped = invoke(getParent(), "_npi_flip", this, flipParams);
+ return flipped;
+ }
+
+ /**
+ * Returns this {@code NDArray} with axes transposed.
+ *
+ * @return the newly permuted array
+ */
+ public NDArray transpose() {
+ return invoke(getParent(), "_np_transpose", this, null);
+ }
+
+ /**
+ * Returns this {@code NDArray} with given axes transposed.
+ *
+ * @param dimensions the axes to swap to
+ * @return the transposed {@code NDArray}
+ * @throws IllegalArgumentException thrown when passing a axis that is greater than the actual
+ * number of dimensions
+ */
+ public NDArray transpose(int... dimensions) {
+ if (Arrays.stream(dimensions).anyMatch(d -> d < 0)) {
+ throw new UnsupportedOperationException(
+ "Passing -1 for broadcasting the dimension is not currently supported");
+ }
+ if (!Arrays.equals(
+ Arrays.stream(dimensions).sorted().toArray(),
+ IntStream.range(0, getShape().dimension()).toArray())) {
+ throw new IllegalArgumentException(
+ "You must include each of the dimensions from 0 until "
+ + getShape().dimension());
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("axes", dimensions);
+ return invoke(getParent(), "_np_transpose", this, params);
+ }
+
+ /**
+ * Broadcasts this {@code NDArray} to the given {@link Shape}.
+ *
+ * @param shape the target {@link Shape}
+ * @return the broadcast {@code NDArray}
+ */
+ public NDArray broadcast(Shape shape) {
+ OpParams broadcastParams = new OpParams();
+ broadcastParams.setShape(shape);
+ NDArray broadcasted = invoke(getParent(), "_npi_broadcast_to", this, broadcastParams);
+ return broadcasted;
+ }
+
+ /**
+ * Finds the index of the largest value in the flattened {@code NDArray}.
+ *
+ * @return a {@code NDArray} containing the index
+ * @throws IllegalArgumentException if this {@code NDArray} is empty
+ */
+ public NDArray argMax() {
+ if (!isEmpty()) {
+ return invoke(getParent(), "_npi_argmax", this, null);
+ }
+ throw new IllegalArgumentException("attempt to get argMax of an empty MxNDArray");
+ }
+
+ /**
+ * Finds the indices of the largest values along the given axis.
+ *
+ * @param axis the axis along which to search for maximum values
+ * @return a {@code NDArray} containing indices
+ */
+ public NDArray argMax(int axis) {
+ OpParams searchParams = new OpParams();
+ searchParams.addParam("axis", axis);
+ NDArray indices = invoke(getParent(), "_npi_argmax", this, searchParams);
+ return indices;
+ }
+
+ /**
+ * Finds the index of the smallest value in the flattened {@code NDArray}.
+ *
+ * @return a {@code NDArray} containing the index
+ * @throws IllegalArgumentException if this {@code NDArray} is empty
+ */
+ public NDArray argMin() {
+ if (!isEmpty()) {
+ return invoke(getParent(), "_npi_argmin", this, null);
+ }
+ throw new IllegalArgumentException("attempt to get argMin of an empty MxNDArray");
+ }
+
+ /**
+ * Finds the indices of the smallest values along the given axis.
+ *
+ * @param axis the axis along which to search for minimum values
+ * @return a {@code NDArray} containing indices
+ */
+ public NDArray argMin(int axis) {
+ OpParams searchParams = new OpParams();
+ searchParams.addParam("axis", axis);
+ NDArray indices = invoke(getParent(), "_npi_argmin", this, searchParams);
+ return indices;
+ }
+
+ /**
+ * Returns percentile for this {@code NDArray}.
+ *
+ * @param percentile the target percentile in range of 0..100
+ * @return the result {@code NDArray}
+ */
+ public NDArray percentile(Number percentile) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /**
+ * Returns median along given dimension(s).
+ *
+ * @param percentile the target percentile in range of 0..100
+ * @param dimension the dimension to calculate percentile for
+ * @return the result {@code NDArray} NDArray
+ */
+ public NDArray percentile(Number percentile, int[] dimension) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /**
+ * Returns median value for this {@code NDArray}.
+ *
+ * @return the median {@code NDArray}
+ */
+ public NDArray median() {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /**
+ * Returns median value along given axes.
+ *
+ * @param axes the axes along which to perform the median operation
+ * @return the median {@code NDArray} along the specified axes
+ */
+ public NDArray median(int[] axes) {
+ throw new UnsupportedOperationException("Not implemented yet.");
+ }
+
+ /**
+ * Locates the elements of this {@code NDArray} that are non-zero.
+ *
+ * <p>Unlike numpy.nonzero, which returns a tuple with one NDArray per dimension, this returns a
+ * single {@code NDArray} whose last dimension holds the index in every dimension.
+ *
+ * @return the indices of the elements that are non-zero
+ */
+ public NDArray nonzero() {
+ // boolean input is converted to INT32 before calling the native op
+ NDArray input = this;
+ if (getDataType() == DataType.BOOLEAN) {
+ input = toType(DataType.INT32, false);
+ }
+ return invoke(getParent(), "_npx_nonzero", input, null);
+ }
+
+ /**
+ * Returns element-wise inverse gauss error function of the {@code NDArray}.
+ *
+ * @return The inverse of gauss error of the {@code NDArray}, element-wise
+ */
+ public NDArray erfinv() {
+ return invoke(getParent(), "erfinv", this, null);
+ }
+
+ /**
+ * Computes the norm of this {@code NDArray} with the operator's default order and axes.
+ *
+ * @param keepDims if {@code true}, the axes that are normed over are kept in the result as
+ * size-one dimensions, so the result broadcasts correctly against the original array
+ * @return the norm of this {@code NDArray}
+ */
+ public NDArray norm(boolean keepDims) {
+ OpParams normParams = new OpParams();
+ // NOTE(review): the meaning of flag=-2 for _npi_norm is not shown here; confirm against MXNet
+ normParams.add("flag", -2);
+ normParams.addParam("keepdims", keepDims);
+ NDArray result = invoke(getParent(), "_npi_norm", this, normParams);
+ return result;
+ }
+
+ /**
+ * Returns the norm of this {@code NDArray}.
+ *
+ * @param ord Order of the norm.
+ * @param axes If axes contains an integer, it specifies the axis of x along which to compute
+ * the vector norms. If axis contains 2 integers, it specifies the axes that hold 2-D
+ * matrices, and the matrix norms of these matrices are computed.
+ * @param keepDims keepDims If this is set to True, the axes which are normed over are left in
+ * the result as dimensions with size one. With this option the result will broadcast
+ * correctly against the original x.
+ * @return the norm of this {@code NDArray}
+ */
+ public NDArray norm(int ord, int[] axes, boolean keepDims) {
+ OpParams params = new OpParams();
+ params.addParam("ord", (double) ord);
+ params.addTupleParam("axis", axes);
+ params.addParam("keepdims", keepDims);
+ return invoke(getParent(), "_npi_norm", this, params);
+ }
+
+ /**
+ * Encodes this {@code NDArray} of indices as a one-hot {@code NDArray}.
+ *
+ * <ul>
+ * <li>Locations named by the indices take {@code onValue}; all other locations take {@code
+ * offValue}.
+ * <li>A rank-N input produces a rank-(N+1) output; the new axis of length {@code depth} is
+ * appended at the end.
+ * <li>A scalar input produces a vector of length depth.
+ * <li>A vector of length features produces a features x depth output.
+ * <li>A matrix of shape [batch, features] produces a batch x features x depth output.
+ * </ul>
+ *
+ * @param depth the length of the one-hot dimension
+ * @param onValue the value placed at the locations represented by the indices
+ * @param offValue the value placed at all other locations
+ * @param dataType the data type of the output
+ * @return the one-hot encoding of this {@code NDArray}
+ */
+ public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
+ OpParams encoding = new OpParams();
+ encoding.add("depth", depth);
+ encoding.add("on_value", onValue);
+ encoding.add("off_value", offValue);
+ encoding.add("dtype", dataType);
+ NDArray encoded = invoke(getParent(), "_npx_one_hot", this, encoding);
+ // the native result is additionally converted to the requested type
+ return encoded.toType(dataType, false);
+ }
+
+ /**
+ * Batchwise product of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * <ul>
+ * <li>batchDot is used to compute dot product of x and y when x and y are data in batch,
+ * namely N-D (N greater or equal to 3) arrays in shape of (B0, …, B_i, :, :). For
+ * example, given x with shape (B_0, …, B_i, N, M) and y with shape (B_0, …, B_i, M, K),
+ * the result array will have shape (B_0, …, B_i, N, K), which is computed by:
+ * batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :,
+ * :])
+ * </ul>
+ *
+ * @param other the other {@code NDArray} to perform batch dot product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray batchDot(NDArray other) {
+ return invoke(getParent(), "_npx_batch_dot", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns an internal representative of Native {@code NDArray}.
+ *
+ * <p>This method should only be used by Engine provider
+ *
+ * @return an internal representative of Native {@code NDArray}
+ */
+ public NDArrayEx getNDArrayInternal() {
+ return mxNDArrayEx;
+ }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Frees the sub-resources first, then the native handle, and marks this array closed.
+     * Calling this method after the array has been marked closed does nothing.
+     */
+    @Override
+    public void close() {
+        if (getClosed()) {
+            return;
+        }
+        // %s (not %S): %S would upper-case the UID and log a value that differs from the real id.
+        logger.debug(String.format("Start to free NDArray instance: %s", this.getUid()));
+        try {
+            super.freeSubResources();
+        } finally {
+            // Release the native handle and mark closed even if freeing sub-resources failed,
+            // otherwise the native memory would be leaked and close() retried forever.
+            if (this.getHandle() != null) {
+                JnaUtils.freeNdArray(this.getHandle());
+            }
+            setClosed(true);
+        }
+        logger.debug(String.format("Finish to free NDArray instance: %s", this.getUid()));
+    }
+
+ /**
+ * Returns {@code true} if this {@code NDArray} is special case: no-value {@code NDArray}.
+ *
+ * @return {@code true} if this NDArray is empty
+ */
+ public boolean isEmpty() {
+ return getShape().size() == 0;
+ }
+
+    boolean isSparse() {
+        // Any storage format other than DENSE counts as sparse.
+        SparseFormat format = getSparseFormat();
+        return format != SparseFormat.DENSE;
+    }
+
+    boolean shapeEquals(NDArray other) {
+        Shape ownShape = getShape();
+        return ownShape.equals(other.getShape());
+    }
+
+    /**
+     * Invokes a native MXNet operator on a list of inputs and collects every output.
+     *
+     * <p>Engine specific: avoid this method where possible, because native operators may not be
+     * portable or compatible between MXNet versions.
+     *
+     * @param parent the parent {@link MxResource} of the created {@link NDList}
+     * @param operation the native operation to perform
+     * @param src the {@link NDList} of source {@link NDArray}
+     * @param params the parameters to be passed to the native operation
+     * @return an {@link NDList} holding all outputs of the operator
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static NDList invoke(
+            MxResource parent, String operation, NDList src, PairList<String, ?> params) {
+        NDArray[] outputs = JnaUtils.op(operation).invoke(parent, src.toArray(EMPTY), params);
+        return new NDList(outputs);
+    }
+
+    /**
+     * Invokes a native MXNet operator, writing its outputs into the arrays of {@code dest}.
+     *
+     * <p>Engine specific: avoid this method where possible, because native operators may not be
+     * portable or compatible between MXNet versions.
+     *
+     * @param operation the native operation to perform
+     * @param src the {@link NDList} of source {@link NDArray}
+     * @param dest the {@link NDList} to save output to
+     * @param params the parameters to be passed to the native operator
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static void invoke(
+            String operation, NDList src, NDList dest, PairList<String, ?> params) {
+        NDArray[] inputs = src.toArray(EMPTY);
+        NDArray[] outputs = dest.toArray(EMPTY);
+        invoke(operation, inputs, outputs, params);
+    }
+
+    /**
+     * Invokes a native MXNet operator and returns only its first output.
+     *
+     * <p>Engine specific: avoid this method where possible, because native operators may not be
+     * portable or compatible between MXNet versions.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param operation the native operation to perform
+     * @param src the array of source {@link NDArray}
+     * @param params the parameters to be passed to the native operator
+     * @return the first output {@link NDArray} of the operator
+     */
+    public static NDArray invoke(
+            MxResource parent, String operation, NDArray[] src, PairList<String, ?> params) {
+        NDArray[] results = JnaUtils.op(operation).invoke(parent, src, params);
+        return results[0];
+    }
+
+    /**
+     * Invokes a native MXNet operator, writing its outputs into caller-provided arrays.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause a portability issue. Native operation may not be compatible between
+     * each version.
+     *
+     * @param operation the native operation to perform
+     * @param src the array of source {@link NDArray}s
+     * @param dest the array of {@link NDArray}s to save the output to
+     * @param params the parameters to be passed to the native operation
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static void invoke(
+            String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        JnaUtils.op(operation).invoke(src, dest, params);
+    }
+
+    /**
+     * Invokes a native MXNet operator on a single input and returns its first output.
+     *
+     * <p>Engine specific: avoid this method where possible, because native operators may not be
+     * portable or compatible between MXNet versions.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param operation the native operation to perform
+     * @param src the source {@link NDArray}
+     * @param params the parameters to be passed to the native operator
+     * @return the first output {@link NDArray} of the operator
+     */
+    public static NDArray invoke(
+            MxResource parent, String operation, NDArray src, PairList<String, ?> params) {
+        NDArray[] single = {src};
+        return invoke(parent, operation, single, params);
+    }
+
+ /**
+ * An engine specific generic invocation to native operator.
+ *
+ * <p>You should avoid using this function if possible. Since this function is engine specific,
+ * using this API may cause portability issues. A native operation may not be compatible between
+ * each version.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param operation the native operation to perform
+ * @param params the parameters to be passed to the native operator
+ * @return the output array of {@link NDArray}
+ */
+ public static NDArray invoke(MxResource parent, String operation, PairList<String, ?> params) {
+ return invoke(parent, operation, EMPTY, params);
+ }
+
+ /**
+ * Encodes {@code MxNDArray} to byte array.
+ *
+ * @return byte array
+ */
+ public byte[] encode() {
+ return NDSerializer.encode(this);
+ }
+
+    /**
+     * Draws samples from a uniform distribution over the half-open interval {@code [low, high)}.
+     *
+     * <p>Every value in the interval is equally likely; {@code low} can be drawn, {@code high}
+     * cannot.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param low the inclusive lower bound of the samples
+     * @param high the exclusive upper bound of the samples
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomUniform(
+            MxResource parent,
+            float low,
+            float high,
+            Shape shape,
+            DataType dataType,
+            Device device) {
+        OpParams uniformParams = new OpParams();
+        uniformParams.addParam("low", low);
+        uniformParams.addParam("high", high);
+        uniformParams.addParam("size", shape);
+        uniformParams.setDevice(device);
+        uniformParams.setDataType(dataType);
+        return invoke(parent, "_npi_uniform", uniformParams);
+    }
+
+    /**
+     * Draws samples from a uniform distribution over {@code [low, high)}, using the device
+     * returned by {@code Device.defaultIfNull(null)}.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param low the inclusive lower bound of the samples
+     * @param high the exclusive upper bound of the samples
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    private static NDArray randomUniform(
+            MxResource parent, float low, float high, Shape shape, DataType dataType) {
+        Device fallback = Device.defaultIfNull(null);
+        return randomUniform(parent, low, high, shape, dataType, fallback);
+    }
+
+    /**
+     * Draws random samples from a normal (Gaussian) distribution on the requested device.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param loc the mean (centre) of the distribution
+     * @param scale the standard deviation (spread or "width") of the distribution
+     * @param shape the output {@link Shape}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}; {@code null} is resolved through
+     *     {@code Device.defaultIfNull}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomNormal(
+            MxResource parent,
+            float loc,
+            float scale,
+            Shape shape,
+            DataType dataType,
+            Device device) {
+        OpParams params = new OpParams();
+        params.addParam("loc", loc);
+        params.addParam("scale", scale);
+        params.addParam("size", shape);
+        // Honour the caller's device; previously it was silently dropped and the default was used.
+        params.setDevice(Device.defaultIfNull(device));
+        params.setDataType(dataType);
+        return invoke(parent, "_npi_normal", params);
+    }
+
+    /**
+     * Draws random samples from a normal (Gaussian) distribution on the device returned by
+     * {@code Device.defaultIfNull(null)}.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param loc the mean (centre) of the distribution
+     * @param scale the standard deviation (spread or "width") of the distribution
+     * @param shape the output {@link Shape}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomNormal(
+            MxResource parent, float loc, float scale, Shape shape, DataType dataType) {
+        return randomNormal(parent, loc, scale, shape, dataType, null);
+    }
+
+    /**
+     * Rebuilds an {@link NDArray} from bytes produced by {@link #encode()}.
+     *
+     * @param parent the parent {@link MxResource} to create the {@link NDArray}
+     * @param bytes byte array to load from
+     * @return the decoded {@link NDArray}
+     * @throws IllegalArgumentException if the bytes cannot be decoded
+     */
+    static NDArray decode(MxResource parent, byte[] bytes) {
+        ByteArrayInputStream source = new ByteArrayInputStream(bytes);
+        try (DataInputStream in = new DataInputStream(source)) {
+            return NDSerializer.decode(parent, in);
+        } catch (IOException cause) {
+            throw new IllegalArgumentException("NDArray decoding failed", cause);
+        }
+    }
+
+ /**
+ * Decodes {@link NDArray} through {@link DataInputStream}.
+ *
+ * @param parent the parent {@link MxResource} to create the {@link NDArray}
+ * @param is input stream data to load from
+ * @return {@link NDArray}
+ * @throws IOException data is not readable
+ */
+ public static NDArray decode(MxResource parent, InputStream is) throws IOException {
+ return NDSerializer.decode(parent, is);
+ }
+
+    /**
+     * Copies the contents of this {@code NDArray} into a boxed {@link Number} array.
+     *
+     * <p>Element types follow the {@link DataType}:
+     * <ul>
+     *   <li>FLOAT16 and FLOAT32 become {@code Float}.
+     *   <li>FLOAT64 becomes {@code Double}.
+     *   <li>INT32 and UINT8 become {@code Integer}.
+     *   <li>INT64 becomes {@code Long}.
+     *   <li>BOOLEAN and INT8 become {@code Byte}.
+     * </ul>
+     *
+     * @return a Number array
+     * @throws IllegalStateException if the {@link DataType} is not supported
+     */
+    public Number[] toArray() {
+        DataType type = getDataType();
+        switch (type) {
+            case FLOAT16:
+            case FLOAT32:
+                {
+                    float[] floats = toFloatArray();
+                    Number[] boxed = new Number[floats.length];
+                    for (int i = 0; i < floats.length; i++) {
+                        boxed[i] = floats[i];
+                    }
+                    return boxed;
+                }
+            case FLOAT64:
+                {
+                    double[] doubles = toDoubleArray();
+                    Double[] boxed = new Double[doubles.length];
+                    for (int i = 0; i < doubles.length; i++) {
+                        boxed[i] = doubles[i];
+                    }
+                    return boxed;
+                }
+            case INT32:
+                {
+                    int[] ints = toIntArray();
+                    Integer[] boxed = new Integer[ints.length];
+                    for (int i = 0; i < ints.length; i++) {
+                        boxed[i] = ints[i];
+                    }
+                    return boxed;
+                }
+            case INT64:
+                {
+                    long[] longs = toLongArray();
+                    Long[] boxed = new Long[longs.length];
+                    for (int i = 0; i < longs.length; i++) {
+                        boxed[i] = longs[i];
+                    }
+                    return boxed;
+                }
+            case BOOLEAN:
+            case INT8:
+                {
+                    ByteBuffer raw = toByteBuffer();
+                    Byte[] boxed = new Byte[raw.remaining()];
+                    for (int i = 0; i < boxed.length; i++) {
+                        boxed[i] = raw.get();
+                    }
+                    return boxed;
+                }
+            case UINT8:
+                {
+                    int[] unsigned = toUint8Array();
+                    Integer[] boxed = new Integer[unsigned.length];
+                    for (int i = 0; i < unsigned.length; i++) {
+                        boxed[i] = unsigned[i];
+                    }
+                    return boxed;
+                }
+            default:
+                throw new IllegalStateException("Unsupported DataType: " + type);
+        }
+    }
+
+    /**
+     * Copies this BOOLEAN {@code NDArray} into a {@code boolean[]}; any non-zero byte is {@code true}.
+     *
+     * @return a boolean array
+     * @throws IllegalStateException if the {@link DataType} is not BOOLEAN
+     */
+    public boolean[] toBooleanArray() {
+        DataType actual = getDataType();
+        if (actual != DataType.BOOLEAN) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required boolean" + " Actual " + actual);
+        }
+        ByteBuffer raw = toByteBuffer();
+        boolean[] flags = new boolean[raw.remaining()];
+        int index = 0;
+        while (raw.hasRemaining()) {
+            flags[index++] = raw.get() != 0;
+        }
+        return flags;
+    }
+
+    /**
+     * Copies this FLOAT64 {@code NDArray} into a {@code double[]}.
+     *
+     * @return a double array
+     * @throws IllegalStateException if the {@link DataType} is not FLOAT64
+     */
+    public double[] toDoubleArray() {
+        DataType actual = getDataType();
+        if (actual != DataType.FLOAT64) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required double" + " Actual " + actual);
+        }
+        DoubleBuffer view = toByteBuffer().asDoubleBuffer();
+        double[] values = new double[view.remaining()];
+        view.get(values);
+        return values;
+    }
+
+    /**
+     * Copies this FLOAT32 or FLOAT16 {@code NDArray} into a {@code float[]}.
+     *
+     * @return a float array
+     * @throws IllegalStateException if the {@link DataType} is neither FLOAT32 nor FLOAT16
+     */
+    public float[] toFloatArray() {
+        DataType actual = getDataType();
+        if (actual == DataType.FLOAT16) {
+            // Half-precision data is converted by Float16Utils.
+            return Float16Utils.fromByteBuffer(toByteBuffer());
+        }
+        if (actual != DataType.FLOAT32) {
+            throw new IllegalStateException("DataType mismatch, Required float, Actual " + actual);
+        }
+        FloatBuffer view = toByteBuffer().asFloatBuffer();
+        float[] values = new float[view.remaining()];
+        view.get(values);
+        return values;
+    }
+
+    /**
+     * Copies this INT32 {@code NDArray} into an {@code int[]}.
+     *
+     * @return an int array
+     * @throws IllegalStateException if the {@link DataType} is not INT32
+     */
+    public int[] toIntArray() {
+        DataType actual = getDataType();
+        if (actual != DataType.INT32) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required int" + " Actual " + actual);
+        }
+        IntBuffer view = toByteBuffer().asIntBuffer();
+        int[] values = new int[view.remaining()];
+        view.get(values);
+        return values;
+    }
+
+    /**
+     * Copies this INT64 {@code NDArray} into a {@code long[]}.
+     *
+     * @return a long array
+     * @throws IllegalStateException if the {@link DataType} is not INT64
+     */
+    public long[] toLongArray() {
+        DataType actual = getDataType();
+        if (actual != DataType.INT64) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required long" + " Actual " + actual);
+        }
+        LongBuffer view = toByteBuffer().asLongBuffer();
+        long[] values = new long[view.remaining()];
+        view.get(values);
+        return values;
+    }
+
+    /**
+     * Copies the raw bytes of this {@code NDArray} into a {@code byte[]}, regardless of its
+     * {@link DataType}.
+     *
+     * @return a byte array holding exactly the remaining bytes of {@link #toByteBuffer()}
+     */
+    public byte[] toByteArray() {
+        ByteBuffer bb = toByteBuffer();
+        int length = bb.remaining();
+        // The backing array can only be returned directly when it covers exactly the readable
+        // region; otherwise array() would expose bytes outside [position, limit).
+        if (bb.hasArray()
+                && bb.arrayOffset() == 0
+                && bb.position() == 0
+                && bb.array().length == length) {
+            return bb.array();
+        }
+        byte[] buf = new byte[length];
+        bb.get(buf);
+        return buf;
+    }
+
+    /**
+     * Reads the raw bytes of this {@code NDArray} as unsigned 8-bit values in {@code [0, 255]}.
+     *
+     * <p>No {@link DataType} check is performed; every byte of {@link #toByteBuffer()} is
+     * converted.
+     *
+     * @return an int array of unsigned byte values
+     */
+    public int[] toUint8Array() {
+        ByteBuffer raw = toByteBuffer();
+        int[] unsigned = new int[raw.remaining()];
+        int index = 0;
+        while (raw.hasRemaining()) {
+            unsigned[index++] = Byte.toUnsignedInt(raw.get());
+        }
+        return unsigned;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Closed arrays report a fixed message instead of their contents.
+     */
+    @Override
+    public String toString() {
+        return getClosed()
+                ? "This array is already closed"
+                : toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
+    }
+
+    /**
+     * Formats this {@code NDArray} for debugging via {@link NDFormat}, truncating large contents.
+     *
+     * @param maxSize the maximum elements to print out
+     * @param maxDepth the maximum depth to print out
+     * @param maxRows the maximum rows to print out
+     * @param maxColumns the maximum columns to print out
+     * @return the debug string representation of this {@code NDArray}
+     */
+    String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) {
+        String formatted = NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns);
+        return formatted;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
new file mode 100644
index 0000000..047ff04
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
@@ -0,0 +1,1107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+
+/** An internal interface that encapsulates engine specific operations. */
+@SuppressWarnings("MissingJavadocMethod")
+public class NDArrayEx {
+
+ private static final NDArrayIndexer INDEXER = new NDArrayIndexer();
+
+ private NDArray array;
+
+    /**
+     * Creates the engine-specific operation helper bound to {@code parent}.
+     *
+     * @param parent the {@link NDArray} whose operations this helper implements
+     */
+    NDArrayEx(NDArray parent) {
+        array = parent;
+    }
+
+ // TODO only used to calculate zero-dim numpy shape
+ // remove it once MXNet have all the np op that we support
+    /**
+     * Computes the shape produced by broadcasting {@code lhs} against {@code rhs}.
+     *
+     * <p>Shapes are right-aligned; a missing leading axis counts as size 1. Two sizes are
+     * compatible when they are equal or when one of them is 1.
+     *
+     * @param lhs the left-hand shape
+     * @param rhs the right-hand shape
+     * @return the broadcast shape
+     * @throws IllegalArgumentException if the shapes are not broadcast-compatible
+     */
+    private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
+        int rank = Math.max(lhs.dimension(), rhs.dimension());
+        int lhsOffset = rank - lhs.dimension();
+        int rhsOffset = rank - rhs.dimension();
+        long[] dims = new long[rank];
+        for (int axis = 0; axis < rank; axis++) {
+            long left = axis < lhsOffset ? 1 : lhs.get(axis - lhsOffset);
+            long right = axis < rhsOffset ? 1 : rhs.get(axis - rhsOffset);
+            if (left == right || right == 1) {
+                dims[axis] = left;
+            } else if (left == 1) {
+                dims[axis] = right;
+            } else {
+                throw new IllegalArgumentException(
+                        "operands could not be broadcast together with shapes "
+                                + lhs
+                                + " "
+                                + rhs);
+            }
+        }
+        return new Shape(dims);
+    }
+
+ ////////////////////////////////////////
+ // MxNDArrays
+ ////////////////////////////////////////
+    /**
+     * Computes {@code n / x} element-wise, where {@code x} is the wrapped array.
+     *
+     * @param n the scalar numerator
+     * @return a new array holding {@code n / x}
+     */
+    public NDArray rdiv(Number n) {
+        String scalar = n.toString();
+        OpParams divParams = new OpParams();
+        divParams.add("scalar", scalar);
+        return NDArray.invoke(getArray().getParent(), "_rdiv_scalar", array, divParams);
+    }
+
+ /**
+ * Applies reverse division with a scalar - i.e., (n / thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse division
+ * @return a copy of the array after applying reverse division
+ */
+ public NDArray rdiv(NDArray b) {
+ return b.div(array);
+ }
+
+ /**
+ * Applies in place reverse division - i.e., (n / thisArrayValues).
+ *
+ * @param n the value to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rdivi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ /**
+ * Applies in place reverse division - i.e., (n / thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rdivi(NDArray b) {
+ NDArray.invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null);
+ return array;
+ }
+
+    /**
+     * Computes {@code n - x} element-wise, where {@code x} is the wrapped array.
+     *
+     * @param n the scalar minuend
+     * @return a new array holding {@code n - x}; the wrapped array is not modified
+     */
+    public NDArray rsub(Number n) {
+        // n - x == -(x - n). Negate the fresh difference in place rather than calling neg(),
+        // which would allocate a second array and orphan the temporary one.
+        return array.sub(n).negi();
+    }
+
+    /**
+     * Computes {@code b - x} element-wise, where {@code x} is the wrapped array.
+     *
+     * @param b the minuend array
+     * @return a new array holding {@code b - x}; the wrapped array is not modified
+     */
+    public NDArray rsub(NDArray b) {
+        // b - x == -(x - b). Negate the fresh difference in place rather than calling neg(),
+        // which would allocate a second array and orphan the temporary one.
+        return array.sub(b).negi();
+    }
+
+ /**
+ * Applies reverse subtraction in place - i.e., (n - thisArrayValues).
+ *
+ * @param n the value to use for reverse subtraction
+ * @return this array after reverse subtraction
+ */
+ public NDArray rsubi(Number n) {
+ return array.subi(n).negi();
+ }
+
+ /**
+ * Applies reverse subtraction in place - i.e., (n - thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse subtraction
+ * @return this array after reverse subtraction
+ */
+ public NDArray rsubi(NDArray b) {
+ return array.subi(b).negi();
+ }
+
+    /**
+     * Computes {@code n mod x} element-wise via {@code _npi_rmod_scalar}, where {@code x} is the
+     * wrapped array.
+     *
+     * @param n the scalar dividend
+     * @return a new array holding the remainders
+     */
+    public NDArray rmod(Number n) {
+        String scalar = n.toString();
+        OpParams modParams = new OpParams();
+        modParams.add("scalar", scalar);
+        return NDArray.invoke(getArray().getParent(), "_npi_rmod_scalar", array, modParams);
+    }
+
+ /**
+ * Applies reverse remainder of division with a scalar.
+ *
+ * @param b the value to use for reverse division
+ * @return a copy of array after applying reverse division
+ */
+ public NDArray rmod(NDArray b) {
+ return b.mod(array);
+ }
+
+ /**
+ * Applies in place reverse remainder of division with a scalar.
+ *
+ * @param n the value to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rmodi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ /**
+ * Applies in place reverse remainder of division.
+ *
+ * @param b the ndarray to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rmodi(NDArray b) {
+ NDArray.invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null);
+ return array;
+ }
+
+    /**
+     * Computes {@code n ^ x} element-wise, where {@code x} is the wrapped array.
+     *
+     * @param n the scalar base
+     * @return a new array holding {@code n ^ x}
+     */
+    public NDArray rpow(Number n) {
+        String scalar = n.toString();
+        OpParams powParams = new OpParams();
+        powParams.add("scalar", scalar);
+        return NDArray.invoke(getArray().getParent(), "_npi_rpower_scalar", array, powParams);
+    }
+
+ /**
+ * Reverses the power of each element being raised in the {@code NDArray} in place.
+ *
+ * @param n the value to use for reverse power
+ * @return a copy of array after applying reverse power
+ */
+ public NDArray rpowi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_npi_rpower_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ ////////////////////////////////////////
+ // Activations
+ ////////////////////////////////////////
+    /**
+     * Applies the rectified linear unit through {@code _npx_activation} ({@code act_type=relu}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray relu() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "relu");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", input, activation);
+    }
+
+    /**
+     * Applies the sigmoid activation through {@code _npx_activation} ({@code act_type=sigmoid}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray sigmoid() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "sigmoid");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", input, activation);
+    }
+
+    /**
+     * Applies the hyperbolic tangent through {@code _npx_activation} ({@code act_type=tanh}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray tanh() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "tanh");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", input, activation);
+    }
+
+    /**
+     * Applies softplus through {@code _npx_activation}, which MXNet names {@code softrelu}.
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray softPlus() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "softrelu");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", input, activation);
+    }
+
+    /**
+     * Applies softsign through {@code _npx_activation} ({@code act_type=softsign}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray softSign() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "softsign");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", input, activation);
+    }
+
+    /**
+     * Applies leaky ReLU through {@code _npx_leaky_relu} ({@code act_type=leaky}).
+     *
+     * @param alpha the value passed as the operator's {@code slope}
+     * @return a new array holding the activation
+     */
+    public NDArray leakyRelu(float alpha) {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "leaky");
+        activation.addParam("slope", alpha);
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", input, activation);
+    }
+
+    /**
+     * Applies ELU through {@code _npx_leaky_relu} ({@code act_type=elu}).
+     *
+     * @param alpha the value passed as the operator's {@code slope}
+     * @return a new array holding the activation
+     */
+    public NDArray elu(float alpha) {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "elu");
+        activation.addParam("slope", alpha);
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", input, activation);
+    }
+
+    /**
+     * Applies SELU through {@code _npx_leaky_relu} ({@code act_type=selu}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray selu() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "selu");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", input, activation);
+    }
+
+    /**
+     * Applies GELU through {@code _npx_leaky_relu} ({@code act_type=gelu}).
+     *
+     * @return a new array holding the activation
+     */
+    public NDArray gelu() {
+        OpParams activation = new OpParams();
+        activation.addParam("act_type", "gelu");
+        NDArray input = array;
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", input, activation);
+    }
+
+ ////////////////////////////////////////
+ // Pooling Operations
+ ////////////////////////////////////////
+
+    /**
+     * Max pooling through {@code _npx_pooling}.
+     *
+     * @param kernelShape the pooling window
+     * @param stride the window stride
+     * @param padding the padding applied to the input
+     * @param ceilMode {@code true} selects the {@code full} pooling convention, {@code false}
+     *     selects {@code valid}
+     * @return a new array holding the pooled values
+     */
+    public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+        String convention = ceilMode ? "full" : "valid";
+        OpParams pooling = new OpParams();
+        pooling.addParam("kernel", kernelShape);
+        pooling.add("pool_type", "max");
+        pooling.addParam("stride", stride);
+        pooling.addParam("pad", padding);
+        pooling.add("pooling_convention", convention);
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), pooling);
+    }
+
+    /**
+     * Global max pooling through {@code _npx_pooling} with {@code global_pool=true}.
+     *
+     * <p>The pooled result is reshaped to its first two axes (presumably batch and channel —
+     * confirm), and the intermediate pooled array is closed.
+     *
+     * @return a new two-dimensional array holding the pooled values
+     */
+    public NDArray globalMaxPool() {
+        OpParams pooling = new OpParams();
+        pooling.add("kernel", getGlobalPoolingShapes(1));
+        pooling.add("pad", getGlobalPoolingShapes(0));
+        pooling.add("pool_type", "max");
+        pooling.addParam("global_pool", true);
+        try (NDArray pooled =
+                NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), pooling)) {
+            Shape pooledShape = pooled.getShape();
+            return pooled.reshape(pooledShape.size(0), pooledShape.size(1));
+        }
+    }
+
+    /**
+     * Average pooling through {@code _npx_pooling}.
+     *
+     * @param kernelShape the pooling window
+     * @param stride the window stride
+     * @param padding the padding applied to the input
+     * @param ceilMode {@code true} selects the {@code full} pooling convention, {@code false}
+     *     selects {@code valid}
+     * @param countIncludePad forwarded as the operator's {@code count_include_pad}
+     * @return a new array holding the pooled values
+     */
+    public NDArray avgPool(
+            Shape kernelShape,
+            Shape stride,
+            Shape padding,
+            boolean ceilMode,
+            boolean countIncludePad) {
+        String convention = ceilMode ? "full" : "valid";
+        OpParams pooling = new OpParams();
+        pooling.addParam("kernel", kernelShape);
+        pooling.add("pool_type", "avg");
+        pooling.addParam("stride", stride);
+        pooling.addParam("pad", padding);
+        pooling.add("pooling_convention", convention);
+        pooling.addParam("count_include_pad", countIncludePad);
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), pooling);
+    }
+
+    /**
+     * Global average pooling through {@code _npx_pooling} with {@code global_pool=true}.
+     *
+     * <p>The pooled result is reshaped to its first two axes (presumably batch and channel —
+     * confirm), and the intermediate pooled array is closed.
+     *
+     * @return a new two-dimensional array holding the pooled values
+     */
+    public NDArray globalAvgPool() {
+        OpParams pooling = new OpParams();
+        pooling.add("kernel", getGlobalPoolingShapes(1));
+        pooling.add("pad", getGlobalPoolingShapes(0));
+        pooling.add("pool_type", "avg");
+        pooling.addParam("global_pool", true);
+        try (NDArray pooled =
+                NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), pooling)) {
+            Shape pooledShape = pooled.getShape();
+            return pooled.reshape(pooledShape.size(0), pooledShape.size(1));
+        }
+    }
+
+    /**
+     * Lp pooling through {@code _npx_pooling}.
+     *
+     * @param normType the norm order; must be a whole number
+     * @param kernelShape the pooling window
+     * @param stride the window stride
+     * @param padding the padding applied to the input
+     * @param ceilMode {@code true} selects the {@code full} pooling convention, {@code false}
+     *     selects {@code valid}
+     * @return a new array holding the pooled values
+     * @throws IllegalArgumentException if {@code normType} has a fractional part
+     */
+    public NDArray lpPool(
+            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+        int pValue = (int) normType;
+        if (pValue != normType) {
+            throw new IllegalArgumentException(
+                    "float type of normType is not supported in MXNet engine, please use integer instead");
+        }
+        String convention = ceilMode ? "full" : "valid";
+        OpParams pooling = new OpParams();
+        pooling.addParam("p_value", pValue);
+        pooling.addParam("kernel", kernelShape);
+        pooling.add("pool_type", "lp");
+        pooling.addParam("stride", stride);
+        pooling.addParam("pad", padding);
+        pooling.add("pooling_convention", convention);
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), pooling);
+    }
+
+ public NDArray globalLpPool(float normType) {
+ if (((int) normType) != normType) {
+ throw new IllegalArgumentException(
+ "float type of normType is not supported in MXNet engine, please use integer instead");
+ }
+ OpParams params = new OpParams();
+ params.add("pool_type", "lp");
+ params.addParam("p_value", (int) normType);
+ params.addParam("global_pool", true);
+ try (NDArray temp =
+ NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) {
+ return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+ }
+ }
+
+ ////////////////////////////////////////
+ // Optimizer
+ ////////////////////////////////////////
+
+ // public void adadeltaUpdate(
+ // MxNDList inputs,
+ // MxNDList weights,
+ // float weightDecay,
+ // float rescaleGrad,
+ // float clipGrad,
+ // float rho,
+ // float epsilon) {
+ // MxNDArray weight = inputs.get(0);
+ // MxNDArray grad = inputs.get(1);
+ // MxNDArray s = inputs.get(2);
+ // MxNDArray delta = inputs.get(3);
+ //
+ // // create a baseManager to close all intermediate MxNDArrays
+ // try (NDManager subManager = NDManager.newBaseManager()) {
+ // subManager.tempAttachAll(inputs, weights);
+ //
+ // // Preprocess Gradient
+ // grad.muli(rescaleGrad);
+ // if (clipGrad > 0) {
+ // grad = grad.clip(-clipGrad, clipGrad);
+ // }
+ // grad.addi(weight.mul(weightDecay));
+ //
+ // // Update s, g, and delta
+ // s.muli(rho).addi(grad.square().mul(1 - rho));
+ // MxNDArray g = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad);
+ // delta.muli(rho).addi(g.square().mul(1 - rho));
+ //
+ // // Update weight
+ // weight.subi(g);
+ // }
+ // }
+
+    /**
+     * Runs MXNet's {@code adagrad_update} operator, writing its outputs into {@code weights}.
+     *
+     * @param inputs the operator inputs, in the order the native operator expects (presumably
+     *     weight, gradient, history — confirm against the operator definition)
+     * @param weights the arrays the operator writes its outputs to
+     * @param learningRate forwarded as {@code lr}
+     * @param weightDecay forwarded as {@code wd}
+     * @param rescaleGrad forwarded as {@code rescale_grad}
+     * @param clipGrad forwarded as {@code clip_gradient}
+     * @param epsilon forwarded as {@code epsilon}
+     */
+    public void adagradUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float epsilon) {
+        OpParams update = new OpParams();
+        update.addParam("lr", learningRate);
+        update.addParam("wd", weightDecay);
+        update.addParam("rescale_grad", rescaleGrad);
+        update.addParam("clip_gradient", clipGrad);
+        update.addParam("epsilon", epsilon);
+        NDArray.invoke("adagrad_update", inputs, weights, update);
+    }
+
+    /**
+     * Runs MXNet's {@code adam_update} operator, writing its outputs into {@code weights}.
+     *
+     * @param inputs the operator inputs, in the order the native operator expects (presumably
+     *     weight, gradient, mean, variance — confirm against the operator definition)
+     * @param weights the arrays the operator writes its outputs to
+     * @param learningRate forwarded as {@code lr}
+     * @param weightDecay forwarded as {@code wd}
+     * @param rescaleGrad forwarded as {@code rescale_grad}
+     * @param clipGrad forwarded as {@code clip_gradient}
+     * @param beta1 forwarded as {@code beta1}
+     * @param beta2 forwarded as {@code beta2}
+     * @param epsilon forwarded as {@code epsilon}
+     * @param lazyUpdate forwarded as {@code lazy_update}
+     */
+    public void adamUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float beta1,
+            float beta2,
+            float epsilon,
+            boolean lazyUpdate) {
+        OpParams update = new OpParams();
+        update.addParam("lr", learningRate);
+        update.addParam("wd", weightDecay);
+        update.addParam("rescale_grad", rescaleGrad);
+        update.addParam("clip_gradient", clipGrad);
+        update.addParam("beta1", beta1);
+        update.addParam("beta2", beta2);
+        update.addParam("epsilon", epsilon);
+        update.addParam("lazy_update", lazyUpdate);
+        NDArray.invoke("adam_update", inputs, weights, update);
+    }
+
+    /**
+     * Runs MXNet's RMSProp update, writing its outputs into {@code weights}.
+     *
+     * <p>{@code rmsprop_update} is used when {@code centered} is false. Otherwise
+     * {@code rmspropalex_update} is used, with {@code gamma2} added to the parameters.
+     *
+     * @param inputs the operator inputs, in the order the chosen native operator expects
+     * @param weights the arrays the operator writes its outputs to
+     * @param learningRate forwarded as {@code lr}
+     * @param weightDecay forwarded as {@code wd}
+     * @param rescaleGrad forwarded as {@code rescale_grad}
+     * @param clipGrad forwarded as {@code clip_gradient}
+     * @param gamma1 forwarded as {@code gamma1}
+     * @param gamma2 forwarded as {@code gamma2} (centered variant only)
+     * @param epsilon forwarded as {@code epsilon}
+     * @param centered selects the centered ({@code rmspropalex_update}) variant
+     */
+    public void rmspropUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float gamma1,
+            float gamma2,
+            float epsilon,
+            boolean centered) {
+        OpParams update = new OpParams();
+        update.addParam("lr", learningRate);
+        update.addParam("wd", weightDecay);
+        update.addParam("rescale_grad", rescaleGrad);
+        update.addParam("clip_gradient", clipGrad);
+        update.addParam("gamma1", gamma1);
+        update.addParam("epsilon", epsilon);
+
+        if (centered) {
+            update.addParam("gamma2", gamma2);
+            NDArray.invoke("rmspropalex_update", inputs, weights, update);
+        } else {
+            NDArray.invoke("rmsprop_update", inputs, weights, update);
+        }
+    }
+
+    /**
+     * Runs MXNet's {@code nag_mom_update} (Nesterov momentum) operator, writing its outputs into
+     * {@code weights}.
+     *
+     * @param inputs the operator inputs, in the order the native operator expects
+     * @param weights the arrays the operator writes its outputs to
+     * @param learningRate forwarded as {@code lr}
+     * @param weightDecay forwarded as {@code wd}
+     * @param rescaleGrad forwarded as {@code rescale_grad}
+     * @param clipGrad forwarded as {@code clip_gradient}
+     * @param momentum forwarded as {@code momentum}
+     */
+    public void nagUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float momentum) {
+        OpParams update = new OpParams();
+        update.addParam("lr", learningRate);
+        update.addParam("wd", weightDecay);
+        update.addParam("rescale_grad", rescaleGrad);
+        update.addParam("clip_gradient", clipGrad);
+        update.addParam("momentum", momentum);
+        NDArray.invoke("nag_mom_update", inputs, weights, update);
+    }
+
+ /**
+ * Runs MXNet's SGD update. A non-zero {@code momentum} selects {@code sgd_mom_update};
+ * otherwise {@code sgd_update} is used and no momentum parameter is sent.
+ *
+ * @param inputs operator inputs
+ * @param weights output arrays handed to the operator
+ * @param learningRate sent as {@code lr}
+ * @param weightDecay sent as {@code wd}
+ * @param rescaleGrad sent as {@code rescale_grad}
+ * @param clipGrad sent as {@code clip_gradient}
+ * @param momentum sent as {@code momentum} when non-zero
+ * @param lazyUpdate sent as {@code lazy_update}
+ */
+ public void sgdUpdate(
+ NDList inputs,
+ NDList weights,
+ float learningRate,
+ float weightDecay,
+ float rescaleGrad,
+ float clipGrad,
+ float momentum,
+ boolean lazyUpdate) {
+ OpParams sgdParams = new OpParams();
+ sgdParams.addParam("lr", learningRate);
+ sgdParams.addParam("wd", weightDecay);
+ sgdParams.addParam("rescale_grad", rescaleGrad);
+ sgdParams.addParam("clip_gradient", clipGrad);
+ sgdParams.addParam("lazy_update", lazyUpdate);
+
+ boolean useMomentum = momentum != 0;
+ if (useMomentum) {
+ sgdParams.addParam("momentum", momentum);
+ }
+ NDArray.invoke(useMomentum ? "sgd_mom_update" : "sgd_update", inputs, weights, sgdParams);
+ }
+
+ ////////////////////////////////////////
+ // Neural network
+ ////////////////////////////////////////
+
+ /**
+ * Applies MXNet's {@code _npx_convolution} operator.
+ *
+ * <p>The kernel size is taken from {@code weight}'s shape after its first two dimensions, and
+ * the filter count from its first dimension.
+ *
+ * @param input the input array
+ * @param weight the convolution weight
+ * @param bias the optional bias; {@code null} sets {@code no_bias}
+ * @param stride sent as {@code stride}
+ * @param padding sent as {@code pad}
+ * @param dilation sent as {@code dilate}
+ * @param groups sent as {@code num_group}
+ * @return the operator outputs
+ */
+ public NDList convolution(
+ NDArray input,
+ NDArray weight,
+ NDArray bias,
+ Shape stride,
+ Shape padding,
+ Shape dilation,
+ int groups) {
+ Shape weightShape = weight.getShape();
+ OpParams convParams = new OpParams();
+ convParams.addParam("kernel", weightShape.slice(2));
+ convParams.addParam("stride", stride);
+ convParams.addParam("pad", padding);
+ convParams.addParam("dilate", dilation);
+ convParams.addParam("num_group", groups);
+ convParams.addParam("num_filter", weightShape.get(0));
+
+ NDList operands = new NDList(input, weight);
+ boolean hasBias = bias != null;
+ convParams.add("no_bias", !hasBias);
+ if (hasBias) {
+ operands.add(bias);
+ }
+
+ return NDArray.invoke(getArray().getParent(), "_npx_convolution", operands, convParams);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx_deconvolution} (transposed convolution) operator.
+ *
+ * @param input the input array
+ * @param weight the deconvolution weight
+ * @param bias the optional bias; {@code null} sets {@code no_bias}
+ * @param stride sent as {@code stride}
+ * @param padding sent as {@code pad}
+ * @param outPadding sent as {@code adj}
+ * @param dilation sent as {@code dilate}
+ * @param groups sent as {@code num_group}
+ * @return the operator outputs
+ */
+ public NDList deconvolution(
+ NDArray input,
+ NDArray weight,
+ NDArray bias,
+ Shape stride,
+ Shape padding,
+ Shape outPadding,
+ Shape dilation,
+ int groups) {
+ Shape weightShape = weight.getShape();
+ OpParams deconvParams = new OpParams();
+ deconvParams.addParam("kernel", weightShape.slice(2));
+ deconvParams.addParam("stride", stride);
+ deconvParams.addParam("pad", padding);
+ deconvParams.addParam("adj", outPadding);
+ deconvParams.addParam("dilate", dilation);
+ deconvParams.addParam("num_group", groups);
+ // NOTE(review): num_filter is read from weight dim 0, exactly as in convolution; confirm this
+ // matches the weight layout the deconvolution operator expects.
+ deconvParams.addParam("num_filter", weightShape.get(0));
+
+ NDList operands = new NDList(input, weight);
+ boolean hasBias = bias != null;
+ deconvParams.add("no_bias", !hasBias);
+ if (hasBias) {
+ operands.add(bias);
+ }
+
+ return NDArray.invoke(getArray().getParent(), "_npx_deconvolution", operands, deconvParams);
+ }
+
+ public NDList linear(NDArray input, NDArray weight, NDArray bias) {
+ OpParams params = new OpParams();
+ params.addParam("num_hidden", weight.size(0));
+ params.addParam("flatten", false);
+ params.addParam("no_bias", bias == null);
+ NDList inputs = new NDList(input, weight);
+ if (bias != null) {
+ inputs.add(bias);
+ }
+
+ return NDArray.invoke(getArray().getParent(), "_npx_fully_connected", inputs, params);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx_embedding} operator.
+ *
+ * @param input the indices to look up
+ * @param weight the embedding table; dims 0 and 1 give {@code input_dim} and
+ * {@code output_dim}
+ * @param sparse the gradient format; only DENSE and ROW_SPARSE are accepted
+ * @return the operator outputs
+ * @throws IllegalArgumentException for any other sparse format
+ */
+ public NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) {
+ boolean supported =
+ sparse.equals(SparseFormat.DENSE) || sparse.equals(SparseFormat.ROW_SPARSE);
+ if (!supported) {
+ throw new IllegalArgumentException("MXNet only supports row sparse");
+ }
+ Shape tableShape = weight.getShape();
+ long vocabularySize = tableShape.get(0);
+ long embeddingSize = tableShape.get(1);
+
+ OpParams embeddingParams = new OpParams();
+ embeddingParams.addParam("input_dim", vocabularySize);
+ embeddingParams.addParam("output_dim", embeddingSize);
+ embeddingParams.addParam("sparse_grad", sparse.getValue());
+ NDList operands = new NDList(input, weight);
+ return NDArray.invoke(getArray().getParent(), "_npx_embedding", operands, embeddingParams);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx_leaky_relu} operator in {@code prelu} mode.
+ *
+ * @param input the input array
+ * @param alpha the learned slope array
+ * @return the operator outputs
+ */
+ public NDList prelu(NDArray input, NDArray alpha) {
+ OpParams reluParams = new OpParams();
+ reluParams.addParam("act_type", "prelu");
+ NDList operands = new NDList(input, alpha);
+ return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", operands, reluParams);
+ }
+
+ public NDList dropout(NDArray input, float rate, boolean training) {
+ if (training != JnaUtils.autogradIsTraining()) {
+ throw new IllegalArgumentException(
+ "the mode of dropout in MXNet should align with the mode of GradientCollector");
+ }
+
+ OpParams params = new OpParams();
+ params.addParam("p", rate);
+
+ return NDArray.invoke(getArray().getParent(), "_npx_dropout", new NDList(input), params);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx_batch_norm} operator.
+ *
+ * <p>{@code fix_gamma} is set when {@code gamma} is {@code null}. NOTE(review): a null
+ * {@code gamma} is still placed in the input list; confirm that the invoke path tolerates it.
+ *
+ * @param input the input array
+ * @param runningMean the running mean
+ * @param runningVar the running variance
+ * @param gamma the scale, or {@code null}
+ * @param beta the shift
+ * @param axis sent as {@code axis}
+ * @param momentum sent as {@code momentum}
+ * @param eps sent as {@code eps}
+ * @param training must equal the current autograd training flag
+ * @return the operator outputs
+ * @throws IllegalArgumentException if {@code training} disagrees with autograd's mode
+ */
+ public NDList batchNorm(
+ NDArray input,
+ NDArray runningMean,
+ NDArray runningVar,
+ NDArray gamma,
+ NDArray beta,
+ int axis,
+ float momentum,
+ float eps,
+ boolean training) {
+ if (JnaUtils.autogradIsTraining() != training) {
+ throw new IllegalArgumentException(
+ "the mode of batchNorm in MXNet should align with the mode of GradientCollector");
+ }
+
+ OpParams normParams = new OpParams();
+ normParams.addParam("axis", axis);
+ normParams.addParam("fix_gamma", gamma == null);
+ normParams.addParam("eps", eps);
+ normParams.addParam("momentum", momentum);
+
+ NDList operands = new NDList(input, gamma, beta, runningMean, runningVar);
+ return NDArray.invoke(getArray().getParent(), "_npx_batch_norm", operands, normParams);
+ }
+
+ // public MxNDList rnn(
+ // MxNDArray input,
+ // MxNDArray state,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // RNN.Activation activation,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of rnn in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", state.getShape().tail());
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("mode", activation == RNN.Activation.TANH ? "rnn_tanh" :
+ // "rnn_relu");
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ //
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ //
+ // inputs.add(state);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+ // }
+ // }
+
+ // public MxNDList gru(
+ // MxNDArray input,
+ // MxNDArray state,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of gru in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", state.getShape().tail());
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("mode", "gru");
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ //
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ //
+ // inputs.add(state);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+ // }
+ // }
+ //
+ // public MxNDList lstm(
+ // MxNDArray input,
+ // MxNDList states,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of lstm in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("mode", "lstm");
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", states.head().getShape().tail());
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("lstm_state_clip_nan", true);
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ // inputs.addAll(states);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1), result.get(2));
+ // }
+ // }
+
+ ////////////////////////////////////////
+ // Image and CV
+ ////////////////////////////////////////
+
+ /**
+ * Applies MXNet's {@code _npx__image_normalize} operator to this array.
+ *
+ * @param mean sent as the {@code mean} tuple
+ * @param std sent as the {@code std} tuple
+ * @return the operator result
+ */
+ public NDArray normalize(float[] mean, float[] std) {
+ OpParams normalizeParams = new OpParams();
+ normalizeParams.addTupleParam("mean", mean);
+ normalizeParams.addTupleParam("std", std);
+ String operator = "_npx__image_normalize";
+ return NDArray.invoke(getArray().getParent(), operator, array, normalizeParams);
+ }
+
+ public NDArray toTensor() {
+ return NDArray.invoke(getArray().getParent(), "_npx__image_to_tensor", array, null);
+ }
+
+ public NDArray resize(int width, int height, int interpolation) {
+ if (array.isEmpty()) {
+ throw new IllegalArgumentException("attempt to resize of an empty MxNDArray");
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("size", width, height);
+ params.addParam("interp", interpolation);
+ return NDArray.invoke(getArray().getParent(), "_npx__image_resize", array, params);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx__image_crop} operator to this array.
+ *
+ * @param x sent as {@code x}
+ * @param y sent as {@code y}
+ * @param width sent as {@code width}
+ * @param height sent as {@code height}
+ * @return the operator result
+ */
+ public NDArray crop(int x, int y, int width, int height) {
+ OpParams cropParams = new OpParams();
+ // origin first, then extent
+ cropParams.add("x", x);
+ cropParams.add("y", y);
+ cropParams.add("width", width);
+ cropParams.add("height", height);
+ String operator = "_npx__image_crop";
+ return NDArray.invoke(getArray().getParent(), operator, array, cropParams);
+ }
+
+ public NDArray randomFlipLeftRight() {
+ if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+ throw new UnsupportedOperationException("randomFlipLeftRight is not supported on GPU");
+ }
+ return NDArray.invoke(
+ getArray().getParent(), "_npx__image_random_flip_left_right", array, null);
+ }
+
+ public NDArray randomFlipTopBottom() {
+ if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+ throw new UnsupportedOperationException("randomFlipTopBottom is not supported on GPU");
+ }
+ return NDArray.invoke(
+ getArray().getParent(), "_npx__image_random_flip_top_bottom", array, null);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx__image_random_brightness} operator with factors in
+ * {@code [max(0, 1 - brightness), 1 + brightness]}.
+ *
+ * @param brightness half-width of the factor range
+ * @return the operator result
+ * @throws UnsupportedOperationException if this array is on a GPU device
+ */
+ public NDArray randomBrightness(float brightness) {
+ boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+ if (onGpu) {
+ throw new UnsupportedOperationException("randomBrightness is not supported on GPU");
+ }
+ float lowerFactor = Math.max(0, 1 - brightness);
+ float upperFactor = 1 + brightness;
+ OpParams brightnessParams = new OpParams();
+ brightnessParams.addParam("min_factor", lowerFactor);
+ brightnessParams.addParam("max_factor", upperFactor);
+ String operator = "_npx__image_random_brightness";
+ return NDArray.invoke(getArray().getParent(), operator, array, brightnessParams);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx__image_random_hue} operator with factors in
+ * {@code [max(0, 1 - hue), 1 + hue]}.
+ *
+ * @param hue half-width of the factor range
+ * @return the operator result
+ * @throws UnsupportedOperationException if this array is on a GPU device
+ */
+ public NDArray randomHue(float hue) {
+ boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+ if (onGpu) {
+ throw new UnsupportedOperationException("randomHue is not supported on GPU");
+ }
+ float lowerFactor = Math.max(0, 1 - hue);
+ float upperFactor = 1 + hue;
+ OpParams hueParams = new OpParams();
+ hueParams.addParam("min_factor", lowerFactor);
+ hueParams.addParam("max_factor", upperFactor);
+ String operator = "_npx__image_random_hue";
+ return NDArray.invoke(getArray().getParent(), operator, array, hueParams);
+ }
+
+ /**
+ * Applies MXNet's {@code _npx__image_random_color_jitter} operator to this array.
+ *
+ * @param brightness sent as {@code brightness}
+ * @param contrast sent as {@code contrast}
+ * @param saturation sent as {@code saturation}
+ * @param hue sent as {@code hue}
+ * @return the operator result
+ * @throws UnsupportedOperationException if this array is on a GPU device
+ */
+ public NDArray randomColorJitter(
+ float brightness, float contrast, float saturation, float hue) {
+ boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+ if (onGpu) {
+ throw new UnsupportedOperationException("randomColorJitter is not supported on GPU");
+ }
+ OpParams jitterParams = new OpParams();
+ jitterParams.addParam("brightness", brightness);
+ jitterParams.addParam("contrast", contrast);
+ jitterParams.addParam("saturation", saturation);
+ jitterParams.addParam("hue", hue);
+ String operator = "_npx__image_random_color_jitter";
+ return NDArray.invoke(getArray().getParent(), operator, array, jitterParams);
+ }
+
+ public NDArrayIndexer getIndexer() {
+ return INDEXER;
+ }
+
+ ////////////////////////////////////////
+ // Miscellaneous
+ ////////////////////////////////////////
+
+ /**
+ * Selects elements from this array where {@code condition} is non-zero and from {@code other}
+ * elsewhere, using MXNet's {@code where} operator.
+ *
+ * <p>A BOOLEAN condition is first converted to INT32. If this array and {@code other} differ in
+ * shape, both are broadcast to a common shape. Every temporary created here (the converted
+ * condition and any broadcast copies) is closed before returning, including when the operator
+ * throws.
+ *
+ * @param condition the selection mask
+ * @param other the array supplying values where {@code condition} is zero
+ * @return the selected values
+ * @throws IllegalArgumentException if this array and {@code other} have different data types
+ */
+ @SuppressWarnings("PMD.UseTryWithResources")
+ public NDArray where(NDArray condition, NDArray other) {
+ // validate before allocating anything so a mismatch cannot leak temporaries
+ if (array.getDataType() != other.getDataType()) {
+ throw new IllegalArgumentException(
+ "DataType mismatch, required "
+ + array.getDataType()
+ + " actual "
+ + other.getDataType());
+ }
+ NDArray cond =
+ (condition.getDataType() == DataType.BOOLEAN)
+ ? condition.toType(DataType.INT32, false)
+ : condition;
+ NDArray array1 = array;
+ NDArray array2 = other;
+ try {
+ if (!array.shapeEquals(other)) {
+ Shape res = deriveBroadcastedShape(array.getShape(), other.getShape());
+ if (!res.equals(array.getShape())) {
+ array1 = array.broadcast(res);
+ }
+ if (!res.equals(other.getShape())) {
+ array2 = other.broadcast(res);
+ }
+ }
+ return NDArray.invoke(
+ getArray().getParent(), "where", new NDArray[] {cond, array1, array2}, null);
+ } finally {
+ // close only the temporaries this method created, never the caller's arrays
+ if (cond != condition) {
+ cond.close();
+ }
+ if (array1 != array) {
+ array1.close();
+ }
+ if (array2 != other) {
+ array2.close();
+ }
+ }
+ }
+
+ /**
+ * Stacks this array followed by {@code arrays} along {@code axis} via {@code _npi_stack}.
+ *
+ * @param arrays the arrays to place after this one
+ * @param axis sent as {@code axis}
+ * @return the stacked array
+ */
+ public NDArray stack(NDList arrays, int axis) {
+ OpParams stackParams = new OpParams();
+ stackParams.addParam("axis", axis);
+ int count = arrays.size();
+ NDArray[] others = arrays.toArray(new NDArray[0]);
+ NDArray[] operands = new NDArray[count + 1];
+ // this array always comes first
+ operands[0] = array;
+ System.arraycopy(others, 0, operands, 1, count);
+ return NDArray.invoke(getArray().getParent(), "_npi_stack", operands, stackParams);
+ }
+
+ /**
+ * Validates the operands of a concatenation.
+ *
+ * <p>The list must be non-empty. The arrays must not all be zero-dimensional (a mix of scalars
+ * and non-scalars is caught by the next rule), and every array must have the same number of
+ * dimensions as the first one.
+ *
+ * @param list input {@link NDList}
+ * @throws IllegalArgumentException if {@code list} is empty, every array is a scalar, or the
+ * arrays do not share the same number of dimensions
+ */
+ public static void checkConcatInput(NDList list) {
+ NDArray[] arrays = list.toArray(new NDArray[0]);
+ // without this, an empty list would hit allMatch(...) == true and report a misleading
+ // "scalar" error
+ if (arrays.length == 0) {
+ throw new IllegalArgumentException("at least one array is required for concatenation");
+ }
+ if (Stream.of(arrays).allMatch(array -> array.getShape().dimension() == 0)) {
+ throw new IllegalArgumentException(
+ "scalar(zero-dimensional) arrays cannot be concatenated");
+ }
+ int dimension = arrays[0].getShape().dimension();
+ for (int i = 1; i < arrays.length; i++) {
+ if (arrays[i].getShape().dimension() != dimension) {
+ throw new IllegalArgumentException(
+ "all the input arrays must have same number of dimensions, but the array at index 0 has "
+ + dimension
+ + " dimension(s) and the array at index "
+ + i
+ + " has "
+ + arrays[i].getShape().dimension()
+ + " dimension(s)");
+ }
+ }
+ }
+
+ /**
+ * Concatenates this array followed by {@code list} along {@code axis} via
+ * {@code _npi_concatenate}.
+ *
+ * @param list the arrays to append after this array
+ * @param axis the axis to join along, sent as {@code axis}
+ * @return the concatenated array
+ * @throws IllegalArgumentException if every operand is a scalar or their ranks differ
+ */
+ public NDArray concat(NDList list, int axis) {
+ NDArray[] srcArray = new NDArray[list.size() + 1];
+ srcArray[0] = array;
+ System.arraycopy(list.toArray(new NDArray[0]), 0, srcArray, 1, list.size());
+ // validate every operand, including this array, so a rank mismatch between this array and
+ // the list is reported as well
+ checkConcatInput(new NDList(srcArray));
+
+ OpParams params = new OpParams();
+ params.addParam("axis", axis);
+ return NDArray.invoke(getArray().getParent(), "_npi_concatenate", srcArray, params);
+ }
+
+ /**
+ * Runs MXNet's {@code MultiBoxTarget} operator.
+ *
+ * @param inputs operator inputs
+ * @param iouThreshold sent as {@code overlap_threshold}
+ * @param ignoreLabel sent as {@code ignore_label}
+ * @param negativeMiningRatio sent as {@code negative_mining_ratio}
+ * @param negativeMiningThreshold sent as {@code negative_mining_thresh}
+ * @param minNegativeSamples sent as {@code minimum_negative_samples}
+ * @return the operator outputs
+ */
+ public NDList multiBoxTarget(
+ NDList inputs,
+ float iouThreshold,
+ float ignoreLabel,
+ float negativeMiningRatio,
+ float negativeMiningThreshold,
+ int minNegativeSamples) {
+ OpParams targetParams = new OpParams();
+ targetParams.add("minimum_negative_samples", minNegativeSamples);
+ targetParams.add("overlap_threshold", iouThreshold);
+ targetParams.add("ignore_label", ignoreLabel);
+ targetParams.add("negative_mining_ratio", negativeMiningRatio);
+ targetParams.add("negative_mining_thresh", negativeMiningThreshold);
+ String operator = "MultiBoxTarget";
+ return NDArray.invoke(getArray().getParent(), operator, inputs, targetParams);
+ }
+
+ /**
+ * Runs MXNet's {@code MultiBoxPrior} operator on this array.
+ *
+ * @param sizes sent as {@code sizes}
+ * @param ratios sent as {@code ratios}
+ * @param steps sent as {@code steps}
+ * @param offsets sent as {@code offsets}
+ * @param clip sent as {@code clip}
+ * @return the operator outputs
+ */
+ public NDList multiBoxPrior(
+ List<Float> sizes,
+ List<Float> ratios,
+ List<Float> steps,
+ List<Float> offsets,
+ boolean clip) {
+ OpParams priorParams = new OpParams();
+ priorParams.add("sizes", sizes);
+ priorParams.add("ratios", ratios);
+ priorParams.add("steps", steps);
+ priorParams.add("offsets", offsets);
+ priorParams.add("clip", clip);
+ NDList operands = new NDList(array);
+ return NDArray.invoke(getArray().getParent(), "MultiBoxPrior", operands, priorParams);
+ }
+
+ /**
+ * Runs MXNet's {@code MultiBoxDetection} operator.
+ *
+ * @param inputs operator inputs
+ * @param clip sent as {@code clip}
+ * @param threshold sent as {@code threshold}
+ * @param backgroundId sent as {@code background_id}
+ * @param nmsThreashold sent as {@code nms_threshold}
+ * @param forceSuppress sent as {@code force_suppress}
+ * @param nmsTopK sent as {@code nms_topk}
+ * @return the operator outputs
+ */
+ public NDList multiBoxDetection(
+ NDList inputs,
+ boolean clip,
+ float threshold,
+ int backgroundId,
+ float nmsThreashold,
+ boolean forceSuppress,
+ int nmsTopK) {
+ OpParams detectionParams = new OpParams();
+ detectionParams.add("clip", clip);
+ detectionParams.add("threshold", threshold);
+ detectionParams.add("background_id", backgroundId);
+ detectionParams.add("nms_threshold", nmsThreashold);
+ detectionParams.add("force_suppress", forceSuppress);
+ detectionParams.add("nms_topk", nmsTopK);
+ String operator = "MultiBoxDetection";
+ return NDArray.invoke(getArray().getParent(), operator, inputs, detectionParams);
+ }
+
+ /**
+ * Returns the array this helper operates on.
+ *
+ * @return the wrapped {@link NDArray}
+ */
+ public NDArray getArray() {
+ return this.array;
+ }
+
+ /**
+ * Returns the number of spatial dimensions for global pooling: the rank of this array minus
+ * two (batch and channel).
+ *
+ * @return the pooling dimensionality, between 1 and 3
+ * @throws IllegalStateException if the derived dimensionality is outside 1..3
+ */
+ private int getGlobalPoolingDim() {
+ int poolDim = getArray().getShape().dimension() - 2;
+ if (poolDim < 1 || poolDim > 3) {
+ // previously the concatenation lacked a space and rendered as "support1 to 3"
+ throw new IllegalStateException(
+ "GlobalPooling only supports "
+ + "1 to 3 Dimensions, "
+ + poolDim
+ + "D is not supported.");
+ }
+ return poolDim;
+ }
+
+ /**
+ * Builds a shape with one entry per global-pooling dimension, each set to {@code fillValue}.
+ *
+ * @param fillValue the value to put in every entry
+ * @return the filled shape
+ */
+ private Shape getGlobalPoolingShapes(long fillValue) {
+ // the entry count is the input rank minus the batch and channel dims
+ long[] window = new long[getGlobalPoolingDim()];
+ Arrays.fill(window, fillValue);
+ return new Shape(window);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
new file mode 100644
index 0000000..ee93fd7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Stack;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.ndarray.dim.NDIndexBooleans;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullPick;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullSlice;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */
+public class NDArrayIndexer {
+
+ /**
+ * Returns a subarray by picking the elements.
+ *
+ * <p>Dispatch order:
+ * <ul>
+ * <li>a rank-0 index on a scalar array duplicates the array;</li>
+ * <li>a leading boolean index is applied with {@code booleanMask};</li>
+ * <li>otherwise the index is resolved to a full pick, then to a full slice.</li>
+ * </ul>
+ *
+ * @param array the array to get from
+ * @param index the index to get
+ * @return the subArray
+ * @throws IllegalArgumentException if a boolean index is combined with other indices
+ * @throws UnsupportedOperationException if the index is neither a pick nor a slice
+ */
+ public NDArray get(NDArray array, NDIndex index) {
+ if (index.getRank() == 0 && array.getShape().isScalar()) {
+ return array.duplicate();
+ }
+
+ // use booleanMask for NDIndexBooleans case
+ List<NDIndexElement> indices = index.getIndices();
+ if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
+ if (indices.size() != 1) {
+ throw new IllegalArgumentException(
+ "get() currently doesn't support more than one boolean NDArray");
+ }
+ return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
+ }
+
+ Optional<NDIndexFullPick> fullPick = NDIndexFullPick.fromIndex(index, array.getShape());
+ if (fullPick.isPresent()) {
+ return get(array, fullPick.get());
+ }
+
+ Optional<NDIndexFullSlice> fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape());
+ if (fullSlice.isPresent()) {
+ return get(array, fullSlice.get());
+ }
+ throw new UnsupportedOperationException(
+ "get() currently supports all, fixed, and slices indices");
+ }
+
+ /**
+ * Returns a subarray by picking the elements.
+ *
+ * <p>Uses MXNet's {@code pick} operator with {@code keepdims} enabled and {@code wrap} mode.
+ *
+ * @param array the array to get from
+ * @param fullPick the elements to pick
+ * @return the subArray
+ */
+ public NDArray get(NDArray array, NDIndexFullPick fullPick) {
+ OpParams params = new OpParams();
+ params.addParam("axis", fullPick.getAxis());
+ params.addParam("keepdims", true);
+ params.add("mode", "wrap");
+ return NDArray.invoke(
+ array.getParent(), "pick", new NDList(array, fullPick.getIndices()), params)
+ .singletonOrThrow();
+ }
+
+ /**
+ * Returns a subarray at the slice.
+ *
+ * <p>Slices with {@code _npi_slice}, then squeezes the axes reported by
+ * {@link NDIndexFullSlice#getToSqueeze()}. The intermediate slice is closed even if squeezing
+ * fails.
+ *
+ * @param array the array to get from
+ * @param fullSlice the fullSlice index of the array
+ * @return the subArray
+ */
+ @SuppressWarnings("PMD.UseTryWithResources")
+ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
+ OpParams params = new OpParams();
+ params.addTupleParam("begin", fullSlice.getMin());
+ params.addTupleParam("end", fullSlice.getMax());
+ params.addTupleParam("step", fullSlice.getStep());
+
+ NDArray sliced = NDArray.invoke(array.getParent(), "_npi_slice", array, params);
+ int[] toSqueeze = fullSlice.getToSqueeze();
+ if (toSqueeze.length == 0) {
+ return sliced;
+ }
+ try {
+ return sliced.squeeze(toSqueeze);
+ } finally {
+ // the unsqueezed slice is an internal temporary
+ sliced.close();
+ }
+ }
+
+ /**
+ * Sets the values of the array at the fullSlice with an array.
+ *
+ * <p>{@code value} is moved to {@code array}'s device, reshaped by dropping leading target
+ * dimensions until the sizes match (e.g. target (1, 10, 1) with value (10) becomes (10, 1)), and
+ * broadcast to the slice shape. It is then written with {@code _npi_slice_assign}. Every
+ * intermediate array is closed even if a step fails. {@code value} itself is never closed.
+ * NOTE(review): no dtype conversion of {@code value} is performed here.
+ *
+ * @param array the array to set
+ * @param fullSlice the fullSlice of the index to set in the array
+ * @param value the value to set with
+ */
+ @SuppressWarnings("PMD.UseTryWithResources")
+ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
+ OpParams params = new OpParams();
+ params.addTupleParam("begin", fullSlice.getMin());
+ params.addTupleParam("end", fullSlice.getMax());
+ params.addTupleParam("step", fullSlice.getStep());
+
+ Stack<NDArray> prepareValue = new Stack<>();
+ prepareValue.add(value);
+ try {
+ prepareValue.add(prepareValue.peek().toDevice(array.getDevice(), false));
+ // Deal with the case target: (1, 10, 1), original (10)
+ // try to find (10, 1) and reshape (10) to that
+ Shape targetShape = fullSlice.getShape();
+ while (targetShape.size() > value.size()) {
+ targetShape = targetShape.slice(1);
+ }
+ prepareValue.add(prepareValue.peek().reshape(targetShape));
+ prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape()));
+
+ NDArray.invoke(
+ "_npi_slice_assign",
+ new NDArray[] {array, prepareValue.peek()},
+ new NDArray[] {array},
+ params);
+ } finally {
+ // release every temporary, including on failure, but never the caller's value
+ for (NDArray toClean : prepareValue) {
+ if (toClean != value) {
+ toClean.close();
+ }
+ }
+ }
+ }
+
+ /**
+ * Sets the values of the array at the fullSlice with a number.
+ *
+ * @param array the array to set
+ * @param fullSlice the fullSlice of the index to set in the array
+ * @param value the value to set with
+ */
+ public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
+ OpParams params = new OpParams();
+ params.addTupleParam("begin", fullSlice.getMin());
+ params.addTupleParam("end", fullSlice.getMax());
+ params.addTupleParam("step", fullSlice.getStep());
+ params.addParam("scalar", value);
+ NDArray.invoke(
+ "_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
new file mode 100644
index 0000000..df0c6b8
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
@@ -0,0 +1,2008 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** This class contains various methods for manipulating MxNDArrays. */
+public final class NDArrays {
+
+    private NDArrays() {
+        // Holder of static helpers only; instances are never created.
+    }
+
+    /**
+     * Validates the arrays handed to a multi-array operation.
+     *
+     * @param arrays the arrays to validate
+     * @throws IllegalArgumentException if {@code arrays} is {@code null} or holds fewer than two
+     *     arrays, or if more than two arrays are given and their shapes are not all equal
+     */
+    private static void checkInputs(NDArray[] arrays) {
+        if (arrays == null || arrays.length < 2) {
+            throw new IllegalArgumentException("Passed in arrays must have at least two elements");
+        }
+        // Equal shapes are only enforced for three or more inputs; presumably the two-input case
+        // is served by a broadcasting binary operator. NOTE(review): confirm against callers.
+        if (arrays.length > 2
+                && Arrays.stream(arrays).skip(1).anyMatch(array -> !arrays[0].shapeEquals(array))) {
+            throw new IllegalArgumentException("The shape of all inputs must be the same");
+        }
+    }
+
+ ////////////////////////////////////////
+ // Operations: Element Comparison
+ ////////////////////////////////////////
+
+    /**
+     * Checks whether every element of {@link NDArray} a is equal to the number n.
+     *
+     * <p>A {@code null} array is never considered equal. The result is a plain {@code boolean}
+     * rather than a boolean {@link NDArray}.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell> NDArray array = manager.ones(new Shape(3));
+     * jshell> NDArrays.contentEquals(array, 1); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param n the number to compare
+     * @return {@code true} if {@code a} is non-null and all of its elements equal {@code n}
+     */
+    public static boolean contentEquals(NDArray a, Number n) {
+        return a != null && a.contentEquals(n);
+    }
+
+ /**
+ * Returns {@code true} if all elements in {@link NDArray} a are equal to {@link NDArray} b.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.arange(6f).reshape(2, 3);
+ * jshell> MxNDArray array2 = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(2, 3));
+ * jshell> MxNDArrays.contentEquals(array1, array2); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean result
+ */
+ public static boolean contentEquals(NDArray a, NDArray b) {
+ return a.contentEquals(b);
+ }
+
+ /**
+ * Checks 2 {@link NDArray}s for equal shapes.
+ *
+ * <p>Shapes are considered equal if:
+ *
+ * <ul>
+ * <li>Both {@link NDArray}s have equal rank, and
+ * <li>size(0)...size(rank()-1) are equal for both {@link NDArray}s
+ * </ul>
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.ones(new Shape(1, 2, 3));
+ * jshell> MxNDArray array2 = manager.create(new Shape(1, 2, 3));
+ * jshell> MxNDArrays.shapeEquals(array1, array2); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return {@code true} if the {@link Shape}s are the same
+ */
+ public static boolean shapeEquals(NDArray a, NDArray b) {
+ return a.shapeEquals(b);
+ }
+
+ /**
+ * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10,1e-7});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10,1e-8});
+ * jshell> MxNDArrays.allClose(array1, array2); // return false instead of boolean MxNDArray
+ * false
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10,1e-8});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10,1e-9});
+ * jshell> MxNDArrays.allClose(array1, array2); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare with
+ * @param b the {@link NDArray} to compare with
+ * @return the boolean result
+ */
+ // public static boolean allClose(MxNDArray a, MxNDArray b) {
+ // return a.allClose(b);
+ // }
+
+ /**
+ * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10, 1e-7});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return false instead of boolean MxNDArray
+ * false
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10, 1e-8});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return true instead of boolean MxNDArray
+ * true
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, Float.NaN});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, Float.NaN});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, true); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare with
+ * @param b the {@link NDArray} to compare with
+ * @param rtol the relative tolerance parameter
+ * @param atol the absolute tolerance parameter
+ * @param equalNan whether to compare NaNs as equal. If {@code true}, NaNs in the {@link
+ * NDArray} will be considered equal to NaNs in the other {@link NDArray}
+ * @return the boolean result
+ */
+ // public static boolean allClose(
+ // MxNDArray a, MxNDArray b, double rtol, double atol, boolean equalNan) {
+ // return a.allClose(b, rtol, atol, equalNan);
+ // }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.ones(new Shape(1));
+ * jshell> MxNDArrays.eq(array, 1);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to compare
+ * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+ */
+ public static NDArray eq(NDArray a, Number n) {
+ return a.eq(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.ones(new Shape(1));
+ * jshell> MxNDArrays.eq(1, array);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * </pre>
+ *
+ * @param n the number to compare
+ * @param a the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+ */
+ public static NDArray eq(Number n, NDArray a) {
+ return a.eq(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {0f, 1f, 3f});
+ * jshell> MxNDArray array2 = manager.arange(3f);
+ * jshell> MxNDArrays.eq(array1, array2);
+ * ND: (3) cpu() boolean
+ * [ true, true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+ */
+ public static NDArray eq(NDArray a, NDArray b) {
+ return a.eq(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.arange(4f).reshape(2, 2);
+ * jshell> MxNDArrays.neq(array, 1);
+ * ND: (2, 2) cpu() boolean
+ * [[ true, false],
+ * [ true, true],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to compare
+ * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+ */
+ public static NDArray neq(NDArray a, Number n) {
+ return a.neq(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.arange(f4).reshape(2, 2);
+ * jshell> MxNDArrays.neq(1, array);
+ * ND: (2, 2) cpu() boolean
+ * [[ true, false],
+ * [ true, true],
+ * ]
+ * </pre>
+ *
+ * @param n the number to compare
+ * @param a the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+ */
+ public static NDArray neq(Number n, NDArray a) {
+ return a.neq(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, 3f});
+ * jshell> MxNDArrays.neq(array1, array2);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, 3f, 1f, 4f}, new Shape(2, 2));
+ * jshell> MxNDArrays.neq(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() boolean
+ * [[false, true],
+ * [false, true],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+ */
+ public static NDArray neq(NDArray a, NDArray b) {
+ return a.neq(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArrays.gt(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+ */
+ public static NDArray gt(NDArray a, Number n) {
+ return a.gt(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArrays.gt(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, false]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the MxNDArray to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+ */
+ public static NDArray gt(Number n, NDArray a) {
+ return a.lt(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> MxNDArrays.gt(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+ */
+ public static NDArray gt(NDArray a, NDArray b) {
+ return a.gt(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArrays.gte(array, 2);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+ */
+ public static NDArray gte(NDArray a, Number n) {
+ return a.gte(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArrays.gte(2, array);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+ */
+ public static NDArray gte(Number n, NDArray a) {
+ return a.lte(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {4f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> MxNDArrays.gte(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+ */
+ public static NDArray gte(NDArray a, NDArray b) {
+ return a.gte(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.lt(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less" comparison
+ */
+ public static NDArray lt(NDArray a, Number n) {
+ return a.lt(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.lt(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, false]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less" comparison
+ */
+ public static NDArray lt(Number n, NDArray a) {
+ return a.gt(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> MxNDArrays.lt(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less" comparison
+ */
+ public static NDArray lt(NDArray a, NDArray b) {
+ return a.lt(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.lte(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+ */
+ public static NDArray lte(NDArray a, Number n) {
+ return a.lte(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.lte(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+ */
+ public static NDArray lte(Number n, NDArray a) {
+ return a.gte(n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> MxNDArrays.lte(array1, array2)
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+ */
+ public static NDArray lte(NDArray a, NDArray b) {
+ return a.lte(b);
+ }
+
+ /**
+ * Returns elements chosen from the {@link NDArray} or the other {@link NDArray} depending on
+ * condition.
+ *
+ * <p>Given three {@link NDArray}s, condition, a, and b, returns an {@link NDArray} with the
+ * elements from a or b, depending on whether the elements from condition {@link NDArray} are
+ * {@code true} or {@code false}. If condition has the same shape as a, each element in the
+ * output {@link NDArray} is from this if the corresponding element in the condition is {@code
+ * true}, and from other if {@code false}.
+ *
+ * <p>Note that all non-zero values are interpreted as {@code true} in condition {@link
+ * NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.arange(10f);
+ * jshell> MxNDArrays.where(array.lt(5), array, array.mul(10));
+ * ND: (10) cpu() float32
+ * [ 0., 1., 2., 3., 4., 50., 60., 70., 80., 90.]
+ * jshell> MxNDArray array = manager.create(new float[]{0f, 1f, 2f, 0f, 2f, 4f, 0f, 3f, 6f}, new Shape(3, 3));
+ * jshell> MxNDArrays.where(array.lt(4), array, manager.create(-1f));
+ * ND: (3, 3) cpu() float32
+ * [[ 0., 1., 2.],
+ * [ 0., 2., -1.],
+ * [ 0., 3., -1.],
+ * ]
+ * </pre>
+ *
+ * @param condition the condition {@code MxNDArray}
+ * @param a the first {@link NDArray}
+ * @param b the other {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray where(NDArray condition, NDArray a, NDArray b) {
+ return a.getNDArrayInternal().where(condition, b);
+ }
+
+ /**
+ * Returns the maximum of a {@link NDArray} and a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArrays.maximum(array, 3f);
+ * ND: (3) cpu() float32
+ * [3., 3., 4.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared
+ * @return the maximum of a {@link NDArray} and a number element-wise
+ */
+ public static NDArray maximum(NDArray a, Number n) {
+ return a.maximum(n);
+ }
+
+ /**
+ * Returns the maximum of a number and a {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArrays.maximum(3f, array);
+ * ND: (3) cpu() float32
+ * [3., 3., 4.]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the {@link NDArray} to be compared
+ * @return the maximum of a number and a {@link NDArray} element-wise
+ */
+ public static NDArray maximum(Number n, NDArray a) {
+ return maximum(a, n);
+ }
+
+ /**
+ * Returns the maximum of {@link NDArray} a and {@link NDArray} b element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+ * jshell> MxNDArrays.maximum(array1, array2);
+ * ND: (3) cpu() float32
+ * [2., 5., 4.]
+ * jshell> MxNDArray array1 = manager.eye(2);
+ * jshell> MxNDArray array2 = manager.create(new float[] {0.5f, 2f});
+ * jshell> MxNDArrays.maximum(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() float32
+ * [[1. , 2. ],
+ * [0.5, 2. ],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared
+ * @return the maximum of {@link NDArray} a and {@link NDArray} b element-wise
+ */
+ public static NDArray maximum(NDArray a, NDArray b) {
+ return a.maximum(b);
+ }
+
+ /**
+ * Returns the minimum of a {@link NDArray} and a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArrays.minimum(array, 3f);
+ * ND: (3) cpu() float32
+ * [2., 3., 3.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared
+ * @return the minimum of a {@link NDArray} and a number element-wise
+ */
+ public static NDArray minimum(NDArray a, Number n) {
+ return a.minimum(n);
+ }
+
+ /**
+ * Returns the minimum of a number and a {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArrays.minimum(3f, array);
+ * ND: (3) cpu() float32
+ * [2., 3., 3.]
+ * </pre>
+ *
+ * @param n the number to be compared
+ * @param a the {@link NDArray} to be compared
+ * @return the minimum of a number and a {@link NDArray} element-wise
+ */
+ public static NDArray minimum(Number n, NDArray a) {
+ return minimum(a, n);
+ }
+
+ /**
+ * Returns the minimum of {@link NDArray} a and {@link NDArray} b element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+ * jshell> MxNDArrays.minimum(array1, array2);
+ * ND: (3) cpu() float32
+ * [1., 3., 2.]
+ * jshell> MxNDArray array1 = manager.eye(2);
+ * jshell> MxNDArray array2 = manager.create(new float[] {0.5f, 2f});
+ * jshell> MxNDArrays.minimum(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() float32
+ * [[0.5, 0. ],
+ * [0. , 1. ],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared
+ * @return the minimum of {@link NDArray} a and {@link NDArray} b element-wise
+ */
+ public static NDArray minimum(NDArray a, NDArray b) {
+ return a.minimum(b);
+ }
+
+ /**
+ * Returns portion of the {@link NDArray} given the index boolean {@link NDArray} along first
+ * axis.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(3, 2));
+ * jshell> MxNDArray mask = manager.create(new boolean[] {true, false, true});
+ * jshell> MxNDArrays.booleanMask(array, mask);
+ * ND: (2, 2) cpu() float32
+ * [[1., 2.],
+ * [5., 6.],
+ * ]
+ * </pre>
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param index the boolean {@link NDArray} mask
+ * @return the result {@link NDArray}
+ */
+ public static NDArray booleanMask(NDArray data, NDArray index) {
+ return booleanMask(data, index, 0);
+ }
+
+ /**
+ * Returns portion of the {@link NDArray} given the index boolean {@link NDArray} along given
+ * axis.
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param index the boolean {@link NDArray} mask
+ * @param axis an integer that represents the axis of {@link NDArray} to mask from
+ * @return the result {@link NDArray}
+ */
+ public static NDArray booleanMask(NDArray data, NDArray index, int axis) {
+ return data.booleanMask(index, axis);
+ }
+
+ /**
+ * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to a
+ * constant value.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+ * an input array of positive ints of dimension [batch_size].
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param sequenceLength used to handle variable-length sequences
+ * @param value the constant value to be set
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sequenceMask(NDArray data, NDArray sequenceLength, float value) {
+ return data.sequenceMask(sequenceLength, value);
+ }
+
+ /**
+ * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to 0.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+ * an input array of positive ints of dimension [batch_size].
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param sequenceLength used to handle variable-length sequences
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sequenceMask(NDArray data, NDArray sequenceLength) {
+ return data.sequenceMask(sequenceLength);
+ }
+
+ ////////////////////////////////////////
+ // Operations: Element Arithmetic
+ ////////////////////////////////////////
+
+ /**
+ * Adds a number to the {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.add(array, 2f);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be added to
+ * @param n the number to add
+ * @return the result {@link NDArray}
+ */
+ public static NDArray add(NDArray a, Number n) {
+ return a.add(n);
+ }
+
+ /**
+ * Adds a {@link NDArray} to a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> MxNDArrays.add(2f, array);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param n the number to be added to
+ * @param a the {@link NDArray} to add
+ * @return the result {@link NDArray}
+ */
+ public static NDArray add(Number n, NDArray a) {
+ return a.add(n);
+ }
+
+ /**