Java2.0 proto (#20461)
* ADD: first commit
* ADD: load local libraries
* UPDATE: use header files of MXNet 2.0
* ADD: load binaries from environment variable, java properties or jar files.
* ADD: add symbol loading and closing
add module integration
* ADD: [WIP] Component MxNDArray
* ADD: [WIP] Component MxNDArray
* ADD: Component MxNDArray. Pass static compilation check
* ADD: Component CachedOp
* REMOVE: module api which is not used
* FIX: dependency missing
* ADD: [WIP] add test cases for NdArray and CachedOp
* ADD: [WIP] add test cases for NdArray and CachedOp
* ADD: implement of the forward function for MxSymbolblock
* ADD: implement of the forward function for MxSymbolblock
* ADD: Sample model downloader for MLP
* ADD: doc
* ADD: Front-end module for inference, class MxModel, Predictor and so on.
* FIX: Mxnet crash when process exits.
* FIX: remove and initialize 3rdparty directory
* FIX: revert version of submodules: dlpack, dmlc-core, googletest, ps-lite
* Revert "FIX: remove and initialize 3rdparty directory"
This reverts commit d097675e
* FIX: redownload files in 3rdparty
* FIX: reset --hard the version of a few submodules
* FIX: reset --hard the version of a few submodules
* FIX: reset --hard the version of a few submodules
* PERF: [WIP] optimize code structure and memory management
* ADD: add copyright; remove Mx prefix for some classes
* ADD: add copyright
* FIX: group name, path to find header files
* UPDATE: README.md
* ADD: copyright
* ADD: copyright
* ADD: package-info
ADD: ci
ADD: ci
ADD: make modification to trigger ci
ADD: ci
ADD: ci
ADD: ci
ADD: ci
ADD: gradlew
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
ADD: java_package_ci
FIX: build failure
* FIX: ci config file
* UPDATE: remove ParameterStore and some scripts
UPDATE: remove Initializer.java
* UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
FIX: issues in Resource close methods
FIX: issues in Resource close methods
FIX: issues in Resource close methods
UPDATE: remove scripts for dev
* FIX: loading on Linux platform
* # This is a combination of 18 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
* # This is a combination of 27 commits.
parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
# This is a combination of 21 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
# This is the commit message #19:
UPDATE: ci for java-package
# This is the commit message #20:
UPDATE: ci for java-package
# This is the commit message #21:
UPDATE: ci for java-package
# This is the commit message #22:
UPDATE: ci for java-package
# This is the commit message #23:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #24:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #25:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #26:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #27:
UPDATE: jenkins ci scripts for java-package
* MERGE: resolve conflicts
* MERGE: resolve conflicts
* # This is a combination of 35 commits.
parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
# This is a combination of 21 commits.
# This is the 1st commit message:
FIX: loading on Linux platform
# This is the commit message #2:
UPDATE: ci for java-package
# This is the commit message #3:
UPDATE: ci for java-package
# This is the commit message #4:
UPDATE: ci for java-package
# This is the commit message #5:
UPDATE: ci for java-package
# This is the commit message #6:
UPDATE: ci for java-package
# This is the commit message #7:
UPDATE: ci for java-package
# This is the commit message #8:
UPDATE: ci for java-package
# This is the commit message #9:
UPDATE: ci for java-package
# This is the commit message #10:
UPDATE: ci for java-package
# This is the commit message #11:
UPDATE: ci for java-package
# This is the commit message #12:
UPDATE: ci for java-package
# This is the commit message #13:
UPDATE: ci for java-package
# This is the commit message #14:
UPDATE: ci for java-package
# This is the commit message #15:
UPDATE: ci for java-package
# This is the commit message #16:
UPDATE: ci for java-package
# This is the commit message #17:
UPDATE: ci for java-package
# This is the commit message #18:
UPDATE: ci for java-package
# This is the commit message #19:
UPDATE: ci for java-package
# This is the commit message #20:
UPDATE: ci for java-package
# This is the commit message #21:
UPDATE: ci for java-package
# This is the commit message #22:
UPDATE: ci for java-package
# This is the commit message #23:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #24:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #25:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #26:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #27:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #28:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #30:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #31:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #32:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #33:
UPDATE: jenkins ci scripts for java-package
# This is the commit message #34:
FIX: issues in Resource close methods
# This is the commit message #35:
FIX: issues in Resource close methods
* parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800
FIX: loading on Linux platform
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: ci for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
UPDATE: jenkins ci scripts for java-package
FIX: issues in Resource close methods
FIX: issues in Resource close methods
FIX: issues in Resource close methods
DOC: add doc
STYLE: change code style for pmd check
FIX: avoid the register for a signal handler twice
STYLE: pass pmd check
UPDATE: remove unused scripts
* FIX: solve problems before merge
* UPDATE: remove useless files
* FIX: licence to apache
* FIX: sanity check
* FIX: sanity check
* FIX: sanity check
* FIX: remove unused files
* FIX: remove unused files
* DOC: add document
* FIX: doesn't work on osx
* FIX: clang static check
* FIX: sanity
* FIX: skip signal handler registration when building java package
* FIX: remove DataType String
* FIX: add license
Co-authored-by: cspchen <cspchen@amazon.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b901f41..c6e0a4e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,6 +91,7 @@
option(BUILD_EXTENSION_PATH "Path to extension to build" "")
option(BUILD_CYTHON_MODULES "Build cython modules." OFF)
option(LOG_FATAL_THROW "Log exceptions but do not abort" ON)
+option(BUILD_JAVA_NATIVE "Skip signal handler registration for Java Binding" OFF)
cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF)
cmake_dependent_option(USE_CCACHE "Attempt using CCache to wrap the compilation" ON "UNIX" OFF)
cmake_dependent_option(MXNET_FORCE_SHARED_CRT "Build with dynamic CRT on Windows (/MD)" ON "MXNET_BUILD_SHARED_LIBS" OFF)
@@ -970,6 +971,10 @@
target_compile_definitions(mxnet PUBLIC MXNET_USE_CPP_PACKAGE=1)
endif()
+if(BUILD_JAVA_NATIVE)
+ add_definitions(-DSKIP_SIGNAL_HANDLER_REGISTRATION=1)
+endif()
+
if(NOT CMAKE_BUILD_TYPE STREQUAL "Distribution")
# Staticbuild applies linker version script to hide private symbols, breaking unit tests
add_subdirectory(tests)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 70934f6..1e75ab2 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -310,6 +310,23 @@
build_ubuntu_cpu_openblas
}
+# CI entry point for the Java package: builds MXNet via
+# build_ubuntu_cpu_openblas_java, then runs java_package_integration_test.
+build_ubuntu_cpu_and_test_java() {
+ build_ubuntu_cpu_openblas_java
+ java_package_integration_test
+}
+
+# Build the Java package with Gradle, package the locally built native
+# library into jars, and run the integration test.
+java_package_integration_test() {
+ # abort on the first failing step and echo commands, like the other CI functions
+ set -ex
+ # make sure you are using java 11
+ # build java project (javadoc generation is skipped)
+ cd /work/mxnet/java-package
+ ./gradlew build -x javadoc
+ # generate native library
+ ./gradlew :native:buildLocalLibraryJarDefault
+ ./gradlew :native:mkl-linuxJar
+ # run integration
+ ./gradlew :integration:run
+}
+
build_ubuntu_cpu_openblas() {
set -ex
cd /work/build
@@ -327,6 +344,24 @@
ninja
}
+# Configure and build MXNet for the Java package CI job: CUDA and oneDNN are
+# off and BLAS is OpenBLAS. BUILD_JAVA_NATIVE=ON makes CMakeLists.txt define
+# SKIP_SIGNAL_HANDLER_REGISTRATION=1 for the build.
+build_ubuntu_cpu_openblas_java() {
+ set -ex
+ cd /work/build
+ CXXFLAGS="-Wno-error=strict-overflow" CC=gcc-7 CXX=g++-7 cmake \
+ -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+ -DENABLE_TESTCOVERAGE=ON \
+ -DUSE_TVM_OP=ON \
+ -DUSE_BLAS=Open \
+ -DUSE_ONEDNN=OFF \
+ -DUSE_CUDA=OFF \
+ -DUSE_DIST_KVSTORE=ON \
+ -DBUILD_CYTHON_MODULES=ON \
+ -DBUILD_EXTENSION_PATH=/work/mxnet/example/extensions/lib_external_ops \
+ -DBUILD_JAVA_NATIVE=ON \
+ -G Ninja /work/mxnet
+ ninja
+}
+
build_ubuntu_cpu_mkl() {
set -ex
cd /work/build
diff --git a/java-package/Develop.md b/java-package/Develop.md
new file mode 100644
index 0000000..8bc0ffe
--- /dev/null
+++ b/java-package/Develop.md
@@ -0,0 +1,109 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Development Tips
+
+## Set up the Project
+### Step 1. Obtain MXNet Library
+The first step is to obtain the MXNet library. We recommend that you build it from source. Alternatively, you can take the
+library from an installed MXNet package (see "Download Pre-built library" below).
+#### Build from source
+Refer to [Build From Source](https://mxnet.apache.org/get_started/build_from_source#building-mxnet)
+For MacOS users:
+- Prepare
+```shell
+# Install OS X Developer Tools
+$ xcode-select --install
+
+# Install Homebrew
+$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+
+# Install dependencies
+$ brew install cmake ninja ccache opencv
+```
+- Clone 3rd party projects
+```shell
+# Clone the third-party dependencies of mxnet. This step is required
+$ git submodule update --init --recursive
+```
+- Build MXNet
+```shell
+# select and copy cmake configure files for macos
+$ cp config/darwin.cmake config.cmake
+
+# create build directory for the project
+$ mkdir build; cd build
+
+# cmake
+$ cmake ..
+$ cmake --build .
+```
+Libraries will be generated under the directory _build/_.
+
+For Linux users:
+Docker might help you build libraries on different platforms. You can get help from [README for CI](../ci/README.md).
+For example, you can build MXNet on Ubuntu with the following command.
+```shell
+$ python3 ci/build.py -p ubuntu_cpu
+```
+#### Download Pre-built library
+You can find the MXNet library inside an installed MXNet package, such as the Python module. However, MXNet 2.0 has not
+been released yet, which is why we recommend building it from source.
+```shell
+# download the python module for mxnet (note that mxnet 2.0 has not been released yet)
+$ pip3 install mxnet==1.7.0.post2
+# find the location of the installed module
+$ python
+Python 3.6.8 |Anaconda, Inc.| (default, Dec 29 2018, 19:04:46)
+>>> import mxnet
+>>> mxnet
+<module 'mxnet' from '/Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/__init__.py'>
+>>> quit()
+# you can locate the module under /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/
+$ ls /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/ | grep libmxnet
+libmxnet.dylib
+```
+The compiled library is the file named _libmxnet.*_. On MacOS the file has the suffix _.dylib_; on Linux the suffix is
+_.so_; on Windows the suffix is _.dll_.
+
+### Step 2. Build MXNet Native Lib for Java
+The project uses gradle to manage dependencies. You can build the project using gradle. We have to encapsulate the mxnet
+library into a jar file so that we can load it into JVM.
+```shell
+$ cd java-package
+# Build the project
+$ ./gradlew build
+# Create gradle tasks to package mxnet library into jar
+# The task name is in this form {$flavor}-{$platform}Jar
+# MacOS -> mkl-osxJar
+# Linux -> mkl-linuxJar
+# Windows -> mkl-winJar
+$ ./gradlew :native:buildLocalLibraryJarDefault
+# Build native lib for macos
+$ ./gradlew mkl-osxJar
+# Check the lib for osx
+$ ls native/build/libs | grep osx
+native-2.0.0-SNAPSHOT-osx-x86_64.jar
+```
+The jar file _native-2.0.0-SNAPSHOT-osx-x86_64.jar_ is the output lib file.
+
+### Step 3. Run Integration Test
+When the integration test task runs, the built MXNet native library is added to the classpath automatically.
+```shell
+$ ./gradlew :integration:run
+
+```
\ No newline at end of file
diff --git a/java-package/README.md b/java-package/README.md
new file mode 100644
index 0000000..eb5b901
--- /dev/null
+++ b/java-package/README.md
@@ -0,0 +1,58 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Java Package for MXNet 2.0
+
+## Requirements
+
+## Install
+
+## Scripts
+- customize mxnet library path
+```bash
+export MXNET_LIBRARY_PATH=//anaconda3/lib/python3.8/site-packages/mxnet/
+```
+
+
+## Tests
+Test case for a rough inference run with MXNet model
+```bash
+./gradlew :integration:run
+```
+
+## Example
+
+```java
+try (MxResource base = BaseMxResource.getSystemMxResource())
+ {
+ Model model = Model.loadModel(Item.MLP);
+// Model model = Model.loadModel("test", Paths.get("/Users/cspchen/mxnet.java_package/cache/repo/test-models/mlp.tar.gz/mlp/"));
+ Predictor<NDList, NDList> predictor = model.newPredictor();
+ NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+ NDList inputs = new NDList();
+ inputs.add(input);
+ NDList result = predictor.predict(inputs);
+ NDArray expected = NDArray.create(
+ base,
+ new float[]{4.93476f, -0.76084447f, 0.37713608f, 0.6605506f, -1.3485785f, -0.8736369f
+ , 0.018061712f, -1.3274033f, 1.0609543f, 0.24042489f}, new Shape(1, 10));
+ Assertions.assertAlmostEquals(result.get(0), expected);
+
+ } catch (IOException e) {
+ logger.error(e.getMessage(), e);
+ }
+```
\ No newline at end of file
diff --git a/java-package/build.gradle b/java-package/build.gradle
new file mode 100644
index 0000000..909dbbb
--- /dev/null
+++ b/java-package/build.gradle
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id "com.github.spotbugs" version "4.2.0" apply true
+}
+
+defaultTasks 'build'
+
+allprojects {
+ group 'org.apache.mxnet'
+ boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
+ version = "${java_package_version}" + (isRelease ? "" : "-SNAPSHOT")
+
+ repositories {
+// maven {
+// url "https://mlrepo.djl.ai/maven/"
+// }
+ mavenCentral()
+ maven {
+ url 'https://oss.sonatype.org/content/repositories/snapshots/'
+ }
+ }
+
+ apply plugin: 'idea'
+ idea {
+ module {
+ outputDir = file('build/classes/java/main')
+ testOutputDir = file('build/classes/java/test')
+ // inheritOutputDirs = true
+ }
+ }
+}
+
+def javaProjects() {
+ return subprojects.findAll { new File(it.projectDir, "src/main").exists() }
+}
+
+configure(javaProjects()) {
+ apply plugin: 'java-library'
+ sourceCompatibility = 1.8
+ targetCompatibility = 1.8
+ compileJava.options.encoding = "UTF-8"
+ compileTestJava.options.encoding = "UTF-8"
+ if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+ compileJava.options.compilerArgs.addAll(["--release", "8"])
+ }
+
+ apply plugin: 'eclipse'
+
+ eclipse {
+ jdt.file.withProperties { props ->
+ props.setProperty "org.eclipse.jdt.core.circularClasspath", "warning"
+ }
+ classpath {
+ sourceSets.test.java {
+ srcDirs = ["src/test/java"]
+ exclude "**/package-info.java"
+ }
+ }
+ }
+
+ apply from: file("${rootProject.projectDir}/tools/gradle/java-formatter.gradle")
+ apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle")
+
+ test {
+ // tensorflow mobilenet and resnet require more cpu memory
+ maxHeapSize = "4096m"
+ doFirst {
+ if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+ jvmArgs = [
+ '--add-opens', "java.base/jdk.internal.loader=ALL-UNNAMED"
+ ]
+ }
+ }
+
+ useTestNG() {
+// suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
+ }
+
+ testLogging {
+ showStandardStreams = true
+ events "passed", "skipped", "failed", "standardOut", "standardError"
+ }
+
+ doFirst {
+ systemProperties System.getProperties()
+ systemProperties.remove("user.dir")
+ systemProperty "org.apache.mxnet.logging.level", "debug"
+ systemProperty "org.slf4j.simpleLogger.defaultLogLevel", "debug"
+ systemProperty "org.slf4j.simpleLogger.log.org.mortbay.log", "warn"
+ systemProperty "disableProgressBar", "true"
+ systemProperty "nightly", System.getProperty("nightly", "false")
+// systemProperty "java.library.path", "/Users/cspchen/Work/incubator-mxnet/build"
+ if (gradle.startParameter.offline) {
+ systemProperty "offline", "true"
+ }
+ }
+ }
+
+ compileJava {
+ options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+ }
+
+ compileTestJava {
+ options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+ }
+
+ jar {
+ manifest {
+ // The module name prefix must match the project group 'org.apache.mxnet'
+ // ('-' is not legal in a module name, so it is replaced with '_').
+ attributes("Automatic-Module-Name": "org.apache.mxnet.${project.name.replace('-', '_')}")
+ }
+ }
+}
+
+apply from: file("${rootProject.projectDir}/tools/gradle/jacoco.gradle")
+apply from: file("${rootProject.projectDir}/tools/gradle/stats.gradle")
diff --git a/java-package/example/build.gradle b/java-package/example/build.gradle
new file mode 100644
index 0000000..a6831ee
--- /dev/null
+++ b/java-package/example/build.gradle
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'java'
+}
+
+group 'incubator-mxnet.java-package'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+
+dependencies {
+ testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0'
+ testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0'
+}
+
+test {
+ useJUnitPlatform()
+}
\ No newline at end of file
diff --git a/java-package/gradle.properties b/java-package/gradle.properties
new file mode 100644
index 0000000..9d32099
--- /dev/null
+++ b/java-package/gradle.properties
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+org.gradle.daemon=true
+org.gradle.jvmargs=-Xmx2048M
+
+systemProp.org.gradle.internal.http.socketTimeout=120000
+systemProp.org.gradle.internal.http.connectionTimeout=60000
+
+# FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308
+systemProp.org.gradle.internal.publish.checksums.insecure=true
+
+java_package_version=0.0.1
+mxnet_version=2.0.0
+api_version=0.0.1
+jnarator_version=0.0.1
+
+antlr_version=4.7.2
+commons_cli_version=1.4
+commons_compress_version=1.20
+commons_csv_version=1.8
+gson_version=2.8.6
+jna_version=5.3.0
+netty_version=4.1.51.Final
+slf4j_version=1.7.30
+log4j_slf4j_version=2.13.3
+testng_version=7.1.0
+powermock_version=2.0.7
+
diff --git a/java-package/gradle/wrapper/gradle-wrapper.jar b/java-package/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000..e708b1c
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/java-package/gradle/wrapper/gradle-wrapper.properties b/java-package/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000..689c068
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
diff --git a/java-package/gradlew b/java-package/gradlew
new file mode 100755
index 0000000..ebb6c09
--- /dev/null
+++ b/java-package/gradlew
@@ -0,0 +1,186 @@
+#!/usr/bin/env sh
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+ echo "$*"
+}
+
+die () {
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+ NONSTOP* )
+ nonstop=true
+ ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=`expr $i + 1`
+ done
+ case $i in
+ 0) set -- ;;
+ 1) set -- "$args0" ;;
+ 2) set -- "$args0" "$args1" ;;
+ 3) set -- "$args0" "$args1" "$args2" ;;
+ 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Escape application args
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
+}
+APP_ARGS=`save "$@"`
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+exec "$JAVACMD" "$@"
diff --git a/java-package/gradlew.bat b/java-package/gradlew.bat
new file mode 100644
index 0000000..bcf7fb7
--- /dev/null
+++ b/java-package/gradlew.bat
@@ -0,0 +1,92 @@
+@REM
+@REM Licensed to the Apache Software Foundation (ASF) under one
+@REM or more contributor license agreements. See the NOTICE file
+@REM distributed with this work for additional information
+@REM regarding copyright ownership. The ASF licenses this file
+@REM to you under the Apache License, Version 2.0 (the
+@REM "License"); you may not use this file except in compliance
+@REM with the License. You may obtain a copy of the License at
+@REM
+@REM http://www.apache.org/licenses/LICENSE-2.0
+@REM
+@REM Unless required by applicable law or agreed to in writing,
+@REM software distributed under the License is distributed on an
+@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+@REM KIND, either express or implied. See the License for the
+@REM specific language governing permissions and limitations
+@REM under the License.
+@REM
+
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
+
+@rem Find java.exe: use JAVA_HOME when it is defined, otherwise probe for java.exe on the PATH
+if defined JAVA_HOME goto findJavaFromJavaHome
+@rem JAVA_HOME is not defined: check that "java.exe -version" can be started from the PATH
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+@rem JAVA_HOME is defined: strip any double quotes from it, then require bin/java.exe to exist
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem JAVA_OPTS and GRADLE_OPTS are expanded unquoted below; %* forwards all script arguments
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
+
+:end
+@rem A zero ERRORLEVEL skips the failure path; the local variable scope is ended at :mainEnd
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+@rem Closing label; no goto in this script targets it
+:omega
diff --git a/java-package/integration/build.gradle b/java-package/integration/build.gradle
new file mode 100644
index 0000000..fed8860
--- /dev/null
+++ b/java-package/integration/build.gradle
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'application'
+ id 'jacoco'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+// The class to launch can be overridden with the "main" system property (-Dmain=<class>)
+application {
+ mainClassName = System.getProperty("main", "org.apache.mxnet.integration.IntegrationTest")
+}
+
+dependencies {
+ api "commons-cli:commons-cli:${commons_cli_version}"
+ api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+ // TestNG is a main-source-set dependency because the integration tests live under src/main
+ api project(":mxnet-engine")
+ implementation "org.testng:testng:${testng_version}"
+// testImplementation(":mxnet-engine")
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+}
+
+run {
+ systemProperties System.getProperties() // forward the build JVM's system properties to the launched application
+ systemProperties.remove("user.dir") // ...except user.dir, which is not forwarded
+ systemProperty("file.encoding", "UTF-8")
+ jvmArgs "-Xverify:none" // NOTE(review): disables bytecode verification; deprecated since JDK 13 - confirm it is still wanted
+}
+
+checkstyleMain {
+ // skip check style for this package
+ exclude 'org/apache/mxnet/integration/**'
+}
+// NOTE(review): the TestNG "test" task below is disabled; re-enable or delete it
+//test {
+//
+// useTestNG()
+// filter {
+// includeTestsMatching "org.apache.mxnet.integration.tests.engine.*"
+// }
+//
+//}
+
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
new file mode 100644
index 0000000..96d0ad3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.mxnet.integration.util.Arguments;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.SkipException;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+public class IntegrationTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(IntegrationTest.class);
+
+ private Class<?> source; // class whose code source (directory or jar) is scanned for tests
+ /** Creates a runner that discovers test classes from the code source of {@code source}. */
+ public IntegrationTest(Class<?> source) {
+ this.source = source;
+ }
+
+ public static void main(String[] args) {
+ new IntegrationTest(IntegrationTest.class).runTests(args);
+ // NOTE(review): the boolean returned by runTests is discarded, so failures do not change the exit code.
+ // TODO: find an elegant fix for the native library crash; the System.exit(0) workaround is disabled.
+ }
+
+ /**
+  * Parses {@code args}, discovers the tests and runs them until the configured duration elapses.
+  * @return {@code true} if every pass succeeded; {@code false} on failure or invalid arguments
+  */
+ public boolean runTests(String[] args) {
+     Options options = Arguments.getOptions();
+     try {
+         CommandLine cmd = new DefaultParser().parse(options, args, null, false);
+         Arguments arguments = new Arguments(cmd);
+         Duration duration = Duration.ofMinutes(arguments.getDuration());
+         List<TestClass> tests = listTests(arguments, source);
+         boolean testsPassed = true;
+         while (!duration.isNegative()) {
+             long begin = System.currentTimeMillis();
+             // Always execute the pass; short-circuiting on an earlier failure would busy-spin.
+             testsPassed = runTests(tests) && testsPassed;
+             long delta = Math.max(1L, System.currentTimeMillis() - begin); // always progress
+             duration = duration.minus(Duration.ofMillis(delta));
+         }
+         return testsPassed;
+     } catch (ParseException e) {
+         HelpFormatter formatter = new HelpFormatter();
+         formatter.setLeftPadding(1);
+         formatter.setWidth(120);
+         formatter.printHelp(e.getMessage(), options);
+         return false;
+     } catch (Throwable t) {
+         logger.error("Unexpected error", t);
+         return false;
+     }
+ }
+
+ /** Runs every test class once, logs a summary, and reports whether no test failed. */
+ private boolean runTests(List<TestClass> tests) {
+     Map<TestResult, Integer> counts = new ConcurrentHashMap<>();
+     for (TestClass suite : tests) {
+         logger.info("Running test {} ...", suite.getName());
+         int methodCount = suite.getTestCount();
+         try {
+             if (suite.beforeClass()) {
+                 for (int idx = 0; idx < methodCount; idx++) {
+                     counts.merge(suite.runTest(idx), 1, Integer::sum);
+                 }
+             } else {
+                 // Class-level setup failed: count every test of this class as failed.
+                 counts.merge(TestResult.FAILED, methodCount, Integer::sum);
+             }
+         } finally {
+             suite.afterClass();
+         }
+     }
+
+     // A missing key means that no test ended with that result.
+     int failed = counts.getOrDefault(TestResult.FAILED, 0);
+     int passed = counts.getOrDefault(TestResult.SUCCESS, 0);
+     int skipped = counts.getOrDefault(TestResult.SKIPPED, 0);
+     int unsupported = counts.getOrDefault(TestResult.UNSUPPORTED, 0);
+     if (skipped > 0) {
+         logger.info("Skipped: {} tests", skipped);
+     }
+     if (unsupported > 0) {
+         logger.info("Unsupported: {} tests", unsupported);
+     }
+     if (failed == 0) {
+         logger.info("Passed all {} tests", passed);
+     } else {
+         logger.error("Failed {} out of {} tests", failed, failed + passed);
+     }
+     return failed == 0;
+ }
+
+ /** Builds the runnable test classes selected by the class/method/package arguments. */
+ private static List<TestClass> listTests(Arguments arguments, Class<?> source)
+         throws IOException, ReflectiveOperationException, URISyntaxException {
+     String className = arguments.getClassName();
+     String methodName = arguments.getMethodName();
+     List<TestClass> tests = new ArrayList<>();
+
+     try {
+         List<Class<?>> classes;
+         if (className == null) {
+             classes = listTestClasses(arguments, source);
+         } else {
+             String prefix = arguments.getPackageName();
+             String qualified = className.startsWith(prefix) ? className : prefix + className;
+             classes = Collections.singletonList(Class.forName(qualified));
+         }
+         for (Class<?> clazz : classes) {
+             // ifPresent rather than map: the Optional is consumed purely for its side effect.
+             getTestsInClass(clazz, methodName).ifPresent(tests::add);
+         }
+     } catch (ReflectiveOperationException | IOException | URISyntaxException e) {
+         logger.error("Failed to resolve test class.", e);
+         throw e;
+     }
+     return tests;
+ }
+
+ private static Optional<TestClass> getTestsInClass(Class<?> clazz, String methodName)
+ throws ReflectiveOperationException {
+ if (clazz.getConstructors().length == 0) {
+ return Optional.empty();
+ }
+ Constructor<?> ctor = clazz.getConstructor();
+ Object obj = ctor.newInstance();
+ TestClass testClass = new TestClass(obj);
+
+ for (Method method : clazz.getDeclaredMethods()) {
+ Test testMethod = method.getAnnotation(Test.class);
+ if (testMethod != null) {
+ if (testMethod.enabled()
+ && (methodName == null || methodName.equals(method.getName()))) {
+ testClass.addTestMethod(method);
+ }
+ continue;
+ }
+ BeforeClass beforeClass = method.getAnnotation(BeforeClass.class);
+ if (beforeClass != null) {
+ testClass.addBeforeClass(method);
+ continue;
+ }
+ AfterClass afterClass = method.getAnnotation(AfterClass.class);
+ if (afterClass != null) {
+ testClass.addAfterClass(method);
+ continue;
+ }
+ BeforeTest beforeTest = method.getAnnotation(BeforeTest.class);
+ if (beforeTest != null) {
+ testClass.addBeforeTest(method);
+ continue;
+ }
+ AfterTest afterTest = method.getAnnotation(AfterTest.class);
+ if (afterTest != null) {
+ testClass.addAfterTest(method);
+ }
+ }
+
+ return Optional.of(testClass);
+ }
+
+ /**
+  * Lists the classes under the configured package at the location {@code clazz} was loaded from.
+  * Scans {@code file:} class directories and jars only; nested classes ('$' in the name) are skipped.
+  */
+ private static List<Class<?>> listTestClasses(Arguments arguments, Class<?> clazz)
+         throws IOException, ClassNotFoundException, URISyntaxException {
+     URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+     if (!"file".equalsIgnoreCase(url.getProtocol())) {
+         return Collections.emptyList();
+     }
+     String packageName = arguments.getPackageName();
+     List<Class<?>> classList = new ArrayList<>();
+     Path classPath = Paths.get(url.toURI());
+     if (Files.isDirectory(classPath)) {
+         Collection<Path> files;
+         // Files.walk keeps directory handles open until the stream is closed.
+         try (java.util.stream.Stream<Path> walk = Files.walk(classPath)) {
+             files = walk.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
+                     .collect(Collectors.toList());
+         }
+         for (Path file : files) {
+             String className = classPath.relativize(file).toString();
+             className = className.substring(0, className.lastIndexOf('.'));
+             className = className.replace(File.separatorChar, '.');
+             if (className.startsWith(packageName) && !className.contains("$")) {
+                 try {
+                     classList.add(Class.forName(className));
+                 } catch (ExceptionInInitializerError ignore) {
+                     // ignore
+                 }
+             }
+         }
+     } else if (url.getPath().toLowerCase().endsWith(".jar")) {
+         try (JarFile jarFile = new JarFile(classPath.toFile())) {
+             Enumeration<JarEntry> en = jarFile.entries();
+             while (en.hasMoreElements()) {
+                 String fileName = en.nextElement().getName();
+                 if (!fileName.endsWith(".class")) {
+                     continue;
+                 }
+                 fileName = fileName.substring(0, fileName.lastIndexOf('.')).replace('/', '.');
+                 if (fileName.startsWith(packageName) && !fileName.contains("$")) {
+                     try {
+                         classList.add(Class.forName(fileName));
+                     } catch (ExceptionInInitializerError ignore) {
+                         // ignore
+                     }
+                 }
+             }
+         }
+     }
+     return classList;
+ }
+
+ private static final class TestClass {
+
+ private Object object;
+ private List<Method> testMethods;
+ private List<Method> beforeClass;
+ private List<Method> afterClass;
+ private List<Method> beforeTest;
+ private List<Method> afterTest;
+
+ public TestClass(Object object) {
+ this.object = object;
+ testMethods = new ArrayList<>();
+ beforeClass = new ArrayList<>();
+ afterClass = new ArrayList<>();
+ beforeTest = new ArrayList<>();
+ afterTest = new ArrayList<>();
+ }
+
+ public void addTestMethod(Method method) {
+ testMethods.add(method);
+ }
+
+ public void addBeforeClass(Method method) {
+ beforeClass.add(method);
+ }
+
+ public void addAfterClass(Method method) {
+ afterClass.add(method);
+ }
+
+ public void addBeforeTest(Method method) {
+ beforeTest.add(method);
+ }
+
+ public void addAfterTest(Method method) {
+ afterTest.add(method);
+ }
+
+ public boolean beforeClass() {
+ try {
+ for (Method method : beforeClass) {
+ method.invoke(object);
+ }
+ return true;
+ } catch (InvocationTargetException | IllegalAccessException e) {
+ logger.error("", e.getCause());
+ }
+ return false;
+ }
+
+ public void afterClass() {
+ try {
+ for (Method method : afterClass) {
+ method.invoke(object);
+ }
+ } catch (InvocationTargetException | IllegalAccessException e) {
+ logger.error("", e.getCause());
+ }
+ }
+
+ public boolean beforeTest() {
+ try {
+ for (Method method : beforeTest) {
+ method.invoke(object);
+ }
+ return true;
+ } catch (InvocationTargetException | IllegalAccessException e) {
+ logger.error("", e.getCause());
+ }
+ return false;
+ }
+
+ public void afterTest() {
+ try {
+ for (Method method : afterTest) {
+ method.invoke(object);
+ }
+ } catch (InvocationTargetException | IllegalAccessException e) {
+ logger.error("", e.getCause());
+ }
+ }
+
+ public TestResult runTest(int index) {
+ if (!beforeTest()) {
+ return TestResult.FAILED;
+ }
+
+ TestResult result;
+ Method method = testMethods.get(index);
+ try {
+ long begin = System.nanoTime();
+ method.invoke(object);
+ String time = String.format("%.3f", (System.nanoTime() - begin) / 1000_0000f);
+ logger.info("Test {}.{} PASSED, duration: {}", getName(), method.getName(), time);
+ result = TestResult.SUCCESS;
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ if (expectedException(method, e)) {
+ logger.info("Test {}.{} PASSED", getName(), method.getName());
+ result = TestResult.SUCCESS;
+ } else if (e.getCause() instanceof SkipException) {
+ logger.info("Test {}.{} SKIPPED", getName(), method.getName());
+ result = TestResult.SKIPPED;
+ } else if (e.getCause() instanceof UnsupportedOperationException) {
+ logger.info("Test {}.{} UNSUPPORTED", getName(), method.getName());
+ logger.trace("", e.getCause());
+ result = TestResult.UNSUPPORTED;
+ } else {
+ logger.error("Test {}.{} FAILED", getName(), method.getName());
+ logger.error("", e.getCause());
+ result = TestResult.FAILED;
+ }
+ } finally {
+ afterTest();
+ }
+ return result;
+ }
+
+ public int getTestCount() {
+ return testMethods.size();
+ }
+
+ public String getName() {
+ return object.getClass().getName();
+ }
+
+ private static boolean expectedException(Method method, Exception e) {
+ Test test = method.getAnnotation(Test.class);
+ Class<?>[] exceptions = test.expectedExceptions();
+ if (exceptions.length > 0) {
+ Throwable exception = e.getCause();
+ for (Class<?> c : exceptions) {
+ if (c.isInstance(exception)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+ }
+
+ public enum TestResult {
+ SUCCESS, // the test returned normally or threw one of its expected exceptions
+ FAILED, // the test threw an unexpected exception, or a before-class/before-test hook failed
+ SKIPPED, // the test threw a TestNG SkipException
+ UNSUPPORTED; // the test threw an UnsupportedOperationException
+ }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
new file mode 100644
index 0000000..f97e77a
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
new file mode 100644
index 0000000..8beffb2
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.engine;
+
+import java.io.IOException;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Model;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Predictor;
+import org.apache.mxnet.integration.tests.jna.JnaUtilTest;
+import org.apache.mxnet.integration.util.Assertions;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.repository.Item;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.annotations.Test;
+
+/** Integration test that loads the sample MLP model and checks one prediction against known values. */
+public class ModelTest {
+    private static final Logger logger = LoggerFactory.getLogger(ModelTest.class);
+
+    @Test
+    public void modelLoadAndPredictTest() throws IOException {
+        try (MxResource base = BaseMxResource.getSystemMxResource()) {
+            Model model = Model.loadModel(Item.MLP);
+            Predictor<NDList, NDList> predictor = model.newPredictor();
+            NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+            NDList result = predictor.predict(new NDList(input));
+            NDArray expected =
+                    NDArray.create(
+                            base,
+                            new float[] {
+                                4.93476f,
+                                -0.76084447f,
+                                0.37713608f,
+                                0.6605506f,
+                                -1.3485785f,
+                                -0.8736369f,
+                                0.018061712f,
+                                -1.3274033f,
+                                1.0609543f,
+                                0.24042489f
+                            },
+                            new Shape(1, 10));
+            Assertions.assertAlmostEquals(result.get(0), expected);
+        } catch (IOException e) {
+            logger.error(e.getMessage(), e);
+            throw e; // rethrow so an I/O failure fails the test instead of passing silently
+        }
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
new file mode 100644
index 0000000..cc6a916
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration.tests.engine;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
new file mode 100644
index 0000000..e751140
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.jna;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class JnaUtilTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class);
+
+ @Test
+ public void doForwardTest() throws IOException {
+     try (MxResource base = BaseMxResource.getSystemMxResource()) {
+         Path modelPath = Repository.initRepository(Item.MLP);
+         Path symbolPath = modelPath.resolve("mlp-symbol.json");
+         Path paramsPath = modelPath.resolve("mlp-0000.params");
+         Symbol symbol = Symbol.loadSymbol(base, symbolPath);
+         SymbolBlock block = new SymbolBlock(base, symbol);
+         Device device = Device.defaultIfNull();
+         NDList mxNDArray = JnaUtils.loadNdArray(base, paramsPath, Device.defaultIfNull(null));
+
+         // Index the block's parameters by name so the loaded arrays can be matched to them.
+         List<Parameter> parameters = block.getAllParameters();
+         Map<String, Parameter> map = new ConcurrentHashMap<>();
+         parameters.forEach(p -> map.put(p.getName(), p));
+         for (NDArray nd : mxNDArray) {
+             String key = nd.getName();
+             if (key == null) {
+                 throw new IllegalArgumentException(
+                         "Array names must be present in parameter file");
+             }
+             // Keep the text after the first ':'; a key without ':' is used unchanged.
+             String paramName = key.substring(key.indexOf(':') + 1);
+             Parameter parameter = map.remove(paramName);
+             Assert.assertNotNull(parameter, "No parameter in the block for array " + key);
+             parameter.setArray(nd);
+         }
+         block.setInputNames(new ArrayList<>(map.keySet()));
+
+         NDArray arr = NDArray.create(base, new Shape(1, 28, 28), device).ones();
+         block.forward(new NDList(arr), new PairList<>(), device);
+         logger.info(
+                 "Number of MxResource managed by baseMxResource: {}",
+                 BaseMxResource.getSystemMxResource().getSubResource().size());
+     } catch (IOException e) {
+         logger.error(e.getMessage(), e);
+         throw e;
+     }
+     Assert.assertEquals(BaseMxResource.getSystemMxResource().getSubResource().size(), 0);
+ }
+
+ @Test
+ public void createNdArray() {
+ try {
+ try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+ int[] originIntegerArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ float[] originFloatArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ double[] originDoubleArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ long[] originLongArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ boolean[] originBooleanArray = {
+ true, false, false, true, true, true, true, false, false, true, true, true
+ };
+ byte[] originByteArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ NDArray intArray = NDArray.create(base, originIntegerArray, new Shape(3, 4));
+ NDArray floatArray = NDArray.create(base, originFloatArray, new Shape(3, 4));
+ NDArray doubleArray = NDArray.create(base, originDoubleArray, new Shape(3, 4));
+ NDArray longArray = NDArray.create(base, originLongArray, new Shape(3, 4));
+ NDArray booleanArray = NDArray.create(base, originBooleanArray, new Shape(3, 4));
+ NDArray byteArray = NDArray.create(base, originByteArray, new Shape(3, 4));
+ NDArray intArray2 = NDArray.create(base, originIntegerArray);
+ NDArray floatArray2 = NDArray.create(base, originFloatArray);
+ NDArray doubleArray2 = NDArray.create(base, originDoubleArray);
+ NDArray longArray2 = NDArray.create(base, originLongArray);
+ NDArray booleanArray2 = NDArray.create(base, originBooleanArray);
+ NDArray byteArray2 = NDArray.create(base, originByteArray);
+
+ int[] ndArrayInt = intArray.toIntArray();
+ Assert.assertEquals(originIntegerArray, ndArrayInt);
+ // Float -> Double
+ float[] floats = floatArray.toFloatArray();
+ Assert.assertEquals(originFloatArray, floats);
+ double[] ndArrayDouble = doubleArray.toDoubleArray();
+ Assert.assertEquals(originDoubleArray, ndArrayDouble);
+ long[] ndArrayLong = longArray.toLongArray();
+ Assert.assertEquals(originLongArray, ndArrayLong);
+ boolean[] ndArrayBoolean = booleanArray.toBooleanArray();
+ Assert.assertEquals(originBooleanArray, ndArrayBoolean);
+ byte[] ndArrayByte = byteArray.toByteArray();
+ Assert.assertEquals(originByteArray, ndArrayByte);
+
+ int[] ndArrayInt2 = intArray2.toIntArray();
+ Assert.assertEquals(originIntegerArray, ndArrayInt2);
+
+ // Float -> Double
+ float[] floats2 = floatArray2.toFloatArray();
+ Assert.assertEquals(originFloatArray, floats2);
+ double[] ndArrayDouble2 = doubleArray2.toDoubleArray();
+ Assert.assertEquals(originDoubleArray, ndArrayDouble2);
+ long[] ndArrayLong2 = longArray2.toLongArray();
+ Assert.assertEquals(originLongArray, ndArrayLong2);
+ boolean[] ndArrayBoolean2 = booleanArray2.toBooleanArray();
+ Assert.assertEquals(originBooleanArray, ndArrayBoolean2);
+ byte[] ndArrayByte2 = byteArray2.toByteArray();
+ Assert.assertEquals(originByteArray, ndArrayByte2);
+ } catch (ClassCastException e) {
+ logger.error(e.getMessage());
+ throw e;
+ }
+ BaseMxResource base = BaseMxResource.getSystemMxResource();
+ int countNotReleased = 0;
+ for (MxResource mxResource : base.getSubResource().values()) {
+ if (!mxResource.getClosed()) {
+ ++countNotReleased;
+ }
+ }
+ Assert.assertEquals(countNotReleased, 0);
+ } catch (ClassCastException e) {
+ logger.error(e.getMessage());
+ throw e;
+ }
+ }
+
+ @Test
+ public void loadNdArray() throws IOException {
+ try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+ Path modelPath = Repository.initRepository(Item.MLP);
+ Path paramsPath = modelPath.resolve("mlp-0000.params");
+ NDList mxNDArray =
+ JnaUtils.loadNdArray(
+ base, Paths.get(paramsPath.toUri()), Device.defaultIfNull(null));
+ logger.info(mxNDArray.toString());
+ logger.info(
+ String.format(
+ "The amount of sub resources managed by BaseMxResource: %s",
+ base.getSubResource().size()));
+ } catch (IOException e) {
+ logger.error(e.getMessage());
+ throw e;
+ }
+ logger.info(
+ String.format(
+ "The amount of sub resources managed by BaseMxResource: %s",
+ BaseMxResource.getSystemMxResource().getSubResource().size()));
+ }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
new file mode 100644
index 0000000..15771f3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.util;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+
+/**
+ * Command-line settings for the integration test runner.
+ *
+ * <p>Values are read once from a parsed {@link CommandLine}; {@link #getOptions()} describes the
+ * options that are understood.
+ */
+public class Arguments {
+
+    private String methodName;
+    private String className;
+    private String packageName;
+    private int duration;
+    private int iteration = 1;
+
+    /**
+     * Reads the settings from the parsed command line.
+     *
+     * <p>The package name falls back to {@code "org.apache.mxnet.integration.tests."} when
+     * {@code package-name} is absent. {@code duration} stays 0 and {@code iteration} stays 1 when
+     * the corresponding option is absent.
+     *
+     * @param cmd the parsed command line
+     * @throws NumberFormatException if a supplied {@code duration} or {@code iteration} value is
+     *     not a parsable integer
+     */
+    public Arguments(CommandLine cmd) {
+        methodName = cmd.getOptionValue("method-name");
+        className = cmd.getOptionValue("class-name");
+        packageName =
+                cmd.hasOption("package-name")
+                        ? cmd.getOptionValue("package-name")
+                        : "org.apache.mxnet.integration.tests.";
+        duration = intOption(cmd, "duration", duration);
+        iteration = intOption(cmd, "iteration", iteration);
+    }
+
+    /** Returns the integer value of {@code name} when the option is present, else {@code fallback}. */
+    private static int intOption(CommandLine cmd, String name, int fallback) {
+        if (!cmd.hasOption(name)) {
+            return fallback;
+        }
+        return Integer.parseInt(cmd.getOptionValue(name));
+    }
+
+    /** Builds an option that takes exactly one argument. */
+    private static Option argOption(
+            String shortName, String longName, String argName, String description) {
+        return Option.builder(shortName)
+                .longOpt(longName)
+                .hasArg()
+                .argName(argName)
+                .desc(description)
+                .build();
+    }
+
+    /**
+     * Describes the supported command-line options.
+     *
+     * @return the options {@code -d/--duration}, {@code -n/--iteration}, {@code
+     *     -p/--package-name}, {@code -c/--class-name} and {@code -m/--method-name}, each taking
+     *     one argument
+     */
+    public static Options getOptions() {
+        Options supported = new Options();
+        supported.addOption(argOption("d", "duration", "DURATION", "Duration of the test."));
+        supported.addOption(
+                argOption("n", "iteration", "ITERATION", "Number of iterations in each test."));
+        supported.addOption(
+                argOption("p", "package-name", "PACKAGE-NAME", "Name of the package to run"));
+        supported.addOption(
+                argOption("c", "class-name", "CLASS-NAME", "Name of the class to run"));
+        supported.addOption(
+                argOption("m", "method-name", "METHOD-NAME", "Name of the method to run"));
+        return supported;
+    }
+
+    /** @return the value of {@code --duration}, or 0 when it was not given */
+    public int getDuration() {
+        return duration;
+    }
+
+    /** @return the value of {@code --iteration}, or 1 when it was not given */
+    public int getIteration() {
+        return iteration;
+    }
+
+    /** @return the value of {@code --package-name}, or the default test package prefix */
+    public String getPackageName() {
+        return packageName;
+    }
+
+    /** @return the value of {@code --class-name}, or {@code null} when it was not given */
+    public String getClassName() {
+        return className;
+    }
+
+    /** @return the value of {@code --method-name}, or {@code null} when it was not given */
+    public String getMethodName() {
+        return methodName;
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
new file mode 100644
index 0000000..b342b45
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.testng.Assert;
+
+/**
+ * Approximate-equality and in-place assertions for {@link NDArray}, {@link NDList} and plain
+ * doubles.
+ *
+ * <p>Two values are considered close when {@code |actual - expected| <= atol + rtol *
+ * |expected|}. NaN is only close to NaN, and an infinite value is only close to the same
+ * infinity.
+ */
+public final class Assertions {
+    private static final double RTOL = 1e-5;
+    private static final double ATOL = 1e-3;
+
+    private Assertions() {}
+
+    private static <T> String getDefaultErrorMessage(T actual, T expected) {
+        return getDefaultErrorMessage(actual, expected, null);
+    }
+
+    /** Formats "[errorMessage]\nExpected: ...\nActual: ..."; the prefix is omitted when null. */
+    private static <T> String getDefaultErrorMessage(T actual, T expected, String errorMessage) {
+        StringBuilder sb = new StringBuilder(100);
+        if (errorMessage != null) {
+            sb.append(errorMessage);
+        }
+        sb.append(System.lineSeparator())
+                .append("Expected: ")
+                .append(expected)
+                .append(System.lineSeparator())
+                .append("Actual: ")
+                .append(actual);
+        return sb.toString();
+    }
+
+    /**
+     * Decides whether two doubles are within tolerance.
+     *
+     * <p>The plain {@code Math.abs(a - b) > tol} test is false whenever NaN is involved (and when
+     * the tolerance itself becomes infinite), which used to let NaN-vs-number and
+     * finite-vs-infinite comparisons pass. Non-finite operands are therefore handled explicitly.
+     */
+    private static boolean isClose(double actual, double expected, double rtol, double atol) {
+        if (Double.isNaN(actual) || Double.isNaN(expected)) {
+            return Double.isNaN(actual) && Double.isNaN(expected);
+        }
+        if (Double.isInfinite(actual) || Double.isInfinite(expected)) {
+            return actual == expected;
+        }
+        // Same comparison as before for finite operands.
+        return !(Math.abs(actual - expected) > (atol + rtol * Math.abs(expected)));
+    }
+
+    /** Asserts closeness of two arrays using the default tolerances. */
+    public static void assertAlmostEquals(NDArray actual, NDArray expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /** Asserts closeness of two lists using the default tolerances. */
+    public static void assertAlmostEquals(NDList actual, NDList expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /** Asserts closeness of two doubles using the default tolerances. */
+    public static void assertAlmostEquals(double actual, double expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    /**
+     * Asserts that {@code actual} is within tolerance of {@code expected}.
+     *
+     * @param actual the value produced by the code under test
+     * @param expected the reference value
+     * @param rtol relative tolerance, scaled by {@code |expected|}
+     * @param atol absolute tolerance
+     * @throws AssertionError if the values are not close
+     */
+    public static void assertAlmostEquals(
+            double actual, double expected, double rtol, double atol) {
+        if (!isClose(actual, expected, rtol, atol)) {
+            throw new AssertionError(getDefaultErrorMessage(actual, expected));
+        }
+    }
+
+    /**
+     * Asserts that both lists have the same size and that their arrays are pairwise close.
+     *
+     * @param actual the list produced by the code under test
+     * @param expected the reference list
+     * @param rtol relative tolerance
+     * @param atol absolute tolerance
+     */
+    public static void assertAlmostEquals(
+            NDList actual, NDList expected, double rtol, double atol) {
+        Assert.assertEquals(
+                actual.size(),
+                expected.size(),
+                getDefaultErrorMessage(
+                        actual.size(), expected.size(), "The NDLists have different sizes"));
+        int size = actual.size();
+        for (int i = 0; i < size; i++) {
+            assertAlmostEquals(actual.get(i), expected.get(i), rtol, atol);
+        }
+    }
+
+    /**
+     * Asserts that both arrays have the same shape and element-wise close values.
+     *
+     * @param actual the array produced by the code under test
+     * @param expected the reference array
+     * @param rtol relative tolerance
+     * @param atol absolute tolerance
+     * @throws AssertionError if the shapes differ or any element pair is not close
+     */
+    public static void assertAlmostEquals(
+            NDArray actual, NDArray expected, double rtol, double atol) {
+        if (!actual.getShape().equals(expected.getShape())) {
+            throw new AssertionError(
+                    getDefaultErrorMessage(
+                            actual.getShape(),
+                            expected.getShape(),
+                            "The shape of two NDArray are different!"));
+        }
+        Number[] actualDoubleArray = actual.toArray();
+        Number[] expectedDoubleArray = expected.toArray();
+        for (int i = 0; i < actualDoubleArray.length; i++) {
+            double a = actualDoubleArray[i].doubleValue();
+            double b = expectedDoubleArray[i].doubleValue();
+            if (!isClose(a, b, rtol, atol)) {
+                throw new AssertionError("Expected:" + b + " but got " + a);
+            }
+        }
+    }
+
+    /**
+     * Asserts that {@code actual} equals {@code expected} and is the very same instance as
+     * {@code original}.
+     *
+     * @param actual the array returned by the in-place operation
+     * @param expected the reference result
+     * @param original the array the operation was applied to
+     */
+    public static void assertInPlaceEquals(NDArray actual, NDArray expected, NDArray original) {
+        Assert.assertEquals(
+                actual, expected, getDefaultErrorMessage(actual, expected, "Assert Equal failed!"));
+        // The identity check compares actual with original, so the message reports those two.
+        Assert.assertSame(
+                actual,
+                original,
+                getDefaultErrorMessage(actual, original, "Assert Inplace failed!"));
+    }
+
+    /** Same as the five-argument overload, using the default tolerances. */
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original) {
+        assertInPlaceAlmostEquals(actual, expected, original, RTOL, ATOL);
+    }
+
+    /**
+     * Asserts that {@code actual} is close to {@code expected} and is the very same instance as
+     * {@code original}.
+     *
+     * @param actual the array returned by the in-place operation
+     * @param expected the reference result
+     * @param original the array the operation was applied to
+     * @param rtol relative tolerance
+     * @param atol absolute tolerance
+     */
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original, double rtol, double atol) {
+        assertAlmostEquals(actual, expected, rtol, atol);
+        // The identity check compares actual with original, so the message reports those two.
+        Assert.assertSame(
+                actual,
+                original,
+                getDefaultErrorMessage(actual, original, "Assert Inplace failed!"));
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
new file mode 100644
index 0000000..e897c27
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+
+/**
+ * Reflection helpers that instantiate every class found next to a given class and invoke its
+ * simple accessors, so that trivial code is exercised by the test run.
+ */
+public final class CoverageUtils {
+
+    private CoverageUtils() {}
+
+    /**
+     * Instantiates each class located at the code source of {@code baseClass} and invokes its
+     * zero-argument {@code get*}/{@code is*}/{@code toString}/{@code hashCode} methods, its
+     * one-argument {@code set*}/{@code fromValue} methods and its {@code equals} methods.
+     *
+     * <p>Reflective failures of individual constructors or methods are ignored on purpose;
+     * classes that cannot be instantiated are skipped.
+     *
+     * @param baseClass a class whose code source (directory or jar) is scanned
+     * @throws IOException if the code source cannot be read
+     * @throws ReflectiveOperationException if the class path URLs cannot be obtained reflectively
+     * @throws URISyntaxException if the code source location is not a valid URI
+     */
+    public static void testGetterSetters(Class<?> baseClass)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        List<Class<?>> list = getClasses(baseClass);
+        for (Class<?> clazz : list) {
+            Object obj = null;
+            if (clazz.isEnum()) {
+                // An enum without constants cannot provide an instance; it is skipped below.
+                Object[] constants = clazz.getEnumConstants();
+                if (constants.length > 0) {
+                    obj = constants[0];
+                }
+            } else {
+                // Every public constructor is tried; the last successful instance is kept.
+                Constructor<?>[] constructors = clazz.getConstructors();
+                for (Constructor<?> con : constructors) {
+                    try {
+                        Class<?>[] types = con.getParameterTypes();
+                        Object[] args = new Object[types.length];
+                        for (int i = 0; i < args.length; ++i) {
+                            args[i] = getMockInstance(types[i], true);
+                        }
+                        con.setAccessible(true);
+                        obj = con.newInstance(args);
+                    } catch (ReflectiveOperationException ignore) {
+                        // ignore
+                    }
+                }
+            }
+            if (obj == null) {
+                continue;
+            }
+
+            Method[] methods = clazz.getDeclaredMethods();
+            for (Method method : methods) {
+                String methodName = method.getName();
+                int parameterCount = method.getParameterCount();
+                try {
+                    if (parameterCount == 0
+                            && (methodName.startsWith("get")
+                                    || methodName.startsWith("is")
+                                    || "toString".equals(methodName)
+                                    || "hashCode".equals(methodName))) {
+                        method.invoke(obj);
+                    } else if (parameterCount == 1
+                            && (methodName.startsWith("set") || "fromValue".equals(methodName))) {
+                        Class<?> type = method.getParameterTypes()[0];
+                        method.invoke(obj, getMockInstance(type, true));
+                    } else if (parameterCount == 1 && "equals".equals(methodName)) {
+                        // Only pass arguments the parameter type can accept: a mismatch raises
+                        // IllegalArgumentException, which is not a ReflectiveOperationException
+                        // and would abort the whole scan.
+                        Class<?> type = method.getParameterTypes()[0];
+                        if (type.isInstance(obj)) {
+                            method.invoke(obj, obj);
+                        }
+                        if (!type.isPrimitive()) {
+                            method.invoke(obj, (Object) null);
+                        }
+                        method.invoke(obj, getMockInstance(type, true));
+                    }
+                } catch (ReflectiveOperationException ignore) {
+                    // ignore
+                }
+            }
+        }
+    }
+
+    /**
+     * Loads every {@code .class} entry found at the code source of {@code clazz} (a directory or
+     * a jar) through a child-first class loader.
+     *
+     * <p>The class path URLs are read reflectively from the {@code ucp} field of the context
+     * class loader. NOTE(review): this relies on a JDK-internal field and may be rejected by JDKs
+     * that enforce strong encapsulation; confirm the supported JDK range.
+     *
+     * @return the loaded classes, or an empty list when the code source is unknown or is not a
+     *     {@code file:} URL
+     */
+    private static List<Class<?>> getClasses(Class<?> clazz)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader();
+        Field field = appClassLoader.getClass().getDeclaredField("ucp");
+        field.setAccessible(true);
+        Object ucp = field.get(appClassLoader);
+        Method method = ucp.getClass().getDeclaredMethod("getURLs");
+        URL[] urls = (URL[]) method.invoke(ucp);
+        ClassLoader cl = new TestClassLoader(urls, Thread.currentThread().getContextClassLoader());
+
+        // getCodeSource() is null for classes without a recorded origin.
+        if (clazz.getProtectionDomain().getCodeSource() == null) {
+            return Collections.emptyList();
+        }
+        URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+        if (url == null || !"file".equalsIgnoreCase(url.getProtocol())) {
+            return Collections.emptyList();
+        }
+        String path = url.getPath();
+
+        List<Class<?>> classList = new ArrayList<>();
+
+        Path classPath = Paths.get(url.toURI());
+        if (Files.isDirectory(classPath)) {
+            Collection<Path> files;
+            // Files.walk holds open directory handles until the stream is closed.
+            try (java.util.stream.Stream<Path> walk = Files.walk(classPath)) {
+                files =
+                        walk.filter(
+                                        p ->
+                                                Files.isRegularFile(p)
+                                                        && p.toString().endsWith(".class"))
+                                .collect(Collectors.toList());
+            }
+            for (Path file : files) {
+                Path p = classPath.relativize(file);
+                String className = p.toString();
+                className = className.substring(0, className.lastIndexOf('.'));
+                className = className.replace(File.separatorChar, '.');
+
+                try {
+                    classList.add(Class.forName(className, true, cl));
+                } catch (Error ignore) {
+                    // ignore
+                }
+            }
+        } else if (path.toLowerCase().endsWith(".jar")) {
+            try (JarFile jarFile = new JarFile(classPath.toFile())) {
+                Enumeration<JarEntry> en = jarFile.entries();
+                while (en.hasMoreElements()) {
+                    JarEntry entry = en.nextElement();
+                    String fileName = entry.getName();
+                    if (fileName.endsWith(".class")) {
+                        fileName = fileName.substring(0, fileName.lastIndexOf('.'));
+                        fileName = fileName.replace('/', '.');
+                        try {
+                            classList.add(Class.forName(fileName, true, cl));
+                        } catch (Error ignore) {
+                            // ignore
+                        }
+                    }
+                }
+            }
+        }
+
+        return classList;
+    }
+
+    /**
+     * Produces a placeholder argument for a parameter of type {@code clazz}.
+     *
+     * <p>Primitives get a fixed value, {@code String}/{@code List}/{@code Set}/{@code Map}
+     * compatible types get an empty instance, enums their first constant, interfaces a proxy
+     * whose methods return {@code null}. Other types are built through a public constructor when
+     * {@code useConstructor} is true.
+     *
+     * @param clazz the parameter type
+     * @param useConstructor whether public constructors may be tried; nested calls pass {@code
+     *     false} so that construction goes only one level deep
+     * @return the placeholder, or {@code null} when none could be produced
+     */
+    private static Object getMockInstance(Class<?> clazz, boolean useConstructor) {
+        if (clazz.isPrimitive()) {
+            if (clazz == Boolean.TYPE) {
+                return Boolean.TRUE;
+            }
+            if (clazz == Character.TYPE) {
+                return '0';
+            }
+            if (clazz == Byte.TYPE) {
+                return (byte) 0;
+            }
+            if (clazz == Short.TYPE) {
+                return (short) 0;
+            }
+            if (clazz == Integer.TYPE) {
+                return 0;
+            }
+            if (clazz == Long.TYPE) {
+                return 0L;
+            }
+            if (clazz == Float.TYPE) {
+                return 0f;
+            }
+            if (clazz == Double.TYPE) {
+                return 0d;
+            }
+        }
+
+        if (clazz.isAssignableFrom(String.class)) {
+            return "";
+        }
+
+        if (clazz.isAssignableFrom(List.class)) {
+            return new ArrayList<>();
+        }
+
+        if (clazz.isAssignableFrom(Set.class)) {
+            return new HashSet<>();
+        }
+
+        if (clazz.isAssignableFrom(Map.class)) {
+            return new HashMap<>();
+        }
+
+        if (clazz.isEnum()) {
+            // An enum without constants has no usable value.
+            Object[] constants = clazz.getEnumConstants();
+            return constants.length > 0 ? constants[0] : null;
+        }
+
+        if (clazz.isInterface()) {
+            return newProxyInstance(clazz);
+        }
+
+        if (useConstructor) {
+            Constructor<?>[] constructors = clazz.getConstructors();
+            for (Constructor<?> con : constructors) {
+                try {
+                    Class<?>[] types = con.getParameterTypes();
+                    Object[] args = new Object[types.length];
+                    for (int i = 0; i < args.length; ++i) {
+                        args[i] = getMockInstance(types[i], false);
+                    }
+                    con.setAccessible(true);
+                    return con.newInstance(args);
+                } catch (ReflectiveOperationException ignore) {
+                    // ignore
+                }
+            }
+        }
+
+        return null;
+    }
+
+    /** Creates a proxy for {@code clazz} whose every method returns {@code null}. */
+    @SuppressWarnings({"rawtypes", "PMD.UseProperClassLoader"})
+    private static Object newProxyInstance(Class<?> clazz) {
+        ClassLoader cl = clazz.getClassLoader();
+        return Proxy.newProxyInstance(cl, new Class[] {clazz}, (proxy, method, args) -> null);
+    }
+
+    /** Class loader that looks in its own URLs first and delegates to the parent afterwards. */
+    private static final class TestClassLoader extends URLClassLoader {
+
+        public TestClassLoader(URL[] urls, ClassLoader parent) {
+            super(urls, parent);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public Class<?> loadClass(String name) throws ClassNotFoundException {
+            synchronized (getClassLoadingLock(name)) {
+                // Defining the same class twice in one loader raises LinkageError, so reuse
+                // whatever this loader has already produced for the name.
+                Class<?> loaded = findLoadedClass(name);
+                if (loaded != null) {
+                    return loaded;
+                }
+                try {
+                    return findClass(name);
+                } catch (ClassNotFoundException e) {
+                    ClassLoader classLoader = getParent();
+                    if (classLoader == null) {
+                        classLoader = getSystemClassLoader();
+                    }
+                    return classLoader.loadClass(name);
+                }
+            }
+        }
+    }
+}
diff --git a/java-package/integration/src/main/resources/log4j2.xml b/java-package/integration/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..ff05a01
--- /dev/null
+++ b/java-package/integration/src/main/resources/log4j2.xml
@@ -0,0 +1,34 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<Configuration status="INFO">
+    <Appenders>
+        <!-- Single console appender printing "[LEVEL] - message". -->
+        <Console name="console" target="SYSTEM_OUT">
+            <PatternLayout
+                    pattern="[%-5level] - %msg%n"/>
+        </Console>
+    </Appenders>
+    <Loggers>
+        <Root level="info" additivity="false">
+            <AppenderRef ref="console"/>
+        </Root>
+        <!--
+          ~ Logger names are matched against package prefixes. The project's packages start with
+          ~ org.apache.mxnet, so the logger must use that prefix to take effect (it was declared
+          ~ as com.apache.mxnet, which matches no class). The system property that overrides the
+          ~ level keeps its existing name so that current command lines continue to work.
+          -->
+        <Logger name="org.apache.mxnet" level="${sys:com.apache.mxnet.level:-debug}" additivity="false">
+            <AppenderRef ref="console"/>
+        </Logger>
+    </Loggers>
+</Configuration>
diff --git a/java-package/jnarator/build.gradle b/java-package/jnarator/build.gradle
new file mode 100644
index 0000000..abc7cc0
--- /dev/null
+++ b/java-package/jnarator/build.gradle
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Build script of the jnarator module: applies the ANTLR plugin, declares the module's
+// dependencies, adjusts the static-analysis tasks and packages a runnable jar.
+plugins {
+ id 'antlr'
+}
+
+dependencies {
+ // ANTLR tool used by the 'antlr' configuration of the plugin applied above.
+ antlr "org.antlr:antlr4:${antlr_version}"
+
+ api "commons-cli:commons-cli:${commons_cli_version}"
+ api "org.antlr:antlr4-runtime:${antlr_version}"
+ api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+
+ // TestNG without its junit dependency.
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ testRuntimeOnly project(":mxnet-engine")
+// testRuntimeOnly ":mxnet-native-auto:${mxnet_version}"
+}
+
+// Restrict checkstyle to the hand-written sources directory.
+checkstyleMain.source = 'src/main/java'
+
+checkstyleMain {
+ // skip check style for this package
+ exclude 'org/apache/mxnet/jnarator/**'
+}
+
+// Restrict PMD to the hand-written sources directory.
+pmdMain.source = 'src/main/java'
+//pmdMain.ignoreFailures(true)
+// SpotBugs findings are reported but do not fail the build.
+spotbugs.ignoreFailures = true
+
+jar {
+ manifest {
+ attributes (
+ "Main-Class" : "org.apache.mxnet.jnarator.Main",
+ "Multi-Release" : true
+ )
+ }
+ includeEmptyDirs = false
+ // Entries with the same path coming from different inputs are all written to the jar.
+ duplicatesStrategy = DuplicatesStrategy.INCLUDE
+ // Bundle the runtime classpath into this jar: directories are added as they are,
+ // archives are expanded with zipTree.
+ from configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
+}
diff --git a/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4 b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
new file mode 100644
index 0000000..6206c13
--- /dev/null
+++ b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
@@ -0,0 +1,923 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Combined ANTLR grammar: the parser rules and the lexer rules for C are both defined in this file.
+grammar C;
+
+// Code placed at the top of the generated parser: puts it in the jnarator parser package.
+@parser::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+// Code placed at the top of the generated lexer: same package as the parser.
+@lexer::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+
+// Atoms of an expression: an identifier, a constant, one or more adjacent string literals,
+// a parenthesized expression, a _Generic selection, a parenthesized compound statement
+// (optionally prefixed with __extension__), or one of the two __builtin_* forms below.
+primaryExpression
+ : Identifier
+ | Constant
+ | StringLiteral+
+ | '(' expression ')'
+ | genericSelection
+ | '__extension__'? '(' compoundStatement ')' // Blocks (GCC extension)
+ | '__builtin_va_arg' '(' unaryExpression ',' typeName ')'
+ | '__builtin_offsetof' '(' typeName ',' unaryExpression ')'
+ ;
+
+genericSelection
+ : '_Generic' '(' assignmentExpression ',' genericAssocList ')'
+ ;
+
+genericAssocList
+ : genericAssociation
+ | genericAssocList ',' genericAssociation
+ ;
+
+genericAssociation
+ : typeName ':' assignmentExpression
+ | 'default' ':' assignmentExpression
+ ;
+
+postfixExpression
+ : primaryExpression
+ | postfixExpression '[' expression ']'
+ | postfixExpression '(' argumentExpressionList? ')'
+ | postfixExpression '.' Identifier
+ | postfixExpression '->' Identifier
+ | postfixExpression '++'
+ | postfixExpression '--'
+ | '(' typeName ')' '{' initializerList '}'
+ | '(' typeName ')' '{' initializerList ',' '}'
+ | '__extension__' '(' typeName ')' '{' initializerList '}'
+ | '__extension__' '(' typeName ')' '{' initializerList ',' '}'
+ ;
+
+argumentExpressionList
+ : assignmentExpression
+ | argumentExpressionList ',' assignmentExpression
+ ;
+
+unaryExpression
+ : postfixExpression
+ | '++' unaryExpression
+ | '--' unaryExpression
+ | unaryOperator castExpression
+ | 'sizeof' unaryExpression
+ | 'sizeof' '(' typeName ')'
+ | '_Alignof' '(' typeName ')'
+ | '&&' Identifier // GCC extension address of label
+ ;
+
+unaryOperator
+ : '&' | '*' | '+' | '-' | '~' | '!'
+ ;
+
+// A (possibly nested) cast, optionally prefixed with __extension__, or a plain unaryExpression.
+castExpression
+ : '(' typeName ')' castExpression
+ | '__extension__' '(' typeName ')' castExpression
+ | unaryExpression
+ | DigitSequence // NOTE(review): the original comment here was only "for" (truncated); why a bare DigitSequence token is accepted is unconfirmed
+ ;
+
+multiplicativeExpression
+ : castExpression
+ | multiplicativeExpression '*' castExpression
+ | multiplicativeExpression '/' castExpression
+ | multiplicativeExpression '%' castExpression
+ ;
+
+additiveExpression
+ : multiplicativeExpression
+ | additiveExpression '+' multiplicativeExpression
+ | additiveExpression '-' multiplicativeExpression
+ ;
+
+shiftExpression
+ : additiveExpression
+ | shiftExpression '<<' additiveExpression
+ | shiftExpression '>>' additiveExpression
+ ;
+
+relationalExpression
+ : shiftExpression
+ | relationalExpression '<' shiftExpression
+ | relationalExpression '>' shiftExpression
+ | relationalExpression '<=' shiftExpression
+ | relationalExpression '>=' shiftExpression
+ ;
+
+equalityExpression
+ : relationalExpression
+ | equalityExpression '==' relationalExpression
+ | equalityExpression '!=' relationalExpression
+ ;
+
+andExpression
+ : equalityExpression
+ | andExpression '&' equalityExpression
+ ;
+
+exclusiveOrExpression
+ : andExpression
+ | exclusiveOrExpression '^' andExpression
+ ;
+
+inclusiveOrExpression
+ : exclusiveOrExpression
+ | inclusiveOrExpression '|' exclusiveOrExpression
+ ;
+
+logicalAndExpression
+ : inclusiveOrExpression
+ | logicalAndExpression '&&' inclusiveOrExpression
+ ;
+
+logicalOrExpression
+ : logicalAndExpression
+ | logicalOrExpression '||' logicalAndExpression
+ ;
+
+conditionalExpression
+ : logicalOrExpression ('?' expression ':' conditionalExpression)?
+ ;
+
+// Either a conditionalExpression, or "unaryExpression <assignmentOperator> assignmentExpression"
+// (the recursion on the right makes chained assignments group to the right).
+assignmentExpression
+ : conditionalExpression
+ | unaryExpression assignmentOperator assignmentExpression
+ | DigitSequence // NOTE(review): the original comment here was only "for" (truncated); why a bare DigitSequence token is accepted is unconfirmed
+ ;
+
+assignmentOperator
+ : '=' | '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+ ;
+
+expression
+ : assignmentExpression
+ | expression ',' assignmentExpression
+ ;
+
+constantExpression
+ : conditionalExpression
+ ;
+
+declaration
+ : declarationSpecifiers initDeclaratorList ';'
+ | declarationSpecifiers ';'
+ | staticAssertDeclaration
+ ;
+
+declarationSpecifiers
+ : declarationSpecifier+
+ ;
+
+// Same body as declarationSpecifiers. In this grammar it is referenced by the
+// parameterDeclaration alternative that takes an optional abstractDeclarator.
+declarationSpecifiers2
+ : declarationSpecifier+
+ ;
+
+declarationSpecifier
+ : storageClassSpecifier
+ | typeSpecifier
+ | typeQualifier
+ | functionSpecifier
+ | alignmentSpecifier
+ ;
+
+initDeclaratorList
+ : initDeclarator
+ | initDeclaratorList ',' initDeclarator
+ ;
+
+initDeclarator
+ : declarator
+ | declarator '=' initializer
+ ;
+
+storageClassSpecifier
+ : 'typedef'
+ | 'extern'
+ | 'static'
+ | '_Thread_local'
+ | 'auto'
+ | 'register'
+ ;
+
+// A single type specifier: a built-in type keyword (including the __m128, __m128d and __m128i
+// names, also accepted in the form __extension__(...)), an _Atomic(...) specifier, a
+// struct/union or enum specifier, a typedef name, or a __typeof__(...) form.
+// The last alternative is left-recursive and lets a pointer follow a type specifier, so the
+// pointer becomes part of the specifier itself.
+typeSpecifier
+ : ('void'
+ | 'char'
+ | 'short'
+ | 'int'
+ | 'long'
+ | 'float'
+ | 'double'
+ | 'signed'
+ | 'unsigned'
+ | '_Bool'
+ | '_Complex'
+ | '__m128'
+ | '__m128d'
+ | '__m128i')
+ | '__extension__' '(' ('__m128' | '__m128d' | '__m128i') ')'
+ | atomicTypeSpecifier
+ | structOrUnionSpecifier
+ | enumSpecifier
+ | typedefName
+ | '__typeof__' '(' constantExpression ')' // GCC extension
+ | typeSpecifier pointer
+ ;
+
+structOrUnionSpecifier
+ : structOrUnion Identifier? '{' structDeclarationList '}'
+ | structOrUnion Identifier
+ ;
+
+structOrUnion
+ : 'struct'
+ | 'union'
+ ;
+
+structDeclarationList
+ : structDeclaration
+ | structDeclarationList structDeclaration
+ ;
+
+structDeclaration
+ : specifierQualifierList structDeclaratorList? ';'
+ | staticAssertDeclaration
+ ;
+
+specifierQualifierList
+ : typeSpecifier specifierQualifierList?
+ | typeQualifier specifierQualifierList?
+ ;
+
+structDeclaratorList
+ : structDeclarator
+ | structDeclaratorList ',' structDeclarator
+ ;
+
+structDeclarator
+ : declarator
+ | declarator? ':' constantExpression
+ ;
+
+enumSpecifier
+ : 'enum' Identifier? '{' enumeratorList '}'
+ | 'enum' Identifier? '{' enumeratorList ',' '}'
+ | 'enum' Identifier
+ ;
+
+enumeratorList
+ : enumerator
+ | enumeratorList ',' enumerator
+ ;
+
+enumerator
+ : enumerationConstant
+ | enumerationConstant '=' constantExpression
+ ;
+
+enumerationConstant
+ : Identifier
+ ;
+
+atomicTypeSpecifier
+ : '_Atomic' '(' typeName ')'
+ ;
+
+typeQualifier
+ : 'const'
+ | 'restrict'
+ | 'volatile'
+ | '_Atomic'
+ ;
+
+// Function specifiers: the keywords listed in the first group, a GCC __attribute__((...))
+// specifier, or __declspec(Identifier) with exactly one identifier argument.
+functionSpecifier
+ : ('inline'
+ | '_Noreturn'
+ | '__inline__' // GCC extension
+ | '__stdcall')
+ | gccAttributeSpecifier
+ | '__declspec' '(' Identifier ')'
+ ;
+
+alignmentSpecifier
+ : '_Alignas' '(' typeName ')'
+ | '_Alignas' '(' constantExpression ')'
+ ;
+
+declarator
+ : pointer? directDeclarator gccDeclaratorExtension*
+ ;
+
+// The part of a declarator without a leading pointer: a name or a parenthesized declarator,
+// followed (through left recursion) by any number of array suffixes "[...]" or parameter
+// lists "(...)". The last two alternatives additionally accept "name : digits" and a
+// parenthesized "[typeSpecifier] pointer directDeclarator" form.
+directDeclarator
+ : Identifier
+ | '(' declarator ')'
+ | directDeclarator '[' typeQualifierList? assignmentExpression? ']'
+ | directDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+ | directDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+ | directDeclarator '[' typeQualifierList? '*' ']'
+ | directDeclarator '(' parameterTypeList ')'
+ | directDeclarator '(' identifierList? ')'
+ | Identifier ':' DigitSequence // bit field
+ | '(' typeSpecifier? pointer directDeclarator ')' // function pointer like: (__cdecl *f)
+ ;
+
+gccDeclaratorExtension
+ : '__asm' '(' StringLiteral+ ')'
+ | gccAttributeSpecifier
+ ;
+
+gccAttributeSpecifier
+ : '__attribute__' '(' '(' gccAttributeList ')' ')'
+ ;
+
+gccAttributeList
+ : gccAttribute (',' gccAttribute)*
+ | // empty
+ ;
+
+// One attribute inside __attribute__((...)): any single token other than ',', '(' or ')',
+// optionally followed by a parenthesized argument list. The second alternative matches nothing,
+// so an attribute may also be empty.
+gccAttribute
+ : ~(',' | '(' | ')') // relaxed def for "identifier or reserved word"
+ ('(' argumentExpressionList? ')')?
+ | // empty
+ ;
+
+nestedParenthesesBlock
+ : ( ~('(' | ')')
+ | '(' nestedParenthesesBlock ')'
+ )*
+ ;
+
+pointer
+ : '*' typeQualifierList?
+ | '*' typeQualifierList? pointer
+ | '^' typeQualifierList? // Blocks language extension
+ | '^' typeQualifierList? pointer // Blocks language extension
+ ;
+
+typeQualifierList
+ : typeQualifier
+ | typeQualifierList typeQualifier
+ ;
+
+parameterTypeList
+ : parameterList
+ | parameterList ',' '...'
+ ;
+
+parameterList
+ : parameterDeclaration
+ | parameterList ',' parameterDeclaration
+ ;
+
+// A parameter is either specifiers followed by a (named) declarator, or specifiers followed by
+// an optional abstractDeclarator. The second alternative uses declarationSpecifiers2, whose
+// body is identical to declarationSpecifiers.
+parameterDeclaration
+ : declarationSpecifiers declarator
+ | declarationSpecifiers2 abstractDeclarator?
+ ;
+
+identifierList
+ : Identifier
+ | identifierList ',' Identifier
+ ;
+
+typeName
+ : specifierQualifierList abstractDeclarator?
+ ;
+
+abstractDeclarator
+ : pointer
+ | pointer? directAbstractDeclarator gccDeclaratorExtension*
+ ;
+
+directAbstractDeclarator
+ : '(' abstractDeclarator ')' gccDeclaratorExtension*
+ | '[' typeQualifierList? assignmentExpression? ']'
+ | '[' 'static' typeQualifierList? assignmentExpression ']'
+ | '[' typeQualifierList 'static' assignmentExpression ']'
+ | '[' '*' ']'
+ | '(' parameterTypeList? ')' gccDeclaratorExtension*
+ | directAbstractDeclarator '[' typeQualifierList? assignmentExpression? ']'
+ | directAbstractDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+ | directAbstractDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+ | directAbstractDeclarator '[' '*' ']'
+ | directAbstractDeclarator '(' parameterTypeList? ')' gccDeclaratorExtension*
+ ;
+
+typedefName
+ : Identifier
+ ;
+
+initializer
+ : assignmentExpression
+ | '{' initializerList '}'
+ | '{' initializerList ',' '}'
+ ;
+
+initializerList
+ : designation? initializer
+ | initializerList ',' designation? initializer
+ ;
+
+designation
+ : designatorList '='
+ ;
+
+designatorList
+ : designator
+ | designatorList designator
+ ;
+
+designator
+ : '[' constantExpression ']'
+ | '.' Identifier
+ ;
+
+staticAssertDeclaration
+ : '_Static_assert' '(' constantExpression ',' StringLiteral+ ')' ';'
+ ;
+
+// Any statement. The last alternative is an inline-assembly statement: "__asm" or "__asm__",
+// a mandatory "volatile" or "__volatile__", then a parenthesized body made of comma-separated
+// logicalOrExpression lists, each further list introduced by ':', terminated by ';'.
+statement
+ : labeledStatement
+ | compoundStatement
+ | expressionStatement
+ | selectionStatement
+ | iterationStatement
+ | jumpStatement
+ | ('__asm' | '__asm__') ('volatile' | '__volatile__') '(' (logicalOrExpression (',' logicalOrExpression)*)? (':' (logicalOrExpression (',' logicalOrExpression)*)?)* ')' ';'
+ ;
+
+labeledStatement
+ : Identifier ':' statement
+ | 'case' constantExpression ':' statement
+ | 'default' ':' statement
+ ;
+
+compoundStatement
+ : '{' blockItemList? '}'
+ ;
+
+blockItemList
+ : blockItem
+ | blockItemList blockItem
+ ;
+
+blockItem
+ : statement
+ | declaration
+ ;
+
+expressionStatement
+ : expression? ';'
+ ;
+
+selectionStatement
+ : 'if' '(' expression ')' statement ('else' statement)?
+ | 'switch' '(' expression ')' statement
+ ;
+
+iterationStatement
+ : While '(' expression ')' statement
+ | Do statement While '(' expression ')' ';'
+ | For '(' forCondition ')' statement
+ ;
+
+// | 'for' '(' expression? ';' expression? ';' forUpdate? ')' statement
+// | For '(' declaration expression? ';' expression? ')' statement
+
+// The three ';'-separated parts between the parentheses of a for statement. The first part is
+// either a declaration (forDeclaration) or an optional expression; the second and third parts
+// are optional comma-separated expression lists (forExpression).
+forCondition
+ : forDeclaration ';' forExpression? ';' forExpression?
+ | expression? ';' forExpression? ';' forExpression?
+ ;
+
+forDeclaration
+ : declarationSpecifiers initDeclaratorList
+ | declarationSpecifiers
+ ;
+
+forExpression
+ : assignmentExpression
+ | forExpression ',' assignmentExpression
+ ;
+
+jumpStatement
+ : 'goto' Identifier ';'
+ | 'continue' ';'
+ | 'break' ';'
+ | 'return' expression? ';'
+ | 'goto' unaryExpression ';' // GCC extension
+ ;
+
+// A whole input: an optional translationUnit terminated by end of input or by a '}' token.
+// NOTE(review): accepting '}' presumably tolerates an unmatched closing brace left over from a
+// wrapper such as extern "C" { ... } -- confirm against the headers being parsed.
+compilationUnit
+ : translationUnit? ( EOF | '}' )
+ ;
+
+translationUnit
+ : externalDeclaration
+ | translationUnit externalDeclaration
+ ;
+
+externalDeclaration
+ : functionDefinition
+ | declaration
+ | ';' // stray ;
+ ;
+
+// A function definition: optional specifiers, the declarator, an optional list of declarations
+// between the declarator and the body (presumably old-style K&R parameter declarations --
+// TODO confirm), and the body as a compoundStatement.
+functionDefinition
+ : declarationSpecifiers? declarator declarationList? compoundStatement
+ ;
+
+declarationList
+ : declaration
+ | declarationList declaration
+ ;
+
+// Keyword tokens. They are declared before Identifier, so a keyword spelling
+// is tokenized as the keyword when both rules match the same text.
+Auto : 'auto';
+Break : 'break';
+Case : 'case';
+Char : 'char';
+Const : 'const';
+Continue : 'continue';
+Default : 'default';
+Do : 'do';
+Double : 'double';
+Else : 'else';
+Enum : 'enum';
+Extern : 'extern';
+Float : 'float';
+For : 'for';
+Goto : 'goto';
+If : 'if';
+Inline : 'inline';
+Int : 'int';
+Long : 'long';
+Register : 'register';
+Restrict : 'restrict';
+Return : 'return';
+Short : 'short';
+Signed : 'signed';
+Sizeof : 'sizeof';
+Static : 'static';
+Struct : 'struct';
+Switch : 'switch';
+Typedef : 'typedef';
+Union : 'union';
+Unsigned : 'unsigned';
+Void : 'void';
+Volatile : 'volatile';
+While : 'while';
+
+// Keywords spelled with a leading underscore followed by a capital letter.
+Alignas : '_Alignas';
+Alignof : '_Alignof';
+Atomic : '_Atomic';
+Bool : '_Bool';
+Complex : '_Complex';
+Generic : '_Generic';
+Imaginary : '_Imaginary';
+Noreturn : '_Noreturn';
+StaticAssert : '_Static_assert';
+ThreadLocal : '_Thread_local';
+
+// Bracketing tokens.
+LeftParen : '(';
+RightParen : ')';
+LeftBracket : '[';
+RightBracket : ']';
+LeftBrace : '{';
+RightBrace : '}';
+
+// Relational and shift operators.
+Less : '<';
+LessEqual : '<=';
+Greater : '>';
+GreaterEqual : '>=';
+LeftShift : '<<';
+RightShift : '>>';
+
+// Arithmetic operators, including increment and decrement.
+Plus : '+';
+PlusPlus : '++';
+Minus : '-';
+MinusMinus : '--';
+Star : '*';
+Div : '/';
+Mod : '%';
+
+// Bitwise and logical operators.
+And : '&';
+Or : '|';
+AndAnd : '&&';
+OrOr : '||';
+Caret : '^';
+Not : '!';
+Tilde : '~';
+
+// Conditional-operator parts and separators.
+Question : '?';
+Colon : ':';
+Semi : ';';
+Comma : ',';
+
+// Simple and compound assignment operators.
+Assign : '=';
+// '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+StarAssign : '*=';
+DivAssign : '/=';
+ModAssign : '%=';
+PlusAssign : '+=';
+MinusAssign : '-=';
+LeftShiftAssign : '<<=';
+RightShiftAssign : '>>=';
+AndAssign : '&=';
+XorAssign : '^=';
+OrAssign : '|=';
+
+// Equality operators.
+Equal : '==';
+NotEqual : '!=';
+
+// Member access and the variadic-parameter marker.
+Arrow : '->';
+Dot : '.';
+Ellipsis : '...';
+
+// An identifier: a non-digit start character followed by any number of
+// non-digit characters or digits.
+Identifier
+ : IdentifierNondigit
+ ( IdentifierNondigit
+ | Digit
+ )*
+ ;
+
+// A character that may start an identifier: a letter, an underscore, or a
+// universal character name.
+fragment
+IdentifierNondigit
+ : Nondigit
+ | UniversalCharacterName
+ //| // other implementation-defined characters...
+ ;
+
+// ASCII letter or underscore.
+fragment
+Nondigit
+ : [a-zA-Z_]
+ ;
+
+// Decimal digit.
+fragment
+Digit
+ : [0-9]
+ ;
+
+// \u followed by four hex digits, or \U followed by eight hex digits.
+fragment
+UniversalCharacterName
+ : '\\u' HexQuad
+ | '\\U' HexQuad HexQuad
+ ;
+
+// Exactly four hexadecimal digits.
+fragment
+HexQuad
+ : HexadecimalDigit HexadecimalDigit HexadecimalDigit HexadecimalDigit
+ ;
+
+// A constant token: integer, floating-point or character constant. The
+// enumeration-constant alternative is commented out.
+Constant
+ : IntegerConstant
+ | FloatingConstant
+ //| EnumerationConstant
+ | CharacterConstant
+ ;
+
+// Decimal, octal or hexadecimal constant with an optional suffix, or a binary
+// constant (which takes no suffix in this grammar).
+fragment
+IntegerConstant
+ : DecimalConstant IntegerSuffix?
+ | OctalConstant IntegerSuffix?
+ | HexadecimalConstant IntegerSuffix?
+ | BinaryConstant
+ ;
+
+// '0b' or '0B' followed by one or more binary digits.
+fragment
+BinaryConstant
+ : '0' [bB] [0-1]+
+ ;
+
+// A non-zero digit followed by any number of digits.
+fragment
+DecimalConstant
+ : NonzeroDigit Digit*
+ ;
+
+// '0' followed by any number of octal digits; a lone '0' matches here.
+fragment
+OctalConstant
+ : '0' OctalDigit*
+ ;
+
+// Hexadecimal prefix followed by one or more hexadecimal digits.
+fragment
+HexadecimalConstant
+ : HexadecimalPrefix HexadecimalDigit+
+ ;
+
+// '0x' or '0X'.
+fragment
+HexadecimalPrefix
+ : '0' [xX]
+ ;
+
+fragment
+NonzeroDigit
+ : [1-9]
+ ;
+
+fragment
+OctalDigit
+ : [0-7]
+ ;
+
+fragment
+HexadecimalDigit
+ : [0-9a-fA-F]
+ ;
+
+// Unsigned and/or long / long-long suffix, in either order.
+fragment
+IntegerSuffix
+ : UnsignedSuffix LongSuffix?
+ | UnsignedSuffix LongLongSuffix
+ | LongSuffix UnsignedSuffix?
+ | LongLongSuffix UnsignedSuffix?
+ ;
+
+fragment
+UnsignedSuffix
+ : [uU]
+ ;
+
+fragment
+LongSuffix
+ : [lL]
+ ;
+
+// Both letters must have the same case ('lL' and 'Ll' are not accepted).
+fragment
+LongLongSuffix
+ : 'll' | 'LL'
+ ;
+
+// Floating-point constant in decimal or hexadecimal notation.
+fragment
+FloatingConstant
+ : DecimalFloatingConstant
+ | HexadecimalFloatingConstant
+ ;
+
+// Either a fractional constant with an optional exponent, or a digit sequence
+// with a mandatory exponent; both take an optional suffix.
+fragment
+DecimalFloatingConstant
+ : FractionalConstant ExponentPart? FloatingSuffix?
+ | DigitSequence ExponentPart FloatingSuffix?
+ ;
+
+// Hexadecimal floating constant; the binary exponent part is mandatory.
+fragment
+HexadecimalFloatingConstant
+ : HexadecimalPrefix HexadecimalFractionalConstant BinaryExponentPart FloatingSuffix?
+ | HexadecimalPrefix HexadecimalDigitSequence BinaryExponentPart FloatingSuffix?
+ ;
+
+// Digits around a '.', where the digits on one side may be omitted.
+fragment
+FractionalConstant
+ : DigitSequence? '.' DigitSequence
+ | DigitSequence '.'
+ ;
+
+// 'e' or 'E', an optional sign, then decimal digits.
+fragment
+ExponentPart
+ : 'e' Sign? DigitSequence
+ | 'E' Sign? DigitSequence
+ ;
+
+fragment
+Sign
+ : '+' | '-'
+ ;
+
+// One or more decimal digits. Unlike its neighbours this rule is not marked
+// `fragment`, so it is also a token type of its own.
+// NOTE(review): presumably a parser rule refers to DigitSequence directly --
+// confirm before adding `fragment`.
+DigitSequence
+ : Digit+
+ ;
+
+// Hexadecimal digits around a '.', where one side may be omitted.
+fragment
+HexadecimalFractionalConstant
+ : HexadecimalDigitSequence? '.' HexadecimalDigitSequence
+ | HexadecimalDigitSequence '.'
+ ;
+
+// 'p' or 'P', an optional sign, then decimal digits.
+fragment
+BinaryExponentPart
+ : 'p' Sign? DigitSequence
+ | 'P' Sign? DigitSequence
+ ;
+
+fragment
+HexadecimalDigitSequence
+ : HexadecimalDigit+
+ ;
+
+fragment
+FloatingSuffix
+ : 'f' | 'l' | 'F' | 'L'
+ ;
+
+// Character constant in single quotes, optionally prefixed with L, u or U.
+fragment
+CharacterConstant
+ : '\'' CCharSequence '\''
+ | 'L\'' CCharSequence '\''
+ | 'u\'' CCharSequence '\''
+ | 'U\'' CCharSequence '\''
+ ;
+
+// One or more characters of a character constant.
+fragment
+CCharSequence
+ : CChar+
+ ;
+
+// Any character except a single quote, backslash, CR or LF, or an escape
+// sequence.
+fragment
+CChar
+ : ~['\\\r\n]
+ | EscapeSequence
+ ;
+
+// Any of the supported backslash escapes.
+fragment
+EscapeSequence
+ : SimpleEscapeSequence
+ | OctalEscapeSequence
+ | HexadecimalEscapeSequence
+ | UniversalCharacterName
+ ;
+
+// Backslash followed by one of: ' " ? a b f n r t v \
+fragment
+SimpleEscapeSequence
+ : '\\' ['"?abfnrtv\\]
+ ;
+
+// Backslash followed by one to three octal digits.
+fragment
+OctalEscapeSequence
+ : '\\' OctalDigit
+ | '\\' OctalDigit OctalDigit
+ | '\\' OctalDigit OctalDigit OctalDigit
+ ;
+
+// '\x' followed by one or more hexadecimal digits.
+fragment
+HexadecimalEscapeSequence
+ : '\\x' HexadecimalDigit+
+ ;
+
+// Double-quoted string literal with an optional encoding prefix; the content
+// may be empty.
+StringLiteral
+ : EncodingPrefix? '"' SCharSequence? '"'
+ ;
+
+// Prefixes accepted before a string literal.
+fragment
+EncodingPrefix
+ : 'u8'
+ | 'u'
+ | 'U'
+ | 'L'
+ ;
+
+// One or more characters of a string literal.
+fragment
+SCharSequence
+ : SChar+
+ ;
+
+// Any character except a double quote, backslash, CR or LF; an escape
+// sequence; or a backslash immediately followed by a line break (LF or CRLF),
+// which lets a literal continue on the next line.
+fragment
+SChar
+ : ~["\\\r\n]
+ | EscapeSequence
+ | '\\\n' // Added line
+ | '\\\r\n' // Added line
+ ;
+
+// Everything from '#' to the end of the line is discarded.
+LineAfterPreprocessing
+ : '#' ~[\r\n]*
+ -> skip
+ ;
+
+// Discarded input: spaces and tabs, plus several literal texts that are
+// dropped as if they were whitespace -- 'MXNET_DLL', 'NNVM_DLL',
+// 'extern "C" {' (spelled exactly like this) and a 'DEFAULT(...)' form, which
+// is matched non-greedily up to the first ')'.
+Whitespace
+ : ( [ \t]
+ | 'MXNET_DLL'
+ | 'NNVM_DLL'
+ | 'extern "C" {'
+ | 'DEFAULT' Whitespace? '(' .*? ')'
+ )+
+ -> skip
+ ;
+
+// Line breaks (CR, CRLF or LF) are discarded.
+Newline
+ : ( '\r' '\n'?
+ | '\n'
+ )
+ -> skip
+ ;
+
+// /* ... */ comments are discarded; matching is non-greedy, so the comment
+// ends at the first '*/'.
+BlockComment
+ : '/*' .*? '*/'
+ -> skip
+ ;
+
+// '//' comments are discarded up to, but not including, the line break.
+LineComment
+ : '//' ~[\r\n]*
+ -> skip
+ ;
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
new file mode 100644
index 0000000..3a2b5e0
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/** Static helpers for inspecting the parse tree produced by the generated C parser. */
+public final class AntlrUtils {
+
+    private AntlrUtils() {}
+
+    /**
+     * Tells whether the first declaration specifier is the {@code typedef} storage class.
+     *
+     * @param specs the declaration specifiers to inspect
+     * @return {@code true} if the first specifier is a {@code typedef} storage class specifier
+     */
+    public static boolean isTypeDef(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.StorageClassSpecifierContext storage = spec.storageClassSpecifier();
+        return storage != null && storage.Typedef() != null;
+    }
+
+    /**
+     * Joins the text of every child except the first one with single spaces.
+     *
+     * @param specs the declaration specifiers of a typedef declaration
+     * @return the space-separated text of the children at index 1 and above
+     */
+    public static String getTypeDefValue(CParser.DeclarationSpecifiersContext specs) {
+        List<String> list = new ArrayList<>();
+        for (int i = 1; i < specs.getChildCount(); ++i) {
+            list.add(specs.getChild(i).getText());
+        }
+        return String.join(" ", list);
+    }
+
+    /**
+     * Tells whether the first declaration specifier is an enum specifier.
+     *
+     * @param specs the declaration specifiers to inspect
+     * @return {@code true} if the first specifier is a type specifier holding an enum specifier
+     */
+    public static boolean isEnum(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        return type != null && type.enumSpecifier() != null;
+    }
+
+    /**
+     * Tells whether the first declaration specifier is a struct or union specifier.
+     *
+     * @param specs the declaration specifiers to inspect
+     * @return {@code true} if the first specifier is a type specifier holding a struct/union
+     *     specifier
+     */
+    public static boolean isStructOrUnion(CParser.DeclarationSpecifiersContext specs) {
+        CParser.DeclarationSpecifierContext spec = firstSpecifier(specs);
+        if (spec == null) {
+            return false;
+        }
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        return type != null && type.structOrUnionSpecifier() != null;
+    }
+
+    /**
+     * Renders a parse tree as nested, JSON-like text. Terminal text is written as-is, without
+     * escaping.
+     *
+     * @param tree the tree to render
+     * @return the textual dump of {@code tree}
+     */
+    public static String getText(ParseTree tree) {
+        StringBuilder sb = new StringBuilder();
+        getText(sb, tree);
+        return sb.toString();
+    }
+
+    /** Appends the dump of {@code tree} to {@code sb}, recursing into its children. */
+    private static void getText(StringBuilder sb, ParseTree tree) {
+        if (tree instanceof TerminalNode) {
+            sb.append("\"v\" : \"").append(tree.getText()).append('"');
+            return;
+        }
+        sb.append('"');
+        sb.append(tree.getClass().getSimpleName()).append("\" : {");
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            getText(sb, tree.getChild(i));
+            if (i < tree.getChildCount() - 1) {
+                sb.append(',');
+            }
+        }
+        sb.append('}');
+    }
+
+    /**
+     * Converts an underscore-separated name to upper camel case, e.g. {@code "nd_array"} becomes
+     * {@code "NdArray"}. Only the first character of each token is changed.
+     *
+     * <p>Empty tokens, which come from leading or repeated underscores or from an empty name, are
+     * skipped instead of failing.
+     *
+     * @param name the underscore-separated name
+     * @return the camel-cased name, possibly empty
+     */
+    public static String toCamelCase(String name) {
+        StringBuilder sb = new StringBuilder(name.length());
+        for (String token : name.split("_")) {
+            if (token.isEmpty()) {
+                // "_foo".split("_") starts with "", and "a__b" has "" in the middle
+                continue;
+            }
+            sb.append(Character.toUpperCase(token.charAt(0)));
+            sb.append(token, 1, token.length());
+        }
+        return sb.toString();
+    }
+
+    /**
+     * Returns the first child of {@code specs} as a declaration specifier, or {@code null} when
+     * there is no such child.
+     */
+    private static CParser.DeclarationSpecifierContext firstSpecifier(
+            CParser.DeclarationSpecifiersContext specs) {
+        // RuleContext.isEmpty() only reports a missing invoking state, so the child count is
+        // checked as well before touching child 0.
+        if (specs.isEmpty() || specs.getChildCount() == 0) {
+            return null;
+        }
+        ParseTree child = specs.getChild(0);
+        if (child instanceof CParser.DeclarationSpecifierContext) {
+            return (CParser.DeclarationSpecifierContext) child;
+        }
+        return null;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
new file mode 100644
index 0000000..3b1c1fe
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/**
+ * A C type as collected from the parse tree: a type name, a {@code const} flag, a pointer depth
+ * and a function-pointer flag. Equality and hash code are based on the type name only.
+ */
+public class DataType {
+
+    private boolean isConst;
+    private boolean functionPointer;
+    private int pointerCount;
+    private final StringBuilder type = new StringBuilder(); // NOPMD
+
+    /** Returns whether the type was marked {@code const}. */
+    public boolean isConst() {
+        return isConst;
+    }
+
+    /** Marks the type as {@code const}; the flag cannot be cleared again. */
+    public void setConst() {
+        isConst = true;
+    }
+
+    /** Returns whether the type was flagged as a function pointer. */
+    public boolean isFunctionPointer() {
+        return functionPointer;
+    }
+
+    /** Sets the function-pointer flag. */
+    public void setFunctionPointer(boolean functionPointer) {
+        this.functionPointer = functionPointer;
+    }
+
+    /** Returns the pointer depth (number of {@code *}). */
+    public int getPointerCount() {
+        return pointerCount;
+    }
+
+    /** Sets the pointer depth. */
+    public void setPointerCount(int pointerCount) {
+        this.pointerCount = pointerCount;
+    }
+
+    /** Increases the pointer depth by one. */
+    public void increasePointerCount() {
+        ++pointerCount;
+    }
+
+    /** Returns the type name collected so far. */
+    public String getType() {
+        return type.toString();
+    }
+
+    /** Replaces the type name with {@code typeName}. */
+    public void setType(String typeName) {
+        type.setLength(0);
+        type.append(typeName);
+    }
+
+    /** Appends {@code name} to the type name, separated by a space when a name already exists. */
+    public void appendTypeName(String name) {
+        if (type.length() > 0) {
+            type.append(' ');
+        }
+        type.append(name);
+    }
+
+    /**
+     * Maps this C type to the name of the Java type used in the generated code.
+     *
+     * <p>When the type name is a non-callback typedef, the typedef value replaces it: {@code
+     * const} markers and {@code *} characters found in the value are removed from the name and
+     * folded into this object's const flag and pointer count. This method therefore modifies
+     * this instance.
+     *
+     * @param map typedef definitions keyed by typedef name
+     * @param structs names of known structs; looked up with the name before typedef resolution
+     * @return the Java type name
+     */
+    public String map(Map<String, TypeDefine> map, Set<String> structs) {
+        String typeName = type.toString().trim();
+        TypeDefine typeDefine = map.get(typeName);
+        boolean isStruct = structs.contains(typeName);
+        if (typeDefine != null && !typeDefine.isCallBack()) {
+            typeName = typeDefine.getValue();
+
+            // literal replacement; no regular expression is needed for these patterns
+            String mapped = typeName.replace("const ", "").replace(" const", "");
+            if (mapped.length() < typeName.length()) {
+                isConst = true;
+            }
+            typeName = mapped;
+            mapped = typeName.replace("*", "");
+            // every removed '*' is one more level of indirection
+            pointerCount += typeName.length() - mapped.length();
+            typeName = mapped;
+            setType(typeName);
+        }
+
+        if (pointerCount > 2) {
+            return "PointerByReference";
+        }
+
+        typeName = baseTypeMapping(typeName);
+
+        if (pointerCount == 2) {
+            if (isConst && "char".equals(typeName)) {
+                return "String[]";
+            }
+            return "PointerByReference";
+        }
+
+        if (pointerCount == 1) {
+            switch (typeName) {
+                case "byte":
+                    return "ByteBuffer";
+                case "NativeSize":
+                    return "NativeSizeByReference";
+                case "int":
+                    if (isConst) {
+                        return "int[]";
+                    }
+                    return "IntBuffer";
+                case "long":
+                    if (isConst) {
+                        return "long[]";
+                    }
+                    return "LongBuffer";
+                case "char":
+                    if (isConst) {
+                        return "String";
+                    }
+                    return "ByteBuffer";
+                case "float":
+                    return "FloatBuffer";
+                case "void":
+                    return "Pointer";
+                default:
+                    if (isStruct) {
+                        return typeName + ".ByReference";
+                    }
+                    return "Pointer";
+            }
+        }
+        return typeName;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        if (isConst) {
+            sb.append("const ");
+        }
+        sb.append(type);
+        if (pointerCount > 0) {
+            sb.append(' ');
+            for (int i = 0; i < pointerCount; ++i) {
+                sb.append('*');
+            }
+        }
+        return sb.toString();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        DataType dataType = (DataType) o;
+        return type.toString().equals(dataType.type.toString());
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        // Hash the text, not the StringBuilder: StringBuilder uses identity hashing, which
+        // would give equal DataType objects different hash codes.
+        return Objects.hash(type.toString());
+    }
+
+    /**
+     * Builds a data type from every terminal found under {@code tree}.
+     *
+     * @param tree the parse tree to read; may be {@code null}
+     * @return the collected data type
+     */
+    static DataType parse(ParseTree tree) {
+        DataType dataType = new DataType();
+        parseTypeSpec(dataType, tree);
+        return dataType;
+    }
+
+    /**
+     * Converts a list of declaration specifiers into data types.
+     *
+     * <p>A type qualifier does not produce an entry: {@code const} sets the flag on the pending
+     * type and any other qualifier is appended to its name. Every other specifier completes the
+     * pending type (after parsing its type specifier, if it has one) and starts a new one.
+     * Qualifiers that follow the last such specifier are dropped.
+     *
+     * @param list the declaration specifiers in source order
+     * @return one data type per non-qualifier specifier
+     */
+    static List<DataType> parseDataTypes(List<CParser.DeclarationSpecifierContext> list) {
+        List<DataType> ret = new ArrayList<>();
+        DataType dataType = new DataType();
+        for (CParser.DeclarationSpecifierContext spec : list) {
+            CParser.TypeQualifierContext qualifier = spec.typeQualifier();
+            if (qualifier != null) {
+                String qualifierName = qualifier.getText();
+                if ("const".equals(qualifierName)) {
+                    dataType.setConst();
+                } else {
+                    dataType.appendTypeName(qualifierName);
+                }
+                continue;
+            }
+
+            CParser.TypeSpecifierContext type = spec.typeSpecifier();
+            parseTypeSpec(dataType, type);
+            ret.add(dataType);
+            dataType = new DataType();
+        }
+
+        return ret;
+    }
+
+    /**
+     * Walks {@code tree} depth-first and records what it finds in {@code dataType}: {@code
+     * const} sets the flag, {@code *} raises the pointer count, and any other terminal is
+     * appended to the type name. Struct/union keyword nodes are ignored, and a typedef name is
+     * used only when no type name has been collected yet.
+     */
+    private static void parseTypeSpec(DataType dataType, ParseTree tree) {
+        if (tree == null) {
+            return;
+        }
+
+        if (tree instanceof CParser.StructOrUnionContext) {
+            return;
+        }
+        if (tree instanceof CParser.TypedefNameContext) {
+            if (dataType.getType().isEmpty()) {
+                dataType.appendTypeName(tree.getText());
+            }
+            return;
+        }
+
+        if (tree instanceof TerminalNode) {
+            String value = tree.getText();
+            if ("const".equals(value)) {
+                dataType.setConst();
+            } else if ("*".equals(value)) {
+                dataType.increasePointerCount();
+            } else {
+                dataType.appendTypeName(value);
+            }
+            return;
+        }
+
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            parseTypeSpec(dataType, tree.getChild(i));
+        }
+    }
+
+    /** Maps a C base type name to its Java counterpart; unknown names are returned unchanged. */
+    private static String baseTypeMapping(String type) {
+        switch (type) {
+            case "uint64_t":
+            case "int64_t":
+            case "long":
+                return "long";
+            case "uint32_t":
+            case "unsigned int":
+            case "unsigned":
+            case "int":
+                return "int";
+            case "bool":
+                return "byte";
+            case "size_t":
+                return "NativeSize";
+            case "char":
+            case "void":
+            case "float":
+            default:
+                return type;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
new file mode 100644
index 0000000..3cdc9c9
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+/** A parsed C function declaration: its name, return type and parameter list. */
+public class FuncInfo {
+
+    private String name;
+    private DataType returnType;
+    private List<Parameter> parameters = new ArrayList<>();
+
+    /** Returns the function name, or {@code null} if it has not been set. */
+    public String getName() {
+        return name;
+    }
+
+    /** Sets the function name. */
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    /** Returns the return type, or {@code null} if it has not been set. */
+    public DataType getReturnType() {
+        return returnType;
+    }
+
+    /** Sets the return type. */
+    public void setReturnType(DataType returnType) {
+        this.returnType = returnType;
+    }
+
+    /** Returns the live parameter list (not a copy). */
+    public List<Parameter> getParameters() {
+        return parameters;
+    }
+
+    /** Replaces the parameter list. */
+    public void setParameters(List<Parameter> parameters) {
+        this.parameters = parameters;
+    }
+
+    /** Appends one parameter to the parameter list. */
+    public void addParameter(Parameter parameter) {
+        parameters.add(parameter);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        sb.append(returnType).append(' ').append(name).append('(');
+        if (parameters != null) {
+            boolean first = true;
+            for (Parameter param : parameters) {
+                if (first) {
+                    first = false;
+                } else {
+                    sb.append(", ");
+                }
+                sb.append(param);
+            }
+        }
+
+        sb.append(");");
+        return sb.toString();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        FuncInfo funcInfo = (FuncInfo) o;
+        // name is null until setName is called and parameters can be replaced with null, so
+        // compare null-safely instead of calling equals on the fields directly
+        return Objects.equals(name, funcInfo.name)
+                && Objects.equals(parameters, funcInfo.parameters);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(name);
+    }
+
+    /**
+     * Builds a function description from a declaration node.
+     *
+     * <p>The first data type among the declaration specifiers becomes the return type. When the
+     * specifiers yield more than one data type, the second one supplies the function name and
+     * the declarator is read as a single parameter; otherwise the name and the parameter list
+     * are taken from the declarator.
+     *
+     * @param ctx the declaration to read
+     * @return the parsed function description
+     */
+    static FuncInfo parse(CParser.DeclarationContext ctx) {
+        FuncInfo info = new FuncInfo();
+
+        List<CParser.DeclarationSpecifierContext> specs =
+                ctx.declarationSpecifiers().declarationSpecifier();
+        List<DataType> dataTypes = DataType.parseDataTypes(specs);
+        info.setReturnType(dataTypes.get(0));
+        if (dataTypes.size() > 1) {
+            info.setName(dataTypes.get(1).getType());
+        }
+
+        CParser.InitDeclaratorContext init = ctx.initDeclaratorList().initDeclarator();
+        CParser.DirectDeclaratorContext declarator = init.declarator().directDeclarator();
+
+        CParser.DirectDeclaratorContext name = declarator.directDeclarator();
+        if (info.getName() == null) {
+            info.setName(name.getText());
+            CParser.ParameterTypeListContext paramListCtx = declarator.parameterTypeList();
+            if (paramListCtx != null) {
+                Parameter.parseParams(info.getParameters(), paramListCtx);
+            }
+        } else {
+            // the name came from the specifiers, so the declarator holds one parameter
+            DataType dataType = new DataType();
+            CParser.TypeSpecifierContext type = declarator.typeSpecifier();
+            dataType.appendTypeName(type.getText());
+            if (declarator.pointer() != null) {
+                dataType.increasePointerCount();
+            }
+            Parameter param = new Parameter(dataType, name.getText());
+            info.addParameter(param);
+        }
+
+        return info;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
new file mode 100644
index 0000000..7be963d
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+
+public class JnaGenerator {
+
+ private Path dir;
+ private String packageName;
+ private String libName;
+ private String className;
+ private Map<String, TypeDefine> typedefMap;
+ private Set<String> structs;
+ private Properties mapping;
+
+ public JnaGenerator(
+ String libName,
+ String packageName,
+ Map<String, TypeDefine> typedefMap,
+ Set<String> structs,
+ Properties mapping) {
+ this.libName = libName;
+ this.packageName = packageName;
+ this.typedefMap = typedefMap;
+ this.structs = structs;
+ this.mapping = mapping;
+ }
+
+ public void init(String output) throws IOException {
+ String[] tokens = packageName.split("\\.");
+ dir = Paths.get(output, tokens);
+ Files.createDirectories(dir);
+ className = AntlrUtils.toCamelCase(libName) + "Library";
+ }
+
+ @SuppressWarnings("PMD.UseConcurrentHashMap")
+ public void writeStructure(Map<String, List<TypeDefine>> structMap) throws IOException {
+ for (Map.Entry<String, List<TypeDefine>> entry : structMap.entrySet()) {
+ String name = entry.getKey();
+ Path path = dir.resolve(name + ".java");
+ try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+ writer.append("package ").append(packageName).append(";\n\n");
+
+ Set<String> importSet = new HashSet<>();
+ importSet.add("com.sun.jna.Pointer");
+ importSet.add("com.sun.jna.Structure");
+ importSet.add("java.util.List");
+
+ Map<String, String> fieldNames = new LinkedHashMap<>();
+ for (TypeDefine typeDefine : entry.getValue()) {
+ String fieldName = typeDefine.getValue();
+ String typeName;
+ if (typeDefine.isCallBack()) {
+ typeName = AntlrUtils.toCamelCase(fieldName) + "Callback";
+ importSet.add("com.sun.jna.Callback");
+ for (Parameter param : typeDefine.getParameters()) {
+ String type = param.getType().map(typedefMap, structs);
+ addImports(importSet, type);
+ }
+ } else {
+ typeName = typeDefine.getDataType().map(typedefMap, structs);
+ addImports(importSet, typeName);
+ }
+ fieldNames.put(fieldName, typeName);
+ }
+
+ int fieldCount = fieldNames.size();
+ if (fieldCount < 2) {
+ importSet.add("java.util.Collections");
+ } else {
+ importSet.add("java.util.Arrays");
+ }
+
+ List<String> imports = new ArrayList<>(importSet.size());
+ imports.addAll(importSet);
+ Collections.sort(imports);
+ for (String imp : imports) {
+ writer.append("import ").append(imp).append(";\n");
+ }
+
+ writer.append("\npublic class ").append(name).append(" extends Structure {\n");
+ if (fieldCount > 0) {
+ writer.write("\n");
+ }
+ for (Map.Entry<String, String> field : fieldNames.entrySet()) {
+ writer.append(" public ").append(field.getValue()).append(' ');
+ writer.append(field.getKey()).append(";\n");
+ }
+
+ writer.append("\n public ").append(name).append("() {\n");
+ writer.append(" }\n");
+ writer.append("\n public ").append(name).append("(Pointer peer) {\n");
+ writer.append(" super(peer);\n");
+ writer.append(" }\n");
+
+ writer.append("\n @Override\n");
+ writer.append(" protected List<String> getFieldOrder() {\n");
+ switch (fieldNames.size()) {
+ case 0:
+ writer.append(" return Collections.emptyList();\n");
+ break;
+ case 1:
+ writer.append(" return Collections.singletonList(");
+ String firstField = fieldNames.keySet().iterator().next();
+ writer.append('"').append(firstField).append("\");\n");
+ break;
+ default:
+ writer.append(" return Arrays.asList(");
+ boolean first = true;
+ for (String fieldName : fieldNames.keySet()) {
+ if (first) {
+ first = false;
+ } else {
+ writer.write(", ");
+ }
+ writer.append('"').append(fieldName).append('"');
+ }
+ writer.append(");\n");
+ break;
+ }
+ writer.append(" }\n");
+
+ for (TypeDefine typeDefine : entry.getValue()) {
+ String fieldName = typeDefine.getValue();
+ String typeName = fieldNames.get(fieldName);
+ String getterName;
+ if (!typeDefine.isCallBack()) {
+ getterName = AntlrUtils.toCamelCase(fieldName);
+ } else {
+ getterName = typeName;
+ }
+
+ writer.append("\n public void set").append(getterName).append('(');
+ writer.append(typeName).append(' ').append(fieldName).append(") {\n");
+ writer.append(" this.").append(fieldName).append(" = ");
+ writer.append(fieldName).append(";\n");
+ writer.append(" }\n");
+ writer.append("\n public ").append(typeName).append(" get");
+ writer.append(getterName).append("() {\n");
+ writer.append(" return ").append(fieldName).append(";\n");
+ writer.append(" }\n");
+ }
+
+ writer.append("\n public static final class ByReference extends ");
+ writer.append(name).append(" implements Structure.ByReference {}\n");
+
+ writer.append("\n public static final class ByValue extends ");
+ writer.append(name).append(" implements Structure.ByValue {}\n");
+
+ for (TypeDefine typeDefine : entry.getValue()) {
+ if (typeDefine.isCallBack()) {
+ DataType dataType = typeDefine.getDataType();
+ String fieldName = typeDefine.getValue();
+
+ String callbackName = fieldNames.get(fieldName);
+ String returnType = mapping.getProperty(callbackName);
+ if (returnType == null) {
+ returnType = dataType.map(typedefMap, structs);
+ }
+
+ writer.append("\n public interface ").append(callbackName);
+ writer.append(" extends Callback {\n");
+ writer.append(" ").append(returnType).append(" apply(");
+ writeParameters(writer, fieldName, typeDefine.getParameters());
+ writer.append(");\n");
+ writer.append(" }\n");
+ }
+ }
+
+ writer.append("}\n");
+ }
+ }
+ }
+
+ public void writeLibrary(Collection<FuncInfo> functions, Map<String, List<String>> enumMap)
+ throws IOException {
+ try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve(className + ".java"))) {
+ writer.append("package ").append(packageName).append(";\n\n");
+
+ writer.append("import com.sun.jna.Callback;\n");
+ writer.append("import com.sun.jna.Library;\n");
+ writer.append("import com.sun.jna.Pointer;\n");
+ writer.append("import com.sun.jna.ptr.PointerByReference;\n");
+ writer.append("import java.nio.ByteBuffer;\n");
+ writer.append("import java.nio.FloatBuffer;\n");
+ writer.append("import java.nio.IntBuffer;\n");
+ writer.append("import java.nio.LongBuffer;\n");
+
+ writer.append("\npublic interface ").append(className).append(" extends Library {\n\n");
+
+ for (Map.Entry<String, List<String>> entry : enumMap.entrySet()) {
+ String name = entry.getKey();
+ writer.append("\n enum ").append(name).append(" {\n");
+ List<String> fields = entry.getValue();
+ int count = 0;
+ for (String field : fields) {
+ writer.append(" ").append(field);
+ if (++count < fields.size()) {
+ writer.append(',');
+ }
+ writer.append('\n');
+ }
+ writer.append(" }\n");
+ }
+
+ for (TypeDefine typeDefine : typedefMap.values()) {
+ if (typeDefine.isCallBack()) {
+ String callbackName = typeDefine.getDataType().getType();
+ String returnType = mapping.getProperty(callbackName);
+ if (returnType == null) {
+ returnType = typeDefine.getValue();
+ }
+ writer.append("\n interface ").append(callbackName);
+ writer.append(" extends Callback {\n");
+ writer.append(" ").append(returnType).append(" apply(");
+ writeParameters(writer, callbackName, typeDefine.getParameters());
+ writer.append(");\n");
+ writer.append(" }\n");
+ }
+ }
+
+ for (FuncInfo info : functions) {
+ writeFunction(writer, info);
+ }
+ writer.append("}\n");
+ }
+ }
+
+ public void writeNativeSize() throws IOException {
+ try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve("NativeSize.java"))) {
+ writer.append("package ").append(packageName).append(";\n\n");
+ writer.append("import com.sun.jna.IntegerType;\n");
+ writer.append("import com.sun.jna.Native;\n\n");
+
+ writer.append("public class NativeSize extends IntegerType {\n\n");
+ writer.append(" private static final long serialVersionUID = 1L;\n\n");
+ writer.append(" public static final int SIZE = Native.SIZE_T_SIZE;\n\n");
+ writer.append(" public NativeSize() {\n");
+ writer.append(" this(0);\n");
+ writer.append(" }\n\n");
+ writer.append(" public NativeSize(long value) {\n");
+ writer.append(" super(SIZE, value);\n");
+ writer.append(" }\n");
+ writer.append("}\n");
+ }
+
+ Path path = dir.resolve("NativeSizeByReference.java");
+ try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+ writer.append("package ").append(packageName).append(";\n\n");
+ writer.append("import com.sun.jna.ptr.ByReference;\n\n");
+ writer.append("public class NativeSizeByReference extends ByReference {\n\n");
+ writer.append(" public NativeSizeByReference() {\n");
+ writer.append(" this(new NativeSize(0));\n");
+ writer.append(" }\n\n");
+ writer.append(" public NativeSizeByReference(NativeSize value) {\n");
+ writer.append(" super(NativeSize.SIZE);\n");
+ writer.append(" setValue(value);\n");
+ writer.append(" }\n\n");
+ writer.append(" public void setValue(NativeSize value) {\n");
+ writer.append(" if (NativeSize.SIZE == 4) {\n");
+ writer.append(" getPointer().setInt(0, value.intValue());\n");
+ writer.append(" } else if (NativeSize.SIZE == 8) {\n");
+ writer.append(" getPointer().setLong(0, value.longValue());\n");
+ writer.append(" } else {\n");
+ writer.append(
+ " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+ writer.append(" }\n");
+ writer.append(" }\n\n");
+ writer.append(" public NativeSize getValue() {\n");
+ writer.append(" if (NativeSize.SIZE == 4) {\n");
+ writer.append(" return new NativeSize(getPointer().getInt(0));\n");
+ writer.append(" } else if (NativeSize.SIZE == 8) {\n");
+ writer.append(" return new NativeSize(getPointer().getLong(0));\n");
+ writer.append(" } else {\n");
+ writer.append(
+ " throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+ writer.append(" }\n");
+ writer.append(" }\n");
+ writer.append("}\n");
+ }
+ }
+
+ private void writeFunction(BufferedWriter writer, FuncInfo info) throws IOException {
+ String funcName = info.getName();
+ String returnType = mapping.getProperty(funcName);
+ if (returnType == null) {
+ returnType = info.getReturnType().map(typedefMap, structs);
+ }
+ writer.append("\n ").append(returnType).append(' ');
+ writer.append(funcName).append('(');
+ writeParameters(writer, funcName, info.getParameters());
+ writer.append(");\n");
+ }
+
+ private void writeParameters(BufferedWriter writer, String funcName, List<Parameter> parameters)
+ throws IOException {
+ if (parameters != null) {
+ boolean first = true;
+ for (Parameter param : parameters) {
+ if (first) {
+ first = false;
+ } else {
+ writer.append(", ");
+ }
+ String paramName = param.getName();
+ String type = mapping.getProperty(funcName + '.' + paramName);
+ if (type == null) {
+ type = param.getType().map(typedefMap, structs);
+ }
+ if (!"void".equals(type)) {
+ writer.append(type).append(' ');
+ writer.append(paramName);
+ }
+ }
+ }
+ }
+
+ private static void addImports(Set<String> importSet, String typeName) {
+ switch (typeName) {
+ case "ByReference":
+ case "ByteByReference":
+ case "DoubleByReference":
+ case "FloatByReference":
+ case "IntByReference":
+ case "LongByReference":
+ case "NativeLongByReference":
+ case "PointerByReference":
+ case "ShortByReference":
+ importSet.add("com.sun.jna.ptr." + typeName);
+ break;
+ case "ByteBuffer":
+ case "DoubleBuffer":
+ case "FloatBuffer":
+ case "IntBuffer":
+ case "LongBuffer":
+ case "ShortBuffer":
+ importSet.add("java.nio." + typeName);
+ break;
+ default:
+ break;
+ }
+ }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
new file mode 100644
index 0000000..16a350e
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.ParseTreeWalker;
+import org.apache.mxnet.jnarator.parser.CBaseListener;
+import org.apache.mxnet.jnarator.parser.CLexer;
+import org.apache.mxnet.jnarator.parser.CParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JnaParser {
+
+     /** Logger for parser diagnostics, named after this class so messages are attributed to it. */
+     static final Logger logger = LoggerFactory.getLogger(JnaParser.class);
+
+ Map<String, List<TypeDefine>> structMap;
+ Map<String, List<String>> enumMap;
+ List<FuncInfo> functions;
+ Map<String, TypeDefine> typedefMap;
+ private Set<String> functionNames;
+
+ /**
+ * Creates a parser with empty result containers. The struct, enum and typedef maps are
+ * {@link LinkedHashMap}s, so they iterate in insertion order.
+ */
+ public JnaParser() {
+ structMap = new LinkedHashMap<>();
+ enumMap = new LinkedHashMap<>();
+ functions = new ArrayList<>();
+ typedefMap = new LinkedHashMap<>();
+ functionNames = new HashSet<>();
+ }
+
+     /**
+      * Parses a C header file and collects its typedefs, structs/unions, enums and functions.
+      *
+      * <p>Results are added to this parser's maps and function list, so the method may be called
+      * once per header file. An {@link IOException} raised while reading the file is logged
+      * together with the file name and is not rethrown.
+      *
+      * @param headerFile the path of the header file to parse
+      */
+     public void parse(String headerFile) {
+         try {
+             CLexer lexer = new CLexer(CharStreams.fromFileName(headerFile));
+             CommonTokenStream tokens = new CommonTokenStream(lexer);
+             CParser parser = new CParser(tokens);
+             ParseTree tree = parser.compilationUnit();
+
+             ParseTreeWalker walker = new ParseTreeWalker();
+             CBaseListener listener =
+                     new CBaseListener() {
+
+                         /** {@inheritDoc} */
+                         @Override
+                         public void enterDeclaration(CParser.DeclarationContext ctx) {
+                             CParser.DeclarationSpecifiersContext specs =
+                                     ctx.declarationSpecifiers();
+                             CParser.InitDeclaratorListContext init = ctx.initDeclaratorList();
+
+                             if (AntlrUtils.isTypeDef(specs)) {
+                                 TypeDefine value = TypeDefine.parse(init, specs);
+                                 typedefMap.put(value.getDataType().getType(), value);
+                             } else if (AntlrUtils.isStructOrUnion(specs)) {
+                                 CParser.DeclarationSpecifierContext spec =
+                                         (CParser.DeclarationSpecifierContext) specs.getChild(0);
+                                 CParser.TypeSpecifierContext type = spec.typeSpecifier();
+                                 CParser.StructOrUnionSpecifierContext struct =
+                                         type.structOrUnionSpecifier();
+                                 String name = struct.Identifier().getText();
+                                 List<TypeDefine> fields = new ArrayList<>();
+
+                                 CParser.StructDeclarationListContext list =
+                                         struct.structDeclarationList();
+                                 parseStructFields(fields, list);
+
+                                 structMap.put(name, fields);
+                             } else if (AntlrUtils.isEnum(specs)) {
+                                 CParser.DeclarationSpecifierContext spec =
+                                         (CParser.DeclarationSpecifierContext) specs.getChild(0);
+                                 CParser.TypeSpecifierContext type = spec.typeSpecifier();
+                                 CParser.EnumSpecifierContext enumSpecifierContext =
+                                         type.enumSpecifier();
+                                 String name = enumSpecifierContext.Identifier().getText();
+                                 List<String> fields = new ArrayList<>();
+                                 parseEnum(fields, ctx);
+                                 enumMap.put(name, fields);
+                             } else {
+                                 // Any declaration that is not a typedef, struct/union or enum
+                                 // is handled as a function declaration.
+                                 FuncInfo info = FuncInfo.parse(ctx);
+                                 if (checkDuplicate(info)) {
+                                     logger.warn("Duplicate function: {}.", info.getName());
+                                 } else {
+                                     functions.add(info);
+                                 }
+                             }
+                         }
+                     };
+             walker.walk(listener, tree);
+         } catch (IOException e) {
+             // Name the file so a failure can be traced when several headers are processed.
+             logger.error("Failed to read header file: {}", headerFile, e);
+         }
+     }
+
+ /**
+ * Recursively walks a parse tree and appends one {@link TypeDefine} to {@code fields} for every
+ * struct declaration node found.
+ *
+ * @param fields the list receiving the parsed fields
+ * @param tree the parse tree node to inspect
+ */
+ void parseStructFields(List<TypeDefine> fields, ParseTree tree) {
+ if (tree instanceof CParser.StructDeclarationContext) {
+ CParser.StructDeclarationContext ctx = (CParser.StructDeclarationContext) tree;
+ CParser.SpecifierQualifierListContext qualifierList = ctx.specifierQualifierList();
+ DataType dataType = DataType.parse(qualifierList);
+
+ TypeDefine typeDefine = new TypeDefine();
+ fields.add(typeDefine);
+
+ typeDefine.setDataType(dataType);
+
+ CParser.StructDeclaratorListContext name = ctx.structDeclaratorList();
+ if (name != null) {
+ // A declaration with a declarator list is flagged as a callback field.
+ // NOTE(review): presumably only function-pointer members reach this branch - confirm
+ // against the grammar, since the nested declarator lookups below assume that shape.
+ typeDefine.setCallBack(true);
+
+ CParser.DirectDeclaratorContext dd =
+ name.structDeclarator().declarator().directDeclarator();
+ CParser.DirectDeclaratorContext nameCtx =
+ dd.directDeclarator().declarator().directDeclarator();
+ String fieldName = nameCtx.getText();
+ typeDefine.setValue(fieldName);
+
+ CParser.ParameterTypeListContext paramListCtx = dd.parameterTypeList();
+ if (paramListCtx != null) {
+ Parameter.parseParams(typeDefine.getParameters(), paramListCtx);
+ }
+ } else {
+ // Without a declarator list the field name is read from the nested specifier list:
+ // its own nested list when there is one, otherwise its full text.
+ CParser.SpecifierQualifierListContext nameList =
+ qualifierList.specifierQualifierList();
+ if (nameList.specifierQualifierList() != null) {
+ typeDefine.setValue(nameList.specifierQualifierList().getText());
+ } else {
+ typeDefine.setValue(nameList.getText());
+ }
+ }
+ // A struct declaration node is fully handled here; its children are not visited.
+ return;
+ }
+
+ for (int i = 0; i < tree.getChildCount(); i++) {
+ parseStructFields(fields, tree.getChild(i));
+ }
+ }
+
+     /**
+      * Collects the text of every enumeration constant found beneath a parse tree node.
+      *
+      * @param fields the list receiving the constant names in traversal order
+      * @param ctx the parse tree node to inspect
+      */
+     void parseEnum(List<String> fields, ParseTree ctx) {
+         if (!(ctx instanceof CParser.EnumerationConstantContext)) {
+             // Not a constant itself: descend into each child.
+             for (int idx = 0; idx < ctx.getChildCount(); idx++) {
+                 parseEnum(fields, ctx.getChild(idx));
+             }
+             return;
+         }
+         fields.add(ctx.getText());
+     }
+
+ /**
+ * Returns the parsed structs and unions, keyed by name and mapped to their fields.
+ *
+ * @return the internal struct map, which is filled by {@code parse}
+ */
+ public Map<String, List<TypeDefine>> getStructMap() {
+ return structMap;
+ }
+
+ /**
+ * Returns the parsed enums, keyed by name and mapped to their constant names.
+ *
+ * @return the internal enum map, which is filled by {@code parse}
+ */
+ public Map<String, List<String>> getEnumMap() {
+ return enumMap;
+ }
+
+ /**
+ * Returns the parsed functions.
+ *
+ * @return the internal function list, which is filled by {@code parse}
+ */
+ public List<FuncInfo> getFunctions() {
+ return functions;
+ }
+
+ /**
+ * Returns the parsed typedefs, keyed by the typedef's type name.
+ *
+ * @return the internal typedef map, which is filled by {@code parse}
+ */
+ public Map<String, TypeDefine> getTypedefMap() {
+ return typedefMap;
+ }
+
+     /**
+      * Records the function name and reports whether an equal function was already collected.
+      *
+      * @param function the function to check
+      * @return {@code true} if the name was seen before and an equal function is already in the
+      *     function list
+      */
+     boolean checkDuplicate(FuncInfo function) {
+         boolean firstOccurrence = functionNames.add(function.getName());
+         if (firstOccurrence) {
+             return false;
+         }
+         // Same name seen before: a duplicate only if an already collected function is equal.
+         return functions.stream().anyMatch(function::equals);
+     }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
new file mode 100644
index 0000000..39aa8fc
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class Main {
+
+ private static final Logger logger = LoggerFactory.getLogger(Main.class);
+
+ private Main() {}
+
+ public static void main(String[] args) {
+ Options options = Config.getOptions();
+ try {
+ DefaultParser cmdParser = new DefaultParser();
+ CommandLine cmd = cmdParser.parse(options, args, null, false);
+ Config config = new Config(cmd);
+
+ String output = config.getOutput();
+ String packageName = config.getPackageName();
+ String library = config.getLibrary();
+ String[] headerFiles = config.getHeaderFiles();
+ String mappingFile = config.getMappingFile();
+
+ Path dir = Paths.get(output);
+ Files.createDirectories(dir);
+
+ Properties mapping = new Properties();
+ if (mappingFile != null) {
+ Path file = Paths.get(mappingFile);
+ if (Files.notExists(file)) {
+ logger.error("mapping file does not exists: {}", mappingFile);
+ System.exit(-1); // NOPMD
+ }
+ try (InputStream in = Files.newInputStream(file)) {
+ mapping.load(in);
+ }
+ }
+
+ JnaParser jnaParser = new JnaParser();
+ Map<String, TypeDefine> typedefMap = jnaParser.getTypedefMap();
+ Map<String, List<TypeDefine>> structMap = jnaParser.getStructMap();
+ JnaGenerator generator =
+ new JnaGenerator(library, packageName, typedefMap, structMap.keySet(), mapping);
+ generator.init(output);
+
+ for (String headerFile : headerFiles) {
+ jnaParser.parse(headerFile);
+ }
+
+ generator.writeNativeSize();
+ generator.writeStructure(structMap);
+ generator.writeLibrary(jnaParser.getFunctions(), jnaParser.getEnumMap());
+ } catch (ParseException e) {
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.setLeftPadding(1);
+ formatter.setWidth(120);
+ formatter.printHelp(e.getMessage(), options);
+ System.exit(-1); // NOPMD
+ } catch (Throwable t) {
+ logger.error("", t);
+ System.exit(-1); // NOPMD
+ }
+ }
+
+     /** Holds the generator settings read from the parsed command line. */
+     public static final class Config {
+
+         private final String library;
+         private final String packageName;
+         private final String output;
+         private final String[] headerFiles;
+         private final String mappingFile;
+
+         /**
+          * Reads all settings from the parsed command line.
+          *
+          * @param cmd the parsed command line
+          */
+         public Config(CommandLine cmd) {
+             this.library = cmd.getOptionValue("library");
+             this.packageName = cmd.getOptionValue("package");
+             this.output = cmd.getOptionValue("output");
+             this.headerFiles = cmd.getOptionValues("header");
+             this.mappingFile = cmd.getOptionValue("mappingFile");
+         }
+
+         /**
+          * Builds the command line options understood by the generator.
+          *
+          * @return the options; all except the mapping file are required
+          */
+         public static Options getOptions() {
+             Option libraryOption =
+                     Option.builder("l")
+                             .longOpt("library")
+                             .hasArg()
+                             .required()
+                             .argName("LIBRARY")
+                             .desc("library name")
+                             .build();
+             Option packageOption =
+                     Option.builder("p")
+                             .longOpt("package")
+                             .required()
+                             .hasArg()
+                             .argName("PACKAGE")
+                             .desc("Java package name")
+                             .build();
+             Option outputOption =
+                     Option.builder("o")
+                             .longOpt("output")
+                             .required()
+                             .hasArg()
+                             .argName("OUTPUT")
+                             .desc("output directory")
+                             .build();
+             Option headerOption =
+                     Option.builder("f")
+                             .longOpt("header")
+                             .required()
+                             .hasArgs()
+                             .argName("HEADER")
+                             .desc("Header files")
+                             .build();
+             Option mappingOption =
+                     Option.builder("m")
+                             .longOpt("mappingFile")
+                             .hasArg()
+                             .argName("MAPPING_FILE")
+                             .desc("Type mappingFile config file")
+                             .build();
+
+             return new Options()
+                     .addOption(libraryOption)
+                     .addOption(packageOption)
+                     .addOption(outputOption)
+                     .addOption(headerOption)
+                     .addOption(mappingOption);
+         }
+
+         /**
+          * Returns the value of the {@code library} option.
+          *
+          * @return the library name
+          */
+         public String getLibrary() {
+             return library;
+         }
+
+         /**
+          * Returns the value of the {@code package} option.
+          *
+          * @return the Java package name
+          */
+         public String getPackageName() {
+             return packageName;
+         }
+
+         /**
+          * Returns the value of the {@code output} option.
+          *
+          * @return the output directory
+          */
+         public String getOutput() {
+             return output;
+         }
+
+         /**
+          * Returns the values of the {@code header} option.
+          *
+          * @return the header files
+          */
+         public String[] getHeaderFiles() {
+             return headerFiles;
+         }
+
+         /**
+          * Returns the value of the {@code mappingFile} option.
+          *
+          * @return the mapping file, or {@code null} if the option was not given
+          */
+         public String getMappingFile() {
+             return mappingFile;
+         }
+     }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
new file mode 100644
index 0000000..f46e5e7
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.List;
+import java.util.Objects;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class Parameter {
+
+ private DataType type;
+ private String name;
+
+ public Parameter(DataType type, String name) {
+ this.type = type;
+ this.name = name;
+ }
+
+ public DataType getType() {
+ return type;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return type.toString() + ' ' + name;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Parameter parameter = (Parameter) o;
+ return type.equals(parameter.type);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Objects.hash(type);
+ }
+
+ static void parseParams(List<Parameter> params, ParseTree ctx) {
+ if (ctx instanceof CParser.ParameterDeclarationContext) {
+ CParser.ParameterDeclarationContext declarationContext =
+ (CParser.ParameterDeclarationContext) ctx;
+ CParser.DeclarationSpecifiersContext spec = declarationContext.declarationSpecifiers();
+ DataType dataType;
+ if (spec == null) {
+ dataType = DataType.parse(declarationContext.declarationSpecifiers2());
+ } else {
+ dataType = DataType.parse(spec);
+ }
+
+ CParser.DeclaratorContext declarator = declarationContext.declarator();
+
+ String name;
+ if (declarator != null) {
+ CParser.PointerContext pointer = declarator.pointer();
+ if (pointer != null) {
+ dataType.increasePointerCount();
+ }
+ name = declarator.directDeclarator().getText();
+ } else {
+ name = "arg" + (params.size() + 1);
+ }
+
+ Parameter param = new Parameter(dataType, name);
+ params.add(param);
+ return;
+ }
+ for (int i = 0; i < ctx.getChildCount(); i++) {
+ parseParams(params, ctx.getChild(i));
+ }
+ }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
new file mode 100644
index 0000000..1f30d21
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+ /** Describes a parsed typedef or struct field, including callback parameters when present. */
+ public class TypeDefine {
+
+     private DataType dataType;
+     private boolean callBack;
+     private String value;
+     private List<Parameter> parameters = new ArrayList<>();
+
+     /**
+      * Returns the data type.
+      *
+      * @return the data type
+      */
+     public DataType getDataType() {
+         return dataType;
+     }
+
+     /**
+      * Sets the data type.
+      *
+      * @param dataType the data type
+      */
+     public void setDataType(DataType dataType) {
+         this.dataType = dataType;
+     }
+
+     /**
+      * Returns whether this definition is flagged as a callback.
+      *
+      * @return {@code true} if flagged as a callback
+      */
+     public boolean isCallBack() {
+         return callBack;
+     }
+
+     /**
+      * Sets the callback flag.
+      *
+      * @param callBack the new flag value
+      */
+     public void setCallBack(boolean callBack) {
+         this.callBack = callBack;
+     }
+
+     /**
+      * Returns the value text.
+      *
+      * @return the value text
+      */
+     public String getValue() {
+         return value;
+     }
+
+     /**
+      * Sets the value text.
+      *
+      * @param value the value text
+      */
+     public void setValue(String value) {
+         this.value = value;
+     }
+
+     /**
+      * Returns the mutable parameter list.
+      *
+      * @return the parameters
+      */
+     public List<Parameter> getParameters() {
+         return parameters;
+     }
+
+     /**
+      * Builds a {@code TypeDefine} from a typedef declaration.
+      *
+      * <p>The value is the text of every specifier child except the first, joined by single
+      * spaces.
+      *
+      * @param init the init declarator list of the declaration
+      * @param specs the declaration specifiers of the declaration
+      * @return the parsed definition
+      */
+     static TypeDefine parse(
+             CParser.InitDeclaratorListContext init, CParser.DeclarationSpecifiersContext specs) {
+         TypeDefine result = new TypeDefine();
+         DataType type = new DataType();
+         result.setDataType(type);
+
+         CParser.DirectDeclaratorContext direct =
+                 init.initDeclarator().declarator().directDeclarator();
+         CParser.DirectDeclaratorContext callback = direct.directDeclarator();
+         if (callback != null) {
+             // A nested direct declarator marks a callback; its parameters are parsed as well.
+             result.setCallBack(true);
+             type.setType(callback.declarator().directDeclarator().getText());
+             Parameter.parseParams(result.getParameters(), direct.parameterTypeList());
+         } else {
+             type.setType(direct.getText());
+         }
+
+         StringBuilder joined = new StringBuilder();
+         for (int idx = 1; idx < specs.getChildCount(); ++idx) {
+             if (idx > 1) {
+                 joined.append(' ');
+             }
+             joined.append(specs.getChild(idx).getText());
+         }
+         result.setValue(joined.toString());
+         return result;
+     }
+ }
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
new file mode 100644
index 0000000..a6cff70
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains classes to generate the Apache MXNet (incubating) native interface. */
+package org.apache.mxnet.jnarator;
diff --git a/java-package/jnarator/src/main/resources/log4j2.xml b/java-package/jnarator/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..4818a95
--- /dev/null
+++ b/java-package/jnarator/src/main/resources/log4j2.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<Configuration status="WARN">
+ <Appenders>
+ <Console name="Console" target="SYSTEM_OUT">
+ <PatternLayout pattern="[%highlight{%-5level}] %msg%n%throwable"/>
+ </Console>
+ </Appenders>
+ <Loggers>
+ <Root level="DEBUG" additivity="false">
+ <AppenderRef ref="Console"/>
+ </Root>
+ </Loggers>
+</Configuration>
diff --git a/java-package/mxnet-engine/build.gradle b/java-package/mxnet-engine/build.gradle
new file mode 100644
index 0000000..c874f72
--- /dev/null
+++ b/java-package/mxnet-engine/build.gradle
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'java'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+ mavenCentral()
+}
+
+ /**
+  * Maps the JVM "os.name" system property to the OS classifier used in the native jar file name.
+  *
+  * Matching is case-insensitive: the JVM reports Windows as e.g. "Windows 10", which a
+  * case-sensitive search for "windows" never matches.
+  *
+  * @return "win", "osx" or "linux"; the unmodified os.name value for any other system
+  */
+ def getOsName() {
+     def osName = System.properties['os.name']
+     def normalized = osName.toLowerCase(Locale.ROOT)
+     if (normalized.contains('windows')) {
+         return "win"
+     } else if (normalized.contains('mac os x')) {
+         return "osx"
+     } else if (normalized.contains('linux')) {
+         return "linux"
+     } else {
+         return osName
+     }
+ }
+
+dependencies {
+ api "com.google.code.gson:gson:${gson_version}"
+ api "net.java.dev.jna:jna:${jna_version}"
+ api "org.apache.commons:commons-compress:${commons_compress_version}"
+ api "org.slf4j:slf4j-api:${slf4j_version}"
+
+ testImplementation("org.testng:testng:${testng_version}") {
+ exclude group: "junit", module: "junit"
+ }
+ testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ // Solve the problem: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
+ implementation "org.slf4j:slf4j-simple:${slf4j_version}"
+ def osName = getOsName()
+ implementation files("${project(':native').buildDir}/libs/native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar")
+// implementation fileTree(dir: "${project(':native').buildDir}/lib", includes: ["native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar"])
+}
+
+sourceSets {
+ main {
+ java {
+ srcDirs = ['src/main/java', 'build/generated-src']
+ }
+ }
+}
+
+checkstyleMain.source = 'src/main/java'
+pmdMain.source = 'src/main/java'
+
+ // Generates the JNA binding sources into build/generated-src by running the jar built by the
+ // :jnarator project against the MXNet and NNVM C API headers.
+ task jnarator(dependsOn: ":jnarator:jar") {
+ outputs.dir "${project.buildDir}/generated-src"
+ doLast {
+ File jnaGenerator = project(":jnarator").jar.outputs.files.singleFile
+ javaexec {
+ // "-jar" is passed in place of a main class and the jar path is the first argument, so the
+ // resulting command is effectively `java -jar <jnarator jar> <options>`.
+ main = "-jar"
+ args = [
+ jnaGenerator.absolutePath,
+ "-l",
+ "mxnet",
+ "-p",
+ "org.apache.mxnet.jna",
+ "-o",
+ "${project.buildDir}/generated-src",
+ "-m",
+ "${project.projectDir}/src/main/jna/mapping.properties",
+ "-f",
+ // NOTE(review): relative header paths presumably resolve against this project's
+ // directory - confirm the working directory used by javaexec.
+ "../../include/mxnet/c_api.h",
+ "../../include/nnvm/c_api.h"
+ ]
+ }
+ }
+ }
+
+test {
+ useTestNG() {
+ useDefaultListeners = true
+ }
+ environment "PATH", "src/test/bin:${environment.PATH}"
+// environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}"
+ maxHeapSize = '6G'
+ testLogging.showStandardStreams = true
+ beforeTest { descriptor ->
+ logger.lifecycle("Running test: " + descriptor)
+ }
+ failFast = false
+ onOutput { descriptor, event ->
+ logger.lifecycle("Test: " + descriptor + " produced standard out/err: " + event.message )
+ }
+// debugOptions {
+// enabled = true
+// port = 4455
+// server = true
+// suspend = true
+// }
+// filter {
+// includeTestsMatching("*Test")
+// }
+}
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//import java.util.regex.Matcher
+//import java.util.regex.Pattern
+
+//def checkForUpdate(String path, String url) {
+// def expected = new URL(url).text
+// def actual = new File("${project.projectDir}/src/main/include/${path}").text
+// if (!actual.equals(expected)) {
+// def fileName = path.replaceAll("[/\\\\]", '_')
+// file("${project.projectDir}/build").mkdirs()
+// (file("${project.projectDir}/build/${fileName}")).text = expected
+// logger.warn("[\033[31mWARN\033[0m ] Header file has been changed in open source project: ${path}.")
+// }
+//}
+
+//task checkHeaderFile() {
+// outputs.files "${project.buildDir}/mxnet_c_api.h", "${project.buildDir}/nnvm_c_api.h"
+// doFirst {
+// if (gradle.startParameter.offline) {
+// logger.warn("[\033[31mWARN\033[0m ] Ignore header validation in offline mode.")
+// return
+// }
+//
+// def mxnetUrl = "https://raw.githubusercontent.com/apache/incubator-mxnet/v1.7.x/"
+// checkForUpdate("mxnet/c_api.h", "${mxnetUrl}/include/mxnet/c_api.h")
+// def content = new URL("https://github.com/apache/incubator-mxnet/tree/v1.7.x/3rdparty").text
+//
+// Pattern pattern = Pattern.compile("href=\"/apache/incubator-tvm/tree/([a-z0-9]+)\"")
+// Matcher m = pattern.matcher(content);
+// if (!m.find()) {
+// throw new GradleException("Failed to retrieve submodule hash for tvm from github")
+// }
+// String hash = m.group(1);
+//
+// def nnvmUrl = "https://raw.githubusercontent.com/apache/incubator-tvm/${hash}"
+// checkForUpdate("nnvm/c_api.h", "${nnvmUrl}/nnvm/include/nnvm/c_api.h")
+// }
+//}
+
+compileJava.dependsOn(jnarator)
+
+// TODO
+//publishing {
+// publications {
+// maven(MavenPublication) {
+// pom {
+// name = "DJL Engine Adapter for Apache MXNet"
+// description = "Deep Java Library (DJL) Engine Adapter for Apache MXNet"
+// url = "http://www.djl.ai/mxnet/${project.name}"
+// }
+// }
+// }
+//}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
new file mode 100644
index 0000000..13ae88d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.jna.JnaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The top-level {@link MxResource} instance, with no parent Resource to manage. The {@link
+ * BaseMxResource} instance will be lazy loaded when the first time called, like when {@link Model}
+ * instance is loaded for the first time.
+ */
+ public final class BaseMxResource extends MxResource {
+
+     private static final Logger logger = LoggerFactory.getLogger(BaseMxResource.class);
+
+     private static BaseMxResource systemMxResource;
+
+     /**
+      * Creates the instance, triggers the native initialization workarounds and registers a
+      * shutdown hook that calls {@code JnaUtils.waitAll()}.
+      */
+     protected BaseMxResource() {
+         super();
+         // Workaround MXNet engine lazy initialization issue
+         JnaUtils.getAllOpNames();
+
+         JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);
+
+         // Workaround MXNet shutdown crash issue
+         Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD
+     }
+
+     /**
+      * Getter method for the singleton {@code systemMxResource} instance, which is created on the
+      * first call.
+      *
+      * @return The top-level {@link BaseMxResource} instance.
+      */
+     public static synchronized BaseMxResource getSystemMxResource() {
+         if (systemMxResource == null) {
+             systemMxResource = new BaseMxResource();
+         }
+         return systemMxResource;
+     }
+
+     /**
+      * Frees the sub resources of this instance and marks it closed. Does nothing when the
+      * instance is already closed.
+      */
+     @Override
+     public void close() {
+         if (!getClosed()) {
+             // Parameterized logging keeps the uid as-is (the former "%S" format upper-cased it)
+             // and skips formatting when debug logging is disabled.
+             logger.debug("Start to free BaseMxResource instance: {}", this.getUid());
+             // only clean sub resources
+             JnaUtils.waitAll();
+             super.freeSubResources();
+             setClosed(true);
+             logger.debug("Finish to free BaseMxResource instance: {}", this.getUid());
+         }
+     }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
new file mode 100644
index 0000000..eef058e
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code CachedOp} is an internal helper that provides the core functionality to execute a
+ * {@link SymbolBlock}.
+ *
+ * <p>We don't recommend users interact with this class directly. Users should use {@link Predictor}
+ * instead. CachedOp is an operator that simplifies calling and analyzing the input shape. It
+ * requires minimum input to do inference because most of the information can be obtained from the
+ * model itself.
+ */
+public class CachedOp extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(CachedOp.class);
+
+    private final List<Parameter> parameters;
+    private final PairList<String, Integer> dataIndices;
+    private final Map<String, Integer> dataIndicesMap;
+    private final List<Integer> paramIndices;
+
+    /**
+     * Creates an instance of {@link CachedOp}.
+     *
+     * <p>It can be created by using {@link JnaUtils#createCachedOp(SymbolBlock, MxResource)}
+     *
+     * @param parent the MxResource object to manage this instance of CachedOp
+     * @param handle the C handle of the CachedOp
+     * @param parameters the parameter values
+     * @param paramIndices the parameters required by the model and their corresponding location
+     * @param dataIndices the input data names required by the model and their corresponding
+     *     location
+     */
+    public CachedOp(
+            MxResource parent,
+            Pointer handle,
+            List<Parameter> parameters,
+            List<Integer> paramIndices,
+            PairList<String, Integer> dataIndices) {
+        super(parent, handle);
+        this.parameters = parameters;
+        this.dataIndices = dataIndices;
+        this.paramIndices = paramIndices;
+        this.dataIndicesMap = dataIndices.toMap();
+    }
+
+    /**
+     * Invokes the cached op on {@code data}; unnamed arrays fill the data slots in order.
+     *
+     * @param data the input in {@link NDList} format
+     * @return an {@link NDList} wrapping the arrays returned by the native invocation
+     * @throws IllegalArgumentException if an input name is unknown or a default cannot be built
+     */
+    public NDList forward(NDList data) {
+        // one slot per entry of parameters; data inputs go to the slots listed in dataIndices
+        NDArray[] allInputsNDArray = new NDArray[parameters.size()];
+        // check device of input
+        Device device = data.isEmpty() ? Device.defaultIfNull() : data.head().getDevice();
+        // fill allInputsNDArray with parameter values on correct device
+        for (int index : paramIndices) {
+            Parameter parameter = parameters.get(index);
+            NDArray value = parameter.getArray();
+            if (value == null) {
+                throw new NullPointerException("Failed to find parameter from parameterStore");
+            }
+            value.setDevice(device);
+            allInputsNDArray[index] = value;
+        }
+
+        // fill allInputsNDArray with data values
+        int index = 0;
+        for (NDArray array : data) {
+            // TODO: NDArray name doesn't match. To confirm the format of input name
+            // String inputName = array.getName().split(":")[1];
+            String inputName = array.getName();
+            // if inputName not provided, value will follow the default order
+            int idx = indexOf(inputName, index++);
+            allInputsNDArray[idx] = array;
+        }
+
+        // check the input, set as Shape(batchSize) by default
+        for (Pair<String, Integer> pair : dataIndices) {
+            if (allInputsNDArray[pair.getValue()] == null) {
+                String key = pair.getKey();
+                if (data.isEmpty()) {
+                    throw new IllegalArgumentException(
+                            "Input " + key + " not found and no input data to infer its shape");
+                }
+                // TODO: Do we need to set default to the input?
+                long batchSize = data.head().getShape().get(0);
+                if (!"prob_label".equals(key) && !"softmax_label".equals(key)) {
+                    logger.warn(
+                            "Input {} not found, set NDArray to Shape({}) by default",
+                            key, batchSize);
+                }
+                // TODO: consider how to manage MxNDArray generated during inference
+                allInputsNDArray[pair.getValue()] =
+                        NDArray.create(this, new Shape(batchSize), device);
+            }
+        }
+        NDArray[] result = JnaUtils.cachedOpInvoke(getParent(), getHandle(), allInputsNDArray);
+        return new NDList(result);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug("Start to free CachedOp instance: {}", getUid());
+            super.freeSubResources();
+            Pointer pointer = handle.getAndSet(null);
+            if (pointer != null) {
+                JnaUtils.freeCachedOp(pointer);
+            }
+            setClosed(true);
+            logger.debug("Finish to free CachedOp instance: {}", getUid());
+        }
+    }
+
+    /** Resolves the slot of an input by name, or by position when the name is null. */
+    private int indexOf(String inputName, int position) {
+        if (inputName == null) {
+            return dataIndices.valueAt(position);
+        }
+        Integer index = dataIndicesMap.get(inputName);
+        if (index == null) {
+            throw new IllegalArgumentException(
+                    "Unknown input name: " + inputName
+                            + ", expected inputs: " + dataIndicesMap.keySet());
+        }
+        return index;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
new file mode 100644
index 0000000..7fb74c6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.cuda.CudaUtils;
+
+/**
+ * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@link
+ * org.apache.mxnet.ndarray.NDArray}.
+ *
+ * <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with
+ * deviceType and deviceId provided. Instances are immutable and obtained through {@link #of}.
+ */
+public final class Device {
+
+    private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();
+
+    private static final Device CPU = new Device(Type.CPU, -1);
+
+    private static final Device GPU = Device.of(Type.GPU, 0);
+
+    private static final Device DEFAULT_DEVICE = CPU;
+
+    private final String deviceType;
+    private final int deviceId;
+
+    /**
+     * Creates a {@code Device} with basic information.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can
+     *     choose which GPU to process the NDArray
+     */
+    private Device(String deviceType, int deviceId) {
+        this.deviceType = deviceType;
+        this.deviceId = deviceId;
+    }
+
+    /**
+     * Returns the cached {@code Device} for a type and id; CPU always maps to {@link #cpu()}.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware; ignored for the CPU type
+     * @return a {@code Device} instance
+     */
+    public static Device of(String deviceType, int deviceId) {
+        if (Type.CPU.equals(deviceType)) {
+            return CPU;
+        }
+        String key = deviceType + '-' + deviceId;
+        return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId));
+    }
+
+    /**
+     * Returns the device type of the Device.
+     *
+     * @return the device type of the Device
+     */
+    public String getDeviceType() {
+        return deviceType;
+    }
+
+    /**
+     * Returns the {@code deviceId} of a non-CPU Device.
+     *
+     * @return the {@code deviceId} of the Device
+     * @throws IllegalStateException if this is the CPU device
+     */
+    public int getDeviceId() {
+        if (Type.CPU.equals(deviceType)) {
+            throw new IllegalStateException("CPU doesn't have device id");
+        }
+        return deviceId;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        if (Type.CPU.equals(deviceType)) {
+            return deviceType + "()";
+        }
+        return deviceType + '(' + deviceId + ')';
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        Device device = (Device) o;
+        if (Type.CPU.equals(deviceType)) {
+            return Objects.equals(deviceType, device.getDeviceType());
+        }
+        return deviceId == device.deviceId && Objects.equals(deviceType, device.deviceType);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(deviceType, deviceId);
+    }
+
+    /**
+     * Returns the default CPU Device.
+     *
+     * @return the default CPU Device
+     */
+    public static Device cpu() {
+        return CPU;
+    }
+
+    /**
+     * Returns the default GPU Device, i.e. the GPU with device id 0.
+     *
+     * @return the default GPU Device
+     */
+    public static Device gpu() {
+        return GPU;
+    }
+
+    /**
+     * Returns the GPU {@code Device} with the specified {@code deviceId}; instances are cached.
+     *
+     * @param deviceId the GPU device ID
+     * @return the GPU {@code Device} with the specified {@code deviceId}
+     */
+    public static Device gpu(int deviceId) {
+        return of(Type.GPU, deviceId);
+    }
+
+    /**
+     * Returns an array of devices.
+     *
+     * <p>If GPUs are available, it will return one {@code Device} per available GPU. Else, it
+     * will return an array with a single CPU device.
+     *
+     * @return an array of devices
+     */
+    public static Device[] getDevices() {
+        return getDevices(Integer.MAX_VALUE);
+    }
+
+    /**
+     * Returns an array of devices given the maximum number of GPUs to use.
+     *
+     * <p>If GPUs are available, it will return an array of {@code Device} of size
+     * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device.
+     *
+     * @param maxGpus the max number of GPUs to use. Use 0 for no GPUs.
+     * @return an array of devices
+     */
+    public static Device[] getDevices(int maxGpus) {
+        int count = getGpuCount();
+        if (maxGpus <= 0 || count <= 0) {
+            return new Device[] {CPU};
+        }
+
+        count = Math.min(maxGpus, count);
+        Device[] devices = new Device[count];
+        for (int i = 0; i < devices.length; ++i) {
+            devices[i] = gpu(i);
+        }
+        return devices;
+    }
+
+    /**
+     * Returns the number of GPUs reported by {@link CudaUtils#getGpuCount()}.
+     *
+     * @return the number of GPUs available in the system
+     */
+    public static int getGpuCount() {
+        return CudaUtils.getGpuCount();
+    }
+
+    /**
+     * Returns the default device used by the engine.
+     *
+     * <p>The default is currently fixed to {@link #cpu()}; no GPU detection is performed when
+     * choosing it.
+     *
+     * @return a {@link Device}
+     */
+    private static Device defaultDevice() {
+        return DEFAULT_DEVICE;
+    }
+
+    /**
+     * Returns the given device or the default if it is null.
+     *
+     * @param device the device to try to return
+     * @return the given device or the default if it is null
+     */
+    public static Device defaultIfNull(Device device) {
+        if (device != null) {
+            return device;
+        }
+        return defaultDevice();
+    }
+
+    /**
+     * Returns the default device.
+     *
+     * @return the default device
+     */
+    public static Device defaultIfNull() {
+        return defaultIfNull(null);
+    }
+
+    /** Contains device type string constants. */
+    public interface Type {
+        String CPU = "cpu";
+        String GPU = "gpu";
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
new file mode 100644
index 0000000..02ad120
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/** {@code DeviceType} maps a {@link Device} type name to its MXNet type number and back. */
+public final class DeviceType {
+
+    private static final String CPU_PINNED = "cpu_pinned";
+
+    private DeviceType() {}
+
+    /**
+     * Converts a {@link Device} to the corresponding MXNet device number.
+     *
+     * @param device the java {@link Device}
+     * @return the MXNet device number: 1 for cpu, 2 for gpu, 3 for cpu_pinned
+     * @exception IllegalArgumentException the device is null or is not supported
+     */
+    public static int toDeviceType(Device device) {
+        if (device == null) {
+            throw new IllegalArgumentException("Unsupported device: null");
+        }
+
+        final String typeName = device.getDeviceType();
+        if (Device.Type.CPU.equals(typeName)) {
+            return 1;
+        }
+        if (Device.Type.GPU.equals(typeName)) {
+            return 2;
+        }
+        if (CPU_PINNED.equals(typeName)) {
+            return 3;
+        }
+        throw new IllegalArgumentException("Unsupported device: " + device);
+    }
+
+    /**
+     * Converts an MXNet device number to the matching {@link Device.Type} name.
+     *
+     * <p>Both 1 and 3 map to {@link Device.Type#CPU}; 2 maps to {@link Device.Type#GPU}.
+     *
+     * @param deviceType the MXNet device number
+     * @return the device type name
+     * @throws IllegalArgumentException if the device number is none of 1, 2 or 3
+     */
+    public static String fromDeviceType(int deviceType) {
+        // CPU_PINNED (3) is hidden from the frontend user and reported as plain CPU,
+        // but the advanced user can still create CPU_PINNED to pass through the engine
+        if (deviceType == 1 || deviceType == 3) {
+            return Device.Type.CPU;
+        }
+        if (deviceType == 2) {
+            return Device.Type.GPU;
+        }
+        throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
new file mode 100644
index 0000000..c388de3
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/** An enum that indicates whether gradient is required. */
+public enum GradReq {
+    NULL("null", 0),
+    WRITE("write", 1),
+    ADD("add", 3);
+
+    private final String type;
+    private final int value;
+
+    GradReq(String type, int value) {
+        this.type = type;
+        this.value = value;
+    }
+
+    /**
+     * Gets the type of this {@code GradReq}.
+     *
+     * @return the type name, e.g. {@code "write"}
+     */
+    public String getType() {
+        return type;
+    }
+
+    /**
+     * Gets the value of this {@code GradReq}.
+     *
+     * @return the integer value paired with the type name
+     */
+    public int getValue() {
+        return value;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
new file mode 100644
index 0000000..90cceb4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.translate.NoOpTranslator;
+import org.apache.mxnet.translate.Translator;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A model is a collection of artifacts that is created by the training process.
+ *
+ * <p>{@code Model} locates the symbol and parameter files in a model directory, loads them into a
+ * {@link SymbolBlock}, and creates {@link Predictor} instances for inference. The loaded graph is
+ * available through {@link #getSymbolBlock()}.
+ */
+public class Model extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Model.class);
+    protected Path modelDir;
+    protected SymbolBlock symbolBlock;
+    protected String modelName;
+    protected DataType dataType;
+    protected PairList<String, Shape> inputData;
+    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
+    protected Map<String, String> properties = new ConcurrentHashMap<>();
+
+    Model(String name, Device device) {
+        this(BaseMxResource.getSystemMxResource(), name, device);
+    }
+
+    private Model(MxResource parent, String name, Device device) {
+        super(parent);
+        setDevice(Device.defaultIfNull(device));
+        setDataType(DataType.FLOAT32);
+        setModelName(name);
+    }
+
+    /**
+     * Creates a default {@link Predictor} that uses a {@link NoOpTranslator}, i.e. works directly
+     * on {@link NDList} inputs and outputs.
+     *
+     * @return {@link Predictor}
+     */
+    public Predictor<NDList, NDList> newPredictor() {
+        Translator<NDList, NDList> noOpTranslator = new NoOpTranslator();
+        return newPredictor(noOpTranslator);
+    }
+
+    /**
+     * Creates a {@link Predictor} instance that uses the given {@link Translator}.
+     *
+     * @param translator {@link Translator} used to convert inputs and outputs into {@link NDList}
+     *     to get inferred
+     * @param <I> the input type
+     * @param <O> the output type
+     * @return {@link Predictor}
+     */
+    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
+        return new Predictor<>(this, translator);
+    }
+
+    /**
+     * Creates and initializes a {@code Model} named {@code "model"} from the model directory.
+     *
+     * @param modelPath {@code Path} model directory
+     * @return loaded {@code Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Path modelPath) throws IOException {
+        return loadModel("model", modelPath);
+    }
+
+    /**
+     * Creates and initializes a {@code Model} from a repository {@link Item}.
+     *
+     * @param modelItem {@link Item} describing the model to load
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Item modelItem) throws IOException {
+        Model model = createModel(modelItem);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates and initializes a {@code Model} with a model name from the model directory.
+     *
+     * @param modelName {@link String} model name
+     * @param modelPath {@link Path} model directory
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(String modelName, Path modelPath) throws IOException {
+        Model model = createModel(modelName, modelPath);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates a {@code Model} with a specific model name and model directory, without loading
+     * it. The instance is managed by the top level {@link BaseMxResource}.
+     *
+     * @param modelName {@link String} model name
+     * @param modelDir {@link Path} local model path
+     * @return {@link Model}
+     */
+    static Model createModel(String modelName, Path modelDir) {
+        Model model = new Model(modelName, Device.defaultIfNull());
+        model.setModelDir(modelDir);
+        return model;
+    }
+
+    /**
+     * Creates a {@code Model} for a repository item, using the directory that {@link
+     * Repository#initRepository(Item)} returns for it. The model is not loaded yet.
+     *
+     * @param item {@link Item} sample model to be created
+     * @return created {@link Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    static Model createModel(Item item) throws IOException {
+        Path modelDir = Repository.initRepository(item);
+        return createModel(item.getName(), modelDir);
+    }
+
+    /**
+     * Initializes the model by loading the symbol and parameters from the model directory that
+     * was assigned to this instance.
+     *
+     * @throws IOException when IO operation fails in loading a resource
+     * @throws FileNotFoundException if Model Directory is not assigned
+     */
+    public void initial() throws IOException {
+        if (getModelDir() == null) {
+            throw new FileNotFoundException("Model path is not defined!");
+        }
+        load(getModelDir());
+    }
+
+    /**
+     * Loads the model from the {@code modelPath}, using the model name as file prefix.
+     *
+     * @param modelPath the directory or file path of the model location
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public void load(Path modelPath) throws IOException {
+        load(modelPath, null, null);
+    }
+
+    /**
+     * Loads the MXNet model from a specified location.
+     *
+     * <p>Looks for {PREFIX}-symbol.json and {PREFIX}-{EPOCH}.params in the directory; if no
+     * parameter file matches the prefix, the directory name is tried as prefix. The epoch is
+     * resolved by {@code Utils.getCurrentEpoch} unless the {@code "epoch"} option is given:
+     *
+     * <pre>
+     * Map&lt;String, String&gt; options = new HashMap&lt;&gt;();
+     * <b>options.put("epoch", "3");</b>
+     * model.load(modelPath, "squeezenet", options);
+     * </pre>
+     *
+     * @param modelPath the directory of the model
+     * @param prefix the model file name prefix; {@code null} means the model name
+     * @param options load options ({@code "epoch"}, {@code "MxOptimizeFor"}); may be {@code null}
+     * @throws IOException if the parameter or symbol file cannot be found or loaded
+     */
+    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
+        modelDir = modelPath.toAbsolutePath();
+        if (prefix == null) {
+            prefix = modelName;
+        }
+        Path paramFile = paramPathResolver(prefix, options);
+        if (paramFile == null) {
+            prefix = modelDir.toFile().getName();
+            paramFile = paramPathResolver(prefix, options);
+            if (paramFile == null) {
+                throw new FileNotFoundException(
+                        "Parameter file with prefix: " + prefix + " not found in: " + modelDir);
+            }
+        }
+
+        if (getSymbolBlock() == null) {
+            // no block set yet: build the SymbolBlock from the symbol file next to the params
+            Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
+            if (Files.notExists(symbolFile)) {
+                throw new FileNotFoundException(
+                        "Symbol file not found: " + symbolFile
+                                + ", please set block manually for imperative model.");
+            }
+
+            // TODO: change default name "data" to model-specific one
+            setMxSymbolBlock(SymbolBlock.createMxSymbolBlock(this, symbolFile));
+        }
+        loadParameters(paramFile);
+        // TODO: Check if Symbol has all names that params file have
+        if (options != null && options.containsKey("MxOptimizeFor")) {
+            String optimization = (String) options.get("MxOptimizeFor");
+            getSymbolBlock().optimizeFor(optimization, getDevice());
+        }
+    }
+
+    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
+        try {
+            int epoch = getEpoch(prefix, options);
+            return getModelDir()
+                    .resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch));
+        } catch (FileNotFoundException e) {
+            return null;
+        }
+    }
+
+    private int getEpoch(String prefix, Map<String, ?> options) throws IOException {
+        if (options != null) {
+            Object epochOption = options.getOrDefault("epoch", null);
+            if (epochOption != null) {
+                return Integer.parseInt(epochOption.toString());
+            }
+        }
+        return Utils.getCurrentEpoch(getModelDir(), prefix);
+    }
+
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    private void loadParameters(Path paramFile) {
+        NDList paramNDlist = JnaUtils.loadNdArray(this, paramFile, getDevice());
+        List<Parameter> parameters = getSymbolBlock().getAllParameters();
+        Map<String, Parameter> map = new LinkedHashMap<>();
+        parameters.forEach(p -> map.put(p.getName(), p));
+
+        for (NDArray nd : paramNDlist) {
+            String key = nd.getName();
+            if (key == null) {
+                throw new IllegalArgumentException("Array names must be present in parameter file");
+            }
+            // keys are expected as "<prefix>:<name>"; a key without ':' is used as the name
+            String paramName = key.substring(key.indexOf(':') + 1);
+            Parameter parameter = map.remove(paramName);
+            if (parameter == null) {
+                throw new IllegalArgumentException(
+                        "Parameter " + paramName + " in " + paramFile + " is unknown to the block");
+            }
+            parameter.setArray(nd);
+        }
+        getSymbolBlock().setInputNames(new ArrayList<>(map.keySet()));
+
+        // TODO: Find a better to infer model DataType from SymbolBlock.
+        dataType = paramNDlist.head().getDataType();
+        logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
+    }
+
+    /**
+     * Get the modelDir from the Model.
+     *
+     * @return {@link Path} modelDir for the Model
+     */
+    public Path getModelDir() {
+        return modelDir;
+    }
+
+    /**
+     * Set the modelDir for the Model.
+     *
+     * @param modelDir {@link Path}
+     */
+    public void setModelDir(Path modelDir) {
+        this.modelDir = modelDir;
+    }
+
+    /**
+     * Get the symbolBlock of the Model.
+     *
+     * @return {@link SymbolBlock}, or {@code null} if none has been set or the Model is closed
+     */
+    public SymbolBlock getSymbolBlock() {
+        return symbolBlock;
+    }
+
+    /**
+     * Set the symbolBlock for the Model.
+     *
+     * @param symbolBlock {@link SymbolBlock}
+     */
+    public void setMxSymbolBlock(SymbolBlock symbolBlock) {
+        this.symbolBlock = symbolBlock;
+    }
+
+    /**
+     * Get the name of the Model.
+     *
+     * @return modelName
+     */
+    public String getModelName() {
+        return modelName;
+    }
+
+    /**
+     * Set the model name for the Model.
+     *
+     * @param modelName for the Model
+     */
+    public final void setModelName(String modelName) {
+        this.modelName = modelName;
+    }
+
+    /**
+     * Get data type for the Model.
+     *
+     * @return {@link DataType}
+     */
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    /**
+     * Set data type for the Model.
+     *
+     * @param dataType {@link DataType}
+     */
+    public final void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    /**
+     * Get input data of the Model.
+     *
+     * @return {@link PairList} inputData
+     */
+    public PairList<String, Shape> getInputData() {
+        return inputData;
+    }
+
+    /**
+     * Set input data for the Model.
+     *
+     * @param inputData {@link PairList}
+     */
+    public void setInputData(PairList<String, Shape> inputData) {
+        this.inputData = inputData;
+    }
+
+    /**
+     * Get the Artifact Object from artifacts by key.
+     *
+     * @param key for the Artifact Object
+     * @return Artifact {@link Object} instance, or {@code null} if the key is not present
+     */
+    public Object getArtifact(String key) {
+        return artifacts.get(key);
+    }
+
+    /**
+     * Set the Artifact Object for artifacts.
+     *
+     * @param key for the Artifact
+     * @param artifact {@link Object}
+     */
+    public void setArtifact(String key, Object artifact) {
+        artifacts.put(key, artifact);
+    }
+
+    /**
+     * Get the property from properties by key.
+     *
+     * @param key {@link String}
+     * @return {@link String} property, or {@code null} if the key is not present
+     */
+    public String getProperty(String key) {
+        return properties.get(key);
+    }
+
+    /**
+     * Set the property for the Model.
+     *
+     * @param key for the property
+     * @param property value of the property
+     */
+    public void setProperties(String key, String property) {
+        this.properties.put(key, property);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        if (device == null) {
+            return super.getDevice();
+        }
+        return device;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug("Start to free Model instance: {}", getModelName());
+            // release sub resources
+            super.freeSubResources();
+            // drop own references; the maps are cleared, not nulled, so accessors stay usable
+            this.symbolBlock = null;
+            this.artifacts.clear();
+            this.properties.clear();
+            setClosed(true);
+            logger.debug("Finish to free Model instance: {}", getModelName());
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
new file mode 100644
index 0000000..b96941f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.types.DataType;
+
+ /** Helper to convert between {@link DataType} and the MXNet internal DataTypes. */
+ public final class MxDataType {
+
+     private static final Map<DataType, String> TO_MX = createMapToMx();
+     private static final Map<String, DataType> FROM_MX = createMapFromMx();
+
+     private MxDataType() {}
+
+     /** Builds the table from {@link DataType} to MXNet type name; the only hand-written list. */
+     private static Map<DataType, String> createMapToMx() {
+         Map<DataType, String> map = new ConcurrentHashMap<>();
+         map.put(DataType.FLOAT32, "float32");
+         map.put(DataType.FLOAT64, "float64");
+         map.put(DataType.INT32, "int32");
+         map.put(DataType.INT64, "int64");
+         map.put(DataType.UINT8, "uint8");
+         return map;
+     }
+
+     /** Inverts {@link #TO_MX} (declared first, so already populated) to keep both in sync. */
+     private static Map<String, DataType> createMapFromMx() {
+         Map<String, DataType> map = new ConcurrentHashMap<>();
+         for (Map.Entry<DataType, String> entry : TO_MX.entrySet()) {
+             map.put(entry.getValue(), entry.getKey());
+         }
+         return map;
+     }
+
+     /**
+      * Converts a MXNet type String into a {@link DataType}.
+      *
+      * @param mxType the type String to convert; must not be {@code null}
+      * @return the {@link DataType}, or {@code null} if the name has no mapping
+      */
+     public static DataType fromMx(String mxType) {
+         return FROM_MX.get(mxType);
+     }
+
+     /**
+      * Converts a {@link DataType} into the corresponding MXNet type String.
+      *
+      * @param jType the java {@link DataType} to convert; must not be {@code null}
+      * @return the converted MXNet type string, or {@code null} if the type has no mapping
+      */
+     public static String toMx(DataType jType) {
+         return TO_MX.get(jType);
+     }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
new file mode 100644
index 0000000..5740d9b
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.NativeResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An auto closable Resource object whose life circle can be managed by its parent {@link
+ * MxResource} instance. Meanwhile, it manages life circle of child {@link MxResource} instances.
+ */
+public class MxResource extends NativeResource<Pointer> {
+
+ private static final Logger logger = LoggerFactory.getLogger(MxResource.class);
+
+ private static boolean closed;
+
+ protected Device device;
+
+ private MxResource parent;
+
+ private ConcurrentHashMap<String, MxResource> subResources;
+
+ protected MxResource() {
+ super();
+ setParent(null);
+ }
+
+ protected MxResource(MxResource parent, String uid) {
+ super(uid);
+ setClosed(false);
+ setParent(parent);
+ getParent().addSubResource(this);
+ }
+
+ protected MxResource(MxResource parent) {
+ this(parent, UUID.randomUUID().toString());
+ }
+
+ protected MxResource(MxResource parent, Pointer handle) {
+ super(handle);
+ setParent(parent);
+ if (parent != null) {
+ parent.addSubResource(this);
+ } else {
+ BaseMxResource.getSystemMxResource().addSubResource(this);
+ }
+ }
+ /**
+ * Add the sub {@link MxResource} under the current instance.
+ *
+ * @param subMxResource the instance to be added
+ */
+ public void addSubResource(MxResource subMxResource) {
+ getSubResource().put(subMxResource.getUid(), subMxResource);
+ }
+
+ /** Free all sub {@link MxResource} instances of the current instance. */
+ public void freeSubResources() {
+ if (subResourceInitialized()) {
+ for (MxResource subResource : subResources.values()) {
+ try {
+ subResource.close();
+ } catch (Exception e) {
+ logger.error("MxResource close failed.", e);
+ }
+ }
+ subResources = null;
+ }
+ }
+
+ /**
+ * Check whether {@code subResource} has been initialized.
+ *
+ * @return boolean
+ */
+ public boolean subResourceInitialized() {
+ return subResources != null;
+ }
+
+ /**
+ * Get the {@code subResources} of the {@link MxResource}.
+ *
+ * @return subResources
+ */
+ public ConcurrentHashMap<String, MxResource> getSubResource() {
+ if (!subResourceInitialized()) {
+ subResources = new ConcurrentHashMap<>();
+ }
+ return subResources;
+ }
+
+ protected final void setParent(MxResource parent) {
+ this.parent = parent;
+ }
+
+ /**
+ * Get parent {@link MxResource} of the current instance.
+ *
+ * @return {@link MxResource}
+ */
+ public MxResource getParent() {
+ return this.parent;
+ }
+
+ /**
+ * Set the {@link Device} for the {@link MxResource}.
+ *
+ * @param device {@link Device}
+ */
+ public void setDevice(Device device) {
+ this.device = device;
+ }
+
+ /**
+ * Returns the {@link Device} of this {@code MxResource}.
+ *
+ * <p>{@link Device} class contains the information where this {@code NDArray} stored in memory,
+ * like CPU/GPU.
+ *
+ * @return the {@link Device} of this {@code MxResource}
+ */
+ public Device getDevice() {
+ Device curDevice = getParent() == null ? null : getParent().getDevice();
+ return Device.defaultIfNull(curDevice);
+ }
+
+ /**
+ * Sets closed for MxResource instance.
+ *
+ * @param isClosed whether this {@link MxResource} get closed
+ */
+ public final void setClosed(boolean isClosed) {
+ this.closed = isClosed;
+ }
+
+ /**
+ * Get the attribute closed for the MxResource to check out whether it is closed.
+ *
+ * @return closed
+ */
+ public boolean getClosed() {
+ return closed;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ freeSubResources();
+ setClosed(true);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
new file mode 100644
index 0000000..6869cf6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/**
+ * An {@code MxResourceList} represents a sequence of {@link MxResource}s with names.
+ *
+ * <p>Each {@link MxResource} in this list can optionally have a name. You can use the name to look
+ * up an MxResource in the MxResourceList.
+ *
+ * @see MxResource
+ */
+public class MxResourceList extends PairList<String, MxResource> {
+
+ /** Creates an empty {@code MxResourceList}. */
+ public MxResourceList() {}
+
+ /**
+ * Constructs an empty {@code MxResourceList} with the specified initial capacity.
+ *
+ * @param initialCapacity the initial capacity of the list
+ * @throws IllegalArgumentException if the specified initial capacity is negative
+ */
+ public MxResourceList(int initialCapacity) {
+ super(initialCapacity);
+ }
+
+ /**
+ * Constructs a {@code BlockList} containing the elements of the specified keys and values.
+ *
+ * @param keys the key list containing the elements to be placed into this {@code
+ * MxResourceList}
+ * @param values the value list containing the elements to be placed into this {@code
+ * MxResource}
+ * @throws IllegalArgumentException if the keys and values size are different
+ */
+ public MxResourceList(List<String> keys, List<MxResource> values) {
+ super(keys, values);
+ }
+
+ /**
+ * Constructs a {@code BlockList} containing the elements of the specified list of Pairs.
+ *
+ * @param list the list containing the elements to be placed into this {@code MxResourceList}
+ */
+ public MxResourceList(List<Pair<String, MxResource>> list) {
+ super(list);
+ }
+
+ /**
+ * Constructs a {@code BlockList} containing the elements of the specified map.
+ *
+ * @param map the map containing keys and values
+ */
+ public MxResourceList(Map<String, MxResource> map) {
+ super(map);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
new file mode 100644
index 0000000..3362e03
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+
+ /** Internal helper that collects the parameters passed to MXNet operators. */
+ public class OpParams extends PairList<String, Object> {
+     // A cpu device is always written as "cpu(0)", whatever its toString() would be.
+     private static final String MXNET_CPU = "cpu(0)";
+
+     /**
+      * Records the shape under the {@code "shape"} key.
+      *
+      * @param shape the shape to record; nothing is added when it is {@code null}
+      */
+     public void setShape(Shape shape) {
+         addParam("shape", shape);
+     }
+
+     /**
+      * Records the device under the {@code "ctx"} key via {@link #setParam(String, String)}.
+      *
+      * @param device the device to use for the operation
+      */
+     public void setDevice(Device device) {
+         boolean onCpu = "cpu".equals(device.getDeviceType());
+         setParam("ctx", onCpu ? MXNET_CPU : device.toString());
+     }
+
+     /**
+      * Records the data type under the {@code "dtype"} key via {@link #setParam(String, String)}.
+      *
+      * @param dataType the dataType to use; {@code null} leaves the list untouched
+      */
+     public void setDataType(org.apache.mxnet.ndarray.types.DataType dataType) {
+         if (dataType == null) {
+             return;
+         }
+         setParam("dtype", MxDataType.toMx(dataType));
+     }
+
+     /**
+      * Records the sparse format under the {@code "stype"} key via {@link #setParam(String, String)}.
+      *
+      * @param sparseFormat the sparseFormat to use; {@code null} leaves the list untouched
+      */
+     public void setSparseFormat(SparseFormat sparseFormat) {
+         if (sparseFormat == null) {
+             return;
+         }
+         setParam("stype", String.valueOf(sparseFormat.getValue()));
+     }
+
+     /**
+      * Replaces a parameter by calling {@code remove(paramName)} and then adding the new pair.
+      *
+      * @param paramName the parameter name to update
+      * @param value the value to set the parameter to
+      */
+     public void setParam(String paramName, String value) {
+         remove(paramName);
+         add(paramName, value);
+     }
+
+     /**
+      * Appends a {@link Shape} parameter, stored as {@code shape.toString()}.
+      *
+      * @param paramName the name of the new parameter
+      * @param shape the value of the new parameter; {@code null} is skipped
+      */
+     public void addParam(String paramName, Shape shape) {
+         if (shape == null) {
+             return;
+         }
+         add(paramName, shape.toString());
+     }
+
+     /**
+      * Appends a String parameter unchanged.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, String value) {
+         add(paramName, value);
+     }
+
+     /**
+      * Appends an int parameter in its decimal String form.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, int value) {
+         add(paramName, Integer.toString(value));
+     }
+
+     /**
+      * Appends a long parameter in its decimal String form.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, long value) {
+         add(paramName, Long.toString(value));
+     }
+
+     /**
+      * Appends a double parameter in its {@link Double#toString(double)} form.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, double value) {
+         add(paramName, Double.toString(value));
+     }
+
+     /**
+      * Appends a float parameter in its {@link Float#toString(float)} form.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, float value) {
+         add(paramName, Float.toString(value));
+     }
+
+     /**
+      * Appends a boolean parameter written as {@code "True"} or {@code "False"}.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, boolean value) {
+         String literal = value ? "True" : "False";
+         add(paramName, literal);
+     }
+
+     /**
+      * Appends a {@link Number} parameter in its {@code String.valueOf} form.
+      *
+      * @param paramName the name of the new parameter
+      * @param value the value of the new parameter
+      */
+     public void addParam(String paramName, Number value) {
+         add(paramName, String.valueOf(value));
+     }
+
+     /**
+      * Appends an int tuple rendered as {@code "(v0, v1, ...)"}; no values gives {@code "()"}.
+      *
+      * @param paramName the name of the new parameter
+      * @param tuple the values of the new parameter
+      */
+     public void addTupleParam(String paramName, int... tuple) {
+         StringBuilder text = new StringBuilder("(");
+         String separator = "";
+         for (int element : tuple) {
+             text.append(separator).append(element);
+             separator = ", ";
+         }
+         text.append(')');
+         add(paramName, text.toString());
+     }
+
+     /**
+      * Appends a long tuple rendered as {@code "(v0, v1, ...)"}; no values gives {@code "()"}.
+      *
+      * @param paramName the name of the new parameter
+      * @param tuple the values of the new parameter
+      */
+     public void addTupleParam(String paramName, long... tuple) {
+         StringBuilder text = new StringBuilder("(");
+         String separator = "";
+         for (long element : tuple) {
+             text.append(separator).append(element);
+             separator = ", ";
+         }
+         text.append(')');
+         add(paramName, text.toString());
+     }
+
+     /**
+      * Appends a float tuple rendered as {@code "(v0, v1, ...)"}; no values gives {@code "()"}.
+      *
+      * @param paramName the name of the new parameter
+      * @param tuple the values of the new parameter
+      */
+     public void addTupleParam(String paramName, float... tuple) {
+         StringBuilder text = new StringBuilder("(");
+         String separator = "";
+         for (float element : tuple) {
+             text.append(separator).append(element);
+             separator = ", ";
+         }
+         text.append(')');
+         add(paramName, text.toString());
+     }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
new file mode 100644
index 0000000..739f6e0
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.mxnet.exception.TranslateException;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.translate.Translator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code Predictor} class provides a session for model inference.
+ *
+ * <p>You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on
+ * a {@link Model}
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ * @see Model
+ * @see Translator
+ * @see <a href="http://docs.djl.ai/docs/development/memory_management.html">The guide on memory
+ * management</a>
+ * @see <a
+ * href="https://github.com/deepjavalibrary/djl/blob/master/examples/docs/multithread_inference.md">The
+ * guide on running multi-threaded inference</a>
+ * @see <a href="http://docs.djl.ai/docs/development/inference_performance_optimization.html">The
+ * guide on inference performance optimization</a>
+ */
+public class Predictor<I, O> extends MxResource {
+
+ private static final Logger logger = LoggerFactory.getLogger(Predictor.class);
+ private Translator<I, O> translator;
+ private Model model;
+
+ /**
+ * Creates a new instance of {@code Predictor} with the given {@link Model} and {@link
+ * Translator}.
+ *
+ * @param model the model on which the predictions are based
+ * @param translator the translator to be used
+ */
+ public Predictor(Model model, Translator<I, O> translator) {
+ super(model);
+ this.model = model;
+ this.translator = translator;
+ }
+
+ /**
+ * Predicts an item for inference.
+ *
+ * @param input the input
+ * @return the output object defined by the user
+ * @throws TranslateException if an error occurs during prediction
+ */
+ @SuppressWarnings("PMD.AvoidRethrowingException")
+ public List<O> predict(List<I> input) {
+ NDList[] ndLists = processInputs(input);
+ for (int i = 0; i < ndLists.length; ++i) {
+ ndLists[i] = forward(ndLists[i]);
+ }
+ return processOutPut(ndLists);
+ }
+
+ /**
+ * Predicts an Item for inference.
+ *
+ * @param input input data
+ * @return O the output object defined by the user
+ * @throws TranslateException if an error occurs during prediction
+ */
+ public O predict(I input) {
+ return predict(Collections.singletonList(input)).get(0);
+ }
+
+ private NDList forward(NDList ndList) {
+ logger.trace("Predictor input data: {}", ndList);
+ return model.getSymbolBlock().forward(ndList);
+ }
+
+ // TODO: add batch predict
+
+ private NDList[] processInputs(List<I> inputs) {
+ int batchSize = inputs.size();
+ NDList[] preprocessed = new NDList[batchSize];
+ try {
+ for (int i = 0; i < batchSize; ++i) {
+ preprocessed[i] = translator.processInput(inputs.get(i));
+ }
+ } catch (Exception e) {
+ logger.error("Error occurs when process input items.", e);
+ throw new TranslateException(e);
+ }
+ return preprocessed;
+ }
+
+ private List<O> processOutPut(NDList[] ndLists) {
+ List<O> outputs = new ArrayList<>();
+ try {
+ for (NDList mxNDList : ndLists) {
+ outputs.add(translator.processOutput(mxNDList));
+ }
+ } catch (Exception e) {
+ logger.error("Error occurs when process output items.", e);
+ throw new TranslateException(e);
+ }
+ return outputs;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
new file mode 100644
index 0000000..7777add
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Symbol} is an internal helper for symbolic model graphs used by the {@link
+ * org.apache.mxnet.nn.SymbolBlock}.
+ *
+ * @see org.apache.mxnet.nn.SymbolBlock
+ */
+public class Symbol extends MxResource {
+
+ private static final Logger logger = LoggerFactory.getLogger(Symbol.class);
+
+ private String[] outputs;
+
+ protected Symbol(MxResource parent, Pointer handle) {
+ super(parent, handle);
+ }
+
+ static Symbol loadFromFile(MxResource parent, String path) {
+ Pointer p = JnaUtils.createSymbolFromFile(path);
+ return new Symbol(parent, p);
+ }
+
+ /**
+ * Load {@link Symbol} from the given {@link Path}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param path the {@link Path} to load the {@link Symbol}
+ * @return {@link Symbol}
+ */
+ public static Symbol loadSymbol(MxResource parent, Path path) {
+ return loadFromFile(parent, path.toAbsolutePath().toString());
+ }
+
+ /**
+ * Loads a symbol from a json string.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param json the json string of the symbol.
+ * @return the new symbol
+ */
+ public static Symbol loadJson(MxResource parent, String json) {
+ Pointer pointer = JnaUtils.createSymbolFromString(json);
+ return new Symbol(parent, pointer);
+ }
+
+ /**
+ * Returns the symbol outputs.
+ *
+ * @return the symbol outputs
+ */
+ public String[] getOutputNames() {
+ if (this.outputs == null) {
+ this.outputs = JnaUtils.listSymbolOutputs(getHandle());
+ }
+ return this.outputs;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ if (!getClosed()) {
+ logger.debug(String.format("Start to free Symbol instance: %S", this.toJsonString()));
+ super.freeSubResources();
+ Pointer pointer = handle.getAndSet(null);
+ if (pointer != null) {
+ JnaUtils.freeSymbol(pointer);
+ }
+ setClosed(true);
+ logger.debug(String.format("Finish to free Symbol instance: %S", this.toJsonString()));
+ }
+ }
+
+ /**
+ * Returns the output symbol by index.
+ *
+ * @param index the index of the output
+ * @return the symbol output as a new symbol
+ */
+ public Symbol get(int index) {
+ Pointer pointer = JnaUtils.getSymbolOutput(getInternals().getHandle(), index);
+ return new Symbol(getParent(), pointer);
+ }
+
+ /**
+ * Returns the output symbol with the given name.
+ *
+ * @param name the name of the symbol to return
+ * @return the output symbol
+ * @throws IllegalArgumentException Thrown if no output matches the name
+ */
+ public Symbol get(String name) {
+ String[] out = getInternalOutputNames();
+ int index = Utils.indexOf(out, name);
+ if (index < 0) {
+ throw new IllegalArgumentException("Cannot find output that matches name: " + name);
+ }
+ return get(index);
+ }
+
+ /**
+ * Returns the symbol argument names.
+ *
+ * @return the symbol argument names
+ */
+ public String[] getArgNames() {
+ return JnaUtils.listSymbolArguments(getHandle());
+ }
+
+ /**
+ * Returns the MXNet auxiliary states for the symbol.
+ *
+ * @return the MXNet auxiliary states for the symbol
+ */
+ public String[] getAuxNames() {
+ return JnaUtils.listSymbolAuxiliaryStates(getHandle());
+ }
+
+ /**
+ * Returns the symbol names.
+ *
+ * @return the symbol names
+ */
+ public String[] getAllNames() {
+ return JnaUtils.listSymbolNames(getHandle());
+ }
+
+ /**
+ * Returns the list of names for all internal outputs.
+ *
+ * @return a list of names
+ */
+ public List<String> getLayerNames() {
+ String[] outputNames = getInternalOutputNames();
+ String[] allNames = getAllNames();
+ Set<String> allNamesSet = new LinkedHashSet<>(Arrays.asList(allNames));
+ // Kill all params field and keep the output layer
+ return Arrays.stream(outputNames)
+ .filter(n -> !allNamesSet.contains(n))
+ .collect(Collectors.toList());
+ }
+
+ private String[] getInternalOutputNames() {
+ return JnaUtils.listSymbolOutputs(getInternals().getHandle());
+ }
+
+ /**
+ * Returns the symbol internals.
+ *
+ * @return the symbol internals symbol
+ */
+ public Symbol getInternals() {
+ Pointer pointer = JnaUtils.getSymbolInternals(getHandle());
+ return new Symbol(getParent(), pointer);
+ }
+
+ /**
+ * Infers the shapes for all parameters inside a symbol from the given input shapes.
+ *
+ * @param pairs the given input name and shape
+ * @return a map of arguments with names and shapes
+ */
+ public Map<String, Shape> inferShape(PairList<String, Shape> pairs) {
+ List<List<Shape>> shapes = JnaUtils.inferShape(this, pairs);
+ if (shapes == null) {
+ throw new IllegalArgumentException("Cannot infer shape based on the data provided!");
+ }
+ List<Shape> argShapes = shapes.get(0);
+ List<Shape> outputShapes = shapes.get(1);
+ List<Shape> auxShapes = shapes.get(2);
+ // TODO: add output to the map
+ String[] argNames = getArgNames();
+ String[] auxNames = getAuxNames();
+ String[] outputNames = getOutputNames();
+ Map<String, Shape> shapesMap = new ConcurrentHashMap<>();
+ for (int i = 0; i < argNames.length; i++) {
+ shapesMap.put(argNames[i], argShapes.get(i));
+ }
+ for (int i = 0; i < auxNames.length; i++) {
+ shapesMap.put(auxNames[i], auxShapes.get(i));
+ }
+ for (int i = 0; i < outputNames.length; i++) {
+ shapesMap.put(outputNames[i], outputShapes.get(i));
+ }
+ return shapesMap;
+ }
+
+ /**
+ * [Experimental] Add customized optimization on the Symbol.
+ *
+ * <p>This method can be used with EIA or TensorRT for model acceleration
+ *
+ * @param backend backend name
+ * @param device the device assigned
+ * @return optimized Symbol
+ */
+ public Symbol optimizeFor(String backend, Device device) {
+ return new Symbol(getParent(), JnaUtils.optimizeFor(this, backend, device));
+ }
+
+ /**
+ * Converts Symbol to json string for saving purpose.
+ *
+ * @return the json string
+ */
+ public String toJsonString() {
+ return JnaUtils.getSymbolString(getHandle());
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
new file mode 100644
index 0000000..e88f05d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.engine;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
new file mode 100644
index 0000000..0123a50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+ /** Thrown to indicate that a native error is raised from the underlying native library. */
+ public class BaseException extends RuntimeException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message. The detail message is saved for later retrieval by the
+ * {@link #getMessage()} method.
+ */
+ public BaseException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message (which is saved for later retrieval by the {@link
+ * #getMessage()} method)
+ * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+ * method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+ * or unknown.)
+ */
+ public BaseException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} (which typically contains the class and detail
+ * message of {@code cause}). This constructor is useful for exceptions that are little more
+ * than wrappers for other throwables (for example, {@link
+ * java.security.PrivilegedActionException}).
+ *
+ * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+ * method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+ * or unknown.)
+ */
+ public BaseException(Throwable cause) {
+ super(cause);
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
new file mode 100644
index 0000000..f6f26fc
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that a JNA function call did not behave as expected. */
+public class JnaCallException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message. The detail message is saved for later retrieval by the
+ * {@link #getMessage()} method.
+ */
+ public JnaCallException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public JnaCallException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables, for example {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public JnaCallException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
new file mode 100644
index 0000000..5455633
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that model parameters are not in the expected format or are malformed. */
+public class MalformedModelException extends ModelException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message. The detail message is saved for later retrieval by the
+ * {@link #getMessage()} method.
+ */
+ public MalformedModelException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public MalformedModelException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables, for example {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public MalformedModelException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
new file mode 100644
index 0000000..94a12e7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate a model-related failure; {@link MalformedModelException} extends it. */
+public class ModelException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ */
+ public ModelException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public ModelException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables, for example {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public ModelException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
new file mode 100644
index 0000000..b17758a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that the translate pipeline did not work as expected. */
+public class TranslateException extends BaseException {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a new exception with the specified detail message. The cause is not initialized,
+ * and may subsequently be initialized by a call to {@link #initCause}.
+ *
+ * @param message the detail message. The detail message is saved for later retrieval by the
+ * {@link #getMessage()} method.
+ */
+ public TranslateException(String message) {
+ super(message);
+ }
+
+ /**
+ * Constructs a new exception with the specified detail message and cause.
+ *
+ * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+ * incorporated in this exception's detail message.
+ *
+ * @param message the detail message that is saved for later retrieval by the {@link
+ * #getMessage()} method
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public TranslateException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ /**
+ * Constructs a new exception with the specified cause and a detail message of {@code
+ * (cause==null ? null : cause.toString())} which typically contains the class and detail
+ * message of {@code cause}. This constructor is useful for exceptions that are little more than
+ * wrappers for other throwables, for example {@link java.security.PrivilegedActionException}.
+ *
+ * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+ * {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+ */
+ public TranslateException(Throwable cause) {
+ super(cause);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
new file mode 100644
index 0000000..d464bfd
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the exception classes thrown by the Java front-end for Apache MXNet. */
+package org.apache.mxnet.exception;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
new file mode 100644
index 0000000..a2da116
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A FunctionInfo represents an operator (i.e. a function) within the MXNet Engine. */
+public class FunctionInfo {
+
+ private final Pointer handle;
+ private final String name;
+ private final PairList<String, String> arguments;
+
+ private static final Logger logger = LoggerFactory.getLogger(FunctionInfo.class);
+
+ FunctionInfo(Pointer pointer, String functionName, PairList<String, String> arguments) {
+ this.handle = pointer;
+ this.name = functionName;
+ this.arguments = arguments;
+ }
+
+ /**
+ * Returns the name of the operator.
+ *
+ * @return the name of the operator
+ */
+ public String getFunctionName() {
+ return name;
+ }
+
+ /**
+ * Returns the names of the params to the operator.
+ *
+ * @return the names of the params to the operator
+ */
+ public List<String> getArgumentNames() {
+ return arguments.keys();
+ }
+
+ /**
+ * Returns the types of the operator arguments.
+ *
+ * @return the types of the operator arguments
+ */
+ public List<String> getArgumentTypes() {
+ return arguments.values();
+ }
+
+ /**
+ * Calls an operator with the given arguments, passing {@code dest} as the output arrays.
+ *
+ * @param src the input NDArray(s) to the operator
+ * @param dest the destination NDArray(s) to be overwritten with the result of the operator
+ * @param params the non-NDArray arguments, as a {@code PairList<String, String>}
+ * @return the number of outputs reported by the native invocation (this is not an error code)
+ */
+ public int invoke(NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+ checkDevices(src);
+ checkDevices(dest);
+ return JnaUtils.imperativeInvoke(handle, src, dest, params).size();
+ }
+
+ /**
+ * Calls an operator with the given arguments and wraps its outputs in new NDArrays.
+ *
+ * @param parent {@link MxResource} for the current instance
+ * @param src the input NDArray(s) to the operator
+ * @param params the non-NDArray arguments, as a {@code PairList<String, String>}
+ * @return the NDArrays created from the output handles of the native invocation; a non-dense
+ * output is created with its reported {@link SparseFormat}
+ */
+ public NDArray[] invoke(MxResource parent, NDArray[] src, PairList<String, ?> params) {
+ checkDevices(src);
+ PairList<Pointer, SparseFormat> pairList =
+ JnaUtils.imperativeInvoke(handle, src, null, params);
+ return pairList.stream()
+ .map(
+ pair -> {
+ if (pair.getValue() != SparseFormat.DENSE) {
+ return NDArray.create(parent, pair.getKey(), pair.getValue());
+ }
+ return NDArray.create(parent, pair.getKey());
+ })
+ .toArray(NDArray[]::new);
+ }
+
+ private void checkDevices(NDArray[] src) {
+ // debug-only check that all NDArrays share one device; a null array is skipped, not an NPE
+ if (src != null && src.length > 1 && logger.isDebugEnabled()) {
+ Device device = src[0].getDevice();
+ for (int i = 1; i < src.length; ++i) {
+ if (!device.equals(src[i].getDevice())) {
+ logger.warn(
+ "Please make sure all the NDArrays are in the same device. You can call toDevice() to move the NDArray to the desired Device.");
+ }
+ }
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
new file mode 100644
index 0000000..46cff30
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
@@ -0,0 +1,893 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import com.sun.jna.ptr.PointerByReference;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.CachedOp;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.DeviceType;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.exception.JnaCallException;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class containing utilities to interact with the MXNet Engine's Java Native Access (JNA) layer.
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class JnaUtils {
+
+ private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class);
+
+ public static final MxnetLibrary LIB = LibUtils.loadLibrary();
+
+ public static final ObjectPool<PointerByReference> REFS =
+ new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));
+
+ private static final String[] OP_NAME_PREFIX = {
+ "_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"
+ };
+
+ private static final Map<String, FunctionInfo> OPS = getNdArrayFunctions();
+ // private static final Map<String, FunctionInfo> OPS = null;
+
+ private static final Set<String> FEATURES = getFeaturesInternal();
+
+ public static final String[] EMPTY_ARRAY = new String[0];
+
+ private JnaUtils() {
+ // not called
+ }
+
+ /** An enum that enumerates the statuses of numpy mode. */
+ public enum NumpyMode {
+ OFF,
+ THREAD_LOCAL_ON,
+ GLOBAL_ON
+ }
+
+ public static void waitAll() {
+ checkCall(LIB.MXNDArrayWaitAll());
+ }
+
+ public static void setNumpyMode(NumpyMode mode) {
+ // The int buffer only satisfies the native signature; whatever is written to it is ignored.
+ checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), IntBuffer.allocate(1)));
+ }
+
+ /////////////////////////////////
+ // Related to CachedOp
+ /////////////////////////////////
+ public static CachedOp createCachedOp(SymbolBlock block, MxResource parent) {
+ Symbol symbol = block.getSymbol();
+ List<Parameter> parameters = block.getAllParameters();
+
+ // Walk the block's parameters once and sort each position into one of two groups:
+ // positions of data inputs (kept together with the parameter name) and positions of
+ // parameters that are already initialized.
+ PairList<String, Integer> dataIndices = new PairList<>();
+ List<Integer> paramIndices = new ArrayList<>();
+
+ for (int position = 0; position < parameters.size(); ++position) {
+ Parameter candidate = parameters.get(position);
+ // An uninitialized parameter is assumed to be a data input.
+ if (!candidate.isInitialized()) {
+ dataIndices.add(candidate.getName(), position);
+ } else {
+ paramIndices.add(position);
+ }
+ }
+
+ // Options for the native call; static_alloc and static_shape are both passed as "1".
+ String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
+ String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
+
+ // Create the native CachedOp, reading its handle out of the pooled reference holder
+ // before the holder is handed back.
+ Pointer symbolHandle = symbol.getHandle();
+ PointerByReference out = REFS.acquire();
+ checkCall(LIB.MXCreateCachedOp(symbolHandle, keys.length, keys, values, out, (byte) 0));
+ Pointer cachedOpHandle = out.getValue();
+ REFS.recycle(out);
+
+ return new CachedOp(parent, cachedOpHandle, parameters, paramIndices, dataIndices);
+ }
+
+ public static void freeCachedOp(Pointer handle) {
+ checkCall(LIB.MXFreeCachedOp(handle));
+ }
+
+ /////////////////////////////////
+ // About Symbol
+ /////////////////////////////////
+ public static Pointer createSymbolFromFile(String path) {
+ PointerByReference out = REFS.acquire();
+ checkCall(LIB.MXSymbolCreateFromFile(path, out));
+ Pointer symbolHandle = out.getValue(); // read before the holder returns to the pool
+ REFS.recycle(out);
+ return symbolHandle;
+ }
+
+ public static Pointer createSymbolFromString(String json) {
+ PointerByReference out = REFS.acquire();
+ checkCall(LIB.MXSymbolCreateFromJSON(json, out));
+ Pointer symbolHandle = out.getValue(); // read before the holder returns to the pool
+ REFS.recycle(out);
+ return symbolHandle;
+ }
+
+ public static String[] listSymbolOutputs(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ checkCall(LIB.MXSymbolListOutputs(symbol, count, namesRef));
+ // Convert the result before the reference holder goes back to the pool.
+ String[] outputs = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return outputs;
+ }
+
+ public static String printSymbol(Pointer symbol) {
+ String[] printed = new String[1]; // single-slot holder filled in by the native call
+ checkCall(LIB.NNSymbolPrint(symbol, printed));
+ return printed[0];
+ }
+
+ public static void freeSymbol(Pointer symbol) {
+ checkCall(LIB.MXSymbolFree(symbol));
+ }
+
+ public static String[] listSymbolArguments(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ // The native call reports the entry count through count and the entries through namesRef.
+ checkCall(LIB.MXSymbolListArguments(symbol, count, namesRef));
+ // Convert the result before the reference holder goes back to the pool.
+ String[] argumentNames = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return argumentNames;
+ }
+
+ public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ // The native call reports the entry count through count and the entries through namesRef.
+ checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, count, namesRef));
+ // Convert the result before the reference holder goes back to the pool.
+ String[] auxiliaryNames = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return auxiliaryNames;
+ }
+
+ public static String[] listSymbolNames(Pointer symbol) {
+ PointerByReference namesRef = REFS.acquire();
+ IntBuffer count = IntBuffer.allocate(1);
+ // The native call reports the entry count through count and the entries through namesRef.
+ checkCall(LIB.NNSymbolListInputNames(symbol, 0, count, namesRef));
+ // Convert the result before the reference holder goes back to the pool.
+ String[] inputNames = toStringArray(namesRef, count.get());
+ REFS.recycle(namesRef);
+ return inputNames;
+ }
+
+ public static Pointer getSymbolInternals(Pointer symbol) {
+ PointerByReference out = REFS.acquire();
+ checkCall(LIB.MXSymbolGetInternals(symbol, out));
+ Pointer internalsHandle = out.getValue(); // read before the holder returns to the pool
+ REFS.recycle(out);
+ return internalsHandle;
+ }
+
+ private static List<Shape> recoverShape(
+ NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
+ int shapeCount = (int) size.getValue().longValue();
+ if (shapeCount == 0) {
+ return new ArrayList<>();
+ }
+ // nDim yields one rank per shape; the dimensions of all shapes are then read back to back
+ int[] ranks = nDim.getValue().getIntArray(0, shapeCount);
+ int totalDims = Arrays.stream(ranks).sum();
+ long[] flat = data.getValue().getPointer(0).getLongArray(0, totalDims);
+
+ // Cut the flat array into consecutive slices, one slice of length rank per shape.
+ List<Shape> shapes = new ArrayList<>();
+ int offset = 0;
+ for (int rank : ranks) {
+ long[] dimensions = new long[rank];
+ System.arraycopy(flat, offset, dimensions, 0, rank);
+ offset += rank;
+ shapes.add(new Shape(dimensions));
+ }
+ return shapes;
+ }
+
+ public static List<List<Shape>> inferShape(Symbol symbol, PairList<String, Shape> args) {
+ Pointer handler = symbol.getHandle();
+ int numArgs = args.size();
+ String[] keys = args.keys().toArray(new String[0]);
+ // The argument shapes are passed in CSR form: all dimensions are concatenated into one flat
+ // array and indPtr[i] .. indPtr[i + 1] delimits the slice that belongs to argument i.
+ long[] indPtr = new long[numArgs + 1];
+ Shape flattened = new Shape();
+ // indPtr[0] stays 0; every later entry is the previous offset plus that shape's rank.
+ for (int i = 0; i < numArgs; i++) {
+ Shape shape = args.valueAt(i);
+ indPtr[i + 1] = indPtr[i] + shape.dimension();
+ flattened = flattened.addAll(shape);
+ }
+ long[] flattenedShapeArray = flattened.getShape();
+
+ NativeSizeByReference inShapeSize = new NativeSizeByReference();
+ PointerByReference inShapeNDim = REFS.acquire();
+ PointerByReference inShapeData = REFS.acquire();
+ NativeSizeByReference outShapeSize = new NativeSizeByReference();
+ PointerByReference outShapeNDim = REFS.acquire();
+ PointerByReference outShapeData = REFS.acquire();
+ NativeSizeByReference auxShapeSize = new NativeSizeByReference();
+ PointerByReference auxShapeNDim = REFS.acquire();
+ PointerByReference auxShapeData = REFS.acquire();
+ IntBuffer complete = IntBuffer.allocate(1);
+ checkCall(
+ LIB.MXSymbolInferShape64(
+ handler,
+ numArgs,
+ keys,
+ indPtr,
+ flattenedShapeArray,
+ inShapeSize,
+ inShapeNDim,
+ inShapeData,
+ outShapeSize,
+ outShapeNDim,
+ outShapeData,
+ auxShapeSize,
+ auxShapeNDim,
+ auxShapeData,
+ complete));
+ if (complete.get() != 0) {
+ return Arrays.asList(
+ recoverShape(inShapeSize, inShapeNDim, inShapeData),
+ recoverShape(outShapeSize, outShapeNDim, outShapeData),
+ recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
+ }
+ return null; // the native call did not flag the inference as complete
+ }
+
+ public static Pointer optimizeFor(Symbol current, String backend, Device device) {
+ // TODO: Support partition on parameters
+ PointerByReference returnedSymbolHandle = REFS.acquire();
+ // Placeholder references for the pointer-typed arguments that this wrapper does not use.
+ PointerByReference[] placeHolders = {
+ REFS.acquire(),
+ REFS.acquire(),
+ REFS.acquire(),
+ REFS.acquire(),
+ REFS.acquire(),
+ REFS.acquire()
+ };
+ // No parameters are passed along, so every count below is 0 and every array is empty.
+ // TODO: check the type of the 22nd argument (currently passed as (byte) 0)
+ checkCall(
+ LIB.MXOptimizeForBackend(
+ current.getHandle(),
+ backend,
+ DeviceType.toDeviceType(device),
+ returnedSymbolHandle,
+ 0,
+ placeHolders[0],
+ 0,
+ placeHolders[1],
+ 0,
+ new String[0],
+ new String[0],
+ 0,
+ new String[0],
+ new long[0],
+ new int[0],
+ 0,
+ new String[0],
+ new int[0],
+ 0,
+ new String[0],
+ new int[0],
+ (byte) 0,
+ IntBuffer.allocate(0),
+ placeHolders[2],
+ placeHolders[3],
+ IntBuffer.allocate(0),
+ placeHolders[4],
+ placeHolders[5]));
+ Pointer ptr = returnedSymbolHandle.getValue();
+ REFS.recycle(returnedSymbolHandle);
+ Arrays.stream(placeHolders).forEach(REFS::recycle);
+ return ptr;
+ }
+
+ public static String getSymbolString(Pointer symbol) {
+ String[] json = new String[1]; // single-slot holder filled in by the native call
+ checkCall(LIB.MXSymbolSaveToJSON(symbol, json));
+ return json[0];
+ }
+
+ public static Pointer getSymbolOutput(Pointer symbol, int index) {
+ PointerByReference out = REFS.acquire();
+ checkCall(LIB.MXSymbolGetOutput(symbol, index, out));
+ Pointer outputHandle = out.getValue(); // read before the holder returns to the pool
+ REFS.recycle(out);
+ return outputHandle;
+ }
+
+ public static NDList loadNdArray(MxResource parent, Path path, Device device) {
+ IntBuffer arrayCountBuf = IntBuffer.allocate(1);
+ IntBuffer nameCountBuf = IntBuffer.allocate(1);
+ PointerByReference handlesRef = REFS.acquire();
+ PointerByReference namesRef = REFS.acquire();
+ checkCall(
+ LIB.MXNDArrayLoad(path.toString(), arrayCountBuf, handlesRef, nameCountBuf, namesRef));
+
+ int ndArrayCount = arrayCountBuf.get();
+ int nameCount = nameCountBuf.get();
+ // Names are optional, but when the file has any there must be exactly one per array.
+ if (nameCount > 0 && ndArrayCount != nameCount) {
+ throw new IllegalStateException(
+ "Mismatch between names and arrays in checkpoint file: " + path.toString());
+ }
+
+ Pointer[] handles = handlesRef.getValue().getPointerArray(0, ndArrayCount);
+ String[] names = nameCount == 0 ? null : namesRef.getValue().getStringArray(0, nameCount);
+
+ // Wrap every handle in an NDArray, naming it only when the file carried names.
+ NDList loaded = new NDList();
+ for (int i = 0; i < handles.length; i++) {
+ NDArray array = NDArray.create(parent, handles[i]);
+ if (names != null) {
+ array.setName(names[i]);
+ }
+ loaded.add(array);
+ }
+ REFS.recycle(namesRef);
+ REFS.recycle(handlesRef);
+
+ // MXNet always load NDArray on CPU, so only a different target device needs a copy.
+ if (!Device.cpu().equals(device)) {
+ NDList moved = loaded.toDevice(device, true);
+ loaded.close();
+ return moved;
+ }
+ return loaded;
+ }
+
+ public static PairList<String, Pointer> loadNdArrayFromFile(String path) {
+ IntBuffer handleCountBuf = IntBuffer.allocate(1);
+ IntBuffer nameCountBuf = IntBuffer.allocate(1);
+ PointerByReference handlesRef = REFS.acquire();
+ PointerByReference namesRef = REFS.acquire();
+ checkCall(LIB.MXNDArrayLoad(path, handleCountBuf, handlesRef, nameCountBuf, namesRef));
+ // TODO : construct NDArray Objects
+
+ int handleCount = handleCountBuf.get();
+ int nameCount = nameCountBuf.get();
+ // Names are optional, but when the file has any there must be exactly one per handle.
+ if (nameCount > 0 && nameCount != handleCount) {
+ throw new IllegalStateException(
+ "Mismatch between names and arrays in checkpoint file: " + path);
+ }
+
+ Pointer[] handles = handlesRef.getValue().getPointerArray(0, handleCount);
+ String[] names = nameCount == 0 ? null : namesRef.getValue().getStringArray(0, nameCount);
+
+ // Pair every handle with its name; entries of a file without names get a null key.
+ PairList<String, Pointer> entries = new PairList<>();
+ for (int i = 0; i < handles.length; i++) {
+ String key = names == null ? null : names[i];
+ entries.add(key, handles[i]);
+ }
+
+ // Both reference holders have been fully read and can go back to the pool.
+ REFS.recycle(namesRef);
+ REFS.recycle(handlesRef);
+
+ return entries;
+ }
+
+ public static void freeNdArray(Pointer handle) {
+ checkCall(LIB.MXNDArrayFree(handle));
+ }
+
+ public static Pointer loadNdArrayFromByteArray(byte[] buf, int offset, int size) {
+ // Stage the requested slice of buf in native memory so it can be handed to the library.
+ Memory nativeCopy = new Memory(size);
+ nativeCopy.write(0, buf, offset, size);
+
+ PointerByReference handleRef = REFS.acquire();
+ checkCall(LIB.MXNDArrayLoadFromRawBytes(nativeCopy, new NativeSize(size), handleRef));
+ Pointer ndArrayHandle = handleRef.getValue();
+ REFS.recycle(handleRef);
+ return ndArrayHandle;
+ }
+
+ public static Pointer loadNdArrayFromByteBuffer(ByteBuffer byteBuffer) {
+ // Copy only the readable window [position, limit): sizing the array by limit() alone made
+ // the bulk get() throw BufferUnderflowException whenever position() was non-zero.
+ // TODO: hand a direct buffer's native address to the library instead of copying.
+ byte[] bytes = new byte[byteBuffer.remaining()];
+ byteBuffer.get(bytes);
+ return loadNdArrayFromByteArray(bytes, 0, bytes.length);
+ }
+
+ public static ByteBuffer saveNdArrayAsByteBuffer(Pointer ndArray) {
+ PointerByReference bytesRef = new PointerByReference();
+ NativeSizeByReference byteCount = new NativeSizeByReference();
+ checkCall(LIB.MXNDArraySaveRawBytes(ndArray, byteCount, bytesRef));
+ return bytesRef.getValue().getByteBuffer(0, byteCount.getValue().longValue());
+ }
+
+ /**
+  * Serializes an NDArray and copies the serialized bytes into a new Java array.
+  *
+  * @param ndArray the handle of the NDArray to serialize
+  * @return a copy of the bytes produced by {@link #saveNdArrayAsByteBuffer(Pointer)}
+  */
+ public static byte[] saveNdArrayAsByteArray(Pointer ndArray) {
+     ByteBuffer serialized = saveNdArrayAsByteBuffer(ndArray);
+     byte[] copy = new byte[serialized.limit()];
+     serialized.get(copy);
+     return copy;
+ }
+
+ /**
+  * Copies NDArray content to the memory behind {@code data} via {@code MXNDArraySyncCopyToCPU}.
+  *
+  * @param ndArray the handle to copy from
+  * @param data the destination pointer
+  * @param len the size value forwarded to the native call
+  * @throws IllegalArgumentException if {@code ndArray} or {@code data} is {@code null}
+  */
+ public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
+     NativeSize nativeLen = new NativeSize(len);
+     checkNDArray(ndArray, "copy from");
+     checkNDArray(data, "copy to");
+     int status = LIB.MXNDArraySyncCopyToCPU(ndArray, data, nativeLen);
+     checkCall(status);
+ }
+
+ /**
+  * Copies the content of a Java buffer into an NDArray via {@code MXNDArraySyncCopyFromCPU}.
+  *
+  * @param ndArray the handle to copy into
+  * @param data the source buffer; its address is taken with
+  *     {@code Native.getDirectBufferPointer}
+  * @param len the size value forwarded to the native call
+  */
+ public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
+     NativeSize nativeLen = new NativeSize(len);
+     Pointer source = Native.getDirectBufferPointer(data);
+     int status = LIB.MXNDArraySyncCopyFromCPU(ndArray, source, nativeLen);
+     checkCall(status);
+ }
+
+ /**
+  * Invokes {@code MXNDArrayWaitToRead} on the given handle.
+  *
+  * @param ndArray the NDArray handle
+  * @throws IllegalArgumentException if {@code ndArray} is {@code null}
+  */
+ public static void waitToRead(Pointer ndArray) {
+     checkNDArray(ndArray, "wait to read");
+     int status = LIB.MXNDArrayWaitToRead(ndArray);
+     checkCall(status);
+ }
+
+ /**
+  * Invokes {@code MXNDArrayWaitToWrite} on the given handle.
+  *
+  * @param ndArray the NDArray handle
+  * @throws IllegalArgumentException if {@code ndArray} is {@code null}
+  */
+ public static void waitToWrite(Pointer ndArray) {
+     checkNDArray(ndArray, "wait to write");
+     int status = LIB.MXNDArrayWaitToWrite(ndArray);
+     checkCall(status);
+ }
+
+ /**
+  * Calls {@code MXNDArrayDetach} and returns the handle it writes.
+  *
+  * @param handle the NDArray handle passed to the native call
+  * @return the handle produced by the native call
+  */
+ public static Pointer detachGradient(Pointer handle) {
+     PointerByReference detachedOut = REFS.acquire();
+     checkCall(LIB.MXNDArrayDetach(handle, detachedOut));
+     Pointer detached = detachedOut.getValue();
+     REFS.recycle(detachedOut);
+     return detached;
+ }
+
+ /**
+  * Calls {@code MXNDArrayGetGrad} and returns the handle it writes.
+  *
+  * @param handle the NDArray handle whose gradient is requested
+  * @return the handle produced by the native call
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static Pointer getGradient(Pointer handle) {
+     PointerByReference gradOut = REFS.acquire();
+     checkNDArray(handle, "get the gradient for");
+     int status = LIB.MXNDArrayGetGrad(handle, gradOut);
+     checkCall(status);
+     Pointer grad = gradOut.getValue();
+     REFS.recycle(gradOut);
+     return grad;
+ }
+
+ /**
+  * Calls {@code MXAutogradMarkVariables} with the given variable and gradient handles.
+  *
+  * <p>Each pointer is stored in a pooled {@code PointerByReference} before the call.
+  * NOTE(review): a reference slot carries a single pointer, so this looks intended for
+  * {@code numVar == 1} - confirm against callers.
+  *
+  * @param numVar the variable count forwarded to the native call
+  * @param varHandles the variable handle(s)
+  * @param reqsArray the request values forwarded to the native call
+  * @param gradHandles the gradient handle(s)
+  */
+ public static void autogradMarkVariables(
+         int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
+     PointerByReference varSlot = REFS.acquire();
+     PointerByReference gradSlot = REFS.acquire();
+     varSlot.setValue(varHandles);
+     gradSlot.setValue(gradHandles);
+
+     int status = LIB.MXAutogradMarkVariables(numVar, varSlot, reqsArray, gradSlot);
+     checkCall(status);
+
+     REFS.recycle(varSlot);
+     REFS.recycle(gradSlot);
+ }
+
+ /**
+  * Builds a lookup table of all operators reported by {@link #getAllOpNames()}.
+  *
+  * @return a map from the operator name with its known prefix stripped (see
+  *     {@link #getOpNamePrefix(String)}) to the operator's {@code FunctionInfo}
+  */
+ public static Map<String, FunctionInfo> getNdArrayFunctions() {
+     Set<String> allOps = JnaUtils.getAllOpNames();
+     Map<String, FunctionInfo> functions = new ConcurrentHashMap<>();
+
+     // One reference slot is reused for every lookup and cleared after each use.
+     PointerByReference opHandleOut = REFS.acquire();
+     for (String opName : allOps) {
+         checkCall(LIB.NNGetOpHandle(opName, opHandleOut));
+         String shortName = getOpNamePrefix(opName);
+         FunctionInfo info = getFunctionByName(opName, shortName, opHandleOut.getValue());
+         functions.put(shortName, info);
+         opHandleOut.setValue(null);
+     }
+     REFS.recycle(opHandleOut);
+     return functions;
+ }
+
+ /**
+  * Runs an operator through {@code MXImperativeInvoke}.
+  *
+  * @param function the operator handle
+  * @param src the input arrays
+  * @param dest the output arrays, or {@code null} to pass no pre-existing outputs
+  * @param params the operator parameters (values are converted with {@code toString()}), or
+  *     {@code null} for none
+  * @return the output handles paired with the storage format reported for each of them
+  */
+ public static PairList<Pointer, SparseFormat> imperativeInvoke(
+         Pointer function, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+     boolean noParams = params == null;
+     String[] keys = noParams ? EMPTY_ARRAY : params.keyArray(EMPTY_ARRAY);
+     String[] values =
+             noParams
+                     ? EMPTY_ARRAY
+                     : params.values().stream().map(Object::toString).toArray(String[]::new);
+
+     PointerArray inputs = toPointerArray(src);
+     PointerArray outputs = toPointerArray(dest);
+     PointerByReference outputsRef = REFS.acquire();
+     outputsRef.setValue(outputs);
+     PointerByReference storageTypesRef = REFS.acquire();
+     // The output count is an in/out argument that starts at 1.
+     IntBuffer outputCountBuf = IntBuffer.wrap(new int[] {1});
+
+     int status =
+             LIB.MXImperativeInvoke(
+                     function,
+                     src.length,
+                     inputs,
+                     outputCountBuf,
+                     outputsRef,
+                     keys.length,
+                     keys,
+                     values,
+                     storageTypesRef);
+     checkCall(status);
+
+     int outputCount = outputCountBuf.get(0);
+     Pointer[] outputHandles = outputsRef.getValue().getPointerArray(0, outputCount);
+     int[] storageTypes = storageTypesRef.getValue().getIntArray(0, outputCount);
+     PairList<Pointer, SparseFormat> results = new PairList<>();
+     for (int i = 0; i < outputCount; i++) {
+         results.add(outputHandles[i], SparseFormat.fromValue(storageTypes[i]));
+     }
+
+     REFS.recycle(outputsRef);
+     REFS.recycle(storageTypesRef);
+     inputs.recycle();
+     if (outputs != null) {
+         outputs.recycle();
+     }
+     return results;
+ }
+
+ /**
+  * Collects the native handles of the given arrays into a {@code PointerArray}.
+  *
+  * @param vals the arrays, may be {@code null}
+  * @return the pointer array, or {@code null} when {@code vals} is {@code null}
+  */
+ private static PointerArray toPointerArray(NDArray[] vals) {
+     if (vals == null) {
+         return null;
+     }
+     Pointer[] handles = Arrays.stream(vals).map(NDArray::getHandle).toArray(Pointer[]::new);
+     return PointerArray.of(handles);
+ }
+
+ /**
+  * Looks up a registered operator by name.
+  *
+  * @param opName the operator name
+  * @return the matching {@code FunctionInfo}
+  * @throws IllegalArgumentException if no operator is registered under {@code opName}
+  */
+ public static FunctionInfo op(String opName) {
+     if (OPS.containsKey(opName)) {
+         return OPS.get(opName);
+     }
+     throw new IllegalArgumentException("Unknown operator: " + opName);
+ }
+
+ /**
+  * Queries {@code MXSymbolGetAtomicSymbolInfo} and wraps the argument list in a
+  * {@code FunctionInfo}.
+  *
+  * @param name the operator name passed to the native call
+  * @param functionName the name stored in the returned {@code FunctionInfo}
+  * @param handle the operator handle
+  * @return the function description holding the handle, name, and (argument name, argument
+  *     type) pairs
+  */
+ public static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
+     String[] nameHolder = {name};
+     String[] descriptionOut = new String[1];
+     IntBuffer argCountBuf = IntBuffer.allocate(1);
+     PointerByReference argNamesOut = REFS.acquire();
+     PointerByReference argTypesOut = REFS.acquire();
+     PointerByReference argDescsOut = REFS.acquire();
+     String[] keyVarArgsOut = new String[1];
+     String[] returnTypeOut = new String[1];
+
+     int status =
+             LIB.MXSymbolGetAtomicSymbolInfo(
+                     handle,
+                     nameHolder,
+                     descriptionOut,
+                     argCountBuf,
+                     argNamesOut,
+                     argTypesOut,
+                     argDescsOut,
+                     keyVarArgsOut,
+                     returnTypeOut);
+     checkCall(status);
+
+     PairList<String, String> arguments = new PairList<>();
+     int argCount = argCountBuf.get();
+     if (argCount != 0) {
+         String utf8 = StandardCharsets.UTF_8.name();
+         String[] argNames = argNamesOut.getValue().getStringArray(0, argCount, utf8);
+         String[] argTypes = argTypesOut.getValue().getStringArray(0, argCount, utf8);
+         for (int i = 0; i < argNames.length; i++) {
+             arguments.add(argNames[i], argTypes[i]);
+         }
+     }
+
+     REFS.recycle(argNamesOut);
+     REFS.recycle(argTypesOut);
+     REFS.recycle(argDescsOut);
+
+     return new FunctionInfo(handle, functionName, arguments);
+ }
+
+ /**
+  * Lists the operator names reported by {@code MXListAllOpNames}.
+  *
+  * @return the operator names, decoded as UTF-8
+  */
+ public static Set<String> getAllOpNames() {
+     IntBuffer countBuf = IntBuffer.allocate(1);
+     PointerByReference namesOut = REFS.acquire();
+     checkCall(LIB.MXListAllOpNames(countBuf, namesOut));
+
+     int count = countBuf.get();
+     String utf8 = StandardCharsets.UTF_8.name();
+     Set<String> names = new HashSet<>();
+     for (Pointer namePtr : namesOut.getValue().getPointerArray(0, count)) {
+         names.add(namePtr.getString(0, utf8));
+     }
+
+     REFS.recycle(namesOut);
+     return names;
+ }
+
+ /**
+  * Strips the first matching entry of {@code OP_NAME_PREFIX} from the front of an operator
+  * name.
+  *
+  * @param name the full operator name
+  * @return the name without the matched prefix, or {@code name} itself when no prefix matches
+  */
+ public static String getOpNamePrefix(String name) {
+     for (String candidate : OP_NAME_PREFIX) {
+         if (!name.startsWith(candidate)) {
+             continue;
+         }
+         return name.substring(candidate.length());
+     }
+     return name;
+ }
+
+ /**
+  * Reads the data type of an NDArray via {@code MXNDArrayGetDType}.
+  *
+  * @param handle the NDArray handle
+  * @return the {@code DataType} constant whose position in {@code DataType.values()} equals
+  *     the integer written by the native call
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static DataType getDataTypeOfNdArray(Pointer handle) {
+     IntBuffer typeBuf = IntBuffer.allocate(1);
+     checkNDArray(handle, "get the data type of");
+     checkCall(LIB.MXNDArrayGetDType(handle, typeBuf));
+     int typeIndex = typeBuf.get();
+     return DataType.values()[typeIndex];
+ }
+
+ /**
+  * Reads the device of an NDArray via {@code MXNDArrayGetContext}.
+  *
+  * @param handle the NDArray handle
+  * @return the device built from the reported device type and device id
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static Device getDeviceOfNdArray(Pointer handle) {
+     IntBuffer typeBuf = IntBuffer.allocate(1);
+     IntBuffer idBuf = IntBuffer.allocate(1);
+     checkNDArray(handle, "get the device of");
+     checkCall(LIB.MXNDArrayGetContext(handle, typeBuf, idBuf));
+
+     String typeName = DeviceType.fromDeviceType(typeBuf.get(0));
+     int id = idBuf.get(0);
+     // The id is forwarded unchanged for every device type, CPU included.
+     return Device.of(typeName, id);
+ }
+
+ /**
+  * Reads the shape of an NDArray via {@code MXNDArrayGetShape} (32-bit dimensions).
+  *
+  * @param handle the NDArray handle
+  * @return the shape; an empty {@code Shape} when the native call reports zero dimensions
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static Shape getShapeOfNdArray(Pointer handle) {
+     IntBuffer rankBuf = IntBuffer.allocate(1);
+     PointerByReference dimsOut = REFS.acquire();
+     checkNDArray(handle, "get the shape of");
+     checkCall(LIB.MXNDArrayGetShape(handle, rankBuf, dimsOut));
+
+     int rank = rankBuf.get();
+     // Read the dimensions (if any) before the reference goes back to the pool.
+     int[] dims = rank == 0 ? null : dimsOut.getValue().getIntArray(0, rank);
+     REFS.recycle(dimsOut);
+
+     if (dims == null) {
+         return new Shape();
+     }
+     return new Shape(Arrays.stream(dims).asLongStream().toArray());
+ }
+
+ /**
+  * Reads the shape of an NDArray via {@code MXNDArrayGetShape64} (64-bit dimensions).
+  *
+  * @param handle the NDArray handle
+  * @return the shape; an empty {@code Shape} when the native call reports zero dimensions
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static Shape getShape64OfNdArray(Pointer handle) {
+     IntBuffer dim = IntBuffer.allocate(1);
+     PointerByReference ref = REFS.acquire();
+     checkNDArray(handle, "get the shape64 of");
+     checkCall(LIB.MXNDArrayGetShape64(handle, dim, ref));
+     int nDim = dim.get();
+     if (nDim == 0) {
+         REFS.recycle(ref);
+         return new Shape();
+     }
+     // The 64-bit variant writes int64_t dimensions; reading them as 32-bit ints would split
+     // every dimension into two bogus values, so they must be read as longs.
+     long[] shape = ref.getValue().getLongArray(0, nDim);
+     REFS.recycle(ref);
+     return new Shape(shape);
+ }
+
+ /**
+  * Reads the storage type of an NDArray via {@code MXNDArrayGetStorageType}.
+  *
+  * @param handle the NDArray handle
+  * @return the {@code SparseFormat} matching the integer written by the native call
+  * @throws IllegalArgumentException if {@code handle} is {@code null}
+  */
+ public static SparseFormat getStorageType(Pointer handle) {
+     IntBuffer typeBuf = IntBuffer.allocate(1);
+     checkNDArray(handle, "get the storage type of");
+     checkCall(LIB.MXNDArrayGetStorageType(handle, typeBuf));
+     int storageType = typeBuf.get();
+     return SparseFormat.fromValue(storageType);
+ }
+
+ /**
+  * Creates a native NDArray through {@code MXNDArrayCreate}.
+  *
+  * @param device the target device
+  * @param shape the array shape; every dimension must fit in an {@code int}
+  * @param dataType the element type (its ordinal is passed to the native call)
+  * @param size the value passed as the second native argument, right after the shape array
+  * @param delayedAlloc whether the delayed-allocation flag is set
+  * @return the handle written by the native call
+  * @throws ArithmeticException if a dimension overflows an {@code int}
+  */
+ public static Pointer createNdArray(
+         Device device, Shape shape, DataType dataType, int size, boolean delayedAlloc) {
+     int deviceType = DeviceType.toDeviceType(device);
+     // Device type 1 is sent with id -1 (presumably the CPU case - confirm against DeviceType).
+     int deviceId = -1;
+     if (deviceType != 1) {
+         deviceId = device.getDeviceId();
+     }
+     int delayFlag = delayedAlloc ? 1 : 0;
+
+     PointerByReference handleOut = REFS.acquire();
+     int[] dims = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+     int status =
+             LIB.MXNDArrayCreate(
+                     dims, size, deviceType, deviceId, delayFlag, dataType.ordinal(), handleOut);
+     checkCall(status);
+
+     Pointer handle = handleOut.getValue();
+     REFS.recycle(handleOut);
+     return handle;
+ }
+
+ /**
+  * Creates a native sparse NDArray through {@code MXNDArrayCreateSparseEx}.
+  *
+  * @param fmt the sparse storage format
+  * @param device the target device
+  * @param shape the array shape; every dimension must fit in an {@code int}
+  * @param dtype the element type (its ordinal is passed to the native call)
+  * @param auxDTypes the data types of the auxiliary arrays
+  * @param auxShapes the shapes of the auxiliary arrays
+  * @param delayedAlloc whether the delayed-allocation flag is set
+  * @return the handle written by the native call
+  * @throws ArithmeticException if a dimension overflows an {@code int}
+  */
+ public static Pointer createSparseNdArray(
+         SparseFormat fmt,
+         Device device,
+         Shape shape,
+         DataType dtype,
+         DataType[] auxDTypes,
+         Shape[] auxShapes,
+         boolean delayedAlloc) {
+     int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+     int deviceType = DeviceType.toDeviceType(device);
+     // Device type 1 is sent with id -1 (presumably the CPU case - confirm against DeviceType).
+     int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+     int delay = delayedAlloc ? 1 : 0;
+     PointerByReference ref = REFS.acquire();
+     IntBuffer auxDTypesInt =
+             IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(DataType::ordinal).toArray());
+     IntBuffer auxNDims =
+             IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
+     // The native side consumes auxNDims[i] entries per auxiliary array, so the shapes must be
+     // concatenated in full; passing only each shape's first dimension under-fills the buffer
+     // whenever an auxiliary shape has more than one dimension.
+     int[] auxShapesInt =
+             Arrays.stream(auxShapes)
+                     .flatMapToLong(auxShape -> Arrays.stream(auxShape.getShape()))
+                     .mapToInt(Math::toIntExact)
+                     .toArray();
+     checkCall(
+             LIB.MXNDArrayCreateSparseEx(
+                     fmt.getValue(),
+                     shapeArray,
+                     shapeArray.length,
+                     deviceType,
+                     deviceId,
+                     delay,
+                     dtype.ordinal(),
+                     auxDTypes.length,
+                     auxDTypesInt,
+                     auxNDims,
+                     auxShapesInt,
+                     ref));
+     Pointer pointer = ref.getValue();
+     REFS.recycle(ref);
+     return pointer;
+ }
+
+ /**
+  * Copies one NDArray into another via {@code MXNDArraySyncCopyFromNDArray}.
+  *
+  * @param dest the destination array
+  * @param src the source array
+  * @param location the location value forwarded to the native call
+  */
+ public static void ndArraySyncCopyFromNdArray(NDArray dest, NDArray src, int location) {
+     int status =
+             LIB.MXNDArraySyncCopyFromNDArray(dest.getHandle(), src.getHandle(), location);
+     checkCall(status);
+ }
+
+ /**
+  * Returns the version number reported by {@code MXGetVersion}.
+  *
+  * @return the integer written by the native call
+  */
+ public static int getVersion() {
+     IntBuffer versionBuf = IntBuffer.allocate(1);
+     int status = LIB.MXGetVersion(versionBuf);
+     checkCall(status);
+     return versionBuf.get();
+ }
+
+ /**
+  * Executes a cached operator through {@code MXInvokeCachedOp}.
+  *
+  * <p>The device type of the first input is passed as the default device type; the default
+  * device id is always 0.
+  *
+  * @param parent the resource passed to {@code NDArray.create} for every output
+  * @param cachedOpHandle the cached operator handle
+  * @param inputs the input arrays; the first element is read for its device
+  * @return the output arrays, created with an explicit {@code SparseFormat} whenever the
+  *     reported storage type is non-zero
+  */
+ public static NDArray[] cachedOpInvoke(
+         MxResource parent, Pointer cachedOpHandle, NDArray[] inputs) {
+     IntBuffer outputCountBuf = IntBuffer.allocate(1);
+     PointerArray inputHandles = toPointerArray(inputs);
+     PointerByReference outputsOut = REFS.acquire();
+     PointerByReference storageTypesOut = REFS.acquire();
+     Device device = inputs[0].getDevice();
+     // TODO: check the init value of default_dev_type and default_dev_id
+     int status =
+             LIB.MXInvokeCachedOp(
+                     cachedOpHandle,
+                     inputs.length,
+                     inputHandles,
+                     DeviceType.toDeviceType(device),
+                     0,
+                     outputCountBuf,
+                     outputsOut,
+                     storageTypesOut);
+     checkCall(status);
+
+     int outputCount = outputCountBuf.get();
+     Pointer[] outputHandles = outputsOut.getValue().getPointerArray(0, outputCount);
+     int[] storageTypes = storageTypesOut.getValue().getIntArray(0, outputCount);
+     NDArray[] outputs = new NDArray[outputCount];
+     for (int i = 0; i < outputCount; i++) {
+         if (storageTypes[i] == 0) {
+             outputs[i] = NDArray.create(parent, outputHandles[i]);
+         } else {
+             outputs[i] =
+                     NDArray.create(
+                             parent, outputHandles[i], SparseFormat.fromValue(storageTypes[i]));
+         }
+     }
+
+     REFS.recycle(outputsOut);
+     REFS.recycle(storageTypesOut);
+     inputHandles.recycle();
+     return outputs;
+ }
+
+ /**
+  * Rejects {@code null} handles before they reach a native call.
+  *
+  * @param pointer the handle to validate
+  * @param msg the attempted action, inserted into the error message
+  * @throws IllegalArgumentException if {@code pointer} is {@code null}
+  */
+ private static void checkNDArray(Pointer pointer, String msg) {
+     if (pointer != null) {
+         return;
+     }
+     throw new IllegalArgumentException(
+             "Tried to " + msg + " an MXNet NDArray that was already closed");
+ }
+
+ /**
+  * Verifies the status code returned by a native MXNet call.
+  *
+  * @param ret the status code; zero means success
+  * @throws JnaCallException if {@code ret} is non-zero; the message carries the last error
+  *     reported by the native library
+  */
+ public static void checkCall(int ret) {
+     if (ret != 0) {
+         // Fetch the native error once so the log entry and the exception always agree.
+         String lastError = getLastError();
+         logger.error("MXNet engine call failed: {}", lastError);
+         throw new JnaCallException("MXNet engine call failed: " + lastError);
+     }
+ }
+
+ /**
+  * Returns the string produced by {@code MXGetLastError}.
+  *
+  * @return the last error message reported by the native library
+  */
+ private static String getLastError() {
+     String lastError = LIB.MXGetLastError();
+     return lastError;
+ }
+
+ /**
+  * Decodes a native array of C strings into Java strings using UTF-8.
+  *
+  * @param ref the reference holding the pointer to the native pointer array
+  * @param size the number of strings to read
+  * @return the decoded strings; an empty array when {@code size} is 0
+  */
+ private static String[] toStringArray(PointerByReference ref, int size) {
+     if (size == 0) {
+         return new String[0];
+     }
+
+     Pointer[] stringPointers = ref.getValue().getPointerArray(0, size);
+     String utf8 = StandardCharsets.UTF_8.name();
+     String[] decoded = new String[size];
+     int index = 0;
+     while (index < size) {
+         decoded[index] = stringPointers[index].getString(0, utf8);
+         index++;
+     }
+     return decoded;
+ }
+
+ /**
+  * Queries {@code MXLibInfoFeatures} and collects the names of the enabled features.
+  *
+  * @return the names of all features whose enabled flag equals 1; an empty set when the native
+  *     call reports no features
+  */
+ private static Set<String> getFeaturesInternal() {
+     PointerByReference featuresOut = REFS.acquire();
+     NativeSizeByReference countOut = new NativeSizeByReference();
+     checkCall(LIB.MXLibInfoFeatures(featuresOut, countOut));
+
+     int count = countOut.getValue().intValue();
+     if (count == 0) {
+         REFS.recycle(featuresOut);
+         return Collections.emptySet();
+     }
+
+     // Map the first structure, then let it expand into an array over the native block.
+     LibFeature first = new LibFeature(featuresOut.getValue());
+     first.read();
+     LibFeature[] all = (LibFeature[]) first.toArray(count);
+
+     Set<String> enabled = new HashSet<>();
+     for (LibFeature candidate : all) {
+         if (candidate.getEnabled() != 1) {
+             continue;
+         }
+         enabled.add(candidate.getName());
+     }
+     REFS.recycle(featuresOut);
+     return enabled;
+ }
+
+ /**
+  * Returns the feature set held in {@code FEATURES}.
+  *
+  * @return the feature names
+  */
+ public static Set<String> getFeatures() {
+     Set<String> features = FEATURES;
+     return features;
+ }
+
+ /**
+  * Queries {@code MXAutogradIsTraining}.
+  *
+  * @return {@code true} if the byte written by the native call equals 1
+  */
+ public static boolean autogradIsTraining() {
+     ByteBuffer flag = ByteBuffer.allocate(1);
+     int status = LIB.MXAutogradIsTraining(flag);
+     checkCall(status);
+     byte value = flag.get(0);
+     return value == 1;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
new file mode 100644
index 0000000..d110398
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Library;
+import com.sun.jna.Native;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.Platform;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Utilities for finding the MXNet Engine binary on the System.
+ *
+ * <p>The Engine will be searched for in a variety of locations in the following order:
+ *
+ * <ol>
+ * <li>In the path specified by the MXNET_LIBRARY_PATH environment variable
+ * <li>In a jar file location in the classpath. These jars can be created with the mxnet-native
+ * module.
+ * <li>In the python3 path. These can be installed using pip.
+ * <li>In the python path. These can be installed using pip.
+ * </ol>
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class LibUtils {
+
+ private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
+
+ private static final String LIB_NAME = "mxnet";
+
+ private static final String MXNET_LIBRARY_PATH = "MXNET_LIBRARY_PATH";
+
+ private static final String MXNET_PROPERTIES_FILE_PATH = "native/lib/mxnet.properties";
+
+ private LibUtils() {}
+
+ public static MxnetLibrary loadLibrary() {
+
+ String libName = getLibName();
+ logger.debug("Loading mxnet library from: {}", libName);
+ if (System.getProperty("os.name").startsWith("Linux")) {
+ logger.info("Loading on Linux platform");
+ Map<String, Integer> options = new ConcurrentHashMap<>();
+ int rtld = 1; // Linux RTLD lazy + local
+ options.put(Library.OPTION_OPEN_FLAGS, rtld);
+ return Native.load(libName, MxnetLibrary.class, options);
+ }
+ return Native.load(libName, MxnetLibrary.class);
+ }
+
+ public static String getLibName() {
+ String libName = findOverrideLibrary();
+ if (libName == null) {
+ libName = LibUtils.findLibraryInClasspath();
+ if (libName == null) {
+ libName = LIB_NAME;
+ }
+ }
+
+ return libName;
+ }
+
+ private static String findOverrideLibrary() {
+ // TODO: load from jar files
+ String libPath = System.getenv(MXNET_LIBRARY_PATH);
+ if (libPath != null) {
+ String libName = findLibraryInPath(libPath);
+ if (libName != null) {
+ return libName;
+ }
+ }
+
+ libPath = System.getProperty("java.library.path");
+ if (libPath != null) {
+ return findLibraryInPath(libPath);
+ }
+ return null;
+ }
+
+ private static synchronized String findLibraryInClasspath() {
+ Enumeration<URL> urls = getUrls();
+ // No native jars
+ if (!urls.hasMoreElements()) {
+ logger.debug("mxnet.properties not found in class path.");
+ return null;
+ }
+
+ // Find the mxnet library version that matches local system platform
+ // throw exception if no one matches
+ Platform systemPlatform = Platform.fromSystem();
+ try {
+ while (urls.hasMoreElements()) {
+ URL url = urls.nextElement();
+ Platform platform = Platform.fromUrl(url);
+ if (!platform.isPlaceholder() && platform.matches(systemPlatform)) {
+ return loadLibraryFromClasspath(platform);
+ }
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException(
+ "Failed to read MXNet native library jar properties", e);
+ }
+
+ throw new IllegalStateException(
+ "Your MXNet native library jar does not match your operating system. Make sure that the Maven Dependency Classifier matches your system type.");
+ }
+
+ private static Enumeration<URL> getUrls() {
+ try {
+ return Thread.currentThread()
+ .getContextClassLoader()
+ .getResources(MXNET_PROPERTIES_FILE_PATH);
+ } catch (IOException e) {
+ logger.warn(
+ String.format(
+ "IO Exception occurs when try to find the file %s", MXNET_LIBRARY_PATH),
+ e);
+ return null;
+ }
+ }
+
+ private static String loadLibraryFromClasspath(Platform platform) {
+ Path tmp = null;
+ try {
+ String libName = System.mapLibraryName(LIB_NAME);
+ Path cacheFolder = Utils.getEngineCacheDir(LIB_NAME);
+ logger.debug("Using cache dir: {}", cacheFolder);
+
+ Path dir = cacheFolder.resolve(platform.getVersion() + platform.getClassifier());
+ Path path = dir.resolve(libName);
+ if (Files.exists(path)) {
+ return path.toAbsolutePath().toString();
+ }
+ Files.createDirectories(cacheFolder);
+ tmp = Files.createTempDirectory(cacheFolder, "tmp");
+ for (String file : platform.getLibraries()) {
+ String libPath = "/native/lib/" + file;
+ try (InputStream is = LibUtils.class.getResourceAsStream(libPath)) {
+ logger.info("Extracting {} to cache ...", file);
+ Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+
+ Utils.moveQuietly(tmp, dir);
+ return path.toAbsolutePath().toString();
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to extract MXNet native library", e);
+ } finally {
+ if (tmp != null) {
+ Utils.deleteQuietly(tmp);
+ }
+ }
+ }
+
+ private static String findLibraryInPath(String libPath) {
+ String[] paths = libPath.split(File.pathSeparator);
+ List<String> mappedLibNames;
+ if (com.sun.jna.Platform.isMac()) {
+ mappedLibNames = Arrays.asList("libmxnet.dylib", "libmxnet.jnilib", "libmxnet.so");
+ } else {
+ mappedLibNames = Collections.singletonList(System.mapLibraryName(LIB_NAME));
+ }
+
+ for (String path : paths) {
+ File p = new File(path);
+ if (!p.exists()) {
+ continue;
+ }
+ for (String name : mappedLibNames) {
+ if (p.isFile() && p.getName().endsWith(name)) {
+ return p.getAbsolutePath();
+ }
+
+ File file = new File(path, name);
+ if (file.exists() && file.isFile()) {
+ return file.getAbsolutePath();
+ }
+ }
+ }
+ return null;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
new file mode 100644
index 0000000..d43e475
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+
+/**
+ * Provides a temporary allocation of an immutable C string (<code>const char*</code> or <code>
+ * const wchar_t*</code>) for use when converting a Java String into a native memory function
+ * argument.
+ */
+final class NativeString {
+
+ private static final ObjectPool<NativeString> POOL = new ObjectPool<>(null, null);
+
+ private Memory pointer;
+
+ /**
+ * Create a native string (NUL-terminated array of <code>char</code>), using the requested
+ * encoding.
+ *
+ * @param data the bytes of the string
+ */
+ private NativeString(byte[] data) {
+ pointer = new Memory(data.length + 1);
+ setData(data);
+ }
+
+ private void setData(byte[] data) {
+ pointer.write(0, data, 0, data.length);
+ pointer.setByte(data.length, (byte) 0);
+ }
+
+ /**
+ * Acquires a pooled {@code NativeString} object if available, otherwise a new instance is
+ * created.
+ *
+ * @param string the string value
+ * @param encoding the charset encoding
+ * @return a {@code NativeString} object
+ */
+ public static NativeString of(String string, Charset encoding) {
+ byte[] data = string.getBytes(encoding);
+ NativeString array = POOL.acquire();
+ if (array != null && array.pointer.size() > data.length) {
+ array.setData(data);
+ return array;
+ }
+ return new NativeString(data);
+ }
+
+ /** Recycles this instance and return it back to the pool. */
+ public void recycle() {
+ POOL.recycle(this);
+ }
+
+ /**
+ * Returns the peer pointer.
+ *
+ * @return the peer pointer
+ */
+ public Pointer getPointer() {
+ return pointer;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
new file mode 100644
index 0000000..573acd4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+/**
+ * A generic object pool implementation.
+ *
+ * @param <T> the type of object to put in the pool
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public class ObjectPool<T> {
+
+ private Queue<T> queue;
+ private Supplier<T> supplier;
+ private Consumer<T> consumer;
+
+ public ObjectPool(Supplier<T> supplier, Consumer<T> consumer) {
+ queue = new ConcurrentLinkedQueue<>();
+ this.supplier = supplier;
+ this.consumer = consumer;
+ }
+
+ public T acquire() {
+ T item = queue.poll();
+ if (item == null) {
+ if (supplier != null) {
+ return supplier.get();
+ }
+ }
+ return item;
+ }
+
+ public void recycle(T item) {
+ if (consumer != null) {
+ consumer.accept(item);
+ }
+ queue.add(item);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
new file mode 100644
index 0000000..a864b64
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Function;
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+
+/**
+ * An abstraction for a native pointer array data type ({@code void**}).
+ *
+ * @see Pointer
+ * @see com.sun.jna.ptr.PointerByReference
+ * @see Function
+ */
+@SuppressWarnings("checkstyle:EqualsHashCode")
+final class PointerArray extends Memory {
+
+ private static final ObjectPool<PointerArray> POOL = new ObjectPool<>(null, null);
+
+ private int length;
+
+ /**
+ * Constructs a {@link Memory} buffer PointerArray given the Pointers to include in it.
+ *
+ * @param arg the pointers to include in the array
+ */
+ private PointerArray(Pointer... arg) {
+ super(Native.POINTER_SIZE * (arg.length + 1));
+ length = arg.length;
+ setPointers(arg);
+ }
+
+ /**
+ * Acquires a pooled {@code PointerArray} object if available, otherwise a new instance is
+ * created.
+ *
+ * @param arg the pointers to include in the array
+ * @return a {@code PointerArray} object
+ */
+ public static PointerArray of(Pointer... arg) {
+ PointerArray array = POOL.acquire();
+ if (array != null && array.length >= arg.length) {
+ array.setPointers(arg);
+ return array;
+ }
+ return new PointerArray(arg);
+ }
+
+ /** Recycles this instance and return it back to the pool. */
+ public void recycle() {
+ POOL.recycle(this);
+ }
+
+ private void setPointers(Pointer[] pointers) {
+ for (int i = 0; i < pointers.length; i++) {
+ setPointer(i * Native.POINTER_SIZE, pointers[i]);
+ }
+ setPointer(Native.POINTER_SIZE * length, null);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ return o == this;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
new file mode 100644
index 0000000..00ffa8f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+
+ /**
+  * An abstraction for a native string array data type ({@code char**}).
+  *
+  * <p>The native buffer holds one pointer per string followed by a {@code NULL} terminator.
+  */
+ @SuppressWarnings("checkstyle:EqualsHashCode")
+ final class StringArray extends Memory {
+
+     private static final Charset ENCODING = Native.DEFAULT_CHARSET;
+     private static final ObjectPool<StringArray> POOL = new ObjectPool<>(null, null);
+     /** Holds every {@code NativeString} in use so that none is garbage collected. */
+     private List<NativeString> natives; // NOPMD
+
+     private int length;
+
+     /**
+      * Create a native array of strings.
+      *
+      * @param strings the strings
+      */
+     private StringArray(String[] strings) {
+         super((strings.length + 1) * Native.POINTER_SIZE);
+         natives = new ArrayList<>();
+         length = strings.length;
+         setPointers(strings);
+     }
+
+     /**
+      * Replaces the content of the buffer with pointers to native copies of the given strings.
+      *
+      * <p>Native strings from a previous use are recycled first. A {@code null} element is
+      * stored as a {@code NULL} pointer.
+      *
+      * @param strings the strings to store
+      */
+     private void setPointers(String[] strings) {
+         natives.forEach(NativeString::recycle);
+         natives.clear();
+
+         int offset = 0;
+         for (String value : strings) {
+             Pointer target = null;
+             if (value != null) {
+                 NativeString nativeValue = NativeString.of(value, ENCODING);
+                 natives.add(nativeValue);
+                 target = nativeValue.getPointer();
+             }
+             setPointer(offset, target);
+             offset += Native.POINTER_SIZE;
+         }
+         // The slot after the last string carries the NULL terminator.
+         setPointer(offset, null);
+     }
+
+     /**
+      * Acquires a pooled {@code StringArray} object if available, otherwise a new instance is
+      * created.
+      *
+      * @param strings the strings to include in the array
+      * @return a {@code StringArray} object
+      */
+     public static StringArray of(String[] strings) {
+         StringArray pooled = POOL.acquire();
+         boolean reusable = pooled != null && pooled.length >= strings.length;
+         if (!reusable) {
+             return new StringArray(strings);
+         }
+         pooled.setPointers(strings);
+         return pooled;
+     }
+
+     /** Recycles this instance and returns it to the pool. */
+     public void recycle() {
+         POOL.recycle(this);
+     }
+
+     /** {@inheritDoc} */
+     @Override
+     public boolean equals(Object o) {
+         return this == o;
+     }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
new file mode 100644
index 0000000..d524d50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.jna;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
new file mode 100644
index 0000000..f775f44
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
@@ -0,0 +1,3455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import java.util.Arrays;
+import java.util.stream.IntStream;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.Float16Utils;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class representing an n-dimensional array.
+ *
+ * <p>NDArray is the core data structure for all mathematical computations. An NDArray represents a
+ * multidimensional, fixed-size homogeneous array. It has very similar behaviour to the Numpy python
+ * package with the addition of efficient computing.
+ */
+public class NDArray extends MxResource {
+
+ private static final Logger logger = LoggerFactory.getLogger(NDArray.class);
+
+ private static final int MAX_SIZE = 100;
+ private static final int MAX_DEPTH = 10;
+ private static final int MAX_ROWS = 10;
+ private static final int MAX_COLUMNS = 20;
+ private static final NDArray[] EMPTY = new NDArray[0];
+
+ private String name;
+ private Device device;
+ private SparseFormat sparseFormat;
+ private DataType dataType;
+ private Shape shape;
+ // use Boolean object to maintain three status: false, true
+ // and null which means the flag is not set by the native engine yet
+ private Boolean hasGradient;
+ private Integer version;
+ private NDArrayEx mxNDArrayEx;
+
+ protected NDArray(Pointer handle) {
+ super(BaseMxResource.getSystemMxResource(), handle); // NOTE(review): mxNDArrayEx is not set on this path - confirm
+ }
+
+ /**
+ * Constructs an {@link NDArray} from a native handle and metadata (internal; use the {@code
+ * create} methods instead).
+ *
+ * @param parent the parent {@link MxResource} that manages the lifecycle of the {@link NDArray}
+ * @param handle the pointer to the native NDArray memory
+ * @param device the device the new array will be located on
+ * @param shape the shape of the new array; no dimension may be negative
+ * @param dataType the dataType of the new array
+ * @param hasGradient the gradient status of the new array
+ */
+ NDArray(
+ MxResource parent,
+ Pointer handle,
+ Device device,
+ Shape shape,
+ DataType dataType,
+ boolean hasGradient) {
+ this(parent, handle);
+ setParent(parent);
+ this.device = device;
+ // reject shapes that contain a negative dimension
+ if (Arrays.stream(shape.getShape()).anyMatch(s -> s < 0)) {
+ throw new IllegalArgumentException("The shape must be >= 0");
+ }
+ this.shape = shape;
+ this.dataType = dataType;
+ this.hasGradient = hasGradient;
+ if (parent != null) {
+ parent.addSubResource(this);
+ }
+ }
+
+ /**
+ * Constructs an {@link NDArray} from a native handle only (internal).
+ *
+ * @param parent the parent {@link MxResource} that manages the lifecycle of the {@link NDArray}
+ * @param handle the pointer to the native NDArray memory
+ */
+ NDArray(MxResource parent, Pointer handle) {
+ super(parent, handle);
+ this.mxNDArrayEx = new NDArrayEx(this);
+ }
+
+ /**
+ * Constructs an {@link NDArray} from a native handle and a known sparse format (internal).
+ *
+ * @param parent the parent {@link MxResource} that manages the lifecycle of the {@link NDArray}
+ * @param handle the pointer to the native NDArray memory
+ * @param fmt the sparse format
+ */
+ NDArray(MxResource parent, Pointer handle, SparseFormat fmt) {
+ this(parent, handle);
+ this.sparseFormat = fmt;
+ }
+
+ /**
+ * Creates an {@link NDArray} from the given native memory pointer and parent {@link MxResource}.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param handle the array's native memory pointer
+ * @return the created array
+ */
+ public static NDArray create(MxResource parent, Pointer handle) {
+ return new NDArray(parent, handle);
+ }
+
+ /**
+ * Creates an {@link NDArray} from the given native memory pointer, parent and sparse format.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param handle the array's native memory pointer
+ * @param fmt the sparse format
+ * @return the created array
+ */
+ public static NDArray create(MxResource parent, Pointer handle, SparseFormat fmt) {
+ return new NDArray(parent, handle, fmt);
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+ * parent {@link MxResource}, {@link Shape} and {@link Device}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, Shape shape, Device device) {
+ return create(parent, shape, DataType.FLOAT32, device);
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+ * parent {@link MxResource} and {@link Shape}, on the device given by Device.defaultIfNull().
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, Shape shape) {
+ return create(parent, shape, DataType.FLOAT32, Device.defaultIfNull());
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+ * MxResource}, {@link Shape}, {@link DataType}, {@link Device} and {@code hasGradient}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}; {@code null} selects the default
+ * @param hasGradient true if the gradient calculation is required for this {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(
+ MxResource parent, Shape shape, DataType dataType, Device device, boolean hasGradient) {
+ Device dev = Device.defaultIfNull(device); // resolve once so the handle and wrapper agree
+ Pointer handle = JnaUtils.createNdArray(dev, shape, dataType, shape.dimension(), hasGradient);
+ return new NDArray(parent, handle, dev, shape, dataType, hasGradient);
+ }
+
+ /**
+ * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+ * MxResource}, {@link Shape}, {@link DataType}, {@link Device}.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}; {@code null} selects the default
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, Shape shape, DataType dataType, Device device) {
+ // resolve a null device before the native allocation so the handle and the wrapper agree
+ return create(parent, shape, dataType, Device.defaultIfNull(device), false);
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray}.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the {@link Number} that needs to be set
+ * @return a new instance of {@link NDArray}
+ * @throws IllegalArgumentException when data is null or not an Integer, Float, Double, Long or Byte
+ */
+ public static NDArray create(MxResource parent, Number data) {
+ if (data instanceof Integer) {
+ return create(parent, data.intValue());
+ } else if (data instanceof Float) {
+ return create(parent, data.floatValue());
+ } else if (data instanceof Double) {
+ return create(parent, data.doubleValue());
+ } else if (data instanceof Long) {
+ return create(parent, data.longValue());
+ } else if (data instanceof Byte) {
+ return create(parent, data.byteValue());
+ }
+ String type = data == null ? "null" : data.getClass().getName();
+ throw new IllegalArgumentException("Unsupported Number type: " + type);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and float
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array whose values are set on the new array
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[] data, Shape shape) {
+ return create(parent, FloatBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and int
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the int array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[] data, Shape shape) {
+ return create(parent, IntBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and
+ * double array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the double array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[] data, Shape shape) {
+ return create(parent, DoubleBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and long
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the long array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[] data, Shape shape) {
+ return create(parent, LongBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and byte
+ * array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the byte array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[] data, Shape shape) {
+ return create(parent, ByteBuffer.wrap(data), shape);
+ }
+
+ /**
+ * Creates and initializes a {@link DataType#BOOLEAN} {@link NDArray} with specified {@link
+ * Shape} from a boolean array, encoding each element as one byte (1 for true, 0 for false).
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the boolean array that needs to be set
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean[] data, Shape shape) {
+ ByteBuffer encoded = ByteBuffer.allocate(data.length);
+ for (int idx = 0; idx < data.length; idx++) {
+ encoded.put(idx, (byte) (data[idx] ? 1 : 0));
+ }
+ return create(parent, encoded, shape, DataType.BOOLEAN);
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding a float value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float data) {
+ return create(parent, new float[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding an int value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the int data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int data) {
+ return create(parent, new int[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding a double value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the double data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double data) {
+ return create(parent, new double[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding a long value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the long data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long data) {
+ return create(parent, new long[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding a byte value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the byte data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte data) {
+ return create(parent, new byte[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a scalar {@link NDArray} holding a boolean value.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the boolean data that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean data) {
+
+ return create(parent, new boolean[] {data}, new Shape());
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the float array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the int array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the int array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the double array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the double array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the long array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the long array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the byte array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the byte array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 1D {@link NDArray} whose length equals the boolean array length.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the boolean array that needs to be set
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean[] data) {
+ return create(parent, data, new Shape(data.length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray} from an array of float rows.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the float rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, float[][] data) {
+ FloatBuffer flat = FloatBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ flat.put(data[row]);
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray} from an array of int rows.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the int rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, int[][] data) {
+ IntBuffer flat = IntBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ flat.put(data[row]);
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray} from an array of double rows.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the double rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, double[][] data) {
+ DoubleBuffer flat = DoubleBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ flat.put(data[row]);
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray} from an array of long rows.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param data the long rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, long[][] data) {
+ LongBuffer flat = LongBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ flat.put(data[row]);
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link NDArray} from an array of byte rows.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the byte rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, byte[][] data) {
+ ByteBuffer flat = ByteBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ flat.put(data[row]);
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length));
+ }
+
+ /**
+ * Creates and initializes a 2D {@link DataType#BOOLEAN} {@link NDArray} from boolean rows.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the boolean rows to copy; row 0 determines the column count of the shape
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray create(MxResource parent, boolean[][] data) {
+ ByteBuffer flat = ByteBuffer.allocate(data.length * data[0].length);
+ for (int row = 0; row < data.length; row++) {
+ for (int col = 0; col < data[row].length; col++) {
+ flat.put((byte) (data[row][col] ? 1 : 0));
+ }
+ }
+ flat.rewind();
+ return create(parent, flat, new Shape(data.length, data[0].length), DataType.BOOLEAN);
+ }
+
+ /**
+ * Creates and initializes a {@link NDArray} with specified {@link Shape}.
+ *
+ * <p>The {@link DataType} of the {@code NDArray} is determined by the type of the Buffer.
+ *
+ * @param parent the parent {@link MxResource} instance
+ * @param data the data to initialize the {@code NDArray}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ static NDArray create(MxResource parent, Buffer data, Shape shape) {
+ DataType dataType = DataType.fromBuffer(data);
+ return create(parent, data, shape, dataType);
+ }
+
+ static NDArray create(MxResource parent, Buffer data, Shape shape, DataType dataType) {
+ NDArray array = create(parent, shape, dataType, Device.defaultIfNull()); // allocate first
+ array.set(data); // then copy the buffer contents into the new array
+ return array;
+ }
+
+ /**
+ * Returns the name previously assigned to this {@code NDArray}.
+ *
+ * @return the name of this {@code NDArray}
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Sets the name of this {@code NDArray}.
+ *
+ * @param name the new name of the {@code NDArray}
+ */
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ /**
+ * Returns the {@link DataType} of this {@code NDArray}, looked up via the handle on first use.
+ *
+ * @return the {@link DataType} of this {@code NDArray}
+ */
+ public DataType getDataType() {
+ if (this.dataType == null) {
+ this.dataType = JnaUtils.getDataTypeOfNdArray(getHandle());
+ }
+ return this.dataType;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Device getDevice() {
+ if (this.device == null) { // not supplied at construction: look it up via the handle once
+ this.device = JnaUtils.getDeviceOfNdArray(getHandle());
+ }
+ return this.device;
+ }
+
+ /**
+ * Returns the {@link Shape} of this {@code NDArray}, looked up via the handle on first use.
+ *
+ * @return the {@link Shape} of this {@code NDArray}
+ */
+ public Shape getShape() {
+ if (this.shape == null) {
+ this.shape = JnaUtils.getShapeOfNdArray(getHandle());
+ }
+ return this.shape;
+ }
+
+ /**
+ * Returns the {@link SparseFormat} of this {@code NDArray}, looked up via the handle on first use.
+ *
+ * @return the {@link SparseFormat} of this {@code NDArray}
+ */
+ public SparseFormat getSparseFormat() {
+ if (this.sparseFormat == null) {
+ this.sparseFormat = JnaUtils.getStorageType(getHandle());
+ }
+ return this.sparseFormat;
+ }
+
+ /**
+ * Returns the version of this {@code NDArray}.
+ *
+ * @return the version of this {@code NDArray}
+ */
+ public Integer getVersion() {
+ if (this.version == null) {
+ this.version = JnaUtils.getVersion(); // NOTE(review): takes no handle - per-array or library version? confirm
+ }
+ return this.version;
+ }
+
+ private NDArray duplicate(Shape shape, DataType dataType, Device device, String name) {
+ // TODO get copy parameter
+ NDArray array = create(getParent(), shape, dataType, device); // new array under the same parent
+ array.setName(name);
+ copyTo(array); // fill the new array from this one
+ return array;
+ }
+
+ /**
+ * Returns a copy of this {@code NDArray}.
+ *
+ * <p>The copy is created under the same parent with the same shape, data type, device and
+ * name, and its contents are filled through {@link #copyTo(NDArray)}.
+ *
+ * @return a copy of this {@code NDArray}
+ */
+ NDArray duplicate() {
+ return duplicate(getShape(), getDataType(), getDevice(), getName());
+ }
+
+ /**
+ * Moves this {@code NDArray} to a different {@link Device}.
+ *
+ * <p>Returns this instance when it is already on {@code device} and {@code copy} is false.
+ *
+ * @param device the {@link Device} to be set
+ * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray}
+ * @return the result {@code NDArray} with the new {@link Device}
+ */
+ public NDArray toDevice(Device device, boolean copy) {
+ boolean reuse = device.equals(getDevice()) && !copy;
+ return reuse ? this : duplicate(getShape(), getDataType(), device, getName());
+ }
+
+ /**
+ * Converts this {@code NDArray} to a different {@link DataType}.
+ *
+ * <p>Returns this instance when it already has {@code dataType} and {@code copy} is false.
+ *
+ * @param dataType the {@link DataType} to be set
+ * @param copy set {@code true} if you want to return a copy of the Existing {@code NDArray}
+ * @return the result {@code NDArray} with the new {@link DataType}
+ */
+ public NDArray toType(DataType dataType, boolean copy) {
+ boolean reuse = dataType.equals(getDataType()) && !copy;
+ return reuse ? this : duplicate(getShape(), dataType, getDevice(), getName());
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}; must not be {@code null}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ * @see #zeros(Shape, DataType, Device)
+ */
+ public NDArray zeros(Shape shape, DataType dataType) {
+ return fill("_npi_zeros", shape, dataType);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with the same {@link Shape} and {@link DataType} as
+ * this array, filled with zeros.
+ *
+ * @return a new instance of {@link NDArray}
+ * @see #zeros(Shape, DataType, Device)
+ */
+ public NDArray zeros() {
+ return zeros(getShape(), getDataType());
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+ * {@link DataType} filled with zeros.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}; {@code null} means this array's device
+ * @return a new instance of {@link NDArray}
+ */
+ NDArray zeros(Shape shape, DataType dataType, Device device) {
+ if (device == null || device.equals(getDevice())) {
+ return zeros(shape, dataType);
+ }
+ return create(getParent(), shape, dataType, device).zeros(); // honor the requested device
+ }
+
+ private NDArray createGradient(SparseFormat format) {
+ try (NDArray zeros = this.zeros(getShape(), getDataType(), getDevice())) { // closed on exit
+ return zeros.toSparse(format);
+ }
+ }
+
+ private NDArray fill(String opName, Shape shape, DataType dataType) {
+ if (shape == null) { // validate before building any operator parameters
+ throw new IllegalArgumentException("Shape is required for " + opName.substring(1));
+ }
+ OpParams params = new OpParams();
+ params.addParam("shape", shape);
+ params.setDevice(getDevice()); // the device field is populated lazily; the getter resolves it
+ params.setDataType(dataType);
+ return invoke(getParent(), opName, params);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+ *
+ * @param parent the parent {@link MxResource}
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ public static NDArray ones(MxResource parent, Shape shape, DataType dataType, Device device) {
+ return create(parent, shape, dataType, device).ones(); // allocate, then derive the ones array
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}; must not be {@code null}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ NDArray ones(Shape shape, DataType dataType) {
+ return fill("_npi_ones", shape, dataType);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with the same {@link Shape} and {@link DataType} as
+ * this array, filled with ones.
+ *
+ * @return a new instance of {@link NDArray}
+ */
+ public NDArray ones() {
+ return ones(getShape(), getDataType());
+ }
+ /**
+ * Creates a {@link DataType#FLOAT32} {@link NDArray} with specified {@link Shape} filled with ones.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @return a new instance of {@link NDArray}
+ */
+ NDArray ones(Shape shape) {
+ return ones(shape, DataType.FLOAT32);
+ }
+
+ /**
+ * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+ * {@link DataType} filled with ones.
+ *
+ * @param shape the {@link Shape} of the {@link NDArray}
+ * @param dataType the {@link DataType} of the {@link NDArray}
+ * @param device the {@link Device} of the {@link NDArray}; {@code null} means this array's device
+ * @return a new instance of {@link NDArray}
+ */
+ NDArray ones(Shape shape, DataType dataType, Device device) {
+ if (device == null || device.equals(getDevice())) {
+ return ones(shape, dataType);
+ }
+ return create(getParent(), shape, dataType, device).ones();
+ }
+
+ /**
+ * Returns the gradient {@code NDArray} attached to this {@code NDArray}.
+ *
+ * @return the gradient {@code NDArray}, created under the same parent as this array
+ * @throws IllegalStateException when hasGradient is false
+ */
+ public NDArray getGradient() {
+ if (!hasGradient()) {
+ throw new IllegalStateException(
+ "No gradient attached to this MxNDArray, please call array.requiredGradient() "
+ + "on your MxNDArray or block.setInitializer() on your Block");
+ }
+ Pointer pointer = JnaUtils.getGradient(getHandle());
+ return create(getParent(), pointer);
+ }
+
+ /**
+ * Returns true if the gradient calculation is required for this {@code NDArray}.
+ *
+ * @return true if the gradient calculation is required for this {@code NDArray} else false
+ */
+ public boolean hasGradient() {
+ if (hasGradient == null) { // flag not known yet: derive it from the gradient pointer and cache
+ Pointer pointer = JnaUtils.getGradient(getHandle());
+ hasGradient = pointer != null;
+ }
+ return hasGradient;
+ }
+
+ /**
+ * Returns an NDArray equal to this one that stops gradient propagation through it.
+ *
+ * @return an NDArray equal to this that stops gradient propagation through it
+ */
+ public NDArray stopGradient() {
+ Pointer pointer = JnaUtils.detachGradient(getHandle());
+ return create(getParent(), pointer);
+ }
+
+ /**
+ * Converts this {@code NDArray} to a String array (not supported yet: this always throws).
+ *
+ * <p>This method is only applicable to the String typed NDArray and not for printing purpose
+ *
+ * @return Array of Strings
+ */
+ public String[] toStringArray() {
+ throw new UnsupportedOperationException("String MxNDArray is not supported!");
+ }
+
+ /**
+ * Copies the contents of this dense {@code NDArray} into a newly allocated direct ByteBuffer.
+ *
+ * @return a ByteBuffer
+ */
+ public ByteBuffer toByteBuffer() {
+ SparseFormat format = getSparseFormat();
+ if (format != SparseFormat.DENSE) {
+ throw new IllegalStateException("Require Dense MxNDArray, actual " + format);
+ }
+ long elements = getShape().size();
+ long bytes = getDataType().getNumOfBytes() * elements;
+ ByteBuffer target = NDSerializer.allocateDirect(Math.toIntExact(bytes));
+ Pointer targetAddress = Native.getDirectBufferPointer(target);
+ // the buffer is sized in bytes while the native copy is given the element count
+ JnaUtils.syncCopyToCPU(getHandle(), targetAddress, Math.toIntExact(elements));
+ return target;
+ }
+
+ /**
+ * Returns the total number of elements in this {@code NDArray}.
+ *
+ * @return the number of elements in this {@code NDArray}
+ */
+ long size() {
+ return getShape().size();
+ }
+
+ long size(int axis) {
+ return getShape().size(axis); // size along the given axis, as reported by the Shape
+ }
+
+ /**
+ * Sets this {@code NDArray} value from {@link Buffer}.
+ *
+ * @param data the input buffered data; a non-direct buffer has its limit and position changed
+ */
+ public void set(Buffer data) {
+ int size = Math.toIntExact(size());
+ if (data.remaining() < size) {
+ throw new IllegalArgumentException(
+ "The MxNDArray size is: " + size + ", but buffer size is: " + data.remaining());
+ }
+ if (data.isDirect()) {
+ JnaUtils.syncCopyFromCPU(getHandle(), data, size);
+ return;
+ }
+ // limit is absolute, so offset it by the current position to expose exactly size elements
+ data.limit(data.position() + size);
+ // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
+ DataType inputType = DataType.fromBuffer(data);
+ validate(inputType);
+
+ int numOfBytes = inputType.getNumOfBytes();
+ ByteBuffer buf = NDSerializer.allocateDirect(size * numOfBytes);
+
+ switch (inputType) {
+ case FLOAT32:
+ buf.asFloatBuffer().put((FloatBuffer) data);
+ break;
+ case FLOAT64:
+ buf.asDoubleBuffer().put((DoubleBuffer) data);
+ break;
+ case UINT8:
+ case INT8:
+ case BOOLEAN:
+ buf.put((ByteBuffer) data);
+ break;
+ case INT32:
+ buf.asIntBuffer().put((IntBuffer) data);
+ break;
+ case INT64:
+ buf.asLongBuffer().put((LongBuffer) data);
+ break;
+ case FLOAT16:
+ default:
+ throw new UnsupportedOperationException("data type is not supported!");
+ }
+ buf.rewind();
+ JnaUtils.syncCopyFromCPU(getHandle(), buf, size);
+ }
+
+ private void validate(DataType inputType) {
+ DataType expected = getDataType();
+ // Inferring the DataType from a Buffer always returns INT8, so INT8 input is also accepted
+ // for UINT8 and BOOLEAN arrays that are set from a regular ByteBuffer.
+ boolean bytes = expected == DataType.UINT8 || expected == DataType.BOOLEAN;
+ if (expected != inputType && !(bytes && inputType == DataType.INT8)) {
+ throw new IllegalStateException(
+ "DataType mismatch, required: " + expected + ", actual: " + inputType);
+ }
+ }
+
+ /**
+ * Returns {@code true} if this {@code NDArray} is a scalar {@code NDArray} with empty
+ * {@link Shape}.
+ *
+ * @return {@code true} if this {@code NDArray} is a scalar {@code NDArray} with empty
+ * {@link Shape}
+ */
+ boolean isScalar() {
+ return getShape().isScalar();
+ }
+
+ /**
+ * Tests whether all elements within this {@code NDArray} are non-zero or {@code
+ * true}.
+ *
+ * @return an {@code NDArray} comparing the sum of the boolean-cast elements with the element count
+ */
+ NDArray all() {
+ // result of sum operation is int64 now
+ return toType(DataType.BOOLEAN, false).sum().eq(size());
+ }
+
+ /**
+ * Deep-copies the current {@code NDArray} to the one passed in.
+ *
+ * @param ndArray this {@code NDArray} prepared to be copied to
+ */
+ public void copyTo(NDArray ndArray) {
+
+ Shape inShape = getShape();
+ Shape destShape = ndArray.getShape();
+ if (!Arrays.equals(inShape.getShape(), destShape.getShape())) {
+ throw new IllegalArgumentException(
+ "shape are diff. Required: " + destShape + ", Actual " + inShape);
+ }
+ JnaUtils.op("_npi_copyto").invoke(new NDArray[] {this}, new NDArray[] {ndArray}, null);
+ }
+
+ NDArray booleanMask(NDArray index) {
+ return booleanMask(index, 0);
+ }
+
+ /**
+ * Returns portion of this {@code NDArray} given the index boolean {@code NDArray} along given
+ * axis.
+ *
+ * @param index boolean {@code NDArray} mask
+ * @param axis an integer that represents the axis of {@code NDArray} to mask from
+ * @return the result {@code NDArray}
+ */
+ public NDArray booleanMask(NDArray index, int axis) {
+     boolean anyScalar = isScalar() || index.isScalar();
+     if (anyScalar) {
+         throw new IllegalArgumentException("booleanMask didn't support scalar!");
+     }
+     // TODO remove reshape when MXNet numpy support multi-dim index
+     // and boolean MxNDArray reshape
+     Shape trailingDims = getShape().slice(index.getShape().dimension());
+     int trailingRank = trailingDims.dimension();
+     // target dimensions are {-1, trailingDims...}
+     long[] targetDims = new long[trailingRank + 1];
+     targetDims[0] = -1;
+     System.arraycopy(trailingDims.getShape(), 0, targetDims, 1, trailingRank);
+     OpParams maskParams = new OpParams();
+     maskParams.addParam("axis", axis);
+     try (NDArray flattened = this.reshape(new Shape(targetDims));
+             NDArray flatMask = index.toType(DataType.INT32, false).reshape(-1);
+             NDArray masked =
+                     invoke(
+                             getParent(),
+                             "_npi_boolean_mask",
+                             new NDArray[] {flattened, flatMask},
+                             maskParams)) {
+         return masked.reshape(targetDims);
+     }
+ }
+
+ /**
+ * Sets all elements outside the sequence to a constant value.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. sequence_length should be an
+ * input array of positive ints of dimension [batch_size].
+ *
+ * @param sequenceLength used to handle variable-length sequences
+ * @param value the constant value to be set
+ * @return the result {@code NDArray}
+ */
+ public NDArray sequenceMask(NDArray sequenceLength, float value) {
+ if (getShape().dimension() < 2 || getShape().isScalar() || getShape().hasZeroDimension()) {
+ throw new IllegalArgumentException(
+ "sequenceMask is not supported for MxNDArray with less than 2 dimensions");
+ }
+ Shape expectedSequenceLengthShape = new Shape(getShape().get(0));
+ if (!sequenceLength.getShape().equals(expectedSequenceLengthShape)) {
+ throw new IllegalArgumentException("SequenceLength must be of shape [batchSize]");
+ }
+ OpParams params = new OpParams();
+ params.add("value", value);
+ params.add("use_sequence_length", true);
+ params.add("axis", 1);
+ return invoke(getParent(), "_npx_sequence_mask", new NDList(this, sequenceLength), params)
+ .head();
+ }
+
+ /**
+ * Sets all elements outside the sequence to 0.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. sequence_length should be an
+ * input array of positive ints of dimension [batch_size].
+ *
+ * @param sequenceLength used to handle variable-length sequences
+ * @return the result {@code NDArray}
+ */
+ public NDArray sequenceMask(NDArray sequenceLength) {
+     final float fillValue = 0;
+     return sequenceMask(sequenceLength, fillValue);
+ }
+
+ /**
+ * Returns an {@code NDArray} of zeros with the same {@link Shape}, {@link DataType} and {@link
+ * SparseFormat} as the input {@code NDArray}.
+ *
+ * @return a {@code NDArray} filled with zeros
+ */
+ public NDArray zerosLike() {
+     final String opName = "_npi_full_like";
+     OpParams fillParams = new OpParams();
+     fillParams.addParam("fill_value", 0);
+     return invoke(getParent(), opName, this, fillParams);
+ }
+
+ /**
+ * Returns an {@code NDArray} of ones with the same {@link Shape}, {@link DataType} and {@link
+ * SparseFormat} as the input {@code NDArray}.
+ *
+ * @return a {@code NDArray} filled with ones
+ */
+ public NDArray onesLike() {
+     final String opName = "_npi_full_like";
+     OpParams fillParams = new OpParams();
+     fillParams.addParam("fill_value", 1);
+     return invoke(getParent(), opName, this, fillParams);
+ }
+
+ NDArray get(NDIndex index) {
+     // indexing is delegated to the indexer of the internal NDArray helper
+     return getNDArrayInternal().getIndexer().get(this, index);
+ }
+
+ NDArray get(long... indices) {
+     NDIndex position = new NDIndex(indices);
+     return get(position);
+ }
+
+ NDArray getScalar(long... indices) {
+     NDIndex position = new NDIndex(indices);
+     NDArray selected = get(position);
+     if (selected.size() == 1) {
+         return selected;
+     }
+     throw new IllegalArgumentException("The supplied Index does not produce a scalar");
+ }
+
+ boolean getBoolean(long... indices) {
+     NDArray scalar = getScalar(indices);
+     boolean[] values = scalar.toBooleanArray();
+     return values[0];
+ }
+
+ /**
+ * Returns {@code true} if all elements in this {@code NDArray} are equal to the {@link Number}.
+ *
+ * @param number the number to compare
+ * @return the boolean result
+ */
+ public boolean contentEquals(Number number) {
+ if (number == null) {
+ return false;
+ }
+ try (NDArray result = eq(number)) {
+ return result.all().getBoolean();
+ }
+ }
+
+ /**
+ * Returns {@code true} if all elements in this {@code NDArray} are equal to the other {@link
+ * NDArray}.
+ *
+ * @param other the other {@code NDArray} to compare
+ * @return the boolean result
+ */
+ public boolean contentEquals(NDArray other) {
+ if (other == null || (!shapeEquals(other))) {
+ return false;
+ }
+ if (getDataType() != other.getDataType()) {
+ return false;
+ }
+ try (NDArray result = eq(other).toType(DataType.INT32, false)) {
+ return result.all().getBoolean();
+ }
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+ *
+ * @param n the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+ */
+ public NDArray eq(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_equal_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+ *
+ * @param other the {@code NDArray} to compare
+ * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+ */
+ public NDArray eq(NDArray other) {
+ return invoke(getParent(), "_npi_equal", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+ *
+ * @param n the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+ */
+ public NDArray neq(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_not_equal_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+ *
+ * @param other the {@code NDArray} to compare
+ * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+ */
+ public NDArray neq(NDArray other) {
+ return invoke(getParent(), "_npi_not_equal", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Greater" comparison.
+ *
+ * @param other the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Greater" comparison
+ */
+ public NDArray gt(Number other) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", other.toString());
+     final String opName = "_npi_greater_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Greater Than" comparison.
+ *
+ * @param other the {@code NDArray} to compare
+ * @return the boolean {@code NDArray} for element-wis "Greater Than" comparison
+ */
+ public NDArray gt(NDArray other) {
+ return invoke(getParent(), "_npi_greater", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * @param other the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+ */
+ public NDArray gte(Number other) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", other.toString());
+     final String opName = "_npi_greater_equal_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * @param other the number to compare
+ * @return the boolean {@code NDArray} for "Greater or equals" comparison
+ */
+ public NDArray gte(NDArray other) {
+ return invoke(getParent(), "_npi_greater_equal", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+ *
+ * @param other the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Less" comparison
+ */
+ public NDArray lt(Number other) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", other.toString());
+     final String opName = "_npi_less_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+ *
+ * @param other the {@code NDArray} to compare
+ * @return the boolean {@code NDArray} for element-wise "Less" comparison
+ */
+ public NDArray lt(NDArray other) {
+ return invoke(getParent(), "_npi_less", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+ *
+ * @param other the number to compare
+ * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+ */
+ public NDArray lte(Number other) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", other.toString());
+     final String opName = "_npi_less_equal_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+ *
+ * @param other the {@code NDArray} to compare
+ * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+ */
+ public NDArray lte(NDArray other) {
+ return invoke(getParent(), "_npi_less_equal", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Adds a number to this {@code NDArray} element-wise.
+ *
+ * @param n the number to add
+ * @return the result {@code NDArray}
+ */
+ public NDArray add(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_add_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Adds other {@code NDArray}s to this {@code NDArray} element-wise.
+ *
+ * @param other the other {@code NDArray}s to add
+ * @return the result {@code NDArray}
+ * @throws IllegalArgumentException others arrays must have at least one element
+ */
+ public NDArray add(NDArray other) {
+ return invoke(getParent(), "_npi_add", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Subtracts a number from this {@code NDArray} element-wise.
+ *
+ * @param n the number to subtract from
+ * @return the result {@code NDArray}
+ */
+ public NDArray sub(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_subtract_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise.
+ *
+ * @param other the other {@code NDArray} to subtract from
+ * @return the result {@code NDArray}
+ */
+ public NDArray sub(NDArray other) {
+ return invoke(getParent(), "_npi_subtract", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Multiplies this {@code NDArray} by a number element-wise.
+ *
+ * @param n the number to multiply by
+ * @return the result {@code NDArray}
+ */
+ public NDArray mul(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_multiply_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Multiplies this {@code NDArray} by other {@code NDArray}s element-wise.
+ *
+ * @param other the other {@code NDArray}s to multiply by
+ * @return the result {@code NDArray}
+ * @throws IllegalArgumentException others arrays must have at least one element
+ */
+ public NDArray mul(NDArray other) {
+ return invoke(getParent(), "_npi_multiply", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Divides this {@code NDArray} by a number element-wise.
+ *
+ * @param n the number to divide by
+ * @return the result {@code NDArray}
+ */
+ public NDArray div(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_true_divide_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Divides this {@code NDArray} by the other {@code NDArray} element-wise.
+ *
+ * @param other the other {@code NDArray} to divide by
+ * @return the result {@code NDArray}
+ */
+ public NDArray div(NDArray other) {
+ return invoke(getParent(), "_npi_true_divide", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns element-wise remainder of division.
+ *
+ * @param n the divisor number
+ * @return the result {@code NDArray}
+ */
+ public NDArray mod(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_mod_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns element-wise remainder of division.
+ *
+ * @param other the divisor {@code NDArray}
+ * @return the result {@code NDArray}
+ */
+ public NDArray mod(NDArray other) {
+ return invoke(getParent(), "_npi_mod", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Takes the power of this {@code NDArray} with a number element-wise.
+ *
+ * @param n the number to take the power with
+ * @return the result {@code NDArray}
+ */
+ public NDArray pow(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_power_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise.
+ *
+ * @param other the other {@code NDArray} to take the power with
+ * @return the result {@code NDArray}
+ */
+ public NDArray pow(NDArray other) {
+ return invoke(getParent(), "_npi_power", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Adds a number to this {@code NDArray} element-wise in place.
+ *
+ * @param n the number to add
+ * @return the result {@code NDArray}
+ */
+ public NDArray addi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_add_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Adds other {@code NDArray}s to this {@code NDArray} element-wise in place.
+ *
+ * @param other the other {@code NDArray}s to add
+ * @return the result {@code NDArray}
+ */
+ public NDArray addi(NDArray other) {
+ invoke("_npi_add", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Subtracts a number from this {@code NDArray} element-wise in place.
+ *
+ * @param n the number to subtract
+ * @return the result {@code NDArray}
+ */
+ public NDArray subi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_subtract_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise in place.
+ *
+ * @param other the other {@code NDArray} to subtract from
+ * @return the result {@code NDArray}
+ */
+ public NDArray subi(NDArray other) {
+ invoke("_npi_subtract", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Multiplies this {@code NDArray} by a number element-wise in place.
+ *
+ * @param n the number to multiply by
+ * @return the result {@code NDArray}
+ */
+ public NDArray muli(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_multiply_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Multiplies this {@code NDArray} by other {@code NDArray} element-wise in place.
+ *
+ * @param other the other NDArrays to multiply with
+ * @return the result {@code NDArray}
+ */
+ public NDArray muli(NDArray other) {
+ invoke("_npi_multiply", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Divides this {@code NDArray} by a number element-wise in place.
+ *
+ * @param n the number to divide values by
+ * @return the array after applying division operation
+ */
+ public NDArray divi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_true_divide_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Divides this {@code NDArray} by the other {@code NDArray} element-wise in place.
+ *
+ * @param other the other {@code NDArray} to divide by
+ * @return the result of the divide
+ */
+ public NDArray divi(NDArray other) {
+ invoke("_npi_true_divide", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Returns element-wise remainder of division in place.
+ *
+ * @param n the divisor number
+ * @return the result {@code NDArray}
+ */
+ public NDArray modi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_mod_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Returns in place element-wise remainder of division in place.
+ *
+ * @param other the divisor {@code NDArray}
+ * @return the result of the divide
+ */
+ public NDArray modi(NDArray other) {
+ invoke("_npi_mod", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Takes the power of this {@code NDArray} with a number element-wise in place.
+ *
+ * @param n the number to raise the power to
+ * @return the result {@code NDArray}
+ */
+ public NDArray powi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ invoke("_npi_power_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+ return this;
+ }
+
+ /**
+ * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise in place.
+ *
+ * @param other the other {@code NDArray} to take the power with
+ * @return the result {@code NDArray}
+ */
+ public NDArray powi(NDArray other) {
+ invoke("_npi_power", new NDArray[] {this, other}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Returns the element-wise sign.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray sign() {
+     final String opName = "_npi_sign";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the element-wise sign in-place.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray signi() {
+ invoke("_npi_sign", new NDArray[] {this}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Returns the maximum of this {@code NDArray} and a number element-wise.
+ *
+ * @param n the number to be compared
+ * @return the maximum of this {@code NDArray} and a number element-wise
+ */
+ public NDArray maximum(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_maximum_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the maximum of this {@code NDArray} and the other {@code NDArray} element-wise.
+ *
+ * @param other the {@code NDArray} to be compared
+ * @return the maximum of this {@code NDArray} and the other {@code NDArray} element-wise
+ */
+ public NDArray maximum(NDArray other) {
+ return invoke(getParent(), "_npi_maximum", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the minimum of this {@code NDArray} and a number element-wise.
+ *
+ * @param n the number to be compared
+ * @return the minimum of this {@code NDArray} and a number element-wise
+ */
+ public NDArray minimum(Number n) {
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", n.toString());
+     final String opName = "_npi_minimum_scalar";
+     return invoke(getParent(), opName, this, scalarParams);
+ }
+
+ /**
+ * Returns the maximum of this {@code NDArray} and the other {@code NDArray} element-wise.
+ *
+ * @param other the {@code NDArray} to be compared
+ * @return the maximum of this {@code NDArray} and the other {@code NDArray} element-wise
+ */
+ public NDArray minimum(NDArray other) {
+ return invoke(getParent(), "_npi_minimum", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Returns the numerical negative {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray neg() {
+     final String opName = "_npi_negative";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the numerical negative {@code NDArray} element-wise in place.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray negi() {
+ invoke("_npi_negative", new NDArray[] {this}, new NDArray[] {this}, null);
+ return this;
+ }
+
+ /**
+ * Returns the absolute value of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray abs() {
+     final String opName = "_npi_absolute";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the square of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray square() {
+     final String opName = "_npi_square";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the square root of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray sqrt() {
+     final String opName = "_npi_sqrt";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the cube-root of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray cbrt() {
+     final String opName = "_npi_cbrt";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the floor of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray floor() {
+     final String opName = "_npi_floor";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the ceiling of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray ceil() {
+     final String opName = "_npi_ceil";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the round of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray round() {
+     // NOTE(review): unlike its neighbours this operator name has no "_npi_" prefix - confirm
+     final String opName = "round";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the truncated value of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray trunc() {
+     final String opName = "_npi_trunc";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the exponential value of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray exp() {
+     final String opName = "_npi_exp";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the natural logarithmic value of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray log() {
+     final String opName = "_npi_log";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the base 10 logarithm of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray log10() {
+     final String opName = "_npi_log10";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the base 2 logarithm of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray log2() {
+     final String opName = "_npi_log2";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the trigonometric sine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray sin() {
+     final String opName = "_npi_sin";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the trigonometric cosine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray cos() {
+     final String opName = "_npi_cos";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the trigonometric tangent of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray tan() {
+     final String opName = "_npi_tan";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse trigonometric sine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray asin() {
+     final String opName = "_npi_arcsin";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse trigonometric cosine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray acos() {
+     final String opName = "_npi_arccos";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse trigonometric tangent of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray atan() {
+     final String opName = "_npi_arctan";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the hyperbolic sine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray sinh() {
+     final String opName = "_npi_sinh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the hyperbolic cosine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray cosh() {
+     final String opName = "_npi_cosh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the hyperbolic tangent of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray tanh() {
+     final String opName = "_npi_tanh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse hyperbolic sine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray asinh() {
+     final String opName = "_npi_arcsinh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse hyperbolic cosine of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray acosh() {
+     final String opName = "_npi_arccosh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the inverse hyperbolic tangent of this {@code NDArray} element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray atanh() {
+     final String opName = "_npi_arctanh";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Converts this {@code NDArray} from radians to degrees element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray toDegrees() {
+     final String opName = "_npi_degrees";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Converts this {@code NDArray} from degrees to radians element-wise.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray toRadians() {
+     final String opName = "_npi_radians";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the maximum of this {@code NDArray}.
+ *
+ * @return the maximum of this {@code NDArray}
+ */
+ public NDArray max() {
+     final String opName = "_np_max";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+  * Returns the maximum of this {@code NDArray} along given axes.
+  *
+  * <p>Equivalent to {@code max(axes, false)}; the reduced axes are not kept in the output.
+  *
+  * @param axes the axes along which to operate
+  * @return the maximum of this {@code NDArray} with the specified axes removed from the Shape
+  *     containing the max
+  * @see NDArray#max(int[], boolean)
+  */
+ public NDArray max(int[] axes) {
+     // Delegate to the two-argument overload, as the single-argument min/sum/prod/mean
+     // reductions do, instead of building the operator parameters a second time.
+     return max(axes, false);
+ }
+
+ /**
+ * Returns the maximum of this {@code NDArray} along given axes.
+ *
+ * @param axes the axes along which to operate
+ * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+ * false} to squeeze the values out of the output array.
+ * @return the maximum of this {@code NDArray}
+ */
+ public NDArray max(int[] axes, boolean keepDims) {
+     final String opName = "_np_max";
+     OpParams reduceParams = new OpParams();
+     reduceParams.addTupleParam("axis", axes);
+     reduceParams.addParam("keepdims", keepDims);
+     return invoke(getParent(), opName, this, reduceParams);
+ }
+
+ /**
+ * Returns the minimum of this {@code NDArray}.
+ *
+ * @return the minimum of this {@code NDArray}
+ */
+ public NDArray min() {
+     final String opName = "_np_min";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the minimum of this {@code NDArray} along given axes.
+ *
+ * @param axes the axes along which to operate
+ * @return the minimum of this {@code NDArray} with the specified axes removed from the Shape
+ * containing the min
+ * @see NDArray#min(int[], boolean)
+ */
+ public NDArray min(int[] axes) {
+     // reduced axes are dropped from the result
+     final boolean keepDims = false;
+     return min(axes, keepDims);
+ }
+
+ /**
+ * Returns the minimum of this {@code NDArray} along given axes.
+ *
+ * @param axes the axes along which to operate
+ * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+ * false} to squeeze the values out of the output array
+ * @return the minimum of this {@code NDArray}
+ */
+ public NDArray min(int[] axes, boolean keepDims) {
+     final String opName = "_np_min";
+     OpParams reduceParams = new OpParams();
+     reduceParams.addTupleParam("axis", axes);
+     reduceParams.addParam("keepdims", keepDims);
+     return invoke(getParent(), opName, this, reduceParams);
+ }
+
+ /**
+ * Returns the sum of this {@code NDArray}.
+ *
+ * @return the sum of this {@code NDArray}
+ */
+ public NDArray sum() {
+ // TODO current windows doesn't support boolean MxNDArray
+ if (System.getProperty("os.name").toLowerCase().contains("win")) {
+ DataType target = getDataType();
+ if (!target.isFloating()) {
+ try (NDArray thisArr = toType(DataType.FLOAT32, false)) {
+ if (target == DataType.BOOLEAN) {
+ target = DataType.INT64;
+ }
+ try (NDArray array = invoke(getParent(), "_np_sum", thisArr, null)) {
+ return array.toType(target, false);
+ }
+ }
+ }
+ }
+ return invoke(getParent(), "_np_sum", this, null);
+ }
+
+ /**
+ * Returns the sum of this {@code NDArray} along given axes.
+ *
+ * @param axes the axes along which to operate
+ * @return the sum of this {@code NDArray} with the specified axes removed from the Shape
+ * containing the sum
+ * @see NDArray#sum(int[], boolean)
+ */
+ public NDArray sum(int[] axes) {
+     // reduced axes are dropped from the result
+     final boolean keepDims = false;
+     return sum(axes, keepDims);
+ }
+
+ /**
+ * Returns the sum of this {@code NDArray} along given axes.
+ *
+ * @param axes the axes along which to operate
+ * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+ * false} to squeeze the values out of the output array
+ * @return the sum of this {@code NDArray}
+ */
+ public NDArray sum(int[] axes, boolean keepDims) {
+     final String opName = "_np_sum";
+     OpParams reduceParams = new OpParams();
+     reduceParams.addTupleParam("axis", axes);
+     reduceParams.addParam("keepdims", keepDims);
+     return invoke(getParent(), opName, this, reduceParams);
+ }
+
+ /**
+ * Returns the product of this {@code NDArray}.
+ *
+ * @return the product of this {@code NDArray}
+ */
+ public NDArray prod() {
+     final String opName = "_np_prod";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Returns the product of this {@code NDArray} elements over the given axes.
+ *
+ * @param axes the axes along which to operate
+ * @return the product of this {@code NDArray} with the specified axes removed from the Shape
+ * containing the prod
+ * @see NDArray#prod(int[], boolean)
+ */
+ NDArray prod(int[] axes) {
+ return prod(axes, false);
+ }
+
+ /**
+ * Returns the product of this {@code NDArray} elements over the given axes.
+ *
+ * @param axes the axes along which to operate
+ * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+ * false} to squeeze the values out of the output array
+ * @return the product of this {@code NDArray}
+ */
+ public NDArray prod(int[] axes, boolean keepDims) {
+     final String opName = "_np_prod";
+     OpParams reduceParams = new OpParams();
+     reduceParams.addTupleParam("axis", axes);
+     reduceParams.addParam("keepdims", keepDims);
+     return invoke(getParent(), opName, this, reduceParams);
+ }
+
+ /**
+ * Returns the average of this {@code NDArray}.
+ *
+ * @return the average of this {@code NDArray}
+ */
+ public NDArray mean() {
+     final String opName = "_npi_mean";
+     return invoke(getParent(), opName, this, null);
+ }
+
+ /**
+ * Computes the average of this {@code NDArray} along the given axes.
+ *
+ * <p>Convenience overload of {@link NDArray#mean(int[], boolean)} that passes {@code false}
+ * for {@code keepDims}.
+ *
+ * @param axes the axes along which to operate
+ * @return the average of this {@code NDArray} with the specified axes removed from the Shape
+ * containing the mean
+ * @see NDArray#mean(int[], boolean)
+ */
+ public NDArray mean(int[] axes) {
+ final boolean keepDims = false;
+ return mean(axes, keepDims);
+ }
+
+ /**
+ * Computes the average of this {@code NDArray} along the given axes.
+ *
+ * <p>Invokes the {@code _npi_mean} operator with the {@code axis} and {@code keepdims}
+ * parameters.
+ *
+ * @param axes the axes along which to operate
+ * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+ * false} to squeeze the values out of the output array
+ * @return the average of this {@code NDArray}
+ */
+ public NDArray mean(int[] axes, boolean keepDims) {
+ OpParams opParams = new OpParams();
+ opParams.addTupleParam("axis", axes);
+ opParams.addParam("keepdims", keepDims);
+ NDArray average = invoke(getParent(), "_npi_mean", this, opParams);
+ return average;
+ }
+
+ /**
+ * Rotates an array by 90 degrees in the plane specified by axes.
+ *
+ * @param times Number of times the array is rotated by 90 degrees.
+ * @param axes The array is rotated in the plane defined by the axes. Axes must be different.
+ * @return the rotated NDArray
+ */
+ public NDArray rotate90(int times, int[] axes) {
+ if (axes.length != 2) {
+ throw new IllegalArgumentException("Axes must be 2");
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("axes", axes);
+ params.addParam("k", times);
+ return invoke(getParent(), "_npi_rot90", this, params);
+ }
+
+ /**
+ * Computes the sum along diagonals of this {@code NDArray}.
+ *
+ * <p>For a 2-D {@code NDArray} the result is the sum along the diagonal with the given offset,
+ * i.e., the sum of elements a[i,i+offset] for all i. For an {@code NDArray} with more than two
+ * dimensions, axis1 and axis2 select the 2-D sub-arrays whose traces are returned, and the
+ * {@link Shape} of the result is that of this {@code NDArray} with axis1 and axis2 removed.
+ *
+ * <p>Invokes the {@code _np_trace} operator with the {@code offset}, {@code axis1} and {@code
+ * axis2} parameters.
+ *
+ * @param offset offset of the diagonal from the main diagonal. Can be both positive and
+ * negative.
+ * @param axis1 axis to be used as the first axis of the 2-D sub-arrays from which the diagonals
+ * should be taken
+ * @param axis2 axis to be used as the second axis of the 2-D sub-arrays from which the
+ * diagonals should be taken
+ * @return the sum along diagonals of this {@code NDArray}
+ */
+ public NDArray trace(int offset, int axis1, int axis2) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("offset", offset);
+ opParams.addParam("axis1", axis1);
+ opParams.addParam("axis2", axis2);
+ NDArray diagonalSum = invoke(getParent(), "_np_trace", this, opParams);
+ return diagonalSum;
+ }
+
+ /**
+ * Computes the sum along diagonals of this {@code NDArray}, using axes 0 and 1.
+ *
+ * <p>Convenience overload of {@link NDArray#trace(int, int, int)} that passes {@code 0} as
+ * axis1 and {@code 1} as axis2.
+ *
+ * @param offset offset of the diagonal from the main diagonal. Can be both positive and
+ * negative.
+ * @return the sum along diagonals of this {@code NDArray}
+ */
+ public NDArray trace(int offset) {
+ final int firstAxis = 0;
+ final int secondAxis = 1;
+ return trace(offset, firstAxis, secondAxis);
+ }
+
+ /**
+ * Splits this {@code NDArray} into multiple sub-{@code NDArray}s at the given indices along the
+ * given axis.
+ *
+ * <p>When {@code indices} is empty, an {@link NDList} containing only this {@code NDArray} is
+ * returned. Otherwise a leading {@code 0} is prepended when the first index is not zero (to
+ * follow the numpy behavior) and the {@code _npi_split} operator is invoked with {@code
+ * squeeze_axis} set to {@code false}.
+ *
+ * @param indices the positions along the axis at which this {@code NDArray} is split
+ * @param axis the axis to split along
+ * @return an {@link NDList} with the resulting sub-{@code NDArray}s
+ */
+ public NDList split(long[] indices, int axis) {
+ if (indices.length == 0) {
+ return new NDList(this);
+ }
+ // follow the numpy behavior: the operator expects the split points to start at 0
+ long[] splitPoints = indices;
+ if (indices[0] != 0) {
+ splitPoints = new long[indices.length + 1];
+ splitPoints[0] = 0;
+ System.arraycopy(indices, 0, splitPoints, 1, indices.length);
+ }
+ OpParams opParams = new OpParams();
+ opParams.addTupleParam("indices", splitPoints);
+ opParams.addParam("axis", axis);
+ opParams.addParam("squeeze_axis", false);
+ NDList source = new NDList(this);
+ return invoke(getParent(), "_npi_split", source, opParams);
+ }
+
+ /**
+ * Flattens this {@code NDArray} into a 1-D {@code NDArray} in row-major order.
+ *
+ * <p>To flatten in column-major order, first transpose this {@code NDArray}
+ *
+ * @return a 1-D {@code NDArray} of equal size
+ */
+ public NDArray flatten() {
+ return reshape(new Shape(Math.toIntExact(size())));
+ }
+
+ /**
+ * Reshapes this {@code NDArray} to the given {@link Shape}.
+ *
+ * <p>To match the shape of another NDArray, call {@code a.reshape(b.getShape()) }. The work is
+ * done by the {@code _np_reshape} operator, which receives the shape as {@code newshape}.
+ *
+ * @param shape the {@link Shape} to reshape into. Must have equal size to the current shape
+ * @return a reshaped {@code NDArray}
+ * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+ * the current shape
+ */
+ public NDArray reshape(Shape shape) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("newshape", shape);
+ NDArray reshaped = invoke(getParent(), "_np_reshape", this, opParams);
+ return reshaped;
+ }
+
+ /**
+ * Reshapes this {@code NDArray} to the given {@link Shape}.
+ *
+ * @param newShape the long array to reshape into. Must have equal size to the current shape
+ * @return a reshaped {@code NDArray}
+ * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+ * the current shape
+ */
+ public NDArray reshape(long... newShape) {
+ return reshape(new Shape(newShape));
+ }
+
+ /**
+ * Expands the {@link Shape} of a {@code NDArray}.
+ *
+ * <p>A new axis is inserted at the given position of the expanded {@code NDArray} shape, using
+ * the {@code _npi_expand_dims} operator.
+ *
+ * @param axis the position in the expanded axes where the new axis is placed
+ * @return the result {@code NDArray}. The number of dimensions is one greater than that of the
+ * {@code NDArray}
+ */
+ public NDArray expandDims(int axis) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ NDArray expanded = invoke(getParent(), "_npi_expand_dims", this, opParams);
+ return expanded;
+ }
+
+ /**
+ * Removes all singleton dimensions from this {@code NDArray} {@link Shape}.
+ *
+ * <p>Delegates to the {@code _np_squeeze} operator without any operator parameters.
+ *
+ * @return a result {@code NDArray} of same size and data without singleton dimensions
+ */
+ public NDArray squeeze() {
+ final String operator = "_np_squeeze";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Removes singleton dimensions at the given axes.
+ *
+ * <p>Invokes the {@code _np_squeeze} operator with the {@code axis} parameter.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
+ * jshell&gt; array;
+ * ND: (1, 3, 1) cpu() float32
+ * [[[0.],
+ * [1.],
+ * [2.],
+ * ],
+ * ]
+ * jshell&gt; array.squeeze(new int[] {0, 2});
+ * ND: (3) cpu() float32
+ * [0., 1., 2.]
+ * </pre>
+ *
+ * @param axes the axes at which to remove the singleton dimensions
+ * @return a result {@code NDArray} of same size and data without the given axes in its shape
+ * @throws IllegalArgumentException thrown if any of the given axes are not a singleton
+ * dimension
+ */
+ public NDArray squeeze(int[] axes) {
+ OpParams opParams = new OpParams();
+ opParams.addTupleParam("axis", axes);
+ NDArray squeezed = invoke(getParent(), "_np_squeeze", this, opParams);
+ return squeezed;
+ }
+
+ /**
+ * Computes the truth value of this {@code NDArray} AND the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ * Boolean operands are first converted to {@link DataType#INT32}, the {@code
+ * broadcast_logical_and} operator is applied, and the result is converted to {@link
+ * DataType#BOOLEAN}.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical AND operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalAnd(NDArray other) {
+ // TODO switch to numpy op, although current op support zero-dim, scalar
+ NDArray lhs = this;
+ if (getDataType() == DataType.BOOLEAN) {
+ lhs = toType(DataType.INT32, false);
+ }
+ NDArray rhs = other;
+ if (other.getDataType() == DataType.BOOLEAN) {
+ rhs = other.toType(DataType.INT32, false);
+ }
+ NDArray combined =
+ invoke(getParent(), "broadcast_logical_and", new NDArray[] {lhs, rhs}, null);
+ return combined.toType(DataType.BOOLEAN, false);
+ }
+
+ /**
+ * Computes the truth value of this {@code NDArray} OR the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ * Boolean operands are first converted to {@link DataType#INT32}, the {@code
+ * broadcast_logical_or} operator is applied, and the result is converted to {@link
+ * DataType#BOOLEAN}.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical OR operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalOr(NDArray other) {
+ // TODO switch to numpy op, although current op support zero-dim, scalar
+ NDArray lhs = this;
+ if (getDataType() == DataType.BOOLEAN) {
+ lhs = toType(DataType.INT32, false);
+ }
+ NDArray rhs = other;
+ if (other.getDataType() == DataType.BOOLEAN) {
+ rhs = other.toType(DataType.INT32, false);
+ }
+ NDArray combined =
+ invoke(getParent(), "broadcast_logical_or", new NDArray[] {lhs, rhs}, null);
+ return combined.toType(DataType.BOOLEAN, false);
+ }
+
+ /**
+ * Computes the truth value of this {@code NDArray} XOR the other {@code NDArray} element-wise.
+ *
+ * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+ * Boolean operands are first converted to {@link DataType#INT32}, the {@code
+ * broadcast_logical_xor} operator is applied, and the result is converted to {@link
+ * DataType#BOOLEAN}.
+ *
+ * @param other the other {@code NDArray} to operate on
+ * @return the boolean {@code NDArray} of the logical XOR operation applied to the elements of
+ * this {@code NDArray} and the other {@code NDArray}
+ */
+ public NDArray logicalXor(NDArray other) {
+ // TODO switch to numpy op, although current op support zero-dim, scalar
+ NDArray lhs = this;
+ if (getDataType() == DataType.BOOLEAN) {
+ lhs = toType(DataType.INT32, false);
+ }
+ NDArray rhs = other;
+ if (other.getDataType() == DataType.BOOLEAN) {
+ rhs = other.toType(DataType.INT32, false);
+ }
+ NDArray combined =
+ invoke(getParent(), "broadcast_logical_xor", new NDArray[] {lhs, rhs}, null);
+ return combined.toType(DataType.BOOLEAN, false);
+ }
+
+ /**
+ * Computes the truth value of NOT this {@code NDArray} element-wise.
+ *
+ * <p>Delegates to the {@code _npi_logical_not} operator without any operator parameters.
+ *
+ * @return the boolean {@code NDArray}
+ */
+ public NDArray logicalNot() {
+ final String operator = "_npi_logical_not";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Returns the indices that would sort this {@code NDArray} given the axis.
+ *
+ * <p>Performs an indirect sort along the given axis and returns a {@code NDArray} of indices
+ * of the same {@link Shape} as this {@code NDArray}. The {@code _npi_argsort} operator is
+ * invoked with the {@code axis} and {@code is_ascend} parameters and a requested data type of
+ * {@link DataType#INT64}.
+ *
+ * @param axis the axis to sort along
+ * @param ascending whether to sort ascending
+ * @return a {@code NDArray} of indices corresponding to elements in this {@code NDArray} on the
+ * axis, the output DataType is always {@link DataType#INT64}
+ */
+ public NDArray argSort(int axis, boolean ascending) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ // be careful that MXNet numpy argsort op didn't officially support this param
+ opParams.addParam("is_ascend", ascending);
+ opParams.setDataType(DataType.INT64);
+ NDArray sortedIndices = invoke(getParent(), "_npi_argsort", this, opParams);
+ return sortedIndices;
+ }
+
+ /**
+ * Sorts this {@code NDArray} along the given axis.
+ *
+ * <p>Invokes the {@code _npi_sort} operator with the {@code axis} parameter.
+ *
+ * @param axis the axis to sort along
+ * @return the sorted {@code NDArray}
+ */
+ public NDArray sort(int axis) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ NDArray sorted = invoke(getParent(), "_npi_sort", this, opParams);
+ return sorted;
+ }
+
+ /**
+ * Sorts this {@code NDArray} using the operator's default settings.
+ *
+ * <p>Delegates to the {@code _npi_sort} operator without any operator parameters.
+ *
+ * @return the sorted {@code NDArray}
+ */
+ public NDArray sort() {
+ final String operator = "_npi_sort";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Applies the softmax function along the given axis.
+ *
+ * @param axis the axis along which to apply
+ * @return the result {@code NDArray}
+ * @see <a href="https://en.wikipedia.org/wiki/Softmax_function">softmax</a>
+ * @see NDArray#softmax(int)
+ */
+ public NDArray softmax(int axis) {
+ // MXNet softmax op bug on GPU
+ if (isEmpty()) {
+ return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+ }
+ OpParams params = new OpParams();
+ params.addParam("axis", axis);
+ return invoke(getParent(), "_npx_softmax", this, params);
+ }
+
+ /**
+ * Applies the softmax function followed by a logarithm.
+ *
+ * <p>Mathematically equivalent to calling softmax and then log. This single operator is faster
+ * than calling two operators and numerically more stable when computing gradients.
+ *
+ * @param axis the axis along which to apply
+ * @return the result {@code NDArray}
+ */
+ public NDArray logSoftmax(int axis) {
+ // MXNet logsoftmax op bug on GPU
+ if (isEmpty()) {
+ return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+ }
+ OpParams params = new OpParams();
+ params.addParam("axis", axis);
+ return invoke(getParent(), "_npx_log_softmax", this, params);
+ }
+
+ /**
+ * Computes the cumulative sum of the elements in the flattened {@code NDArray}.
+ *
+ * <p>Delegates to the {@code _np_cumsum} operator without any operator parameters.
+ *
+ * @return the cumulative sum of the elements in the flattened {@code NDArray}
+ */
+ public NDArray cumSum() {
+ final String operator = "_np_cumsum";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Computes the cumulative sum of the elements along a given axis.
+ *
+ * <p>Invokes the {@code _np_cumsum} operator with the {@code axis} parameter.
+ *
+ * @param axis the axis along which the cumulative sum is computed
+ * @return the cumulative sum along the specified axis
+ */
+ public NDArray cumSum(int axis) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ NDArray running = invoke(getParent(), "_np_cumsum", this, opParams);
+ return running;
+ }
+
+ /**
+ * Replaces the handle of this NDArray with the handle of the other one. The NDArray used for
+ * replacement will be killed.
+ *
+ * <p>The handle of {@code replaced} is taken over (and set to {@code null} on {@code
+ * replaced}), the previous handle of this NDArray is passed to {@code JnaUtils.waitToRead}
+ * and then freed, and finally {@code replaced} is closed.
+ *
+ * <p>Please use with caution, this method will make the input argument unusable.
+ *
+ * @param replaced the handle provider that will be killed
+ */
+ public void intern(NDArray replaced) {
+ Pointer donatedHandle = replaced.handle.getAndSet(null);
+ Pointer previousHandle = handle.getAndSet(donatedHandle);
+ JnaUtils.waitToRead(previousHandle);
+ JnaUtils.freeNdArray(previousHandle);
+ // dereference old ndarray
+ replaced.close();
+ }
+
+ /**
+ * Intended to return the boolean {@code NDArray} with value {@code true} where this {@code
+ * NDArray}'s entries are infinite, or {@code false} where they are not infinite.
+ *
+ * <p>This operation is not implemented yet and always throws.
+ *
+ * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s entries
+ * are infinite
+ * @throws UnsupportedOperationException always
+ */
+ public NDArray isInfinite() {
+ final String reason = "Not implemented yet.";
+ throw new UnsupportedOperationException(reason);
+ }
+
+ /**
+ * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+ * entries are NaN, or {@code false} where they are not NaN.
+ *
+ * <p>Delegates to the {@code _npi_isnan} operator without any operator parameters.
+ *
+ * @return the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+ * entries are NaN
+ */
+ public NDArray isNaN() {
+ final String operator = "_npi_isnan";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Returns a dense representation of the sparse {@code NDArray}.
+ *
+ * <p>An {@code NDArray} that is not sparse is simply duplicated; a sparse one is converted via
+ * a storage cast to {@link SparseFormat#DENSE}.
+ *
+ * @return the result {@code NDArray}
+ */
+ public NDArray toDense() {
+ return isSparse() ? castStorage(SparseFormat.DENSE) : duplicate();
+ }
+
+ /**
+ * Returns a sparse representation of {@code NDArray}.
+ *
+ * @param fmt the {@link SparseFormat} of this {@code NDArray}
+ * @return the result {@code NDArray}
+ */
+ public NDArray toSparse(SparseFormat fmt) {
+ if (fmt != SparseFormat.DENSE
+ && fmt != SparseFormat.CSR
+ && fmt != SparseFormat.ROW_SPARSE) {
+ throw new UnsupportedOperationException(fmt + " is not supported");
+ }
+ if (fmt == getSparseFormat()) {
+ return duplicate();
+ }
+ return castStorage(fmt);
+ }
+
+ // Invokes the "cast_storage" operator with "stype" set to the type of the given format.
+ private NDArray castStorage(SparseFormat fmt) {
+ OpParams opParams = new OpParams();
+ opParams.setParam("stype", fmt.getType());
+ NDArray converted = invoke(getParent(), "cast_storage", this, opParams);
+ return converted;
+ }
+
+ /**
+ * Constructs a {@code NDArray} by repeating this {@code NDArray} the given number of times
+ * along every dimension.
+ *
+ * <p>An empty {@code NDArray} is duplicated unchanged. A scalar is treated as having one
+ * dimension.
+ *
+ * @param repeats the number of times to repeat for each dimension
+ * @return a NDArray that has been tiled
+ */
+ public NDArray tile(long repeats) {
+ // zero-dim
+ if (isEmpty()) {
+ return duplicate();
+ }
+ // scalar
+ int rank;
+ if (isScalar()) {
+ rank = 1;
+ } else {
+ rank = getShape().dimension();
+ }
+ long[] perAxis = new long[rank];
+ Arrays.fill(perAxis, repeats);
+ return tile(perAxis);
+ }
+
+ /**
+ * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by
+ * repeats.
+ *
+ * <p>Invokes the {@code _npi_tile} operator with the {@code reps} parameter.
+ *
+ * @param repeats the number of times to repeat along each axis
+ * @return a {@code NDArray} that has been tiled
+ */
+ public NDArray tile(long[] repeats) {
+ OpParams opParams = new OpParams();
+ opParams.addTupleParam("reps", repeats);
+ NDArray tiled = invoke(getParent(), "_npi_tile", this, opParams);
+ return tiled;
+ }
+
+ /**
+ * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by
+ * repeats along given axis.
+ *
+ * <p>All other axes get a repeat count of 1.
+ *
+ * @param axis the axis to repeat
+ * @param repeats the number of times to repeat along the given axis
+ * @return a {@code NDArray} that has been tiled
+ * @throws IllegalArgumentException thrown if this {@code NDArray} is a scalar
+ */
+ public NDArray tile(int axis, long repeats) {
+ // scalar
+ boolean scalar = isScalar();
+ if (scalar) {
+ throw new IllegalArgumentException("scalar didn't support specifying axis");
+ }
+ int rank = getShape().dimension();
+ long[] perAxis = new long[rank];
+ Arrays.fill(perAxis, 1);
+ int target = withAxis(axis);
+ perAxis[target] = repeats;
+ return tile(perAxis);
+ }
+
+ /**
+ * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times to match
+ * the desired shape.
+ *
+ * <p>If the desired {@link Shape}has fewer dimensions than this {@code NDArray}, it will tile
+ * against the last axis.
+ *
+ * @param desiredShape the {@link Shape}that should be converted to
+ * @return a {@code NDArray} that has been tiled
+ */
+ public NDArray tile(Shape desiredShape) {
+ return tile(repeatsToMatchShape(desiredShape));
+ }
+
+ /**
+ * Normalizes a possibly negative axis to a non-negative index into this array's shape.
+ *
+ * @param axis the axis, counted from the front when non-negative and from the back when
+ * negative
+ * @return the equivalent axis in the range {@code [0, dimension)}
+ * @throws IllegalArgumentException if the axis is outside {@code [-dimension, dimension)}
+ */
+ private int withAxis(int axis) {
+ int dimension = getShape().dimension();
+ // Reject out-of-range axes instead of silently wrapping them (or dividing by zero for a
+ // zero-dimensional shape); callers document IllegalArgumentException for an invalid axis.
+ if (axis < -dimension || axis >= dimension) {
+ throw new IllegalArgumentException(
+ "axis " + axis + " is out of bounds for NDArray of dimension " + dimension);
+ }
+ return Math.floorMod(axis, dimension);
+ }
+
+ /**
+ * Computes, per axis, how many times this array must be repeated to reach the desired shape.
+ *
+ * <p>When the desired shape has fewer dimensions than this array, it is left-padded with the
+ * leading dimensions of the current shape, so those axes get a repeat count of 1.
+ *
+ * @param desiredShape the shape to reach
+ * @return the repeat count for every axis of this array
+ * @throws IllegalArgumentException if the desired shape has more dimensions than this array, or
+ * if a desired dimension is not a multiple of the current one (including a current
+ * dimension of 0)
+ */
+ private long[] repeatsToMatchShape(Shape desiredShape) {
+ Shape curShape = getShape();
+ int dimension = curShape.dimension();
+ int desiredDimension = desiredShape.dimension();
+ if (desiredDimension > dimension) {
+ throw new IllegalArgumentException("The desired shape has too many dimensions");
+ }
+ if (desiredDimension < dimension) {
+ int additionalDimensions = dimension - desiredDimension;
+ desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape);
+ }
+ long[] repeats = new long[dimension];
+ for (int i = 0; i < dimension; i++) {
+ long current = curShape.get(i);
+ long desired = desiredShape.get(i);
+ if (current == 0 || desired % current != 0) {
+ throw new IllegalArgumentException(
+ "The desired shape is not a multiple of the original shape");
+ }
+ // Divisibility was just verified, so exact integer division is correct; it avoids the
+ // precision loss of a round trip through double for very large dimensions.
+ repeats[i] = desired / current;
+ }
+ return repeats;
+ }
+
+ /**
+ * Repeats the elements of this {@code NDArray} the given number of times along every
+ * dimension.
+ *
+ * <p>An empty {@code NDArray} is duplicated unchanged. A scalar is treated as having one
+ * dimension.
+ *
+ * @param repeats the number of times to repeat for each axis
+ * @return an {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(long repeats) {
+ // zero-dim
+ if (isEmpty()) {
+ return duplicate();
+ }
+ // scalar
+ int rank;
+ if (isScalar()) {
+ rank = 1;
+ } else {
+ rank = getShape().dimension();
+ }
+ long[] perAxis = new long[rank];
+ Arrays.fill(perAxis, repeats);
+ return repeat(perAxis);
+ }
+
+ /**
+ * Repeats element of this {@code NDArray} the number of times given repeats along given axis.
+ *
+ * <p>All other axes get a repeat count of 1.
+ *
+ * @param axis the axis to repeat
+ * @param repeats the number of times to repeat along the given axis
+ * @return an {@code NDArray} that has been repeated
+ * @throws IllegalArgumentException thrown if this {@code NDArray} is a scalar
+ */
+ public NDArray repeat(int axis, long repeats) {
+ // A scalar has no axes: fail with the documented exception, consistent with
+ // tile(int, long), instead of an ArithmeticException from the axis normalization.
+ if (isScalar()) {
+ throw new IllegalArgumentException("scalar didn't support specifying axis");
+ }
+ long[] repeatsArray = new long[getShape().dimension()];
+ Arrays.fill(repeatsArray, 1);
+ repeatsArray[withAxis(axis)] = repeats;
+ return repeat(repeatsArray);
+ }
+
+ /**
+ * Repeats element of this {@code NDArray} the number of times given repeats along each axis.
+ *
+ * <p>The {@code _np_repeat} operator is invoked once for every entry of {@code repeats} that
+ * is greater than 1; entry {@code i} applies to axis {@code dimension - repeats.length + i}.
+ * Intermediate arrays produced along the way are closed; this {@code NDArray} itself is never
+ * closed and is returned unchanged when no entry is greater than 1.
+ *
+ * @param repeats the number of times to repeat along each axis
+ * @return a {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(long[] repeats) {
+ // TODO get rid of for loop once bug in MXNet np.repeat is fixed
+ NDArray current = this;
+ int axisOffset = getShape().dimension() - repeats.length;
+ for (int idx = 0; idx < repeats.length; idx++) {
+ long count = repeats[idx];
+ if (count <= 1) {
+ continue;
+ }
+ OpParams opParams = new OpParams();
+ opParams.addParam("repeats", count);
+ opParams.addParam("axis", axisOffset + idx);
+ NDArray next = invoke(getParent(), "_np_repeat", current, opParams);
+ // release the intermediate result of the previous step, but never the caller's array
+ if (current != this) {
+ current.close();
+ }
+ current = next;
+ }
+ return current;
+ }
+
+ /**
+ * Repeats element of this {@code NDArray} to match the desired shape.
+ *
+ * <p>If the desired {@link Shape} has fewer dimensions that the array, it will repeat against
+ * the last axis.
+ *
+ * @param desiredShape the {@link Shape} that should be converted to
+ * @return an {@code NDArray} that has been repeated
+ */
+ public NDArray repeat(Shape desiredShape) {
+ return repeat(repeatsToMatchShape(desiredShape));
+ }
+
+ /**
+ * Dot product of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * <ul>
+ * <li>If both this {@code NDArray} and the other {@code NDArray} are 1-D {@code NDArray}s, it
+ * is inner product of vectors (without complex conjugation).
+ * <li>If both this {@code NDArray} and the other {@code NDArray} are 2-D {@code NDArray}s, it
+ * is matrix multiplication.
+ * <li>If either this {@code NDArray} or the other {@code NDArray} is 0-D {@code NDArray}
+ * (scalar), it is equivalent to mul.
+ * <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is 1-D
+ * {@code NDArray}, it is a sum product over the last axis of those.
+ * <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is M-D
+ * {@code NDArray}(where M>=2), it is a sum product over the last axis of this
+ * {@code NDArray} and the second-to-last axis of the other {@code NDArray}
+ * </ul>
+ *
+ * @param other the other {@code NDArray} to perform dot product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray dot(NDArray other) {
+ return invoke(getParent(), "_np_dot", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Product matrix of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * @param other the other {@code NDArray} to perform matrix product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray matMul(NDArray other) {
+ if (isScalar() || other.isScalar()) {
+ throw new IllegalArgumentException("scalar is not allowed for matMul()");
+ }
+ return invoke(getParent(), "_npi_matmul", new NDArray[] {this, other}, null);
+ }
+
+ /**
+ * Clips (limit) the values in this {@code NDArray}.
+ *
+ * <p>Given an interval, values outside the interval are clipped to the interval edges. For
+ * example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values
+ * larger than 1 become 1.
+ *
+ * <p>Invokes the {@code _npi_clip} operator with the {@code a_min} and {@code a_max}
+ * parameters.
+ *
+ * @param min the minimum value
+ * @param max the maximum value
+ * @return an {@code NDArray} with the elements of this {@code NDArray}, but where values less
+ * than min are replaced with min, and those greater than max with max
+ */
+ public NDArray clip(Number min, Number max) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("a_min", min);
+ opParams.addParam("a_max", max);
+ NDArray clipped = invoke(getParent(), "_npi_clip", this, opParams);
+ return clipped;
+ }
+
+ /**
+ * Interchanges two axes of this {@code NDArray}.
+ *
+ * <p>Invokes the {@code _npi_swapaxes} operator with the {@code dim1} and {@code dim2}
+ * parameters.
+ *
+ * @param axis1 the first axis
+ * @param axis2 the second axis
+ * @return the swapped axes {@code NDArray}
+ */
+ public NDArray swapAxes(int axis1, int axis2) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("dim1", axis1);
+ opParams.addParam("dim2", axis2);
+ NDArray swapped = invoke(getParent(), "_npi_swapaxes", this, opParams);
+ return swapped;
+ }
+
+ /**
+ * Reverses the order of elements in an array along the given axes.
+ *
+ * <p>The shape of the array is preserved, but the elements are reordered. Invokes the {@code
+ * _npi_flip} operator with the {@code axis} parameter.
+ *
+ * @param axes the axes to flip on
+ * @return the newly flipped array
+ */
+ public NDArray flip(int... axes) {
+ OpParams opParams = new OpParams();
+ opParams.addTupleParam("axis", axes);
+ NDArray flipped = invoke(getParent(), "_npi_flip", this, opParams);
+ return flipped;
+ }
+
+ /**
+ * Returns this {@code NDArray} with axes transposed.
+ *
+ * <p>Delegates to the {@code _np_transpose} operator without any operator parameters.
+ *
+ * @return the newly permuted array
+ */
+ public NDArray transpose() {
+ final String operator = "_np_transpose";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Returns this {@code NDArray} with given axes transposed.
+ *
+ * @param dimensions the axes to swap to
+ * @return the transposed {@code NDArray}
+ * @throws IllegalArgumentException thrown when passing a axis that is greater than the actual
+ * number of dimensions
+ */
+ public NDArray transpose(int... dimensions) {
+ if (Arrays.stream(dimensions).anyMatch(d -> d < 0)) {
+ throw new UnsupportedOperationException(
+ "Passing -1 for broadcasting the dimension is not currently supported");
+ }
+ if (!Arrays.equals(
+ Arrays.stream(dimensions).sorted().toArray(),
+ IntStream.range(0, getShape().dimension()).toArray())) {
+ throw new IllegalArgumentException(
+ "You must include each of the dimensions from 0 until "
+ + getShape().dimension());
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("axes", dimensions);
+ return invoke(getParent(), "_np_transpose", this, params);
+ }
+
+ /**
+ * Broadcasts this {@code NDArray} to be the given shape.
+ *
+ * <p>Invokes the {@code _npi_broadcast_to} operator with the shape set on its parameters.
+ *
+ * @param shape the new {@link Shape} of this {@code NDArray}
+ * @return the broadcasted {@code NDArray}
+ */
+ public NDArray broadcast(Shape shape) {
+ OpParams opParams = new OpParams();
+ opParams.setShape(shape);
+ NDArray broadcasted = invoke(getParent(), "_npi_broadcast_to", this, opParams);
+ return broadcasted;
+ }
+
+ /**
+ * Returns the indices of the maximum values into the flattened {@code NDArray}.
+ *
+ * <p>Delegates to the {@code _npi_argmax} operator without any operator parameters.
+ *
+ * @return a {@code NDArray} containing indices
+ * @throws IllegalArgumentException if this {@code NDArray} is empty
+ */
+ public NDArray argMax() {
+ if (isEmpty()) {
+ // the message names the current class; "MxNDArray" was a leftover of the old class name
+ throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
+ }
+ return invoke(getParent(), "_npi_argmax", this, null);
+ }
+
+ /**
+ * Returns the indices of the maximum values along given axis.
+ *
+ * <p>Invokes the {@code _npi_argmax} operator with the {@code axis} parameter.
+ *
+ * @param axis the axis along which to find maximum values
+ * @return a {@code NDArray} containing indices
+ */
+ public NDArray argMax(int axis) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ NDArray indices = invoke(getParent(), "_npi_argmax", this, opParams);
+ return indices;
+ }
+
+ /**
+ * Returns the indices of the minimum values into the flattened {@code NDArray}.
+ *
+ * <p>Delegates to the {@code _npi_argmin} operator without any operator parameters.
+ *
+ * @return a {@code NDArray} containing indices
+ * @throws IllegalArgumentException if this {@code NDArray} is empty
+ */
+ public NDArray argMin() {
+ if (isEmpty()) {
+ // the message names the current class; "MxNDArray" was a leftover of the old class name
+ throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
+ }
+ return invoke(getParent(), "_npi_argmin", this, null);
+ }
+
+ /**
+ * Returns the indices of the minimum values along given axis.
+ *
+ * <p>Invokes the {@code _npi_argmin} operator with the {@code axis} parameter.
+ *
+ * @param axis the axis along which to find minimum values
+ * @return a {@code NDArray} containing indices
+ */
+ public NDArray argMin(int axis) {
+ OpParams opParams = new OpParams();
+ opParams.addParam("axis", axis);
+ NDArray indices = invoke(getParent(), "_npi_argmin", this, opParams);
+ return indices;
+ }
+
+ /**
+ * Intended to return the percentile for this {@code NDArray}.
+ *
+ * <p>This operation is not implemented yet and always throws.
+ *
+ * @param percentile the target percentile in range of 0..100
+ * @return the result {@code NDArray}
+ * @throws UnsupportedOperationException always
+ */
+ public NDArray percentile(Number percentile) {
+ final String reason = "Not implemented yet.";
+ throw new UnsupportedOperationException(reason);
+ }
+
+ /**
+ * Intended to return the percentile along the given dimension(s).
+ *
+ * <p>This operation is not implemented yet and always throws.
+ *
+ * @param percentile the target percentile in range of 0..100
+ * @param dimension the dimension(s) to calculate the percentile for
+ * @return the result {@code NDArray}
+ * @throws UnsupportedOperationException always
+ */
+ public NDArray percentile(Number percentile, int[] dimension) {
+ final String reason = "Not implemented yet.";
+ throw new UnsupportedOperationException(reason);
+ }
+
+ /**
+ * Intended to return the median value for this {@code NDArray}.
+ *
+ * <p>This operation is not implemented yet and always throws.
+ *
+ * @return the median {@code NDArray}
+ * @throws UnsupportedOperationException always
+ */
+ public NDArray median() {
+ final String reason = "Not implemented yet.";
+ throw new UnsupportedOperationException(reason);
+ }
+
+ /**
+ * Intended to return the median value along the given axes.
+ *
+ * <p>This operation is not implemented yet and always throws.
+ *
+ * @param axes the axes along which to perform the median operation
+ * @return the median {@code NDArray} along the specified axes
+ * @throws UnsupportedOperationException always
+ */
+ public NDArray median(int[] axes) {
+ final String reason = "Not implemented yet.";
+ throw new UnsupportedOperationException(reason);
+ }
+
+ /**
+ * Returns the indices of elements that are non-zero.
+ *
+ * <p>Note that the behavior is slightly different from numpy.nonzero. Numpy returns a tuple of
+ * NDArray, one for each dimension of NDArray. DJL nonzero returns only one {@code NDArray} with
+ * last dimension containing all dimension of indices.
+ *
+ * <p>A boolean {@code NDArray} is converted to {@link DataType#INT32} before the {@code
+ * _npx_nonzero} operator is invoked.
+ *
+ * @return the indices of the elements that are non-zero
+ */
+ public NDArray nonzero() {
+ NDArray input = this;
+ if (getDataType() == DataType.BOOLEAN) {
+ input = toType(DataType.INT32, false);
+ }
+ return invoke(getParent(), "_npx_nonzero", input, null);
+ }
+
+ /**
+ * Computes the element-wise inverse gauss error function of the {@code NDArray}.
+ *
+ * <p>Delegates to the {@code erfinv} operator without any operator parameters.
+ *
+ * @return The inverse of gauss error of the {@code NDArray}, element-wise
+ */
+ public NDArray erfinv() {
+ final String operator = "erfinv";
+ return invoke(getParent(), operator, this, null);
+ }
+
+ /**
+ * Computes the norm of this {@code NDArray}.
+ *
+ * <p>Invokes the {@code _npi_norm} operator with {@code flag} set to {@code -2} and the given
+ * {@code keepdims} value.
+ *
+ * @param keepDims If this is set to True, the axes which are normed over are left in the result
+ * as dimensions with size one. With this option the result will broadcast correctly against
+ * the original x.
+ * @return the norm of this {@code NDArray}
+ */
+ public NDArray norm(boolean keepDims) {
+ OpParams opParams = new OpParams();
+ opParams.add("flag", -2);
+ opParams.addParam("keepdims", keepDims);
+ NDArray result = invoke(getParent(), "_npi_norm", this, opParams);
+ return result;
+ }
+
+ /**
+ * Returns the norm of this {@code NDArray}.
+ *
+ * @param ord Order of the norm.
+ * @param axes If axes contains an integer, it specifies the axis of x along which to compute
+ * the vector norms. If axis contains 2 integers, it specifies the axes that hold 2-D
+ * matrices, and the matrix norms of these matrices are computed.
+ * @param keepDims keepDims If this is set to True, the axes which are normed over are left in
+ * the result as dimensions with size one. With this option the result will broadcast
+ * correctly against the original x.
+ * @return the norm of this {@code NDArray}
+ */
+ public NDArray norm(int ord, int[] axes, boolean keepDims) {
+ OpParams params = new OpParams();
+ params.addParam("ord", (double) ord);
+ params.addTupleParam("axis", axes);
+ params.addParam("keepdims", keepDims);
+ return invoke(getParent(), "_npi_norm", this, params);
+ }
+
+ // public MxNDArray oneHot(int depth) {
+ // return LazyNDArray.super.oneHot(depth);
+ // }
+
+ /**
+ * Returns a one-hot {@code NDArray}.
+ *
+ * <ul>
+ * <li>The locations represented by indices take value onValue, while all other locations take
+ * value offValue.
+ * <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
+ * appended at the end.
+ * <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
+ * <li>If {@code NDArray} is a vector of length features, the output shape will be features x
+ * depth.
+ * <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
+ * batch x features x depth.
+ * </ul>
+ *
+ * <p>Invokes the {@code _npx_one_hot} operator and converts its result to {@code dataType}.
+ *
+ * @param depth Depth of the one hot dimension.
+ * @param onValue The value assigned to the locations represented by indices.
+ * @param offValue The value assigned to the locations not represented by indices.
+ * @param dataType dataType of the output.
+ * @return one-hot encoding of this {@code NDArray}
+ */
+ public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
+ OpParams opParams = new OpParams();
+ opParams.add("depth", depth);
+ opParams.add("on_value", onValue);
+ opParams.add("off_value", offValue);
+ opParams.add("dtype", dataType);
+ NDArray encoded = invoke(getParent(), "_npx_one_hot", this, opParams);
+ return encoded.toType(dataType, false);
+ }
+
+ /**
+ * Batchwise product of this {@code NDArray} and the other {@code NDArray}.
+ *
+ * <ul>
+ * <li>batchDot is used to compute dot product of x and y when x and y are data in batch,
+ * namely N-D (N greater or equal to 3) arrays in shape of (B0, …, B_i, :, :). For
+ * example, given x with shape (B_0, …, B_i, N, M) and y with shape (B_0, …, B_i, M, K),
+ * the result array will have shape (B_0, …, B_i, N, K), which is computed by:
+ * batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :,
+ * :])
+ * </ul>
+ *
+ * @param other the other {@code NDArray} to perform batch dot product with
+ * @return the result {@code NDArray}
+ */
+ public NDArray batchDot(NDArray other) {
+ return invoke(getParent(), "_npx_batch_dot", new NDArray[] {this, other}, null);
+ }
+
+ /**
+  * Gives access to the {@link NDArrayEx} helper held by this {@code NDArray}.
+  *
+  * <p>This method should only be used by the Engine provider.
+  *
+  * @return the {@link NDArrayEx} instance stored in this {@code NDArray}
+  */
+ public NDArrayEx getNDArrayInternal() {
+     return this.mxNDArrayEx;
+ }
+
+ /**
+  * {@inheritDoc}
+  *
+  * <p>Frees the sub-resources, then the native handle when one is present, and finally marks
+  * this instance as closed. Calling this method on an already closed instance does nothing.
+  */
+ @Override
+ public void close() {
+     if (getClosed()) {
+         return;
+     }
+     // Use %s rather than %S: the upper-case conversion would upper-case the uid in the log
+     // output, so it would no longer match the uid printed elsewhere.
+     logger.debug(String.format("Start to free NDArray instance: %s", this.getUid()));
+     super.freeSubResources();
+
+     if (this.getHandle() != null) {
+         JnaUtils.freeNdArray(this.getHandle());
+     }
+     setClosed(true);
+     logger.debug(String.format("Finish to free NDArray instance: %s", this.getUid()));
+ }
+
+ /**
+ * Returns {@code true} if this {@code NDArray} is special case: no-value {@code NDArray}.
+ *
+ * @return {@code true} if this NDArray is empty
+ */
+ public boolean isEmpty() {
+ return getShape().size() == 0;
+ }
+
+ /**
+  * Tells whether the sparse format of this {@code NDArray} is anything other than
+  * {@code SparseFormat.DENSE}.
+  *
+  * @return {@code true} if the sparse format is not dense
+  */
+ boolean isSparse() {
+     return SparseFormat.DENSE != getSparseFormat();
+ }
+
+ /**
+  * Compares the shape of this {@code NDArray} with the shape of another one.
+  *
+  * @param other the {@code NDArray} whose shape is compared
+  * @return {@code true} if both shapes are equal
+  */
+ boolean shapeEquals(NDArray other) {
+     Shape ownShape = getShape();
+     Shape otherShape = other.getShape();
+     return ownShape.equals(otherShape);
+ }
+
+ /**
+  * Runs a native operation on a list of inputs and wraps the outputs in an {@link NDList}.
+  *
+  * <p>You should avoid using this function if possible. Since this function is engine
+  * specific, using this API may cause a portability issue. A native operation may not be
+  * compatible between each version.
+  *
+  * @param parent the parent {@link MxResource} of the created {@link NDList}
+  * @param operation the native operation to perform
+  * @param src the {@link NDList} of source {@link NDArray}
+  * @param params the parameters to be passed to the native operation
+  * @return the {@link NDList} holding the output {@link NDArray}s
+  * @throws IllegalArgumentException if operation is not supported by Engine
+  */
+ public static NDList invoke(
+         MxResource parent, String operation, NDList src, PairList<String, ?> params) {
+     NDArray[] outputs = JnaUtils.op(operation).invoke(parent, src.toArray(EMPTY), params);
+     return new NDList(outputs);
+ }
+
+ /**
+  * Runs a native operator on a list of inputs, writing into the given list of outputs.
+  *
+  * <p>You should avoid using this function if possible. Since this function is engine
+  * specific, using this API may cause portability issues. A native operation may not be
+  * compatible between each version.
+  *
+  * @param operation the native operation to perform
+  * @param src the {@link NDList} of source {@link NDArray}
+  * @param dest the {@link NDList} to save output to
+  * @param params the parameters to be passed to the native operator
+  * @throws IllegalArgumentException if operation is not supported by Engine
+  */
+ public static void invoke(
+         String operation, NDList src, NDList dest, PairList<String, ?> params) {
+     // Unwrap both lists and delegate to the array based overload.
+     NDArray[] inputs = src.toArray(EMPTY);
+     NDArray[] outputs = dest.toArray(EMPTY);
+     invoke(operation, inputs, outputs, params);
+ }
+
+ /**
+  * Runs a native operator on an array of inputs and returns its first output.
+  *
+  * <p>You should avoid using this function if possible. Since this function is engine
+  * specific, using this API may cause portability issues. A native operation may not be
+  * compatible between each version.
+  *
+  * @param parent the parent {@link MxResource} to manage this instance
+  * @param operation the native operation to perform
+  * @param src the array of source {@link NDArray}
+  * @param params the parameters to be passed to the native operator
+  * @return the first {@link NDArray} of the operator output
+  */
+ public static NDArray invoke(
+         MxResource parent, String operation, NDArray[] src, PairList<String, ?> params) {
+     NDArray[] outputs = JnaUtils.op(operation).invoke(parent, src, params);
+     return outputs[0];
+ }
+
+ /**
+ * An engine specific generic invocation to native operation, writing the results into
+ * caller supplied output arrays.
+ *
+ * <p>You should avoid using this function if possible. Since this function is engine specific,
+ * using this API may cause a portability issue. Native operation may not be compatible between
+ * each version.
+ *
+ * @param operation the native operation to perform
+ * @param src the array of source {@link NDArray}
+ * @param dest the array of {@link NDArray} to save output to
+ * @param params the parameters to be passed to the native operation
+ * @throws IllegalArgumentException if operation is not supported by Engine
+ */
+ public static void invoke(
+ String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+ // Look up the operator by name and let it run on the given source and destination arrays.
+ JnaUtils.op(operation).invoke(src, dest, params);
+ }
+
+ /**
+  * Runs a native operator on a single input and returns its output.
+  *
+  * <p>You should avoid using this function if possible. Since this function is engine
+  * specific, using this API may cause portability issues. A native operation may not be
+  * compatible between each version.
+  *
+  * @param parent the parent {@link MxResource} to manage this instance
+  * @param operation the native operation to perform
+  * @param src the source {@link NDArray}
+  * @param params the parameters to be passed to the native operator
+  * @return the output {@link NDArray}
+  */
+ public static NDArray invoke(
+         MxResource parent, String operation, NDArray src, PairList<String, ?> params) {
+     // Wrap the single source into an array and reuse the array based overload.
+     NDArray[] sources = {src};
+     return invoke(parent, operation, sources, params);
+ }
+
+ /**
+ * An engine specific generic invocation to native operator.
+ *
+ * <p>You should avoid using this function if possible. Since this function is engine specific,
+ * using this API may cause portability issues. A native operation may not be compatible between
+ * each version.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param operation the native operation to perform
+ * @param params the parameters to be passed to the native operator
+ * @return the output array of {@link NDArray}
+ */
+ public static NDArray invoke(MxResource parent, String operation, PairList<String, ?> params) {
+ return invoke(parent, operation, EMPTY, params);
+ }
+
+ /**
+ * Encodes {@code MxNDArray} to byte array.
+ *
+ * @return byte array
+ */
+ public byte[] encode() {
+ return NDSerializer.encode(this);
+ }
+
+ /**
+  * Draws samples from a uniform distribution.
+  *
+  * <p>Samples are uniformly distributed over the half-open interval [low, high) (includes
+  * low, but excludes high). In other words, any value within the given interval is equally
+  * likely to be drawn by uniform.
+  *
+  * @param parent {@link MxResource} of this instance
+  * @param low the lower boundary of the output interval. All values generated will be
+  *     greater than or equal to low.
+  * @param high the upper boundary of the output interval. All values generated will be less
+  *     than high.
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}
+  * @return the drawn samples {@link NDArray}
+  */
+ public static NDArray randomUniform(
+         MxResource parent,
+         float low,
+         float high,
+         Shape shape,
+         DataType dataType,
+         Device device) {
+     OpParams uniformParams = new OpParams();
+     // Interval bounds and output size.
+     uniformParams.addParam("low", low);
+     uniformParams.addParam("high", high);
+     uniformParams.addParam("size", shape);
+     // Placement and element type of the result.
+     uniformParams.setDevice(device);
+     uniformParams.setDataType(dataType);
+     return invoke(parent, "_npi_uniform", uniformParams);
+ }
+
+ /**
+  * Draws samples from a uniform distribution on the device returned by
+  * {@code Device.defaultIfNull(null)}.
+  *
+  * <p>Samples are uniformly distributed over the half-open interval [low, high) (includes
+  * low, but excludes high). In other words, any value within the given interval is equally
+  * likely to be drawn by uniform.
+  *
+  * @param parent {@link MxResource} of this instance
+  * @param low the lower boundary of the output interval. All values generated will be
+  *     greater than or equal to low.
+  * @param high the upper boundary of the output interval. All values generated will be less
+  *     than high.
+  * @param shape the {@link Shape} of the {@link NDArray}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @return the drawn samples {@link NDArray}
+  */
+ private static NDArray randomUniform(
+         MxResource parent, float low, float high, Shape shape, DataType dataType) {
+     Device fallbackDevice = Device.defaultIfNull(null);
+     return randomUniform(parent, low, high, shape, dataType, fallbackDevice);
+ }
+
+ /**
+  * Draws random samples from a normal (Gaussian) distribution on the requested device.
+  *
+  * @param parent {@link MxResource} of this instance
+  * @param loc the mean (centre) of the distribution
+  * @param scale the standard deviation (spread or "width") of the distribution
+  * @param shape the output {@link Shape}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @param device the {@link Device} of the {@link NDArray}; a {@code null} value is resolved
+  *     through {@code Device.defaultIfNull}
+  * @return the drawn samples {@link NDArray}
+  */
+ public static NDArray randomNormal(
+         MxResource parent,
+         float loc,
+         float scale,
+         Shape shape,
+         DataType dataType,
+         Device device) {
+     // Build the parameters here instead of delegating to the device-less overload: that
+     // overload always passes Device.defaultIfNull(null), so delegating to it silently
+     // dropped the device requested by the caller.
+     OpParams params = new OpParams();
+     params.addParam("loc", loc);
+     params.addParam("scale", scale);
+     params.addParam("size", shape);
+     params.setDevice(Device.defaultIfNull(device));
+     params.setDataType(dataType);
+     return invoke(parent, "_npi_normal", params);
+ }
+
+ /**
+  * Draws random samples from a normal (Gaussian) distribution on the device returned by
+  * {@code Device.defaultIfNull(null)}.
+  *
+  * @param parent {@link MxResource} of this instance
+  * @param loc the mean (centre) of the distribution
+  * @param scale the standard deviation (spread or "width") of the distribution
+  * @param shape the output {@link Shape}
+  * @param dataType the {@link DataType} of the {@link NDArray}
+  * @return the drawn samples {@link NDArray}
+  */
+ public static NDArray randomNormal(
+         MxResource parent, float loc, float scale, Shape shape, DataType dataType) {
+     OpParams normalParams = new OpParams();
+     // Distribution parameters and output size.
+     normalParams.addParam("loc", loc);
+     normalParams.addParam("scale", scale);
+     normalParams.addParam("size", shape);
+     // Placement and element type of the result.
+     Device fallbackDevice = Device.defaultIfNull(null);
+     normalParams.setDevice(fallbackDevice);
+     normalParams.setDataType(dataType);
+     return invoke(parent, "_npi_normal", normalParams);
+ }
+
+ /**
+  * Decodes an {@link NDArray} from a byte array.
+  *
+  * @param parent the parent {@link MxResource} to create the {@link NDArray}
+  * @param bytes byte array to load from
+  * @return {@link NDArray}
+  * @throws IllegalArgumentException if reading from the byte array fails
+  */
+ static NDArray decode(MxResource parent, byte[] bytes) {
+     ByteArrayInputStream rawBytes = new ByteArrayInputStream(bytes);
+     try (DataInputStream dis = new DataInputStream(rawBytes)) {
+         return NDSerializer.decode(parent, dis);
+     } catch (IOException e) {
+         // Surface I/O problems as an unchecked exception, keeping the cause.
+         throw new IllegalArgumentException("NDArray decoding failed", e);
+     }
+ }
+
+ /**
+ * Decodes {@link NDArray} through {@link DataInputStream}.
+ *
+ * @param parent the parent {@link MxResource} to create the {@link NDArray}
+ * @param is input stream data to load from
+ * @return {@link NDArray}
+ * @throws IOException data is not readable
+ */
+ public static NDArray decode(MxResource parent, InputStream is) throws IOException {
+ return NDSerializer.decode(parent, is);
+ }
+
+ /**
+  * Converts this {@code NDArray} to an array of boxed numbers chosen by its {@link DataType}.
+  *
+  * <p>FLOAT16 and FLOAT32 yield {@code Float} elements in a {@code Number[]}, FLOAT64 a
+  * {@code Double[]}, INT32 and UINT8 an {@code Integer[]}, INT64 a {@code Long[]}, and
+  * BOOLEAN and INT8 a {@code Byte[]}.
+  *
+  * @return a Number array
+  * @throws IllegalStateException if the data type is not one of the types listed above
+  */
+ public Number[] toArray() {
+     switch (getDataType()) {
+         case FLOAT16:
+         case FLOAT32:
+             {
+                 float[] floats = toFloatArray();
+                 Number[] boxedFloats = new Number[floats.length];
+                 for (int i = 0; i < floats.length; i++) {
+                     boxedFloats[i] = Float.valueOf(floats[i]);
+                 }
+                 return boxedFloats;
+             }
+         case FLOAT64:
+             {
+                 double[] doubles = toDoubleArray();
+                 Double[] boxedDoubles = new Double[doubles.length];
+                 for (int i = 0; i < doubles.length; i++) {
+                     boxedDoubles[i] = Double.valueOf(doubles[i]);
+                 }
+                 return boxedDoubles;
+             }
+         case INT32:
+             {
+                 int[] ints = toIntArray();
+                 Integer[] boxedInts = new Integer[ints.length];
+                 for (int i = 0; i < ints.length; i++) {
+                     boxedInts[i] = Integer.valueOf(ints[i]);
+                 }
+                 return boxedInts;
+             }
+         case INT64:
+             {
+                 long[] longs = toLongArray();
+                 Long[] boxedLongs = new Long[longs.length];
+                 for (int i = 0; i < longs.length; i++) {
+                     boxedLongs[i] = Long.valueOf(longs[i]);
+                 }
+                 return boxedLongs;
+             }
+         case BOOLEAN:
+         case INT8:
+             {
+                 // Read the remaining bytes of the buffer one by one.
+                 ByteBuffer bytes = toByteBuffer();
+                 Byte[] boxedBytes = new Byte[bytes.remaining()];
+                 for (int i = 0; i < boxedBytes.length; i++) {
+                     boxedBytes[i] = bytes.get();
+                 }
+                 return boxedBytes;
+             }
+         case UINT8:
+             {
+                 int[] unsigned = toUint8Array();
+                 Integer[] boxedUnsigned = new Integer[unsigned.length];
+                 for (int i = 0; i < unsigned.length; i++) {
+                     boxedUnsigned[i] = Integer.valueOf(unsigned[i]);
+                 }
+                 return boxedUnsigned;
+             }
+         default:
+             throw new IllegalStateException("Unsupported DataType: " + getDataType());
+     }
+ }
+
+ /**
+  * Converts this {@code NDArray} to a boolean array; every non-zero byte maps to
+  * {@code true}.
+  *
+  * @return a boolean array
+  * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+  */
+ public boolean[] toBooleanArray() {
+     DataType actual = getDataType();
+     if (actual != DataType.BOOLEAN) {
+         throw new IllegalStateException(
+                 "DataType mismatch, Required boolean" + " Actual " + actual);
+     }
+     ByteBuffer bytes = toByteBuffer();
+     int count = bytes.remaining();
+     boolean[] flags = new boolean[count];
+     for (int i = 0; i < count; i++) {
+         flags[i] = bytes.get() != 0;
+     }
+     return flags;
+ }
+
+ /**
+  * Converts this {@code NDArray} to a double array.
+  *
+  * @return a double array
+  * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+  */
+ public double[] toDoubleArray() {
+     DataType actual = getDataType();
+     if (actual != DataType.FLOAT64) {
+         throw new IllegalStateException(
+                 "DataType mismatch, Required double" + " Actual " + actual);
+     }
+     // View the raw bytes as doubles and copy them out.
+     DoubleBuffer view = toByteBuffer().asDoubleBuffer();
+     double[] values = new double[view.remaining()];
+     view.get(values);
+     return values;
+ }
+
+ /**
+  * Converts this {@code NDArray} to a float array.
+  *
+  * <p>FLOAT16 data is converted through {@code Float16Utils}; FLOAT32 data is read directly.
+  *
+  * @return a float array
+  * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+  */
+ public float[] toFloatArray() {
+     DataType actual = getDataType();
+     if (actual == DataType.FLOAT16) {
+         return Float16Utils.fromByteBuffer(toByteBuffer());
+     }
+     if (actual != DataType.FLOAT32) {
+         throw new IllegalStateException(
+                 "DataType mismatch, Required float, Actual " + actual);
+     }
+     // View the raw bytes as floats and copy them out.
+     FloatBuffer view = toByteBuffer().asFloatBuffer();
+     float[] values = new float[view.remaining()];
+     view.get(values);
+     return values;
+ }
+
+ /**
+  * Converts this {@code NDArray} to an int array.
+  *
+  * @return an int array
+  * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+  */
+ public int[] toIntArray() {
+     DataType actual = getDataType();
+     if (actual != DataType.INT32) {
+         throw new IllegalStateException(
+                 "DataType mismatch, Required int" + " Actual " + actual);
+     }
+     // View the raw bytes as ints and copy them out.
+     IntBuffer view = toByteBuffer().asIntBuffer();
+     int[] values = new int[view.remaining()];
+     view.get(values);
+     return values;
+ }
+
+ /**
+  * Converts this {@code NDArray} to a long array.
+  *
+  * @return a long array
+  * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+  */
+ public long[] toLongArray() {
+     DataType actual = getDataType();
+     if (actual != DataType.INT64) {
+         throw new IllegalStateException(
+                 "DataType mismatch, Required long" + " Actual " + actual);
+     }
+     // View the raw bytes as longs and copy them out.
+     LongBuffer view = toByteBuffer().asLongBuffer();
+     long[] values = new long[view.remaining()];
+     view.get(values);
+     return values;
+ }
+
+ /**
+  * Converts this {@code NDArray} to a byte array.
+  *
+  * <p>The backing array of the buffer is returned directly only when it holds exactly the
+  * remaining bytes of the buffer; in every other case the remaining bytes are copied into a
+  * new array. No data type check is performed here.
+  *
+  * @return a byte array containing the remaining bytes of the buffer
+  */
+ public byte[] toByteArray() {
+     ByteBuffer bb = toByteBuffer();
+     // array() exposes the whole backing array and ignores arrayOffset, position and limit,
+     // so it is only a valid result when the buffer spans that array exactly.
+     if (bb.hasArray()
+             && bb.arrayOffset() == 0
+             && bb.position() == 0
+             && bb.remaining() == bb.array().length) {
+         return bb.array();
+     }
+     byte[] buf = new byte[bb.remaining()];
+     bb.get(buf);
+     return buf;
+ }
+
+ /**
+  * Converts this {@code NDArray} to a uint8 array, reading every byte as an unsigned value
+  * in the range 0 to 255.
+  *
+  * <p>No data type check is performed in this method itself.
+  *
+  * @return a uint8 array
+  */
+ public int[] toUint8Array() {
+     ByteBuffer bytes = toByteBuffer();
+     int count = bytes.remaining();
+     int[] unsigned = new int[count];
+     for (int i = 0; i < count; i++) {
+         unsigned[i] = Byte.toUnsignedInt(bytes.get());
+     }
+     return unsigned;
+ }
+
+ /**
+  * {@inheritDoc}
+  *
+  * <p>A closed array yields a fixed message; otherwise the debug string with the default
+  * limits is returned.
+  */
+ @Override
+ public String toString() {
+     return getClosed()
+             ? "This array is already closed"
+             : toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
+ }
+
+ /**
+  * Produces the debug string representation of this {@code NDArray} via {@code NDFormat}.
+  *
+  * @param maxSize the maximum elements to print out
+  * @param maxDepth the maximum depth to print out
+  * @param maxRows the maximum rows to print out
+  * @param maxColumns the maximum columns to print out
+  * @return the debug string representation of this {@code NDArray}
+  */
+ String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) {
+     String formatted = NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns);
+     return formatted;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
new file mode 100644
index 0000000..047ff04
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
@@ -0,0 +1,1107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+
+/** An internal helper class that encapsulates engine specific operations. */
+@SuppressWarnings("MissingJavadocMethod")
+public class NDArrayEx {
+
+ private static final NDArrayIndexer INDEXER = new NDArrayIndexer();
+
+ private NDArray array;
+
+ /**
+  * Constructs an {@code NDArrayEx} given a {@link NDArray}.
+  *
+  * @param parent the {@link NDArray} to extend
+  */
+ NDArrayEx(NDArray parent) {
+     array = parent;
+ }
+
+ // TODO only used to calculate zero-dim numpy shape
+ // remove it once MXNet have all the np op that we support
+ /**
+  * Derives the shape obtained by broadcasting two shapes against each other.
+  *
+  * <p>The shapes are aligned on their trailing dimensions and missing leading dimensions
+  * count as 1. Two aligned dimensions are compatible when they are equal or one of them is 1.
+  *
+  * @param lhs the left shape
+  * @param rhs the right shape
+  * @return the broadcast shape
+  * @throws IllegalArgumentException if a pair of aligned dimensions is incompatible
+  */
+ private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
+     int rank = Math.max(lhs.dimension(), rhs.dimension());
+     int lhsPad = rank - lhs.dimension();
+     int rhsPad = rank - rhs.dimension();
+     long[] dims = new long[rank];
+     for (int axis = 0; axis < rank; axis++) {
+         long l = axis < lhsPad ? 1 : lhs.get(axis - lhsPad);
+         long r = axis < rhsPad ? 1 : rhs.get(axis - rhsPad);
+         if (l != r && l != 1 && r != 1) {
+             throw new IllegalArgumentException(
+                     "operands could not be broadcast together with shapes "
+                             + lhs
+                             + " "
+                             + rhs);
+         }
+         // Either both sizes agree or one of them is 1; take the one that is not 1.
+         dims[axis] = (l == 1) ? r : l;
+     }
+     return new Shape(dims);
+ }
+
+ ////////////////////////////////////////
+ // MxNDArrays
+ ////////////////////////////////////////
+ /**
+  * Applies reverse division with a scalar - i.e., (n / thisArrayValues).
+  *
+  * @param n the value to use for reverse division
+  * @return a copy of the array after applying reverse division
+  */
+ public NDArray rdiv(Number n) {
+     String scalar = n.toString();
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", scalar);
+     return NDArray.invoke(getArray().getParent(), "_rdiv_scalar", array, scalarParams);
+ }
+
+ /**
+ * Applies reverse division with a scalar - i.e., (n / thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse division
+ * @return a copy of the array after applying reverse division
+ */
+ public NDArray rdiv(NDArray b) {
+ return b.div(array);
+ }
+
+ /**
+ * Applies in place reverse division - i.e., (n / thisArrayValues).
+ *
+ * @param n the value to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rdivi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ /**
+ * Applies in place reverse division - i.e., (n / thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rdivi(NDArray b) {
+ NDArray.invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null);
+ return array;
+ }
+
+ /**
+ * Applies reverse subtraction with duplicates - i.e., (n - thisArrayValues).
+ *
+ * @param n the value to use for reverse subtraction
+ * @return a copy of array after reverse subtraction
+ */
+ public NDArray rsub(Number n) {
+ return array.sub(n).neg();
+ }
+
+ /**
+ * Applies reverse subtraction with duplicates - i.e., (n - thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse subtraction
+ * @return a copy of the array after reverse subtraction
+ */
+ public NDArray rsub(NDArray b) {
+ return array.sub(b).neg();
+ }
+
+ /**
+ * Applies reverse subtraction in place - i.e., (n - thisArrayValues).
+ *
+ * @param n the value to use for reverse subtraction
+ * @return this array after reverse subtraction
+ */
+ public NDArray rsubi(Number n) {
+ return array.subi(n).negi();
+ }
+
+ /**
+ * Applies reverse subtraction in place - i.e., (n - thisArrayValues).
+ *
+ * @param b the ndarray to use for reverse subtraction
+ * @return this array after reverse subtraction
+ */
+ public NDArray rsubi(NDArray b) {
+ return array.subi(b).negi();
+ }
+
+ /**
+  * Applies reverse remainder of division with a scalar through {@code _npi_rmod_scalar}.
+  *
+  * @param n the value to use for reverse division
+  * @return a copy of array after applying reverse division
+  */
+ public NDArray rmod(Number n) {
+     String scalar = n.toString();
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", scalar);
+     return NDArray.invoke(getArray().getParent(), "_npi_rmod_scalar", array, scalarParams);
+ }
+
+ /**
+ * Applies reverse remainder of division with a scalar.
+ *
+ * @param b the value to use for reverse division
+ * @return a copy of array after applying reverse division
+ */
+ public NDArray rmod(NDArray b) {
+ return b.mod(array);
+ }
+
+ /**
+ * Applies in place reverse remainder of division with a scalar.
+ *
+ * @param n the value to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rmodi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ /**
+ * Applies in place reverse remainder of division.
+ *
+ * @param b the ndarray to use for reverse division
+ * @return this array after applying reverse division
+ */
+ public NDArray rmodi(NDArray b) {
+ NDArray.invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null);
+ return array;
+ }
+
+ /**
+  * Reverses the power of each element being raised in the {@code NDArray}.
+  *
+  * @param n the value to use for reverse power
+  * @return a copy of array after applying reverse power
+  */
+ public NDArray rpow(Number n) {
+     String scalar = n.toString();
+     OpParams scalarParams = new OpParams();
+     scalarParams.add("scalar", scalar);
+     return NDArray.invoke(getArray().getParent(), "_npi_rpower_scalar", array, scalarParams);
+ }
+
+ /**
+ * Reverses the power of each element being raised in the {@code NDArray} in place.
+ *
+ * @param n the value to use for reverse power
+ * @return a copy of array after applying reverse power
+ */
+ public NDArray rpowi(Number n) {
+ OpParams params = new OpParams();
+ params.add("scalar", n.toString());
+ NDArray.invoke("_npi_rpower_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+ return array;
+ }
+
+ ////////////////////////////////////////
+ // Activations
+ ////////////////////////////////////////
+ /**
+  * Computes rectified linear activation through {@code _npx_activation}.
+  *
+  * @return a copy of array after applying relu
+  */
+ public NDArray relu() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "relu");
+     return NDArray.invoke(getArray().getParent(), "_npx_activation", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_activation} with {@code act_type} set to {@code sigmoid}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray sigmoid() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "sigmoid");
+     return NDArray.invoke(getArray().getParent(), "_npx_activation", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_activation} with {@code act_type} set to {@code tanh}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray tanh() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "tanh");
+     return NDArray.invoke(getArray().getParent(), "_npx_activation", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_activation} with {@code act_type} set to {@code softrelu}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray softPlus() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "softrelu");
+     return NDArray.invoke(getArray().getParent(), "_npx_activation", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_activation} with {@code act_type} set to {@code softsign}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray softSign() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "softsign");
+     return NDArray.invoke(getArray().getParent(), "_npx_activation", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_leaky_relu} with {@code act_type} set to {@code leaky}.
+  *
+  * @param alpha the value passed to the operator as {@code slope}
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray leakyRelu(float alpha) {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "leaky");
+     activation.addParam("slope", alpha);
+     return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_leaky_relu} with {@code act_type} set to {@code elu}.
+  *
+  * @param alpha the value passed to the operator as {@code slope}
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray elu(float alpha) {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "elu");
+     activation.addParam("slope", alpha);
+     return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_leaky_relu} with {@code act_type} set to {@code selu}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray selu() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "selu");
+     return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, activation);
+ }
+
+ /**
+  * Applies {@code _npx_leaky_relu} with {@code act_type} set to {@code gelu}.
+  *
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray gelu() {
+     OpParams activation = new OpParams();
+     activation.addParam("act_type", "gelu");
+     return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, activation);
+ }
+
+ ////////////////////////////////////////
+ // Pooling Operations
+ ////////////////////////////////////////
+
+ /**
+  * Runs {@code _npx_pooling} with {@code pool_type} set to {@code max}.
+  *
+  * @param kernelShape the value passed as {@code kernel}
+  * @param stride the value passed as {@code stride}
+  * @param padding the value passed as {@code pad}
+  * @param ceilMode selects the {@code full} pooling convention when {@code true}, otherwise
+  *     {@code valid}
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+     String convention = ceilMode ? "full" : "valid";
+     OpParams poolParams = new OpParams();
+     poolParams.addParam("kernel", kernelShape);
+     poolParams.add("pool_type", "max");
+     poolParams.addParam("stride", stride);
+     poolParams.addParam("pad", padding);
+     poolParams.add("pooling_convention", convention);
+     return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), poolParams);
+ }
+
+ /**
+  * Runs {@code _npx_pooling} with {@code pool_type} set to {@code max} and
+  * {@code global_pool} enabled, then reshapes the result to its first two dimensions.
+  *
+  * @return the reshaped pooling result
+  */
+ public NDArray globalMaxPool() {
+     OpParams poolParams = new OpParams();
+     poolParams.add("kernel", getGlobalPoolingShapes(1));
+     poolParams.add("pad", getGlobalPoolingShapes(0));
+     poolParams.add("pool_type", "max");
+     poolParams.addParam("global_pool", true);
+     // The raw operator output is only needed for the reshape and is closed afterwards.
+     try (NDArray pooled =
+             NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), poolParams)) {
+         return pooled.reshape(pooled.getShape().size(0), pooled.getShape().size(1));
+     }
+ }
+
+ /**
+  * Runs {@code _npx_pooling} with {@code pool_type} set to {@code avg}.
+  *
+  * @param kernelShape the value passed as {@code kernel}
+  * @param stride the value passed as {@code stride}
+  * @param padding the value passed as {@code pad}
+  * @param ceilMode selects the {@code full} pooling convention when {@code true}, otherwise
+  *     {@code valid}
+  * @param countIncludePad the value passed as {@code count_include_pad}
+  * @return the {@code NDArray} produced by the operator
+  */
+ public NDArray avgPool(
+         Shape kernelShape,
+         Shape stride,
+         Shape padding,
+         boolean ceilMode,
+         boolean countIncludePad) {
+     String convention = ceilMode ? "full" : "valid";
+     OpParams poolParams = new OpParams();
+     poolParams.addParam("kernel", kernelShape);
+     poolParams.add("pool_type", "avg");
+     poolParams.addParam("stride", stride);
+     poolParams.addParam("pad", padding);
+     poolParams.add("pooling_convention", convention);
+     poolParams.addParam("count_include_pad", countIncludePad);
+     return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), poolParams);
+ }
+
+ /**
+  * Runs {@code _npx_pooling} with {@code pool_type} set to {@code avg} and
+  * {@code global_pool} enabled, then reshapes the result to its first two dimensions.
+  *
+  * @return the reshaped pooling result
+  */
+ public NDArray globalAvgPool() {
+     OpParams poolParams = new OpParams();
+     poolParams.add("kernel", getGlobalPoolingShapes(1));
+     poolParams.add("pad", getGlobalPoolingShapes(0));
+     poolParams.add("pool_type", "avg");
+     poolParams.addParam("global_pool", true);
+     // The raw operator output is only needed for the reshape and is closed afterwards.
+     try (NDArray pooled =
+             NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), poolParams)) {
+         return pooled.reshape(pooled.getShape().size(0), pooled.getShape().size(1));
+     }
+ }
+
+ /**
+  * Runs {@code _npx_pooling} with {@code pool_type} set to {@code lp}.
+  *
+  * @param normType the norm order; it must hold an integral value and is passed as
+  *     {@code p_value}
+  * @param kernelShape the value passed as {@code kernel}
+  * @param stride the value passed as {@code stride}
+  * @param padding the value passed as {@code pad}
+  * @param ceilMode selects the {@code full} pooling convention when {@code true}, otherwise
+  *     {@code valid}
+  * @return the {@code NDArray} produced by the operator
+  * @throws IllegalArgumentException if {@code normType} is not an integral value
+  */
+ public NDArray lpPool(
+         float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+     int order = (int) normType;
+     if (order != normType) {
+         throw new IllegalArgumentException(
+                 "float type of normType is not supported in MXNet engine, please use integer instead");
+     }
+     String convention = ceilMode ? "full" : "valid";
+     OpParams poolParams = new OpParams();
+     poolParams.addParam("p_value", order);
+     poolParams.addParam("kernel", kernelShape);
+     poolParams.add("pool_type", "lp");
+     poolParams.addParam("stride", stride);
+     poolParams.addParam("pad", padding);
+     poolParams.add("pooling_convention", convention);
+
+     return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), poolParams);
+ }
+
+ public NDArray globalLpPool(float normType) {
+ if (((int) normType) != normType) {
+ throw new IllegalArgumentException(
+ "float type of normType is not supported in MXNet engine, please use integer instead");
+ }
+ OpParams params = new OpParams();
+ params.add("pool_type", "lp");
+ params.addParam("p_value", (int) normType);
+ params.addParam("global_pool", true);
+ try (NDArray temp =
+ NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) {
+ return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+ }
+ }
+
+ ////////////////////////////////////////
+ // Optimizer
+ ////////////////////////////////////////
+
+ // public void adadeltaUpdate(
+ // MxNDList inputs,
+ // MxNDList weights,
+ // float weightDecay,
+ // float rescaleGrad,
+ // float clipGrad,
+ // float rho,
+ // float epsilon) {
+ // MxNDArray weight = inputs.get(0);
+ // MxNDArray grad = inputs.get(1);
+ // MxNDArray s = inputs.get(2);
+ // MxNDArray delta = inputs.get(3);
+ //
+ // // create a baseManager to close all intermediate MxNDArrays
+ // try (NDManager subManager = NDManager.newBaseManager()) {
+ // subManager.tempAttachAll(inputs, weights);
+ //
+ // // Preprocess Gradient
+ // grad.muli(rescaleGrad);
+ // if (clipGrad > 0) {
+ // grad = grad.clip(-clipGrad, clipGrad);
+ // }
+ // grad.addi(weight.mul(weightDecay));
+ //
+ // // Update s, g, and delta
+ // s.muli(rho).addi(grad.square().mul(1 - rho));
+ // MxNDArray g = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad);
+ // delta.muli(rho).addi(g.square().mul(1 - rho));
+ //
+ // // Update weight
+ // weight.subi(g);
+ // }
+ // }
+
+ /**
+  * Invokes the {@code adagrad_update} operator.
+  *
+  * @param inputs the source arrays handed to the operator
+  * @param weights the destination arrays the operator writes to
+  * @param learningRate the value passed as {@code lr}
+  * @param weightDecay the value passed as {@code wd}
+  * @param rescaleGrad the value passed as {@code rescale_grad}
+  * @param clipGrad the value passed as {@code clip_gradient}
+  * @param epsilon the value passed as {@code epsilon}
+  */
+ public void adagradUpdate(
+         NDList inputs,
+         NDList weights,
+         float learningRate,
+         float weightDecay,
+         float rescaleGrad,
+         float clipGrad,
+         float epsilon) {
+     OpParams hyperParams = new OpParams();
+     // Settings shared by all optimizer updates.
+     hyperParams.addParam("lr", learningRate);
+     hyperParams.addParam("wd", weightDecay);
+     hyperParams.addParam("rescale_grad", rescaleGrad);
+     hyperParams.addParam("clip_gradient", clipGrad);
+     // Adagrad specific setting.
+     hyperParams.addParam("epsilon", epsilon);
+
+     NDArray.invoke("adagrad_update", inputs, weights, hyperParams);
+ }
+
+ /**
+  * Invokes the {@code adam_update} operator.
+  *
+  * @param inputs the source arrays handed to the operator
+  * @param weights the destination arrays the operator writes to
+  * @param learningRate the value passed as {@code lr}
+  * @param weightDecay the value passed as {@code wd}
+  * @param rescaleGrad the value passed as {@code rescale_grad}
+  * @param clipGrad the value passed as {@code clip_gradient}
+  * @param beta1 the value passed as {@code beta1}
+  * @param beta2 the value passed as {@code beta2}
+  * @param epsilon the value passed as {@code epsilon}
+  * @param lazyUpdate the value passed as {@code lazy_update}
+  */
+ public void adamUpdate(
+         NDList inputs,
+         NDList weights,
+         float learningRate,
+         float weightDecay,
+         float rescaleGrad,
+         float clipGrad,
+         float beta1,
+         float beta2,
+         float epsilon,
+         boolean lazyUpdate) {
+     OpParams hyperParams = new OpParams();
+     // Settings shared by all optimizer updates.
+     hyperParams.addParam("lr", learningRate);
+     hyperParams.addParam("wd", weightDecay);
+     hyperParams.addParam("rescale_grad", rescaleGrad);
+     hyperParams.addParam("clip_gradient", clipGrad);
+     // Adam specific settings.
+     hyperParams.addParam("beta1", beta1);
+     hyperParams.addParam("beta2", beta2);
+     hyperParams.addParam("epsilon", epsilon);
+     hyperParams.addParam("lazy_update", lazyUpdate);
+
+     NDArray.invoke("adam_update", inputs, weights, hyperParams);
+ }
+
+ /**
+  * Invokes {@code rmsprop_update}, or {@code rmspropalex_update} when {@code centered} is
+  * set.
+  *
+  * @param inputs the source arrays handed to the operator
+  * @param weights the destination arrays the operator writes to
+  * @param learningRate the value passed as {@code lr}
+  * @param weightDecay the value passed as {@code wd}
+  * @param rescaleGrad the value passed as {@code rescale_grad}
+  * @param clipGrad the value passed as {@code clip_gradient}
+  * @param gamma1 the value passed as {@code gamma1}
+  * @param gamma2 the value passed as {@code gamma2}; only used when {@code centered} is set
+  * @param epsilon the value passed as {@code epsilon}
+  * @param centered whether to use the {@code rmspropalex_update} operator
+  */
+ public void rmspropUpdate(
+         NDList inputs,
+         NDList weights,
+         float learningRate,
+         float weightDecay,
+         float rescaleGrad,
+         float clipGrad,
+         float gamma1,
+         float gamma2,
+         float epsilon,
+         boolean centered) {
+     OpParams hyperParams = new OpParams();
+     hyperParams.addParam("lr", learningRate);
+     hyperParams.addParam("wd", weightDecay);
+     hyperParams.addParam("rescale_grad", rescaleGrad);
+     hyperParams.addParam("clip_gradient", clipGrad);
+
+     hyperParams.addParam("gamma1", gamma1);
+     hyperParams.addParam("epsilon", epsilon);
+
+     String operation = "rmsprop_update";
+     if (centered) {
+         // The centered variant takes one more parameter and uses a different operator.
+         hyperParams.addParam("gamma2", gamma2);
+         operation = "rmspropalex_update";
+     }
+     NDArray.invoke(operation, inputs, weights, hyperParams);
+ }
+
+ /**
+  * Invokes the {@code nag_mom_update} operator.
+  *
+  * @param inputs the source arrays handed to the operator
+  * @param weights the destination arrays the operator writes to
+  * @param learningRate the value passed as {@code lr}
+  * @param weightDecay the value passed as {@code wd}
+  * @param rescaleGrad the value passed as {@code rescale_grad}
+  * @param clipGrad the value passed as {@code clip_gradient}
+  * @param momentum the value passed as {@code momentum}
+  */
+ public void nagUpdate(
+         NDList inputs,
+         NDList weights,
+         float learningRate,
+         float weightDecay,
+         float rescaleGrad,
+         float clipGrad,
+         float momentum) {
+     OpParams hyperParams = new OpParams();
+     // Settings shared by all optimizer updates.
+     hyperParams.addParam("lr", learningRate);
+     hyperParams.addParam("wd", weightDecay);
+     hyperParams.addParam("rescale_grad", rescaleGrad);
+     hyperParams.addParam("clip_gradient", clipGrad);
+     // NAG specific setting.
+     hyperParams.addParam("momentum", momentum);
+     NDArray.invoke("nag_mom_update", inputs, weights, hyperParams);
+ }
+
+ /**
+  * Invokes the backend SGD update operator.
+  *
+  * <p>When {@code momentum} differs from zero it is added to the parameters and
+  * {@code sgd_mom_update} is used; otherwise {@code sgd_update} is used.
+  *
+  * @param inputs the arrays handed to {@code NDArray.invoke} as operator inputs
+  * @param weights the arrays handed to {@code NDArray.invoke} right after {@code inputs}
+  * @param learningRate value sent as {@code lr}
+  * @param weightDecay value sent as {@code wd}
+  * @param rescaleGrad value sent as {@code rescale_grad}
+  * @param clipGrad value sent as {@code clip_gradient}
+  * @param momentum value sent as {@code momentum} when it is not zero
+  * @param lazyUpdate value sent as {@code lazy_update}
+  */
+ public void sgdUpdate(
+         NDList inputs,
+         NDList weights,
+         float learningRate,
+         float weightDecay,
+         float rescaleGrad,
+         float clipGrad,
+         float momentum,
+         boolean lazyUpdate) {
+     OpParams opParams = new OpParams();
+     opParams.addParam("lr", learningRate);
+     opParams.addParam("wd", weightDecay);
+     opParams.addParam("rescale_grad", rescaleGrad);
+     opParams.addParam("clip_gradient", clipGrad);
+     opParams.addParam("lazy_update", lazyUpdate);
+
+     boolean withMomentum = momentum != 0;
+     if (withMomentum) {
+         opParams.addParam("momentum", momentum);
+     }
+     String operator = withMomentum ? "sgd_mom_update" : "sgd_update";
+     NDArray.invoke(operator, inputs, weights, opParams);
+ }
+
+ ////////////////////////////////////////
+ // Neural network
+ ////////////////////////////////////////
+
+ /**
+  * Invokes the backend {@code _npx_convolution} operator.
+  *
+  * <p>{@code kernel} is taken from the weight shape starting at dimension 2 and
+  * {@code num_filter} from dimension 0 of the weight shape.
+  *
+  * @param input the first operator input
+  * @param weight the second operator input
+  * @param bias the optional third operator input; when {@code null}, {@code no_bias} is set
+  *     to {@code true} and no bias array is passed
+  * @param stride value sent as {@code stride}
+  * @param padding value sent as {@code pad}
+  * @param dilation value sent as {@code dilate}
+  * @param groups value sent as {@code num_group}
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList convolution(
+         NDArray input,
+         NDArray weight,
+         NDArray bias,
+         Shape stride,
+         Shape padding,
+         Shape dilation,
+         int groups) {
+     Shape weightShape = weight.getShape();
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("kernel", weightShape.slice(2));
+     opParams.addParam("stride", stride);
+     opParams.addParam("pad", padding);
+     opParams.addParam("dilate", dilation);
+     opParams.addParam("num_group", groups);
+     opParams.addParam("num_filter", weightShape.get(0));
+
+     NDList operands = new NDList(input, weight);
+     boolean noBias = bias == null;
+     opParams.add("no_bias", noBias);
+     if (!noBias) {
+         operands.add(bias);
+     }
+
+     return NDArray.invoke(getArray().getParent(), "_npx_convolution", operands, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code _npx_deconvolution} operator.
+  *
+  * <p>{@code kernel} is taken from the weight shape starting at dimension 2. The MXNet
+  * deconvolution weight is laid out as {@code (in_channels, num_filter / num_group, kernel...)},
+  * so {@code num_filter} is derived as {@code weight.shape[1] * groups} rather than from
+  * dimension 0 (which holds the input channel count).
+  *
+  * @param input the first operator input
+  * @param weight the second operator input
+  * @param bias the optional third operator input; when {@code null}, {@code no_bias} is set
+  *     to {@code true} and no bias array is passed
+  * @param stride value sent as {@code stride}
+  * @param padding value sent as {@code pad}
+  * @param outPadding value sent as {@code adj}
+  * @param dilation value sent as {@code dilate}
+  * @param groups value sent as {@code num_group}
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList deconvolution(
+         NDArray input,
+         NDArray weight,
+         NDArray bias,
+         Shape stride,
+         Shape padding,
+         Shape outPadding,
+         Shape dilation,
+         int groups) {
+     OpParams params = new OpParams();
+     params.addParam("kernel", weight.getShape().slice(2));
+     params.addParam("stride", stride);
+     params.addParam("pad", padding);
+     params.addParam("adj", outPadding);
+     params.addParam("dilate", dilation);
+     params.addParam("num_group", groups);
+     // dimension 1 holds num_filter / num_group for a deconvolution weight
+     params.addParam("num_filter", weight.getShape().get(1) * groups);
+
+     NDList inputs = new NDList(input, weight);
+     if (bias != null) {
+         params.add("no_bias", false);
+         inputs.add(bias);
+     } else {
+         params.add("no_bias", true);
+     }
+
+     return NDArray.invoke(getArray().getParent(), "_npx_deconvolution", inputs, params);
+ }
+
+ public NDList linear(NDArray input, NDArray weight, NDArray bias) {
+ OpParams params = new OpParams();
+ params.addParam("num_hidden", weight.size(0));
+ params.addParam("flatten", false);
+ params.addParam("no_bias", bias == null);
+ NDList inputs = new NDList(input, weight);
+ if (bias != null) {
+ inputs.add(bias);
+ }
+
+ return NDArray.invoke(getArray().getParent(), "_npx_fully_connected", inputs, params);
+ }
+
+ /**
+  * Invokes the backend {@code _npx_embedding} operator.
+  *
+  * @param input the first operator input
+  * @param weight the second operator input; dimensions 0 and 1 of its shape are sent as
+  *     {@code input_dim} and {@code output_dim}
+  * @param sparse the requested format; only {@code DENSE} and {@code ROW_SPARSE} are accepted
+  * @return the list returned by {@code NDArray.invoke}
+  * @throws IllegalArgumentException if {@code sparse} is neither {@code DENSE} nor
+  *     {@code ROW_SPARSE}
+  */
+ public NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) {
+     boolean supported =
+             sparse.equals(SparseFormat.DENSE) || sparse.equals(SparseFormat.ROW_SPARSE);
+     if (!supported) {
+         throw new IllegalArgumentException("MXNet only supports row sparse");
+     }
+
+     OpParams opParams = new OpParams();
+     long rows = weight.getShape().get(0);
+     long columns = weight.getShape().get(1);
+     opParams.addParam("input_dim", rows);
+     opParams.addParam("output_dim", columns);
+     opParams.addParam("sparse_grad", sparse.getValue());
+
+     NDList operands = new NDList(input, weight);
+     return NDArray.invoke(getArray().getParent(), "_npx_embedding", operands, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code _npx_leaky_relu} operator with {@code act_type} set to
+  * {@code prelu}.
+  *
+  * @param input the first operator input
+  * @param alpha the second operator input
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList prelu(NDArray input, NDArray alpha) {
+     OpParams opParams = new OpParams();
+     opParams.addParam("act_type", "prelu");
+
+     NDList operands = new NDList(input, alpha);
+     return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", operands, opParams);
+ }
+
+ public NDList dropout(NDArray input, float rate, boolean training) {
+ if (training != JnaUtils.autogradIsTraining()) {
+ throw new IllegalArgumentException(
+ "the mode of dropout in MXNet should align with the mode of GradientCollector");
+ }
+
+ OpParams params = new OpParams();
+ params.addParam("p", rate);
+
+ return NDArray.invoke(getArray().getParent(), "_npx_dropout", new NDList(input), params);
+ }
+
+ /**
+  * Invokes the backend {@code _npx_batch_norm} operator.
+  *
+  * <p>The operator inputs are passed in the order {@code input, gamma, beta, runningMean,
+  * runningVar}.
+  *
+  * @param input the first operator input
+  * @param runningMean the fourth operator input
+  * @param runningVar the fifth operator input
+  * @param gamma the second operator input; {@code fix_gamma} is {@code true} when it is
+  *     {@code null}
+  * @param beta the third operator input
+  * @param axis value sent as {@code axis}
+  * @param momentum value sent as {@code momentum}
+  * @param eps value sent as {@code eps}
+  * @param training must equal the value reported by {@code JnaUtils.autogradIsTraining()}
+  * @return the list returned by {@code NDArray.invoke}
+  * @throws IllegalArgumentException if {@code training} differs from
+  *     {@code JnaUtils.autogradIsTraining()}
+  */
+ public NDList batchNorm(
+         NDArray input,
+         NDArray runningMean,
+         NDArray runningVar,
+         NDArray gamma,
+         NDArray beta,
+         int axis,
+         float momentum,
+         float eps,
+         boolean training) {
+     // guard: the requested mode has to match the backend autograd mode
+     if (training != JnaUtils.autogradIsTraining()) {
+         throw new IllegalArgumentException(
+                 "the mode of batchNorm in MXNet should align with the mode of GradientCollector");
+     }
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("axis", axis);
+     opParams.addParam("fix_gamma", gamma == null);
+     opParams.addParam("eps", eps);
+     opParams.addParam("momentum", momentum);
+
+     NDList operands = new NDList(input, gamma, beta, runningMean, runningVar);
+     return NDArray.invoke(getArray().getParent(), "_npx_batch_norm", operands, opParams);
+ }
+
+ // public MxNDList rnn(
+ // MxNDArray input,
+ // MxNDArray state,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // RNN.Activation activation,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of rnn in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", state.getShape().tail());
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("mode", activation == RNN.Activation.TANH ? "rnn_tanh" :
+ // "rnn_relu");
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ //
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ //
+ // inputs.add(state);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+ // }
+ // }
+
+ // public MxNDList gru(
+ // MxNDArray input,
+ // MxNDArray state,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of gru in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", state.getShape().tail());
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("mode", "gru");
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ //
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ //
+ // inputs.add(state);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+ // }
+ // }
+ //
+ // public MxNDList lstm(
+ // MxNDArray input,
+ // MxNDList states,
+ // MxNDList params,
+ // boolean hasBiases,
+ // int numLayers,
+ // double dropRate,
+ // boolean training,
+ // boolean bidirectional,
+ // boolean batchFirst) {
+ // int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+ // Preconditions.checkArgument(
+ // params.size() == numParams,
+ // "The size of Params is incorrect expect "
+ // + numParams
+ // + " parameters but got "
+ // + params.size());
+ //
+ // if (training != JnaUtils.autogradIsTraining()) {
+ // throw new IllegalArgumentException(
+ // "the mode of lstm in MXNet should align with the mode of
+ // GradientCollector");
+ // }
+ //
+ // if (batchFirst) {
+ // input = input.swapAxes(0, 1);
+ // }
+ //
+ // MxOpParams opParams = new MxOpParams();
+ // opParams.addParam("mode", "lstm");
+ // opParams.addParam("p", dropRate);
+ // opParams.addParam("state_size", states.head().getShape().tail());
+ // opParams.addParam("state_outputs", true);
+ // opParams.addParam("num_layers", numLayers);
+ // opParams.addParam("bidirectional", bidirectional);
+ // opParams.addParam("lstm_state_clip_nan", true);
+ //
+ // MxNDList inputs = new MxNDList();
+ // inputs.add(input);
+ // try (MxNDList temp = new MxNDList()) {
+ // for (MxNDArray param : params) {
+ // temp.add(param.flatten());
+ // }
+ // MxNDArray tempParam = MxNDArrays.concat(temp);
+ // tempParam.attach(input.getManager());
+ // inputs.add(tempParam);
+ // }
+ // inputs.addAll(states);
+ //
+ // if (!batchFirst) {
+ // return getManager().invoke("_npx_rnn", inputs, opParams);
+ // }
+ //
+ // MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+ // try (MxNDArray temp = result.head()) {
+ // return new MxNDList(temp.swapAxes(0, 1), result.get(1), result.get(2));
+ // }
+ // }
+
+ ////////////////////////////////////////
+ // Image and CV
+ ////////////////////////////////////////
+
+ /**
+  * Invokes the backend {@code _npx__image_normalize} operator on the wrapped array.
+  *
+  * @param mean values sent as the {@code mean} tuple
+  * @param std values sent as the {@code std} tuple
+  * @return the array returned by {@code NDArray.invoke}
+  */
+ public NDArray normalize(float[] mean, float[] std) {
+     OpParams opParams = new OpParams();
+
+     opParams.addTupleParam("mean", mean);
+     opParams.addTupleParam("std", std);
+
+     return NDArray.invoke(getArray().getParent(), "_npx__image_normalize", array, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code _npx__image_to_tensor} operator on the wrapped array, without
+  * any operator parameters.
+  *
+  * @return the array returned by {@code NDArray.invoke}
+  */
+ public NDArray toTensor() {
+     final String operator = "_npx__image_to_tensor";
+     return NDArray.invoke(getArray().getParent(), operator, array, null);
+ }
+
+ public NDArray resize(int width, int height, int interpolation) {
+ if (array.isEmpty()) {
+ throw new IllegalArgumentException("attempt to resize of an empty MxNDArray");
+ }
+ OpParams params = new OpParams();
+ params.addTupleParam("size", width, height);
+ params.addParam("interp", interpolation);
+ return NDArray.invoke(getArray().getParent(), "_npx__image_resize", array, params);
+ }
+
+ /**
+  * Invokes the backend {@code _npx__image_crop} operator on the wrapped array.
+  *
+  * @param x value sent as {@code x}
+  * @param y value sent as {@code y}
+  * @param width value sent as {@code width}
+  * @param height value sent as {@code height}
+  * @return the array returned by {@code NDArray.invoke}
+  */
+ public NDArray crop(int x, int y, int width, int height) {
+     OpParams opParams = new OpParams();
+
+     // origin of the crop region
+     opParams.add("x", x);
+     opParams.add("y", y);
+     // extent of the crop region
+     opParams.add("width", width);
+     opParams.add("height", height);
+
+     return NDArray.invoke(getArray().getParent(), "_npx__image_crop", array, opParams);
+ }
+
+ public NDArray randomFlipLeftRight() {
+ if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+ throw new UnsupportedOperationException("randomFlipLeftRight is not supported on GPU");
+ }
+ return NDArray.invoke(
+ getArray().getParent(), "_npx__image_random_flip_left_right", array, null);
+ }
+
+ public NDArray randomFlipTopBottom() {
+ if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+ throw new UnsupportedOperationException("randomFlipTopBottom is not supported on GPU");
+ }
+ return NDArray.invoke(
+ getArray().getParent(), "_npx__image_random_flip_top_bottom", array, null);
+ }
+
+ /**
+  * Invokes the backend {@code _npx__image_random_brightness} operator on the wrapped array.
+  *
+  * <p>{@code min_factor} is {@code max(0, 1 - brightness)} and {@code max_factor} is
+  * {@code 1 + brightness}.
+  *
+  * @param brightness the amount used to derive both factors
+  * @return the array returned by {@code NDArray.invoke}
+  * @throws UnsupportedOperationException if the wrapped array lives on a GPU device
+  */
+ public NDArray randomBrightness(float brightness) {
+     boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+     if (onGpu) {
+         throw new UnsupportedOperationException("randomBrightness is not supported on GPU");
+     }
+
+     float lowerFactor = Math.max(0, 1 - brightness);
+     float upperFactor = 1 + brightness;
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("min_factor", lowerFactor);
+     opParams.addParam("max_factor", upperFactor);
+
+     return NDArray.invoke(
+             getArray().getParent(), "_npx__image_random_brightness", array, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code _npx__image_random_hue} operator on the wrapped array.
+  *
+  * <p>{@code min_factor} is {@code max(0, 1 - hue)} and {@code max_factor} is
+  * {@code 1 + hue}.
+  *
+  * @param hue the amount used to derive both factors
+  * @return the array returned by {@code NDArray.invoke}
+  * @throws UnsupportedOperationException if the wrapped array lives on a GPU device
+  */
+ public NDArray randomHue(float hue) {
+     boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+     if (onGpu) {
+         throw new UnsupportedOperationException("randomHue is not supported on GPU");
+     }
+
+     float lowerFactor = Math.max(0, 1 - hue);
+     float upperFactor = 1 + hue;
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("min_factor", lowerFactor);
+     opParams.addParam("max_factor", upperFactor);
+
+     return NDArray.invoke(getArray().getParent(), "_npx__image_random_hue", array, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code _npx__image_random_color_jitter} operator on the wrapped array.
+  *
+  * @param brightness value sent as {@code brightness}
+  * @param contrast value sent as {@code contrast}
+  * @param saturation value sent as {@code saturation}
+  * @param hue value sent as {@code hue}
+  * @return the array returned by {@code NDArray.invoke}
+  * @throws UnsupportedOperationException if the wrapped array lives on a GPU device
+  */
+ public NDArray randomColorJitter(
+         float brightness, float contrast, float saturation, float hue) {
+     boolean onGpu = array.getDevice().getDeviceType().equals(Device.Type.GPU);
+     if (onGpu) {
+         throw new UnsupportedOperationException("randomColorJitter is not supported on GPU");
+     }
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("brightness", brightness);
+     opParams.addParam("contrast", contrast);
+     opParams.addParam("saturation", saturation);
+     opParams.addParam("hue", hue);
+
+     final String operator = "_npx__image_random_color_jitter";
+     return NDArray.invoke(getArray().getParent(), operator, array, opParams);
+ }
+
+ /**
+  * Returns the {@code INDEXER} instance held by this class.
+  *
+  * @return the {@link NDArrayIndexer} stored in {@code INDEXER}
+  */
+ public NDArrayIndexer getIndexer() {
+ return INDEXER;
+ }
+
+ ////////////////////////////////////////
+ // Miscellaneous
+ ////////////////////////////////////////
+
+ /**
+  * Invokes the backend {@code where} operator with {@code condition}, the wrapped array and
+  * {@code other} (in that order).
+  *
+  * <p>A {@code BOOLEAN} condition is converted to {@code INT32} first. When the wrapped array
+  * and {@code other} differ in shape, both are broadcast to the shape computed by {@code
+  * deriveBroadcastedShape}. Every temporary array created here (converted condition,
+  * broadcast operands) is closed before returning, including when a step fails.
+  *
+  * @param condition the selector array
+  * @param other the array paired with the wrapped array; must have the same data type
+  * @return the array returned by {@code NDArray.invoke}
+  * @throws IllegalArgumentException if the data types of the wrapped array and {@code other}
+  *     differ
+  */
+ @SuppressWarnings("PMD.UseTryWithResources")
+ public NDArray where(NDArray condition, NDArray other) {
+     NDArray mask =
+             (condition.getDataType() == DataType.BOOLEAN)
+                     ? condition.toType(DataType.INT32, false)
+                     : condition;
+     NDArray array1 = array;
+     NDArray array2 = other;
+     try {
+         if (array.getDataType() != other.getDataType()) {
+             throw new IllegalArgumentException(
+                     "DataType mismatch, required "
+                             + array.getDataType()
+                             + " actual "
+                             + other.getDataType());
+         }
+         if (!array.shapeEquals(other)) {
+             Shape res = deriveBroadcastedShape(array.getShape(), other.getShape());
+             if (!res.equals(array.getShape())) {
+                 array1 = array.broadcast(res);
+             }
+             if (!res.equals(other.getShape())) {
+                 array2 = other.broadcast(res);
+             }
+         }
+         return NDArray.invoke(
+                 getArray().getParent(),
+                 "where",
+                 new NDArray[] {mask, array1, array2},
+                 null);
+     } finally {
+         // only close arrays created in this method, never the caller's arguments
+         if (mask != condition) {
+             mask.close();
+         }
+         if (array1 != array) {
+             array1.close();
+         }
+         if (array2 != other) {
+             array2.close();
+         }
+     }
+ }
+
+ /**
+  * Invokes the backend {@code _npi_stack} operator on the wrapped array followed by the
+  * arrays of {@code arrays}.
+  *
+  * @param arrays the arrays placed after the wrapped array in the operator inputs
+  * @param axis value sent as {@code axis}
+  * @return the array returned by {@code NDArray.invoke}
+  */
+ public NDArray stack(NDList arrays, int axis) {
+     // operator inputs: the wrapped array first, then the given arrays in order
+     NDArray[] others = arrays.toArray(new NDArray[0]);
+     NDArray[] operands = new NDArray[arrays.size() + 1];
+     operands[0] = array;
+     System.arraycopy(others, 0, operands, 1, arrays.size());
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("axis", axis);
+     return NDArray.invoke(getArray().getParent(), "_npi_stack", operands, opParams);
+ }
+
+ /**
+  * Validates the input of a concat operation.
+  *
+  * <p>Three criteria are checked: 1. the list is not empty, 2. the arrays are not all scalars
+  * (zero-dimensional), 3. every array has the same number of dimensions as the first one.
+  *
+  * @param list input {@link NDList}
+  * @throws IllegalArgumentException if any of the criteria is violated
+  */
+ public static void checkConcatInput(NDList list) {
+     NDArray[] arrays = list.toArray(new NDArray[0]);
+     // allMatch is vacuously true on an empty stream, so report the empty case explicitly
+     // instead of blaming non-existent scalar arrays
+     if (arrays.length == 0) {
+         throw new IllegalArgumentException("at least one array is required for concatenation");
+     }
+     if (Stream.of(arrays).allMatch(array -> array.getShape().dimension() == 0)) {
+         throw new IllegalArgumentException(
+                 "scalar(zero-dimensional) arrays cannot be concatenated");
+     }
+     int dimension = arrays[0].getShape().dimension();
+     for (int i = 1; i < arrays.length; i++) {
+         int current = arrays[i].getShape().dimension();
+         if (current != dimension) {
+             throw new IllegalArgumentException(
+                     "all the input arrays must have same number of dimensions, but the array at index 0 has "
+                             + dimension
+                             + " dimension(s) and the array at index "
+                             + i
+                             + " has "
+                             + current
+                             + " dimension(s)");
+         }
+     }
+ }
+
+ /**
+  * Invokes the backend {@code _npi_concatenate} operator on the wrapped array followed by
+  * the arrays of {@code list}.
+  *
+  * <p>{@code list} is validated with {@code checkConcatInput} before the operator is invoked.
+  *
+  * @param list the arrays placed after the wrapped array in the operator inputs
+  * @param axis value sent as {@code axis}
+  * @return the array returned by {@code NDArray.invoke}
+  */
+ public NDArray concat(NDList list, int axis) {
+     checkConcatInput(list);
+
+     // operator inputs: the wrapped array first, then the given arrays in order
+     NDArray[] others = list.toArray(new NDArray[0]);
+     NDArray[] operands = new NDArray[list.size() + 1];
+     operands[0] = array;
+     System.arraycopy(others, 0, operands, 1, list.size());
+
+     OpParams opParams = new OpParams();
+     opParams.addParam("axis", axis);
+     return NDArray.invoke(getArray().getParent(), "_npi_concatenate", operands, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code MultiBoxTarget} operator.
+  *
+  * @param inputs the operator inputs
+  * @param iouThreshold value sent as {@code overlap_threshold}
+  * @param ignoreLabel value sent as {@code ignore_label}
+  * @param negativeMiningRatio value sent as {@code negative_mining_ratio}
+  * @param negativeMiningThreshold value sent as {@code negative_mining_thresh}
+  * @param minNegativeSamples value sent as {@code minimum_negative_samples}
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList multiBoxTarget(
+         NDList inputs,
+         float iouThreshold,
+         float ignoreLabel,
+         float negativeMiningRatio,
+         float negativeMiningThreshold,
+         int minNegativeSamples) {
+     OpParams opParams = new OpParams();
+
+     opParams.add("minimum_negative_samples", minNegativeSamples);
+     opParams.add("overlap_threshold", iouThreshold);
+     opParams.add("ignore_label", ignoreLabel);
+     opParams.add("negative_mining_ratio", negativeMiningRatio);
+     opParams.add("negative_mining_thresh", negativeMiningThreshold);
+
+     final String operator = "MultiBoxTarget";
+     return NDArray.invoke(getArray().getParent(), operator, inputs, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code MultiBoxPrior} operator on the wrapped array.
+  *
+  * @param sizes value sent as {@code sizes}
+  * @param ratios value sent as {@code ratios}
+  * @param steps value sent as {@code steps}
+  * @param offsets value sent as {@code offsets}
+  * @param clip value sent as {@code clip}
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList multiBoxPrior(
+         List<Float> sizes,
+         List<Float> ratios,
+         List<Float> steps,
+         List<Float> offsets,
+         boolean clip) {
+     OpParams opParams = new OpParams();
+
+     opParams.add("sizes", sizes);
+     opParams.add("ratios", ratios);
+     opParams.add("steps", steps);
+     opParams.add("offsets", offsets);
+     opParams.add("clip", clip);
+
+     NDList operands = new NDList(array);
+     return NDArray.invoke(getArray().getParent(), "MultiBoxPrior", operands, opParams);
+ }
+
+ /**
+  * Invokes the backend {@code MultiBoxDetection} operator.
+  *
+  * @param inputs the operator inputs
+  * @param clip value sent as {@code clip}
+  * @param threshold value sent as {@code threshold}
+  * @param backgroundId value sent as {@code background_id}
+  * @param nmsThreashold value sent as {@code nms_threshold}
+  * @param forceSuppress value sent as {@code force_suppress}
+  * @param nmsTopK value sent as {@code nms_topk}
+  * @return the list returned by {@code NDArray.invoke}
+  */
+ public NDList multiBoxDetection(
+         NDList inputs,
+         boolean clip,
+         float threshold,
+         int backgroundId,
+         float nmsThreashold,
+         boolean forceSuppress,
+         int nmsTopK) {
+     OpParams opParams = new OpParams();
+
+     opParams.add("clip", clip);
+     opParams.add("threshold", threshold);
+     opParams.add("background_id", backgroundId);
+     opParams.add("nms_threshold", nmsThreashold);
+     opParams.add("force_suppress", forceSuppress);
+     opParams.add("nms_topk", nmsTopK);
+
+     final String operator = "MultiBoxDetection";
+     return NDArray.invoke(getArray().getParent(), operator, inputs, opParams);
+ }
+
+ /**
+  * Returns the array stored in the {@code array} field of this object.
+  *
+  * @return the wrapped {@link NDArray}
+  */
+ public NDArray getArray() {
+     return this.array;
+ }
+
+ /**
+  * Computes the number of pooling dimensions of the wrapped array: its number of dimensions
+  * minus two.
+  *
+  * @return the pooling dimension count, between 1 and 3 inclusive
+  * @throws IllegalStateException if the computed count is outside 1 to 3
+  */
+ private int getGlobalPoolingDim() {
+     int poolDim = getArray().getShape().dimension() - 2;
+     if (poolDim < 1 || poolDim > 3) {
+         // the original literals were concatenated without a separating space
+         throw new IllegalStateException(
+                 "GlobalPooling only support "
+                         + "1 to 3 Dimensions, "
+                         + poolDim
+                         + "D is not supported.");
+     }
+     return poolDim;
+ }
+
+ /**
+  * Builds a {@link Shape} with one entry per pooling dimension (as reported by {@code
+  * getGlobalPoolingDim()}), each entry set to {@code fillValue}.
+  *
+  * @param fillValue the value stored in every entry of the shape
+  * @return the constant-filled shape
+  */
+ private Shape getGlobalPoolingShapes(long fillValue) {
+     // one entry per pooling dimension: input dimensions minus 2
+     long[] dims = new long[getGlobalPoolingDim()];
+     for (int i = 0; i < dims.length; i++) {
+         dims[i] = fillValue;
+     }
+     return new Shape(dims);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
new file mode 100644
index 0000000..ee93fd7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Stack;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.ndarray.dim.NDIndexBooleans;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullPick;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullSlice;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */
+public class NDArrayIndexer {
+
+    /**
+     * Returns a subarray by picking the elements.
+     *
+     * <p>A rank-0 index on a scalar array returns a duplicate of the array. A leading boolean
+     * index is delegated to {@code booleanMask}. Otherwise the index is resolved as a full pick
+     * first and as a full slice second.
+     *
+     * @param array the array to get from
+     * @param index the index to get
+     * @return the subArray
+     * @throws IllegalArgumentException if a boolean index is combined with other indices
+     * @throws UnsupportedOperationException if the index is neither a pick nor a slice
+     */
+    public NDArray get(NDArray array, NDIndex index) {
+        if (index.getRank() == 0 && array.getShape().isScalar()) {
+            return array.duplicate();
+        }
+
+        // use booleanMask for NDIndexBooleans case
+        List<NDIndexElement> indices = index.getIndices();
+        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
+            if (indices.size() != 1) {
+                throw new IllegalArgumentException(
+                        "get() currently didn't support more than one boolean NDArray");
+            }
+            return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
+        }
+
+        Optional<NDIndexFullPick> fullPick = NDIndexFullPick.fromIndex(index, array.getShape());
+        if (fullPick.isPresent()) {
+            return get(array, fullPick.get());
+        }
+
+        Optional<NDIndexFullSlice> fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape());
+        if (fullSlice.isPresent()) {
+            return get(array, fullSlice.get());
+        }
+        throw new UnsupportedOperationException(
+                "get() currently supports all, fixed, and slices indices");
+    }
+
+    /**
+     * Returns a subarray by picking the elements.
+     *
+     * @param array the array to get from
+     * @param fullPick the elements to pick
+     * @return the subArray
+     */
+    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
+        OpParams params = new OpParams();
+        params.addParam("axis", fullPick.getAxis());
+        params.addParam("keepdims", true);
+        params.add("mode", "wrap");
+        return NDArray.invoke(
+                        array.getParent(), "pick", new NDList(array, fullPick.getIndices()), params)
+                .singletonOrThrow();
+    }
+
+    /**
+     * Returns a subarray at the slice.
+     *
+     * <p>When the slice asks for axes to be squeezed, the intermediate (unsqueezed) array is
+     * closed before returning, also when squeezing fails.
+     *
+     * @param array the array to get from
+     * @param fullSlice the fullSlice index of the array
+     * @return the subArray
+     */
+    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
+        OpParams params = sliceParams(fullSlice);
+
+        NDArray sliced = NDArray.invoke(array.getParent(), "_npi_slice", array, params);
+        int[] toSqueeze = fullSlice.getToSqueeze();
+        if (toSqueeze.length == 0) {
+            return sliced;
+        }
+        try {
+            return sliced.squeeze(toSqueeze);
+        } finally {
+            // the unsqueezed intermediate is never handed to the caller
+            sliced.close();
+        }
+    }
+
+    /**
+     * Sets the values of the array at the fullSlice with an array.
+     *
+     * <p>The value is moved to the device of {@code array}, reshaped and broadcast to the shape
+     * of the slice before being assigned. Every intermediate array created along the way is
+     * closed afterwards, also when one of the steps fails; {@code value} itself is not closed.
+     *
+     * @param array the array to set
+     * @param fullSlice the fullSlice of the index to set in the array
+     * @param value the value to set with
+     */
+    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
+        OpParams params = sliceParams(fullSlice);
+
+        Stack<NDArray> prepareValue = new Stack<>();
+        prepareValue.add(value);
+        try {
+            prepareValue.add(prepareValue.peek().toDevice(array.getDevice(), false));
+            // prepareValue.add(prepareValue.peek().asType(getDataType(), false));
+            // Deal with the case target: (1, 10, 1), original (10)
+            // try to find (10, 1) and reshape (10) to that
+            Shape targetShape = fullSlice.getShape();
+            while (targetShape.size() > value.size()) {
+                targetShape = targetShape.slice(1);
+            }
+            prepareValue.add(prepareValue.peek().reshape(targetShape));
+            prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape()));
+
+            NDArray.invoke(
+                    "_npi_slice_assign",
+                    new NDArray[] {array, prepareValue.peek()},
+                    new NDArray[] {array},
+                    params);
+        } finally {
+            for (NDArray toClean : prepareValue) {
+                if (toClean != value) {
+                    toClean.close();
+                }
+            }
+        }
+    }
+
+    /**
+     * Sets the values of the array at the fullSlice with a number.
+     *
+     * @param array the array to set
+     * @param fullSlice the fullSlice of the index to set in the array
+     * @param value the value to set with
+     */
+    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
+        OpParams params = sliceParams(fullSlice);
+        params.addParam("scalar", value);
+        NDArray.invoke(
+                "_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+    }
+
+    /** Builds the {@code begin}/{@code end}/{@code step} tuple parameters of a slice. */
+    private static OpParams sliceParams(NDIndexFullSlice fullSlice) {
+        OpParams params = new OpParams();
+        params.addTupleParam("begin", fullSlice.getMin());
+        params.addTupleParam("end", fullSlice.getMax());
+        params.addTupleParam("step", fullSlice.getStep());
+        return params;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
new file mode 100644
index 0000000..df0c6b8
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
@@ -0,0 +1,2008 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** This class contains various methods for manipulating {@link NDArray}s. */
+public final class NDArrays {
+
+ private NDArrays() {} // private: this final class is not meant to be instantiated
+
+ private static void checkInputs(NDArray[] arrays) { // validates a varargs operand list
+ if (arrays == null || arrays.length < 2) {
+ throw new IllegalArgumentException("Passed in arrays must have at least two elements");
+ }
+ if (arrays.length > 2 // exactly two operands are not shape-checked here
+ && Arrays.stream(arrays).skip(1).anyMatch(array -> !arrays[0].shapeEquals(array))) {
+ throw new IllegalArgumentException("The shape of all inputs must be the same");
+ }
+ }
+
+ ////////////////////////////////////////
+ // Operations: Element Comparison
+ ////////////////////////////////////////
+
+ /**
+ * Returns {@code true} if all elements in {@link NDArray} a are equal to the number n.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.ones(new Shape(3));
+ * jshell> NDArrays.contentEquals(array, 1); // return true instead of boolean NDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to compare
+ * @return the boolean result, or {@code false} if {@code a} is {@code null}
+ */
+ public static boolean contentEquals(NDArray a, Number n) {
+ if (a == null) {
+ return false;
+ }
+ return a.contentEquals(n);
+ }
+
+ /**
+ * Returns {@code true} if all elements in {@link NDArray} a are equal to {@link NDArray} b.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(6f).reshape(2, 3);
+ * jshell> NDArray array2 = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(2, 3));
+ * jshell> NDArrays.contentEquals(array1, array2); // return true instead of boolean NDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean result
+ */
+ public static boolean contentEquals(NDArray a, NDArray b) {
+ return a.contentEquals(b);
+ }
+
+ /**
+ * Checks 2 {@link NDArray}s for equal shapes.
+ *
+ * <p>Shapes are considered equal if:
+ *
+ * <ul>
+ * <li>Both {@link NDArray}s have equal rank, and
+ * <li>size(0)...size(rank()-1) are equal for both {@link NDArray}s
+ * </ul>
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.ones(new Shape(1, 2, 3));
+ * jshell> NDArray array2 = manager.create(new Shape(1, 2, 3));
+ * jshell> NDArrays.shapeEquals(array1, array2); // return true instead of boolean NDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return {@code true} if the {@link Shape}s are the same
+ */
+ public static boolean shapeEquals(NDArray a, NDArray b) {
+ return a.shapeEquals(b);
+ }
+
+ /**
+ * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10,1e-7});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10,1e-8});
+ * jshell> MxNDArrays.allClose(array1, array2); // return false instead of boolean MxNDArray
+ * false
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10,1e-8});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10,1e-9});
+ * jshell> MxNDArrays.allClose(array1, array2); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare with
+ * @param b the {@link NDArray} to compare with
+ * @return the boolean result
+ */
+ // public static boolean allClose(MxNDArray a, MxNDArray b) {
+ // return a.allClose(b);
+ // }
+
+ /**
+ * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10, 1e-7});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return false instead of boolean MxNDArray
+ * false
+ * jshell> MxNDArray array1 = manager.create(new double[] {1e10, 1e-8});
+ * jshell> MxNDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return true instead of boolean MxNDArray
+ * true
+ * jshell> MxNDArray array1 = manager.create(new float[] {1f, Float.NaN});
+ * jshell> MxNDArray array2 = manager.create(new float[] {1f, Float.NaN});
+ * jshell> MxNDArrays.allClose(array1, array2, 1e-05, 1e-08, true); // return true instead of boolean MxNDArray
+ * true
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare with
+ * @param b the {@link NDArray} to compare with
+ * @param rtol the relative tolerance parameter
+ * @param atol the absolute tolerance parameter
+ * @param equalNan whether to compare NaN’s as equal. If {@code true}, NaN’s in the {@link
+ * NDArray} will be considered equal to NaN’s in the other {@link NDArray}
+ * @return the boolean result
+ */
+ // public static boolean allClose(
+ // MxNDArray a, MxNDArray b, double rtol, double atol, boolean equalNan) {
+ // return a.allClose(b, rtol, atol, equalNan);
+ // }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.ones(new Shape(1));
+ * jshell> NDArrays.eq(array, 1);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to compare
+ * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+ */
+ public static NDArray eq(NDArray a, Number n) {
+ return a.eq(n);
+ }
+
+ /**
+ * Computes the element-wise "Equals" comparison between a number and a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.ones(new Shape(1));
+ * jshell> NDArrays.eq(1, array);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Equals" result
+ */
+ public static NDArray eq(Number n, NDArray a) {
+ return eq(a, n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 3f});
+ * jshell> NDArray array2 = manager.arange(3f);
+ * jshell> NDArrays.eq(array1, array2);
+ * ND: (3) cpu() boolean
+ * [ true, true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+ */
+ public static NDArray eq(NDArray a, NDArray b) {
+ return a.eq(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(4f).reshape(2, 2);
+ * jshell> NDArrays.neq(array, 1);
+ * ND: (2, 2) cpu() boolean
+ * [[ true, false],
+ * [ true, true],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to compare
+ * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+ */
+ public static NDArray neq(NDArray a, Number n) {
+ return a.neq(n);
+ }
+
+ /**
+ * Computes the element-wise "Not equals" comparison between a number and a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(4f).reshape(2, 2);
+ * jshell> NDArrays.neq(1, array);
+ * ND: (2, 2) cpu() boolean
+ * [[ true, false],
+ * [ true, true],
+ * ]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Not equals" result
+ */
+ public static NDArray neq(Number n, NDArray a) {
+ return neq(a, n);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {1f, 3f});
+ * jshell> NDArrays.neq(array1, array2);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {1f, 3f, 1f, 4f}, new Shape(2, 2));
+ * jshell> NDArrays.neq(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() boolean
+ * [[false, true],
+ * [false, true],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param b the {@link NDArray} to compare
+ * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+ */
+ public static NDArray neq(NDArray a, NDArray b) {
+ return a.neq(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> NDArrays.gt(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to compare
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+ */
+ public static NDArray gt(NDArray a, Number n) {
+ return a.gt(n);
+ }
+
+ /**
+ * Computes element-wise whether a number is "Greater Than" the values of a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> NDArrays.gt(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, false]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Greater Than" result
+ */
+ public static NDArray gt(Number n, NDArray a) {
+ return lt(a, n); // n > a is evaluated as a < n
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> NDArrays.gt(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+ */
+ public static NDArray gt(NDArray a, NDArray b) {
+ return a.gt(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> NDArrays.gte(array, 2);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+ */
+ public static NDArray gte(NDArray a, Number n) {
+ return a.gte(n);
+ }
+
+ /**
+ * Computes element-wise whether a number is "Greater or equals" to a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {4f, 2f});
+ * jshell> NDArrays.gte(2, array);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Greater or equals" result
+ */
+ public static NDArray gte(Number n, NDArray a) {
+ return lte(a, n); // n >= a is evaluated as a <= n
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> NDArrays.gte(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+ */
+ public static NDArray gte(NDArray a, NDArray b) {
+ return a.gte(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.lt(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less" comparison
+ */
+ public static NDArray lt(NDArray a, Number n) {
+ return a.lt(n);
+ }
+
+ /**
+ * Computes element-wise whether a number is "Less" than the values of a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.lt(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, false]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Less" result
+ */
+ public static NDArray lt(Number n, NDArray a) {
+ return gt(a, n); // n < a is evaluated as a > n
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> NDArrays.lt(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less" comparison
+ */
+ public static NDArray lt(NDArray a, NDArray b) {
+ return a.lt(b);
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.lte(array, 2f);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+ */
+ public static NDArray lte(NDArray a, Number n) {
+ return a.lte(n);
+ }
+
+ /**
+ * Computes element-wise whether a number is "Less or equals" to a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.lte(2f, array);
+ * ND: (2) cpu() boolean
+ * [false, true]
+ * </pre>
+ *
+ * @param n the number on the left-hand side of the comparison
+ * @param a the {@link NDArray} on the right-hand side of the comparison
+ * @return a boolean {@link NDArray} holding the element-wise "Less or equals" result
+ */
+ public static NDArray lte(Number n, NDArray a) {
+ return gte(a, n); // n <= a is evaluated as a >= n
+ }
+
+ /**
+ * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
+ * jshell> NDArrays.lte(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared against
+ * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+ */
+ public static NDArray lte(NDArray a, NDArray b) {
+ return a.lte(b);
+ }
+
+ /**
+ * Returns elements chosen from the {@link NDArray} or the other {@link NDArray} depending on
+ * condition.
+ *
+ * <p>Given three {@link NDArray}s, condition, a, and b, returns an {@link NDArray} with the
+ * elements from a or b, depending on whether the elements from condition {@link NDArray} are
+ * {@code true} or {@code false}. If condition has the same shape as a, each element in the
+ * output {@link NDArray} is from a if the corresponding element in the condition is {@code
+ * true}, and from b if {@code false}.
+ *
+ * <p>Note that all non-zero values are interpreted as {@code true} in condition {@link
+ * NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(10f);
+ * jshell> NDArrays.where(array.lt(5), array, array.mul(10));
+ * ND: (10) cpu() float32
+ * [ 0., 1., 2., 3., 4., 50., 60., 70., 80., 90.]
+ * jshell> NDArray array = manager.create(new float[]{0f, 1f, 2f, 0f, 2f, 4f, 0f, 3f, 6f}, new Shape(3, 3));
+ * jshell> NDArrays.where(array.lt(4), array, manager.create(-1f));
+ * ND: (3, 3) cpu() float32
+ * [[ 0., 1., 2.],
+ * [ 0., 2., -1.],
+ * [ 0., 3., -1.],
+ * ]
+ * </pre>
+ *
+ * @param condition the condition {@link NDArray}
+ * @param a the first {@link NDArray}
+ * @param b the other {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray where(NDArray condition, NDArray a, NDArray b) {
+ return a.getNDArrayInternal().where(condition, b);
+ }
+
+ /**
+ * Returns the maximum of a {@link NDArray} and a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArrays.maximum(array, 3f);
+ * ND: (3) cpu() float32
+ * [3., 3., 4.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared
+ * @return the maximum of a {@link NDArray} and a number element-wise
+ */
+ public static NDArray maximum(NDArray a, Number n) {
+ return a.maximum(n);
+ }
+
+ /**
+ * Picks, element by element, the larger of a number and the values of a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArrays.maximum(3f, array);
+ * ND: (3) cpu() float32
+ * [3., 3., 4.]
+ * </pre>
+ *
+ * @param n the number each element is compared with
+ * @param a the {@link NDArray} whose elements are compared
+ * @return a {@link NDArray} holding the element-wise maximum of the number and the array
+ */
+ public static NDArray maximum(Number n, NDArray a) {
+ return a.maximum(n);
+ }
+
+ /**
+ * Returns the maximum of {@link NDArray} a and {@link NDArray} b element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+ * jshell> NDArrays.maximum(array1, array2);
+ * ND: (3) cpu() float32
+ * [2., 5., 4.]
+ * jshell> NDArray array1 = manager.eye(2);
+ * jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
+ * jshell> NDArrays.maximum(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() float32
+ * [[1. , 2. ],
+ * [0.5, 2. ],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared
+ * @return the maximum of {@link NDArray} a and {@link NDArray} b element-wise
+ */
+ public static NDArray maximum(NDArray a, NDArray b) {
+ return a.maximum(b);
+ }
+
+ /**
+ * Returns the minimum of a {@link NDArray} and a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArrays.minimum(array, 3f);
+ * ND: (3) cpu() float32
+ * [2., 3., 3.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param n the number to be compared
+ * @return the minimum of a {@link NDArray} and a number element-wise
+ */
+ public static NDArray minimum(NDArray a, Number n) {
+ return a.minimum(n);
+ }
+
+ /**
+ * Picks, element by element, the smaller of a number and the values of a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArrays.minimum(3f, array);
+ * ND: (3) cpu() float32
+ * [2., 3., 3.]
+ * </pre>
+ *
+ * @param n the number each element is compared with
+ * @param a the {@link NDArray} whose elements are compared
+ * @return a {@link NDArray} holding the element-wise minimum of the number and the array
+ */
+ public static NDArray minimum(Number n, NDArray a) {
+ return a.minimum(n);
+ }
+
+ /**
+ * Returns the minimum of {@link NDArray} a and {@link NDArray} b element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+ * jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+ * jshell> NDArrays.minimum(array1, array2);
+ * ND: (3) cpu() float32
+ * [1., 3., 2.]
+ * jshell> NDArray array1 = manager.eye(2);
+ * jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
+ * jshell> NDArrays.minimum(array1, array2); // broadcasting
+ * ND: (2, 2) cpu() float32
+ * [[0.5, 0. ],
+ * [0. , 1. ],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be compared
+ * @param b the {@link NDArray} to be compared
+ * @return the minimum of {@link NDArray} a and {@link NDArray} b element-wise
+ */
+ public static NDArray minimum(NDArray a, NDArray b) {
+ return a.minimum(b);
+ }
+
+ /**
+ * Selects the entries of the {@link NDArray} along its first axis (axis 0) that the boolean
+ * index {@link NDArray} marks.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(3, 2));
+ * jshell> NDArray mask = manager.create(new boolean[] {true, false, true});
+ * jshell> NDArrays.booleanMask(array, mask);
+ * ND: (2, 2) cpu() float32
+ * [[1., 2.],
+ * [5., 6.],
+ * ]
+ * </pre>
+ *
+ * @param data the {@link NDArray} to select from
+ * @param index the boolean {@link NDArray} used as the mask
+ * @return the {@link NDArray} holding the selected entries
+ */
+ public static NDArray booleanMask(NDArray data, NDArray index) {
+ return data.booleanMask(index, 0);
+ }
+
+ /**
+ * Returns the portion of the {@link NDArray} selected by the boolean index {@link NDArray}
+ * along the given axis.
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param index the boolean {@link NDArray} mask
+ * @param axis an integer that represents the axis of {@link NDArray} to mask from
+ * @return the result {@link NDArray}
+ */
+ public static NDArray booleanMask(NDArray data, NDArray index, int axis) {
+ return data.booleanMask(index, axis);
+ }
+
+ /**
+ * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to a
+ * constant value.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ...] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+ * an input array of positive ints of dimension [batch_size].
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param sequenceLength used to handle variable-length sequences
+ * @param value the constant value to be set
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sequenceMask(NDArray data, NDArray sequenceLength, float value) {
+ return data.sequenceMask(sequenceLength, value);
+ }
+
+ /**
+ * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to 0.
+ *
+ * <p>This function takes an n-dimensional input array of the form [batch_size,
+ * max_sequence_length, ...] and returns an array of the same shape. Parameter {@code
+ * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+ * an input array of positive ints of dimension [batch_size].
+ *
+ * @param data the {@link NDArray} to operate on
+ * @param sequenceLength used to handle variable-length sequences
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sequenceMask(NDArray data, NDArray sequenceLength) {
+ return data.sequenceMask(sequenceLength);
+ }
+
+ ////////////////////////////////////////
+ // Operations: Element Arithmetic
+ ////////////////////////////////////////
+
+ /**
+ * Adds a number to the {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.add(array, 2f);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be added to
+ * @param n the number to add
+ * @return the result {@link NDArray}
+ */
+ public static NDArray add(NDArray a, Number n) {
+ return a.add(n);
+ }
+
+ /**
+ * Computes the element-wise sum of a number and a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.add(2f, array);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param n the number operand
+ * @param a the {@link NDArray} operand
+ * @return the {@link NDArray} holding the element-wise sum
+ */
+ public static NDArray add(Number n, NDArray a) {
+ return add(a, n);
+ }
+
+ /**
+ * Sums all of the given {@link NDArray}s element-wise.
+ *
+ * <p>When more than two {@link NDArray}s are passed, their shapes must all be the same.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.add(array, array, array);
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * </pre>
+ *
+ * @param arrays the {@link NDArray}s to add together
+ * @return the result {@link NDArray}
+ * @throws IllegalArgumentException if fewer than two arrays are passed
+ * @throws IllegalArgumentException if three or more arrays do not all share one shape
+ */
+ public static NDArray add(NDArray... arrays) {
+ checkInputs(arrays);
+ if (arrays.length > 2) {
+ try (NDArray stacked = NDArrays.stack(new NDList(arrays))) {
+ return stacked.sum(new int[] {0});
+ }
+ }
+ return arrays[0].add(arrays[1]);
+ }
+
+ /**
+ * Subtracts a number from the {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.sub(array, 2f);
+ * ND: (2) cpu() float32
+ * [-1., 0.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be subtracted from
+ * @param n the number to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sub(NDArray a, Number n) {
+ return a.sub(n);
+ }
+
+ /**
+ * Subtracts a {@link NDArray} from a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.sub(3f, array);
+ * ND: (2) cpu() float32
+ * [2., 1.]
+ * </pre>
+ *
+ * @param n the number to be subtracted from
+ * @param a the {@link NDArray} to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sub(Number n, NDArray a) {
+ return a.getNDArrayInternal().rsub(n);
+ }
+
+ /**
+ * Subtracts a {@link NDArray} from a {@link NDArray} element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+ * jshell> NDArray array2 = manager.arange(3f);
+ * jshell> NDArrays.sub(array1, array2); // broadcasting
+ * ND: (3, 3) cpu() float32
+ * [[0., 0., 0.],
+ * [3., 3., 3.],
+ * [6., 6., 6.],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be subtracted from
+ * @param b the {@link NDArray} to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray sub(NDArray a, NDArray b) {
+ return a.sub(b);
+ }
+
+ /**
+ * Multiplies the {@link NDArray} by a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.mul(array, 3f);
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be multiplied
+ * @param n the number to multiply by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray mul(NDArray a, Number n) {
+ return a.mul(n);
+ }
+
+ /**
+ * Computes the element-wise product of a number and a {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.mul(3f, array);
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * </pre>
+ *
+ * @param n the number operand
+ * @param a the {@link NDArray} operand
+ * @return the {@link NDArray} holding the element-wise product
+ */
+ public static NDArray mul(Number n, NDArray a) {
+ return mul(a, n);
+ }
+
+ /**
+ * Multiplies all of the given {@link NDArray}s together element-wise.
+ *
+ * <p>When more than two {@link NDArray}s are passed, their shapes must all be the same.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.mul(array, array, array);
+ * ND: (2) cpu() float32
+ * [1., 8.]
+ * </pre>
+ *
+ * @param arrays the {@link NDArray}s to multiply together
+ * @return the result {@link NDArray}
+ * @throws IllegalArgumentException if fewer than two arrays are passed
+ * @throws IllegalArgumentException if three or more arrays do not all share one shape
+ */
+ public static NDArray mul(NDArray... arrays) {
+ checkInputs(arrays);
+ if (arrays.length > 2) {
+ try (NDArray stacked = NDArrays.stack(new NDList(arrays))) {
+ return stacked.prod(new int[] {0});
+ }
+ }
+ return arrays[0].mul(arrays[1]);
+ }
+
+ /**
+ * Divides the {@link NDArray} by a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.div(array, 4f);
+ * ND: (5) cpu() float32
+ * [0. , 0.25, 0.5 , 0.75, 1. ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be divided
+ * @param n the number to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray div(NDArray a, Number n) {
+ return a.div(n);
+ }
+
+ /**
+ * Divides a number by a {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f).add(1);
+ * jshell> NDArrays.div(4f, array);
+ * ND: (5) cpu() float32
+ * [4. , 2. , 1.3333, 1. , 0.8 ]
+ * </pre>
+ *
+ * @param n the number to be divided
+ * @param a the {@link NDArray} to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray div(Number n, NDArray a) {
+ return a.getNDArrayInternal().rdiv(n);
+ }
+
+ /**
+ * Divides a {@link NDArray} by a {@link NDArray} element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+ * jshell> NDArray array2 = manager.ones(new Shape(3)).mul(10);
+ * jshell> NDArrays.div(array1, array2); // broadcasting
+ * ND: (3, 3) cpu() float32
+ * [[0. , 0.1, 0.2],
+ * [0.3, 0.4, 0.5],
+ * [0.6, 0.7, 0.8],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be divided
+ * @param b the {@link NDArray} to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray div(NDArray a, NDArray b) {
+ return a.div(b);
+ }
+
+ /**
+ * Returns element-wise remainder of division.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(7f);
+ * jshell> NDArrays.mod(array, 5f);
+ * ND: (7) cpu() float32
+ * [0., 1., 2., 3., 4., 0., 1.]
+ * </pre>
+ *
+ * @param a the dividend {@link NDArray}
+ * @param n the divisor number
+ * @return the result {@link NDArray}
+ */
+ public static NDArray mod(NDArray a, Number n) {
+ return a.mod(n);
+ }
+
+ /**
+ * Returns the element-wise remainder of dividing a number by the {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(7f).add(1);
+ * jshell> NDArrays.mod(5f, array);
+ * ND: (7) cpu() float32
+ * [0., 1., 2., 1., 0., 5., 5.]
+ * </pre>
+ *
+ * @param n the dividend number
+ * @param a the divisor {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray mod(Number n, NDArray a) {
+ return a.getNDArrayInternal().rmod(n);
+ }
+
+ /**
+ * Returns the element-wise remainder of dividing one {@link NDArray} by another.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {4f, 7f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
+ * jshell> NDArrays.mod(array1, array2);
+ * ND: (2) cpu() float32
+ * [0., 1.]
+ * </pre>
+ *
+ * @param a the dividend {@link NDArray}
+ * @param b the divisor {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray mod(NDArray a, NDArray b) {
+ return a.mod(b);
+ }
+
+ /**
+ * Takes the power of the {@link NDArray} with a number element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.pow(array, 4f);
+ * ND: (5) cpu() float32
+ * [ 0., 1., 16., 81., 256.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be raised to the power (the base)
+ * @param n the number to use as the exponent
+ * @return the result {@link NDArray}
+ */
+ public static NDArray pow(NDArray a, Number n) {
+ return a.pow(n);
+ }
+
+ /**
+ * Takes the power of a number with a {@link NDArray} element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.pow(4f, array);
+ * ND: (5) cpu() float32
+ * [ 1., 4., 16., 64., 256.]
+ * </pre>
+ *
+ * @param n the number to be raised to the power (the base)
+ * @param a the {@link NDArray} to use as the exponents
+ * @return the result {@link NDArray}
+ */
+ public static NDArray pow(Number n, NDArray a) {
+ return a.getNDArrayInternal().rpow(n);
+ }
+
+ /**
+ * Takes the power of a {@link NDArray} with a {@link NDArray} element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(6f).reshape(3, 2);
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
+ * jshell> NDArrays.pow(array1, array2); // broadcasting
+ * ND: (3, 2) cpu() float32
+ * [[ 0., 1.],
+ * [ 4., 27.],
+ * [ 16., 125.],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be raised to the power (the bases)
+ * @param b the {@link NDArray} to use as the exponents
+ * @return the result {@link NDArray}
+ */
+ public static NDArray pow(NDArray a, NDArray b) {
+ return a.pow(b);
+ }
+
+ /**
+ * Adds a number to the {@link NDArray} element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.addi(array, 2f);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be added to
+ * @param n the number to add
+ * @return the result {@link NDArray}
+ */
+ public static NDArray addi(NDArray a, Number n) {
+ return a.addi(n);
+ }
+
+ /**
+ * Adds a number to a {@link NDArray} element-wise in place (number-first argument order).
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.addi(2f, array);
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [3., 4.]
+ * </pre>
+ *
+ * @param n the number to add
+ * @param a the {@link NDArray} to be added to
+ * @return the result {@link NDArray}
+ */
+ public static NDArray addi(Number n, NDArray a) {
+ return a.addi(n);
+ }
+
+ /**
+ * Sums every {@link NDArray} into the first one, element-wise and in place. All of the
+ * {@link NDArray}s must have the same shape.
+ *
+ * <p>Examples
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f});
+ * jshell> NDArray array3 = manager.create(new float[] {5f, 6f});
+ * jshell> NDArrays.addi(array1, array2, array3);
+ * ND: (2) cpu() float32
+ * [9., 12.]
+ * jshell> array1;
+ * ND: (2) cpu() float32
+ * [9., 12.]
+ * </pre>
+ *
+ * @param arrays the {@link NDArray}s to add together; the first one receives the result
+ * @return the first {@link NDArray} of {@code arrays}
+ * @throws IllegalArgumentException arrays must have at least two elements
+ */
+ public static NDArray addi(NDArray... arrays) {
+ checkInputs(arrays);
+ for (int i = 1; i < arrays.length; i++) {
+ arrays[0].addi(arrays[i]);
+ }
+ return arrays[0];
+ }
+
+ /**
+ * Subtracts a number from the {@link NDArray} element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.subi(array, 2f);
+ * ND: (2) cpu() float32
+ * [-1., 0.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [-1., 0.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be subtracted from
+ * @param n the number to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray subi(NDArray a, Number n) {
+ return a.subi(n);
+ }
+
+ /**
+ * Subtracts a {@link NDArray} from a number element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.subi(3f, array);
+ * ND: (2) cpu() float32
+ * [2., 1.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [2., 1.]
+ * </pre>
+ *
+ * @param n the number to be subtracted from
+ * @param a the {@link NDArray} to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray subi(Number n, NDArray a) {
+ return a.getNDArrayInternal().rsubi(n);
+ }
+
+ /**
+ * Subtracts a {@link NDArray} from a {@link NDArray} element-wise in place.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+ * jshell> NDArray array2 = manager.arange(3f);
+ * jshell> NDArrays.subi(array1, array2); // broadcasting
+ * ND: (3, 3) cpu() float32
+ * [[0., 0., 0.],
+ * [3., 3., 3.],
+ * [6., 6., 6.],
+ * ]
+ * jshell> array1;
+ * [[0., 0., 0.],
+ * [3., 3., 3.],
+ * [6., 6., 6.],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be subtracted from
+ * @param b the {@link NDArray} to subtract
+ * @return the result {@link NDArray}
+ */
+ public static NDArray subi(NDArray a, NDArray b) {
+ return a.subi(b);
+ }
+
+ /**
+ * Multiplies the {@link NDArray} by a number element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.muli(array, 3f);
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be multiplied
+ * @param n the number to multiply by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray muli(NDArray a, Number n) {
+ return a.muli(n);
+ }
+
+ /**
+ * Multiplies a number by a {@link NDArray} element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.muli(3f, array);
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * jshell> array;
+ * ND: (2) cpu() float32
+ * [3., 6.]
+ * </pre>
+ *
+ * @param n the number to multiply by
+ * @param a the {@link NDArray} to be multiplied
+ * @return the result {@link NDArray}
+ */
+ public static NDArray muli(Number n, NDArray a) {
+ return a.muli(n);
+ }
+
+ /**
+ * Multiplies every {@link NDArray} into the first one, element-wise and in place. All of the
+ * {@link NDArray}s must have the same shape.
+ *
+ * <p>Examples
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f});
+ * jshell> NDArray array3 = manager.create(new float[] {5f, 6f});
+ * jshell> NDArrays.muli(array1, array2, array3);
+ * ND: (2) cpu() float32
+ * [15., 48.]
+ * jshell> array1;
+ * ND: (2) cpu() float32
+ * [15., 48.]
+ * </pre>
+ *
+ * @param arrays the {@link NDArray}s to multiply together; the first one receives the result
+ * @return the first {@link NDArray} of {@code arrays}
+ * @throws IllegalArgumentException arrays must have at least two elements
+ */
+ public static NDArray muli(NDArray... arrays) {
+ checkInputs(arrays);
+ for (int i = 1; i < arrays.length; i++) {
+ arrays[0].muli(arrays[i]);
+ }
+ return arrays[0];
+ }
+
+ /**
+ * Divides the {@link NDArray} by a number element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.divi(array, 4f);
+ * ND: (5) cpu() float32
+ * [0. , 0.25, 0.5 , 0.75, 1. ]
+ * jshell> array;
+ * ND: (5) cpu() float32
+ * [0. , 0.25, 0.5 , 0.75, 1. ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be divided
+ * @param n the number to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray divi(NDArray a, Number n) {
+ return a.divi(n);
+ }
+
+ /**
+ * Divides a number by a {@link NDArray} element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f).add(1);
+ * jshell> NDArrays.divi(4f, array);
+ * ND: (5) cpu() float32
+ * [4. , 2. , 1.3333, 1. , 0.8 ]
+ * jshell> array;
+ * ND: (5) cpu() float32
+ * [4. , 2. , 1.3333, 1. , 0.8 ]
+ * </pre>
+ *
+ * @param n the number to be divided
+ * @param a the {@link NDArray} to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray divi(Number n, NDArray a) {
+ return a.getNDArrayInternal().rdivi(n);
+ }
+
+ /**
+ * Divides a {@link NDArray} by a {@link NDArray} element-wise in place.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
+ * jshell> NDArray array2 = manager.ones(new Shape(3)).mul(10);
+ * jshell> NDArrays.divi(array1, array2); // broadcasting
+ * ND: (3, 3) cpu() float32
+ * [[0. , 0.1, 0.2],
+ * [0.3, 0.4, 0.5],
+ * [0.6, 0.7, 0.8],
+ * ]
+ * jshell> array1;
+ * [[0. , 0.1, 0.2],
+ * [0.3, 0.4, 0.5],
+ * [0.6, 0.7, 0.8],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be divided
+ * @param b the {@link NDArray} to divide by
+ * @return the result {@link NDArray}
+ */
+ public static NDArray divi(NDArray a, NDArray b) {
+ return a.divi(b);
+ }
+
+ /**
+ * Returns the element-wise remainder of dividing the {@link NDArray} by a number, in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(7f);
+ * jshell> NDArrays.modi(array, 5f);
+ * ND: (7) cpu() float32
+ * [0., 1., 2., 3., 4., 0., 1.]
+ * jshell> array;
+ * ND: (7) cpu() float32
+ * [0., 1., 2., 3., 4., 0., 1.]
+ * </pre>
+ *
+ * @param a the dividend {@link NDArray}
+ * @param n the divisor number
+ * @return the result {@link NDArray}
+ */
+ public static NDArray modi(NDArray a, Number n) {
+ return a.modi(n);
+ }
+
+ /**
+ * Returns the element-wise remainder of dividing a number by the {@link NDArray}, in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(7f);
+ * jshell> NDArrays.modi(5f, array);
+ * ND: (7) cpu() float32
+ * [0., 0., 1., 2., 1., 0., 5.]
+ * jshell> array;
+ * ND: (7) cpu() float32
+ * [0., 0., 1., 2., 1., 0., 5.]
+ * </pre>
+ *
+ * @param n the dividend number
+ * @param a the divisor {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray modi(Number n, NDArray a) {
+ return a.getNDArrayInternal().rmodi(n);
+ }
+
+ /**
+ * Returns the element-wise remainder of dividing one {@link NDArray} by another, in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {4f, 7f});
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
+ * jshell> NDArrays.modi(array1, array2);
+ * ND: (2) cpu() float32
+ * [0., 1.]
+ * jshell> array1;
+ * ND: (2) cpu() float32
+ * [0., 1.]
+ * </pre>
+ *
+ * @param a the dividend {@link NDArray}
+ * @param b the divisor {@link NDArray}
+ * @return the result {@link NDArray}
+ */
+ public static NDArray modi(NDArray a, NDArray b) {
+ return a.modi(b);
+ }
+
+ /**
+ * Takes the power of the {@link NDArray} with a number element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.powi(array, 4f);
+ * ND: (5) cpu() float32
+ * [ 0., 1., 16., 81., 256.]
+ * jshell> array;
+ * ND: (5) cpu() float32
+ * [ 0., 1., 16., 81., 256.]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be raised to the power (the base)
+ * @param n the number to use as the exponent
+ * @return the result {@link NDArray}
+ */
+ public static NDArray powi(NDArray a, Number n) {
+ return a.powi(n);
+ }
+
+ /**
+ * Takes the power of a number with a {@link NDArray} element-wise in place.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.powi(4f, array);
+ * ND: (5) cpu() float32
+ * [ 1., 4., 16., 64., 256.]
+ * jshell> array;
+ * ND: (5) cpu() float32
+ * [ 1., 4., 16., 64., 256.]
+ * </pre>
+ *
+ * @param n the number to be raised to the power (the base)
+ * @param a the {@link NDArray} to use as the exponents
+ * @return the result {@link NDArray}
+ */
+ public static NDArray powi(Number n, NDArray a) {
+ return a.getNDArrayInternal().rpowi(n);
+ }
+
+ /**
+ * Takes the power of a {@link NDArray} with a {@link NDArray} element-wise in place.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.arange(6f).reshape(3, 2);
+ * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
+ * jshell> NDArrays.powi(array1, array2); // broadcasting
+ * ND: (3, 2) cpu() float32
+ * [[ 0., 1.],
+ * [ 4., 27.],
+ * [ 16., 125.],
+ * ]
+ * jshell> array1;
+ * ND: (3, 2) cpu() float32
+ * [[ 0., 1.],
+ * [ 4., 27.],
+ * [ 16., 125.],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to be raised to the power (the bases)
+ * @param b the {@link NDArray} to use as the exponents
+ * @return the result {@link NDArray}
+ */
+ public static NDArray powi(NDArray a, NDArray b) {
+ return a.powi(b);
+ }
+
+ /**
+ * Dot product of {@link NDArray} a and {@link NDArray} b.
+ *
+ * <ul>
+ * <li>If both the {@link NDArray} and the other {@link NDArray} are 1-D {@link NDArray}s, it
+ * is the inner product of vectors (without complex conjugation).
+ * <li>If both the {@link NDArray} and the other {@link NDArray} are 2-D {@link NDArray}s, it
+ * is matrix multiplication.
+ * <li>If either the {@link NDArray} or the other {@link NDArray} is a 0-D {@link NDArray}
+ * (scalar), it is equivalent to mul.
+ * <li>If the {@link NDArray} is an N-D {@link NDArray} and the other {@link NDArray} is a 1-D
+ * {@link NDArray}, it is a sum product over the last axis of those.
+ * <li>If the {@link NDArray} is an N-D {@link NDArray} and the other {@link NDArray} is an M-D
+ * {@link NDArray} (where M>=2), it is a sum product over the last axis of this
+ * {@link NDArray} and the second-to-last axis of the other {@link NDArray}.
+ * </ul>
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f, 3f});
+ * jshell> NDArray array2 = manager.create(new float[] {4f, 5f, 6f});
+ * jshell> NDArrays.dot(array1, array2); // inner product
+ * ND: () cpu() float32
+ * 32.
+ * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+ * jshell> array2 = manager.create(new float[] {5f, 6f, 7f, 8f}, new Shape(2, 2));
+ * jshell> NDArrays.dot(array1, array2); // matrix multiplication
+ * ND: (2, 2) cpu() float32
+ * [[19., 22.],
+ * [43., 50.],
+ * ]
+ * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+ * jshell> array2 = manager.create(5f);
+ * jshell> NDArrays.dot(array1, array2);
+ * ND: (2, 2) cpu() float32
+ * [[ 5., 10.],
+ * [15., 20.],
+ * ]
+ * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+ * jshell> array2 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.dot(array1, array2);
+ * ND: (2) cpu() float32
+ * [ 5., 11.]
+ * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f}, new Shape(2, 2, 2));
+ * jshell> array2 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+ * jshell> NDArrays.dot(array1, array2);
+ * ND: (2, 2, 2) cpu() float32
+ * [[[ 7., 10.],
+ * [15., 22.],
+ * ],
+ * [[23., 34.],
+ * [31., 46.],
+ * ],
+ * ]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to perform dot product with
+ * @param b the {@link NDArray} to perform dot product with
+ * @return the result {@link NDArray}
+ */
+ public static NDArray dot(NDArray a, NDArray b) {
+ return a.dot(b);
+ }
+
+ /**
+ * Product matrix of {@link NDArray} a and {@link NDArray} b.
+ *
+ * <p>The behavior depends on the arguments in the following way.
+ *
+ * <ul>
+ * <li>If both {@code a} and {@code b} are 2-D {@link NDArray}s, they are multiplied like
+ * conventional matrices.
+ * <li>If either {@code a} or {@code b} is an N-D {@link NDArray} with N > 2, it is treated
+ * as a stack of matrices residing in the last two indexes and broadcast
+ * accordingly.
+ * <li>If {@code a} is a 1-D {@link NDArray}, it is promoted to a matrix by
+ * prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is
+ * removed.
+ * <li>If {@code b} is a 1-D {@link NDArray}, it is promoted to a matrix by
+ * appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
+ * </ul>
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+ * jshell> NDArray array2 = manager.create(new float[] {4f, 1f, 2f, 2f}, new Shape(2, 2));
+ * jshell> NDArrays.matMul(array1, array2); // for 2-D arrays, it is the matrix product
+ * ND: (2, 2) cpu() float32
+ * [[4., 1.],
+ * [2., 2.],
+ * ]
+ * jshell> array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+ * jshell> array2 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.matMul(array1, array2);
+ * ND: (2) cpu() float32
+ * [1., 2.]
+ * jshell> array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
+ * jshell> array2 = manager.create(new float[] {1f, 2f});
+ * jshell> NDArrays.matMul(array1, array2);
+ * ND: (2) cpu() float32
+ * [1., 2.]
+ * jshell> array1 = manager.arange(2f * 2f * 4f).reshape(2, 2, 4);
+ * jshell> array2 = manager.arange(2f * 2f * 4f).reshape(2, 4, 2);
+ * jshell> NDArrays.matMul(array1, array2).get(0, 1, 1);
+ * ND: () cpu() float32
+ * 98.
+ * </pre>
+ *
+ * @param a the {@link NDArray} to perform matrix product with
+ * @param b the {@link NDArray} to perform matrix product with
+ * @return the result {@link NDArray}
+ */
+ public static NDArray matMul(NDArray a, NDArray b) {
+ return a.matMul(b);
+ }
+
+ /**
+ * Joins a sequence of {@link NDArray}s in {@link NDList} along a new first axis (axis 0).
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+ * jshell> NDArray array3 = manager.create(new float[] {6f, 7f, 8f});
+ * jshell> NDArrays.stack(new NDList(array1, array2, array3));
+ * ND: (3, 3) cpu() float32
+ * [[0., 1., 2.],
+ * [3., 4., 5.],
+ * [6., 7., 8.],
+ * ]
+ * </pre>
+ *
+ * @param arrays the input {@link NDList}. Every {@link NDArray} in the {@link NDList} must
+ * have the same shape
+ * @return the result {@link NDArray}. The stacked {@link NDArray} has one more dimension than
+ * the {@link NDArray}s in {@link NDList}
+ */
+ public static NDArray stack(NDList arrays) {
+ return stack(arrays, 0);
+ }
+
+ /**
+ * Stacks the {@link NDArray}s of an {@link NDList} along a newly inserted axis.
+ *
+ * <p>The {@code axis} argument is the position of the new axis within the dimensions of the
+ * result: axis=0 makes it the first dimension, while axis=-1 makes it the last
+ * dimension.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+ * jshell> NDArrays.stack(new NDList(array1, array2), 0);
+ * ND: (2, 3) cpu() float32
+ * [[0., 1., 2.],
+ * [3., 4., 5.],
+ * ]
+ * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+ * jshell> NDArrays.stack(new NDList(array1, array2), 1);
+ * ND: (3, 2) cpu() float32
+ * [[0., 3.],
+ * [1., 4.],
+ * [2., 5.],
+ * ]
+ * </pre>
+ *
+ * @param arrays the {@link NDList} to stack; it must not be empty and all of its
+ * {@link NDArray}s must share one shape
+ * @param axis the index, in the result {@link NDArray}, of the axis along which the inputs
+ * are stacked
+ * @return the stacked {@link NDArray}, with one more dimension than each input {@link NDArray}
+ * @throws IllegalArgumentException if {@code arrays} is empty
+ */
+ public static NDArray stack(NDList arrays, int axis) {
+ if (arrays.size() < 1) {
+ throw new IllegalArgumentException("need at least one array to stack");
+ }
+ // head() supplies the receiver of the call; subNDList(1) supplies the remaining arrays.
+ return arrays.head().getNDArrayInternal().stack(arrays.subNDList(1), axis);
+ }
+
+ /**
+ * Joins a {@link NDList} along the first axis (axis 0).
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 2f});
+ * jshell> NDArray array2 = manager.create(new float[] {3f, 4f, 5f});
+ * jshell> NDArray array3 = manager.create(new float[] {6f, 7f, 8f});
+ * jshell> NDArrays.concat(new NDList(array1, array2, array3));
+ * ND: (9) cpu() float32
+ * [0., 1., 2., 3., 4., 5., 6., 7., 8.]
+ * </pre>
+ *
+ * @param arrays a {@link NDList} whose {@link NDArray}s have the same shape, except in
+ * the dimension corresponding to the first axis
+ * @return the concatenated {@link NDArray}
+ */
+ public static NDArray concat(NDList arrays) {
+ return concat(arrays, 0);
+ }
+
+ /**
+ * Concatenates the {@link NDArray}s of an {@link NDList} along an axis they already have.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+ * jshell> NDArray array2 = manager.create(new float[] {5f, 6f}, new Shape(1, 2));
+ * jshell> NDArrays.concat(new NDList(array1, array2), 0);
+ * ND: (3, 2) cpu() float32
+ * [[1., 2.],
+ * [3., 4.],
+ * [5., 6.],
+ * ]
+ * jshell> NDArrays.concat(new NDList(array1, array2.transpose()), 1);
+ * ND: (2, 3) cpu() float32
+ * [[1., 2., 5.],
+ * [3., 4., 6.],
+ * ]
+ * </pre>
+ *
+ * @param arrays the {@link NDList} to join; shapes must match except along {@code axis}
+ * @param axis the axis along which the {@link NDList} will be joined
+ * @return the concatenated {@link NDArray}; for a one-element list, a duplicate of that element
+ * @throws IllegalArgumentException if {@code arrays} is empty
+ */
+ public static NDArray concat(NDList arrays, int axis) {
+ final int count = arrays.size();
+ if (count < 1) {
+ throw new IllegalArgumentException("need at least one array to concatenate");
+ }
+ if (count == 1) {
+ // Nothing to join with: hand back a duplicate of the only element.
+ return arrays.singletonOrThrow().duplicate();
+ }
+ // head() supplies the receiver of the call; subNDList(1) supplies the remaining arrays.
+ return arrays.head().getNDArrayInternal().concat(arrays.subNDList(1), axis);
+ }
+
+ /**
+ * Returns the truth value of {@link NDArray} a AND {@link NDArray} b element-wise.
+ *
+ * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new boolean[] {true});
+ * jshell> NDArray array2 = manager.create(new boolean[] {false});
+ * jshell> NDArrays.logicalAnd(array1, array2);
+ * ND: (1) cpu() boolean
+ * [false]
+ * jshell> array1 = manager.create(new boolean[] {true, false});
+ * jshell> array2 = manager.create(new boolean[] {false, false});
+ * jshell> NDArrays.logicalAnd(array1, array2);
+ * ND: (2) cpu() boolean
+ * [false, false]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to operate on
+ * @param b the {@link NDArray} to operate on
+ * @return the boolean {@link NDArray} of the logical AND operation applied to the elements of
+ * the {@link NDArray} a and {@link NDArray} b
+ */
+ public static NDArray logicalAnd(NDArray a, NDArray b) {
+ return a.logicalAnd(b);
+ }
+
+ /**
+ * Computes the truth value of {@link NDArray} a OR {@link NDArray} b element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new boolean[] {true});
+ * jshell> NDArray array2 = manager.create(new boolean[] {false});
+ * jshell> NDArrays.logicalOr(array1, array2);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * jshell> array1 = manager.create(new boolean[] {true, false});
+ * jshell> array2 = manager.create(new boolean[] {false, false});
+ * jshell> NDArrays.logicalOr(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.logicalOr(array.lt(1), array.gt(3));
+ * ND: (5) cpu() boolean
+ * [ true, false, false, false, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to operate on
+ * @param b the {@link NDArray} to operate on
+ * @return the boolean {@link NDArray} of the logical OR operation applied to the elements of
+ * the {@link NDArray} a and {@link NDArray} b
+ */
+ public static NDArray logicalOr(NDArray a, NDArray b) {
+ return a.logicalOr(b);
+ }
+
+ /**
+ * Computes the truth value of {@link NDArray} a XOR {@link NDArray} b element-wise.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array1 = manager.create(new boolean[] {true}), array2 = manager.create(new boolean[] {false});
+ * jshell> NDArrays.logicalXor(array1, array2);
+ * ND: (1) cpu() boolean
+ * [ true]
+ * jshell> array1 = manager.create(new boolean[] {true, false});
+ * jshell> array2 = manager.create(new boolean[] {false, false});
+ * jshell> NDArrays.logicalXor(array1, array2);
+ * ND: (2) cpu() boolean
+ * [ true, false]
+ * </pre>
+ *
+ * <pre>
+ * jshell> NDArray array = manager.arange(5f);
+ * jshell> NDArrays.logicalXor(array.lt(1), array.gt(3));
+ * ND: (5) cpu() boolean
+ * [ true, false, false, false, true]
+ * </pre>
+ *
+ * @param a the {@link NDArray} to operate on
+ * @param b the {@link NDArray} to operate on
+ * @return the boolean {@link NDArray} of the logical XOR operation applied to the elements of
+ * the {@link NDArray} a and {@link NDArray} b
+ */
+ public static NDArray logicalXor(NDArray a, NDArray b) {
+ return a.logicalXor(b);
+ }
+
+ /**
+ * Returns the element-wise inverse Gauss error function of the input {@link NDArray}.
+ *
+ * <p>Examples
+ *
+ * <pre>
+ * jshell> NDArray array = manager.create(new float[] {0f, 0.5f, -1f});
+ * jshell> NDArrays.erfinv(array);
+ * ND: (3) cpu() float32
+ * [0., 0.4769, -inf]
+ * </pre>
+ *
+ * @param input the input {@link NDArray}
+ * @return the inverse Gauss error function of the input, element-wise
+ */
+ public static NDArray erfinv(NDArray input) {
+ return input.erfinv();
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java
new file mode 100644
index 0000000..686719f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDFormat.java
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Locale;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.util.Utils;
+
+/** A helper for printing an {@link NDArray}. */
+public abstract class NDFormat {
+
+ private static final int PRECISION = 8;
+ private static final String LF = System.getProperty("line.separator");
+ private static final Pattern PATTERN = Pattern.compile("\\s*\\d\\.(\\d*?)0*e[+-](\\d+)");
+
+ /**
+ * Renders an {@link NDArray} as a pretty printable string.
+ *
+ * @param array the array to print
+ * @param maxSize the maximum elements to print out
+ * @param maxDepth the maximum depth to print out
+ * @param maxRows the maximum rows to print out
+ * @param maxColumns the maximum columns to print out
+ * @return the string representation of the array
+ */
+ public static String format(
+ NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) {
+ DataType type = array.getDataType();
+
+ // Checked in this order: UINT8, then BOOLEAN, then any integer type, otherwise float.
+ // Only the integer and float formatters inspect the array values up front.
+ NDFormat formatter =
+ type == DataType.UINT8
+ ? new HexFormat()
+ : type == DataType.BOOLEAN
+ ? new BooleanFormat()
+ : type.isInteger() ? new IntFormat(array) : new FloatFormat(array);
+
+ // dump() writes the header line followed by the (possibly truncated) contents.
+ return formatter.dump(array, maxSize, maxDepth, maxRows, maxColumns);
+ }
+
+ protected abstract CharSequence format(Number value); // renders one element as text
+
+ private String dump(NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) {
+ StringBuilder sb = new StringBuilder(1000); // receives the header line, then the contents
+ String name = array.getName();
+ if (name != null) {
+ sb.append(name).append(": ");
+ } else {
+ sb.append("ND: "); // generic label for unnamed arrays
+ }
+ sb.append(array.getShape()) // rest of the header: shape, device and data type
+ .append(' ')
+ .append(array.getDevice())
+ .append(' ')
+ .append(array.getDataType());
+ if (array.hasGradient()) {
+ sb.append(" hasGradient");
+ }
+ sb.append(LF);
+
+ long size = array.size();
+ long dimension = array.getShape().dimension();
+ if (size == 0) {
+ // corner case: the array holds no elements
+ sb.append("[]").append(LF);
+ } else if (dimension == 0) {
+ // scalar case: print the single value without brackets
+ sb.append(format(array.toArray()[0])).append(LF);
+ } else if (size > maxSize) {
+ sb.append("[ Exceed max print size ]"); // note: no line separator is appended here
+ } else if (dimension > maxDepth) {
+ sb.append("[ Exceed max print dimension ]"); // note: no line separator here either
+ } else {
+ dump(sb, array, 0, true, maxRows, maxColumns); // recursive dump ends its own lines
+ }
+ return sb.toString();
+ }
+
+ private void dump(
+ StringBuilder sb,
+ NDArray array,
+ int depth,
+ boolean first,
+ int maxRows,
+ int maxColumns) {
+ if (!first) {
+ Utils.pad(sb, ' ', depth); // presumably indents by depth spaces; skipped for a first child
+ }
+ sb.append('[');
+ Shape shape = array.getShape();
+ if (shape.dimension() == 1) {
+ append(sb, array.toArray(), maxColumns); // innermost level: values are printed inline
+ } else {
+ long len = shape.head(); // presumably the size of the leading dimension - TODO confirm
+ long limit = Math.min(len, maxRows);
+ for (int i = 0; i < limit; ++i) {
+ try (NDArray nd = array.get(i)) { // the sub-array is closed once it has been printed
+ dump(sb, nd, depth + 1, i == 0, maxRows, maxColumns);
+ }
+ }
+ long remaining = len - limit;
+ if (remaining > 0) {
+ Utils.pad(sb, ' ', depth + 1);
+ sb.append("... ").append(remaining).append(" more"); // rows cut off by maxRows
+ }
+ Utils.pad(sb, ' ', depth);
+ }
+ // last "]": only the outermost level omits the trailing comma
+ if (depth == 0) {
+ sb.append(']').append(LF);
+ } else {
+ sb.append("],").append(LF);
+ }
+ }
+
+ private void append(StringBuilder sb, Number[] values, int maxColumns) {
+ int total = values.length;
+ if (total == 0) {
+ return;
+ }
+ // The first element is always written, even if maxColumns is not positive.
+ sb.append(format(values[0]));
+ long shown = Math.min(total, maxColumns);
+ for (int col = 1; col < shown; col++) {
+ sb.append(", ").append(format(values[col]));
+ }
+ long hidden = total - shown;
+ if (hidden > 0) {
+ sb.append(", ... ").append(hidden).append(" more");
+ }
+ }
+
+ private static final class FloatFormat extends NDFormat {
+
+ private boolean exponential;
+ private int precision;
+ private int totalLength;
+
+ public FloatFormat(NDArray array) {
+ Number[] values = array.toArray();
+ int maxIntPartLen = 0;
+ int maxFractionLen = 0;
+ int expFractionLen = 0;
+ int maxExpSize = 2; // NOTE(review): updated in the loop but never read afterwards
+ boolean sign = false;
+
+ double max = 0;
+ double min = Double.MAX_VALUE;
+ for (Number n : values) {
+ double v = n.doubleValue();
+ if (v < 0) {
+ sign = true;
+ }
+
+ if (!Double.isFinite(v)) { // NaN and infinities only reserve room for nan / inf / -inf
+ int intPartLen = v < 0 ? 4 : 3;
+ if (totalLength < intPartLen) {
+ totalLength = intPartLen;
+ }
+ continue;
+ }
+ double abs = Math.abs(v);
+ String str = String.format(Locale.ENGLISH, "%16e", abs); // e.g. 1.5 -> "1.500000e+00"
+ Matcher m = PATTERN.matcher(str);
+ if (!m.matches()) {
+ throw new AssertionError("Invalid decimal value: " + str);
+ }
+ int fractionLen = m.group(1).length(); // mantissa fraction digits, trailing zeros dropped
+ if (expFractionLen < fractionLen) {
+ expFractionLen = fractionLen;
+ }
+ int expSize = m.group(2).length();
+ if (expSize > maxExpSize) {
+ maxExpSize = expSize;
+ }
+
+ if (abs >= 1) {
+ int intPartLen = (int) Math.log10(abs) + 1; // number of integer digits, without sign
+ int fullFractionLen = fractionLen + 1 - intPartLen; // digits only: sign is added below
+ if (v < 0) {
+ ++intPartLen; // room for the minus sign; must not reduce the fraction length
+ }
+ if (intPartLen > maxIntPartLen) {
+ maxIntPartLen = intPartLen;
+ }
+ if (maxFractionLen < fullFractionLen) {
+ maxFractionLen = fullFractionLen;
+ }
+ } else {
+ int intPartLen = v < 0 ? 2 : 1; // a single leading 0, plus the sign if negative
+ if (intPartLen > maxIntPartLen) {
+ maxIntPartLen = intPartLen;
+ }
+
+ int fullFractionLen = fractionLen + Integer.parseInt(m.group(2)); // exponent adds zeros
+ if (maxFractionLen < fullFractionLen) {
+ maxFractionLen = fullFractionLen;
+ }
+ }
+
+ if (abs > max) {
+ max = abs;
+ }
+ if (abs < min && abs > 0) {
+ min = abs;
+ }
+ }
+ double ratio = max / min; // spread between the largest and smallest non-zero magnitude
+ if (max > 1.e8 || min < 0.0001 || ratio > 1000.) {
+ exponential = true; // very large, very small or widely spread values
+ precision = Math.min(PRECISION, expFractionLen);
+ totalLength = precision + 4;
+ if (sign) {
+ ++totalLength;
+ }
+ } else {
+ precision = Math.min(4, maxFractionLen); // at most four fraction digits in plain notation
+ int len = maxIntPartLen + precision + 1;
+ if (totalLength < len) {
+ totalLength = len;
+ }
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public CharSequence format(Number value) {
+ double d = value.doubleValue();
+ if (Double.isNaN(d)) { // non-finite values are right-aligned within totalLength
+ return String.format(Locale.ENGLISH, "%" + totalLength + "s", "nan");
+ } else if (Double.isInfinite(d)) {
+ if (d > 0) {
+ return String.format(Locale.ENGLISH, "%" + totalLength + "s", "inf");
+ } else {
+ return String.format(Locale.ENGLISH, "%" + totalLength + "s", "-inf");
+ }
+ }
+ if (exponential) {
+ precision = Math.max(PRECISION, precision); // NOTE(review): overwrites the field
+ return String.format(Locale.ENGLISH, "% ." + precision + "e", value.doubleValue());
+ }
+ if (precision == 0) { // no fraction digits: print the integer part and a bare '.'
+ String fmt = "%" + (totalLength - 1) + '.' + precision + "f.";
+ return String.format(Locale.ENGLISH, fmt, value.doubleValue());
+ }
+
+ String fmt = "%" + totalLength + '.' + precision + 'f';
+ String ret = String.format(Locale.ENGLISH, fmt, value.doubleValue());
+ // Replace trailing zeros with space so the column width stays the same
+ char[] chars = ret.toCharArray();
+ for (int i = chars.length - 1; i >= 0; --i) {
+ if (chars[i] == '0') {
+ chars[i] = ' ';
+ } else {
+ break; // stop at the first character that is not a zero
+ }
+ }
+ return new String(chars);
+ }
+ }
+
+ private static final class HexFormat extends NDFormat { // formatter chosen for UINT8 arrays
+
+ /** {@inheritDoc} */
+ @Override
+ public CharSequence format(Number value) {
+ return String.format(Locale.ENGLISH, "0x%02X", value.byteValue()); // 0x + two hex digits
+ }
+ }
+
+ private static final class IntFormat extends NDFormat {
+
+ private boolean exponential;
+ private int precision;
+ private int totalLength;
+
+ public IntFormat(NDArray array) {
+ Number[] values = array.toArray();
+ // scalar case: a single value needs no column alignment
+ if (values.length == 1) {
+ totalLength = 1;
+ return;
+ }
+ long max = 0; // largest magnitude of any value
+ long negativeMax = 0; // largest magnitude among the negative values only
+ for (Number n : values) {
+ long v = n.longValue();
+ long abs = Math.abs(v);
+ if (v < 0 && abs > negativeMax) {
+ negativeMax = abs;
+ }
+ if (abs > max) {
+ max = abs;
+ }
+ }
+
+ if (max >= 1.e8) { // nine or more digits: switch to exponential notation
+ exponential = true;
+ precision = Math.min(PRECISION, (int) Math.log10(max) + 1);
+ } else {
+ int size = (max != 0) ? (int) Math.log10(max) + 1 : 1; // digit count of the widest value
+ int negativeSize = (negativeMax != 0) ? (int) Math.log10(negativeMax) + 2 : 2; // 2 if none
+ totalLength = Math.max(size, negativeSize); // hence never below 2 on this path
+ }
+ }
+
+ /** {@inheritDoc} Uses exponential notation when the constructor selected it. */
+ @Override
+ public CharSequence format(Number value) {
+ if (exponential) { // double, not float: float keeps only about 7 significant digits
+ return String.format(Locale.ENGLISH, "% ." + precision + "e", value.doubleValue());
+ }
+ return String.format(Locale.ENGLISH, "%" + totalLength + "d", value.longValue());
+ }
+ }
+
+ private static final class BooleanFormat extends NDFormat { // chosen for BOOLEAN arrays
+
+ /** {@inheritDoc} */
+ @Override
+ public CharSequence format(Number value) {
+ return value.byteValue() != 0 ? " true" : "false"; // both literals are five chars wide
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java
new file mode 100644
index 0000000..b467a3a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDList.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/**
+ * An {@code NDList} represents a sequence of {@link NDArray}s with names.
+ *
+ * <p>Each {@link NDArray} in this list can optionally have a name. You can use the name to look up
+ * an NDArray in the NDList.
+ *
+ * @see NDArray
+ */
+public class NDList extends ArrayList<NDArray> implements AutoCloseable {
+ private static final long serialVersionUID = 1L;
+
+ /** Constructs an empty {@code NDList} with the default {@link ArrayList} capacity. */
+ public NDList() {}
+
+ /**
+ * Constructs an empty NDList with the specified initial capacity.
+ *
+ * @param initialCapacity the initial capacity of the list
+ * @throws IllegalArgumentException if the specified initial capacity is negative
+ */
+ public NDList(int initialCapacity) {
+ super(initialCapacity); // the capacity check is performed by ArrayList(int)
+ }
+
+ /**
+ * Constructs and initiates an NDList with the specified {@link NDArray}s.
+ *
+ * @param arrays the {@link NDArray}s
+ */
+ public NDList(NDArray... arrays) {
+ super(Arrays.asList(arrays));
+ }
+
+ /**
+ * Constructs and initiates an NDList with the {@link NDArray}s of the specified collection.
+ *
+ * @param other the collection whose {@link NDArray}s are placed into this list
+ */
+ public NDList(Collection<NDArray> other) {
+ super(other);
+ }
+
+ /**
+ * Decodes NDList from byte array.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param byteArray byte array to load from
+ * @return the decoded {@code NDList}
+ */
+ public static NDList decode(MxResource parent, byte[] byteArray) {
+ return decode(parent, new ByteArrayInputStream(byteArray)); // reuses the stream decoder
+ }
+
+ /**
+ * Decodes NDList from {@link InputStream}. The stream is closed when this method returns.
+ *
+ * @param parent {@link MxResource} assigned to {@link NDArray}
+ * @param is input stream contains the ndlist information
+ * @return the decoded {@code NDList}
+ */
+ public static NDList decode(MxResource parent, InputStream is) {
+ try (DataInputStream dis = new DataInputStream(is)) {
+ int size = dis.readInt(); // element count written first by encode()
+ if (size < 0) {
+ throw new IllegalArgumentException("Invalid NDList size: " + size);
+ }
+ NDList list = new NDList();
+ for (int i = 0; i < size; i++) {
+ list.add(i, NDArray.decode(parent, dis)); // each NDArray is read from the same stream
+ }
+ return list;
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Malformed data", e); // I/O failures are rethrown unchecked
+ }
+ }
+
+ /**
+ * Removes the first {@link NDArray} with the specified name from this NDList if it is present.
+ *
+ * <p>If this list does not contain an NDArray with that name, it is unchanged. More formally,
+ * removes the element with the lowest index {@code i} such that {@code
+ * name.equals(get(i).getName())} (if such an element exists).
+ *
+ * @param name the name of the NDArray to be removed from this NDList, if present
+ * @return the element that was removed, or {@code null} if no NDArray has the specified name
+ */
+ public NDArray remove(String name) {
+ int index = 0;
+ for (NDArray array : this) {
+ if (name.equals(array.getName())) {
+ remove(index); // safe during iteration only because the method returns right away
+ return array;
+ }
+ ++index;
+ }
+ return null;
+ }
+
+ /**
+ * Checks whether any {@link NDArray} in this NDList carries the specified name.
+ *
+ * <p>The NDArrays are tested in list order and the search stops at the first match. NDArrays
+ * whose name is {@code null} never match.
+ *
+ * @param name the name of the NDArray to look for in this NDList
+ * @return {@code true} if this list contains an NDArray with the specified name
+ * @throws NullPointerException if {@code name} is {@code null} and this list is not empty
+ */
+ public boolean contains(String name) {
+ // name is the equals() receiver, so an NDArray without a name simply compares unequal.
+ return stream().anyMatch(candidate -> name.equals(candidate.getName()));
+ }
+
+ /**
+ * Returns the first NDArray of the NDList.
+ *
+ * @return the head NDArray
+ * @throws IndexOutOfBoundsException if the NDList is empty, since {@code get(0)} is then out
+ * of range
+ */
+ public NDArray head() {
+ return get(0);
+ }
+
+ /**
+ * Returns the sole element of this NDList, failing if the size is not exactly one.
+ *
+ * @return the head NDArray
+ * @throws IndexOutOfBoundsException if the list does not contain exactly one element
+ */
+ public NDArray singletonOrThrow() {
+ int count = size();
+ if (count == 1) {
+ return get(0);
+ }
+ throw new IndexOutOfBoundsException(
+ "Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was "
+ + count);
+ }
+
+ /**
+ * Adds every NDArray of the specified NDList to the end of this NDList, one at a time and in
+ * the iteration order of the specified NDList.
+ *
+ * <p>This overload returns the list itself so that calls can be chained.
+ *
+ * @param other the NDList containing NDArray to be added to this list
+ * @return this NDList after the addition
+ */
+ public NDList addAll(NDList other) {
+ other.forEach(this::add);
+ return this;
+ }
+
+ /**
+ * Returns a new NDList holding the portion of this NDList between the specified fromIndex,
+ * inclusive, and the end. The list is a copy, not a view; the NDArrays themselves are shared.
+ *
+ * @param fromIndex the start index (inclusive)
+ * @return a new NDList with the selected portion of this NDList
+ */
+ public NDList subNDList(int fromIndex) {
+ return new NDList(subList(fromIndex, size())); // the constructor copies the sub-list view
+ }
+
+ /**
+ * Converts all the {@code NDArray} in {@code NDList} to a different {@link Device}.
+ *
+ * @param device the {@link Device} to be set
+ * @param copy set {@code true} if you want to return a copy of the underlying NDArray
+ * @return this {@code NDList} if {@code copy} is {@code false} and every NDArray is already on
+ * {@code device}; otherwise a new {@code NDList} with the NDArrays on the specified device
+ */
+ public NDList toDevice(Device device, boolean copy) {
+ // Compare devices by value, not identity: an equal Device instance obtained elsewhere must
+ // still count as "already on device". Objects.equals also tolerates a null device.
+ if (!copy && stream().allMatch(a -> java.util.Objects.equals(a.getDevice(), device))) {
+ return this;
+ }
+ NDList newNDList = new NDList(size());
+ forEach(a -> newNDList.add(a.toDevice(device, copy)));
+ return newNDList;
+ }
+
+ /**
+ * Encodes the NDList to byte array: the element count followed by each encoded NDArray.
+ *
+ * @return the byte array
+ */
+ public byte[] encode() {
+ try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
+ DataOutputStream dos = new DataOutputStream(baos);
+ dos.writeInt(size()); // read back first by decode()
+ for (NDArray nd : this) {
+ dos.write(nd.encode());
+ }
+ dos.flush(); // make sure everything has reached baos before it is copied out
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new AssertionError("NDList is not writable", e);
+ }
+ }
+
+ /**
+ * Gets the shapes of all NDArrays in the {@code NDList}, in list order.
+ *
+ * @return shapes in {@code NDList}
+ */
+ public Shape[] getShapes() {
+ return stream().map(NDArray::getShape).toArray(Shape[]::new);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ forEach(NDArray::close); // close every contained NDArray first
+ clear(); // then drop the references, leaving the list empty
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ // A header line with the element count, then one "index name: shape dataType" line each.
+ StringBuilder text = new StringBuilder(200);
+ text.append("NDList size: ").append(size()).append('\n');
+ for (int position = 0; position < size(); position++) {
+ NDArray element = get(position);
+ String label = element.getName();
+ text.append(position).append(' ');
+ // Unnamed arrays contribute nothing between the index and the colon.
+ text.append(label == null ? "" : label);
+ text.append(": ");
+ text.append(element.getShape());
+ text.append(' ');
+ text.append(element.getDataType());
+ text.append('\n');
+ }
+ return text.toString();
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java
new file mode 100644
index 0000000..34d37a5
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDSerializer.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+
+ /** A utility class that contains the encoding and decoding logic for {@link NDArray}. */
+public final class NDSerializer {
+
+ static final int BUFFER_SIZE = 81920;
+ static final String MAGIC_NUMBER = "NDAR";
+ static final int VERSION = 2;
+
+ private NDSerializer() {}
+
+ /**
+ * Allocates a new direct byte buffer that uses the platform's native byte order.
+ *
+ * @param capacity the new buffer's capacity, in bytes
+ * @return the new byte buffer
+ */
+ public static ByteBuffer allocateDirect(int capacity) {
+ return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
+ }
+
+ /**
+ * Encodes {@link NDArray} to byte array.
+ *
+ * @param array the input {@link NDArray}
+ * @return byte array
+ */
+ static byte[] encode(NDArray array) {
+ try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
+ DataOutputStream dos = new DataOutputStream(baos);
+ // magic string for version identification
+ dos.writeUTF(MAGIC_NUMBER);
+ dos.writeInt(VERSION);
+ String name = array.getName();
+ if (name == null) {
+ dos.write(0); // flag byte 0: no name follows
+ } else {
+ dos.write(1); // flag byte 1: the name follows as a UTF string
+ dos.writeUTF(name);
+ }
+ dos.writeUTF(array.getSparseFormat().name()); // skipped again by decode()
+ dos.writeUTF(array.getDataType().name());
+
+ Shape shape = array.getShape();
+ dos.write(shape.getEncoded());
+
+ ByteBuffer bb = array.toByteBuffer();
+ int length = bb.remaining();
+ dos.writeInt(length); // length prefix for the raw data that follows
+
+ if (length > 0) {
+ if (length > BUFFER_SIZE) {
+ byte[] buf = new byte[BUFFER_SIZE]; // copy large payloads in fixed-size chunks
+ while (length > BUFFER_SIZE) {
+ bb.get(buf);
+ dos.write(buf);
+ length = bb.remaining();
+ }
+ }
+
+ byte[] buf = new byte[length]; // whatever is left, at most BUFFER_SIZE bytes
+ bb.get(buf);
+ dos.write(buf);
+ }
+ dos.flush();
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new AssertionError("This should never happen", e);
+ }
+ }
+
+ /**
+ * Decodes {@link NDArray} through {@link DataInputStream}.
+ *
+ * @param parent the parent MxResource object which create the returned object
+ * @param is input stream data to load from
+ * @return {@link NDArray}
+ * @throws IOException data is not readable
+ * @throws IllegalArgumentException if the magic string, version or data length is invalid
+ */
+ public static NDArray decode(MxResource parent, InputStream is) throws IOException {
+ DataInputStream dis =
+ is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);
+
+ if (!MAGIC_NUMBER.equals(dis.readUTF())) {
+ throw new IllegalArgumentException("Malformed NDArray data");
+ }
+
+ // NDArray encode version
+ int version = dis.readInt();
+ if (version < 1 || version > VERSION) {
+ throw new IllegalArgumentException("Unexpected NDArray encode version " + version);
+ }
+
+ String name = null;
+ if (version > 1) { // the optional name is only read for versions above 1
+ byte flag = dis.readByte();
+ if (flag == 1) {
+ name = dis.readUTF();
+ }
+ }
+
+ dis.readUTF(); // ignore SparseFormat
+
+ // DataType, stored as a UTF string holding its name
+ DataType dataType = DataType.valueOf(dis.readUTF());
+
+ // Shape
+ Shape shape = Shape.decode(dis);
+
+ // Data: a length prefix followed by that many raw bytes
+ int length = dis.readInt();
+ if (length < 0) {
+ throw new IllegalArgumentException("Invalid NDArray data length: " + length);
+ }
+ ByteBuffer data = allocateDirect(length);
+
+ if (length > 0) {
+ // Copy in BUFFER_SIZE chunks; the final chunk holds the remaining 1..BUFFER_SIZE bytes.
+ byte[] buf = new byte[BUFFER_SIZE];
+ while (length > BUFFER_SIZE) {
+ dis.readFully(buf);
+ data.put(buf);
+ length -= BUFFER_SIZE;
+ }
+ dis.readFully(buf, 0, length);
+ data.put(buf, 0, length);
+ data.rewind();
+ }
+ NDArray array = NDArray.create(parent, dataType.asDataType(data), shape, dataType);
+ array.setName(name);
+ return array;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java
new file mode 100644
index 0000000..8d348e4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexAll.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.dim;
+
+ /** An {@code NDIndexElement} to return all values in a particular dimension. */
+ public class NDIndexAll implements NDIndexElement {
+
+ /** {@inheritDoc} */
+ @Override
+ public int getRank() {
+ return 1; // always reports a rank of one
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java
new file mode 100644
index 0000000..1468414
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexBooleans.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.dim;
+
+import org.apache.mxnet.ndarray.NDArray;
+
+ /** An {@code NDIndexElement} to return values based on a mask binary NDArray. */
+ public class NDIndexBooleans implements NDIndexElement {
+
+ private final NDArray index;
+
+ /**
+ * Constructs a {@code NDIndexBooleans} instance with specified mask binary NDArray.
+ *
+ * @param index the mask binary {@code NDArray}
+ */
+ public NDIndexBooleans(NDArray index) {
+ this.index = index;
+ }
+
+ /**
+ * Returns the mask binary {@code NDArray}.
+ *
+ * @return the mask binary {@code NDArray}
+ */
+ public NDArray getIndex() {
+ return index;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int getRank() {
+ return index.getShape().dimension(); // the rank is the mask's own number of dimensions
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java
new file mode 100644
index 0000000..4f89fbb
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexElement.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.dim;
+
+ /** An index for particular dimensions created by NDIndex; its rank tells how many it spans. */
+ public interface NDIndexElement {
+
+ /**
+ * Returns the number of dimensions occupied by this index element.
+ *
+ * @return the number of dimensions occupied by this index element
+ */
+ int getRank();
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java
new file mode 100644
index 0000000..e0713ac
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexFixed.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim;
+
+ /** An NDIndexElement that returns only a specific value in the corresponding dimension. */
+ public class NDIndexFixed implements NDIndexElement {
+
+ private final long index;
+
+ /**
+ * Constructs a {@code NDIndexFixed} instance with the specified index.
+ *
+ * @param index the index to select in the corresponding dimension
+ */
+ public NDIndexFixed(long index) {
+ this.index = index;
+ }
+
+ /**
+ * Returns the fixed index.
+ *
+ * @return the index passed to the constructor
+ */
+ public long getIndex() {
+ return index;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int getRank() {
+ return 1;
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java
new file mode 100644
index 0000000..f651f01
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexPick.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim;
+
+import org.apache.mxnet.ndarray.NDArray;
+
+ /** An {@link NDIndexElement} that gets elements by index in the specified axis. */
+ public class NDIndexPick implements NDIndexElement {
+
+ private final NDArray indices;
+
+ /**
+ * Constructs a pick.
+ *
+ * @param indices the indices to pick
+ */
+ public NDIndexPick(NDArray indices) {
+ this.indices = indices;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int getRank() {
+ return 1;
+ }
+
+ /**
+ * Returns the indices to pick.
+ *
+ * @return the indices to pick
+ */
+ public NDArray getIndices() {
+ return indices;
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java
new file mode 100644
index 0000000..e87784c
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/NDIndexSlice.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim;
+
+/** An NDIndexElement that returns a range of values in the specified dimension. */
+public class NDIndexSlice implements NDIndexElement {
+
+    private final Long min;
+    private final Long max;
+    private final Long step;
+
+    /**
+     * Constructs a {@code NDIndexSlice} instance with specified range and step.
+     *
+     * @param min the start of the range, or {@code null} if unspecified
+     * @param max the end of the range, or {@code null} if unspecified
+     * @param step the step between each slice, or {@code null} if unspecified
+     * @throws IllegalArgumentException Thrown if the step is zero
+     */
+    public NDIndexSlice(Long min, Long max, Long step) {
+        if (step != null && step == 0) {
+            throw new IllegalArgumentException("The step can not be zero");
+        }
+        this.min = min;
+        this.max = max;
+        this.step = step;
+    }
+
+    /**
+     * Returns the start of the range.
+     *
+     * @return the start of the range, or {@code null} if none was given
+     */
+    public Long getMin() {
+        return min;
+    }
+
+    /**
+     * Returns the end of the range.
+     *
+     * @return the end of the range, or {@code null} if none was given
+     */
+    public Long getMax() {
+        return max;
+    }
+
+    /**
+     * Returns the step between each slice.
+     *
+     * @return the step between each slice, or {@code null} if none was given
+     */
+    public Long getStep() {
+        return step;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int getRank() {
+        return 1;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java
new file mode 100644
index 0000000..632badf
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullPick.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim.full;
+
+import java.util.Optional;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.dim.NDIndexAll;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.NDIndexPick;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** A simplified representation of a pick-based {@link NDIndex}. */
+public final class NDIndexFullPick {
+
+    private final NDArray indices;
+    private final int axis;
+
+    /**
+     * Constructs a new {@link NDIndexFullPick}.
+     *
+     * @param indices the indices to pick
+     * @param axis the axis to pick at
+     */
+    private NDIndexFullPick(NDArray indices, int axis) {
+        this.indices = indices;
+        this.axis = axis;
+    }
+
+    /**
+     * Returns (if possible) the {@link NDIndexFullPick} representation of an {@link NDIndex}.
+     *
+     * @param index the index to represent
+     * @param target the shape of the array to index (not consulted by this method)
+     * @return the full pick representation or nothing if it can't represent the index
+     * @throws UnsupportedOperationException if the index contains more than one pick
+     */
+    public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
+        NDIndexFullPick result = null;
+        int currentAxis = 0;
+        for (NDIndexElement element : index.getIndices()) {
+            if (element instanceof NDIndexAll) {
+                currentAxis++;
+                continue;
+            }
+            if (!(element instanceof NDIndexPick)) {
+                // Invalid dim for fullPick
+                return Optional.empty();
+            }
+            if (result != null) {
+                throw new UnsupportedOperationException(
+                        "Only one pick per get is currently supported");
+            }
+            result = new NDIndexFullPick(((NDIndexPick) element).getIndices(), currentAxis);
+        }
+        return Optional.ofNullable(result);
+    }
+
+    /**
+     * Returns the indices to pick.
+     *
+     * @return the indices to pick
+     */
+    public NDArray getIndices() {
+        return indices;
+    }
+
+    /**
+     * Returns the axis to pick.
+     *
+     * @return the number of full-axis elements that preceded the pick in the index
+     */
+    public int getAxis() {
+        return axis;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java
new file mode 100644
index 0000000..9b8d9f0
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/NDIndexFullSlice.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.dim.full;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import org.apache.mxnet.ndarray.dim.NDIndexAll;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.NDIndexFixed;
+import org.apache.mxnet.ndarray.dim.NDIndexSlice;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** An index as a slice on all dimensions where some dimensions can be squeezed. */
+public final class NDIndexFullSlice {
+ private long[] min;
+ private long[] max;
+ private long[] step;
+ private int[] toSqueeze;
+ private Shape shape;
+ private Shape squeezedShape;
+
+ /**
+ * Constructs a {@link NDIndexFullSlice}.
+ *
+ * @param min the min for each axis
+ * @param max the max for each axis
+ * @param step the step for each axis
+ * @param toSqueeze the axes to squeeze after slicing
+ * @param shape the result shape (without squeezing)
+ * @param squeezedShape the result shape (with squeezing)
+ */
+ private NDIndexFullSlice(
+ long[] min,
+ long[] max,
+ long[] step,
+ int[] toSqueeze,
+ Shape shape,
+ Shape squeezedShape) {
+ this.min = min;
+ this.max = max;
+ this.step = step;
+ this.toSqueeze = toSqueeze;
+ this.shape = shape;
+ this.squeezedShape = squeezedShape;
+ }
+
+    /**
+     * Returns (if possible) the {@link NDIndexFullSlice} representation of an {@link NDIndex}.
+     *
+     * <p>Only indices made up solely of {@link NDIndexAll}, {@link NDIndexFixed} and {@link
+     * NDIndexSlice} elements can be represented. Axes of {@code target} that the index does not
+     * mention are filled with full slices: at the position of the ellipsis when the index has
+     * one, otherwise after the last explicit element.
+     *
+     * @param index the index to represent
+     * @param target the shape of the array to index
+     * @return the full slice representation or nothing if it can't represent the index
+     * @throws IllegalArgumentException if the index has more dimensions than {@code target}
+     */
+    public static Optional<NDIndexFullSlice> fromIndex(NDIndex index, Shape target) {
+        boolean representable =
+                index.stream()
+                        .allMatch(
+                                el ->
+                                        el instanceof NDIndexAll
+                                                || el instanceof NDIndexFixed
+                                                || el instanceof NDIndexSlice);
+        if (!representable) {
+            return Optional.empty();
+        }
+
+        int indexRank = index.getRank();
+        int targetRank = target.dimension();
+        if (indexRank > targetRank) {
+            throw new IllegalArgumentException(
+                    "The index has too many dimensions - "
+                            + indexRank
+                            + " dimensions for array with "
+                            + targetRank
+                            + " dimensions");
+        }
+
+        // Per-axis slice description covering every axis of the target
+        long[] min = new long[targetRank];
+        long[] max = new long[targetRank];
+        long[] step = new long[targetRank];
+        long[] shape = new long[targetRank];
+        List<Integer> toSqueeze = new ArrayList<>(targetRank);
+        List<Long> squeezedShape = new ArrayList<>(targetRank);
+
+        // Number of explicit elements placed before the implicit full slices. Without an
+        // ellipsis (-1), or with a trailing one, every explicit element comes first.
+        int ellipsisIndex = index.getEllipsisIndex();
+        int leading = ellipsisIndex == -1 ? indexRank : ellipsisIndex;
+        int padded = targetRank - indexRank;
+
+        for (int axis = 0; axis < targetRank; axis++) {
+            if (axis >= leading && axis < leading + padded) {
+                // Axis covered by the ellipsis / implicit padding: take everything
+                padIndexAll(axis, target, min, max, step, shape, squeezedShape);
+            } else {
+                // Elements after the padding are shifted back to their position in the index
+                int elementPos = axis < leading ? axis : axis - padded;
+                addSliceInfo(
+                        index.get(elementPos),
+                        axis,
+                        target,
+                        min,
+                        max,
+                        step,
+                        toSqueeze,
+                        shape,
+                        squeezedShape);
+            }
+        }
+
+        int[] squeeze = toSqueeze.stream().mapToInt(Integer::intValue).toArray();
+        return Optional.of(
+                new NDIndexFullSlice(
+                        min, max, step, squeeze, new Shape(shape), new Shape(squeezedShape)));
+    }
+
+    /** Records the slice of axis {@code i} described by {@code ie}; empty ranges get extent 0. */
+    private static void addSliceInfo(
+            NDIndexElement ie,
+            int i,
+            Shape target,
+            long[] min,
+            long[] max,
+            long[] step,
+            List<Integer> toSqueeze,
+            long[] shape,
+            List<Long> squeezedShape) {
+        if (ie instanceof NDIndexFixed) {
+            long rawIndex = ((NDIndexFixed) ie).getIndex();
+            min[i] = rawIndex < 0 ? Math.floorMod(rawIndex, target.get(i)) : rawIndex;
+            max[i] = min[i] + 1;
+            step[i] = 1;
+            toSqueeze.add(i);
+            shape[i] = 1;
+        } else if (ie instanceof NDIndexSlice) {
+            NDIndexSlice slice = (NDIndexSlice) ie;
+            long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L);
+            min[i] = rawMin < 0 ? Math.floorMod(rawMin, target.get(i)) : rawMin;
+            long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(i));
+            max[i] = rawMax < 0 ? Math.floorMod(rawMax, target.get(i)) : rawMax;
+            step[i] = Optional.ofNullable(slice.getStep()).orElse(1L);
+            shape[i] = Math.max(0L, (long) Math.ceil(((double) (max[i] - min[i])) / step[i]));
+            squeezedShape.add(shape[i]);
+        } else if (ie instanceof NDIndexAll) {
+            padIndexAll(i, target, min, max, step, shape, squeezedShape);
+        }
+    }
+
+ private static void padIndexAll(
+ int i,
+ Shape target,
+ long[] min,
+ long[] max,
+ long[] step,
+ long[] shape,
+ List<Long> squeezedShape) {
+ min[i] = 0;
+ max[i] = target.size(i);
+ step[i] = 1;
+ shape[i] = target.size(i);
+ squeezedShape.add(target.size(i));
+ }
+
+ /**
+ * Returns the slice min for each axis.
+ *
+ * @return the slice min for each axis (the internal array, not a copy)
+ */
+ public long[] getMin() {
+ return min;
+ }
+
+ /**
+ * Returns the slice max for each axis.
+ *
+ * @return the slice max for each axis (the internal array, not a copy)
+ */
+ public long[] getMax() {
+ return max;
+ }
+
+ /**
+ * Returns the slice step for each axis.
+ *
+ * @return the slice step for each axis (the internal array, not a copy)
+ */
+ public long[] getStep() {
+ return step;
+ }
+
+ /**
+ * Returns the squeeze array of axis.
+ *
+ * @return the axes to squeeze after slicing (the internal array, not a copy)
+ */
+ public int[] getToSqueeze() {
+ return toSqueeze;
+ }
+
+ /**
+ * Returns the slice shape without squeezing.
+ *
+ * @return the slice shape without squeezing
+ */
+ public Shape getShape() {
+ return shape;
+ }
+
+ /**
+ * Returns the slice shape with squeezing.
+ *
+ * @return the slice shape with squeezing
+ */
+ public Shape getSqueezedShape() {
+ return squeezedShape;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java
new file mode 100644
index 0000000..e796f52
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/full/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains whole-array slice and pick representations of an NDIndex. */
+package org.apache.mxnet.ndarray.dim.full;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java
new file mode 100644
index 0000000..9a5c2f8
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/dim/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the element types that make up an NDIndex. */
+package org.apache.mxnet.ndarray.dim;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java
new file mode 100644
index 0000000..f08ee67
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/NDIndex.java
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.ndarray.index;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Stream;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.dim.NDIndexAll;
+import org.apache.mxnet.ndarray.dim.NDIndexBooleans;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.NDIndexFixed;
+import org.apache.mxnet.ndarray.dim.NDIndexPick;
+import org.apache.mxnet.ndarray.dim.NDIndexSlice;
+import org.apache.mxnet.ndarray.types.DataType;
+
+/**
+ * The {@code NDIndex} allows you to specify a subset of an NDArray that can be used for fetching or
+ * updating.
+ *
+ * <p>It accepts a different index option for each dimension, given in the order of the dimensions.
+ * Each dimension has options corresponding to:
+ *
+ * <ul>
+ * <li>Return all values in the dimension - Use addAllDim (or ":" / "*" in an index string)
+ * <li>A single value in the dimension - Pass the value to addIndices with a negative index -i
+ * corresponding to [dimensionLength - i]
+ * <li>A range of values - Use addSliceDim
+ * </ul>
+ *
+ * <p>We recommend creating the NDIndex using {@link #NDIndex(String, Object...)}.
+ *
+ * @see #NDIndex(String, Object...)
+ */
+public class NDIndex {
+
+ /* Android regex requires escape } char as well */
+ private static final Pattern ITEM_PATTERN =
+ Pattern.compile(
+ "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})");
+
+ private int rank;
+ private List<NDIndexElement> indices;
+ private int ellipsisIndex;
+
+ /** Creates an empty {@link NDIndex} to append values to. */
+ public NDIndex() {
+ rank = 0;
+ indices = new ArrayList<>();
+ ellipsisIndex = -1;
+ }
+
+ /**
+ * Creates a {@link NDIndex} given the index values.
+ *
+ * <p>Here are some examples of the indices format.
+ *
+ * <pre>
+ * NDArray a = manager.ones(new Shape(5, 4, 3));
+ *
+ * // Gets a subsection of the NDArray in the first axis.
+ * assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
+ *
+ * // Gets a subsection of the NDArray indexing from the end (-i == length - i).
+ * assertEquals(a.get(new NDIndex("-1")).getShape(), new Shape(4, 3));
+ *
+ * // Gets everything in the first axis and a subsection in the second axis.
+ * // You can use either : or * to represent everything
+ * assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
+ * assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
+ *
+ * // Gets a range of values along the second axis that is inclusive on the bottom and exclusive on the top.
+ * assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
+ *
+ * // Excludes either the min or the max of the range to go all the way to the beginning or end.
+ * assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
+ * assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 3, 3));
+ *
+ * // Uses the value after the second colon in a slicing range, the step, to get every other result.
+ * assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
+ *
+ * // Uses a negative step to reverse along the dimension (TODO confirm backend support).
+ * assertEquals(a.get(new NDIndex("::-1")).getShape(), new Shape(5, 4, 3));
+ *
+ * // Uses a variable argument to the index
+ * // It can replace any number in any of these formats with {} and then the value of {}
+ * // is specified in an argument following the indices string.
+ * assertEquals(a.get(new NDIndex("{}, {}:{}", 0, 1, 3)).getShape(), new Shape(2, 3));
+ *
+ * // Uses ellipsis to insert many full slices
+ * assertEquals(a.get(new NDIndex("...")).getShape(), new Shape(5, 4, 3));
+ *
+ * // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
+ * assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
+ * </pre>
+ *
+ * @param indices a comma separated list of indices corresponding to either subsections,
+ * everything, or slices on a particular dimension
+ * @param args arguments to replace the variable "{}" in the indices string. Can be an integer,
+ * long, boolean {@link NDArray}, or integer {@link NDArray}.
+ * @see <a href="https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html">Numpy
+ * Indexing</a>
+ */
+ public NDIndex(String indices, Object... args) {
+ this();
+ addIndices(indices, args);
+ }
+
+ /**
+ * Creates an NDIndex with the given indices as specified values on the NDArray.
+ *
+ * @param indices the indices with each index corresponding to the dimensions and negative
+ * indices starting from the end
+ */
+ public NDIndex(long... indices) {
+ this();
+ addIndices(indices);
+ }
+
+    /**
+     * Creates an {@link NDIndex} that just has one slice in the given axis.
+     *
+     * @param axis the axis to slice; every axis before it is selected in full
+     * @param min the min of the slice
+     * @param max the max of the slice
+     * @return a new {@link NDIndex} with the given slice.
+     * @throws IllegalArgumentException if {@code axis} is negative
+     */
+    public static NDIndex sliceAxis(int axis, long min, long max) {
+        NDIndex ind = new NDIndex();
+        // addAllDim(int) rejects a negative axis instead of silently slicing axis 0
+        ind.addAllDim(axis);
+        ind.addSliceDim(min, max);
+        return ind;
+    }
+
+ /**
+ * Returns the number of dimensions specified in the Index.
+ *
+ * @return the number of dimensions specified in the Index
+ */
+ public int getRank() {
+ return rank;
+ }
+
+ /**
+ * Returns the index of the ellipsis.
+ *
+ * @return the index of the ellipsis within this index or -1 for none.
+ */
+ public int getEllipsisIndex() {
+ return ellipsisIndex;
+ }
+
+ /**
+ * Returns the index affecting the given dimension.
+ *
+ * @param dimension the affected dimension
+ * @return the index affecting the given dimension
+ */
+ public NDIndexElement get(int dimension) {
+ return indices.get(dimension);
+ }
+
+ /**
+ * Returns the indices.
+ *
+ * @return the indices
+ */
+ public List<NDIndexElement> getIndices() {
+ return indices;
+ }
+
+    /**
+     * Updates the NDIndex by appending indices to the array.
+     *
+     * @param indices the indices to add similar to {@link #NDIndex(String, Object...)}
+     * @param args arguments to replace the variable "{}" in the indices string. Can be an integer,
+     *     long, boolean {@link NDArray}, or integer {@link NDArray}.
+     * @return the updated {@link NDIndex}
+     * @throws IllegalArgumentException if the index string is invalid or arguments are left over
+     * @see #NDIndex(String, Object...)
+     */
+    public final NDIndex addIndices(String indices, Object... args) {
+        String[] indexItems = indices.split(",");
+        int argIndex = 0;
+        for (String indexItem : indexItems) {
+            if ("...".equals(indexItem.trim())) {
+                // make sure ellipsis appear only once
+                if (ellipsisIndex != -1) {
+                    throw new IllegalArgumentException(
+                            "an index can only have a single ellipsis (\"...\")");
+                }
+                // Position among all elements added so far, not just within this call
+                ellipsisIndex = this.indices.size();
+            } else {
+                argIndex = addIndexItem(indexItem, argIndex, args);
+                // The ellipsis itself occupies no dimension, so only real items count
+                rank++;
+            }
+        }
+        if (argIndex != args.length) {
+            throw new IllegalArgumentException("Incorrect number of index arguments");
+        }
+        return this;
+    }
+
+ /**
+ * Updates the NDIndex by appending indices as specified values on the NDArray.
+ *
+ * @param indices with each index corresponding to the dimensions and negative indices starting
+ * from the end
+ * @return the updated {@link NDIndex}
+ */
+ public final NDIndex addIndices(long... indices) {
+ rank += indices.length;
+ for (long i : indices) {
+ this.indices.add(new NDIndexFixed(i));
+ }
+ return this;
+ }
+
+ /**
+ * Updates the NDIndex by appending a boolean NDArray.
+ *
+ * <p>The NDArray should have a matching shape to the dimensions being fetched and will return
+ * where the values in NDIndex do not equal zero.
+ *
+ * @param index a boolean NDArray where all nonzero elements correspond to elements to return
+ * @return the updated {@link NDIndex}
+ */
+ public NDIndex addBooleanIndex(NDArray index) {
+ rank += index.getShape().dimension();
+ indices.add(new NDIndexBooleans(index));
+ return this;
+ }
+
+ /**
+ * Appends a new index to get all values in the dimension.
+ *
+ * @return the updated {@link NDIndex}
+ */
+ public NDIndex addAllDim() {
+ rank++;
+ indices.add(new NDIndexAll());
+ return this;
+ }
+
+ /**
+ * Appends multiple new index to get all values in the dimension.
+ *
+ * @param count how many axes of {@link NDIndexAll} to add.
+ * @return the updated {@link NDIndex}
+ * @throws IllegalArgumentException if count is negative
+ */
+ public NDIndex addAllDim(int count) {
+ if (count < 0) {
+ throw new IllegalArgumentException(
+ "The number of index dimensions to add can't be negative");
+ }
+ rank += count;
+ for (int i = 0; i < count; i++) {
+ indices.add(new NDIndexAll());
+ }
+ return this;
+ }
+
+ /**
+ * Appends a new index to slice the dimension and returns a range of values.
+ *
+ * @param min the minimum of the range
+ * @param max the maximum of the range
+ * @return the updated {@link NDIndex}
+ */
+ public NDIndex addSliceDim(long min, long max) {
+ rank++;
+ indices.add(new NDIndexSlice(min, max, null));
+ return this;
+ }
+
+    /**
+     * Appends a new index to slice the dimension and returns a range of values.
+     *
+     * @param min the minimum of the range
+     * @param max the maximum of the range
+     * @param step the step of the slice; zero is rejected with an IllegalArgumentException
+     * @return the updated {@link NDIndex}
+     */
+    public NDIndex addSliceDim(long min, long max, long step) {
+        indices.add(new NDIndexSlice(min, max, step));
+        rank++; // counted only once the element exists (a zero step throws above)
+        return this;
+    }
+
+ /**
+ * Appends a picking index that gets values by index in the axis.
+ *
+ * @param index the indices should be NDArray. For each element in the indices array, it acts
+ * like a fixed index returning an element of that shape. So, the final shape would be
+ * indices.getShape().addAll(target.getShape().slice(1)) (assuming it is the first index
+ * element).
+ * @return the updated {@link NDIndex}
+ */
+ public NDIndex addPickDim(NDArray index) {
+ rank++;
+ indices.add(new NDIndexPick(index));
+ return this;
+ }
+
+ /**
+ * Returns a stream of the NDIndexElements.
+ *
+ * @return a stream of the NDIndexElements
+ */
+ public Stream<NDIndexElement> stream() {
+ return indices.stream();
+ }
+
+    /**
+     * Parses one comma-separated item and appends the matching element to {@code indices}.
+     *
+     * @param indexItem the item text: "*", a number, "{}", or a "min:max:step" slice
+     * @param argIndex the position in {@code args} of the next unused argument
+     * @param args the values substituted for "{}" placeholders
+     * @return the position of the next unused argument after this item
+     * @throws IllegalArgumentException if the item, or the argument for a lone "{}", is invalid
+     */
+    private int addIndexItem(String indexItem, int argIndex, Object[] args) {
+        String trimmed = indexItem.trim();
+        Matcher m = ITEM_PATTERN.matcher(trimmed);
+        if (!m.matches()) {
+            throw new IllegalArgumentException("Invalid argument index: " + trimmed);
+        }
+        // "*" case
+        if (m.group(1) != null) {
+            indices.add(new NDIndexAll());
+            return argIndex;
+        }
+        // "number" number only case
+        String digit = m.group(7);
+        if (digit != null) {
+            if (!"{}".equals(digit)) {
+                indices.add(new NDIndexFixed(Long.parseLong(digit)));
+                return argIndex;
+            }
+            if (argIndex >= args.length) {
+                // Fail with the same exception type as every other argument-count problem
+                throw new IllegalArgumentException("Incorrect number of index arguments");
+            }
+            Object arg = args[argIndex];
+            if (arg instanceof Integer || arg instanceof Long) {
+                indices.add(new NDIndexFixed(((Number) arg).longValue()));
+                return argIndex + 1;
+            }
+            if (arg instanceof NDArray) {
+                NDArray array = (NDArray) arg;
+                if (array.getDataType() == DataType.BOOLEAN) {
+                    indices.add(new NDIndexBooleans(array));
+                    return argIndex + 1;
+                }
+                if (array.getDataType().isInteger()) {
+                    indices.add(new NDIndexPick(array));
+                    return argIndex + 1;
+                }
+            }
+            throw new IllegalArgumentException("Unknown argument: " + arg);
+        }
+
+        // Slice: regex groups 3, 4 and 6 hold min, max and step, consumed in that order
+        Long[] bounds = new Long[3];
+        int[] groups = {3, 4, 6};
+        for (int k = 0; k < groups.length; k++) {
+            String part = m.group(groups[k]);
+            if (part != null) {
+                bounds[k] = parseSliceItem(part, argIndex, args);
+                if ("{}".equals(part)) {
+                    argIndex++;
+                }
+            }
+        }
+        if (bounds[0] == null && bounds[1] == null && bounds[2] == null) {
+            indices.add(new NDIndexAll());
+        } else {
+            indices.add(new NDIndexSlice(bounds[0], bounds[1], bounds[2]));
+        }
+        return argIndex;
+    }
+
+    /** Resolves a slice bound: a literal number, or the "{}" argument at {@code argIndex}. */
+    private Long parseSliceItem(String sliceItem, int argIndex, Object... args) {
+        if (!"{}".equals(sliceItem)) {
+            return Long.parseLong(sliceItem);
+        }
+        if (argIndex >= args.length) {
+            throw new IllegalArgumentException("Incorrect number of index arguments");
+        }
+        if (args[argIndex] instanceof Integer || args[argIndex] instanceof Long) {
+            return ((Number) args[argIndex]).longValue();
+        }
+        throw new IllegalArgumentException("Unknown slice argument: " + args[argIndex]);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java
new file mode 100644
index 0000000..7cc862c
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/index/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the NDIndex class used to select a subset of an NDArray. */
+package org.apache.mxnet.ndarray.index;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java
new file mode 100644
index 0000000..b161896
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.ndarray;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java
new file mode 100644
index 0000000..bda4181
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/DataType.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.types;
+
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import org.apache.mxnet.ndarray.NDArray;
+
+/** An enum representing the underlying {@link NDArray}'s data type. */
+public enum DataType {
+ FLOAT32(Format.FLOATING, 4),
+ FLOAT64(Format.FLOATING, 8),
+ FLOAT16(Format.FLOATING, 2),
+ UINT8(Format.UINT, 1),
+ INT32(Format.INT, 4),
+ INT8(Format.INT, 1),
+ INT64(Format.INT, 8),
+ BOOLEAN(Format.BOOLEAN, 1),
+ UNKNOWN(Format.UNKNOWN, 0);
+ /** The general data type format categories. */
+ public enum Format {
+ FLOATING,
+ UINT,
+ INT,
+ BOOLEAN,
+ UNKNOWN
+ }
+
+ private Format format;
+ private int numOfBytes;
+
+ DataType(Format format, int numOfBytes) {
+ this.format = format;
+ this.numOfBytes = numOfBytes;
+ }
+
+ /**
+ * Returns the number of bytes for each element.
+ *
+ * @return the number of bytes for each element
+ */
+ public int getNumOfBytes() {
+ return numOfBytes;
+ }
+
+ /**
+ * Returns the format of the data type.
+ *
+ * @return the format of the data type
+ */
+ public Format getFormat() {
+ return format;
+ }
+
+ /**
+ * Checks whether it is a floating data type.
+ *
+ * @return whether it is a floating data type
+ */
+ public boolean isFloating() {
+ return format == Format.FLOATING;
+ }
+
+ /**
+ * Checks whether it is an integer data type.
+ *
+ * @return whether it is an integer type
+ */
+ public boolean isInteger() {
+ return format == Format.UINT || format == Format.INT;
+ }
+
+ /**
+ * Returns the data type to use for a data buffer.
+ *
+ * @param data the buffer to analyze
+ * @return the data type for the buffer
+ */
+ public static DataType fromBuffer(Buffer data) {
+ if (data instanceof FloatBuffer) {
+ return DataType.FLOAT32;
+ } else if (data instanceof DoubleBuffer) {
+ return DataType.FLOAT64;
+ } else if (data instanceof IntBuffer) {
+ return DataType.INT32;
+ } else if (data instanceof LongBuffer) {
+ return DataType.INT64;
+ } else if (data instanceof ByteBuffer) {
+ return DataType.INT8;
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported buffer type: " + data.getClass().getSimpleName());
+ }
+ }
+
+ /**
+ * Converts a {@link ByteBuffer} to a buffer for this data type.
+ *
+ * @param data the buffer to convert
+ * @return the converted buffer
+ */
+ public Buffer asDataType(ByteBuffer data) {
+ switch (this) {
+ case FLOAT32:
+ return data.asFloatBuffer();
+ case FLOAT64:
+ return data.asDoubleBuffer();
+ case INT32:
+ return data.asIntBuffer();
+ case INT64:
+ return data.asLongBuffer();
+ case UINT8:
+ case INT8:
+ case FLOAT16:
+ case UNKNOWN:
+ default:
+ return data;
+ }
+ }
+
+ /** Returns the constant name in lower case, independent of the default locale. */
+ @Override
+ public String toString() {
+ return name().toLowerCase(java.util.Locale.ROOT);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java
new file mode 100644
index 0000000..9602146
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/LayoutType.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.types;
+
+import java.util.stream.IntStream;
+import org.apache.mxnet.ndarray.NDArray;
+
+/**
+ * An enum to represent the meaning of a particular axis in an {@link NDArray}.
+ *
+ * <p>The options are:
+ *
+ * <ul>
+ * <li>{@link LayoutType#BATCH} - Different elements in a batch
+ * <li>{@link LayoutType#CHANNEL} - Each channel represents a different aspect of the data such as
+ * RGB showing different color channels.
+ * <li>{@link LayoutType#DEPTH} - The depth of a 3-D input
+ * <li>{@link LayoutType#HEIGHT} - The height of a multi-dimensional input, usually an image.
+ * <li>{@link LayoutType#WIDTH} - The width of a multi-dimensional input, usually an image.
+ * <li>{@link LayoutType#TIME} - The time within a sequence such as text or video.
+ * <li>{@link LayoutType#UNKNOWN} - An unknown or otherwise unrepresentable layout type.
+ * </ul>
+ */
+public enum LayoutType {
+ BATCH('N'),
+ CHANNEL('C'),
+ DEPTH('D'),
+ HEIGHT('H'),
+ WIDTH('W'),
+ TIME('T'),
+ UNKNOWN('?');
+
+ private char value;
+
+ LayoutType(char value) {
+ this.value = value;
+ }
+
+ /**
+ * Returns the character representation of the layout type.
+ *
+ * @return the character representation of the layout type
+ */
+ public char getValue() {
+ return value;
+ }
+
+ /**
+ * Converts the character to the matching layout type.
+ *
+ * @param value the character to convert
+ * @return the matching layout type
+ * @throws IllegalArgumentException thrown if the character does not match any layout type
+ */
+ public static LayoutType fromValue(char value) {
+ for (LayoutType type : LayoutType.values()) {
+ if (value == type.value) {
+ return type;
+ }
+ }
+ throw new IllegalArgumentException(
+ "The value does not match any layoutTypes. Use '?' for Unknown");
+ }
+
+ /**
+ * Converts each character to the matching layout type.
+ *
+ * @param layout the character string to convert
+ * @return the list of layout types for each character in the string
+ * @throws IllegalArgumentException thrown if the character does not match any layout type
+ */
+ public static LayoutType[] fromValue(String layout) {
+ return IntStream.range(0, layout.length())
+ .mapToObj(i -> fromValue(layout.charAt(i)))
+ .toArray(LayoutType[]::new);
+ }
+
+ /**
+ * Converts a layout type array to a string of the character representations.
+ *
+ * @param layouts the layout type to convert
+ * @return the string of the character representations
+ */
+ public static String toString(LayoutType[] layouts) {
+ StringBuilder sb = new StringBuilder(layouts.length);
+ for (LayoutType layout : layouts) {
+ sb.append(layout.getValue());
+ }
+ return sb.toString();
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java
new file mode 100644
index 0000000..4348e45
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/Shape.java
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.types;
+
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/** A class that presents the {@link NDArray}'s shape information. */
+public class Shape {
+
+ private long[] shape;
+ private LayoutType[] layout;
+
+ /**
+ * Constructs and initializes a {@code Shape} with specified dimension as {@code (long...
+ * shape)}.
+ *
+ * @param shape the dimensions of the shape
+ * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be
+ * less than -1. Also thrown if the shape and layout do not have equal sizes.
+ */
+ public Shape(long... shape) {
+ this(
+ shape,
+ Arrays.stream(shape).mapToObj(x -> LayoutType.UNKNOWN).toArray(LayoutType[]::new));
+ }
+
+ /**
+ * Constructs and initializes a {@code Shape} with specified dimension.
+ *
+ * @param shape the dimensions of the shape
+ * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be
+ * less than -1. Also thrown if the shape and layout do not have equal sizes.
+ */
+ public Shape(List<Long> shape) {
+ this(
+ shape.stream().mapToLong(l -> l).toArray(),
+ shape.stream().map(x -> LayoutType.UNKNOWN).toArray(LayoutType[]::new));
+ }
+
+ /**
+ * Constructs and initializes a {@code Shape} with specified shape and layout pairList.
+ *
+ * @param shape the dimensions and layout of the shape
+ * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be
+ * less than -1. Also thrown if the shape and layout do not have equal sizes.
+ */
+ public Shape(PairList<Long, LayoutType> shape) {
+ this(
+ shape.keys().stream().mapToLong(l -> l).toArray(),
+ shape.values().toArray(new LayoutType[shape.size()]));
+ }
+
+ /**
+ * Constructs and initializes a {@code Shape} with specified dimension and layout.
+ *
+ * @param shape the size of each axis of the shape
+ * @param layout the {@link LayoutType} of each axis in the shape
+ * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be
+ * less than -1. Also thrown for an invalid layout. Also thrown if the shape and layout do
+ * not have equal sizes.
+ */
+ public Shape(long[] shape, String layout) {
+ this(shape, LayoutType.fromValue(layout));
+ }
+
+ /**
+ * Constructs and initializes a {@code Shape} with specified dimension and layout.
+ *
+ * @param shape the size of each axis of the shape
+ * @param layout the {@link LayoutType} of each axis in the shape
+ * @throws IllegalArgumentException Thrown if any element in Shape is invalid. It should not be
+ * less than -1. Also thrown if the shape and layout do not have equal sizes.
+ */
+ public Shape(long[] shape, LayoutType[] layout) {
+ if (Arrays.stream(shape).anyMatch(s -> s < -1)) {
+ throw new IllegalArgumentException("The shape must be >= -1");
+ }
+ if (shape.length != layout.length) {
+ throw new IllegalArgumentException("The shape and layout must have the same length");
+ }
+ this.shape = shape;
+ this.layout = layout;
+ }
+
+ /**
+ * Returns a new shape altering the given dimension.
+ *
+ * @param shape the shape to update
+ * @param dimension the dimension to get the shape in
+ * @param value the value to set the dimension to
+ * @return a new shape with the update applied
+ */
+ public static Shape update(Shape shape, int dimension, long value) {
+ long[] newShape = shape.shape.clone();
+ newShape[dimension] = value;
+ return new Shape(newShape, shape.layout);
+ }
+
+ /**
+ * Returns the dimensions of the {@code Shape}.
+ *
+ * @return the dimensions of the {@code Shape}
+ */
+ public long[] getShape() {
+ return shape;
+ }
+
+ /**
+ * Returns the shape in the given dimension.
+ *
+ * @param dimension the dimension to get the shape in
+ * @return the shape in the given dimension
+ */
+ public long get(int dimension) {
+ return shape[dimension];
+ }
+
+ /**
+ * Returns the layout type in the given dimension.
+ *
+ * @param dimension the dimension to get the layout type in
+ * @return the layout type in the given dimension
+ */
+ public LayoutType getLayoutType(int dimension) {
+ return layout[dimension];
+ }
+
+ /**
+ * Returns the size of a specific dimension or several specific dimensions.
+ *
+ * @param dimensions the dimension or dimensions to find the size of
+ * @return the size of specific dimension(s) or -1 for indeterminate size
+ * @throws IllegalArgumentException thrown if passed an invalid dimension
+ */
+ public long size(int... dimensions) {
+ long total = 1;
+ for (long d : dimensions) {
+ if (d < 0 || d >= shape.length) {
+ throw new IllegalArgumentException("Invalid dimension " + d);
+ }
+ if (shape[Math.toIntExact(d)] == -1) {
+ return -1;
+ }
+ total *= shape[Math.toIntExact(d)];
+ }
+ return total;
+ }
+
+ /**
+ * Returns the total size.
+ *
+ * @return the total size or -1 for indeterminate size
+ */
+ public long size() {
+ long total = 1;
+ for (long v : shape) {
+ if (v == -1) {
+ return -1;
+ }
+ total *= v;
+ }
+ return total;
+ }
+
+ /**
+ * Returns the number of dimensions of this {@code Shape}.
+ *
+ * @return the number of dimensions of this {@code Shape}
+ */
+ public int dimension() {
+ return shape.length;
+ }
+
+ /**
+ * Returns the count of unknown values in this {@code Shape}.
+ *
+ * @return the number of unknown values (dimensions equal to -1) in this {@code Shape}
+ */
+ public long getUnknownValueCount() {
+ return Arrays.stream(shape).filter(s -> s == -1).count();
+ }
+
+ /**
+ * Creates a new {@code Shape} whose content is a slice of this shape.
+ *
+ * <p>The sub shape begins at the specified {@code beginIndex} and extends to the last
+ * dimension of this shape.
+ *
+ * @param beginIndex the beginning index, inclusive
+ * @return a new {@code Shape} whose content is a slice of this shape
+ */
+ public Shape slice(int beginIndex) {
+ return slice(beginIndex, shape.length);
+ }
+
+ /**
+ * Creates a new {@code Shape} whose content is a slice of this shape.
+ *
+ * <p>The sub shape begins at the specified {@code beginIndex} and extends to {@code endIndex -
+ * 1}.
+ *
+ * @param beginIndex the beginning index, inclusive
+ * @param endIndex the ending index, exclusive
+ * @return a new {@code Shape} whose content is a slice of this shape
+ */
+ public Shape slice(int beginIndex, int endIndex) {
+ int size = endIndex - beginIndex;
+ long[] out = new long[size];
+ System.arraycopy(shape, beginIndex, out, 0, size);
+ return new Shape(out);
+ }
+
+ /**
+ * Returns only the axes of the Shape whose layout types match the predicate.
+ *
+ * @param predicate the predicate to compare the axes of the Shape with
+ * @return a new filtered Shape
+ */
+ public Shape filterByLayoutType(Predicate<LayoutType> predicate) {
+ return new Shape(
+ new PairList<>(
+ this.stream()
+ .filter(pair -> predicate.test(pair.getValue()))
+ .collect(Collectors.toList())));
+ }
+
+ /**
+ * Returns a mapped shape.
+ *
+ * @param mapper the function to map each element of the Shape by
+ * @return a new mapped Shape
+ */
+ public Shape map(Function<Pair<Long, LayoutType>, Pair<Long, LayoutType>> mapper) {
+ return new Shape(new PairList<>(stream().map(mapper).collect(Collectors.toList())));
+ }
+
+ /**
+ * Returns a stream of the Shape.
+ *
+ * @return the stream of the Shape
+ */
+ public Stream<Pair<Long, LayoutType>> stream() {
+ return new PairList<>(
+ Arrays.stream(shape).boxed().collect(Collectors.toList()),
+ Arrays.asList(layout))
+ .stream();
+ }
+
+ /**
+ * Joins this shape with axes.
+ *
+ * @param axes the axes to join
+ * @return the joined {@code Shape}
+ */
+ public Shape add(long... axes) {
+ return this.addAll(new Shape(axes));
+ }
+
+ /**
+ * Joins this shape with specified {@code other} shape.
+ *
+ * @param other the shape to join
+ * @return the joined {@code Shape}
+ */
+ public Shape addAll(Shape other) {
+ return new Shape(
+ LongStream.concat(Arrays.stream(shape), Arrays.stream(other.shape)).toArray());
+ }
+
+ /**
+ * Returns the head index of the shape.
+ *
+ * @return the head index of the shape
+ * @throws IndexOutOfBoundsException Thrown if the shape is empty
+ */
+ public long head() {
+ // scalar case
+ if (shape.length == 0) {
+ throw new IndexOutOfBoundsException("can't get value from scalar shape.");
+ }
+ return shape[0];
+ }
+
+ /**
+ * Returns the tail index of the shape.
+ *
+ * @return the tail index of the shape
+ * @throws IndexOutOfBoundsException Thrown if the shape is empty
+ */
+ public long tail() {
+ // scalar case
+ if (shape.length == 0) {
+ throw new IndexOutOfBoundsException("can't get value from scalar shape.");
+ }
+ return shape[shape.length - 1];
+ }
+
+ /**
+ * Returns the number of trailing ones in the array shape.
+ *
+ * <p>For example, shape [10, 1, 1] returns 2 and the all-ones shape [1, 1] returns 2.
+ *
+ * @return the number of trailing ones in the shape; 0 for a scalar (empty) shape
+ */
+ public int getTrailingOnes() {
+ for (int i = 0; i < shape.length; i++) {
+ if (shape[shape.length - i - 1] != 1) {
+ return i;
+ }
+ }
+ return shape.length; // no axis differs from 1, so every axis is a trailing one
+ }
+
+ /**
+ * Returns the number of leading ones in the array shape.
+ *
+ * <p>For example, shape [1, 10, 1] returns 1 and the all-ones shape [1, 1] returns 2.
+ *
+ * @return the number of leading ones in the shape; 0 for a scalar (empty) shape
+ */
+ public int getLeadingOnes() {
+ for (int i = 0; i < shape.length; i++) {
+ if (shape[i] != 1) {
+ return i;
+ }
+ }
+ return shape.length; // no axis differs from 1, so every axis is a leading one
+ }
+
+ /**
+ * Returns {@code true} if the NDArray is a scalar.
+ *
+ * @return whether the NDArray is a scalar
+ */
+ public boolean isScalar() {
+ return dimension() == 0;
+ }
+
+ /**
+ * Returns {@code true} if any dimension of this shape has size zero.
+ *
+ * @return whether any dimension of this shape has size zero
+ */
+ public boolean hasZeroDimension() {
+ for (int i = 0; i < dimension(); i++) {
+ if (shape[i] == 0) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Returns {@code true} if a layout is set.
+ *
+ * @return whether a layout has been set
+ */
+ public boolean isLayoutKnown() {
+ return !Arrays.stream(layout).allMatch(l -> l == LayoutType.UNKNOWN);
+ }
+
+ /**
+ * Returns the layout type for each axis in this shape.
+ *
+ * @return the layout type for each axis in this shape
+ */
+ public LayoutType[] getLayout() {
+ return layout;
+ }
+
+ /**
+ * Returns the string layout type for each axis in this shape.
+ *
+ * @return the string layout type for each axis in this shape
+ */
+ public String toLayoutString() {
+ return LayoutType.toString(layout);
+ }
+
+ /**
+ * Gets the byte array representation of this {@code Shape} for serialization.
+ *
+ * @return a byte array representation of this {@code Shape}
+ */
+ public byte[] getEncoded() {
+ int length = 8 + shape.length * 8 + layout.length * 2;
+ ByteBuffer bb = ByteBuffer.allocate(length);
+ bb.putInt(shape.length);
+ for (long l : shape) {
+ bb.putLong(l);
+ }
+ bb.putInt(layout.length);
+ for (LayoutType layoutType : layout) {
+ bb.putChar(layoutType.getValue());
+ }
+ return bb.array();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Shape shape1 = (Shape) o;
+ return Arrays.equals(shape, shape1.shape);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(shape);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append('(');
+ for (int i = 0; i < shape.length; ++i) {
+ if (i > 0) {
+ sb.append(", ");
+ }
+ sb.append(shape[i]);
+ }
+ sb.append(')');
+ return sb.toString();
+ }
+
+ /**
+ * Decodes the data in the given {@link DataInputStream} and converts it into the corresponding
+ * {@link Shape} object.
+ *
+ * @param dis the inputstream to read from
+ * @return the corresponding {@link Shape} object
+ * @throws IOException when an I/O error occurs
+ */
+ public static Shape decode(DataInputStream dis) throws IOException {
+ // Shape
+ int length = dis.readInt();
+ long[] shapeValue = new long[length];
+ for (int i = 0; i < length; ++i) {
+ shapeValue[i] = dis.readLong();
+ }
+
+ // Layout
+ length = dis.readInt();
+ char[] layout = new char[length];
+ for (int i = 0; i < length; ++i) {
+ layout[i] = dis.readChar();
+ }
+ return new Shape(shapeValue, new String(layout));
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java
new file mode 100644
index 0000000..7c3d389
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/SparseFormat.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray.types;
+
+/**
+ * An enum representing Sparse matrix storage formats.
+ *
+ * <ul>
+ * <li>DENSE: Stride format
+ * <li>ROW_SPARSE: Row Sparse
+ * <li>CSR: Compressed Sparse Row
+ * </ul>
+ *
+ * @see <a href="https://software.intel.com/en-us/node/471374">Sparse Matrix Storage Formats</a>
+ */
+public enum SparseFormat {
+ // the dense format is accelerated by MKLDNN by default
+ DENSE("default", 0),
+ ROW_SPARSE("row_sparse", 1),
+ CSR("csr", 2);
+
+ private String type;
+ private int value;
+
+ SparseFormat(String type, int value) {
+ this.type = type;
+ this.value = value;
+ }
+
+ /**
+ * Gets the {@code SparseFormat} from its integer value.
+ *
+ * @param value the integer value of the {@code SparseFormat}
+ * @return a {@code SparseFormat}
+ */
+ public static SparseFormat fromValue(int value) {
+ for (SparseFormat t : values()) {
+ if (value == t.getValue()) {
+ return t;
+ }
+ }
+ throw new IllegalArgumentException("Unknown Sparse type: " + value);
+ }
+
+ /**
+ * Returns the {@code SparseFormat} name.
+ *
+ * @return the {@code SparseFormat} name
+ */
+ public String getType() {
+ return type;
+ }
+
+ /**
+ * Returns the integer value of this {@code SparseFormat}.
+ *
+ * @return the integer value of this {@code SparseFormat}
+ */
+ public int getValue() {
+ return value;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java
new file mode 100644
index 0000000..c58f71d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/types/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.ndarray.types;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java
new file mode 100644
index 0000000..9a268a3
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/Parameter.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.nn;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Objects;
+import java.util.UUID;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.exception.MalformedModelException;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDSerializer;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Parameter} is a container class that holds a learnable parameter of a model.
+ *
+ * <p>Every {@code Parameter} is associated with a {@link SymbolBlock}. The output of the block's
+ * forward function depends on the values in the {@code Parameter}. During training, the values in
+ * the {@code Parameter} are updated to reflect the training data. This process forms the crux of
+ * learning.
+ *
+ * @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/parameters.html">The D2L
+ * chapter on parameter management</a>
+ */
+public class Parameter extends MxResource {
+ private static final Logger logger = LoggerFactory.getLogger(Parameter.class);
+
+ private static final byte VERSION = 1;
+
+ private String id;
+ private String name;
+ private Shape shape;
+ private Type type;
+ private NDArray array;
+
+ Parameter(Builder builder) {
+ this.id = UUID.randomUUID().toString();
+ this.name = builder.name;
+ this.shape = builder.shape;
+ this.type = builder.type;
+ this.array = builder.array;
+ }
+
+ /**
+ * Gets the ID of this {@code Parameter}.
+ *
+ * @return the ID of this {@code Parameter}
+ */
+ public String getId() {
+ return id;
+ }
+
+ /**
+ * Gets the name of this {@code Parameter}.
+ *
+ * @return the name of this {@code Parameter}
+ */
+ public String getName() {
+ return name == null ? "" : name;
+ }
+
+ /**
+ * Gets the type of this {@code Parameter}.
+ *
+ * @return the type of this {@code Parameter}
+ */
+ public Type getType() {
+ return type;
+ }
+
+ /**
+ * Sets the values of this {@code Parameter}.
+ *
+ * @param array the {@link NDArray} that contains values of this {@code Parameter}
+ */
+ public void setArray(NDArray array) {
+ if (shape != null) {
+ throw new IllegalStateException("array has been set! Use either setArray or setShape");
+ }
+ this.array = array;
+ shape = array.getShape();
+ array.setName(name);
+ }
+
+ /**
+ * Sets the shape of this {@code Parameter}.
+ *
+ * @param shape the shape of this {@code Parameter}
+ */
+ public void setShape(Shape shape) {
+ if (array != null) {
+ throw new IllegalStateException("array has been set! Use either setArray or setShape");
+ }
+ this.shape = shape;
+ }
+
+ /**
+ * Gets the values of this {@code Parameter} as an {@link NDArray}.
+ *
+ * @return an {@link NDArray} that contains values of this {@code Parameter}
+ */
+ public NDArray getArray() {
+ if (!isInitialized()) {
+ throw new IllegalStateException("The array has not been initialized");
+ }
+ return array;
+ }
+
+ /**
+ * Checks if this {@code Parameter} is initialized.
+ *
+ * @return {@code true} if this {@code Parameter} is initialized
+ */
+ public boolean isInitialized() {
+ return array != null;
+ }
+
+ /**
+ * Initializes the parameter, with given {@link DataType} for the given expected input shapes.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param dataType the datatype of the {@code Parameter}
+ * @param device the device of {@link NDArray} in the {@code Parameter}
+ */
+ public void initialize(MxResource parent, DataType dataType, Device device) {
+ Objects.requireNonNull(shape, "No parameter shape has been set");
+ }
+
+ /**
+ * Writes the parameter NDArrays to the given output stream.
+ *
+ * @param dos the output stream to write to
+ * @throws IOException if the write operation fails
+ */
+ public void save(DataOutputStream dos) throws IOException {
+ if (!isInitialized()) {
+ dos.writeChar('N');
+ return;
+ }
+
+ dos.writeChar('P');
+ dos.writeByte(VERSION);
+ dos.writeUTF(getName());
+ dos.write(array.encode());
+ }
+
+ /**
+ * Loads parameter NDArrays from InputStream.
+ *
+ * <p>Currently, we cannot deserialize into the exact subclass of NDArray. The SparseNDArray
+ * will be loaded as NDArray only.
+ *
+ * @param parent the parent {@link MxResource} to manage this instance
+ * @param dis the InputStream
+ * @throws IOException if failed to read (parameters).
+ */
+ public void load(MxResource parent, DataInputStream dis) throws IOException {
+ char magic = dis.readChar();
+ if (magic == 'N') {
+ return;
+ } else if (magic != 'P') {
+ throw new MalformedModelException("Invalid input data.");
+ }
+
+ // Version
+ byte version = dis.readByte();
+ if (version != VERSION) {
+ throw new MalformedModelException("Unsupported encoding version: " + version);
+ }
+
+ String parameterName = dis.readUTF();
+ if (!parameterName.equals(getName())) {
+ throw new MalformedModelException(
+ "Unexpected parameter name: " + parameterName + ", expected: " + name);
+ }
+
+ array = NDSerializer.decode(parent, dis);
+ // set the shape of the parameter and prepare() can be skipped
+ shape = array.getShape();
+ }
+
+ /** Frees sub-resources and the backing array, then marks this instance as closed. */
+ @Override
+ public void close() {
+ if (!getClosed()) {
+ logger.debug(String.format("Start to free Parameter instance: %S", this.getUid()));
+ super.freeSubResources();
+ if (array != null) {
+ array.close();
+ array = null;
+ }
+ setClosed(true);
+ logger.debug(String.format("Finish to free Parameter instance: %S", this.getUid()));
+ }
+ }
+
+ /**
+ * Creates a builder to build a {@code Parameter}.
+ *
+ * <p>Methods that start with {@code set} populate required fields, and methods that start
+ * with {@code opt} populate optional fields.
+ *
+ * @return a new builder
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** Enumerates the types of {@link Parameter}. */
+ public enum Type {
+ WEIGHT,
+ BIAS,
+ GAMMA,
+ BETA,
+ RUNNING_MEAN,
+ RUNNING_VAR,
+ OTHER;
+ }
+
+ /** A Builder to construct a {@code Parameter}. */
+ public static final class Builder {
+ String name;
+ Shape shape;
+ Type type;
+ NDArray array;
+
+ /**
+ * Sets the name of the {@code Parameter}.
+ *
+ * @param name the name of the {@code Parameter}
+ * @return this {@code Builder}
+ */
+ public Builder setName(String name) {
+ this.name = name;
+ return this;
+ }
+
+ /**
+ * Sets the {@code Type} of the {@code Parameter}.
+ *
+ * @param type the {@code Type} of the {@code Parameter}
+ * @return this {@code Builder}
+ */
+ public Builder setType(Type type) {
+ this.type = type;
+ return this;
+ }
+
+ /**
+ * Sets the shape of the {@code Parameter}.
+ *
+ * @param shape the shape of the {@code Parameter}
+ * @return this {@code Builder}
+ */
+ public Builder optShape(Shape shape) {
+ this.shape = shape;
+ return this;
+ }
+
+ /**
+ * Sets the array of the {@code Parameter}.
+ *
+ * @param array the array of the {@code Parameter}
+ * @return this {@code Builder}
+ */
+ public Builder optArray(NDArray array) {
+ this.array = array;
+ return this;
+ }
+
+ /**
+ * Builds a {@code Parameter} instance.
+ *
+ * @return the {@code Parameter} instance
+ */
+ public Parameter build() {
+ return new Parameter(this);
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java
new file mode 100644
index 0000000..575972a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/ParameterList.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.nn;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+ /** A {@link PairList} whose keys are parameter names and whose values are {@link Parameter}s. */
+ public class ParameterList extends PairList<String, Parameter> {
+
+ /** Creates an empty {@code ParameterList}. */
+ public ParameterList() {}
+
+ /**
+ * Creates an empty {@code ParameterList} with the specified initial capacity.
+ *
+ * @param initialCapacity the initial capacity of the list
+ * @throws IllegalArgumentException if the specified initial capacity is negative
+ */
+ public ParameterList(int initialCapacity) {
+ super(initialCapacity);
+ }
+
+ /**
+ * Creates a {@code ParameterList} from parallel lists of names and parameters.
+ *
+ * @param keys the names to be placed into this {@code ParameterList}
+ * @param values the parameters to be placed into this {@code ParameterList}, in the same
+ * order as {@code keys}
+ * @throws IllegalArgumentException if the keys and values size are different
+ */
+ public ParameterList(List<String> keys, List<Parameter> values) {
+ super(keys, values);
+ }
+
+ /**
+ * Creates a {@code ParameterList} containing the elements of the specified list of Pairs.
+ *
+ * @param list the (name, parameter) pairs to be placed into this {@code ParameterList}
+ */
+ public ParameterList(List<Pair<String, Parameter>> list) {
+ super(list);
+ }
+
+ /**
+ * Creates a {@code ParameterList} containing the entries of the specified map.
+ *
+ * @param map the map from parameter name to parameter
+ */
+ public ParameterList(Map<String, Parameter> map) {
+ super(map);
+ }
+ }
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java
new file mode 100644
index 0000000..cb52f38
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/SymbolBlock.java
@@ -0,0 +1,556 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.nn;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.mxnet.engine.CachedOp;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.MxResourceList;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.exception.MalformedModelException;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code SymbolBlock} is a {@link MxResource}. It is used to load models that were exported
+ * directly from the engine in its native format.
+ */
+public class SymbolBlock extends MxResource {
+
+ private static final Logger logger = LoggerFactory.getLogger(SymbolBlock.class);
+
+ /** The shape of the input for this block, set by the initialization process. */
+ protected Shape[] inputShapes;
+
+ /** List of names for the input, named inputs should be manually set in sub class. */
+ protected List<String> inputNames = Collections.emptyList();
+
+ /**
+ * The model version of this block, used for checking if parameters are still valid during
+ * parameter loading.
+ */
+ protected byte version;
+
+ /**
+ * All direct parameters of this Block. Keys are name of the parameters.
+ *
+ * <p>Use the {@link SymbolBlock#addParameter(Parameter)} method to add children. All parameters
+ * in this map are automatically loaded / saved.
+ */
+ @SuppressWarnings("PMD.UseConcurrentHashMap")
+ protected Map<String, Parameter> parameters = new LinkedHashMap<>();
+
+ private static final byte VERSION = 3;
+
+ private CachedOp op;
+ private Symbol symbol;
+ private List<Parameter> mxNetParams; // includes input data
+ private Map<String, Shape> paramShapes;
+ private Shape[] outputShapes;
+ private PairList<String, Shape> inputDescriptions;
+ private PairList<String, Shape> outputDescriptions;
+ private boolean first;
+
+ /**
+ * Constructs a {@code MxSymbolBlock} for a {@link Symbol}.
+ *
+ * @param parent the parent MxResource to use for the block
+ * @param symbol the symbol containing the block's symbolic graph
+ */
+ public SymbolBlock(MxResource parent, Symbol symbol) {
+ super();
+ setParent(parent);
+ this.symbol = symbol;
+ initBlock();
+ }
+
+ /**
+ * Constructs an empty {@code MxSymbolBlock}.
+ *
+ * @param parent the parent {@code MxSymbolBlock} instance to manage this MxSymbolBlock
+ */
+ private SymbolBlock(MxResource parent) {
+ super();
+ setParent(parent);
+ }
+
+ /**
+ * Constructs an {@code MxSymbolBlock} and load the symbol according to {@code Path} The life
+ * circle of the {@code Symbol} instance is managed by parent {@code MxResource}.
+ *
+ * @param parent the parent MxResource Object to manage this MxSymbolBlock
+ * @param symbolPath the Path to load symbol
+ * @return created {@code SymbolBlock} instance
+ */
+ public static SymbolBlock createMxSymbolBlock(MxResource parent, Path symbolPath) {
+ SymbolBlock symbolBlock = new SymbolBlock(parent);
+ symbolBlock.loadSymbol(symbolPath);
+ symbolBlock.initBlock();
+ return symbolBlock;
+ }
+
+ private void loadSymbol(Path symbolPath) {
+ this.symbol = Symbol.loadSymbol(this, symbolPath);
+ }
+
+ /**
+ * Sets the names of the input data.
+ *
+ * @param inputNames the names of the input data
+ */
+ public void setInputNames(List<String> inputNames) {
+ this.inputNames = inputNames;
+ // now that we know which of the parameters are just input placeholders and which
+ // are trainable, add them properly so they are correctly handled
+ Set<String> nameLookup = new HashSet<>(inputNames);
+ for (Parameter mxNetParameter : mxNetParams) {
+ if (!nameLookup.contains(mxNetParameter.getName())) {
+ addParameter(mxNetParameter);
+ }
+ }
+ }
+
+ protected final Parameter addParameter(Parameter parameter) {
+ parameters.put(parameter.getName(), parameter);
+ return parameter;
+ }
+
+ /**
+ * Returns the list of inputs and parameter NDArrays.
+ *
+ * @return the list of inputs and parameter NDArrays
+ */
+ public List<Parameter> getAllParameters() {
+ return mxNetParams;
+ }
+
+ /**
+ * Returns the layers' name.
+ *
+ * @return a List of String containing the layers' name
+ */
+ public List<String> getLayerNames() {
+ return symbol.getLayerNames();
+ }
+
+ /**
+ * Returns the Symbolic graph from the model.
+ *
+ * @return a {@link Symbol} object
+ */
+ public Symbol getSymbol() {
+ return symbol;
+ }
+
+ /**
+ * Applies Optimization algorithm for the model.
+ *
+ * @param optimization the name of the optimization
+ * @param device the device assigned
+ */
+ public void optimizeFor(String optimization, Device device) {
+ Symbol newSymbol = symbol.optimizeFor(optimization, device);
+ symbol.close();
+ symbol = newSymbol;
+ }
+
+ /**
+ * Returns a {@link PairList} of input names, and shapes.
+ *
+ * @return the {@link PairList} of input names, and shapes
+ */
+ public PairList<String, Shape> describeInput() {
+ if (inputDescriptions == null) {
+ inputDescriptions = new PairList<>();
+ for (String name : inputNames) {
+ // Add empty shapes as input shapes are not saved
+ // in MXNet models
+ logger.warn(
+ "Input shapes are unknown, please run predict or forward once"
+ + "and call describeInput again.");
+ inputDescriptions.add(name, new Shape());
+ }
+ }
+ return inputDescriptions;
+ }
+
+ /**
+ * Returns a {@link PairList} of output names and shapes stored in model file.
+ *
+ * @return the {@link PairList} of output names, and shapes
+ */
+ public PairList<String, Shape> describeOutput() {
+ if (outputDescriptions == null) {
+ logger.warn(
+ "Output shapes are unknown, please run predict or forward once"
+ + "and call describeOutput again.");
+ }
+ return outputDescriptions;
+ }
+
+ /**
+ * Applies the operating function of the mxSymbolBlock once. This method should be called only
+ * on blocks that are initialized.
+ *
+ * @param inputs the input NDList
+ * @param params optional parameters
+ * @param device device to use
+ * @return the output of the forward pass
+ */
+ public final NDList forward(NDList inputs, PairList<String, Object> params, Device device) {
+
+ if (!isInitialized()) {
+ initialize(getParent(), DataType.FLOAT32, device, inputs.getShapes());
+ }
+ return forwardInternal(inputs, params);
+ }
+
+ /**
+ * Applies the operating function of the block once. This method should be called only on blocks
+ * that are initialized.
+ *
+ * @param inputs the input NDList
+ * @return the output of the forward pass
+ */
+ public NDList forward(NDList inputs) {
+ return forward(inputs, null, getDevice());
+ }
+
+ /**
+ * A forward call using both training data and labels.
+ *
+ * <p>Within this forward call, it can be assumed that training is true.
+ *
+ * @param data the input data NDList
+ * @param labels the input labels NDList
+ * @param params optional parameters
+ * @param device the device assigned
+ * @return the output of the forward pass
+ * @see #forward(NDList, PairList, Device)
+ */
+ public NDList forward(
+ NDList data, NDList labels, PairList<String, Object> params, Device device) {
+ if (!isInitialized()) {
+ initialize(getParent(), DataType.FLOAT32, device, data.getShapes());
+ }
+ return forwardInternal(data, labels, params);
+ }
+
+ /**
+ * A helper for {@link SymbolBlock#forward(NDList, NDList, PairList, Device)} after
+ * initialization.
+ *
+ * @param data the input data NDList
+ * @param labels the input labels NDList
+ * @param params optional parameters
+ * @return the output of the forward pass
+ * @see #forward(NDList, PairList, Device)
+ */
+ protected NDList forwardInternal(NDList data, NDList labels, PairList<String, Object> params) {
+ return forwardInternal(data, params);
+ }
+
+ protected NDList forwardInternal(NDList inputs, PairList<String, Object> params) {
+ if (first) {
+ synchronized (SymbolBlock.class) {
+ if (first) {
+ // create CachedOp is not thread-safe
+ // add synchronized block to avoid creating multiple CachedOps
+ op = JnaUtils.createCachedOp(this, getParent());
+ inputDescriptions = new PairList<>();
+ outputDescriptions = new PairList<>();
+ for (NDArray array : inputs) {
+ inputDescriptions.add(array.getName(), array.getShape());
+ }
+ NDList outputs = op.forward(inputs);
+ for (NDArray array : outputs) {
+ outputDescriptions.add(array.getName(), array.getShape());
+ }
+ first = false;
+ return outputs;
+ }
+ }
+ }
+ return op.forward(inputs);
+ }
+
+ /**
+ * Returns a boolean whether the {@link SymbolBlock} is initialized.
+ *
+ * @return whether the block is initialized
+ */
+ public boolean isInitialized() {
+ for (Parameter param : getParameters().values()) {
+ if (!param.isInitialized()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Initializes the parameters of the block. This method must be called before calling `forward`.
+ *
+ * @param parent the parent {@link MxResource} to manage this initialized
+ * @param dataType the datatype of the parameters
+ * @param device the device of the parameters
+ * @param inputShapes the shapes of the inputs to the block
+ */
+ public void initialize(
+ MxResource parent, DataType dataType, Device device, Shape... inputShapes) {
+ beforeInitialize(inputShapes);
+
+ // no need to initialize() for inference
+
+ for (Parameter parameter : parameters.values()) {
+ parameter.initialize(parent, dataType, device);
+ }
+ initializeChildBlocks();
+ }
+
+ /**
+ * Initializes the Child blocks of this block. You need to override this method if your subclass
+ * has child blocks. Used to determine the correct input shapes for child blocks based on the
+ * requested input shape for this block.
+ */
+ protected void initializeChildBlocks() {
+ if (!getSubResource().isEmpty()) {
+ throw new IllegalStateException(
+ getClass().getSimpleName()
+ + " has child blocks but initializeChildBlocks is not overwritten.");
+ }
+ }
+
+ protected void beforeInitialize(Shape... inputShapes) {
+ if (inputNames.isEmpty()) {
+ // automatically assign input names
+ inputNames = new ArrayList<>();
+ for (int i = 0; i < inputShapes.length; ++i) {
+ inputNames.add("data" + i);
+ }
+ }
+ this.inputShapes = inputShapes;
+ }
+
+ /**
+ * Returns a list of all the parameters of the block, including the parameters of its children
+ * fetched recursively.
+ *
+ * @return the list of all parameters of the SymbolBlock
+ */
+ public ParameterList getParameters() {
+ // we accumulate a list of all parameters by starting with a list of the direct parameters
+ ParameterList allParams = getDirectParameters();
+ // then we add the parameters of child blocks
+ for (Pair<String, MxResource> childPair : getChildren()) {
+ if (SymbolBlock.class.equals(childPair.getValue().getClass())) {
+ SymbolBlock symbolBlock = (SymbolBlock) childPair.getValue();
+ for (Pair<String, Parameter> paramPair : symbolBlock.getParameters()) {
+ // we prepend the name of the child block to the parameter name
+ allParams.add(
+ childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue());
+ }
+ }
+ }
+ return allParams;
+ }
+
+ /**
+ * Returns a list of all the children of the SymbolBlock.
+ *
+ * @return the list of child blocks
+ */
+ public MxResourceList getChildren() {
+ MxResourceList defensiveCopy = new MxResourceList(getSubResource().size());
+ for (Map.Entry<String, MxResource> entry : getSubResource().entrySet()) {
+ defensiveCopy.add(entry.getKey(), entry.getValue());
+ }
+ return defensiveCopy;
+ }
+
+ /**
+ * Returns a list of all the direct parameters of the SymbolBlock.
+ *
+ * @return the list of {@link Parameter}
+ */
+ public ParameterList getDirectParameters() {
+ return new ParameterList(parameters);
+ }
+
+ /**
+ * Returns the expected output shapes of the SymbolBlock for the specified input shapes.
+ *
+ * @param inputShapes the shapes of the inputs
+ * @return the expected output shapes of the block
+ */
+ public Shape[] getOutputShapes(Shape[] inputShapes) {
+ if (outputShapes == null) {
+ String[] outputNames = symbol.getOutputNames();
+ outputShapes = new Shape[outputNames.length];
+ for (int i = 0; i < outputShapes.length; ++i) {
+ outputShapes[i] = getParameterShape(outputNames[i], inputShapes);
+ }
+ }
+ return outputShapes;
+ }
+
+ /** Removes the last block in the symbolic graph. */
+ public void removeLastBlock() {
+ List<String> layerNames = getLayerNames();
+ String layerName = layerNames.get(layerNames.size() - 2);
+
+ Symbol sliced = symbol.get(layerName);
+ symbol.close();
+ symbol = sliced;
+
+ HashSet<String> set = new HashSet<>(Arrays.asList(symbol.getAllNames()));
+ for (int i = mxNetParams.size() - 1; i >= 0; --i) {
+ Parameter parameter = mxNetParams.get(i);
+ if (!set.contains(parameter.getName())) {
+ mxNetParams.remove(i).close();
+ parameters.remove(parameter.getName(), parameter);
+ }
+ }
+ }
+
+ private Shape getParameterShape(String name, Shape[] inputShapes) {
+ if (paramShapes == null) {
+ PairList<String, Shape> pairs = new PairList<>();
+ for (int i = 0; i < inputNames.size(); i++) {
+ pairs.add(inputNames.get(i), inputShapes[i]);
+ }
+ paramShapes = symbol.inferShape(pairs);
+ }
+ if (paramShapes.containsKey(name)) {
+ return paramShapes.get(name);
+ } else {
+ throw new IllegalArgumentException("Name " + name + " not found");
+ }
+ }
+
+ /**
+ * Writes the parameters of the SymbolBlock to the given outputStream.
+ *
+ * @param os the outputstream to save the parameters to
+ * @throws IOException if an I/O error occurs
+ */
+ public void saveParameters(DataOutputStream os) throws IOException {
+ os.writeByte(VERSION);
+ String json = symbol.toJsonString();
+ // symbol size may go beyond os.writeUTF() size (65535)
+ byte[] bytes = json.getBytes(StandardCharsets.UTF_8);
+ os.writeInt(bytes.length);
+ os.write(bytes);
+ int size = inputNames.size();
+ os.writeInt(size);
+ for (String name : inputNames) {
+ os.writeUTF(name);
+ }
+ for (Parameter parameter : mxNetParams) {
+ parameter.save(os);
+ }
+ }
+
+ /**
+ * Loads the parameters from the given input stream.
+ *
+ * @param parent the parent {@link MxResource} to create the parameter arrays
+ * @param is the inputstream that stream the parameter values
+ * @throws IOException if an I/O error occurs
+ * @throws MalformedModelException if the model file is corrupted or unsupported
+ */
+ public void loadParameters(MxResource parent, DataInputStream is) throws IOException {
+ Byte currentVersion = is.readByte();
+ if (currentVersion > VERSION) {
+ throw new MalformedModelException("Unsupported encoding version: " + version);
+ }
+ if (currentVersion < VERSION && symbol == null) {
+ throw new IllegalStateException(
+ "Symbol is required for version 2, please use Model to load");
+ }
+ if (currentVersion == VERSION) {
+ int len = is.readInt();
+ byte[] bytes = new byte[len];
+ if (is.read(bytes) == -1) {
+ throw new MalformedModelException("InputStream ends at symbol loading!");
+ }
+ // init block only if it is not set
+ symbol = Symbol.loadJson(this, new String(bytes, StandardCharsets.UTF_8));
+ initBlock();
+ }
+ int size = is.readInt();
+ for (int i = 0; i < size; ++i) {
+ inputNames.add(is.readUTF());
+ }
+
+ for (Parameter parameter : mxNetParams) {
+ parameter.load(parent, is);
+ }
+ setInputNames(inputNames);
+ }
+
+ private void initBlock() {
+ inputNames = new ArrayList<>();
+
+ String[] allNames = symbol.getAllNames();
+ mxNetParams = new ArrayList<>(allNames.length);
+
+ for (String name : allNames) {
+ Parameter.Type type = inferType(name);
+ mxNetParams.add(Parameter.builder().setName(name).setType(type).build());
+ }
+ first = true;
+ }
+
+ private static Parameter.Type inferType(String name) {
+ if (name.endsWith("bias")) {
+ return Parameter.Type.BIAS;
+ } else if (name.endsWith("gamma")) {
+ return Parameter.Type.GAMMA;
+ } else if (name.endsWith("beta")) {
+ return Parameter.Type.BETA;
+ } else if (name.endsWith("moving_mean") || name.endsWith("running_mean")) {
+ return Parameter.Type.RUNNING_MEAN;
+ } else if (name.endsWith("moving_var") || name.endsWith("running_var")) {
+ return Parameter.Type.RUNNING_VAR;
+ } else if (name.endsWith("weight")) {
+ return Parameter.Type.WEIGHT;
+ }
+ return Parameter.Type.OTHER;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java
new file mode 100644
index 0000000..283e625
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/nn/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.nn;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java
new file mode 100644
index 0000000..0f085fc
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Item.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.repository;
+
+/**
+ * {@link Item} is some listed repositories where we can download pre-trained models data. It is
+ * used by developers to download specific models data by initialize a {@link Repository}.
+ */
+public enum Item {
+ MLP("mlp", "https://resources.djl.ai/test-models/mlp.tar.gz");
+
+ private String name;
+ private String url;
+
+ Item(String name, String url) {
+ this.name = name;
+ this.url = url;
+ }
+
+ /**
+ * Gets the name of this {@code Item}.
+ *
+ * @return the name of this {@code Item}
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Gets the URL of this {@code Item} to download.
+ *
+ * @return the URL of this {@code Item}
+ */
+ public String getUrl() {
+ return url;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java
new file mode 100644
index 0000000..2eacee5
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/Repository.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.repository;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipInputStream;
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.mxnet.util.FilenameUtils;
+import org.apache.mxnet.util.Utils;
+import org.apache.mxnet.util.ZipUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Repository} is a format for storing data {@link Item}s for various uses including deep
+ * learning models and datasets.
+ */
+public class Repository {
+
+ private static final Logger logger = LoggerFactory.getLogger(Repository.class);
+
+ private String name;
+ private URI uri;
+ private Path resourceDir;
+
+ Repository(String name, String uri) {
+ setName(name);
+ setUri(URI.create(uri));
+ }
+
+ Repository(Item item) {
+ this(item.getName(), item.getUrl());
+ }
+
+ /**
+ * Initialize a {@link Repository} by a specific {@link Item}, which provides the name for the
+ * repository and the URL to achieve it.
+ *
+ * @param item {@link Item} to initialize the {@link Repository}
+ * @return {@link Path} of the initialized {@link Repository}
+ * @throws IOException when fail to prepare the {@link Repository}
+ */
+ public static Path initRepository(Item item) throws IOException {
+ Repository repository = new Repository(item);
+ repository.prepare();
+ return repository.getLocalDir();
+ }
+
+ private void setResourceDir(Path mResourceDir) {
+ this.resourceDir = mResourceDir;
+ }
+
+ private Path getResourceDir() {
+ return resourceDir;
+ }
+
+ /**
+ * Returns the local directory to store resources.
+ *
+ * @return {@link Path} of the local resource directory
+ */
+ public Path getLocalDir() {
+ return getResourceDir().resolve(getName());
+ }
+
+ /**
+ * Sets the {@link URI} for the {@link Repository}.
+ *
+ * @param uri of the repository
+ */
+ public final void setUri(URI uri) {
+ this.uri = uri;
+ }
+
+ /**
+ * Returns {@link URI} for the {@link Repository}.
+ *
+ * @return {@link URI} of the {@link Repository}
+ */
+ public URI getUri() {
+ return uri;
+ }
+
+ /**
+ * Sets the name for the {@link Repository}.
+ *
+ * @param name for the {@link Repository}
+ */
+ public final void setName(String name) {
+ this.name = name;
+ }
+
+ /**
+ * Returns the name for the {@link Repository}.
+ *
+ * @return name for the {@link Repository}
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Prepares the repository for use.
+ *
+ * @throws IOException if it failed to prepare
+ */
+ public void prepare() throws IOException {
+ String uriPath = getUri().getPath();
+ if (uriPath != null && !"".equals(uriPath) && uriPath.charAt(0) == '/') {
+ uriPath = uriPath.substring(1);
+ }
+ setResourceDir(getCacheDirectory().resolve(uriPath));
+ if (Files.exists(getResourceDir())) {
+ logger.debug("Files have been downloaded already: {}", getResourceDir());
+ return;
+ }
+ Path parentDir = getResourceDir().toAbsolutePath().getParent();
+ if (parentDir == null) {
+ throw new AssertionError(
+ String.format(
+ "Parent path should never be null: {}", getResourceDir().toString()));
+ }
+
+ Files.createDirectories(parentDir);
+ Path tmp = Files.createTempDirectory(parentDir, getResourceDir().toFile().getName());
+
+ // dismiss Progress related
+
+ try {
+ logger.debug("Repository to download: {}", getUri().toString());
+ download(tmp);
+ Utils.moveQuietly(tmp, getResourceDir());
+ } finally {
+ Utils.deleteQuietly(tmp);
+ }
+ }
+
+ private void download(Path tmp) throws IOException {
+ logger.debug("Downloading artifact: {} at {}...", getName(), getUri());
+ try (InputStream is = getUri().toURL().openStream()) {
+ String extension = FilenameUtils.getFileType(getUri().getPath());
+ save(is, tmp, name, extension, isArchiveFile(extension));
+ }
+ }
+
+ private boolean isArchiveFile(String fileType) {
+ return "tgz".equals(fileType) || "zip".equals(fileType) || "tar".equals(fileType);
+ }
+
+ protected void save(
+ InputStream is, Path tmp, String repoName, String extension, boolean archive)
+ throws IOException {
+ // ProgressInputStream pis = new ProgressInputStream(is);
+
+ if (archive) {
+ Path diretory;
+ if (!repoName.isEmpty()) {
+ // honer the name set in metadata.json
+ diretory = tmp.resolve(repoName);
+ Files.createDirectories(diretory);
+ } else {
+ diretory = tmp;
+ }
+ if ("zip".equals(extension)) {
+ ZipUtils.unzip(is, diretory);
+ } else if ("tgz".equals(extension)) {
+ untar(is, diretory, true);
+ } else if ("tar".equals(extension)) {
+ untar(is, diretory, false);
+ } else {
+ throw new IOException("File type is not supported: " + extension);
+ }
+ } else {
+ Path file = tmp.resolve(repoName);
+ if ("zip".equals(extension)) {
+ ZipInputStream zis = new ZipInputStream(is);
+ zis.getNextEntry();
+ Files.copy(zis, file, StandardCopyOption.REPLACE_EXISTING);
+ } else if ("gzip".equals(extension)) {
+ Files.copy(new GZIPInputStream(is), file, StandardCopyOption.REPLACE_EXISTING);
+ } else {
+ Files.copy(is, file, StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ // pis.validateChecksum(item);
+ }
+
+ private void untar(InputStream is, Path dir, boolean gzip) throws IOException {
+ InputStream bis;
+ if (gzip) {
+ bis = new GzipCompressorInputStream(new BufferedInputStream(is));
+ } else {
+ bis = new BufferedInputStream(is);
+ }
+ try (TarArchiveInputStream tis = new TarArchiveInputStream(bis)) {
+ TarArchiveEntry entry;
+ while ((entry = tis.getNextTarEntry()) != null) {
+ String entryName = entry.getName();
+ if (entryName.contains("..")) {
+ throw new IOException("Malicious zip entry: " + entryName);
+ }
+ Path file = dir.resolve(entryName).toAbsolutePath();
+ if (entry.isDirectory()) {
+ Files.createDirectories(file);
+ } else {
+ Path parentFile = file.getParent();
+ if (parentFile == null) {
+ throw new AssertionError(
+ "Parent path should never be null: " + file.toString());
+ }
+ Files.createDirectories(parentFile);
+ Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns the cache directory for the repository.
+ *
+ * @return the cache directory path
+ * @throws IOException if it failed to ensure the creation of the cache directory
+ */
+ public Path getCacheDirectory() throws IOException {
+ Path dir = Utils.getCacheDir().resolve("cache/repo");
+ if (Files.notExists(dir)) {
+ Files.createDirectories(dir);
+ } else if (!Files.isDirectory(dir)) {
+ throw new IOException("Failed initialize cache directory: " + dir.toString());
+ }
+ return dir;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java
new file mode 100644
index 0000000..6248f96
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/repository/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.repository;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java
new file mode 100644
index 0000000..0c5d69f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/NoOpTranslator.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import org.apache.mxnet.ndarray.NDList;
+
+/**
+ * Default no-op implementation of {@link Translator} that returns the input and output {@link
+ * NDList} unchanged.
+ */
+public class NoOpTranslator implements Translator<NDList, NDList> {
+
+ /** {@inheritDoc} */
+ @Override
+ public Pipeline getPipeline() {
+ return Translator.super.getPipeline();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processInput(NDList input) {
+ return input;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList processOutput(NDList output) {
+ return output;
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java
new file mode 100644
index 0000000..15facc4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Pipeline.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/** {@code Pipeline} allows applying multiple transforms on an input {@link NDList}. */
+public class Pipeline {
+
+ private PairList<IndexKey, Transform> transforms;
+
+ /** Creates a new instance of {@code Pipeline} that has no {@link Transform} defined yet. */
+ public Pipeline() {
+ transforms = new PairList<>();
+ }
+
+ /**
+ * Creates a new instance of {@code Pipeline} that can apply the given transforms on its input.
+ *
+ * <p>Since no keys are provided for these transforms, they will be applied to the first element
+ * in the input {@link NDList} when the {@link #transform(NDList) transform} method is called on
+ * this object.
+ *
+ * @param transforms the transforms to be applied when the {@link #transform(NDList) transform}
+ * method is called on this object
+ */
+ public Pipeline(Transform... transforms) {
+ this.transforms = new PairList<>();
+ for (Transform transform : transforms) {
+ this.transforms.add(new IndexKey(0), transform);
+ }
+ }
+
+ /**
+ * Adds the given {@link Transform} to the list of transforms to be applied on the input when
+ * the {@link #transform(NDList) transform} method is called on this object.
+ *
+ * <p>Since no keys are provided for this {@link Transform}, it will be applied to the first
+ * element in the input {@link NDList}.
+ *
+ * @param transform the {@link Transform} to be added
+ * @return this {@code Pipeline}
+ */
+ public Pipeline add(Transform transform) {
+ transforms.add(new IndexKey(0), transform);
+ return this;
+ }
+
+ /**
+ * Adds the given {@link Transform} to the list of transforms to be applied on the {@link
+ * NDArray} at the given index in the input {@link NDList}.
+ *
+ * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on
+ * which the given transform must be applied to
+ * @param transform the {@link Transform} to be added
+ * @return this {@code Pipeline}
+ */
+ public Pipeline add(int index, Transform transform) {
+ transforms.add(new IndexKey(index), transform);
+ return this;
+ }
+
+ /**
+ * Adds the given {@link Transform} to the list of transforms to be applied on the {@link
+ * NDArray} with the given key as name in the input {@link NDList}.
+ *
+ * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which
+ * the given transform must be applied to
+ * @param transform the {@code Transform} to be applied when the {@link #transform(NDList)
+ * transform} method is called on this object
+ * @return this {@code Pipeline}
+ */
+ public Pipeline add(String name, Transform transform) {
+ transforms.add(new IndexKey(name), transform);
+ return this;
+ }
+
+ /**
+ * Inserts the given {@link Transform} to the list of transforms at the given position.
+ *
+ * <p>Since no keys or indices are provided for this {@link Transform}, it will be applied to
+ * the first element in the input {@link NDList} when the {@link #transform(NDList) transform}
+ * method is called on this object.
+ *
+ * @param position the position at which the {@link Transform} must be inserted
+ * @param transform the {@code Transform} to be inserted
+ * @return this {@code Pipeline}
+ */
+ public Pipeline insert(int position, Transform transform) {
+ transforms.add(position, new IndexKey(0), transform);
+ return this;
+ }
+
+ /**
+ * Inserts the given {@link Transform} to the list of transforms at the given position to be
+ * applied on the {@link NDArray} at the given index in the input {@link NDList}.
+ *
+ * @param position the position at which the {@link Transform} must be inserted
+ * @param index the index corresponding to the {@link NDArray} in the input {@link NDList} on
+ * which the given transform must be applied to
+ * @param transform the {@code Transform} to be inserted
+ * @return this {@code Pipeline}
+ */
+ public Pipeline insert(int position, int index, Transform transform) {
+ transforms.add(position, new IndexKey(index), transform);
+ return this;
+ }
+
+ /**
+ * Inserts the given {@link Transform} to the list of transforms at the given position to be
+ * applied on the {@link NDArray} with the given name in the input {@link NDList}.
+ *
+ * @param position the position at which the {@link Transform} must be inserted
+ * @param name the key corresponding to the {@link NDArray} in the input {@link NDList} on which
+ * the given transform must be applied to
+ * @param transform the {@code Transform} to be inserted
+ * @return this {@code Pipeline}
+ */
+ public Pipeline insert(int position, String name, Transform transform) {
+ transforms.add(position, new IndexKey(name), transform);
+ return this;
+ }
+
+ /**
+ * Applies the transforms configured in this object on the input {@link NDList}.
+ *
+ * <p>A transform added with an index or name applies only to the matching {@link NDArray} of
+ * the input; a transform added without either applies to the first element. Transforms run in
+ * list order, and each result keeps the name of the array it replaces.
+ *
+ * @param input the input {@link NDList} on which the transforms are to be applied
+ * @return the output {@link NDList} after applying the transforms
+ * @throws IllegalArgumentException if a transform's index or name matches no input array
+ */
+ public NDList transform(NDList input) {
+ if (transforms.isEmpty() || input.isEmpty()) {
+ return input;
+ }
+
+ NDArray[] arrays = input.toArray(new NDArray[0]);
+ Map<IndexKey, Integer> map = new ConcurrentHashMap<>();
+ // create mapping: every array is reachable by position and, when named, by name
+ for (int i = 0; i < input.size(); i++) {
+ String key = input.get(i).getName();
+ if (key != null) {
+ map.put(new IndexKey(key), i);
+ }
+ map.put(new IndexKey(i), i);
+ }
+ // apply transform
+ for (Pair<IndexKey, Transform> transform : transforms) {
+ Integer index = map.get(transform.getKey());
+ if (index == null) {
+ throw new IllegalArgumentException("No input NDArray matches a transform key");
+ }
+ NDArray array = arrays[index];
+ arrays[index] = transform.getValue().transform(array);
+ arrays[index].setName(array.getName());
+ }
+ return new NDList(arrays);
+ }
+
+ private static final class IndexKey {
+ private String key; // name of the target array; null for a positional key
+ private int index; // position of the target array; only meaningful when key is null
+
+ private IndexKey(String key) {
+ this.key = key;
+ }
+
+ private IndexKey(int index) {
+ this.index = index;
+ }
+
+ /** Hashes by name when a name is present, otherwise by position. */
+ @Override
+ public int hashCode() {
+ if (key == null) {
+ return index;
+ }
+ return key.hashCode();
+ }
+
+ /** A positional key only equals another positional key, keeping {@code equals} symmetric. */
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof IndexKey)) {
+ return false;
+ }
+ IndexKey other = (IndexKey) obj;
+ if (key == null) {
+ return other.key == null && index == other.index;
+ }
+ return key.equals(other.key);
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java
new file mode 100644
index 0000000..cee04d0
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Processor.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import org.apache.mxnet.ndarray.NDList;
+
+/**
+ * An interface that provides pre-processing and post-processing functionality.
+ * @param <I> the type of the input object
+ * @param <O> the type of the output object
+ */
+public interface Processor<I, O> {
+
+ /**
+ * Gets the {@link Pipeline} applied to the input; the default implementation is unsupported.
+ *
+ * @return the {@link Pipeline}
+ */
+ default Pipeline getPipeline() {
+ throw new UnsupportedOperationException("Not implemented.");
+ }
+
+ /**
+ * Processes the input and converts it to NDList.
+ *
+ * @param input the input object
+ * @return the {@link NDList} after pre-processing
+ * @throws Exception if an error occurs during processing input
+ */
+ @SuppressWarnings("PMD.SignatureDeclareThrowsException")
+ NDList processInput(I input) throws Exception;
+
+ /**
+ * Processes the output {@link NDList} and converts it to the output object.
+ *
+ * @param output the {@link NDList} to post-process
+ * @return the output object after post-processing
+ * @throws Exception if an error occurs during processing output
+ */
+ @SuppressWarnings("PMD.SignatureDeclareThrowsException")
+ O processOutput(NDList output) throws Exception;
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java
new file mode 100644
index 0000000..8e24304
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Transform.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import org.apache.mxnet.ndarray.NDArray;
+
+/**
+ * An interface to apply various transforms to the input.
+ *
+ * <p>A transform can be any function that modifies the input. Some examples of transform are crop
+ * and resize.
+ */
+// TODO : not used by now
+public interface Transform {
+ /**
+ * Applies the {@code Transform} to the given {@link NDArray}.
+ *
+ * @param array the {@link NDArray} on which the {@link Transform} is applied
+ * @return the output of the {@code Transform}
+ */
+ NDArray transform(NDArray array);
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java
new file mode 100644
index 0000000..c476e24
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/Translator.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.translate;
+
+import java.io.IOException;
+import org.apache.mxnet.engine.Model;
+import org.apache.mxnet.engine.Predictor;
+
+/**
+ * The {@code Translator} interface provides model pre-processing and post-processing functionality.
+ *
+ * <p>Users can use this in {@link Predictor} with input and output objects specified. The input
+ * and output object types are given by the type parameters below.
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ */
+public interface Translator<I, O> extends Processor<I, O> {
+ // TODO: implement getPipeline() and related methods
+ /**
+ * Prepares the translator with the model to use; the default implementation does nothing.
+ *
+ * @param model the model to translate for
+ * @throws IOException if there is an error reading inputs for preparing the translator
+ */
+ default void prepare(Model model) throws IOException {}
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java
new file mode 100644
index 0000000..5aaeeb5
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/translate/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains interfaces and classes for pre-processing input and post-processing output. */
+package org.apache.mxnet.translate;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java
new file mode 100644
index 0000000..5493893
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/FilenameUtils.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.util;
+
+import java.util.Locale;
+
+/** A class containing utility methods for inspecting file names. */
+public final class FilenameUtils {
+
+ private FilenameUtils() {}
+
+ /**
+ * Returns the type of the file, determined case-insensitively from the file name suffix.
+ *
+ * @param fileName the file name
+ * @return {@code "zip"}, {@code "tgz"}, {@code "tar"}, {@code "gzip"}, or an empty string
+ */
+ public static String getFileType(String fileName) {
+ fileName = fileName.toLowerCase(Locale.ROOT);
+ if (fileName.endsWith(".zip")) {
+ return "zip";
+ } else if (fileName.endsWith(".tgz")
+ || fileName.endsWith(".tar.gz")
+ || fileName.endsWith(".tar.z")) {
+ return "tgz";
+ } else if (fileName.endsWith(".tar")) {
+ return "tar";
+ } else if (fileName.endsWith(".gz") || fileName.endsWith(".z")) {
+ return "gzip";
+ } else {
+ return "";
+ }
+ }
+
+ /**
+ * Returns whether the file is an archive file (zip, tar or compressed tar).
+ *
+ * @param fileName the file name
+ * @return {@code true} if {@link #getFileType(String)} is "tgz", "zip" or "tar"
+ */
+ public static boolean isArchiveFile(String fileName) {
+ String fileType = getFileType(fileName);
+ return "tgz".equals(fileType) || "zip".equals(fileType) || "tar".equals(fileType);
+ }
+
+ /**
+ * Returns the name of the file without its archive or compression extension.
+ *
+ * @param name the file name
+ * @return the name without a recognized extension, or {@code name} itself if none matches
+ */
+ public static String getNamePart(String name) {
+ String lowerCase = name.toLowerCase(Locale.ROOT); // suffixes match case-insensitively
+ if (lowerCase.endsWith(".tar.gz")) {
+ return name.substring(0, name.length() - 7);
+ } else if (lowerCase.endsWith(".tar.z")) {
+ return name.substring(0, name.length() - 6);
+ } else if (lowerCase.endsWith(".tgz") || lowerCase.endsWith(".zip") || lowerCase.endsWith(".tar")) {
+ return name.substring(0, name.length() - 4);
+ } else if (lowerCase.endsWith(".gz")) {
+ return name.substring(0, name.length() - 3);
+ } else if (lowerCase.endsWith(".z")) {
+ return name.substring(0, name.length() - 2);
+ }
+ return name;
+ }
+
+ /**
+ * Returns the file name extension of the file.
+ *
+ * @param fileName the file name
+ * @return the text after the last dot, or an empty string if no dot follows the first character
+ */
+ public static String getFileExtension(String fileName) {
+ int pos = fileName.lastIndexOf('.');
+ if (pos > 0) {
+ return fileName.substring(pos + 1);
+ }
+ return "";
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java
new file mode 100644
index 0000000..5961272
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Float16Utils.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.nio.ByteBuffer;
+import java.nio.ShortBuffer;
+import org.apache.mxnet.ndarray.NDSerializer;
+
+/** {@code Float16Utils} is a set of utilities for working with float16. */
+@SuppressWarnings("PMD.AvoidUsingShortType")
+public final class Float16Utils {
+
+ private Float16Utils() {}
+
+ /**
+ * Converts a byte buffer of float16 values into a float32 array.
+ *
+ * @param buffer the buffer of float16 values as bytes.
+ * @return an array of float32 values.
+ */
+ public static float[] fromByteBuffer(ByteBuffer buffer) {
+ return fromShortBuffer(buffer.asShortBuffer());
+ }
+
+ /**
+ * Converts a short buffer of float16 values into a float32 array.
+ *
+ * @param buffer the buffer of float16 values as shorts.
+ * @return an array of float32 values.
+ */
+ public static float[] fromShortBuffer(ShortBuffer buffer) {
+ int index = 0;
+ float[] ret = new float[buffer.remaining()];
+ while (buffer.hasRemaining()) {
+ short value = buffer.get();
+ ret[index++] = halfToFloat(value);
+ }
+ return ret;
+ }
+
+ /**
+ * Converts an array of float32 values into a byte buffer of float16 values.
+ *
+ * @param floats an array of float32 values.
+ * @return a byte buffer with float16 values represented as shorts (2 bytes each).
+ */
+ public static ByteBuffer toByteBuffer(float[] floats) {
+ ByteBuffer buffer = NDSerializer.allocateDirect(floats.length * 2);
+ for (float f : floats) {
+ short value = floatToHalf(f);
+ buffer.putShort(value);
+ }
+ buffer.rewind();
+ return buffer;
+ }
+
+ /**
+ * Converts a float32 value into a float16 value.
+ *
+ * @param fVal a float32 value.
+ * @return a float16 value represented as a short.
+ */
+ public static short floatToHalf(float fVal) {
+ int bits = Float.floatToIntBits(fVal);
+ int sign = bits >>> 16 & 0x8000;
+ int val = (bits & 0x7fffffff) + 0x1000;
+ if (val >= 0x47800000) {
+ if ((bits & 0x7fffffff) >= 0x47800000) {
+ if (val < 0x7f800000) {
+ return (short) (sign | 0x7c00);
+ }
+ return (short) (sign | 0x7c00 | (bits & 0x007fffff) >>> 13);
+ }
+ return (short) (sign | 0x7bff);
+ }
+ if (val >= 0x38800000) {
+ return (short) (sign | val - 0x38000000 >>> 13);
+ }
+ if (val < 0x33000000) {
+ return (short) sign;
+ }
+ val = (bits & 0x7fffffff) >>> 23;
+ return (short)
+ (sign | ((bits & 0x7fffff | 0x800000) + (0x800000 >>> val - 102) >>> 126 - val));
+ }
+
+ /**
+ * Converts a float16 value into a float32 value.
+ *
+ * @param half a float16 value represented as a short.
+ * @return a float32 value.
+ */
+ public static float halfToFloat(short half) {
+ int mant = half & 0x03ff;
+ int exp = half & 0x7c00;
+ if (exp == 0x7c00) {
+ exp = 0x3fc00;
+ } else if (exp != 0) {
+ exp += 0x1c000;
+ if (mant == 0 && exp > 0x1c400) {
+ return Float.intBitsToFloat((half & 0x8000) << 16 | exp << 13);
+ }
+ } else if (mant != 0) {
+ exp = 0x1c400;
+ do {
+ mant <<= 1;
+ exp -= 0x400;
+ } while ((mant & 0x400) == 0);
+ mant &= 0x3ff;
+ }
+ return Float.intBitsToFloat((half & 0x8000) << 16 | (exp | mant) << 13);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java
new file mode 100644
index 0000000..e966ef4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/NativeResource.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import com.sun.jna.Pointer;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory.
+ *
+ * @param <T> the resource that could map to a native pointer or java object
+ */
+public abstract class NativeResource<T> implements AutoCloseable {
+
+ protected final AtomicReference<T> handle;
+ private final String uid;
+
+ protected NativeResource(T handle) {
+ this.handle = new AtomicReference<>(handle);
+ this.uid = handle.toString();
+ }
+
+ protected NativeResource() {
+ this.handle = null; // no handle holder at all; the accessors below guard against this
+ this.uid = null;
+ }
+
+ /**
+ * To initialize a NativeResource with handle = null.
+ *
+ * @param uid for the {@link NativeResource}
+ */
+ protected NativeResource(String uid) {
+ this.handle = null;
+ this.uid = uid;
+ }
+
+ /**
+ * Gets the boolean that indicates whether this resource has been released.
+ *
+ * @return {@code true} if this resource holds no handle, either released or never assigned
+ */
+ public boolean isReleased() {
+ return handle == null || handle.get() == null;
+ }
+
+ /**
+ * Gets the handle (for example a {@link Pointer}) of this resource, failing if there is none.
+ *
+ * @return the handle of this resource
+ */
+ public T getHandle() {
+ T reference = handle == null ? null : handle.get();
+ if (reference == null) {
+ throw new IllegalStateException("Native resource has been release already.");
+ }
+ return reference;
+ }
+
+ /**
+ * Gets the unique ID of this resource.
+ *
+ * @return the unique ID of this resource
+ */
+ public final String getUid() {
+ return uid;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() {
+ throw new UnsupportedOperationException("Not implemented.");
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java
new file mode 100644
index 0000000..8f1da32
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Pair.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.util.Objects;
+
+/**
+ * An immutable class containing a key-value pair; both key and value may be {@code null}.
+ *
+ * @param <K> the key type
+ * @param <V> the value type
+ */
+public class Pair<K, V> {
+
+ private final K key;
+ private final V value;
+
+ /**
+ * Constructs a {@code Pair} instance with key and value.
+ *
+ * @param key the key
+ * @param value the value
+ */
+ public Pair(K key, V value) {
+ this.key = key;
+ this.value = value;
+ }
+
+ /**
+ * Returns the key of this pair.
+ *
+ * @return the key
+ */
+ public K getKey() {
+ return key;
+ }
+
+ /**
+ * Returns the value of this pair.
+ *
+ * @return the value
+ */
+ public V getValue() {
+ return value;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Pair<?, ?> pair = (Pair<?, ?>) o;
+ return Objects.equals(key, pair.key) && Objects.equals(value, pair.value);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public int hashCode() {
+ return Objects.hash(key, value);
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java
new file mode 100644
index 0000000..c803ceb
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/PairList.java
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
+/**
+ * The {@code PairList} class provides an efficient way to access a list of key-value pairs.
+ *
+ * @param <K> the key type
+ * @param <V> the value type
+ */
+public class PairList<K, V> implements Iterable<Pair<K, V>> {
+
+    private final List<K> keys;
+    private final List<V> values;
+
+    /** Constructs an empty {@code PairList}. */
+    public PairList() {
+        keys = new ArrayList<>();
+        values = new ArrayList<>();
+    }
+
+    /**
+     * Constructs an empty {@code PairList} with the specified initial capacity.
+     *
+     * @param initialCapacity the initial capacity of the list
+     * @throws IllegalArgumentException if the specified initial capacity is negative
+     */
+    public PairList(int initialCapacity) {
+        keys = new ArrayList<>(initialCapacity);
+        values = new ArrayList<>(initialCapacity);
+    }
+
+    /**
+     * Constructs a {@code PairList} backed directly (without copying) by the given lists.
+     *
+     * @param keys the key list containing elements to be placed into this PairList
+     * @param values the value list containing elements to be placed into this PairList
+     * @throws IllegalArgumentException if the keys and values size are different
+     */
+    public PairList(List<K> keys, List<V> values) {
+        if (keys.size() != values.size()) {
+            throw new IllegalArgumentException("key value size mismatch.");
+        }
+        this.keys = keys;
+        this.values = values;
+    }
+
+    /**
+     * Constructs a {@code PairList} containing the elements of the specified list of Pairs.
+     *
+     * @param list the list containing elements to be placed into this PairList
+     */
+    public PairList(List<Pair<K, V>> list) {
+        this(list.size());
+        for (Pair<K, V> pair : list) {
+            keys.add(pair.getKey());
+            values.add(pair.getValue());
+        }
+    }
+
+    /**
+     * Constructs a {@code PairList} containing the elements of the specified map.
+     *
+     * @param map the map contains keys and values
+     */
+    public PairList(Map<K, V> map) {
+        keys = new ArrayList<>(map.size());
+        values = new ArrayList<>(map.size());
+        for (Map.Entry<K, V> entry : map.entrySet()) {
+            keys.add(entry.getKey());
+            values.add(entry.getValue());
+        }
+    }
+
+    /**
+     * Inserts the specified element at the specified position in this list (optional operation),
+     * and shifts the element currently at that position (if any) and any subsequent elements to the
+     * right (adds one to their indices).
+     *
+     * @param index the index at which the specified element is to be inserted
+     * @param key the key
+     * @param value the value
+     */
+    public void add(int index, K key, V value) {
+        keys.add(index, key);
+        values.add(index, value);
+    }
+
+    /**
+     * Adds a key and value to the list.
+     *
+     * @param key the key
+     * @param value the value
+     */
+    public void add(K key, V value) {
+        keys.add(key);
+        values.add(value);
+    }
+
+    /**
+     * Appends all of the elements in the specified pair list to the end of this list.
+     *
+     * @param other the {@code PairList} to append; a {@code null} argument is ignored
+     */
+    public void addAll(PairList<K, V> other) {
+        if (other != null) {
+            keys.addAll(other.keys);
+            values.addAll(other.values);
+        }
+    }
+
+    /**
+     * Returns the size of the list.
+     *
+     * @return the size of the list
+     */
+    public int size() {
+        return keys.size();
+    }
+
+    /**
+     * Checks whether the list is empty.
+     *
+     * @return whether the list is empty
+     */
+    public boolean isEmpty() {
+        return size() == 0;
+    }
+
+    /**
+     * Returns the key-value pair at the specified position in this list.
+     *
+     * @param index the index of the element to return
+     * @return a newly created {@link Pair} holding the key and value at that position
+     */
+    public Pair<K, V> get(int index) {
+        return new Pair<>(keys.get(index), values.get(index));
+    }
+
+    /**
+     * Returns the value for the first key found in the list.
+     *
+     * @param key the key of the element to get
+     * @return the value for the first matching key, or {@code null} if the key is not present
+     */
+    public V get(K key) {
+        int index = keys.indexOf(key);
+        if (index == -1) {
+            return null;
+        }
+        return values.get(index);
+    }
+
+    /**
+     * Returns the key at the specified position in this list.
+     *
+     * @param index the index of the element to return
+     * @return the key at the specified position in this list
+     */
+    public K keyAt(int index) {
+        return keys.get(index);
+    }
+
+    /**
+     * Returns the value at the specified position in this list.
+     *
+     * @param index the index of the element to return
+     * @return the value at the specified position in this list
+     */
+    public V valueAt(int index) {
+        return values.get(index);
+    }
+
+    /**
+     * Returns all keys of the list.
+     *
+     * @return the backing key list itself (not a copy)
+     */
+    public List<K> keys() {
+        return keys;
+    }
+
+    /**
+     * Returns all values of the list.
+     *
+     * @return the backing value list itself (not a copy)
+     */
+    public List<V> values() {
+        return values;
+    }
+
+    /**
+     * Returns an array containing all of the keys in this list in proper sequence (from first to
+     * last element); the runtime type of the returned array is that of the specified array.
+     *
+     * <p>If the list fits in the specified array, it is returned therein. Otherwise, a new array is
+     * allocated with the runtime type of the specified array and the size of this list.
+     *
+     * @param target the array into which the keys of this list are to be stored, if it is big
+     *     enough; otherwise, a new array of the same runtime type is allocated for this purpose.
+     * @return an array containing the keys of this list
+     */
+    public K[] keyArray(K[] target) {
+        return keys.toArray(target);
+    }
+
+    /**
+     * Returns an array containing all of the values in this list in proper sequence (from first to
+     * last element); the runtime type of the returned array is that of the specified array.
+     *
+     * <p>If the list fits in the specified array, it is returned therein. Otherwise, a new array is
+     * allocated with the runtime type of the specified array and the size of this list.
+     *
+     * @param target the array into which the values of this list are to be stored, if it is big
+     *     enough; otherwise, a new array of the same runtime type is allocated for this purpose.
+     * @return an array containing the values of this list
+     */
+    public V[] valueArray(V[] target) {
+        return values.toArray(target);
+    }
+
+    /**
+     * Removes the key-value pair for the first key found in the list.
+     *
+     * @param key the key of the element to be removed
+     * @return the value of the removed element, {@code null} if not found
+     */
+    public V remove(K key) {
+        int index = keys.indexOf(key);
+        if (index == -1) {
+            return null;
+        }
+        return remove(index);
+    }
+
+    /**
+     * Removes the key-value pair at an index by delegating to {@link List#remove(int)}.
+     *
+     * @param index the index of the element to remove
+     * @return the value of the removed element
+     */
+    public V remove(int index) {
+        keys.remove(index);
+        return values.remove(index);
+    }
+
+    /**
+     * Returns a view of the portion of this PairList between the specified {@code fromIndex}
+     * inclusive, and to the end.
+     *
+     * @param fromIndex the start index (inclusive)
+     * @return a view of the portion of this PairList
+     */
+    public PairList<K, V> subList(int fromIndex) {
+        return subList(fromIndex, size());
+    }
+
+    /**
+     * Returns a view of the portion of this PairList between the specified {@code fromIndex}
+     * inclusive, and {@code toIndex}, exclusive.
+     *
+     * @param fromIndex the start index (inclusive)
+     * @param toIndex the end index (exclusive)
+     * @return a view of the portion of this PairList
+     */
+    public PairList<K, V> subList(int fromIndex, int toIndex) {
+        List<K> subKeys = keys.subList(fromIndex, toIndex);
+        List<V> subValues = values.subList(fromIndex, toIndex);
+        return new PairList<>(subKeys, subValues);
+    }
+
+    /**
+     * Returns the {@link Stream} type of the PairList.
+     *
+     * @return a sequential {@link Stream} over the pairs of this PairList
+     */
+    public Stream<Pair<K, V>> stream() {
+        return StreamSupport.stream(spliterator(), false);
+    }
+
+    /**
+     * Returns {@code true} if this list contains the specified key.
+     *
+     * @param key the key whose presence will be tested
+     * @return {@code true} if this list contains the specified key
+     */
+    public boolean contains(K key) {
+        return keys.contains(key);
+    }
+
+    /**
+     * Builds a new list in which every key appears only once.
+     *
+     * @return a new {@code PairList} with one entry per key, taking the latest value for each
+     *     key; entries follow the iteration order of the intermediate map, not of this list
+     */
+    public PairList<K, V> unique() {
+        return new PairList<>(toMap(false));
+    }
+
+    /**
+     * Returns a {@code Map} that contains the key-value mappings of this list.
+     *
+     * @return a {@code Map} of this list; duplicated keys raise {@link IllegalStateException}
+     */
+    public Map<K, V> toMap() {
+        return toMap(true);
+    }
+
+    /**
+     * Returns a {@code Map} that contains the key-value mappings of this list.
+     *
+     * @param checkDuplicate if true, an IllegalStateException is thrown on a duplicated key
+     * @return a {@code Map} that contains the key-value mappings of this list
+     */
+    public Map<K, V> toMap(boolean checkDuplicate) {
+        int size = keys.size();
+        Map<K, V> map = new ConcurrentHashMap<>(size * 3 / 2);
+        for (int i = 0; i < size; ++i) {
+            if (map.put(keys.get(i), values.get(i)) != null && checkDuplicate) {
+                throw new IllegalStateException("Duplicate keys: " + keys.get(i));
+            }
+        }
+        return map;
+    }
+
+    @Override
+    public Iterator<Pair<K, V>> iterator() {
+        return new Itr();
+    }
+
+    /** Internal Iterator implementation; the list size is captured when it is created. */
+    private class Itr implements Iterator<Pair<K, V>> {
+
+        private int cursor;
+        private final int size = size();
+
+        Itr() {}
+
+        /** {@inheritDoc} */
+        @Override
+        public boolean hasNext() {
+            return cursor < size;
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public Pair<K, V> next() {
+            if (cursor >= size) {
+                throw new NoSuchElementException();
+            }
+
+            return get(cursor++);
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java
new file mode 100644
index 0000000..e7e8c72
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Platform.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.util.Properties;
+import org.apache.mxnet.util.cuda.CudaUtils;
+
+/**
+ * The platform contains information regarding the version, os, and build flavor of the MXNet native
+ * code.
+ */
+public final class Platform {
+
+    private String version;
+    private String osPrefix;
+    private String flavor;
+    private String cudaArch;
+    private String[] libraries;
+    private boolean placeholder;
+
+    /** Constructor used only for {@link Platform#fromSystem()}. */
+    private Platform() {}
+
+    /**
+     * Returns the platform parsed from an {@code <engine>.properties} file.
+     *
+     * @param url the url to the {@code <engine>.properties} file
+     * @return the system platform overlaid with the values read from the properties file
+     * @throws IOException if the file could not be read
+     * @throws IllegalArgumentException if the version is missing or the classifier is malformed
+     */
+    public static Platform fromUrl(URL url) throws IOException {
+        Platform platform = Platform.fromSystem();
+        try (InputStream conf = url.openStream()) {
+            Properties prop = new Properties();
+            prop.load(conf);
+            platform.version = prop.getProperty("version"); // expected since 1.6.0
+            if (platform.version == null) {
+                throw new IllegalArgumentException(
+                        "version key is required in <engine>.properties file.");
+            }
+            platform.placeholder = prop.getProperty("placeholder") != null;
+            String classifier = prop.getProperty("classifier", "");
+            String libraryList = prop.getProperty("libraries", "");
+            platform.libraries = libraryList.isEmpty() ? new String[0] : libraryList.split(",");
+            if (!classifier.isEmpty()) {
+                String[] tokens = classifier.split("-");
+                if (tokens.length < 2) {
+                    throw new IllegalArgumentException("Invalid classifier: " + classifier);
+                }
+                platform.flavor = tokens[0];
+                platform.osPrefix = tokens[1];
+            }
+        }
+        return platform;
+    }
+
+    /**
+     * Returns the platform for the current system.
+     *
+     * @return the platform for the current system
+     */
+    public static Platform fromSystem() {
+        Platform platform = new Platform();
+        String osName = System.getProperty("os.name");
+        if (osName.startsWith("Win")) {
+            platform.osPrefix = "win";
+        } else if (osName.startsWith("Mac")) {
+            platform.osPrefix = "osx";
+        } else if (osName.startsWith("Linux")) {
+            platform.osPrefix = "linux";
+        } else {
+            throw new AssertionError(String.format("Unsupported platform: %s", osName));
+        }
+        if (CudaUtils.getGpuCount() > 0) {
+            platform.flavor = "cu" + CudaUtils.getCudaVersionString();
+            platform.cudaArch = CudaUtils.getComputeCapability(0);
+        } else {
+            platform.flavor = "";
+        }
+        return platform;
+    }
+
+    /**
+     * Returns the Engine Version.
+     *
+     * @return the Engine version, or {@code null} if not loaded from a properties file
+     */
+    public String getVersion() {
+        return version;
+    }
+
+    /**
+     * Returns the os (win, osx, or linux).
+     *
+     * @return the os (win, osx, or linux)
+     */
+    public String getOsPrefix() {
+        return osPrefix;
+    }
+
+    /**
+     * Returns the MXNet build flavor.
+     *
+     * @return the MXNet build flavor
+     */
+    public String getFlavor() {
+        return flavor;
+    }
+
+    /**
+     * Returns the classifier for the platform.
+     *
+     * @return the os prefix followed by {@code -x86_64}
+     */
+    public String getClassifier() {
+        return getOsPrefix() + "-x86_64";
+    }
+
+    /**
+     * Returns the cuda arch.
+     *
+     * @return the cuda arch, or {@code null} if no GPU was detected
+     */
+    public String getCudaArch() {
+        return cudaArch;
+    }
+
+    /**
+     * Returns the libraries used in the platform.
+     *
+     * @return the libraries, or {@code null} if not loaded from a properties file
+     */
+    public String[] getLibraries() {
+        return libraries;
+    }
+
+    /**
+     * Returns true if the platform is a placeholder.
+     *
+     * @return true if the platform is a placeholder
+     */
+    public boolean isPlaceholder() {
+        return placeholder;
+    }
+
+    /**
+     * Returns true if the platforms match (os and flavor).
+     *
+     * @param system the platform to compare it to
+     * @return true if the platforms match
+     */
+    public boolean matches(Platform system) {
+        if (!osPrefix.equals(system.osPrefix)) {
+            return false;
+        }
+        // if system Machine is GPU
+        if (system.flavor.startsWith("cu")) {
+            // system flavor doesn't contain mkl, but MXNet has: cu110mkl
+            return "".equals(flavor)
+                    || "cpu".equals(flavor)
+                    || "mkl".equals(flavor)
+                    || flavor.startsWith(system.flavor);
+        }
+        return "".equals(flavor) || "cpu".equals(flavor) || "mkl".equals(flavor);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java
new file mode 100644
index 0000000..e6d46df
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Progress.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+/** An interface that allows tracking the progress of a task. */
+public interface Progress {
+
+ /**
+ * Resets the progress tracking indicators with a {@code null} trailing message.
+ *
+ * @param message the message to be shown
+ * @param max the max value that the progress tracking indicator can take
+ */
+ default void reset(String message, long max) {
+ reset(message, max, null);
+ }
+
+ /**
+ * Resets the progress tracking indicators, and sets the message and max to the given values.
+ *
+ * @param message the message to be shown
+ * @param max the max value that the progress tracking indicator can take
+ * @param trailingMessage the trailing message to be shown; the two-argument overload passes null
+ */
+ void reset(String message, long max, String trailingMessage);
+
+ /**
+ * Starts tracking the progress of the progress tracking indicators at the given initial value.
+ *
+ * @param initialProgress the initial value of the progress
+ */
+ void start(long initialProgress);
+
+ /** Updates the tracking indicators to indicate that the task is complete. */
+ void end();
+
+ /**
+ * Increments the progress tracking indicator by the given value.
+ *
+ * @param increment the value to increment the progress by
+ */
+ void increment(long increment);
+
+ /**
+ * Updates the progress tracking indicator to the given value with a {@code null} message.
+ *
+ * @param progress the value of the progress tracking indicator
+ */
+ default void update(long progress) {
+ update(progress, null);
+ }
+
+ /**
+ * Updates the progress tracking indicator to the given value, and displays the optional
+ * message.
+ *
+ * @param progress the value of the progress tracking indicator
+ * @param message the optional message; the one-argument overload passes null
+ */
+ void update(long progress, String message);
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java
new file mode 100644
index 0000000..ce5cc80
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/StandardCapabilities.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+/** Constant definitions for the standard capability. */
+public final class StandardCapabilities {
+ // Each constant's value is the same text as its identifier.
+ public static final String CUDA = "CUDA";
+ public static final String CUDNN = "CUDNN";
+ public static final String MKL = "MKL";
+ public static final String MKLDNN = "MKLDNN";
+ public static final String OPENMP = "OPENMP";
+
+ private StandardCapabilities() {} // holds constants only; never instantiated
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java
new file mode 100644
index 0000000..35a3cc5
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/Utils.java
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.io.ByteArrayOutputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Scanner;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+/** A class containing utility methods. */
+public final class Utils {
+
+    private static final int BUFF_SIZE = 81920;
+    private static final String ENGINE_CACHE_DIR = "ENGINE_CACHE_DIR";
+    private static final String MXNET_CACHE_DIR = "MXNET_CACHE_DIR";
+
+    private Utils() {}
+
+    /**
+     * Returns the index of the first occurrence of the specified element in {@code array}, or -1 if
+     * the array is {@code null} or does not contain the element.
+     *
+     * @param array the input array
+     * @param value the element to search for
+     * @param <T> the array type
+     * @return the index of the first occurrence of the specified element in {@code array}, or -1 if
+     *     the array is {@code null} or does not contain the element
+     */
+    public static <T> int indexOf(T[] array, T value) {
+        if (array != null) {
+            if (value == null) {
+                for (int i = 0; i < array.length; ++i) {
+                    if (array[i] == null) {
+                        return i;
+                    }
+                }
+            } else {
+                for (int i = 0; i < array.length; ++i) {
+                    if (value.equals(array[i])) {
+                        return i;
+                    }
+                }
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Returns {@code true} if the {@code array} contains the specified element.
+     *
+     * @param array the input array
+     * @param value the element whose presence in {@code array} is to be tested
+     * @param <T> the array type
+     * @return {@code true} if the array contains the specified element
+     */
+    public static <T> boolean contains(T[] array, T value) {
+        return indexOf(array, value) >= 0;
+    }
+
+    /**
+     * Adds padding chars to specified StringBuilder.
+     *
+     * @param sb the StringBuilder to append
+     * @param c the padding char
+     * @param count the number characters to be added
+     */
+    public static void pad(StringBuilder sb, char c, int count) {
+        for (int i = 0; i < count; ++i) {
+            sb.append(c);
+        }
+    }
+
+    /**
+     * Deletes an entire directory and ignore all errors.
+     *
+     * @param dir the directory to be removed
+     */
+    public static void deleteQuietly(Path dir) {
+        // try-with-resources closes the directory stream that Files.walk opens.
+        try (java.util.stream.Stream<Path> paths = Files.walk(dir)) {
+            paths.sorted(Comparator.reverseOrder())
+                    .forEach(
+                            path -> {
+                                try {
+                                    Files.deleteIfExists(path);
+                                } catch (IOException ignore) {
+                                    // ignore
+                                }
+                            });
+        } catch (IOException | java.io.UncheckedIOException ignore) {
+            // ignore
+        }
+    }
+
+    /**
+     * Renames a file to a target file and ignore error if target already exists.
+     *
+     * @param source the path to the file to move
+     * @param target the path to the target file
+     * @throws IOException if move file failed
+     */
+    public static void moveQuietly(Path source, Path target) throws IOException {
+        try {
+            Files.move(source, target, StandardCopyOption.ATOMIC_MOVE);
+        } catch (IOException e) {
+            if (!Files.exists(target)) {
+                throw e;
+            }
+        }
+    }
+
+    /**
+     * Reads {@code is} as UTF-8 string.
+     *
+     * @param is the InputStream to be read
+     * @return a UTF-8 encoded string
+     * @throws IOException if IO error occurs
+     */
+    public static String toString(InputStream is) throws IOException {
+        return new String(toByteArray(is), StandardCharsets.UTF_8);
+    }
+
+    /**
+     * Reads {@code is} as byte array.
+     *
+     * @param is the InputStream to be read
+     * @return a byte array
+     * @throws IOException if IO error occurs
+     */
+    public static byte[] toByteArray(InputStream is) throws IOException {
+        // Copy in BUFF_SIZE chunks until the stream reports end of input.
+        try (ByteArrayOutputStream bos = new ByteArrayOutputStream(BUFF_SIZE)) {
+            byte[] buf = new byte[BUFF_SIZE];
+            int read;
+            while ((read = is.read(buf)) != -1) {
+                bos.write(buf, 0, read);
+            }
+            return bos.toByteArray();
+        }
+    }
+
+    /**
+     * Reads all lines from a file.
+     *
+     * @param file the file to be read
+     * @return all lines in the file
+     * @throws IOException if read file failed
+     */
+    public static List<String> readLines(Path file) throws IOException {
+        return readLines(file, false);
+    }
+
+    /**
+     * Reads all lines from a file.
+     *
+     * @param file the file to be read
+     * @param trim true if you want to trim the line and exclude empty lines
+     * @return all lines in the file, or an empty list if the file does not exist
+     * @throws IOException if read file failed
+     */
+    public static List<String> readLines(Path file, boolean trim) throws IOException {
+        if (Files.notExists(file)) {
+            return Collections.emptyList();
+        }
+        try (InputStream is = Files.newInputStream(file)) {
+            return readLines(is, trim);
+        }
+    }
+
+    /**
+     * Reads all lines from the specified InputStream.
+     *
+     * @param is the InputStream to read
+     * @return all lines from the input
+     */
+    public static List<String> readLines(InputStream is) {
+        return readLines(is, false);
+    }
+
+    /**
+     * Reads all lines from the specified InputStream.
+     *
+     * @param is the InputStream to read
+     * @param trim true if you want to trim the line and exclude empty lines
+     * @return all lines from the input
+     */
+    public static List<String> readLines(InputStream is, boolean trim) {
+        List<String> list = new ArrayList<>();
+        try (Scanner scanner =
+                new Scanner(is, StandardCharsets.UTF_8.name()).useDelimiter("\\n|\\r\\n")) {
+            while (scanner.hasNext()) {
+                String line = scanner.next();
+                if (trim) {
+                    line = line.trim();
+                    if (line.isEmpty()) {
+                        continue;
+                    }
+                }
+                list.add(line);
+            }
+        }
+        return list;
+    }
+
+    /**
+     * Converts a List of Number to float array.
+     *
+     * @param list the list to be converted
+     * @return a float array
+     */
+    public static float[] toFloatArray(List<? extends Number> list) {
+        float[] ret = new float[list.size()];
+        int idx = 0;
+        for (Number n : list) {
+            ret[idx++] = n.floatValue();
+        }
+        return ret;
+    }
+
+    /**
+     * Gets the latest epoch number among the {@code <modelName>-NNNN.params} files.
+     *
+     * @param modelDir the path to the directory where the model files are stored
+     * @param modelName the name of the model
+     * @return the largest four-digit epoch number found directly under {@code modelDir}
+     * @throws IOException if an I/O error occurs
+     * @throws FileNotFoundException if no matched parameter file with epoch number is found
+     */
+    public static int getCurrentEpoch(Path modelDir, String modelName) throws IOException {
+        final Pattern pattern = Pattern.compile(Pattern.quote(modelName) + "-(\\d{4})\\.params");
+        List<Integer> checkpoints;
+        // Files.walk keeps a directory handle open until the stream is closed.
+        try (java.util.stream.Stream<Path> files = Files.walk(modelDir, 1)) {
+            checkpoints =
+                    files.map(
+                                    p -> {
+                                        Matcher m = pattern.matcher(p.toFile().getName());
+                                        return m.matches() ? Integer.valueOf(m.group(1)) : null;
+                                    })
+                            .filter(Objects::nonNull)
+                            .sorted()
+                            .collect(Collectors.toList());
+        }
+        if (checkpoints.isEmpty()) {
+            throw new FileNotFoundException(
+                    String.format(
+                            "No matched params file is found in directory: %s for model %s",
+                            modelDir.toAbsolutePath(),
+                            modelName));
+        }
+        return checkpoints.get(checkpoints.size() - 1);
+    }
+
+    /**
+     * Utility function to help debug nan values in parameters and their gradients.
+     *
+     * @param parameters the list of parameters to check
+     * @param checkGradient whether to check parameter value or its gradient value
+     * @param logger the logger to log the result
+     */
+    // TODO
+    //    public static void checkParameterValues(
+    //            Pairlist<String, Parameter> parameters, boolean checkGradient, Logger logger) {
+    //
+    //    }
+
+    /**
+     * Utility function to help summarize the values in an {@link NDArray}.
+     *
+     * @param array the {@link NDArray} to be summarized
+     * @param logger the logger to log the result
+     * @param prefix the prefix or name to be displayed
+     */
+    // TODO
+    //    public static void checkNDArrayValues(NDArray array, Logger logger, String prefix) {
+    //
+    //    }
+
+    /**
+     * Utility function to get Engine specific cache directory.
+     *
+     * @param engine the engine name
+     * @return the {@code engine} sub-directory of the engine cache directory
+     */
+    public static Path getEngineCacheDir(String engine) {
+        return getEngineCacheDir().resolve(engine);
+    }
+
+    /**
+     * Utility function to get Engine cache directory.
+     *
+     * @return ENGINE_CACHE_DIR (system property, then env), else {@link #getCacheDir()}
+     */
+    public static Path getEngineCacheDir() {
+        String cacheDir = System.getProperty(ENGINE_CACHE_DIR);
+        if (cacheDir == null || cacheDir.isEmpty()) {
+            cacheDir = System.getenv(ENGINE_CACHE_DIR);
+            if (cacheDir == null || cacheDir.isEmpty()) {
+                return getCacheDir();
+            }
+        }
+        return Paths.get(cacheDir);
+    }
+
+    /**
+     * Utility function to get the MXNet Java package cache directory.
+     *
+     * @return MXNET_CACHE_DIR (system property, then env), else a default under the user home
+     */
+    public static Path getCacheDir() {
+        String cacheDir = System.getProperty(MXNET_CACHE_DIR);
+        if (cacheDir == null || cacheDir.isEmpty()) {
+            cacheDir = System.getenv(MXNET_CACHE_DIR);
+            if (cacheDir == null || cacheDir.isEmpty()) {
+                Path dir = Paths.get(System.getProperty("user.home"));
+                if (!Files.isWritable(dir)) {
+                    dir = Paths.get(System.getProperty("java.io.tmpdir"));
+                }
+                return dir.resolve("mxnet.java_package");
+            }
+        }
+        return Paths.get(cacheDir);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java
new file mode 100644
index 0000000..44214a9
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/ZipUtils.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
+/** Utilities for working with zip files. */
+public final class ZipUtils {
+
+ private ZipUtils() {}
+
+ /**
+ * Unzips an input stream to a given path, rejecting entries that would land outside of it.
+ *
+ * @param is the input stream to unzip
+ * @param dest the path to store the unzipped files
+ * @throws IOException if an entry escapes dest (".." or absolute name) or extraction fails
+ */
+ public static void unzip(InputStream is, Path dest) throws IOException {
+ ZipInputStream zis = new ZipInputStream(is);
+ ZipEntry entry;
+ while ((entry = zis.getNextEntry()) != null) {
+ String name = entry.getName();
+ Path file = dest.resolve(name).toAbsolutePath().normalize();
+ if (name.contains("..") || !file.startsWith(dest.toAbsolutePath().normalize())) {
+ throw new IOException("Malicious zip entry: " + name);
+ }
+ if (entry.isDirectory()) {
+ Files.createDirectories(file);
+ } else {
+ Path parentFile = file.getParent();
+ if (parentFile == null) {
+ throw new AssertionError(
+ "Parent path should never be null: " + file.toString());
+ }
+ Files.createDirectories(parentFile);
+ Files.copy(zis, file, StandardCopyOption.REPLACE_EXISTING);
+ }
+ }
+ }
+
+ /**
+ * Zips an input directory to a given file.
+ *
+ * @param src the input directory to zip
+ * @param dest the path to store the zipped files
+ * @param includeFolderName whether to include the source directory name in the zip entries
+ * @throws IOException for failures to zip the input directory
+ */
+ public static void zip(Path src, Path dest, boolean includeFolderName) throws IOException {
+ try (ZipOutputStream zos = new ZipOutputStream(Files.newOutputStream(dest))) {
+ Path root = includeFolderName ? src.getParent() : src;
+ if (root == null) {
+ throw new AssertionError("Parent folder should not be null.");
+ }
+ addToZip(root, src, zos);
+ }
+ }
+
+ private static void addToZip(Path root, Path file, ZipOutputStream zos) throws IOException {
+ // Zip entry names must use '/' as the separator on every platform.
+ String name = root.relativize(file).toString().replace(File.separatorChar, '/');
+ if (Files.isDirectory(file)) {
+ if (!name.isEmpty()) {
+ ZipEntry entry = new ZipEntry(name + '/');
+ zos.putNextEntry(entry);
+ }
+ File[] files = file.toFile().listFiles();
+ if (files != null) {
+ for (File f : files) {
+ addToZip(root, f.toPath(), zos);
+ }
+ }
+ } else if (Files.isRegularFile(file)) {
+ if (name.isEmpty()) {
+ name = file.toFile().getName();
+ }
+ ZipEntry entry = new ZipEntry(name);
+ zos.putNextEntry(entry);
+ Files.copy(file, zos);
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java
new file mode 100644
index 0000000..abd5ba2
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaLibrary.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util.cuda;
+
+import com.sun.jna.Library;
+
+/**
+ * {@code CudaLibrary} contains methods mapping to CUDA runtime API.
+ *
+ * <p>see: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
+ */
+public interface CudaLibrary extends Library {
+
+ int INITIALIZATION_ERROR = 3;
+ int INSUFFICIENT_DRIVER = 35;
+ int ERROR_NO_DEVICE = 100;
+ int ERROR_NOT_PERMITTED = 800;
+
+ /**
+ * Gets the number of devices with compute capability greater or equal to 1.0 that are available
+ * for execution.
+ *
+ * @param deviceCount the returned device count
+ * @return CUDA runtime API error code
+ */
+ int cudaGetDeviceCount(int[] deviceCount);
+
+ /**
+ * Returns the version number of the installed CUDA Runtime.
+ *
+ * @param runtimeVersion output buffer of runtime version number
+ * @return CUDA runtime API error code
+ */
+ int cudaRuntimeGetVersion(int[] runtimeVersion);
+
+ /**
+ * Gets the integer value of the attribute {@code attr} on device.
+ *
+ * @param pi the returned device attribute value
+ * @param attr the device attribute to query
+ * @param device the GPU device to retrieve
+ * @return CUDA runtime API error code
+ */
+ int cudaDeviceGetAttribute(int[] pi, int attr, int device);
+
+ /**
+ * Gets free and total device memory.
+ *
+ * @param free the returned free memory in bytes
+ * @param total the returned total memory in bytes
+ * @return CUDA runtime API error code
+ */
+ int cudaMemGetInfo(long[] free, long[] total);
+
+ /**
+ * Set device to be used for GPU executions.
+ *
+ * @param device the ordinal of the GPU device to make current
+ * @return CUDA runtime API error code
+ */
+ int cudaSetDevice(int device);
+
+ /**
+ * Gets which device is currently being used.
+ *
+ * @param device the returned current device
+ * @return CUDA runtime API error code
+ */
+ int cudaGetDevice(int[] device);
+
+ /**
+ * Returns the description string for an error code.
+ *
+ * @param code the CUDA error code to convert to string
+ * @return the description string for an error code
+ */
+ String cudaGetErrorString(int code);
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java
new file mode 100644
index 0000000..d79967d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/CudaUtils.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util.cuda;
+
+import com.sun.jna.Native;
+import java.io.File;
+import java.lang.management.MemoryUsage;
+import java.util.regex.Pattern;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.exception.JnaCallException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A class containing CUDA utility methods. */
+public final class CudaUtils {
+
+ private static final Logger logger = LoggerFactory.getLogger(CudaUtils.class);
+
+ private static final CudaLibrary LIB = loadLibrary();
+
+ private static int gpuCount = -1;
+
+ private CudaUtils() {}
+
+ /**
+ * Gets whether at least one CUDA GPU device is usable on this system.
+ *
+ * @return {@code true} if {@link #getGpuCount()} reports one or more GPUs
+ */
+ public static boolean hasCuda() {
+ return getGpuCount() > 0;
+ }
+
+ /**
+ * Returns the number of GPUs available in the system.
+ *
+ * @return the number of GPUs available in the system
+ */
+ public static int getGpuCount() {
+
+ if (gpuCount != -1) {
+ return gpuCount;
+ }
+
+ try {
+ validateLibrary();
+ } catch (IllegalStateException e) {
+ return 0;
+ }
+ int[] count = new int[1];
+ int result = LIB.cudaGetDeviceCount(count);
+ switch (result) {
+ case 0:
+ gpuCount = count[0];
+ return gpuCount;
+ case CudaLibrary.ERROR_NO_DEVICE:
+ logger.debug(
+ "No GPU device found: {} ({})", LIB.cudaGetErrorString(result), result);
+ gpuCount = 0;
+ return gpuCount;
+ case CudaLibrary.INITIALIZATION_ERROR:
+ case CudaLibrary.INSUFFICIENT_DRIVER:
+ case CudaLibrary.ERROR_NOT_PERMITTED:
+ default:
+ logger.warn(
+ "Failed to detect GPU count: {} ({})",
+ LIB.cudaGetErrorString(result),
+ result);
+ gpuCount = 0;
+ return gpuCount;
+ }
+ }
+
+ /**
+ * Returns the version of CUDA runtime.
+ *
+ * @return the version of CUDA runtime
+ */
+ public static int getCudaVersion() {
+ validateLibrary();
+ int[] version = new int[1];
+ int result = LIB.cudaRuntimeGetVersion(version);
+ checkCall(result);
+ return version[0];
+ }
+
+ /**
+ * Returns the version string of CUDA runtime.
+ *
+ * @return the version string of CUDA runtime
+ */
+ public static String getCudaVersionString() {
+ validateLibrary();
+ int version = getCudaVersion();
+ int major = version / 1000;
+ int minor = (version / 10) % 10;
+ return String.valueOf(major) + minor;
+ }
+
+ /**
+ * Returns the CUDA compute capability.
+ *
+ * @param device the ordinal of the GPU device to query
+ * @return the compute capability as major digits followed by minor digits, e.g. {@code "75"}
+ */
+ public static String getComputeCapability(int device) {
+ validateLibrary();
+ int attrComputeCapabilityMajor = 75;
+ int attrComputeCapabilityMinor = 76;
+
+ int[] major = new int[1];
+ int[] minor = new int[1];
+ checkCall(LIB.cudaDeviceGetAttribute(major, attrComputeCapabilityMajor, device));
+ checkCall(LIB.cudaDeviceGetAttribute(minor, attrComputeCapabilityMinor, device));
+ // Concatenate the digits (7 and 5 -> "75"); adding the ints would yield "12".
+ return String.valueOf(major[0]) + minor[0];
+ }
+
+ /**
+ * Returns the {@link MemoryUsage} of the specified GPU device.
+ *
+ * @param device the GPU {@link Device} to retrieve
+ * @return the {@link MemoryUsage} of the specified GPU device
+ * @throws IllegalArgumentException if {@link Device} is not GPU device or does not exist
+ */
+ public static MemoryUsage getGpuMemory(Device device) {
+ if (!Device.Type.GPU.equals(device.getDeviceType())) {
+ throw new IllegalArgumentException("Only GPU device is allowed.");
+ }
+
+ validateLibrary("No GPU device detected.");
+
+ int[] currentDevice = new int[1];
+ checkCall(LIB.cudaGetDevice(currentDevice));
+ checkCall(LIB.cudaSetDevice(device.getDeviceId()));
+
+ long[] free = new long[1];
+ long[] total = new long[1];
+
+ checkCall(LIB.cudaMemGetInfo(free, total));
+ checkCall(LIB.cudaSetDevice(currentDevice[0]));
+
+ long committed = total[0] - free[0];
+ return new MemoryUsage(-1, committed, committed, total[0]);
+ }
+
+ private static CudaLibrary loadLibrary() {
+ try {
+ if (System.getProperty("os.name").startsWith("Win")) {
+ String path = System.getenv("PATH");
+ if (path == null) {
+ return null;
+ }
+ // Match the versioned CUDA runtime DLL, e.g. cudart64_110.dll
+ Pattern p = Pattern.compile("cudart64_\\d+\\.dll");
+ String cudaPath = System.getenv("CUDA_PATH");
+ String[] searchPath = getPathArray(path, cudaPath);
+
+ for (String item : searchPath) {
+ File dir = new File(item);
+ File[] files = dir.listFiles(n -> p.matcher(n.getName()).matches());
+ if (files != null && files.length > 0) {
+ String fileName = files[0].getName();
+ String cudaRT = fileName.substring(0, fileName.length() - 4);
+ logger.debug("Found cudart: {}", files[0].getAbsolutePath());
+ return Native.load(cudaRT, CudaLibrary.class);
+ }
+ }
+ logger.debug("No cudart library found in path.");
+ return null;
+ }
+ return Native.load("cudart", CudaLibrary.class);
+ } catch (UnsatisfiedLinkError e) {
+ logger.debug("cudart library not found.");
+ logger.trace("", e);
+ return null;
+ }
+ }
+
+ private static String[] getPathArray(String path, String cudaPath) {
+ if (cudaPath == null) {
+ return path.split(";");
+ }
+ // Search %CUDA_PATH%\bin first, then every PATH entry.
+ return String.format("%s\\bin\\;%s", cudaPath, path).split(";");
+ }
+
+ private static void checkCall(int ret) {
+ validateLibrary();
+ if (ret != 0) {
+ throw new JnaCallException(
+ String.format(
+ "CUDA API call failed: %s (%d)", LIB.cudaGetErrorString(ret), ret));
+ }
+ }
+
+ private static void validateLibrary() {
+ if (LIB == null) {
+ throw new IllegalStateException("No cuda library is loaded.");
+ }
+ }
+
+ private static void validateLibrary(String msg) {
+ if (msg == null) {
+ validateLibrary();
+ } else if (LIB == null) {
+ throw new IllegalStateException(msg);
+ }
+ }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java
new file mode 100644
index 0000000..e34e298
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/cuda/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains CUDA runtime utilities used by the Java front-end for Apache MXNet. */
+package org.apache.mxnet.util.cuda;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java
new file mode 100644
index 0000000..59260e2
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/util/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains utility classes used by the Java front-end for Apache MXNet. */
+package org.apache.mxnet.util;
diff --git a/java-package/mxnet-engine/src/main/jna/mapping.properties b/java-package/mxnet-engine/src/main/jna/mapping.properties
new file mode 100644
index 0000000..8a770cc
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/jna/mapping.properties
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+MXNDArraySaveRawBytes.out_buf = PointerByReference
+MXNDArraySave.args = PointerArray
+MXInvokeCachedOp.inputs = Pointer
+MXInvokeCachedOpEx.inputs = Pointer
+MXInvokeCachedOpEX.inputs = Pointer
+MXImperativeInvoke.inputs = PointerArray
+MXImperativeInvokeEx.inputs = PointerArray
+MXImperativeInvokeEx.param_keys = StringArray
+MXImperativeInvokeEx.param_vals = StringArray
+MXKVStoreInit.vals = PointerArray
+MXKVStoreInitEx.vals = PointerArray
+MXKVStorePush.vals = PointerArray
+MXKVStorePushEx.vals = PointerArray
+MXKVStorePull.vals = PointerArray
+MXKVStorePullEx.vals = PointerArray
+MXKVStorePushPullEx.vals = PointerArray
+MXKVStorePushPullEx.outs = PointerArray
+MXAutogradBackwardEx.output_handles = PointerArray
+MXAutogradBackwardEx.ograd_handles = PointerArray
+MXAutogradBackward.output_handles = PointerArray
diff --git a/java-package/native/build.gradle b/java-package/native/build.gradle
new file mode 100644
index 0000000..c3c91e0
--- /dev/null
+++ b/java-package/native/build.gradle
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+ id 'maven-publish'
+ id 'signing'
+}
+
+group = "org.apache.mxnet"
+
+def VERSION = "2.0.0"
+boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
+version = VERSION + (isRelease ? "" : "-SNAPSHOT")
+
+task syncBuiltMxnetLib(type: Sync) {
+ from "${rootProject.projectDir.parent}/build"
+ into "${project.buildDir}/mxnet/native/lib"
+ include "libmxnet.*"
+}
+
+// Create mxnet native library jar without classifier
+jar {
+ def placeholder = "${project.buildDir}/placeholder"
+ // This line forces gradle to build the jar;
+ // otherwise it sometimes doesn't generate the placeholder jar
+ // when there is no java code inside src/main.
+ outputs.dir file("build/libs")
+ doFirst {
+ def versionName = project.version
+ if (!isRelease) {
+ versionName += String.format("-%s", new Date().format('yyyyMMdd'))
+ }
+ def dir = file("${placeholder}/native/lib")
+ dir.mkdirs()
+ def propFile = file("${placeholder}/native/lib/mxnet.properties")
+ propFile.text = "placeholder=true\nversion=${versionName}\n"
+ }
+
+ from placeholder
+}
+
+java {
+ withJavadocJar()
+ withSourcesJar()
+}
+
+project.tasks.withType(GenerateModuleMetadata) {
+ enabled = false
+}
+
+signing {
+ required(project.hasProperty("staging") || project.hasProperty("snapshot"))
+ def signingKey = findProperty("signingKey")
+ def signingPassword = findProperty("signingPassword")
+ sign publishing.publications
+}
+
+task buildLocalLibraryJarDefault() {
+ def flavor = "mkl"
+ def osName = getOsName()
+ buildLocalLibraryJar(flavor, osName)
+}
+
+def buildLocalLibraryJar(flavorName, osName) {
+ def BINARY_ROOT = "${project.buildDir}"
+ tasks.create(name: "${flavorName}-${osName}Jar", type: Jar) {
+ doFirst {
+ copyMxnetNativeLib(flavorName, osName)
+ def propFile = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib/mxnet.properties")
+ propFile.delete()
+ def dsStore = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib/.DS_Store")
+ dsStore.delete()
+
+ def versionName = String.format("${version}-%s", new Date().format('yyyyMMdd'))
+ def dir = file("${BINARY_ROOT}/${flavorName}/${osName}/native/lib")
+ def sb = new StringBuilder()
+ sb.append("version=${versionName}\nclassifier=${flavorName}-${osName}-x86_64\nlibraries=")
+ def first = true
+ for (String name : dir.list().sort()) {
+ if (first) {
+ first = false;
+ } else {
+ sb.append(',')
+ }
+ sb.append(name)
+ }
+ propFile.text = sb.toString()
+ def metaInf = new File("${BINARY_ROOT}/${flavorName}/${osName}/META-INF")
+ metaInf.mkdirs()
+ def licenseFile = new File(metaInf, "LICENSE")
+ licenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/LICENSE").text
+
+ def binaryLicenseFile = new File(metaInf, "LICENSE.binary.dependencies")
+ binaryLicenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/dependencies/LICENSE.binary.dependencies").text
+
+ from file("src/main/resources")
+ }
+ from file("${BINARY_ROOT}/${flavorName}/${osName}")
+ archiveClassifier = "${osName}-x86_64"
+
+ manifest {
+ attributes("Automatic-Module-Name": "org.apache.mxnet.mxnet_native_${flavorName}_${osName}")
+ }
+ }
+ return tasks["${flavorName}-${osName}Jar"]
+}
+
+// Maps the JVM os.name to the short OS tag used in the native jar classifier.
+def getOsName() {
+ def os_name = System.properties['os.name']
+ if (os_name.startsWith('Win')) {
+ return "win"
+ } else if (os_name.contains('Mac OS X')) {
+ return "osx"
+ } else if (os_name.contains('Linux')) {
+ return "linux"
+ } else {
+ return os_name
+ }
+}
+
+//def BINARY_ROOT = "${project.buildDir}/download"
+//def flavorNames = file(BINARY_ROOT).list() ?: []
+//flavorNames.each { flavor ->
+//
+// def platformNames = file("${BINARY_ROOT}/${flavor}").list() ?: []
+//
+// def artifactsNames = []
+//
+// platformNames.each { osName ->
+// tasks.create(name: "${flavor}-${osName}Jar", type: Jar) {
+// doFirst {
+// def propFile = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/mxnet.properties")
+// propFile.delete()
+// def dsStore = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/.DS_Store")
+// dsStore.delete()
+//
+// def versionName = String.format("${version}-%s", new Date().format('yyyyMMdd'))
+// def dir = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib")
+// def sb = new StringBuilder()
+// sb.append("version=${versionName}\nclassifier=${flavor}-${osName}-x86_64\nlibraries=")
+// def first = true
+// for (String name : dir.list().sort()) {
+// if (first) {
+// first = false;
+// } else {
+// sb.append(',')
+// }
+// sb.append(name)
+// }
+// propFile.text = sb.toString()
+// def metaInf = new File("${BINARY_ROOT}/${flavor}/${osName}/META-INF")
+// metaInf.mkdirs()
+// def licenseFile = new File(metaInf, "LICENSE")
+// licenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/LICENSE").text
+//
+// def binaryLicenseFile = new File(metaInf, "LICENSE.binary.dependencies")
+// binaryLicenseFile.text = new URL("https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/dependencies/LICENSE.binary.dependencies").text
+//
+// from file("src/main/resources")
+// }
+// from file("${BINARY_ROOT}/${flavor}/${osName}")
+// archiveClassifier = "${osName}-x86_64"
+//
+// manifest {
+// attributes("Automatic-Module-Name": "org.apache.mxnet.mxnet_native_${flavor}_${osName}")
+// }
+// }
+// artifactsNames.add(tasks["${flavor}-${osName}Jar"])
+// }
+
+ // Only publish if the project directory equals the current directory
+ // This means that publishing from the main project does not publish the native jars
+ // and the native jars have to be published separately
+ // TODO publish info
+// if (project.getProjectDir().toString() == System.getProperty("user.dir")) {
+// publishing.publications.create("${flavor}", MavenPublication) {
+// artifactId "mxnet-native-${flavor}"
+// from components.java
+// artifacts = artifactsNames
+// artifact jar
+// artifact javadocJar
+// artifact sourcesJar
+// pom {
+// name = "DJL release for Apache MXNet native binaries"
+// description = "Deep Java Library (DJL) provided Apache MXNet native library binary distribution"
+// url = "http://www.djl.ai/mxnet/native"
+// packaging = "jar"
+//
+// licenses {
+// license {
+// name = 'The Apache License, Version 2.0'
+// url = 'https://www.apache.org/licenses/LICENSE-2.0'
+// }
+// }
+//
+// scm {
+// connection = "scm:git:git@github.com:deepjavalibrary/djl.git"
+// developerConnection = "scm:git:git@github.com:deepjavalibrary/djl.git"
+// url = "https://github.com/deepjavalibrary/djl"
+// tag = "HEAD"
+// }
+//
+// developers {
+// developer {
+// name = "DJL.AI Team"
+// email = "djl-dev@amazon.com"
+// organization = "Amazon AI"
+// organizationUrl = "https://amazon.com"
+// }
+// }
+// }
+// }
+// }
+//}
+
+//publishing.repositories {
+// maven {
+// if (project.hasProperty("snapshot")) {
+// name = "snapshot"
+// url = "https://oss.sonatype.org/content/repositories/snapshots/"
+// credentials {
+// username = findProperty("ossrhUsername")
+// password = findProperty("ossrhPassword")
+// }
+// } else if (project.hasProperty("staging")) {
+// name = "staging"
+// url = "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
+// credentials {
+// username = findProperty("ossrhUsername")
+// password = findProperty("ossrhPassword")
+// }
+// } else {
+// name = "local"
+// url = "build/repo"
+// }
+// }
+//}
+
+import java.util.zip.GZIPInputStream
+task copyMxnetNativeLibDefault() {
+ copyMxnetNativeLib("mkl", getOsName())
+}
+
+def copyMxnetNativeLib(flavorName, osName) {
+ // TODO: only mkl considered here
+ copy {
+ from("${rootProject.projectDir.parent}/build")
+ into("${project.buildDir}/${flavorName}/${osName}/native/lib")
+ // TODO: load map (flavor-os -> lib name) from configure file
+ switch (osName + "-" + flavorName) {
+ case "osx-mkl":
+ include "libmxnet.dylib"
+ break
+ case "win-common":
+ include "libgcc_s_seh-1.dll"
+ include "libgfortran-3.dll"
+ include "libopenblas.dll"
+ include "libquadmath-0.dll"
+ break
+ case "win-mkl":
+ include "mxnet.dll"
+ break
+ case "linux-common":
+ include "libgfortran.so.4"
+ include "libgomp.so.1"
+ include "libopenblas.so.0"
+ include "libquadmath.so.0"
+ break
+ case "linux-mkl":
+ case "linux-cu102mkl":
+ case "linux-cu110mkl":
+ include "libmxnet.so"
+ break
+ default:
+ include ""
+ }
+ }
+}
+
+//task downloadMxnetNativeLib() {
+// doLast {
+// def url = "https://publish.djl.ai/mxnet-${VERSION}"
+// def files = [
+//// "linux/common/libgfortran.so.4.gz": "mkl/linux/native/lib/libgfortran.so.4",
+// "linux/common/libgomp.so.1.gz" : "mkl/linux/native/lib/libgomp.so.1",
+// "linux/common/libopenblas.so.0.gz": "mkl/linux/native/lib/libopenblas.so.0",
+// "linux/common/libquadmath.so.0.gz": "mkl/linux/native/lib/libquadmath.so.0",
+// "linux/mkl/libmxnet.so.gz" : "mkl/linux/native/lib/libmxnet.so",
+// "linux/cu102mkl/libmxnet.so.gz" : "cu102mkl/linux/native/lib/libmxnet.so",
+// "linux/cu110mkl/libmxnet.so.gz" : "cu110mkl/linux/native/lib/libmxnet.so",
+// "osx/mkl/libmxnet.dylib.gz" : "mkl/osx/native/lib/libmxnet.dylib",
+// "win/common/libgcc_s_seh-1.dll.gz": "mkl/win/native/lib/libgcc_s_seh-1.dll",
+// "win/common/libgfortran-3.dll.gz" : "mkl/win/native/lib/libgfortran-3.dll",
+// "win/common/libopenblas.dll.gz" : "mkl/win/native/lib/libopenblas.dll",
+// "win/common/libquadmath-0.dll.gz" : "mkl/win/native/lib/libquadmath-0.dll",
+// "win/mkl/libmxnet.dll.gz" : "mkl/win/native/lib/mxnet.dll"
+// ]
+//
+// files.each { entry ->
+// project.logger.lifecycle("Downloading ${url}/${entry.key}")
+// def file = new File("${BINARY_ROOT}/${entry.value}")
+// file.getParentFile().mkdirs()
+// new URL("${url}/${entry.key}").withInputStream { i -> file.withOutputStream { it << new GZIPInputStream(i) } }
+// }
+//
+// copy {
+// from("${BINARY_ROOT}/mkl/linux/native/lib") {
+// exclude '**/libmxnet.so'
+// }
+// into("${BINARY_ROOT}/cu102mkl/linux/native/lib")
+// }
+// copy {
+// from("${BINARY_ROOT}/mkl/linux/native/lib") {
+// exclude '**/libmxnet.so'
+// }
+// into("${BINARY_ROOT}/cu110mkl/linux/native/lib")
+// }
+//
+// new File("${BINARY_ROOT}/auto").mkdirs()
+// }
+//}
diff --git a/java-package/native/src/main/resources/META-INF/.gitkeep b/java-package/native/src/main/resources/META-INF/.gitkeep
new file mode 100644
index 0000000..d216be4
--- /dev/null
+++ b/java-package/native/src/main/resources/META-INF/.gitkeep
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
\ No newline at end of file
diff --git a/java-package/scripts/ci/start_integration_test.sh b/java-package/scripts/ci/start_integration_test.sh
new file mode 100644
index 0000000..d4c55f0
--- /dev/null
+++ b/java-package/scripts/ci/start_integration_test.sh
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+cd /work/mxnet
+python3 build.py -p ubuntu_cpu /work/mxnet/ci/docker/runtime_functions.sh java_package_integration_test
\ No newline at end of file
diff --git a/java-package/settings.gradle b/java-package/settings.gradle
new file mode 100644
index 0000000..54c778f
--- /dev/null
+++ b/java-package/settings.gradle
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+rootProject.name = 'org.apache.mxnet'
+
+include 'mxnet-engine'
+include 'native'
+include 'jnarator'
+include 'integration'
+include 'example'
+
diff --git a/java-package/tools/conf/checkstyle.xml b/java-package/tools/conf/checkstyle.xml
new file mode 100644
index 0000000..e4e4fe4
--- /dev/null
+++ b/java-package/tools/conf/checkstyle.xml
@@ -0,0 +1,521 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<!DOCTYPE module PUBLIC "-//Puppy Crawl//DTD Check Configuration 1.3//EN"
+ "http://www.puppycrawl.com/dtds/configuration_1_3.dtd">
+<module name="Checker">
+ <property name="charset" value="UTF-8"/>
+ <property name="fileExtensions" value="java, properties, xml"/>
+
+ <!-- Filters -->
+ <!--
+ <property name="severity" value="warning"/>
+ <module name="SeverityMatchFilter">
+ <property name="severity" value="ignore"/>
+ <property name="acceptOnMatch" value="false"/>
+ </module>
+ -->
+
+ <module name="SuppressionFilter">
+ <property name="file" value="${checkstyle.suppressions.file}"/>
+ </module>
+ <module name="SuppressWarningsFilter"/>
+
+ <!-- Headers -->
+ <module name="Header">
+ <property name="headerFile" value="${checkstyle.licenseHeader.file}"/>
+ <property name="fileExtensions" value="java"/>
+ <property name="ignoreLines" value="2"/>
+ <property name="id" value="header"/>
+ </module>
+ <!--
+ <module name="RegexpHeader">
+ <property name="headerFile" value="${checkstyle.regexp.header.file}"/>
+ <property name="fileExtensions" value="java"/>
+ </module>
+ -->
+
+ <!-- Miscellaneous -->
+ <module name="NewlineAtEndOfFile">
+ <property name="fileExtensions" value="java"/>
+ </module>
+ <module name="Translation">
+ <property name="fileExtensions" value="properties"/>
+ </module>
+ <module name="UniqueProperties"/>
+
+ <!-- Regexp -->
+ <!--
+ <module name="RegexpMultiline"/>
+ <module name="RegexpMultiline">
+ <property name="format" value="\r?\n[\t ]*\r?\n[\t ]*\r?\n"/>
+ <property name="fileExtensions" value="java,xml,properties"/>
+ <property name="message" value="Unnecessary consecutive lines"/>
+ </module>
+ <module name="RegexpMultiline">
+ <property name="format" value="/\*\*\W+\* +\p{javaLowerCase}"/>
+ <property name="fileExtensions" value="java"/>
+ <property name="message"
+ value="First sentence in a comment should start with a capital letter"/>
+ </module>
+ <module name="RegexpSingleline">
+ <property name="format" value="\s+$"/>
+ <property name="minimum" value="0"/>
+ <property name="maximum" value="0"/>
+ </module>
+ <module name="RegexpSingleline">
+ <property name="format" value="/\*\* +\p{javaLowerCase}"/>
+ <property name="fileExtensions" value="java"/>
+ <property name="message"
+ value="First sentence in a comment should start with a capital letter"/>
+ </module>
+ <module name="RegexpSingleline">
+ <property name="format" value="^(?!(.*http|import)).{101,}$"/>
+ <property name="fileExtensions" value="g, g4"/>
+ <property name="message" value="Line should not be longer than 100 symbols"/>
+ </module>
+ <module name="RegexpOnFilename"/>
+ <module name="RegexpOnFilename">
+ <property name="folderPattern" value="[\\/]src[\\/]\w+[\\/]java[\\/]"/>
+ <property name="fileNamePattern" value="\.java$"/>
+ <property name="match" value="false"/>
+ <message key="regexp.filepath.mismatch"
+ value="Only java files should be located in the ''src/*/java'' folders."/>
+ </module>
+ <module name="RegexpOnFilename">
+ <property name="folderPattern" value="[\\/]src[\\/]xdocs[\\/]"/>
+ <property name="fileNamePattern" value="\.(xml|vm)$"/>
+ <property name="match" value="false"/>
+ <message key="regexp.filepath.mismatch"
+ value="All files in the ''src/xdocs'' folder should have the ''xml'' or ''vm'' extension."/>
+ </module>
+ <module name="RegexpOnFilename">
+ <property name="folderPattern" value="[\\/]src[\\/]it[\\/]java[\\/]"/>
+ <property name="fileNamePattern" value="^((\w+Test)|(Base\w+))\.java$"/>
+ <property name="match" value="false"/>
+ <message key="regexp.filepath.mismatch"
+ value="All files in the ''src/it/java'' folder should be named ''*Test.java'' or ''Base*.java''."/>
+ </module>
+ -->
+
+ <!-- Size Violations -->
+ <module name="FileLength">
+ <property name="max" value="3000"/>
+ <property name="fileExtensions" value="java"/>
+ </module>
+
+ <!-- Whitespace -->
+ <!--
+ <module name="FileTabCharacter">
+ <property name="eachLine" value="false"/>
+ </module>
+ -->
+ <module name="JavadocPackage"/>
+
+ <module name="TreeWalker">
+ <!--
+ <property name="tabWidth" value="4"/>
+ -->
+
+ <!-- Annotations -->
+ <module name="AnnotationLocation">
+ <property name="tokens" value="CLASS_DEF, INTERFACE_DEF, ENUM_DEF, METHOD_DEF, CTOR_DEF"/>
+ </module>
+ <module name="AnnotationLocation">
+ <property name="tokens" value="VARIABLE_DEF"/>
+ <property name="allowSamelineMultipleAnnotations" value="true"/>
+ </module>
+ <module name="AnnotationUseStyle"/>
+ <module name="MissingDeprecated"/>
+ <module name="MissingOverride">
+ <property name="javaFiveCompatibility" value="true"/>
+ </module>
+ <module name="PackageAnnotation"/>
+ <module name="SuppressWarnings"/>
+ <module name="SuppressWarningsHolder"/>
+
+ <!-- Block Checks -->
+ <module name="AvoidNestedBlocks">
+ <property name="allowInSwitchCase" value="true"/>
+ </module>
+ <module name="EmptyBlock">
+ <property name="option" value="text"/>
+ </module>
+ <module name="EmptyCatchBlock">
+ <property name="exceptionVariableName" value="expected|ignore"/>
+ <property name="commentFormat" value="ignore.*"/>
+ </module>
+ <module name="NeedBraces"/>
+ <module name="RightCurly">
+ <property name="tokens"
+ value="CLASS_DEF, METHOD_DEF, CTOR_DEF, LITERAL_FOR, STATIC_INIT, INSTANCE_INIT"/>
+ </module>
+
+ <!-- Class Design -->
+ <!--
+ <module name="DesignForExtension">
+ <property name="ignoredAnnotations"
+ value="Override, Test, Before, After, BeforeClass, AfterClass"/>
+ </module>
+ -->
+ <module name="FinalClass"/>
+ <module name="HideUtilityClassConstructor"/>
+ <module name="InnerTypeLast"/>
+ <!--
+ <module name="InterfaceIsType"/>
+ -->
+ <module name="MutableException"/>
+ <module name="OneTopLevelClass"/>
+ <!--
+ <module name="ThrowsCount">
+ <property name="max" value="2"/>
+ </module>
+ -->
+ <module name="VisibilityModifier">
+ <property name="packageAllowed" value="true"/>
+ <property name="protectedAllowed" value="true"/>
+ </module>
+
+ <!-- Coding -->
+ <!--
+ <module name="ArrayTrailingComma"/>
+ <module name="AvoidInlineConditionals"/>
+ -->
+ <module name="CovariantEquals"/>
+ <module name="DeclarationOrder">
+ <property name="ignoreModifiers" value="true"/>
+ </module>
+ <module name="DefaultComesLast"/>
+ <module name="EmptyStatement"/>
+ <module name="EqualsAvoidNull"/>
+ <module name="EqualsHashCode"/>
+ <module name="ExplicitInitialization"/>
+ <module name="FallThrough"/>
+ <!--
+ <module name="FinalLocalVariable"/>
+ -->
+ <module name="HiddenField">
+ <property name="tokens" value="VARIABLE_DEF"/>
+ <property name="ignoreConstructorParameter" value="true"/>
+ <property name="ignoreSetter" value="true"/>
+ <property name="setterCanReturnItsClass" value="true"/>
+ </module>
+ <!--
+ <module name="IllegalCatch">
+ <property name="illegalClassNames"
+ value="java.lang.Exception, java.lang.Throwable, java.lang.RuntimeException, java.lang.NullPointerException"/>
+ </module>
+ <module name="IllegalInstantiation"/>
+ <module name="IllegalThrows"/>
+ <module name="IllegalToken"/>
+ <module name="IllegalTokenText"/>
+ <module name="IllegalType"/>
+ <module name="InnerAssignment"/>
+ <module name="MagicNumber"/>
+ <module name="MissingCtor">
+ <property name="severity" value="ignore"/>
+ </module>
+ -->
+ <module name="MissingSwitchDefault"/>
+ <!--
+ <module name="ModifiedControlVariable"/>
+ <module name="MultipleStringLiterals"/>
+ -->
+ <module name="MultipleVariableDeclarations"/>
+ <!--
+ <module name="NestedForDepth">
+ <property name="max" value="2"/>
+ </module>
+ <module name="NestedIfDepth">
+ <property name="max" value="3"/>
+ </module>
+ <module name="NestedTryDepth"/>
+ -->
+ <module name="NoClone"/>
+ <!--
+ <module name="NoFinalizer"/>
+ -->
+ <module name="OneStatementPerLine"/>
+ <module name="OverloadMethodsDeclarationOrder"/>
+ <module name="PackageDeclaration"/>
+ <!--
+ <module name="ParameterAssignment"/>
+ <module name="RequireThis"/>
+ <module name="ReturnCount">
+ <property name="maxForVoid" value="0"/>
+ </module>
+ -->
+ <module name="SimplifyBooleanExpression"/>
+ <module name="SimplifyBooleanReturn"/>
+ <module name="StringLiteralEquality"/>
+ <module name="SuperClone"/>
+ <module name="SuperFinalize"/>
+ <!--
+ <module name="UnnecessaryParentheses"/>
+ <module name="VariableDeclarationUsageDistance"/>
+ -->
+
+ <!-- Imports -->
+ <module name="AvoidStarImport"/>
+ <module name="AvoidStaticImport"/>
+ <!--
+ <module name="CustomImportOrder">
+ <property name="customImportOrderRules"
+ value="STATIC###STANDARD_JAVA_PACKAGE###SPECIAL_IMPORTS"/>
+ <property name="specialImportsRegExp" value="org"/>
+ <property name="sortImportsInGroupAlphabetically" value="true"/>
+ <property name="separateLineBetweenGroups" value="true"/>
+ </module>
+ -->
+ <module name="IllegalImport"/>
+ <!--
+ <module name="ImportControl">
+ <property name="file" value="${checkstyle.importcontrol.file}"/>
+ </module>
+ -->
+ <module name="ImportOrder">
+ <property name="option" value="under"/>
+ <property name="groups" value=""/>
+ <property name="ordered" value="true"/>
+ <property name="separated" value="true"/>
+ <property name="sortStaticImportsAlphabetically" value="true"/>
+ </module>
+ <module name="RedundantImport"/>
+ <module name="UnusedImports"/>
+
+ <!-- Javadoc Comments -->
+ <module name="AtclauseOrder"/>
+ <module name="JavadocMethod">
+ <property name="allowUndeclaredRTE" value="true"/>
+ <property name="allowThrowsTagsForSubclasses" value="true"/>
+ </module>
+ <module name="JavadocParagraph"/>
+ <module name="JavadocStyle">
+ <property name="scope" value="public"/>
+ </module>
+ <module name="JavadocTagContinuationIndentation"/>
+ <module name="NonEmptyAtclauseDescription"/>
+ <module name="SummaryJavadoc"/>
+ <module name="MissingJavadocMethod"/>
+ <module name="MissingJavadocPackage"/>
+ <module name="MissingJavadocType"/>
+ <!--
+ <module name="JavadocType">
+ <property name="allowUnknownTags" value="true"/>
+ </module>
+ <module name="JavadocVariable"/>
+ <module name="SingleLineJavadoc"/>
+ <module name="WriteTag"/>
+ -->
+
+ <!-- Metrics -->
+ <!--
+ <module name="BooleanExpressionComplexity">
+ <property name="max" value="7"/>
+ </module>
+ <module name="ClassDataAbstractionCoupling">
+ <property name="excludedClasses"
+ value="boolean, byte, char, double, float, int, long, short, void, Boolean, Byte, Character, Double, Float, Integer, Long, Short, Void, Object, Class, String,
+ StringBuffer, StringBuilder, ArrayIndexOutOfBoundsException, Exception, RuntimeException, IllegalArgumentException, IllegalStateException,
+ IndexOutOfBoundsException, NullPointerException, Throwable, SecurityException, UnsupportedOperationException, List, ArrayList, Deque, Queue, LinkedList, Set,
+ HashSet, SortedSet, TreeSet, Map, HashMap, SortedMap, TreeMap,
+ DetailsAST, CheckstyleException, UnsupportedEncodingException, BuildException, ConversionException, FileNotFoundException, TestException"/>
+ </module>
+ <module name="ClassFanOutComplexity">
+ <property name="max" value="25"/>
+ <property name="excludedClasses"
+ value="boolean, byte, char, double, float, int, long, short, void, Boolean, Byte, Character, Double, Float, Integer, Long, Short, Void, Object, Class, String,
+ StringBuffer, StringBuilder, ArrayIndexOutOfBoundsException, Exception, RuntimeException, IllegalArgumentException, IllegalStateException,
+ IndexOutOfBoundsException, NullPointerException, Throwable, SecurityException, UnsupportedOperationException, List, ArrayList, Deque, Queue, LinkedList, Set,
+ HashSet, SortedSet, TreeSet, Map, HashMap, SortedMap, TreeMap, DetailsAST, CheckstyleException, UnsupportedEncodingException, BuildException,
+ ConversionException, FileNotFoundException, TestException, Log, Sets, Multimap, TokenStreamRecognitionException, RecognitionException, TokenStreamException,
+ IOException"/>
+ </module>
+ <module name="CyclomaticComplexity">
+ <property name="switchBlockAsSingleDecisionPoint" value="true"/>
+ </module>
+ <module name="JavaNCSS"/>
+ <module name="NPathComplexity"/>
+ -->
+
+ <!-- Misc -->
+ <module name="ArrayTypeStyle"/>
+ <module name="AvoidEscapedUnicodeCharacters">
+ <property name="allowEscapesForControlCharacters" value="true"/>
+ <property name="allowByTailComment" value="true"/>
+ <property name="allowNonPrintableEscapes" value="true"/>
+ </module>
+ <module name="CommentsIndentation"/>
+ <!--
+ <module name="DescendantToken"/>
+ -->
+ <!--
+ <module name="FinalParameters">
+ <property name="severity" value="ignore"/>
+ </module>
+ <module name="Indentation">
+ <property name="basicOffset" value="4"/>
+ <property name="braceAdjustment" value="0"/>
+ <property name="caseIndent" value="4"/>
+ <property name="throwsIndent" value="8"/>
+ </module>
+ -->
+ <module name="OuterTypeFilename"/>
+ <!--
+ <module name="TodoComment">
+ <property name="format" value="(TODO)|(FIXME)"/>
+ </module>
+ <module name="TrailingComment"/>
+ <module name="UncommentedMain">
+ <property name="excludedClasses" value="\.Main$"/>
+ </module>
+ -->
+ <module name="UpperEll"/>
+
+ <!-- Modifiers -->
+ <module name="ModifierOrder"/>
+ <!--
+ <module name="RedundantModifier"/>
+ -->
+
+ <!-- Naming Conventions -->
+ <!--
+ <module name="AbbreviationAsWordInName">
+ <property name="ignoreFinal" value="false"/>
+ <property name="allowedAbbreviationLength" value="1"/>
+ <property name="allowedAbbreviations" value="AST"/>
+ </module>
+ <module name="AbstractClassName"/>
+ -->
+ <module name="ClassTypeParameterName">
+ <property name="format" value="(^[A-Z][0-9]?)$|([A-Z][a-zA-Z0-9]*[T]$)"/>
+ <message key="name.invalidPattern"
+ value="Class type name ''{0}'' must match pattern ''{1}''."/>
+ </module>
+ <module name="ConstantName">
+ <property name="format" value="^log(ger)?$|^[A-Z][A-Z0-9]*(_[A-Z0-9]+)*$"/>
+ </module>
+ <module name="InterfaceTypeParameterName"/>
+ <module name="LocalFinalVariableName"/>
+ <module name="LocalVariableName">
+ <property name="format" value="^[a-z][a-zA-Z0-9]*$"/>
+ </module>
+ <module name="MemberName"/>
+ <module name="MethodName"/>
+ <module name="MethodTypeParameterName"/>
+ <module name="PackageName"/>
+ <module name="ParameterName"/>
+ <!--
+ <module name="CatchParameterName"/>
+ -->
+ <module name="StaticVariableName"/>
+ <module name="TypeName"/>
+
+ <!-- Regexp -->
+ <!--
+ <module name="Regexp"/>
+ <module name="RegexpSinglelineJava"/>
+ <module name="RegexpSinglelineJava">
+ <property name="format" value="[^\p{ASCII}]"/>
+ <property name="ignoreComments" value="true"/>
+ </module>
+ -->
+
+ <!-- Size Violations -->
+ <!--
+ <module name="AnonInnerLength"/>
+ <module name="ExecutableStatementCount">
+ <property name="max" value="30"/>
+ </module>
+ <module name="LineLength">
+ <property name="max" value="100"/>
+ <property name="ignorePattern" value="^ *\* *[^ ]+$"/>
+ </module>
+ <module name="MethodCount">
+ <property name="maxTotal" value="35"/>
+ </module>
+ <module name="MethodLength"/>
+ -->
+ <module name="OuterTypeNumber"/>
+ <!--
+ <module name="ParameterNumber"/>
+ -->
+
+ <!-- Whitespace -->
+ <!--
+ <module name="EmptyForInitializerPad"/>
+ <module name="EmptyForIteratorPad"/>
+ <module name="EmptyLineSeparator">
+ <property name="allowNoEmptyLineBetweenFields" value="true"/>
+ <property name="allowMultipleEmptyLinesInsideClassMembers" value="false"/>
+ </module>
+ <module name="GenericWhitespace"/>
+ <module name="MethodParamPad"/>
+ <module name="NoLineWrap"/>
+ <module name="NoWhitespaceAfter">
+ <property name="tokens" value="ARRAY_INIT"/>
+ <property name="tokens" value="BNOT"/>
+ <property name="tokens" value="DEC"/>
+ <property name="tokens" value="DOT"/>
+ <property name="tokens" value="INC"/>
+ <property name="tokens" value="LNOT"/>
+ <property name="tokens" value="UNARY_MINUS"/>
+ <property name="tokens" value="UNARY_PLUS"/>
+ <property name="tokens" value="ARRAY_DECLARATOR"/>
+ </module>
+ <module name="NoWhitespaceBefore"/>
+ <module name="NoWhitespaceBefore">
+ <property name="tokens" value="DOT"/>
+ <property name="allowLineBreaks" value="true"/>
+ </module>
+ <module name="OperatorWrap"/>
+ <module name="OperatorWrap">
+ <property name="tokens" value="ASSIGN"/>
+ <property name="tokens" value="DIV_ASSIGN"/>
+ <property name="tokens" value="PLUS_ASSIGN"/>
+ <property name="tokens" value="MINUS_ASSIGN"/>
+ <property name="tokens" value="STAR_ASSIGN"/>
+ <property name="tokens" value="MOD_ASSIGN"/>
+ <property name="tokens" value="SR_ASSIGN"/>
+ <property name="tokens" value="BSR_ASSIGN"/>
+ <property name="tokens" value="SL_ASSIGN"/>
+ <property name="tokens" value="BXOR_ASSIGN"/>
+ <property name="tokens" value="BOR_ASSIGN"/>
+ <property name="tokens" value="BAND_ASSIGN"/>
+ <property name="option" value="eol"/>
+ </module>
+ <module name="ParenPad"/>
+ <module name="SeparatorWrap">
+ <property name="tokens" value="DOT"/>
+ <property name="option" value="nl"/>
+ </module>
+ <module name="SeparatorWrap">
+ <property name="tokens" value="COMMA"/>
+ <property name="option" value="EOL"/>
+ </module>
+ <module name="SingleSpaceSeparator">
+ <property name="validateComments" value="false"/>
+ </module>
+ <module name="TypecastParenPad"/>
+ <module name="WhitespaceAfter"/>
+ <module name="WhitespaceAround"/>
+ -->
+
+ </module>
+
+</module>
diff --git a/java-package/tools/conf/findbugs-exclude.xml b/java-package/tools/conf/findbugs-exclude.xml
new file mode 100644
index 0000000..6497265
--- /dev/null
+++ b/java-package/tools/conf/findbugs-exclude.xml
@@ -0,0 +1,50 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<FindBugsFilter>
+ <Match>
+ <Bug pattern="DM_EXIT,DMI_EMPTY_DB_PASSWORD,DMI_HARDCODED_ABSOLUTE_FILENAME,EI_EXPOSE_REP,EI_EXPOSE_REP2,SF_SWITCH_FALLTHROUGH,NM_CONFUSING"/>
+ </Match>
+ <!-- low priority issues-->
+ <Match>
+ <Bug pattern="DM_CONVERT_CASE,SE_TRANSIENT_FIELD_NOT_RESTORED,UWF_FIELD_NOT_INITIALIZED_IN_CONSTRUCTOR,BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"/>
+ </Match>
+ <Match>
+ <Bug pattern="PZLA_PREFER_ZERO_LENGTH_ARRAYS,DB_DUPLICATE_SWITCH_CLAUSES,BC_UNCONFIRMED_CAST"/>
+ </Match>
+
+ <!-- wildcard suppression -->
+ <Match>
+ <Class name="~org\.apache\.mxnet\.jnarator\.parser\..*"/>
+ </Match>
+ <!-- function suppression -->
+ <Match>
+ <Bug pattern="DC_DOUBLECHECK"/>
+ <Class name="~ai\.djl\.mxnet\.engine\.MxSymbolBlock"/>
+ <Method name="forwardInternal"/>
+ </Match>
+ <Match>
+ <Bug pattern="DC_DOUBLECHECK"/>
+ <Class name="~ai\.djl\.pytorch\.engine\.PtSymbolBlock"/>
+ <Method name="forwardInternal"/>
+ </Match>
+ <Match>
+ <Bug pattern="URF_UNREAD_FIELD"/>
+ <Class name="~ai\.djl\.pytorch\.engine\.PtNDArray"/>
+ </Match>
+</FindBugsFilter>
diff --git a/java-package/tools/conf/licenseHeader.java b/java-package/tools/conf/licenseHeader.java
new file mode 100644
index 0000000..3e7c6c2
--- /dev/null
+++ b/java-package/tools/conf/licenseHeader.java
@@ -0,0 +1,16 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
diff --git a/java-package/tools/conf/pmd.xml b/java-package/tools/conf/pmd.xml
new file mode 100644
index 0000000..d8c0e86
--- /dev/null
+++ b/java-package/tools/conf/pmd.xml
@@ -0,0 +1,466 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<ruleset name="pmd">
+ <description>Java Rule in PMD</description>
+
+ <!--
+ <rule ref="category/java/bestpractices.xml">
+ <exclude name="AbstractClassWithoutAbstractMethod"/>
+ <exclude name="AccessorClassGeneration"/>
+ <exclude name="AccessorMethodGeneration"/>
+ <exclude name="ArrayIsStoredDirectly"/>
+ <exclude name="AvoidPrintStackTrace"/>
+ <exclude name="AvoidReassigningParameters"/>
+ <exclude name="AvoidStringBufferField"/>
+ <exclude name="AvoidUsingHardCodedIP"/>
+ <exclude name="CheckResultSet"/>
+ <exclude name="ConstantsInInterface"/>
+ <exclude name="DefaultLabelNotLastInSwitchStmt"/>
+ <exclude name="ForLoopCanBeForeach"/>
+ <exclude name="GuardLogStatement"/>
+ <exclude name="JUnit4SuitesShouldUseSuiteAnnotation"/>
+ <exclude name="JUnit4TestShouldUseAfterAnnotation"/>
+ <exclude name="JUnit4TestShouldUseBeforeAnnotation"/>
+ <exclude name="JUnit4TestShouldUseTestAnnotation"/>
+ <exclude name="JUnitAssertionsShouldIncludeMessage"/>
+ <exclude name="JUnitTestContainsTooManyAsserts"/>
+ <exclude name="JUnitTestsShouldIncludeAssert"/>
+ <exclude name="JUnitUseExpected"/>
+ <exclude name="LooseCoupling"/>
+ <exclude name="MethodReturnsInternalArray"/>
+ <exclude name="MissingOverride"/>
+ <exclude name="OneDeclarationPerLine"/>
+ <exclude name="PositionLiteralsFirstInCaseInsensitiveComparisons"/>
+ <exclude name="PositionLiteralsFirstInComparisons"/>
+ <exclude name="PreserveStackTrace"/>
+ <exclude name="ReplaceEnumerationWithIterator"/>
+ <exclude name="ReplaceHashtableWithMap"/>
+ <exclude name="ReplaceVectorWithList"/>
+ <exclude name="SwitchStmtsShouldHaveDefault"/>
+ <exclude name="SystemPrintln"/>
+ <exclude name="UnusedFormalParameter"/>
+ <exclude name="UnusedImports"/>
+ <exclude name="UnusedLocalVariable"/>
+ <exclude name="UnusedPrivateField"/>
+ <exclude name="UnusedPrivateMethod"/>
+ <exclude name="UseAssertEqualsInsteadOfAssertTrue"/>
+ <exclude name="UseAssertNullInsteadOfAssertTrue"/>
+ <exclude name="UseAssertSameInsteadOfAssertTrue"/>
+ <exclude name="UseAssertTrueInsteadOfAssertEquals"/>
+ <exclude name="UseCollectionIsEmpty"/>
+ <exclude name="UseVarargs"/>
+ </rule>
+
+ <rule ref="category/java/codestyle.xml">
+ <exclude name="AbstractNaming"/>
+ <exclude name="AtLeastOneConstructor"/>
+ <exclude name="AvoidDollarSigns"/>
+ <exclude name="AvoidFinalLocalVariable"/>
+ <exclude name="AvoidPrefixingMethodParameters"/>
+ <exclude name="AvoidProtectedFieldInFinalClass"/>
+ <exclude name="AvoidProtectedMethodInFinalClassNotExtending"/>
+ <exclude name="AvoidUsingNativeCode"/>
+ <exclude name="BooleanGetMethodName"/>
+ <exclude name="CallSuperInConstructor"/>
+ <exclude name="ClassNamingConventions"/>
+ <exclude name="CommentDefaultAccessModifier"/>
+ <exclude name="ConfusingTernary"/>
+ <exclude name="ControlStatementBraces"/>
+ <exclude name="DefaultPackage"/>
+ <exclude name="DontImportJavaLang"/>
+ <exclude name="DuplicateImports"/>
+ <exclude name="EmptyMethodInAbstractClassShouldBeAbstract"/>
+ <exclude name="ExtendsObject"/>
+ <exclude name="FieldDeclarationsShouldBeAtStartOfClass"/>
+ <exclude name="FieldNamingConventions"/>
+ <exclude name="ForLoopShouldBeWhileLoop"/>
+ <exclude name="ForLoopsMustUseBraces"/>
+ <exclude name="FormalParameterNamingConventions"/>
+ <exclude name="GenericsNaming"/>
+ <exclude name="IdenticalCatchBranches"/>
+ <exclude name="IfElseStmtsMustUseBraces"/>
+ <exclude name="IfStmtsMustUseBraces"/>
+ <exclude name="LinguisticNaming"/>
+ <exclude name="LocalHomeNamingConvention"/>
+ <exclude name="LocalInterfaceSessionNamingConvention"/>
+ <exclude name="LocalVariableCouldBeFinal"/>
+ <exclude name="LocalVariableNamingConventions"/>
+ <exclude name="LongVariable"/>
+ <exclude name="MDBAndSessionBeanNamingConvention"/>
+ <exclude name="MethodArgumentCouldBeFinal"/>
+ <exclude name="MethodNamingConventions"/>
+ <exclude name="MIsLeadingVariableName"/>
+ <exclude name="NoPackage"/>
+ <exclude name="OnlyOneReturn"/>
+ <exclude name="PackageCase"/>
+ <exclude name="PrematureDeclaration"/>
+ <exclude name="RemoteInterfaceNamingConvention"/>
+ <exclude name="RemoteSessionInterfaceNamingConvention"/>
+ <exclude name="ShortClassName"/>
+ <exclude name="ShortMethodName"/>
+ <exclude name="ShortVariable"/>
+ <exclude name="SuspiciousConstantFieldName"/>
+ <exclude name="TooManyStaticImports"/>
+ <exclude name="UnnecessaryAnnotationValueElement"/>
+ <exclude name="UnnecessaryConstructor"/>
+ <exclude name="UnnecessaryFullyQualifiedName"/>
+ <exclude name="UnnecessaryLocalBeforeReturn"/>
+ <exclude name="UnnecessaryModifier"/>
+ <exclude name="UnnecessaryReturn"/>
+ <exclude name="UselessParentheses"/>
+ <exclude name="UselessQualifiedThis"/>
+ <exclude name="VariableNamingConventions"/>
+ <exclude name="WhileLoopsMustUseBraces"/>
+ </rule>
+
+ <rule ref="category/java/design.xml">
+ <exclude name="AbstractClassWithoutAnyMethod"/>
+ <exclude name="AvoidCatchingGenericException"/>
+ <exclude name="AvoidDeeplyNestedIfStmts"/>
+ <exclude name="AvoidRethrowingException"/>
+ <exclude name="AvoidThrowingNewInstanceOfSameException"/>
+ <exclude name="AvoidThrowingNullPointerException"/>
+ <exclude name="AvoidThrowingRawExceptionTypes"/>
+ <exclude name="ClassWithOnlyPrivateConstructorsShouldBeFinal"/>
+ <exclude name="CollapsibleIfStatements"/>
+ <exclude name="CouplingBetweenObjects"/>
+ <exclude name="CyclomaticComplexity"/>
+ <exclude name="DataClass"/>
+ <exclude name="DoNotExtendJavaLangError"/>
+ <exclude name="ExceptionAsFlowControl"/>
+ <exclude name="ExcessiveClassLength"/>
+ <exclude name="ExcessiveImports"/>
+ <exclude name="ExcessiveMethodLength"/>
+ <exclude name="ExcessiveParameterList"/>
+ <exclude name="ExcessivePublicCount"/>
+ <exclude name="FinalFieldCouldBeStatic"/>
+ <exclude name="GodClass"/>
+ <exclude name="ImmutableField"/>
+ <exclude name="LawOfDemeter"/>
+ <exclude name="LogicInversion"/>
+ <exclude name="LoosePackageCoupling"/>
+ <exclude name="ModifiedCyclomaticComplexity"/>
+ <exclude name="NcssConstructorCount"/>
+ <exclude name="NcssCount"/>
+ <exclude name="NcssMethodCount"/>
+ <exclude name="NcssTypeCount"/>
+ <exclude name="NPathComplexity"/>
+ <exclude name="SignatureDeclareThrowsException"/>
+ <exclude name="SimplifiedTernary"/>
+ <exclude name="SimplifyBooleanAssertion"/>
+ <exclude name="SimplifyBooleanExpressions"/>
+ <exclude name="SimplifyBooleanReturns"/>
+ <exclude name="SimplifyConditional"/>
+ <exclude name="SingularField"/>
+ <exclude name="StdCyclomaticComplexity"/>
+ <exclude name="SwitchDensity"/>
+ <exclude name="TooManyFields"/>
+ <exclude name="TooManyMethods"/>
+ <exclude name="UselessOverridingMethod"/>
+ <exclude name="UseObjectForClearerAPI"/>
+ <exclude name="UseUtilityClass"/>
+ </rule>
+
+ <rule ref="category/java/documentation.xml">
+ <exclude name="CommentContent"/>
+ <exclude name="CommentRequired"/>
+ <exclude name="CommentSize"/>
+ <exclude name="UncommentedEmptyConstructor"/>
+ <exclude name="UncommentedEmptyMethodBody"/>
+ </rule>
+
+ <rule ref="category/java/errorprone.xml">
+ <exclude name="AssignmentInOperand"/>
+ <exclude name="AssignmentToNonFinalStatic"/>
+ <exclude name="AvoidAccessibilityAlteration"/>
+ <exclude name="AvoidAssertAsIdentifier"/>
+ <exclude name="AvoidBranchingStatementAsLastInLoop"/>
+ <exclude name="AvoidCallingFinalize"/>
+ <exclude name="AvoidCatchingNPE"/>
+ <exclude name="AvoidCatchingThrowable"/>
+ <exclude name="AvoidDecimalLiteralsInBigDecimalConstructor"/>
+ <exclude name="AvoidDuplicateLiterals"/>
+ <exclude name="AvoidEnumAsIdentifier"/>
+ <exclude name="AvoidFieldNameMatchingMethodName"/>
+ <exclude name="AvoidFieldNameMatchingTypeName"/>
+ <exclude name="AvoidInstanceofChecksInCatchClause"/>
+ <exclude name="AvoidLiteralsInIfCondition"/>
+ <exclude name="AvoidLosingExceptionInformation"/>
+ <exclude name="AvoidMultipleUnaryOperators"/>
+ <exclude name="AvoidUsingOctalValues"/>
+ <exclude name="BadComparison"/>
+ <exclude name="BeanMembersShouldSerialize"/>
+ <exclude name="BrokenNullCheck"/>
+ <exclude name="CallSuperFirst"/>
+ <exclude name="CallSuperLast"/>
+ <exclude name="CheckSkipResult"/>
+ <exclude name="ClassCastExceptionWithToArray"/>
+ <exclude name="CloneMethodMustBePublic"/>
+ <exclude name="CloneMethodMustImplementCloneable"/>
+ <exclude name="CloneMethodReturnTypeMustMatchClassName"/>
+ <exclude name="CloneThrowsCloneNotSupportedException"/>
+ <exclude name="CloseResource"/>
+ <exclude name="CompareObjectsWithEquals"/>
+ <exclude name="ConstructorCallsOverridableMethod"/>
+ <exclude name="DataflowAnomalyAnalysis"/>
+ <exclude name="DoNotCallGarbageCollectionExplicitly"/>
+ <exclude name="DoNotCallSystemExit"/>
+ <exclude name="DoNotExtendJavaLangThrowable"/>
+ <exclude name="DoNotHardCodeSDCard"/>
+ <exclude name="DoNotThrowExceptionInFinally"/>
+ <exclude name="DontImportSun"/>
+ <exclude name="DontUseFloatTypeForLoopIndices"/>
+ <exclude name="EmptyCatchBlock"/>
+ <exclude name="EmptyFinalizer"/>
+ <exclude name="EmptyFinallyBlock"/>
+ <exclude name="EmptyIfStmt"/>
+ <exclude name="EmptyInitializer"/>
+ <exclude name="EmptyStatementBlock"/>
+ <exclude name="EmptyStatementNotInLoop"/>
+ <exclude name="EmptySwitchStatements"/>
+ <exclude name="EmptySynchronizedBlock"/>
+ <exclude name="EmptyTryBlock"/>
+ <exclude name="EmptyWhileStmt"/>
+ <exclude name="EqualsNull"/>
+ <exclude name="FinalizeDoesNotCallSuperFinalize"/>
+ <exclude name="FinalizeOnlyCallsSuperFinalize"/>
+ <exclude name="FinalizeOverloaded"/>
+ <exclude name="FinalizeShouldBeProtected"/>
+ <exclude name="IdempotentOperations"/>
+ <exclude name="ImportFromSamePackage"/>
+ <exclude name="InstantiationToGetClass"/>
+ <exclude name="InvalidSlf4jMessageFormat"/>
+ <exclude name="JumbledIncrementer"/>
+ <exclude name="JUnitSpelling"/>
+ <exclude name="JUnitStaticSuite"/>
+ <exclude name="LoggerIsNotStaticFinal"/>
+ <exclude name="MethodWithSameNameAsEnclosingClass"/>
+ <exclude name="MisplacedNullCheck"/>
+ <exclude name="MissingBreakInSwitch"/>
+ <exclude name="MissingSerialVersionUID"/>
+ <exclude name="MissingStaticMethodInNonInstantiatableClass"/>
+ <exclude name="MoreThanOneLogger"/>
+ <exclude name="NonCaseLabelInSwitchStatement"/>
+ <exclude name="NonStaticInitializer"/>
+ <exclude name="NullAssignment"/>
+ <exclude name="OverrideBothEqualsAndHashcode"/>
+ <exclude name="ProperCloneImplementation"/>
+ <exclude name="ProperLogger"/>
+ <exclude name="ReturnEmptyArrayRatherThanNull"/>
+ <exclude name="ReturnFromFinallyBlock"/>
+ <exclude name="SimpleDateFormatNeedsLocale"/>
+ <exclude name="SingleMethodSingleton"/>
+ <exclude name="SingletonClassReturningNewInstance"/>
+ <exclude name="StaticEJBFieldShouldBeFinal"/>
+ <exclude name="StringBufferInstantiationWithChar"/>
+ <exclude name="SuspiciousEqualsMethodName"/>
+ <exclude name="SuspiciousHashcodeMethodName"/>
+ <exclude name="SuspiciousOctalEscape"/>
+ <exclude name="TestClassWithoutTestCases"/>
+ <exclude name="UnconditionalIfStatement"/>
+ <exclude name="UnnecessaryBooleanAssertion"/>
+ <exclude name="UnnecessaryCaseChange"/>
+ <exclude name="UnnecessaryConversionTemporary"/>
+ <exclude name="UnusedNullCheckInEquals"/>
+ <exclude name="UseCorrectExceptionLogging"/>
+ <exclude name="UseEqualsToCompareStrings"/>
+ <exclude name="UselessOperationOnImmutable"/>
+ <exclude name="UseLocaleWithCaseConversions"/>
+ <exclude name="UseProperClassLoader"/>
+ </rule>
+
+ <rule ref="category/java/multithreading.xml">
+ <exclude name="AvoidSynchronizedAtMethodLevel"/>
+ <exclude name="AvoidThreadGroup"/>
+ <exclude name="AvoidUsingVolatile"/>
+ <exclude name="DoNotUseThreads"/>
+ <exclude name="DontCallThreadRun"/>
+ <exclude name="DoubleCheckedLocking"/>
+ <exclude name="NonThreadSafeSingleton"/>
+ <exclude name="UnsynchronizedStaticDateFormatter"/>
+ <exclude name="UseConcurrentHashMap"/>
+ <exclude name="UseNotifyAllInsteadOfNotify"/>
+ </rule>
+
+ <rule ref="category/java/performance.xml">
+ <exclude name="AddEmptyString"/>
+ <exclude name="AppendCharacterWithChar"/>
+ <exclude name="AvoidArrayLoops"/>
+ <exclude name="AvoidFileStream"/>
+ <exclude name="AvoidInstantiatingObjectsInLoops"/>
+ <exclude name="AvoidUsingShortType"/>
+ <exclude name="BigIntegerInstantiation"/>
+ <exclude name="BooleanInstantiation"/>
+ <exclude name="ByteInstantiation"/>
+ <exclude name="ConsecutiveAppendsShouldReuse"/>
+ <exclude name="ConsecutiveLiteralAppends"/>
+ <exclude name="InefficientEmptyStringCheck"/>
+ <exclude name="InefficientStringBuffering"/>
+ <exclude name="InsufficientStringBufferDeclaration"/>
+ <exclude name="IntegerInstantiation"/>
+ <exclude name="LongInstantiation"/>
+ <exclude name="OptimizableToArrayCall"/>
+ <exclude name="RedundantFieldInitializer"/>
+ <exclude name="SimplifyStartsWith"/>
+ <exclude name="ShortInstantiation"/>
+ <exclude name="StringInstantiation"/>
+ <exclude name="StringToString"/>
+ <exclude name="TooFewBranchesForASwitchStatement"/>
+ <exclude name="UnnecessaryWrapperObjectCreation"/>
+ <exclude name="UseArrayListInsteadOfVector"/>
+ <exclude name="UseArraysAsList"/>
+ <exclude name="UseIndexOfChar"/>
+ <exclude name="UselessStringValueOf"/>
+ <exclude name="UseStringBufferForStringAppends"/>
+ <exclude name="UseStringBufferLength"/>
+ </rule>
+
+ <rule ref="category/java/security.xml">
+ <exclude name="HardCodedCryptoKey"/>
+ <exclude name="InsecureCryptoIv"/>
+ </rule>
+ -->
+
+ <rule ref="category/java/bestpractices.xml">
+ <exclude name="ArrayIsStoredDirectly"/>
+ <exclude name="AvoidReassigningParameters"/>
+ <exclude name="AvoidUsingHardCodedIP"/>
+ <exclude name="ConstantsInInterface"/>
+ <exclude name="DefaultLabelNotLastInSwitchStmt"/>
+ <exclude name="GuardLogStatement"/>
+ <exclude name="JUnit4SuitesShouldUseSuiteAnnotation"/>
+ <exclude name="JUnit4TestShouldUseAfterAnnotation"/>
+ <exclude name="JUnit4TestShouldUseBeforeAnnotation"/>
+ <exclude name="JUnit4TestShouldUseTestAnnotation"/>
+ <exclude name="JUnitAssertionsShouldIncludeMessage"/>
+ <exclude name="JUnitTestContainsTooManyAsserts"/>
+ <exclude name="JUnitTestsShouldIncludeAssert"/>
+ <exclude name="JUnitUseExpected"/>
+ <exclude name="MethodReturnsInternalArray"/>
+ <exclude name="UnusedPrivateMethod"/>
+ <exclude name="UseVarargs"/>
+ </rule>
+
+ <rule ref="category/java/codestyle.xml">
+ <exclude name="AbstractNaming"/>
+ <exclude name="AtLeastOneConstructor"/>
+ <exclude name="BooleanGetMethodName"/>
+ <exclude name="CallSuperInConstructor"/>
+ <exclude name="ClassNamingConventions"/>
+ <exclude name="CommentDefaultAccessModifier"/>
+ <exclude name="ConfusingTernary"/>
+ <exclude name="DefaultPackage"/>
+ <exclude name="EmptyMethodInAbstractClassShouldBeAbstract"/>
+ <exclude name="FieldNamingConventions"/>
+ <exclude name="LinguisticNaming"/>
+ <exclude name="LocalVariableCouldBeFinal"/>
+ <exclude name="LongVariable"/>
+ <exclude name="MethodArgumentCouldBeFinal"/>
+ <exclude name="MethodNamingConventions"/>
+ <exclude name="OnlyOneReturn"/>
+ <exclude name="PrematureDeclaration"/>
+ <exclude name="ShortClassName"/>
+ <exclude name="ShortMethodName"/>
+ <exclude name="ShortVariable"/>
+ <exclude name="UnnecessaryModifier"/>
+ <exclude name="UselessParentheses"/>
+ <exclude name="UseUnderscoresInNumericLiterals"/>
+ <exclude name="VariableNamingConventions"/>
+ </rule>
+
+ <rule ref="category/java/design.xml">
+ <exclude name="AvoidCatchingGenericException"/>
+ <exclude name="AvoidDeeplyNestedIfStmts"/>
+ <exclude name="AvoidThrowingNullPointerException"/>
+ <exclude name="CollapsibleIfStatements"/>
+ <exclude name="CyclomaticComplexity"/>
+ <exclude name="DataClass"/>
+ <exclude name="ExcessiveClassLength"/>
+ <exclude name="ExcessiveImports"/>
+ <exclude name="ExcessiveMethodLength"/>
+ <exclude name="ExcessiveParameterList"/>
+ <exclude name="ExcessivePublicCount"/>
+ <exclude name="GodClass"/>
+ <exclude name="ImmutableField"/>
+ <exclude name="LawOfDemeter"/>
+ <exclude name="LoosePackageCoupling"/>
+ <exclude name="NcssConstructorCount"/>
+ <exclude name="NcssCount"/>
+ <exclude name="NcssMethodCount"/>
+ <exclude name="NcssTypeCount"/>
+ <exclude name="NPathComplexity"/>
+ <exclude name="SwitchDensity"/>
+ <exclude name="TooManyFields"/>
+ <exclude name="TooManyMethods"/>
+ <exclude name="UseObjectForClearerAPI"/>
+ </rule>
+
+ <!--
+ <rule ref="category/java/documentation.xml"/>
+ -->
+
+ <rule ref="category/java/errorprone.xml">
+ <exclude name="AssignmentInOperand"/>
+ <exclude name="AvoidCatchingThrowable"/>
+ <exclude name="AvoidDuplicateLiterals"/>
+ <exclude name="AvoidFieldNameMatchingMethodName"/>
+ <exclude name="AvoidFieldNameMatchingTypeName"/>
+ <exclude name="AvoidInstanceofChecksInCatchClause"/>
+ <exclude name="AvoidLiteralsInIfCondition"/>
+ <exclude name="BeanMembersShouldSerialize"/>
+ <exclude name="CloneMethodMustImplementCloneable"/>
+ <exclude name="CloseResource"/>
+ <exclude name="CompareObjectsWithEquals"/>
+ <exclude name="DataflowAnomalyAnalysis"/>
+ <exclude name="EmptyCatchBlock"/>
+ <exclude name="EmptyIfStmt"/>
+ <exclude name="JumbledIncrementer"/>
+ <exclude name="LoggerIsNotStaticFinal"/>
+ <exclude name="MissingBreakInSwitch"/>
+ <exclude name="MoreThanOneLogger"/>
+ <exclude name="NonCaseLabelInSwitchStatement"/>
+ <exclude name="NonStaticInitializer"/>
+ <exclude name="NullAssignment"/>
+ <exclude name="OverrideBothEqualsAndHashcode"/>
+ <exclude name="ReturnEmptyArrayRatherThanNull"/>
+ <exclude name="SimpleDateFormatNeedsLocale"/>
+ <exclude name="UnnecessaryConversionTemporary"/>
+ <exclude name="UseEqualsToCompareStrings"/>
+ <exclude name="UseLocaleWithCaseConversions"/>
+ </rule>
+
+ <rule ref="category/java/multithreading.xml">
+ <exclude name="AvoidSynchronizedAtMethodLevel"/>
+ <exclude name="DoNotUseThreads"/>
+ </rule>
+
+ <rule ref="category/java/performance.xml">
+ <exclude name="AddEmptyString"/>
+ <exclude name="AvoidInstantiatingObjectsInLoops"/>
+ <exclude name="ConsecutiveAppendsShouldReuse"/>
+ <exclude name="InefficientStringBuffering"/>
+ <exclude name="OptimizableToArrayCall"/>
+ <exclude name="SimplifyStartsWith"/>
+ <exclude name="TooFewBranchesForASwitchStatement"/>
+ <exclude name="UseArrayListInsteadOfVector"/>
+ </rule>
+
+ <rule ref="category/java/security.xml"/>
+
+</ruleset>
diff --git a/java-package/tools/conf/suppressions.xml b/java-package/tools/conf/suppressions.xml
new file mode 100644
index 0000000..263bc2e
--- /dev/null
+++ b/java-package/tools/conf/suppressions.xml
@@ -0,0 +1,34 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<!DOCTYPE suppressions PUBLIC "-//Puppy Crawl//DTD Suppressions 1.1//EN"
+ "http://www.puppycrawl.com/dtds/suppressions_1_1.dtd">
+<suppressions>
+ <suppress checks="FileLength" files="NDArray.java" />
+
+ <suppress checks="(AvoidStaticImport|ImportOrder)" files="src[\\/]test[\\/]java[\\/]"/>
+
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="src[\\/](test|it)[\\/].*"/>
+
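+ <!-- Suppress javadoc checks in selected modules. NOTE(review): the ai/djl/... path patterns below look inherited from another project; confirm they match this project's package layout. -->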
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="ai[\\/]djl[\\/]testing[\\/]"/>
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="ai[\\/]djl[\\/]integration[\\/]"/>
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="ai[\\/]djl[\\/]examples[\\/]"/>
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="ai[\\/]djl[\\/]mxnet[\\/]jnarator[\\/]"/>
+ <suppress checks="(MissingJavadocMethod|MissingJavadocType)" files="ai[\\/]djl[\\/]tensorflow[\\/]"/>
+</suppressions>
diff --git a/java-package/tools/gradle/check.gradle b/java-package/tools/gradle/check.gradle
new file mode 100644
index 0000000..e38eb34
--- /dev/null
+++ b/java-package/tools/gradle/check.gradle
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+if (JavaVersion.current() < JavaVersion.VERSION_11) { // SpotBugs is applied only when the build JVM is older than Java 11
+ apply plugin: "com.github.spotbugs"
+ spotbugs {
+ excludeFilter = file("${rootProject.projectDir}/tools/conf/findbugs-exclude.xml") // exclusion filter taken from the root project's tools/conf directory
+ ignoreFailures = false // findings are not ignored
+ spotbugsTest.enabled = true // the test source set is analysed too
+ }
+ spotbugsMain { // report formats for the main source set: HTML on, XML off
+ reports {
+ xml.enabled false
+ html.enabled true
+ }
+ }
+ spotbugsTest { // same report formats for the test source set
+ reports {
+ xml.enabled false
+ html.enabled true
+ }
+ }
+}
+
+apply plugin: "pmd" // PMD static analysis, configured from tools/conf/pmd.xml
+pmd {
+ ignoreFailures = false // violations are not ignored
+ pmdTest.enabled = false // the pmdTest task is switched off
+ ruleSets = [] // empty the rule set list (workaround for a pmd gradle plugin bug); rules come from ruleSetFiles below
+ ruleSetFiles = files("${rootProject.projectDir}/tools/conf/pmd.xml")
+}
+tasks.withType(Pmd){ // every Pmd task writes both an XML and an HTML report
+ reports{
+ xml.enabled=true
+ html.enabled=true
+ }
+}
+
+apply plugin: "checkstyle" // Checkstyle, configured from tools/conf/checkstyle.xml
+checkstyle {
+ toolVersion = "8.26"
+ ignoreFailures = false // violations are not ignored
+ checkstyleTest.enabled = true // the test source set is checked too
+ configProperties = [ // NOTE(review): presumably referenced as properties from checkstyle.xml - confirm
+ "checkstyle.suppressions.file" : file("${rootProject.projectDir}/tools/conf/suppressions.xml"),
+ "checkstyle.licenseHeader.file" : file("${rootProject.projectDir}/tools/conf/licenseHeader.java")
+ ]
+ configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml")
+}
+checkstyleMain {
+ classpath += configurations.compileClasspath // adds the compile classpath to the checkstyleMain task's classpath
+}
+tasks.withType(Checkstyle) { // every Checkstyle task: HTML report on, XML report off
+ reports {
+ xml.enabled false
+ html.enabled true
+ }
+}
+
+apply plugin: 'jacoco' // per-project JaCoCo coverage
+jacoco {
+ toolVersion = "0.8.5"
+}
+jacocoTestReport { // coverage report formats: XML on, CSV off
+ reports {
+ xml.enabled true
+ csv.enabled false
+ }
+}
+
+test.finalizedBy jacocoTestReport // the coverage report task is scheduled after every run of 'test'
+build.dependsOn javadoc // 'build' also runs the javadoc task
diff --git a/java-package/tools/gradle/jacoco.gradle b/java-package/tools/gradle/jacoco.gradle
new file mode 100644
index 0000000..010ae47
--- /dev/null
+++ b/java-package/tools/gradle/jacoco.gradle
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+apply plugin: 'jacoco'
+
+// Subprojects that take part in coverage: each one that has a src/test/java
+// directory and whose project path is not in the inline skip list.
+def jacocoProjects = subprojects.findAll { candidate ->
+    def skipped = [":jnarator"].contains(candidate.getPath())
+    return !skipped && new File(candidate.projectDir, "src/test/java").exists()
+}
+
+task jacocoMergeTestData(type: JacocoMerge) { // JacocoMerge task fed by the Test tasks of every project in jacocoProjects
+ jacocoProjects.each { p ->
+ dependsOn(p.test, p.jacocoTestReport) // runs after each project's test and jacocoTestReport tasks
+ executionData p.tasks.withType(Test) // registers all of that project's Test tasks as execution data sources
+ }
+}
+
+def exclusions = [":examples", ":integration"] // project paths whose sources and classes are left out of the root report and root verification
+
+task jacocoRootReport(type: JacocoReport) { // aggregate coverage report across jacocoProjects
+ dependsOn jacocoMergeTestData
+ description = 'Generates an aggregate report from all subprojects'
+
+ jacocoProjects.each { p ->
+ if (!exclusions.contains(p.getPath())) { // sources/classes of excluded projects are not added
+ additionalSourceDirs.from files((Set<File>) p.sourceSets.main.allJava.srcDirs)
+ sourceDirectories.from files((Set<File>) p.sourceSets.main.allSource.srcDirs)
+ classDirectories.from files((FileCollection) p.sourceSets.main.output)
+ additionalClassDirs((FileCollection) p.sourceSets.main.output)
+ }
+ }
+ executionData.from = files(jacocoProjects.jacocoTestReport.executionData).filter { f -> f.exists() } // execution data of all jacocoProjects (excluded ones included), restricted to files that exist
+
+ reports { // both XML and HTML output
+ xml.enabled = true
+ html.enabled = true
+ }
+}
+
+task jacocoRootVerification(type: JacocoCoverageVerification) { // coverage verification across jacocoProjects with a minimum-coverage rule
+ dependsOn jacocoMergeTestData
+
+ jacocoProjects.each { p ->
+ if (!exclusions.contains(p.getPath())) { // sources/classes of excluded projects are not added
+ additionalSourceDirs.from files((Set<File>) p.sourceSets.main.allJava.srcDirs)
+ sourceDirectories.from files((Set<File>) p.sourceSets.main.allSource.srcDirs)
+ classDirectories.from files((FileCollection) p.sourceSets.main.output)
+ additionalClassDirs((FileCollection) p.sourceSets.main.output)
+ }
+ }
+ executionData.from = files(jacocoProjects.jacocoTestReport.executionData).filter { f -> f.exists() } // execution data of all jacocoProjects, restricted to files that exist
+
+ violationRules {
+ rule {
+ limit {
+ if (Boolean.getBoolean("nightly")) { // true only when the JVM system property "nightly" equals "true"
+ minimum = 0.70 // stricter threshold when "nightly" is set
+ } else {
+ minimum = 0.65 // default threshold
+ }
+ }
+ }
+ }
+}
diff --git a/java-package/tools/gradle/java-formatter.gradle b/java-package/tools/gradle/java-formatter.gradle
new file mode 100644
index 0000000..2cbe96d
--- /dev/null
+++ b/java-package/tools/gradle/java-formatter.gradle
@@ -0,0 +1,85 @@
+buildscript { // puts google-java-format on this script's classpath so the plugin class defined in this file can use it
+ repositories {
+ maven {
+ url "https://plugins.gradle.org/m2/"
+ }
+ }
+ dependencies {
+ classpath 'com.google.googlejavaformat:google-java-format:1.6'
+ }
+}
+
+apply plugin: JavaFormatterPlugin // the plugin class is declared further down in this same script
+
+check.dependsOn verifyJava // 'check' runs the verifyJava task registered by the plugin
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import com.google.googlejavaformat.java.Formatter
+import com.google.googlejavaformat.java.ImportOrderer
+import com.google.googlejavaformat.java.JavaFormatterOptions
+import com.google.googlejavaformat.java.Main
+import com.google.googlejavaformat.java.RemoveUnusedImports
+
+// Registers the 'formatJava' and 'verifyJava' tasks; both walk all source sets and skip non-.java files and paths containing "generated-src".
+class JavaFormatterPlugin implements Plugin<Project> {
+    void apply(Project project) {
+        // formatJava: rewrites every eligible file through the google-java-format command-line entry point.
+        project.task('formatJava') {
+            doLast {
+                Main formatter = new Main(new PrintWriter(System.out, true), new PrintWriter(System.err, true), System.in)
+                for (item in project.sourceSets) {
+                    for (File file : item.getAllSource()) {
+                        if (!file.getName().endsWith(".java") || file.getAbsolutePath().contains("generated-src")) {
+                            continue
+                        }
+                        if (formatter.format("-a", "-i", file.getAbsolutePath()) != 0) { // "-a": AOSP style, "-i": edit in place; non-zero means failure
+                            throw new GradleException("Format java failed: " + file.getAbsolutePath())
+                        }
+                    }
+                }
+            }
+        }
+
+        // verifyJava: formats every eligible file in memory (AOSP style) and fails on the first file that differs from disk.
+        project.task('verifyJava') {
+            doLast {
+                def options = JavaFormatterOptions.builder().style(JavaFormatterOptions.Style.AOSP).build()
+                Formatter formatter = new Formatter(options)
+                for (item in project.sourceSets) {
+                    for (File file : item.getAllSource()) {
+                        if (!file.getName().endsWith(".java") || file.getAbsolutePath().contains("generated-src")) {
+                            continue
+                        }
+                        String src = new String(file.bytes, "UTF-8")
+                        String formatted = formatter.formatSource(src)
+                        formatted = RemoveUnusedImports.removeUnusedImports(formatted, RemoveUnusedImports.JavadocOnlyImports.KEEP)
+                        formatted = ImportOrderer.reorderImports(formatted)
+                        if (!src.equals(formatted)) {
+                            throw new GradleException("File not formatted: " + file.getAbsolutePath()
+                                    + System.lineSeparator()
+                                    + "In order to reformat your code, run './gradlew formatJava' (or './gradlew fJ' for short)"
+                                    + System.lineSeparator()
+                                    + "See https://github.com/deepjavalibrary/djl/blob/master/docs/development/development_guideline.md#coding-conventions for more details")
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/java-package/tools/gradle/stats.gradle b/java-package/tools/gradle/stats.gradle
new file mode 100644
index 0000000..ae827f5
--- /dev/null
+++ b/java-package/tools/gradle/stats.gradle
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Collected [durationSeconds, projectName] pairs; a list (not a duration-keyed map) keeps projects whose durations tie.
+def testsResults = []
+gradle.taskGraph.beforeTask { Task task ->
+    // Stamp every task with its start time so afterTask can compute the elapsed time.
+    task.ext.setProperty("startTime", java.time.Instant.now())
+}
+
+gradle.taskGraph.afterTask { Task task, TaskState state ->
+    if (task.name.equals("test") && state.didWork) {
+        // java.time is not a default import in Gradle scripts, hence the qualified names;
+        // getSeconds() is used because Duration.toSeconds() only exists from Java 9 on.
+        long duration = java.time.Duration.between(task.ext.startTime, java.time.Instant.now()).getSeconds()
+        testsResults.add([duration, task.project.name])
+    }
+}
+
+gradle.buildFinished {
+    // Only report when "build" was requested on the command line and at least one test task did work.
+    if (gradle.startParameter.taskNames.contains("build") && !testsResults.isEmpty()) {
+        println "========== Test duration =========="
+        testsResults.sort { a, b -> b[0] <=> a[0] }.take(6).each { entry -> // slowest first, at most six entries
+            println "\t${entry[1]}:\t${entry[0]}s"
+        }
+    }
+}
diff --git a/src/initialize.cc b/src/initialize.cc
index 9ef5121..1319cfc 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -375,9 +375,12 @@
}), \
[](auto f) { signal(SIGNAL, f); });
+// TODO(cspchen): avoid jvm exit with code 139. https://github.com/apache/incubator-mxnet/pull/20461
+#if !SKIP_SIGNAL_HANDLER_REGISTRATION
SIGNAL_HANDLER(SIGSEGV, SIGSEGVHandler, true);
SIGNAL_HANDLER(SIGFPE, SIGFPEHandler, false);
SIGNAL_HANDLER(SIGBUS, SIGBUSHandler, false);
+#endif
#endif