Java2.0 proto (#20461)

* ADD: first commit

* ADD: load local libraries

* UPDATE: use header files of MXNet 2.0

* ADD: load binaries from environment variable, java properties or jar files.

* ADD: add symbol loading and closing
add module integration

* ADD: [WIP] Component MxNDArray

* ADD: [WIP] Component MxNDArray

* ADD: Component MxNDArray. Pass static compilation check

* ADD: Component CachedOp

* REMOVE: unused module api

* FIX: dependency missing

* ADD: [WIP] add test cases for NdArray and CachedOp

* ADD: [WIP] add test cases for NdArray and CachedOp

* ADD: implementation of the forward function for MxSymbolblock

* ADD: implementation of the forward function for MxSymbolblock

* ADD: Sample model downloader for MLP

* ADD: doc

* ADD: Front-end module for inference, class MxModel, Predictor and so on.

* FIX: Mxnet crash when process exits.

* FIX: remove and initialize 3rdparty directory

* FIX: revert version of submodules: dlpack, dmlc-core, googletest, ps-lite

* Revert "FIX: remove and initialize 3rdparty directory"

This reverts commit d097675e

* FIX: redownload files in 3rdparty

* FIX: reset --hard the version of a few submodules

* FIX: reset --hard the version of a few submodules

* FIX: reset --hard the version of a few submodules

* PERF: [WIP] optimize code structure and memory management and

* ADD: add copyright; remove Mx prefix for some classes

* ADD: add copyright

* FIX: group name, path to find header files

* UPDATE: README.md

* ADD: copyright

* ADD: copyright

* ADD: package-info

ADD: ci

ADD: ci

ADD: make modification to trigger ci

ADD: ci

ADD: ci

ADD: ci

ADD: ci

ADD: gradlew

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

ADD: java_package_ci

FIX: build failure

* FIX: ci config file

* UPDATE: remove ParameterStore and some scripts

UPDATE: remove Initializer.java

* UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

FIX: issues in Resource close methods

FIX: issues in Resource close methods

FIX: issues in Resource close methods

UPDATE: remove scripts for dev

* FIX: loading on Linux platform

* MERGE: resolve conflicts

* MERGE: resolve conflicts

* parent 1ea18edce197b402cbfbaaaa54a94347501d92ab
author cspchen <cspchen@amazon.com> 1629186478 +0800
committer cspchen <cspchen@amazon.com> 1629186485 +0800

FIX: loading on Linux platform

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: ci for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

UPDATE: jenkins ci scripts for java-package

FIX: issues in Resource close methods

FIX: issues in Resource close methods

FIX: issues in Resource close methods

DOC: add doc

STYLE: change code style for pmd check

FIX: avoid registering a signal handler twice

STYLE: pass pmd check

UPDATE: remove unused scripts

* FIX: solve problems before merge

* UPDATE: remove useless files

* FIX: licence to apache

* FIX: sanity check

* FIX: sanity check

* FIX: sanity check

* FIX: remove unused files

* FIX: remove unused files

* DOC: add document

* FIX: doesn't work on osx

* FIX: clang static check

* FIX: sanity

* FIX: skip signal handler registration when building java package

* FIX: remove DataType String

* FIX: add license
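
The lookup named in the "load binaries from environment variable, java properties or jar files" entry above can be pictured roughly as follows. This is only a sketch, not the loader in this change: the class `LibraryLocator`, the property key `mxnet.library.path`, and the lookup order (environment variable first, then Java property, then jar) are assumptions made for illustration. Only the variable name `MXNET_LIBRARY_PATH` comes from the README in this diff.

```java
import java.util.Map;
import java.util.Optional;
import java.util.Properties;

/** Hypothetical helper for illustration; not the class used by the java-package. */
public final class LibraryLocator {
    // MXNET_LIBRARY_PATH is mentioned in the README; the property key is an assumed name.
    public static final String ENV_KEY = "MXNET_LIBRARY_PATH";
    public static final String PROP_KEY = "mxnet.library.path";

    private LibraryLocator() {}

    /**
     * Returns the explicitly configured library directory, or an empty Optional
     * when the caller should fall back to the native jar on the classpath.
     */
    public static Optional<String> resolve(Map<String, String> env, Properties props) {
        // 1. The environment variable takes precedence (assumed order).
        String fromEnv = env.get(ENV_KEY);
        if (fromEnv != null && !fromEnv.isEmpty()) {
            return Optional.of(fromEnv);
        }
        // 2. Next, a Java system property.
        String fromProp = props.getProperty(PROP_KEY);
        if (fromProp != null && !fromProp.isEmpty()) {
            return Optional.of(fromProp);
        }
        // 3. Nothing configured: signal "use the jar".
        return Optional.empty();
    }

    public static void main(String[] args) {
        Optional<String> dir = resolve(System.getenv(), System.getProperties());
        System.out.println(dir.map(d -> "load from " + d).orElse("extract from jar on classpath"));
    }
}
```

Passing the environment and properties in as arguments keeps the lookup free of global state, so it can be exercised without touching the real process environment.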

Co-authored-by: cspchen <cspchen@amazon.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b901f41..c6e0a4e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,6 +91,7 @@
 option(BUILD_EXTENSION_PATH "Path to extension to build" "")
 option(BUILD_CYTHON_MODULES "Build cython modules." OFF)
 option(LOG_FATAL_THROW "Log exceptions but do not abort" ON)
+option(BUILD_JAVA_NATIVE "Skip signal handler registration for Java Binding" OFF)
 cmake_dependent_option(USE_SPLIT_ARCH_DLL "Build a separate DLL for each Cuda arch (Windows only)." ON "MSVC" OFF)
 cmake_dependent_option(USE_CCACHE "Attempt using CCache to wrap the compilation" ON "UNIX" OFF)
 cmake_dependent_option(MXNET_FORCE_SHARED_CRT "Build with dynamic CRT on Windows (/MD)" ON "MXNET_BUILD_SHARED_LIBS" OFF)
@@ -970,6 +971,10 @@
   target_compile_definitions(mxnet PUBLIC MXNET_USE_CPP_PACKAGE=1)
 endif()
 
+if(BUILD_JAVA_NATIVE)
+  add_definitions(-DSKIP_SIGNAL_HANDLER_REGISTRATION=1)
+endif()
+
 if(NOT CMAKE_BUILD_TYPE STREQUAL "Distribution")
   # Staticbuild applies linker version script to hide private symbols, breaking unit tests
   add_subdirectory(tests)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 70934f6..1e75ab2 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -310,6 +310,23 @@
     build_ubuntu_cpu_openblas
 }
 
+build_ubuntu_cpu_and_test_java() {
+    build_ubuntu_cpu_openblas_java
+    java_package_integration_test
+}
+
+java_package_integration_test() {
+    # make sure you are using java 11
+    # build java project
+    cd /work/mxnet/java-package
+    ./gradlew build -x javadoc
+    # generate native library
+    ./gradlew :native:buildLocalLibraryJarDefault
+    ./gradlew :native:mkl-linuxJar
+    # run integration
+    ./gradlew :integration:run
+}
+
 build_ubuntu_cpu_openblas() {
     set -ex
     cd /work/build
@@ -327,6 +344,24 @@
     ninja
 }
 
+build_ubuntu_cpu_openblas_java() {
+    set -ex
+    cd /work/build
+    CXXFLAGS="-Wno-error=strict-overflow" CC=gcc-7 CXX=g++-7 cmake \
+        -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+        -DENABLE_TESTCOVERAGE=ON \
+        -DUSE_TVM_OP=ON \
+        -DUSE_BLAS=Open \
+        -DUSE_ONEDNN=OFF \
+        -DUSE_CUDA=OFF \
+        -DUSE_DIST_KVSTORE=ON \
+        -DBUILD_CYTHON_MODULES=ON \
+        -DBUILD_EXTENSION_PATH=/work/mxnet/example/extensions/lib_external_ops \
+        -DBUILD_JAVA_NATIVE=ON \
+        -G Ninja /work/mxnet
+    ninja
+}
+
 build_ubuntu_cpu_mkl() {
     set -ex
     cd /work/build
diff --git a/java-package/Develop.md b/java-package/Develop.md
new file mode 100644
index 0000000..8bc0ffe
--- /dev/null
+++ b/java-package/Develop.md
@@ -0,0 +1,109 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Development Tips
+
+## Set up the Project
+### Step 1. Obtain MXNet Library
+The first step is to obtain the mxnet library. We recommend building it from source. Alternatively, you can take the
+library from an installed mxnet package, as described under "Download Pre-built library".
+#### Build from source
+Refer to [Build From Source](https://mxnet.apache.org/get_started/build_from_source#building-mxnet)   
+For MacOS users:
+- Prepare  
+```shell
+# Install OS X Developer Tools
+$ xcode-select --install
+
+# Install Homebrew
+$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+
+# Install dependencies
+$ brew install cmake ninja ccache opencv
+```
+- Clone 3rd party projects
+```shell
+# Clone the third-party dependencies of mxnet (required)
+$ git submodule update --init --recursive
+```
+- Build MXNet  
+```shell
+# select and copy cmake configure files for macos
+$ cp config/darwin.cmake config.cmake
+
+# create build directory for the project
+$ mkdir build; cd build
+
+# cmake
+$ cmake  ..
+$ cmake --build .
+```
+Libraries will be generated under the directory _build/_.
+
+For Linux users:  
+Docker might help you build libraries on different platforms. You can get help from [README for CI](../ci/README.md).  
+For example, you can build mxnet on Ubuntu with the following command.  
+```shell
+$ python3 ci/build.py -p ubuntu_cpu
+```
+#### Download Pre-built library
+The mxnet library also ships inside installed mxnet packages, such as the Python module. However, mxnet 2.0 has not been
+released yet, which is why we recommend building it from source.  
+```shell
+# download the python module for mxnet (note that mxnet 2.0 has not been released yet)
+$ pip3 install mxnet==1.7.0.post2
+# find the location of the installed module
+$ python
+Python 3.6.8 |Anaconda, Inc.| (default, Dec 29 2018, 19:04:46)
+>>> import mxnet
+>>> mxnet
+<module 'mxnet' from '/Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/__init__.py'>
+>>> quit()
+# you can locate the module under /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/
+$ ls /Users/xxx/anaconda3/lib/python3.6/site-packages/mxnet/ | grep libmxnet
+libmxnet.dylib
+```
+The compiled library is the file named _libmxnet.*_. On macOS, the file has the suffix 
+_.dylib_; on Linux, the suffix is ".so"; on Windows, the suffix is "." 
+
+### Step 2. Build MXNet Native Lib for Java
+The project uses gradle to manage dependencies and to build. The mxnet library has to be packaged
+into a jar file so that it can be loaded into the JVM.
+```shell
+$ cd java-package
+# Build the project
+$ ./gradlew build 
+# Create gradle tasks to package the mxnet library into a jar
+# The task name has the form {$flavor}-{$platform}Jar
+# MacOS -> mkl-osxJar
+# Linux -> mkl-linuxJar
+# Windows -> mkl-winJar
+$ ./gradlew :native:buildLocalLibraryJarDefault
+# Build native lib for macos
+$ ./gradlew mkl-osxJar
+# Check the lib for osx
+$ ls native/build/libs | grep osx
+native-2.0.0-SNAPSHOT-osx-x86_64.jar
+```
+The jar file _native-2.0.0-SNAPSHOT-osx-x86_64.jar_ is the output lib file. 
+
+### Step 3. Run Integration Test
+When the integration test task runs, the built mxnet native lib is added to the classpath automatically. 
+```shell
+$ ./gradlew :integration:run
+
+```
\ No newline at end of file
diff --git a/java-package/README.md b/java-package/README.md
new file mode 100644
index 0000000..eb5b901
--- /dev/null
+++ b/java-package/README.md
@@ -0,0 +1,58 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Java Package for MXNet 2.0
+
+## Requirements
+
+## Install
+
+## Scripts
+- customize mxnet library path  
+```bash
+export MXNET_LIBRARY_PATH=//anaconda3/lib/python3.8/site-packages/mxnet/
+```
+
+
+## Tests  
+Test case for a rough inference run with MXNet model  
+```bash
+./gradlew :integration:run  
+```
+
+## Example
+
+```java
+try (MxResource base = BaseMxResource.getSystemMxResource()) {
+    // Load the sample MLP model. To load from a local directory instead:
+    // Model model = Model.loadModel("test", Paths.get("/Users/cspchen/mxnet.java_package/cache/repo/test-models/mlp.tar.gz/mlp/"));
+    Model model = Model.loadModel(Item.MLP);
+    Predictor<NDList, NDList> predictor = model.newPredictor();
+    NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+    NDList inputs = new NDList();
+    inputs.add(input);
+    NDList result = predictor.predict(inputs);
+    NDArray expected = NDArray.create(
+            base,
+            new float[] {4.93476f, -0.76084447f, 0.37713608f, 0.6605506f, -1.3485785f, -0.8736369f,
+                    0.018061712f, -1.3274033f, 1.0609543f, 0.24042489f},
+            new Shape(1, 10));
+    Assertions.assertAlmostEquals(result.get(0), expected);
+} catch (IOException e) {
+    logger.error(e.getMessage(), e);
+}
+```
\ No newline at end of file
diff --git a/java-package/build.gradle b/java-package/build.gradle
new file mode 100644
index 0000000..909dbbb
--- /dev/null
+++ b/java-package/build.gradle
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+    id "com.github.spotbugs" version "4.2.0" apply true
+}
+
+defaultTasks 'build'
+
+allprojects {
+    group 'org.apache.mxnet'
+    boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
+    version = "${java_package_version}" + (isRelease ? "" : "-SNAPSHOT")
+
+    repositories {
+//        maven {
+//            url "https://mlrepo.djl.ai/maven/"
+//        }
+        mavenCentral()
+        maven {
+            url 'https://oss.sonatype.org/content/repositories/snapshots/'
+        }
+    }
+
+    apply plugin: 'idea'
+    idea {
+        module {
+            outputDir = file('build/classes/java/main')
+            testOutputDir = file('build/classes/java/test')
+            // inheritOutputDirs = true
+        }
+    }
+}
+
+def javaProjects() {
+    return subprojects.findAll { new File(it.projectDir, "src/main").exists() }
+}
+
+configure(javaProjects()) {
+    apply plugin: 'java-library'
+    sourceCompatibility = 1.8
+    targetCompatibility = 1.8
+    compileJava.options.encoding = "UTF-8"
+    compileTestJava.options.encoding = "UTF-8"
+    if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+        compileJava.options.compilerArgs.addAll(["--release", "8"])
+    }
+
+    apply plugin: 'eclipse'
+
+    eclipse {
+        jdt.file.withProperties { props ->
+            props.setProperty "org.eclipse.jdt.core.circularClasspath", "warning"
+        }
+        classpath {
+            sourceSets.test.java {
+                srcDirs = ["src/test/java"]
+                exclude "**/package-info.java"
+            }
+        }
+    }
+
+    apply from: file("${rootProject.projectDir}/tools/gradle/java-formatter.gradle")
+    apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle")
+
+    test {
+        // tensorflow mobilenet and resnet require more cpu memory
+        maxHeapSize = "4096m"
+        doFirst {
+            if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
+                jvmArgs = [
+                        '--add-opens', "java.base/jdk.internal.loader=ALL-UNNAMED"
+                ]
+            }
+        }
+
+        useTestNG() {
+//             suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
+        }
+
+        testLogging {
+            showStandardStreams = true
+            events "passed", "skipped", "failed", "standardOut", "standardError"
+        }
+
+        doFirst {
+            systemProperties System.getProperties()
+            systemProperties.remove("user.dir")
+            systemProperty "org.apache.mxnet.logging.level", "debug"
+            systemProperty "org.slf4j.simpleLogger.defaultLogLevel", "debug"
+            systemProperty "org.slf4j.simpleLogger.log.org.mortbay.log", "warn"
+            systemProperty "disableProgressBar", "true"
+            systemProperty "nightly", System.getProperty("nightly", "false")
+//            systemProperty "java.library.path", "/Users/cspchen/Work/incubator-mxnet/build"
+            if (gradle.startParameter.offline) {
+                systemProperty "offline", "true"
+            }
+        }
+    }
+
+    compileJava {
+        options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+    }
+
+    compileTestJava {
+        options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror"
+    }
+
+    jar {
+        manifest {
+            attributes("Automatic-Module-Name": "org.apache.mxnet.${project.name.replace('-', '_')}")
+        }
+    }
+}
+
+apply from: file("${rootProject.projectDir}/tools/gradle/jacoco.gradle")
+apply from: file("${rootProject.projectDir}/tools/gradle/stats.gradle")
diff --git a/java-package/example/build.gradle b/java-package/example/build.gradle
new file mode 100644
index 0000000..a6831ee
--- /dev/null
+++ b/java-package/example/build.gradle
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+    id 'java'
+}
+
+group 'incubator-mxnet.java-package'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+    mavenCentral()
+}
+
+dependencies {
+    testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.0'
+    testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0'
+}
+
+test {
+    useJUnitPlatform()
+}
\ No newline at end of file
diff --git a/java-package/gradle.properties b/java-package/gradle.properties
new file mode 100644
index 0000000..9d32099
--- /dev/null
+++ b/java-package/gradle.properties
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+org.gradle.daemon=true
+org.gradle.jvmargs=-Xmx2048M
+
+systemProp.org.gradle.internal.http.socketTimeout=120000
+systemProp.org.gradle.internal.http.connectionTimeout=60000
+
+# FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308
+systemProp.org.gradle.internal.publish.checksums.insecure=true
+
+java_package_version=0.0.1
+mxnet_version=2.0.0
+api_version=0.0.1
+jnarator_version=0.0.1
+
+antlr_version=4.7.2
+commons_cli_version=1.4
+commons_compress_version=1.20
+commons_csv_version=1.8
+gson_version=2.8.6
+jna_version=5.3.0
+netty_version=4.1.51.Final
+slf4j_version=1.7.30
+log4j_slf4j_version=2.13.3
+testng_version=7.1.0
+powermock_version=2.0.7
+
diff --git a/java-package/gradle/wrapper/gradle-wrapper.jar b/java-package/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000..e708b1c
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.jar
Binary files differ
diff --git a/java-package/gradle/wrapper/gradle-wrapper.properties b/java-package/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000..689c068
--- /dev/null
+++ b/java-package/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
diff --git a/java-package/gradlew b/java-package/gradlew
new file mode 100755
index 0000000..ebb6c09
--- /dev/null
+++ b/java-package/gradlew
@@ -0,0 +1,186 @@
+#!/usr/bin/env sh
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=`expr $i + 1`
+    done
+    case $i in
+        0) set -- ;;
+        1) set -- "$args0" ;;
+        2) set -- "$args0" "$args1" ;;
+        3) set -- "$args0" "$args1" "$args2" ;;
+        4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=`save "$@"`
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+exec "$JAVACMD" "$@"
diff --git a/java-package/gradlew.bat b/java-package/gradlew.bat
new file mode 100644
index 0000000..bcf7fb7
--- /dev/null
+++ b/java-package/gradlew.bat
@@ -0,0 +1,92 @@
+@REM
+@REM Licensed to the Apache Software Foundation (ASF) under one
+@REM or more contributor license agreements.  See the NOTICE file
+@REM distributed with this work for additional information
+@REM regarding copyright ownership.  The ASF licenses this file
+@REM to you under the Apache License, Version 2.0 (the
+@REM "License"); you may not use this file except in compliance
+@REM with the License.  You may obtain a copy of the License at
+@REM
+@REM   http://www.apache.org/licenses/LICENSE-2.0
+@REM
+@REM Unless required by applicable law or agreed to in writing,
+@REM software distributed under the License is distributed on an
+@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+@REM KIND, either express or implied.  See the License for the
+@REM specific language governing permissions and limitations
+@REM under the License.
+@REM
+
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto execute
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/java-package/integration/build.gradle b/java-package/integration/build.gradle
new file mode 100644
index 0000000..fed8860
--- /dev/null
+++ b/java-package/integration/build.gradle
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+    id 'application'
+    id 'jacoco'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+    mavenCentral()
+}
+
+application {
+    mainClassName = System.getProperty("main", "org.apache.mxnet.integration.IntegrationTest")
+}
+
+dependencies {
+    api "commons-cli:commons-cli:${commons_cli_version}"
+    api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+
+    api project(":mxnet-engine")
+    implementation "org.testng:testng:${testng_version}"
+//    testImplementation(":mxnet-engine")
+    testImplementation("org.testng:testng:${testng_version}") {
+        exclude group: "junit", module: "junit"
+    }
+    testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+}
+
+run {
+    systemProperties System.getProperties()
+    systemProperties.remove("user.dir")
+    systemProperty("file.encoding", "UTF-8")
+    jvmArgs "-Xverify:none"
+}
+
+checkstyleMain {
+    // Skip Checkstyle for the integration test package
+    exclude 'org/apache/mxnet/integration/**'
+}
+
+//test {
+//
+//    useTestNG()
+//    filter {
+//        includeTestsMatching "org.apache.mxnet.integration.tests.engine.*"
+//    }
+//
+//}
+
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
new file mode 100644
index 0000000..96d0ad3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/IntegrationTest.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.mxnet.integration.util.Arguments;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.SkipException;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+public class IntegrationTest {
+
+    private static final Logger logger = LoggerFactory.getLogger(IntegrationTest.class);
+
+    private Class<?> source;
+
+    public IntegrationTest(Class<?> source) {
+        this.source = source;
+    }
+
+    public static void main(String[] args) {
+        new IntegrationTest(IntegrationTest.class).runTests(args);
+        // TODO: find a cleaner workaround for the native library crash than forcing System.exit(0)
+        //        System.exit(0);
+    }
+
+    public boolean runTests(String[] args) {
+        Options options = Arguments.getOptions();
+        try {
+            DefaultParser parser = new DefaultParser();
+            CommandLine cmd = parser.parse(options, args, null, false);
+            Arguments arguments = new Arguments(cmd);
+
+            Duration duration = Duration.ofMinutes(arguments.getDuration());
+            List<TestClass> tests = listTests(arguments, source);
+
+            boolean testsPassed = true;
+            while (!duration.isNegative()) {
+                long begin = System.currentTimeMillis();
+
+                testsPassed = runTests(tests) && testsPassed;
+
+                long delta = System.currentTimeMillis() - begin;
+                duration = duration.minus(Duration.ofMillis(delta));
+            }
+            return testsPassed;
+        } catch (ParseException e) {
+            HelpFormatter formatter = new HelpFormatter();
+            formatter.setLeftPadding(1);
+            formatter.setWidth(120);
+            formatter.printHelp(e.getMessage(), options);
+            return false;
+        } catch (Throwable t) {
+            logger.error("Unexpected error", t);
+            return false;
+        }
+    }
+
+    private boolean runTests(List<TestClass> tests) {
+        Map<TestResult, Integer> totals = new ConcurrentHashMap<>();
+        for (TestClass testClass : tests) {
+            logger.info("Running test {} ...", testClass.getName());
+            int testCount = testClass.getTestCount();
+
+            try {
+                if (!testClass.beforeClass()) {
+                    totals.merge(TestResult.FAILED, testCount, Integer::sum);
+                    continue;
+                }
+
+                for (int i = 0; i < testCount; ++i) {
+                    TestResult result = testClass.runTest(i);
+                    totals.merge(result, 1, Integer::sum);
+                }
+            } finally {
+                testClass.afterClass();
+            }
+        }
+
+        int totalFailed = totals.getOrDefault(TestResult.FAILED, 0);
+        int totalPassed = totals.getOrDefault(TestResult.SUCCESS, 0);
+        int totalSkipped = totals.getOrDefault(TestResult.SKIPPED, 0);
+        int totalUnsupported = totals.getOrDefault(TestResult.UNSUPPORTED, 0);
+        if (totalSkipped > 0) {
+            logger.info("Skipped: {} tests", totalSkipped);
+        }
+        if (totalUnsupported > 0) {
+            logger.info("Unsupported: {} tests", totalUnsupported);
+        }
+        if (totalFailed > 0) {
+            logger.error("Failed {} out of {} tests", totalFailed, totalFailed + totalPassed);
+        } else {
+            logger.info("Passed all {} tests", totalPassed);
+        }
+        return totalFailed == 0;
+    }
+
+    private static List<TestClass> listTests(Arguments arguments, Class<?> source)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        String className = arguments.getClassName();
+        String methodName = arguments.getMethodName();
+        List<TestClass> tests = new ArrayList<>();
+        try {
+            if (className != null) {
+                Class<?> clazz;
+                if (className.startsWith(arguments.getPackageName())) {
+                    clazz = Class.forName(className);
+                } else {
+                    clazz = Class.forName(arguments.getPackageName() + className);
+                }
+                getTestsInClass(clazz, methodName).ifPresent(tests::add);
+            } else {
+                List<Class<?>> classes = listTestClasses(arguments, source);
+                for (Class<?> clazz : classes) {
+                    getTestsInClass(clazz, methodName).ifPresent(tests::add);
+                }
+            }
+        } catch (ReflectiveOperationException | IOException | URISyntaxException e) {
+            logger.error("Failed to resolve test class.", e);
+            throw e;
+        }
+        return tests;
+    }
+
+    private static Optional<TestClass> getTestsInClass(Class<?> clazz, String methodName)
+            throws ReflectiveOperationException {
+        if (clazz.getConstructors().length == 0) {
+            return Optional.empty();
+        }
+        Constructor<?> ctor = clazz.getConstructor();
+        Object obj = ctor.newInstance();
+        TestClass testClass = new TestClass(obj);
+
+        for (Method method : clazz.getDeclaredMethods()) {
+            Test testMethod = method.getAnnotation(Test.class);
+            if (testMethod != null) {
+                if (testMethod.enabled()
+                        && (methodName == null || methodName.equals(method.getName()))) {
+                    testClass.addTestMethod(method);
+                }
+                continue;
+            }
+            BeforeClass beforeClass = method.getAnnotation(BeforeClass.class);
+            if (beforeClass != null) {
+                testClass.addBeforeClass(method);
+                continue;
+            }
+            AfterClass afterClass = method.getAnnotation(AfterClass.class);
+            if (afterClass != null) {
+                testClass.addAfterClass(method);
+                continue;
+            }
+            BeforeTest beforeTest = method.getAnnotation(BeforeTest.class);
+            if (beforeTest != null) {
+                testClass.addBeforeTest(method);
+                continue;
+            }
+            AfterTest afterTest = method.getAnnotation(AfterTest.class);
+            if (afterTest != null) {
+                testClass.addAfterTest(method);
+            }
+        }
+
+        return Optional.of(testClass);
+    }
+
+    private static List<Class<?>> listTestClasses(Arguments arguments, Class<?> clazz)
+            throws IOException, ClassNotFoundException, URISyntaxException {
+        URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+        String path = url.getPath();
+
+        if (!"file".equalsIgnoreCase(url.getProtocol())) {
+            return Collections.emptyList();
+        }
+
+        List<Class<?>> classList = new ArrayList<>();
+
+        Path classPath = Paths.get(url.toURI());
+        if (Files.isDirectory(classPath)) {
+            Collection<Path> files =
+                    Files.walk(classPath)
+                            .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
+                            .collect(Collectors.toList());
+            for (Path file : files) {
+                Path p = classPath.relativize(file);
+                String className = p.toString();
+                className = className.substring(0, className.lastIndexOf('.'));
+                className = className.replace(File.separatorChar, '.');
+                if (className.startsWith(arguments.getPackageName()) && !className.contains("$")) {
+                    try {
+                        classList.add(Class.forName(className));
+                    } catch (ExceptionInInitializerError ignore) {
+                        // ignore
+                    }
+                }
+            }
+        } else if (path.toLowerCase().endsWith(".jar")) {
+            try (JarFile jarFile = new JarFile(classPath.toFile())) {
+                Enumeration<JarEntry> en = jarFile.entries();
+                while (en.hasMoreElements()) {
+                    JarEntry entry = en.nextElement();
+                    String fileName = entry.getName();
+                    if (fileName.endsWith(".class")) {
+                        fileName = fileName.substring(0, fileName.lastIndexOf('.'));
+                        fileName = fileName.replace('/', '.');
+                        if (fileName.startsWith(arguments.getPackageName())) {
+                            try {
+                                classList.add(Class.forName(fileName));
+                            } catch (ExceptionInInitializerError ignore) {
+                                // ignore
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return classList;
+    }
+
+    private static final class TestClass {
+
+        private Object object;
+        private List<Method> testMethods;
+        private List<Method> beforeClass;
+        private List<Method> afterClass;
+        private List<Method> beforeTest;
+        private List<Method> afterTest;
+
+        public TestClass(Object object) {
+            this.object = object;
+            testMethods = new ArrayList<>();
+            beforeClass = new ArrayList<>();
+            afterClass = new ArrayList<>();
+            beforeTest = new ArrayList<>();
+            afterTest = new ArrayList<>();
+        }
+
+        public void addTestMethod(Method method) {
+            testMethods.add(method);
+        }
+
+        public void addBeforeClass(Method method) {
+            beforeClass.add(method);
+        }
+
+        public void addAfterClass(Method method) {
+            afterClass.add(method);
+        }
+
+        public void addBeforeTest(Method method) {
+            beforeTest.add(method);
+        }
+
+        public void addAfterTest(Method method) {
+            afterTest.add(method);
+        }
+
+        public boolean beforeClass() {
+            try {
+                for (Method method : beforeClass) {
+                    method.invoke(object);
+                }
+                return true;
+            } catch (InvocationTargetException | IllegalAccessException e) {
+                logger.error("", e.getCause());
+            }
+            return false;
+        }
+
+        public void afterClass() {
+            try {
+                for (Method method : afterClass) {
+                    method.invoke(object);
+                }
+            } catch (InvocationTargetException | IllegalAccessException e) {
+                logger.error("", e.getCause());
+            }
+        }
+
+        public boolean beforeTest() {
+            try {
+                for (Method method : beforeTest) {
+                    method.invoke(object);
+                }
+                return true;
+            } catch (InvocationTargetException | IllegalAccessException e) {
+                logger.error("", e.getCause());
+            }
+            return false;
+        }
+
+        public void afterTest() {
+            try {
+                for (Method method : afterTest) {
+                    method.invoke(object);
+                }
+            } catch (InvocationTargetException | IllegalAccessException e) {
+                logger.error("", e.getCause());
+            }
+        }
+
+        public TestResult runTest(int index) {
+            if (!beforeTest()) {
+                return TestResult.FAILED;
+            }
+
+            TestResult result;
+            Method method = testMethods.get(index);
+            try {
+                long begin = System.nanoTime();
+                method.invoke(object);
+                String time = String.format("%.3f", (System.nanoTime() - begin) / 1_000_000f);
+                logger.info("Test {}.{} PASSED, duration: {}", getName(), method.getName(), time);
+                result = TestResult.SUCCESS;
+            } catch (IllegalAccessException | InvocationTargetException e) {
+                if (expectedException(method, e)) {
+                    logger.info("Test {}.{} PASSED", getName(), method.getName());
+                    result = TestResult.SUCCESS;
+                } else if (e.getCause() instanceof SkipException) {
+                    logger.info("Test {}.{} SKIPPED", getName(), method.getName());
+                    result = TestResult.SKIPPED;
+                } else if (e.getCause() instanceof UnsupportedOperationException) {
+                    logger.info("Test {}.{} UNSUPPORTED", getName(), method.getName());
+                    logger.trace("", e.getCause());
+                    result = TestResult.UNSUPPORTED;
+                } else {
+                    logger.error("Test {}.{} FAILED", getName(), method.getName());
+                    logger.error("", e.getCause());
+                    result = TestResult.FAILED;
+                }
+            } finally {
+                afterTest();
+            }
+            return result;
+        }
+
+        public int getTestCount() {
+            return testMethods.size();
+        }
+
+        public String getName() {
+            return object.getClass().getName();
+        }
+
+        private static boolean expectedException(Method method, Exception e) {
+            Test test = method.getAnnotation(Test.class);
+            Class<?>[] exceptions = test.expectedExceptions();
+            if (exceptions.length > 0) {
+                Throwable exception = e.getCause();
+                for (Class<?> c : exceptions) {
+                    if (c.isInstance(exception)) {
+                        return true;
+                    }
+                }
+            }
+            return false;
+        }
+    }
+
+    public enum TestResult {
+        SUCCESS,
+        FAILED,
+        SKIPPED,
+        UNSUPPORTED;
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
new file mode 100644
index 0000000..f97e77a
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
new file mode 100644
index 0000000..8beffb2
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/ModelTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.engine;
+
+import java.io.IOException;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Model;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Predictor;
+import org.apache.mxnet.integration.tests.jna.JnaUtilTest;
+import org.apache.mxnet.integration.util.Assertions;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.repository.Item;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.annotations.Test;
+
+public class ModelTest {
+    private static final Logger logger = LoggerFactory.getLogger(ModelTest.class);
+
+    @Test
+    public void modelLoadAndPredictTest() {
+        try (MxResource base = BaseMxResource.getSystemMxResource()) {
+            Model model = Model.loadModel(Item.MLP);
+            Predictor<NDList, NDList> predictor = model.newPredictor();
+            NDArray input = NDArray.create(base, new Shape(1, 28, 28)).ones();
+            NDList inputs = new NDList();
+            inputs.add(input);
+            NDList result = predictor.predict(inputs);
+            NDArray expected =
+                    NDArray.create(
+                            base,
+                            new float[] {
+                                4.93476f,
+                                -0.76084447f,
+                                0.37713608f,
+                                0.6605506f,
+                                -1.3485785f,
+                                -0.8736369f,
+                                0.018061712f,
+                                -1.3274033f,
+                                1.0609543f,
+                                0.24042489f
+                            },
+                            new Shape(1, 10));
+            Assertions.assertAlmostEquals(result.get(0), expected);
+        } catch (IOException e) {
+            throw new AssertionError("Failed to load or run the MLP model", e);
+        }
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
new file mode 100644
index 0000000..cc6a916
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains integration tests that use the engine to test the actual behavior of the API. */
+package org.apache.mxnet.integration.tests.engine;
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
new file mode 100644
index 0000000..e751140
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/tests/jna/JnaUtilTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.tests.jna;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class JnaUtilTest {
+
+    private static final Logger logger = LoggerFactory.getLogger(JnaUtilTest.class);
+
+    @Test
+    public void doForwardTest() throws IOException {
+        try (MxResource base = BaseMxResource.getSystemMxResource()) {
+            Path modelPath = Repository.initRepository(Item.MLP);
+            Path symbolPath = modelPath.resolve("mlp-symbol.json");
+            Path paramsPath = modelPath.resolve("mlp-0000.params");
+            Symbol symbol = Symbol.loadSymbol(base, symbolPath);
+            SymbolBlock block = new SymbolBlock(base, symbol);
+            Device device = Device.defaultIfNull();
+            NDList mxNDArray = JnaUtils.loadNdArray(base, paramsPath, device);
+
+            // load parameters
+            List<Parameter> parameters = block.getAllParameters();
+            Map<String, Parameter> map = new ConcurrentHashMap<>();
+            parameters.forEach(p -> map.put(p.getName(), p));
+
+            for (NDArray nd : mxNDArray) {
+                String key = nd.getName();
+                if (key == null) {
+                    throw new IllegalArgumentException(
+                            "Array names must be present in parameter file");
+                }
+
+                String paramName = key.split(":", 2)[1];
+                Parameter parameter = map.remove(paramName);
+                parameter.setArray(nd);
+            }
+            block.setInputNames(new ArrayList<>(map.keySet()));
+
+            NDArray arr = NDArray.create(base, new Shape(1, 28, 28), device).ones();
+            block.forward(new NDList(arr), new PairList<>(), device);
+            logger.info(
+                    "Number of MxResource managed by baseMxResource: {}",
+                    BaseMxResource.getSystemMxResource().getSubResource().size());
+        } catch (IOException e) {
+            logger.error(e.getMessage(), e);
+            throw e;
+        }
+        Assert.assertEquals(BaseMxResource.getSystemMxResource().getSubResource().size(), 0);
+    }
+
+    @Test
+    public void createNdArray() {
+        try {
+            try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+                int[] originIntegerArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+                float[] originFloatArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+                double[] originDoubleArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+                long[] originLongArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+                boolean[] originBooleanArray = {
+                    true, false, false, true, true, true, true, false, false, true, true, true
+                };
+                byte[] originByteArray = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+                NDArray intArray = NDArray.create(base, originIntegerArray, new Shape(3, 4));
+                NDArray floatArray = NDArray.create(base, originFloatArray, new Shape(3, 4));
+                NDArray doubleArray = NDArray.create(base, originDoubleArray, new Shape(3, 4));
+                NDArray longArray = NDArray.create(base, originLongArray, new Shape(3, 4));
+                NDArray booleanArray = NDArray.create(base, originBooleanArray, new Shape(3, 4));
+                NDArray byteArray = NDArray.create(base, originByteArray, new Shape(3, 4));
+                NDArray intArray2 = NDArray.create(base, originIntegerArray);
+                NDArray floatArray2 = NDArray.create(base, originFloatArray);
+                NDArray doubleArray2 = NDArray.create(base, originDoubleArray);
+                NDArray longArray2 = NDArray.create(base, originLongArray);
+                NDArray booleanArray2 = NDArray.create(base, originBooleanArray);
+                NDArray byteArray2 = NDArray.create(base, originByteArray);
+
+                int[] ndArrayInt = intArray.toIntArray();
+                Assert.assertEquals(originIntegerArray, ndArrayInt);
+                // Float -> Double
+                float[] floats = floatArray.toFloatArray();
+                Assert.assertEquals(originFloatArray, floats);
+                double[] ndArrayDouble = doubleArray.toDoubleArray();
+                Assert.assertEquals(originDoubleArray, ndArrayDouble);
+                long[] ndArrayLong = longArray.toLongArray();
+                Assert.assertEquals(originLongArray, ndArrayLong);
+                boolean[] ndArrayBoolean = booleanArray.toBooleanArray();
+                Assert.assertEquals(originBooleanArray, ndArrayBoolean);
+                byte[] ndArrayByte = byteArray.toByteArray();
+                Assert.assertEquals(originByteArray, ndArrayByte);
+
+                int[] ndArrayInt2 = intArray2.toIntArray();
+                Assert.assertEquals(originIntegerArray, ndArrayInt2);
+
+                // Float -> Double
+                float[] floats2 = floatArray2.toFloatArray();
+                Assert.assertEquals(originFloatArray, floats2);
+                double[] ndArrayDouble2 = doubleArray2.toDoubleArray();
+                Assert.assertEquals(originDoubleArray, ndArrayDouble2);
+                long[] ndArrayLong2 = longArray2.toLongArray();
+                Assert.assertEquals(originLongArray, ndArrayLong2);
+                boolean[] ndArrayBoolean2 = booleanArray2.toBooleanArray();
+                Assert.assertEquals(originBooleanArray, ndArrayBoolean2);
+                byte[] ndArrayByte2 = byteArray2.toByteArray();
+                Assert.assertEquals(originByteArray, ndArrayByte2);
+            } catch (ClassCastException e) {
+                logger.error(e.getMessage());
+                throw e;
+            }
+            BaseMxResource base = BaseMxResource.getSystemMxResource();
+            int countNotReleased = 0;
+            for (MxResource mxResource : base.getSubResource().values()) {
+                if (!mxResource.getClosed()) {
+                    ++countNotReleased;
+                }
+            }
+            Assert.assertEquals(countNotReleased, 0);
+        } catch (ClassCastException e) {
+            logger.error(e.getMessage());
+            throw e;
+        }
+    }
+
+    @Test
+    public void loadNdArray() throws IOException {
+        try (BaseMxResource base = BaseMxResource.getSystemMxResource()) {
+            Path modelPath = Repository.initRepository(Item.MLP);
+            Path paramsPath = modelPath.resolve("mlp-0000.params");
+            NDList mxNDArray =
+                    JnaUtils.loadNdArray(
+                            base, Paths.get(paramsPath.toUri()), Device.defaultIfNull(null));
+            logger.info(mxNDArray.toString());
+            logger.info(
+                    String.format(
+                            "The number of sub-resources managed by BaseMxResource: %s",
+                            base.getSubResource().size()));
+        } catch (IOException e) {
+            logger.error(e.getMessage());
+            throw e;
+        }
+        logger.info(
+                String.format(
+                        "The number of sub-resources managed by BaseMxResource: %s",
+                        BaseMxResource.getSystemMxResource().getSubResource().size()));
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
new file mode 100644
index 0000000..15771f3
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Arguments.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.integration.util;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+
+public class Arguments {
+
+    private String methodName;
+    private String className;
+    private String packageName;
+    private int duration;
+    private int iteration = 1;
+
+    public Arguments(CommandLine cmd) {
+        methodName = cmd.getOptionValue("method-name");
+        className = cmd.getOptionValue("class-name");
+        if (cmd.hasOption("package-name")) {
+            packageName = cmd.getOptionValue("package-name");
+        } else {
+            packageName = "org.apache.mxnet.integration.tests.";
+        }
+
+        if (cmd.hasOption("duration")) {
+            duration = Integer.parseInt(cmd.getOptionValue("duration"));
+        }
+        if (cmd.hasOption("iteration")) {
+            iteration = Integer.parseInt(cmd.getOptionValue("iteration"));
+        }
+    }
+
+    public static Options getOptions() {
+        Options options = new Options();
+        options.addOption(
+                Option.builder("d")
+                        .longOpt("duration")
+                        .hasArg()
+                        .argName("DURATION")
+                        .desc("Duration of the test.")
+                        .build());
+        options.addOption(
+                Option.builder("n")
+                        .longOpt("iteration")
+                        .hasArg()
+                        .argName("ITERATION")
+                        .desc("Number of iterations in each test.")
+                        .build());
+        options.addOption(
+                Option.builder("p")
+                        .longOpt("package-name")
+                        .hasArg()
+                        .argName("PACKAGE-NAME")
+                        .desc("Name of the package to run")
+                        .build());
+        options.addOption(
+                Option.builder("c")
+                        .longOpt("class-name")
+                        .hasArg()
+                        .argName("CLASS-NAME")
+                        .desc("Name of the class to run")
+                        .build());
+        options.addOption(
+                Option.builder("m")
+                        .longOpt("method-name")
+                        .hasArg()
+                        .argName("METHOD-NAME")
+                        .desc("Name of the method to run")
+                        .build());
+        return options;
+    }
+
+    public int getDuration() {
+        return duration;
+    }
+
+    public int getIteration() {
+        return iteration;
+    }
+
+    public String getPackageName() {
+        return packageName;
+    }
+
+    public String getClassName() {
+        return className;
+    }
+
+    public String getMethodName() {
+        return methodName;
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
new file mode 100644
index 0000000..b342b45
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/Assertions.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.testng.Assert;
+
+public final class Assertions {
+    private static final double RTOL = 1e-5;
+    private static final double ATOL = 1e-3;
+
+    private Assertions() {}
+
+    private static <T> String getDefaultErrorMessage(T actual, T expected) {
+        return getDefaultErrorMessage(actual, expected, null);
+    }
+
+    private static <T> String getDefaultErrorMessage(T actual, T expected, String errorMessage) {
+        StringBuilder sb = new StringBuilder(100);
+        if (errorMessage != null) {
+            sb.append(errorMessage);
+        }
+        sb.append(System.lineSeparator())
+                .append("Expected: ")
+                .append(expected)
+                .append(System.lineSeparator())
+                .append("Actual: ")
+                .append(actual);
+        return sb.toString();
+    }
+
+    public static void assertAlmostEquals(NDArray actual, NDArray expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    public static void assertAlmostEquals(NDList actual, NDList expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    public static void assertAlmostEquals(double actual, double expected) {
+        assertAlmostEquals(actual, expected, RTOL, ATOL);
+    }
+
+    public static void assertAlmostEquals(
+            double actual, double expected, double rtol, double atol) {
+        if (Math.abs(actual - expected) > (atol + rtol * Math.abs(expected))) {
+            throw new AssertionError(getDefaultErrorMessage(actual, expected));
+        }
+    }
+
+    public static void assertAlmostEquals(
+            NDList actual, NDList expected, double rtol, double atol) {
+        Assert.assertEquals(
+                actual.size(),
+                expected.size(),
+                getDefaultErrorMessage(
+                        actual.size(), expected.size(), "The NDLists have different sizes"));
+        int size = actual.size();
+        for (int i = 0; i < size; i++) {
+            assertAlmostEquals(actual.get(i), expected.get(i), rtol, atol);
+        }
+    }
+
+    public static void assertAlmostEquals(
+            NDArray actual, NDArray expected, double rtol, double atol) {
+        if (!actual.getShape().equals(expected.getShape())) {
+            throw new AssertionError(
+                    getDefaultErrorMessage(
+                            actual.getShape(),
+                            expected.getShape(),
+                            "The shapes of the two NDArrays are different!"));
+        }
+        Number[] actualDoubleArray = actual.toArray();
+        Number[] expectedDoubleArray = expected.toArray();
+        for (int i = 0; i < actualDoubleArray.length; i++) {
+            double a = actualDoubleArray[i].doubleValue();
+            double b = expectedDoubleArray[i].doubleValue();
+            if (Math.abs(a - b) > (atol + rtol * Math.abs(b))) {
+                throw new AssertionError("Expected: " + b + " but got " + a);
+            }
+        }
+    }
+
+    public static void assertInPlaceEquals(NDArray actual, NDArray expected, NDArray original) {
+        Assert.assertEquals(
+                actual, expected, getDefaultErrorMessage(actual, expected, "Assert Equal failed!"));
+        Assert.assertSame(
+                original,
+                actual,
+                getDefaultErrorMessage(original, actual, "Assert Inplace failed!"));
+    }
+
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original) {
+        assertInPlaceAlmostEquals(actual, expected, original, RTOL, ATOL);
+    }
+
+    public static void assertInPlaceAlmostEquals(
+            NDArray actual, NDArray expected, NDArray original, double rtol, double atol) {
+        assertAlmostEquals(actual, expected, rtol, atol);
+        Assert.assertSame(
+                original,
+                actual,
+                getDefaultErrorMessage(original, actual, "Assert Inplace failed!"));
+    }
+}
diff --git a/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
new file mode 100644
index 0000000..e897c27
--- /dev/null
+++ b/java-package/integration/src/main/java/org/apache/mxnet/integration/util/CoverageUtils.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.integration.util;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.jar.JarEntry;
+import java.util.jar.JarFile;
+import java.util.stream.Collectors;
+
+public final class CoverageUtils {
+
+    private CoverageUtils() {}
+
+    public static void testGetterSetters(Class<?> baseClass)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        List<Class<?>> list = getClasses(baseClass);
+        for (Class<?> clazz : list) {
+            Object obj = null;
+            if (clazz.isEnum()) {
+                obj = clazz.getEnumConstants()[0];
+            } else {
+                Constructor<?>[] constructors = clazz.getConstructors();
+                for (Constructor<?> con : constructors) {
+                    try {
+                        Class<?>[] types = con.getParameterTypes();
+                        Object[] args = new Object[types.length];
+                        for (int i = 0; i < args.length; ++i) {
+                            args[i] = getMockInstance(types[i], true);
+                        }
+                        con.setAccessible(true);
+                        obj = con.newInstance(args);
+                    } catch (ReflectiveOperationException ignore) {
+                        // ignore
+                    }
+                }
+            }
+            if (obj == null) {
+                continue;
+            }
+
+            Method[] methods = clazz.getDeclaredMethods();
+            for (Method method : methods) {
+                String methodName = method.getName();
+                int parameterCount = method.getParameterCount();
+                try {
+                    if (parameterCount == 0
+                            && (methodName.startsWith("get")
+                                    || methodName.startsWith("is")
+                                    || "toString".equals(methodName)
+                                    || "hashCode".equals(methodName))) {
+                        method.invoke(obj);
+                    } else if (parameterCount == 1
+                            && (methodName.startsWith("set") || "fromValue".equals(methodName))) {
+                        Class<?> type = method.getParameterTypes()[0];
+                        method.invoke(obj, getMockInstance(type, true));
+                    } else if (parameterCount == 1 && "equals".equals(methodName)) {
+                        method.invoke(obj, obj);
+                        method.invoke(obj, (Object) null);
+                        Class<?> type = method.getParameterTypes()[0];
+                        method.invoke(obj, getMockInstance(type, true));
+                    }
+                } catch (ReflectiveOperationException ignore) {
+                    // ignore
+                }
+            }
+        }
+    }
+
+    private static List<Class<?>> getClasses(Class<?> clazz)
+            throws IOException, ReflectiveOperationException, URISyntaxException {
+        ClassLoader appClassLoader = Thread.currentThread().getContextClassLoader();
+        Field field = appClassLoader.getClass().getDeclaredField("ucp");
+        field.setAccessible(true);
+        Object ucp = field.get(appClassLoader);
+        Method method = ucp.getClass().getDeclaredMethod("getURLs");
+        URL[] urls = (URL[]) method.invoke(ucp);
+        ClassLoader cl = new TestClassLoader(urls, Thread.currentThread().getContextClassLoader());
+
+        URL url = clazz.getProtectionDomain().getCodeSource().getLocation();
+        String path = url.getPath();
+
+        if (!"file".equalsIgnoreCase(url.getProtocol())) {
+            return Collections.emptyList();
+        }
+
+        List<Class<?>> classList = new ArrayList<>();
+
+        Path classPath = Paths.get(url.toURI());
+        if (Files.isDirectory(classPath)) {
+            Collection<Path> files =
+                    Files.walk(classPath)
+                            .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
+                            .collect(Collectors.toList());
+            for (Path file : files) {
+                Path p = classPath.relativize(file);
+                String className = p.toString();
+                className = className.substring(0, className.lastIndexOf('.'));
+                className = className.replace(File.separatorChar, '.');
+
+                try {
+                    classList.add(Class.forName(className, true, cl));
+                } catch (Error ignore) {
+                    // ignore
+                }
+            }
+        } else if (path.toLowerCase().endsWith(".jar")) {
+            try (JarFile jarFile = new JarFile(classPath.toFile())) {
+                Enumeration<JarEntry> en = jarFile.entries();
+                while (en.hasMoreElements()) {
+                    JarEntry entry = en.nextElement();
+                    String fileName = entry.getName();
+                    if (fileName.endsWith(".class")) {
+                        fileName = fileName.substring(0, fileName.lastIndexOf('.'));
+                        fileName = fileName.replace('/', '.');
+                        try {
+                            classList.add(Class.forName(fileName, true, cl));
+                        } catch (Error ignore) {
+                            // ignore
+                        }
+                    }
+                }
+            }
+        }
+
+        return classList;
+    }
+
+    private static Object getMockInstance(Class<?> clazz, boolean useConstructor) {
+        if (clazz.isPrimitive()) {
+            if (clazz == Boolean.TYPE) {
+                return Boolean.TRUE;
+            }
+            if (clazz == Character.TYPE) {
+                return '0';
+            }
+            if (clazz == Byte.TYPE) {
+                return (byte) 0;
+            }
+            if (clazz == Short.TYPE) {
+                return (short) 0;
+            }
+            if (clazz == Integer.TYPE) {
+                return 0;
+            }
+            if (clazz == Long.TYPE) {
+                return 0L;
+            }
+            if (clazz == Float.TYPE) {
+                return 0f;
+            }
+            if (clazz == Double.TYPE) {
+                return 0d;
+            }
+        }
+
+        if (clazz.isAssignableFrom(String.class)) {
+            return "";
+        }
+
+        if (clazz.isAssignableFrom(List.class)) {
+            return new ArrayList<>();
+        }
+
+        if (clazz.isAssignableFrom(Set.class)) {
+            return new HashSet<>();
+        }
+
+        if (clazz.isAssignableFrom(Map.class)) {
+            return new HashMap<>();
+        }
+
+        if (clazz.isEnum()) {
+            return clazz.getEnumConstants()[0];
+        }
+
+        if (clazz.isInterface()) {
+            return newProxyInstance(clazz);
+        }
+
+        if (useConstructor) {
+            Constructor<?>[] constructors = clazz.getConstructors();
+            for (Constructor<?> con : constructors) {
+                try {
+                    Class<?>[] types = con.getParameterTypes();
+                    Object[] args = new Object[types.length];
+                    for (int i = 0; i < args.length; ++i) {
+                        args[i] = getMockInstance(types[i], false);
+                    }
+                    con.setAccessible(true);
+                    return con.newInstance(args);
+                } catch (ReflectiveOperationException ignore) {
+                    // ignore
+                }
+            }
+        }
+
+        return null;
+    }
+
+    @SuppressWarnings({"rawtypes", "PMD.UseProperClassLoader"})
+    private static Object newProxyInstance(Class<?> clazz) {
+        ClassLoader cl = clazz.getClassLoader();
+        return Proxy.newProxyInstance(cl, new Class[] {clazz}, (proxy, method, args) -> null);
+    }
+
+    private static final class TestClassLoader extends URLClassLoader {
+
+        public TestClassLoader(URL[] urls, ClassLoader parent) {
+            super(urls, parent);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public Class<?> loadClass(String name) throws ClassNotFoundException {
+            try {
+                return findClass(name);
+            } catch (ClassNotFoundException e) {
+                ClassLoader classLoader = getParent();
+                if (classLoader == null) {
+                    classLoader = getSystemClassLoader();
+                }
+                return classLoader.loadClass(name);
+            }
+        }
+    }
+}
diff --git a/java-package/integration/src/main/resources/log4j2.xml b/java-package/integration/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..ff05a01
--- /dev/null
+++ b/java-package/integration/src/main/resources/log4j2.xml
@@ -0,0 +1,34 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<Configuration status="INFO">
+  <Appenders>
+    <Console name="console" target="SYSTEM_OUT">
+      <PatternLayout
+          pattern="[%-5level] - %msg%n"/>
+    </Console>
+  </Appenders>
+  <Loggers>
+    <Root level="info" additivity="false">
+      <AppenderRef ref="console"/>
+    </Root>
+    <Logger name="org.apache.mxnet" level="${sys:com.apache.mxnet.level:-debug}" additivity="false">
+      <AppenderRef ref="console"/>
+    </Logger>
+  </Loggers>
+</Configuration>
diff --git a/java-package/jnarator/build.gradle b/java-package/jnarator/build.gradle
new file mode 100644
index 0000000..abc7cc0
--- /dev/null
+++ b/java-package/jnarator/build.gradle
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+    id 'antlr'
+}
+
+dependencies {
+    antlr "org.antlr:antlr4:${antlr_version}"
+
+    api "commons-cli:commons-cli:${commons_cli_version}"
+    api "org.antlr:antlr4-runtime:${antlr_version}"
+    api "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
+
+    testImplementation("org.testng:testng:${testng_version}") {
+        exclude group: "junit", module: "junit"
+    }
+
+    testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+    testRuntimeOnly project(":mxnet-engine")
+//    testRuntimeOnly ":mxnet-native-auto:${mxnet_version}"
+}
+
+checkstyleMain.source = 'src/main/java'
+
+checkstyleMain {
+    // skip check style for this package
+    exclude 'org/apache/mxnet/jnarator/**'
+}
+
+pmdMain.source = 'src/main/java'
+//pmdMain.ignoreFailures(true)
+spotbugs.ignoreFailures = true
+
+jar {
+    manifest {
+        attributes (
+                "Main-Class" : "org.apache.mxnet.jnarator.Main",
+                "Multi-Release" : true
+        )
+    }
+    includeEmptyDirs = false
+    duplicatesStrategy = DuplicatesStrategy.INCLUDE
+    from configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
+}
diff --git a/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4 b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
new file mode 100644
index 0000000..6206c13
--- /dev/null
+++ b/java-package/jnarator/src/main/antlr/org/apache/mxnet/jnarator/parser/C.g4
@@ -0,0 +1,923 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+grammar C;
+
+@parser::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+@lexer::header {
+package org.apache.mxnet.jnarator.parser;
+
+}
+
+
+primaryExpression
+    :   Identifier
+    |   Constant
+    |   StringLiteral+
+    |   '(' expression ')'
+    |   genericSelection
+    |   '__extension__'? '(' compoundStatement ')' // Blocks (GCC extension)
+    |   '__builtin_va_arg' '(' unaryExpression ',' typeName ')'
+    |   '__builtin_offsetof' '(' typeName ',' unaryExpression ')'
+    ;
+
+genericSelection
+    :   '_Generic' '(' assignmentExpression ',' genericAssocList ')'
+    ;
+
+genericAssocList
+    :   genericAssociation
+    |   genericAssocList ',' genericAssociation
+    ;
+
+genericAssociation
+    :   typeName ':' assignmentExpression
+    |   'default' ':' assignmentExpression
+    ;
+
+postfixExpression
+    :   primaryExpression
+    |   postfixExpression '[' expression ']'
+    |   postfixExpression '(' argumentExpressionList? ')'
+    |   postfixExpression '.' Identifier
+    |   postfixExpression '->' Identifier
+    |   postfixExpression '++'
+    |   postfixExpression '--'
+    |   '(' typeName ')' '{' initializerList '}'
+    |   '(' typeName ')' '{' initializerList ',' '}'
+    |   '__extension__' '(' typeName ')' '{' initializerList '}'
+    |   '__extension__' '(' typeName ')' '{' initializerList ',' '}'
+    ;
+
+argumentExpressionList
+    :   assignmentExpression
+    |   argumentExpressionList ',' assignmentExpression
+    ;
+
+unaryExpression
+    :   postfixExpression
+    |   '++' unaryExpression
+    |   '--' unaryExpression
+    |   unaryOperator castExpression
+    |   'sizeof' unaryExpression
+    |   'sizeof' '(' typeName ')'
+    |   '_Alignof' '(' typeName ')'
+    |   '&&' Identifier // GCC extension address of label
+    ;
+
+unaryOperator
+    :   '&' | '*' | '+' | '-' | '~' | '!'
+    ;
+
+castExpression
+    :   '(' typeName ')' castExpression
+    |   '__extension__' '(' typeName ')' castExpression
+    |   unaryExpression
+    |   DigitSequence // for
+    ;
+
+multiplicativeExpression
+    :   castExpression
+    |   multiplicativeExpression '*' castExpression
+    |   multiplicativeExpression '/' castExpression
+    |   multiplicativeExpression '%' castExpression
+    ;
+
+additiveExpression
+    :   multiplicativeExpression
+    |   additiveExpression '+' multiplicativeExpression
+    |   additiveExpression '-' multiplicativeExpression
+    ;
+
+shiftExpression
+    :   additiveExpression
+    |   shiftExpression '<<' additiveExpression
+    |   shiftExpression '>>' additiveExpression
+    ;
+
+relationalExpression
+    :   shiftExpression
+    |   relationalExpression '<' shiftExpression
+    |   relationalExpression '>' shiftExpression
+    |   relationalExpression '<=' shiftExpression
+    |   relationalExpression '>=' shiftExpression
+    ;
+
+equalityExpression
+    :   relationalExpression
+    |   equalityExpression '==' relationalExpression
+    |   equalityExpression '!=' relationalExpression
+    ;
+
+andExpression
+    :   equalityExpression
+    |   andExpression '&' equalityExpression
+    ;
+
+exclusiveOrExpression
+    :   andExpression
+    |   exclusiveOrExpression '^' andExpression
+    ;
+
+inclusiveOrExpression
+    :   exclusiveOrExpression
+    |   inclusiveOrExpression '|' exclusiveOrExpression
+    ;
+
+logicalAndExpression
+    :   inclusiveOrExpression
+    |   logicalAndExpression '&&' inclusiveOrExpression
+    ;
+
+logicalOrExpression
+    :   logicalAndExpression
+    |   logicalOrExpression '||' logicalAndExpression
+    ;
+
+conditionalExpression
+    :   logicalOrExpression ('?' expression ':' conditionalExpression)?
+    ;
+
+assignmentExpression
+    :   conditionalExpression
+    |   unaryExpression assignmentOperator assignmentExpression
+    |   DigitSequence // for
+    ;
+
+assignmentOperator
+    :   '=' | '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+    ;
+
+expression
+    :   assignmentExpression
+    |   expression ',' assignmentExpression
+    ;
+
+constantExpression
+    :   conditionalExpression
+    ;
+
+declaration
+    :   declarationSpecifiers initDeclaratorList ';'
+    |   declarationSpecifiers ';'
+    |   staticAssertDeclaration
+    ;
+
+declarationSpecifiers
+    :   declarationSpecifier+
+    ;
+
+declarationSpecifiers2
+    :   declarationSpecifier+
+    ;
+
+declarationSpecifier
+    :   storageClassSpecifier
+    |   typeSpecifier
+    |   typeQualifier
+    |   functionSpecifier
+    |   alignmentSpecifier
+    ;
+
+initDeclaratorList
+    :   initDeclarator
+    |   initDeclaratorList ',' initDeclarator
+    ;
+
+initDeclarator
+    :   declarator
+    |   declarator '=' initializer
+    ;
+
+storageClassSpecifier
+    :   'typedef'
+    |   'extern'
+    |   'static'
+    |   '_Thread_local'
+    |   'auto'
+    |   'register'
+    ;
+
+typeSpecifier
+    :   ('void'
+    |   'char'
+    |   'short'
+    |   'int'
+    |   'long'
+    |   'float'
+    |   'double'
+    |   'signed'
+    |   'unsigned'
+    |   '_Bool'
+    |   '_Complex'
+    |   '__m128'
+    |   '__m128d'
+    |   '__m128i')
+    |   '__extension__' '(' ('__m128' | '__m128d' | '__m128i') ')'
+    |   atomicTypeSpecifier
+    |   structOrUnionSpecifier
+    |   enumSpecifier
+    |   typedefName
+    |   '__typeof__' '(' constantExpression ')' // GCC extension
+    |   typeSpecifier pointer
+    ;
+
+structOrUnionSpecifier
+    :   structOrUnion Identifier? '{' structDeclarationList '}'
+    |   structOrUnion Identifier
+    ;
+
+structOrUnion
+    :   'struct'
+    |   'union'
+    ;
+
+structDeclarationList
+    :   structDeclaration
+    |   structDeclarationList structDeclaration
+    ;
+
+structDeclaration
+    :   specifierQualifierList structDeclaratorList? ';'
+    |   staticAssertDeclaration
+    ;
+
+specifierQualifierList
+    :   typeSpecifier specifierQualifierList?
+    |   typeQualifier specifierQualifierList?
+    ;
+
+structDeclaratorList
+    :   structDeclarator
+    |   structDeclaratorList ',' structDeclarator
+    ;
+
+structDeclarator
+    :   declarator
+    |   declarator? ':' constantExpression
+    ;
+
+enumSpecifier
+    :   'enum' Identifier? '{' enumeratorList '}'
+    |   'enum' Identifier? '{' enumeratorList ',' '}'
+    |   'enum' Identifier
+    ;
+
+enumeratorList
+    :   enumerator
+    |   enumeratorList ',' enumerator
+    ;
+
+enumerator
+    :   enumerationConstant
+    |   enumerationConstant '=' constantExpression
+    ;
+
+enumerationConstant
+    :   Identifier
+    ;
+
+atomicTypeSpecifier
+    :   '_Atomic' '(' typeName ')'
+    ;
+
+typeQualifier
+    :   'const'
+    |   'restrict'
+    |   'volatile'
+    |   '_Atomic'
+    ;
+
+functionSpecifier
+    :   ('inline'
+    |   '_Noreturn'
+    |   '__inline__' // GCC extension
+    |   '__stdcall')
+    |   gccAttributeSpecifier
+    |   '__declspec' '(' Identifier ')'
+    ;
+
+alignmentSpecifier
+    :   '_Alignas' '(' typeName ')'
+    |   '_Alignas' '(' constantExpression ')'
+    ;
+
+declarator
+    :   pointer? directDeclarator gccDeclaratorExtension*
+    ;
+
+directDeclarator
+    :   Identifier
+    |   '(' declarator ')'
+    |   directDeclarator '[' typeQualifierList? assignmentExpression? ']'
+    |   directDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+    |   directDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+    |   directDeclarator '[' typeQualifierList? '*' ']'
+    |   directDeclarator '(' parameterTypeList ')'
+    |   directDeclarator '(' identifierList? ')'
+    |   Identifier ':' DigitSequence  // bit field
+    |   '(' typeSpecifier? pointer directDeclarator ')' // function pointer like: (__cdecl *f)
+    ;
+
+gccDeclaratorExtension
+    :   '__asm' '(' StringLiteral+ ')'
+    |   gccAttributeSpecifier
+    ;
+
+gccAttributeSpecifier
+    :   '__attribute__' '(' '(' gccAttributeList ')' ')'
+    ;
+
+gccAttributeList
+    :   gccAttribute (',' gccAttribute)*
+    |   // empty
+    ;
+
+gccAttribute
+    :   ~(',' | '(' | ')') // relaxed def for "identifier or reserved word"
+        ('(' argumentExpressionList? ')')?
+    |   // empty
+    ;
+
+nestedParenthesesBlock
+    :   (   ~('(' | ')')
+        |   '(' nestedParenthesesBlock ')'
+        )*
+    ;
+
+pointer
+    :   '*' typeQualifierList?
+    |   '*' typeQualifierList? pointer
+    |   '^' typeQualifierList? // Blocks language extension
+    |   '^' typeQualifierList? pointer // Blocks language extension
+    ;
+
+typeQualifierList
+    :   typeQualifier
+    |   typeQualifierList typeQualifier
+    ;
+
+parameterTypeList
+    :   parameterList
+    |   parameterList ',' '...'
+    ;
+
+parameterList
+    :   parameterDeclaration
+    |   parameterList ',' parameterDeclaration
+    ;
+
+parameterDeclaration
+    :   declarationSpecifiers declarator
+    |   declarationSpecifiers2 abstractDeclarator?
+    ;
+
+identifierList
+    :   Identifier
+    |   identifierList ',' Identifier
+    ;
+
+typeName
+    :   specifierQualifierList abstractDeclarator?
+    ;
+
+abstractDeclarator
+    :   pointer
+    |   pointer? directAbstractDeclarator gccDeclaratorExtension*
+    ;
+
+directAbstractDeclarator
+    :   '(' abstractDeclarator ')' gccDeclaratorExtension*
+    |   '[' typeQualifierList? assignmentExpression? ']'
+    |   '[' 'static' typeQualifierList? assignmentExpression ']'
+    |   '[' typeQualifierList 'static' assignmentExpression ']'
+    |   '[' '*' ']'
+    |   '(' parameterTypeList? ')' gccDeclaratorExtension*
+    |   directAbstractDeclarator '[' typeQualifierList? assignmentExpression? ']'
+    |   directAbstractDeclarator '[' 'static' typeQualifierList? assignmentExpression ']'
+    |   directAbstractDeclarator '[' typeQualifierList 'static' assignmentExpression ']'
+    |   directAbstractDeclarator '[' '*' ']'
+    |   directAbstractDeclarator '(' parameterTypeList? ')' gccDeclaratorExtension*
+    ;
+
+typedefName
+    :   Identifier
+    ;
+
+initializer
+    :   assignmentExpression
+    |   '{' initializerList '}'
+    |   '{' initializerList ',' '}'
+    ;
+
+initializerList
+    :   designation? initializer
+    |   initializerList ',' designation? initializer
+    ;
+
+designation
+    :   designatorList '='
+    ;
+
+designatorList
+    :   designator
+    |   designatorList designator
+    ;
+
+designator
+    :   '[' constantExpression ']'
+    |   '.' Identifier
+    ;
+
+staticAssertDeclaration
+    :   '_Static_assert' '(' constantExpression ',' StringLiteral+ ')' ';'
+    ;
+
+statement
+    :   labeledStatement
+    |   compoundStatement
+    |   expressionStatement
+    |   selectionStatement
+    |   iterationStatement
+    |   jumpStatement
+    |   ('__asm' | '__asm__') ('volatile' | '__volatile__') '(' (logicalOrExpression (',' logicalOrExpression)*)? (':' (logicalOrExpression (',' logicalOrExpression)*)?)* ')' ';'
+    ;
+
+labeledStatement
+    :   Identifier ':' statement
+    |   'case' constantExpression ':' statement
+    |   'default' ':' statement
+    ;
+
+compoundStatement
+    :   '{' blockItemList? '}'
+    ;
+
+blockItemList
+    :   blockItem
+    |   blockItemList blockItem
+    ;
+
+blockItem
+    :   statement
+    |   declaration
+    ;
+
+expressionStatement
+    :   expression? ';'
+    ;
+
+selectionStatement
+    :   'if' '(' expression ')' statement ('else' statement)?
+    |   'switch' '(' expression ')' statement
+    ;
+
+iterationStatement
+    :   While '(' expression ')' statement
+    |   Do statement While '(' expression ')' ';'
+    |   For '(' forCondition ')' statement
+    ;
+
+//    |   'for' '(' expression? ';' expression?  ';' forUpdate? ')' statement
+//    |   For '(' declaration  expression? ';' expression? ')' statement
+
+forCondition
+    :   forDeclaration ';' forExpression? ';' forExpression?
+    |   expression? ';' forExpression? ';' forExpression?
+    ;
+
+forDeclaration
+    :   declarationSpecifiers initDeclaratorList
+    |   declarationSpecifiers
+    ;
+
+forExpression
+    :   assignmentExpression
+    |   forExpression ',' assignmentExpression
+    ;
+
+jumpStatement
+    :   'goto' Identifier ';'
+    |   'continue' ';'
+    |   'break' ';'
+    |   'return' expression? ';'
+    |   'goto' unaryExpression ';' // GCC extension
+    ;
+
+compilationUnit
+    :   translationUnit? ( EOF | '}' )
+    ;
+
+translationUnit
+    :   externalDeclaration
+    |   translationUnit externalDeclaration
+    ;
+
+externalDeclaration
+    :   functionDefinition
+    |   declaration
+    |   ';' // stray ;
+    ;
+
+functionDefinition
+    :   declarationSpecifiers? declarator declarationList? compoundStatement
+    ;
+
+declarationList
+    :   declaration
+    |   declarationList declaration
+    ;
+
+Auto : 'auto';
+Break : 'break';
+Case : 'case';
+Char : 'char';
+Const : 'const';
+Continue : 'continue';
+Default : 'default';
+Do : 'do';
+Double : 'double';
+Else : 'else';
+Enum : 'enum';
+Extern : 'extern';
+Float : 'float';
+For : 'for';
+Goto : 'goto';
+If : 'if';
+Inline : 'inline';
+Int : 'int';
+Long : 'long';
+Register : 'register';
+Restrict : 'restrict';
+Return : 'return';
+Short : 'short';
+Signed : 'signed';
+Sizeof : 'sizeof';
+Static : 'static';
+Struct : 'struct';
+Switch : 'switch';
+Typedef : 'typedef';
+Union : 'union';
+Unsigned : 'unsigned';
+Void : 'void';
+Volatile : 'volatile';
+While : 'while';
+
+Alignas : '_Alignas';
+Alignof : '_Alignof';
+Atomic : '_Atomic';
+Bool : '_Bool';
+Complex : '_Complex';
+Generic : '_Generic';
+Imaginary : '_Imaginary';
+Noreturn : '_Noreturn';
+StaticAssert : '_Static_assert';
+ThreadLocal : '_Thread_local';
+
+LeftParen : '(';
+RightParen : ')';
+LeftBracket : '[';
+RightBracket : ']';
+LeftBrace : '{';
+RightBrace : '}';
+
+Less : '<';
+LessEqual : '<=';
+Greater : '>';
+GreaterEqual : '>=';
+LeftShift : '<<';
+RightShift : '>>';
+
+Plus : '+';
+PlusPlus : '++';
+Minus : '-';
+MinusMinus : '--';
+Star : '*';
+Div : '/';
+Mod : '%';
+
+And : '&';
+Or : '|';
+AndAnd : '&&';
+OrOr : '||';
+Caret : '^';
+Not : '!';
+Tilde : '~';
+
+Question : '?';
+Colon : ':';
+Semi : ';';
+Comma : ',';
+
+Assign : '=';
+// '*=' | '/=' | '%=' | '+=' | '-=' | '<<=' | '>>=' | '&=' | '^=' | '|='
+StarAssign : '*=';
+DivAssign : '/=';
+ModAssign : '%=';
+PlusAssign : '+=';
+MinusAssign : '-=';
+LeftShiftAssign : '<<=';
+RightShiftAssign : '>>=';
+AndAssign : '&=';
+XorAssign : '^=';
+OrAssign : '|=';
+
+Equal : '==';
+NotEqual : '!=';
+
+Arrow : '->';
+Dot : '.';
+Ellipsis : '...';
+
+Identifier
+    :   IdentifierNondigit
+        (   IdentifierNondigit
+        |   Digit
+        )*
+    ;
+
+fragment
+IdentifierNondigit
+    :   Nondigit
+    |   UniversalCharacterName
+    //|   // other implementation-defined characters...
+    ;
+
+fragment
+Nondigit
+    :   [a-zA-Z_]
+    ;
+
+fragment
+Digit
+    :   [0-9]
+    ;
+
+fragment
+UniversalCharacterName
+    :   '\\u' HexQuad
+    |   '\\U' HexQuad HexQuad
+    ;
+
+fragment
+HexQuad
+    :   HexadecimalDigit HexadecimalDigit HexadecimalDigit HexadecimalDigit
+    ;
+
+Constant
+    :   IntegerConstant
+    |   FloatingConstant
+    //|   EnumerationConstant
+    |   CharacterConstant
+    ;
+
+fragment
+IntegerConstant
+    :   DecimalConstant IntegerSuffix?
+    |   OctalConstant IntegerSuffix?
+    |   HexadecimalConstant IntegerSuffix?
+    |   BinaryConstant
+    ;
+
+fragment
+BinaryConstant
+    :   '0' [bB] [0-1]+
+    ;
+
+fragment
+DecimalConstant
+    :   NonzeroDigit Digit*
+    ;
+
+fragment
+OctalConstant
+    :   '0' OctalDigit*
+    ;
+
+fragment
+HexadecimalConstant
+    :   HexadecimalPrefix HexadecimalDigit+
+    ;
+
+fragment
+HexadecimalPrefix
+    :   '0' [xX]
+    ;
+
+fragment
+NonzeroDigit
+    :   [1-9]
+    ;
+
+fragment
+OctalDigit
+    :   [0-7]
+    ;
+
+fragment
+HexadecimalDigit
+    :   [0-9a-fA-F]
+    ;
+
+fragment
+IntegerSuffix
+    :   UnsignedSuffix LongSuffix?
+    |   UnsignedSuffix LongLongSuffix
+    |   LongSuffix UnsignedSuffix?
+    |   LongLongSuffix UnsignedSuffix?
+    ;
+
+fragment
+UnsignedSuffix
+    :   [uU]
+    ;
+
+fragment
+LongSuffix
+    :   [lL]
+    ;
+
+fragment
+LongLongSuffix
+    :   'll' | 'LL'
+    ;
+
+fragment
+FloatingConstant
+    :   DecimalFloatingConstant
+    |   HexadecimalFloatingConstant
+    ;
+
+fragment
+DecimalFloatingConstant
+    :   FractionalConstant ExponentPart? FloatingSuffix?
+    |   DigitSequence ExponentPart FloatingSuffix?
+    ;
+
+fragment
+HexadecimalFloatingConstant
+    :   HexadecimalPrefix HexadecimalFractionalConstant BinaryExponentPart FloatingSuffix?
+    |   HexadecimalPrefix HexadecimalDigitSequence BinaryExponentPart FloatingSuffix?
+    ;
+
+fragment
+FractionalConstant
+    :   DigitSequence? '.' DigitSequence
+    |   DigitSequence '.'
+    ;
+
+fragment
+ExponentPart
+    :   'e' Sign? DigitSequence
+    |   'E' Sign? DigitSequence
+    ;
+
+fragment
+Sign
+    :   '+' | '-'
+    ;
+
+DigitSequence
+    :   Digit+
+    ;
+
+fragment
+HexadecimalFractionalConstant
+    :   HexadecimalDigitSequence? '.' HexadecimalDigitSequence
+    |   HexadecimalDigitSequence '.'
+    ;
+
+fragment
+BinaryExponentPart
+    :   'p' Sign? DigitSequence
+    |   'P' Sign? DigitSequence
+    ;
+
+fragment
+HexadecimalDigitSequence
+    :   HexadecimalDigit+
+    ;
+
+fragment
+FloatingSuffix
+    :   'f' | 'l' | 'F' | 'L'
+    ;
+
+fragment
+CharacterConstant
+    :   '\'' CCharSequence '\''
+    |   'L\'' CCharSequence '\''
+    |   'u\'' CCharSequence '\''
+    |   'U\'' CCharSequence '\''
+    ;
+
+fragment
+CCharSequence
+    :   CChar+
+    ;
+
+fragment
+CChar
+    :   ~['\\\r\n]
+    |   EscapeSequence
+    ;
+
+fragment
+EscapeSequence
+    :   SimpleEscapeSequence
+    |   OctalEscapeSequence
+    |   HexadecimalEscapeSequence
+    |   UniversalCharacterName
+    ;
+
+fragment
+SimpleEscapeSequence
+    :   '\\' ['"?abfnrtv\\]
+    ;
+
+fragment
+OctalEscapeSequence
+    :   '\\' OctalDigit
+    |   '\\' OctalDigit OctalDigit
+    |   '\\' OctalDigit OctalDigit OctalDigit
+    ;
+
+fragment
+HexadecimalEscapeSequence
+    :   '\\x' HexadecimalDigit+
+    ;
+
+StringLiteral
+    :   EncodingPrefix? '"' SCharSequence? '"'
+    ;
+
+fragment
+EncodingPrefix
+    :   'u8'
+    |   'u'
+    |   'U'
+    |   'L'
+    ;
+
+fragment
+SCharSequence
+    :   SChar+
+    ;
+
+fragment
+SChar
+    :   ~["\\\r\n]
+    |   EscapeSequence
+    |   '\\\n'   // Added line
+    |   '\\\r\n' // Added line
+    ;
+
+LineAfterPreprocessing
+    :   '#' ~[\r\n]*
+        -> skip
+    ;
+
+Whitespace
+    :   ( [ \t]
+        |   'MXNET_DLL'
+        |   'NNVM_DLL'
+        |   'extern "C" {'
+        |   'DEFAULT' Whitespace? '(' .*? ')'
+        )+
+        -> skip
+    ;
+
+Newline
+    :   (   '\r' '\n'?
+        |   '\n'
+        )
+        -> skip
+    ;
+
+BlockComment
+    :   '/*' .*? '*/'
+        -> skip
+    ;
+
+LineComment
+    :   '//' ~[\r\n]*
+        -> skip
+    ;
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
new file mode 100644
index 0000000..3a2b5e0
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/AntlrUtils.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public final class AntlrUtils {
+
+    private AntlrUtils() {}
+
+    public static boolean isTypeDef(CParser.DeclarationSpecifiersContext specs) {
+        if (specs.getChildCount() == 0) { // RuleContext.isEmpty() tests the invoking state, not the children
+            return false;
+        }
+
+        CParser.DeclarationSpecifierContext spec =
+                (CParser.DeclarationSpecifierContext) specs.getChild(0);
+        CParser.StorageClassSpecifierContext storage = spec.storageClassSpecifier();
+        if (storage != null) {
+            return storage.Typedef() != null;
+        }
+        return false;
+    }
+
+    public static String getTypeDefValue(CParser.DeclarationSpecifiersContext specs) {
+        List<String> list = new ArrayList<>();
+        for (int i = 1; i < specs.getChildCount(); ++i) {
+            list.add(specs.getChild(i).getText());
+        }
+        return String.join(" ", list);
+    }
+
+    public static boolean isEnum(CParser.DeclarationSpecifiersContext specs) {
+        if (specs.getChildCount() == 0) { // RuleContext.isEmpty() tests the invoking state, not the children
+            return false;
+        }
+
+        CParser.DeclarationSpecifierContext spec =
+                (CParser.DeclarationSpecifierContext) specs.getChild(0);
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        if (type == null) {
+            return false;
+        }
+        return type.enumSpecifier() != null;
+    }
+
+    public static boolean isStructOrUnion(CParser.DeclarationSpecifiersContext specs) {
+        if (specs.getChildCount() == 0) { // RuleContext.isEmpty() tests the invoking state, not the children
+            return false;
+        }
+
+        CParser.DeclarationSpecifierContext spec =
+                (CParser.DeclarationSpecifierContext) specs.getChild(0);
+        CParser.TypeSpecifierContext type = spec.typeSpecifier();
+        if (type == null) {
+            return false;
+        }
+        return type.structOrUnionSpecifier() != null;
+    }
+
+    public static String getText(ParseTree tree) {
+        StringBuilder sb = new StringBuilder();
+        getText(sb, tree);
+        return sb.toString();
+    }
+
+    private static void getText(StringBuilder sb, ParseTree tree) {
+        if (tree instanceof TerminalNode) {
+            sb.append("\"v\" : \"").append(tree.getText()).append('"');
+            return;
+        }
+        sb.append('"');
+        sb.append(tree.getClass().getSimpleName()).append("\" : {");
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            getText(sb, tree.getChild(i));
+            if (i < tree.getChildCount() - 1) {
+                sb.append(',');
+            }
+        }
+        sb.append('}');
+    }
+
+    public static String toCamelCase(String name) {
+        String[] tokens = name.split("_");
+        for (int i = 0; i < tokens.length; ++i) {
+            String token = tokens[i]; // empty for an empty name or for leading/repeated underscores
+            tokens[i] = token.isEmpty() ? token : Character.toUpperCase(token.charAt(0)) + token.substring(1); // NOPMD
+        }
+        return String.join("", tokens);
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
new file mode 100644
index 0000000..3b1c1fe
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/DataType.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class DataType {
+
+    private boolean isConst;
+    private boolean functionPointer;
+    private int pointerCount;
+    private StringBuilder type = new StringBuilder(); // NOPMD
+
+    public boolean isConst() {
+        return isConst;
+    }
+
+    public void setConst() {
+        isConst = true;
+    }
+
+    public boolean isFunctionPointer() {
+        return functionPointer;
+    }
+
+    public void setFunctionPointer(boolean functionPointer) {
+        this.functionPointer = functionPointer;
+    }
+
+    public int getPointerCount() {
+        return pointerCount;
+    }
+
+    public void setPointerCount(int pointerCount) {
+        this.pointerCount = pointerCount;
+    }
+
+    public void increasePointerCount() {
+        ++pointerCount;
+    }
+
+    public String getType() {
+        return type.toString();
+    }
+
+    public void setType(String typeName) {
+        type.setLength(0);
+        type.append(typeName);
+    }
+
+    public void appendTypeName(String name) {
+        if (type.length() > 0) {
+            type.append(' ');
+        }
+        type.append(name);
+    }
+
+    public String map(Map<String, TypeDefine> map, Set<String> structs) {
+        String typeName = type.toString().trim();
+        TypeDefine typeDefine = map.get(typeName);
+        boolean isStruct = structs.contains(typeName);
+        if (typeDefine != null && !typeDefine.isCallBack()) {
+            typeName = typeDefine.getValue();
+
+            String mapped = typeName.replaceAll("const ", "").replaceAll(" const", "");
+            if (typeName.length() - mapped.length() > 0) {
+                isConst = true;
+            }
+            typeName = mapped;
+            mapped = typeName.replaceAll("\\*", "");
+            int count = typeName.length() - mapped.length();
+            pointerCount += count;
+            typeName = mapped;
+            setType(typeName);
+        }
+
+        if (pointerCount > 2) {
+            return "PointerByReference";
+        }
+
+        typeName = baseTypeMapping(typeName);
+
+        if (pointerCount == 2) {
+            if (isConst && "char".equals(typeName)) {
+                return "String[]";
+            }
+            return "PointerByReference";
+        }
+
+        if (pointerCount == 1) {
+            switch (typeName) {
+                case "byte":
+                    return "ByteBuffer";
+                case "NativeSize":
+                    return "NativeSizeByReference";
+                case "int":
+                    if (isConst) {
+                        return "int[]";
+                    }
+                    return "IntBuffer";
+                case "long":
+                    if (isConst) {
+                        return "long[]";
+                    }
+                    return "LongBuffer";
+                case "char":
+                    if (isConst) {
+                        return "String";
+                    }
+                    return "ByteBuffer";
+                case "float":
+                    return "FloatBuffer";
+                case "void":
+                    return "Pointer";
+                default:
+                    if (isStruct) {
+                        return typeName + ".ByReference";
+                    }
+                    return "Pointer";
+            }
+        }
+        return typeName;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        if (isConst) {
+            sb.append("const ");
+        }
+        sb.append(type);
+        if (pointerCount > 0) {
+            sb.append(' ');
+            for (int i = 0; i < pointerCount; ++i) {
+                sb.append('*');
+            }
+        }
+        return sb.toString();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        DataType dataType = (DataType) o;
+        return type.toString().equals(dataType.type.toString());
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(type.toString()); // StringBuilder uses identity hashCode; hash the text to match equals()
+    }
+
+    static DataType parse(ParseTree tree) {
+        DataType dataType = new DataType();
+        parseTypeSpec(dataType, tree);
+        return dataType;
+    }
+
+    static List<DataType> parseDataTypes(List<CParser.DeclarationSpecifierContext> list) {
+        List<DataType> ret = new ArrayList<>();
+        DataType dataType = new DataType();
+        for (CParser.DeclarationSpecifierContext spec : list) {
+            CParser.TypeQualifierContext qualifier = spec.typeQualifier();
+            if (qualifier != null) {
+                String qualifierName = qualifier.getText();
+                if ("const".equals(qualifierName)) {
+                    dataType.setConst();
+                } else {
+                    dataType.appendTypeName(qualifierName);
+                }
+                continue;
+            }
+
+            CParser.TypeSpecifierContext type = spec.typeSpecifier();
+            parseTypeSpec(dataType, type);
+            ret.add(dataType);
+            dataType = new DataType();
+        }
+
+        return ret;
+    }
+
+    private static void parseTypeSpec(DataType dataType, ParseTree tree) {
+        if (tree == null) {
+            return;
+        }
+
+        if (tree instanceof CParser.StructOrUnionContext) {
+            return;
+        }
+        if (tree instanceof CParser.TypedefNameContext) {
+            if (dataType.getType().isEmpty()) {
+                dataType.appendTypeName(tree.getText());
+            }
+            return;
+        }
+
+        if (tree instanceof TerminalNode) {
+            String value = tree.getText();
+            if ("const".equals(value)) {
+                dataType.setConst();
+            } else if ("*".equals(value)) {
+                dataType.increasePointerCount();
+            } else {
+                dataType.appendTypeName(value);
+            }
+            return;
+        }
+
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            parseTypeSpec(dataType, tree.getChild(i));
+        }
+    }
+
+    private static String baseTypeMapping(String type) {
+        switch (type) {
+            case "uint64_t":
+            case "int64_t":
+            case "long":
+                return "long";
+            case "uint32_t":
+            case "unsigned int":
+            case "unsigned":
+            case "int":
+                return "int";
+            case "bool":
+                return "byte";
+            case "size_t":
+                return "NativeSize";
+            case "char":
+            case "void":
+            case "float":
+            default:
+                return type;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
new file mode 100644
index 0000000..3cdc9c9
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/FuncInfo.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class FuncInfo {
+
+    private String name;
+    private DataType returnType;
+    private List<Parameter> parameters = new ArrayList<>();
+
+    public String getName() {
+        return name;
+    }
+
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    public DataType getReturnType() {
+        return returnType;
+    }
+
+    public void setReturnType(DataType returnType) {
+        this.returnType = returnType;
+    }
+
+    public List<Parameter> getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(List<Parameter> parameters) {
+        this.parameters = parameters;
+    }
+
+    public void addParameter(Parameter parameter) {
+        parameters.add(parameter);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder();
+        sb.append(returnType).append(' ').append(name).append('(');
+        if (parameters != null) {
+            boolean first = true;
+            for (Parameter param : parameters) {
+                if (first) {
+                    first = false;
+                } else {
+                    sb.append(", ");
+                }
+                sb.append(param);
+            }
+        }
+
+        sb.append(");");
+        return sb.toString();
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        FuncInfo funcInfo = (FuncInfo) o;
+        return name.equals(funcInfo.name) && parameters.equals(funcInfo.parameters);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(name);
+    }
+
+    static FuncInfo parse(CParser.DeclarationContext ctx) {
+        FuncInfo info = new FuncInfo();
+
+        List<CParser.DeclarationSpecifierContext> specs =
+                ctx.declarationSpecifiers().declarationSpecifier();
+        List<DataType> dataTypes = DataType.parseDataTypes(specs);
+        info.setReturnType(dataTypes.get(0));
+        if (dataTypes.size() > 1) {
+            info.setName(dataTypes.get(1).getType());
+        }
+
+        CParser.InitDeclaratorContext init = ctx.initDeclaratorList().initDeclarator();
+        CParser.DirectDeclaratorContext declarator = init.declarator().directDeclarator();
+
+        CParser.DirectDeclaratorContext name = declarator.directDeclarator();
+        if (info.getName() == null) {
+            info.setName(name.getText());
+            CParser.ParameterTypeListContext paramListCtx = declarator.parameterTypeList();
+            if (paramListCtx != null) {
+                Parameter.parseParams(info.getParameters(), paramListCtx);
+            }
+        } else {
+            DataType dataType = new DataType();
+            CParser.TypeSpecifierContext type = declarator.typeSpecifier();
+            dataType.appendTypeName(type.getText());
+            if (declarator.pointer() != null) {
+                dataType.increasePointerCount();
+            }
+            Parameter param = new Parameter(dataType, name.getText());
+            info.addParameter(param);
+        }
+
+        return info;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
new file mode 100644
index 0000000..7be963d
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaGenerator.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+
+public class JnaGenerator {
+
+    private Path dir;
+    private String packageName;
+    private String libName;
+    private String className;
+    private Map<String, TypeDefine> typedefMap;
+    private Set<String> structs;
+    private Properties mapping;
+
+    public JnaGenerator(
+            String libName,
+            String packageName,
+            Map<String, TypeDefine> typedefMap,
+            Set<String> structs,
+            Properties mapping) {
+        this.libName = libName;
+        this.packageName = packageName;
+        this.typedefMap = typedefMap;
+        this.structs = structs;
+        this.mapping = mapping;
+    }
+
+    public void init(String output) throws IOException {
+        String[] tokens = packageName.split("\\.");
+        dir = Paths.get(output, tokens);
+        Files.createDirectories(dir);
+        className = AntlrUtils.toCamelCase(libName) + "Library";
+    }
+
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    public void writeStructure(Map<String, List<TypeDefine>> structMap) throws IOException {
+        for (Map.Entry<String, List<TypeDefine>> entry : structMap.entrySet()) {
+            String name = entry.getKey();
+            Path path = dir.resolve(name + ".java");
+            try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+                writer.append("package ").append(packageName).append(";\n\n");
+
+                Set<String> importSet = new HashSet<>();
+                importSet.add("com.sun.jna.Pointer");
+                importSet.add("com.sun.jna.Structure");
+                importSet.add("java.util.List");
+
+                Map<String, String> fieldNames = new LinkedHashMap<>();
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    String fieldName = typeDefine.getValue();
+                    String typeName;
+                    if (typeDefine.isCallBack()) {
+                        typeName = AntlrUtils.toCamelCase(fieldName) + "Callback";
+                        importSet.add("com.sun.jna.Callback");
+                        for (Parameter param : typeDefine.getParameters()) {
+                            String type = param.getType().map(typedefMap, structs);
+                            addImports(importSet, type);
+                        }
+                    } else {
+                        typeName = typeDefine.getDataType().map(typedefMap, structs);
+                        addImports(importSet, typeName);
+                    }
+                    fieldNames.put(fieldName, typeName);
+                }
+
+                int fieldCount = fieldNames.size();
+                if (fieldCount < 2) {
+                    importSet.add("java.util.Collections");
+                } else {
+                    importSet.add("java.util.Arrays");
+                }
+
+                List<String> imports = new ArrayList<>(importSet.size());
+                imports.addAll(importSet);
+                Collections.sort(imports);
+                for (String imp : imports) {
+                    writer.append("import ").append(imp).append(";\n");
+                }
+
+                writer.append("\npublic class ").append(name).append(" extends Structure {\n");
+                if (fieldCount > 0) {
+                    writer.write("\n");
+                }
+                for (Map.Entry<String, String> field : fieldNames.entrySet()) {
+                    writer.append("    public ").append(field.getValue()).append(' ');
+                    writer.append(field.getKey()).append(";\n");
+                }
+
+                writer.append("\n    public ").append(name).append("() {\n");
+                writer.append("    }\n");
+                writer.append("\n    public ").append(name).append("(Pointer peer) {\n");
+                writer.append("        super(peer);\n");
+                writer.append("    }\n");
+
+                writer.append("\n    @Override\n");
+                writer.append("    protected List<String> getFieldOrder() {\n");
+                switch (fieldNames.size()) {
+                    case 0:
+                        writer.append("        return Collections.emptyList();\n");
+                        break;
+                    case 1:
+                        writer.append("        return Collections.singletonList(");
+                        String firstField = fieldNames.keySet().iterator().next();
+                        writer.append('"').append(firstField).append("\");\n");
+                        break;
+                    default:
+                        writer.append("        return Arrays.asList(");
+                        boolean first = true;
+                        for (String fieldName : fieldNames.keySet()) {
+                            if (first) {
+                                first = false;
+                            } else {
+                                writer.write(", ");
+                            }
+                            writer.append('"').append(fieldName).append('"');
+                        }
+                        writer.append(");\n");
+                        break;
+                }
+                writer.append("    }\n");
+
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    String fieldName = typeDefine.getValue();
+                    String typeName = fieldNames.get(fieldName);
+                    String getterName;
+                    if (!typeDefine.isCallBack()) {
+                        getterName = AntlrUtils.toCamelCase(fieldName);
+                    } else {
+                        getterName = typeName;
+                    }
+
+                    writer.append("\n    public void set").append(getterName).append('(');
+                    writer.append(typeName).append(' ').append(fieldName).append(") {\n");
+                    writer.append("        this.").append(fieldName).append(" = ");
+                    writer.append(fieldName).append(";\n");
+                    writer.append("    }\n");
+                    writer.append("\n    public ").append(typeName).append(" get");
+                    writer.append(getterName).append("() {\n");
+                    writer.append("        return ").append(fieldName).append(";\n");
+                    writer.append("    }\n");
+                }
+
+                writer.append("\n    public static final class ByReference extends ");
+                writer.append(name).append(" implements Structure.ByReference {}\n");
+
+                writer.append("\n    public static final class ByValue extends ");
+                writer.append(name).append(" implements Structure.ByValue {}\n");
+
+                for (TypeDefine typeDefine : entry.getValue()) {
+                    if (typeDefine.isCallBack()) {
+                        DataType dataType = typeDefine.getDataType();
+                        String fieldName = typeDefine.getValue();
+
+                        String callbackName = fieldNames.get(fieldName);
+                        String returnType = mapping.getProperty(callbackName);
+                        if (returnType == null) {
+                            returnType = dataType.map(typedefMap, structs);
+                        }
+
+                        writer.append("\n    public interface ").append(callbackName);
+                        writer.append(" extends Callback {\n");
+                        writer.append("        ").append(returnType).append(" apply(");
+                        writeParameters(writer, fieldName, typeDefine.getParameters());
+                        writer.append(");\n");
+                        writer.append("    }\n");
+                    }
+                }
+
+                writer.append("}\n");
+            }
+        }
+    }
+
+    public void writeLibrary(Collection<FuncInfo> functions, Map<String, List<String>> enumMap)
+            throws IOException {
+        try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve(className + ".java"))) {
+            writer.append("package ").append(packageName).append(";\n\n");
+
+            writer.append("import com.sun.jna.Callback;\n");
+            writer.append("import com.sun.jna.Library;\n");
+            writer.append("import com.sun.jna.Pointer;\n");
+            writer.append("import com.sun.jna.ptr.PointerByReference;\n");
+            writer.append("import java.nio.ByteBuffer;\n");
+            writer.append("import java.nio.FloatBuffer;\n");
+            writer.append("import java.nio.IntBuffer;\n");
+            writer.append("import java.nio.LongBuffer;\n");
+
+            writer.append("\npublic interface ").append(className).append(" extends Library {\n\n");
+
+            for (Map.Entry<String, List<String>> entry : enumMap.entrySet()) {
+                String name = entry.getKey();
+                writer.append("\n    enum ").append(name).append(" {\n");
+                List<String> fields = entry.getValue();
+                int count = 0;
+                for (String field : fields) {
+                    writer.append("        ").append(field);
+                    if (++count < fields.size()) {
+                        writer.append(',');
+                    }
+                    writer.append('\n');
+                }
+                writer.append("    }\n");
+            }
+
+            for (TypeDefine typeDefine : typedefMap.values()) {
+                if (typeDefine.isCallBack()) {
+                    String callbackName = typeDefine.getDataType().getType();
+                    String returnType = mapping.getProperty(callbackName);
+                    if (returnType == null) {
+                        returnType = typeDefine.getValue();
+                    }
+                    writer.append("\n    interface ").append(callbackName);
+                    writer.append(" extends Callback {\n");
+                    writer.append("        ").append(returnType).append(" apply(");
+                    writeParameters(writer, callbackName, typeDefine.getParameters());
+                    writer.append(");\n");
+                    writer.append("    }\n");
+                }
+            }
+
+            for (FuncInfo info : functions) {
+                writeFunction(writer, info);
+            }
+            writer.append("}\n");
+        }
+    }
+
+    public void writeNativeSize() throws IOException {
+        try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve("NativeSize.java"))) {
+            writer.append("package ").append(packageName).append(";\n\n");
+            writer.append("import com.sun.jna.IntegerType;\n");
+            writer.append("import com.sun.jna.Native;\n\n");
+
+            writer.append("public class NativeSize extends IntegerType {\n\n");
+            writer.append("    private static final long serialVersionUID = 1L;\n\n");
+            writer.append("    public static final int SIZE = Native.SIZE_T_SIZE;\n\n");
+            writer.append("    public NativeSize() {\n");
+            writer.append("        this(0);\n");
+            writer.append("    }\n\n");
+            writer.append("    public NativeSize(long value) {\n");
+            writer.append("        super(SIZE, value);\n");
+            writer.append("    }\n");
+            writer.append("}\n");
+        }
+
+        Path path = dir.resolve("NativeSizeByReference.java");
+        try (BufferedWriter writer = Files.newBufferedWriter(path)) {
+            writer.append("package ").append(packageName).append(";\n\n");
+            writer.append("import com.sun.jna.ptr.ByReference;\n\n");
+            writer.append("public class NativeSizeByReference extends ByReference {\n\n");
+            writer.append("    public NativeSizeByReference() {\n");
+            writer.append("        this(new NativeSize(0));\n");
+            writer.append("    }\n\n");
+            writer.append("    public NativeSizeByReference(NativeSize value) {\n");
+            writer.append("        super(NativeSize.SIZE);\n");
+            writer.append("        setValue(value);\n");
+            writer.append("    }\n\n");
+            writer.append("    public void setValue(NativeSize value) {\n");
+            writer.append("        if (NativeSize.SIZE == 4) {\n");
+            writer.append("            getPointer().setInt(0, value.intValue());\n");
+            writer.append("        } else if (NativeSize.SIZE == 8) {\n");
+            writer.append("            getPointer().setLong(0, value.longValue());\n");
+            writer.append("        } else {\n");
+            writer.append(
+                    "            throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+            writer.append("        }\n");
+            writer.append("    }\n\n");
+            writer.append("    public NativeSize getValue() {\n");
+            writer.append("        if (NativeSize.SIZE == 4) {\n");
+            writer.append("            return new NativeSize(getPointer().getInt(0));\n");
+            writer.append("        } else if (NativeSize.SIZE == 8) {\n");
+            writer.append("            return new NativeSize(getPointer().getLong(0));\n");
+            writer.append("        } else {\n");
+            writer.append(
+                    "            throw new IllegalArgumentException(\"size_t has to be either 4 or 8 bytes.\");\n");
+            writer.append("        }\n");
+            writer.append("    }\n");
+            writer.append("}\n");
+        }
+    }
+
+    private void writeFunction(BufferedWriter writer, FuncInfo info) throws IOException {
+        String funcName = info.getName();
+        String returnType = mapping.getProperty(funcName);
+        if (returnType == null) {
+            returnType = info.getReturnType().map(typedefMap, structs);
+        }
+        writer.append("\n    ").append(returnType).append(' ');
+        writer.append(funcName).append('(');
+        writeParameters(writer, funcName, info.getParameters());
+        writer.append(");\n");
+    }
+
+    private void writeParameters(BufferedWriter writer, String funcName, List<Parameter> parameters)
+            throws IOException {
+        if (parameters != null) {
+            boolean first = true;
+            for (Parameter param : parameters) {
+                if (first) {
+                    first = false;
+                } else {
+                    writer.append(", ");
+                }
+                String paramName = param.getName();
+                String type = mapping.getProperty(funcName + '.' + paramName);
+                if (type == null) {
+                    type = param.getType().map(typedefMap, structs);
+                }
+                if (!"void".equals(type)) {
+                    writer.append(type).append(' ');
+                    writer.append(paramName);
+                }
+            }
+        }
+    }
+
+    private static void addImports(Set<String> importSet, String typeName) {
+        switch (typeName) {
+            case "ByReference":
+            case "ByteByReference":
+            case "DoubleByReference":
+            case "FloatByReference":
+            case "IntByReference":
+            case "LongByReference":
+            case "NativeLongByReference":
+            case "PointerByReference":
+            case "ShortByReference":
+                importSet.add("com.sun.jna.ptr." + typeName);
+                break;
+            case "ByteBuffer":
+            case "DoubleBuffer":
+            case "FloatBuffer":
+            case "IntBuffer":
+            case "LongBuffer":
+            case "ShortBuffer":
+                importSet.add("java.nio." + typeName);
+                break;
+            default:
+                break;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
new file mode 100644
index 0000000..16a350e
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/JnaParser.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.ParseTreeWalker;
+import org.apache.mxnet.jnarator.parser.CBaseListener;
+import org.apache.mxnet.jnarator.parser.CLexer;
+import org.apache.mxnet.jnarator.parser.CParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JnaParser {
+
+    static final Logger logger = LoggerFactory.getLogger(JnaParser.class);
+
+    Map<String, List<TypeDefine>> structMap;
+    Map<String, List<String>> enumMap;
+    List<FuncInfo> functions;
+    Map<String, TypeDefine> typedefMap;
+    private Set<String> functionNames;
+
+    public JnaParser() {
+        structMap = new LinkedHashMap<>();
+        enumMap = new LinkedHashMap<>();
+        functions = new ArrayList<>();
+        typedefMap = new LinkedHashMap<>();
+        functionNames = new HashSet<>();
+    }
+
+    public void parse(String headerFile) {
+        try {
+            CLexer lexer = new CLexer(CharStreams.fromFileName(headerFile));
+            CommonTokenStream tokens = new CommonTokenStream(lexer);
+            CParser parser = new CParser(tokens);
+            ParseTree tree = parser.compilationUnit();
+
+            ParseTreeWalker walker = new ParseTreeWalker();
+            CBaseListener listener =
+                    new CBaseListener() {
+
+                        /** {@inheritDoc} */
+                        @Override
+                        public void enterDeclaration(CParser.DeclarationContext ctx) {
+                            CParser.DeclarationSpecifiersContext specs =
+                                    ctx.declarationSpecifiers();
+                            CParser.InitDeclaratorListContext init = ctx.initDeclaratorList();
+
+                            if (AntlrUtils.isTypeDef(specs)) {
+                                TypeDefine value = TypeDefine.parse(init, specs);
+                                typedefMap.put(value.getDataType().getType(), value);
+                            } else if (AntlrUtils.isStructOrUnion(specs)) {
+                                CParser.DeclarationSpecifierContext spec =
+                                        (CParser.DeclarationSpecifierContext) specs.getChild(0);
+                                CParser.TypeSpecifierContext type = spec.typeSpecifier();
+                                CParser.StructOrUnionSpecifierContext struct =
+                                        type.structOrUnionSpecifier();
+                                String name = struct.Identifier().getText();
+                                List<TypeDefine> fields = new ArrayList<>();
+
+                                CParser.StructDeclarationListContext list =
+                                        struct.structDeclarationList();
+                                parseStructFields(fields, list);
+
+                                structMap.put(name, fields);
+                            } else if (AntlrUtils.isEnum(specs)) {
+                                CParser.DeclarationSpecifierContext spec =
+                                        (CParser.DeclarationSpecifierContext) specs.getChild(0);
+                                CParser.TypeSpecifierContext type = spec.typeSpecifier();
+                                CParser.EnumSpecifierContext enumSpecifierContext =
+                                        type.enumSpecifier();
+                                String name = enumSpecifierContext.Identifier().getText();
+                                List<String> fields = new ArrayList<>();
+                                parseEnum(fields, ctx);
+                                enumMap.put(name, fields);
+                            } else {
+                                FuncInfo info = FuncInfo.parse(ctx);
+                                if (checkDuplicate(info)) {
+                                    logger.warn("Duplicate function: {}.", info.getName());
+                                } else {
+                                    functions.add(info);
+                                }
+                            }
+                        }
+                    };
+            walker.walk(listener, tree);
+        } catch (IOException e) {
+            logger.error("Failed to parse header file: {}", headerFile, e);
+        }
+    }
+
+    void parseStructFields(List<TypeDefine> fields, ParseTree tree) {
+        if (tree instanceof CParser.StructDeclarationContext) {
+            CParser.StructDeclarationContext ctx = (CParser.StructDeclarationContext) tree;
+            CParser.SpecifierQualifierListContext qualifierList = ctx.specifierQualifierList();
+            DataType dataType = DataType.parse(qualifierList);
+
+            TypeDefine typeDefine = new TypeDefine();
+            fields.add(typeDefine);
+
+            typeDefine.setDataType(dataType);
+
+            CParser.StructDeclaratorListContext name = ctx.structDeclaratorList();
+            if (name != null) {
+                typeDefine.setCallBack(true);
+
+                CParser.DirectDeclaratorContext dd =
+                        name.structDeclarator().declarator().directDeclarator();
+                CParser.DirectDeclaratorContext nameCtx =
+                        dd.directDeclarator().declarator().directDeclarator();
+                String fieldName = nameCtx.getText();
+                typeDefine.setValue(fieldName);
+
+                CParser.ParameterTypeListContext paramListCtx = dd.parameterTypeList();
+                if (paramListCtx != null) {
+                    Parameter.parseParams(typeDefine.getParameters(), paramListCtx);
+                }
+            } else {
+                CParser.SpecifierQualifierListContext nameList =
+                        qualifierList.specifierQualifierList();
+                if (nameList.specifierQualifierList() != null) {
+                    typeDefine.setValue(nameList.specifierQualifierList().getText());
+                } else {
+                    typeDefine.setValue(nameList.getText());
+                }
+            }
+            return;
+        }
+
+        for (int i = 0; i < tree.getChildCount(); i++) {
+            parseStructFields(fields, tree.getChild(i));
+        }
+    }
+
+    void parseEnum(List<String> fields, ParseTree ctx) {
+        if (ctx instanceof CParser.EnumerationConstantContext) {
+            fields.add(ctx.getText());
+            return;
+        }
+
+        for (int i = 0; i < ctx.getChildCount(); i++) {
+            parseEnum(fields, ctx.getChild(i));
+        }
+    }
+
+    public Map<String, List<TypeDefine>> getStructMap() {
+        return structMap;
+    }
+
+    public Map<String, List<String>> getEnumMap() {
+        return enumMap;
+    }
+
+    public List<FuncInfo> getFunctions() {
+        return functions;
+    }
+
+    public Map<String, TypeDefine> getTypedefMap() {
+        return typedefMap;
+    }
+
+    boolean checkDuplicate(FuncInfo function) {
+        if (!functionNames.add(function.getName())) {
+            for (FuncInfo info : functions) {
+                if (function.equals(info)) {
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
new file mode 100644
index 0000000..39aa8fc
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Main.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class Main {
+
+    private static final Logger logger = LoggerFactory.getLogger(Main.class);
+
+    private Main() {}
+
+    public static void main(String[] args) {
+        Options options = Config.getOptions();
+        try {
+            DefaultParser cmdParser = new DefaultParser();
+            CommandLine cmd = cmdParser.parse(options, args, null, false);
+            Config config = new Config(cmd);
+
+            String output = config.getOutput();
+            String packageName = config.getPackageName();
+            String library = config.getLibrary();
+            String[] headerFiles = config.getHeaderFiles();
+            String mappingFile = config.getMappingFile();
+
+            Path dir = Paths.get(output);
+            Files.createDirectories(dir);
+
+            Properties mapping = new Properties();
+            if (mappingFile != null) {
+                Path file = Paths.get(mappingFile);
+                if (Files.notExists(file)) {
+                    logger.error("mapping file does not exist: {}", mappingFile);
+                    System.exit(-1); // NOPMD
+                }
+                try (InputStream in = Files.newInputStream(file)) {
+                    mapping.load(in);
+                }
+            }
+
+            JnaParser jnaParser = new JnaParser();
+            Map<String, TypeDefine> typedefMap = jnaParser.getTypedefMap();
+            Map<String, List<TypeDefine>> structMap = jnaParser.getStructMap();
+            JnaGenerator generator =
+                    new JnaGenerator(library, packageName, typedefMap, structMap.keySet(), mapping);
+            generator.init(output);
+
+            for (String headerFile : headerFiles) {
+                jnaParser.parse(headerFile);
+            }
+
+            generator.writeNativeSize();
+            generator.writeStructure(structMap);
+            generator.writeLibrary(jnaParser.getFunctions(), jnaParser.getEnumMap());
+        } catch (ParseException e) {
+            HelpFormatter formatter = new HelpFormatter();
+            formatter.setLeftPadding(1);
+            formatter.setWidth(120);
+            formatter.printHelp(e.getMessage(), options);
+            System.exit(-1); // NOPMD
+        } catch (Throwable t) {
+            logger.error("Failed to generate JNA bindings", t);
+            System.exit(-1); // NOPMD
+        }
+    }
+
+    public static final class Config {
+
+        private String library;
+        private String packageName;
+        private String output;
+        private String[] headerFiles;
+        private String mappingFile;
+
+        public Config(CommandLine cmd) {
+            library = cmd.getOptionValue("library");
+            packageName = cmd.getOptionValue("package");
+            output = cmd.getOptionValue("output");
+            headerFiles = cmd.getOptionValues("header");
+            mappingFile = cmd.getOptionValue("mappingFile");
+        }
+
+        public static Options getOptions() {
+            Options options = new Options();
+            options.addOption(
+                    Option.builder("l")
+                            .longOpt("library")
+                            .hasArg()
+                            .required()
+                            .argName("LIBRARY")
+                            .desc("library name")
+                            .build());
+            options.addOption(
+                    Option.builder("p")
+                            .longOpt("package")
+                            .required()
+                            .hasArg()
+                            .argName("PACKAGE")
+                            .desc("Java package name")
+                            .build());
+            options.addOption(
+                    Option.builder("o")
+                            .longOpt("output")
+                            .required()
+                            .hasArg()
+                            .argName("OUTPUT")
+                            .desc("output directory")
+                            .build());
+            options.addOption(
+                    Option.builder("f")
+                            .longOpt("header")
+                            .required()
+                            .hasArgs()
+                            .argName("HEADER")
+                            .desc("Header files")
+                            .build());
+            options.addOption(
+                    Option.builder("m")
+                            .longOpt("mappingFile")
+                            .hasArg()
+                            .argName("MAPPING_FILE")
+                            .desc("Type mapping config file")
+                            .build());
+            return options;
+        }
+
+        public String getLibrary() {
+            return library;
+        }
+
+        public String getPackageName() {
+            return packageName;
+        }
+
+        public String getOutput() {
+            return output;
+        }
+
+        public String[] getHeaderFiles() {
+            return headerFiles;
+        }
+
+        public String getMappingFile() {
+            return mappingFile;
+        }
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
new file mode 100644
index 0000000..f46e5e7
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/Parameter.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.List;
+import java.util.Objects;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class Parameter {
+
+    private DataType type;
+    private String name;
+
+    public Parameter(DataType type, String name) {
+        this.type = type;
+        this.name = name;
+    }
+
+    public DataType getType() {
+        return type;
+    }
+
+    public String getName() {
+        return name;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        return type.toString() + ' ' + name;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        Parameter parameter = (Parameter) o;
+        return type.equals(parameter.type);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(type);
+    }
+
+    static void parseParams(List<Parameter> params, ParseTree ctx) {
+        if (ctx instanceof CParser.ParameterDeclarationContext) {
+            CParser.ParameterDeclarationContext declarationContext =
+                    (CParser.ParameterDeclarationContext) ctx;
+            CParser.DeclarationSpecifiersContext spec = declarationContext.declarationSpecifiers();
+            DataType dataType;
+            if (spec == null) {
+                dataType = DataType.parse(declarationContext.declarationSpecifiers2());
+            } else {
+                dataType = DataType.parse(spec);
+            }
+
+            CParser.DeclaratorContext declarator = declarationContext.declarator();
+
+            String name;
+            if (declarator != null) {
+                CParser.PointerContext pointer = declarator.pointer();
+                if (pointer != null) {
+                    dataType.increasePointerCount();
+                }
+                name = declarator.directDeclarator().getText();
+            } else {
+                name = "arg" + (params.size() + 1);
+            }
+
+            Parameter param = new Parameter(dataType, name);
+            params.add(param);
+            return;
+        }
+        for (int i = 0; i < ctx.getChildCount(); i++) {
+            parseParams(params, ctx.getChild(i));
+        }
+    }
+}
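Review note on `Parameter`: `equals` and `hashCode` look only at the type, so two parameters that differ only by name compare equal, and unnamed parameters receive a positional name (`arg1`, `arg2`, ...). The sketch below shows both behaviours in isolation. `ParamSketch` and `defaultName` are hypothetical names, and a plain `String` stands in for `DataType`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for Parameter; a String replaces DataType for brevity.
final class ParamSketch {
    private final String type;
    private final String name;

    ParamSketch(String type, String name) {
        this.type = type;
        this.name = name;
    }

    // Mirrors the fallback in parseParams: the name is derived from the
    // number of parameters collected so far, starting at arg1.
    static String defaultName(List<ParamSketch> params) {
        return "arg" + (params.size() + 1);
    }

    @Override
    public String toString() {
        return type + ' ' + name;
    }

    // Equality deliberately ignores the name, as in the patch.
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        return type.equals(((ParamSketch) o).type);
    }

    @Override
    public int hashCode() {
        return Objects.hash(type);
    }

    public static void main(String[] args) {
        List<ParamSketch> params = new ArrayList<>();
        params.add(new ParamSketch("int", defaultName(params)));
        params.add(new ParamSketch("int", defaultName(params)));
        // Both entries share a type, so they are equal despite different names.
        System.out.println(params.get(0).equals(params.get(1)));
        System.out.println(params);
    }
}
```

Ignoring the name is plausibly what lets two declarations of the same C function with differently named arguments be treated as the same signature, though that intent is an inference from the code rather than something the patch states.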
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
new file mode 100644
index 0000000..1f30d21
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/TypeDefine.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.jnarator;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.mxnet.jnarator.parser.CParser;
+
+public class TypeDefine {
+
+    private DataType dataType;
+    private boolean callBack;
+    private String value;
+    private List<Parameter> parameters = new ArrayList<>();
+
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    public void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    public boolean isCallBack() {
+        return callBack;
+    }
+
+    public void setCallBack(boolean callBack) {
+        this.callBack = callBack;
+    }
+
+    public String getValue() {
+        return value;
+    }
+
+    public void setValue(String value) {
+        this.value = value;
+    }
+
+    public List<Parameter> getParameters() {
+        return parameters;
+    }
+
+    static TypeDefine parse(
+            CParser.InitDeclaratorListContext init, CParser.DeclarationSpecifiersContext specs) {
+        TypeDefine typeDefine = new TypeDefine();
+        DataType dataType = new DataType();
+        typeDefine.setDataType(dataType);
+
+        CParser.DirectDeclaratorContext ctx = init.initDeclarator().declarator().directDeclarator();
+        CParser.DirectDeclaratorContext callback = ctx.directDeclarator();
+        if (callback == null) {
+            dataType.setType(ctx.getText());
+        } else {
+            typeDefine.setCallBack(true);
+            dataType.setType(callback.declarator().directDeclarator().getText());
+            CParser.ParameterTypeListContext paramListCtx = ctx.parameterTypeList();
+            List<Parameter> parameters = typeDefine.getParameters();
+            Parameter.parseParams(parameters, paramListCtx);
+        }
+
+        List<String> list = new ArrayList<>();
+        for (int i = 1; i < specs.getChildCount(); ++i) {
+            list.add(specs.getChild(i).getText());
+        }
+
+        typeDefine.setValue(String.join(" ", list));
+        return typeDefine;
+    }
+}
diff --git a/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
new file mode 100644
index 0000000..a6cff70
--- /dev/null
+++ b/java-package/jnarator/src/main/java/org/apache/mxnet/jnarator/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains classes to generate the Apache MXNet (incubating) native interface. */
+package org.apache.mxnet.jnarator;
diff --git a/java-package/jnarator/src/main/resources/log4j2.xml b/java-package/jnarator/src/main/resources/log4j2.xml
new file mode 100644
index 0000000..4818a95
--- /dev/null
+++ b/java-package/jnarator/src/main/resources/log4j2.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<Configuration status="WARN">
+  <Appenders>
+    <Console name="Console" target="SYSTEM_OUT">
+      <PatternLayout pattern="[%highlight{%-5level}] %msg%n%throwable"/>
+    </Console>
+  </Appenders>
+  <Loggers>
+    <Root level="DEBUG" additivity="false">
+      <AppenderRef ref="Console"/>
+    </Root>
+  </Loggers>
+</Configuration>
diff --git a/java-package/mxnet-engine/build.gradle b/java-package/mxnet-engine/build.gradle
new file mode 100644
index 0000000..c874f72
--- /dev/null
+++ b/java-package/mxnet-engine/build.gradle
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+    id 'java'
+}
+
+group 'org.apache.mxnet'
+version '0.0.1-SNAPSHOT'
+
+repositories {
+    mavenCentral()
+}
+
+def getOsName() {
+    def os_name = System.properties['os.name']
+    if (os_name.toLowerCase().contains('windows')) {
+        return "win"
+    } else if (os_name.contains('Mac OS X')) {
+        return "osx"
+    } else if (os_name.contains('Linux')) {
+        return "linux"
+    } else {
+        return System.properties['os.name']
+    }
+}
+
+dependencies {
+    api "com.google.code.gson:gson:${gson_version}"
+    api "net.java.dev.jna:jna:${jna_version}"
+    api "org.apache.commons:commons-compress:${commons_compress_version}"
+    api "org.slf4j:slf4j-api:${slf4j_version}"
+
+    testImplementation("org.testng:testng:${testng_version}") {
+        exclude group: "junit", module: "junit"
+    }
+    testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
+    // Solve the problem: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
+    implementation "org.slf4j:slf4j-simple:${slf4j_version}"
+    def osName = getOsName()
+    implementation files("${project(':native').buildDir}/libs/native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar")
+//    implementation fileTree(dir: "${project(':native').buildDir}/lib", includes: ["native-${mxnet_version}-SNAPSHOT-${osName}-x86_64.jar"])
+}
+
+sourceSets {
+    main {
+        java {
+            srcDirs = ['src/main/java', 'build/generated-src']
+        }
+    }
+}
+
+checkstyleMain.source = 'src/main/java'
+pmdMain.source = 'src/main/java'
+
+task jnarator(dependsOn: ":jnarator:jar") {
+    outputs.dir "${project.buildDir}/generated-src"
+    doLast {
+        File jnaGenerator = project(":jnarator").jar.outputs.files.singleFile
+        javaexec {
+            main = "-jar"
+            args = [
+                    jnaGenerator.absolutePath,
+                    "-l",
+                    "mxnet",
+                    "-p",
+                    "org.apache.mxnet.jna",
+                    "-o",
+                    "${project.buildDir}/generated-src",
+                    "-m",
+                    "${project.projectDir}/src/main/jna/mapping.properties",
+                    "-f",
+                    "../../include/mxnet/c_api.h",
+                    "../../include/nnvm/c_api.h"
+            ]
+        }
+    }
+}
+
+test {
+    useTestNG() {
+        useDefaultListeners = true
+    }
+    environment "PATH", "src/test/bin:${environment.PATH}"
+//    environment "MXNET_LIBRARY_PATH", "${MXNET_LIBRARY_PATH}"
+    maxHeapSize = '6G'
+    testLogging.showStandardStreams = true
+    beforeTest { descriptor ->
+        logger.lifecycle("Running test: " + descriptor)
+    }
+    failFast = false
+    onOutput { descriptor, event ->
+        logger.lifecycle("Test: " + descriptor + " produced standard out/err: " + event.message )
+    }
+//    debugOptions {
+//        enabled = true
+//        port = 4455
+//        server = true
+//        suspend = true
+//    }
+//    filter {
+//        includeTestsMatching("*Test")
+//    }
+}
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//import java.util.regex.Matcher
+//import java.util.regex.Pattern
+
+//def checkForUpdate(String path, String url) {
+//    def expected = new URL(url).text
+//    def actual = new File("${project.projectDir}/src/main/include/${path}").text
+//    if (!actual.equals(expected)) {
+//        def fileName = path.replaceAll("[/\\\\]", '_')
+//        file("${project.projectDir}/build").mkdirs()
+//        (file("${project.projectDir}/build/${fileName}")).text = expected
+//        logger.warn("[\033[31mWARN\033[0m ] Header file has been changed in open source project: ${path}.")
+//    }
+//}
+
+//task checkHeaderFile() {
+//    outputs.files "${project.buildDir}/mxnet_c_api.h", "${project.buildDir}/nnvm_c_api.h"
+//    doFirst {
+//        if (gradle.startParameter.offline) {
+//            logger.warn("[\033[31mWARN\033[0m ] Ignore header validation in offline mode.")
+//            return
+//        }
+//
+//        def mxnetUrl = "https://raw.githubusercontent.com/apache/incubator-mxnet/v1.7.x/"
+//        checkForUpdate("mxnet/c_api.h", "${mxnetUrl}/include/mxnet/c_api.h")
+//        def content = new URL("https://github.com/apache/incubator-mxnet/tree/v1.7.x/3rdparty").text
+//
+//        Pattern pattern = Pattern.compile("href=\"/apache/incubator-tvm/tree/([a-z0-9]+)\"")
+//        Matcher m = pattern.matcher(content);
+//        if (!m.find()) {
+//            throw new GradleException("Failed to retrieve submodule hash for tvm from github")
+//        }
+//        String hash = m.group(1);
+//
+//        def nnvmUrl = "https://raw.githubusercontent.com/apache/incubator-tvm/${hash}"
+//        checkForUpdate("nnvm/c_api.h", "${nnvmUrl}/nnvm/include/nnvm/c_api.h")
+//    }
+//}
+
+compileJava.dependsOn(jnarator)
+
+// TODO
+//publishing {
+//    publications {
+//        maven(MavenPublication) {
+//            pom {
+//                name = "DJL Engine Adapter for Apache MXNet"
+//                description = "Deep Java Library (DJL) Engine Adapter for Apache MXNet"
+//                url = "http://www.djl.ai/mxnet/${project.name}"
+//            }
+//        }
+//    }
+//}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
new file mode 100644
index 0000000..13ae88d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/BaseMxResource.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.jna.JnaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The top-level {@link MxResource} instance, which has no parent resource managing it. The {@link
+ * BaseMxResource} singleton is created lazily on first use, for example when a {@link Model}
+ * instance is loaded for the first time.
+ */
+public final class BaseMxResource extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(BaseMxResource.class);
+
+    private static BaseMxResource systemMxResource;
+
+    protected BaseMxResource() {
+        super();
+        // Work around the MXNet engine lazy initialization issue
+        JnaUtils.getAllOpNames();
+
+        JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);
+
+        // Work around the MXNet shutdown crash issue
+        Runtime.getRuntime().addShutdownHook(new Thread(JnaUtils::waitAll)); // NOPMD
+    }
+
+    /**
+     * Getter method for the singleton {@code systemMxResource} instance.
+     *
+     * @return The top-level {@link BaseMxResource} instance.
+     */
+    public static synchronized BaseMxResource getSystemMxResource() {
+        if (systemMxResource == null) {
+            systemMxResource = new BaseMxResource();
+        }
+        return systemMxResource;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Start to free BaseMxResource instance: %s", this.getUid()));
+            // only clean sub resources
+            JnaUtils.waitAll();
+            super.freeSubResources();
+            setClosed(true);
+            logger.debug(
+                    String.format("Finish to free BaseMxResource instance: %s", this.getUid()));
+        }
+    }
+}
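Review note on `getSystemMxResource`: the getter is a lazily initialized singleton guarded by `synchronized`, so the constructor (and the shutdown hook it registers) runs at most once, no matter how many callers request the root resource. A minimal sketch of the same pattern follows. `LazyRootSketch` and its construction counter are hypothetical and exist only to make the behaviour observable.

```java
// Hypothetical stand-in for BaseMxResource showing only the lazy-singleton shape.
final class LazyRootSketch {
    private static LazyRootSketch instance;
    private static int constructed;

    private LazyRootSketch() {
        // In the patch, the one-time native setup and the shutdown hook
        // registration happen here.
        constructed++;
    }

    // synchronized makes the null check and the assignment one atomic step,
    // so concurrent first callers cannot construct two instances.
    static synchronized LazyRootSketch get() {
        if (instance == null) {
            instance = new LazyRootSketch();
        }
        return instance;
    }

    static synchronized int constructedCount() {
        return constructed;
    }
}
```

The cost of this choice is a lock acquisition on every call. For a getter that is called rarely, that is usually acceptable and simpler than double-checked locking.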
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
new file mode 100644
index 0000000..eef058e
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/CachedOp.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code CachedOp} is an internal helper that provides the core functionality to execute a
+ * {@link SymbolBlock}.
+ *
+ * <p>We don't recommend users interact with this class directly. Users should use {@link Predictor}
+ * instead. CachedOp is an operator that simplifies calling and analyzing the input shape. It
+ * requires minimal input to run inference because most of the information can be obtained from the
+ * model itself.
+ */
+public class CachedOp extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(CachedOp.class);
+
+    private List<Parameter> parameters;
+    private PairList<String, Integer> dataIndices;
+    private Map<String, Integer> dataIndicesMap;
+    private List<Integer> paramIndices;
+
+    /**
+     * Creates an instance of {@link CachedOp}.
+     *
+     * <p>It can be created by using {@link JnaUtils#createCachedOp(SymbolBlock, MxResource)}
+     *
+     * @param parent the MxResource object to manage this instance of CachedOp
+     * @param handle the C handle of the CachedOp
+     * @param parameters the parameter values
+     * @param paramIndices the parameters required by the model and their corresponding location
+     * @param dataIndices the input data names required by the model and their corresponding
+     *     location
+     */
+    public CachedOp(
+            MxResource parent,
+            Pointer handle,
+            List<Parameter> parameters,
+            List<Integer> paramIndices,
+            PairList<String, Integer> dataIndices) {
+        super(parent, handle);
+        this.parameters = parameters;
+        this.dataIndices = dataIndices;
+        this.paramIndices = paramIndices;
+        this.dataIndicesMap = dataIndices.toMap();
+    }
+
+    /**
+     * Runs the forward pass: assigns {@code data} to the empty input slots and invokes the op.
+     *
+     * @param data the input in {@link NDList} format
+     * @return the outputs as an {@link NDList}
+     */
+    public NDList forward(NDList data) {
+        // reset the input data index at the beginning
+        NDArray[] allInputsNDArray = new NDArray[parameters.size()];
+        // check device of input
+        Device device = data.isEmpty() ? Device.defaultIfNull() : data.head().getDevice();
+        // fill allInputsNDArray with parameter values on correct device
+        for (int index : paramIndices) {
+            Parameter parameter = parameters.get(index);
+            NDArray value = parameter.getArray();
+            if (value == null) {
+                throw new NullPointerException("Failed to find parameter from parameterStore");
+            }
+            value.setDevice(device);
+            allInputsNDArray[index] = value;
+        }
+
+        // fill allInputsNDArray with data values
+        int index = 0;
+        for (NDArray array : data) {
+            // TODO: NDArray name doesn't match. To confirm the format of input name
+            //            String inputName = array.getName().split(":")[1];
+            String inputName = array.getName();
+            // if inputName not provided, value will follow the default order
+            int idx = indexOf(inputName, index++);
+            allInputsNDArray[idx] = array;
+        }
+
+        // check the input, set as Shape(batchSize) by default
+        for (Pair<String, Integer> pair : dataIndices) {
+            if (allInputsNDArray[pair.getValue()] == null) {
+                // TODO: Do we need to set default to the input?
+                long batchSize = data.head().getShape().get(0);
+                String key = pair.getKey();
+                if (!"prob_label".equals(key) && !"softmax_label".equals(key)) {
+                    logger.warn(
+                            "Input "
+                                    + key
+                                    + " not found, set NDArray to Shape("
+                                    + batchSize
+                                    + ") by default");
+                }
+                // TODO: consider how to manage MxNDArray generated during inference
+                allInputsNDArray[pair.getValue()] =
+                        NDArray.create(this, new Shape(batchSize), device);
+            }
+        }
+        NDArray[] result = JnaUtils.cachedOpInvoke(getParent(), getHandle(), allInputsNDArray);
+        return new NDList(result);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Freeing CachedOp instance: %s", this.getUid()));
+            super.freeSubResources();
+            Pointer pointer = handle.getAndSet(null);
+            if (pointer != null) {
+                JnaUtils.freeCachedOp(pointer);
+            }
+            setClosed(true);
+            logger.debug(String.format("Freed CachedOp instance: %s", this.getUid()));
+        }
+    }
+
+    private int indexOf(String inputName, int position) {
+        if (inputName == null) {
+            return dataIndices.valueAt(position);
+        }
+
+        Integer index = dataIndicesMap.get(inputName);
+        if (index == null) {
+            throw new IllegalArgumentException(
+                    "Unknown input name: "
+                            + inputName
+                            + ", expected inputs: "
+                            + dataIndicesMap.keySet().toString());
+        }
+        return index;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
new file mode 100644
index 0000000..7fb74c6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Device.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.cuda.CudaUtils;
+
+/**
+ * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@link
+ * org.apache.mxnet.ndarray.NDArray}.
+ *
+ * <p>Users can use this to specify whether to load/compute the {@code NDArray} on CPU/GPU with
+ * deviceType and deviceId provided.
+ */
+public final class Device {
+
+    private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();
+
+    private static final Device CPU = new Device(Type.CPU, -1);
+
+    private static final Device GPU = Device.of(Type.GPU, 0);
+
+    private String deviceType;
+
+    private int deviceId;
+
+    private static final Device DEFAULT_DEVICE = CPU;
+
+    /**
+     * Creates a {@code Device} with basic information.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware. For example, if you have multiple GPUs, you can
+     *     choose which GPU to process the NDArray
+     */
+    private Device(String deviceType, int deviceId) {
+        this.deviceType = deviceType;
+        this.deviceId = deviceId;
+    }
+
+    /**
+     * Returns a {@code Device} with device type and device id.
+     *
+     * @param deviceType the device type, typically CPU or GPU
+     * @param deviceId the deviceId on the hardware.
+     * @return a {@code Device} instance
+     */
+    public static Device of(String deviceType, int deviceId) {
+        if (Type.CPU.equals(deviceType)) {
+            return CPU;
+        }
+        String key = deviceType + '-' + deviceId;
+        return CACHE.computeIfAbsent(key, k -> new Device(deviceType, deviceId));
+    }
+
+    /**
+     * Returns the device type of the Device.
+     *
+     * @return the device type of the Device
+     */
+    public String getDeviceType() {
+        return deviceType;
+    }
+
+    /**
+     * Returns the {@code deviceId} of the Device.
+     *
+     * @return the {@code deviceId} of the Device
+     */
+    public int getDeviceId() {
+        if (Type.CPU.equals(deviceType)) {
+            throw new IllegalStateException("CPU doesn't have device id");
+        }
+        return deviceId;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        if (Type.CPU.equals(deviceType)) {
+            return deviceType + "()";
+        }
+        return deviceType + '(' + deviceId + ')';
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        Device device = (Device) o;
+        if (Type.CPU.equals(deviceType)) {
+            return Objects.equals(deviceType, device.getDeviceType());
+        }
+        return deviceId == device.deviceId && Objects.equals(deviceType, device.deviceType);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int hashCode() {
+        return Objects.hash(deviceType, deviceId);
+    }
+
+    /**
+     * Returns the default CPU Device.
+     *
+     * @return the default CPU Device
+     */
+    public static Device cpu() {
+        return CPU;
+    }
+
+    /**
+     * Returns the default GPU Device.
+     *
+     * @return the default GPU Device
+     */
+    public static Device gpu() {
+        return GPU;
+    }
+
+    /**
+     * Returns the GPU {@code Device} with the specified {@code deviceId}.
+     *
+     * @param deviceId the GPU device ID
+     * @return the GPU {@code Device} with the specified {@code deviceId}
+     */
+    public static Device gpu(int deviceId) {
+        return of(Type.GPU, deviceId);
+    }
+
+    /**
+     * Returns an array of devices.
+     *
+     * <p>If GPUs are available, it will return an array with one {@code Device} per available
+     * GPU. Otherwise, it will return an array with a single CPU device.
+     *
+     * @return an array of devices
+     */
+    public static Device[] getDevices() {
+        return getDevices(Integer.MAX_VALUE);
+    }
+
+    /**
+     * Returns an array of devices given the maximum number of GPUs to use.
+     *
+     * <p>If GPUs are available, it will return an array of {@code Device} of size
+     * \(min(numAvailable, maxGpus)\). Else, it will return an array with a single CPU device.
+     *
+     * @param maxGpus the max number of GPUs to use. Use 0 for no GPUs.
+     * @return an array of devices
+     */
+    public static Device[] getDevices(int maxGpus) {
+        int count = getGpuCount();
+        if (maxGpus <= 0 || count <= 0) {
+            return new Device[] {CPU};
+        }
+
+        count = Math.min(maxGpus, count);
+        Device[] devices = new Device[count];
+        for (int i = 0; i < devices.length; ++i) {
+            devices[i] = gpu(i);
+        }
+        return devices;
+    }
+
+    /**
+     * Returns the number of GPUs available in the system.
+     *
+     * @return the number of GPUs available in the system
+     */
+    public static int getGpuCount() {
+        return CudaUtils.getGpuCount();
+    }
+
+    /**
+     * Returns the default device.
+     *
+     * <p>The default device is currently always the CPU device ({@code DEFAULT_DEVICE}); it does
+     * not depend on whether GPUs are available on the machine.
+     *
+     * @return a {@link Device}
+     */
+    private static Device defaultDevice() {
+        return DEFAULT_DEVICE;
+    }
+
+    /**
+     * Returns the given device or the default if it is null.
+     *
+     * @param device the device to try to return
+     * @return the given device or the default if it is null
+     */
+    public static Device defaultIfNull(Device device) {
+        if (device != null) {
+            return device;
+        }
+        return defaultDevice();
+    }
+
+    /**
+     * Returns the default device.
+     *
+     * @return the default device
+     */
+    public static Device defaultIfNull() {
+        return defaultIfNull(null);
+    }
+
+    /** Contains device type string constants. */
+    public interface Type {
+        String CPU = "cpu";
+        String GPU = "gpu";
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
new file mode 100644
index 0000000..02ad120
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/DeviceType.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/** {@code DeviceType} maps device type names to their corresponding MXNet type numbers. */
+public final class DeviceType {
+
+    private static final String CPU_PINNED = "cpu_pinned";
+
+    private DeviceType() {}
+
+    /**
+     * Converts a {@link Device} to the corresponding MXNet device number.
+     *
+     * @param device the java {@link Device}
+     * @return the MXNet device number
+     * @exception IllegalArgumentException the device is null or is not supported
+     */
+    public static int toDeviceType(Device device) {
+        if (device == null) {
+            throw new IllegalArgumentException("Unsupported device: null");
+        }
+
+        String deviceType = device.getDeviceType();
+
+        if (Device.Type.CPU.equals(deviceType)) {
+            return 1;
+        } else if (Device.Type.GPU.equals(deviceType)) {
+            return 2;
+        } else if (CPU_PINNED.equals(deviceType)) {
+            return 3;
+        } else {
+            throw new IllegalArgumentException("Unsupported device: " + device.toString());
+        }
+    }
+
+    /**
+     * Converts from an MXNet device number to the corresponding device type string.
+     *
+     * @param deviceType the MXNet device number
+     * @return the corresponding device type string, as defined in {@link Device.Type}
+     */
+    public static String fromDeviceType(int deviceType) {
+        switch (deviceType) {
+            case 1:
+            case 3:
+                // hide CPU_PINNED from front-end users;
+                // advanced users can still create a CPU_PINNED device
+                // and pass it through to the engine
+                return Device.Type.CPU;
+            case 2:
+                return Device.Type.GPU;
+            default:
+                throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
new file mode 100644
index 0000000..c388de3
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/GradReq.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+/** An enum that indicates whether gradient is required. */
+public enum GradReq {
+    NULL("null", 0),
+    WRITE("write", 1),
+    ADD("add", 3);
+
+    private String type;
+    private int value;
+
+    GradReq(String type, int value) {
+        this.type = type;
+        this.value = value;
+    }
+
+    /**
+     * Gets the type of this {@code GradReq}.
+     *
+     * @return the type
+     */
+    public String getType() {
+        return type;
+    }
+
+    /**
+     * Gets the value of this {@code GradReq}.
+     *
+     * @return the value
+     */
+    public int getValue() {
+        return value;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
new file mode 100644
index 0000000..90cceb4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Model.java
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.repository.Item;
+import org.apache.mxnet.repository.Repository;
+import org.apache.mxnet.translate.NoOpTranslator;
+import org.apache.mxnet.translate.Translator;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A model is a collection of artifacts that is created by the training process.
+ *
+ * <p>Model contains methods to load and process a model object. In addition, it provides
+ * MXNet-specific functionality, such as getSymbol to obtain the symbolic graph and getParameters
+ * to obtain the parameter NDArrays.
+ */
+public class Model extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Model.class);
+    protected Path modelDir;
+    protected SymbolBlock symbolBlock;
+    protected String modelName;
+    protected DataType dataType;
+    protected PairList<String, Shape> inputData;
+    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
+    protected Map<String, String> properties = new ConcurrentHashMap<>();
+
+    Model(String name, Device device) {
+        this(BaseMxResource.getSystemMxResource(), name, device);
+    }
+
+    private Model(MxResource parent, String name, Device device) {
+        super(parent);
+        setDevice(Device.defaultIfNull(device));
+        setDataType(DataType.FLOAT32);
+        setModelName(name);
+    }
+
+    /**
+     * Creates a default {@link Predictor} instance that uses {@link NoOpTranslator} as its
+     * translator and does not copy parameters to a parameter store.
+     *
+     * @return {@link Predictor}
+     */
+    public Predictor<NDList, NDList> newPredictor() {
+        Translator<NDList, NDList> noOpTranslator = new NoOpTranslator();
+        return newPredictor(noOpTranslator);
+    }
+
+    /**
+     * Creates a {@link Predictor} instance with the specified {@link Translator}.
+     *
+     * @param translator the {@link Translator} used to convert inputs to, and outputs from, the
+     *     {@link NDList} form used for inference
+     * @param <I> the input type
+     * @param <O> the output type
+     * @return {@link Predictor}
+     */
+    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
+        return new Predictor<>(this, translator);
+    }
+
+    /**
+     * Creates and initializes a {@code Model} from the model directory.
+     *
+     * @param modelPath {@code Path} model directory
+     * @return loaded {@code Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Path modelPath) throws IOException {
+        return loadModel("model", modelPath);
+    }
+
+    /**
+     * Creates and initializes a {@code Model} from a repository {@link Item}.
+     *
+     * @param modelItem the repository {@link Item} that describes the model
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(Item modelItem) throws IOException {
+        Model model = createModel(modelItem);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates and initializes a {@code Model} with a model name from the model directory.
+     *
+     * @param modelName {@link String} model name
+     * @param modelPath {@link Path} model directory
+     * @return {@link Model}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static Model loadModel(String modelName, Path modelPath) throws IOException {
+        Model model = createModel(modelName, modelPath);
+        model.initial();
+        return model;
+    }
+
+    /**
+     * Creates a {@code Model} with a specific model name and model directory.
+     *
+     * <p>By default, the instance is managed by the top-level {@link BaseMxResource}.
+     *
+     * @param modelName {@link String} model name
+     * @param modelDir {@link Path} local model path
+     * @return {@link Model}
+     */
+    static Model createModel(String modelName, Path modelDir) {
+        Model model = new Model(modelName, Device.defaultIfNull());
+        model.setModelDir(modelDir);
+        return model;
+    }
+
+    /**
+     * Creates a sample {@code Model}, downloading it or locating its local path as needed.
+     *
+     * @param item {@link Item} sample model to be created
+     * @return created {@link Model} instance
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    static Model createModel(Item item) throws IOException {
+        Path modelDir = Repository.initRepository(item);
+        return createModel(item.getName(), modelDir);
+    }
+
+    /**
+     * Initializes the model: loads the symbol and parameters from the model directory set by
+     * {@link #setModelDir(Path)}.
+     *
+     * @throws IOException when IO operation fails in loading a resource
+     * @throws FileNotFoundException if Model Directory is not assigned
+     */
+    public void initial() throws IOException {
+        if (getModelDir() == null) {
+            throw new FileNotFoundException("Model path is not defined!");
+        }
+        load(getModelDir());
+    }
+
+    /**
+     * Loads the model from the {@code modelPath}.
+     *
+     * @param modelPath the directory or file path of the model location
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public void load(Path modelPath) throws IOException {
+        load(modelPath, null, null);
+    }
+
+    /**
+     * Loads the MXNet model from a specified location.
+     *
+     * <p>MXNet Model looks for {MODEL_NAME}-symbol.json and {MODEL_NAME}-{EPOCH}.params files in
+     * the specified directory. By default, it will pick up the latest epoch of the parameter file.
+     * However, users can explicitly specify an epoch to be loaded:
+     *
+     * <pre>
+     * Map&lt;String, String&gt; options = new HashMap&lt;&gt;();
+     * <b>options.put("epoch", "3");</b>
+     * model.load(modelPath, "squeezenet", options);
+     * </pre>
+     *
+     * @param modelPath the directory of the model
+     * @param prefix the model file name or path prefix
+     * @param options load model options, see documentation for the specific engine
+     * @throws IOException Exception for file loading
+     */
+    public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException {
+        modelDir = modelPath.toAbsolutePath();
+        if (prefix == null) {
+            prefix = modelName;
+        }
+        Path paramFile = paramPathResolver(prefix, options);
+        if (paramFile == null) {
+            prefix = modelDir.toFile().getName();
+            paramFile = paramPathResolver(prefix, options);
+            if (paramFile == null) {
+                throw new FileNotFoundException(
+                        "Parameter file with prefix: " + prefix + " not found in: " + modelDir);
+            }
+        }
+
+        if (getSymbolBlock() == null) {
+            // load MxSymbolBlock
+            Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
+            if (Files.notExists(symbolFile)) {
+                throw new FileNotFoundException(
+                        "Symbol file not found: "
+                                + symbolFile
+                                + ", please set block manually for imperative model.");
+            }
+
+            // TODO: change default name "data" to model-specific one
+            setMxSymbolBlock(SymbolBlock.createMxSymbolBlock(this, symbolFile));
+        }
+        loadParameters(paramFile);
+        // TODO: Check if Symbol has all names that params file have
+        if (options != null && options.containsKey("MxOptimizeFor")) {
+            String optimization = (String) options.get("MxOptimizeFor");
+            getSymbolBlock().optimizeFor(optimization, getDevice());
+        }
+    }
+
+    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
+        try {
+            int epoch = getEpoch(prefix, options);
+            return getModelDir()
+                    .resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch));
+        } catch (FileNotFoundException e) {
+            return null;
+        }
+    }
+
+    private int getEpoch(String prefix, Map<String, ?> options) throws IOException {
+        if (options != null) {
+            Object epochOption = options.getOrDefault("epoch", null);
+            if (epochOption != null) {
+                return Integer.parseInt(epochOption.toString());
+            }
+        }
+        return Utils.getCurrentEpoch(getModelDir(), prefix);
+    }
+
+    @SuppressWarnings("PMD.UseConcurrentHashMap")
+    private void loadParameters(Path paramFile) {
+
+        NDList paramNDlist = JnaUtils.loadNdArray(this, paramFile, getDevice());
+
+        List<Parameter> parameters = getSymbolBlock().getAllParameters();
+        Map<String, Parameter> map = new LinkedHashMap<>();
+        parameters.forEach(p -> map.put(p.getName(), p));
+
+        for (NDArray nd : paramNDlist) {
+            String key = nd.getName();
+            if (key == null) {
+                throw new IllegalArgumentException("Array names must be present in parameter file");
+            }
+
+            String paramName = key.split(":", 2)[1];
+            Parameter parameter = map.remove(paramName);
+            parameter.setArray(nd);
+        }
+        getSymbolBlock().setInputNames(new ArrayList<>(map.keySet()));
+
+        // TODO: Find a better way to infer the model DataType from SymbolBlock.
+        dataType = paramNDlist.head().getDataType();
+        logger.debug("MXNet Model {} ({}) loaded successfully.", paramFile, dataType);
+    }
+
+    /**
+     * Get the modelDir from the Model.
+     *
+     * @return {@link Path} modelDir for the Model
+     */
+    public Path getModelDir() {
+        return modelDir;
+    }
+
+    /**
+     * Set the modelDir for the Model.
+     *
+     * @param modelDir {@link Path}
+     */
+    public void setModelDir(Path modelDir) {
+        this.modelDir = modelDir;
+    }
+
+    /**
+     * Get the symbolBlock of the Model.
+     *
+     * @return {@link SymbolBlock}
+     */
+    public SymbolBlock getSymbolBlock() {
+        return symbolBlock;
+    }
+
+    /**
+     * Set the symbolBlock for the Model.
+     *
+     * @param symbolBlock {@link SymbolBlock}
+     */
+    public void setMxSymbolBlock(SymbolBlock symbolBlock) {
+        this.symbolBlock = symbolBlock;
+    }
+
+    /**
+     * Get the name of the Model.
+     *
+     * @return modelName
+     */
+    public String getModelName() {
+        return modelName;
+    }
+
+    /**
+     * Set the model name for the Model.
+     *
+     * @param modelName for the Model
+     */
+    public final void setModelName(String modelName) {
+        this.modelName = modelName;
+    }
+
+    /**
+     * Get data type for the Model.
+     *
+     * @return {@link DataType}
+     */
+    public DataType getDataType() {
+        return dataType;
+    }
+
+    /**
+     * Set data type for the Model.
+     *
+     * @param dataType {@link DataType}
+     */
+    public final void setDataType(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    /**
+     * Get input data of the Model.
+     *
+     * @return {@link PairList} inputData
+     */
+    public PairList<String, Shape> getInputData() {
+        return inputData;
+    }
+
+    /**
+     * Set input data for the Model.
+     *
+     * @param inputData {@link PairList}
+     */
+    public void setInputData(PairList<String, Shape> inputData) {
+        this.inputData = inputData;
+    }
+
+    /**
+     * Get the Artifact Object from artifacts by key.
+     *
+     * @param key for the Artifact Object
+     * @return Artifact {@link Object} instance
+     */
+    public Object getArtifact(String key) {
+        return artifacts.get(key);
+    }
+
+    /**
+     * Set the Artifact Object for artifacts.
+     *
+     * @param key for the Artifact
+     * @param artifact {@link Object}
+     */
+    public void setArtifact(String key, Object artifact) {
+        artifacts.put(key, artifact);
+    }
+
+    /**
+     * Get the property from properties by key.
+     *
+     * @param key {@link String}
+     * @return {@link String} property
+     */
+    public String getProperty(String key) {
+        return properties.get(key);
+    }
+
+    /**
+     * Set the property for the Model.
+     *
+     * @param key for the property
+     * @param property value of the property
+     */
+    public void setProperties(String key, String property) {
+        this.properties.put(key, property);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        if (device == null) {
+            return super.getDevice();
+        }
+        return device;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Freeing Model instance: %s", this.getModelName()));
+            // release sub resources
+            super.freeSubResources();
+            // release itself
+            this.symbolBlock = null;
+            this.artifacts = null;
+            this.properties = null;
+            setClosed(true);
+            logger.debug(String.format("Freed Model instance: %s", this.getModelName()));
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
new file mode 100644
index 0000000..b96941f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxDataType.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.ndarray.types.DataType;
+
+/** Helper to convert between {@link DataType} and the MXNet internal DataTypes. */
+public final class MxDataType {
+
+    private static Map<DataType, String> toMx = createMapToMx();
+    private static Map<String, DataType> fromMx = createMapFromMx();
+
+    private MxDataType() {}
+
+    private static Map<DataType, String> createMapToMx() {
+        Map<DataType, String> map = new ConcurrentHashMap<>();
+        map.put(DataType.FLOAT32, "float32");
+        map.put(DataType.FLOAT64, "float64");
+        map.put(DataType.INT32, "int32");
+        map.put(DataType.INT64, "int64");
+        map.put(DataType.UINT8, "uint8");
+        return map;
+    }
+
+    private static Map<String, DataType> createMapFromMx() {
+        Map<String, DataType> map = new ConcurrentHashMap<>();
+        map.put("float32", DataType.FLOAT32);
+        map.put("float64", DataType.FLOAT64);
+        map.put("int32", DataType.INT32);
+        map.put("int64", DataType.INT64);
+        map.put("uint8", DataType.UINT8);
+        return map;
+    }
+
+    /**
+     * Converts a MXNet type String into a {@link DataType}.
+     *
+     * @param mxType the type String to convert
+     * @return the {@link DataType}
+     */
+    public static DataType fromMx(String mxType) {
+        return fromMx.get(mxType);
+    }
+
+    /**
+     * Converts a {@link DataType} into the corresponding MXNet type String.
+     *
+     * @param jType the java {@link DataType} to convert
+     * @return the converted MXNet type string
+     */
+    public static String toMx(DataType jType) {
+        return toMx.get(jType);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
new file mode 100644
index 0000000..5740d9b
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResource.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.NativeResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An auto-closable resource whose life cycle can be managed by its parent {@link MxResource}
+ * instance. It also manages the life cycle of its child {@link MxResource} instances.
+ */
+public class MxResource extends NativeResource<Pointer> {
+
+    private static final Logger logger = LoggerFactory.getLogger(MxResource.class);
+
+    private boolean closed;
+
+    protected Device device;
+
+    private MxResource parent;
+
+    private ConcurrentHashMap<String, MxResource> subResources;
+
+    protected MxResource() {
+        super();
+        setParent(null);
+    }
+
+    protected MxResource(MxResource parent, String uid) {
+        super(uid);
+        setClosed(false);
+        setParent(parent);
+        getParent().addSubResource(this);
+    }
+
+    protected MxResource(MxResource parent) {
+        this(parent, UUID.randomUUID().toString());
+    }
+
+    protected MxResource(MxResource parent, Pointer handle) {
+        super(handle);
+        setParent(parent);
+        if (parent != null) {
+            parent.addSubResource(this);
+        } else {
+            BaseMxResource.getSystemMxResource().addSubResource(this);
+        }
+    }
+    /**
+     * Add the sub {@link MxResource} under the current instance.
+     *
+     * @param subMxResource the instance to be added
+     */
+    public void addSubResource(MxResource subMxResource) {
+        getSubResource().put(subMxResource.getUid(), subMxResource);
+    }
+
+    /** Free all sub {@link MxResource} instances of the current instance. */
+    public void freeSubResources() {
+        if (subResourceInitialized()) {
+            for (MxResource subResource : subResources.values()) {
+                try {
+                    subResource.close();
+                } catch (Exception e) {
+                    logger.error("MxResource close failed.", e);
+                }
+            }
+            subResources = null;
+        }
+    }
+
+    /**
+     * Checks whether {@code subResources} has been initialized.
+     *
+     * @return {@code true} if {@code subResources} has been initialized
+     */
+    public boolean subResourceInitialized() {
+        return subResources != null;
+    }
+
+    /**
+     * Get the {@code subResources} of the {@link MxResource}.
+     *
+     * @return subResources
+     */
+    public ConcurrentHashMap<String, MxResource> getSubResource() {
+        if (!subResourceInitialized()) {
+            subResources = new ConcurrentHashMap<>();
+        }
+        return subResources;
+    }
+
+    protected final void setParent(MxResource parent) {
+        this.parent = parent;
+    }
+
+    /**
+     * Get parent {@link MxResource} of the current instance.
+     *
+     * @return {@link MxResource}
+     */
+    public MxResource getParent() {
+        return this.parent;
+    }
+
+    /**
+     * Set the {@link Device} for the {@link MxResource}.
+     *
+     * @param device {@link Device}
+     */
+    public void setDevice(Device device) {
+        this.device = device;
+    }
+
+    /**
+     * Returns the {@link Device} of this {@code MxResource}.
+     *
+     * <p>The {@link Device} describes where this {@code MxResource} is stored in memory, such as
+     * CPU or GPU.
+     *
+     * @return the {@link Device} of this {@code MxResource}
+     */
+    public Device getDevice() {
+        Device curDevice = getParent() == null ? null : getParent().getDevice();
+        return Device.defaultIfNull(curDevice);
+    }
+
+    /**
+     * Sets the closed state of this {@link MxResource} instance.
+     *
+     * @param isClosed whether this {@link MxResource} has been closed
+     */
+    public final void setClosed(boolean isClosed) {
+        this.closed = isClosed;
+    }
+
+    /**
+     * Returns whether this {@link MxResource} has been closed.
+     *
+     * @return {@code true} if this {@link MxResource} has been closed
+     */
+    public boolean getClosed() {
+        return closed;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        freeSubResources();
+        setClosed(true);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
new file mode 100644
index 0000000..6869cf6
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/MxResourceList.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.mxnet.util.Pair;
+import org.apache.mxnet.util.PairList;
+
+/**
+ * An {@code MxResourceList} represents a sequence of {@link MxResource}s with names.
+ *
+ * <p>Each {@link MxResource} in this list can optionally have a name. You can use the name to look
+ * up an MxResource in the MxResourceList.
+ *
+ * @see MxResource
+ */
+public class MxResourceList extends PairList<String, MxResource> {
+
+    /** Creates an empty {@code MxResourceList}. */
+    public MxResourceList() {}
+
+    /**
+     * Constructs an empty {@code MxResourceList} with the specified initial capacity.
+     *
+     * @param initialCapacity the initial capacity of the list
+     * @throws IllegalArgumentException if the specified initial capacity is negative
+     */
+    public MxResourceList(int initialCapacity) {
+        super(initialCapacity);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} from the specified keys and values.
+     *
+     * @param keys the key list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @param values the value list containing the elements to be placed into this {@code
+     *     MxResourceList}
+     * @throws IllegalArgumentException if the keys and values size are different
+     */
+    public MxResourceList(List<String> keys, List<MxResource> values) {
+        super(keys, values);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} from the specified list of pairs.
+     *
+     * @param list the list containing the elements to be placed into this {@code MxResourceList}
+     */
+    public MxResourceList(List<Pair<String, MxResource>> list) {
+        super(list);
+    }
+
+    /**
+     * Constructs an {@code MxResourceList} containing the elements of the specified map.
+     *
+     * @param map the map containing keys and values
+     */
+    public MxResourceList(Map<String, MxResource> map) {
+        super(map);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
new file mode 100644
index 0000000..3362e03
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/OpParams.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+
+/** An internal helper for creating the MXNet operator parameters. */
+public class OpParams extends PairList<String, Object> {
+    // MXNet expects the CPU context to include a device index
+    private static final String MXNET_CPU = "cpu(0)";
+    /**
+     * Sets the Shape parameter.
+     *
+     * @param shape the shape to set
+     */
+    public void setShape(Shape shape) {
+        addParam("shape", shape);
+    }
+
+    /**
+     * Sets the device to use for the operation.
+     *
+     * @param device the device to use for the operation
+     */
+    public void setDevice(Device device) {
+        setParam("ctx", ("cpu".equals(device.getDeviceType()) ? MXNET_CPU : device.toString()));
+    }
+
+    /**
+     * Sets the dataType to use for the operation.
+     *
+     * @param dataType the dataType to use for the operation
+     */
+    public void setDataType(org.apache.mxnet.ndarray.types.DataType dataType) {
+        if (dataType != null) {
+            setParam("dtype", MxDataType.toMx(dataType));
+        }
+    }
+
+    /**
+     * Sets the sparseFormat to use for the operation.
+     *
+     * @param sparseFormat the sparseFormat to use for the operation
+     */
+    public void setSparseFormat(SparseFormat sparseFormat) {
+        if (sparseFormat != null) {
+            setParam("stype", String.valueOf(sparseFormat.getValue()));
+        }
+    }
+
+    /**
+     * Sets a (potentially existing) parameter to a new value.
+     *
+     * @param paramName the parameter name to update
+     * @param value the value to set the parameter to
+     */
+    public void setParam(String paramName, String value) {
+        remove(paramName);
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param shape the value of the new parameter
+     */
+    public void addParam(String paramName, Shape shape) {
+        if (shape != null) {
+            add(paramName, shape.toString());
+        }
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, String value) {
+        add(paramName, value);
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, int value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, long value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, double value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, float value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, boolean value) {
+        add(paramName, value ? "True" : "False");
+    }
+
+    /**
+     * Adds a parameter.
+     *
+     * @param paramName the name of the new parameter
+     * @param value the value of the new parameter
+     */
+    public void addParam(String paramName, Number value) {
+        add(paramName, String.valueOf(value));
+    }
+
+    /**
+     * Adds a parameter with tuple value.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, int... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+
+    /**
+     * Adds a parameter with tuple value.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, long... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+
+    /**
+     * Adds a parameter with tuple value.
+     *
+     * @param paramName the name of the new parameter
+     * @param tuple the values of the new parameter
+     */
+    public void addTupleParam(String paramName, float... tuple) {
+        StringBuilder sb = new StringBuilder();
+        sb.append('(');
+        for (int i = 0; i < tuple.length; ++i) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(tuple[i]);
+        }
+        sb.append(')');
+        add(paramName, sb.toString());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
new file mode 100644
index 0000000..739f6e0
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Predictor.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.mxnet.exception.TranslateException;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.translate.Translator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@code Predictor} class provides a session for model inference.
+ *
+ * <p>You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on
+ * a {@link Model}.
+ *
+ * @param <I> the input type
+ * @param <O> the output type
+ * @see Model
+ * @see Translator
+ * @see <a href="http://docs.djl.ai/docs/development/memory_management.html">The guide on memory
+ *     management</a>
+ * @see <a
+ *     href="https://github.com/deepjavalibrary/djl/blob/master/examples/docs/multithread_inference.md">The
+ *     guide on running multi-threaded inference</a>
+ * @see <a href="http://docs.djl.ai/docs/development/inference_performance_optimization.html">The
+ *     guide on inference performance optimization</a>
+ */
+public class Predictor<I, O> extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);
+    private Translator<I, O> translator;
+    private Model model;
+
+    /**
+     * Creates a new instance of {@code Predictor} with the given {@link Model} and {@link
+     * Translator}.
+     *
+     * @param model the model on which the predictions are based
+     * @param translator the translator to be used
+     */
+    public Predictor(Model model, Translator<I, O> translator) {
+        super(model);
+        this.model = model;
+        this.translator = translator;
+    }
+
+    /**
+     * Runs inference on a list of inputs.
+     *
+     * @param input the list of inputs
+     * @return the list of output objects defined by the user, one per input
+     * @throws TranslateException if an error occurs during prediction
+     */
+    @SuppressWarnings("PMD.AvoidRethrowingException")
+    public List<O> predict(List<I> input) {
+        NDList[] ndLists = processInputs(input);
+        for (int i = 0; i < ndLists.length; ++i) {
+            ndLists[i] = forward(ndLists[i]);
+        }
+        return processOutPut(ndLists);
+    }
+
+    /**
+     * Runs inference on a single input.
+     *
+     * @param input the input data
+     * @return the output object defined by the user
+     * @throws TranslateException if an error occurs during prediction
+     */
+    public O predict(I input) {
+        return predict(Collections.singletonList(input)).get(0);
+    }
+
+    private NDList forward(NDList ndList) {
+        logger.trace("Predictor input data: {}", ndList);
+        return model.getSymbolBlock().forward(ndList);
+    }
+
+    // TODO: add batch predict
+
+    private NDList[] processInputs(List<I> inputs) {
+        int batchSize = inputs.size();
+        NDList[] preprocessed = new NDList[batchSize];
+        try {
+            for (int i = 0; i < batchSize; ++i) {
+                preprocessed[i] = translator.processInput(inputs.get(i));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurred while processing input items.", e);
+            throw new TranslateException(e);
+        }
+        return preprocessed;
+    }
+
+    private List<O> processOutPut(NDList[] ndLists) {
+        List<O> outputs = new ArrayList<>();
+        try {
+            for (NDList mxNDList : ndLists) {
+                outputs.add(translator.processOutput(mxNDList));
+            }
+        } catch (Exception e) {
+            logger.error("Error occurred while processing output items.", e);
+            throw new TranslateException(e);
+        }
+        return outputs;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
new file mode 100644
index 0000000..7777add
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/Symbol.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.engine;
+
+import com.sun.jna.Pointer;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.util.PairList;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@code Symbol} is an internal helper for symbolic model graphs used by the {@link
+ * org.apache.mxnet.nn.SymbolBlock}.
+ *
+ * @see org.apache.mxnet.nn.SymbolBlock
+ */
+public class Symbol extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(Symbol.class);
+
+    private String[] outputs;
+
+    protected Symbol(MxResource parent, Pointer handle) {
+        super(parent, handle);
+    }
+
+    static Symbol loadFromFile(MxResource parent, String path) {
+        Pointer p = JnaUtils.createSymbolFromFile(path);
+        return new Symbol(parent, p);
+    }
+
+    /**
+     * Load {@link Symbol} from the given {@link Path}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param path the {@link Path} to load the {@link Symbol}
+     * @return {@link Symbol}
+     */
+    public static Symbol loadSymbol(MxResource parent, Path path) {
+        return loadFromFile(parent, path.toAbsolutePath().toString());
+    }
+
+    /**
+     * Loads a symbol from a json string.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param json the json string of the symbol.
+     * @return the new symbol
+     */
+    public static Symbol loadJson(MxResource parent, String json) {
+        Pointer pointer = JnaUtils.createSymbolFromString(json);
+        return new Symbol(parent, pointer);
+    }
+
+    /**
+     * Returns the names of the symbol outputs.
+     *
+     * @return the names of the symbol outputs
+     */
+    public String[] getOutputNames() {
+        if (this.outputs == null) {
+            this.outputs = JnaUtils.listSymbolOutputs(getHandle());
+        }
+        return this.outputs;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug("Start to free Symbol instance: {}", this.toJsonString());
+            super.freeSubResources();
+            Pointer pointer = handle.getAndSet(null);
+            if (pointer != null) {
+                JnaUtils.freeSymbol(pointer);
+            }
+            setClosed(true);
+            logger.debug("Finished freeing Symbol instance.");
+        }
+    }
+
+    /**
+     * Returns the output symbol by index.
+     *
+     * @param index the index of the output
+     * @return the symbol output as a new symbol
+     */
+    public Symbol get(int index) {
+        Pointer pointer = JnaUtils.getSymbolOutput(getInternals().getHandle(), index);
+        return new Symbol(getParent(), pointer);
+    }
+
+    /**
+     * Returns the output symbol with the given name.
+     *
+     * @param name the name of the symbol to return
+     * @return the output symbol
+     * @throws IllegalArgumentException if no output matches the name
+     */
+    public Symbol get(String name) {
+        String[] out = getInternalOutputNames();
+        int index = Utils.indexOf(out, name);
+        if (index < 0) {
+            throw new IllegalArgumentException("Cannot find output that matches name: " + name);
+        }
+        return get(index);
+    }
+
+    /**
+     * Returns the symbol argument names.
+     *
+     * @return the symbol argument names
+     */
+    public String[] getArgNames() {
+        return JnaUtils.listSymbolArguments(getHandle());
+    }
+
+    /**
+     * Returns the MXNet auxiliary states for the symbol.
+     *
+     * @return the MXNet auxiliary states for the symbol
+     */
+    public String[] getAuxNames() {
+        return JnaUtils.listSymbolAuxiliaryStates(getHandle());
+    }
+
+    /**
+     * Returns the symbol names.
+     *
+     * @return the symbol names
+     */
+    public String[] getAllNames() {
+        return JnaUtils.listSymbolNames(getHandle());
+    }
+
+    /**
+     * Returns the list of names for all internal outputs.
+     *
+     * @return a list of names
+     */
+    public List<String> getLayerNames() {
+        String[] outputNames = getInternalOutputNames();
+        String[] allNames = getAllNames();
+        Set<String> allNamesSet = new LinkedHashSet<>(Arrays.asList(allNames));
+        // Remove the parameter entries and keep only the layer outputs
+        return Arrays.stream(outputNames)
+                .filter(n -> !allNamesSet.contains(n))
+                .collect(Collectors.toList());
+    }
+
+    private String[] getInternalOutputNames() {
+        return JnaUtils.listSymbolOutputs(getInternals().getHandle());
+    }
+
+    /**
+     * Returns the symbol internals.
+     *
+     * @return a {@code Symbol} that exposes the internal outputs
+     */
+    public Symbol getInternals() {
+        Pointer pointer = JnaUtils.getSymbolInternals(getHandle());
+        return new Symbol(getParent(), pointer);
+    }
+
+    /**
+     * Infers the shapes for all parameters inside a symbol from the given input shapes.
+     *
+     * @param pairs the given input name and shape
+     * @return a map from argument, auxiliary state, and output names to their inferred shapes
+     */
+    public Map<String, Shape> inferShape(PairList<String, Shape> pairs) {
+        List<List<Shape>> shapes = JnaUtils.inferShape(this, pairs);
+        if (shapes == null) {
+            throw new IllegalArgumentException("Cannot infer shape based on the data provided!");
+        }
+        List<Shape> argShapes = shapes.get(0);
+        List<Shape> outputShapes = shapes.get(1);
+        List<Shape> auxShapes = shapes.get(2);
+        // Arguments, auxiliary states, and outputs are all added to the map
+        String[] argNames = getArgNames();
+        String[] auxNames = getAuxNames();
+        String[] outputNames = getOutputNames();
+        Map<String, Shape> shapesMap = new ConcurrentHashMap<>();
+        for (int i = 0; i < argNames.length; i++) {
+            shapesMap.put(argNames[i], argShapes.get(i));
+        }
+        for (int i = 0; i < auxNames.length; i++) {
+            shapesMap.put(auxNames[i], auxShapes.get(i));
+        }
+        for (int i = 0; i < outputNames.length; i++) {
+            shapesMap.put(outputNames[i], outputShapes.get(i));
+        }
+        return shapesMap;
+    }
+
+    /**
+     * [Experimental] Add customized optimization on the Symbol.
+     *
+     * <p>This method can be used with EIA or TensorRT for model acceleration.
+     *
+     * @param backend backend name
+     * @param device the device assigned
+     * @return optimized Symbol
+     */
+    public Symbol optimizeFor(String backend, Device device) {
+        return new Symbol(getParent(), JnaUtils.optimizeFor(this, backend, device));
+    }
+
+    /**
+     * Converts this {@code Symbol} to a JSON string, for example to save it.
+     *
+     * @return the json string
+     */
+    public String toJsonString() {
+        return JnaUtils.getSymbolString(getHandle());
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
new file mode 100644
index 0000000..e88f05d
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/engine/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the Java front-end implementation for Apache MXNet. */
+package org.apache.mxnet.engine;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
new file mode 100644
index 0000000..0123a50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/BaseException.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that an error was raised by the underlying native engine. */
+public class BaseException extends RuntimeException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message. The detail message is saved for later retrieval by the
+     *     {@link #getMessage()} method.
+     */
+    public BaseException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message (which is saved for later retrieval by the {@link
+     *     #getMessage()} method)
+     * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+     *     or unknown.)
+     */
+    public BaseException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} (which typically contains the class and detail
+     * message of {@code cause}). This constructor is useful for exceptions that are little more
+     * than wrappers for other throwables (for example, {@link
+     * java.security.PrivilegedActionException}).
+     *
+     * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A {@code null} value is permitted, and indicates that the cause is nonexistent
+     *     or unknown.)
+     */
+    public BaseException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
new file mode 100644
index 0000000..f6f26fc
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/JnaCallException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate JNA functions are not called as expected. */
+public class JnaCallException extends BaseException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message. The detail message is saved for later retrieval by the
+     *     {@link #getMessage()} method.
+     */
+    public JnaCallException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message that is saved for later retrieval by the {@link
+     *     #getMessage()} method
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public JnaCallException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} which typically contains the class and detail
+     * message of {@code cause}. This constructor is useful for exceptions that are little more than
+     * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+     *
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public JnaCallException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
new file mode 100644
index 0000000..5455633
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/MalformedModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that model parameters are malformed or not in the expected format. */
+public class MalformedModelException extends ModelException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message. The detail message is saved for later retrieval by the
+     *     {@link #getMessage()} method.
+     */
+    public MalformedModelException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message that is saved for later retrieval by the {@link
+     *     #getMessage()} method
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public MalformedModelException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} which typically contains the class and detail
+     * message of {@code cause}. This constructor is useful for exceptions that are little more than
+     * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+     *
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public MalformedModelException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
new file mode 100644
index 0000000..94a12e7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/ModelException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate a model-related error. */
+public class ModelException extends BaseException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message that is saved for later retrieval by the {@link
+     *     #getMessage()} method
+     */
+    public ModelException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message that is saved for later retrieval by the {@link
+     *     #getMessage()} method
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public ModelException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} which typically contains the class and detail
+     * message of {@code cause}. This constructor is useful for exceptions that are little more than
+     * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+     *
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public ModelException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
new file mode 100644
index 0000000..b17758a
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/TranslateException.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.exception;
+
+/** Thrown to indicate that the Translate pipeline does not work as expected. */
+public class TranslateException extends BaseException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs a new exception with the specified detail message. The cause is not initialized,
+     * and may subsequently be initialized by a call to {@link #initCause}.
+     *
+     * @param message the detail message. The detail message is saved for later retrieval by the
+     *     {@link #getMessage()} method.
+     */
+    public TranslateException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs a new exception with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated in this exception's detail message.
+     *
+     * @param message the detail message that is saved for later retrieval by the {@link
+     *     #getMessage()} method
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public TranslateException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    /**
+     * Constructs a new exception with the specified cause and a detail message of {@code
+     * (cause==null ? null : cause.toString())} which typically contains the class and detail
+     * message of {@code cause}. This constructor is useful for exceptions that are little more than
+     * wrappers for other throwables. For example, {@link java.security.PrivilegedActionException}.
+     *
+     * @param cause the cause that is saved for later retrieval by the {@link #getCause()} method. A
+     *     {@code null} value is permitted, and indicates that the cause is nonexistent or unknown
+     */
+    public TranslateException(Throwable cause) {
+        super(cause);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
new file mode 100644
index 0000000..d464bfd
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/exception/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the exception classes used by the Java front-end for Apache MXNet. */
+package org.apache.mxnet.exception;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
new file mode 100644
index 0000000..a2da116
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/FunctionInfo.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Pointer;
+import java.util.List;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A FunctionInfo represents an operator (i.e., a function) within the MXNet Engine. */
+public class FunctionInfo {
+
+    private Pointer handle;
+    private String name;
+    private PairList<String, String> arguments;
+
+    private static final Logger logger = LoggerFactory.getLogger(FunctionInfo.class);
+
+    FunctionInfo(Pointer pointer, String functionName, PairList<String, String> arguments) {
+        this.handle = pointer;
+        this.name = functionName;
+        this.arguments = arguments;
+    }
+
+    /**
+     * Returns the name of the operator.
+     *
+     * @return the name of the operator
+     */
+    public String getFunctionName() {
+        return name;
+    }
+
+    /**
+     * Returns the names of the params to the operator.
+     *
+     * @return the names of the params to the operator
+     */
+    public List<String> getArgumentNames() {
+        return arguments.keys();
+    }
+
+    /**
+     * Returns the types of the operator arguments.
+     *
+     * @return the types of the operator arguments
+     */
+    public List<String> getArgumentTypes() {
+        return arguments.values();
+    }
+    /**
+     * Calls an operator with the given arguments.
+     *
+     * @param src the input NDArray(s) to the operator
+     * @param dest the destination NDArray(s) to be overwritten with the result of the operator
+     * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
+     *     String>}
+     * @return the number of output NDArrays returned by the operator
+     */
+    public int invoke(NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        checkDevices(src);
+        checkDevices(dest);
+        return JnaUtils.imperativeInvoke(handle, src, dest, params).size();
+    }
+
+    /**
+     * Calls an operator with the given arguments.
+     *
+     * @param parent {@link MxResource} for the current instance
+     * @param src the input NDArray(s) to the operator
+     * @param params the non-NDArray arguments to the operator. Should be a {@code PairList<String,
+     *     String>}
+     * @return the output NDArray(s) produced by the operator
+     */
+    public NDArray[] invoke(MxResource parent, NDArray[] src, PairList<String, ?> params) {
+        checkDevices(src);
+        PairList<Pointer, SparseFormat> pairList =
+                JnaUtils.imperativeInvoke(handle, src, null, params);
+        return pairList.stream()
+                .map(
+                        pair -> {
+                            if (pair.getValue() != SparseFormat.DENSE) {
+                                return NDArray.create(parent, pair.getKey(), pair.getValue());
+                            }
+                            return NDArray.create(parent, pair.getKey());
+                        })
+                .toArray(NDArray[]::new);
+    }
+
+    private void checkDevices(NDArray[] src) {
+        // check whether all the NDArrays are on the same device
+        if (logger.isDebugEnabled() && src.length > 1) {
+            Device device = src[0].getDevice();
+            for (int i = 1; i < src.length; ++i) {
+                if (!device.equals(src[i].getDevice())) {
+                    logger.warn(
+                            "Please make sure all the NDArrays are on the same device. You can call toDevice() to move the NDArray to the desired Device.");
+                }
+            }
+        }
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
new file mode 100644
index 0000000..46cff30
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/JnaUtils.java
@@ -0,0 +1,893 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import com.sun.jna.ptr.PointerByReference;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.IntBuffer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.engine.CachedOp;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.DeviceType;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.Symbol;
+import org.apache.mxnet.exception.JnaCallException;
+import org.apache.mxnet.ndarray.NDArray;
+import org.apache.mxnet.ndarray.NDList;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.nn.Parameter;
+import org.apache.mxnet.nn.SymbolBlock;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class containing utilities to interact with the MXNet Engine's Java Native Access (JNA) layer.
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class JnaUtils {
+
+    private static final Logger logger = LoggerFactory.getLogger(JnaUtils.class);
+
+    public static final MxnetLibrary LIB = LibUtils.loadLibrary();
+
+    public static final ObjectPool<PointerByReference> REFS =
+            new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));
+
+    private static final String[] OP_NAME_PREFIX = {
+        "_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"
+    };
+
+    private static final Map<String, FunctionInfo> OPS = getNdArrayFunctions();
+    //    private static final Map<String, FunctionInfo> OPS = null;
+
+    private static final Set<String> FEATURES = getFeaturesInternal();
+
+    public static final String[] EMPTY_ARRAY = new String[0];
+
+    private JnaUtils() {
+        // not called
+    }
+
+    /** An enum that enumerates the statuses of numpy mode. */
+    public enum NumpyMode {
+        OFF,
+        THREAD_LOCAL_ON,
+        GLOBAL_ON
+    }
+
+    public static void waitAll() {
+        checkCall(LIB.MXNDArrayWaitAll());
+    }
+
+    public static void setNumpyMode(NumpyMode mode) {
+        IntBuffer ret = IntBuffer.allocate(1);
+        checkCall(LIB.MXSetIsNumpyShape(mode.ordinal(), ret));
+    }
+
+    /////////////////////////////////
+    // Related to CachedOp
+    /////////////////////////////////
+    public static CachedOp createCachedOp(SymbolBlock block, MxResource parent) {
+        Symbol symbol = block.getSymbol();
+
+        List<Parameter> parameters = block.getAllParameters();
+
+        // record data index in all inputs
+        PairList<String, Integer> dataIndices = new PairList<>();
+        // record parameter index in all inputs
+        List<Integer> paramIndices = new ArrayList<>();
+        int index = 0;
+        for (Parameter parameter : parameters) {
+            // We assume uninitialized parameters are data inputs
+            if (parameter.isInitialized()) {
+                paramIndices.add(index);
+            } else {
+                dataIndices.add(parameter.getName(), index);
+            }
+            ++index;
+        }
+
+        // Creating CachedOp
+        Pointer symbolHandle = symbol.getHandle();
+        PointerByReference ref = REFS.acquire();
+
+        // static_alloc and static_shape are enabled by default
+        String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
+        String[] values = {dataIndices.values().toString(), paramIndices.toString(), "1", "1"};
+
+        checkCall(LIB.MXCreateCachedOp(symbolHandle, keys.length, keys, values, ref, (byte) 0));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+
+        return new CachedOp(parent, pointer, parameters, paramIndices, dataIndices);
+    }
+
+    public static void freeCachedOp(Pointer handle) {
+        checkCall(LIB.MXFreeCachedOp(handle));
+    }
+
+    /////////////////////////////////
+    // About Symbol
+    /////////////////////////////////
+    public static Pointer createSymbolFromFile(String path) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolCreateFromFile(path, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static Pointer createSymbolFromString(String json) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static String[] listSymbolOutputs(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListOutputs(symbol, size, ref));
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    public static String printSymbol(Pointer symbol) {
+        String[] outStr = new String[1];
+        checkCall(LIB.NNSymbolPrint(symbol, outStr));
+        return outStr[0];
+    }
+
+    public static void freeSymbol(Pointer symbol) {
+        checkCall(LIB.MXSymbolFree(symbol));
+    }
+
+    public static String[] listSymbolArguments(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListArguments(symbol, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    public static String[] listSymbolNames(Pointer symbol) {
+        IntBuffer size = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+
+        checkCall(LIB.NNSymbolListInputNames(symbol, 0, size, ref));
+
+        String[] ret = toStringArray(ref, size.get());
+        REFS.recycle(ref);
+        return ret;
+    }
+
+    public static Pointer getSymbolInternals(Pointer symbol) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolGetInternals(symbol, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    private static List<Shape> recoverShape(
+            NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
+        int shapeLength = (int) size.getValue().longValue();
+        if (shapeLength == 0) {
+            return new ArrayList<>();
+        }
+        int[] dims = nDim.getValue().getIntArray(0, shapeLength);
+        int flattenedLength = 0;
+        for (int dim : dims) {
+            flattenedLength += dim;
+        }
+        long[] flattenedShapes = data.getValue().getPointer(0).getLongArray(0, flattenedLength);
+        int idx = 0;
+        List<Shape> result = new ArrayList<>();
+        for (int dim : dims) {
+            long[] shape = new long[dim];
+            System.arraycopy(flattenedShapes, idx, shape, 0, dim);
+            idx += dim;
+            result.add(new Shape(shape));
+        }
+        return result;
+    }
+
+    public static List<List<Shape>> inferShape(Symbol symbol, PairList<String, Shape> args) {
+        Pointer handler = symbol.getHandle();
+        int numArgs = args.size();
+        String[] keys = args.keys().toArray(new String[0]);
+        // the following two arrays (indPtr and the flattened shape data) use
+        // the same layout as the representation of a CSR NDArray
+        long[] indPtr = new long[numArgs + 1];
+        Shape flattened = new Shape();
+        indPtr[0] = 0;
+        for (int i = 0; i < args.size(); i++) {
+            Shape shape = args.valueAt(i);
+            indPtr[i + 1] = indPtr[i] + shape.dimension();
+            flattened = flattened.addAll(shape);
+        }
+        long[] flattenedShapeArray = flattened.getShape();
+
+        NativeSizeByReference inShapeSize = new NativeSizeByReference();
+        PointerByReference inShapeNDim = REFS.acquire();
+        PointerByReference inShapeData = REFS.acquire();
+        NativeSizeByReference outShapeSize = new NativeSizeByReference();
+        PointerByReference outShapeNDim = REFS.acquire();
+        PointerByReference outShapeData = REFS.acquire();
+        NativeSizeByReference auxShapeSize = new NativeSizeByReference();
+        PointerByReference auxShapeNDim = REFS.acquire();
+        PointerByReference auxShapeData = REFS.acquire();
+        IntBuffer complete = IntBuffer.allocate(1);
+        checkCall(
+                LIB.MXSymbolInferShape64(
+                        handler,
+                        numArgs,
+                        keys,
+                        indPtr,
+                        flattenedShapeArray,
+                        inShapeSize,
+                        inShapeNDim,
+                        inShapeData,
+                        outShapeSize,
+                        outShapeNDim,
+                        outShapeData,
+                        auxShapeSize,
+                        auxShapeNDim,
+                        auxShapeData,
+                        complete));
+        if (complete.get() != 0) {
+            return Arrays.asList(
+                    recoverShape(inShapeSize, inShapeNDim, inShapeData),
+                    recoverShape(outShapeSize, outShapeNDim, outShapeData),
+                    recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
+        }
+        return null;
+    }
+
+    public static Pointer optimizeFor(Symbol current, String backend, Device device) {
+        // TODO: Support partition on parameters
+        PointerByReference returnedSymbolHandle = REFS.acquire();
+        // placeHolders
+        PointerByReference[] placeHolders = {
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire(),
+            REFS.acquire()
+        };
+        // there is no need to update parameters
+        // TODO: check the type of the 22nd parameter
+        checkCall(
+                LIB.MXOptimizeForBackend(
+                        current.getHandle(),
+                        backend,
+                        DeviceType.toDeviceType(device),
+                        returnedSymbolHandle,
+                        0,
+                        placeHolders[0],
+                        0,
+                        placeHolders[1],
+                        0,
+                        new String[0],
+                        new String[0],
+                        0,
+                        new String[0],
+                        new long[0],
+                        new int[0],
+                        0,
+                        new String[0],
+                        new int[0],
+                        0,
+                        new String[0],
+                        new int[0],
+                        (byte) 0,
+                        IntBuffer.allocate(0),
+                        placeHolders[2],
+                        placeHolders[3],
+                        IntBuffer.allocate(0),
+                        placeHolders[4],
+                        placeHolders[5]));
+        Pointer ptr = returnedSymbolHandle.getValue();
+        REFS.recycle(returnedSymbolHandle);
+        Arrays.stream(placeHolders).forEach(REFS::recycle);
+        return ptr;
+    }
+
+    public static String getSymbolString(Pointer symbol) {
+        String[] holder = new String[1];
+        checkCall(LIB.MXSymbolSaveToJSON(symbol, holder));
+        return holder[0];
+    }
+
+    public static Pointer getSymbolOutput(Pointer symbol, int index) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXSymbolGetOutput(symbol, index, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static NDList loadNdArray(MxResource parent, Path path, Device device) {
+        IntBuffer handlesSize = IntBuffer.allocate(1);
+        PointerByReference handlesRef = REFS.acquire();
+        PointerByReference namesRef = REFS.acquire();
+        IntBuffer namesSize = IntBuffer.allocate(1);
+        checkCall(LIB.MXNDArrayLoad(path.toString(), handlesSize, handlesRef, namesSize, namesRef));
+        int ndArrayCount = handlesSize.get();
+        int nameCount = namesSize.get();
+        if (nameCount > 0 && ndArrayCount != nameCount) {
+            throw new IllegalStateException(
+                    "Mismatch between names and arrays in checkpoint file: " + path.toString());
+        }
+        Pointer[] handles = handlesRef.getValue().getPointerArray(0, ndArrayCount);
+        NDList ndList = new NDList();
+        if (nameCount == 0) {
+            for (Pointer handle : handles) {
+                ndList.add(NDArray.create(parent, handle));
+            }
+        } else {
+            String[] names = namesRef.getValue().getStringArray(0, nameCount);
+            for (int i = 0; i < ndArrayCount; i++) {
+                NDArray array = NDArray.create(parent, handles[i]);
+                array.setName(names[i]);
+                ndList.add(array);
+            }
+        }
+
+        REFS.recycle(namesRef);
+        REFS.recycle(handlesRef);
+
+        // MXNet always loads NDArrays on the CPU
+        if (Device.cpu().equals(device)) {
+            return ndList;
+        }
+
+        NDList ret = ndList.toDevice(device, true);
+        ndList.close();
+        return ret;
+    }
+
+    public static PairList<String, Pointer> loadNdArrayFromFile(String path) {
+        IntBuffer handleSize = IntBuffer.allocate(1);
+        IntBuffer namesSize = IntBuffer.allocate(1);
+        PointerByReference handlesRef = REFS.acquire();
+        PointerByReference namesRef = REFS.acquire();
+        checkCall(LIB.MXNDArrayLoad(path, handleSize, handlesRef, namesSize, namesRef));
+        // TODO : construct NDArray Objects
+        int handleCount = handleSize.get();
+        int nameCount = namesSize.get();
+        if (nameCount > 0 && nameCount != handleCount) {
+            throw new IllegalStateException(
+                    "Mismatch between names and arrays in checkpoint file: " + path);
+        }
+        Pointer[] handles = handlesRef.getValue().getPointerArray(0, handleCount);
+
+        PairList<String, Pointer> pairList = new PairList<>();
+
+        if (nameCount == 0) {
+            for (Pointer handle : handles) {
+                pairList.add(null, handle);
+            }
+        } else {
+            String[] names = namesRef.getValue().getStringArray(0, nameCount);
+            for (int i = 0; i < handleCount; i++) {
+                pairList.add(names[i], handles[i]);
+            }
+        }
+        REFS.recycle(namesRef);
+        REFS.recycle(handlesRef);
+
+        return pairList;
+    }
+
+    public static void freeNdArray(Pointer handle) {
+        checkCall(LIB.MXNDArrayFree(handle));
+    }
+
+    public static Pointer loadNdArrayFromByteArray(byte[] buf, int offset, int size) {
+        Memory memory = new Memory(size);
+        memory.write(0, buf, offset, size);
+        PointerByReference outRef = REFS.acquire();
+        checkCall(LIB.MXNDArrayLoadFromRawBytes(memory, new NativeSize(size), outRef));
+        Pointer p = outRef.getValue();
+        //        outRef.getValue().getPointerArray(0, size);
+
+        REFS.recycle(outRef);
+        return p;
+    }
+
+    public static Pointer loadNdArrayFromByteBuffer(ByteBuffer byteBuffer) {
+        //        Pointer handle = new Pointer(byteBuffer.address);
+        //        ((DirectByteBuffer) byteBuffer).address()
+        // TODO
+        byte[] bytes = new byte[byteBuffer.remaining()];
+        byteBuffer.get(bytes);
+        return loadNdArrayFromByteArray(bytes, 0, bytes.length);
+    }
+
+    public static ByteBuffer saveNdArrayAsByteBuffer(Pointer ndArray) {
+        NativeSizeByReference size = new NativeSizeByReference();
+        PointerByReference ref = new PointerByReference();
+        checkCall(LIB.MXNDArraySaveRawBytes(ndArray, size, ref));
+        return ref.getValue().getByteBuffer(0, size.getValue().longValue());
+    }
+
+    public static byte[] saveNdArrayAsByteArray(Pointer ndArray) {
+        ByteBuffer buffer = saveNdArrayAsByteBuffer(ndArray);
+        byte[] bytes = new byte[buffer.limit()];
+        buffer.get(bytes);
+        return bytes;
+    }
+
+    public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
+        NativeSize size = new NativeSize(len);
+        checkNDArray(ndArray, "copy from");
+        checkNDArray(data, "copy to");
+        checkCall(LIB.MXNDArraySyncCopyToCPU(ndArray, data, size));
+    }
+
+    public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
+        NativeSize size = new NativeSize(len);
+        Pointer pointer = Native.getDirectBufferPointer(data);
+        checkCall(LIB.MXNDArraySyncCopyFromCPU(ndArray, pointer, size));
+    }
+
+    public static void waitToRead(Pointer ndArray) {
+        checkNDArray(ndArray, "wait to read");
+        checkCall(LIB.MXNDArrayWaitToRead(ndArray));
+    }
+
+    public static void waitToWrite(Pointer ndArray) {
+        checkNDArray(ndArray, "wait to write");
+        checkCall(LIB.MXNDArrayWaitToWrite(ndArray));
+    }
+
+    public static Pointer detachGradient(Pointer handle) {
+        PointerByReference ref = REFS.acquire();
+        checkCall(LIB.MXNDArrayDetach(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static Pointer getGradient(Pointer handle) {
+        PointerByReference ref = REFS.acquire();
+        checkNDArray(handle, "get the gradient for");
+        checkCall(LIB.MXNDArrayGetGrad(handle, ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static void autogradMarkVariables(
+            int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
+        PointerByReference varRef = REFS.acquire();
+        PointerByReference gradRef = REFS.acquire();
+        varRef.setValue(varHandles);
+        gradRef.setValue(gradHandles);
+        checkCall(LIB.MXAutogradMarkVariables(numVar, varRef, reqsArray, gradRef));
+        REFS.recycle(varRef);
+        REFS.recycle(gradRef);
+    }
+
+    public static Map<String, FunctionInfo> getNdArrayFunctions() {
+        Set<String> opNames = JnaUtils.getAllOpNames();
+        Map<String, FunctionInfo> map = new ConcurrentHashMap<>();
+
+        PointerByReference ref = REFS.acquire();
+        for (String opName : opNames) {
+            checkCall(LIB.NNGetOpHandle(opName, ref));
+            String functionName = getOpNamePrefix(opName);
+            map.put(functionName, getFunctionByName(opName, functionName, ref.getValue()));
+            ref.setValue(null);
+        }
+        REFS.recycle(ref);
+        return map;
+    }
+
+    public static PairList<Pointer, SparseFormat> imperativeInvoke(
+            Pointer function, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        String[] keys;
+        String[] values;
+        if (params == null) {
+            keys = EMPTY_ARRAY;
+            values = EMPTY_ARRAY;
+        } else {
+            keys = params.keyArray(EMPTY_ARRAY);
+            values = params.values().stream().map(Object::toString).toArray(String[]::new);
+        }
+        //        StringArray keyArray = StringArray.of(keys);
+        //        StringArray valueArray = StringArray.of(values);
+        PointerArray srcArray = toPointerArray(src);
+        PointerArray destArray = toPointerArray(dest);
+        PointerByReference destRef = REFS.acquire();
+        destRef.setValue(destArray);
+        PointerByReference destSType = REFS.acquire();
+        IntBuffer numOutputs = IntBuffer.allocate(1);
+        numOutputs.put(0, 1);
+
+        checkCall(
+                LIB.MXImperativeInvoke(
+                        function,
+                        src.length,
+                        srcArray,
+                        numOutputs,
+                        destRef,
+                        keys.length,
+                        keys,
+                        values,
+                        destSType));
+        int numOfOutputs = numOutputs.get(0);
+        Pointer[] ptrArray = destRef.getValue().getPointerArray(0, numOfOutputs);
+        int[] sTypes = destSType.getValue().getIntArray(0, numOfOutputs);
+        PairList<Pointer, SparseFormat> pairList = new PairList<>();
+        for (int i = 0; i < numOfOutputs; i++) {
+            pairList.add(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
+        }
+        REFS.recycle(destRef);
+        REFS.recycle(destSType);
+        srcArray.recycle();
+        //        keyArray.recycle();
+        //        valueArray.recycle();
+
+        if (destArray != null) {
+            destArray.recycle();
+        }
+        return pairList;
+    }
+
+    private static PointerArray toPointerArray(NDArray[] vals) {
+        if (vals == null) {
+            return null;
+        }
+        Pointer[] valPointers = new Pointer[vals.length];
+        for (int i = 0; i < vals.length; i++) {
+            valPointers[i] = vals[i].getHandle();
+        }
+        return PointerArray.of(valPointers);
+    }
+
+    public static FunctionInfo op(String opName) {
+        if (!OPS.containsKey(opName)) {
+            throw new IllegalArgumentException("Unknown operator: " + opName);
+        }
+        return OPS.get(opName);
+    }
+
+    public static FunctionInfo getFunctionByName(String name, String functionName, Pointer handle) {
+        String[] nameRef = {name};
+        String[] description = new String[1];
+        IntBuffer numArgs = IntBuffer.allocate(1);
+        PointerByReference argNameRef = REFS.acquire();
+        PointerByReference argTypeRef = REFS.acquire();
+        PointerByReference argDescRef = REFS.acquire();
+        String[] keyVarArgs = new String[1];
+        String[] returnType = new String[1];
+
+        checkCall(
+                LIB.MXSymbolGetAtomicSymbolInfo(
+                        handle,
+                        nameRef,
+                        description,
+                        numArgs,
+                        argNameRef,
+                        argTypeRef,
+                        argDescRef,
+                        keyVarArgs,
+                        returnType));
+
+        int count = numArgs.get();
+        PairList<String, String> arguments = new PairList<>();
+        if (count != 0) {
+            String[] argNames =
+                    argNameRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
+            String[] argTypes =
+                    argTypeRef.getValue().getStringArray(0, count, StandardCharsets.UTF_8.name());
+            for (int i = 0; i < argNames.length; i++) {
+                arguments.add(argNames[i], argTypes[i]);
+            }
+        }
+
+        REFS.recycle(argNameRef);
+        REFS.recycle(argTypeRef);
+        REFS.recycle(argDescRef);
+
+        return new FunctionInfo(handle, functionName, arguments);
+    }
+
+    public static Set<String> getAllOpNames() {
+        IntBuffer outSize = IntBuffer.allocate(1);
+        PointerByReference outArray = REFS.acquire();
+
+        checkCall(LIB.MXListAllOpNames(outSize, outArray));
+
+        int size = outSize.get();
+        Pointer[] pointers = outArray.getValue().getPointerArray(0, size);
+
+        Set<String> set = new HashSet<>();
+        for (Pointer p : pointers) {
+            set.add(p.getString(0, StandardCharsets.UTF_8.name()));
+        }
+        REFS.recycle(outArray);
+        return set;
+    }
+
+    public static String getOpNamePrefix(String name) {
+        for (String prefix : OP_NAME_PREFIX) {
+            if (name.startsWith(prefix)) {
+                return name.substring(prefix.length());
+            }
+        }
+        return name;
+    }
+
+    public static DataType getDataTypeOfNdArray(Pointer handle) {
+        IntBuffer dataType = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the data type of");
+        checkCall(LIB.MXNDArrayGetDType(handle, dataType));
+        return DataType.values()[dataType.get()];
+    }
+
+    public static Device getDeviceOfNdArray(Pointer handle) {
+        IntBuffer deviceType = IntBuffer.allocate(1);
+        IntBuffer deviceId = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the device of");
+        checkCall(LIB.MXNDArrayGetContext(handle, deviceType, deviceId));
+        String deviceTypeStr = DeviceType.fromDeviceType(deviceType.get(0));
+        // CPU is a special case that has no device id
+        return Device.of(deviceTypeStr, deviceId.get(0));
+    }
+
+    public static Shape getShapeOfNdArray(Pointer handle) {
+        IntBuffer dim = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+        checkNDArray(handle, "get the shape of");
+        checkCall(LIB.MXNDArrayGetShape(handle, dim, ref));
+        int nDim = dim.get();
+        if (nDim == 0) {
+            REFS.recycle(ref);
+            return new Shape();
+        }
+        int[] shape = ref.getValue().getIntArray(0, nDim);
+        REFS.recycle(ref);
+        return new Shape(Arrays.stream(shape).asLongStream().toArray());
+    }
+
+    public static Shape getShape64OfNdArray(Pointer handle) {
+        IntBuffer dim = IntBuffer.allocate(1);
+        PointerByReference ref = REFS.acquire();
+        checkNDArray(handle, "get the shape64 of");
+        checkCall(LIB.MXNDArrayGetShape64(handle, dim, ref));
+        int nDim = dim.get();
+        if (nDim == 0) {
+            REFS.recycle(ref);
+            return new Shape();
+        }
+        long[] shape = ref.getValue().getLongArray(0, nDim);
+        REFS.recycle(ref);
+        return new Shape(shape);
+    }
+
+    public static SparseFormat getStorageType(Pointer handle) {
+        IntBuffer type = IntBuffer.allocate(1);
+        checkNDArray(handle, "get the storage type of");
+        checkCall(LIB.MXNDArrayGetStorageType(handle, type));
+        return SparseFormat.fromValue(type.get());
+    }
+
+    public static Pointer createNdArray(
+            Device device, Shape shape, DataType dataType, int size, boolean delayedAlloc) {
+        int deviceType = DeviceType.toDeviceType(device);
+        int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+        int delay = delayedAlloc ? 1 : 0;
+
+        PointerByReference ref = REFS.acquire();
+        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+        checkCall(
+                LIB.MXNDArrayCreate(
+                        shapeArray, size, deviceType, deviceId, delay, dataType.ordinal(), ref));
+
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static Pointer createSparseNdArray(
+            SparseFormat fmt,
+            Device device,
+            Shape shape,
+            DataType dtype,
+            DataType[] auxDTypes,
+            Shape[] auxShapes,
+            boolean delayedAlloc) {
+        int[] shapeArray = Arrays.stream(shape.getShape()).mapToInt(Math::toIntExact).toArray();
+        int deviceType = DeviceType.toDeviceType(device);
+        int deviceId = (deviceType != 1) ? device.getDeviceId() : -1;
+        int delay = delayedAlloc ? 1 : 0;
+        PointerByReference ref = REFS.acquire();
+        IntBuffer auxDTypesInt =
+                IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(DataType::ordinal).toArray());
+        IntBuffer auxNDims =
+                IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
+        int[] auxShapesInt = Arrays.stream(auxShapes).mapToInt(ele -> (int) ele.head()).toArray();
+        checkCall(
+                LIB.MXNDArrayCreateSparseEx(
+                        fmt.getValue(),
+                        shapeArray,
+                        shapeArray.length,
+                        deviceType,
+                        deviceId,
+                        delay,
+                        dtype.ordinal(),
+                        auxDTypes.length,
+                        auxDTypesInt,
+                        auxNDims,
+                        auxShapesInt,
+                        ref));
+        Pointer pointer = ref.getValue();
+        REFS.recycle(ref);
+        return pointer;
+    }
+
+    public static void ndArraySyncCopyFromNdArray(NDArray dest, NDArray src, int location) {
+        checkCall(LIB.MXNDArraySyncCopyFromNDArray(dest.getHandle(), src.getHandle(), location));
+    }
+
+    public static int getVersion() {
+        IntBuffer version = IntBuffer.allocate(1);
+        checkCall(LIB.MXGetVersion(version));
+        return version.get();
+    }
+
+    public static NDArray[] cachedOpInvoke(
+            MxResource parent, Pointer cachedOpHandle, NDArray[] inputs) {
+        IntBuffer buf = IntBuffer.allocate(1);
+        PointerArray array = toPointerArray(inputs);
+        PointerByReference ref = REFS.acquire();
+        PointerByReference outSTypeRef = REFS.acquire();
+        Device device = inputs[0].getDevice();
+        // TODO: check the init value of default_dev_type and default_dev_id
+        checkCall(
+                LIB.MXInvokeCachedOp(
+                        cachedOpHandle,
+                        inputs.length,
+                        array,
+                        DeviceType.toDeviceType(device),
+                        0,
+                        buf,
+                        ref,
+                        outSTypeRef));
+        int numOutputs = buf.get();
+        Pointer[] ptrArray = ref.getValue().getPointerArray(0, numOutputs);
+        int[] sTypes = outSTypeRef.getValue().getIntArray(0, numOutputs);
+        NDArray[] output = new NDArray[numOutputs];
+        for (int i = 0; i < numOutputs; i++) {
+            if (sTypes[i] != 0) {
+                output[i] = NDArray.create(parent, ptrArray[i], SparseFormat.fromValue(sTypes[i]));
+            } else {
+                output[i] = NDArray.create(parent, ptrArray[i]);
+            }
+        }
+        REFS.recycle(ref);
+        REFS.recycle(outSTypeRef);
+        array.recycle();
+        return output;
+    }
+
+    private static void checkNDArray(Pointer pointer, String msg) {
+        if (pointer == null) {
+            throw new IllegalArgumentException(
+                    "Tried to " + msg + " an MXNet NDArray that was already closed");
+        }
+    }
+
+    public static void checkCall(int ret) {
+        if (ret != 0) {
+            logger.error("MXNet engine call failed: " + getLastError());
+            throw new JnaCallException("MXNet engine call failed: " + getLastError());
+        }
+    }
+
+    private static String getLastError() {
+        return LIB.MXGetLastError();
+    }
+
+    private static String[] toStringArray(PointerByReference ref, int size) {
+        if (size == 0) {
+            return new String[0];
+        }
+
+        Pointer[] pointers = ref.getValue().getPointerArray(0, size);
+
+        String[] arr = new String[size];
+        for (int i = 0; i < size; ++i) {
+            arr[i] = pointers[i].getString(0, StandardCharsets.UTF_8.name());
+        }
+
+        return arr;
+    }
+
+    private static Set<String> getFeaturesInternal() {
+        PointerByReference ref = REFS.acquire();
+        NativeSizeByReference outSize = new NativeSizeByReference();
+        checkCall(LIB.MXLibInfoFeatures(ref, outSize));
+        int size = outSize.getValue().intValue();
+        if (size == 0) {
+            REFS.recycle(ref);
+            return Collections.emptySet();
+        }
+
+        LibFeature pointer = new LibFeature(ref.getValue());
+        pointer.read();
+
+        LibFeature[] features = (LibFeature[]) pointer.toArray(size);
+
+        Set<String> set = new HashSet<>();
+        for (LibFeature feature : features) {
+            if (feature.getEnabled() == 1) {
+                set.add(feature.getName());
+            }
+        }
+        REFS.recycle(ref);
+        return set;
+    }
+
+    public static Set<String> getFeatures() {
+        return FEATURES;
+    }
+
+    public static boolean autogradIsTraining() {
+        ByteBuffer isTraining = ByteBuffer.allocate(1);
+        checkCall(LIB.MXAutogradIsTraining(isTraining));
+        return isTraining.get(0) == 1;
+    }
+}
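Note on the acquire/recycle pattern in JnaUtils above: most methods take `PointerByReference` objects from `REFS` and recycle them after the native call, but only on the success path, so a failing `checkCall` skips the recycle step. The following is a minimal sketch of an exception-safe variant. `SimplePool`, `with`, and `idleCount` are hypothetical names made up for illustration; this is not the `ObjectPool` class from this change, and whether a skipped recycle matters depends on that class's semantics.

```java
import java.util.ArrayDeque;
import java.util.function.Function;
import java.util.function.Supplier;

/** Hypothetical stand-in for the REFS pool; NOT the ObjectPool class from this change. */
final class SimplePool<T> {
    private final ArrayDeque<T> idle = new ArrayDeque<>();
    private final Supplier<T> factory;

    SimplePool(Supplier<T> factory) {
        this.factory = factory;
    }

    /** Returns an idle item if one exists, otherwise creates a new one. */
    synchronized T acquire() {
        T item = idle.poll();
        return item != null ? item : factory.get();
    }

    /** Puts an item back so that a later acquire() can reuse it. */
    synchronized void recycle(T item) {
        idle.push(item);
    }

    /** Number of items currently waiting for reuse. */
    synchronized int idleCount() {
        return idle.size();
    }

    /** Runs body with a pooled item and recycles the item even if body throws. */
    <R> R with(Function<T, R> body) {
        T item = acquire();
        try {
            return body.apply(item);
        } finally {
            recycle(item);
        }
    }
}

final class SimplePoolDemo {
    public static void main(String[] args) {
        // long[1] plays the role of a by-reference out parameter
        SimplePool<long[]> pool = new SimplePool<>(() -> new long[1]);
        long value = pool.with(ref -> {
            ref[0] = 42L; // a native call would fill this in
            return ref[0];
        });
        System.out.println(value + " " + pool.idleCount());
    }
}
```

The design choice here is to copy the result out inside the callback and let `finally` return the holder, so no call site has to repeat the recycle step on every exit path.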
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
new file mode 100644
index 0000000..d110398
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/LibUtils.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Library;
+import com.sun.jna.Native;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.mxnet.util.Platform;
+import org.apache.mxnet.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Utilities for finding the MXNet Engine binary on the system.
+ *
+ * <p>The Engine will be searched for in a variety of locations in the following order:
+ *
+ * <ol>
+ *   <li>In the path specified by the MXNET_LIBRARY_PATH environment variable
+ *   <li>In the path specified by the java.library.path system property
+ *   <li>In a jar file location in the classpath. These jars can be created with the mxnet-native
+ *       module.
+ *   <li>Under the default library name "mxnet", which is left to JNA to resolve
+ * </ol>
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public final class LibUtils {
+
+    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);
+
+    private static final String LIB_NAME = "mxnet";
+
+    private static final String MXNET_LIBRARY_PATH = "MXNET_LIBRARY_PATH";
+
+    private static final String MXNET_PROPERTIES_FILE_PATH = "native/lib/mxnet.properties";
+
+    private LibUtils() {}
+
+    public static MxnetLibrary loadLibrary() {
+
+        String libName = getLibName();
+        logger.debug("Loading mxnet library from: {}", libName);
+        if (System.getProperty("os.name").startsWith("Linux")) {
+            logger.info("Loading on Linux platform");
+            Map<String, Integer> options = new ConcurrentHashMap<>();
+            int rtld = 1; // Linux RTLD lazy + local
+            options.put(Library.OPTION_OPEN_FLAGS, rtld);
+            return Native.load(libName, MxnetLibrary.class, options);
+        }
+        return Native.load(libName, MxnetLibrary.class);
+    }
+
+    public static String getLibName() {
+        String libName = findOverrideLibrary();
+        if (libName == null) {
+            libName = LibUtils.findLibraryInClasspath();
+            if (libName == null) {
+                libName = LIB_NAME;
+            }
+        }
+
+        return libName;
+    }
+
+    private static String findOverrideLibrary() {
+        // TODO: load from jar files
+        String libPath = System.getenv(MXNET_LIBRARY_PATH);
+        if (libPath != null) {
+            String libName = findLibraryInPath(libPath);
+            if (libName != null) {
+                return libName;
+            }
+        }
+
+        libPath = System.getProperty("java.library.path");
+        if (libPath != null) {
+            return findLibraryInPath(libPath);
+        }
+        return null;
+    }
+
+    private static synchronized String findLibraryInClasspath() {
+        Enumeration<URL> urls = getUrls();
+        // No native jars
+        if (!urls.hasMoreElements()) {
+            logger.debug("mxnet.properties not found in class path.");
+            return null;
+        }
+
+        // Find the mxnet library version that matches the local system platform;
+        // throw an exception if none matches
+        Platform systemPlatform = Platform.fromSystem();
+        try {
+            while (urls.hasMoreElements()) {
+                URL url = urls.nextElement();
+                Platform platform = Platform.fromUrl(url);
+                if (!platform.isPlaceholder() && platform.matches(systemPlatform)) {
+                    return loadLibraryFromClasspath(platform);
+                }
+            }
+        } catch (IOException e) {
+            throw new IllegalStateException(
+                    "Failed to read MXNet native library jar properties", e);
+        }
+
+        throw new IllegalStateException(
+                "Your MXNet native library jar does not match your operating system. Make sure that the Maven Dependency Classifier matches your system type.");
+    }
+
+    private static Enumeration<URL> getUrls() {
+        try {
+            return Thread.currentThread()
+                    .getContextClassLoader()
+                    .getResources(MXNET_PROPERTIES_FILE_PATH);
+        } catch (IOException e) {
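+            logger.warn("IOException while looking for {}", MXNET_PROPERTIES_FILE_PATH, e);
+            // Return an empty enumeration rather than null: findLibraryInClasspath calls
+            // hasMoreElements() on the result without a null check, and an empty result
+            // takes its "no native jars" branch.
+            return Collections.emptyEnumeration();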
+        }
+    }
+
+    private static String loadLibraryFromClasspath(Platform platform) {
+        Path tmp = null;
+        try {
+            String libName = System.mapLibraryName(LIB_NAME);
+            Path cacheFolder = Utils.getEngineCacheDir(LIB_NAME);
+            logger.debug("Using cache dir: {}", cacheFolder);
+
+            Path dir = cacheFolder.resolve(platform.getVersion() + platform.getClassifier());
+            Path path = dir.resolve(libName);
+            if (Files.exists(path)) {
+                return path.toAbsolutePath().toString();
+            }
+            Files.createDirectories(cacheFolder);
+            tmp = Files.createTempDirectory(cacheFolder, "tmp");
+            for (String file : platform.getLibraries()) {
+                String libPath = "/native/lib/" + file;
+                try (InputStream is = LibUtils.class.getResourceAsStream(libPath)) {
+                    logger.info("Extracting {} to cache ...", file);
+                    Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
+                }
+            }
+
+            Utils.moveQuietly(tmp, dir);
+            return path.toAbsolutePath().toString();
+        } catch (IOException e) {
+            throw new IllegalStateException("Failed to extract MXNet native library", e);
+        } finally {
+            if (tmp != null) {
+                Utils.deleteQuietly(tmp);
+            }
+        }
+    }
+
+    private static String findLibraryInPath(String libPath) {
+        String[] paths = libPath.split(File.pathSeparator);
+        List<String> mappedLibNames;
+        if (com.sun.jna.Platform.isMac()) {
+            mappedLibNames = Arrays.asList("libmxnet.dylib", "libmxnet.jnilib", "libmxnet.so");
+        } else {
+            mappedLibNames = Collections.singletonList(System.mapLibraryName(LIB_NAME));
+        }
+
+        for (String path : paths) {
+            File p = new File(path);
+            if (!p.exists()) {
+                continue;
+            }
+            for (String name : mappedLibNames) {
+                if (p.isFile() && p.getName().endsWith(name)) {
+                    return p.getAbsolutePath();
+                }
+
+                File file = new File(path, name);
+                if (file.exists() && file.isFile()) {
+                    return file.getAbsolutePath();
+                }
+            }
+        }
+        return null;
+    }
+}
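Note on `findLibraryInPath` above: each entry of the search path may be either the library file itself or a directory that contains it. The sketch below restates that lookup using only the JDK so the two cases can be tried in isolation. `LibraryPathSearch` and `find` are hypothetical names for illustration, not part of this change.

```java
import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Illustrative restatement of the path lookup; class and method names are hypothetical. */
final class LibraryPathSearch {
    private LibraryPathSearch() {}

    /** Returns the absolute path of the first match, or empty if no entry matches. */
    static Optional<String> find(String searchPath, List<String> candidateNames) {
        for (String entry : searchPath.split(File.pathSeparator)) {
            File location = new File(entry);
            if (!location.exists()) {
                continue; // skip entries that do not exist on disk
            }
            for (String name : candidateNames) {
                // Case 1: the entry is the library file itself
                if (location.isFile() && location.getName().endsWith(name)) {
                    return Optional.of(location.getAbsolutePath());
                }
                // Case 2: the entry is a directory that contains the library
                File candidate = new File(location, name);
                if (candidate.isFile()) {
                    return Optional.of(candidate.getAbsolutePath());
                }
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        String searchPath = System.getProperty("java.library.path", "");
        List<String> names = Collections.singletonList(System.mapLibraryName("mxnet"));
        System.out.println(find(searchPath, names).orElse("not found"));
    }
}
```

Returning `Optional` instead of `null` is a choice made for this sketch only; the method in the change returns `null` when nothing matches.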
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
new file mode 100644
index 0000000..d43e475
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/NativeString.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+
+/**
+ * Provides a temporary allocation of an immutable C string (<code>const char*</code> or <code>
+ * const wchar_t*</code>) for use when converting a Java String into a native memory function
+ * argument.
+ */
+final class NativeString {
+
+    private static final ObjectPool<NativeString> POOL = new ObjectPool<>(null, null);
+
+    private Memory pointer;
+
+    /**
+     * Creates a native string (NUL-terminated array of <code>char</code>) from bytes that are
+     * already encoded.
+     *
+     * @param data the bytes of the string
+     */
+    private NativeString(byte[] data) {
+        pointer = new Memory(data.length + 1);
+        setData(data);
+    }
+
+    private void setData(byte[] data) {
+        pointer.write(0, data, 0, data.length);
+        pointer.setByte(data.length, (byte) 0);
+    }
+
+    /**
+     * Acquires a pooled {@code NativeString} object if available, otherwise a new instance is
+     * created.
+     *
+     * @param string the string value
+     * @param encoding the charset encoding
+     * @return a {@code NativeString} object
+     */
+    public static NativeString of(String string, Charset encoding) {
+        byte[] data = string.getBytes(encoding);
+        NativeString array = POOL.acquire();
+        if (array != null && array.pointer.size() > data.length) {
+            array.setData(data);
+            return array;
+        }
+        return new NativeString(data);
+    }
+
+    /** Recycles this instance and returns it to the pool. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    /**
+     * Returns the peer pointer.
+     *
+     * @return the peer pointer
+     */
+    public Pointer getPointer() {
+        return pointer;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
new file mode 100644
index 0000000..573acd4
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/ObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+/**
+ * A generic object pool implementation.
+ *
+ * @param <T> the type of object to put in the pool
+ */
+@SuppressWarnings("MissingJavadocMethod")
+public class ObjectPool<T> {
+
+    private Queue<T> queue;
+    private Supplier<T> supplier;
+    private Consumer<T> consumer;
+
+    public ObjectPool(Supplier<T> supplier, Consumer<T> consumer) {
+        queue = new ConcurrentLinkedQueue<>();
+        this.supplier = supplier;
+        this.consumer = consumer;
+    }
+
+    public T acquire() {
+        T item = queue.poll();
+        if (item == null) {
+            if (supplier != null) {
+                return supplier.get();
+            }
+        }
+        return item;
+    }
+
+    public void recycle(T item) {
+        if (consumer != null) {
+            consumer.accept(item);
+        }
+        queue.add(item);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
new file mode 100644
index 0000000..a864b64
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/PointerArray.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Function;
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+
+/**
+ * An abstraction for a native pointer array data type ({@code void**}).
+ *
+ * @see Pointer
+ * @see com.sun.jna.ptr.PointerByReference
+ * @see Function
+ */
+@SuppressWarnings("checkstyle:EqualsHashCode")
+final class PointerArray extends Memory {
+
+    private static final ObjectPool<PointerArray> POOL = new ObjectPool<>(null, null);
+
+    private int length;
+
+    /**
+     * Constructs a {@link Memory} buffer PointerArray given the Pointers to include in it.
+     *
+     * @param arg the pointers to include in the array
+     */
+    private PointerArray(Pointer... arg) {
+        super(Native.POINTER_SIZE * (arg.length + 1));
+        length = arg.length;
+        setPointers(arg);
+    }
+
+    /**
+     * Acquires a pooled {@code PointerArray} object if available, otherwise a new instance is
+     * created.
+     *
+     * @param arg the pointers to include in the array
+     * @return a {@code PointerArray} object
+     */
+    public static PointerArray of(Pointer... arg) {
+        PointerArray array = POOL.acquire();
+        if (array != null && array.length >= arg.length) {
+            array.setPointers(arg);
+            return array;
+        }
+        return new PointerArray(arg);
+    }
+
+    /** Recycles this instance and returns it to the pool. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    private void setPointers(Pointer[] pointers) {
+        for (int i = 0; i < pointers.length; i++) {
+            setPointer(i * Native.POINTER_SIZE, pointers[i]);
+        }
+        setPointer(Native.POINTER_SIZE * pointers.length, null);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        return o == this;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
new file mode 100644
index 0000000..00ffa8f
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/StringArray.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.jna;
+
+import com.sun.jna.Memory;
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+
+/** An abstraction for a native string array data type ({@code char**}). */
+@SuppressWarnings("checkstyle:EqualsHashCode")
+final class StringArray extends Memory {
+
+    private static final Charset ENCODING = Native.DEFAULT_CHARSET;
+    private static final ObjectPool<StringArray> POOL = new ObjectPool<>(null, null);
+    /** Holds all {@code NativeString} instances so they are not garbage-collected. */
+    private List<NativeString> natives; // NOPMD
+
+    private int length;
+
+    /**
+     * Create a native array of strings.
+     *
+     * @param strings the strings
+     */
+    private StringArray(String[] strings) {
+        super((strings.length + 1) * Native.POINTER_SIZE);
+        natives = new ArrayList<>();
+        length = strings.length;
+        setPointers(strings);
+    }
+
+    private void setPointers(String[] strings) {
+        for (NativeString ns : natives) {
+            ns.recycle();
+        }
+        natives.clear();
+        for (int i = 0; i < strings.length; ++i) {
+            Pointer p = null;
+            if (strings[i] != null) {
+                NativeString ns = NativeString.of(strings[i], ENCODING);
+                natives.add(ns);
+                p = ns.getPointer();
+            }
+            setPointer(Native.POINTER_SIZE * i, p);
+        }
+        setPointer(Native.POINTER_SIZE * strings.length, null);
+    }
+
+    /**
+     * Acquires a pooled {@code StringArray} object if available, otherwise a new instance is
+     * created.
+     *
+     * @param strings the strings to include in the array
+     * @return a {@code StringArray} object
+     */
+    public static StringArray of(String[] strings) {
+        StringArray array = POOL.acquire();
+        if (array != null && array.length >= strings.length) {
+            array.setPointers(strings);
+            return array;
+        }
+        return new StringArray(strings);
+    }
+
+    /** Recycles this instance and returns it to the pool. */
+    public void recycle() {
+        POOL.recycle(this);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public boolean equals(Object o) {
+        return this == o;
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
new file mode 100644
index 0000000..d524d50
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/jna/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Contains the JNA bindings and helper classes for the Apache MXNet native library. */
+package org.apache.mxnet.jna;
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
new file mode 100644
index 0000000..f775f44
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArray.java
@@ -0,0 +1,3455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import com.sun.jna.Native;
+import com.sun.jna.Pointer;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import java.util.Arrays;
+import java.util.stream.IntStream;
+import org.apache.mxnet.engine.BaseMxResource;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.MxResource;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+import org.apache.mxnet.util.Float16Utils;
+import org.apache.mxnet.util.PairList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class representing an n-dimensional array.
+ *
+ * <p>NDArray is the core data structure for all mathematical computations. An NDArray represents a
+ * multidimensional, fixed-size homogeneous array. It behaves much like a NumPy array, with the
+ * addition of efficient computing.
+ */
+public class NDArray extends MxResource {
+
+    private static final Logger logger = LoggerFactory.getLogger(NDArray.class);
+
+    private static final int MAX_SIZE = 100;
+    private static final int MAX_DEPTH = 10;
+    private static final int MAX_ROWS = 10;
+    private static final int MAX_COLUMNS = 20;
+    private static final NDArray[] EMPTY = new NDArray[0];
+
+    private String name;
+    private Device device;
+    private SparseFormat sparseFormat;
+    private DataType dataType;
+    private Shape shape;
+    // use a Boolean object to represent three states: false, true,
+    // and null, which means the flag has not been set by the native engine yet
+    private Boolean hasGradient;
+    private Integer version;
+    private NDArrayEx mxNDArrayEx;
+
+    protected NDArray(Pointer handle) {
+        super(BaseMxResource.getSystemMxResource(), handle);
+    }
+
+    /**
+     * Constructs an {@link NDArray} from a native handle and metadata (internal; use the {@code
+     * create} methods instead).
+     *
+     * @param parent the parent {@link MxResource} to manage the lifecycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     * @param device the device the new array will be located on
+     * @param shape the shape of the new array
+     * @param dataType the dataType of the new array
+     * @param hasGradient the gradient status of the new array
+     */
+    NDArray(
+            MxResource parent,
+            Pointer handle,
+            Device device,
+            Shape shape,
+            DataType dataType,
+            boolean hasGradient) {
+        this(parent, handle);
+        setParent(parent);
+        this.device = device;
+        // shape check
+        if (Arrays.stream(shape.getShape()).anyMatch(s -> s < 0)) {
+            throw new IllegalArgumentException("The shape must be >= 0");
+        }
+        this.shape = shape;
+        this.dataType = dataType;
+        this.hasGradient = hasGradient;
+        if (parent != null) {
+            parent.addSubResource(this);
+        }
+    }
+
+    /**
+     * Constructs an {@link NDArray} from a native handle and metadata (internal).
+     *
+     * @param parent the parent {@link MxResource} to manage the lifecycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     */
+    NDArray(MxResource parent, Pointer handle) {
+        super(parent, handle);
+        this.mxNDArrayEx = new NDArrayEx(this);
+    }
+
+    /**
+     * Constructs an {@link NDArray} from a native handle and metadata (internal).
+     *
+     * @param parent the parent {@link MxResource} to manage the lifecycle of the {@link NDArray}
+     * @param handle the pointer to the native NDArray memory
+     * @param fmt the sparse format
+     */
+    NDArray(MxResource parent, Pointer handle, SparseFormat fmt) {
+        this(parent, handle);
+        this.sparseFormat = fmt;
+    }
+
+    /**
+     * Creates an NDArray with the given Native Memory Pointer and parent MxResource.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param handle the array's native memory pointer
+     * @return the created array
+     */
+    public static NDArray create(MxResource parent, Pointer handle) {
+        return new NDArray(parent, handle);
+    }
+
+    /**
+     * Creates an NDArray with the given Native Memory Pointer and parent MxResource.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param handle the array's native memory pointer
+     * @param fmt the sparse format
+     * @return the created array
+     */
+    public static NDArray create(MxResource parent, Pointer handle, SparseFormat fmt) {
+        return new NDArray(parent, handle, fmt);
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+     * parent {@link MxResource}, {@link Shape} and {@link Device}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, Shape shape, Device device) {
+        return create(parent, shape, DataType.FLOAT32, device);
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link DataType#FLOAT32} {@link NDArray} with specified
+     * parent {@link MxResource} and {@link Shape}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, Shape shape) {
+        return create(parent, shape, DataType.FLOAT32, Device.defaultIfNull());
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+     * MxResource}, {@link Shape}, {@link DataType}, {@link Device} and {@code hasGradient}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @param hasGradient true if the gradient calculation is required for this {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(
+            MxResource parent, Shape shape, DataType dataType, Device device, boolean hasGradient) {
+        Pointer handle =
+                JnaUtils.createNdArray(device, shape, dataType, shape.dimension(), hasGradient);
+        return new NDArray(parent, handle, device, shape, dataType, hasGradient);
+    }
+
+    /**
+     * Creates an uninitialized instance of {@link NDArray} with specified parent {@link
+     * MxResource}, {@link Shape}, {@link DataType}, {@link Device}.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, Shape shape, DataType dataType, Device device) {
+        Pointer handle = JnaUtils.createNdArray(device, shape, dataType, shape.dimension(), false);
+        return new NDArray(parent, handle, Device.defaultIfNull(device), shape, dataType, false);
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the {@link Number} that needs to be set
+     * @return a new instance of {@link NDArray}
+     * @throws IllegalArgumentException if the type of {@code data} is not supported
+     */
+    public static NDArray create(MxResource parent, Number data) {
+        if (data instanceof Integer) {
+            return create(parent, data.intValue());
+        } else if (data instanceof Float) {
+            return create(parent, data.floatValue());
+        } else if (data instanceof Double) {
+            return create(parent, data.doubleValue());
+        } else if (data instanceof Long) {
+            return create(parent, data.longValue());
+        } else if (data instanceof Byte) {
+            return create(parent, data.byteValue());
+        } else {
+            throw new IllegalArgumentException("Short conversion not supported!");
+        }
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and float
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float[] data, Shape shape) {
+        return create(parent, FloatBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and int
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int[] data, Shape shape) {
+        return create(parent, IntBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and
+     * double array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[] data, Shape shape) {
+        return create(parent, DoubleBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and long
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[] data, Shape shape) {
+        return create(parent, LongBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and byte
+     * array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[] data, Shape shape) {
+        return create(parent, ByteBuffer.wrap(data), shape);
+    }
+
+    /**
+     * Creates and initializes an instance of {@link NDArray} with specified {@link Shape} and
+     * boolean array.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the boolean array that needs to be set
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[] data, Shape shape) {
+        byte[] byteData = new byte[data.length];
+        for (int i = 0; i < data.length; i++) {
+            byteData[i] = (byte) (data[i] ? 1 : 0);
+        }
+        return create(parent, ByteBuffer.wrap(byteData), shape, DataType.BOOLEAN);
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float data) {
+        return create(parent, new float[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int data) {
+        return create(parent, new int[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double data) {
+        return create(parent, new double[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long data) {
+        return create(parent, new long[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte data) {
+        return create(parent, new byte[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a scalar {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the boolean data that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean data) {
+
+        return create(parent, new boolean[] {data}, new Shape());
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the byte array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 1D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the boolean array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[] data) {
+        return create(parent, data, new Shape(data.length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the float array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, float[][] data) {
+        FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length);
+        for (float[] d : data) {
+            buffer.put(d);
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the int array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, int[][] data) {
+        IntBuffer buffer = IntBuffer.allocate(data.length * data[0].length);
+        for (int[] d : data) {
+            buffer.put(d);
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the double array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, double[][] data) {
+        DoubleBuffer buffer = DoubleBuffer.allocate(data.length * data[0].length);
+        for (double[] d : data) {
+            buffer.put(d);
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param data the long array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, long[][] data) {
+        LongBuffer buffer = LongBuffer.allocate(data.length * data[0].length);
+        for (long[] d : data) {
+            buffer.put(d);
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the byte array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, byte[][] data) {
+        ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+        for (byte[] d : data) {
+            buffer.put(d);
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length));
+    }
+
+    /**
+     * Creates and initializes a 2-D {@link NDArray}.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the boolean array that needs to be set
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray create(MxResource parent, boolean[][] data) {
+        ByteBuffer buffer = ByteBuffer.allocate(data.length * data[0].length);
+        for (boolean[] d : data) {
+            for (boolean b : d) {
+                buffer.put((byte) (b ? 1 : 0));
+            }
+        }
+        buffer.rewind();
+        return create(parent, buffer, new Shape(data.length, data[0].length), DataType.BOOLEAN);
+    }
+
+    /**
+     * Creates and initializes a {@link NDArray} with specified {@link Shape}.
+     *
+     * <p>The {@link DataType} of the {@code NDArray} is determined by the type of the buffer.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the data to initialize the {@code NDArray}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    static NDArray create(MxResource parent, Buffer data, Shape shape) {
+        DataType dataType = DataType.fromBuffer(data);
+        return create(parent, data, shape, dataType);
+    }
+
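+    /**
+     * Creates and initializes a {@link NDArray} with the specified shape and data type.
+     *
+     * @param parent the parent {@link MxResource} instance
+     * @param data the data to initialize the {@code NDArray}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    static NDArray create(MxResource parent, Buffer data, Shape shape, DataType dataType) {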
+        NDArray array = create(parent, shape, dataType, Device.defaultIfNull());
+        array.set(data);
+        return array;
+    }
+
+    /**
+     * Returns the name of this {@code NDArray}.
+     *
+     * @return the name of this {@code NDArray}
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Sets the name of this {@code NDArray}.
+     *
+     * @param name of the {@code NDArray}
+     */
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    /**
+     * Returns the {@link DataType} of this {@code NDArray}.
+     *
+     * @return the {@link DataType} of this {@code NDArray}
+     */
+    public DataType getDataType() {
+        if (this.dataType == null) {
+            this.dataType = JnaUtils.getDataTypeOfNdArray(getHandle());
+        }
+        return this.dataType;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Device getDevice() {
+        if (this.device == null) {
+            this.device = JnaUtils.getDeviceOfNdArray(getHandle());
+        }
+        return this.device;
+    }
+
+    /**
+     * Returns the {@link Shape} of this {@code NDArray}.
+     *
+     * @return the {@link Shape} of this {@code NDArray}
+     */
+    public Shape getShape() {
+        if (this.shape == null) {
+            this.shape = JnaUtils.getShapeOfNdArray(getHandle());
+        }
+        return this.shape;
+    }
+
+    /**
+     * Returns the {@link SparseFormat} of this {@code NDArray}.
+     *
+     * @return the {@link SparseFormat} of this {@code NDArray}
+     */
+    public SparseFormat getSparseFormat() {
+        if (this.sparseFormat == null) {
+            this.sparseFormat = JnaUtils.getStorageType(getHandle());
+        }
+        return this.sparseFormat;
+    }
+
+    /**
+     * Returns the version of this {@code NDArray}.
+     *
+     * @return the version of this {@code NDArray}
+     */
+    public Integer getVersion() {
+        if (this.version == null) {
+            this.version = JnaUtils.getVersion();
+        }
+        return this.version;
+    }
+
+    private NDArray duplicate(Shape shape, DataType dataType, Device device, String name) {
+        // TODO get copy parameter
+        NDArray array = create(getParent(), shape, dataType, device);
+        array.setName(name);
+        copyTo(array);
+        return array;
+    }
+
+    /**
+     * Returns a copy of this {@code NDArray}.
+     *
+     * @return a copy of this {@code NDArray}
+     */
+    NDArray duplicate() {
+        NDArray array = create(getParent(), getShape(), getDataType(), getDevice());
+        array.setName(getName());
+        copyTo(array);
+        return array;
+    }
+
+    /**
+     * Moves this {@code NDArray} to a different {@link Device}.
+     *
+     * @param device the {@link Device} to be set
+     * @param copy set {@code true} if you want to return a copy of the existing {@code NDArray}
+     * @return the result {@code NDArray} with the new {@link Device}
+     */
+    public NDArray toDevice(Device device, boolean copy) {
+        if (device.equals(getDevice()) && !copy) {
+            return this;
+        }
+        return duplicate(getShape(), getDataType(), device, getName());
+    }
+
+    /**
+     * Converts this {@code NDArray} to a different {@link DataType}.
+     *
+     * @param dataType the {@link DataType} to be set
+     * @param copy set {@code true} if you want to return a copy of the existing {@code NDArray}
+     * @return the result {@code NDArray} with the new {@link DataType}
+     */
+    public NDArray toType(DataType dataType, boolean copy) {
+        if (dataType.equals(getDataType()) && !copy) {
+            return this;
+        }
+        return duplicate(getShape(), dataType, getDevice(), getName());
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros.
+     *
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     * @see #zeros(Shape, DataType, Device)
+     */
+    public NDArray zeros(Shape shape, DataType dataType) {
+        return fill("_npi_zeros", shape, dataType);
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with the same {@link Shape} and {@link DataType}
+     * filled with zeros.
+     *
+     * @return a new instance of {@link NDArray}
+     * @see #zeros(Shape, DataType, Device)
+     */
+    public NDArray zeros() {
+        return zeros(getShape(), getDataType());
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+     * {@link DataType} filled with zeros.
+     *
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    NDArray zeros(Shape shape, DataType dataType, Device device) {
+        if (device == null || device.equals(getDevice())) {
+            return zeros(shape, dataType);
+        }
+        return create(getParent(), shape, dataType, device).zeros();
+    }
+
+    private NDArray createGradient(SparseFormat format) {
+        try (NDArray zeros = this.zeros(getShape(), getDataType(), getDevice())) {
+            return zeros.toSparse(format);
+        }
+    }
+
+    private NDArray fill(String opName, Shape shape, DataType dataType) {
+        OpParams params = new OpParams();
+        if (shape == null) {
+            throw new IllegalArgumentException("Shape is required for " + opName.substring(1));
+        }
+        params.addParam("shape", shape);
+        params.setDevice(getDevice());
+        params.setDataType(dataType);
+        return invoke(getParent(), opName, params);
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+     *
+     * @param parent the parent {@link MxResource}
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    public static NDArray ones(MxResource parent, Shape shape, DataType dataType, Device device) {
+        return create(parent, shape, dataType, device).ones();
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+     *
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    NDArray ones(Shape shape, DataType dataType) {
+        return fill("_npi_ones", shape, dataType);
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with same {@link Shape} and {@link DataType} filled
+     * with ones.
+     *
+     * @return a new instance of {@link NDArray}
+     */
+    public NDArray ones() {
+        return ones(getShape(), getDataType());
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Shape} filled with ones.
+     *
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    NDArray ones(Shape shape) {
+        return ones(shape, DataType.FLOAT32);
+    }
+
+    /**
+     * Creates an instance of {@link NDArray} with specified {@link Device}, {@link Shape}, and
+     * {@link DataType} filled with ones.
+     *
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return a new instance of {@link NDArray}
+     */
+    NDArray ones(Shape shape, DataType dataType, Device device) {
+        if (device == null || device.equals(getDevice())) {
+            return ones(shape, dataType);
+        }
+        return create(getParent(), shape, dataType, device).ones();
+    }
+
+    /**
+     * Returns the gradient {@code NDArray} attached to this {@code NDArray}.
+     *
+     * @return the gradient {@code NDArray}
+     * @throws IllegalStateException when hasGradient is false
+     */
+    public NDArray getGradient() {
+        if (!hasGradient()) {
+            throw new IllegalStateException(
+                    "No gradient attached to this MxNDArray, please call array.requiredGradient()"
+                            + " on your MxNDArray or block.setInitializer() on your Block");
+        }
+        Pointer pointer = JnaUtils.getGradient(getHandle());
+        return create(getParent(), pointer);
+    }
+
+    /**
+     * Returns true if the gradient calculation is required for this {@code NDArray}.
+     *
+     * @return true if the gradient calculation is required for this {@code NDArray} else false
+     */
+    public boolean hasGradient() {
+        if (hasGradient == null) {
+            Pointer pointer = JnaUtils.getGradient(getHandle());
+            hasGradient = pointer != null;
+        }
+        return hasGradient;
+    }
+
+    /**
+     * Returns an NDArray equal to this that stops gradient propagation through it.
+     *
+     * @return an NDArray equal to this that stops gradient propagation through it
+     */
+    public NDArray stopGradient() {
+        Pointer pointer = JnaUtils.detachGradient(getHandle());
+        return create(getParent(), pointer);
+    }
+
+    /**
+     * Converts this {@code NDArray} to a String array.
+     *
+     * <p>This method applies only to String-typed NDArrays and is not intended for printing.
+     *
+     * @return Array of Strings
+     */
+    public String[] toStringArray() {
+        throw new UnsupportedOperationException("String MxNDArray is not supported!");
+    }
+
+    /**
+     * Converts this {@code NDArray} to a ByteBuffer.
+     *
+     * @return a ByteBuffer
+     */
+    public ByteBuffer toByteBuffer() {
+        if (getSparseFormat() != SparseFormat.DENSE) {
+            throw new IllegalStateException("Require Dense MxNDArray, actual " + getSparseFormat());
+        }
+        Shape sh = getShape();
+        DataType dType = getDataType();
+        long product = sh.size();
+        long len = dType.getNumOfBytes() * product;
+        ByteBuffer bb = NDSerializer.allocateDirect(Math.toIntExact(len));
+        Pointer pointer = Native.getDirectBufferPointer(bb);
+        JnaUtils.syncCopyToCPU(getHandle(), pointer, Math.toIntExact(product));
+        return bb;
+    }
+
+    /**
+     * Returns the total number of elements in this {@code NDArray}.
+     *
+     * @return the number of elements in this {@code NDArray}
+     */
+    long size() {
+        return getShape().size();
+    }
+
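+    /**
+     * Returns the number of elements along the given axis of this {@code NDArray}.
+     *
+     * @param axis the axis to query
+     * @return the size of this {@code NDArray} along {@code axis}
+     */
+    long size(int axis) {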
+        return getShape().size(axis);
+    }
+
+    /**
+     * Sets this {@code NDArray} value from {@link Buffer}.
+     *
+     * @param data the input buffered data
+     */
+    public void set(Buffer data) {
+        int size = Math.toIntExact(size());
+        if (data.remaining() < size) {
+            throw new IllegalArgumentException(
+                    "The MxNDArray size is: " + size + ", but buffer size is: " + data.remaining());
+        }
+        if (data.isDirect()) {
+            JnaUtils.syncCopyFromCPU(getHandle(), data, size);
+            return;
+        }
+
+        data.limit(size);
+        // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
+        DataType inputType = DataType.fromBuffer(data);
+        validate(inputType);
+
+        int numOfBytes = inputType.getNumOfBytes();
+        ByteBuffer buf = NDSerializer.allocateDirect(size * numOfBytes);
+
+        switch (inputType) {
+            case FLOAT32:
+                buf.asFloatBuffer().put((FloatBuffer) data);
+                break;
+            case FLOAT64:
+                buf.asDoubleBuffer().put((DoubleBuffer) data);
+                break;
+            case UINT8:
+            case INT8:
+            case BOOLEAN:
+                buf.put((ByteBuffer) data);
+                break;
+            case INT32:
+                buf.asIntBuffer().put((IntBuffer) data);
+                break;
+            case INT64:
+                buf.asLongBuffer().put((LongBuffer) data);
+                break;
+            case FLOAT16:
+            default:
+                throw new UnsupportedOperationException("data type is not supported!");
+        }
+        buf.rewind();
+        JnaUtils.syncCopyFromCPU(getHandle(), buf, size);
+    }
+
+    private void validate(DataType inputType) {
+        if (getDataType() != inputType
+                && ((dataType != DataType.UINT8 && dataType != DataType.BOOLEAN)
+                        || inputType != DataType.INT8)) {
+            // Inferring the DataType from a ByteBuffer always returns INT8, so UINT8 and BOOLEAN
+            // are special-cased to allow setting those arrays from a regular ByteBuffer.
+            throw new IllegalStateException(
+                    "DataType mismatch, required: " + dataType + ", actual: " + inputType);
+        }
+    }
+
+    /**
+     * Returns {@code true} if this {@code NDArray} is a scalar {@code NDArray} with empty
+     * {@link Shape}.
+     *
+     * @return {@code true} if this {@code NDArray} is a scalar {@code NDArray} with empty
+     *     {@link Shape}
+     */
+    boolean isScalar() {
+        return getShape().isScalar();
+    }
+
+    /**
+     * Returns {@code true} if all elements within this {@code NDArray} are non-zero or {@code
+     * true}.
+     *
+     * @return {@code true} if all elements within this {@code NDArray} are non-zero or {@code true}
+     */
+    NDArray all() {
+        // result of sum operation is int64 now
+        return toType(DataType.BOOLEAN, false).sum().eq(size());
+    }
+
+    /**
+     * Deep-copies the current {@code NDArray} to the one passed in.
+     *
+     * @param ndArray the destination {@code NDArray} to copy into
+     */
+    public void copyTo(NDArray ndArray) {
+        Shape inShape = getShape();
+        Shape destShape = ndArray.getShape();
+        if (!Arrays.equals(inShape.getShape(), destShape.getShape())) {
+            throw new IllegalArgumentException(
+                    "Shapes differ. Required: " + destShape + ", actual: " + inShape);
+        }
+        JnaUtils.op("_npi_copyto").invoke(new NDArray[] {this}, new NDArray[] {ndArray}, null);
+    }
+
+    NDArray booleanMask(NDArray index) {
+        return booleanMask(index, 0);
+    }
+
+    /**
+     * Returns portion of this {@code NDArray} given the index boolean {@code NDArray} along given
+     * axis.
+     *
+     * @param index boolean {@code NDArray} mask
+     * @param axis an integer that represents the axis of {@code NDArray} to mask from
+     * @return the result {@code NDArray}
+     */
+    public NDArray booleanMask(NDArray index, int axis) {
+        if (isScalar() || index.isScalar()) {
+            throw new IllegalArgumentException("booleanMask does not support scalars!");
+        }
+        // TODO remove reshape when MXNet numpy supports multi-dim index
+        // and boolean NDArray reshape
+        Shape remainingDims = getShape().slice(index.getShape().dimension());
+        // create a reshape array {-1, remainingDims}
+        long[] reshape = new long[remainingDims.dimension() + 1];
+        reshape[0] = -1;
+        System.arraycopy(remainingDims.getShape(), 0, reshape, 1, remainingDims.dimension());
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        try (NDArray reshaped = this.reshape(new Shape(reshape));
+                NDArray reshapedIndex = index.toType(DataType.INT32, false).reshape(-1);
+                NDArray result =
+                        invoke(
+                                getParent(),
+                                "_npi_boolean_mask",
+                                new NDArray[] {reshaped, reshapedIndex},
+                                params)) {
+            return result.reshape(reshape);
+        }
+    }
+
+    /**
+     * Sets all elements outside the sequence to a constant value.
+     *
+     * <p>This function takes an n-dimensional input array of the form [batch_size,
+     * max_sequence_length, ...] and returns an array of the same shape. Parameter {@code
+     * sequenceLength} is used to handle variable-length sequences. It should be an input
+     * array of positive ints of dimension [batch_size].
+     *
+     * @param sequenceLength used to handle variable-length sequences
+     * @param value the constant value to be set
+     * @return the result {@code NDArray}
+     */
+    public NDArray sequenceMask(NDArray sequenceLength, float value) {
+        if (getShape().dimension() < 2 || getShape().isScalar() || getShape().hasZeroDimension()) {
+            throw new IllegalArgumentException(
+                    "sequenceMask is not supported for MxNDArray with less than 2 dimensions");
+        }
+        Shape expectedSequenceLengthShape = new Shape(getShape().get(0));
+        if (!sequenceLength.getShape().equals(expectedSequenceLengthShape)) {
+            throw new IllegalArgumentException("SequenceLength must be of shape [batchSize]");
+        }
+        OpParams params = new OpParams();
+        params.add("value", value);
+        params.add("use_sequence_length", true);
+        params.add("axis", 1);
+        return invoke(getParent(), "_npx_sequence_mask", new NDList(this, sequenceLength), params)
+                .head();
+    }
+
+    /**
+     * Sets all elements outside the sequence to 0.
+     *
+     * <p>This function takes an n-dimensional input array of the form [batch_size,
+     * max_sequence_length, ...] and returns an array of the same shape. Parameter {@code
+     * sequenceLength} is used to handle variable-length sequences. It should be an input
+     * array of positive ints of dimension [batch_size].
+     *
+     * @param sequenceLength used to handle variable-length sequences
+     * @return the result {@code NDArray}
+     */
+    public NDArray sequenceMask(NDArray sequenceLength) {
+        return sequenceMask(sequenceLength, 0);
+    }
+
+    /**
+     * Returns an {@code NDArray} of zeros with the same {@link Shape}, {@link DataType} and {@link
+     * SparseFormat} as the input {@code NDArray}.
+     *
+     * @return an {@code NDArray} filled with zeros
+     */
+    public NDArray zerosLike() {
+        OpParams params = new OpParams();
+        params.addParam("fill_value", 0);
+        return invoke(getParent(), "_npi_full_like", this, params);
+    }
+
+    /**
+     * Returns an {@code NDArray} of ones with the same {@link Shape}, {@link DataType} and {@link
+     * SparseFormat} as the input {@code NDArray}.
+     *
+     * @return an {@code NDArray} filled with ones
+     */
+    public NDArray onesLike() {
+        OpParams params = new OpParams();
+        params.addParam("fill_value", 1);
+        return invoke(getParent(), "_npi_full_like", this, params);
+    }
+
+    NDArray get(NDIndex index) {
+        return getNDArrayInternal().getIndexer().get(this, index);
+    }
+
+    NDArray get(long... indices) {
+        return get(new NDIndex(indices));
+    }
+
+    NDArray getScalar(long... indices) {
+        NDArray value = get(new NDIndex(indices));
+        if (value.size() != 1) {
+            throw new IllegalArgumentException("The supplied Index does not produce a scalar");
+        }
+        return value;
+    }
+
+    boolean getBoolean(long... indices) {
+        return getScalar(indices).toBooleanArray()[0];
+    }
+
+    /**
+     * Returns {@code true} if all elements in this {@code NDArray} are equal to the {@link Number}.
+     *
+     * @param number the number to compare
+     * @return the boolean result
+     */
+    public boolean contentEquals(Number number) {
+        if (number == null) {
+            return false;
+        }
+        try (NDArray result = eq(number)) {
+            return result.all().getBoolean();
+        }
+    }
+
+    /**
+     * Returns {@code true} if all elements in this {@code NDArray} are equal to the other {@link
+     * NDArray}.
+     *
+     * @param other the other {@code NDArray} to compare
+     * @return the boolean result
+     */
+    public boolean contentEquals(NDArray other) {
+        if (other == null || (!shapeEquals(other))) {
+            return false;
+        }
+        if (getDataType() != other.getDataType()) {
+            return false;
+        }
+        try (NDArray result = eq(other).toType(DataType.INT32, false)) {
+            return result.all().getBoolean();
+        }
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+     *
+     * @param n the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+     */
+    public NDArray eq(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Equals" comparison
+     */
+    public NDArray eq(NDArray other) {
+        return invoke(getParent(), "_npi_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+     *
+     * @param n the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+     */
+    public NDArray neq(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_not_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Not equals" comparison
+     */
+    public NDArray neq(NDArray other) {
+        return invoke(getParent(), "_npi_not_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater Than" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater Than" comparison
+     */
+    public NDArray gt(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_greater_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater Than" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater Than" comparison
+     */
+    public NDArray gt(NDArray other) {
+        return invoke(getParent(), "_npi_greater", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+     */
+    public NDArray gte(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_greater_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
+     */
+    public NDArray gte(NDArray other) {
+        return invoke(getParent(), "_npi_greater_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Less" comparison
+     */
+    public NDArray lt(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_less_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Less" comparison
+     */
+    public NDArray lt(NDArray other) {
+        return invoke(getParent(), "_npi_less", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+     *
+     * @param other the number to compare
+     * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+     */
+    public NDArray lte(Number other) {
+        OpParams params = new OpParams();
+        params.add("scalar", other.toString());
+        return invoke(getParent(), "_npi_less_equal_scalar", this, params);
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
+     *
+     * @param other the {@code NDArray} to compare
+     * @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
+     */
+    public NDArray lte(NDArray other) {
+        return invoke(getParent(), "_npi_less_equal", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Adds a number to this {@code NDArray} element-wise.
+     *
+     * @param n the number to add
+     * @return the result {@code NDArray}
+     */
+    public NDArray add(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_add_scalar", this, params);
+    }
+
+    /**
+     * Adds the other {@code NDArray} to this {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to add
+     * @return the result {@code NDArray}
+     */
+    public NDArray add(NDArray other) {
+        return invoke(getParent(), "_npi_add", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Subtracts a number from this {@code NDArray} element-wise.
+     *
+     * @param n the number to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray sub(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_subtract_scalar", this, params);
+    }
+
+    /**
+     * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray sub(NDArray other) {
+        return invoke(getParent(), "_npi_subtract", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by a number element-wise.
+     *
+     * @param n the number to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray mul(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_multiply_scalar", this, params);
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray mul(NDArray other) {
+        return invoke(getParent(), "_npi_multiply", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Divides this {@code NDArray} by a number element-wise.
+     *
+     * @param n the number to divide by
+     * @return the result {@code NDArray}
+     */
+    public NDArray div(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_true_divide_scalar", this, params);
+    }
+
+    /**
+     * Divides this {@code NDArray} by the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to divide by
+     * @return the result {@code NDArray}
+     */
+    public NDArray div(NDArray other) {
+        return invoke(getParent(), "_npi_true_divide", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns element-wise remainder of division.
+     *
+     * @param n the divisor number
+     * @return the result {@code NDArray}
+     */
+    public NDArray mod(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_mod_scalar", this, params);
+    }
+
+    /**
+     * Returns element-wise remainder of division.
+     *
+     * @param other the divisor {@code NDArray}
+     * @return the result {@code NDArray}
+     */
+    public NDArray mod(NDArray other) {
+        return invoke(getParent(), "_npi_mod", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with a number element-wise.
+     *
+     * @param n the number to take the power with
+     * @return the result {@code NDArray}
+     */
+    public NDArray pow(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_power_scalar", this, params);
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise.
+     *
+     * @param other the other {@code NDArray} to take the power with
+     * @return the result {@code NDArray}
+     */
+    public NDArray pow(NDArray other) {
+        return invoke(getParent(), "_npi_power", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Adds a number to this {@code NDArray} element-wise in place.
+     *
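+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected values assume NumPy semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray a = manager.create(new float[] {1f, 2f}, new Shape(2));
+     * jshell&gt; NDArray b = a.addi(1); // a is overwritten and is expected to hold [2, 3]
+     * jshell&gt; a == b; // true, because addi returns this instance
+     * </pre>
+     *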
+     * @param n the number to add
+     * @return the result {@code NDArray}
+     */
+    public NDArray addi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_add_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Adds the other {@code NDArray} to this {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to add
+     * @return the result {@code NDArray}
+     */
+    public NDArray addi(NDArray other) {
+        invoke("_npi_add", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Subtracts a number from this {@code NDArray} element-wise in place.
+     *
+     * @param n the number to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray subi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_subtract_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Subtracts the other {@code NDArray} from this {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to subtract
+     * @return the result {@code NDArray}
+     */
+    public NDArray subi(NDArray other) {
+        invoke("_npi_subtract", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by a number element-wise in place.
+     *
+     * @param n the number to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray muli(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_multiply_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Multiplies this {@code NDArray} by the other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to multiply by
+     * @return the result {@code NDArray}
+     */
+    public NDArray muli(NDArray other) {
+        invoke("_npi_multiply", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Divides this {@code NDArray} by a number element-wise in place.
+     *
+     * @param n the number to divide values by
+     * @return the array after applying division operation
+     */
+    public NDArray divi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_true_divide_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Divides this {@code NDArray} by the other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to divide by
+     * @return the result {@code NDArray}
+     */
+    public NDArray divi(NDArray other) {
+        invoke("_npi_true_divide", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns element-wise remainder of division in place.
+     *
+     * @param n the divisor number
+     * @return the result {@code NDArray}
+     */
+    public NDArray modi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_mod_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Returns element-wise remainder of division in place.
+     *
+     * @param other the divisor {@code NDArray}
+     * @return the result {@code NDArray}
+     */
+    public NDArray modi(NDArray other) {
+        invoke("_npi_mod", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with a number element-wise in place.
+     *
+     * @param n the number to raise the power to
+     * @return the result {@code NDArray}
+     */
+    public NDArray powi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        invoke("_npi_power_scalar", new NDArray[] {this}, new NDArray[] {this}, params);
+        return this;
+    }
+
+    /**
+     * Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise in place.
+     *
+     * @param other the other {@code NDArray} to take the power with
+     * @return the result {@code NDArray}
+     */
+    public NDArray powi(NDArray other) {
+        invoke("_npi_power", new NDArray[] {this, other}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the element-wise sign.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sign() {
+        return invoke(getParent(), "_npi_sign", this, null);
+    }
+
+    /**
+     * Returns the element-wise sign in place.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray signi() {
+        invoke("_npi_sign", new NDArray[] {this}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} and a number element-wise.
+     *
+     * @param n the number to be compared
+     * @return the maximum of this {@code NDArray} and a number element-wise
+     */
+    public NDArray maximum(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_maximum_scalar", this, params);
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} and the other {@code NDArray} element-wise.
+     *
+     * @param other the {@code NDArray} to be compared
+     * @return the maximum of this {@code NDArray} and the other {@code NDArray} element-wise
+     */
+    public NDArray maximum(NDArray other) {
+        return invoke(getParent(), "_npi_maximum", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} and a number element-wise.
+     *
+     * @param n the number to be compared
+     * @return the minimum of this {@code NDArray} and a number element-wise
+     */
+    public NDArray minimum(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return invoke(getParent(), "_npi_minimum_scalar", this, params);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} and the other {@code NDArray} element-wise.
+     *
+     * @param other the {@code NDArray} to be compared
+     * @return the minimum of this {@code NDArray} and the other {@code NDArray} element-wise
+     */
+    public NDArray minimum(NDArray other) {
+        return invoke(getParent(), "_npi_minimum", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns the numerical negative {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray neg() {
+        return invoke(getParent(), "_npi_negative", this, null);
+    }
+
+    /**
+     * Returns the numerical negative {@code NDArray} element-wise in place.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray negi() {
+        invoke("_npi_negative", new NDArray[] {this}, new NDArray[] {this}, null);
+        return this;
+    }
+
+    /**
+     * Returns the absolute value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray abs() {
+        return invoke(getParent(), "_npi_absolute", this, null);
+    }
+
+    /**
+     * Returns the square of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray square() {
+        return invoke(getParent(), "_npi_square", this, null);
+    }
+
+    /**
+     * Returns the square root of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sqrt() {
+        return invoke(getParent(), "_npi_sqrt", this, null);
+    }
+
+    /**
+     * Returns the cube-root of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray cbrt() {
+        return invoke(getParent(), "_npi_cbrt", this, null);
+    }
+
+    /**
+     * Returns the floor of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray floor() {
+        return invoke(getParent(), "_npi_floor", this, null);
+    }
+
+    /**
+     * Returns the ceiling of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray ceil() {
+        return invoke(getParent(), "_npi_ceil", this, null);
+    }
+
+    /**
+     * Returns the rounded value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray round() {
+        return invoke(getParent(), "round", this, null);
+    }
+
+    /**
+     * Returns the truncated value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray trunc() {
+        return invoke(getParent(), "_npi_trunc", this, null);
+    }
+
+    /**
+     * Returns the exponential value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray exp() {
+        return invoke(getParent(), "_npi_exp", this, null);
+    }
+
+    /**
+     * Returns the natural logarithmic value of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log() {
+        return invoke(getParent(), "_npi_log", this, null);
+    }
+
+    /**
+     * Returns the base 10 logarithm of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log10() {
+        return invoke(getParent(), "_npi_log10", this, null);
+    }
+
+    /**
+     * Returns the base 2 logarithm of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray log2() {
+        return invoke(getParent(), "_npi_log2", this, null);
+    }
+
+    /**
+     * Returns the trigonometric sine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sin() {
+        return invoke(getParent(), "_npi_sin", this, null);
+    }
+
+    /**
+     * Returns the trigonometric cosine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray cos() {
+        return invoke(getParent(), "_npi_cos", this, null);
+    }
+
+    /**
+     * Returns the trigonometric tangent of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray tan() {
+        return invoke(getParent(), "_npi_tan", this, null);
+    }
+
+    /**
+     * Returns the inverse trigonometric sine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray asin() {
+        return invoke(getParent(), "_npi_arcsin", this, null);
+    }
+
+    /**
+     * Returns the inverse trigonometric cosine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray acos() {
+        return invoke(getParent(), "_npi_arccos", this, null);
+    }
+
+    /**
+     * Returns the inverse trigonometric tangent of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray atan() {
+        return invoke(getParent(), "_npi_arctan", this, null);
+    }
+
+    /**
+     * Returns the hyperbolic sine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray sinh() {
+        return invoke(getParent(), "_npi_sinh", this, null);
+    }
+
+    /**
+     * Returns the hyperbolic cosine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray cosh() {
+        return invoke(getParent(), "_npi_cosh", this, null);
+    }
+
+    /**
+     * Returns the hyperbolic tangent of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray tanh() {
+        return invoke(getParent(), "_npi_tanh", this, null);
+    }
+
+    /**
+     * Returns the inverse hyperbolic sine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray asinh() {
+        return invoke(getParent(), "_npi_arcsinh", this, null);
+    }
+
+    /**
+     * Returns the inverse hyperbolic cosine of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray acosh() {
+        return invoke(getParent(), "_npi_arccosh", this, null);
+    }
+
+    /**
+     * Returns the inverse hyperbolic tangent of this {@code NDArray} element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray atanh() {
+        return invoke(getParent(), "_npi_arctanh", this, null);
+    }
+
+    /**
+     * Converts this {@code NDArray} from radians to degrees element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray toDegrees() {
+        return invoke(getParent(), "_npi_degrees", this, null);
+    }
+
+    /**
+     * Converts this {@code NDArray} from degrees to radians element-wise.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray toRadians() {
+        return invoke(getParent(), "_npi_radians", this, null);
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray}.
+     *
+     * @return the maximum of this {@code NDArray}
+     */
+    public NDArray max() {
+        return invoke(getParent(), "_np_max", this, null);
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @return the maximum of this {@code NDArray} with the specified axes removed from the Shape
+     *     containing the max
+     * @see NDArray#max(int[], boolean)
+     */
+    public NDArray max(int[] axes) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        return invoke(getParent(), "_np_max", this, params);
+    }
+
+    /**
+     * Returns the maximum of this {@code NDArray} along given axes.
+     *
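+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected results assume NumPy semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 5f, 3f, 4f}, new Shape(2, 2));
+     * jshell&gt; array.max(new int[] {1}, false); // expected shape (2), values [5, 4]
+     * jshell&gt; array.max(new int[] {1}, true); // expected shape (2, 1), values [[5], [4]]
+     * </pre>
+     *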
+     * @param axes the axes along which to operate
+     * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+     *     false} to squeeze the values out of the output array.
+     * @return the maximum of this {@code NDArray}
+     */
+    public NDArray max(int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_np_max", this, params);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray}.
+     *
+     * @return the minimum of this {@code NDArray}
+     */
+    public NDArray min() {
+        return invoke(getParent(), "_np_min", this, null);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @return the minimum of this {@code NDArray} with the specified axes removed from the Shape
+     *     containing the min
+     * @see NDArray#min(int[], boolean)
+     */
+    public NDArray min(int[] axes) {
+        return min(axes, false);
+    }
+
+    /**
+     * Returns the minimum of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+     *     false} to squeeze the values out of the output array
+     * @return the minimum of this {@code NDArray}
+     */
+    public NDArray min(int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_np_min", this, params);
+    }
+
+    /**
+     * Returns the sum of this {@code NDArray}.
+     *
+     * @return the sum of this {@code NDArray}
+     */
+    public NDArray sum() {
+        // TODO: MXNet on Windows currently doesn't support boolean NDArray
+        if (System.getProperty("os.name").toLowerCase().contains("win")) {
+            DataType target = getDataType();
+            if (!target.isFloating()) {
+                try (NDArray thisArr = toType(DataType.FLOAT32, false)) {
+                    if (target == DataType.BOOLEAN) {
+                        target = DataType.INT64;
+                    }
+                    try (NDArray array = invoke(getParent(), "_np_sum", thisArr, null)) {
+                        return array.toType(target, false);
+                    }
+                }
+            }
+        }
+        return invoke(getParent(), "_np_sum", this, null);
+    }
+
+    /**
+     * Returns the sum of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @return the sum of this {@code NDArray} with the specified axes removed from the Shape
+     *     containing the sum
+     * @see NDArray#sum(int[], boolean)
+     */
+    public NDArray sum(int[] axes) {
+        return sum(axes, false);
+    }
+
+    /**
+     * Returns the sum of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+     *     false} to squeeze the values out of the output array
+     * @return the sum of this {@code NDArray}
+     */
+    public NDArray sum(int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_np_sum", this, params);
+    }
+
+    /**
+     * Returns the product of this {@code NDArray}.
+     *
+     * @return the product of this {@code NDArray}
+     */
+    public NDArray prod() {
+        return invoke(getParent(), "_np_prod", this, null);
+    }
+
+    /**
+     * Returns the product of this {@code NDArray} elements over the given axes.
+     *
+     * @param axes the axes along which to operate
+     * @return the product of this {@code NDArray} with the specified axes removed from the Shape
+     *     containing the prod
+     * @see NDArray#prod(int[], boolean)
+     */
+    public NDArray prod(int[] axes) {
+        return prod(axes, false);
+    }
+
+    /**
+     * Returns the product of this {@code NDArray} elements over the given axes.
+     *
+     * @param axes the axes along which to operate
+     * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+     *     false} to squeeze the values out of the output array
+     * @return the product of this {@code NDArray}
+     */
+    public NDArray prod(int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_np_prod", this, params);
+    }
+
+    /**
+     * Returns the average of this {@code NDArray}.
+     *
+     * @return the average of this {@code NDArray}
+     */
+    public NDArray mean() {
+        return invoke(getParent(), "_npi_mean", this, null);
+    }
+
+    /**
+     * Returns the average of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @return the average of this {@code NDArray} with the specified axes removed from the Shape
+     *     containing the mean
+     * @see NDArray#mean(int[], boolean)
+     */
+    public NDArray mean(int[] axes) {
+        return mean(axes, false);
+    }
+
+    /**
+     * Returns the average of this {@code NDArray} along given axes.
+     *
+     * @param axes the axes along which to operate
+     * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
+     *     false} to squeeze the values out of the output array
+     * @return the average of this {@code NDArray}
+     */
+    public NDArray mean(int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_npi_mean", this, params);
+    }
+
+    /**
+     * Rotates an array by 90 degrees in the plane specified by axes.
+     *
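+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected values assume NumPy {@code rot90}
+     * semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell&gt; array.rotate90(1, new int[] {0, 1}); // expected values [[2, 4], [1, 3]]
+     * </pre>
+     *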
+     * @param times the number of times the array is rotated by 90 degrees
+     * @param axes the two axes that define the plane of rotation; they must be different
+     * @return the rotated {@code NDArray}
+     */
+    public NDArray rotate90(int times, int[] axes) {
+        if (axes.length != 2) {
+            throw new IllegalArgumentException("axes must contain exactly 2 elements");
+        }
+        OpParams params = new OpParams();
+        params.addTupleParam("axes", axes);
+        params.addParam("k", times);
+        return invoke(getParent(), "_npi_rot90", this, params);
+    }
+
+    /**
+     * Returns the sum along diagonals of this {@code NDArray}.
+     *
+     * <p>If this {@code NDArray} is 2-D, the sum along its diagonal with the given offset is
+     * returned, i.e., the sum of elements a[i,i+offset] for all i. If this {@code NDArray} has more
+     * than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D
+     * sub-arrays whose traces are returned. The {@link Shape} of the resulting array is the same as
+     * this {@code NDArray} with axis1 and axis2 removed.
+     *
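+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected values assume NumPy {@code trace}
+     * semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
+     * jshell&gt; array.trace(0, 0, 1); // main diagonal, 1 + 4: expected 5
+     * jshell&gt; array.trace(1, 0, 1); // first diagonal above the main one: expected 2
+     * </pre>
+     *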
+     * @param offset offset of the diagonal from the main diagonal. Can be both positive and
+     *     negative.
+     * @param axis1 axes to be used as the first axis of the 2-D sub-arrays from which the diagonals
+     *     should be taken
+     * @param axis2 axes to be used as the second axis of the 2-D sub-arrays from which the
+     *     diagonals should be taken
+     * @return the sum along diagonals of this {@code NDArray}
+     */
+    public NDArray trace(int offset, int axis1, int axis2) {
+        OpParams params = new OpParams();
+        params.addParam("offset", offset);
+        params.addParam("axis1", axis1);
+        params.addParam("axis2", axis2);
+        return invoke(getParent(), "_np_trace", this, params);
+    }
+
+    /**
+     * Returns the sum along diagonals of this {@code NDArray}.
+     *
+     * <p>If this {@code NDArray} is 2-D, the sum along its diagonal with the given offset is
+     * returned, i.e., the sum of elements a[i,i+offset] for all i. If this {@code NDArray} has more
+     * than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D
+     * sub-arrays whose traces are returned. The {@link Shape} of the resulting array is the same as
+     * this {@code NDArray} with axis1 and axis2 removed.
+     *
+     * @param offset offset of the diagonal from the main diagonal. Can be both positive and
+     *     negative.
+     * @return the sum along diagonals of this {@code NDArray}
+     */
+    public NDArray trace(int offset) {
+        return trace(offset, 0, 1);
+    }
+
+    /**
+     * Splits this {@code NDArray} into multiple sub-{@code NDArray}s at the given indices along
+     * the given axis.
+     *
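+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected result assumes NumPy {@code split}
+     * semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(6));
+     * jshell&gt; array.split(new long[] {2, 4}, 0); // expected pieces: [0, 1], [2, 3], [4, 5]
+     * </pre>
+     *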
+     * @param indices the positions along the axis at which to split; a leading 0 is prepended if
+     *     absent, following the NumPy convention
+     * @param axis the axis to split along
+     * @return an {@link NDList} of the resulting sub-{@code NDArray}s, or an {@link NDList}
+     *     containing only this {@code NDArray} if {@code indices} is empty
+     */
+    public NDList split(long[] indices, int axis) {
+        if (indices.length == 0) {
+            return new NDList(this);
+        }
+        OpParams params = new OpParams();
+        // follow the numpy behavior
+        if (indices[0] != 0) {
+            long[] tempIndices = new long[indices.length + 1];
+            tempIndices[0] = 0;
+            System.arraycopy(indices, 0, tempIndices, 1, indices.length);
+            indices = tempIndices;
+        }
+        params.addTupleParam("indices", indices);
+        params.addParam("axis", axis);
+        params.addParam("squeeze_axis", false);
+        return invoke(getParent(), "_npi_split", new NDList(this), params);
+    }
+
+    /**
+     * Flattens this {@code NDArray} into a 1-D {@code NDArray} in row-major order.
+     *
+     * <p>To flatten in column-major order, first transpose this {@code NDArray}.
+     *
+     * @return a 1-D {@code NDArray} of equal size
+     */
+    public NDArray flatten() {
+        return reshape(new Shape(Math.toIntExact(size())));
+    }
+
+    /**
+     * Reshapes this {@code NDArray} to the given {@link Shape}.
+     *
+     * <p>You can reshape it to match another NDArray by calling {@code a.reshape(b.getShape())}.
+     *
+     * @param shape the {@link Shape} to reshape into. Must have equal size to the current shape
+     * @return a reshaped {@code NDArray}
+     * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+     *     the current shape
+     */
+    public NDArray reshape(Shape shape) {
+        OpParams params = new OpParams();
+        params.addParam("newshape", shape);
+        return invoke(getParent(), "_np_reshape", this, params);
+    }
+
+    /**
+     * Reshapes this {@code NDArray} to the given {@link Shape}.
+     *
+     * @param newShape the long array to reshape into. Must have equal size to the current shape
+     * @return a reshaped {@code NDArray}
+     * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
+     *     the current shape
+     */
+    public NDArray reshape(long... newShape) {
+        return reshape(new Shape(newShape));
+    }
+
+    /**
+     * Expands the {@link Shape} of a {@code NDArray}.
+     *
+     * <p>Inserts a new axis that will appear at the axis position in the expanded {@code NDArray}
+     * shape.
+     *
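+     * <p>Example (an illustrative sketch; {@code manager.create} is assumed to be available as in
+     * the {@code squeeze} example in this file, and the expected shapes assume NumPy {@code
+     * expand_dims} semantics):
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f}, new Shape(2));
+     * jshell&gt; array.expandDims(0); // expected shape (1, 2)
+     * jshell&gt; array.expandDims(1); // expected shape (2, 1)
+     * </pre>
+     *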
+     * @param axis the position in the expanded axes where the new axis is placed
+     * @return the result {@code NDArray}. The number of dimensions is one greater than that of the
+     *     {@code NDArray}
+     */
+    public NDArray expandDims(int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npi_expand_dims", this, params);
+    }
+
+    /**
+     * Removes all singleton dimensions from this {@code NDArray} {@link Shape}.
+     *
+     * @return a result {@code NDArray} of same size and data without singleton dimensions
+     */
+    public NDArray squeeze() {
+        return invoke(getParent(), "_np_squeeze", this, null);
+    }
+
+    /**
+     * Removes singleton dimensions at the given axes.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
+     * jshell&gt; array;
+     * ND: (1, 3, 1) cpu() float32
+     * [[[0.],
+     *   [1.],
+     *   [2.],
+     *  ],
+     * ]
+     * jshell&gt; array.squeeze(new int[] {0, 2});
+     * ND: (3) cpu() float32
+     * [0., 1., 2.]
+     * </pre>
+     *
+     * @param axes the axes at which to remove the singleton dimensions
+     * @return a result {@code NDArray} of same size and data with the given axes removed from the
+     *     shape
+     * @throws IllegalArgumentException thrown if any of the given axes are not a singleton
+     *     dimension
+     */
+    public NDArray squeeze(int[] axes) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        return invoke(getParent(), "_np_squeeze", this, params);
+    }
+
+    /**
+     * Returns the truth value of this {@code NDArray} AND the other {@code NDArray} element-wise.
+     *
+     * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+     *
+     * @param other the other {@code NDArray} to operate on
+     * @return the boolean {@code NDArray} of the logical AND operation applied to the elements of
+     *     this {@code NDArray} and the other {@code NDArray}
+     */
+    public NDArray logicalAnd(NDArray other) {
+        // TODO switch to numpy op, although the current op supports zero-dim and scalar
+        NDArray thisArr =
+                (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this;
+        other =
+                (other.getDataType() == DataType.BOOLEAN)
+                        ? other.toType(DataType.INT32, false)
+                        : other;
+        return invoke(getParent(), "broadcast_logical_and", new NDArray[] {thisArr, other}, null)
+                .toType(DataType.BOOLEAN, false);
+    }
+
+    /**
+     * Computes the truth value of this {@code NDArray} OR the other {@code NDArray} element-wise.
+     *
+     * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+     *
+     * @param other the other {@code NDArray} to operate on
+     * @return the boolean {@code NDArray} of the logical OR operation applied to the elements of
+     *     this {@code NDArray} and the other {@code NDArray}
+     */
+    public NDArray logicalOr(NDArray other) {
+        // TODO switch to numpy op, although the current op supports zero-dim and scalar
+        NDArray thisArr =
+                (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this;
+        other =
+                (other.getDataType() == DataType.BOOLEAN)
+                        ? other.toType(DataType.INT32, false)
+                        : other;
+        return invoke(getParent(), "broadcast_logical_or", new NDArray[] {thisArr, other}, null)
+                .toType(DataType.BOOLEAN, false);
+    }
+
+    /**
+     * Computes the truth value of this {@code NDArray} XOR the other {@code NDArray} element-wise.
+     *
+     * <p>The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
+     *
+     * @param other the other {@code NDArray} to operate on
+     * @return the boolean {@code NDArray} of the logical XOR operation applied to the elements of
+     *     this {@code NDArray} and the other {@code NDArray}
+     */
+    public NDArray logicalXor(NDArray other) {
+        // TODO switch to numpy op, although the current op supports zero-dim and scalar inputs
+        NDArray thisArr =
+                (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this;
+        other =
+                (other.getDataType() == DataType.BOOLEAN)
+                        ? other.toType(DataType.INT32, false)
+                        : other;
+        return invoke(getParent(), "broadcast_logical_xor", new NDArray[] {thisArr, other}, null)
+                .toType(DataType.BOOLEAN, false);
+    }
+
+    /**
+     * Computes the truth value of NOT this {@code NDArray} element-wise.
+     *
+     * @return the boolean {@code NDArray}
+     */
+    public NDArray logicalNot() {
+        return invoke(getParent(), "_npi_logical_not", this, null);
+    }
+
+    /**
+     * Returns the indices that would sort this {@code NDArray} given the axis.
+     *
+     * <p>Perform an indirect sort along the given axis. It returns a {@code NDArray} of indices of
+     * the same {@link Shape} as this {@code NDArray}.
+     *
+     * @param axis the axis to sort along
+     * @param ascending whether to sort ascending
+     * @return a {@code NDArray} of indices corresponding to elements in this {@code NDArray} on the
+     *     axis, the output DataType is always {@link DataType#INT64}
+     */
+    public NDArray argSort(int axis, boolean ascending) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        // note: the MXNet numpy argsort op does not officially support this param
+        params.addParam("is_ascend", ascending);
+        params.setDataType(DataType.INT64);
+        return invoke(getParent(), "_npi_argsort", this, params);
+    }
+
+    /**
+     * Sorts this {@code NDArray} along the given axis.
+     *
+     * @param axis the axis to sort along
+     * @return the sorted {@code NDArray}
+     */
+    public NDArray sort(int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npi_sort", this, params);
+    }
+
+    /**
+     * Sorts the flattened {@code NDArray}.
+     *
+     * @return the sorted {@code NDArray}
+     */
+    public NDArray sort() {
+        return invoke(getParent(), "_npi_sort", this, null);
+    }
+
+    /**
+     * Applies the softmax function along the given axis.
+     *
+     * @param axis the axis along which to apply
+     * @return the result {@code NDArray}
+     * @see <a href="https://en.wikipedia.org/wiki/Softmax_function">softmax</a>
+     * @see NDArray#softmax(int)
+     */
+    public NDArray softmax(int axis) {
+        // MXNet softmax op bug on GPU
+        if (isEmpty()) {
+            return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+        }
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npx_softmax", this, params);
+    }
+
+    /**
+     * Applies the softmax function followed by a logarithm.
+     *
+     * <p>Mathematically equivalent to calling softmax and then log. This single operator is faster
+     * than calling two operators and numerically more stable when computing gradients.
+     *
+     * @param axis the axis along which to apply
+     * @return the result {@code NDArray}
+     */
+    public NDArray logSoftmax(int axis) {
+        // MXNet logsoftmax op bug on GPU
+        if (isEmpty()) {
+            return create(getParent(), getShape(), DataType.FLOAT32, getDevice());
+        }
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npx_log_softmax", this, params);
+    }
+
+    /**
+     * Returns the cumulative sum of the elements in the flattened {@code NDArray}.
+     *
+     * @return the cumulative sum of the elements in the flattened {@code NDArray}
+     */
+    public NDArray cumSum() {
+        return invoke(getParent(), "_np_cumsum", this, null);
+    }
+
+    /**
+     * Returns the cumulative sum of the elements along a given axis.
+     *
+     * @param axis the axis along which the cumulative sum is computed
+     * @return the cumulative sum along the specified axis
+     */
+    public NDArray cumSum(int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_np_cumsum", this, params);
+    }
+
+    /**
+     * Replaces the handle of this {@code NDArray} with the handle of the other {@code NDArray}.
+     * The {@code NDArray} used for replacement will be closed.
+     *
+     * <p>Use with caution: this method makes the input argument unusable.
+     *
+     * @param replaced the {@code NDArray} that provides the new handle and is closed afterwards
+     */
+    public void intern(NDArray replaced) {
+        NDArray arr = replaced;
+        Pointer oldHandle = handle.getAndSet(arr.handle.getAndSet(null));
+        JnaUtils.waitToRead(oldHandle);
+        JnaUtils.freeNdArray(oldHandle);
+        // dereference old ndarray
+        arr.close();
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+     * entries are infinite, or {@code false} where they are not infinite.
+     *
+     * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s entries
+     *     are infinite
+     */
+    public NDArray isInfinite() {
+        throw new UnsupportedOperationException("Not implemented yet.");
+    }
+
+    /**
+     * Returns the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
+     * entries are NaN, or {@code false} where they are not NaN.
+     *
+     * @return the boolean {@code NDArray} with value {@code true} if this {@code NDArray}'s entries
+     *     are NaN
+     */
+    public NDArray isNaN() {
+        return invoke(getParent(), "_npi_isnan", this, null);
+    }
+
+    /**
+     * Returns a dense representation of the sparse {@code NDArray}.
+     *
+     * @return the result {@code NDArray}
+     */
+    public NDArray toDense() {
+        if (!isSparse()) {
+            return duplicate();
+        }
+        return castStorage(SparseFormat.DENSE);
+    }
+
+    /**
+     * Returns a sparse representation of this {@code NDArray}.
+     *
+     * @param fmt the {@link SparseFormat} of this {@code NDArray}
+     * @return the result {@code NDArray}
+     */
+    public NDArray toSparse(SparseFormat fmt) {
+        if (fmt != SparseFormat.DENSE
+                && fmt != SparseFormat.CSR
+                && fmt != SparseFormat.ROW_SPARSE) {
+            throw new UnsupportedOperationException(fmt + " is not supported");
+        }
+        if (fmt == getSparseFormat()) {
+            return duplicate();
+        }
+        return castStorage(fmt);
+    }
+
+    private NDArray castStorage(SparseFormat fmt) {
+        OpParams params = new OpParams();
+        params.setParam("stype", fmt.getType());
+        return invoke(getParent(), "cast_storage", this, params);
+    }
+
+    /**
+     * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by
+     * repeats.
+     *
+     * @param repeats the number of times to repeat for each dimension
+     * @return a {@code NDArray} that has been tiled
+     */
+    public NDArray tile(long repeats) {
+        // zero-dim
+        if (isEmpty()) {
+            return duplicate();
+        }
+        // scalar
+        int dim = (isScalar()) ? 1 : getShape().dimension();
+        long[] repeatsArray = new long[dim];
+        Arrays.fill(repeatsArray, repeats);
+        return tile(repeatsArray);
+    }
+
+    /**
+     * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by
+     * repeats.
+     *
+     * @param repeats the number of times to repeat along each axis
+     * @return a {@code NDArray} that has been tiled
+     */
+    public NDArray tile(long[] repeats) {
+        OpParams params = new OpParams();
+        params.addTupleParam("reps", repeats);
+        return invoke(getParent(), "_npi_tile", this, params);
+    }
+
+    /**
+     * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times given by
+     * repeats along given axis.
+     *
+     * @param axis the axis to repeat
+     * @param repeats the number of times to repeat for each axis
+     * @return a {@code NDArray} that has been tiled
+     * @throws IllegalArgumentException thrown for invalid axis
+     */
+    public NDArray tile(int axis, long repeats) {
+        // scalar
+        if (isScalar()) {
+            throw new IllegalArgumentException("scalar doesn't support specifying axis");
+        }
+        long[] repeatsArray = new long[getShape().dimension()];
+        Arrays.fill(repeatsArray, 1);
+        repeatsArray[withAxis(axis)] = repeats;
+        return tile(repeatsArray);
+    }
+
+    /**
+     * Constructs a {@code NDArray} by repeating this {@code NDArray} the number of times to match
+     * the desired shape.
+     *
+     * <p>If the desired {@link Shape} has fewer dimensions than this {@code NDArray}, it will tile
+     * against the last axis.
+     *
+     * @param desiredShape the {@link Shape} that should be converted to
+     * @return a {@code NDArray} that has been tiled
+     */
+    public NDArray tile(Shape desiredShape) {
+        return tile(repeatsToMatchShape(desiredShape));
+    }
+
+    private int withAxis(int axis) {
+        return Math.floorMod(axis, getShape().dimension());
+    }
+
+    private long[] repeatsToMatchShape(Shape desiredShape) {
+        Shape curShape = getShape();
+        int dimension = curShape.dimension();
+        if (desiredShape.dimension() > dimension) {
+            throw new IllegalArgumentException("The desired shape has too many dimensions");
+        }
+        if (desiredShape.dimension() < dimension) {
+            int additionalDimensions = dimension - desiredShape.dimension();
+            desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape);
+        }
+        long[] repeats = new long[dimension];
+        for (int i = 0; i < dimension; i++) {
+            if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) {
+                throw new IllegalArgumentException(
+                        "The desired shape is not a multiple of the original shape");
+            }
+            repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i)));
+        }
+        return repeats;
+    }
+
+    /**
+     * Repeats each element of this {@code NDArray} the number of times given by repeats.
+     *
+     * @param repeats the number of times to repeat for each axis
+     * @return an {@code NDArray} that has been repeated
+     */
+    public NDArray repeat(long repeats) {
+        // zero-dim
+        if (isEmpty()) {
+            return duplicate();
+        }
+        // scalar
+        int dim = (isScalar()) ? 1 : getShape().dimension();
+        long[] repeatsArray = new long[dim];
+        Arrays.fill(repeatsArray, repeats);
+        return repeat(repeatsArray);
+    }
+
+    /**
+     * Repeats each element of this {@code NDArray} the given number of times along the given axis.
+     *
+     * @param axis the axis to repeat
+     * @param repeats the number of times to repeat along the given axis
+     * @return an {@code NDArray} that has been repeated
+     * @throws IllegalArgumentException thrown for invalid axis
+     */
+    public NDArray repeat(int axis, long repeats) {
+        long[] repeatsArray = new long[getShape().dimension()];
+        Arrays.fill(repeatsArray, 1);
+        repeatsArray[withAxis(axis)] = repeats;
+        return repeat(repeatsArray);
+    }
+
+    /**
+     * Repeats each element of this {@code NDArray} the given number of times along each axis.
+     *
+     * @param repeats the number of times to repeat along each axis
+     * @return a {@code NDArray} that has been repeated
+     */
+    public NDArray repeat(long[] repeats) {
+        // TODO get rid of for loop once bug in MXNet np.repeat is fixed
+        NDArray array = this;
+        int baseAxis = getShape().dimension() - repeats.length;
+        for (int i = 0; i < repeats.length; i++) {
+            if (repeats[i] > 1) {
+                NDArray previousArray = array;
+                OpParams params = new OpParams();
+                params.addParam("repeats", repeats[i]);
+                params.addParam("axis", baseAxis + i);
+                array = invoke(getParent(), "_np_repeat", array, params);
+                if (previousArray != this) {
+                    previousArray.close();
+                }
+            }
+        }
+        return array;
+    }
+
+    /**
+     * Repeats element of this {@code NDArray} to match the desired shape.
+     *
+     * <p>If the desired {@link Shape} has fewer dimensions than the array, it will repeat against
+     * the last axis.
+     *
+     * @param desiredShape the {@link Shape} that should be converted to
+     * @return an {@code NDArray} that has been repeated
+     */
+    public NDArray repeat(Shape desiredShape) {
+        return repeat(repeatsToMatchShape(desiredShape));
+    }
+
+    /**
+     * Dot product of this {@code NDArray} and the other {@code NDArray}.
+     *
+     * <ul>
+     *   <li>If both this {@code NDArray} and the other {@code NDArray} are 1-D {@code NDArray}s, it
+     *       is inner product of vectors (without complex conjugation).
+     *   <li>If both this {@code NDArray} and the other {@code NDArray} are 2-D {@code NDArray}s, it
+     *       is matrix multiplication.
+     *   <li>If either this {@code NDArray} or the other {@code NDArray} is 0-D {@code NDArray}
+     *       (scalar), it is equivalent to mul.
+     *   <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is 1-D
+     *       {@code NDArray}, it is a sum product over the last axis of those.
+     *   <li>If this {@code NDArray} is N-D {@code NDArray} and the other {@code NDArray} is M-D
+     *       {@code NDArray}(where M&gt;&#61;2), it is a sum product over the last axis of this
+     *       {@code NDArray} and the second-to-last axis of the other {@code NDArray}
+     * </ul>
+     *
+     * @param other the other {@code NDArray} to perform dot product with
+     * @return the result {@code NDArray}
+     */
+    public NDArray dot(NDArray other) {
+        return invoke(getParent(), "_np_dot", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Product matrix of this {@code NDArray} and the other {@code NDArray}.
+     *
+     * @param other the other {@code NDArray} to perform matrix product with
+     * @return the result {@code NDArray}
+     */
+    public NDArray matMul(NDArray other) {
+        if (isScalar() || other.isScalar()) {
+            throw new IllegalArgumentException("scalar is not allowed for matMul()");
+        }
+        return invoke(getParent(), "_npi_matmul", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Clips (limits) the values in this {@code NDArray}.
+     *
+     * <p>Given an interval, values outside the interval are clipped to the interval edges. For
+     * example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values
+     * larger than 1 become 1.
+     *
+     * @param min the minimum value
+     * @param max the maximum value
+     * @return an {@code NDArray} with the elements of this {@code NDArray}, but where values &lt;
+     *     min are replaced with min, and those &gt; max with max
+     */
+    public NDArray clip(Number min, Number max) {
+        OpParams params = new OpParams();
+        params.addParam("a_min", min);
+        params.addParam("a_max", max);
+        return invoke(getParent(), "_npi_clip", this, params);
+    }
+
+    /**
+     * Interchanges two axes of this {@code NDArray}.
+     *
+     * @param axis1 the first axis
+     * @param axis2 the second axis
+     * @return the swapped axes {@code NDArray}
+     */
+    public NDArray swapAxes(int axis1, int axis2) {
+        OpParams params = new OpParams();
+        params.addParam("dim1", axis1);
+        params.addParam("dim2", axis2);
+        return invoke(getParent(), "_npi_swapaxes", this, params);
+    }
+
+    /**
+     * Returns the reverse order of elements in an array along the given axis.
+     *
+     * <p>The shape of the array is preserved, but the elements are reordered.
+     *
+     * @param axes the axes to flip on
+     * @return the newly flipped array
+     */
+    public NDArray flip(int... axes) {
+        OpParams params = new OpParams();
+        params.addTupleParam("axis", axes);
+        return invoke(getParent(), "_npi_flip", this, params);
+    }
+
+    /**
+     * Returns this {@code NDArray} with axes transposed.
+     *
+     * @return the newly permuted array
+     */
+    public NDArray transpose() {
+        return invoke(getParent(), "_np_transpose", this, null);
+    }
+
+    /**
+     * Returns this {@code NDArray} with given axes transposed.
+     *
+     * @param dimensions the axes to swap to
+     * @return the transposed {@code NDArray}
+     * @throws IllegalArgumentException thrown when the given axes are not a permutation of all
+     *     dimensions of this {@code NDArray}
+     */
+    public NDArray transpose(int... dimensions) {
+        if (Arrays.stream(dimensions).anyMatch(d -> d < 0)) {
+            throw new UnsupportedOperationException(
+                    "Passing -1 for broadcasting the dimension is not currently supported");
+        }
+        if (!Arrays.equals(
+                Arrays.stream(dimensions).sorted().toArray(),
+                IntStream.range(0, getShape().dimension()).toArray())) {
+            throw new IllegalArgumentException(
+                    "You must include each of the dimensions from 0 until "
+                            + getShape().dimension());
+        }
+        OpParams params = new OpParams();
+        params.addTupleParam("axes", dimensions);
+        return invoke(getParent(), "_np_transpose", this, params);
+    }
+
+    /**
+     * Broadcasts this {@code NDArray} to be the given shape.
+     *
+     * @param shape the new {@link Shape} of this {@code NDArray}
+     * @return the broadcasted {@code NDArray}
+     */
+    public NDArray broadcast(Shape shape) {
+        OpParams params = new OpParams();
+        params.setShape(shape);
+        return invoke(getParent(), "_npi_broadcast_to", this, params);
+    }
+
+    /**
+     * Returns the indices of the maximum values into the flattened {@code NDArray}.
+     *
+     * @return a {@code NDArray} containing indices
+     */
+    public NDArray argMax() {
+        if (isEmpty()) {
+            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
+        }
+        return invoke(getParent(), "_npi_argmax", this, null);
+    }
+
+    /**
+     * Returns the indices of the maximum values along given axis.
+     *
+     * @param axis the axis along which to find maximum values
+     * @return a {@code NDArray} containing indices
+     */
+    public NDArray argMax(int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npi_argmax", this, params);
+    }
+
+    /**
+     * Returns the indices of the minimum values into the flattened {@code NDArray}.
+     *
+     * @return a {@code NDArray} containing indices
+     */
+    public NDArray argMin() {
+        if (isEmpty()) {
+            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
+        }
+        return invoke(getParent(), "_npi_argmin", this, null);
+    }
+
+    /**
+     * Returns the indices of the minimum values along given axis.
+     *
+     * @param axis the axis along which to find minimum values
+     * @return a {@code NDArray} containing indices
+     */
+    public NDArray argMin(int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        return invoke(getParent(), "_npi_argmin", this, params);
+    }
+
+    /**
+     * Returns percentile for this {@code NDArray}.
+     *
+     * @param percentile the target percentile in range of 0..100
+     * @return the result {@code NDArray}
+     */
+    public NDArray percentile(Number percentile) {
+        throw new UnsupportedOperationException("Not implemented yet.");
+    }
+
+    /**
+     * Returns percentile along given dimension(s).
+     *
+     * @param percentile the target percentile in range of 0..100
+     * @param dimension the dimension to calculate percentile for
+     * @return the result {@code NDArray}
+     */
+    public NDArray percentile(Number percentile, int[] dimension) {
+        throw new UnsupportedOperationException("Not implemented yet.");
+    }
+
+    /**
+     * Returns median value for this {@code NDArray}.
+     *
+     * @return the median {@code NDArray}
+     */
+    public NDArray median() {
+        throw new UnsupportedOperationException("Not implemented yet.");
+    }
+
+    /**
+     * Returns median value along given axes.
+     *
+     * @param axes the axes along which to perform the median operation
+     * @return the median {@code NDArray} along the specified axes
+     */
+    public NDArray median(int[] axes) {
+        throw new UnsupportedOperationException("Not implemented yet.");
+    }
+
+    /**
+     * Returns the indices of elements that are non-zero.
+     *
+     * <p>Note that the behavior is slightly different from numpy.nonzero. NumPy returns a tuple of
+     * arrays, one for each dimension of the input. This method returns only one {@code NDArray}
+     * whose last dimension contains the indices for all dimensions.
+     *
+     * @return the indices of the elements that are non-zero
+     */
+    public NDArray nonzero() {
+        NDArray thisArr =
+                (getDataType() == DataType.BOOLEAN) ? toType(DataType.INT32, false) : this;
+        return invoke(getParent(), "_npx_nonzero", thisArr, null);
+    }
+
+    /**
+     * Returns the element-wise inverse Gauss error function of this {@code NDArray}.
+     *
+     * @return the inverse Gauss error function of this {@code NDArray}, element-wise
+     */
+    public NDArray erfinv() {
+        return invoke(getParent(), "erfinv", this, null);
+    }
+
+    /**
+     * Returns the norm of this {@code NDArray}.
+     *
+     * @param keepDims If this is set to True, the axes which are normed over are left in the result
+     *     as dimensions with size one. With this option the result will broadcast correctly against
+     *     the original x.
+     * @return the norm of this {@code NDArray}
+     */
+    public NDArray norm(boolean keepDims) {
+        OpParams params = new OpParams();
+        params.add("flag", -2);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_npi_norm", this, params);
+    }
+
+    /**
+     * Returns the norm of this {@code NDArray}.
+     *
+     * @param ord Order of the norm.
+     * @param axes If axes contains an integer, it specifies the axis of x along which to compute
+     *     the vector norms. If axis contains 2 integers, it specifies the axes that hold 2-D
+     *     matrices, and the matrix norms of these matrices are computed.
+     * @param keepDims If this is set to True, the axes which are normed over are left in
+     *     the result as dimensions with size one. With this option the result will broadcast
+     *     correctly against the original x.
+     * @return the norm of this {@code NDArray}
+     */
+    public NDArray norm(int ord, int[] axes, boolean keepDims) {
+        OpParams params = new OpParams();
+        params.addParam("ord", (double) ord);
+        params.addTupleParam("axis", axes);
+        params.addParam("keepdims", keepDims);
+        return invoke(getParent(), "_npi_norm", this, params);
+    }
+
+    //    public MxNDArray oneHot(int depth) {
+    //        return LazyNDArray.super.oneHot(depth);
+    //    }
+
+    /**
+     * Returns a one-hot {@code NDArray}.
+     *
+     * <ul>
+     *   <li>The locations represented by indices take value onValue, while all other locations take
+     *       value offValue.
+     *   <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
+     *       appended at the end.
+     *   <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
+     *   <li>If {@code NDArray} is a vector of length features, the output shape will be features x
+     *       depth.
+     *   <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
+     *       batch x features x depth.
+     * </ul>
+     *
+     * @param depth Depth of the one hot dimension.
+     * @param onValue The value assigned to the locations represented by indices.
+     * @param offValue The value assigned to the locations not represented by indices.
+     * @param dataType dataType of the output.
+     * @return one-hot encoding of this {@code NDArray}
+     */
+    public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
+        OpParams params = new OpParams();
+        params.add("depth", depth);
+        params.add("on_value", onValue);
+        params.add("off_value", offValue);
+        params.add("dtype", dataType);
+        return invoke(getParent(), "_npx_one_hot", this, params).toType(dataType, false);
+    }
+
+    /**
+     * Batchwise product of this {@code NDArray} and the other {@code NDArray}.
+     *
+     * <ul>
+     *   <li>batchDot is used to compute dot product of x and y when x and y are data in batch,
+     *       namely N-D (N greater or equal to 3) arrays in shape of (B0, …, B_i, :, :). For
+     *       example, given x with shape (B_0, …, B_i, N, M) and y with shape (B_0, …, B_i, M, K),
+     *       the result array will have shape (B_0, …, B_i, N, K), which is computed by:
+     *       batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :,
+     *       :])
+     * </ul>
+     *
+     * @param other the other {@code NDArray} to perform batch dot product with
+     * @return the result {@code NDArray}
+     */
+    public NDArray batchDot(NDArray other) {
+        return invoke(getParent(), "_npx_batch_dot", new NDArray[] {this, other}, null);
+    }
+
+    /**
+     * Returns an internal representative of Native {@code NDArray}.
+     *
+     * <p>This method should only be used by Engine provider
+     *
+     * @return an internal representative of Native {@code NDArray}
+     */
+    public NDArrayEx getNDArrayInternal() {
+        return mxNDArrayEx;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void close() {
+        if (!getClosed()) {
+            logger.debug(String.format("Start to free NDArray instance: %s", this.getUid()));
+            super.freeSubResources();
+
+            if (this.getHandle() != null) {
+                JnaUtils.freeNdArray(this.getHandle());
+            }
+            setClosed(true);
+            logger.debug(String.format("Finished freeing NDArray instance: %s", this.getUid()));
+        }
+    }
+
+    /**
+     * Returns {@code true} if this {@code NDArray} is the special case of a no-value array.
+     *
+     * @return {@code true} if this NDArray is empty
+     */
+    public boolean isEmpty() {
+        return getShape().size() == 0;
+    }
+
+    boolean isSparse() {
+        return getSparseFormat() != SparseFormat.DENSE;
+    }
+
+    boolean shapeEquals(NDArray other) {
+        return getShape().equals(other.getShape());
+    }
+
+    /**
+     * An engine specific generic invocation to native operation.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause a portability issue. Native operation may not be compatible between
+     * each version.
+     *
+     * @param parent the parent {@link MxResource} of the created {@link NDList}
+     * @param operation the native operation to perform
+     * @param src the {@link NDList} of source {@link NDArray}
+     * @param params the parameters to be passed to the native operation
+     * @return the output {@link NDList}
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static NDList invoke(
+            MxResource parent, String operation, NDList src, PairList<String, ?> params) {
+        return new NDList(JnaUtils.op(operation).invoke(parent, src.toArray(EMPTY), params));
+    }
+
+    /**
+     * An engine specific generic invocation to native operator.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause portability issues. A native operation may not be compatible between
+     * each version.
+     *
+     * @param operation the native operation to perform
+     * @param src the {@link NDList} of source {@link NDArray}
+     * @param dest the {@link NDList} to save output to
+     * @param params the parameters to be passed to the native operator
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static void invoke(
+            String operation, NDList src, NDList dest, PairList<String, ?> params) {
+        invoke(operation, src.toArray(EMPTY), dest.toArray(EMPTY), params);
+    }
+
+    /**
+     * An engine specific generic invocation to native operator.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause portability issues. A native operation may not be compatible between
+     * each version.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param operation the native operation to perform
+     * @param src the array of source {@link NDArray}
+     * @param params the parameters to be passed to the native operator
+     * @return the first output {@link NDArray}
+     */
+    public static NDArray invoke(
+            MxResource parent, String operation, NDArray[] src, PairList<String, ?> params) {
+        return JnaUtils.op(operation).invoke(parent, src, params)[0];
+    }
+
+    /**
+     * An engine specific generic invocation to native operation.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause a portability issue. Native operation may not be compatible between
+     * each version.
+     *
+     * @param operation the native operation to perform
+     * @param src the array of source {@link NDArray}
+     * @param dest the array of {@link NDArray} to save output to
+     * @param params the parameters to be passed to the native operation
+     * @throws IllegalArgumentException if operation is not supported by Engine
+     */
+    public static void invoke(
+            String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
+        JnaUtils.op(operation).invoke(src, dest, params);
+    }
+
+    /**
+     * An engine specific generic invocation to native operator.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause portability issues. A native operation may not be compatible between
+     * each version.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param operation the native operation to perform
+     * @param src the source {@link NDArray}
+     * @param params the parameters to be passed to the native operator
+     * @return the first output {@link NDArray} of the operation
+     */
+    public static NDArray invoke(
+            MxResource parent, String operation, NDArray src, PairList<String, ?> params) {
+        return invoke(parent, operation, new NDArray[] {src}, params);
+    }
+
+    /**
+     * An engine specific generic invocation to native operator.
+     *
+     * <p>You should avoid using this function if possible. Since this function is engine specific,
+     * using this API may cause portability issues. A native operation may not be compatible across
+     * versions.
+     *
+     * @param parent the parent {@link MxResource} to manage this instance
+     * @param operation the native operation to perform
+     * @param params the parameters to be passed to the native operator
+     * @return the first output {@link NDArray} of the operation
+     */
+    public static NDArray invoke(MxResource parent, String operation, PairList<String, ?> params) {
+        return invoke(parent, operation, EMPTY, params);
+    }
+
+    /**
+     * Encodes this {@code NDArray} to a byte array.
+     *
+     * @return byte array
+     */
+    public byte[] encode() {
+        return NDSerializer.encode(this);
+    }
+
+    /**
+     * Draws samples from a uniform distribution.
+     *
+     * <p>Samples are uniformly distributed over the half-open interval [low, high) (includes low,
+     * but excludes high). In other words, any value within the given interval is equally likely to
+     * be drawn by uniform.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param low the lower boundary of the output interval. All values generated will be greater
+     *     than or equal to low.
+     * @param high the upper boundary of the output interval. All values generated will be less than
+     *     high.
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomUniform(
+            MxResource parent,
+            float low,
+            float high,
+            Shape shape,
+            DataType dataType,
+            Device device) {
+        OpParams params = new OpParams();
+        params.addParam("low", low);
+        params.addParam("high", high);
+        params.addParam("size", shape);
+        params.setDevice(device);
+        params.setDataType(dataType);
+        return invoke(parent, "_npi_uniform", params);
+    }
+
+    /**
+     * Draws samples from a uniform distribution.
+     *
+     * <p>Samples are uniformly distributed over the half-open interval [low, high) (includes low,
+     * but excludes high). In other words, any value within the given interval is equally likely to
+     * be drawn by uniform.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param low the lower boundary of the output interval. All values generated will be greater
+     *     than or equal to low.
+     * @param high the upper boundary of the output interval. All values generated will be less than
+     *     high.
+     * @param shape the {@link Shape} of the {@link NDArray}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    private static NDArray randomUniform(
+            MxResource parent, float low, float high, Shape shape, DataType dataType) {
+        return randomUniform(parent, low, high, shape, dataType, Device.defaultIfNull(null));
+    }
+
+    /**
+     * Draws random samples from a normal (Gaussian) distribution.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param loc the mean (centre) of the distribution
+     * @param scale the standard deviation (spread or "width") of the distribution
+     * @param shape the output {@link Shape}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @param device the {@link Device} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomNormal(
+            MxResource parent,
+            float loc,
+            float scale,
+            Shape shape,
+            DataType dataType,
+            Device device) {
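+        OpParams params = new OpParams();
+        params.addParam("loc", loc);
+        params.addParam("scale", scale);
+        params.addParam("size", shape);
+        // Honor the requested device; fall back to the default device when it is null.
+        params.setDevice(Device.defaultIfNull(device));
+        params.setDataType(dataType);
+        return invoke(parent, "_npi_normal", params);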
+    }
+
+    /**
+     * Draws random samples from a normal (Gaussian) distribution.
+     *
+     * @param parent {@link MxResource} of this instance
+     * @param loc the mean (centre) of the distribution
+     * @param scale the standard deviation (spread or "width") of the distribution
+     * @param shape the output {@link Shape}
+     * @param dataType the {@link DataType} of the {@link NDArray}
+     * @return the drawn samples {@link NDArray}
+     */
+    public static NDArray randomNormal(
+            MxResource parent, float loc, float scale, Shape shape, DataType dataType) {
+        OpParams params = new OpParams();
+        params.addParam("loc", loc);
+        params.addParam("scale", scale);
+        params.addParam("size", shape);
+        params.setDevice(Device.defaultIfNull(null));
+        params.setDataType(dataType);
+        return invoke(parent, "_npi_normal", params);
+    }
+
+    /**
+     * Decodes {@link NDArray} through byte array.
+     *
+     * @param parent the parent {@link MxResource} to create the {@link NDArray}
+     * @param bytes byte array to load from
+     * @return {@link NDArray}
+     */
+    static NDArray decode(MxResource parent, byte[] bytes) {
+        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
+            return NDSerializer.decode(parent, dis);
+        } catch (IOException e) {
+            throw new IllegalArgumentException("NDArray decoding failed", e);
+        }
+    }
+
+    /**
+     * Decodes {@link NDArray} through {@link InputStream}.
+     *
+     * @param parent the parent {@link MxResource} to create the {@link NDArray}
+     * @param is input stream data to load from
+     * @return {@link NDArray}
+     * @throws IOException if the data is not readable
+     */
+    public static NDArray decode(MxResource parent, InputStream is) throws IOException {
+        return NDSerializer.decode(parent, is);
+    }
+
+    /**
+     * Converts this {@code NDArray} to a Number array based on its {@link DataType}.
+     *
+     * @return a Number array
+     */
+    public Number[] toArray() {
+        switch (getDataType()) {
+            case FLOAT16:
+            case FLOAT32:
+                float[] floatArray = toFloatArray();
+                return IntStream.range(0, floatArray.length)
+                        .mapToObj(i -> floatArray[i])
+                        .toArray(Number[]::new);
+            case FLOAT64:
+                return Arrays.stream(toDoubleArray()).boxed().toArray(Double[]::new);
+            case INT32:
+                return Arrays.stream(toIntArray()).boxed().toArray(Integer[]::new);
+            case INT64:
+                return Arrays.stream(toLongArray()).boxed().toArray(Long[]::new);
+            case BOOLEAN:
+            case INT8:
+                ByteBuffer bb = toByteBuffer();
+                Byte[] ret = new Byte[bb.remaining()];
+                for (int i = 0; i < ret.length; ++i) {
+                    ret[i] = bb.get();
+                }
+                return ret;
+            case UINT8:
+                return Arrays.stream(toUint8Array()).boxed().toArray(Integer[]::new);
+            default:
+                throw new IllegalStateException("Unsupported DataType: " + getDataType());
+        }
+    }
+
+    /**
+     * Converts this {@code NDArray} to a boolean array.
+     *
+     * @return a boolean array
+     * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+     */
+    public boolean[] toBooleanArray() {
+        if (getDataType() != DataType.BOOLEAN) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required boolean, Actual " + getDataType());
+        }
+        ByteBuffer bb = toByteBuffer();
+        boolean[] ret = new boolean[bb.remaining()];
+        for (int i = 0; i < ret.length; ++i) {
+            ret[i] = bb.get() != 0;
+        }
+        return ret;
+    }
+
+    /**
+     * Converts this {@code NDArray} to a double array.
+     *
+     * @return a double array
+     * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+     */
+    public double[] toDoubleArray() {
+        if (getDataType() != DataType.FLOAT64) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required double, Actual " + getDataType());
+        }
+        DoubleBuffer db = toByteBuffer().asDoubleBuffer();
+        double[] ret = new double[db.remaining()];
+        db.get(ret);
+        return ret;
+    }
+
+    /**
+     * Converts this {@code NDArray} to a float array.
+     *
+     * @return a float array
+     * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+     */
+    public float[] toFloatArray() {
+        if (getDataType() == DataType.FLOAT16) {
+            return Float16Utils.fromByteBuffer(toByteBuffer());
+        } else if (getDataType() != DataType.FLOAT32) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required float, Actual " + getDataType());
+        }
+        FloatBuffer fb = toByteBuffer().asFloatBuffer();
+        float[] ret = new float[fb.remaining()];
+        fb.get(ret);
+        return ret;
+    }
+
+    /**
+     * Converts this {@code NDArray} to an int array.
+     *
+     * @return an int array
+     * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+     */
+    public int[] toIntArray() {
+        if (getDataType() != DataType.INT32) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required int, Actual " + getDataType());
+        }
+        IntBuffer ib = toByteBuffer().asIntBuffer();
+        int[] ret = new int[ib.remaining()];
+        ib.get(ret);
+        return ret;
+    }
+
+    /**
+     * Converts this {@code NDArray} to a long array.
+     *
+     * @return a long array
+     * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
+     */
+    public long[] toLongArray() {
+        if (getDataType() != DataType.INT64) {
+            throw new IllegalStateException(
+                    "DataType mismatch, Required long, Actual " + getDataType());
+        }
+        LongBuffer lb = toByteBuffer().asLongBuffer();
+        long[] ret = new long[lb.remaining()];
+        lb.get(ret);
+        return ret;
+    }
+
+    /**
+     * Converts this {@code NDArray} to a byte array.
+     *
+     * @return a byte array holding the raw bytes of this {@code NDArray}
+     */
+    public byte[] toByteArray() {
+        ByteBuffer bb = toByteBuffer();
+        if (bb.hasArray()) {
+            return bb.array();
+        }
+        byte[] buf = new byte[bb.remaining()];
+        bb.get(buf);
+        return buf;
+    }
+
+    /**
+     * Converts this {@code NDArray} to a uint8 array.
+     *
+     * @return an int array holding each byte as an unsigned 8-bit value
+     */
+    public int[] toUint8Array() {
+        ByteBuffer bb = toByteBuffer();
+        int[] buf = new int[bb.remaining()];
+        for (int i = 0; i < buf.length; ++i) {
+            buf[i] = bb.get() & 0xff;
+        }
+        return buf;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        if (getClosed()) {
+            return "This array is already closed";
+        }
+        return toDebugString(MAX_SIZE, MAX_DEPTH, MAX_ROWS, MAX_COLUMNS);
+    }
+
+    /**
+     * Returns the debug string representation of this {@code NDArray}.
+     *
+     * @param maxSize the maximum elements to print out
+     * @param maxDepth the maximum depth to print out
+     * @param maxRows the maximum rows to print out
+     * @param maxColumns the maximum columns to print out
+     * @return the debug string representation of this {@code NDArray}
+     */
+    String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) {
+        return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns);
+    }
+}
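Review note on the buffer conversions in NDArray.java above: the `toUint8Array` and `toFloatArray` methods both read from the `ByteBuffer` returned by `toByteBuffer()`. The standalone sketch below reproduces those two patterns using only `java.nio`, so they can be exercised without the native library. The class name `BufferConversions` and its static methods are hypothetical and are not part of this patch; the sketch also assumes the caller has already set the byte order it wants on the buffer.

```java
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;

/** Hypothetical helper, not part of the patch: mirrors the buffer-to-array patterns above. */
final class BufferConversions {

    private BufferConversions() {}

    /** Reads the remaining bytes as unsigned values in the range [0, 255]. */
    static int[] toUint8Array(ByteBuffer bb) {
        int[] out = new int[bb.remaining()];
        for (int i = 0; i < out.length; ++i) {
            // Java bytes are signed; masking with 0xff recovers the unsigned value.
            out[i] = bb.get() & 0xff;
        }
        return out;
    }

    /** Reinterprets the remaining bytes as 32-bit floats, using the buffer's byte order. */
    static float[] toFloatArray(ByteBuffer bb) {
        FloatBuffer fb = bb.asFloatBuffer();
        float[] out = new float[fb.remaining()];
        fb.get(out);
        return out;
    }
}
```

For example, `BufferConversions.toUint8Array(ByteBuffer.wrap(new byte[] {(byte) 0xff}))` yields a single element equal to 255, whereas reading the same byte without the mask would give -1.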
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
new file mode 100644
index 0000000..047ff04
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayEx.java
@@ -0,0 +1,1107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.mxnet.engine.Device;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.jna.JnaUtils;
+import org.apache.mxnet.ndarray.types.DataType;
+import org.apache.mxnet.ndarray.types.Shape;
+import org.apache.mxnet.ndarray.types.SparseFormat;
+
+/** An internal helper class that encapsulates engine-specific operations. */
+@SuppressWarnings("MissingJavadocMethod")
+public class NDArrayEx {
+
+    private static final NDArrayIndexer INDEXER = new NDArrayIndexer();
+
+    private NDArray array;
+
+    /**
+     * Constructs an {@code NDArrayEx} given a {@link NDArray}.
+     *
+     * @param parent the {@link NDArray} to extend
+     */
+    NDArrayEx(NDArray parent) {
+        this.array = parent;
+    }
+
+    // TODO: only used to calculate the zero-dim numpy shape;
+    // remove it once MXNet has all the np ops that we support
+    private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
+        long[] result = new long[Math.max(lhs.dimension(), rhs.dimension())];
+        long lDiff = result.length - lhs.dimension();
+        long rDiff = result.length - rhs.dimension();
+        for (int i = 0; i < result.length; i++) {
+            long l = 1;
+            long r = 1;
+            if (i >= lDiff) {
+                l = lhs.get(Math.toIntExact(i - lDiff));
+            }
+            if (i >= rDiff) {
+                r = rhs.get(Math.toIntExact(i - rDiff));
+            }
+            if (l != r) {
+                if (l != 1 && r != 1) {
+                    throw new IllegalArgumentException(
+                            "operands could not be broadcast together with shapes "
+                                    + lhs
+                                    + " "
+                                    + rhs);
+                }
+                result[i] = (l == 1) ? r : l;
+            } else {
+                result[i] = l;
+            }
+        }
+        return new Shape(result);
+    }
+
+    ////////////////////////////////////////
+    // NDArrays
+    ////////////////////////////////////////
+    /**
+     * Applies reverse division with a scalar - i.e., (n / thisArrayValues).
+     *
+     * @param n the Value to use for reverse division
+     * @return a copy of the array after applying reverse division
+     */
+    public NDArray rdiv(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return NDArray.invoke(getArray().getParent(), "_rdiv_scalar", array, params);
+    }
+
+    /**
+     * Applies reverse division with an {@code NDArray} - i.e., (b / thisArrayValues).
+     *
+     * @param b the ndarray to use for reverse division
+     * @return a copy of the array after applying reverse division
+     */
+    public NDArray rdiv(NDArray b) {
+        return b.div(array);
+    }
+
+    /**
+     * Applies in place reverse division - i.e., (n / thisArrayValues).
+     *
+     * @param n the value to use for reverse division
+     * @return this array after applying reverse division
+     */
+    public NDArray rdivi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        NDArray.invoke("_rdiv_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+        return array;
+    }
+
+    /**
+     * Applies in place reverse division - i.e., (b / thisArrayValues).
+     *
+     * @param b the ndarray to use for reverse division
+     * @return this array after applying reverse division
+     */
+    public NDArray rdivi(NDArray b) {
+        NDArray.invoke("elemwise_div", new NDArray[] {b, array}, new NDArray[] {array}, null);
+        return array;
+    }
+
+    /**
+     * Applies reverse subtraction with a scalar - i.e., (n - thisArrayValues).
+     *
+     * @param n the value to use for reverse subtraction
+     * @return a copy of array after reverse subtraction
+     */
+    public NDArray rsub(Number n) {
+        return array.sub(n).neg();
+    }
+
+    /**
+     * Applies reverse subtraction with an {@code NDArray} - i.e., (b - thisArrayValues).
+     *
+     * @param b the ndarray to use for reverse subtraction
+     * @return a copy of the array after reverse subtraction
+     */
+    public NDArray rsub(NDArray b) {
+        return array.sub(b).neg();
+    }
+
+    /**
+     * Applies reverse subtraction in place - i.e., (n - thisArrayValues).
+     *
+     * @param n the value to use for reverse subtraction
+     * @return this array after reverse subtraction
+     */
+    public NDArray rsubi(Number n) {
+        return array.subi(n).negi();
+    }
+
+    /**
+     * Applies reverse subtraction in place - i.e., (b - thisArrayValues).
+     *
+     * @param b the ndarray to use for reverse subtraction
+     * @return this array after reverse subtraction
+     */
+    public NDArray rsubi(NDArray b) {
+        return array.subi(b).negi();
+    }
+
+    public NDArray rmod(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return NDArray.invoke(getArray().getParent(), "_npi_rmod_scalar", array, params);
+    }
+
+    /**
+     * Applies reverse remainder of division with an {@code NDArray}.
+     *
+     * @param b the ndarray to use for reverse division
+     * @return a copy of array after applying reverse division
+     */
+    public NDArray rmod(NDArray b) {
+        return b.mod(array);
+    }
+
+    /**
+     * Applies in place reverse remainder of division with a scalar.
+     *
+     * @param n the value to use for reverse division
+     * @return this array after applying reverse division
+     */
+    public NDArray rmodi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        NDArray.invoke("_npi_rmod_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+        return array;
+    }
+
+    /**
+     * Applies in place reverse remainder of division.
+     *
+     * @param b the ndarray to use for reverse division
+     * @return this array after applying reverse division
+     */
+    public NDArray rmodi(NDArray b) {
+        NDArray.invoke("_npi_mod", new NDArray[] {b, array}, new NDArray[] {array}, null);
+        return array;
+    }
+
+    /**
+     * Raises {@code n} to the power of each element of this {@code NDArray}.
+     *
+     * @param n the value to use for reverse power
+     * @return a copy of array after applying reverse power
+     */
+    public NDArray rpow(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        return NDArray.invoke(getArray().getParent(), "_npi_rpower_scalar", array, params);
+    }
+
+    /**
+     * Raises {@code n} to the power of each element of this {@code NDArray} in place.
+     *
+     * @param n the value to use for reverse power
+     * @return this array after applying reverse power
+     */
+    public NDArray rpowi(Number n) {
+        OpParams params = new OpParams();
+        params.add("scalar", n.toString());
+        NDArray.invoke("_npi_rpower_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+        return array;
+    }
+
+    ////////////////////////////////////////
+    // Activations
+    ////////////////////////////////////////
+    /**
+     * Computes rectified linear activation.
+     *
+     * @return a copy of array after applying relu
+     */
+    public NDArray relu() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "relu");
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params);
+    }
+
+    public NDArray sigmoid() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "sigmoid");
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params);
+    }
+
+    public NDArray tanh() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "tanh");
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params);
+    }
+
+    public NDArray softPlus() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "softrelu");
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params);
+    }
+
+    public NDArray softSign() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "softsign");
+        return NDArray.invoke(getArray().getParent(), "_npx_activation", array, params);
+    }
+
+    public NDArray leakyRelu(float alpha) {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "leaky");
+        params.addParam("slope", alpha);
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params);
+    }
+
+    public NDArray elu(float alpha) {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "elu");
+        params.addParam("slope", alpha);
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params);
+    }
+
+    public NDArray selu() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "selu");
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params);
+    }
+
+    public NDArray gelu() {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "gelu");
+        return NDArray.invoke(getArray().getParent(), "_npx_leaky_relu", array, params);
+    }
+
+    ////////////////////////////////////////
+    // Pooling Operations
+    ////////////////////////////////////////
+
+    public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+        OpParams params = new OpParams();
+        params.addParam("kernel", kernelShape);
+        params.add("pool_type", "max");
+        params.addParam("stride", stride);
+        params.addParam("pad", padding);
+        params.add("pooling_convention", ceilMode ? "full" : "valid");
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params);
+    }
+
+    public NDArray globalMaxPool() {
+        OpParams params = new OpParams();
+        params.add("kernel", getGlobalPoolingShapes(1));
+        params.add("pad", getGlobalPoolingShapes(0));
+        params.add("pool_type", "max");
+        params.addParam("global_pool", true);
+        try (NDArray temp =
+                NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) {
+            return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+        }
+    }
+
+    public NDArray avgPool(
+            Shape kernelShape,
+            Shape stride,
+            Shape padding,
+            boolean ceilMode,
+            boolean countIncludePad) {
+        OpParams params = new OpParams();
+        params.addParam("kernel", kernelShape);
+        params.add("pool_type", "avg");
+        params.addParam("stride", stride);
+        params.addParam("pad", padding);
+        params.add("pooling_convention", ceilMode ? "full" : "valid");
+        params.addParam("count_include_pad", countIncludePad);
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params);
+    }
+
+    public NDArray globalAvgPool() {
+        OpParams params = new OpParams();
+        params.add("kernel", getGlobalPoolingShapes(1));
+        params.add("pad", getGlobalPoolingShapes(0));
+        params.add("pool_type", "avg");
+        params.addParam("global_pool", true);
+        try (NDArray temp =
+                NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) {
+            return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+        }
+    }
+
+    public NDArray lpPool(
+            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
+        if (((int) normType) != normType) {
+            throw new IllegalArgumentException(
+                    "float type of normType is not supported in MXNet engine, please use integer instead");
+        }
+        OpParams params = new OpParams();
+        params.addParam("p_value", (int) normType);
+        params.addParam("kernel", kernelShape);
+        params.add("pool_type", "lp");
+        params.addParam("stride", stride);
+        params.addParam("pad", padding);
+        params.add("pooling_convention", ceilMode ? "full" : "valid");
+
+        return NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params);
+    }
+
+    public NDArray globalLpPool(float normType) {
+        if (((int) normType) != normType) {
+            throw new IllegalArgumentException(
+                    "float type of normType is not supported in MXNet engine, please use integer instead");
+        }
+        OpParams params = new OpParams();
+        params.add("pool_type", "lp");
+        params.addParam("p_value", (int) normType);
+        params.addParam("global_pool", true);
+        try (NDArray temp =
+                NDArray.invoke(getArray().getParent(), "_npx_pooling", getArray(), params)) {
+            return temp.reshape(temp.getShape().size(0), temp.getShape().size(1));
+        }
+    }
+
+    ////////////////////////////////////////
+    // Optimizer
+    ////////////////////////////////////////
+
+    //    public void adadeltaUpdate(
+    //            MxNDList inputs,
+    //            MxNDList weights,
+    //            float weightDecay,
+    //            float rescaleGrad,
+    //            float clipGrad,
+    //            float rho,
+    //            float epsilon) {
+    //        MxNDArray weight = inputs.get(0);
+    //        MxNDArray grad = inputs.get(1);
+    //        MxNDArray s = inputs.get(2);
+    //        MxNDArray delta = inputs.get(3);
+    //
+    //        // create a baseManager to close all intermediate MxNDArrays
+    //        try (NDManager subManager = NDManager.newBaseManager()) {
+    //            subManager.tempAttachAll(inputs, weights);
+    //
+    //            // Preprocess Gradient
+    //            grad.muli(rescaleGrad);
+    //            if (clipGrad > 0) {
+    //                grad = grad.clip(-clipGrad, clipGrad);
+    //            }
+    //            grad.addi(weight.mul(weightDecay));
+    //
+    //            // Update s, g, and delta
+    //            s.muli(rho).addi(grad.square().mul(1 - rho));
+    //            MxNDArray g = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad);
+    //            delta.muli(rho).addi(g.square().mul(1 - rho));
+    //
+    //            // Update weight
+    //            weight.subi(g);
+    //        }
+    //    }
+
+    public void adagradUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float epsilon) {
+        OpParams params = new OpParams();
+        params.addParam("lr", learningRate);
+        params.addParam("wd", weightDecay);
+        params.addParam("rescale_grad", rescaleGrad);
+        params.addParam("clip_gradient", clipGrad);
+
+        params.addParam("epsilon", epsilon);
+
+        NDArray.invoke("adagrad_update", inputs, weights, params);
+    }
+
+    public void adamUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float beta1,
+            float beta2,
+            float epsilon,
+            boolean lazyUpdate) {
+        OpParams params = new OpParams();
+        params.addParam("lr", learningRate);
+        params.addParam("wd", weightDecay);
+        params.addParam("rescale_grad", rescaleGrad);
+        params.addParam("clip_gradient", clipGrad);
+
+        params.addParam("beta1", beta1);
+        params.addParam("beta2", beta2);
+        params.addParam("epsilon", epsilon);
+        params.addParam("lazy_update", lazyUpdate);
+
+        NDArray.invoke("adam_update", inputs, weights, params);
+    }
+
+    public void rmspropUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float gamma1,
+            float gamma2,
+            float epsilon,
+            boolean centered) {
+        OpParams params = new OpParams();
+        params.addParam("lr", learningRate);
+        params.addParam("wd", weightDecay);
+        params.addParam("rescale_grad", rescaleGrad);
+        params.addParam("clip_gradient", clipGrad);
+
+        params.addParam("gamma1", gamma1);
+        params.addParam("epsilon", epsilon);
+
+        if (!centered) {
+            NDArray.invoke("rmsprop_update", inputs, weights, params);
+        } else {
+            params.addParam("gamma2", gamma2);
+
+            NDArray.invoke("rmspropalex_update", inputs, weights, params);
+        }
+    }
+
+    public void nagUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float momentum) {
+        OpParams params = new OpParams();
+        params.addParam("lr", learningRate);
+        params.addParam("wd", weightDecay);
+        params.addParam("rescale_grad", rescaleGrad);
+        params.addParam("clip_gradient", clipGrad);
+        params.addParam("momentum", momentum);
+        NDArray.invoke("nag_mom_update", inputs, weights, params);
+    }
+
+    public void sgdUpdate(
+            NDList inputs,
+            NDList weights,
+            float learningRate,
+            float weightDecay,
+            float rescaleGrad,
+            float clipGrad,
+            float momentum,
+            boolean lazyUpdate) {
+        OpParams params = new OpParams();
+        params.addParam("lr", learningRate);
+        params.addParam("wd", weightDecay);
+        params.addParam("rescale_grad", rescaleGrad);
+        params.addParam("clip_gradient", clipGrad);
+        params.addParam("lazy_update", lazyUpdate);
+
+        if (momentum != 0) {
+            params.addParam("momentum", momentum);
+            NDArray.invoke("sgd_mom_update", inputs, weights, params);
+        } else {
+            NDArray.invoke("sgd_update", inputs, weights, params);
+        }
+    }
+
+    ////////////////////////////////////////
+    // Neural network
+    ////////////////////////////////////////
+
+    public NDList convolution(
+            NDArray input,
+            NDArray weight,
+            NDArray bias,
+            Shape stride,
+            Shape padding,
+            Shape dilation,
+            int groups) {
+        OpParams params = new OpParams();
+        params.addParam("kernel", weight.getShape().slice(2));
+        params.addParam("stride", stride);
+        params.addParam("pad", padding);
+        params.addParam("dilate", dilation);
+        params.addParam("num_group", groups);
+        params.addParam("num_filter", weight.getShape().get(0));
+
+        NDList inputs = new NDList(input, weight);
+        if (bias != null) {
+            params.add("no_bias", false);
+            inputs.add(bias);
+        } else {
+            params.add("no_bias", true);
+        }
+
+        return NDArray.invoke(getArray().getParent(), "_npx_convolution", inputs, params);
+    }
+
+    public NDList deconvolution(
+            NDArray input,
+            NDArray weight,
+            NDArray bias,
+            Shape stride,
+            Shape padding,
+            Shape outPadding,
+            Shape dilation,
+            int groups) {
+        OpParams params = new OpParams();
+        params.addParam("kernel", weight.getShape().slice(2));
+        params.addParam("stride", stride);
+        params.addParam("pad", padding);
+        params.addParam("adj", outPadding);
+        params.addParam("dilate", dilation);
+        params.addParam("num_group", groups);
+        params.addParam("num_filter", weight.getShape().get(0));
+
+        NDList inputs = new NDList(input, weight);
+        if (bias != null) {
+            params.add("no_bias", false);
+            inputs.add(bias);
+        } else {
+            params.add("no_bias", true);
+        }
+
+        return NDArray.invoke(getArray().getParent(), "_npx_deconvolution", inputs, params);
+    }
+
+    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
+        OpParams params = new OpParams();
+        params.addParam("num_hidden", weight.size(0));
+        params.addParam("flatten", false);
+        params.addParam("no_bias", bias == null);
+        NDList inputs = new NDList(input, weight);
+        if (bias != null) {
+            inputs.add(bias);
+        }
+
+        return NDArray.invoke(getArray().getParent(), "_npx_fully_connected", inputs, params);
+    }
+
+    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) {
+        if (!sparse.equals(SparseFormat.DENSE) && !sparse.equals(SparseFormat.ROW_SPARSE)) {
+            throw new IllegalArgumentException("MXNet only supports dense or row sparse");
+        }
+        OpParams params = new OpParams();
+        long inputDim = weight.getShape().get(0);
+        long outputDim = weight.getShape().get(1);
+        params.addParam("input_dim", inputDim);
+        params.addParam("output_dim", outputDim);
+        params.addParam("sparse_grad", sparse.getValue());
+        return NDArray.invoke(
+                getArray().getParent(), "_npx_embedding", new NDList(input, weight), params);
+    }
+
+    public NDList prelu(NDArray input, NDArray alpha) {
+        OpParams params = new OpParams();
+        params.addParam("act_type", "prelu");
+        return NDArray.invoke(
+                getArray().getParent(), "_npx_leaky_relu", new NDList(input, alpha), params);
+    }
+
+    public NDList dropout(NDArray input, float rate, boolean training) {
+        if (training != JnaUtils.autogradIsTraining()) {
+            throw new IllegalArgumentException(
+                    "the mode of dropout in MXNet should align with the mode of GradientCollector");
+        }
+
+        OpParams params = new OpParams();
+        params.addParam("p", rate);
+
+        return NDArray.invoke(getArray().getParent(), "_npx_dropout", new NDList(input), params);
+    }
+
+    public NDList batchNorm(
+            NDArray input,
+            NDArray runningMean,
+            NDArray runningVar,
+            NDArray gamma,
+            NDArray beta,
+            int axis,
+            float momentum,
+            float eps,
+            boolean training) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        params.addParam("fix_gamma", gamma == null);
+        params.addParam("eps", eps);
+        params.addParam("momentum", momentum);
+
+        if (training != JnaUtils.autogradIsTraining()) {
+            throw new IllegalArgumentException(
+                    "the mode of batchNorm in MXNet should align with the mode of GradientCollector");
+        }
+
+        return NDArray.invoke(
+                getArray().getParent(),
+                "_npx_batch_norm",
+                new NDList(input, gamma, beta, runningMean, runningVar),
+                params);
+    }
+
+    //    public MxNDList rnn(
+    //            MxNDArray input,
+    //            MxNDArray state,
+    //            MxNDList params,
+    //            boolean hasBiases,
+    //            int numLayers,
+    //            RNN.Activation activation,
+    //            double dropRate,
+    //            boolean training,
+    //            boolean bidirectional,
+    //            boolean batchFirst) {
+    //        int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+    //        Preconditions.checkArgument(
+    //                params.size() == numParams,
+    //                "The size of Params is incorrect expect "
+    //                        + numParams
+    //                        + " parameters but got "
+    //                        + params.size());
+    //
+    //        if (training != JnaUtils.autogradIsTraining()) {
+    //            throw new IllegalArgumentException(
+    //                    "the mode of rnn in MXNet should align with the mode of
+    // GradientCollector");
+    //        }
+    //
+    //        if (batchFirst) {
+    //            input = input.swapAxes(0, 1);
+    //        }
+    //
+    //        MxOpParams opParams = new MxOpParams();
+    //        opParams.addParam("p", dropRate);
+    //        opParams.addParam("state_size", state.getShape().tail());
+    //        opParams.addParam("num_layers", numLayers);
+    //        opParams.addParam("bidirectional", bidirectional);
+    //        opParams.addParam("state_outputs", true);
+    //        opParams.addParam("mode", activation == RNN.Activation.TANH ? "rnn_tanh" :
+    // "rnn_relu");
+    //
+    //        MxNDList inputs = new MxNDList();
+    //        inputs.add(input);
+    //
+    //        try (MxNDList temp = new MxNDList()) {
+    //            for (MxNDArray param : params) {
+    //                temp.add(param.flatten());
+    //            }
+    //            MxNDArray tempParam = MxNDArrays.concat(temp);
+    //            tempParam.attach(input.getManager());
+    //            inputs.add(tempParam);
+    //        }
+    //
+    //        inputs.add(state);
+    //
+    //        if (!batchFirst) {
+    //            return getManager().invoke("_npx_rnn", inputs, opParams);
+    //        }
+    //
+    //        MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+    //        try (MxNDArray temp = result.head()) {
+    //            return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+    //        }
+    //    }
+
+    //    public MxNDList gru(
+    //            MxNDArray input,
+    //            MxNDArray state,
+    //            MxNDList params,
+    //            boolean hasBiases,
+    //            int numLayers,
+    //            double dropRate,
+    //            boolean training,
+    //            boolean bidirectional,
+    //            boolean batchFirst) {
+    //        int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+    //        Preconditions.checkArgument(
+    //                params.size() == numParams,
+    //                "The size of Params is incorrect expect "
+    //                        + numParams
+    //                        + " parameters but got "
+    //                        + params.size());
+    //
+    //        if (training != JnaUtils.autogradIsTraining()) {
+    //            throw new IllegalArgumentException(
+    //                    "the mode of gru in MXNet should align with the mode of
+    // GradientCollector");
+    //        }
+    //
+    //        if (batchFirst) {
+    //            input = input.swapAxes(0, 1);
+    //        }
+    //
+    //        MxOpParams opParams = new MxOpParams();
+    //        opParams.addParam("p", dropRate);
+    //        opParams.addParam("state_size", state.getShape().tail());
+    //        opParams.addParam("num_layers", numLayers);
+    //        opParams.addParam("bidirectional", bidirectional);
+    //        opParams.addParam("state_outputs", true);
+    //        opParams.addParam("mode", "gru");
+    //
+    //        MxNDList inputs = new MxNDList();
+    //        inputs.add(input);
+    //
+    //        try (MxNDList temp = new MxNDList()) {
+    //            for (MxNDArray param : params) {
+    //                temp.add(param.flatten());
+    //            }
+    //            MxNDArray tempParam = MxNDArrays.concat(temp);
+    //            tempParam.attach(input.getManager());
+    //            inputs.add(tempParam);
+    //        }
+    //
+    //        inputs.add(state);
+    //
+    //        if (!batchFirst) {
+    //            return getManager().invoke("_npx_rnn", inputs, opParams);
+    //        }
+    //
+    //        MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+    //        try (MxNDArray temp = result.head()) {
+    //            return new MxNDList(temp.swapAxes(0, 1), result.get(1));
+    //        }
+    //    }
+    //
+    //    public MxNDList lstm(
+    //            MxNDArray input,
+    //            MxNDList states,
+    //            MxNDList params,
+    //            boolean hasBiases,
+    //            int numLayers,
+    //            double dropRate,
+    //            boolean training,
+    //            boolean bidirectional,
+    //            boolean batchFirst) {
+    //        int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
+    //        Preconditions.checkArgument(
+    //                params.size() == numParams,
+    //                "The size of Params is incorrect expect "
+    //                        + numParams
+    //                        + " parameters but got "
+    //                        + params.size());
+    //
+    //        if (training != JnaUtils.autogradIsTraining()) {
+    //            throw new IllegalArgumentException(
+    //                    "the mode of lstm in MXNet should align with the mode of
+    // GradientCollector");
+    //        }
+    //
+    //        if (batchFirst) {
+    //            input = input.swapAxes(0, 1);
+    //        }
+    //
+    //        MxOpParams opParams = new MxOpParams();
+    //        opParams.addParam("mode", "lstm");
+    //        opParams.addParam("p", dropRate);
+    //        opParams.addParam("state_size", states.head().getShape().tail());
+    //        opParams.addParam("state_outputs", true);
+    //        opParams.addParam("num_layers", numLayers);
+    //        opParams.addParam("bidirectional", bidirectional);
+    //        opParams.addParam("lstm_state_clip_nan", true);
+    //
+    //        MxNDList inputs = new MxNDList();
+    //        inputs.add(input);
+    //        try (MxNDList temp = new MxNDList()) {
+    //            for (MxNDArray param : params) {
+    //                temp.add(param.flatten());
+    //            }
+    //            MxNDArray tempParam = MxNDArrays.concat(temp);
+    //            tempParam.attach(input.getManager());
+    //            inputs.add(tempParam);
+    //        }
+    //        inputs.addAll(states);
+    //
+    //        if (!batchFirst) {
+    //            return getManager().invoke("_npx_rnn", inputs, opParams);
+    //        }
+    //
+    //        MxNDList result = getManager().invoke("_npx_rnn", inputs, opParams);
+    //        try (MxNDArray temp = result.head()) {
+    //            return new MxNDList(temp.swapAxes(0, 1), result.get(1), result.get(2));
+    //        }
+    //    }
+
+    ////////////////////////////////////////
+    // Image and CV
+    ////////////////////////////////////////
+
+    public NDArray normalize(float[] mean, float[] std) {
+        OpParams params = new OpParams();
+        params.addTupleParam("mean", mean);
+        params.addTupleParam("std", std);
+        return NDArray.invoke(getArray().getParent(), "_npx__image_normalize", array, params);
+    }
+
+    public NDArray toTensor() {
+        return NDArray.invoke(getArray().getParent(), "_npx__image_to_tensor", array, null);
+    }
+
+    public NDArray resize(int width, int height, int interpolation) {
+        if (array.isEmpty()) {
+            throw new IllegalArgumentException("attempt to resize an empty NDArray");
+        }
+        OpParams params = new OpParams();
+        params.addTupleParam("size", width, height);
+        params.addParam("interp", interpolation);
+        return NDArray.invoke(getArray().getParent(), "_npx__image_resize", array, params);
+    }
+
+    public NDArray crop(int x, int y, int width, int height) {
+        OpParams params = new OpParams();
+        params.add("x", x);
+        params.add("y", y);
+        params.add("width", width);
+        params.add("height", height);
+        return NDArray.invoke(getArray().getParent(), "_npx__image_crop", array, params);
+    }
+
+    public NDArray randomFlipLeftRight() {
+        if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+            throw new UnsupportedOperationException("randomFlipLeftRight is not supported on GPU");
+        }
+        return NDArray.invoke(
+                getArray().getParent(), "_npx__image_random_flip_left_right", array, null);
+    }
+
+    public NDArray randomFlipTopBottom() {
+        if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+            throw new UnsupportedOperationException("randomFlipTopBottom is not supported on GPU");
+        }
+        return NDArray.invoke(
+                getArray().getParent(), "_npx__image_random_flip_top_bottom", array, null);
+    }
+
+    public NDArray randomBrightness(float brightness) {
+        if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+            throw new UnsupportedOperationException("randomBrightness is not supported on GPU");
+        }
+        OpParams params = new OpParams();
+        float min = Math.max(0, 1 - brightness);
+        float max = 1 + brightness;
+        params.addParam("min_factor", min);
+        params.addParam("max_factor", max);
+        return NDArray.invoke(
+                getArray().getParent(), "_npx__image_random_brightness", array, params);
+    }
+
+    public NDArray randomHue(float hue) {
+        if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+            throw new UnsupportedOperationException("randomHue is not supported on GPU");
+        }
+        OpParams params = new OpParams();
+        float min = Math.max(0, 1 - hue);
+        float max = 1 + hue;
+        params.addParam("min_factor", min);
+        params.addParam("max_factor", max);
+        return NDArray.invoke(getArray().getParent(), "_npx__image_random_hue", array, params);
+    }
+
+    public NDArray randomColorJitter(
+            float brightness, float contrast, float saturation, float hue) {
+        if (array.getDevice().getDeviceType().equals(Device.Type.GPU)) {
+            throw new UnsupportedOperationException("randomColorJitter is not supported on GPU");
+        }
+        OpParams params = new OpParams();
+        params.addParam("brightness", brightness);
+        params.addParam("contrast", contrast);
+        params.addParam("saturation", saturation);
+        params.addParam("hue", hue);
+        return NDArray.invoke(
+                getArray().getParent(), "_npx__image_random_color_jitter", array, params);
+    }
+
+    public NDArrayIndexer getIndexer() {
+        return INDEXER;
+    }
+
+    ////////////////////////////////////////
+    // Miscellaneous
+    ////////////////////////////////////////
+
+    @SuppressWarnings("PMD.UseTryWithResources")
+    public NDArray where(NDArray condition, NDArray other) {
+        NDArray array1;
+        NDArray array2;
+        condition =
+                (condition.getDataType() == DataType.BOOLEAN)
+                        ? condition.toType(DataType.INT32, false)
+                        : condition;
+        if (array.getDataType() != other.getDataType()) {
+            throw new IllegalArgumentException(
+                    "DataType mismatch, required "
+                            + array.getDataType()
+                            + " actual "
+                            + other.getDataType());
+        }
+        if (!array.shapeEquals(other)) {
+            Shape res = deriveBroadcastedShape(array.getShape(), other.getShape());
+            array1 = (!res.equals(array.getShape())) ? array.broadcast(res) : array;
+            array2 = (!res.equals(other.getShape())) ? other.broadcast(res) : other;
+        } else {
+            array1 = array;
+            array2 = other;
+        }
+        try {
+            return NDArray.invoke(
+                    getArray().getParent(),
+                    "where",
+                    new NDArray[] {condition, array1, array2},
+                    null);
+        } finally {
+            if (array1 != array) {
+                array1.close();
+            }
+            if (array2 != other) {
+                array2.close();
+            }
+        }
+    }
+
+    public NDArray stack(NDList arrays, int axis) {
+        OpParams params = new OpParams();
+        params.addParam("axis", axis);
+        NDArray[] srcArray = new NDArray[arrays.size() + 1];
+        srcArray[0] = array;
+        System.arraycopy(arrays.toArray(new NDArray[0]), 0, srcArray, 1, arrays.size());
+        return NDArray.invoke(getArray().getParent(), "_npi_stack", srcArray, params);
+    }
+
+    /**
+     * Checks two criteria of the concat input: 1. no scalars 2. all arrays must have the same
+     * number of dimensions.
+     *
+     * @param list input {@link NDList}
+     */
+    public static void checkConcatInput(NDList list) {
+        NDArray[] arrays = list.toArray(new NDArray[0]);
+        if (Stream.of(arrays).allMatch(array -> array.getShape().dimension() == 0)) {
+            throw new IllegalArgumentException(
+                    "scalar(zero-dimensional) arrays cannot be concatenated");
+        }
+        int dimension = arrays[0].getShape().dimension();
+        for (int i = 1; i < arrays.length; i++) {
+            if (arrays[i].getShape().dimension() != dimension) {
+                throw new IllegalArgumentException(
+                        "all the input arrays must have same number of dimensions, but the array at index 0 has "
+                                + dimension
+                                + " dimension(s) and the array at index "
+                                + i
+                                + " has "
+                                + arrays[i].getShape().dimension()
+                                + " dimension(s)");
+            }
+        }
+    }
+
+    public NDArray concat(NDList list, int axis) {
+        checkConcatInput(list);
+
+        OpParams params = new OpParams();
+        // the _npi_concatenate operator takes axis as its argument name
+        params.addParam("axis", axis);
+        NDArray[] srcArray = new NDArray[list.size() + 1];
+        srcArray[0] = array;
+        System.arraycopy(list.toArray(new NDArray[0]), 0, srcArray, 1, list.size());
+        return NDArray.invoke(getArray().getParent(), "_npi_concatenate", srcArray, params);
+    }
+
+    public NDList multiBoxTarget(
+            NDList inputs,
+            float iouThreshold,
+            float ignoreLabel,
+            float negativeMiningRatio,
+            float negativeMiningThreshold,
+            int minNegativeSamples) {
+        OpParams parameters = new OpParams();
+        parameters.add("minimum_negative_samples", minNegativeSamples);
+        parameters.add("overlap_threshold", iouThreshold);
+        parameters.add("ignore_label", ignoreLabel);
+        parameters.add("negative_mining_ratio", negativeMiningRatio);
+        parameters.add("negative_mining_thresh", negativeMiningThreshold);
+        return NDArray.invoke(getArray().getParent(), "MultiBoxTarget", inputs, parameters);
+    }
+
+    public NDList multiBoxPrior(
+            List<Float> sizes,
+            List<Float> ratios,
+            List<Float> steps,
+            List<Float> offsets,
+            boolean clip) {
+        OpParams parameters = new OpParams();
+        parameters.add("sizes", sizes);
+        parameters.add("ratios", ratios);
+        parameters.add("steps", steps);
+        parameters.add("offsets", offsets);
+        parameters.add("clip", clip);
+        return NDArray.invoke(
+                getArray().getParent(), "MultiBoxPrior", new NDList(array), parameters);
+    }
+
+    public NDList multiBoxDetection(
+            NDList inputs,
+            boolean clip,
+            float threshold,
+            int backgroundId,
+            float nmsThreshold,
+            boolean forceSuppress,
+            int nmsTopK) {
+        OpParams parameters = new OpParams();
+        parameters.add("clip", clip);
+        parameters.add("threshold", threshold);
+        parameters.add("background_id", backgroundId);
+        parameters.add("nms_threshold", nmsThreshold);
+        parameters.add("force_suppress", forceSuppress);
+        parameters.add("nms_topk", nmsTopK);
+        return NDArray.invoke(getArray().getParent(), "MultiBoxDetection", inputs, parameters);
+    }
+
+    public NDArray getArray() {
+        return array;
+    }
+
+    private int getGlobalPoolingDim() {
+        int poolDim = getArray().getShape().dimension() - 2;
+        if (poolDim < 1 || poolDim > 3) {
+            throw new IllegalStateException(
+                    "GlobalPooling only supports "
+                            + "1 to 3 dimensions, "
+                            + poolDim
+                            + "D is not supported.");
+        }
+        return poolDim;
+    }
+
+    private Shape getGlobalPoolingShapes(long fillValue) {
+        // determine pooling dimension according to input
+        // input dimension minus 2 (batch and channel dim)
+        int poolDim = getGlobalPoolingDim();
+        long[] shape = new long[poolDim];
+        Arrays.fill(shape, fillValue);
+        return new Shape(shape);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
new file mode 100644
index 0000000..ee93fd7
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrayIndexer.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Stack;
+import org.apache.mxnet.engine.OpParams;
+import org.apache.mxnet.ndarray.dim.NDIndexBooleans;
+import org.apache.mxnet.ndarray.dim.NDIndexElement;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullPick;
+import org.apache.mxnet.ndarray.dim.full.NDIndexFullSlice;
+import org.apache.mxnet.ndarray.index.NDIndex;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */
+public class NDArrayIndexer {
+
+    /**
+     * Returns a subarray by picking the elements.
+     *
+     * @param array the array to get from
+     * @param index the index to get
+     * @return the subArray
+     */
+    public NDArray get(NDArray array, NDIndex index) {
+        if (index.getRank() == 0 && array.getShape().isScalar()) {
+            return array.duplicate();
+        }
+
+        // use booleanMask for NDIndexBooleans case
+        List<NDIndexElement> indices = index.getIndices();
+        if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
+            if (indices.size() != 1) {
+                throw new IllegalArgumentException(
+                        "get() currently doesn't support more than one boolean NDArray");
+            }
+            return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
+        }
+
+        Optional<NDIndexFullPick> fullPick = NDIndexFullPick.fromIndex(index, array.getShape());
+        if (fullPick.isPresent()) {
+            return get(array, fullPick.get());
+        }
+
+        Optional<NDIndexFullSlice> fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape());
+        if (fullSlice.isPresent()) {
+            return get(array, fullSlice.get());
+        }
+        throw new UnsupportedOperationException(
+                "get() currently supports only all, fixed, and slice indices");
+    }
+
+    /**
+     * Returns a subarray by picking the elements.
+     *
+     * @param array the array to get from
+     * @param fullPick the elements to pick
+     * @return the subArray
+     */
+    public NDArray get(NDArray array, NDIndexFullPick fullPick) {
+        OpParams params = new OpParams();
+        params.addParam("axis", fullPick.getAxis());
+        params.addParam("keepdims", true);
+        params.add("mode", "wrap");
+        return NDArray.invoke(
+                        array.getParent(), "pick", new NDList(array, fullPick.getIndices()), params)
+                .singletonOrThrow();
+    }
+
+    /**
+     * Returns a subarray at the slice.
+     *
+     * @param array the array to get from
+     * @param fullSlice the fullSlice index of the array
+     * @return the subArray
+     */
+    public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
+        OpParams params = new OpParams();
+        params.addTupleParam("begin", fullSlice.getMin());
+        params.addTupleParam("end", fullSlice.getMax());
+        params.addTupleParam("step", fullSlice.getStep());
+
+        NDArray result = NDArray.invoke(array.getParent(), "_npi_slice", array, params);
+        int[] toSqueeze = fullSlice.getToSqueeze();
+        if (toSqueeze.length > 0) {
+            NDArray oldResult = result;
+            result = result.squeeze(toSqueeze);
+            oldResult.close();
+        }
+        return result;
+    }
+
+    /**
+     * Sets the values of the array at the fullSlice with an array.
+     *
+     * @param array the array to set
+     * @param fullSlice the fullSlice of the index to set in the array
+     * @param value the value to set with
+     */
+    public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
+        OpParams params = new OpParams();
+        params.addTupleParam("begin", fullSlice.getMin());
+        params.addTupleParam("end", fullSlice.getMax());
+        params.addTupleParam("step", fullSlice.getStep());
+
+        Stack<NDArray> prepareValue = new Stack<>();
+        prepareValue.add(value);
+        prepareValue.add(prepareValue.peek().toDevice(array.getDevice(), false));
+        // prepareValue.add(prepareValue.peek().asType(getDataType(), false));
+        // Deal with the case target: (1, 10, 1), original (10)
+        // try to find (10, 1) and reshape (10) to that
+        Shape targetShape = fullSlice.getShape();
+        while (targetShape.size() > value.size()) {
+            targetShape = targetShape.slice(1);
+        }
+        prepareValue.add(prepareValue.peek().reshape(targetShape));
+        prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape()));
+
+        NDArray.invoke(
+                "_npi_slice_assign",
+                new NDArray[] {array, prepareValue.peek()},
+                new NDArray[] {array},
+                params);
+        for (NDArray toClean : prepareValue) {
+            if (toClean != value) {
+                toClean.close();
+            }
+        }
+    }
+
+    /**
+     * Sets the values of the array at the fullSlice with a number.
+     *
+     * @param array the array to set
+     * @param fullSlice the fullSlice of the index to set in the array
+     * @param value the value to set with
+     */
+    public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
+        OpParams params = new OpParams();
+        params.addTupleParam("begin", fullSlice.getMin());
+        params.addTupleParam("end", fullSlice.getMax());
+        params.addTupleParam("step", fullSlice.getStep());
+        params.addParam("scalar", value);
+        NDArray.invoke(
+                "_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
+    }
+}
diff --git a/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
new file mode 100644
index 0000000..df0c6b8
--- /dev/null
+++ b/java-package/mxnet-engine/src/main/java/org/apache/mxnet/ndarray/NDArrays.java
@@ -0,0 +1,2008 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.ndarray;
+
+import java.util.Arrays;
+import org.apache.mxnet.ndarray.types.Shape;
+
+/** This class contains various methods for manipulating {@link NDArray}s. */
+public final class NDArrays {
+
+    private NDArrays() {}
+
+    private static void checkInputs(NDArray[] arrays) {
+        if (arrays == null || arrays.length < 2) {
+            throw new IllegalArgumentException("Passed in arrays must have at least 2 elements");
+        }
+        if (arrays.length > 2
+                && Arrays.stream(arrays).skip(1).anyMatch(array -> !arrays[0].shapeEquals(array))) {
+            throw new IllegalArgumentException("The shape of all inputs must be the same");
+        }
+    }
+
+    ////////////////////////////////////////
+    // Operations: Element Comparison
+    ////////////////////////////////////////
+
+    /**
+     * Returns {@code true} if all elements in {@link NDArray} a are equal to the number n.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.ones(new Shape(3));
+     * jshell&gt; NDArrays.contentEquals(array, 1); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param n the number to compare
+     * @return the boolean result
+     */
+    public static boolean contentEquals(NDArray a, Number n) {
+        if (a == null) {
+            return false;
+        }
+        return a.contentEquals(n);
+    }
+
+    /**
+     * Returns {@code true} if all elements in {@link NDArray} a are equal to {@link NDArray} b.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.arange(6f).reshape(2, 3);
+     * jshell&gt; NDArray array2 = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(2, 3));
+     * jshell&gt; NDArrays.contentEquals(array1, array2); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param b the {@link NDArray} to compare
+     * @return the boolean result
+     */
+    public static boolean contentEquals(NDArray a, NDArray b) {
+        return a.contentEquals(b);
+    }
+
+    /**
+     * Checks 2 {@link NDArray}s for equal shapes.
+     *
+     * <p>Shapes are considered equal if:
+     *
+     * <ul>
+     *   <li>Both {@link NDArray}s have equal rank, and
+     *   <li>size(0)...size(rank()-1) are equal for both {@link NDArray}s
+     * </ul>
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.ones(new Shape(1, 2, 3));
+     * jshell&gt; NDArray array2 = manager.create(new Shape(1, 2, 3));
+     * jshell&gt; NDArrays.shapeEquals(array1, array2); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param b the {@link NDArray} to compare
+     * @return {@code true} if the {@link Shape}s are the same
+     */
+    public static boolean shapeEquals(NDArray a, NDArray b) {
+        return a.shapeEquals(b);
+    }
+
+    /**
+     * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-7});
+     * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
+     * jshell&gt; NDArrays.allClose(array1, array2); // return false instead of boolean NDArray
+     * false
+     * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-8});
+     * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
+     * jshell&gt; NDArrays.allClose(array1, array2); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare with
+     * @param b the {@link NDArray} to compare with
+     * @return the boolean result
+     */
+    //    public static boolean allClose(MxNDArray a, MxNDArray b) {
+    //        return a.allClose(b);
+    //    }
+
+    /**
+     * Returns {@code true} if two {@link NDArray} are element-wise equal within a tolerance.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-7});
+     * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
+     * jshell&gt; NDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return false instead of boolean NDArray
+     * false
+     * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-8});
+     * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
+     * jshell&gt; NDArrays.allClose(array1, array2, 1e-05, 1e-08, false); // return true instead of boolean NDArray
+     * true
+     * jshell&gt; NDArray array1 = manager.create(new float[] {1f, Float.NaN});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {1f, Float.NaN});
+     * jshell&gt; NDArrays.allClose(array1, array2, 1e-05, 1e-08, true); // return true instead of boolean NDArray
+     * true
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare with
+     * @param b the {@link NDArray} to compare with
+     * @param rtol the relative tolerance parameter
+     * @param atol the absolute tolerance parameter
+     * @param equalNan whether to compare NaNs as equal. If {@code true}, NaNs in the {@link
+     *     NDArray} will be considered equal to NaNs in the other {@link NDArray}
+     * @return the boolean result
+     */
+    //    public static boolean allClose(
+    //            MxNDArray a, MxNDArray b, double rtol, double atol, boolean equalNan) {
+    //        return a.allClose(b, rtol, atol, equalNan);
+    //    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.ones(new Shape(1));
+     * jshell&gt; NDArrays.eq(array, 1);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param n the number to compare
+     * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+     */
+    public static NDArray eq(NDArray a, Number n) {
+        return a.eq(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.ones(new Shape(1));
+     * jshell&gt; NDArrays.eq(1, array);
+     * ND: (1) cpu() boolean
+     * [ true]
+     * </pre>
+     *
+     * @param n the number to compare
+     * @param a the {@link NDArray} to compare
+     * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+     */
+    public static NDArray eq(Number n, NDArray a) {
+        return a.eq(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {0f, 1f, 3f});
+     * jshell&gt; NDArray array2 = manager.arange(3f);
+     * jshell&gt; NDArrays.eq(array1, array2);
+     * ND: (3) cpu() boolean
+     * [ true,  true, false]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param b the {@link NDArray} to compare
+     * @return the boolean {@link NDArray} for element-wise "Equals" comparison
+     */
+    public static NDArray eq(NDArray a, NDArray b) {
+        return a.eq(b);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.arange(4f).reshape(2, 2);
+     * jshell&gt; NDArrays.neq(array, 1);
+     * ND: (2, 2) cpu() boolean
+     * [[ true, false],
+     *  [ true,  true],
+     * ]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param n the number to compare
+     * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+     */
+    public static NDArray neq(NDArray a, Number n) {
+        return a.neq(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.arange(4f).reshape(2, 2);
+     * jshell&gt; NDArrays.neq(1, array);
+     * ND: (2, 2) cpu() boolean
+     * [[ true, false],
+     *  [ true,  true],
+     * ]
+     * </pre>
+     *
+     * @param n the number to compare
+     * @param a the {@link NDArray} to compare
+     * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+     */
+    public static NDArray neq(Number n, NDArray a) {
+        return a.neq(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Not equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {1f, 3f});
+     * jshell&gt; NDArrays.neq(array1, array2);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * jshell&gt; NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {1f, 3f, 1f, 4f}, new Shape(2, 2));
+     * jshell&gt; NDArrays.neq(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() boolean
+     * [[false,  true],
+     *  [false,  true],
+     * ]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param b the {@link NDArray} to compare
+     * @return the boolean {@link NDArray} for element-wise "Not equals" comparison
+     */
+    public static NDArray neq(NDArray a, NDArray b) {
+        return a.neq(b);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArrays.gt(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to compare
+     * @param n the number to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+     */
+    public static NDArray gt(NDArray a, Number n) {
+        return a.gt(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArrays.gt(2f, array);
+     * ND: (2) cpu() boolean
+     * [false, false]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+     */
+    public static NDArray gt(Number n, NDArray a) {
+        return a.lt(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater Than" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell&gt; NDArrays.gt(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater Than" comparison
+     */
+    public static NDArray gt(NDArray a, NDArray b) {
+        return a.gt(b);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArrays.gte(array, 2);
+     * ND: (2) cpu() boolean
+     * [ true,  true]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param n the number to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+     */
+    public static NDArray gte(NDArray a, Number n) {
+        return a.gte(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArrays.gte(2, array);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+     */
+    public static NDArray gte(Number n, NDArray a) {
+        return a.lte(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Greater or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {4f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell&gt; NDArrays.gte(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true,  true]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Greater or equals" comparison
+     */
+    public static NDArray gte(NDArray a, NDArray b) {
+        return a.gte(b);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.lt(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param n the number to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less" comparison
+     */
+    public static NDArray lt(NDArray a, Number n) {
+        return a.lt(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.lt(2f, array);
+     * ND: (2) cpu() boolean
+     * [false, false]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less" comparison
+     */
+    public static NDArray lt(Number n, NDArray a) {
+        return a.gt(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell&gt; NDArrays.lt(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true, false]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less" comparison
+     */
+    public static NDArray lt(NDArray a, NDArray b) {
+        return a.lt(b);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.lte(array, 2f);
+     * ND: (2) cpu() boolean
+     * [ true,  true]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param n the number to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+     */
+    public static NDArray lte(NDArray a, Number n) {
+        return a.lte(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.lte(2f, array);
+     * ND: (2) cpu() boolean
+     * [false,  true]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+     */
+    public static NDArray lte(Number n, NDArray a) {
+        return a.gte(n);
+    }
+
+    /**
+     * Returns the boolean {@link NDArray} for element-wise "Less or equals" comparison.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {2f, 2f});
+     * jshell&gt; NDArrays.lte(array1, array2);
+     * ND: (2) cpu() boolean
+     * [ true,  true]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared against
+     * @return the boolean {@link NDArray} for element-wise "Less or equals" comparison
+     */
+    public static NDArray lte(NDArray a, NDArray b) {
+        return a.lte(b);
+    }
+
+    /**
+     * Returns elements chosen from the {@link NDArray} or the other {@link NDArray} depending on
+     * condition.
+     *
+     * <p>Given three {@link NDArray}s, condition, a, and b, returns an {@link NDArray} with the
+     * elements from a or b, depending on whether the elements from condition {@link NDArray} are
+     * {@code true} or {@code false}. If condition has the same shape as a, each element in the
+     * output {@link NDArray} is from a if the corresponding element in the condition is {@code
+     * true}, and from b if {@code false}.
+     *
+     * <p>Note that all non-zero values are interpreted as {@code true} in condition {@link
+     * NDArray}.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.arange(10f);
+     * jshell&gt; NDArrays.where(array.lt(5), array, array.mul(10));
+     * ND: (10) cpu() float32
+     * [ 0.,  1.,  2.,  3.,  4., 50., 60., 70., 80., 90.]
+     * jshell&gt; NDArray array = manager.create(new float[]{0f, 1f, 2f, 0f, 2f, 4f, 0f, 3f, 6f}, new Shape(3, 3));
+     * jshell&gt; NDArrays.where(array.lt(4), array, manager.create(-1f));
+     * ND: (3, 3) cpu() float32
+     * [[ 0.,  1.,  2.],
+     *  [ 0.,  2., -1.],
+     *  [ 0.,  3., -1.],
+     * ]
+     * </pre>
+     *
+     * @param condition the condition {@link NDArray}
+     * @param a the first {@link NDArray}
+     * @param b the other {@link NDArray}
+     * @return the result {@link NDArray}
+     */
+    public static NDArray where(NDArray condition, NDArray a, NDArray b) {
+        return a.getNDArrayInternal().where(condition, b);
+    }
+
+    /**
+     * Returns the maximum of a {@link NDArray} and a number element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArrays.maximum(array, 3f);
+     * ND: (3) cpu() float32
+     * [3., 3., 4.]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param n the number to be compared
+     * @return the maximum of a {@link NDArray} and a number element-wise
+     */
+    public static NDArray maximum(NDArray a, Number n) {
+        return a.maximum(n);
+    }
+
+    /**
+     * Returns the maximum of a number and a {@link NDArray} element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArrays.maximum(3f, array);
+     * ND: (3) cpu() float32
+     * [3., 3., 4.]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared
+     * @return the maximum of a number and a {@link NDArray} element-wise
+     */
+    public static NDArray maximum(Number n, NDArray a) {
+        return maximum(a, n);
+    }
+
+    /**
+     * Returns the maximum of {@link NDArray} a and {@link NDArray} b element-wise.
+     *
+     * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+     * jshell&gt; NDArrays.maximum(array1, array2);
+     * ND: (3) cpu() float32
+     * [2., 5., 4.]
+     * jshell&gt; NDArray array1 = manager.eye(2);
+     * jshell&gt; NDArray array2 = manager.create(new float[] {0.5f, 2f});
+     * jshell&gt; NDArrays.maximum(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() float32
+     * [[1. , 2. ],
+     *  [0.5, 2. ],
+     * ]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared
+     * @return the maximum of {@link NDArray} a and {@link NDArray} b element-wise
+     */
+    public static NDArray maximum(NDArray a, NDArray b) {
+        return a.maximum(b);
+    }
+
+    /**
+     * Returns the minimum of a {@link NDArray} and a number element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArrays.minimum(array, 3f);
+     * ND: (3) cpu() float32
+     * [2., 3., 3.]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param n the number to be compared
+     * @return the minimum of a {@link NDArray} and a number element-wise
+     */
+    public static NDArray minimum(NDArray a, Number n) {
+        return a.minimum(n);
+    }
+
+    /**
+     * Returns the minimum of a number and a {@link NDArray} element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArrays.minimum(3f, array);
+     * ND: (3) cpu() float32
+     * [2., 3., 3.]
+     * </pre>
+     *
+     * @param n the number to be compared
+     * @param a the {@link NDArray} to be compared
+     * @return the minimum of a number and a {@link NDArray} element-wise
+     */
+    public static NDArray minimum(Number n, NDArray a) {
+        return minimum(a, n);
+    }
+
+    /**
+     * Returns the minimum of {@link NDArray} a and {@link NDArray} b element-wise.
+     *
+     * <p>The shapes of {@link NDArray} a and {@link NDArray} b must be broadcastable.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
+     * jshell&gt; NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
+     * jshell&gt; NDArrays.minimum(array1, array2);
+     * ND: (3) cpu() float32
+     * [1., 3., 2.]
+     * jshell&gt; NDArray array1 = manager.eye(2);
+     * jshell&gt; NDArray array2 = manager.create(new float[] {0.5f, 2f});
+     * jshell&gt; NDArrays.minimum(array1, array2); // broadcasting
+     * ND: (2, 2) cpu() float32
+     * [[0.5, 0. ],
+     *  [0. , 1. ],
+     * ]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be compared
+     * @param b the {@link NDArray} to be compared
+     * @return the minimum of {@link NDArray} a and {@link NDArray} b element-wise
+     */
+    public static NDArray minimum(NDArray a, NDArray b) {
+        return a.minimum(b);
+    }
+
+    /**
+     * Returns the portion of the {@link NDArray} selected by the boolean index {@link NDArray}
+     * along the first axis.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(3, 2));
+     * jshell&gt; NDArray mask = manager.create(new boolean[] {true, false, true});
+     * jshell&gt; NDArrays.booleanMask(array, mask);
+     * ND: (2, 2) cpu() float32
+     * [[1., 2.],
+     *  [5., 6.],
+     * ]
+     * </pre>
+     *
+     * @param data the {@link NDArray} to operate on
+     * @param index the boolean {@link NDArray} mask
+     * @return the result {@link NDArray}
+     */
+    public static NDArray booleanMask(NDArray data, NDArray index) {
+        return booleanMask(data, index, 0);
+    }
+
+    /**
+     * Returns the portion of the {@link NDArray} selected by the boolean index {@link NDArray}
+     * along the given axis.
+     *
+     * @param data the {@link NDArray} to operate on
+     * @param index the boolean {@link NDArray} mask
+     * @param axis an integer that represents the axis of {@link NDArray} to mask from
+     * @return the result {@link NDArray}
+     */
+    public static NDArray booleanMask(NDArray data, NDArray index, int axis) {
+        return data.booleanMask(index, axis);
+    }
+
+    /**
+     * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to a
+     * constant value.
+     *
+     * <p>This function takes an n-dimensional input array of the form [batch_size,
+     * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+     * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+     * an input array of positive ints of dimension [batch_size].
+     *
+     * @param data the {@link NDArray} to operate on
+     * @param sequenceLength used to handle variable-length sequences
+     * @param value the constant value to be set
+     * @return the result {@link NDArray}
+     */
+    public static NDArray sequenceMask(NDArray data, NDArray sequenceLength, float value) {
+        return data.sequenceMask(sequenceLength, value);
+    }
+
+    /**
+     * Sets all elements of the given {@link NDArray} outside the sequence {@link NDArray} to 0.
+     *
+     * <p>This function takes an n-dimensional input array of the form [batch_size,
+     * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
+     * sequenceLength} is used to handle variable-length sequences. {@code sequenceLength} should be
+     * an input array of positive ints of dimension [batch_size].
+     *
+     * @param data the {@link NDArray} to operate on
+     * @param sequenceLength used to handle variable-length sequences
+     * @return the result {@link NDArray}
+     */
+    public static NDArray sequenceMask(NDArray data, NDArray sequenceLength) {
+        return data.sequenceMask(sequenceLength);
+    }
+
+    ////////////////////////////////////////
+    // Operations: Element Arithmetic
+    ////////////////////////////////////////
+
+    /**
+     * Adds a number to the {@link NDArray} element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.add(array, 2f);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * </pre>
+     *
+     * @param a the {@link NDArray} to be added to
+     * @param n the number to add
+     * @return the result {@link NDArray}
+     */
+    public static NDArray add(NDArray a, Number n) {
+        return a.add(n);
+    }
+
+    /**
+     * Adds a {@link NDArray} to a number element-wise.
+     *
+     * <p>Examples
+     *
+     * <pre>
+     * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f});
+     * jshell&gt; NDArrays.add(2f, array);
+     * ND: (2) cpu() float32
+     * [3., 4.]
+     * </pre>
+     *
+     * @param n the number to be added to
+     * @param a the {@link NDArray} to add
+     * @return the result {@link NDArray}
+     */
+    public static NDArray add(Number n, NDArray a) {
+        return a.add(n);
+    }
+
+    /**