Merge branch 'dev' for V3.1-RC1
diff --git a/.asf.yaml b/.asf.yaml
index df96840..1e0b37f 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -24,3 +24,5 @@
     wiki: true
     # Enable issues on github
     issues: true
+    # Enable settings on github
+    settings: true
diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 0000000..497c927
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+ignore:
+  - proto/*
diff --git a/.github/workflows/conda.yaml b/.github/workflows/conda.yaml
new file mode 100644
index 0000000..882a424
--- /dev/null
+++ b/.github/workflows/conda.yaml
@@ -0,0 +1,71 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Workflow that builds the SINGA conda package, runs pytest during the build, and uploads the package when an Anaconda token is available
+
+name: conda
+
+# Controls when the action will run. The workflow is triggered on every push
+# and on every pull request.
+on:
+  push:
+  pull_request:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  build-pytest-package-ubuntu:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v1
+      - name: install-conda-build
+        run: conda install conda-build anaconda-client
+      - name: conda-config
+        run: conda config --add channels conda-forge && conda config --add channels nusdbsystem && conda config --set anaconda_upload no
+      - name: build-pytest
+        run:  conda build tool/conda/singa --python 3.6
+        env:
+          TEST_COMMAND: pytest --cov=$PREFIX/lib/python3.7/site-packages/singa --cov-report=xml && codecov --flags singa-python
+      - name: upload-package
+        env: 
+          ANACONDA_UPLOAD_TOKEN: ${{ secrets.ANACONDA_UPLOAD_TOKEN }}
+        if: ${{ env.ANACONDA_UPLOAD_TOKEN }}
+        run: /usr/share/miniconda/bin/anaconda -t $ANACONDA_UPLOAD_TOKEN upload -u nusdbsystem -l main /usr/share/miniconda/conda-bld/linux-64/singa-*.tar.bz2 --force
+        # 
+
+
+  build-pytest-package-macos:
+    runs-on: macos-latest
+
+    steps:
+      - uses: actions/checkout@v1
+      - name: set permission
+        run: sudo chmod -R 777 /usr/local/miniconda 
+        # && xcrun --show-sdk-path
+      - name: install-conda-build
+        run: conda install conda-build anaconda-client
+      - name: conda-config
+        run: conda config --add channels conda-forge && conda config --add channels nusdbsystem && conda config --set anaconda_upload no
+      - name: build-pytest
+        run:  conda build tool/conda/singa --python 3.6
+        env:
+          TEST_COMMAND: pytest --cov=$PREFIX/lib/python3.6/site-packages/singa --cov-report=xml && codecov --flags singa-python
+      - name: upload-package
+        env: 
+          ANACONDA_UPLOAD_TOKEN: ${{ secrets.ANACONDA_UPLOAD_TOKEN }}
+        if: ${{ env.ANACONDA_UPLOAD_TOKEN }}
+        run: /usr/local/miniconda/bin/anaconda -t $ANACONDA_UPLOAD_TOKEN upload -u nusdbsystem -l main /usr/local/miniconda/conda-bld/osx-64/singa-*.tar.bz2 --force
\ No newline at end of file
diff --git a/.github/workflows/macOS.yaml b/.github/workflows/macOS.yaml
new file mode 100644
index 0000000..d38ce32
--- /dev/null
+++ b/.github/workflows/macOS.yaml
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+name: Native-MacOS
+
+on:
+  push:
+  pull_request:
+
+jobs:
+  build-cpptest-cpu:
+    runs-on: macos-latest
+
+    steps:
+      - uses: actions/checkout@v1
+      - uses: actions/setup-python@v2
+        with:
+          python-version: "3.7"
+      - name: install-build-dependencies
+        run: |
+         brew install protobuf swig opencv glog lmdb numpy
+         pip3 install numpy && wget https://github.com/oneapi-src/oneDNN/releases/download/v1.2/dnnl_mac_1.2.0_cpu_tbb.tgz -P /tmp
+         tar zxf /tmp/dnnl_mac_1.2.0_cpu_tbb.tgz -C /tmp
+      - name: configure
+        run: mkdir build && cd build && cmake -DUSE_PYTHON3=YES -DENABLE_TEST=YES -DUSE_DNNL=YES ..
+        env:
+          CMAKE_INCLUDE_PATH: /usr/local/opt/openblas/include:$CMAKE_INCLUDE_PATH
+          CMAKE_LIBRARY_PATH: /usr/local/opt/openblas/lib:$CMAKE_LIBRARY_PATH
+          DNNL_ROOT: /tmp/dnnl_mac_1.2.0_cpu_tbb/
+      - name: build
+        run: cd build && make
+        env:
+          CXXFLAGS: -I  /Users/runner/hostedtoolcache/Python/3.7.8/x64/lib/python3.7/site-packages/numpy/core/include $CXXFLAGS
+          LD_LIBRARY_PATH: /usr/local/opt/openblas/lib:/tmp/dnnl_mac_1.2.0_cpu_tbb/lib:$LD_LIBRARY_PATH
+      - name: C++ test
+        run: |
+         brew install tbb
+         install_name_tool -change libdnnl.1.dylib /tmp/dnnl_mac_1.2.0_cpu_tbb/lib/libdnnl.1.dylib /Users/runner/work/singa/singa/build/lib/libsinga.dylib
+         install_name_tool -change libdnnl.1.dylib /tmp/dnnl_mac_1.2.0_cpu_tbb/lib/libdnnl.1.dylib build/bin/test_singa
+         build/bin/test_singa
+        env:
+          LD_LIBRARY_PATH: /usr/local/opt/openblas/lib:/tmp/dnnl_mac_1.2.0_cpu_tbb/lib:$LD_LIBRARY_PATH
diff --git a/.github/workflows/rat.yaml b/.github/workflows/rat.yaml
index d3462e0..b7c588d 100644
--- a/.github/workflows/rat.yaml
+++ b/.github/workflows/rat.yaml
@@ -17,7 +17,7 @@
 
 # This is a basic workflow to help you get started with Actions
 
-name: CI
+name: License-Check
 
 # Controls when the action will run. Triggers the workflow on push or pull request 
 # events but only for the master branch
diff --git a/.github/workflows/ubuntu.yaml b/.github/workflows/ubuntu.yaml
new file mode 100644
index 0000000..b67fcda
--- /dev/null
+++ b/.github/workflows/ubuntu.yaml
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Workflow that builds SINGA natively on Ubuntu, runs the C++ unit tests, and reports coverage to Codecov
+
+name: Native-Ubuntu
+
+# Controls when the action will run. The workflow is triggered on every push
+# and on every pull request.
+on:
+  push:
+  pull_request:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  # build-ubuntu-cpp:
+  #  runs-on: ubuntu-latest
+
+  #  steps:
+  #    - uses: actions/checkout@v1
+  #    - name: install-build-dependencies
+  #      run: sudo apt-get install -y libgoogle-glog-dev libprotobuf-dev protobuf-compiler libncurses-dev libopenblas-dev gfortran libblas-dev liblapack-dev libatlas-base-dev swig libcurl3-dev cmake dh-autoreconf  
+  #    - name: configure
+  #      run: mkdir build && cd build && cmake -DUSE_PYTHON=NO -DENABLE_TEST=YES ..
+  #    - name: build
+  #      run: cd build && make
+  #    - name: C++ test
+  #      run: build/bin/test_singa
+ 
+  build-cpptest-on-cpu:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v1
+      - name: get-oneDNN
+        run: wget https://github.com/oneapi-src/oneDNN/releases/download/v1.1/dnnl_lnx_1.1.0_cpu_gomp.tgz -P /tmp/ && tar zxf /tmp/dnnl_lnx_1.1.0_cpu_gomp.tgz -C /tmp
+      - name: install-build-dependencies
+        run: sudo apt-get install -y libgoogle-glog-dev libprotobuf-dev protobuf-compiler libncurses-dev libopenblas-dev gfortran libblas-dev liblapack-dev libatlas-base-dev swig dh-autoreconf lcov
+      - name: configure
+        run: mkdir build && cd build && cmake -DUSE_PYTHON=NO -DENABLE_TEST=YES -DCODE_COVERAGE=YES -DUSE_DNNL=YES ..
+        env:
+         DNNL_ROOT: /tmp/dnnl_lnx_1.1.0_cpu_gomp/
+      - name: build
+        run: cd build && make
+      - name: C++ test
+        run: build/bin/test_singa
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v1
+        with:
+          flags: singa-cpp
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index b53c0af..0000000
--- a/.travis.yml
+++ /dev/null
@@ -1,70 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-sudo: required
-language: cpp
-
-stages:
-  - linting
-  - build
-
-matrix:
-  include:
-  - os: osx
-    stage: build
-    compiler: clang
-    # system cblas will be used by cmake for other xcode versions
-    osx_image: xcode8
-    name: "macOS 10.11 (C++)"
-  - os: linux
-    stage: build
-    dist: trusty
-    compiler: gcc
-    name: "Ubuntu 14.04 (C++)"
-  - os: linux
-    stage: build
-    dist: xenial
-    compiler: gcc
-    name: "Ubuntu 16.04 (C++)"
-  - os: linux
-    stage: build
-    dist: bionic
-    compiler: gcc
-    name: "Ubuntu 18.04 (C++)"
-  - os: linux
-    stage: linting
-    dist: bionic
-    compiler: gcc
-    name: "Ubuntu 18.04 (Static Analysis Python)"
-    script:
-      - bash -ex tool/linting/py.sh
-  - os: linux
-    stage: linting
-    dist: bionic
-    compiler: gcc
-    name: "Ubuntu 18.04 (Static Analysis Cpp)"
-    script:
-      - bash -ex tool/linting/cpp.sh
-
-install:
-  - travis_wait bash -ex tool/travis/depends.sh
-
-after_success:
-  - bash <(curl -s https://codecov.io/bash)
-
-script:
-  - bash -ex tool/travis/build.sh
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cd67c9f..ba3102c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -67,6 +67,7 @@
 
 OPTION(USE_CUDA "Use Cuda libs" OFF)
 OPTION(ENABLE_TEST "Enable unit test" OFF)
+option(CODE_COVERAGE "Enable coverage reporting" OFF)
 OPTION(USE_PYTHON "Generate py wrappers" ON)
 OPTION(USE_PYTHON3 "Python 3x" OFF)
 
diff --git a/README.md b/README.md
index 392451f..849b2e6 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,9 @@
 
 # Apache SINGA
 
-[![Build Status](https://travis-ci.org/apache/singa.png)](https://travis-ci.org/apache/singa)
+![Native Ubuntu build status](https://github.com/apache/singa/workflows/Native-Ubuntu/badge.svg)
+![Native Mac build status](https://github.com/apache/singa/workflows/Native-MacOS/badge.svg)
+![conda build status](https://github.com/apache/singa/workflows/conda/badge.svg)
 [![Documentation Status](https://readthedocs.org/projects/apache-singa/badge/?version=latest)](https://apache-singa.readthedocs.io/en/latest/?badge=latest)
 ![License](http://img.shields.io/:license-Apache%202.0-blue.svg)
 [![Follow Apache SINGA on Twitter](https://img.shields.io/twitter/follow/apachesinga.svg?style=social&label=Follow)](https://twitter.com/ApacheSinga)
@@ -42,8 +44,9 @@
 
 ## Code Analysis:
 
-![LGTM C++ Grade](https://img.shields.io/lgtm/grade/cpp/github/apache/incubator-singa)
-![LGTM Python Grade](https://img.shields.io/lgtm/grade/python/github/apache/incubator-singa)
+![LGTM C++ Grade](https://img.shields.io/lgtm/grade/cpp/github/apache/singa)
+![LGTM Python Grade](https://img.shields.io/lgtm/grade/python/github/apache/singa)
+[![codecov](https://codecov.io/gh/apache/singa/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/singa)
 
 [![Stargazers over time](https://starchart.cc/apache/singa.svg)](https://starchart.cc/apache/singa)
 
diff --git a/examples/cnn/README.md b/examples/cnn/README.md
index c7fe673..b081aff 100644
--- a/examples/cnn/README.md
+++ b/examples/cnn/README.md
@@ -34,7 +34,7 @@
   [neural network operations](../../python/singa/autograd.py) imperatively. 
   The computational graph is not created.
 
-* `train.py` is the training script, which controls the training flow by
+* `train_cnn.py` is the training script, which controls the training flow by
   doing BackPropagation and SGD update.
 
 * `train_multiprocess.py` is the script for distributed training on a single
diff --git a/examples/cnn/autograd/mnist_cnn.py b/examples/cnn/autograd/mnist_cnn.py
index a9187e7..ff2e1dc 100644
--- a/examples/cnn/autograd/mnist_cnn.py
+++ b/examples/cnn/autograd/mnist_cnn.py
@@ -19,6 +19,7 @@
 
 from singa import singa_wrap as singa
 from singa import autograd
+from singa import layer
 from singa import tensor
 from singa import device
 from singa import opt
@@ -33,23 +34,27 @@
 class CNN:
 
     def __init__(self):
-        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
-        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
-        self.linear1 = autograd.Linear(4 * 4 * 50, 500)
-        self.linear2 = autograd.Linear(500, 10)
-        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
+        self.conv1 = layer.Conv2d(1, 20, 5, padding=0)
+        self.conv2 = layer.Conv2d(20, 50, 5, padding=0)
+        self.linear1 = layer.Linear(4 * 4 * 50, 500)
+        self.linear2 = layer.Linear(500, 10)
+        self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
+        self.relu1 = layer.ReLU()
+        self.relu2 = layer.ReLU()
+        self.relu3 = layer.ReLU()
+        self.flatten = layer.Flatten()
 
     def forward(self, x):
         y = self.conv1(x)
-        y = autograd.relu(y)
+        y = self.relu1(y)
         y = self.pooling1(y)
         y = self.conv2(y)
-        y = autograd.relu(y)
+        y = self.relu2(y)
         y = self.pooling2(y)
-        y = autograd.flatten(y)
+        y = self.flatten(y)
         y = self.linear1(y)
-        y = autograd.relu(y)
+        y = self.relu3(y)
         y = self.linear2(y)
         return y
 
@@ -255,7 +260,7 @@
                                                    topK=topK,
                                                    corr=corr)
             else:
-                sgd.backward_and_update(loss)
+                sgd(loss)
 
         if DIST:
             # Reduce the Evaluation Accuracy and Loss from Multiple Devices
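
The hunk above replaces the functional autograd operators with stateful objects from singa.layer and swaps sgd.backward_and_update(loss) for a direct call on the optimizer. Below is a minimal sketch of that new pattern in isolation, assuming a SINGA 3.1 build with a usable default device; the Linear(10, 2) layer, the shapes, and the random data are illustrative and not part of the example file.

import numpy as np
from singa import autograd, device, layer, opt, tensor

dev = device.get_default_device()
autograd.training = True                     # enable gradient recording

# illustrative batch: 4 samples, 10 features, 2 classes (one-hot targets)
tx = tensor.Tensor((4, 10), dev, tensor.float32)
tx.gaussian(0.0, 1.0)
ty = tensor.Tensor((4, 2), dev, tensor.float32)
ty.copy_from_numpy(np.eye(2, dtype=np.float32)[[0, 1, 0, 1]])

fc = layer.Linear(10, 2)                     # layer object, as in the updated mnist_cnn.py
sgd = opt.SGD(lr=0.01, momentum=0.9)

out = fc(tx)
loss = autograd.softmax_cross_entropy(out, ty)
sgd(loss)                                    # was: sgd.backward_and_update(loss)
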
diff --git a/examples/cnn/autograd/xceptionnet.py b/examples/cnn/autograd/xceptionnet.py
index c1c63be..357e47d 100644
--- a/examples/cnn/autograd/xceptionnet.py
+++ b/examples/cnn/autograd/xceptionnet.py
@@ -18,6 +18,7 @@
 from singa import autograd
 from singa import tensor
 from singa import device
+from singa import layer
 from singa import opt
 
 import numpy as np
@@ -27,7 +28,7 @@
 # https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
 
 
-class Block(autograd.Layer):
+class Block(layer.Layer):
 
     def __init__(self,
                  in_filters,
@@ -40,13 +41,13 @@
         super(Block, self).__init__()
 
         if out_filters != in_filters or strides != 1:
-            self.skip = autograd.Conv2d(in_filters,
-                                        out_filters,
-                                        1,
-                                        stride=strides,
-                                        padding=padding,
-                                        bias=False)
-            self.skipbn = autograd.BatchNorm2d(out_filters)
+            self.skip = layer.Conv2d(in_filters,
+                                     out_filters,
+                                     1,
+                                     stride=strides,
+                                     padding=padding,
+                                     bias=False)
+            self.skipbn = layer.BatchNorm2d(out_filters)
         else:
             self.skip = None
 
@@ -54,48 +55,52 @@
 
         filters = in_filters
         if grow_first:
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(in_filters,
-                                         out_filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(out_filters))
+                layer.SeparableConv2d(in_filters,
+                                      out_filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(out_filters))
             filters = out_filters
 
         for i in range(reps - 1):
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(filters,
-                                         filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(filters))
+                layer.SeparableConv2d(filters,
+                                      filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(filters))
 
         if not grow_first:
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(in_filters,
-                                         out_filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(out_filters))
+                layer.SeparableConv2d(in_filters,
+                                      out_filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(out_filters))
 
         if not start_with_relu:
             self.layers = self.layers[1:]
         else:
-            self.layers[0] = autograd.ReLU()
+            self.layers[0] = layer.ReLU()
 
         if strides != 1:
-            self.layers.append(autograd.MaxPool2d(3, strides, padding + 1))
+            self.layers.append(layer.MaxPool2d(3, strides, padding + 1))
 
-    def __call__(self, x):
+        self.register_layers(*self.layers)
+
+        self.add = layer.Add()
+
+    def forward(self, x):
         y = self.layers[0](x)
         for layer in self.layers[1:]:
             if isinstance(y, tuple):
@@ -107,14 +112,14 @@
             skip = self.skipbn(skip)
         else:
             skip = x
-        y = autograd.add(y, skip)
+        y = self.add(y, skip)
         return y
 
 
 __all__ = ['Xception']
 
 
-class Xception(autograd.Layer):
+class Xception(layer.Layer):
     """
     Xception optimized for the ImageNet dataset, as specified in
     https://arxiv.org/pdf/1610.02357.pdf
@@ -128,11 +133,13 @@
         super(Xception, self).__init__()
         self.num_classes = num_classes
 
-        self.conv1 = autograd.Conv2d(3, 32, 3, 2, 0, bias=False)
-        self.bn1 = autograd.BatchNorm2d(32)
+        self.conv1 = layer.Conv2d(3, 32, 3, 2, 0, bias=False)
+        self.bn1 = layer.BatchNorm2d(32)
+        self.relu1 = layer.ReLU()
 
-        self.conv2 = autograd.Conv2d(32, 64, 3, 1, 1, bias=False)
-        self.bn2 = autograd.BatchNorm2d(64)
+        self.conv2 = layer.Conv2d(32, 64, 3, 1, 1, bias=False)
+        self.bn2 = layer.BatchNorm2d(64)
+        self.relu2 = layer.ReLU()
         # do relu here
 
         self.block1 = Block(64,
@@ -214,24 +221,27 @@
                              start_with_relu=True,
                              grow_first=False)
 
-        self.conv3 = autograd.SeparableConv2d(1024, 1536, 3, 1, 1)
-        self.bn3 = autograd.BatchNorm2d(1536)
+        self.conv3 = layer.SeparableConv2d(1024, 1536, 3, 1, 1)
+        self.bn3 = layer.BatchNorm2d(1536)
+        self.relu3 = layer.ReLU()
 
         # do relu here
-        self.conv4 = autograd.SeparableConv2d(1536, 2048, 3, 1, 1)
-        self.bn4 = autograd.BatchNorm2d(2048)
+        self.conv4 = layer.SeparableConv2d(1536, 2048, 3, 1, 1)
+        self.bn4 = layer.BatchNorm2d(2048)
 
-        self.globalpooling = autograd.MaxPool2d(10, 1)
-        self.fc = autograd.Linear(2048, num_classes)
+        self.relu4 = layer.ReLU()
+        self.globalpooling = layer.MaxPool2d(10, 1)
+        self.flatten = layer.Flatten()
+        self.fc = layer.Linear(2048, num_classes)
 
     def features(self, input):
         x = self.conv1(input)
         x = self.bn1(x)
-        x = autograd.relu(x)
+        x = self.relu1(x)
 
         x = self.conv2(x)
         x = self.bn2(x)
-        x = autograd.relu(x)
+        x = self.relu2(x)
 
         x = self.block1(x)
         x = self.block2(x)
@@ -248,20 +258,20 @@
 
         x = self.conv3(x)
         x = self.bn3(x)
-        x = autograd.relu(x)
+        x = self.relu3(x)
 
         x = self.conv4(x)
         x = self.bn4(x)
         return x
 
     def logits(self, features):
-        x = autograd.relu(features)
+        x = self.relu4(features)
         x = self.globalpooling(x)
-        x = autograd.flatten(x)
+        x = self.flatten(x)
         x = self.fc(x)
         return x
 
-    def __call__(self, input):
+    def forward(self, input):
         x = self.features(input)
         x = self.logits(x)
         return x
@@ -290,5 +300,4 @@
         for _ in t:
             x = model(tx)
             loss = autograd.softmax_cross_entropy(x, ty)
-            for p, g in autograd.backward(loss):
-                sgd.update(p, g)
+            sgd(loss)
diff --git a/examples/cnn/benchmark.py b/examples/cnn/benchmark.py
index d2bbc3b..a182139 100644
--- a/examples/cnn/benchmark.py
+++ b/examples/cnn/benchmark.py
@@ -30,7 +30,7 @@
 from tqdm import trange
 
 
-def train_resnet(DIST=True, graph=True, sequential=False):
+def train_resnet(DIST=True, graph=True, sequential=False, verbosity=0):
 
     # Define the hypermeters good for the train_resnet
     niters = 100
@@ -40,7 +40,7 @@
     IMG_SIZE = 224
 
     # For distributed training, sequential has better throughput in the current version
-    if DIST:
+    if DIST == True:
         sgd = opt.DistOpt(sgd)
         world_size = sgd.world_size
         local_rank = sgd.local_rank
@@ -61,23 +61,23 @@
     tx.copy_from_numpy(x)
     ty.copy_from_numpy(y)
 
+    dev.SetVerbosity(verbosity)
+    dev.SetSkipIteration(5)
+
     # construct the model
     from model import resnet
     model = resnet.resnet50(num_channels=3, num_classes=1000)
 
     model.train()
-    model.on_device(dev)
     model.set_optimizer(sgd)
-    model.graph(graph, sequential)
+    model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
 
     # train model
     dev.Sync()
     start = time.time()
     with trange(niters) as t:
         for _ in t:
-            out = model(tx)
-            loss = model.loss(out, ty)
-            model.optim(loss, dist_option='fp32', spars=None)
+            model(tx, ty, dist_option='fp32', spars=None)
 
     dev.Sync()
     end = time.time()
@@ -87,6 +87,7 @@
         print("Throughput = {} per second".format(throughput), flush=True)
         print("TotalTime={}".format(end - start), flush=True)
         print("Total={}".format(titer), flush=True)
+        dev.PrintTimeProfiling()
 
 
 if __name__ == "__main__":
@@ -105,7 +106,16 @@
                         action='store_false',
                         help='disable graph',
                         dest='graph')
+    parser.add_argument('--verbosity',
+                        '--log-verbosity',
+                        default=0,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
 
     args = parser.parse_args()
 
-    train_resnet(DIST=args.DIST, graph=args.graph)
+    train_resnet(DIST=args.DIST,
+                 graph=args.graph,
+                 sequential=False,
+                 verbosity=args.verbosity)
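
The benchmark now drives SINGA's device-level time profiling (SetVerbosity, SetSkipIteration, PrintTimeProfiling). A minimal sketch of how these hooks wrap a training run, assuming a CUDA-enabled build as in benchmark.py; the verbosity value and the elided training loop are illustrative.

from singa import device

dev = device.create_cuda_gpu_on(0)    # same device call used by benchmark.py
dev.SetVerbosity(1)                   # verbosity passed from --verbosity in benchmark.py
dev.SetSkipIteration(5)               # skip the first 5 iterations, mirroring benchmark.py
# ... compile the model on dev and run the training iterations here ...
dev.Sync()
dev.PrintTimeProfiling()              # print the collected time-profiling information
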
diff --git a/examples/cnn/model/alexnet.py b/examples/cnn/model/alexnet.py
index e13f525..988596e 100644
--- a/examples/cnn/model/alexnet.py
+++ b/examples/cnn/model/alexnet.py
@@ -17,61 +17,72 @@
 # under the License.
 #
 
-from singa import autograd
-from singa import module
+from singa import layer
+from singa import model
 
 
-class AlexNet(module.Module):
+class AlexNet(model.Model):
 
     def __init__(self, num_classes=10, num_channels=1):
         super(AlexNet, self).__init__()
         self.num_classes = num_classes
         self.input_size = 224
         self.dimension = 4
-        self.conv1 = autograd.Conv2d(num_channels, 64, 11, stride=4, padding=2)
-        self.conv2 = autograd.Conv2d(64, 192, 5, padding=2)
-        self.conv3 = autograd.Conv2d(192, 384, 3, padding=1)
-        self.conv4 = autograd.Conv2d(384, 256, 3, padding=1)
-        self.conv5 = autograd.Conv2d(256, 256, 3, padding=1)
-        self.linear1 = autograd.Linear(1024, 4096)
-        self.linear2 = autograd.Linear(4096, 4096)
-        self.linear3 = autograd.Linear(4096, num_classes)
-        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling3 = autograd.MaxPool2d(2, 2, padding=0)
-        self.avg_pooling1 = autograd.AvgPool2d(3, 2, padding=0)
+        self.conv1 = layer.Conv2d(num_channels, 64, 11, stride=4, padding=2)
+        self.conv2 = layer.Conv2d(64, 192, 5, padding=2)
+        self.conv3 = layer.Conv2d(192, 384, 3, padding=1)
+        self.conv4 = layer.Conv2d(384, 256, 3, padding=1)
+        self.conv5 = layer.Conv2d(256, 256, 3, padding=1)
+        self.linear1 = layer.Linear(4096)
+        self.linear2 = layer.Linear(4096)
+        self.linear3 = layer.Linear(num_classes)
+        self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling3 = layer.MaxPool2d(2, 2, padding=0)
+        self.avg_pooling1 = layer.AvgPool2d(3, 2, padding=0)
+        self.relu1 = layer.ReLU()
+        self.relu2 = layer.ReLU()
+        self.relu3 = layer.ReLU()
+        self.relu4 = layer.ReLU()
+        self.relu5 = layer.ReLU()
+        self.relu6 = layer.ReLU()
+        self.relu7 = layer.ReLU()
+        self.flatten = layer.Flatten()
+        self.dropout1 = layer.Dropout()
+        self.dropout2 = layer.Dropout()
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
 
     def forward(self, x):
         y = self.conv1(x)
-        y = autograd.relu(y)
+        y = self.relu1(y)
         y = self.pooling1(y)
         y = self.conv2(y)
-        y = autograd.relu(y)
+        y = self.relu2(y)
         y = self.pooling2(y)
         y = self.conv3(y)
-        y = autograd.relu(y)
+        y = self.relu3(y)
         y = self.conv4(y)
-        y = autograd.relu(y)
+        y = self.relu4(y)
         y = self.conv5(y)
-        y = autograd.relu(y)
+        y = self.relu5(y)
         y = self.pooling3(y)
         y = self.avg_pooling1(y)
-        y = autograd.flatten(y)
-        y = autograd.dropout(y)
+        y = self.flatten(y)
+        y = self.dropout1(y)
         y = self.linear1(y)
-        y = autograd.relu(y)
-        y = autograd.dropout(y)
+        y = self.relu6(y)
+        y = self.dropout2(y)
         y = self.linear2(y)
-        y = autograd.relu(y)
+        y = self.relu7(y)
         y = self.linear3(y)
         return y
 
-    def loss(self, out, ty):
-        return autograd.softmax_cross_entropy(out, ty)
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
 
-    def optim(self, loss, dist_option, spars):
         if dist_option == 'fp32':
-            self.optimizer.backward_and_update(loss)
+            self.optimizer(loss)
         elif dist_option == 'fp16':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
@@ -84,6 +95,7 @@
             self.optimizer.backward_and_sparse_update(loss,
                                                       topK=False,
                                                       spars=spars)
+        return out, loss
 
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer
diff --git a/examples/cnn/model/cnn.py b/examples/cnn/model/cnn.py
index 24547d4..28ecd6c 100644
--- a/examples/cnn/model/cnn.py
+++ b/examples/cnn/model/cnn.py
@@ -17,43 +17,44 @@
 # under the License.
 #
 
-from singa import autograd
-from singa import module
+from singa import layer
+from singa import model
 
 
-class CNN(module.Module):
+class CNN(model.Model):
 
     def __init__(self, num_classes=10, num_channels=1):
         super(CNN, self).__init__()
         self.num_classes = num_classes
         self.input_size = 28
         self.dimension = 4
-        self.conv1 = autograd.Conv2d(num_channels, 20, 5, padding=0)
-        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
-        self.linear1 = autograd.Linear(4 * 4 * 50, 500)
-        self.linear2 = autograd.Linear(500, num_classes)
-        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
+        self.conv1 = layer.Conv2d(num_channels, 20, 5, padding=0, activation="RELU")
+        self.conv2 = layer.Conv2d(20, 50, 5, padding=0, activation="RELU")
+        self.linear1 = layer.Linear(500)
+        self.linear2 = layer.Linear(num_classes)
+        self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
+        self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
+        self.relu = layer.ReLU()
+        self.flatten = layer.Flatten()
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
 
     def forward(self, x):
         y = self.conv1(x)
-        y = autograd.relu(y)
         y = self.pooling1(y)
         y = self.conv2(y)
-        y = autograd.relu(y)
         y = self.pooling2(y)
-        y = autograd.flatten(y)
+        y = self.flatten(y)
         y = self.linear1(y)
-        y = autograd.relu(y)
+        y = self.relu(y)
         y = self.linear2(y)
         return y
 
-    def loss(self, out, ty):
-        return autograd.softmax_cross_entropy(out, ty)
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
 
-    def optim(self, loss, dist_option, spars):
         if dist_option == 'fp32':
-            self.optimizer.backward_and_update(loss)
+            self.optimizer(loss)
         elif dist_option == 'fp16':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
@@ -66,6 +67,7 @@
             self.optimizer.backward_and_sparse_update(loss,
                                                       topK=False,
                                                       spars=spars)
+        return out, loss
 
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer
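
With the move from module.Module to model.Model above, the separate forward/loss/optim steps collapse into train_one_batch, and the training script drives it via compile() and a direct call on the model (see the train_cnn.py diff later in this patch). A minimal end-to-end sketch using the CNN class above, assuming a CUDA GPU and that examples/cnn/model is on the import path (as when running from examples/cnn); the random MNIST-shaped batch is illustrative.

import numpy as np
from singa import device, opt, tensor
from model import cnn                 # examples/cnn/model/cnn.py from this patch

dev = device.create_cuda_gpu_on(0)
sgd = opt.SGD(lr=0.005, momentum=0.9, weight_decay=1e-5)

m = cnn.CNN(num_classes=10, num_channels=1)
m.train()
m.set_optimizer(sgd)

batch = 64
tx = tensor.Tensor((batch, 1, 28, 28), dev, tensor.float32)
ty = tensor.Tensor((batch,), dev, tensor.int32)
m.compile([tx], is_train=True, use_graph=True, sequential=False)

x = np.random.randn(batch, 1, 28, 28).astype(np.float32)
y = np.random.randint(0, 10, batch).astype(np.int32)
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)

# calling the model runs train_one_batch and returns (out, loss)
out, loss = m(tx, ty, dist_option='fp32', spars=None)
print("loss:", tensor.to_numpy(loss)[0])
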
diff --git a/examples/cnn/model/resnet.py b/examples/cnn/model/resnet.py
index 4e83757..2b2a7fd 100644
--- a/examples/cnn/model/resnet.py
+++ b/examples/cnn/model/resnet.py
@@ -20,40 +20,43 @@
 # the code is modified from
 # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
 
-from singa import autograd
-from singa import module
+from singa import layer
+from singa import model
 
 
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
-    return autograd.Conv2d(
+    return layer.Conv2d(
         in_planes,
         out_planes,
-        kernel_size=3,
+        3,
         stride=stride,
         padding=1,
         bias=False,
     )
 
 
-class BasicBlock(autograd.Layer):
+class BasicBlock(layer.Layer):
     expansion = 1
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = autograd.BatchNorm2d(planes)
+        self.bn1 = layer.BatchNorm2d(planes)
         self.conv2 = conv3x3(planes, planes)
-        self.bn2 = autograd.BatchNorm2d(planes)
+        self.bn2 = layer.BatchNorm2d(planes)
+        self.relu1 = layer.ReLU()
+        self.add = layer.Add()
+        self.relu2 = layer.ReLU()
         self.downsample = downsample
         self.stride = stride
 
-    def __call__(self, x):
+    def forward(self, x):
         residual = x
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = autograd.relu(out)
+        out = self.relu1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -61,48 +64,50 @@
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out = autograd.add(out, residual)
-        out = autograd.relu(out)
+        out = self.add(out, residual)
+        out = self.relu2(out)
 
         return out
 
 
-class Bottleneck(autograd.Layer):
+class Bottleneck(layer.Layer):
     expansion = 4
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
-        self.conv1 = autograd.Conv2d(inplanes,
-                                     planes,
-                                     kernel_size=1,
-                                     bias=False)
-        self.bn1 = autograd.BatchNorm2d(planes)
-        self.conv2 = autograd.Conv2d(planes,
-                                     planes,
-                                     kernel_size=3,
-                                     stride=stride,
-                                     padding=1,
-                                     bias=False)
-        self.bn2 = autograd.BatchNorm2d(planes)
-        self.conv3 = autograd.Conv2d(planes,
-                                     planes * self.expansion,
-                                     kernel_size=1,
-                                     bias=False)
-        self.bn3 = autograd.BatchNorm2d(planes * self.expansion)
+        self.conv1 = layer.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = layer.BatchNorm2d(planes)
+        self.relu1 = layer.ReLU()
+        self.conv2 = layer.Conv2d(planes,
+                                  planes,
+                                  3,
+                                  stride=stride,
+                                  padding=1,
+                                  bias=False)
+        self.bn2 = layer.BatchNorm2d(planes)
+        self.relu2 = layer.ReLU()
+        self.conv3 = layer.Conv2d(planes,
+                                  planes * self.expansion,
+                                  1,
+                                  bias=False)
+        self.bn3 = layer.BatchNorm2d(planes * self.expansion)
+
+        self.add = layer.Add()
+        self.relu3 = layer.ReLU()
 
         self.downsample = downsample
         self.stride = stride
 
-    def __call__(self, x):
+    def forward(self, x):
         residual = x
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = autograd.relu(out)
+        out = self.relu1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
-        out = autograd.relu(out)
+        out = self.relu2(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
@@ -110,8 +115,8 @@
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out = autograd.add(out, residual)
-        out = autograd.relu(out)
+        out = self.add(out, residual)
+        out = self.relu3(out)
 
         return out
 
@@ -121,7 +126,7 @@
 ]
 
 
-class ResNet(module.Module):
+class ResNet(model.Model):
 
     def __init__(self, block, layers, num_classes=10, num_channels=3):
         self.inplanes = 64
@@ -129,32 +134,37 @@
         self.num_classes = num_classes
         self.input_size = 224
         self.dimension = 4
-        self.conv1 = autograd.Conv2d(num_channels,
-                                     64,
-                                     kernel_size=7,
-                                     stride=2,
-                                     padding=3,
-                                     bias=False)
-        self.bn1 = autograd.BatchNorm2d(64)
-        self.maxpool = autograd.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-        self.avgpool = autograd.AvgPool2d(7, stride=1)
-        self.fc = autograd.Linear(512 * block.expansion, num_classes)
+        self.conv1 = layer.Conv2d(num_channels,
+                                  64,
+                                  7,
+                                  stride=2,
+                                  padding=3,
+                                  bias=False)
+        self.bn1 = layer.BatchNorm2d(64)
+        self.relu = layer.ReLU()
+        self.maxpool = layer.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1, layers1 = self._make_layer(block, 64, layers[0])
+        self.layer2, layers2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3, layers3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4, layers4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = layer.AvgPool2d(7, stride=1)
+        self.flatten = layer.Flatten()
+        self.fc = layer.Linear(num_classes)
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+
+        self.register_layers(*layers1, *layers2, *layers3, *layers4)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
-            conv = autograd.Conv2d(
+            conv = layer.Conv2d(
                 self.inplanes,
                 planes * block.expansion,
-                kernel_size=1,
+                1,
                 stride=stride,
                 bias=False,
             )
-            bn = autograd.BatchNorm2d(planes * block.expansion)
+            bn = layer.BatchNorm2d(planes * block.expansion)
 
             def _downsample(x):
                 return bn(conv(x))
@@ -172,12 +182,12 @@
                 x = layer(x)
             return x
 
-        return forward
+        return forward, layers
 
     def forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
-        x = autograd.relu(x)
+        x = self.relu(x)
         x = self.maxpool(x)
 
         x = self.layer1(x)
@@ -186,17 +196,17 @@
         x = self.layer4(x)
 
         x = self.avgpool(x)
-        x = autograd.flatten(x)
+        x = self.flatten(x)
         x = self.fc(x)
 
         return x
 
-    def loss(self, out, ty):
-        return autograd.softmax_cross_entropy(out, ty)
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
 
-    def optim(self, loss, dist_option, spars):
         if dist_option == 'fp32':
-            self.optimizer.backward_and_update(loss)
+            self.optimizer(loss)
         elif dist_option == 'fp16':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
@@ -209,6 +219,7 @@
             self.optimizer.backward_and_sparse_update(loss,
                                                       topK=False,
                                                       spars=spars)
+        return out, loss
 
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer
diff --git a/examples/cnn/model/xceptionnet.py b/examples/cnn/model/xceptionnet.py
index 9015a34..524e3f6 100644
--- a/examples/cnn/model/xceptionnet.py
+++ b/examples/cnn/model/xceptionnet.py
@@ -18,11 +18,11 @@
 # the code is modified from
 # https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
 
-from singa import autograd
-from singa import module
+from singa import layer
+from singa import model
 
 
-class Block(autograd.Layer):
+class Block(layer.Layer):
 
     def __init__(self,
                  in_filters,
@@ -35,13 +35,13 @@
         super(Block, self).__init__()
 
         if out_filters != in_filters or strides != 1:
-            self.skip = autograd.Conv2d(in_filters,
-                                        out_filters,
-                                        1,
-                                        stride=strides,
-                                        padding=padding,
-                                        bias=False)
-            self.skipbn = autograd.BatchNorm2d(out_filters)
+            self.skip = layer.Conv2d(in_filters,
+                                     out_filters,
+                                     1,
+                                     stride=strides,
+                                     padding=padding,
+                                     bias=False)
+            self.skipbn = layer.BatchNorm2d(out_filters)
         else:
             self.skip = None
 
@@ -49,48 +49,52 @@
 
         filters = in_filters
         if grow_first:
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(in_filters,
-                                         out_filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(out_filters))
+                layer.SeparableConv2d(in_filters,
+                                      out_filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(out_filters))
             filters = out_filters
 
         for i in range(reps - 1):
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(filters,
-                                         filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(filters))
+                layer.SeparableConv2d(filters,
+                                      filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(filters))
 
         if not grow_first:
-            self.layers.append(autograd.ReLU())
+            self.layers.append(layer.ReLU())
             self.layers.append(
-                autograd.SeparableConv2d(in_filters,
-                                         out_filters,
-                                         3,
-                                         stride=1,
-                                         padding=1,
-                                         bias=False))
-            self.layers.append(autograd.BatchNorm2d(out_filters))
+                layer.SeparableConv2d(in_filters,
+                                      out_filters,
+                                      3,
+                                      stride=1,
+                                      padding=1,
+                                      bias=False))
+            self.layers.append(layer.BatchNorm2d(out_filters))
 
         if not start_with_relu:
             self.layers = self.layers[1:]
         else:
-            self.layers[0] = autograd.ReLU()
+            self.layers[0] = layer.ReLU()
 
         if strides != 1:
-            self.layers.append(autograd.MaxPool2d(3, strides, padding + 1))
+            self.layers.append(layer.MaxPool2d(3, strides, padding + 1))
 
-    def __call__(self, x):
+        self.register_layers(*self.layers)
+
+        self.add = layer.Add()
+
+    def forward(self, x):
         y = self.layers[0](x)
         for layer in self.layers[1:]:
             if isinstance(y, tuple):
@@ -102,11 +106,11 @@
             skip = self.skipbn(skip)
         else:
             skip = x
-        y = autograd.add(y, skip)
+        y = self.add(y, skip)
         return y
 
 
-class Xception(module.Module):
+class Xception(model.Model):
     """
     Xception optimized for the ImageNet dataset, as specified in
     https://arxiv.org/pdf/1610.02357.pdf
@@ -122,11 +126,13 @@
         self.input_size = 299
         self.dimension = 4
 
-        self.conv1 = autograd.Conv2d(num_channels, 32, 3, 2, 0, bias=False)
-        self.bn1 = autograd.BatchNorm2d(32)
+        self.conv1 = layer.Conv2d(num_channels, 32, 3, 2, 0, bias=False)
+        self.bn1 = layer.BatchNorm2d(32)
+        self.relu1 = layer.ReLU()
 
-        self.conv2 = autograd.Conv2d(32, 64, 3, 1, 1, bias=False)
-        self.bn2 = autograd.BatchNorm2d(64)
+        self.conv2 = layer.Conv2d(32, 64, 3, 1, 1, bias=False)
+        self.bn2 = layer.BatchNorm2d(64)
+        self.relu2 = layer.ReLU()
         # do relu here
 
         self.block1 = Block(64,
@@ -208,24 +214,29 @@
                              start_with_relu=True,
                              grow_first=False)
 
-        self.conv3 = autograd.SeparableConv2d(1024, 1536, 3, 1, 1)
-        self.bn3 = autograd.BatchNorm2d(1536)
+        self.conv3 = layer.SeparableConv2d(1024, 1536, 3, 1, 1)
+        self.bn3 = layer.BatchNorm2d(1536)
+        self.relu3 = layer.ReLU()
 
         # do relu here
-        self.conv4 = autograd.SeparableConv2d(1536, 2048, 3, 1, 1)
-        self.bn4 = autograd.BatchNorm2d(2048)
+        self.conv4 = layer.SeparableConv2d(1536, 2048, 3, 1, 1)
+        self.bn4 = layer.BatchNorm2d(2048)
 
-        self.globalpooling = autograd.MaxPool2d(10, 1)
-        self.fc = autograd.Linear(2048, num_classes)
+        self.relu4 = layer.ReLU()
+        self.globalpooling = layer.MaxPool2d(10, 1)
+        self.flatten = layer.Flatten()
+        self.fc = layer.Linear(num_classes)
+
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
 
     def features(self, input):
         x = self.conv1(input)
         x = self.bn1(x)
-        x = autograd.relu(x)
+        x = self.relu1(x)
 
         x = self.conv2(x)
         x = self.bn2(x)
-        x = autograd.relu(x)
+        x = self.relu2(x)
 
         x = self.block1(x)
         x = self.block2(x)
@@ -242,30 +253,29 @@
 
         x = self.conv3(x)
         x = self.bn3(x)
-        x = autograd.relu(x)
+        x = self.relu3(x)
 
         x = self.conv4(x)
         x = self.bn4(x)
         return x
 
     def logits(self, features):
-        x = autograd.relu(features)
+        x = self.relu4(features)
         x = self.globalpooling(x)
-        x = autograd.flatten(x)
+        x = self.flatten(x)
         x = self.fc(x)
         return x
 
-    def forward(self, input):
-        x = self.features(input)
+    def forward(self, x):
+        x = self.features(x)
         x = self.logits(x)
         return x
 
-    def loss(self, out, ty):
-        return autograd.softmax_cross_entropy(out, ty)
-
-    def optim(self, loss, dist_option, spars):
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
         if dist_option == 'fp32':
-            self.optimizer.backward_and_update(loss)
+            self.optimizer(loss)
         elif dist_option == 'fp16':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
@@ -278,6 +288,7 @@
             self.optimizer.backward_and_sparse_update(loss,
                                                       topK=False,
                                                       spars=spars)
+        return out, loss
 
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer
diff --git a/examples/cnn/train.py b/examples/cnn/train_cnn.py
similarity index 93%
rename from examples/cnn/train.py
rename to examples/cnn/train_cnn.py
index 9f75b12..4c74b99 100644
--- a/examples/cnn/train.py
+++ b/examples/cnn/train_cnn.py
@@ -18,9 +18,9 @@
 #
 
 from singa import singa_wrap as singa
-from singa import opt
 from singa import device
 from singa import tensor
+from singa import opt
 import numpy as np
 import time
 import argparse
@@ -99,6 +99,7 @@
         data,
         sgd,
         graph,
+        verbosity,
         dist_option='fp32',
         spars=None):
     dev = device.create_cuda_gpu_on(local_rank)
@@ -123,7 +124,7 @@
 
     if model == 'resnet':
         from model import resnet
-        model = resnet.resnet18(num_channels=num_channels,
+        model = resnet.resnet50(num_channels=num_channels,
                                 num_classes=num_classes)
     elif model == 'xceptionnet':
         from model import xceptionnet
@@ -182,9 +183,9 @@
     idx = np.arange(train_x.shape[0], dtype=np.int32)
 
     # attached model to graph
-    model.on_device(dev)
     model.set_optimizer(sgd)
-    model.graph(graph, sequential)
+    model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
+    dev.SetVerbosity(verbosity)
 
     # Training and Evaluation Loop
     for epoch in range(max_epoch):
@@ -214,9 +215,7 @@
             ty.copy_from_numpy(y)
 
             # Train the model
-            out = model(tx)
-            loss = model.loss(out, ty)
-            model.optim(loss, dist_option, spars)
+            out, loss = model(tx, ty, dist_option, spars)
             train_correct += accuracy(tensor.to_numpy(out), y)
             train_loss += tensor.to_numpy(loss)[0]
 
@@ -256,6 +255,8 @@
                    time.time() - start_time),
                   flush=True)
 
+    dev.PrintTimeProfiling()
+
 
 if __name__ == '__main__':
     # use argparse to get command config: max_epoch, model, data, etc. for single gpu training
@@ -267,40 +268,46 @@
     parser.add_argument('data',
                         choices=['cifar10', 'cifar100', 'mnist'],
                         default='mnist')
-    parser.add_argument('--epoch',
+    parser.add_argument('-m',
                         '--max-epoch',
                         default=10,
                         type=int,
                         help='maximum epochs',
                         dest='max_epoch')
-    parser.add_argument('--bs',
+    parser.add_argument('-b',
                         '--batch-size',
                         default=64,
                         type=int,
                         help='batch size',
                         dest='batch_size')
-    parser.add_argument('--lr',
+    parser.add_argument('-l',
                         '--learning-rate',
                         default=0.005,
                         type=float,
                         help='initial learning rate',
                         dest='lr')
     # determine which gpu to use
-    parser.add_argument('--id',
+    parser.add_argument('-i',
                         '--device-id',
                         default=0,
                         type=int,
                         help='which GPU to use',
                         dest='device_id')
-    parser.add_argument('--no-graph',
+    parser.add_argument('-g',
                         '--disable-graph',
                         default='True',
                         action='store_false',
                         help='disable graph',
                         dest='graph')
+    parser.add_argument('-v',
+                        '--log-verbosity',
+                        default=0,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
 
     args = parser.parse_args()
 
     sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
     run(0, 1, args.device_id, args.max_epoch, args.batch_size, args.model,
-        args.data, sgd, args.graph)
+        args.data, sgd, args.graph, args.verbosity)
diff --git a/examples/cnn/train_mpi.py b/examples/cnn/train_mpi.py
index 01a32ff..fd78b12 100644
--- a/examples/cnn/train_mpi.py
+++ b/examples/cnn/train_mpi.py
@@ -21,7 +21,7 @@
 from singa import singa_wrap as singa
 from singa import opt
 import argparse
-import train
+import train_cnn
 
 if __name__ == '__main__':
     # use argparse to get command config: max_epoch, model, data, etc. for single gpu training
@@ -31,47 +31,54 @@
                         choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
                         default='cnn')
     parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
-    parser.add_argument('--epoch',
+    parser.add_argument('-m',
                         '--max-epoch',
                         default=10,
                         type=int,
                         help='maximum epochs',
                         dest='max_epoch')
-    parser.add_argument('--bs',
+    parser.add_argument('-b',
                         '--batch-size',
                         default=64,
                         type=int,
                         help='batch size',
                         dest='batch_size')
-    parser.add_argument('--lr',
+    parser.add_argument('-l',
                         '--learning-rate',
                         default=0.005,
                         type=float,
                         help='initial learning rate',
                         dest='lr')
-    parser.add_argument('--op',
-                        '--option',
+    parser.add_argument('-d',
+                        '--dist-option',
                         default='fp32',
                         choices=['fp32','fp16','partialUpdate','sparseTopK','sparseThreshold'],
                         help='distributed training options',
-                        dest='dist_option')  # currently partialUpdate support graph=False only 
-    parser.add_argument('--spars',
+                        dest='dist_option')  # currently partialUpdate supports graph=False only
+    parser.add_argument('-s',
                         '--sparsification',
                         default='0.05',
                         type=float,
                         help='the sparsity parameter used for sparsification, between 0 and 1',
                         dest='spars')
-    parser.add_argument('--no-graph',
+    parser.add_argument('-g',
                         '--disable-graph',
                         default='True',
                         action='store_false',
                         help='disable graph',
                         dest='graph')
+    parser.add_argument('-v',
+                        '--log-verbosity',
+                        default=0,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
 
     args = parser.parse_args()
 
     sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
     sgd = opt.DistOpt(sgd)
 
-    train.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
-              args.batch_size, args.model, args.data, sgd, args.graph, args.dist_option, args.spars)
+    train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
+              args.batch_size, args.model, args.data, sgd, args.graph,
+              args.verbosity, args.dist_option, args.spars)
diff --git a/examples/cnn/train_multiprocess.py b/examples/cnn/train_multiprocess.py
index 11ebe88..9972ddd 100644
--- a/examples/cnn/train_multiprocess.py
+++ b/examples/cnn/train_multiprocess.py
@@ -21,14 +21,15 @@
 from singa import singa_wrap as singa
 from singa import opt
 import argparse
-import train
+import train_cnn
 import multiprocessing
 
 def run(args, local_rank, world_size, nccl_id):
     sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
     sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size)
-    train.run(sgd.global_rank, sgd.world_size, sgd.local_rank,
-              args.max_epoch, args.batch_size, args.model, args.data, sgd, args.graph, args.dist_option, args.spars)
+    train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
+              args.batch_size, args.model, args.data, sgd, args.graph,
+              args.verbosity, args.dist_option, args.spars)
 
 
 if __name__ == '__main__':
@@ -39,49 +40,54 @@
                         choices=['resnet', 'xceptionnet', 'cnn', 'mlp'],
                         default='cnn')
     parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist')
-    parser.add_argument('--epoch',
+    parser.add_argument('-m',
                         '--max-epoch',
                         default=10,
                         type=int,
                         help='maximum epochs',
                         dest='max_epoch')
-    parser.add_argument('--bs',
+    parser.add_argument('-b',
                         '--batch-size',
                         default=64,
                         type=int,
                         help='batch size',
                         dest='batch_size')
-    parser.add_argument('--lr',
+    parser.add_argument('-l',
                         '--learning-rate',
                         default=0.005,
                         type=float,
                         help='initial learning rate',
                         dest='lr')
-    parser.add_argument('--ws',
+    parser.add_argument('-w',
                         '--world-size',
                         default=2,
                         type=int,
                         help='number of gpus to be used',
                         dest='world_size')
-    parser.add_argument('--op',
-                        '--option',
+    parser.add_argument('-d',
+                        '--dist-option',
                         default='fp32',
                         choices=['fp32','fp16','partialUpdate','sparseTopK','sparseThreshold'],
                         help='distributed training options',
-                        dest='dist_option') # currently partialUpdate support graph=False only 
-    parser.add_argument('--spars',
+                        dest='dist_option') # currently partialUpdate supports graph=False only
+    parser.add_argument('-s',
                         '--sparsification',
                         default='0.05',
                         type=float,
                         help='the sparsity parameter used for sparsification, between 0 and 1',
                         dest='spars')
-    parser.add_argument('--no-graph',
+    parser.add_argument('-g',
                         '--disable-graph',
                         default='True',
                         action='store_false',
                         help='disable graph',
                         dest='graph')
-
+    parser.add_argument('-v',
+                        '--log-verbosity',
+                        default=0,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
 
     args = parser.parse_args()
 
diff --git a/examples/gan/README.md b/examples/gan/README.md
new file mode 100644
index 0000000..c805f7b
--- /dev/null
+++ b/examples/gan/README.md
@@ -0,0 +1,34 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+# Train a Generative Adversarial Network (GAN) model
+
+This example trains a Generative Adversarial Network (GAN) model on the MNIST dataset.
+
+## Running instructions
+
+1. Download the pre-processed [MNIST dataset](https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz)
+
+2. Start the training
+
+        python vanilla.py mnist.pkl.gz
+
+By default, the training code runs on the CPU. To run it on a GPU, start the
+program with an additional argument
+
+        python vanilla.py mnist.pkl.gz --use_gpu
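+
+The directory also contains a least-squares GAN variant, `lsgan.py`, which takes
+the same arguments
+
+        python lsgan.py mnist.pkl.gz --use_gpu
+
+Generated sample images are saved periodically to `vanilla_images/` and
+`lsgan_images/`, respectively.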
diff --git a/examples/gan/lsgan.py b/examples/gan/lsgan.py
index dc6582c..39f243e 100644
--- a/examples/gan/lsgan.py
+++ b/examples/gan/lsgan.py
@@ -18,196 +18,169 @@
 #
 
 from singa import device
-from singa import initializer
-from singa import layer
-from singa import loss
-from singa import net as ffnet
-from singa import optimizer
+from singa import opt
 from singa import tensor
 
 import argparse
 import matplotlib.pyplot as plt
 import numpy as np
 import os
-
+from model import lsgan_mlp
 from utils import load_data
 from utils import print_log
 
+
 class LSGAN():
-	def  __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100, hidden_size=128, batch=128, 
-		interval=1000, learning_rate=0.001, epochs=1000000, d_steps=3, g_steps=1, 
-		dataset_filepath='mnist.pkl.gz', file_dir='lsgan_images/'):
-		self.dev = dev
-		self.rows = rows
-		self.cols = cols
-		self.channels = channels
-		self.feature_size = self.rows * self.cols * self.channels
-		self.noise_size = noise_size
-		self.hidden_size = hidden_size
-		self.batch = batch
-		self.batch_size = self.batch//2
-		self.interval = interval
-		self.learning_rate = learning_rate
-		self.epochs = epochs
-		self.d_steps = d_steps
-		self.g_steps = g_steps
-		self.dataset_filepath = dataset_filepath
-		self.file_dir = file_dir
 
-		self.g_w0_specs = {'init': 'xavier',}
-		self.g_b0_specs = {'init': 'constant', 'value': 0,}
-		self.g_w1_specs = {'init': 'xavier',}
-		self.g_b1_specs = {'init': 'constant', 'value': 0,}
-		self.gen_net = ffnet.FeedForwardNet(loss.SquaredError(),)
-		self.gen_net_fc_0 = layer.Dense(name='g_fc_0', num_output=self.hidden_size, use_bias=True, 
-			W_specs=self.g_w0_specs, b_specs=self.g_b0_specs, input_sample_shape=(self.noise_size,))
-		self.gen_net_relu_0 = layer.Activation(name='g_relu_0', mode='relu',input_sample_shape=(self.hidden_size,))
-		self.gen_net_fc_1 = layer.Dense(name='g_fc_1', num_output=self.feature_size, use_bias=True, 
-			W_specs=self.g_w1_specs, b_specs=self.g_b1_specs, input_sample_shape=(self.hidden_size,))
-		self.gen_net_sigmoid_1 = layer.Activation(name='g_relu_1', mode='sigmoid', input_sample_shape=(self.feature_size,))
-		self.gen_net.add(self.gen_net_fc_0)
-		self.gen_net.add(self.gen_net_relu_0)
-		self.gen_net.add(self.gen_net_fc_1)
-		self.gen_net.add(self.gen_net_sigmoid_1)
-		for (p, specs) in zip(self.gen_net.param_values(), self.gen_net.param_specs()):
-			filler = specs.filler
-			if filler.type == 'gaussian':
-				p.gaussian(filler.mean, filler.std)
-			elif filler.type == 'xavier':
-				initializer.xavier(p)
-			else: 
-				p.set_value(0)
-			print(specs.name, filler.type, p.l1())	
-		self.gen_net.to_device(self.dev)		
+    def __init__(self,
+                 dev,
+                 rows=28,
+                 cols=28,
+                 channels=1,
+                 noise_size=100,
+                 hidden_size=128,
+                 batch=128,
+                 interval=1000,
+                 learning_rate=0.001,
+                 iterations=1000000,
+                 d_steps=3,
+                 g_steps=1,
+                 dataset_filepath='mnist.pkl.gz',
+                 file_dir='lsgan_images/'):
+        self.dev = dev
+        self.rows = rows
+        self.cols = cols
+        self.channels = channels
+        self.feature_size = self.rows * self.cols * self.channels
+        self.noise_size = noise_size
+        self.hidden_size = hidden_size
+        self.batch = batch
+        self.batch_size = self.batch // 2
+        self.interval = interval
+        self.learning_rate = learning_rate
+        self.iterations = iterations
+        self.d_steps = d_steps
+        self.g_steps = g_steps
+        self.dataset_filepath = dataset_filepath
+        self.file_dir = file_dir
+        self.model = lsgan_mlp.create_model(noise_size=self.noise_size,
+                                            feature_size=self.feature_size,
+                                            hidden_size=self.hidden_size)
 
-		self.d_w0_specs = {'init': 'xavier',}
-		self.d_b0_specs = {'init': 'constant', 'value': 0,}
-		self.d_w1_specs = {'init': 'xavier',}
-		self.d_b1_specs = {'init': 'constant', 'value': 0,}			
-		self.dis_net = ffnet.FeedForwardNet(loss.SquaredError(),)
-		self.dis_net_fc_0 = layer.Dense(name='d_fc_0', num_output=self.hidden_size, use_bias=True, 
-			W_specs=self.d_w0_specs, b_specs=self.d_b0_specs, input_sample_shape=(self.feature_size,))
-		self.dis_net_relu_0 = layer.Activation(name='d_relu_0', mode='relu',input_sample_shape=(self.hidden_size,))
-		self.dis_net_fc_1 = layer.Dense(name='d_fc_1', num_output=1,  use_bias=True, 
-			W_specs=self.d_w1_specs, b_specs=self.d_b1_specs, input_sample_shape=(self.hidden_size,))
-		self.dis_net.add(self.dis_net_fc_0)
-		self.dis_net.add(self.dis_net_relu_0)
-		self.dis_net.add(self.dis_net_fc_1)			
-		for (p, specs) in zip(self.dis_net.param_values(), self.dis_net.param_specs()):
-			filler = specs.filler
-			if filler.type == 'gaussian':
-				p.gaussian(filler.mean, filler.std)
-			elif filler.type == 'xavier':
-				initializer.xavier(p)
-			else: 
-				p.set_value(0)
-			print(specs.name, filler.type, p.l1())
-		self.dis_net.to_device(self.dev)
+    def train(self):
+        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
+        dev = device.create_cuda_gpu_on(0)
+        dev.SetRandSeed(0)
+        np.random.seed(0)
 
-		self.combined_net = ffnet.FeedForwardNet(loss.SquaredError(), )
-		for l in self.gen_net.layers:
-			self.combined_net.add(l)
-		for l in self.dis_net.layers:
-			self.combined_net.add(l)
-		self.combined_net.to_device(self.dev)
+        #sgd = opt.SGD(lr=self.learning_rate, momentum=0.9, weight_decay=1e-5)
+        sgd = opt.Adam(lr=self.learning_rate)
 
-	def train(self):
-		train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
-		opt_0 = optimizer.Adam(lr=self.learning_rate) # optimizer for discriminator 
-		opt_1 = optimizer.Adam(lr=self.learning_rate) # optimizer for generator, aka the combined model
-		for (p, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
-			opt_0.register(p, specs)
-		for (p, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
-			opt_1.register(p, specs)
+        noise = tensor.Tensor((self.batch_size, self.noise_size), dev,
+                              tensor.float32)
+        real_images = tensor.Tensor((self.batch_size, self.feature_size), dev,
+                                    tensor.float32)
+        real_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
+        fake_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
+        substrahend_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
 
-		for epoch in range(self.epochs):
-			for d_step in range(self.d_steps):
-				idx = np.random.randint(0, train_data.shape[0], self.batch_size)
-				real_imgs = train_data[idx]
-				real_imgs = tensor.from_numpy(real_imgs)
-				real_imgs.to_device(self.dev)
-				noise = tensor.Tensor((self.batch_size, self.noise_size))
-				noise.uniform(-1, 1)
-				noise.to_device(self.dev)
-				fake_imgs = self.gen_net.forward(flag=False, x=noise)
-				substrahend = tensor.Tensor((real_imgs.shape[0], 1))
-				substrahend.set_value(1.0)
-				substrahend.to_device(self.dev)
-				grads, (d_loss_real, _) = self.dis_net.train(real_imgs, substrahend)
-				for (s, p ,g) in zip(self.dis_net.param_names(), self.dis_net.param_values(), grads):
-					opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-				substrahend.set_value(-1.0)
-				grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs, substrahend)
-				for (s, p ,g) in zip(self.dis_net.param_names(), self.dis_net.param_values(), grads):
-					opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-				d_loss = d_loss_real + d_loss_fake
-			
-			for g_step in range(self.g_steps): 
-				noise = tensor.Tensor((self.batch_size, self.noise_size))
-				noise.uniform(-1, 1)
-				noise.to_device(self.dev)
-				substrahend = tensor.Tensor((real_imgs.shape[0], 1))
-				substrahend.set_value(0.0)
-				substrahend.to_device(self.dev)
-				grads, (g_loss, _) = self.combined_net.train(noise, substrahend)
-				for (s, p ,g) in zip(self.gen_net.param_names(), self.gen_net.param_values(), grads):
-					opt_1.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-			
-			if epoch % self.interval == 0:
-				self.save_image(epoch)
-				print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))
+        # attach the model to the graph
+        self.model.set_optimizer(sgd)
+        self.model.compile([noise],
+                           is_train=True,
+                           use_graph=False,
+                           sequential=True)
 
-	def save_image(self, epoch):
-		rows = 5
-		cols = 5
-		channels = self.channels
-		noise = tensor.Tensor((rows*cols*channels, self.noise_size))
-		noise.uniform(-1,1)
-		noise.to_device(self.dev)
-		gen_imgs = self.gen_net.forward(flag=False, x=noise)
-		gen_imgs = tensor.to_numpy(gen_imgs)
-		show_imgs = np.reshape(gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
-		fig, axs = plt.subplots(rows, cols)
-		cnt = 0
-		for r in range(rows):
-			for c in range(cols):
-				axs[r,c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
-				axs[r,c].axis('off')
-				cnt += 1
-		fig.savefig("{}{}.png".format(self.file_dir, epoch))
-		plt.close()
+        real_labels.set_value(1.0)
+        fake_labels.set_value(-1.0)
+        substrahend_labels.set_value(0.0)
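+        # Label scheme used by this least-squares GAN: the discriminator regresses
+        # real images towards 1 and generated images towards -1, while the
+        # generator is trained to pull D(G(z)) towards 0 (the substrahend_labels).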
+
+        for iteration in range(self.iterations):
+
+            for d_step in range(self.d_steps):
+                idx = np.random.randint(0, train_data.shape[0], self.batch_size)
+                real_images.copy_from_numpy(train_data[idx])
+
+                self.model.train()
+
+                # Training the Discriminative Net
+                _, d_loss_real = self.model.train_one_batch_dis(
+                    real_images, real_labels)
+
+                noise.uniform(-1, 1)
+                fake_images = self.model.forward_gen(noise)
+                _, d_loss_fake = self.model.train_one_batch_dis(
+                    fake_images, fake_labels)
+
+                d_loss = tensor.to_numpy(d_loss_real)[0] + tensor.to_numpy(
+                    d_loss_fake)[0]
+
+            for g_step in range(self.g_steps):
+                # Training the Generative Net
+                noise.uniform(-1, 1)
+                _, g_loss_tensor = self.model.train_one_batch(
+                    noise, substrahend_labels)
+
+                g_loss = tensor.to_numpy(g_loss_tensor)[0]
+
+            if iteration % self.interval == 0:
+                self.model.eval()
+                self.save_image(iteration)
+                print_log('Iteration {}: G_LOSS: {}, D_LOSS: {}'.format(
+                    iteration, g_loss, d_loss))
+
+    def save_image(self, iteration):
+        demo_row = 5
+        demo_col = 5
+        if not hasattr(self, "demo_noise"):
+            self.demo_noise = tensor.Tensor(
+                (demo_col * demo_row, self.noise_size), dev, tensor.float32)
+        self.demo_noise.uniform(-1, 1)
+        gen_imgs = self.model.forward_gen(self.demo_noise)
+        gen_imgs = tensor.to_numpy(gen_imgs)
+        show_imgs = np.reshape(
+            gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
+        fig, axs = plt.subplots(demo_row, demo_col)
+        cnt = 0
+        for r in range(demo_row):
+            for c in range(demo_col):
+                axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
+                axs[r, c].axis('off')
+                cnt += 1
+        fig.savefig("{}{}.png".format(self.file_dir, iteration))
+        plt.close()
+
 
 if __name__ == '__main__':
-	parser = argparse.ArgumentParser(description='Train GAN over MNIST')
-	parser.add_argument('filepath',  type=str, help='the dataset path')
-	parser.add_argument('--use_gpu', action='store_true')
-	args = parser.parse_args()
-	
-	if args.use_gpu:
-		print('Using GPU')
-		dev = device.create_cuda_gpu()
-		layer.engine = 'cudnn'
-	else:
-		print('Using CPU')
-		dev = device.get_default_device()
-		layer.engine = 'singacpp'
+    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
+    parser.add_argument('filepath', type=str, help='the dataset path')
+    parser.add_argument('--use_gpu', action='store_true')
+    args = parser.parse_args()
 
-	if not os.path.exists('lsgan_images/'):
-		os.makedirs('lsgan_images/')
+    if args.use_gpu:
+        print('Using GPU')
+        dev = device.create_cuda_gpu()
+    else:
+        print('Using CPU')
+        dev = device.get_default_device()
 
-	rows = 28
-	cols = 28
-	channels = 1
-	noise_size = 100
-	hidden_size = 128
-	batch = 128
-	interval = 1000
-	learning_rate = 0.001
-	epochs = 1000000
-	d_steps = 3
-	g_steps = 1
-	dataset_filepath = 'mnist.pkl.gz'
-	file_dir = 'lsgan_images/'
-	lsgan = LSGAN(dev, rows, cols, channels, noise_size, hidden_size, batch, interval, 
-		learning_rate, epochs, d_steps, g_steps, dataset_filepath, file_dir)
-	lsgan.train()
\ No newline at end of file
+    if not os.path.exists('lsgan_images/'):
+        os.makedirs('lsgan_images/')
+
+    rows = 28
+    cols = 28
+    channels = 1
+    noise_size = 100
+    hidden_size = 128
+    batch = 128
+    interval = 1000
+    learning_rate = 0.0005
+    iterations = 1000000
+    d_steps = 1
+    g_steps = 1
+    dataset_filepath = 'mnist.pkl.gz'
+    file_dir = 'lsgan_images/'
+    lsgan = LSGAN(dev, rows, cols, channels, noise_size, hidden_size, batch,
+                  interval, learning_rate, iterations, d_steps, g_steps,
+                  dataset_filepath, file_dir)
+    lsgan.train()
diff --git a/examples/gan/model/gan_mlp.py b/examples/gan/model/gan_mlp.py
new file mode 100644
index 0000000..d1c46a1
--- /dev/null
+++ b/examples/gan/model/gan_mlp.py
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import layer
+from singa import model
+from singa import autograd
+
+
+class GAN_MLP(model.Model):
+
+    def __init__(self, noise_size=100, feature_size=784, hidden_size=128):
+        super(GAN_MLP, self).__init__()
+        self.noise_size = noise_size
+        self.feature_size = feature_size
+        self.hidden_size = hidden_size
+
+        # Generative Net
+        self.gen_net_fc_0 = layer.Linear(self.hidden_size)
+        self.gen_net_relu_0 = layer.ReLU()
+        self.gen_net_fc_1 = layer.Linear(self.feature_size)
+        self.gen_net_sigmoid_1 = layer.Sigmoid()
+
+        # Discriminative Net
+        self.dis_net_fc_0 = layer.Linear(self.hidden_size)
+        self.dis_net_relu_0 = layer.ReLU()
+        self.dis_net_fc_1 = layer.Linear(1)
+        self.dis_net_sigmoid_1 = layer.Sigmoid()
+        self.binary_cross_entropy = layer.BinaryCrossEntropy()
+
+    def forward(self, x):
+        # Cascaded Net
+        y = self.forward_gen(x)
+        y = self.forward_dis(y)
+        return y
+
+    def forward_dis(self, x):
+        # Discriminative Net
+        y = self.dis_net_fc_0(x)
+        y = self.dis_net_relu_0(y)
+        y = self.dis_net_fc_1(y)
+        y = self.dis_net_sigmoid_1(y)
+        return y
+
+    def forward_gen(self, x):
+        # Generative Net
+        y = self.gen_net_fc_0(x)
+        y = self.gen_net_relu_0(y)
+        y = self.gen_net_fc_1(y)
+        y = self.gen_net_sigmoid_1(y)
+        return y
+
+    def train_one_batch(self, x, y):
+        # Training the Generative Net
+        out = self.forward(x)
+        loss = self.binary_cross_entropy(out, y)
+        # Only update the Generative Net
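+        # autograd.backward yields (param, grad) pairs for the whole cascaded net;
+        # gradients belonging to discriminator parameters are skipped, so the
+        # generator is updated against a fixed discriminator.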
+        for p, g in autograd.backward(loss):
+            if "gen_net" in p.name:
+                self.optimizer.apply(p.name, p, g)
+        return out, loss
+
+    def train_one_batch_dis(self, x, y):
+        # Training the Discriminative Net
+        out = self.forward_dis(x)
+        loss = self.binary_cross_entropy(out, y)
+        # Only update the Discriminative Net
+        for p, g in autograd.backward(loss):
+            if "dis_net" in p.name:
+                self.optimizer.apply(p.name, p, g)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+
+def create_model(pretrained=False, **kwargs):
+    """Constructs a CNN model.
+
+    Args:
+        pretrained (bool): If True, returns a pre-trained model
+    """
+    model = GAN_MLP(**kwargs)
+
+    return model
+
+
+__all__ = ['GAN_MLP', 'create_model']
diff --git a/examples/gan/model/lsgan_mlp.py b/examples/gan/model/lsgan_mlp.py
new file mode 100644
index 0000000..c67222e
--- /dev/null
+++ b/examples/gan/model/lsgan_mlp.py
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import layer
+from singa import model
+from singa import autograd
+
+
+class LSGAN_MLP(model.Model):
+
+    def __init__(self, noise_size=100, feature_size=784, hidden_size=128):
+        super(LSGAN_MLP, self).__init__()
+        self.noise_size = noise_size
+        self.feature_size = feature_size
+        self.hidden_size = hidden_size
+
+        # Generative Net
+        self.gen_net_fc_0 = layer.Linear(self.hidden_size)
+        self.gen_net_relu_0 = layer.ReLU()
+        self.gen_net_fc_1 = layer.Linear(self.feature_size)
+        self.gen_net_sigmoid_1 = layer.Sigmoid()
+
+        # Discriminative Net
+        self.dis_net_fc_0 = layer.Linear(self.hidden_size)
+        self.dis_net_relu_0 = layer.ReLU()
+        self.dis_net_fc_1 = layer.Linear(1)
+        self.mse_loss = layer.MeanSquareError()
+
+    def forward(self, x):
+        # Cascaded Net
+        y = self.forward_gen(x)
+        y = self.forward_dis(y)
+        return y
+
+    def forward_dis(self, x):
+        # Discriminative Net
+        y = self.dis_net_fc_0(x)
+        y = self.dis_net_relu_0(y)
+        y = self.dis_net_fc_1(y)
+        return y
+
+    def forward_gen(self, x):
+        # Generative Net
+        y = self.gen_net_fc_0(x)
+        y = self.gen_net_relu_0(y)
+        y = self.gen_net_fc_1(y)
+        y = self.gen_net_sigmoid_1(y)
+        return y
+
+    def train_one_batch(self, x, y):
+        # Training the Generative Net
+        out = self.forward(x)
+        loss = self.mse_loss(out, y)
+        # Only update the Generative Net
+        for p, g in autograd.backward(loss):
+            if "gen_net" in p.name:
+                self.optimizer.apply(p.name, p, g)
+        return out, loss
+
+    def train_one_batch_dis(self, x, y):
+        # Training the Discriminative Net
+        out = self.forward_dis(x)
+        loss = self.mse_loss(out, y)
+        # Only update the Discriminative Net
+        for p, g in autograd.backward(loss):
+            if "dis_net" in p.name:
+                self.optimizer.apply(p.name, p, g)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+
+def create_model(pretrained=False, **kwargs):
+    """Constructs a CNN model.
+
+    Args:
+        pretrained (bool): If True, returns a pre-trained model
+    """
+    model = LSGAN_MLP(**kwargs)
+
+    return model
+
+
+__all__ = ['LSGAN_MLP', 'create_model']
diff --git a/examples/gan/vanilla.py b/examples/gan/vanilla.py
index ce5e048..49c8ec4 100644
--- a/examples/gan/vanilla.py
+++ b/examples/gan/vanilla.py
@@ -18,190 +18,158 @@
 #
 
 from singa import device
-from singa import initializer
-from singa import layer
-from singa import loss
-from singa import net as ffnet
-from singa import optimizer
+from singa import opt
 from singa import tensor
 
 import argparse
 import matplotlib.pyplot as plt
 import numpy as np
 import os
-
+from model import gan_mlp
 from utils import load_data
 from utils import print_log
 
+
 class VANILLA():
-	def  __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100, hidden_size=128, batch=128, 
-		interval=1000, learning_rate=0.001, epochs=1000000, dataset_filepath='mnist.pkl.gz', file_dir='vanilla_images/'):
-		self.dev = dev
-		self.rows = rows
-		self.cols = cols
-		self.channels = channels
-		self.feature_size = self.rows * self.cols * self.channels
-		self.noise_size = noise_size
-		self.hidden_size = hidden_size
-		self.batch = batch
-		self.batch_size = self.batch//2
-		self.interval = interval
-		self.learning_rate = learning_rate
-		self.epochs = epochs
-		self.dataset_filepath = dataset_filepath
-		self.file_dir = file_dir
 
-		self.g_w0_specs = {'init': 'xavier',}
-		self.g_b0_specs = {'init': 'constant', 'value': 0,}
-		self.g_w1_specs = {'init': 'xavier',}
-		self.g_b1_specs = {'init': 'constant', 'value': 0,}
-		self.gen_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(),)
-		self.gen_net_fc_0 = layer.Dense(name='g_fc_0', num_output=self.hidden_size, use_bias=True, 
-			W_specs=self.g_w0_specs, b_specs=self.g_b0_specs, input_sample_shape=(self.noise_size,))
-		self.gen_net_relu_0 = layer.Activation(name='g_relu_0', mode='relu',input_sample_shape=(self.hidden_size,))
-		self.gen_net_fc_1 = layer.Dense(name='g_fc_1', num_output=self.feature_size, use_bias=True, 
-			W_specs=self.g_w1_specs, b_specs=self.g_b1_specs, input_sample_shape=(self.hidden_size,))
-		self.gen_net_sigmoid_1 = layer.Activation(name='g_relu_1', mode='sigmoid', input_sample_shape=(self.feature_size,))
-		self.gen_net.add(self.gen_net_fc_0)
-		self.gen_net.add(self.gen_net_relu_0)
-		self.gen_net.add(self.gen_net_fc_1)
-		self.gen_net.add(self.gen_net_sigmoid_1)
-		for (p, specs) in zip(self.gen_net.param_values(), self.gen_net.param_specs()):
-			filler = specs.filler
-			if filler.type == 'gaussian':
-				p.gaussian(filler.mean, filler.std)
-			elif filler.type == 'xavier':
-				initializer.xavier(p)
-			else: 
-				p.set_value(0)
-			print(specs.name, filler.type, p.l1())	
-		self.gen_net.to_device(self.dev)		
+    def __init__(self,
+                 dev,
+                 rows=28,
+                 cols=28,
+                 channels=1,
+                 noise_size=100,
+                 hidden_size=128,
+                 batch=128,
+                 interval=1000,
+                 learning_rate=0.001,
+                 iterations=1000000,
+                 dataset_filepath='mnist.pkl.gz',
+                 file_dir='vanilla_images/'):
+        self.dev = dev
+        self.rows = rows
+        self.cols = cols
+        self.channels = channels
+        self.feature_size = self.rows * self.cols * self.channels
+        self.noise_size = noise_size
+        self.hidden_size = hidden_size
+        self.batch = batch
+        self.batch_size = self.batch // 2
+        self.interval = interval
+        self.learning_rate = learning_rate
+        self.iterations = iterations
+        self.dataset_filepath = dataset_filepath
+        self.file_dir = file_dir
+        self.model = gan_mlp.create_model(noise_size=self.noise_size,
+                                          feature_size=self.feature_size,
+                                          hidden_size=self.hidden_size)
 
-		self.d_w0_specs = {'init': 'xavier',}
-		self.d_b0_specs = {'init': 'constant', 'value': 0,}
-		self.d_w1_specs = {'init': 'xavier',}
-		self.d_b1_specs = {'init': 'constant', 'value': 0,}			
-		self.dis_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(),)
-		self.dis_net_fc_0 = layer.Dense(name='d_fc_0', num_output=self.hidden_size, use_bias=True, 
-			W_specs=self.d_w0_specs, b_specs=self.d_b0_specs, input_sample_shape=(self.feature_size,))
-		self.dis_net_relu_0 = layer.Activation(name='d_relu_0', mode='relu',input_sample_shape=(self.hidden_size,))
-		self.dis_net_fc_1 = layer.Dense(name='d_fc_1', num_output=1,  use_bias=True, 
-			W_specs=self.d_w1_specs, b_specs=self.d_b1_specs, input_sample_shape=(self.hidden_size,))
-		self.dis_net.add(self.dis_net_fc_0)
-		self.dis_net.add(self.dis_net_relu_0)
-		self.dis_net.add(self.dis_net_fc_1)			
-		for (p, specs) in zip(self.dis_net.param_values(), self.dis_net.param_specs()):
-			filler = specs.filler
-			if filler.type == 'gaussian':
-				p.gaussian(filler.mean, filler.std)
-			elif filler.type == 'xavier':
-				initializer.xavier(p)
-			else: 
-				p.set_value(0)
-			print(specs.name, filler.type, p.l1())
-		self.dis_net.to_device(self.dev)
+    def train(self):
+        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
+        dev = device.create_cuda_gpu_on(0)
+        dev.SetRandSeed(0)
+        np.random.seed(0)
 
-		self.combined_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy(), )
-		for l in self.gen_net.layers:
-			self.combined_net.add(l)
-		for l in self.dis_net.layers:
-			self.combined_net.add(l)
-		self.combined_net.to_device(self.dev)
+        # sgd = opt.SGD(lr=self.learning_rate, momentum=0.9, weight_decay=1e-5)
+        sgd = opt.Adam(lr=self.learning_rate)
 
-	def train(self):
-		train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
-		opt_0 = optimizer.Adam(lr=self.learning_rate) # optimizer for discriminator 
-		opt_1 = optimizer.Adam(lr=self.learning_rate) # optimizer for generator, aka the combined model
-		for (p, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
-			opt_0.register(p, specs)
-		for (p, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
-			opt_1.register(p, specs)
+        noise = tensor.Tensor((self.batch_size, self.noise_size), dev,
+                              tensor.float32)
+        real_images = tensor.Tensor((self.batch_size, self.feature_size), dev,
+                                    tensor.float32)
+        real_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
+        fake_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
 
-		for epoch in range(self.epochs):
-			idx = np.random.randint(0, train_data.shape[0], self.batch_size)
-			real_imgs = train_data[idx]
-			real_imgs = tensor.from_numpy(real_imgs)
-			real_imgs.to_device(self.dev)
-			noise = tensor.Tensor((self.batch_size, self.noise_size))
-			noise.uniform(-1, 1)
-			noise.to_device(self.dev)
-			fake_imgs = self.gen_net.forward(flag=False, x=noise)
-			real_labels = tensor.Tensor((self.batch_size, 1))
-			fake_labels = tensor.Tensor((self.batch_size, 1))
-			real_labels.set_value(1.0)
-			fake_labels.set_value(0.0)
-			real_labels.to_device(self.dev)
-			fake_labels.to_device(self.dev)
-			grads, (d_loss_real, _) = self.dis_net.train(real_imgs, real_labels)
-			for (s, p ,g) in zip(self.dis_net.param_names(), self.dis_net.param_values(), grads):
-				opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-			grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs, fake_labels)
-			for (s, p ,g) in zip(self.dis_net.param_names(), self.dis_net.param_values(), grads):
-				opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-			d_loss = d_loss_real + d_loss_fake
-			noise = tensor.Tensor((self.batch_size, self.noise_size))
-			noise.uniform(-1,1)
-			noise.to_device(self.dev)
-			real_labels = tensor.Tensor((self.batch_size, 1))
-			real_labels.set_value(1.0)
-			real_labels.to_device(self.dev)
-			grads, (g_loss, _) = self.combined_net.train(noise, real_labels)
-			for (s, p ,g) in zip(self.gen_net.param_names(), self.gen_net.param_values(), grads):
-				opt_1.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
-			
-			if epoch % self.interval == 0:
-				self.save_image(epoch)
-				print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))
+        # attach the model to the graph
+        self.model.set_optimizer(sgd)
+        self.model.compile([noise],
+                           is_train=True,
+                           use_graph=False,
+                           sequential=True)
 
-	def save_image(self, epoch):
-		rows = 5
-		cols = 5
-		channels = self.channels
-		noise = tensor.Tensor((rows*cols*channels, self.noise_size))
-		noise.uniform(-1, 1)
-		noise.to_device(self.dev)
-		gen_imgs = self.gen_net.forward(flag=False, x=noise)
-		gen_imgs = tensor.to_numpy(gen_imgs)
-		show_imgs = np.reshape(gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
-		fig, axs = plt.subplots(rows, cols)
-		cnt = 0
-		for r in range(rows):
-			for c in range(cols):
-				axs[r,c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
-				axs[r,c].axis('off')
-				cnt += 1
-		fig.savefig("{}{}.png".format(self.file_dir, epoch))
-		plt.close()
+        real_labels.set_value(1.0)
+        fake_labels.set_value(0.0)
+
+        for iteration in range(self.iterations):
+            idx = np.random.randint(0, train_data.shape[0], self.batch_size)
+            real_images.copy_from_numpy(train_data[idx])
+
+            self.model.train()
+
+            # Training the Discriminative Net
+            _, d_loss_real = self.model.train_one_batch_dis(
+                real_images, real_labels)
+
+            noise.uniform(-1, 1)
+            fake_images = self.model.forward_gen(noise)
+            _, d_loss_fake = self.model.train_one_batch_dis(
+                fake_images, fake_labels)
+
+            d_loss = tensor.to_numpy(d_loss_real)[0] + tensor.to_numpy(
+                d_loss_fake)[0]
+
+            # Training the Generative Net
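+            # The generator is optimized to make the discriminator classify its
+            # samples as real (target 1), i.e. the non-saturating GAN objective.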
+            noise.uniform(-1, 1)
+            _, g_loss_tensor = self.model.train_one_batch(
+                noise, real_labels)
+
+            g_loss = tensor.to_numpy(g_loss_tensor)[0]
+
+            if iteration % self.interval == 0:
+                self.model.eval()
+                self.save_image(iteration)
+                print_log('Iteration {}: G_LOSS: {}, D_LOSS: {}'.format(
+                    iteration, g_loss, d_loss))
+
+    def save_image(self, iteration):
+        demo_row = 5
+        demo_col = 5
+        if not hasattr(self, "demo_noise"):
+            self.demo_noise = tensor.Tensor(
+                (demo_col * demo_row, self.noise_size), dev, tensor.float32)
+        self.demo_noise.uniform(-1, 1)
+        gen_imgs = self.model.forward_gen(self.demo_noise)
+        gen_imgs = tensor.to_numpy(gen_imgs)
+        show_imgs = np.reshape(
+            gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
+        fig, axs = plt.subplots(demo_row, demo_col)
+        cnt = 0
+        for r in range(demo_row):
+            for c in range(demo_col):
+                axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
+                axs[r, c].axis('off')
+                cnt += 1
+        fig.savefig("{}{}.png".format(self.file_dir, iteration))
+        plt.close()
+
 
 if __name__ == '__main__':
-	parser = argparse.ArgumentParser(description='Train GAN over MNIST')
-	parser.add_argument('filepath',  type=str, help='the dataset path')
-	parser.add_argument('--use_gpu', action='store_true')
-	args = parser.parse_args()
-	
-	if args.use_gpu:
-		print('Using GPU')
-		dev = device.create_cuda_gpu()
-		layer.engine = 'cudnn'
-	else:
-		print('Using CPU')
-		dev = device.get_default_device()
-		layer.engine = 'singacpp'
+    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
+    parser.add_argument('filepath', type=str, help='the dataset path')
+    parser.add_argument('--use_gpu', action='store_true')
+    args = parser.parse_args()
 
-	if not os.path.exists('vanilla_images/'):
-		os.makedirs('vanilla_images/')
+    if args.use_gpu:
+        print('Using GPU')
+        dev = device.create_cuda_gpu()
+    else:
+        print('Using CPU')
+        dev = device.get_default_device()
 
-	rows = 28
-	cols = 28
-	channels = 1
-	noise_size = 100
-	hidden_size = 128
-	batch = 128
-	interval = 1000
-	learning_rate = 0.001
-	epochs = 1000000
-	dataset_filepath = 'mnist.pkl.gz'
-	file_dir = 'vanilla_images/'
-	vanilla = VANILLA(dev, rows, cols, channels, noise_size, hidden_size, batch, 
-		interval, learning_rate, epochs, dataset_filepath, file_dir)
-	vanilla.train()
\ No newline at end of file
+    if not os.path.exists('vanilla_images/'):
+        os.makedirs('vanilla_images/')
+
+    rows = 28
+    cols = 28
+    channels = 1
+    noise_size = 100
+    hidden_size = 128
+    batch = 128
+    interval = 1000
+    learning_rate = 0.0005
+    iterations = 1000000
+    dataset_filepath = 'mnist.pkl.gz'
+    file_dir = 'vanilla_images/'
+    vanilla = VANILLA(dev, rows, cols, channels, noise_size, hidden_size, batch,
+                      interval, learning_rate, iterations, dataset_filepath,
+                      file_dir)
+    vanilla.train()
diff --git a/examples/mlp/module.py b/examples/mlp/module.py
index 6adc03b..ab6a0bf 100644
--- a/examples/mlp/module.py
+++ b/examples/mlp/module.py
@@ -17,51 +17,35 @@
 # under the License.
 #
 
-from singa import module
-from singa import autograd
+from singa import layer
+from singa import model
 from singa import tensor
-from singa.tensor import Tensor
 
 
-class MLP(module.Module):
+class MLP(model.Model):
 
     def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
         super(MLP, self).__init__()
         self.num_classes = num_classes
         self.dimension = 2
 
-        self.w0 = Tensor(shape=(data_size, perceptron_size),
-                         requires_grad=True,
-                         stores_grad=True)
-        self.w0.gaussian(0.0, 0.1)
-        self.b0 = Tensor(shape=(perceptron_size,),
-                         requires_grad=True,
-                         stores_grad=True)
-        self.b0.set_value(0.0)
-
-        self.w1 = Tensor(shape=(perceptron_size, num_classes),
-                         requires_grad=True,
-                         stores_grad=True)
-        self.w1.gaussian(0.0, 0.1)
-        self.b1 = Tensor(shape=(num_classes,),
-                         requires_grad=True,
-                         stores_grad=True)
-        self.b1.set_value(0.0)
+        self.relu = layer.ReLU()
+        self.linear1 = layer.Linear(perceptron_size)
+        self.linear2 = layer.Linear(num_classes)
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
 
     def forward(self, inputs):
-        x = autograd.matmul(inputs, self.w0)
-        x = autograd.add_bias(x, self.b0)
-        x = autograd.relu(x)
-        x = autograd.matmul(x, self.w1)
-        x = autograd.add_bias(x, self.b1)
-        return x
+        y = self.linear1(inputs)
+        y = self.relu(y)
+        y = self.linear2(y)
+        return y
 
-    def loss(self, out, ty):
-        return autograd.softmax_cross_entropy(out, ty)
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
 
-    def optim(self, loss, dist_option, spars):
         if dist_option == 'fp32':
-            self.optimizer.backward_and_update(loss)
+            self.optimizer(loss)
         elif dist_option == 'fp16':
             self.optimizer.backward_and_update_half(loss)
         elif dist_option == 'partialUpdate':
@@ -74,6 +58,7 @@
             self.optimizer.backward_and_sparse_update(loss,
                                                       topK=False,
                                                       spars=spars)
+        return out, loss
 
     def set_optimizer(self, optimizer):
         self.optimizer = optimizer
@@ -116,17 +101,14 @@
     model = MLP(data_size=2, perceptron_size=3, num_classes=2)
 
     # attach the model to the graph
-    model.on_device(dev)
     model.set_optimizer(sgd)
-    model.graph(True, False)
+    model.compile([tx], is_train=True, use_graph=True, sequential=False)
     model.train()
 
     for i in range(1001):
         tx.copy_from_numpy(data)
         ty.copy_from_numpy(label)
-        out = model(tx)
-        loss = model.loss(out, ty)
-        model.optim(loss, 'fp32', spars=None)
+        out, loss = model(tx, ty, 'fp32', spars=None)
 
         if i % 100 == 0:
             print("training loss = ", tensor.to_numpy(loss)[0])
diff --git a/examples/mlp/native.py b/examples/mlp/native.py
index f1283d6..00f4c0d 100644
--- a/examples/mlp/native.py
+++ b/examples/mlp/native.py
@@ -84,7 +84,7 @@
         x = autograd.matmul(x, w1)
         x = autograd.add_bias(x, b1)
         loss = autograd.softmax_cross_entropy(x, target)
-        sgd.backward_and_update(loss)
+        sgd(loss)
 
         if i % 100 == 0:
             print("training loss = ", tensor.to_numpy(loss)[0])
diff --git a/examples/onnx/arcface.py b/examples/onnx/arcface.py
index e1cfa18..6050418 100644
--- a/examples/onnx/arcface.py
+++ b/examples/onnx/arcface.py
@@ -24,10 +24,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -54,18 +53,17 @@
     return img1, img2
 
 
-class Infer:
+class MyModel(sonnx.SONNXModel):
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
 
 
 if __name__ == "__main__":
@@ -78,35 +76,30 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 2)
-
-    # prepare the model
-    logging.info("prepare model...")
-    dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
-
-    # verifty the test dataset
-    # from utils import load_dataset
-    # inputs, ref_outputs = load_dataset(
-    #     os.path.join('/tmp', 'resnet100', 'test_data_set_0'))
-    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
-    # for ref_o, o in zip(ref_outputs, outputs):
-    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
-
     # inference demo
     logging.info("preprocessing...")
     img1, img2 = get_image()
     img1 = preprocess(img1)
     img2 = preprocess(img2)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img1, img2])
 
-    x_batch = tensor.Tensor(device=dev,
-                            data=np.concatenate((img1, img2), axis=0))
+    logging.info("model compling...")
+    dev = device.create_cuda_gpu()
+    x = tensor.Tensor(device=dev, data=np.concatenate((img1, img2), axis=0))
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+    # from utils import load_dataset
+    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'resnet100', 'test_data_set_0'))
+    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+    # outputs = sg_ir.run([x_batch])
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
+
     logging.info("model running...")
-    y = model.forward(x_batch)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     embedding = tensor.to_numpy(y)
@@ -120,4 +113,4 @@
     sim = np.dot(embedding1, embedding2.T)
     # logging.info predictions
     logging.info('Distance = %f' % (dist))
-    logging.info('Similarity = %f' % (sim))
+    logging.info('Similarity = %f' % (sim))
\ No newline at end of file
diff --git a/examples/onnx/bert/bert-squad.py b/examples/onnx/bert/bert-squad.py
index e4a8488..936968e 100644
--- a/examples/onnx/bert/bert-squad.py
+++ b/examples/onnx/bert/bert-squad.py
@@ -24,14 +24,13 @@
 from singa import device
 from singa import tensor
 from singa import sonnx
-from singa import autograd
 import onnx
 import tokenization
 from run_onnx_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
 
 import sys
 sys.path.append(os.path.dirname(__file__) + '/..')
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -54,15 +53,6 @@
     return filename
 
 
-class Infer:
-
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-
-    def forward(self, x):
-        return sg_ir.run(x)
-
-
 def preprocess():
     vocab_file = load_vocab()
     tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
@@ -96,6 +86,19 @@
         print("The result is:", json.dumps(test_data, indent=2))
 
 
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y
+
+    def train_one_batch(self, x, y):
+        pass
+
+
 if __name__ == "__main__":
 
     url = 'https://media.githubusercontent.com/media/onnx/models/master/text/machine_comprehension/bert-squad/model/bertsquad-10.tar.gz'
@@ -107,16 +110,12 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, batch_size)
-    dev = device.create_cuda_gpu()
-    autograd.training = False
-
     # inference
     logging.info("preprocessing...")
     input_ids, input_mask, segment_ids, extra_data, eval_examples = preprocess()
 
-    sg_ir = None
+    m = None
+    dev = device.create_cuda_gpu()
     n = len(input_ids)
     bs = batch_size
     all_results = []
@@ -132,23 +131,20 @@
             input_ids[idx:idx + bs].astype(np.int32),
         ]
 
-        if sg_ir is None:
-            # prepare the model
-            logging.info("model is none, prepare model...")
-            sg_ir = sonnx.prepare(onnx_model,
-                                  device=dev,
-                                  init_inputs=inputs,
-                                  keep_initializers_as_inputs=False)
-            model = Infer(sg_ir)
-
         x_batch = []
         for inp in inputs:
             tmp_tensor = tensor.from_numpy(inp)
             tmp_tensor.to_device(dev)
             x_batch.append(tmp_tensor)
 
+        # prepare the model
+        if m is None:
+            logging.info("model compiling...")
+            m = MyModel(onnx_model)
+            # m.compile(x_batch, is_train=False, use_graph=True, sequential=True)
+
         logging.info("model running for sample {}...".format(idx))
-        outputs = model.forward(x_batch)
+        outputs = m.forward(*x_batch)
 
         logging.info("hanlde the result of sample {}...".format(idx))
         result = []
diff --git a/examples/onnx/bert/tokenization.py b/examples/onnx/bert/tokenization.py
index 4dd0a31..09b9b4f 100644
--- a/examples/onnx/bert/tokenization.py
+++ b/examples/onnx/bert/tokenization.py
@@ -86,8 +86,6 @@
   elif six.PY2:
     if isinstance(text, str):
       return text.decode("utf-8", "ignore")
-    elif isinstance(text, unicode):
-      return text
     else:
       raise ValueError("Unsupported string type: %s" % (type(text)))
   else:
@@ -109,8 +107,6 @@
   elif six.PY2:
     if isinstance(text, str):
       return text
-    elif isinstance(text, unicode):
-      return text.encode("utf-8")
     else:
       raise ValueError("Unsupported string type: %s" % (type(text)))
   else:
diff --git a/examples/onnx/fer_emotion.py b/examples/onnx/fer_emotion.py
index 46c0142..e980580 100644
--- a/examples/onnx/fer_emotion.py
+++ b/examples/onnx/fer_emotion.py
@@ -22,10 +22,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -51,18 +50,17 @@
     return img, labels
 
 
-class Infer:
+class MyModel(sonnx.SONNXModel):
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
 
 
 if __name__ == "__main__":
@@ -75,33 +73,30 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 1)
+    # inference
+    logging.info("preprocessing...")
+    img, labels = get_image_labe()
+    img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
 
-    # prepare the model
-    logging.info("prepare model...")
+    logging.info("model compling...")
     dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
 
     # verify the test
     # from utils import load_dataset
     # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'emotion_ferplus', 'test_data_set_0'))
     # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
+    # outputs = sg_ir.run([x_batch])
     # for ref_o, o in zip(ref_outputs, outputs):
     #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
 
-    # inference
-    logging.info("preprocessing...")
-    img, labels = get_image_labe()
-    img = preprocess(img)
-
-    x_batch = tensor.Tensor(device=dev, data=img)
-
     logging.info("model running...")
-    y = model.forward(x_batch)
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     y = tensor.softmax(y)
diff --git a/examples/onnx/gpt2/gpt2.py b/examples/onnx/gpt2/gpt2.py
new file mode 100644
index 0000000..56b7cfe
--- /dev/null
+++ b/examples/onnx/gpt2/gpt2.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import numpy as np
+
+from singa import device
+from singa import tensor
+from singa import sonnx
+from singa import autograd
+import onnx
+
+import sys
+sys.path.append(os.path.dirname(__file__) + '/..')
+from utils import download_model
+
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
+
+from transformers import GPT2Tokenizer
+
+tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+length = 20
+
+
+def preprocess():
+    text = "Here is some text to encode : Hello World"
+    tokens = tokenizer.encode(text)
+    tokens = np.array(tokens)
+    return tokens.reshape([1, 1, -1]).astype(np.float32)
+
+
+def postprocess(out):
+    text = tokenizer.decode(out)
+    return text
+
+
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
+
+
+if __name__ == "__main__":
+    url = 'https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.tar.gz'
+    download_dir = '/tmp/'
+    model_path = os.path.join(download_dir, 'GPT-2-LM-HEAD', 'model.onnx')
+
+    logging.info("onnx load model...")
+    download_model(url)
+    onnx_model = onnx.load(model_path)
+
+    # inference
+    logging.info("preprocessing...")
+    input_ids = preprocess()
+
+    logging.info("model compling...")
+    dev = device.get_default_device()
+    x = tensor.Tensor(device=dev, data=input_ids)
+    model = MyModel(onnx_model)
+
+    # verify the test
+    # from utils import load_dataset
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # inputs, ref_outputs = load_dataset(
+    #     os.path.join('/tmp', 'GPT-2-LM-HEAD', 'test_data_set_0'))
+    # outputs = sg_ir.run(inputs)
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, o, 4)
+
+    logging.info("model running...")
+    output = []
+
+    for i in range(length):
+        logging.info("word {} generating...".format(i))
+        y = model.forward(x)
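+        # keep only the logits for the last token position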
+        y = autograd.reshape(y, y.shape[-2:])[-1, :]
+        y = tensor.softmax(y)
+        y = tensor.to_numpy(y)[0]
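+        # greedy decoding: pick the most probable next token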
+        y = np.argsort(y)[-1]
+        output.append(y)
+        y = np.array([y]).reshape([1, 1, -1]).astype(np.float32)
+        y = tensor.Tensor(device=dev, data=y)
+        x = tensor.concatenate([x, y], 2)
+
+    text = tokenizer.decode(output)
+    print(text)
\ No newline at end of file
diff --git a/examples/onnx/gpt2/requirements.txt b/examples/onnx/gpt2/requirements.txt
new file mode 100644
index 0000000..14693ad
--- /dev/null
+++ b/examples/onnx/gpt2/requirements.txt
@@ -0,0 +1 @@
+transformers==2.5.1
\ No newline at end of file
diff --git a/examples/onnx/mnist.py b/examples/onnx/mnist.py
deleted file mode 100644
index cf36727..0000000
--- a/examples/onnx/mnist.py
+++ /dev/null
@@ -1,320 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under th
-
-import os
-import gzip
-import numpy as np
-import codecs
-
-from singa import device
-from singa import tensor
-from singa import opt
-from singa import autograd
-from singa import sonnx
-import onnx
-from utils import check_exist_or_download
-
-import logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
-
-
-def load_dataset():
-    train_x_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
-    train_y_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
-    valid_x_url = 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
-    valid_y_url = 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
-    train_x = read_image_file(check_exist_or_download(train_x_url)).astype(
-        np.float32)
-    train_y = read_label_file(check_exist_or_download(train_y_url)).astype(
-        np.float32)
-    valid_x = read_image_file(check_exist_or_download(valid_x_url)).astype(
-        np.float32)
-    valid_y = read_label_file(check_exist_or_download(valid_y_url)).astype(
-        np.float32)
-    return train_x, train_y, valid_x, valid_y
-
-
-def read_label_file(path):
-    with gzip.open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2049
-        length = get_int(data[4:8])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=8).reshape((length))
-        return parsed
-
-
-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
-
-
-def read_image_file(path):
-    with gzip.open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2051
-        length = get_int(data[4:8])
-        num_rows = get_int(data[8:12])
-        num_cols = get_int(data[12:16])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(
-            (length, 1, num_rows, num_cols))
-        return parsed
-
-
-def to_categorical(y, num_classes):
-    y = np.array(y, dtype="int")
-    n = y.shape[0]
-    categorical = np.zeros((n, num_classes))
-    categorical[np.arange(n), y] = 1
-    categorical = categorical.astype(np.float32)
-    return categorical
-
-
-class CNN:
-
-    def __init__(self):
-        self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
-        self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
-        self.linear1 = autograd.Linear(4 * 4 * 50, 500, bias=False)
-        self.linear2 = autograd.Linear(500, 10, bias=False)
-        self.pooling1 = autograd.MaxPool2d(2, 2, padding=0)
-        self.pooling2 = autograd.MaxPool2d(2, 2, padding=0)
-
-    def forward(self, x):
-        y = self.conv1(x)
-        y = autograd.relu(y)
-        y = self.pooling1(y)
-        y = self.conv2(y)
-        y = autograd.relu(y)
-        y = self.pooling2(y)
-        y = autograd.flatten(y)
-        y = self.linear1(y)
-        y = autograd.relu(y)
-        y = self.linear2(y)
-        return y
-
-
-def accuracy(pred, target):
-    y = np.argmax(pred, axis=1)
-    t = np.argmax(target, axis=1)
-    a = y == t
-    return np.array(a, "int").sum() / float(len(t))
-
-
-def train(model,
-          x,
-          y,
-          epochs=1,
-          batch_size=64,
-          dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.001)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                logging.info("acc %6.2f loss, %6.2f" %
-                             (accuracy_rate, tensor.to_numpy(loss)[0]))
-    logging.info("training completed")
-    return x_batch, output_batch
-
-
-def make_onnx(x, y):
-    return sonnx.to_onnx([x], [y])
-
-
-class Infer:
-
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-
-    def forward(self, x):
-        return sg_ir.run([x])[0]
-
-
-def re_train(sg_ir,
-             x,
-             y,
-             epochs=1,
-             batch_size=64,
-             dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    new_model = Infer(sg_ir)
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-            output_batch = new_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.01)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                logging.info("acc %6.2f loss, %6.2f" %
-                             (accuracy_rate, tensor.to_numpy(loss)[0]))
-    logging.info("re-training completed")
-    return new_model
-
-
-class Trans:
-
-    def __init__(self, sg_ir, last_layers):
-        self.sg_ir = sg_ir
-        self.last_layers = last_layers
-        self.append_linear1 = autograd.Linear(500, 128, bias=False)
-        self.append_linear2 = autograd.Linear(128, 32, bias=False)
-        self.append_linear3 = autograd.Linear(32, 10, bias=False)
-
-    def forward(self, x):
-        y = sg_ir.run([x], last_layers=self.last_layers)[0]
-        y = self.append_linear1(y)
-        y = autograd.relu(y)
-        y = self.append_linear2(y)
-        y = autograd.relu(y)
-        y = self.append_linear3(y)
-        y = autograd.relu(y)
-        return y
-
-
-def transfer_learning(sg_ir,
-                      x,
-                      y,
-                      epochs=1,
-                      batch_size=64,
-                      dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    trans_model = Trans(sg_ir, -1)
-
-    for i in range(epochs):
-        for b in range(batch_number):
-            l_idx = b * batch_size
-            r_idx = (b + 1) * batch_size
-
-            x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-            target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-            output_batch = trans_model.forward(x_batch)
-
-            loss = autograd.softmax_cross_entropy(output_batch, target_batch)
-            accuracy_rate = accuracy(tensor.to_numpy(output_batch),
-                                     tensor.to_numpy(target_batch))
-
-            sgd = opt.SGD(lr=0.07)
-            for p, gp in autograd.backward(loss):
-                sgd.update(p, gp)
-            sgd.step()
-
-            if b % 1e2 == 0:
-                logging.info("acc %6.2f loss, %6.2f" %
-                             (accuracy_rate, tensor.to_numpy(loss)[0]))
-    logging.info("transfer-learning completed")
-    return trans_model
-
-
-def test(model, x, y, batch_size=64, dev=device.get_default_device()):
-    batch_number = x.shape[0] // batch_size
-
-    result = 0
-    for b in range(batch_number):
-        l_idx = b * batch_size
-        r_idx = (b + 1) * batch_size
-
-        x_batch = tensor.Tensor(device=dev, data=x[l_idx:r_idx])
-        target_batch = tensor.Tensor(device=dev, data=y[l_idx:r_idx])
-
-        output_batch = model.forward(x_batch)
-        result += accuracy(tensor.to_numpy(output_batch),
-                           tensor.to_numpy(target_batch))
-
-    logging.info("testing acc %6.2f" % (result / batch_number))
-
-
-if __name__ == "__main__":
-    # create device
-    dev = device.create_cuda_gpu()
-    #dev = device.get_default_device()
-    # create model
-    model = CNN()
-    # load data
-    train_x, train_y, valid_x, valid_y = load_dataset()
-    # normalization
-    train_x = train_x / 255
-    valid_x = valid_x / 255
-    train_y = to_categorical(train_y, 10)
-    valid_y = to_categorical(valid_y, 10)
-    # do training
-    autograd.training = True
-    x, y = train(model, train_x, train_y, dev=dev)
-    onnx_model = make_onnx(x, y)
-    # logging.info('The model is:\n{}'.format(onnx_model))
-
-    # Save the ONNX model
-    model_path = os.path.join('/', 'tmp', 'mnist.onnx')
-    onnx.save(onnx_model, model_path)
-    logging.info('The model is saved.')
-
-    # load the ONNX model
-    onnx_model = onnx.load(model_path)
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-
-    # inference
-    autograd.training = False
-    logging.info('The inference result is:')
-    test(Infer(sg_ir), valid_x, valid_y, dev=dev)
-
-    # re-training
-    autograd.training = True
-    new_model = re_train(sg_ir, train_x, train_y, dev=dev)
-    autograd.training = False
-    test(new_model, valid_x, valid_y, dev=dev)
-
-    # transfer-learning
-    autograd.training = True
-    new_model = transfer_learning(sg_ir, train_x, train_y, dev=dev)
-    autograd.training = False
-    test(new_model, valid_x, valid_y, dev=dev)
\ No newline at end of file
diff --git a/examples/onnx/mobilenet.py b/examples/onnx/mobilenet.py
index 75758f1..ad394ca 100644
--- a/examples/onnx/mobilenet.py
+++ b/examples/onnx/mobilenet.py
@@ -22,10 +22,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -56,18 +55,17 @@
     return img, labels
 
 
-class Infer:
+class MyModel(sonnx.SONNXModel):
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
 
 
 if __name__ == "__main__":
@@ -81,32 +79,30 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 1)
-
-    # prepare the model
-    logging.info("prepare model...")
-    dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
-
-    # verifty the test dataset
-    # from utils import load_dataset
-    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'mobilenetv2-1.0', 'test_data_set_0'))
-    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
-    # for ref_o, o in zip(ref_outputs, outputs):
-    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
-
     # inference
     logging.info("preprocessing...")
     img, labels = get_image_labe()
     img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
+
+    logging.info("model compling...")
+    dev = device.create_cuda_gpu()
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+    # from utils import load_dataset
+    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'mobilenetv2-1.0', 'test_data_set_0'))
+    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+    # outputs = sg_ir.run([x_batch])
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
 
     logging.info("model running...")
-    x_batch = tensor.Tensor(device=dev, data=img)
-    y = model.forward(x_batch)
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     y = tensor.softmax(y)
diff --git a/examples/onnx/resnet18.py b/examples/onnx/resnet18.py
index b3381c0..b66c3fb 100644
--- a/examples/onnx/resnet18.py
+++ b/examples/onnx/resnet18.py
@@ -22,10 +22,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
@@ -55,19 +54,17 @@
     img = Image.open(check_exist_or_download(image_url))
     return img, labels
 
+class MyModel(sonnx.SONNXModel):
 
-class Infer:
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def train_one_batch(self, x, y):
+        pass
 
 
 if __name__ == "__main__":
@@ -80,32 +77,30 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 1)
+    # inference
+    logging.info("preprocessing...")
+    img, labels = get_image_labe()
+    img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
 
-    # prepare the model
-    logging.info("prepare model...")
+    logging.info("model compling...")
     dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
 
     # verify the test
     # from utils import load_dataset
     # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'resnet18v1', 'test_data_set_0'))
     # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
+    # outputs = sg_ir.run([x_batch])
     # for ref_o, o in zip(ref_outputs, outputs):
     #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
 
-    # inference
-    logging.info("preprocessing...")
-    img, labels = get_image_labe()
-    img = preprocess(img)
-
     logging.info("model running...")
-    x_batch = tensor.Tensor(device=dev, data=img)
-    y = model.forward(x_batch)
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     y = tensor.softmax(y)
diff --git a/examples/onnx/ro_bert_a.py b/examples/onnx/ro_bert_a.py
new file mode 100644
index 0000000..5b6ac0a
--- /dev/null
+++ b/examples/onnx/ro_bert_a.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import sys
+sys.path.append('/singa/build/python/')
+
+import os
+import numpy as np
+
+from singa import device
+from singa import tensor
+from singa import sonnx
+from singa import autograd
+import onnx
+
+from utils import download_model, check_exist_or_download
+
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
+
+from transformers import RobertaTokenizer
+
+tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+
+def preprocess():
+    text = "This film is so good"
+    tokens = tokenizer.encode(text, add_special_tokens=True)
+    tokens = np.array(tokens)
+    return tokens.reshape([1, -1]).astype(np.float32)
+
+
+def postprocess(out):
+    text = tokenizer.decode(out)
+    return text
+
+
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
+
+
+if __name__ == "__main__":
+    url = 'https://media.githubusercontent.com/media/onnx/models/master/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.tar.gz'
+    download_dir = '/tmp/'
+    model_path = os.path.join(download_dir, 'roberta-sequence-classification-9', 'roberta-sequence-classification-9.onnx')
+
+    logging.info("onnx load model...")
+    download_model(url)
+    onnx_model = onnx.load(model_path)
+
+    # inference
+    logging.info("preprocessing...")
+    input_ids = preprocess()
+
+    logging.info("model compling...")
+    dev = device.get_default_device()
+    x = tensor.Tensor(device=dev, data=input_ids)
+    model = MyModel(onnx_model)
+
+    # verify the test
+    # from utils import load_dataset
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # inputs, ref_outputs = load_dataset(
+    #     os.path.join('/tmp', 'roberta-sst-9', 'test_data_set_0'))
+    # outputs = sg_ir.run(inputs)
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, o, 4)
+
+    logging.info("model running...")
+    y = model.forward(x)
+    y = autograd.reshape(y, y.shape[-2:])[-1, :]
+    y = tensor.softmax(y)
+    y = tensor.to_numpy(y)[0]
+    y = np.argsort(y)[::-1]
+    if y[0] == 0:
+        print("Prediction: negative")
+    else:
+        print("Prediction: positive")
\ No newline at end of file
diff --git a/examples/onnx/shufflenetv2.py b/examples/onnx/shufflenetv2.py
new file mode 100644
index 0000000..74dd794
--- /dev/null
+++ b/examples/onnx/shufflenetv2.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import numpy as np
+from PIL import Image
+
+from singa import device
+from singa import tensor
+from singa import sonnx
+import onnx
+from utils import download_model, check_exist_or_download
+
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+
+
+def preprocess(img):
+    img = img.resize((256, 256))
+    img = img.crop((16, 16, 240, 240))
+    img = np.array(img).astype(np.float32) / 255.
+    img = np.rollaxis(img, 2, 0)
+    for channel, mean, std in zip(range(3), [0.485, 0.456, 0.406],
+                                  [0.229, 0.224, 0.225]):
+        img[channel, :, :] -= mean
+        img[channel, :, :] /= std
+    img = np.expand_dims(img, axis=0)
+    return img
+
+
+def get_image_labe():
+    # download label
+    label_url = 'https://s3.amazonaws.com/onnx-model-zoo/synset.txt'
+    with open(check_exist_or_download(label_url), 'r') as f:
+        labels = [l.rstrip() for l in f]
+
+    # download image
+    image_url = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+    img = Image.open(check_exist_or_download(image_url))
+    return img, labels
+
+
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
+
+
+if __name__ == "__main__":
+
+    url = 'https://github.com/onnx/models/raw/master/vision/classification/shufflenet/model/shufflenet-v2-10.tar.gz'
+    download_dir = '/tmp/'
+    model_path = os.path.join(download_dir, 'model', 'test_shufflenetv2',
+                              'model.onnx')
+
+    logging.info("onnx load model...")
+    download_model(url)
+    onnx_model = onnx.load(model_path)
+
+    # inference
+    logging.info("preprocessing...")
+    img, labels = get_image_labe()
+    img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
+
+    logging.info("model compling...")
+    dev = device.create_cuda_gpu()
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+    # from utils import load_dataset
+    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'model', 'test_shufflenetv2',
+    #                           'model.onnx'))
+    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+    # outputs = sg_ir.run([x_batch])
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
+
+    logging.info("model running...")
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
+
+    logging.info("postprocessing...")
+    y = tensor.softmax(y)
+    scores = tensor.to_numpy(y)
+    scores = np.squeeze(scores)
+    a = np.argsort(scores)[::-1]
+    for i in a[0:5]:
+        logging.info('class=%s ; probability=%f' % (labels[i], scores[i]))
diff --git a/examples/onnx/tiny_yolov2.py b/examples/onnx/tiny_yolov2.py
index e883117..72d3666 100644
--- a/examples/onnx/tiny_yolov2.py
+++ b/examples/onnx/tiny_yolov2.py
@@ -22,10 +22,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -45,20 +44,6 @@
     return img
 
 
-class Infer:
-
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
-
-    def forward(self, x):
-        return sg_ir.run([x])[0]
-
-
 def postprcess(out):
     numClasses = 20
     anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52]
@@ -124,6 +109,19 @@
     img.save("result.png")
 
 
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
+
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
+
+
 if __name__ == "__main__":
 
     url = 'https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/tiny_yolov2.tar.gz'
@@ -134,33 +132,31 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 1)
-
-    # prepare the model
-    logging.info("prepare model...")
-    dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
-
-    # verifty the test dataset
-    # from utils import load_dataset
-    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'tiny_yolov2', 'test_data_set_0'))
-    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
-    # for ref_o, o in zip(ref_outputs, outputs):
-    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
-
     # inference
     logging.info("preprocessing...")
     img = get_image()
     img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
+
+    logging.info("model compling...")
+    dev = device.create_cuda_gpu()
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
+
+    # verify the test
+    # from utils import load_dataset
+    # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'tiny_yolov2', 'test_data_set_0'))
+    # x_batch = tensor.Tensor(device=dev, data=inputs[0])
+    # outputs = sg_ir.run([x_batch])
+    # for ref_o, o in zip(ref_outputs, outputs):
+    #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
 
     logging.info("model running...")
-    x_batch = tensor.Tensor(device=dev, data=img)
-    y = model.forward(x_batch)
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     out = tensor.to_numpy(y)[0]
-    postprcess(out)
+    postprcess(out)
\ No newline at end of file
diff --git a/examples/onnx/training/model.json b/examples/onnx/training/model.json
new file mode 100644
index 0000000..1fe52b1
--- /dev/null
+++ b/examples/onnx/training/model.json
@@ -0,0 +1,84 @@
+{
+    "resnet18v1": {
+        "name": "ResNet-18 Version 1",
+        "description": "ResNet v1 uses post-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.tar.gz",
+        "path": "resnet18v1/resnet18v1.onnx"
+    },
+    "resnet34v1": {
+        "name": "ResNet-34 Version 1",
+        "description": "ResNet v1 uses post-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet34v1/resnet34v1.tar.gz",
+        "path": "resnet34v1/resnet34v1.onnx"
+    },
+    "resnet50v1": {
+        "name": "ResNet-50 Version 1",
+        "description": "ResNet v1 uses post-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.tar.gz",
+        "path": "resnet50v1/resnet50v1.onnx"
+    },
+    "resnet101v1": {
+        "name": "ResNet-101 Version 1",
+        "description": "ResNet v1 uses post-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet101v1/resnet101v1.tar.gz",
+        "path": "resnet101v1/resnet101v1.onnx"
+    },
+    "resnet152v1": {
+        "name": "ResNet-152 Version 1",
+        "description": "ResNet v1 uses post-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v1/resnet152v1.tar.gz",
+        "path": "resnet152v1/resnet152v1.onnx"
+    },
+    "resnet18v2": {
+        "name": "ResNet-18 Version 2",
+        "description": "ResNet v2 uses pre-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v2/resnet18v2.tar.gz",
+        "path": "resnet18v2/resnet18v2.onnx"
+    },
+    "resnet34v2": {
+        "name": "ResNet-34 Version 2",
+        "description": "ResNet v2 uses pre-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet34v2/resnet34v2.tar.gz",
+        "path": "resnet34v2/resnet34v2.onnx"
+    },
+    "resnet50v2": {
+        "name": "ResNet-50 Version 2",
+        "description": "ResNet v2 uses pre-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz",
+        "path": "resnet50v2/resnet50v2.onnx"
+    },
+    "resnet101v2": {
+        "name": "ResNet-101 Version 2",
+        "description": "ResNet v2 uses pre-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet101v2/resnet101v2.tar.gz",
+        "path": "resnet101v2/resnet101v2.onnx"
+    },
+    "resnet152v2": {
+        "name": "ResNet-152 Version 2",
+        "description": "ResNet v2 uses pre-activation for the residual blocks",
+        "url": "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v2/resnet152v2.tar.gz",
+        "path": "resnet152v2/resnet152v2.onnx"
+    },
+    "vgg16": {
+        "name": "VGG-16",
+        "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-7.tar.gz",
+        "path": "vgg16/vgg16.onnx"
+    },
+    "vgg16bn": {
+        "name": "VGG-16 with batch normalization",
+        "description": "VGG have batch normalization applied after each convolutional layer",
+        "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-bn-7.tar.gz",
+        "path": "vgg16-bn/vgg16-bn.onnx"
+    },
+    "vgg19": {
+        "name": "VGG-19",
+        "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-9.tar.gz",
+        "path": "vgg19/vgg19.onnx"
+    },
+    "vgg19bn": {
+        "name": "VGG-19 with batch normalization",
+        "description": "VGG have batch normalization applied after each convolutional layer",
+        "url": "https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-bn-9.tar.gz",
+        "path": "vgg19-bn/vgg19-bn.onnx"
+    }
+}
\ No newline at end of file
diff --git a/examples/cnn/train.py b/examples/onnx/training/train.py
similarity index 71%
copy from examples/cnn/train.py
copy to examples/onnx/training/train.py
index 9f75b12..fca072e 100644
--- a/examples/cnn/train.py
+++ b/examples/onnx/training/train.py
@@ -17,15 +17,27 @@
 # under the License.
 #
 
+import sys, os
+import json
 from singa import singa_wrap as singa
 from singa import opt
 from singa import device
 from singa import tensor
+from singa import sonnx
+from singa import layer
+from singa import autograd
 import numpy as np
 import time
 import argparse
 from PIL import Image
+import onnx
+import logging
+from tqdm import tqdm
 
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+sys.path.append(os.path.dirname(__file__) + '/../../cnn')
+sys.path.append(os.path.dirname(__file__) + '/..')
+from utils import download_model
 
 # Data Augmentation
 def augmentation(x, batch_size):
@@ -90,15 +102,56 @@
     return X
 
 
+class MyModel(sonnx.SONNXModel):
+
+    def __init__(self, onnx_model, num_classes=10, num_channels=3):
+        super(MyModel, self).__init__(onnx_model)
+        self.num_classes = num_classes
+        self.input_size = 224
+        self.dimension = 4
+        self.num_channels = num_channels
+        self.linear = layer.Linear(512, num_classes)
+
+    def forward(self, *x):
+        # if you change to another model, update the aux_output name ('flatten_170') accordingly
+        y = super(MyModel, self).forward(*x, aux_output=['flatten_170'])[1]
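+        # index 1 selects the requested 'flatten_170' feature, which feeds the new linear classifier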
+        y = self.linear(y)
+        return y
+
+    def train_one_batch(self, x, y, dist_option, spars):
+        out = self.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        if dist_option == 'fp32':
+            self.optimizer.backward_and_update(loss)
+        elif dist_option == 'fp16':
+            self.optimizer.backward_and_update_half(loss)
+        elif dist_option == 'partialUpdate':
+            self.optimizer.backward_and_partial_update(loss)
+        elif dist_option == 'sparseTopK':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=True,
+                                                      spars=spars)
+        elif dist_option == 'sparseThreshold':
+            self.optimizer.backward_and_sparse_update(loss,
+                                                      topK=False,
+                                                      spars=spars)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+
 def run(global_rank,
         world_size,
         local_rank,
         max_epoch,
         batch_size,
-        model,
+        model_config,
         data,
         sgd,
         graph,
+        verbosity,
         dist_option='fp32',
         spars=None):
     dev = device.create_cuda_gpu_on(local_rank)
@@ -111,41 +164,18 @@
     elif data == 'cifar100':
         from data import cifar100
         train_x, train_y, val_x, val_y = cifar100.load()
-    elif data == 'mnist':
-        from data import mnist
-        train_x, train_y, val_x, val_y = mnist.load()
 
     num_channels = train_x.shape[1]
     image_size = train_x.shape[2]
     data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
     num_classes = (np.max(train_y) + 1).item()
-    #print(num_classes)
 
-    if model == 'resnet':
-        from model import resnet
-        model = resnet.resnet18(num_channels=num_channels,
-                                num_classes=num_classes)
-    elif model == 'xceptionnet':
-        from model import xceptionnet
-        model = xceptionnet.create_model(num_channels=num_channels,
-                                         num_classes=num_classes)
-    elif model == 'cnn':
-        from model import cnn
-        model = cnn.create_model(num_channels=num_channels,
-                                 num_classes=num_classes)
-    elif model == 'alexnet':
-        from model import alexnet
-        model = alexnet.create_model(num_channels=num_channels,
-                                     num_classes=num_classes)
-    elif model == 'mlp':
-        import os, sys, inspect
-        current = os.path.dirname(
-            os.path.abspath(inspect.getfile(inspect.currentframe())))
-        parent = os.path.dirname(current)
-        sys.path.insert(0, parent)
-        from mlp import module
-        model = module.create_model(data_size=data_size,
-                                    num_classes=num_classes)
+    # read and make onnx model
+    download_model(model_config['url'])
+    onnx_model = onnx.load(os.path.join('/tmp', model_config['path']))
+    model = MyModel(onnx_model,
+                    num_channels=num_channels,
+                    num_classes=num_classes)
 
     # For distributed training, sequential gives better performance
     if hasattr(sgd, "communicator"):
@@ -182,9 +212,9 @@
     idx = np.arange(train_x.shape[0], dtype=np.int32)
 
     # attached model to graph
-    model.on_device(dev)
     model.set_optimizer(sgd)
-    model.graph(graph, sequential)
+    model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
+    dev.SetVerbosity(verbosity)
 
     # Training and Evaluation Loop
     for epoch in range(max_epoch):
@@ -200,7 +230,7 @@
         train_loss = np.zeros(shape=[1], dtype=np.float32)
 
         model.train()
-        for b in range(num_train_batch):
+        for b in tqdm(range(num_train_batch)):
             # Generate the patch data in this iteration
             x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
             if model.dimension == 4:
@@ -214,9 +244,7 @@
             ty.copy_from_numpy(y)
 
             # Train the model
-            out = model(tx)
-            loss = model.loss(out, ty)
-            model.optim(loss, dist_option, spars)
+            out, loss = model(tx, ty, dist_option, spars)
             train_correct += accuracy(tensor.to_numpy(out), y)
             train_loss += tensor.to_numpy(loss)[0]
 
@@ -234,7 +262,7 @@
 
         # Evaluation Phase
         model.eval()
-        for b in range(num_val_batch):
+        for b in tqdm(range(num_val_batch)):
             x = val_x[b * batch_size:(b + 1) * batch_size]
             if model.dimension == 4:
                 if (image_size != model.input_size):
@@ -256,17 +284,29 @@
                    time.time() - start_time),
                   flush=True)
 
+    dev.PrintTimeProfiling()
+
+
+def loss(out, y):
+    return autograd.softmax_cross_entropy(out, y)
+
 
 if __name__ == '__main__':
+
+    with open(os.path.join(os.path.dirname(__file__),
+                           'model.json')) as json_file:
+        model_config = json.load(json_file)
+
     # use argparse to get command config: max_epoch, model, data, etc. for single gpu training
     parser = argparse.ArgumentParser(
         description='Training using the autograd and graph.')
-    parser.add_argument('model',
-                        choices=['resnet', 'xceptionnet', 'cnn', 'mlp', 'alexnet'],
-                        default='cnn')
-    parser.add_argument('data',
-                        choices=['cifar10', 'cifar100', 'mnist'],
-                        default='mnist')
+    parser.add_argument('--model',
+                        choices=list(model_config.keys()),
+                        help='please refer to the model.json for more details',
+                        default='resnet18v1')
+    parser.add_argument('--data',
+                        choices=['cifar10', 'cifar100'],
+                        default='cifar10')
     parser.add_argument('--epoch',
                         '--max-epoch',
                         default=10,
@@ -275,7 +315,7 @@
                         dest='max_epoch')
     parser.add_argument('--bs',
                         '--batch-size',
-                        default=64,
+                        default=32,
                         type=int,
                         help='batch size',
                         dest='batch_size')
@@ -298,9 +338,15 @@
                         action='store_false',
                         help='disable graph',
                         dest='graph')
+    parser.add_argument('--verbosity',
+                        '--log-verbosity',
+                        default=1,
+                        type=int,
+                        help='logging verbosity',
+                        dest='verbosity')
 
     args = parser.parse_args()
 
     sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
-    run(0, 1, args.device_id, args.max_epoch, args.batch_size, args.model,
-        args.data, sgd, args.graph)
+    run(0, 1, args.device_id, args.max_epoch, args.batch_size, model_config[args.model],
+        args.data, sgd, args.graph, args.verbosity)
diff --git a/examples/onnx/utils.py b/examples/onnx/utils.py
index 71d1ef4..b8f7b34 100644
--- a/examples/onnx/utils.py
+++ b/examples/onnx/utils.py
@@ -21,7 +21,6 @@
 import tarfile
 import glob
 import onnx
-from onnx import numpy_helper
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
 
@@ -41,7 +40,7 @@
         onnx_tensor = onnx.TensorProto()
         with open(input_file, 'rb') as f:
             onnx_tensor.ParseFromString(f.read())
-        inputs.append(numpy_helper.to_array(onnx_tensor))
+        inputs.append(onnx.numpy_helper.to_array(onnx_tensor))
 
     # load reference outputs
     ref_outputs = []
@@ -51,7 +50,7 @@
         onnx_tensor = onnx.TensorProto()
         with open(output_file, 'rb') as f:
             onnx_tensor.ParseFromString(f.read())
-        ref_outputs.append(numpy_helper.to_array(onnx_tensor))
+        ref_outputs.append(onnx.numpy_helper.to_array(onnx_tensor))
     return inputs, ref_outputs
 
 
@@ -63,11 +62,3 @@
         logging.info("Downloading %s" % url)
         urllib.request.urlretrieve(url, filename)
     return filename
-
-
-def update_batch_size(onnx_model, batch_size):
-    model_input = onnx_model.graph.input[0]
-    model_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-    model_output = onnx_model.graph.output[0]
-    model_output.type.tensor_type.shape.dim[0].dim_value = batch_size
-    return onnx_model
diff --git a/examples/onnx/vgg16.py b/examples/onnx/vgg16.py
index b26ea94..369cee9 100644
--- a/examples/onnx/vgg16.py
+++ b/examples/onnx/vgg16.py
@@ -22,10 +22,9 @@
 
 from singa import device
 from singa import tensor
-from singa import autograd
 from singa import sonnx
 import onnx
-from utils import download_model, update_batch_size, check_exist_or_download
+from utils import download_model, check_exist_or_download
 
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
@@ -56,18 +55,17 @@
     return img, labels
 
 
-class Infer:
+class MyModel(sonnx.SONNXModel):
 
-    def __init__(self, sg_ir):
-        self.sg_ir = sg_ir
-        for idx, tens in sg_ir.tensor_map.items():
-            # allow the tensors to be updated
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+    def __init__(self, onnx_model):
+        super(MyModel, self).__init__(onnx_model)
 
-    def forward(self, x):
-        return sg_ir.run([x])[0]
+    def forward(self, *x):
+        y = super(MyModel, self).forward(*x)
+        return y[0]
+
+    def train_one_batch(self, x, y):
+        pass
 
 
 if __name__ == "__main__":
@@ -79,32 +77,30 @@
     download_model(url)
     onnx_model = onnx.load(model_path)
 
-    # set batch size
-    onnx_model = update_batch_size(onnx_model, 1)
+    # inference
+    logging.info("preprocessing...")
+    img, labels = get_image_labe()
+    img = preprocess(img)
+    # sg_ir = sonnx.prepare(onnx_model) # run without graph
+    # y = sg_ir.run([img])
 
-    # prepare the model
-    logging.info("prepare model...")
+    logging.info("model compling...")
     dev = device.create_cuda_gpu()
-    sg_ir = sonnx.prepare(onnx_model, device=dev)
-    autograd.training = False
-    model = Infer(sg_ir)
+    x = tensor.PlaceHolder(img.shape, device=dev)
+    model = MyModel(onnx_model)
+    model.compile([x], is_train=False, use_graph=True, sequential=True)
 
     # verify the test
     # from utils import load_dataset
     # inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'vgg16', 'test_data_set_0'))
     # x_batch = tensor.Tensor(device=dev, data=inputs[0])
-    # outputs = model.forward(x_batch)
+    # outputs = sg_ir.run([x_batch])
     # for ref_o, o in zip(ref_outputs, outputs):
     #     np.testing.assert_almost_equal(ref_o, tensor.to_numpy(o), 4)
 
-    # inference
-    logging.info("preprocessing...")
-    img, labels = get_image_labe()
-    img = preprocess(img)
-
     logging.info("model running...")
-    x_batch = tensor.Tensor(device=dev, data=img)
-    y = model.forward(x_batch)
+    x = tensor.Tensor(device=dev, data=img)
+    y = model.forward(x)
 
     logging.info("postprocessing...")
     y = tensor.softmax(y)
diff --git a/examples/qabot/README.md b/examples/qabot/README.md
new file mode 100644
index 0000000..fdbab08
--- /dev/null
+++ b/examples/qabot/README.md
@@ -0,0 +1,31 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+# Train a question answering (QABOT) model
+
+This example shows how to implement a question answering (QABOT) application
+using SINGA's CUDNN RNN layers.
+
+We use an [LSTM](https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) model together with max pooling to train the QABOT.
+
+## Instructions
+
+* Start the training:
+
+        python qabot_train.py
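+
+## Data preparation (sketch)
+
+The snippet below is a minimal sketch, not part of the training script, showing
+how the helpers in qabot_data.py could be used to build padded, embedded
+training arrays. It assumes it is run from this directory; the sequence-length
+limits (10 for questions, 50 for answers) are illustrative values only.
+
+        from qabot_data import prepare_data, limit_encode_train
+
+        # download (or load the cached) insuranceQA data and pretrained embeddings
+        train_raw, test_raw, label2answer, idx2word, idx2vec = prepare_data()
+
+        # questions, positive and negative answers as (num samples, seq len, 300) arrays
+        q, pos, neg = limit_encode_train(train_raw, label2answer, idx2word,
+                                         q_seq_limit=10, ans_seq_limit=50,
+                                         idx2vec=idx2vec)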
diff --git a/examples/qabot/qabot_data.py b/examples/qabot/qabot_data.py
new file mode 100644
index 0000000..4494855
--- /dev/null
+++ b/examples/qabot/qabot_data.py
@@ -0,0 +1,282 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import random
+
+download_dir = "/tmp/"
+import os
+import urllib
+
+
+def check_exist_or_download(url):
+    ''' download data into tmp '''
+    name = url.rsplit('/', 1)[-1]
+    filename = os.path.join(download_dir, name)
+    if not os.path.isfile(filename):
+        print("Downloading %s" % url)
+        urllib.request.urlretrieve(url, filename)
+    return filename
+
+
+def unzip_data(download_dir, data_zip):
+    data_dir = download_dir + "insuranceQA-master/V2/"
+    if not os.path.exists(data_dir):
+        print("extracting %s to %s" % (download_dir, data_dir))
+        from zipfile import ZipFile
+        with ZipFile(data_zip, 'r') as zipObj:
+            zipObj.extractall(download_dir)
+    return data_dir
+
+
+def get_label2answer(data_dir):
+    import gzip
+    label2answer = dict()
+    with gzip.open(data_dir +
+                   "/InsuranceQA.label2answer.token.encoded.gz") as fin:
+        for line in fin:
+            pair = line.decode().strip().split("\t")
+            idxs = pair[1].split(" ")
+            idxs = [int(idx.replace("idx_", "")) for idx in idxs]
+            label2answer[int(pair[0])] = idxs
+    return label2answer
+
+
+pad_idx = 0
+pad_string = "<pad>"
+pad_embed = np.zeros((300,))
+
+insuranceqa_train_filename = "/InsuranceQA.question.anslabel.token.100.pool.solr.train.encoded.gz"
+insuranceqa_test_filename = "/InsuranceQA.question.anslabel.token.100.pool.solr.test.encoded.gz"
+insuranceQA_url = "https://github.com/shuzi/insuranceQA/archive/master.zip"
+insuranceQA_cache_fp = download_dir + "insuranceQA_cache.pickle"
+google_news_pretrain_embeddings_link = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
+
+
+def get_idx2word(data_dir):
+    idx2word = dict()
+    with open(data_dir + "vocabulary", encoding="utf-8") as vc_f:
+        for line in vc_f:
+            pair = line.strip().split("\t")
+            idx = int(pair[0].replace("idx_", ""))
+            idx2word[idx] = pair[1]
+
+    # add padding string to idx2word lookup
+    idx2word[pad_idx] = pad_string
+
+    return idx2word
+
+
+def get_train_raw(data_dir, data_filename):
+    ''' deserialize a training data file
+        args:
+            data_dir: directory of the data file
+            data_filename: name of the gzipped data file
+        return:
+            train_raw: list of QnA tuples; the list length equals the number
+                of samples. Each tuple has 3 fields:
+                    0: the question sentence as word indices; use idx2word to
+                        decode and idx2vec to get the embeddings.
+                    1: the answer labels; each label maps to an answer
+                        sentence via label2answer.
+                    2: the top-K candidate answers, used as negative answers
+                        for training.
+    '''
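+    # e.g. one element of train_raw (hypothetical idx values):
+    #     ([12, 305, 44, 9], [1187], [22, 9203, 4771])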
+    train_raw = []
+    import gzip
+    with gzip.open(data_dir + data_filename) as fin:
+        for line in fin:
+            tpl = line.decode().strip().split("\t")
+            question = [
+                int(idx.replace("idx_", "")) for idx in tpl[1].split(" ")
+            ]
+            ans = [int(label) for label in tpl[2].split(" ")]
+            candis = [int(label) for label in tpl[3].split(" ")]
+            train_raw.append((question, ans, candis))
+    return train_raw
+
+
+def limit_encode_train(train_raw, label2answer, idx2word, q_seq_limit,
+                       ans_seq_limit, idx2vec):
+    ''' convert training data into embedded word-vector sequences, cropped or
+        padded to the given sequence-length limits
+        return:
+            questions_encoded: np ndarray of shape
+                (num samples, seq length, vector size)
+            poss_encoded: same layout, sequences for the positive answers
+            negs_encoded: same layout, sequences for the negative answers
+    '''
+    questions = [question for question, answers, candis in train_raw]
+    # choose 1 answer from answer pool
+    poss = [
+        label2answer[random.choice(answers)]
+        for question, answers, candis in train_raw
+    ]
+    # choose 1 candidate from candidate pool
+    negs = [
+        label2answer[random.choice(candis)]
+        for question, answers, candis in train_raw
+    ]
+
+    # filter out words that are not in idx2vec
+    questions_filtered = [
+        [idx for idx in q if idx in idx2vec] for q in questions
+    ]
+    poss_filtered = [[idx for idx in ans if idx in idx2vec] for ans in poss]
+    negs_filtered = [[idx for idx in ans if idx in idx2vec] for ans in negs]
+
+    # crop to seq limit
+    questions_crop = [
+        q[:q_seq_limit] + [0] * max(0, q_seq_limit - len(q))
+        for q in questions_filtered
+    ]
+    poss_crop = [
+        ans[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(ans))
+        for ans in poss_filtered
+    ]
+    negs_crop = [
+        ans[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(ans))
+        for ans in negs_filtered
+    ]
+
+    # encoded, word idx to word vector
+    questions_encoded = [[idx2vec[idx] for idx in q] for q in questions_crop]
+    poss_encoded = [[idx2vec[idx] for idx in ans] for ans in poss_crop]
+    negs_encoded = [[idx2vec[idx] for idx in ans] for ans in negs_crop]
+
+    # make nd array
+    questions_encoded = np.array(questions_encoded).astype(np.float32)
+    poss_encoded = np.array(poss_encoded).astype(np.float32)
+    negs_encoded = np.array(negs_encoded).astype(np.float32)
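+    # e.g. with q_seq_limit=10, ans_seq_limit=50 and 300-dim embeddings, the
+    # shapes are (N, 10, 300), (N, 50, 300) and (N, 50, 300) for N samples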
+    return questions_encoded, poss_encoded, negs_encoded
+
+
+def get_idx2vec_weights(wv, idx2word):
+    idx2vec = {k: wv[v] for k, v in idx2word.items() if v in wv}
+
+    # add padding embedding (all zeros) to idx2vec lookup
+    idx2vec[pad_idx] = pad_embed
+    return idx2vec
+
+
+def prepare_data(use_cache=True):
+    import pickle
+    if not os.path.isfile(insuranceQA_cache_fp) or not use_cache:
+        # no cache is found, preprocess data from scratch
+        print("prepare data from scratch")
+
+        # get pretrained word vectors
+        from gensim.models.keyedvectors import KeyedVectors
+        google_news_pretrain_fp = check_exist_or_download(
+            google_news_pretrain_embeddings_link)
+        wv = KeyedVectors.load_word2vec_format(google_news_pretrain_fp,
+                                               binary=True)
+
+        # prepare insurance QA dataset
+        data_zip = check_exist_or_download(insuranceQA_url)
+        data_dir = unzip_data(download_dir, data_zip)
+
+        label2answer = get_label2answer(data_dir)
+        idx2word = get_idx2word(data_dir)
+        idx2vec = get_idx2vec_weights(wv, idx2word)
+
+        train_raw = get_train_raw(data_dir, insuranceqa_train_filename)
+        test_raw = get_train_raw(data_dir, insuranceqa_test_filename)
+        with open(insuranceQA_cache_fp, 'wb') as handle:
+            pickle.dump((train_raw, test_raw, label2answer, idx2word, idx2vec),
+                        handle,
+                        protocol=pickle.HIGHEST_PROTOCOL)
+    else:
+        # load from cached pickle
+        with open(insuranceQA_cache_fp, 'rb') as handle:
+            (train_raw, test_raw, label2answer, idx2word,
+             idx2vec) = pickle.load(handle)
+
+    return train_raw, test_raw, label2answer, idx2word, idx2vec
+
+
+def limit_encode_eval(train_raw,
+                      label2answer,
+                      idx2word,
+                      q_seq_limit,
+                      ans_seq_limit,
+                      idx2vec,
+                      top_k_candi_limit=6):
+    ''' prepare evaluation data: embed word index sequences into word vector
+        sequences given the sequence limits
+        return:
+            questions_encoded: np.ndarray of shape
+                (num samples, seq length, vector size)
+            candi_pools_encoded: np.ndarray of shape
+                (num samples, pool size, seq length, vector size)
+            ans_count: number of ground-truth answers per question; the
+                ground-truth answers sit at the front of each pool
+    '''
+    questions = [question for question, answers, candis in train_raw]
+
+    # combine ground-truth and candidate answer labels into a fixed-size pool
+    candi_pools = [
+        list(answers + candis)[:top_k_candi_limit]
+        for question, answers, candis in train_raw
+    ]
+    assert all([len(pool) == top_k_candi_limit for pool in candi_pools])
+
+    ans_count = [len(answers) for question, answers, candis in train_raw]
+    assert all([c > 0 for c in ans_count])
+
+    # look up the answer token sequence for each candidate label
+    candi_pools_encoded = [[label2answer[candi_label]
+                            for candi_label in pool]
+                           for pool in candi_pools]
+
+    # filter out words that are not in idx2vec
+    questions_filtered = [
+        [idx for idx in q if idx in idx2vec] for q in questions
+    ]
+    candi_pools_filtered = [[[idx
+                              for idx in candi_encoded
+                              if idx in idx2vec]
+                             for candi_encoded in pool]
+                            for pool in candi_pools_encoded]
+
+    # crop or pad to the sequence limit
+    questions_crop = [
+        q[:q_seq_limit] + [0] * max(0, q_seq_limit - len(q))
+        for q in questions_filtered
+    ]
+    candi_pools_crop = [[
+        candi[:ans_seq_limit] + [0] * max(0, ans_seq_limit - len(candi))
+        for candi in pool
+    ]
+                        for pool in candi_pools_filtered]
+
+    # encode: map each word idx to its word vector
+    questions_encoded = [[idx2vec[idx] for idx in q] for q in questions_crop]
+    candi_pools_encoded = [[[idx2vec[idx]
+                             for idx in candi]
+                            for candi in pool]
+                           for pool in candi_pools_crop]
+    questions_encoded = np.array(questions_encoded).astype(np.float32)
+    candi_pools_encoded = np.array(candi_pools_encoded).astype(np.float32)
+
+    # candi_pools_encoded shape:
+    #    (number of QnA samples,
+    #     number of candidates per pool,
+    #     number of word indices per candidate,
+    #     word embedding size per word idx)
+    #  e.g. a shape of (10, 5, 8, 300) means
+    #      10 QnA samples to test,
+    #      5 candidate answers per question,
+    #      8 words per answer,
+    #      a 300-dimensional embedding per word
+    return questions_encoded, candi_pools_encoded, ans_count
diff --git a/examples/qabot/qabot_model.py b/examples/qabot/qabot_model.py
new file mode 100644
index 0000000..d5a9d88
--- /dev/null
+++ b/examples/qabot/qabot_model.py
@@ -0,0 +1,152 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import autograd, layer, model
+
+
+class QAModel_mlp(model.Model):
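+    ''' MLP variant: flattens the question/answer sequences, projects each
+        with a Linear layer and ranks answers by cosine similarity
+    '''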
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.linear_q = layer.Linear(hidden_size)
+        self.linear_a = layer.Linear(hidden_size)
+
+    def forward(self, q, a_batch):
+        q = autograd.reshape(q, (q.shape[0], -1))  # bs, seq_q*data_s
+        a_batch = autograd.reshape(a_batch,
+                                   (a_batch.shape[0], -1))  # 2bs, seq_a*data_s
+
+        q = self.linear_q(q)  # bs, hid_s
+        a_batch = self.linear_a(a_batch)  # 2bs, hid_s
+
+        a_pos, a_neg = autograd.split(a_batch, 0,
+                                      [q.shape[0], q.shape[0]])  # 2*(bs, hid)
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel(model.Model):
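+    ''' recurrent variant: encodes questions and answers with two separate
+        CudnnRNN layers and compares their last-step outputs by cosine
+        similarity
+    '''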
+
+    def __init__(self,
+                 hidden_size,
+                 num_layers=1,
+                 bidirectional=True,
+                 return_sequences=False):
+        super(QAModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q)  # bs, Hidden*2
+        a_batch = self.lstm_a(a_batch)  # 2bs, Hidden*2
+
+        bs_a = q.shape[0]
+        # bs, hid*2
+        a_pos, a_neg = autograd.split(a_batch, 0, [bs_a, bs_a])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel_mean(model.Model):
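+    ''' recurrent variant that mean-pools the RNN outputs over the time axis
+        before computing the cosine similarity
+    '''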
+
+    def __init__(self, hidden_size, bidirectional=True, return_sequences=True):
+        super(QAModel_mean, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     batch_first=True,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     batch_first=True,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+
+    def forward(self, q, a_batch):
+        q = self.lstm_q(q)  # bs, seq, Hidden*2
+        a_batch = self.lstm_a(a_batch)  # 2bs, seq, Hidden*2
+
+        # bs, hid*2
+        q = autograd.reduce_mean(q, [1], keepdims=0)
+        # (2bs, hid*2)
+        a_batch = autograd.reduce_mean(a_batch, [1], keepdims=0)
+
+        # 2*(bs, seq, hid*2)
+        a_pos, a_neg = autograd.split(a_batch, 0, [q.shape[0], q.shape[0]])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
+
+
+class QAModel_maxpooling(model.Model):
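+    ''' recurrent variant that max-pools the RNN outputs over the sequence
+        dimension before computing the cosine similarity; this is the model
+        used by qabot_train.py
+    '''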
+
+    def __init__(self,
+                 hidden_size,
+                 q_seq,
+                 a_seq,
+                 num_layers=1,
+                 bidirectional=True,
+                 return_sequences=True):
+        super(QAModel_maxpooling, self).__init__()
+        self.hidden_size = hidden_size
+        self.lstm_q = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.lstm_a = layer.CudnnRNN(hidden_size=hidden_size,
+                                     bidirectional=bidirectional,
+                                     return_sequences=return_sequences)
+        self.q_pool = layer.MaxPool2d((q_seq, 1))
+        self.a_pool = layer.MaxPool2d((a_seq, 1))
+
+    def forward(self, q, a_batch):
+        # bs, seq, Hidden*2
+        q = self.lstm_q(q)
+        # bs, 1, seq, hid*2
+        q = autograd.reshape(q, (q.shape[0], 1, q.shape[1], q.shape[2]))
+        # bs, 1, 1, hid*2
+        q = self.q_pool(q)
+        # bs, hid*2
+        q = autograd.reshape(q, (q.shape[0], q.shape[3]))
+
+        # 2bs, seq, Hidden*2
+        a_batch = self.lstm_a(a_batch)
+        # 2bs, 1, seq, hid*2
+        a_batch = autograd.reshape(
+            a_batch, (a_batch.shape[0], 1, a_batch.shape[1], a_batch.shape[2]))
+        # 2bs, 1, 1, hid*2
+        a_batch = self.a_pool(a_batch)
+        # 2bs, hid*2
+        a_batch = autograd.reshape(a_batch,
+                                   (a_batch.shape[0], a_batch.shape[3]))
+
+        # 2*(bs, hid*2)
+        a_pos, a_neg = autograd.split(a_batch, 0, [q.shape[0], q.shape[0]])
+
+        sim_pos = autograd.cossim(q, a_pos)
+        sim_neg = autograd.cossim(q, a_neg)
+        return sim_pos, sim_neg
\ No newline at end of file
diff --git a/examples/qabot/qabot_train.py b/examples/qabot/qabot_train.py
new file mode 100644
index 0000000..45893e0
--- /dev/null
+++ b/examples/qabot/qabot_train.py
@@ -0,0 +1,159 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import time
+import random
+from tqdm import tqdm
+import argparse
+
+from singa import autograd, tensor, device, opt
+from qabot_data import limit_encode_train, limit_encode_eval, prepare_data
+from qabot_model import QAModel_maxpooling
+
+
+def do_train(m, tq, ta, train, meta_data, args):
+    '''
+    the batch size needs to be large so that enough negative answers are seen
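+    the positive and negative answers are concatenated along the batch axis
+    so that one forward pass scores both halves; the ranking loss then
+    encourages sim(question, positive) to exceed sim(question, negative)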
+    '''
+    m.train()
+    for epoch in range(args.epochs):
+        total_loss = 0
+        start = time.time()
+
+        q, ans_p, ans_n = limit_encode_train(train, meta_data['label2answer'],
+                                             meta_data['idx2word'],
+                                             args.q_seq_limit,
+                                             args.ans_seq_limit,
+                                             meta_data['idx2vec'])
+        bs = args.bs
+
+        for i in tqdm(range(len(q) // bs)):
+            tq.copy_from_numpy(q[i * bs:(i + 1) * bs])
+            a_batch = np.concatenate(
+                [ans_p[i * bs:(i + 1) * bs], ans_n[i * bs:(i + 1) * bs]])
+            ta.copy_from_numpy(a_batch)
+
+            p_sim, n_sim = m.forward(tq, ta)
+            l = autograd.ranking_loss(p_sim, n_sim)
+            m.optimizer(l)
+
+            total_loss += tensor.to_numpy(l)
+        print(
+            "epoch %d, time used %d sec, loss: " % (epoch, time.time() - start),
+            total_loss * bs / len(q))
+
+
+def do_eval(m, tq, ta, test, meta_data, args):
+    q, candis, ans_count = limit_encode_eval(test, meta_data['label2answer'],
+                                             meta_data['idx2word'],
+                                             args.q_seq_limit,
+                                             args.ans_seq_limit,
+                                             meta_data['idx2vec'],
+                                             args.number_of_candidates)
+    m.eval()
+    candi_pool_size = candis.shape[1]
+    correct = 0
+    start = time.time()
+    for i in tqdm(range(len(q))):
+        # the batch size must equal the number of repeated questions:
+        # each question is repeated n times, with n == candi_pool_size // 2
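+        # e.g. with the default batch size of 50 the candidate pool holds 100
+        # answers; the question is tiled 50 times, the model's "positive"
+        # branch scores candidates 0-49 and its "negative" branch scores
+        # candidates 50-99, and the two score halves are concatenated before
+        # taking the argmax; a prediction counts as correct when the argmax
+        # falls among the first ans_count[i] entries, i.e. the ground-truth
+        # answers placed at the front of the pool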
+        _q = np.repeat([q[i]], candi_pool_size // 2, axis=0)
+        tq.copy_from_numpy(_q)
+        ta.copy_from_numpy(candis[i])
+
+        (first_half_score, second_half_score) = m.forward(tq, ta)
+
+        first_half_score = tensor.to_numpy(first_half_score)
+        second_half_score = tensor.to_numpy(second_half_score)
+        scores = np.concatenate((first_half_score, second_half_score))
+        pred_max_idx = np.argmax(scores)
+
+        if pred_max_idx < ans_count[i]:
+            correct += 1
+
+    print("eval top %s " % (candi_pool_size), " accuracy", correct / len(q),
+          " time used %d sec" % (time.time() - start))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-m',
+                        '--max-epoch',
+                        default=30,
+                        type=int,
+                        help='maximum epochs',
+                        dest='epochs')
+    parser.add_argument('-b',
+                        '--batch-size',
+                        default=50,
+                        type=int,
+                        help='batch size',
+                        dest='bs')
+    parser.add_argument('-l',
+                        '--learning-rate',
+                        default=0.01,
+                        type=float,
+                        help='initial learning rate',
+                        dest='lr')
+    parser.add_argument('-i',
+                        '--device-id',
+                        default=0,
+                        type=int,
+                        help='which GPU to use',
+                        dest='device_id')
+
+    args = parser.parse_args()
+
+    args.hid_s = 64
+    args.q_seq_limit = 10
+    args.ans_seq_limit = 50
+    args.embed_size = 300
+    args.number_of_candidates = args.bs * 2
+    assert args.number_of_candidates <= 100, "number_of_candidates should be <= 100"
+
+    dev = device.create_cuda_gpu_on(args.device_id)
+
+    # tensor container
+    tq = tensor.random((args.bs, args.q_seq_limit, args.embed_size), dev)
+    ta = tensor.random((args.bs * 2, args.ans_seq_limit, args.embed_size), dev)
+
+    # model
+    m = QAModel_maxpooling(args.hid_s,
+                           q_seq=args.q_seq_limit,
+                           a_seq=args.ans_seq_limit)
+    m.compile([tq, ta], is_train=True, use_graph=True, sequential=False)
+    m.optimizer = opt.SGD(args.lr, 0.9)
+
+    # get data
+    train_raw, test_raw, label2answer, idx2word, idx2vec = prepare_data()
+    meta_data = {
+        'label2answer': label2answer,
+        'idx2word': idx2word,
+        'idx2vec': idx2vec
+    }
+
+    print("training...")
+    do_train(m, tq, ta, train_raw, meta_data, args)
+
+    print("Eval with train data...")
+    do_eval(m, tq, ta, random.sample(train_raw, 2000), meta_data, args)
+
+    print("Eval with test data...")
+    do_eval(m, tq, ta, test_raw, meta_data, args)
diff --git a/examples/rbm/train.py b/examples/rbm/train.py
index 2c09423..a2419ab 100755
--- a/examples/rbm/train.py
+++ b/examples/rbm/train.py
@@ -28,8 +28,7 @@
 except ImportError:
     import cPickle as pickle
 
-from singa import initializer
-from singa import optimizer
+from singa import opt
 from singa import device
 from singa import tensor
 
@@ -48,16 +47,15 @@
 
 def train(data_file, use_gpu, num_epoch=10, batch_size=100):
     print('Start intialization............')
-    lr = 0.1   # Learning rate
+    lr = 0.0005   # Learning rate
     weight_decay = 0.0002
     hdim = 1000
     vdim = 784
-
     tweight = tensor.Tensor((vdim, hdim))
     tweight.gaussian(0.0, 0.1)
     tvbias = tensor.from_numpy(np.zeros(vdim, dtype=np.float32))
     thbias = tensor.from_numpy(np.zeros(hdim, dtype=np.float32))
-    opt = optimizer.SGD(momentum=0.5, weight_decay=weight_decay)
+    sgd = opt.SGD(lr=lr, momentum=0.9, weight_decay=weight_decay)
 
     print('Loading data ..................')
     train_x, valid_x = load_train_data(data_file)
@@ -103,9 +101,9 @@
             tgvbias = tensor.sum(tnegdata, 0) - tensor.sum(tdata, 0)
             tghbias = tensor.sum(tneghidprob, 0) - tensor.sum(tposhidprob, 0)
 
-            opt.apply_with_lr(epoch, lr / batch_size, tgweight, tweight, 'w')
-            opt.apply_with_lr(epoch, lr / batch_size, tgvbias, tvbias, 'vb')
-            opt.apply_with_lr(epoch, lr / batch_size, tghbias, thbias, 'hb')
+            sgd.apply('w', tweight, tgweight)
+            sgd.apply('vb', tvbias, tgvbias)
+            sgd.apply('hb', thbias, tghbias)
 
         print('training erroraverage = %f' %
               (tensor.to_numpy(trainerrorsum) / train_x.shape[0]))
@@ -116,7 +114,7 @@
         tvalidposhidprob = tvalidposhidprob + thbias
         tvalidposhidprob = tensor.sigmoid(tvalidposhidprob)
         tvalidposhidrandom = tensor.Tensor(tvalidposhidprob.shape, dev)
-        initializer.uniform(tvalidposhidrandom, 0.0, 1.0)
+        tvalidposhidrandom.uniform(0.0, 1.0)
         tvalidposhidsample = tensor.gt(tvalidposhidprob, tvalidposhidrandom)
 
         tvalidnegdata = tensor.mult(tvalidposhidsample, tweight.T())
diff --git a/examples/rnn/README.md b/examples/rnn/README.md
index 7c1c697..36d60ab 100644
--- a/examples/rnn/README.md
+++ b/examples/rnn/README.md
@@ -16,25 +16,20 @@
     specific language governing permissions and limitations
     under the License.
 -->
-# Train Char-RNN over plain text
+# Train RNN model over IMDB dataset
 
 Recurrent neural networks (RNN) are widely used for modelling sequential data,
 e.g., natural language sentences. This example describes how to implement a RNN
-application (or model) using SINGA's RNN layers.
-We will use the [char-rnn](https://github.com/karpathy/char-rnn) model as an
-example, which trains over sentences or
-source code, with each character as an input unit. Particularly, we will train
-a RNN over Linux kernel source code. 
+application (or model) using SINGA's CUDNN RNN layers.
+We will use the [LSTM](https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735) model as an
+example to train on the IMDB dataset.
 
 ## Instructions
 
-* Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/).
-Other plain text files can also be used.
+* Prepare the dataset,
+
+        python imdb_data.py
 
 * Start the training,
 
-        python train.py linux_input.txt
-
-  Some hyper-parameters could be set through command line,
-
-        python train.py -h
+        python imdb_train.py
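+
+  Some hyper-parameters (e.g. the maximum epoch, batch size, learning rate
+  and RNN mode) could be set through command line,
+
+        python imdb_train.py -h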
diff --git a/examples/rnn/train.py b/examples/rnn/char_rnn.py
similarity index 83%
rename from examples/rnn/train.py
rename to examples/rnn/char_rnn.py
index 107f0f1..2979b95 100644
--- a/examples/rnn/train.py
+++ b/examples/rnn/char_rnn.py
@@ -31,21 +31,24 @@
 from singa import device
 from singa import tensor
 from singa import autograd
-from singa import module
+from singa import layer
+from singa import model
 from singa import opt
 
 
-class CharRNN(module.Module):
+class CharRNN(model.Model):
 
     def __init__(self, vocab_size, hidden_size=32):
         super(CharRNN, self).__init__()
-        self.rnn = autograd.LSTM(vocab_size, hidden_size)
-        self.dense = autograd.Linear(hidden_size, vocab_size)
+        self.rnn = layer.LSTM(vocab_size, hidden_size)
+        self.cat = layer.Cat()
+        self.reshape1 = layer.Reshape()
+        self.dense = layer.Linear(hidden_size, vocab_size)
+        self.reshape2 = layer.Reshape()
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
         self.optimizer = opt.SGD(0.01)
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
-        self.hx = tensor.Tensor((1, self.hidden_size))
-        self.cx = tensor.Tensor((1, self.hidden_size))
 
     def reset_states(self, dev):
         self.hx.to_device(dev)
@@ -53,18 +56,37 @@
         self.hx.set_value(0.0)
         self.cx.set_value(0.0)
 
+    def initialize(self, inputs):
+        batchsize = inputs[0].shape[0]
+        self.hx = tensor.Tensor((batchsize, self.hidden_size))
+        self.cx = tensor.Tensor((batchsize, self.hidden_size))
+        self.reset_states(inputs[0].device)
+
     def forward(self, inputs):
-        x, self.hx, self.cx = self.rnn(inputs, (self.hx, self.cx))
-        x = autograd.cat(x)
-        x = autograd.reshape(x, (-1, self.hidden_size))
+        x, hx, cx = self.rnn(inputs, (self.hx, self.cx))
+        self.hx.copy_data(hx)
+        self.cx.copy_data(cx)
+        x = self.cat(x)
+        x = self.reshape1(x, (-1, self.hidden_size))
         return self.dense(x)
 
-    def loss(self, out, ty):
-        ty = autograd.reshape(ty, (-1, 1))
-        return autograd.softmax_cross_entropy(out, ty)
+    def train_one_batch(self, x, y):
+        out = self.forward(x)
+        y = self.reshape2(y, (-1, 1))
+        loss = self.softmax_cross_entropy(out, y)
+        self.optimizer(loss)
+        return out, loss
 
-    def optim(self, loss):
-        self.optimizer.backward_and_update(loss)
+    def get_states(self):
+        ret = super().get_states()
+        ret[self.hx.name] = self.hx
+        ret[self.cx.name] = self.cx
+        return ret
+
+    def set_states(self, states):
+        self.hx.copy_from(states[self.hx.name])
+        self.cx.copy_from(states[self.cx.name])
+        super().set_states(states)
 
 
 class Data(object):
@@ -86,7 +108,7 @@
         data = [self.char_to_idx[c] for c in self.raw_data]
         # seq_length + 1 for the data + label
         nsamples = len(data) // (1 + seq_length)
-        data = data[0: nsamples * (1 + seq_length)]
+        data = data[0:nsamples * (1 + seq_length)]
         data = np.asarray(data, dtype=np.int32)
         data = np.reshape(data, (-1, seq_length + 1))
         # shuffle all sequences
@@ -181,7 +203,7 @@
                                  dev, inputs, labels)
         model.reset_states(dev)
         y = model(inputs)
-        loss = model.loss(y, labels)[0]
+        loss = autograd.softmax_cross_entropy(y, labels)[0]
         val_loss += tensor.to_numpy(loss)[0]
     print('            validation loss is %f' %
           (val_loss / data.num_test_batch / seq_length))
@@ -196,7 +218,6 @@
     # SGD with L2 gradient normalization
     cuda = device.create_cuda_gpu()
     model = CharRNN(data.vocab_size, hidden_size)
-    model.on_device(cuda)
     model.graph(True, False)
 
     inputs, labels = None, None
@@ -208,10 +229,8 @@
             batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
             inputs, labels = convert(batch, batch_size, seq_length,
                                      data.vocab_size, cuda, inputs, labels)
+            out, loss = model(inputs, labels)
             model.reset_states(cuda)
-            y = model(inputs)
-            loss = model.loss(y, labels)
-            model.optim(loss)
             train_loss += tensor.to_numpy(loss)[0]
 
         print('\nEpoch %d, train loss is %f' %
diff --git a/examples/rnn/imdb_data.py b/examples/rnn/imdb_data.py
new file mode 100644
index 0000000..973f9e5
--- /dev/null
+++ b/examples/rnn/imdb_data.py
@@ -0,0 +1,283 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import re
+import os
+import pickle
+import urllib
+import tarfile
+import numpy as np
+import pandas as pd
+import nltk
+from nltk.stem import PorterStemmer
+from nltk.tokenize.toktok import ToktokTokenizer
+from gensim.models.keyedvectors import KeyedVectors
+from sklearn.model_selection import train_test_split
+from bs4 import BeautifulSoup
+'''
+    data collection preprocessing constants
+'''
+download_dir = '/tmp/'
+preprocessed_imdb_data_fp = download_dir + 'imdb_processed.pickle'
+imdb_dataset_link = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
+google_news_pretrain_embeddings_link = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
+
+
+def pad_batch(b, seq_limit):
+    ''' pad (or crop) each encoded sequence in the batch to seq_limit and
+        build the sentiment labels together with their one-hot encoding
+    '''
+    batch_seq = []
+    batch_senti_onehot = []
+    batch_senti = []
+    for r in b:
+        # r[0] encoded sequence
+        # r[1] label 1 or 0
+        encoded = None
+        if len(r[0]) >= seq_limit:
+            encoded = r[0][:seq_limit]
+        else:
+            encoded = r[0] + [0] * (seq_limit - len(r[0]))
+
+        batch_seq.append(encoded)
+        batch_senti.append(r[1])
+        if r[1] == 1:
+            batch_senti_onehot.append([0, 1])
+        else:
+            batch_senti_onehot.append([1, 0])
+    batch_senti = np.array(batch_senti).astype(np.float32)
+    batch_senti_onehot = np.array(batch_senti_onehot).astype(np.float32)
+    batch_seq = np.array(batch_seq).astype(np.int32)
+    return batch_seq, batch_senti_onehot, batch_senti
+
+
+def pad_batch_2vec(b, seq_limit, embed_weights):
+    ''' pad (or crop) each encoded sequence in the batch to seq_limit and
+        convert it to pretrained word vectors via the embed_weights lookup
+    '''
+    batch_seq = []
+    batch_senti_onehot = []
+    batch_senti = []
+    for r in b:
+        # r[0] encoded sequence
+        # r[1] label 1 or 0
+        encoded = None
+        if len(r[0]) >= seq_limit:
+            encoded = r[0][:seq_limit]
+        else:
+            encoded = r[0] + [0] * (seq_limit - len(r[0]))
+
+        batch_seq.append([embed_weights[idx] for idx in encoded])
+        batch_senti.append(r[1])
+        if r[1] == 1:
+            batch_senti_onehot.append([0, 1])
+        else:
+            batch_senti_onehot.append([1, 0])
+    batch_senti = np.array(batch_senti).astype(np.float32)
+    batch_senti_onehot = np.array(batch_senti_onehot).astype(np.float32)
+    batch_seq = np.array(batch_seq).astype(np.float32)
+    return batch_seq, batch_senti_onehot, batch_senti
+
+
+def check_exist_or_download(url):
+    ''' download data into tmp '''
+    name = url.rsplit('/', 1)[-1]
+    filename = os.path.join(download_dir, name)
+    if not os.path.isfile(filename):
+        print("Downloading %s" % url)
+        urllib.request.urlretrieve(url, filename)
+    return filename
+
+
+def unzip_data(download_dir, data_gz):
+    data_dir = download_dir + 'aclImdb'
+    if not os.path.exists(data_dir):
+        print("extracting %s to %s" % (download_dir, data_dir))
+        with tarfile.open(data_gz) as tar:
+            tar.extractall(download_dir)
+    return data_dir
+
+
+def strip_html(text):
+    ''' lambda fn for cleaning html '''
+    soup = BeautifulSoup(text, "html.parser")
+    return soup.get_text()
+
+
+def remove_between_square_brackets(text):
+    ''' lambda fn for cleaning square brackets'''
+    return re.sub(r'\[[^]]*\]', '', text)
+
+
+def remove_special_characters(text, remove_digits=True):
+    ''' lambda fn for removing special char '''
+    pattern = r'[^a-zA-Z0-9\s]'
+    text = re.sub(pattern, '', text)
+    return text
+
+
+def simple_stemmer(text):
+    ''' lambda fn for stemming '''
+    ps = PorterStemmer()
+    text = ' '.join([ps.stem(word) for word in text.split()])
+    return text
+
+
+def remove_stopwords(text, tokenizer, stopword_list, is_lower_case=False):
+    ''' lambda fn for removing stopwords '''
+    tokens = tokenizer.tokenize(text)
+    tokens = [token.strip() for token in tokens]
+    if is_lower_case:
+        filtered_tokens = [
+            token for token in tokens if token not in stopword_list
+        ]
+    else:
+        filtered_tokens = [
+            token for token in tokens if token.lower() not in stopword_list
+        ]
+    filtered_text = ' '.join(filtered_tokens)
+    return filtered_text
+
+
+def tokenize(x):
+    ''' lambda fn for tokenize sentences '''
+    ret = []
+    for w in x.split(" "):
+        if w != '':
+            ret.append(w)
+    return ret
+
+
+def encode_token(words, wv, w2i):
+    ''' lambda fn for encoding string seq to int seq 
+        args: 
+            wv: word vector lookup dictionary
+            w2i: word2index lookup dictionary
+    '''
+    ret = []
+    for w in words:
+        if w in wv:
+            ret.append(w2i[w])
+    return ret
+
+
+def preprocess():
+    ''' collect and preprocess raw data from acl Imdb dataset
+    '''
+    nltk.download('stopwords')
+
+    print("preparing raw imdb data")
+    data_gz = check_exist_or_download(imdb_dataset_link)
+    data_dir = unzip_data(download_dir, data_gz)
+
+    # imdb dirs
+    # vocab_f = data_dir + '/imdb.vocab'
+    train_pos_dir = data_dir + '/train/pos/'
+    train_neg_dir = data_dir + '/train/neg/'
+    test_pos_dir = data_dir + '/test/pos/'
+    test_neg_dir = data_dir + '/test/neg/'
+
+    # nltk helpers
+    tokenizer = ToktokTokenizer()
+    stopword_list = nltk.corpus.stopwords.words('english')
+
+    # load pretrained word2vec binary
+    print("loading pretrained word2vec")
+    google_news_pretrain_fp = check_exist_or_download(
+        google_news_pretrain_embeddings_link)
+    wv = KeyedVectors.load_word2vec_format(google_news_pretrain_fp, binary=True)
+
+    # parse flat files to memory
+    data = []
+    for data_dir, label in [(train_pos_dir, 1), (train_neg_dir, 0),
+                            (test_pos_dir, 1), (test_neg_dir, 0)]:
+        for filename in os.listdir(data_dir):
+            if filename.endswith(".txt"):
+                with open(os.path.join(data_dir, filename),
+                          "r",
+                          encoding="utf-8") as fhdl:
+                    data.append((fhdl.read(), label))
+
+    # text review cleaning
+    print("cleaning text review")
+    imdb_data = pd.DataFrame(data, columns=["review", "label"])
+    imdb_data['review'] = imdb_data['review'].apply(strip_html)
+    imdb_data['review'] = imdb_data['review'].apply(
+        remove_between_square_brackets)
+    imdb_data['review'] = imdb_data['review'].apply(remove_special_characters)
+    imdb_data['review'] = imdb_data['review'].apply(simple_stemmer)
+    imdb_data['review'] = imdb_data['review'].apply(remove_stopwords,
+                                                    args=(tokenizer,
+                                                          stopword_list))
+    imdb_data['token'] = imdb_data['review'].apply(tokenize)
+
+    # build word2index and index2word
+    w2i = dict()
+    i2w = dict()
+
+    # add vocab <pad> as index 0
+    w2i["<pad>"] = 0
+    i2w[0] = "<pad>"
+
+    idx = 1  # start from idx 1
+    for index, row in imdb_data['token'].iteritems():
+        for w in row:
+            if w in wv and w not in w2i:
+                w2i[w] = idx
+                i2w[idx] = w
+                assert idx < 28241
+                idx += 1
+    assert len(w2i) == len(i2w)
+    print("vocab size: ", len(w2i))
+
+    # encode tokens to int
+    imdb_data['encoded'] = imdb_data['token'].apply(encode_token,
+                                                    args=(wv, w2i))
+
+    # select word vector weights for embedding layer from vocab
+    embed_weights = []
+    for w in w2i.keys():
+        val = None
+        if w in wv:
+            val = wv[w]
+        else:
+            val = np.zeros([
+                300,
+            ])
+        embed_weights.append(val)
+    embed_weights = np.array(embed_weights)
+    print("embedding layer lookup weight shape: ", embed_weights.shape)
+
+    # split into train and test
+    train_data = imdb_data[['encoded', 'label']].values
+    train, val = train_test_split(train_data, test_size=0.33, random_state=42)
+
+    # save preprocessed for training
+    imdb_processed = {
+        "train": train,
+        "val": val,
+        "embed_weights": embed_weights,
+        "w2i": w2i,
+        "i2w": i2w
+    }
+    print("saving preprocessed file to ", preprocessed_imdb_data_fp)
+    with open(preprocessed_imdb_data_fp, 'wb') as handle:
+        pickle.dump(imdb_processed, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+if __name__ == "__main__":
+    preprocess()
diff --git a/examples/rnn/imdb_model.py b/examples/rnn/imdb_model.py
new file mode 100644
index 0000000..5698c0c
--- /dev/null
+++ b/examples/rnn/imdb_model.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from singa import autograd
+from singa import layer
+from singa import model
+
+
+class IMDBModel(model.Model):
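+    ''' sentiment classifier: a CudnnRNN encoder followed by two Linear
+        layers that produce the 2-class logits
+    '''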
+
+    def __init__(self,
+                 hidden_size,
+                 mode='lstm',
+                 return_sequences=False,
+                 bidirectional=False,
+                 num_layers=1):
+        super().__init__()
+        batch_first = True
+        self.lstm = layer.CudnnRNN(hidden_size=hidden_size,
+                                   batch_first=batch_first,
+                                   rnn_mode=mode,
+                                   return_sequences=return_sequences,
+                                   num_layers=num_layers,
+                                   dropout=0.9,
+                                   bidirectional=bidirectional)
+        self.l1 = layer.Linear(64)
+        self.l2 = layer.Linear(2)
+
+    def forward(self, x):
+        y = self.lstm(x)
+        y = autograd.reshape(y, (y.shape[0], -1))
+        y = self.l1(y)
+        y = autograd.relu(y)
+        y = self.l2(y)
+        return y
+
+    def train_one_batch(self, x, y):
+        out = self.forward(x)
+        loss = autograd.softmax_cross_entropy(out, y)
+        self.optimizer(loss)
+        return out, loss
+
+    def set_opt(self, optimizer):
+        self.optimizer = optimizer
diff --git a/examples/rnn/imdb_train.py b/examples/rnn/imdb_train.py
new file mode 100644
index 0000000..4952639
--- /dev/null
+++ b/examples/rnn/imdb_train.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import pickle
+import os
+import sys
+import numpy as np
+from singa import tensor
+from singa import device
+from singa import opt
+from imdb_data import pad_batch_2vec, preprocessed_imdb_data_fp
+from imdb_model import IMDBModel
+import argparse
+
+if not os.path.isfile(preprocessed_imdb_data_fp):
+    sys.exit(
+        "Imdb dataset is not found, run python3 examples/rnn/imdb_data.py to prepare data"
+    )
+
+# load preprocessed data
+imdb_processed = None
+with open(preprocessed_imdb_data_fp, 'rb') as handle:
+    imdb_processed = pickle.load(handle)
+
+# parse command line configs (max epoch, batch size, learning rate, device id) for single-GPU training
+parser = argparse.ArgumentParser()
+parser.add_argument('-m',
+                    '--max-epoch',
+                    default=5,
+                    type=int,
+                    help='maximum epochs',
+                    dest='max_epoch')
+parser.add_argument('-b',
+                    '--batch-size',
+                    default=128,
+                    type=int,
+                    help='batch size',
+                    dest='bs')
+parser.add_argument('-l',
+                    '--learning-rate',
+                    default=0.01,
+                    type=float,
+                    help='initial learning rate',
+                    dest='lr')
+# determine which gpu to use
+parser.add_argument('-i',
+                    '--device-id',
+                    default=0,
+                    type=int,
+                    help='which GPU to use',
+                    dest='device_id')
+# training params
+parser.add_argument('--mode',
+                    default='lstm',
+                    help='relu, tanh, lstm, gru',
+                    dest='mode')
+parser.add_argument('-s', '--return-sequences',
+                    default=False,
+                    action='store_true',
+                    help='return sequences',
+                    dest='return_sequences')
+parser.add_argument('-d', '--bidirectional',
+                    default=False,
+                    action='store_true',
+                    help='bidirectional lstm',
+                    dest='bidirectional')
+parser.add_argument('-n', '--num-layers',
+                    default=2,
+                    type=int,
+                    help='num layers',
+                    dest='num_layers')
+
+args = parser.parse_args()
+
+# parameters
+seq_limit = 50
+embed_size = 300
+hid = 32
+
+# gpu device
+dev = device.create_cuda_gpu_on(args.device_id)
+
+# create placeholder
+tx = tensor.Tensor((args.bs, seq_limit, embed_size), dev, tensor.float32)
+ty = tensor.Tensor((args.bs, 2), dev, tensor.float32)
+tx.gaussian(0, 1)
+ty.gaussian(0, 1)
+
+# create model
+m = IMDBModel(hid,
+              mode=args.mode,
+              return_sequences=args.return_sequences,
+              bidirectional=args.bidirectional,
+              num_layers=args.num_layers)
+m.set_opt(opt.SGD(args.lr, 0.9))
+
+m.compile([tx], is_train=True, use_graph=True, sequential=False)
+
+# training
+m.train()
+x_train, y_onehot_train, y_train = pad_batch_2vec(
+    imdb_processed['train'], seq_limit, imdb_processed['embed_weights'])
+x_test, y_onehot_test, y_test = pad_batch_2vec(imdb_processed['val'], seq_limit,
+                                               imdb_processed['embed_weights'])
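+# shapes: x_* is (N, seq_limit, 300) float32 word-vector sequences,
+# y_onehot_* is (N, 2) one-hot labels, and y_* holds the 0/1 labels
+# used only for accuracy computation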
+
+for epoch in range(args.max_epoch):
+    i = 0
+    l = 0
+    correct = 0
+    trials = 0
+    while (i + 1) * args.bs < len(x_train):
+        l_idx = i * args.bs
+        r_idx = l_idx + args.bs
+        x_batch = x_train[l_idx:r_idx]
+        y_onehot_batch = y_onehot_train[l_idx:r_idx]
+        y_batch = y_train[l_idx:r_idx]
+        i += 1
+
+        # reuse placeholders
+        tx.copy_from_numpy(x_batch)
+        ty.copy_from_numpy(y_onehot_batch)
+
+        # train one batch
+        out, loss = m(tx, ty)
+
+        # save output
+        l += tensor.to_numpy(loss)
+        scores = tensor.to_numpy(out)
+        correct += (y_batch == np.argmax(scores, 1)).sum()
+        trials += len(y_batch)
+
+    print("epoch %d loss %s; acc %.3f" % (epoch, l /
+                                          (trials / args.bs), correct / trials))
+    l = 0
+
+# testing:
+m.eval()
+
+i = 0
+correct = 0
+trials = 0
+while (i + 1) * args.bs < len(x_test):
+    l_idx = i * args.bs
+    r_idx = l_idx + args.bs
+    x_batch = x_test[l_idx:r_idx]
+    y_onehot_batch = y_onehot_test[l_idx:r_idx]
+    y_batch = y_test[l_idx:r_idx]
+    i += 1
+
+    # reuse same tensors
+    tx.copy_from_numpy(x_batch)
+    ty.copy_from_numpy(y_onehot_batch)
+
+    # make inference
+    out = m(tx)
+
+    # save correct predictions
+    scores = tensor.to_numpy(out)
+    correct += (y_batch == np.argmax(scores, 1)).sum()
+    trials += len(y_batch)
+
+print("eval acc %.3f" % (correct / trials))
diff --git a/include/singa/core/common.h b/include/singa/core/common.h
index cfef4e5..a408650 100644
--- a/include/singa/core/common.h
+++ b/include/singa/core/common.h
@@ -100,13 +100,21 @@
   std::mt19937 random_generator;
 #ifdef USE_CUDA
   cublasHandle_t cublas_handle;
-  cudaStream_t stream;
-  curandGenerator_t curand_generator;
+  cudaStream_t stream; 
+  curandGenerator_t curand_generator; 
+
 #ifdef USE_CUDNN
   cudnnHandle_t cudnn_handle;
 #endif
 #endif  // USE_CUDA
 
+#ifdef USE_DIST
+  // cuda streams used by communicator
+  cudaStream_t c1;
+  cudaStream_t c2;
+  cudaStream_t s;
+#endif
+
 #ifdef USE_DNNL
   dnnl::engine dnnl_engine;
   dnnl::stream dnnl_stream;
diff --git a/include/singa/core/device.h b/include/singa/core/device.h
index b910fec..50644c0 100644
--- a/include/singa/core/device.h
+++ b/include/singa/core/device.h
@@ -19,6 +19,7 @@
 #ifndef SINGA_CORE_DEVICE_H_
 #define SINGA_CORE_DEVICE_H_
 
+#include <chrono>
 #include <functional>
 #include <map>
 #include <memory>
@@ -61,6 +62,8 @@
   /// max mem size to use (in MB)
   Device(int id, int num_executors);
 
+  void Reset();
+
   virtual void SetRandSeed(unsigned seed) = 0;
 
   void EnableGraph(bool enable) { graph_enabled_ = enable; }
@@ -80,14 +83,15 @@
   /// Copy data within or across devices.
   virtual void CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
                               CopyDirection direction, int dst_offset,
-                              int src_offset);
+                              int src_offset, Context* ctx);
 
   void CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes,
-                           size_t dst_offset = 0);
+                           size_t dst_offset = 0, Context* ctx = nullptr);
   /// Submit the operation to the device, which may execute it right now or
   /// delay it depending on the scheduler.
   void Exec(function<void(Context*)>&& fn, const vector<Block*> read_blocks,
-            const vector<Block*> write_blocks, bool use_rand_generator = false);
+            const vector<Block*> write_blocks, string op_name = "no_name",
+            bool use_rand_generator = false);
 
   void RunGraph(bool serial = false);
 
@@ -108,11 +112,28 @@
 
   bool graph_enabled() const { return graph_enabled_; }
 
+  /// Verbosity of the time profiling function:
+  /// verbosity == 0 (default) -> no logging
+  /// verbosity == 1 -> display forward and backward propagation time
+  /// verbosity == 2 -> display each operation time (OP_ID, op name, time)
+  int verbosity() const { return verbosity_; }
+  /// the number of initial iteration that is skipped for time profiling
+  int skip_iteration() const { return skip_iteration_; }
+
   virtual std::shared_ptr<Device> host() const { return host_; }
 
+  void PrintTimeProfiling();
+  void SetVerbosity(int verbosity) { verbosity_ = verbosity; };
+  void SetSkipIteration(int skip_iteration) {
+    skip_iteration_ = skip_iteration;
+  };
+
  protected:
   /// Execute one operation on one executor.
   virtual void DoExec(function<void(Context*)>&& fn, int executor) = 0;
+  virtual void TimeProfilingDoExec(function<void(Context*)>&& fn, int executor,
+                                   Node* node) = 0;
+  virtual void EvaluateTimeElapsed(Node* node) = 0;
 
   virtual void CopyToFrom(void* dst, const void* src, size_t nBytes,
                           CopyDirection direction, Context* ctx) = 0;
@@ -134,6 +155,8 @@
   int num_executors_ = 0;
   unsigned seed_ = 0;
   bool graph_enabled_ = false;
+  int verbosity_ = 0;
+  int skip_iteration_ = 5;
   /// The computational graph
   Graph* graph_ = nullptr;
   /// Programming language type, could be kCpp, kCuda, kOpencl
@@ -165,6 +188,9 @@
 
  protected:
   void DoExec(function<void(Context*)>&& fn, int executor) override;
+  void TimeProfilingDoExec(function<void(Context*)>&& fn, int executor,
+                           Node* node) override;
+  void EvaluateTimeElapsed(Node* node) override;
 
   void CopyToFrom(void* dst, const void* src, size_t nBytes,
                   CopyDirection direction, Context* ctx) override;
@@ -195,6 +221,11 @@
 
  protected:
   void DoExec(function<void(Context*)>&& fn, int executor) override;
+  void TimeProfilingDoExec(function<void(Context*)>&& fn, int executor,
+                           Node* node) override;
+  void EvaluateTimeElapsed(Node* node) override;
+
+  void SyncBeforeCountingTime();
 
   void CopyToFrom(void* dst, const void* src, size_t nBytes,
                   CopyDirection direction, Context* ctx) override;
@@ -232,7 +263,8 @@
 
   virtual void CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
                               CopyDirection direction, int dst_offset = 0,
-                              int src_offset = 0) override;
+                              int src_offset = 0,
+                              Context* ctx = nullptr) override;
 
  protected:
   /// The OpenCL device that this object represents.
@@ -279,7 +311,11 @@
 class Platform {
  public:
   /// Return the default host device
-  static std::shared_ptr<Device> GetDefaultDevice() { return defaultDevice; }
+  static std::shared_ptr<Device> GetDefaultDevice() {
+    // cannot reset cpu device, which leads to error
+    // defaultDevice->Reset();
+    return defaultDevice;
+  }
 
 #ifdef USE_CUDA
   /// Return the number of total available GPUs
diff --git a/include/singa/core/scheduler.h b/include/singa/core/scheduler.h
index 68407b4..b430101 100644
--- a/include/singa/core/scheduler.h
+++ b/include/singa/core/scheduler.h
@@ -22,14 +22,17 @@
 #include <condition_variable>
 #include <functional>
 #include <mutex>
+#include <string>
 #include <thread>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "singa/core/common.h"
 #include "singa/utils/safe_queue.h"
 
 using std::function;
+using std::string;
 using std::unordered_map;
 using std::vector;
 
@@ -44,22 +47,30 @@
 typedef std::vector<Node *> NodeVec;
 typedef std::vector<Edge *> EdgeVec;
 typedef std::vector<Block *> BlockVec;
+typedef std::unordered_set<Block *> BlockSet;
 typedef std::function<void(Context *)> OpFunc;
 typedef std::unordered_map<Block *, BlkInfo *> Blk2InfoMap;
+typedef std::chrono::high_resolution_clock::time_point TimePoint;
 
 enum BlockType { kUnknow, kInput, kParam, kInter, kEnd };
 
 class Node {
  public:
-  Node(int id, OpFunc &&op) : id_(id), op_(std::move(op)) {}
+  Node(int id, OpFunc &&op, string op_name)
+      : id_(id), op_(std::move(op)), op_name_(op_name) {}
 
   void AddInEdge(Edge *in_edge);
   void AddOutEdge(Edge *out_edge);
 
   // getters of Node
   int id() const { return id_; }
+  string op_name() const { return op_name_; }
   const EdgeVec &in_edges() const { return in_edges_; }
   const EdgeVec &out_edges() const { return out_edges_; }
+  float time_elapsed() const { return time_elapsed_; }
+
+  // time profiling
+  void time_elapsed_inc(float time) { time_elapsed_ += time; }
 
  private:
   friend Graph;
@@ -68,6 +79,15 @@
   OpFunc op_;
   EdgeVec in_edges_;
   EdgeVec out_edges_;
+
+  string op_name_;
+  float time_elapsed_ = 0;
+
+#ifdef USE_CUDA
+  cudaEvent_t start_;
+  cudaEvent_t end_;
+  friend class CudaGPU;
+#endif  // USE_CUDA
 };
 
 class Edge {
@@ -135,27 +155,27 @@
   void Debug();
   void RunGraph();
   void RunInSerial();
+  void PrintTimeProfiling();
   void AddOperation(OpFunc &&op, const BlockVec &read_blocks,
-                    const BlockVec &write_blocks);
+                    const BlockVec &write_blocks, string op_name = "no_name");
 
   // getters of Graph
   const NodeVec &nodes() const { return nodes_; }
   const EdgeVec &edges() const { return edges_; }
   const Blk2InfoMap &blocks() const { return blocks_; }
 
-  const BlockVec &write_blocks() const { return write_blocks_; }
+  const BlockSet &leaf_blocks() const { return leaf_blocks_; }
 
   bool dirty() const { return dirty_; }
   const NodeVec &begin_nodes() const { return begin_nodes_; }
   const std::vector<NodeVec> &next_nodes() const { return next_nodes_; }
   const std::vector<BlockVec> &free_blocks() const { return free_blocks_; }
+  int iteration() const { return iteration_; }
 
   Node *node(const size_t idx) const;
   Edge *edge(const size_t idx) const;
   BlkInfo *block(Block *blk) const;
 
-  Block *write_block(const size_t idx) const;
-
   Node *begin_node(const size_t idx) const;
   const NodeVec &next_nodes(const size_t idx) const;
   const BlockVec &free_blocks(const size_t idx) const;
@@ -165,7 +185,13 @@
   void FreeLoop();
   void AnalyzeNodes();
   void AnalyzeEdges();
-  void AddSyncOp(function<void(Context *)> &&op);
+  void TimeProfilingDoExec(Node *curNode);
+  void AddSyncOp(function<void(Context *)> &&op, string op_name = "no_name");
+
+  void step() { iteration_++; }
+  void time_elapsed_inc(float time) { time_elapsed_ += time; }
+  void TakeStartTime(TimePoint &start);
+  void EvaluateTimeElapsed(const TimePoint &start);
 
   // static void CUDART_CB Callback(cudaStream_t stream, cudaError_t status,
   //                                void *data);
@@ -178,16 +204,20 @@
   EdgeVec edges_;
   Blk2InfoMap blocks_;
 
-  // Blocks written by the last operation, used for sync op
-  BlockVec write_blocks_;
+  // Leaf blocks written by the previous operations, used for sync op
+  BlockSet leaf_blocks_;
 
-  // Calculation graph analysis
+  // Computational graph analysis
   bool dirty_ = false;
   bool in_serial_ = false;
   NodeVec begin_nodes_;
   std::vector<NodeVec> next_nodes_;
   std::vector<BlockVec> free_blocks_;
 
+  // Time Profiling
+  int iteration_ = 0;
+  float time_elapsed_ = 0;
+
   SafeQueue<int> free_queue_;
 };
 
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index ec342ed..aea988d 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -114,6 +114,15 @@
     return false;
   }
 
+  bool broadcasted() const {
+    int strideProduct = 1;
+    for (const auto &i : stride_) strideProduct *= i;
+    if (strideProduct == 0) {
+      return true;
+    }
+    return false;
+  }
+
   const vector<int> &stride() const { return stride_; }
 
   /// Return true if the content of the tensor is initialized
@@ -138,10 +147,10 @@
   /// used for swig code to convert Tensor into numpy array.
   /// It gets data into 'value'
   template <typename SType>
-  void GetValue(SType *value, const size_t num);
+  void GetValue(SType *value, const size_t num) const;
 
   template <typename SType>
-  void get_value(SType *value, const size_t num);
+  void get_value(SType *value, const size_t num) const;
 
   /// Serialize data, shape and transpose to protobuf object.
   void ToProto(singa::TensorProto *proto) const;
@@ -169,7 +178,7 @@
   /// memory with 'offset' (elements).
   template <typename SType>
   void CopyDataFromHostPtr(const SType *src, const size_t num,
-                           const size_t offset = 0);
+                           const size_t offset = 0) const;
 
   /// Copy data from another Tensor which may be on a diff device.
   /// Meta data would not be copied!
@@ -230,6 +239,9 @@
   template <typename SType>
   Tensor &operator/=(const SType x);
 
+  /// if tensor is transposed, transform to contiguous memory
+  Tensor &Contiguous();
+
   /// change the shape (and stride); the block may be reallocated.
   Tensor &Reshape(const Shape &shape);
 
@@ -247,7 +259,7 @@
 
   /// Return a view of the input tensor whose shape is broadcasted to be
   /// compitable with the given shape
-  Tensor &Broadcast(const Shape &shape);
+  Tensor &Broadcast(const Shape &shape, const int ignore_last_dim = 0);
 
   /// Reset the shape, device, and data type as given tensor.
   /// If block size changes, then reallocate a new block.
@@ -320,6 +332,8 @@
 /// which shares the memory with in if possible
 Tensor Reshape(const Tensor &in, const Shape &s);
 
+Tensor Contiguous(const Tensor &in);
+
 Tensor Resize(const Tensor &in, const Shape &s);
 
 /// Reverse the shape vector
@@ -327,7 +341,8 @@
 
 /// Return a view of the input tensor whose shape is broadcasted to be
 /// compitable with the given shape
-Tensor Broadcast(const Tensor &in, const Shape &shape);
+Tensor Broadcast(const Tensor &in, const Shape &shape,
+                 const int ignore_last_dim = 0);
 
 /// Change the axes
 Tensor Transpose(const Tensor &in, const vector<size_t> &axes);
@@ -343,7 +358,11 @@
 
 // =============Element-wise operations====================================
 Tensor Abs(const Tensor &in);
+Tensor Erf(const Tensor &in);
 Tensor Ceil(const Tensor &in);
+Tensor Floor(const Tensor &in);
+Tensor Round(const Tensor &in);
+Tensor RoundE(const Tensor &in);
 Tensor Exp(const Tensor &in);
 Tensor Log(const Tensor &in);
 Tensor ReLU(const Tensor &in);
@@ -368,7 +387,11 @@
 Tensor Transform(const Tensor &in);
 
 void Abs(const Tensor &in, Tensor *out);
+void Erf(const Tensor &in, Tensor *out);
 void Ceil(const Tensor &in, Tensor *out);
+void Floor(const Tensor &in, Tensor *out);
+void Round(const Tensor &in, Tensor *out);
+void RoundE(const Tensor &in, Tensor *out);
 void Exp(const Tensor &in, Tensor *out);
 void Log(const Tensor &in, Tensor *out);
 void ReLU(const Tensor &in, Tensor *out);
@@ -447,6 +470,16 @@
 Tensor operator>=(const Tensor &in1, const Tensor &in2);
 void GE(const Tensor &in1, const Tensor &in2, Tensor *out);
 
+/// Element-wise operation, out[i]= (in[i] == x) ? 1.f : 0.f
+template <typename SType>
+Tensor operator==(const Tensor &in, const SType x);
+template <typename SType>
+void EQ(const Tensor &in, const SType x, Tensor *out);
+
+/// Element-wise operation, out[i]= (in1[i] == in2[i]) ? 1.f : 0.f
+Tensor operator==(const Tensor &in1, const Tensor &in2);
+void EQ(const Tensor &in1, const Tensor &in2, Tensor *out);
+
 Tensor operator+(const Tensor &lhs, const Tensor &rhs);
 void Add(const Tensor &lhs, const Tensor &rhs, Tensor *out);
 Tensor operator-(const Tensor &lhs, const Tensor &rhs);
@@ -557,6 +590,8 @@
 template <typename SType>
 void Axpy(SType alpha, const Tensor &in, Tensor *out);
 
+void Axpy(const Tensor &alpha, const Tensor &in, Tensor *out);
+
 /// Do matrix vector multipication or matrix matrix multiplication depdending
 /// on the Tensor shape.  result = A * B
 Tensor Mult(const Tensor &A, const Tensor &B);
diff --git a/include/singa/io/communicator.h b/include/singa/io/communicator.h
index f805136..3f738ea 100644
--- a/include/singa/io/communicator.h
+++ b/include/singa/io/communicator.h
@@ -98,16 +98,16 @@
   void generateBlocks(Tensor &t);
   void generateBlocks(std::vector<Tensor> &t);
   void allReduce(int size, void *sendbuff, void *recvbuff,
-                 ncclDataType_t ncclType);
+                 ncclDataType_t ncclType, Context *ctx);
   void setup();
   void sparsInit();
   void halfInit();
   void _fusedSparsification(vector<Tensor> &t, Tensor *accumulation,
-                            float sparsThreshold, bool topK);
+                            float sparsThreshold, bool topK, Context *ctx);
   void _sparsification(Tensor &t, Tensor *accumulation, float sparsThreshold,
-                       bool topK);
-  void valSparsAllReduce(size_t num, float *accumulation);
-  void topKSparsAllReduce(size_t num, float *accumulation);
+                       bool topK, Context *ctx);
+  void valSparsAllReduce(size_t num, float *accumulation, Context *ctx);
+  void topKSparsAllReduce(size_t num, float *accumulation, Context *ctx);
 
   // last group of synchronized memory blocks
   std::shared_ptr<Device> device_ = nullptr;
@@ -115,11 +115,6 @@
   std::vector<Block *> prev_blocks_;
 
   ncclUniqueId id;
-  // cuda stream s is for nccl all reduce
-  cudaStream_t s;
-  // cuda streams c1 and c2 are mainly for data copy to and from memory buffers
-  cudaStream_t c1;
-  cudaStream_t c2;
   ncclComm_t comm;
   cudaEvent_t event;
 
diff --git a/java/pom.xml b/java/pom.xml
index cc55ce3..b16dbe3 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -83,6 +83,7 @@
                             <exclude>.gitmodules</exclude>
                             <exclude>java/target/*</exclude>
                             <exclude>miniconda.sh</exclude>
+                            <exclude>**/*.json</exclude>
                         </excludes>
                         <consoleOutput>True</consoleOutput>
                     </configuration>
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 8ac9d98..ddfbf03 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -21,7 +21,6 @@
 
 from collections import Counter, deque
 import numpy as np
-import math
 
 from singa import tensor
 from singa import utils
@@ -39,7 +38,7 @@
         y_shape: the shape of result
         x_shape: the shape of x
     Return:
-        a tuple refering the axes 
+        a tuple referring to the axes
     """
     res = []
     j = len(x_shape) - 1
@@ -73,15 +72,15 @@
     """
     Infer the dependency of all operations with the
     given op as the last operation.
-    Operation A is depending on B if A uses the output(s) of B.
+    Operator A depends on B if A uses the output(s) of B.
 
     Args:
-        op: an Operation instance, e.g. the loss operation.
+        op: an Operator instance, e.g. the loss operation.
 
     Return:
         a Counter instance with the operation as the key,
         and the number of operations that are depending on it as the value;
-        and a Counter instance with the id of the output tensor as the key, and 
+        and a Counter instance with the id of the output tensor as the key, and
         the number of operations that are depending on it as the value.
     """
 
@@ -118,7 +117,11 @@
     """
     grads = {}  # mapping: x->dx if x.stores_grad
     for p, dp in backward(y, dy):
-        grads[p] = dp
+        # TODO: this fn is currently only a helper for test cases.
+        #   1. could implement __hash__, or
+        #   2. make grad an attribute of the tensor class, i.e.
+        #      p.grad = dp
+        grads[id(p)] = dp
     return grads
 
 
@@ -221,12 +224,12 @@
         del op  # delete the operation to free all tensors from this op
 
 
-class Operation(object):
+class Operator(object):
     """
     An operation includes the forward and backward function of
     tensor calculation.
     Steps to add a specific operation Xxxx:
-    1. create a subclass of Operation, name it as Xxxx
+    1. create a subclass of Operator, name it as Xxxx
     2. override the forward() and backward(); The arguments of forward()
        and backward() should only include CTensor;
     """
@@ -236,8 +239,8 @@
     def __init__(self, name=None):
         if name is None:
             self.name = "{}#{}".format(self.__class__.__name__,
-                                       Operation.op_count)
-            Operation.op_count += 1
+                                       Operator.op_count)
+            Operator.op_count += 1
         else:
             self.name = name
 
@@ -338,7 +341,7 @@
         return []
 
 
-class Dummy(Operation):
+class Dummy(Operator):
     """Dummy operation whice serves as a placehoder for autograd
     Args:
         name(string): set it for debug
@@ -361,7 +364,7 @@
         return self.tensor.__getattribute__(name)
 
 
-class Mean(Operation):
+class Mean(Operator):
     """
     Element-wise mean of each of the input CTensors.
     """
@@ -406,9 +409,9 @@
     return Mean()(*l)[0]
 
 
-class ReLU(Operation):
+class ReLU(Operator):
     """
-    Relu means rectified linear function, i.e, y = max(0, x) is applied to the 
+    ReLU is the rectified linear function, i.e., y = max(0, x), applied to the
     CTensor elementwise.
     """
 
@@ -438,7 +441,7 @@
 
 def relu(x):
     """
-    Relu means rectified linear function, i.e, y = max(0, x) is applied to the 
+    ReLU is the rectified linear function, i.e., y = max(0, x), applied to the
     CTensors elementwise.
     Args:
         x (Tensor): input tensor.
@@ -448,9 +451,9 @@
     return ReLU()(x)[0]
 
 
-class Less(Operation):
+class Less(Operator):
     """
-    Returns the tensor resulted from performing the less logical operation 
+    Returns the tensor resulting from performing the 'less than' logical operation
     elementwise on the input CTensors x and y.
     """
 
@@ -483,9 +486,9 @@
     return Less()(x, y)[0]
 
 
-class Clip(Operation):
+class Clip(Operator):
     """
-    Clip operator limits the given input within an interval. The interval 
+    Clip operator limits the given input within an interval. The interval
     is specified by the inputs 'min' and 'max'.
     """
 
@@ -510,6 +513,7 @@
         self.mask.SetFloatValue(1.0)
 
         if self.min is not None:
+            self.min = float(self.min)
             mask0 = singa.LTFloat(x, self.min)
             mask1 = singa.GEFloat(x, self.min)
             self.mask = singa.__mul__(mask1, self.mask)
@@ -517,6 +521,7 @@
                               singa.__mul__(mask1, x))
 
         if self.max is not None:
+            self.max = float(self.max)
             mask0 = singa.GTFloat(x, self.max)
             mask1 = singa.LEFloat(x, self.max)
             self.mask = singa.__mul__(mask1, self.mask)
@@ -537,7 +542,7 @@
 
 def clip(x, min=None, max=None):
     """
-    Clip operator limits the given input within an interval. The interval 
+    Clip operator limits the given input within an interval. The interval
     is specified by the inputs 'min' and 'max'.
     Args:
         x (Tensor): input tensor
@@ -549,7 +554,7 @@
     return Clip(min, max)(x)[0]
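A small usage sketch (illustration, not part of the patch); the expected values follow directly from the definition above:

    import numpy as np
    from singa import autograd, tensor

    x = tensor.from_numpy(np.array([-2.0, -0.5, 0.3, 4.0], dtype=np.float32))
    y = autograd.clip(x, min=-1.0, max=1.0)
    # tensor.to_numpy(y) is expected to be [-1., -0.5, 0.3, 1.]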
 
 
-class Identity(Operation):
+class Identity(Operator):
     """
     Init an identity operator
     """
@@ -587,7 +592,7 @@
     return Identity()(x)[0]
 
 
-class Matmul(Operation):
+class Matmul(Operator):
     """
     Init matrix multiplication operator.
     """
@@ -599,9 +604,11 @@
         """
         Return `np.matmul(x,w)`, where x and w are CTensor.
         """
+        # TODO: Mult cannot handle tensors with more than 2 dims yet
         if training:
             self.input = (x, w)
-        return singa.Mult(x, w)
+        res = singa.Mult(x, w)
+        return res
 
     def backward(self, dy):
         """
@@ -623,9 +630,9 @@
     return Matmul()(x, w)[0]
 
 
-class Greater(Operation):
+class Greater(Operator):
     """
-    Returns the tensor resulted from performing the greater logical 
+    Returns the tensor resulting from performing the 'greater than' logical
     operation elementwise on the input tensors A and B.
     """
 
@@ -658,7 +665,7 @@
     return Greater()(x, y)[0]
 
 
-class AddBias(Operation):
+class AddBias(Operator):
     """
     Add Bias to each row / column of the Tensor, depending on the axis arg.
     """
@@ -710,26 +717,31 @@
     Return:
         the result Tensor
     """
+    assert x.ndim() == 2, "1st arg requires a 2d tensor, got shape: %s" % (
+        x.shape)
+    assert b.ndim() == 1, "2nd arg requires a 1d tensor, got shape: %s" % (
+        b.shape)
+    assert axis in [0, 1], "allowed axis: 0 or 1"
     return AddBias(axis)(x, b)[0]
 
 
-class Reshape(Operation):
+class Reshape(Operator):
     """
-    Reshape the input tensor similar to np.reshape. 
+    Reshape the input tensor similar to np.reshape.
     """
 
     def __init__(self, shape):
         """
         Args:
             shape (list of int): Specified shape for output. At most one
-                dimension of the new shape can be -1. In this case, the 
-                value is inferred from the size of the tensor and the 
-                remaining dimensions. A dimension could also be 0, 
-                in which case the actual dimension value is unchanged 
+                dimension of the new shape can be -1. In this case, the
+                value is inferred from the size of the tensor and the
+                remaining dimensions. A dimension could also be 0,
+                in which case the actual dimension value is unchanged
                 (i.e. taken from the input tensor).
         """
         super(Reshape, self).__init__()
-        self.shape = list(shape)
+        self.shape = shape
 
     def forward(self, x):
         """
@@ -739,7 +751,7 @@
             the result CTensor
         """
         self._shape = x.shape()
-        shape = self.shape
+        shape = list(self.shape)
         # handle the shape with 0
         shape = [
             self._shape[i]
@@ -748,7 +760,7 @@
         ]
         # handle the shape with -1
         hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape)))
-        self.cache = [s if s != -1 else hidden_shape for s in shape]
+        self.cache = [int(s) if s != -1 else hidden_shape for s in shape]
         return singa.Reshape(x, self.cache)
 
     def backward(self, dy):
@@ -763,14 +775,14 @@
 
 def reshape(x, shape):
     """
-    Reshape the input tensor similar to mp.reshape. 
+    Reshape the input tensor similar to np.reshape.
     Args:
         x (Tensor): matrix.
         shape (list of int): Specified shape for output. At most one
-            dimension of the new shape can be -1. In this case, the 
-            value is inferred from the size of the tensor and the 
-            remaining dimensions. A dimension could also be 0, 
-            in which case the actual dimension value is unchanged 
+            dimension of the new shape can be -1. In this case, the
+            value is inferred from the size of the tensor and the
+            remaining dimensions. A dimension could also be 0,
+            in which case the actual dimension value is unchanged
             (i.e. taken from the input tensor).
     Return:
         the result Tensor
@@ -778,9 +790,9 @@
     return Reshape(shape)(x)[0]
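A short sketch (illustration, not part of the patch) of the 0 and -1 shape conventions described above:

    import numpy as np
    from singa import autograd, tensor

    x = tensor.from_numpy(np.zeros((2, 3, 4), dtype=np.float32))
    y = autograd.reshape(x, [0, -1])  # 0 keeps dim 0 (=2), -1 is inferred as 12
    # y.shape is expected to be (2, 12)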
 
 
-class PRelu(Operation):
+class PRelu(Operator):
     """
-    PRelu applies the function `f(x) = slope * x` for x < 0, 
+    PRelu applies the function `f(x) = slope * x` for x < 0,
     `f(x) = x` for x >= 0 to the data tensor elementwise.
     """
 
@@ -830,7 +842,7 @@
 
 def prelu(x, slope):
     """
-    PRelu applies the function `f(x) = slope * x` for x < 0, 
+    PRelu applies the function `f(x) = slope * x` for x < 0,
     `f(x) = x` for x >= 0 to the data tensor elementwise.
     Args:
         x (Tensor): matrix.
@@ -840,7 +852,7 @@
     return PRelu()(x, slope)[0]
 
 
-class Add(Operation):
+class Add(Operator):
     """
     Performs element-wise binary addition.
     """
@@ -884,9 +896,9 @@
     return Add()(a, b)[0]
 
 
-class Elu(Operation):
+class Elu(Operator):
     """
-    `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0., is applied to 
+    `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0, is applied to
     the tensor elementwise.
     """
 
@@ -935,7 +947,7 @@
 
 def elu(x, alpha=1):
     """
-    `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0., is applied to 
+    `f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0, is applied to
     the tensor elementwise.
     Args:
         x (Tensor): matrix
@@ -946,9 +958,9 @@
     return Elu(alpha)(x)[0]
 
 
-class Equal(Operation):
+class Equal(Operator):
     """
-    Returns the tensor resulted from performing the equal logical operation 
+    Returns the tensor resulting from performing the 'equal' logical operation
     elementwise on the input tensors x and y.
     """
 
@@ -959,9 +971,7 @@
         """
         Return `a=b`, where a and b are CTensor.
         """
-        m = singa.__sub__(x, y)
-        cur = singa.__mul__(singa.GEFloat(m, 0), singa.LEFloat(m, 0))
-        return cur
+        return singa.__eq__(x, y)
 
     def backward(self, dy):
         """
@@ -980,9 +990,9 @@
     return Equal()(x, y)[0]
 
 
-class SeLU(Operation):
+class SeLU(Operator):
     """
-    `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0 
+    `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0
     is applied to the tensor elementwise.
     """
 
@@ -1036,7 +1046,7 @@
 
 def selu(x, alpha=1.67326, gamma=1.0507):
     """
-    `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0 
+    `y = gamma * (alpha * e^x - alpha)` for x <= 0, `y = gamma * x` for x > 0
     is applied to the tensor elementwise.
     Args:
         x (Tensor): matrix
@@ -1048,7 +1058,7 @@
     return SeLU(alpha, gamma)(x)[0]
 
 
-class SoftMax(Operation):
+class SoftMax(Operator):
     """
     Apply SoftMax for each row of the Tensor or each column of the Tensor
     according to the parameter axis.
@@ -1095,7 +1105,7 @@
     return SoftMax(axis)(x)[0]
 
 
-class Sum(Operation):
+class Sum(Operator):
     """
     Element-wise sum of each of the input tensors
     """
@@ -1140,16 +1150,17 @@
     return Sum()(*l)[0]
 
 
-class CrossEntropy(Operation):
+class BinaryCrossEntropy(Operator):
 
-    def __init__(self):
-        super(CrossEntropy, self).__init__()
+    def __init__(self, t):
+        super(BinaryCrossEntropy, self).__init__()
+        self.t = t.data
 
     """
     Calculate the negative log likelihood loss for a batch of training data.
     """
 
-    def forward(self, x, t):
+    def forward(self, x):
         """
         Args:
             x (CTensor): 1d or 2d tensor, the prediction data(output)
@@ -1158,11 +1169,14 @@
         Returns:
             loss (CTensor): scalar.
         """
-        loss = singa.SumAll(singa.__mul__(t, singa.Log(x)))
+        posx = singa.AddFloat(x, 0.0001)  # x + eps, to avoid log(0)
+        loss = singa.SumAll(singa.__mul__(self.t, singa.Log(posx)))
+        negt = singa.AddFloat(singa.MultFloat(self.t, -1.0), 1.0)  # 1 - t
+        negx = singa.AddFloat(singa.MultFloat(x, -1.0), 1.0001)  # (1 - x) + eps
+        negLoss = singa.SumAll(singa.__mul__(negt, singa.Log(negx)))
+        loss += negLoss
         loss /= -x.shape()[0]
-        self.x = x
-        self.t = t
-        self.input = (x, t)
+        self.x = singa.AddFloat(x, 0.0001)
         return loss
 
     def backward(self, dy=1.0):
@@ -1175,21 +1189,119 @@
                           of current network. note that this is true for
                           dy = 1.0
         """
+
+        dx = singa.__div__(self.t, self.x)
+        negt = singa.AddFloat(self.t, -1.0)
+        negx = singa.AddFloat(self.x, -0.9999)
+        dx -= singa.__div__(negt, negx)
+        dx *= float(-1.0 / self.x.shape()[0])
+        if isinstance(dy, float):
+            # dtype of dy: float
+            dx *= dy
+            return dx
+        elif isinstance(dy, CTensor):
+            pass  # TODO: broadcast elementwise multiply is not supported yet
+
+
+def binary_cross_entropy(x, t):
+    return BinaryCrossEntropy(t)(x)[0]
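The forward pass computes loss = -(1/N) * sum(t*log(x + 1e-4) + (1 - t)*log(1 - x + 1e-4)), where N is the batch size x.shape[0], so x is expected to hold probabilities (e.g. sigmoid outputs). A numeric sketch (illustration, not part of the patch):

    import numpy as np
    from singa import autograd, tensor

    x = tensor.from_numpy(np.array([[0.9], [0.2]], dtype=np.float32))  # predicted probabilities
    t = tensor.from_numpy(np.array([[1.0], [0.0]], dtype=np.float32))  # binary targets
    loss = autograd.binary_cross_entropy(x, t)
    # roughly -(log(0.9001) + log(0.8001)) / 2, i.e. about 0.164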
+
+
+class CrossEntropy(Operator):
+
+    def __init__(self, t):
+        super(CrossEntropy, self).__init__()
+        self.t = t.data
+
+    """
+    Calculate the negative log likelihood loss for a batch of training data.
+    """
+
+    def forward(self, x):
+        """
+        Args:
+            x (CTensor): 1d or 2d tensor, the prediction data (output)
+                         of current network.
+            t (CTensor): 1d or 2d target tensor, provided via the constructor.
+        Returns:
+            loss (CTensor): scalar.
+        """
+        loss = singa.SumAll(singa.__mul__(self.t, singa.Log(x)))
+        loss /= -x.shape()[0]
+        self.x = x
+        return loss
+
+    def backward(self, dy=1.0):
+        """
+        Args:
+            dy (float or CTensor): scalar, accumulate gradient from outside
+                                of current network, usually equal to 1.0
+        Returns:
+            dx (CTensor): data for the dL /dx, L is the loss, x is the output
+                          of current network. note that this is true for
+                          dy = 1.0
+        """
+
         dx = singa.__div__(self.t, self.x)
         dx *= float(-1.0 / self.x.shape()[0])
         if isinstance(dy, float):
             # dtype of dy: float
             dx *= dy
-            return dx, None
+            return dx
         elif isinstance(dy, CTensor):
             pass  # TODO: broadcast elementwise multiply is not supported yet
 
 
-def cross_entropy(y, t):
-    return CrossEntropy()(y, t)[0]
+def cross_entropy(x, t):
+    assert x.ndim() == 2, "1st arg requires a 2d tensor, got shape: " + str(
+        x.shape)
+    assert t.ndim() <= 2, "2nd arg requires a <=2d tensor, got shape: " + str(
+        t.shape)
+    # x is the logits and t is the ground truth.
+    return CrossEntropy(t)(x)[0]
 
 
-class SoftMaxCrossEntropy(Operation):
+class RankingLoss(Operator):
+
+    def __init__(self, M=0.2):
+        super().__init__()
+        # margin
+        self.M = M
+
+    def forward(self, pos, neg):
+        # L = max{0, M - fn(pos) + fn(neg)}
+        zero = singa.Tensor(list(pos.shape()), pos.device())
+        zero.SetFloatValue(0.0)
+        val = singa.AddFloat(singa.__sub__(neg, pos), self.M)
+        gt_zero = singa.__gt__(val, zero)
+        if training:
+            self.inputs = (gt_zero,)  # (BS,)
+        all_loss = singa.__mul__(gt_zero, val)
+        loss = singa.SumAll(all_loss)
+        loss /= (pos.shape()[0])
+        return loss
+
+    def backward(self, dy=1.0):
+        assert training, "enable training mode to do backward"
+        # dpos = -1 if M-pos+neg > 0 else 0
+        # dneg =  1 if M-pos+neg > 0 else 0
+        gt_zero = self.inputs[0]
+        dpos_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device())
+        dpos_factor.SetFloatValue(-1.0 / gt_zero.Size())
+        dneg_factor = singa.Tensor(list(gt_zero.shape()), gt_zero.device())
+        dneg_factor.SetFloatValue(1.0 / gt_zero.Size())
+        dpos = singa.__mul__(gt_zero, dpos_factor)
+        dneg = singa.__mul__(gt_zero, dneg_factor)
+        return dpos, dneg
+
+
+def ranking_loss(pos, neg, M=0.2):
+    assert pos.shape == neg.shape, "pos and neg shapes differ: %s, %s" % (
+        pos.shape, neg.shape)
+    return RankingLoss(M)(pos, neg)[0]
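RankingLoss is a margin (hinge) loss on score pairs, L = mean(max(0, M - pos + neg)). A numeric sketch (illustration, not part of the patch):

    import numpy as np
    from singa import autograd, tensor

    autograd.training = True  # backward() asserts training mode
    pos = tensor.from_numpy(np.array([0.8, 0.3], dtype=np.float32))  # scores of positive pairs
    neg = tensor.from_numpy(np.array([0.5, 0.4], dtype=np.float32))  # scores of negative pairs
    loss = autograd.ranking_loss(pos, neg, M=0.2)
    # per pair: max(0, 0.2 - 0.8 + 0.5) = 0 and max(0, 0.2 - 0.3 + 0.4) = 0.3
    # averaged over the batch: 0.15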
+
+
+class SoftMaxCrossEntropy(Operator):
 
     def __init__(self, t):
         super(SoftMaxCrossEntropy, self).__init__()
@@ -1209,58 +1321,68 @@
 
 
 def softmax_cross_entropy(x, t):
-    # x is the logits and t is the ground truth; both are 2D.
+    assert x.ndim() == 2, "1st arg requires a 2d tensor, got shape: " + str(
+        x.shape)
+    assert t.ndim() <= 2, "2nd arg requires a <=2d tensor, got shape: " + str(
+        t.shape)
+    # x is the logits and t is the ground truth.
     return SoftMaxCrossEntropy(t)(x)[0]
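A typical use (illustration, not part of the patch): 2d logits with one-hot (or probability) targets, followed by the module-level backward() generator used by the helper near the top of this file to collect gradients. The optimizer call in the comment is only indicative.

    import numpy as np
    from singa import autograd, tensor

    autograd.training = True
    logits = tensor.from_numpy(np.random.randn(4, 10).astype(np.float32))
    target = tensor.from_numpy(np.eye(10, dtype=np.float32)[[1, 0, 3, 7]])  # one-hot labels
    loss = autograd.softmax_cross_entropy(logits, target)
    for p, g in autograd.backward(loss):
        pass  # hand (p, g) to an optimizer, e.g. sgd.update(p, g)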
 
 
-class MeanSquareError(Operation):
+class MeanSquareError(Operator):
 
-    def __init__(self):
+    def __init__(self, t):
         super(MeanSquareError, self).__init__()
+        self.t = t.data
 
-    def forward(self, x, t):
-        self.err = singa.__sub__(x, t)
+    def forward(self, x):
+        self.err = singa.__sub__(x, self.t)
         sqr = singa.Square(self.err)
         loss = singa.SumAll(sqr)
-        loss /= (x.shape()[0] * 2)
+        self.n = 1
+        for s in x.shape():
+            self.n *= s
+        loss /= self.n
         return loss
 
     def backward(self, dy=1.0):
         dx = self.err
-        dx *= float(1 / self.err.shape()[0])
+        dx *= float(2 / self.n)
         dx *= dy
-        return dx, None
+        return dx
 
 
 def mse_loss(x, t):
-    return MeanSquareError()(x, t)[0]
+    assert x.shape == t.shape, "input and target shape different: %s, %s" % (
+        x.shape, t.shape)
+    return MeanSquareError(t)(x)[0]
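Note that the loss is now averaged over all elements of x (self.n), not just the batch dimension, and the gradient is 2/N * (x - t). A numeric sketch (illustration, not part of the patch):

    import numpy as np
    from singa import autograd, tensor

    x = tensor.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    t = tensor.from_numpy(np.array([[1.0, 2.0], [3.0, 2.0]], dtype=np.float32))
    loss = autograd.mse_loss(x, t)
    # sum of squared errors is 4.0 over 4 elements, so the loss is 1.0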
 
 
 def ctensor2numpy(x):
     """
-    To be used in SoftMax Operation.
+    To be used in SoftMax Operator.
     Convert a singa_tensor to numpy_tensor.
     """
     np_array = x.GetFloatValue(int(x.Size()))
     return np_array.reshape(x.shape())
 
 
-class Flatten(Operation):
+class Flatten(Operator):
     """
-    Flattens the input tensor into a 2D matrix. If input tensor has shape 
-    `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ... 
+    Flattens the input tensor into a 2D matrix. If input tensor has shape
+    `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ...
     d_(axis-1), d_axis X d_(axis+1) ... X dn)`.
     """
 
     def __init__(self, axis=1):
         """
         Args:
-            axis (int): Indicate up to which input dimensions (exclusive) 
-                should be flattened to the outer dimension of the output. The 
-                value for axis must be in the range [-r, r], where r is the 
-                rank of the input tensor. Negative value means counting 
-                dimensions from the back. When axis = 0, the shape of the 
-                output tensor is `(1, (d_0 X d_1 ... d_n)`, where the shape 
+            axis (int): Indicate up to which input dimensions (exclusive)
+                should be flattened to the outer dimension of the output. The
+                value for axis must be in the range [-r, r], where r is the
+                rank of the input tensor. Negative value means counting
+                dimensions from the back. When axis = 0, the shape of the
+                output tensor is `(1, d_0 X d_1 ... d_n)`, where the shape
                 of the input tensor is `(d_0, d_1, ... d_n)`.
         Returns:
             the result CTensor
@@ -1301,17 +1423,17 @@
 
 def flatten(x, axis=1):
     """
-    Flattens the input tensor into a 2D matrix. If input tensor has shape 
-    `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ... 
+    Flattens the input tensor into a 2D matrix. If input tensor has shape
+    `(d_0, d_1, ... d_n)` then the output will have shape `(d_0 X d_1 ...
     d_(axis-1), d_axis X d_(axis+1) ... X dn)`.
     Args:
         x (Tensor): the input tensor
-        axis (int): Indicate up to which input dimensions (exclusive) 
-            should be flattened to the outer dimension of the output. The 
-            value for axis must be in the range [-r, r], where r is the 
-            rank of the input tensor. Negative value means counting 
-            dimensions from the back. When axis = 0, the shape of the 
-            output tensor is `(1, (d_0 X d_1 ... d_n)`, where the shape 
+        axis (int): Indicate up to which input dimensions (exclusive)
+            should be flattened to the outer dimension of the output. The
+            value for axis must be in the range [-r, r], where r is the
+            rank of the input tensor. Negative value means counting
+            dimensions from the back. When axis = 0, the shape of the
+            output tensor is `(1, d_0 X d_1 ... d_n)`, where the shape
             of the input tensor is `(d_0, d_1, ... d_n)`.
     Returns:
         the result Tensor
@@ -1319,135 +1441,131 @@
     return Flatten(axis)(x)[0]
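A shape-only sketch of the axis semantics described above (illustration, not part of the patch):

    import numpy as np
    from singa import autograd, tensor

    x = tensor.from_numpy(np.zeros((2, 3, 4, 5), dtype=np.float32))
    # axis=2 groups dims [0, 2) and [2, 4): expected shape (6, 20)
    print(autograd.flatten(x, axis=2).shape)
    # axis=0 flattens everything into the second dim: expected shape (1, 120)
    print(autograd.flatten(x, axis=0).shape)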
 
 
-class Layer(object):
-
-    def __init__(self):
-        self.allow_params = []
-        pass
-
-    def device_check(self, *inputs):
-        x_device = inputs[0].device
-        x_dev_id = x_device.id()
-        for var in inputs:
-            if var.device.id() != x_dev_id:
-                var.to_device(x_device)
-
-    def find_sublayers(self):
-        # return a list whose elements are in form of (attribute_name,
-        # sublayer)
-        sublayers = []
-        for attr in self.__dict__:
-            if isinstance(self.__dict__[attr], Layer):
-                sublayers.append((attr, self.__dict__[attr]))
-        return sublayers
-
-    def get_params(self):
-        sublayers = self.find_sublayers()
-        params = dict()
-        for sublayer_name, sublayer in sublayers:
-            params[sublayer_name] = sublayer.get_params()
-        return params
-
-    def set_params(self, **parameters):
-        # set parameters for Layer
-        # input should be either a PyTensor or numpy ndarray.
-        # examples: Layer.set_params(W=np.ones((in, out), dtype=np.float32)),
-        # Layer.set_params(**{'block1':{'linear1':{'W':np.ones((in, out),
-        # dtype=np.float32)}}})
-        for (parameter_name, parameter_value) in parameters.items():
-            # assert isinstance(self.__dict__[parameter_name], Layer)
-            assert (parameter_name in self.__dict__
-                   ), "please input correct parameters."
-            if isinstance(self.__dict__[parameter_name], Layer):
-                self.__dict__[parameter_name].set_params(
-                    **parameters[parameter_name])
-            elif isinstance(self.__dict__[parameter_name], Tensor):
-                self.set_one_param(parameter_name, parameter_value)
-            else:
-                raise ValueError("please input correct parameters.")
-
-    def set_one_param(self, parameter_name, parameter_value):
-        assert (parameter_name in self.allow_params
-               ), "please input allowed parameters."
-        assert (parameter_value.shape == self.__dict__[parameter_name].shape
-               ), "Shape dismatched."
-        if isinstance(parameter_value, Tensor):
-            self.__dict__[parameter_name].reset_like(parameter_value)
-        elif isinstance(parameter_value, np.ndarray):
-            self.__dict__[parameter_name].copy_from_numpy(parameter_value)
-        else:
-            raise ValueError("parameters should be Tensor or Numpy array.")
-
-
-class Linear(Layer):
+class ScatterElements(Operator):
     """
-    Generate a Linear operator
+    ScatterElements operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Changelog.md#ScatterElements-11
+
+    Example usage:
+    data = [
+        [0.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0],
+    ]
+    axis = 0
+    indices = [
+        [1, 0, 2],
+        [0, 2, 1],
+    ]
+    updates = [
+        [1.0, 1.1, 1.2],
+        [2.0, 2.1, 2.2],
+    ]
+    output = [
+        [2.0, 1.1, 0.0],
+        [1.0, 0.0, 2.2],
+        [0.0, 2.1, 1.2],
+    ]
+
     """
 
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, indices, updates, axis=0):
         """
         Args:
-            in_channels: int, the channel of input
-            out_channels: int, the channel of output, also is the number of 
-                filters
-            bias: bool
+            indices (Tensor): index tensor
+            updates (Tensor): source tensor
+            axis (int): Which axis to scatter on. A negative value means
+                counting dimensions from the back. Accepted range is [-r, r-1]
+                where r = rank(destination_tensor).
         """
-        w_shape = (in_features, out_features)
-        b_shape = (out_features,)
-        self.bias = bias
+        super(ScatterElements, self).__init__()
+        self.indices = indices
+        self.updates = updates
+        self.axis = axis
 
-        self.W = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
-        std = math.sqrt(2.0 / (in_features + out_features))
-        self.W.gaussian(0.0, std)
+    def forward(self, x):
+        x_shape = x.shape()
+        x_rank = len(x_shape)
+        if isinstance(self.indices, Tensor):
+            self.indices = tensor.to_numpy(self.indices)
+        elif isinstance(self.indices, (list, tuple)):
+            self.indices = np.array(self.indices)
+        if isinstance(self.updates, Tensor):
+            self.updates = tensor.to_numpy(self.updates)
+        elif isinstance(self.updates, (list, tuple)):
+            self.updates = np.array(self.updates)
+        # cast indices to integers; note that astype returns a new array
+        self.indices = self.indices.astype(np.int32)
+        _x = tensor.to_numpy(tensor.from_raw_tensor(x))
+        _x = _x.astype(np.float32)
 
-        if self.bias:
-            self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)
+        assert x_rank == 2, "Only support 2D input."
+        assert x_rank == len(
+            self.indices.shape
+        ), "Index should have the same number of dimensions as output"
+        assert -x_rank <= self.axis < x_rank, "Axis is out of range"
+        assert np.logical_and(
+            -_x.shape[self.axis] <= self.indices,
+            self.indices < _x.shape[self.axis]).all(), \
+            "The values of the indices should be between %d and %d" % (
+                -_x.shape[self.axis], _x.shape[self.axis] - 1)
 
-    def __call__(self, x):
-        if self.bias:
-            self.device_check(x, self.W, self.b)
-        else:
-            self.device_check(x, self.W)
-        assert x.shape[1] == self.W.shape[0], (
-            "Linear layer expects input features size %d received %d" %
-            (self.W.shape[0], x.shape[1]))
-        y = matmul(x, self.W)
-        if self.bias:
-            y = add_bias(y, self.b, axis=0)
-        return y
+        self.axis = self.axis % x_rank
+        u_shape = self.updates.shape
+        y = _x.copy()
+        for i in range(u_shape[0]):
+            for j in range(u_shape[1]):
+                idx = int(self.indices[i][j])
+                if self.axis == 0:
+                    y[idx][j] = self.updates[i][j]
+                else:
+                    y[i][idx] = self.updates[i][j]
+        y = tensor.from_numpy(y)
+        y.to_device(x.device())
+        return y.data
 
-    def get_params(self):
-        if self.bias:
-            return {"W": self.W, "b": self.b}
-        else:
-            return {"W": self.W}
-
-    def set_params(self, **parameters):
-        # TODO(wangwei) remove this funciton as Opeation's set_params() enough
-        # set parameters for Linear Layer
-        # input should be either a PyTensor or numpy ndarray.
-        # examples: Linear.set_params(W=np.ones((in, out), dtype=np.float32)),
-        # Linear.set_params(**{'W':np.ones((in, out), dtype=np.float32)})
-        self.allow_params = ["W", "b"]
-        super(Linear, self).set_params(**parameters)
-        for parameter_name in parameters:
-            if parameter_name is "b":
-                self.bias = True
+    def backward(self, dy):
+        mask = np.ones(dy.shape(), dtype=np.float32)
+        u_shape = self.updates.shape
+        for i in range(u_shape[0]):
+            for j in range(u_shape[1]):
+                idx = int(self.indices[i][j])
+                if self.axis == 0:
+                    mask[idx][j] = 0.
+                else:
+                    mask[i][idx] = 0.
+        mask = tensor.from_numpy(mask)
+        mask.to_device(dy.device())
+        return singa.__mul__(dy, mask.data)
 
 
-class Concat(Operation):
+def scatter_elements(x, indices, updates, axis=0):
     """
-    Concatenate a list of tensors into a single tensor. All input tensors must 
-    have the same shape, except for the dimension size of the axis to 
+    Produces a ScatterElements operator
+    Args:
+        x (Tensor): input tensor.
+        indices (Tensor): index tensor
+        updates (Tensor): source tensor
+        axis (int): Which axis to scatter on. A negative value means
+            counting dimensions from the back. Accepted range is [-r, r-1]
+            where r = rank(destination_tensor).
+    Returns:
+        the output Tensor.
+    """
+    return ScatterElements(indices, updates, axis)(x)[0]
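The docstring example above can be reproduced directly (illustration, not part of the patch); indices and updates may be given as plain nested lists, which forward() converts with np.array:

    import numpy as np
    from singa import autograd, tensor

    data = tensor.from_numpy(np.zeros((3, 3), dtype=np.float32))
    indices = [[1, 0, 2], [0, 2, 1]]
    updates = [[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]]
    out = autograd.scatter_elements(data, indices, updates, axis=0)
    # tensor.to_numpy(out) is expected to equal the output matrix in the docstring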
+
+
+class Concat(Operator):
+    """
+    Concatenate a list of tensors into a single tensor. All input tensors must
+    have the same shape, except for the dimension size of the axis to
     concatenate on.
     """
 
     def __init__(self, axis=0):
         """
         Args:
-            axis (int): Which axis to concat on. A negative value means 
-                counting dimensions from the back. Accepted range is [-r, r-1] 
+            axis (int): Which axis to concat on. A negative value means
+                counting dimensions from the back. Accepted range is [-r, r-1]
                 where r = rank(inputs).
         Returns:
             the result CTensor
@@ -1462,6 +1580,8 @@
         Returns:
             a CTensor for the result
         """
+        if self.axis < 0:
+            self.axis = self.axis % len(xs[0].shape())
         if training:
             offset = 0
             self.slice_point = []
@@ -1491,21 +1611,26 @@
 
 def cat(xs, axis=0):
     """
-    Concatenate a list of tensors into a single tensor. All input tensors must 
-    have the same shape, except for the dimension size of the axis to 
+    Concatenate a list of tensors into a single tensor. All input tensors must
+    have the same shape, except for the dimension size of the axis to
     concatenate on.
     Args:
         xs (a list of Tensor): List of tensors for concatenation
-        axis (int): Which axis to concat on. A negative value means 
-            counting dimensions from the back. Accepted range is [-r, r-1] 
+        axis (int): Which axis to concat on. A negative value means
+            counting dimensions from the back. Accepted range is [-r, r-1]
             where r = rank(inputs).
     Returns:
         a Tensor for the result
     """
     return Concat(axis)(*xs)[0]
+"""
+def make_slice(arr, axis, i):  # type: ignore
+        slc = [slice(None)] * arr.ndim
+        slc[axis] = i
+        return slc
+"""
 
-
-class _Conv2d(Operation):
+class _Conv2d(Operator):
     """
     Init a conv 2d operator
     """
@@ -1514,16 +1639,14 @@
         """
         Args:
             handle (object): ConvHandle for cpu or CudnnConvHandle for gpu
-            odd_padding (tuple of four ints):, the odd paddding is the value 
-                that cannot be handled by the tuple padding (w, h) mode so 
-                we need to firstly handle the input, then use the nomal padding 
+            odd_padding (tuple of four ints): the odd padding is the value
+                that cannot be handled by the tuple padding (w, h) mode, so
+                we need to handle the input first, then use the normal padding
                 method.
         """
         super(_Conv2d, self).__init__()
         self.handle = handle
         self.odd_padding = odd_padding
-        if self.odd_padding != (0, 0, 0, 0):
-            self.re_new_handle = True
 
     def forward(self, x, W, b=None):
         """
@@ -1532,15 +1655,11 @@
             W (CTensor): weight
             b (CTensor): bias
         Returns:
-            CTensor 
+            CTensor
         """
         assert x.nDim() == 4, "The dimensions of input should be 4D."
         if self.odd_padding != (0, 0, 0, 0):
             x = utils.handle_odd_pad_fwd(x, self.odd_padding)
-            # re-new a handle with updated x
-            if self.re_new_handle:
-                self.re_new_handle = False
-                self.handle = utils.re_new_handle(self.handle, x)
 
         if training:
             if self.handle.bias_term:
@@ -1602,9 +1721,9 @@
         x (Tensor): input
         W (Tensor): weight
         b (Tensor): bias
-        odd_padding (tuple of four ints):, the odd paddding is the value 
-            that cannot be handled by the tuple padding (w, h) mode so 
-            we need to firstly handle the input, then use the nomal padding 
+        odd_padding (tuple of four ints): the odd padding is the value
+            that cannot be handled by the tuple padding (w, h) mode, so
+            we need to handle the input first, then use the normal padding
             method.
     """
     if b is None:
@@ -1613,332 +1732,16 @@
         return _Conv2d(handle, odd_padding)(x, W, b)[0]
 
 
-class Conv2d(Layer):
+class _BatchNorm2d(Operator):
     """
-    Generate a Conv 2d operator
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 group=1,
-                 bias=True,
-                 pad_mode="NOTSET",
-                 **kwargs):
-        """
-        Args:
-            in_channels (int): the channel of input
-            out_channels (int): the channel of output, also is the number of filters
-            kernel_size (int or tuple): kernel size for two direction of each 
-                axis. For example, (2, 3), the first 2 means will add 2 at the 
-                beginning and also 2 at the end for its axis.and if a int is 
-                accepted, the kernel size will be initiated as (int, int)
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int): tuple, list or None, padding, the logic is the same 
-                as kernel size. However, if you set pad_mode as "SAME_UPPER" or 
-                "SAME_LOWER" mode, you can set padding as None, and the padding 
-                will be computed automatically.
-            dilation (int): only support 1
-            group (int): group
-            bias (bool): bias
-            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where 
-                default value is NOTSET, which means explicit padding is used. 
-                SAME_UPPER or SAME_LOWER mean pad the input so that the output 
-                spatial size match the input. In case of odd number add the extra 
-                padding at the end for SAME_UPPER and at the beginning for SAME_LOWER.
-        """
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-
-        self.group = group
-
-        assert (self.group >= 1 and self.in_channels %
-                self.group == 0), "please set reasonable group."
-
-        assert (self.out_channels >= self.group and self.out_channels %
-                self.group == 0), "out_channels and group dismatched."
-
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
-        elif isinstance(kernel_size, tuple):
-            self.kernel_size = kernel_size
-        else:
-            raise TypeError("Wrong kernel_size type.")
-
-        if isinstance(stride, int):
-            self.stride = (stride, stride)
-        elif isinstance(stride, tuple):
-            self.stride = stride
-        else:
-            raise TypeError("Wrong stride type.")
-
-        self.odd_padding = (0, 0, 0, 0)
-        if isinstance(padding, int):
-            self.padding = (padding, padding)
-        elif isinstance(padding, tuple) or isinstance(padding, list):
-            if len(padding) == 2:
-                self.padding = padding
-            elif len(padding) == 4:
-                _h_mask = padding[0] - padding[1]
-                _w_mask = padding[2] - padding[3]
-                # the odd paddding is the value that cannot be handled by the tuple padding (w, h) mode
-                # so we need to firstly handle the input, then use the nomal padding method.
-                self.odd_padding = (max(_h_mask, 0), max(-_h_mask, 0),
-                                    max(_w_mask, 0), max(-_w_mask, 0))
-                self.padding = (
-                    padding[0] - self.odd_padding[0],
-                    padding[2] - self.odd_padding[2],
-                )
-            else:
-                raise TypeError("Wrong padding value.")
-
-        if dilation != 1:
-            raise ValueError("Not implemented yet")
-
-        self.bias = bias
-
-        self.inner_params = {
-            "cudnn_prefer": "fastest",
-            "workspace_MB_limit": 1024,
-        }
-        # TODO valid value of inner_params check
-
-        for kwarg in kwargs:
-            if kwarg not in self.inner_params:
-                raise TypeError("Keyword argument not understood:", kwarg)
-            else:
-                self.inner_params[kwarg] = kwargs[kwarg]
-
-        w_shape = (
-            self.out_channels,
-            int(self.in_channels / self.group),
-            self.kernel_size[0],
-            self.kernel_size[1],
-        )
-
-        self.W = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
-        # std = math.sqrt(
-        # 2.0 / (self.in_channels * self.kernel_size[0] * self.kernel_size[1] +
-        # self.out_channels))
-        std = math.sqrt(
-            2.0 / (w_shape[1] * self.kernel_size[0] * self.kernel_size[1] +
-                   self.out_channels))
-        self.W.gaussian(0.0, std)
-
-        if self.bias:
-            b_shape = (self.out_channels,)
-            self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
-            self.b.set_value(0.0)
-        else:
-            # to keep consistency when to do forward.
-            self.b = None
-            # Tensor(data=CTensor([]), requires_grad=False, stores_grad=False)
-        self.pad_mode = pad_mode
-
-    def __call__(self, x):
-        assert x.shape[1] == self.in_channels, "in_channels mismatched"
-
-        # if same pad mode, re-compute the padding
-        if self.pad_mode in ("SAME_UPPER", "SAME_LOWER"):
-            self.padding, self.odd_padding = utils.get_padding_shape(
-                self.pad_mode, x.shape[2:], self.kernel_size, self.stride)
-
-        if self.bias:
-            self.device_check(x, self.W, self.b)
-        else:
-            self.device_check(x, self.W)
-
-        if x.device.id() == -1:
-            if self.group != 1:
-                raise ValueError("Not implemented yet")
-            else:
-                if (not hasattr(self, "handle")) or (x.shape[0] !=
-                                                     self.handle.batchsize):
-                    self.handle = singa.ConvHandle(
-                        x.data,
-                        self.kernel_size,
-                        self.stride,
-                        self.padding,
-                        self.in_channels,
-                        self.out_channels,
-                        self.bias,
-                        self.group,
-                    )
-        else:
-            if (not hasattr(self,
-                            "handle")) or (x.shape[0] != self.handle.batchsize):
-                self.handle = singa.CudnnConvHandle(
-                    x.data,
-                    self.kernel_size,
-                    self.stride,
-                    self.padding,
-                    self.in_channels,
-                    self.out_channels,
-                    self.bias,
-                    self.group,
-                )
-
-        y = conv2d(self.handle, x, self.W, self.b, self.odd_padding)
-        return y
-
-    def get_params(self):
-        if self.bias:
-            return {"W": self.W, "b": self.b}
-        else:
-            return {"W": self.W}
-
-    def set_params(self, **parameters):
-        # TODO(wangwei) remove it as Operation's set_params() is enough
-        # input should be either a PyTensor or numpy ndarray.
-        # Conv2d.set_params(W=np.ones((n, c, h, w), dtype=np.float32)),
-        # Conv2d.set_params(**{'W':np.ones((n, c, h, w), dtype=np.float32)})
-        self.allow_params = ["W", "b"]
-        super(Conv2d, self).set_params(**parameters)
-        for parameter_name in parameters:
-            if parameter_name is "b":
-                self.bias = True
-
-
-class SeparableConv2d(Layer):
-    """
-    Generate a Conv 2d operator
-    """
-
-    def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            padding=0,
-            bias=False,
-    ):
-        """
-        Args:
-            in_channels (int): the channel of input
-            out_channels (int): the channel of output, also is the number of filters
-            kernel_size (int or tuple): kernel size for two direction of each 
-                axis. For example, (2, 3), the first 2 means will add 2 at the 
-                beginning and also 2 at the end for its axis.and if a int is 
-                accepted, the kernel size will be initiated as (int, int)
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int): tuple, list or None, padding, the logic is the same 
-                as kernel size. However, if you set pad_mode as "SAME_UPPER" or 
-                "SAME_LOWER" mode, you can set padding as None, and the padding 
-                will be computed automatically.
-            bias (bool): bias
-        """
-        self.depthwise_conv = Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size,
-            stride,
-            padding,
-            group=in_channels,
-            bias=bias,
-        )
-
-        self.point_conv = Conv2d(in_channels, out_channels, 1, bias=bias)
-
-    def __call__(self, x):
-        y = self.depthwise_conv(x)
-        y = self.point_conv(y)
-        return y
-
-
-class BatchNorm2d(Layer):
-    """
-    Generate a BatchNorm 2d operator
-    """
-
-    def __init__(self, num_features, momentum=0.9):
-        """
-        Args:
-            num_features (int): int, the channel of input
-            momentum (float): Factor used in computing the running mean and 
-                variance.
-        """
-        self.channels = num_features
-        self.momentum = momentum
-
-        param_shape = (self.channels,)
-
-        self.scale = Tensor(shape=param_shape,
-                            requires_grad=True,
-                            stores_grad=True)
-        self.scale.set_value(1.0)
-
-        self.bias = Tensor(shape=param_shape,
-                           requires_grad=True,
-                           stores_grad=True)
-        self.bias.set_value(0.0)
-
-        self.running_mean = Tensor(shape=param_shape,
-                                   requires_grad=False,
-                                   stores_grad=False)
-        self.running_mean.set_value(0.0)
-
-        self.running_var = Tensor(shape=param_shape,
-                                  requires_grad=False,
-                                  stores_grad=False)
-        self.running_var.set_value(1.0)
-
-    def __call__(self, x):
-        assert x.shape[1] == self.channels, (
-            "number of channels dismatched. %d vs %d" %
-            (x.shape[1], self.channels))
-
-        self.device_check(x, self.scale, self.bias, self.running_mean,
-                          self.running_var)
-
-        if x.device.id() == -1:
-            if not hasattr(self, "handle"):
-                self.handle = singa.BatchNormHandle(self.momentum, x.data)
-            elif x.shape[0] != self.handle.batchsize:
-                self.handle = singa.BatchNormHandle(self.momentum, x.data)
-        else:
-            if not hasattr(self, "handle"):
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data)
-            elif x.shape[0] != self.handle.batchsize:
-                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data)
-
-        y = batchnorm_2d(
-            self.handle,
-            x,
-            self.scale,
-            self.bias,
-            self.running_mean,
-            self.running_var,
-        )
-        return y
-
-    def get_params(self):
-        return {"scale": self.scale, "bias": self.bias}
-
-    def set_params(self, **parameters):
-        # set parameters for BatchNorm2d Layer
-        # input should be either a PyTensor or numpy ndarray.
-        # examples:
-        #   Batchnorm2d.set_params(scale=np.ones((1,), dtype=np.float32)),
-        #   Batchnorm2d.set_params(**{'bias':np.ones((1), dtype=np.float32)})
-        self.allow_params = ["scale", "bias"]
-        super(BatchNorm2d, self).set_params(**parameters)
-
-
-class _BatchNorm2d(Operation):
-    """
-    Carries out batch normalization as described in the paper 
-    https://arxiv.org/abs/1502.03167. 
+    Carries out batch normalization as described in the paper
+    https://arxiv.org/abs/1502.03167.
     """
 
     def __init__(self, handle, running_mean, running_var, name=None):
         """
         Args:
-            handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle 
+            handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle
                 for gpu
             running_mean (float): the running_mean
             running_var (float): the running_var
@@ -2020,10 +1823,10 @@
 
 def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
     """
-    Carries out batch normalization as described in the paper 
-    https://arxiv.org/abs/1502.03167. 
+    Carries out batch normalization as described in the paper
+    https://arxiv.org/abs/1502.03167.
     Args:
-        handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle 
+        handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle
             for gpu
         x (Tensor): the input tensor
         scale (Tensor): the bias tensor
@@ -2036,7 +1839,7 @@
     return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
-class _Pooling2d(Operation):
+class _Pooling2d(Operator):
     """
     Init a pool 2d operator
     """
@@ -2044,18 +1847,16 @@
     def __init__(self, handle, odd_padding=(0, 0, 0, 0)):
         """
         Args:
-            handle (object): PoolingHandle for cpu or CudnnPoolingHandle for 
+            handle (object): PoolingHandle for cpu or CudnnPoolingHandle for
                 gpu
-            odd_padding (tuple of four int): the odd paddding is the value 
-                that cannot be handled by the tuple padding (w, h) mode so 
-                it needs to firstly handle the input, then use the normal 
+            odd_padding (tuple of four int): the odd padding is the value
+                that cannot be handled by the tuple padding (w, h) mode, so
+                the input is handled first, then the normal
                 padding method.
         """
         super(_Pooling2d, self).__init__()
         self.handle = handle
         self.odd_padding = odd_padding
-        if self.odd_padding != (0, 0, 0, 0):
-            self.re_new_handle = True
 
     def forward(self, x):
         """
@@ -2066,11 +1867,7 @@
         """
         assert x.nDim() == 4, "The dimensions of input should be 4D."
         if self.odd_padding != (0, 0, 0, 0):
-            x = utils.handle_odd_pad_fwd(x, self.odd_padding)
-            # re-new a handle with updated x
-            if self.re_new_handle:
-                self.re_new_handle = False
-                self.handle = utils.re_new_handle(self.handle, x, True)
+            x = utils.handle_odd_pad_fwd(x, self.odd_padding, True)
 
         if (type(self.handle) != singa.PoolingHandle):
             y = singa.GpuPoolingForward(self.handle, x)
@@ -2103,12 +1900,12 @@
     """
     Pooling 2d operator
     Args:
-        handle (object): PoolingHandle for cpu or CudnnPoolingHandle for 
+        handle (object): PoolingHandle for cpu or CudnnPoolingHandle for
             gpu
         x (Tensor): input
-        odd_padding (tuple of four int): the odd paddding is the value 
-            that cannot be handled by the tuple padding (w, h) mode so 
-            it needs to firstly handle the input, then use the normal 
+        odd_padding (tuple of four int): the odd padding is the value
+            that cannot be handled by the tuple padding (w, h) mode, so
+            the input is handled first, then the normal
             padding method.
     Returns:
         the result Tensor
@@ -2116,254 +1913,7 @@
     return _Pooling2d(handle, odd_padding)(x)[0]
 
 
-class Pooling2d(Layer):
-    """
-    Generate a Pooling 2d operator
-    """
-
-    def __init__(self,
-                 kernel_size,
-                 stride=None,
-                 padding=0,
-                 is_max=True,
-                 pad_mode="NOTSET"):
-        """
-        Args:
-            kernel_size (int or tuple): kernel size for two direction of each 
-                axis. For example, (2, 3), the first 2 means will add 2 at the 
-                beginning and also 2 at the end for its axis.and if a int is 
-                accepted, the kernel size will be initiated as (int, int)
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int): tuple, list or None, padding, the logic is the same 
-                as kernel size. However, if you set pad_mode as "SAME_UPPER" or 
-                "SAME_LOWER" mode, you can set padding as None, and the padding 
-                will be computed automatically.
-            is_max (bool): is max pooling or avg pooling
-            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where 
-                default value is NOTSET, which means explicit padding is used. 
-                SAME_UPPER or SAME_LOWER mean pad the input so that the output 
-                spatial size match the input. In case of odd number add the extra 
-                padding at the end for SAME_UPPER and at the beginning for SAME_LOWER.
-        """
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
-        elif isinstance(kernel_size, tuple):
-            self.kernel_size = kernel_size
-        else:
-            raise TypeError("Wrong kernel_size type.")
-
-        if stride is None:
-            self.stride = self.kernel_size
-        elif isinstance(stride, int):
-            self.stride = (stride, stride)
-        elif isinstance(stride, tuple):
-            self.stride = stride
-            assert stride[0] > 0 or (kernel_size[0] == 1 and padding[0] == 0), (
-                "stride[0]=0, but kernel_size[0]=%d, padding[0]=%d" %
-                (kernel_size[0], padding[0]))
-        else:
-            raise TypeError("Wrong stride type.")
-
-        self.odd_padding = (0, 0, 0, 0)
-        if isinstance(padding, int):
-            self.padding = (padding, padding)
-        elif isinstance(padding, tuple) or isinstance(padding, list):
-            if len(padding) == 2:
-                self.padding = padding
-            elif len(padding) == 4:
-                _h_mask = padding[0] - padding[1]
-                _w_mask = padding[2] - padding[3]
-                # the odd paddding is the value that cannot be handled by the tuple padding (w, h) mode
-                # so we need to firstly handle the input, then use the nomal padding method.
-                self.odd_padding = (max(_h_mask, 0), max(-_h_mask, 0),
-                                    max(_w_mask, 0), max(-_w_mask, 0))
-                self.padding = (
-                    padding[0] - self.odd_padding[0],
-                    padding[2] - self.odd_padding[2],
-                )
-            else:
-                raise TypeError("Wrong padding value.")
-
-        self.is_max = is_max
-        self.pad_mode = pad_mode
-
-    def __call__(self, x):
-        # if same pad mode, re-compute the padding
-        if self.pad_mode in ("SAME_UPPER", "SAME_LOWER"):
-            self.padding, self.odd_padding = utils.get_padding_shape(
-                self.pad_mode, x.shape[2:], self.kernel_size, self.stride)
-
-        out_shape_h = (int(
-            (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]) //
-            self.stride[0]) + 1)
-        out_shape_w = (int(
-            (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]) //
-            self.stride[1]) + 1)
-        if x.device.id() == -1:
-            if not hasattr(self, "handle"):
-                self.handle = singa.PoolingHandle(
-                    x.data,
-                    self.kernel_size,
-                    self.stride,
-                    self.padding,
-                    self.is_max,
-                )
-            elif (x.shape[0] != self.handle.batchsize or
-                  out_shape_h != self.handle.pooled_height or
-                  out_shape_w != self.handle.pooled_width):
-                self.handle = singa.PoolingHandle(
-                    x.data,
-                    self.kernel_size,
-                    self.stride,
-                    self.padding,
-                    self.is_max,
-                )
-        else:
-            if not hasattr(self, "handle"):
-                self.handle = singa.CudnnPoolingHandle(
-                    x.data,
-                    self.kernel_size,
-                    self.stride,
-                    self.padding,
-                    self.is_max,
-                )
-            elif (x.shape[0] != self.handle.batchsize or
-                  out_shape_h != self.handle.pooled_height or
-                  out_shape_w != self.handle.pooled_width):
-                self.handle = singa.CudnnPoolingHandle(
-                    x.data,
-                    self.kernel_size,
-                    self.stride,
-                    self.padding,
-                    self.is_max,
-                )
-
-        y = pooling_2d(self.handle, x, self.odd_padding)
-        return y
-
-
-class MaxPool2d(Pooling2d):
-    """
-    Generate a Max Pooling 2d operator
-    """
-
-    def __init__(self,
-                 kernel_size,
-                 stride=None,
-                 padding=0,
-                 odd_padding=(0, 0, 0, 0)):
-        """
-        Args:
-            kernel_size (int or tuple): kernel size for the two spatial axes.
-                For example, (2, 3) applies 2 to the first axis and 3 to the
-                second. If an int is given, the kernel size will be
-                initialized as (int, int).
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int, tuple, list or None): padding, the logic is the same
-                as kernel size. However, if you set pad_mode to "SAME_UPPER" or
-                "SAME_LOWER", you can set padding to None, and the padding
-                will be computed automatically.
-            odd_padding (tuple of four int): the odd padding is the part that
-                cannot be handled by the tuple padding (w, h) mode, so the
-                input is padded with it first, and then the normal padding
-                method is applied.
-        """
-        super(MaxPool2d, self).__init__(kernel_size, stride, padding, True,
-                                        odd_padding)
-
-
-class AvgPool2d(Pooling2d):
-
-    def __init__(self,
-                 kernel_size,
-                 stride=None,
-                 padding=0,
-                 odd_padding=(0, 0, 0, 0)):
-        """
-        Args:
-            kernel_size (int or tuple): kernel size for the two spatial axes.
-                For example, (2, 3) applies 2 to the first axis and 3 to the
-                second. If an int is given, the kernel size will be
-                initialized as (int, int).
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int, tuple, list or None): padding, the logic is the same
-                as kernel size. However, if you set pad_mode to "SAME_UPPER" or
-                "SAME_LOWER", you can set padding to None, and the padding
-                will be computed automatically.
-            odd_padding (tuple of four int): the odd padding is the part that
-                cannot be handled by the tuple padding (w, h) mode, so the
-                input is padded with it first, and then the normal padding
-                method is applied.
-        """
-        super(AvgPool2d, self).__init__(kernel_size, stride, padding, False,
-                                        odd_padding)
-
-
-class MaxPool1d(Pooling2d):
-    """
-    Generate a Max Pooling 1d operator
-    """
-
-    def __init__(self,
-                 kernel_size,
-                 stride=None,
-                 padding=0,
-                 odd_padding=(0, 0, 0, 0)):
-        """
-        Args:
-            kernel_size (int or tuple): kernel size for the two spatial axes.
-                For example, (2, 3) applies 2 to the first axis and 3 to the
-                second. If an int is given, the kernel size will be
-                initialized as (int, int).
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int, tuple, list or None): padding, the logic is the same
-                as kernel size. However, if you set pad_mode to "SAME_UPPER" or
-                "SAME_LOWER", you can set padding to None, and the padding
-                will be computed automatically.
-            odd_padding (tuple of four int): the odd padding is the part that
-                cannot be handled by the tuple padding (w, h) mode, so the
-                input is padded with it first, and then the normal padding
-                method is applied.
-        """
-        if stride is None:
-            stride = kernel_size
-        super(MaxPool1d, self).__init__((1, kernel_size), (1, stride),
-                                        (0, padding), True, odd_padding)
-
-
-class AvgPool1d(Pooling2d):
-    """
-    Generate an Avg Pooling 1d operator
-    """
-
-    def __init__(self,
-                 kernel_size,
-                 stride=None,
-                 padding=0,
-                 odd_padding=(0, 0, 0, 0)):
-        """
-        Args:
-            kernel_size (int or tuple): kernel size for the two spatial axes.
-                For example, (2, 3) applies 2 to the first axis and 3 to the
-                second. If an int is given, the kernel size will be
-                initialized as (int, int).
-            stride (int or tuple): stride, the logic is the same as kernel size.
-            padding (int, tuple, list or None): padding, the logic is the same
-                as kernel size. However, if you set pad_mode to "SAME_UPPER" or
-                "SAME_LOWER", you can set padding to None, and the padding
-                will be computed automatically.
-            odd_padding (tuple of four int): the odd padding is the part that
-                cannot be handled by the tuple padding (w, h) mode, so the
-                input is padded with it first, and then the normal padding
-                method is applied.
-        """
-        if stride is None:
-            stride = kernel_size
-        super(AvgPool1d, self).__init__((1, kernel_size), (1, stride),
-                                        (0, padding), False, odd_padding)
-
-
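# Illustrative sketch (not part of this patch): the pooled output size that
# Pooling2d.__call__ above derives before (re)building its handle follows the
# standard formula; the helper name below is hypothetical.
def pooled_size(in_size, kernel, stride, pad):
    # floor((in + 2 * pad - kernel) / stride) + 1
    return (in_size + 2 * pad - kernel) // stride + 1

assert pooled_size(32, kernel=2, stride=2, pad=0) == 16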
-class Tanh(Operation):
+class Tanh(Operator):
     """
     Calculates the hyperbolic tangent of the given input tensor element-wise.
     """
@@ -2375,7 +1925,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         out = singa.Tanh(x)
@@ -2387,7 +1937,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.__mul__(self.cache[0], self.cache[0])
@@ -2402,13 +1952,13 @@
     Calculates the hyperbolic tangent of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Tanh()(x)[0]
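# A minimal numpy check (illustrative only) of the gradient Tanh.backward
# assembles from its cached output: dL/dx = dL/dy * (1 - tanh(x)**2).
import numpy as np

x = np.array([-1.0, 0.0, 2.0])
y = np.tanh(x)
dy = np.ones_like(x)
dx = dy * (1.0 - y * y)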
 
 
-class Cos(Operation):
+class Cos(Operator):
     """
     Calculates the cosine of the given input tensor, element-wise.
     """
@@ -2420,7 +1970,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2431,7 +1981,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Sin(self.input)
@@ -2445,14 +1995,14 @@
     Calculates the cosine of the given input tensor, element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
 
     return Cos()(x)[0]
 
 
-class Cosh(Operation):
+class Cosh(Operator):
     """
     Calculates the hyperbolic cosine of the given input tensor element-wise.
     """
@@ -2464,7 +2014,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2475,7 +2025,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Sinh(self.input)
@@ -2488,15 +2038,15 @@
     Calculates the hyperbolic cosine of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Cosh()(x)[0]
 
 
-class Acos(Operation):
+class Acos(Operator):
     """
-    Calculates the arccosine (inverse of cosine) of the given input tensor, 
+    Calculates the arccosine (inverse of cosine) of the given input tensor,
     element-wise.
     """
 
@@ -2507,7 +2057,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2518,7 +2068,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Square(self.input)
@@ -2532,17 +2082,17 @@
 
 def acos(x):
     """
-    Calculates the arccosine (inverse of cosine) of the given input tensor, 
+    Calculates the arccosine (inverse of cosine) of the given input tensor,
     element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Acos()(x)[0]
 
 
-class Acosh(Operation):
+class Acosh(Operator):
     """
     Calculates the hyperbolic arccosine of the given input tensor element-wise.
     """
@@ -2554,7 +2104,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2565,7 +2115,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.SubFloat(self.input, 1.0)
@@ -2583,13 +2133,13 @@
     Calculates the hyperbolic arccosine of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Acosh()(x)[0]
 
 
-class Sin(Operation):
+class Sin(Operator):
     """
     Calculates the sine of the given input tensor, element-wise.
     """
@@ -2601,7 +2151,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2612,7 +2162,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Cos(self.input)
@@ -2625,13 +2175,13 @@
     Calculates the sine of the given input tensor, element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Sin()(x)[0]
 
 
-class Sinh(Operation):
+class Sinh(Operator):
     """
     Calculates the hyperbolic sine of the given input tensor element-wise.
     """
@@ -2643,7 +2193,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2654,7 +2204,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Cosh(self.input)
@@ -2667,13 +2217,13 @@
     Calculates the hyperbolic sine of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Sinh()(x)[0]
 
 
-class Asin(Operation):
+class Asin(Operator):
     """
     Calculates the arcsine (inverse of sine) of the given input tensor, element-wise.
     """
@@ -2685,7 +2235,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2696,7 +2246,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Square(self.input)
@@ -2712,14 +2262,14 @@
     Calculates the arcsine (inverse of sine) of the given input tensor, element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
 
     return Asin()(x)[0]
 
 
-class Asinh(Operation):
+class Asinh(Operator):
     """
     Calculates the hyperbolic arcsine of the given input tensor element-wise.
     """
@@ -2731,7 +2281,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2742,7 +2292,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Square(self.input)
@@ -2757,15 +2307,15 @@
     Calculates the hyperbolic arcsine of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Asinh()(x)[0]
 
 
-class Tan(Operation):
+class Tan(Operator):
     """
-    Insert single-dimensional entries to the shape of an input tensor (data). 
+    Calculates the tangent of the given input tensor, element-wise.
     """
 
     def __init__(self):
@@ -2775,7 +2325,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2786,7 +2336,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Cos(self.input)
@@ -2801,13 +2351,13 @@
     Calculates the tangent of the given input tensor, element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Tan()(x)[0]
 
 
-class Atan(Operation):
+class Atan(Operator):
     """
     Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise.
     """
@@ -2819,7 +2369,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2830,7 +2380,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Square(self.input)
@@ -2845,13 +2395,13 @@
     Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Atan()(x)[0]
 
 
-class Atanh(Operation):
+class Atanh(Operator):
     """
     Calculates the hyperbolic arctangent of the given input tensor element-wise.
     """
@@ -2863,7 +2413,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -2874,7 +2424,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Square(self.input)
@@ -2890,13 +2440,13 @@
     Calculates the hyperbolic arctangent of the given input tensor element-wise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Atanh()(x)[0]
 
 
-class Sigmoid(Operation):
+class Sigmoid(Operator):
     """
     `y = 1 / (1 + exp(-x))`, is applied to the tensor elementwise.
     """
@@ -2908,7 +2458,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         out = singa.Sigmoid(x)
@@ -2920,7 +2470,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.MultFloat(self.cache[0], -1.0)
@@ -2935,16 +2485,16 @@
     `y = 1 / (1 + exp(-x))`, is applied to the tensor elementwise.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Sigmoid()(x)[0]
 
 
-class Mul(Operation):
+class Mul(Operator):
     """
-    Performs element-wise binary multiplication (with Numpy-style broadcasting 
-    support).        
+    Performs element-wise binary multiplication (with Numpy-style broadcasting
+    support).
     """
 
     def __init__(self):
@@ -2976,7 +2526,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a tuple for (da, db), da is data for dL / da, db is data
                 for dL / db.
         """
@@ -2998,9 +2548,9 @@
     return Mul()(x, y)[0]
 
 
-class Unsqueeze(Operation):
+class Unsqueeze(Operator):
     """
-    Insert single-dimensional entries to the shape of an input tensor (data). 
+    Insert single-dimensional entries to the shape of an input tensor (data).
     """
 
     def __init__(self, axis):
@@ -3018,7 +2568,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         self.cache = x.shape()
@@ -3034,7 +2584,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         return singa.Reshape(dy, self.cache)
@@ -3042,25 +2592,25 @@
 
 def unsqueeze(x, axis=-1):
     """
-    Insert single-dimensional entries to the shape of an input tensor (data). 
+    Insert single-dimensional entries to the shape of an input tensor (data).
     Args:
         x (Tensor): Input tensor
         axis (list of int): the dimensions to be inserted.
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Unsqueeze(axis)(x)[0]
 
 
-class Transpose(Operation):
+class Transpose(Operator):
     """
-    Transpose the input tensor similar to numpy.transpose. 
+    Transpose the input tensor similar to numpy.transpose.
     """
 
     def __init__(self, perm):
         """
         Args:
-            perm (list of ints): A list of integers. By default, reverse the 
+            perm (list of ints): A list of integers. By default, reverse the
                 dimensions, otherwise permute the axes according to the values given.
         """
         super(Transpose, self).__init__()
@@ -3070,7 +2620,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         return singa.Transpose(x, self.perm)
@@ -3079,7 +2629,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         cur = []
@@ -3090,12 +2640,12 @@
 
 def transpose(x, shape):
     """
-    Transpose the input tensor similar to numpy.transpose. 
+    Transpose the input tensor similar to numpy.transpose.
     Args:
         x (Tensor): Input tensor
-        perm (list of ints): A list of integers. By default, reverse the 
+        shape (list of ints): A list of integers (the permutation). By default, reverse the
             dimensions, otherwise permute the axes according to the values given.
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Transpose(shape)(x)[0]
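# Illustrative sketch: Transpose.backward sends the gradient back with the
# inverse permutation, which is what the loop building `cur` above computes.
perm = (2, 0, 1)
inverse = [perm.index(i) for i in range(len(perm))]  # -> [1, 2, 0]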
@@ -3109,230 +2659,7 @@
     return
 
 
-class RNN_Base(Layer):
-
-    def __init__(self):
-        raise NotImplementedError
-
-    def __call__(self):
-        raise NotImplementedError
-
-    def step_forward(self,
-                     x=None,
-                     h=None,
-                     c=None,
-                     Wx=None,
-                     Wh=None,
-                     Bx=None,
-                     Bh=None,
-                     b=None):
-        raise NotImplementedError
-
-
-class RNN(RNN_Base):
-    """
-    Generate an RNN operator
-    """
-
-    def __init__(
-            self,
-            input_size,
-            hidden_size,
-            num_layers=1,
-            nonlinearity="tanh",
-            bias=True,
-            batch_first=False,
-            dropout=0,
-            bidirectional=False,
-    ):
-        """
-        Args:
-            input_size (int):  The number of expected features in the input x
-            hidden_size (int): The number of features in the hidden state h
-            num_layers (int):  Number of recurrent layers. Default: 1
-            nonlinearity (string): The non-linearity to use. Default: 'tanh'
-            bias (bool):  If False, then the layer does not use bias weights. 
-                Default: True
-            batch_first (bool):  If True, then the input and output tensors 
-                are provided as (batch, seq, feature). Default: False
-            dropout (float): If non-zero, introduces a Dropout layer on the 
-                outputs of each RNN layer except the last layer, with dropout 
-                probability equal to dropout. Default: 0
-            bidirectional (bool): If True, becomes a bidirectional RNN. 
-                Default: False
-        """
-        self.nonlinearity = nonlinearity
-
-        Wx_shape = (input_size, hidden_size)
-        self.Wx = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
-        self.Wx.gaussian(0.0, 1.0)
-
-        Wh_shape = (hidden_size, hidden_size)
-        self.Wh = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
-        self.Wh.gaussian(0.0, 1.0)
-
-        B_shape = (hidden_size,)
-        self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
-        self.b.set_value(0.0)
-
-        self.params = (self.Wx, self.Wh, self.b)
-
-    def __call__(self, xs, h0):
-        # xs: a tuple or list of input tensors
-        if not isinstance(xs, tuple):
-            xs = tuple(xs)
-        inputs = xs + (h0,)
-        self.device_check(*inputs)
-        # self.device_check(inputs[0], *self.params)
-        self.device_check(inputs[0], self.Wx, self.Wh, self.b)
-        batchsize = xs[0].shape[0]
-        out = []
-        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
-        out.append(h)
-        for x in xs[1:]:
-            assert x.shape[0] == batchsize
-            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
-            out.append(h)
-        return out, h
-
-    def step_forward(self, x, h, Wx, Wh, b):
-        y2 = matmul(h, Wh)
-        y1 = matmul(x, Wx)
-        y = add(y2, y1)
-        y = add_bias(y, b, axis=0)
-        if self.nonlinearity == "tanh":
-            y = tanh(y)
-        elif self.nonlinearity == "relu":
-            y = relu(y)
-        else:
-            raise ValueError
-        return y
-
-
-class LSTM(RNN_Base):
-    """
-    Generate an LSTM operator
-    """
-
-    def __init__(
-            self,
-            input_size,
-            hidden_size,
-            nonlinearity="tanh",
-            num_layers=1,
-            bias=True,
-            batch_first=False,
-            dropout=0,
-            bidirectional=False,
-    ):
-        """
-        Args:
-            input_size (int):  The number of expected features in the input x
-            hidden_size (int): The number of features in the hidden state h
-            num_layers (int):  Number of recurrent layers. Default: 1
-            nonlinearity (string): The non-linearity to use. Default: 'tanh'
-            bias (bool):  If False, then the layer does not use bias weights. 
-                Default: True
-            batch_first (bool):  If True, then the input and output tensors 
-                are provided as (batch, seq, feature). Default: False
-            dropout (float): If non-zero, introduces a Dropout layer on the 
-                outputs of each RNN layer except the last layer, with dropout 
-                probability equal to dropout. Default: 0
-            bidirectional (bool): If True, becomes a bidirectional RNN. 
-                Default: False
-        """
-        self.nonlinearity = nonlinearity
-
-        Wx_shape = (input_size, hidden_size)
-        self.Wx = []
-        for i in range(4):
-            w = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 0.01)
-            self.Wx.append(w)
-
-        Wh_shape = (hidden_size, hidden_size)
-        self.Wh = []
-        for i in range(4):
-            w = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 0.01)
-            self.Wh.append(w)
-
-        Bx_shape = (hidden_size,)
-        self.Bx = []
-        for i in range(4):
-            b = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
-            b.set_value(0.0)
-            self.Bx.append(b)
-
-        self.Bh = []
-        for i in range(4):
-            b = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
-            b.set_value(0.0)
-            self.Bh.append(b)
-
-        self.params = self.Wx + self.Wh + self.Bx + self.Bh
-
-    def __call__(self, xs, h0_c0):
-        # xs: a tuple or list of input tensors
-        # h0_c0: a tuple of (h0, c0)
-        h0, c0 = h0_c0
-        if not isinstance(xs, list):
-            xs = list(xs)
-        inputs = xs + list((h0, c0))
-        self.device_check(*inputs)
-        # self.device_check(inputs[0], *self.params)
-        self.device_check(inputs[0], *(self.Wx + self.Wh + self.Bx + self.Bh))
-        batchsize = xs[0].shape[0]
-        out = []
-        h, c = self.step_forward(xs[0], h0, c0, self.Wx, self.Wh, self.Bx,
-                                 self.Bh)
-        out.append(h)
-        for x in xs[1:]:
-            assert x.shape[0] == batchsize
-            h, c = self.step_forward(x, h, c, self.Wx, self.Wh, self.Bx,
-                                     self.Bh)
-            out.append(h)
-        return out, h, c
-
-    def step_forward(self, x, h, c, Wx, Wh, Bx, Bh):
-        y1 = matmul(x, Wx[0])
-        y1 = add_bias(y1, Bx[0], axis=0)
-        y2 = matmul(h, Wh[0])
-        y2 = add_bias(y2, Bh[0], axis=0)
-        i = add(y1, y2)
-        i = sigmoid(i)
-
-        y1 = matmul(x, Wx[1])
-        y1 = add_bias(y1, Bx[1], axis=0)
-        y2 = matmul(h, Wh[1])
-        y2 = add_bias(y2, Bh[1], axis=0)
-        f = add(y1, y2)
-        f = sigmoid(f)
-
-        y1 = matmul(x, Wx[2])
-        y1 = add_bias(y1, Bx[2], axis=0)
-        y2 = matmul(h, Wh[2])
-        y2 = add_bias(y2, Bh[2], axis=0)
-        o = add(y1, y2)
-        o = sigmoid(o)
-
-        y1 = matmul(x, Wx[3])
-        y1 = add_bias(y1, Bx[3], axis=0)
-        y2 = matmul(h, Wh[3])
-        y2 = add_bias(y2, Bh[3], axis=0)
-        g = add(y1, y2)
-        g = tanh(g)
-
-        cout1 = mul(f, c)
-        cout2 = mul(i, g)
-        cout = add(cout1, cout2)
-
-        hout = tanh(cout)
-        hout = mul(o, hout)
-        return hout, cout
-
-
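# Compact, numpy-based restatement (illustrative only) of the gate arithmetic in
# LSTM.step_forward above; `sigma` is the logistic sigmoid.
import numpy as np

def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, Wx, Wh, Bx, Bh):
    i = sigma(x @ Wx[0] + Bx[0] + h @ Wh[0] + Bh[0])    # input gate
    f = sigma(x @ Wx[1] + Bx[1] + h @ Wh[1] + Bh[1])    # forget gate
    o = sigma(x @ Wx[2] + Bx[2] + h @ Wh[2] + Bh[2])    # output gate
    g = np.tanh(x @ Wx[3] + Bx[3] + h @ Wh[3] + Bh[3])  # candidate cell
    c_new = f * c + i * g
    h_new = o * np.tanh(c_new)
    return h_new, c_new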
-class Abs(Operation):
+class Abs(Operator):
     """
     `y = abs(x)`, is applied to the tensor elementwise.
     """
@@ -3349,7 +2676,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Sign(self.input)
@@ -3364,7 +2691,7 @@
     return Abs()(a)[0]
 
 
-class Exp(Operation):
+class Exp(Operator):
     """
     `y = exp(x)`, is applied to the tensor elementwise.
     """
@@ -3381,7 +2708,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Exp(self.input)
@@ -3396,7 +2723,7 @@
     return Exp()(a)[0]
 
 
-class LeakyRelu(Operation):
+class LeakyRelu(Operator):
     """
     `f(x) = alpha * x` for x < 0, `f(x) = x` for x >= 0, is applied to the tensor elementwise.
     """
@@ -3413,7 +2740,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -3429,7 +2756,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         # TODO(wangwei) check the correctness
@@ -3443,20 +2770,20 @@
 
 def leakyrelu(x, a=0.01):
     """
-    `f(x) = alpha * x` for x < 0, `f(x) = x` for x >= 0 is applied to the tensor 
+    `f(x) = alpha * x` for x < 0, `f(x) = x` for x >= 0 is applied to the tensor
     elementwise.
     Args:
         x (Tensor): Input tensor
         a (float): Coefficient of leakage, default to 0.01.
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return LeakyRelu(a)(x)[0]
 
 
-class Sign(Operation):
+class Sign(Operator):
     """
-    Calculate the sign of the given input tensor element-wise. If input > 0, 
+    Calculate the sign of the given input tensor element-wise. If input > 0,
     output 1; if input < 0, output -1; if input == 0, output 0.
     """
 
@@ -3467,7 +2794,7 @@
         """
         Args:
             a (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -3478,7 +2805,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.MultFloat(dy, 0.0)
@@ -3487,17 +2814,17 @@
 
 def sign(a):
     """
-    Calculate the sign of the given input tensor element-wise. If input > 0, 
+    Calculate the sign of the given input tensor element-wise. If input > 0,
     output 1; if input < 0, output -1; if input == 0, output 0.
     Args:
         a (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Sign()(a)[0]
 
 
-class Pow(Operation):
+class Pow(Operator):
     """
     `f(x) = a^b`, is applied to the tensor elementwise.
     """
@@ -3521,7 +2848,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a tuple for (da, db), da is data for dL / da, db is data
                 for dL / db.
         """
@@ -3548,7 +2875,7 @@
     return Pow()(a, b)[0]
 
 
-class SoftSign(Operation):
+class SoftSign(Operator):
     """
     Calculates the softsign `(x/(1+|x|))` of the given input tensor element-wise.
     """
@@ -3572,7 +2899,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.AddFloat(singa.Abs(self.input), 1.0)
@@ -3588,7 +2915,7 @@
     return SoftSign()(x)[0]
 
 
-class Sqrt(Operation):
+class Sqrt(Operator):
     """
     `y = x^0.5`, is applied to the tensor elementwise.
     """
@@ -3608,7 +2935,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.PowFloat(self.input, -0.5)
@@ -3624,7 +2951,7 @@
     return Sqrt()(x)[0]
 
 
-class SoftPlus(Operation):
+class SoftPlus(Operator):
     """
     `y = ln(exp(x) + 1)` is applied to the tensor elementwise.
     """
@@ -3647,7 +2974,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.Exp(singa.MultFloat(self.input, -1.0))
@@ -3663,9 +2990,9 @@
     return SoftPlus()(x)[0]
 
 
-class Sub(Operation):
+class Sub(Operator):
     """
-    Performs element-wise binary subtraction (with Numpy-style broadcasting 
+    Performs element-wise binary subtraction (with Numpy-style broadcasting
     support).
     """
 
@@ -3676,7 +3003,14 @@
         """
         Return `a-b`, where x is CTensor.
         """
+        ori_type = None
+        if a.data_type() != singa.kFloat32:
+            ori_type = a.data_type()
+            a = a.AsType(singa.kFloat32)
+            b = b.AsType(singa.kFloat32)
         res = singa.__sub__(a, b)
+        if ori_type is not None:
+            res = res.AsType(ori_type)
         if training:
             self.shape0 = list(a.shape())
             self.shape1 = list(b.shape())
@@ -3687,7 +3021,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a tuple for (da, db), da is data for dL / da, db is data
                 for dL / db.
         """
@@ -3710,9 +3044,9 @@
 
 
 # optimize min to support multi inputs
-class Min(Operation):
+class Min(Operator):
     """
-    Element-wise min of each of the input tensors (with Numpy-style 
+    Element-wise min of each of the input tensors (with Numpy-style
     broadcasting support).
     """
 
@@ -3725,7 +3059,7 @@
         Args:
             a (CTensor): First operand
             b (CTensor): Second operand
-        Returns: 
+        Returns:
             CTensor, the output
             tuple of CTensor, mask tensor
         """
@@ -3739,7 +3073,7 @@
         """
         Args:
             *x (a list of CTensor): List of tensors for min.
-        Returns: 
+        Returns:
             CTensor, the output
         """
         assert (len(x) > 0)
@@ -3759,7 +3093,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a tuple for (*dx), dx is data for dL / dx.
         """
         if self.l == 1:
@@ -3780,17 +3114,17 @@
 
 def min(*l):
     """
-    Element-wise min of each of the input tensors (with Numpy-style 
+    Element-wise min of each of the input tensors (with Numpy-style
     broadcasting support).
     Args:
         *x (a list of Tensor): List of tensors for min.
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Min()(*l)[0]
 
 
-class Log(Operation):
+class Log(Operator):
     """
     `y = log(x)`, is applied to the tensor elementwise.
     """
@@ -3810,7 +3144,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         dx = singa.PowFloat(self.input, -1)
@@ -3825,7 +3159,7 @@
     return Log()(x)[0]
 
 
-class HardSigmoid(Operation):
+class HardSigmoid(Operator):
     """
     `y = max(0, min(1, alpha * x + beta))`, is applied to the tensor elementwise.
     """
@@ -3862,7 +3196,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         mask0 = singa.GTFloat(self.cache, 0.0)
@@ -3877,26 +3211,26 @@
     Args:
         x (Tensor): matrix
         alpha (float): Value of alpha.
-        gamma (float): Value of beta.        
+        gamma (float): Value of beta.
     Returns:
         a Tensor for the result
     """
     return HardSigmoid(alpha, gamma)(x)[0]
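# Illustrative numpy form of the documented HardSigmoid formula
# y = max(0, min(1, alpha * x + beta)); the default values below are only an
# assumption for this sketch, not taken from this file.
import numpy as np

def hard_sigmoid_ref(x, alpha=0.2, beta=0.5):
    return np.clip(alpha * x + beta, 0.0, 1.0)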
 
 
-class Squeeze(Operation):
+class Squeeze(Operator):
     """
-    Remove single-dimensional entries from the shape of a tensor. Takes a 
-    parameter axes with a list of axes to squeeze. If axes is not provided, 
-    all the single dimensions will be removed from the shape. If an axis is 
+    Remove single-dimensional entries from the shape of a tensor. Takes a
+    parameter axes with a list of axes to squeeze. If axes is not provided,
+    all the single dimensions will be removed from the shape. If an axis is
     selected with shape entry not equal to one, an error is raised.
     """
 
     def __init__(self, axis=[]):
         """
         Args:
-            axis (list of ints): List of integers indicating the dimensions 
-                to squeeze. Negative value means counting dimensions from 
+            axis (list of ints): List of integers indicating the dimensions
+                to squeeze. Negative value means counting dimensions from
                 the back. Accepted range is [-r, r-1] where r = rank(data).
         """
         super(Squeeze, self).__init__()
@@ -3906,7 +3240,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         self.cache = x.shape()
@@ -3932,7 +3266,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         return singa.Reshape(dy, self.cache)
@@ -3940,22 +3274,22 @@
 
 def squeeze(x, axis=[]):
     """
-    Remove single-dimensional entries from the shape of a tensor. Takes a 
-    parameter axes with a list of axes to squeeze. If axes is not provided, 
-    all the single dimensions will be removed from the shape. If an axis is 
+    Remove single-dimensional entries from the shape of a tensor. Takes a
+    parameter axes with a list of axes to squeeze. If axes is not provided,
+    all the single dimensions will be removed from the shape. If an axis is
     selected with shape entry not equal to one, an error is raised.
     Args:
         x (Tensor): Input tensor
-        axis (list of ints): List of integers indicating the dimensions 
-            to squeeze. Negative value means counting dimensions from 
+        axis (list of ints): List of integers indicating the dimensions
+            to squeeze. Negative value means counting dimensions from
             the back. Accepted range is [-r, r-1] where r = rank(data).
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Squeeze(axis)(x)[0]
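# Illustrative numpy analogue of Squeeze as documented above: size-1 axes are
# dropped, and the backward pass simply reshapes back to the cached input shape.
import numpy as np

x = np.zeros((1, 3, 1, 2))
y = np.squeeze(x, axis=(0, 2))      # shape (3, 2)
dy = np.ones_like(y)
dx = dy.reshape(x.shape)            # gradient restored to (1, 3, 1, 2)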
 
 
-class Div(Operation):
+class Div(Operator):
     """
     Performs element-wise binary division (with Numpy-style broadcasting support).
     """
@@ -3967,8 +3301,15 @@
         """
         Return `np.div(a,b)`, where a and b are CTensor.
         """
+        ori_type = None
+        if a.data_type() != singa.kFloat32:
+            ori_type = a.data_type()
+            a = a.AsType(singa.kFloat32)
+            b = b.AsType(singa.kFloat32)
         res = singa.__mul__(a, singa.PowFloat(b, -1.0))
         # res = singa.__div__(a, b)
+        if ori_type is not None:
+            res = res.AsType(ori_type)
         if training:
             self.input = (singa.MultFloat(a, -1.0), singa.PowFloat(b, -1.0)
                          )  # -a, 1/b
@@ -3981,7 +3322,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a CTensor tuple for (da, db), da is data for dL / da, db is data
                 for dL / db.
         """
@@ -4006,9 +3347,9 @@
     return Div()(a, b)[0]
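# Illustrative numpy restatement of Div: the forward computes a * b**-1 (after a
# float32 round-trip for non-float inputs), and the cached (-a, 1/b) gives
# dL/da = dy * (1/b) and dL/db = dy * (-a) * (1/b)**2.
import numpy as np

a, b = np.array([2.0, 6.0]), np.array([4.0, 3.0])
dy = np.ones_like(a)
da = dy * (1.0 / b)
db = dy * (-a) * (1.0 / b) ** 2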
 
 
-class Shape(Operation):
+class Shape(Operator):
     """
-    Takes a tensor as input and outputs a tensor containing the shape of the 
+    Takes a tensor as input and outputs a tensor containing the shape of the
     input tensor.
     """
 
@@ -4019,7 +3360,7 @@
         """
         Args:
             x (CTensor): Input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         cur = list(x.shape())
@@ -4031,7 +3372,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             list of int, the shape of dy
         """
         return list(dy.shape())
@@ -4039,21 +3380,21 @@
 
 def shape(x):
     """
-    Takes a tensor as input and outputs a tensor containing the shape of the 
+    Takes a tensor as input and outputs a tensor containing the shape of the
     input tensor.
     Args:
         x (Tensor): Input tensor
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Shape()(x)[0]
 
 
 # optimize max to support multi inputs
-class Max(Operation):
+class Max(Operator):
     """
-    Element-wise max of each of the input tensors (with Numpy-style 
-    broadcasting support). 
+    Element-wise max of each of the input tensors (with Numpy-style
+    broadcasting support).
     """
 
     def __init__(self):
@@ -4065,7 +3406,7 @@
         Args:
             a (CTensor): First operand
             b (CTensor): Second operand
-        Returns: 
+        Returns:
             CTensor, the output
             tuple of CTensor, mask tensor
         """
@@ -4079,7 +3420,7 @@
         """
         Args:
             *x (a list of CTensor): List of tensors for max.
-        Returns: 
+        Returns:
             CTensor, the output
         """
         assert (len(x) > 0)
@@ -4099,7 +3440,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             a tuple for (*dx), dx is data for dL / dx.
         """
         if self.l == 1:
@@ -4120,16 +3461,16 @@
 
 def max(*l):
     """
-    Element-wise max of each of the input tensors (with Numpy-style broadcasting support). 
+    Element-wise max of each of the input tensors (with Numpy-style broadcasting support).
     Args:
         *x (a list of Tensor): List of tensors for max.
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return Max()(*l)[0]
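# Illustrative sketch of the masking idea behind Max/Min above: each input only
# receives gradient where it supplied the element-wise extremum. The tie-breaking
# convention here is an assumption made for the sketch.
import numpy as np

a, b = np.array([1.0, 5.0, 2.0]), np.array([3.0, 4.0, 2.0])
dy = np.ones(3)
da = dy * (a >= b)
db = dy * (b > a)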
 
 
-class And(Operation):
+class And(Operator):
     """
     Returns the tensor resulting from performing the `and` logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
     """
@@ -4163,7 +3504,7 @@
     return And()(a, b)[0]
 
 
-class Or(Operation):
+class Or(Operator):
     """
     Returns the tensor resulting from performing the `or` logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
     """
@@ -4198,7 +3539,7 @@
     return Or()(a, b)[0]
 
 
-class Not(Operation):
+class Not(Operator):
     """
     Returns the negation of the input tensor element-wise.
     """
@@ -4233,7 +3574,7 @@
     return Not()(x)[0]
 
 
-class Xor(Operation):
+class Xor(Operator):
     """
     Performs the `xor` logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
     """
@@ -4268,7 +3609,7 @@
     return Xor()(a, b)[0]
 
 
-class Negative(Operation):
+class Negative(Operator):
     """
     `y = -x`, is applied to the tensor elementwise.
     """
@@ -4287,7 +3628,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         return singa.MultFloat(dy, -1)
@@ -4300,7 +3641,7 @@
     return Negative()(x)[0]
 
 
-class Reciprocal(Operation):
+class Reciprocal(Operator):
     """
     `y = 1/x`, is applied to the tensor elementwise.
     """
@@ -4322,7 +3663,7 @@
         """
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         #dy/dx = -1/x**2
@@ -4337,11 +3678,11 @@
     return Reciprocal()(x)[0]
 
 
-class Gemm(Operation):
+class Gemm(Operator):
     """
-    Init a General Matrix multiplication(Gemm) operator. Compute `Y = alpha * 
-    A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input 
-    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to 
+    Init a General Matrix multiplication(Gemm) operator. Compute `Y = alpha *
+    A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input
+    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to
     shape (M, N), and output tensor Y has shape (M, N).
     `A' = transpose(A)` if transA else A
     `B' = transpose(B)` if transB else B
@@ -4350,12 +3691,12 @@
     def __init__(self, alpha=1.0, beta=1.0, transA=0, transB=0):
         """
         Args:
-            alpha (float): Scalar multiplier for the product of input tensors 
+            alpha (float): Scalar multiplier for the product of input tensors
                 A * B.
             beta (float): Scalar multiplier for input tensor C.
             transA (int): Whether A should be transposed
             transB (int): Whether B should be transposed
-        Returns: 
+        Returns:
             CTensor, the output
         """
         super(Gemm, self).__init__()
@@ -4368,14 +3709,14 @@
         """
         forward propagation of Gemm
         Args:
-            A (CTensor): The shape of A should be (M, K) if transA is 0, or 
+            A (CTensor): The shape of A should be (M, K) if transA is 0, or
                 (K, M) if transA is non-zero.
-            B (CTensor): The shape of B should be (K, N) if transB is 0, or 
+            B (CTensor): The shape of B should be (K, N) if transB is 0, or
                 (N, K) if transB is non-zero.
-            C (CTensor): (optional), Optional input tensor C. If not specified, 
-                the computation is done as if C is a scalar 0. The shape of C 
+            C (CTensor): (optional), Optional input tensor C. If not specified,
+                the computation is done as if C is a scalar 0. The shape of C
                 should be unidirectional broadcastable to (M, N).
-        Returns: 
+        Returns:
             tensor, the output
         """
         _A = singa.DefaultTranspose(A) if self.transA == 1 else A
@@ -4392,7 +3733,7 @@
         backward propagation of Gemm
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over A
             CTensor, the gradient over B
             CTensor(optional), the gradient over C
@@ -4425,31 +3766,34 @@
 
 def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
     """
-    Init a General Matrix multiplication(Gemm) operator. Compute `Y = alpha * 
-    A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input 
-    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to 
+    Init a General Matrix multiplication(Gemm) operator. Compute `Y = alpha *
+    A' * B' + beta * C`, where input tensor A has shape (M, K) or (K, M), input
+    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to
     shape (M, N), and output tensor Y has shape (M, N).
     `A' = transpose(A)` if transA else A
     `B' = transpose(B)` if transB else B
     Args:
-        A (Tensor): The shape of A should be (M, K) if transA is 0, or 
+        A (Tensor): The shape of A should be (M, K) if transA is 0, or
             (K, M) if transA is non-zero.
-        B (Tensor): The shape of B should be (K, N) if transB is 0, or 
+        B (Tensor): The shape of B should be (K, N) if transB is 0, or
             (N, K) if transB is non-zero.
-        C (Tensor): (optional), Optional input tensor C. If not specified, 
-            the computation is done as if C is a scalar 0. The shape of C 
+        C (Tensor): (optional), Optional input tensor C. If not specified,
+            the computation is done as if C is a scalar 0. The shape of C
             should be unidirectional broadcastable to (M, N).
         alpha (float): Scalar multiplier for the product of input tensors A * B.
         beta (float): Scalar multiplier for input tensor C.
         transA (int): Whether A should be transposed
         transB (int): Whether B should be transposed
-    Returns: 
+    Returns:
         Tensor, the output
     """
-    return Gemm(alpha, beta, transA, transB)(A, B, C)[0]
+    if C is not None:
+        return Gemm(alpha, beta, transA, transB)(A, B, C)[0]
+    else:
+        return Gemm(alpha, beta, transA, transB)(A, B)[0]
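# Illustrative numpy form of the Gemm contract documented above:
# Y = alpha * A' @ B' + beta * C, where A' and B' are optionally transposed.
import numpy as np

def gemm_ref(A, B, C=0.0, alpha=1.0, beta=1.0, transA=0, transB=0):
    A = A.T if transA else A
    B = B.T if transB else B
    return alpha * (A @ B) + beta * C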
 
 
-class GlobalAveragePool(Operation):
+class GlobalAveragePool(Operator):
     """
     Init a GlobalAveragePool operator
     """
@@ -4457,7 +3801,7 @@
     def __init__(self, data_format='channels_first'):
         """
         Args:
-            data_format (string): A string, we support two formats: 
+            data_format (string): A string, we support two formats:
                 channels_last and channels_first, default is channels_first.
                 channels_first means the format of input is (N x C x H x W)
                 channels_last means the format of input is (N x H x W x C)
@@ -4470,7 +3814,7 @@
         forward propagation of GlobalAveragePool
         Args:
             x (CTensor): the input tensor
-        Returns: 
+        Returns:
             CTensor, the output
         """
         if training:
@@ -4502,7 +3846,7 @@
         backward propagation of GlobalAveragePool
         Args:
             dy (CTensor): the gradient tensor from upper operations
-        Returns: 
+        Returns:
             CTensor, the gradient over input
         """
         self.mask.SetFloatValue(self.shape_divisor)
@@ -4514,17 +3858,17 @@
     GlobalAveragePool operator
     Args:
         x (Tensor): the input tensor
-        data_format (string): A string, we support two formats: 
+        data_format (string): A string, we support two formats:
             channels_last and channels_first, default is channels_first.
             channels_first means the format of input is (N x C x H x W)
             channels_last means the format of input is (N x H x W x C)
-    Returns: 
+    Returns:
         Tensor, the output
     """
     return GlobalAveragePool(data_format)(x)[0]
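# Illustrative numpy equivalent for the channels_first case: average over the
# spatial axes while keeping them, so an (N, C, H, W) input becomes (N, C, 1, 1).
import numpy as np

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
y = x.mean(axis=(2, 3), keepdims=True)   # shape (2, 3, 1, 1)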
 
 
-class ConstantOfShape(Operation):
+class ConstantOfShape(Operator):
     """
     Init a ConstantOfShape, which generates a tensor with the given value and shape.
     """
@@ -4532,8 +3876,8 @@
     def __init__(self, value=0.):
         """
         Args:
-            value (float): (Optional) The value of the output elements. Should 
-                be a one-element value. If not specified, it defaults to 0 and 
+            value (float): (Optional) The value of the output elements. Should
+                be a one-element value. If not specified, it defaults to 0 and
                 datatype float32
         """
         super(ConstantOfShape, self).__init__()
@@ -4543,12 +3887,12 @@
         """
         forward of ConstantOfShape
         Args:
-            x: CTensor, 1D tensor. The shape of the expected output tensor. 
+            x: CTensor, 1D tensor. The shape of the expected output tensor.
                 All values must be >= 0.
         Returns:
-            the output CTensor. If attribute 'value' is specified, the value 
-                and datatype of the output tensor is taken from 'value'. If 
-                attribute 'value' is not specified, the value in the output 
+            the output CTensor. If attribute 'value' is specified, the value
+                and datatype of the output tensor is taken from 'value'. If
+                attribute 'value' is not specified, the value in the output
                 defaults to 0, and the datatype defaults to float32.
         """
         x_shape = tensor.to_numpy(tensor.from_raw_tensor(x)).astype(
@@ -4573,33 +3917,36 @@
     """
     Init a ConstantOfShape, which generates a tensor with the given value and shape.
     Args:
-        x: Tensor, 1D tensor. The shape of the expected output tensor. 
+        x: Tensor, 1D tensor. The shape of the expected output tensor.
             All values must be >= 0.
-        value (float): (Optional) The value of the output elements. Should 
-            be a one-element value. If not specified, it defaults to 0 and 
+        value (float): (Optional) The value of the output elements. Should
+            be a one-element value. If not specified, it defaults to 0 and
             datatype float32
     Returns:
-        the output Tensor. If attribute 'value' is specified, the value 
-            and datatype of the output tensor is taken from 'value'. If 
-            attribute 'value' is not specified, the value in the output 
+        the output Tensor. If attribute 'value' is specified, the value
+            and datatype of the output tensor is taken from 'value'. If
+            attribute 'value' is not specified, the value in the output
             defaults to 0, and the datatype defaults to float32.
     """
     return ConstantOfShape(value)(x)[0]
 
 
-class Dropout(Operation):
+class Dropout(Operator):
     """
     Init a Dropout, which scales the masked input data by the following equation:
     `output = scale * data * mask`, `scale = 1. / (1. - ratio)`.
     """
 
-    def __init__(self, ratio=0.5):
+    def __init__(self, seed=0, ratio=0.5):
         """
         Args:
+            seed (int): the random seed
             ratio (float): the ratio of random dropout, with value in [0, 1).
         """
         super(Dropout, self).__init__()
         self.ratio = ratio
+        self.seed = int(seed)
+        self.init_seed = False
 
     def forward(self, x):
         """
@@ -4609,6 +3956,9 @@
         Returns:
             the output CTensor.
         """
+        if not self.init_seed:
+            x.device().SetRandSeed(self.seed)
+            self.init_seed = True
         if training:
             self.scale = 1 / (1 - self.ratio)
             self.mask = singa.Tensor(list(x.shape()), x.device())
@@ -4629,9 +3979,9 @@
         return dy
 
 
-def dropout(x, ratio=0.5):
+def dropout(x, seed=0, ratio=0.5):
     """
-    Init a Dropout, which scales the masked input data by the following 
+    Init a Dropout, which scales the masked input data by the following
     equation: `output = scale * data * mask`, `scale = 1. / (1. - ratio)`.
     Args:
         x (Tensor): input tensor.
@@ -4639,22 +3989,22 @@
     Returns:
         the output Tensor.
     """
-    return Dropout(ratio)(x)[0]
+    return Dropout(seed, ratio)(x)[0]
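# Minimal numpy sketch of the scaling documented for Dropout:
# output = scale * data * mask with scale = 1 / (1 - ratio), so the expected
# value of the output matches the input during training. The mask generation
# below is only an illustration, not the singa kernel.
import numpy as np

rng = np.random.default_rng(0)
ratio = 0.5
data = np.ones((4, 4), dtype=np.float32)
mask = (rng.random(data.shape) > ratio).astype(np.float32)
out = (1.0 / (1.0 - ratio)) * data * mask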
 
 
-class ReduceSum(Operation):
+class ReduceSum(Operator):
     """
-    Init a ReduceSum, computes the sum of the input tensor's element along 
+    Init a ReduceSum, computes the sum of the input tensor's elements along
     the provided axes.
     """
 
     def __init__(self, axes=None, keepdims=1):
         """
         Args:
-            axes (list of int): A list of integers, along which to reduce. 
-                Accepted range is [-r, r-1] where r = rank(data). The default 
+            axes (list of int): A list of integers, along which to reduce.
+                Accepted range is [-r, r-1] where r = rank(data). The default
                 is None, which reduces over all the dimensions of the input tensor.
-            keepdims (int): Keep the reduced dimension or not, default 1 mean 
+            keepdims (int): Keep the reduced dimension or not; default 1 means
                keep the reduced dimension.
         """
         super(ReduceSum, self).__init__()
@@ -4705,14 +4055,14 @@
 
 def reduce_sum(x, axes=None, keepdims=1):
     """
-    Init a ReduceSum, computes the sum of the input tensor's element along 
+    Init a ReduceSum, computes the sum of the input tensor's elements along
     the provided axes.
     Args:
         x (Tensor): input tensor.
-        axes (list of int): A list of integers, along which to reduce. 
-            Accepted range is [-r, r-1] where r = rank(data). The default 
+        axes (list of int): A list of integers, along which to reduce.
+            Accepted range is [-r, r-1] where r = rank(data). The default
             is None, which reduces over all the dimensions of the input tensor.
-        keepdims (int): Keep the reduced dimension or not, default 1 mean 
+        keepdims (int): Keep the reduced dimension or not; default 1 means
            keep the reduced dimension.
     Returns:
         the output Tensor.
@@ -4720,19 +4070,19 @@
     return ReduceSum(axes, keepdims)(x)[0]
 
 
-class ReduceMean(Operation):
+class ReduceMean(Operator):
     """
-    Init a ReduceMean, computes the mean of the input tensor's element along 
+    Init a ReduceMean, computes the mean of the input tensor's element along
     the provided axes.
     """
 
     def __init__(self, axes=None, keepdims=1):
         """
         Args:
-            axes (list of int): A list of integers, along which to reduce. 
-                Accepted range is [-r, r-1] where r = rank(data). The default 
+            axes (list of int): A list of integers, along which to reduce.
+                Accepted range is [-r, r-1] where r = rank(data). The default
                 is None, which reduces over all the dimensions of the input tensor.
-            keepdims (int): Keep the reduced dimension or not, default 1 mean 
+            keepdims (int): Keep the reduced dimension or not; the default 1
                 means to keep the reduced dimension.
         """
         super(ReduceMean, self).__init__()
@@ -4763,6 +4113,7 @@
             _x = tensor.reshape(_x, x_shape)
         self.cache = (x_shape, x)
         scale = np.prod(x_shape) / np.prod(x.shape())
+        self.scale = scale
         _x = singa.MultFloat(_x.data, scale)
         return _x
 
@@ -4779,19 +4130,20 @@
         mask = singa.Tensor(list(x.shape()), x.device())
         mask.SetFloatValue(1.0)
         dy = singa.__mul__(mask, dy)
+        dy = singa.MultFloat(dy, self.scale)
         return dy
 
 
 def reduce_mean(x, axes=None, keepdims=1):
     """
-    Init a ReduceMean, computes the mean of the input tensor's element along 
+    Init a ReduceMean, computes the mean of the input tensor's element along
     the provided axes.
     Args:
         x (Tensor): input tensor.
-        axes (list of int): A list of integers, along which to reduce. 
-            Accepted range is [-r, r-1] where r = rank(data). The default 
+        axes (list of int): A list of integers, along which to reduce.
+            Accepted range is [-r, r-1] where r = rank(data). The default
             is None, which reduces over all the dimensions of the input tensor.
-        keepdims (int): Keep the reduced dimension or not, default 1 mean 
+        keepdims (int): Keep the reduced dimension or not; the default 1
             means to keep the reduced dimension.
     Returns:
         the output Tensor.
@@ -4799,9 +4151,9 @@
     return ReduceMean(axes, keepdims)(x)[0]
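
# A small sketch of reduce_sum / reduce_mean (assuming the numpy interop
# helpers in singa.tensor; the outputs follow the ONNX reduce semantics):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32))
s = autograd.reduce_sum(x, axes=[1], keepdims=0)   # expected [ 6. 15.]
m = autograd.reduce_mean(x, axes=[0], keepdims=1)  # expected [[2.5 3.5 4.5]]
print(tensor.to_numpy(s), tensor.to_numpy(m))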
 
 
-class Slice(Operation):
+class Slice(Operator):
     """
-    Init a Slice, Produces a slice of the input tensor along multiple axes. 
+    Init a Slice. Produces a slice of the input tensor along multiple axes.
     Similar to numpy: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
     """
 
@@ -4810,11 +4162,11 @@
         Args:
             starts (list of int): starting indices of corresponding axis
             ends (list of int): ending indices of corresponding axis
-            axes (list of int): axes that `starts` and `ends` apply to. 
-                Negative value means counting dimensions from the back. 
+            axes (list of int): axes that `starts` and `ends` apply to.
+                Negative value means counting dimensions from the back.
                 Accepted range is [-r, r-1] where r = rank(data).
-            steps (list of int): slice step of corresponding axis in `axes`. 
-                Negative value means slicing backward. 'steps' cannot be 0. 
+            steps (list of int): slice step of corresponding axis in `axes`.
+                Negative value means slicing backward. 'steps' cannot be 0.
                 Defaults to 1.
         """
         super(Slice, self).__init__()
@@ -4843,6 +4195,7 @@
         if self.steps is None:
             self.steps = [1] * len(x_shape)  # steps = None
         for idx, axis in enumerate(self.axes):
+            axis = int(axis)
             start, end, step = self.starts[idx], self.ends[idx], self.steps[idx]
             if end > x_shape[axis]:
                 end = x_shape[axis]
@@ -4884,17 +4237,17 @@
 
 def slice(x, starts, ends, axes=None, steps=None):
     """
-    Init a Slice, Produces a slice of the input tensor along multiple axes. 
+    Init a Slice. Produces a slice of the input tensor along multiple axes.
     Similar to numpy: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
     Args:
         x (Tensor): input tensor.
         starts (list of int): starting indices of corresponding axis
         ends (list of int): ending indices of corresponding axis
-        axes (list of int): axes that `starts` and `ends` apply to. 
-            Negative value means counting dimensions from the back. 
+        axes (list of int): axes that `starts` and `ends` apply to.
+            Negative value means counting dimensions from the back.
             Accepted range is [-r, r-1] where r = rank(data).
-        steps (list of int): slice step of corresponding axis in `axes`. 
-            Negative value means slicing backward. 'steps' cannot be 0. 
+        steps (list of int): slice step of corresponding axis in `axes`.
+            Negative value means slicing backward. 'steps' cannot be 0.
             Defaults to 1.
     Returns:
         the output Tensor.
@@ -4902,9 +4255,9 @@
     return Slice(starts, ends, axes, steps)(x)[0]
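
# Usage sketch for slice(), mirroring numpy basic slicing (assumes the
# singa.tensor numpy helpers; shapes are the expected ones):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.arange(12, dtype=np.float32).reshape(3, 4))
y = autograd.slice(x, starts=[1], ends=[3], axes=[0])  # keep rows 1 and 2
print(tensor.to_numpy(y))  # expected shape (2, 4)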
 
 
-class Ceil(Operation):
+class Ceil(Operator):
     """
-    Ceil takes one input data (Tensor) and produces one output data (Tensor) 
+    Ceil takes one input data (Tensor) and produces one output data (Tensor)
     where the ceil is, `y = ceil(x)`, is applied to the tensor elementwise.
     """
 
@@ -4936,7 +4289,7 @@
 
 def ceil(x):
     """
-    Ceil takes one input data (Tensor) and produces one output data (Tensor) 
+    Ceil takes one input data (Tensor) and produces one output data (Tensor)
     where the ceil is, `y = ceil(x)`, is applied to the tensor elementwise.
     Args:
         x (Tensor): input tensor.
@@ -4946,22 +4299,66 @@
     return Ceil()(x)[0]
 
 
-class Split(Operation):
+class Floor(Operator):
     """
-    Init a Split, Split a tensor into a list of tensors, along the specified 
-    'axis'. 
+    Floor takes one input data (Tensor) and produces one output data (Tensor),
+    where the floor, `y = floor(x)`, is applied to the tensor elementwise.
+    """
+
+    def __init__(self):
+        super(Floor, self).__init__()
+
+    def forward(self, x):
+        """
+        forward of floor
+        Args: 
+            x (CTensor): input tensor
+        Returns:
+            the output CTensor    
+        """
+        return singa.Floor(x)
+
+    def backward(self, dy):
+        """
+        backward of floor. Derivative of floor is 0
+        Args: 
+            dy (CTensor): gradient tensor
+        Returns:
+            the gradient tensor over the input tensor. 
+        """
+        dy = singa.Tensor(dy.shape(), dy.device())
+        dy.SetFloatValue(0.)
+        return dy
+
+
+def floor(x):
+    """
+    Floor takes one input data (Tensor) and produces one output data (Tensor),
+    where the floor, `y = floor(x)`, is applied to the tensor elementwise.
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        the output Tensor.
+    """
+    return Floor()(x)[0]
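
# Quick sketch contrasting ceil() and floor(), element-wise (assumes the
# tensor.from_numpy / tensor.to_numpy helpers):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.array([-1.5, -0.2, 0.2, 1.5], dtype=np.float32))
print(tensor.to_numpy(autograd.ceil(x)))   # expected [-1. -0.  1.  2.]
print(tensor.to_numpy(autograd.floor(x)))  # expected [-2. -1.  0.  1.]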
+
+
+class Split(Operator):
+    """
+    Init a Split, which splits a tensor into a list of tensors along the
+    specified 'axis'.
     """
 
     def __init__(self, axis, parts, num_output=None):
         """
         Args:
-            axis (int): which axis to split on. A negative value means 
-                counting dimensions from the back. Accepted range is 
+            axis (int): which axis to split on. A negative value means
+                counting dimensions from the back. Accepted range is
                 [-rank, rank-1] where r = rank(input).
-            parts (list of int): length of each output, which can be specified 
-                using argument 'parts'. Otherwise, the tensor is parts to equal 
+            parts (list of int): length of each output, which can be specified
+                using argument 'parts'. Otherwise, the tensor is split into
                 equally sized parts.
-            num_output (bool): once parts is none, the tensor is split to equal 
+            num_output (int): if 'parts' is None, the tensor is split into
                 'num_output' equally sized parts.
         """
         super(Split, self).__init__()
@@ -5006,17 +4403,17 @@
 
 def split(x, axis, parts, num_output=None):
     """
-    Init a Split, Split a tensor into a list of tensors, along the specified 
-    'axis'. 
+    Init a Split, which splits a tensor into a list of tensors along the
+    specified 'axis'.
     Args:
         x (Tensor): input tensor.
-        axis (int): which axis to split on. A negative value means 
-            counting dimensions from the back. Accepted range is 
+        axis (int): which axis to split on. A negative value means
+            counting dimensions from the back. Accepted range is
             [-rank, rank-1] where r = rank(input).
-        parts (list of int): length of each output, which can be specified 
-            using argument 'parts'. Otherwise, the tensor is parts to equal 
+        parts (list of int): length of each output, which can be specified
+            using argument 'parts'. Otherwise, the tensor is split into
             equally sized parts.
-        num_output (bool): once parts is none, the tensor is split to equal 
+        num_output (int): if 'parts' is None, the tensor is split into
             'num_output' equally sized parts.
     Returns:
         the output Tensors, one per part.
@@ -5024,18 +4421,18 @@
     return Split(axis, parts, num_output)(x)
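
# Usage sketch for split(); note that it returns one Tensor per part
# (assumes the singa.tensor numpy helpers):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.arange(12, dtype=np.float32).reshape(3, 4))
y1, y2 = autograd.split(x, 1, (2, 2))  # two (3, 2) tensors along axis 1
print(tensor.to_numpy(y1).shape, tensor.to_numpy(y2).shape)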
 
 
-class Gather(Operation):
+class Gather(Operator):
     """
-    Init a Gather, Given data tensor of rank r >= 1, and indices tensor of 
-    rank q, gather entries of the axis dimension of data (by default outer-most 
+    Init a Gather. Given data tensor of rank r >= 1, and indices tensor of
+    rank q, gather entries of the axis dimension of data (by default outer-most
     one as axis=0) indexed by indices, and concatenates them in an output tensor of rank `q + (r - 1)`.
     """
 
     def __init__(self, axis, indices):
         """
         Args:
-            axis (int): which axis to slice on. A negative value means counting 
-                dimensions from the back. Accepted range is [-rank, rank-1] 
+            axis (int): which axis to slice on. A negative value means counting
+                dimensions from the back. Accepted range is [-rank, rank-1]
                 where r = rank(input).
             indices (list of int): entries of the axis dimension of data.
         """
@@ -5057,10 +4454,10 @@
         xs = []
         for indice in self.indices:
             # each indice is a sub-indice
-            if isinstance(indice, tuple) or isinstance(indice, list):
+            if isinstance(indice, (tuple, list, np.ndarray)):
                 sub_xs = []
                 for idx in indice:
-                    idx = idx % _shape
+                    idx = int(idx % _shape)
                     tmp_tensor = singa.SliceOn(x, idx, idx + 1, self.axis)
                     sub_xs.append(tmp_tensor)
                 sub_xs = singa.VecTensor(sub_xs)
@@ -5069,7 +4466,7 @@
                 _slice_shape.insert(self.axis, 1)  # add a new axis to concat
                 tmp_tensor = singa.Reshape(tmp_tensor, _slice_shape)
             else:
-                indice = indice % _shape
+                indice = int(indice % _shape)
                 tmp_tensor = singa.SliceOn(x, indice, indice + 1, self.axis)
             xs.append(tmp_tensor)
         xs = singa.VecTensor(xs)
@@ -5126,13 +4523,13 @@
 
 def gather(x, axis, indices):
     """
-    Init a Gather, Given data tensor of rank r >= 1, and indices tensor of 
-    rank q, gather entries of the axis dimension of data (by default outer-most 
+    Init a Gather. Given data tensor of rank r >= 1, and indices tensor of
+    rank q, gather entries of the axis dimension of data (by default outer-most
     one as axis=0) indexed by indices, and concatenates them in an output tensor of rank `q + (r - 1)`.
     Args:
         x (Tensor): input tensor.
-        axis (int): which axis to slice on. A negative value means counting 
-            dimensions from the back. Accepted range is [-rank, rank-1] 
+        axis (int): which axis to slice on. A negative value means counting
+            dimensions from the back. Accepted range is [-rank, rank-1]
             where r = rank(input).
         indices (list of int): entries of the axis dimension of data.
     Returns:
@@ -5141,17 +4538,17 @@
     return Gather(axis, indices)(x)[0]
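
# Usage sketch for gather() on axis 0 (assumes the numpy helpers; indices
# follow the ONNX Gather semantics documented above):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.arange(12, dtype=np.float32).reshape(3, 4))
y = autograd.gather(x, 0, [2, 0])  # pick row 2, then row 0
print(tensor.to_numpy(y))          # expected shape (2, 4)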
 
 
-class Tile(Operation):
+class Tile(Operator):
     """
-    Init a Tile, Constructs a tensor by tiling a given tensor. This is the same 
+    Init a Tile. Constructs a tensor by tiling a given tensor. This is the same
     as function tile in Numpy: https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     """
 
     def __init__(self, repeats):
         """
         Args:
-            repeats (list of int): 1D int matrix of the same length as input's 
-                dimension number, includes numbers of repeated copies along 
+            repeats (list of int): 1D int matrix of the same length as input's
+                dimension number, includes numbers of repeated copies along
                 input's dimensions.
         """
         super(Tile, self).__init__()
@@ -5211,12 +4608,12 @@
 
 def tile(x, repeats):
     """
-    Init a Tile, Constructs a tensor by tiling a given tensor. This is the same 
+    Init a Tile. Constructs a tensor by tiling a given tensor. This is the same
     as function tile in Numpy: https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     Args:
         x (Tensor): input tensor.
-        repeats (list of int): 1D int matrix of the same length as input's 
-            dimension number, includes numbers of repeated copies along 
+        repeats (list of int): 1D int matrix of the same length as input's
+            dimension number, includes numbers of repeated copies along
             input's dimensions.
     Returns:
         the output Tensor.
@@ -5224,9 +4621,9 @@
     return Tile(repeats)(x)[0]
 
 
-class NonZero(Operation):
+class NonZero(Operator):
     """
-    Init a NonZero, Constructs a tensor by tiling a given tensor. This is the same 
+    Init a NonZero. Returns the indices of the elements of the input tensor
+    that are non-zero. This is similar to the function nonzero in Numpy:
     https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html
     """
 
@@ -5260,7 +4657,7 @@
 
 def nonzero(x):
     """
-    Init a NonZero, Constructs a tensor by tiling a given tensor. This is the same 
+    Init a NonZero. Returns the indices of the elements of the input tensor
+    that are non-zero. This is similar to the function nonzero in Numpy:
    https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html
     Args:
         x (Tensor): input tensor.
@@ -5270,10 +4667,10 @@
     return NonZero()(x)[0]
 
 
-class Cast(Operation):
+class Cast(Operator):
     """
-    The operator casts the elements of a given input tensor to a data type 
-    specified by the 'to' argument and returns an output tensor of the same 
+    The operator casts the elements of a given input tensor to a data type
+    specified by the 'to' argument and returns an output tensor of the same
     size in the converted type.
     """
 
@@ -5300,7 +4697,7 @@
     def backward(self, dy):
         """
         backward of Cast
-        Args:f
+        Args:
             dy (CTensor), gradient tensor.
         Raises:
             AssertionError: no backward function for this operator
@@ -5310,8 +4707,8 @@
 
 def cast(x, to):
     """
-    The operator casts the elements of a given input tensor to a data type 
-    specified by the 'to' argument and returns an output tensor of the same 
+    The operator casts the elements of a given input tensor to a data type
+    specified by the 'to' argument and returns an output tensor of the same
     size in the converted type.
     Args:
         x (Tensor): input tensor.
@@ -5322,27 +4719,27 @@
     return Cast(to)(x)[0]
 
 
-class OneHot(Operation):
+class OneHot(Operator):
     """
-    Produces a one-hot tensor based on inputs. 
+    Produces a one-hot tensor based on inputs.
     """
 
     def __init__(self, axis, depth, values):
         """
         Args:
-            axis (int): Axis along which one-hot representation in added. 
-                Default: axis=-1. axis=-1 means that the additional dimension 
-                will be inserted as the innermost/last dimension in the output 
+            axis (int): Axis along which one-hot representation is added.
+                Default: axis=-1. axis=-1 means that the additional dimension
+                will be inserted as the innermost/last dimension in the output
                 tensor.
-            depth (int): Scalar specifying the number of classes in one-hot 
-                tensor. This is also the size of the one-hot dimension 
-                (specified by 'axis' attribute) added on in the output tensor. 
-                The values in the 'indices' input tensor are expected to be in 
+            depth (int): Scalar specifying the number of classes in one-hot
+                tensor. This is also the size of the one-hot dimension
+                (specified by 'axis' attribute) added on in the output tensor.
+                The values in the 'indices' input tensor are expected to be in
                 the range [-depth, depth-1].
-            values (float): Rank 1 tensor containing exactly two elements, in 
-                the format [off_value, on_value], where 'on_value' is the 
-                value used for filling locations specified in 'indices' input 
-                tensor, 
+            values (float): Rank 1 tensor containing exactly two elements, in
+                the format [off_value, on_value], where 'on_value' is the
+                value used for filling locations specified in the 'indices'
+                input tensor, and 'off_value' is used for the other locations.
         """
         super(OneHot, self).__init__()
         self.axis = axis
@@ -5353,9 +4750,9 @@
         """
         forward of OneHot, we borrow this function from onnx
         Args:
-            indices (CTensor): Scalar specifying the number of classes in 
-                one-hot tensor. The values in the 'indices' input tensor are 
-                expected to be in the range [-depth, depth-1]. 
+            indices (CTensor): input tensor containing the indices. The values
+                in the 'indices' input tensor are expected to be in the range
+                [-depth, depth-1].
         Returns:
             the output CTensor.
         """
@@ -5389,25 +4786,986 @@
 
 def onehot(axis, indices, depth, values):
     """
-    Produces a one-hot tensor based on inputs. 
+    Produces a one-hot tensor based on inputs.
     Args:
-        axis (int): Axis along which one-hot representation in added. 
-            Default: axis=-1. axis=-1 means that the additional dimension 
-            will be inserted as the innermost/last dimension in the output 
+        axis (int): Axis along which one-hot representation is added.
+            Default: axis=-1. axis=-1 means that the additional dimension
+            will be inserted as the innermost/last dimension in the output
             tensor.
-        indices (Tensor): Scalar specifying the number of classes in 
-            one-hot tensor. The values in the 'indices' input tensor are 
-            expected to be in the range [-depth, depth-1]. 
-        depth (int): Scalar specifying the number of classes in one-hot 
-            tensor. This is also the size of the one-hot dimension 
-            (specified by 'axis' attribute) added on in the output tensor. 
-            The values in the 'indices' input tensor are expected to be in 
+        indices (Tensor): input tensor containing the indices. The values in
+            the 'indices' input tensor are expected to be in the range
+            [-depth, depth-1].
+        depth (int): Scalar specifying the number of classes in one-hot
+            tensor. This is also the size of the one-hot dimension
+            (specified by 'axis' attribute) added on in the output tensor.
+            The values in the 'indices' input tensor are expected to be in
             the range [-depth, depth-1].
-        values (float): Rank 1 tensor containing exactly two elements, in 
-            the format [off_value, on_value], where 'on_value' is the 
-            value used for filling locations specified in 'indices' input 
-            tensor, 
+        values (float): Rank 1 tensor containing exactly two elements, in
+            the format [off_value, on_value], where 'on_value' is the
+            value used for filling locations specified in the 'indices'
+            input tensor, and 'off_value' is used for the other locations.
     Returns:
         the output Tensor.
     """
     return OneHot(axis, depth, values)(indices)[0]
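
# Usage sketch for onehot() (assumes the numpy helpers; the indices are given
# as a float tensor, as the other operators in this module expect):
import numpy as np
from singa import autograd, tensor

indices = tensor.from_numpy(np.array([[1., 9.], [2., 4.]], dtype=np.float32))
y = autograd.onehot(-1, indices, 10, [0.0, 1.0])  # off_value=0.0, on_value=1.0
print(tensor.to_numpy(y).shape)  # expected (2, 2, 10)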
+
+
+class _RNN(Operator):
+    """ RNN operation with c++ backend
+    """
+
+    def __init__(
+            self,
+            handle,
+            return_sequences=False,
+            #  batch_first=True,
+            use_mask=False,
+            seq_lengths=None):
+        assert singa.USE_CUDA, "Not able to run without CUDA"
+        super(_RNN, self).__init__()
+        self.handle = handle
+        self.return_sequences = return_sequences
+        self.use_mask = use_mask
+        if use_mask:
+            assert type(seq_lengths) == Tensor, "wrong type for seq_lengths"
+        self.seq_lengths = seq_lengths
+
+    def forward(self, x, hx, cx, w):
+        if training:
+            if self.use_mask:
+                (y, hy,
+                 cy) = singa.GpuRNNForwardTrainingEx(x, hx, cx, w,
+                                                     self.seq_lengths.data,
+                                                     self.handle)
+            else:
+                (y, hy,
+                 cy) = singa.GpuRNNForwardTraining(x, hx, cx, w, self.handle)
+            self.inputs = {
+                'x': x,
+                'hx': hx,
+                'cx': cx,
+                'w': w,
+                'y': y,
+                'hy': hy,
+                'cy': cy
+            }
+        else:
+            if self.use_mask:
+                (y, hy,
+                 cy) = singa.GpuRNNForwardInferenceEx(x, hx, cx, w,
+                                                      self.seq_lengths.data,
+                                                      self.handle)
+            else:
+                (y, hy,
+                 cy) = singa.GpuRNNForwardInference(x, hx, cx, w, self.handle)
+
+        if self.return_sequences:
+            # (seq, bs, data)
+            return y
+        else:
+            # return last time step of y
+            # (seq, bs, data)[-1] -> (bs, data)
+            last_y_shape = (y.shape()[1], y.shape()[2])
+            last_y = singa.Tensor(list(last_y_shape), x.device())
+
+            src_offset = y.Size() - last_y.Size()
+            # def copy_data_to_from(dst, src, size, dst_offset=0, src_offset=0):
+            singa.CopyDataToFrom(last_y, y, last_y.Size(), 0, src_offset)
+            return last_y
+
+    def backward(self, grad):
+        assert training is True and hasattr(
+            self, "inputs"), "Please set training as True before do BP. "
+
+        # (seq, bs, hid)
+        dy = None
+        if self.return_sequences:
+            assert grad.shape() == self.inputs['y'].shape(), (
+                "grad shape %s != y shape %s" %
+                (grad.shape(), self.inputs['y'].shape()))
+            dy = grad
+        else:
+            # grad (bs, directions*hidden) -> dy (seq, bs, directions*hidden)
+            #   empty space filled by zeros
+            assert grad.shape() == (self.inputs['y'].shape()[1],
+                                    self.inputs['y'].shape()[2]), (
+                                        "grad y shape %s != last y shape %s" %
+                                        (grad.shape(),
+                                         (self.inputs['y'].shape()[1],
+                                          self.inputs['y'].shape()[2])))
+            dy = singa.Tensor(list(self.inputs['y'].shape()), grad.device())
+            dy.SetFloatValue(0.0)
+            dst_offset = dy.Size() - grad.Size()
+            singa.CopyDataToFrom(dy, grad, grad.Size(), dst_offset, 0)
+
+        # states grad are zeros, since states are not used in forward pass
+        dhy = singa.Tensor(list(self.inputs['hy'].shape()), grad.device())
+        dhy.SetFloatValue(0.0)
+        dcy = singa.Tensor(list(self.inputs['cy'].shape()), grad.device())
+        dcy.SetFloatValue(0.0)
+
+        if self.use_mask:
+            (dx, dhx,
+             dcx) = singa.GpuRNNBackwardxEx(self.inputs['y'], dy, dhy, dcy,
+                                            self.inputs['w'], self.inputs['hx'],
+                                            self.inputs['cx'],
+                                            self.seq_lengths.data, self.handle)
+            dW = singa.GpuRNNBackwardWEx(self.inputs['x'], self.inputs['hx'],
+                                         self.inputs['y'],
+                                         self.seq_lengths.data, self.handle)
+        else:
+            (dx, dhx,
+             dcx) = singa.GpuRNNBackwardx(self.inputs['y'], dy, dhy, dcy,
+                                          self.inputs['w'], self.inputs['hx'],
+                                          self.inputs['cx'], self.handle)
+            dW = singa.GpuRNNBackwardW(self.inputs['x'], self.inputs['hx'],
+                                       self.inputs['y'], self.handle)
+
+        return dx, dhx, dcx, dW
+
+
+class CosSim(Operator):
+    """
+    Init a cos similarity operator
+    """
+
+    def __init__(self):
+        super(CosSim, self).__init__()
+
+    @classmethod
+    def dot(cls, a, b):
+        """
+        dot multiply
+        Args:
+            a (CTensor): 2d input tensor.
+            b (CTensor): 2d input tensor.
+        Returns:
+            CTensor: the output CTensor.
+        """
+        batch_size = a.shape()[0]
+        ret = []
+        for indice in range(batch_size):
+            tmp_a = singa.SliceOn(a, indice, indice + 1, 0)  # 1 * d
+            tmp_b = singa.SliceOn(b, indice, indice + 1, 0)  # 1 * d
+            tmp_b = singa.DefaultTranspose(tmp_b)
+            tmp_tensor = singa.Mult(tmp_a, tmp_b)  # 1 * d * d * 1
+            ret.append(tmp_tensor)
+        ret = singa.VecTensor(ret)
+        ret = singa.ConcatOn(ret, 0)  # b * 1
+        return singa.Reshape(ret, [ret.shape()[0]])  # b
+
+    def forward(self, a, b):
+        """
+        forward of CosSim
+        Args:
+            a (CTensor): input tensor.
+            b (CTensor): input tensor.
+        Returns:
+            the output CTensor.
+        """
+        ad = CosSim.dot(a, a)
+        bd = CosSim.dot(b, b)
+        ap = singa.PowFloat(ad, 0.5)
+        bp = singa.PowFloat(bd, 0.5)
+        ret = singa.__div__(CosSim.dot(a, b), singa.__mul__(ap, bp))
+        if training:
+            self.cache = (a, b, ad, bd, ap, bp, ret)
+        return ret
+
+    def backward(self, dy):
+        """
+        backward of CosSim
+        follow https://math.stackexchange.com/a/1923705
+        Args:
+            dy (CTensor): gradient tensor.
+        Returns:
+            the gradient tensor over input tensor.
+        """
+        a, b, ad, bd, ap, bp, ret = self.cache
+        ab = singa.__mul__(ap, bp)
+        ab = singa.Reshape(ab, list(ab.shape()) + [1])  # b * 1
+        ad = singa.Reshape(ad, list(ad.shape()) + [1])  # b * 1
+        bd = singa.Reshape(bd, list(bd.shape()) + [1])  # b * 1
+        ret = singa.Reshape(ret, list(ret.shape()) + [1])  # b * 1
+        dy = singa.Reshape(dy, list(dy.shape()) + [1])  # broadcast
+        da = singa.__sub__(singa.__div__(b, ab),
+                           singa.__div__(singa.__mul__(ret, a), ad))
+        db = singa.__sub__(singa.__div__(a, ab),
+                           singa.__div__(singa.__mul__(ret, b), bd))
+        da = singa.__mul__(dy, da)
+        db = singa.__mul__(dy, db)
+        return da, db
+
+
+def cossim(a, b):
+    """
+    Produces a cos similarity operator
+    Args:
+        a (Tensor): input tensor.
+        b (Tensor): input tensor.
+    Returns:
+        the output Tensor.
+    """
+    assert a.shape == b.shape, "shape not match for cossim"
+    assert a.ndim() == 2, "shape should be in 2d for cossim"
+    assert b.ndim() == 2, "shape should be in 2d for cossim"
+    return CosSim()(a, b)[0]
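
# Usage sketch for cossim(): one cosine similarity per row of two 2-D inputs
# (assumes the singa.tensor numpy helpers):
import numpy as np
from singa import autograd, tensor

a = tensor.from_numpy(np.array([[1., 0.], [1., 1.]], dtype=np.float32))
b = tensor.from_numpy(np.array([[1., 0.], [1., 0.]], dtype=np.float32))
sim = autograd.cossim(a, b)
print(tensor.to_numpy(sim))  # expected roughly [1.0, 0.7071]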
+
+
+class Expand(Operator):
+    """
+    Expand operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#Expand
+
+    Example usage::
+    data = [[1.], [2.], [3.]]
+
+    # dim_changed
+    shape = [2, 1, 6]
+    output = [[[1., 1., 1., 1., 1., 1.], 
+               [2., 2., 2., 2., 2., 2.],
+               [3., 3., 3., 3., 3., 3.]],
+              [[1., 1., 1., 1., 1., 1.],
+               [2., 2., 2., 2., 2., 2.],
+               [3., 3., 3., 3., 3., 3.]]]
+
+    # dim_unchanged
+    shape = [3, 4]
+    output = [[1., 1., 1., 1.],
+              [2., 2., 2., 2.],
+              [3., 3., 3., 3.]]
+    """
+
+    def __init__(self, shape):
+        """
+        Args:
+            shape (list[int]): indicates the shape you want to expand to,
+                following the broadcast rule.
+        """
+        super(Expand, self).__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        if isinstance(self.shape, np.ndarray):
+            self.shape = self.shape.tolist()
+        else:
+            self.shape = list(self.shape)
+        self.dim_changed = True
+        self.x_shape = list(x.shape())
+        x_shape = self.x_shape.copy()
+        for s_1, s_2 in zip(self.shape[::-1], x_shape[::-1]):
+            if s_1 != 1 and s_2 != 1 and s_1 != s_2:
+                if len(self.shape) != len(x_shape):
+                    assert False, ('does not support dim_unchanged mode')
+                self.dim_changed = False
+                break
+        if self.dim_changed:
+            tmp_tensor = singa.Tensor(self.shape, x.device())
+            tmp_tensor.SetFloatValue(1.)
+            x = singa.__mul__(x, tmp_tensor)
+        else:
+            for axis, s_1, s_2 in zip(range(len(self.shape)), self.shape,
+                                      x_shape):
+                if s_1 == s_2:
+                    continue
+                xs = [x] * (s_1 // s_2)
+                x = singa.VecTensor(xs)
+                x = singa.ConcatOn(x, axis)
+        return x
+
+    def backward(self, dy):
+        x_shape = self.x_shape
+        if self.dim_changed:
+            dy = tensor.from_raw_tensor(dy)
+            if len(self.shape) > len(x_shape):
+                x_shape = [1] * (len(self.shape) - len(x_shape)) + x_shape
+            # pair each axis with its own extent; walk from the back so axes
+            # removed by the sum do not shift the remaining indices
+            for axis, s in zip(range(len(self.shape))[::-1], x_shape[::-1]):
+                if s == 1:
+                    dy = tensor.sum(dy, axis)
+            dy = dy.data
+        else:
+            for axis, s_1, s_2 in zip(
+                    range(len(self.shape))[::-1], self.shape[::-1],
+                    x_shape[::-1]):
+                if s_1 > s_2:
+                    duplic = s_1 // s_2
+                    dxs = []
+                    for i in range(s_2):
+                        tmp_tensor = None
+                        for j in range(duplic):
+                            if not tmp_tensor:
+                                tmp_tensor = singa.SliceOn(
+                                    dy, j * s_2 + i, j * s_2 + i + 1, axis)
+                            else:
+                                tmp_tensor += singa.SliceOn(
+                                    dy, j * s_2 + i, j * s_2 + i + 1, axis)
+                        dxs.append(tmp_tensor)
+                    dxs = singa.VecTensor(dxs)
+                    dy = singa.ConcatOn(dxs, axis)
+        dy = singa.Reshape(dy, self.x_shape)
+        return dy
+
+
+def expand(x, shape):
+    """
+    Produces an Expand operator
+    Args:
+        x (Tensor): input tensor.
+        shape (list[int]): indicates the shape you want to expand to,
+            following the broadcast rule.
+    Returns:
+        the output Tensor.
+    """
+    return Expand(shape)(x)[0]
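
# Usage sketch for expand(), reproducing the dim_changed example from the
# class docstring above (assumes the numpy helpers):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.array([[1.], [2.], [3.]], dtype=np.float32))
y = autograd.expand(x, [2, 1, 6])  # broadcast (3, 1) against (2, 1, 6)
print(tensor.to_numpy(y).shape)    # expected (2, 3, 6)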
+
+
+class Pad(Operator):
+    """
+    Pad operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pad
+
+    Example usage::
+        data =
+        [
+            [1.0, 1.2],
+            [2.3, 3.4],
+            [4.5, 5.7],
+        ]
+        pads = [0, 2, 0, 0]
+
+        # constant mode
+        mode = 'constant'
+        constant_value = 0.0
+        output =
+        [
+            [0.0, 0.0, 1.0, 1.2],
+            [0.0, 0.0, 2.3, 3.4],
+            [0.0, 0.0, 4.5, 5.7],
+        ]
+
+        # reflect mode
+        mode = 'reflect'
+        output =
+        [
+            [1.0, 1.2, 1.0, 1.2],
+            [2.3, 3.4, 2.3, 3.4],
+            [4.5, 5.7, 4.5, 5.7],
+        ]
+
+        # edge mode
+        mode = 'edge'
+        output =
+        [
+            [1.0, 1.0, 1.0, 1.2],
+            [2.3, 2.3, 2.3, 3.4],
+            [4.5, 4.5, 4.5, 5.7],
+        ]
+    """
+
+    def __init__(self, mode, pads, constant=0.):
+        """
+        Args:
+            mode (string): Supported modes: `constant`(default), `reflect`, `edge`.
+            pads (list[int]): list of integers indicating the number of padding
+                elements to add at the beginning and end of each axis, in the
+                format [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
+            constant (float): A scalar value to be used if the mode chosen is
+                `constant`.
+        """
+        super(Pad, self).__init__()
+        self.mode = mode
+        if self.mode not in ("constant", "reflect", "edge"):
+            assert False, ('Only support three modes: constant, reflect, edge')
+        self.constant = constant
+        self.pads = pads
+        self.pad_width = ()
+
+    def forward(self, x):
+        if not self.pad_width:
+            half_width = len(self.pads) // 2
+            for i in range(half_width):
+                self.pad_width += ((self.pads[i], self.pads[i + half_width])),
+
+        for axis, pads in zip(range(len(x.shape())), self.pad_width):
+            for pad, is_left in zip(pads, (True, False)):
+                if pad == 0:
+                    continue
+                pad_shape = list(x.shape())
+                if self.mode == "constant":
+                    pad_shape[axis] = pad
+                    padding = singa.Tensor(list(pad_shape), x.device())
+                    padding.SetFloatValue(self.constant)
+                    if is_left:
+                        x = singa.ConcatOn(singa.VecTensor([padding, x]), axis)
+                    else:
+                        x = singa.ConcatOn(singa.VecTensor([x, padding]), axis)
+                elif self.mode == "reflect":
+                    axis_shape = pad_shape[axis]
+                    if is_left:
+                        padding = singa.SliceOn(x, 0, pad, axis)
+                        x = singa.ConcatOn(singa.VecTensor([padding, x]), axis)
+                    else:
+                        padding = singa.SliceOn(x, axis_shape - pad, axis_shape,
+                                                axis)
+                        x = singa.ConcatOn(singa.VecTensor([x, padding]), axis)
+                elif self.mode == "edge":
+                    axis_shape = pad_shape[axis]
+                    if is_left:
+                        padding = []
+                        for _ in range(pad):
+                            padding.append(singa.SliceOn(x, 0, 1, axis))
+                        padding.append(x)
+                        padding = singa.VecTensor(padding)
+                        x = singa.ConcatOn(padding, axis)
+                    else:
+                        padding = [x]
+                        for _ in range(pad):
+                            padding.append(
+                                singa.SliceOn(x, axis_shape - 1, axis_shape,
+                                              axis))
+                        padding = singa.VecTensor(padding)
+                        x = singa.ConcatOn(padding, axis)
+        return x
+
+    def backward(self, dy):
+        for axis, pads in zip(range(len(dy.shape())), self.pad_width):
+            for pad, is_left in zip(pads, (True, False)):
+                if pad == 0:
+                    continue
+                axis_shape = list(dy.shape())[axis]
+                if is_left:
+                    dy = singa.SliceOn(dy, pad, axis_shape, axis)
+                else:
+                    dy = singa.SliceOn(dy, 0, axis_shape - pad, axis)
+        return dy
+
+
+def pad(x, mode, pads, constant=0.):
+    """
+    Produces a pad operator
+    Args:
+        x (Tensor): input tensor.
+        mode (string): Supported modes: `constant`(default), `reflect`, `edge`.
+        pads (list[int]): list of integers indicating the number of padding
+            elements to add at the beginning and end of each axis, in the
+            format [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
+        constant (float): A scalar value to be used if the mode chosen is
+            `constant`.
+    Returns:
+        the output Tensor.
+    """
+    return Pad(mode, pads, constant)(x)[0]
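
# Usage sketch for pad(), matching the constant-mode example in the Pad
# docstring (assumes the numpy helpers):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(
    np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32))
y = autograd.pad(x, "constant", [0, 2, 0, 0])  # two zeros before axis 1
print(tensor.to_numpy(y))  # expected shape (3, 4)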
+
+
+class UpSample(Operator):
+    """
+    UpSample operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#upsample
+
+    Example usage::
+    data = [[[[1, 2],
+              [3, 4],]]]
+
+    # nearest
+    scales = [1.0, 1.0, 2.0, 3.0]
+    output = [[[[1, 1, 1, 2, 2, 2],
+                [1, 1, 1, 2, 2, 2],
+                [3, 3, 3, 4, 4, 4],
+                [3, 3, 3, 4, 4, 4],]]]
+    """
+
+    def __init__(self, mode, scales):
+        """
+        Args:
+            mode (string): the interpolation mode; only `nearest` is supported.
+            scales (list[int]): The scale array along each dimension. It takes
+                values greater than or equal to 1.
+        """
+        super(UpSample, self).__init__()
+        self.scales = scales
+        self.mode = mode.lower()
+        if self.mode != "nearest":
+            assert False, "only support nearest mode."
+
+    def forward(self, x):
+        if isinstance(self.scales, np.ndarray):
+            self.scales = self.scales.tolist()
+        else:
+            self.scales = list(self.scales)
+        self.x_shape = list(x.shape())
+        for axis, s in zip(range(len(self.scales)), self.scales):
+            s = int(s)
+            if s == 1:
+                continue
+            x = x.Repeat([
+                s,
+            ], axis)
+        return x
+
+    def backward(self, dy):
+        x_shape = self.x_shape.copy()
+        for axis, s_1, s_2 in zip(
+                range(len(self.scales))[::-1], self.scales[::-1],
+                x_shape[::-1]):
+            s_1 = int(s_1)
+            if s_1 != 1:
+                duplic = s_1
+                dxs = []
+                for i in range(s_2):
+                    tmp_tensor = None
+                    for j in range(duplic):
+                        if not tmp_tensor:
+                            tmp_tensor = singa.SliceOn(dy, i * duplic + j,
+                                                       i * duplic + j + 1, axis)
+                        else:
+                            tmp_tensor += singa.SliceOn(dy, i * duplic + j,
+                                                        i * duplic + j + 1,
+                                                        axis)
+                    dxs.append(tmp_tensor)
+                dxs = singa.VecTensor(dxs)
+                dy = singa.ConcatOn(dxs, axis)
+        dy = singa.Reshape(dy, self.x_shape)
+        return dy
+
+
+def upsample(x, mode, scales):
+    """
+    Produces an upsample operator
+    Args:
+        x (Tensor): input tensor.
+        mode (string): the interpolation mode; only `nearest` is supported.
+        scales (list[int]): The scale array along each dimension. It takes
+            values greater than or equal to 1.
+    Returns:
+        the output Tensor.
+    """
+    return UpSample(mode, scales)(x)[0]
+
+
+class DepthToSpace(Operator):
+    """
+    DepthToSpace operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#DepthToSpace
+
+    Example usage::
+    blocksize = 2
+    # (1, 8, 2, 3) input tensor
+    data = [[[[0., 1., 2.],
+            [3., 4., 5.]],
+            [[9., 10., 11.],
+            [12., 13., 14.]],
+            [[18., 19., 20.],
+            [21., 22., 23.]],
+            [[27., 28., 29.],
+            [30., 31., 32.]],
+            [[36., 37., 38.],
+            [39., 40., 41.]],
+            [[45., 46., 47.],
+            [48., 49., 50.]],
+            [[54., 55., 56.],
+            [57., 58., 59.]],
+            [[63., 64., 65.],
+            [66., 67., 68.]]]]
+
+    # DCR mode
+    # (1, 2, 4, 6) output tensor
+    output = [[[[0., 18., 1., 19., 2., 20.],
+                [36., 54., 37., 55., 38., 56.],
+                [3., 21., 4., 22., 5., 23.],
+                [39., 57., 40., 58., 41., 59.]],
+               [[9., 27., 10., 28., 11., 29.],
+                [45., 63., 46., 64., 47., 65.],
+                [12., 30., 13., 31., 14., 32.],
+                [48., 66., 49., 67., 50., 68.]]]]
+
+    # CRD mode
+    # (1, 2, 4, 6) output tensor
+    output = [[[[0., 9., 1., 10., 2., 11.],
+                [18., 27., 19., 28., 20., 29.],
+                [3., 12., 4., 13., 5., 14.],
+                [21., 30., 22., 31., 23., 32.]],
+               [[36., 45., 37., 46., 38., 47.],
+                [54., 63., 55., 64., 56., 65.],
+                [39., 48., 40., 49., 41., 50.],
+                [57., 66., 58., 67., 59., 68.]]]]
+    """
+
+    def __init__(self, blocksize, mode="DCR"):
+        """
+        Args:
+            blocksize (int): Blocks of [blocksize, blocksize] are moved.
+            mode (string): DCR (default) for depth-column-row order re-
+                arrangement. Use CRD for column-row-depth order.
+        """
+        super(DepthToSpace, self).__init__()
+        self.blocksize = blocksize
+        self.mode = mode.upper()
+
+    def forward(self, x):
+        if training:
+            self.x_shape = x.shape()
+        b, c, h, w = x.shape()
+        blocksize = self.blocksize
+        if self.mode == "DCR":
+            x = singa.Reshape(
+                x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+            x = singa.Transpose(x, [0, 3, 4, 1, 5, 2])
+            x = singa.Reshape(
+                x, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        elif self.mode == "CRD":
+            x = singa.Reshape(
+                x, [b, c // (blocksize**2), blocksize, blocksize, h, w])
+            x = singa.Transpose(x, [0, 1, 4, 2, 5, 3])
+            x = singa.Reshape(
+                x, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        else:
+            assert False, ("only support two methods: DCR and CRD.")
+        return x
+
+    def backward(self, dy):
+        b, c, h, w = self.x_shape
+        blocksize = self.blocksize
+        dy = singa.Reshape(
+            dy, [b, c // (blocksize**2), h, blocksize, w, blocksize])
+        if self.mode == "DCR":
+            dy = singa.Transpose(dy, [0, 3, 5, 1, 2, 4])
+        elif self.mode == "CRD":
+            dy = singa.Transpose(dy, [0, 1, 3, 5, 2, 4])
+        else:
+            assert False, ("only support two methods: DCR and CRD.")
+        dy = singa.Reshape(dy, self.x_shape)
+        return dy
+
+
+def depth_to_space(x, blocksize, mode="DCR"):
+    """
+    Produces a DepthToSpace operator
+    Args:
+        x (Tensor): input tensor.
+        blocksize (int): Blocks of [blocksize, blocksize] are moved.
+        mode (string): DCR (default) for depth-column-row order re-
+            arrangement. Use CRD for column-row-depth order.
+    Returns:
+        the output Tensor.
+    """
+    return DepthToSpace(blocksize, mode)(x)[0]
+
+
+class SpaceToDepth(Operator):
+    """
+    SpaceToDepth operator following ONNX Operator Schemas, reverse of DepthToSpace
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#SpaceToDepth
+    """
+
+    def __init__(self, blocksize, mode="DCR"):
+        """
+        Args:
+            blocksize (int): Blocks of [blocksize, blocksize] are moved.
+            mode (string): DCR (default) for depth-column-row order re-
+                arrangement. Use CRD for column-row-depth order.
+        """
+        super(SpaceToDepth, self).__init__()
+        self.blocksize = blocksize
+        self.mode = mode.upper()
+
+    def forward(self, x):
+        blocksize = self.blocksize
+        b, c, h, w = x.shape()
+        b, c, h, w = b, c * (blocksize**2), h // blocksize, w // blocksize
+        # the transformed shape is also needed by the reshape below, so record
+        # it regardless of the training flag
+        self.x_shape = (b, c, h, w)
+        x = singa.Reshape(
+            x, [b, c // (blocksize**2), h, blocksize, w, blocksize])
+        if self.mode == "DCR":
+            x = singa.Transpose(x, [0, 3, 5, 1, 2, 4])
+        elif self.mode == "CRD":
+            x = singa.Transpose(x, [0, 1, 3, 5, 2, 4])
+        else:
+            assert False, ("only support two methods: DCR and CRD.")
+        x = singa.Reshape(x, self.x_shape)
+        return x
+
+    def backward(self, dy):
+        b, c, h, w = self.x_shape
+        blocksize = self.blocksize
+        if self.mode == "DCR":
+            dy = singa.Reshape(
+                dy, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+            dy = singa.Transpose(dy, [0, 3, 4, 1, 5, 2])
+            dy = singa.Reshape(
+                dy, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        elif self.mode == "CRD":
+            dy = singa.Reshape(
+                dy, [b, c // (blocksize**2), blocksize, blocksize, h, w])
+            dy = singa.Transpose(dy, [0, 1, 4, 2, 5, 3])
+            dy = singa.Reshape(
+                dy, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        else:
+            assert False, ("only support two methods: DCR and CRD.")
+        return dy
+
+
+def space_to_depth(x, blocksize, mode="DCR"):
+    """
+    Produces a SpaceToDepth operator
+    Args:
+        x (Tensor): input tensor.
+        blocksize (int): Blocks of [blocksize, blocksize] are moved.
+        mode (string): DCR (default) for depth-column-row order re-
+            arrangement. Use CRD for column-row-depth order.
+    Returns:
+        the output Tensor.
+    """
+    return SpaceToDepth(blocksize, mode)(x)[0]
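
# Round-trip sketch: space_to_depth() inverts depth_to_space() for the same
# blocksize and mode (assumes the numpy helpers; shapes are the expected ones):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.arange(48, dtype=np.float32).reshape(1, 8, 2, 3))
y = autograd.depth_to_space(x, 2)  # expected shape (1, 2, 4, 6)
z = autograd.space_to_depth(y, 2)  # expected shape (1, 8, 2, 3)
print(tensor.to_numpy(y).shape, tensor.to_numpy(z).shape)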
+
+
+class Where(Operator):
+    """
+    Where operator following ONNX Operator Schemas
+    https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where
+    and Numpy
+    https://numpy.org/doc/stable/reference/generated/numpy.where.html
+    Example usage::
+    condition = [[True, False],
+                 [True, True]]
+    x = [[1, 2],
+         [3, 4]]
+    y = [[9, 8],
+         [7, 6]]
+
+    output = [[1, 8],
+              [3, 4]]
+    """
+
+    def __init__(self, condition):
+        """
+        Args:
+            condition (Tensor): When True (nonzero), yield X, otherwise yield Y
+        """
+        super(Where, self).__init__()
+        self.condition = condition
+
+    def forward(self, a, b):
+        if isinstance(self.condition, list):
+            self.condition = np.array(self.condition)
+        if isinstance(self.condition, np.ndarray):
+            self.condition = self.condition.astype(np.float32)
+            self.condition = tensor.from_numpy(self.condition)
+            self.condition.to_device(a.device())
+            self.condition = self.condition.data
+        self.neg_condition = singa.AddFloat(singa.MultFloat(self.condition, -1.), 1.)
+        _a, _b = a, b
+        dtype0 = _a.data_type()
+        dtype1 = _b.data_type()
+        if dtype0 == singa.kInt or dtype1 == singa.kInt:
+            _a = a.AsType(singa.kFloat32)
+            _b = b.AsType(singa.kFloat32)
+            res = singa.__add__(singa.__mul__(self.condition, _a),
+                             singa.__mul__(self.neg_condition, _b))
+            res = res.AsType(singa.kInt)
+        else:
+            res = singa.__add__(singa.__mul__(self.condition, _a),
+                             singa.__mul__(self.neg_condition, _b))
+        return res
+
+    def backward(self, dy):
+        da = singa.__mul__(self.condition, dy)
+        db = singa.__mul__(self.neg_condition, dy)
+        return da, db
+
+
+def where(x, y, condition):
+    """
+    Produces a Where operator
+    Args:
+        x (Tensor): input tensor.
+        y (Tensor): input tensor.
+        condition (Tensor): When True (nonzero), yield X, otherwise yield Y
+    Returns:
+        the output Tensor.
+    """
+    return Where(condition)(x, y)[0]
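
# Usage sketch for where(), mirroring the docstring example above; the
# condition may be given as a nested Python list (assumes the numpy helpers):
import numpy as np
from singa import autograd, tensor

x = tensor.from_numpy(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
y = tensor.from_numpy(np.array([[9., 8.], [7., 6.]], dtype=np.float32))
out = autograd.where(x, y, [[True, False], [True, True]])
print(tensor.to_numpy(out))  # expected [[1. 8.] [3. 4.]]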
+
+
+class Round(Operator):
+    """
+    Element-wise round the input
+    """
+
+    def __init__(self):
+        super(Round, self).__init__()
+
+    def forward(self, x):
+        return singa.Round(x)
+
+    def backward(self, dy):
+        dy = singa.Tensor(dy.shape(), dy.device())
+        dy.SetFloatValue(0.)
+        return dy
+
+
+def round(x):
+    """
+    Element-wise round the input
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        the output Tensor.
+    """
+    return Round()(x)[0]
+
+
+class Rounde(Operator):
+    """
+    Element-wise round the input; in case of halves, round to the nearest even integer.
+    """
+
+    def __init__(self):
+        super(Rounde, self).__init__()
+
+    def forward(self, x):
+        return singa.RoundE(x)
+
+    def backward(self, dy):
+        dy = singa.Tensor(dy.shape(), dy.device())
+        dy.SetFloatValue(0.)
+        return dy
+
+
+def rounde(x):
+    """
+    Element-wise round the input; in case of halves, round to the nearest even integer.
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        the output Tensor.
+    """
+    return Rounde()(x)[0]
+
+
+class Embedding(Operator):
+    """
+    Init an embedding operator
+    """
+
+    def __init__(self):
+        super(Embedding, self).__init__()
+
+    def forward(self, x, w):
+        """
+        forward of embedding
+        Args:
+            x (CTensor): input tensor.
+            w (CTensor): weight tensor.
+        Returns:
+            the output CTensor.
+        """
+        x = tensor.to_numpy(tensor.from_raw_tensor(x))
+        if training:
+            self.cache = (x, w.shape())
+
+        xs = []
+        x = x.tolist()
+        for indice in x:
+            sub_xs = []
+            for idx in indice:
+                idx = int(idx)
+                tmp_tensor = singa.SliceOn(w, idx, idx + 1, 0)
+                sub_xs.append(tmp_tensor)
+            sub_xs = singa.VecTensor(sub_xs)
+            tmp_tensor = singa.ConcatOn(sub_xs, 0)
+            tmp_tensor = singa.Reshape(tmp_tensor,
+                                       [1] + list(tmp_tensor.shape()))
+
+            xs.append(tmp_tensor)
+        xs = singa.VecTensor(xs)
+        xs = singa.ConcatOn(xs, 0)
+        return xs
+
+    def backward(self, dy):
+        """
+        backward of embedding
+        Args:
+            dy (CTensor): gradient tensor.
+        Returns:
+            the gradient tensors over the input and the weight tensor.
+        """
+        x, w_shape = self.cache
+        dy_shape = dy.shape()
+        # construct the dx
+        dx = tensor.sum(tensor.from_raw_tensor(dy), axis=2)
+
+        # construct the dw
+        dws = []
+        for idx in range(w_shape[0]):
+            tmp_tensor = singa.Tensor((1, w_shape[1]), dy.device())
+            tmp_tensor.SetFloatValue(0.0)
+            dws.append(tmp_tensor)
+        dy = singa.Reshape(dy, [dy_shape[0] * dy_shape[1], dy_shape[2]])
+        x = x.reshape(-1)
+        for idx, val in enumerate(x):
+            tmp_tensor = singa.SliceOn(dy, idx, idx + 1, 0)
+            dws[val] = singa.__add__(dws[val], tmp_tensor)
+        dws = singa.VecTensor(dws)
+        return dx.data, singa.ConcatOn(dws, 0)
+
+
+def embedding(x, w):
+    """
+    Produces an embedding operator.
+    Args:
+        x (Tensor): input tensor carrying the indices to look up.
+        w (Tensor): embedding weight tensor of shape (num_ids, embedding_dim).
+    Returns:
+        the output Tensor.
+    """
+    return Embedding()(x, w)[0]
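
# Usage sketch for embedding(): each index in x selects one row of w
# (assumes the numpy helpers; indices are passed as a float tensor):
import numpy as np
from singa import autograd, tensor

w = tensor.from_numpy(np.random.rand(10, 4).astype(np.float32))    # 10 ids, dim 4
x = tensor.from_numpy(np.array([[0., 2., 5.]], dtype=np.float32))  # one sequence
y = autograd.embedding(x, w)
print(tensor.to_numpy(y).shape)  # expected (1, 3, 4)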
+
+
+class Erf(Operator):
+    """
+    Apply element-wise math.erf to the input
+    """
+
+    def __init__(self):
+        super(Erf, self).__init__()
+
+    def forward(self, x):
+        if training:
+            self.input = x
+        return singa.Erf(x)
+
+    def backward(self, dy):
+        # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2), evaluated at the forward input
+        dx = singa.MultFloat(singa.PowFloat(self.input, 2.0), -1.0)
+        dx = singa.MultFloat(singa.Exp(dx), 2. / np.pi**0.5)
+        return singa.__mul__(dy, dx)
+
+
+def erf(x):
+    """
+    Apply element-wise math.erf to the input
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        the output Tensor.
+    """
+    return Erf()(x)[0]
+
+
+''' alias for Operator and Layers
+'''
+Operation = Operator
+''' import layer at the end to resolve circular import
+'''
+from singa import layer
+Linear = layer.Linear
+Conv2d = layer.Conv2d
+SeparableConv2d = layer.SeparableConv2d
+BatchNorm2d = layer.BatchNorm2d
+Pooling2d = layer.Pooling2d
+MaxPool2d = layer.MaxPool2d
+AvgPool2d = layer.AvgPool2d
+MaxPool1d = layer.MaxPool1d
+AvgPool1d = layer.AvgPool1d
+RNN_Base = layer.RNN_Base
+RNN = layer.RNN
+LSTM = layer.LSTM
diff --git a/python/singa/converter.py b/python/singa/converter.py
deleted file mode 100644
index 34954e1..0000000
--- a/python/singa/converter.py
+++ /dev/null
@@ -1,242 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-from __future__ import print_function
-from builtins import str
-from builtins import range
-from builtins import object
-from google.protobuf import text_format
-from singa import layer
-from singa import metric
-from singa import loss
-from singa import net as ffnet
-from .proto import model_pb2
-from .proto import caffe_pb2
-import numpy as np
-
-
-class CaffeConverter(object):
-
-    def __init__(self,
-                 net_proto,
-                 solver_proto=None,
-                 input_sample_shape=None,
-                 param_path=None):
-        self.caffe_net_path = net_proto
-        self.caffe_solver_path = solver_proto
-        self.input_sample_shape = input_sample_shape
-        self.param_path = param_path
-
-    def read_net_proto(self):
-        net_config = caffe_pb2.NetParameter()
-        return self.read_proto(self.caffe_net_path, net_config)
-
-    def read_solver_proto(self):
-        solver_config = caffe_pb2.SolverParameter()
-        return self.read_proto(self.caffe_solver_path, solver_config)
-
-    def read_caffemodel(self):
-        f = open(self.param_path, 'rb')
-        contents = f.read()
-        net_param = caffe_pb2.NetParameter()
-        net_param.ParseFromString(contents)
-        return net_param
-
-    def read_proto(self, filepath, parser_object):
-        file = open(filepath, "r")
-        if not file:
-            raise FileNotFoundError("ERROR (" + filepath + ")!")
-        # Merges an ASCII representation of a protocol message into a message.
-        text_format.Merge(str(file.read()), parser_object)
-        file.close()
-        return parser_object
-
-    def convert_engine(self, layer_conf, solver_mode):
-        '''
-        Convert caffe engine into singa engine
-        return:
-            a singa engine string
-        '''
-        caffe_engine = ''
-        singa_engine = ''
-
-        # if no 'engine' field in caffe proto, set engine to -1
-        if layer_conf.type == 'Convolution' or layer_conf.type == 4:
-            caffe_engine = layer_conf.convolution_param.engine
-        elif layer_conf.type == 'Pooling' or layer_conf.type == 17:
-            caffe_engine = layer_conf.pooling_param.engine
-        elif layer_conf.type == 'ReLU' or layer_conf.type == 18:
-            caffe_engine = layer_conf.relu_param.engine
-        elif layer_conf.type == 'Sigmoid' or layer_conf.type == 19:
-            caffe_engine = layer_conf.sigmoid_param.engine
-        elif layer_conf.type == 'TanH' or layer_conf.type == 23:
-            caffe_engine = layer_conf.tanh_param.engine
-        elif layer_conf.type == 'LRN' or layer_conf.type == 15:
-            caffe_engine = layer_conf.lrn_param.engine
-        elif layer_conf.type == 'Softmax' or layer_conf.type == 20:
-            caffe_engine = layer_conf.softmax_param.engine
-        elif layer_conf.type == 'InnerProduct' or layer_conf.type == 14:
-            caffe_engine = -1
-        elif layer_conf.type == 'Dropout' or layer_conf.type == 6:
-            caffe_engine = -1
-        elif layer_conf.type == 'Flatten' or layer_conf.type == 8:
-            caffe_engine = -1
-        else:
-            raise Exception('Unknown layer type: ' + layer_conf.type)
-
-        # caffe_engine: -1-no field;  0-DEFAULT; 1-CAFFE; 2-CUDNN
-        # solver_mode: 0-CPU; 1-GPU
-        if solver_mode == 1:
-            singa_engine = 'cudnn'
-        else:
-            if caffe_engine == 2:
-                raise Exception('engine and solver mode mismatch!')
-            else:
-                singa_engine = 'singacpp'
-
-        if ((layer_conf.type == 'InnerProduct' or layer_conf.type == 14) or \
-            (layer_conf.type == 'Flatten' or layer_conf.type == 8)) and \
-            singa_engine == 'cudnn':
-            singa_engine = 'singacuda'
-
-        return singa_engine
-
-    def create_net(self):
-        '''
-        Create singa net based on caffe proto files.
-            net_proto: caffe prototxt that describes net
-            solver_proto: caffe prototxt that describe solver
-            input_sample_shape: shape of input data tensor
-        return:
-            a FeedForwardNet object
-        '''
-        caffe_net = self.read_net_proto()
-        caffe_solver = None
-        if self.caffe_solver_path is not None:
-            caffe_solver = self.read_solver_proto()
-        layer_confs = ''
-        flatten_id = 0
-
-        # If the net proto has the input shape
-        if len(caffe_net.input_dim) > 0:
-            self.input_sample_shape = caffe_net.input_dim
-        if len(caffe_net.layer):
-            layer_confs = caffe_net.layer
-        elif len(caffe_net.layers):
-            layer_confs = caffe_net.layers
-        else:
-            raise Exception('Invalid proto file!')
-
-        net = ffnet.FeedForwardNet()
-        for i in range(len(layer_confs)):
-            if layer_confs[i].type == 'Data' or layer_confs[i].type == 5:
-                continue
-            elif layer_confs[i].type == 'Input':
-                self.input_sample_shape = layer_confs[i].input_param.shape[
-                    0].dim[1:]
-            elif layer_confs[i].type == 'SoftmaxWithLoss' or layer_confs[
-                    i].type == 21:
-                net.loss = loss.SoftmaxCrossEntropy()
-            elif layer_confs[i].type == 'EuclideanLoss' or layer_confs[
-                    i].type == 7:
-                net.loss = loss.SquaredError()
-            elif layer_confs[i].type == 'Accuracy' or layer_confs[i].type == 1:
-                net.metric = metric.Accuracy()
-            else:
-                strConf = layer_confs[i].SerializeToString()
-                conf = model_pb2.LayerConf()
-                conf.ParseFromString(strConf)
-                if caffe_solver:
-                    layer.engine = self.convert_engine(layer_confs[i],
-                                                       caffe_solver.solver_mode)
-                else:
-                    # if caffe_solver is None,
-                    layer.engine = self.convert_engine(layer_confs[i], 0)
-                lyr = layer.Layer(conf.name, conf)
-                if len(net.layers) == 0:
-                    print('input sample shape: ', self.input_sample_shape)
-                    lyr.setup(self.input_sample_shape)
-                    print(lyr.name, lyr.get_output_sample_shape())
-                if layer_confs[i].type == 'InnerProduct' or layer_confs[
-                        i].type == 14:
-                    net.add(layer.Flatten('flat' + str(flatten_id)))
-                    flatten_id += 1
-                net.add(lyr)
-
-        return net
-
-    def convert_params(self, net):
-        '''
-        Convert params in .caffemodel into singa model.
-        This method only supports current version of Caffe(24-Nov-2016).
-        '''
-
-        params = net.param_values()
-        caffe_model = self.read_caffemodel()
-        layers = None
-        if len(caffe_model.layer):
-            layers = caffe_model.layer
-        else:
-            raise Exception('Invalid proto file!')
-
-        i = 0
-        first_conv = True
-        for layer in layers:
-            if layer.type == 'Convolution' or layer.type == 'InnerProduct':
-                assert (len(layer.blobs) == 2), 'Either 2 params per layer or 0'
-                wmat_dim = []
-                if getattr(layer.blobs[0].shape, 'dim', None) is not None:
-                    if len(layer.blobs[0].shape.dim) > 0:
-                        wmat_dim = layer.blobs[0].shape.dim
-                    else:
-                        wmat_dim = [layer.blobs[0].num, \
-                                layer.blobs[0].channels, \
-                                layer.blobs[0].height, \
-                                layer.blobs[0].width]
-                else:
-                    wmat_dim = list(layer.blobs[0].shape)
-
-                wmat = np.array(layer.blobs[0].data, dtype=np.float32)
-                bias = np.array(layer.blobs[1].data, dtype=np.float32)
-                #print layer.name, ' wmat_dim: ', wmat_dim
-
-                wdim = []
-                if layer.type == 'InnerProduct':
-                    wdim = wmat_dim[-2:]
-                else:
-                    if wmat_dim[1] == 3 and first_conv:  # BGR -> RGB
-                        wmat = wmat.reshape(wmat_dim)
-                        wmat[:, [0, 1, 2], :, :] = wmat[:, [2, 1, 0], :, :]
-                        first_conv = False
-                    nb_filters = wmat_dim[0]
-                    chw = 1
-                    for k in range(1, len(wmat_dim)):
-                        chw *= wmat_dim[k]
-                    wdim.extend([nb_filters, chw])
-                #print layer.name, ' wdim: ', wdim
-                w = np.reshape(wmat, wdim)
-
-                # TODO(wangwei) transpose SINGA's weight following caffe
-                if layer.type == 'InnerProduct':
-                    w = np.transpose(w)
-                params[i].copy_from_numpy(w)
-                i += 1
-                params[i].copy_from_numpy(bias)
-                i += 1
-                print(
-                    'converting layer {0}, wmat shape = {1}, bias shape = {2}'.
-                    format(layer.name, w.shape, bias.shape))
diff --git a/python/singa/device.py b/python/singa/device.py
index c1dd837..cfc3eb8 100644
--- a/python/singa/device.py
+++ b/python/singa/device.py
@@ -22,39 +22,10 @@
 TODO(wangwei) implement py CudaGPU class.
 '''
 
-from builtins import object
+# from builtins import object
 from . import singa_wrap as singa
 
 
-class Device(object):
-    """ Class and member functions for singa::Device.
-
-    Create Device instances using the CreateXXXDevice.
-    """
-
-    def __init__(self, id, device):
-        """Device constructor given device ID.
-
-        Args:
-            id (int): device ID.
-            device: swig shared_ptr<Device>
-        """
-        self.id = id
-        self.singa_device = device
-
-    def set_rand_seed(self, seed):
-        self.singa_device.SetRandSeed(seed)
-
-    def enable_graph(self, enable):
-        self.singa_device.EnableGraph(enable)
-
-    def get_host(self):
-        return self.singa_device.host()
-
-    def get_id(self):
-        return self.singa_device.id()
-
-
 def get_num_gpus():
     assert singa.USE_CUDA, 'SINGA has not been compiled with CUDA enabled.'
     return singa.Platform.GetNumGPUs()
@@ -85,6 +56,15 @@
     return singa.Platform.DeviceQuery(id, verbose)
 
 
+def create_cpu_device():
+    '''Create the default CPU device.
+
+    Returns:
+        a swig converted CPU device.
+    '''
+    return singa.Platform.GetDefaultDevice()
+
+
 def create_cuda_gpus(num):
     '''Create a list of CudaGPU devices.
 
@@ -97,17 +77,14 @@
     return singa.Platform.CreateCudaGPUs(num)
 
 
-def create_cuda_gpu(set_default=True):
+def create_cuda_gpu():
     '''Create a single CudaGPU device.
 
     Returns:
         a swig converted CudaGPU device.
     '''
     assert singa.USE_CUDA, 'SINGA has not been compiled with CUDA enabled.'
-    devices = singa.Platform.CreateCudaGPUs(1)
-    if set_default is True:
-        set_default_device(devices[0])
-    return devices[0]
+    return create_cuda_gpu_on(0)
 
 
 def create_cuda_gpus_on(device_ids):
@@ -123,7 +100,7 @@
     return singa.Platform.CreateCudaGPUsOn(device_ids)
 
 
-def create_cuda_gpu_on(device_id, set_default=True):
+def create_cuda_gpu_on(device_id):
     '''Create a CudaGPU device on the given device ID.
 
     Args:
@@ -134,8 +111,6 @@
     '''
     assert singa.USE_CUDA, 'SINGA has not been compiled with CUDA enabled.'
     devices = create_cuda_gpus_on([device_id])
-    if set_default is True:
-        set_default_device(devices[0])
     return devices[0]
 
 
@@ -149,18 +124,13 @@
     return singa.Platform.GetDefaultOpenclDevice()
 
 
-Device.default_device = singa.Platform.GetDefaultDevice()
+default_device = singa.Platform.GetDefaultDevice()
 
 
 def get_default_device():
     '''Get the default host device which is a CppCPU device'''
-    return Device.default_device
+    return default_device
 
 
-def set_default_device(device):
-    '''Set the Device class static variable default_device'''
-    Device.default_device = device
-
-
-def enbale_lazy_alloc(enable):
+def enable_lazy_alloc(enable):
     singa.Device.EnableLazyAlloc(enable)
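
A minimal usage sketch (not part of the patch) for the reworked device API in this hunk; it assumes a standard SINGA build and only uses the functions shown above:

    from singa import device, tensor

    host = device.create_cpu_device()      # the default CppCPU device
    x = tensor.Tensor((2, 3), host)
    x.set_value(1.0)

    # GPU creation no longer mutates a global default; devices are passed explicitly.
    # gpu = device.create_cuda_gpu_on(0)   # requires a CUDA-enabled build
    # x.to_device(gpu)
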
diff --git a/python/singa/initializer.py b/python/singa/initializer.py
index cb2f5a0..c907736 100644
--- a/python/singa/initializer.py
+++ b/python/singa/initializer.py
@@ -17,44 +17,170 @@
 # =============================================================================
 '''Popular initialization methods for parameter values (Tensor objects).
 
+credit: this module is adapted from keras
+https://github.com/keras-team/keras/blob/master/keras/initializers.py
+
+All functions in this module change the input tensor in-place.
+
 Example usages::
 
     from singa import tensor
     from singa import initializer
 
     x = tensor.Tensor((3, 5))
-    initializer.uniform(x, 3, 5) # use both fan_in and fan_out
-    initializer.uniform(x, 3, 0)  # use only fan_in
+    initializer.he_uniform(x)
+    initializer.glorot_normal(x)
 '''
+
 from __future__ import division
 import math
+import numpy as np
+from deprecated import deprecated
 
 
-def uniform(t, fan_in=0, fan_out=0):
+def eye(t):
+    """Initialize the tensor with ones on the diagonal and zeros elsewhere.
+
+    Note: it is implemented by calling numpy. 
+    Do not call it within forward propagation when computation graph is enabled.
+
+    # Arguments
+        t(Tensor): the matrix to be filled in.
+    """
+    if len(t.shape) != 2:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+    a = np.eye(t.shape[0], t.shape[1], dtype=np.float32)
+    t.copy_from(a)
+
+
+def orthogonal(t, gain=1.0):
+    """Initializer that generates a random orthogonal matrix.
+
+    Note: it is implemented by calling numpy. 
+    Do not call it within forward propagation when computation graph is enabled.
+
+    # Arguments
+        t(Tensor): the matrix to be filled in.
+        gain: Multiplicative factor to apply to the orthogonal matrix.
+
+    # References
+        - [Exact solutions to the nonlinear dynamics of learning in deep
+           linear neural networks](http://arxiv.org/abs/1312.6120)
+    """
+    if len(t.shape) != 2:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+
+    a = np.random.normal(0.0, 1.0, t.shape).astype(np.float32)
+    u, _, v = np.linalg.svd(a, full_matrices=False)
+    # Pick the one with the correct shape.
+    q = u if u.shape == t.shape else v
+    q *= gain
+    t.copy_from(q)
+
+
+def lecun_uniform(t):
+    """LeCun uniform initializer.
+
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(3 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor):the tensor to be filled in.
+
+    # References
+        - [Efficient BackProp](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_in', distribution='uniform')
+
+
+def glorot_normal(t):
+    """Glorot normal initializer, also called Xavier normal initializer.
+
+    It draws samples from a normal distribution centered on 0
+    with `stddev = sqrt(2 / (fan_in + fan_out))`
+    where `fan_in` is the number of input units in the weight tensor
+    and `fan_out` is the number of output units in the weight tensor.
+
+    # Arguments
+        t(Tensor):the tensor to be filled in.
+
+    # References
+        - [Understanding the difficulty of training deep feedforward neural
+           networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_avg', distribution='normal')
+
+
+def glorot_uniform(t):
+    """Glorot uniform initializer, also called Xavier uniform initializer.
+
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(6 / (fan_in + fan_out))`
+    where `fan_in` is the number of input units in the weight tensor
+    and `fan_out` is the number of output units in the weight tensor.
+
+    # Arguments
+        t(Tensor):the tensor to be filled in.
+    # References
+        - [Understanding the difficulty of training deep feedforward neural
+           networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_avg', distribution='uniform')
+
+
+def he_normal(t):
+    """He normal initializer.
+
+    It draws samples from a truncated normal distribution centered on 0
+    with `stddev = sqrt(2 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor):the tensor to be filled in.
+
+    # References
+        - [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+           ImageNet Classification](http://arxiv.org/abs/1502.01852)
+    """
+    _random_fill(t, scale=2., mode='fan_in', distribution='normal')
+
+def lecun_normal(t):
+    """LeCun normal initializer.
+
+    It draws samples from a truncated normal distribution centered on 0
+    with `stddev = sqrt(1 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor):the tensor to be filled in.
+
+    # References
+        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+        - [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_in', distribution='normal')
+
+
+def he_uniform(t):
     '''Initialize the values of the input tensor following a uniform
     distribution with specific bounds.
 
-    Args:
-        fan_in(int): for the weight Tensor of a convolution layer,
-            fan_in = nb_channel * kh * kw; for dense layer,
-            fan_in = input_feature_length
-        fan_out(int): for the convolution layer weight Tensor,
-            fan_out = nb_filter * kh * kw; for the weight Tensor of a dense
-            layer, fan_out = output_feature_length
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(6 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
 
-    Ref: [Bengio and Glorot 2010]: Understanding the difficulty of
-    training deep feedforward neuralnetworks.
+    # Arguments
+        t(Tensor): the tensor to be filled in.
 
+    # References
+        - [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+           ImageNet Classification](http://arxiv.org/abs/1502.01852)
     '''
-    assert fan_in > 0 or fan_out > 0, \
-        'fan_in and fan_out cannot be 0 at the same time'
-    avg = 2
-    if fan_in * fan_out == 0:
-        avg = 1
-    x = math.sqrt(3.0 * avg / (fan_in + fan_out))
-    t.uniform(-x, x)
+    _random_fill(t, scale=2., mode='fan_in', distribution='uniform')
 
 
+@deprecated(reason="Use he_normal or glorot_normal")
 def gaussian(t, fan_in=0, fan_out=0):
     '''Initialize the values of the input tensor following a Gaussian
     distribution with specific std.
@@ -79,12 +205,11 @@
     t.gaussian(0, std)
 
 
+@deprecated(reason="Use glorot_normal")
 def xavier(t):
     '''Initialize the matrix parameter follow a Uniform distribution from
     [-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))].
 
-    Deprecated. Please use uniform()
-
     Args:
         t (Tensor): the parater tensor
     '''
@@ -93,12 +218,11 @@
     t.uniform(-scale, scale)
 
 
+@deprecated(reason="Use glorot_uniform")
 def glorot(t):
     '''Initialize the matrix parameter follow a Gaussian distribution with
     mean = 0 and std = sqrt(2.0 / (nb_row + nb_col))
 
-    Deprecated. Please use gaussian()
-
     Args:
         t (Tensor): the parater tensor
     '''
@@ -107,12 +231,11 @@
     t *= scale
 
 
+@deprecated(reason="Use he_normal")
 def msra(t):
     '''Initialize the matrix parameter follow a Guassian distribution with
     mean = 0, std = math.sqrt(2.0 / nb_row).
 
-    Deprecated. Please use gaussian()
-
     Ref [He, Zhang, Ren and Sun 2015]: Specifically accounts for ReLU
     nonlinearities.
 
@@ -120,3 +243,94 @@
         t (Tensor): the parater tensor
     '''
     t.gaussian(0, math.sqrt(2.0 / t.shape[0]))
+
+
+def _compute_fans(shape, data_format='channels_first'):
+    """Computes the number of input and output units for a weight shape.
+    # Arguments
+        shape: Integer shape tuple.
+        data_format: Image data format to use for convolution kernels.
+            Note that all kernels in Keras are standardized on the
+            `channels_last` ordering (even when inputs are set
+            to `channels_first`).
+    # Returns
+        A tuple of scalars, `(fan_in, fan_out)`.
+    # Raises
+        ValueError: in case of invalid `data_format` argument.
+    """
+    if len(shape) == 2:
+        fan_in = shape[0]
+        fan_out = shape[1]
+    elif len(shape) in {3, 4, 5}:
+        # Assuming convolution kernels (1D, 2D or 3D).
+        # TH kernel shape: (depth, input_depth, ...)
+        # TF kernel shape: (..., input_depth, depth)
+        if data_format == 'channels_first':
+            receptive_field_size = np.prod(shape[2:])
+            fan_in = shape[1] * receptive_field_size
+            fan_out = shape[0] * receptive_field_size
+        elif data_format == 'channels_last':
+            receptive_field_size = np.prod(shape[:-2])
+            fan_in = shape[-2] * receptive_field_size
+            fan_out = shape[-1] * receptive_field_size
+        else:
+            raise ValueError('Invalid data_format: ' + data_format)
+    else:
+        # No specific assumptions.
+        fan_in = np.sqrt(np.prod(shape))
+        fan_out = np.sqrt(np.prod(shape))
+    return fan_in, fan_out
+
+
+def _random_fill(t, scale, mode, distribution):
+    """Fill the tensor with values sampled from a distribution.
+
+    With `distribution="normal"`, samples are drawn from a normal
+    distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
+        - number of input units in the weight tensor, if mode = "fan_in"
+        - number of output units, if mode = "fan_out"
+        - average of the numbers of input and output units, if mode = "fan_avg"
+
+    With `distribution="uniform"`,
+    samples are drawn from a uniform distribution
+    within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
+
+
+    Args:
+        t (Tensor): Tensor to be filled
+        scale (float): scale factor  
+        mode (str): "fan_in" or "fan_out" or "fan_avg" 
+        distribution (str): "normal" or "uniform" 
+
+    Raises:
+        ValueError: In case of an invalid value for scale, mode or distribution 
+    """
+    if scale <= 0.:
+        raise ValueError('`scale` must be a positive float. Got:', scale)
+    mode = mode.lower()
+    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
+        raise ValueError(
+            'Invalid `mode` argument: '
+            'expected one of {"fan_in", "fan_out", "fan_avg"} '
+            'but got', mode)
+    distribution = distribution.lower()
+    if distribution not in {'normal', 'uniform'}:
+        raise ValueError(
+            'Invalid `distribution` argument: '
+            'expected one of {"normal", "uniform"} '
+            'but got', distribution)
+
+    fan_in, fan_out = _compute_fans(t.shape)
+    if mode == 'fan_in':
+        scale /= max(1., fan_in)
+    elif mode == 'fan_out':
+        scale /= max(1., fan_out)
+    else:
+        scale /= max(1., float(fan_in + fan_out) / 2)
+    if distribution == 'normal':
+        # 0.879... = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+        # stddev = np.sqrt(scale) / .87962566103423978
+        t.gaussian(0., np.sqrt(scale))
+    else:
+        limit = np.sqrt(3. * scale)
+        t.uniform(-limit, limit)
\ No newline at end of file
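
A short sketch of the new Keras-style initializers added above; every call fills the tensor in place. It assumes the `deprecated` package is installed, as required by the imports in this hunk, and uses only functions defined here:

    from singa import tensor, initializer

    w = tensor.Tensor((64, 128))       # 2-D weight: fan_in=64, fan_out=128
    initializer.he_uniform(w)          # uniform in [-limit, limit], limit = sqrt(6 / fan_in)
    initializer.glorot_normal(w)       # normal with stddev = sqrt(2 / (fan_in + fan_out))

    eye_w = tensor.Tensor((4, 4))
    initializer.eye(eye_w)             # identity fill via numpy; avoid inside a graph-enabled forward
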
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 0041e03..e5abea7 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -8,1443 +8,1617 @@
 #
 #   http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
 # =============================================================================
-""" Python layers wrap the C++ layers to provide simpler construction APIs.
 
-Example usages::
+import math
+from functools import wraps
+from collections import OrderedDict
 
-    from singa import layer
-    from singa import tensor
-    from singa import device
+from singa import utils
+from .tensor import Tensor
+from . import singa_wrap as singa
 
-    layer.engine = 'cudnn'  # to use cudnn layers
-    dev = device.create_cuda_gpu()
 
-    # create a convolution layer
-    conv = layer.Conv2D('conv', 32, 3, 1, pad=1, input_sample_shape=(3, 32, 32))
+class LayerMeta(type):
 
-    # init param values
-    w, b = conv.param_values()
-    w.guassian(0, 0.01)
-    b.set_value(0)
-    conv.to_device(dev)  # move the layer data onto a CudaGPU device
+    def init_wrapper(func):
 
-    x = tensor.Tensor((3, 32, 32), dev)
-    x.uniform(-1, 1)
-    y = conv.foward(True, x)
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if len(args) == 0:
+                return
 
-    dy = tensor.Tensor()
-    dy.reset_like(y)
-    dy.set_value(0.1)
-    # dp is a list of tensors for parameter gradients
-    dx, dp = conv.backward(kTrain, dy)
-"""
-from __future__ import division
-from __future__ import absolute_import
+            if isinstance(args[0], list):
+                assert len(args) > 0 and isinstance(args[0][0], Tensor), (
+                    'initialize function expects PlaceHolders or Tensors')
+                dev = args[0][0].device
+            else:
+                assert len(args) > 0 and isinstance(args[0], Tensor), (
+                    'initialize function expects PlaceHolders or Tensors')
+                dev = args[0].device
 
-from builtins import str
-from builtins import range
-from builtins import object
-from builtins import set
+            prev_state = dev.graph_enabled()
+            dev.EnableGraph(False)
+            func(self, *args, **kwargs)
+            self._initialized = True
+            dev.EnableGraph(prev_state)
 
-from . import singa_wrap
-from .proto import model_pb2
-from . import tensor
+        return wrapper
 
-engine = 'cudnn'
-'''engine is the prefix of layer identifier.
+    def forward_wrapper(func):
 
-The value could be one of [**'cudnn', 'singacpp', 'singacuda', 'singacl'**], for
-layers implemented using the cudnn library, Cpp, Cuda and OpenCL respectively.
-For example, CudnnConvolution layer is identified by 'cudnn_convolution';
-'singacpp_convolution' is for Convolution layer;
-Some layers' implementation use only Tensor functions, thererfore they are
-transparent to the underlying devices. For threse layers, they would have
-multiple identifiers, e.g., singacpp_dropout, singacuda_dropout and
-singacl_dropout are all for the Dropout layer. In addition, it has an extra
-identifier 'singa', i.e. 'singa_dropout' also stands for the Dropout layer.
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if not self._initialized:
+                self.initialize(*args, **kwargs)
+                self._initialized = True
+            return func(self, *args, **kwargs)
 
-engine is case insensitive. Each python layer would create the correct specific
-layer using the engine attribute.
+        return wrapper
+
+    def __new__(cls, name, bases, attr):
+        if 'initialize' in attr:
+            attr['initialize'] = LayerMeta.init_wrapper(attr['initialize'])
+        if 'forward' in attr:
+            attr['forward'] = LayerMeta.forward_wrapper(attr['forward'])
+
+        return super(LayerMeta, cls).__new__(cls, name, bases, attr)
+
+
+class Layer(object, metaclass=LayerMeta):
+
+    sep = '.'
+
+    def __init__(self):
+        self.name = None
+        self._initialized = False
+        self._parent = None
+        self._layers = dict()
+
+    def initialize(self, *input):
+        """ Initialize the layer
+
+        This function will be called before the forward function if this
+        layer hasn't been initialized. Those members that need to be
+        initialized according to the input will be initialized in this
+        function. e.g. parameters, states and handles.
+
+        Args:
+            *input: input args, should be consistent with the forward function
+        """
+        pass
+
+    def forward(self, *input):
+        """ Forward propagation
+
+        Args:
+            *input: input arguments consisting of only PyTensors
+        Returns:
+            PyTensor instance(s)
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def get_params(self):
+        """ Get parameters of this layer and all sublayers
+
+        Returns:
+            parameters(dict): A dictionary that contains parameter names
+            and values of this layer and all sublayers.
+        """
+        params = dict()
+        sublayers = self._layers
+        for name, sublayer in sublayers.items():
+            if sublayer._initialized:
+                params.update(sublayer.get_params())
+        return params
+
+    def set_params(self, parameters):
+        """ Set parameters for this layer and all sublayers
+
+        Args:
+            parameters(dict): A dictionary that contains parameter names
+            and corresponding values. The value should be either a
+            PyTensor or numpy ndarray
+        """
+        names = parameters.keys()
+        sublayers = self._layers
+        for name, sublayer in sublayers.items():
+            if sublayer._initialized:
+                if self._has_layer_param(sublayer, names):
+                    sublayer.set_params(parameters)
+
+    def get_states(self):
+        """ Get states of this layer and all sublayers
+
+        Returns:
+            states(dict): A dictionary that contains state names and values
+            of this layer and all sublayers.
+        """
+        states = dict()
+        sublayers = self._layers
+        for name, sublayer in sublayers.items():
+            if sublayer._initialized:
+                states.update(sublayer.get_states())
+        states.update(self.get_params())
+        return states
+
+    def set_states(self, states):
+        """ Set states for this layer and all sublayers
+
+        Args:
+            states(dict): A dictionary that contains state names and
+            corresponding values. The value should be either a
+            PyTensor or numpy ndarray
+        """
+        names = states.keys()
+        sublayers = self._layers
+        for name, sublayer in sublayers.items():
+            if sublayer._initialized:
+                if self._has_layer_param(sublayer, names):
+                    sublayer.set_states(states)
+        self.set_params(states)
+
+    def device_check(self, *inputs):
+        """ Check if the devices of the input tensor are the same.
+
+        Keep the device where each tensor is located the same as the
+        first tensor. Copy data to the device of the first tensor if
+        the device does not match.
+
+        Args:
+            *inputs: input args consisting of only PyTensors
+        """
+        # disable the graph to avoid buffering the data-transfer operations
+        x_device = inputs[0].device
+        prev_state = x_device.graph_enabled()
+        x_device.EnableGraph(False)
+        x_dev_id = x_device.id()
+        for var in inputs:
+            if var.device.id() != x_dev_id:
+                var.to_device(x_device)
+        x_device.EnableGraph(prev_state)
+
+    def _has_layer_param(self, layer, names):
+        """ Determine whether names contains parameter names in the layer
+
+        Args:
+            layer(Layer): the layer instance
+            names(list): the list of parameter names
+
+        Returns:
+            boolean: whether names contains parameter names in that layer
+        """
+        for name in names:
+            if name.startswith(layer.name):
+                return True
+        return False
+
+    def _get_name_prefix(self):
+        """ Get the name prefix
+
+        Returns:
+            prefix(str): the layer or param name prefix
+        """
+        if self.name and self._parent:
+            return self.name + Layer.sep
+        else:
+            return ''
+
+    def __getattr__(self, name):
+        if '_layers' in self.__dict__:
+            layers = self.__dict__['_layers']
+            if name in layers:
+                return layers[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, name))
+
+    def __setattr__(self, name, value):
+        if isinstance(value, Layer):
+            # TODO: remove the attr from dict first
+            self.__dict__['_layers'][name] = value
+            value.__dict__['_parent'] = self
+            value.name = self._get_name_prefix() + name
+        else:
+            object.__setattr__(self, name, value)
+            if isinstance(value, Tensor) and value.is_dummy():
+                # WARN: If tensors are initialized in __init__ function
+                #       their names may be incorrect and should be reset
+                value.name = self._get_name_prefix() + name
+            elif name == 'name' and value:
+                # WARN: can't reset the name after the initialization
+                # update sublayer name
+                for name, sublayer in self._layers.items():
+                    sublayer.name = self._get_name_prefix() + name
+
+    def __delattr__(self, name):
+        if name in self._layers:
+            del self._layers[name]
+        else:
+            object.__delattr__(self, name)
+
+    def register_layers(self, *args):
+        """ Register a list of sublayers.
+
+        Can only be called once in each subclass.
+
+        Args:
+            *args: a list of sublayers or a dictionary that contains
+            the name and the instance of each sublayer
+        """
+        if len(args) == 1 and isinstance(args[0], OrderedDict):
+            items = args[0].items()
+        else:
+            items = [(v.__class__.__name__ + '_' + str(idx), v)
+                     for idx, v in enumerate(args)]
+
+        for name, value in items:
+            if isinstance(value, Layer):
+                self._layers[name] = value
+                value.__dict__['_parent'] = self
+                value.name = name
+
+
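
A sketch (not part of the patch) of the Layer contract introduced above: sublayers assigned as attributes are auto-registered, initialize() runs lazily on the first forward() call, and get_params() aggregates over the initialized sublayers. TinyBlock is a hypothetical subclass used only for illustration:

    from singa import tensor, layer

    class TinyBlock(layer.Layer):

        def __init__(self):
            super(TinyBlock, self).__init__()
            self.fc1 = layer.Linear(32)   # tracked in self._layers via __setattr__
            self.fc2 = layer.Linear(10)

        def forward(self, x):
            return self.fc2(self.fc1(x))

    blk = TinyBlock()
    x = tensor.Tensor((4, 64))
    x.gaussian(0.0, 1.0)
    y = blk(x)                            # lazy initialization happens here
    params = blk.get_params()             # W and b of both Linear sublayers
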
+class Linear(Layer):
+    """
+    Generate a Linear operator
+    """
+
+    # TODO: replace current with
+    #   def __init__(self, out_features, bias=True):
+    def __init__(self, out_features, *args, bias=True, **kwargs):
+        """
+        Args:
+            out_features: int, the number of output features
+            bias: bool, add a learnable bias if True
+        """
+        super(Linear, self).__init__()
+
+        self.out_features = out_features
+
+        # TODO: for backward compatibility, to remove
+        if len(args) > 0:
+            self.in_features = out_features
+            self.out_features = args[0]
+        if len(args) > 1:
+            self.bias = args[1]
+        else:
+            self.bias = bias
+
+    def initialize(self, x):
+        self.in_features = x.shape[1]
+        w_shape = (self.in_features, self.out_features)
+        b_shape = (self.out_features,)
+
+        self.W = Tensor(shape=w_shape, requires_grad=True, stores_grad=True)
+        std = math.sqrt(2.0 / (self.in_features + self.out_features))
+        self.W.gaussian(0.0, std)
+
+        if self.bias:
+            self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
+            self.b.set_value(0.0)
+        else:
+            self.b = None
+
+    def forward(self, x):
+        if self.b:
+            self.device_check(x, self.W, self.b)
+        else:
+            self.device_check(x, self.W)
+
+        assert x.shape[1] == self.W.shape[0], (
+            "Linear layer expects input features size %d received %d" %
+            (self.W.shape[0], x.shape[1]))
+
+        y = autograd.matmul(x, self.W)
+        if self.bias:
+            y = autograd.add_bias(y, self.b, axis=0)
+        return y
+
+    def get_params(self):
+        if self.bias:
+            return {self.W.name: self.W, self.b.name: self.b}
+        else:
+            return {self.W.name: self.W}
+
+    def set_params(self, parameters):
+        self.W.copy_from(parameters[self.W.name])
+        if self.bias:
+            self.b.copy_from(parameters[self.b.name])
+
+
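
A sketch of the two Linear constructor forms handled above; the single-argument form is the new style, and in_features is inferred from the first input:

    from singa import tensor, layer

    x = tensor.Tensor((8, 32))
    x.gaussian(0.0, 1.0)

    fc = layer.Linear(16)            # new style: in_features inferred as 32
    y = fc(x)                        # y has shape (8, 16)

    fc_old = layer.Linear(32, 16)    # old style kept for compatibility: (in_features, out_features)
    y_old = fc_old(x)
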
+class Gemm(Layer):
+    """
+    Generate a Gemm operator
+    Y = alpha * A' * B' + beta * C
+    B is weight, C is bias
+    """
+
+    def __init__(self,
+                 nb_kernels,
+                 alpha=1.0,
+                 beta=1.0,
+                 transA=False,
+                 transB=True,
+                 bias=True,
+                 bias_shape=None):
+        """
+        Args:
+            nb_kernels: int, the number of output features
+                (the second dimension of Y)
+            alpha (float): Scalar multiplier for the product of input tensors A * B.
+            beta (float): Scalar multiplier for input tensor C.
+            transA (bool): Whether A should be transposed
+            transB (bool): Whether B should be transposed
+            bias: bool
+        """
+        super(Gemm, self).__init__()
+        self.nb_kernels = nb_kernels
+        self.alpha = alpha
+        self.beta = beta
+        self.transA = 1 if transA else 0
+        self.transB = 1 if transB else 0
+        self.bias = bias
+        self.bias_shape = bias_shape
+
+    def initialize(self, x):
+        if self.transA == 0:
+            self.in_features = x.shape[-1]
+        else:
+            self.in_features = x.shape[0]
+
+        if self.transB == 0:
+            w_shape = (self.in_features, self.nb_kernels)
+        else:
+            w_shape = (self.nb_kernels, self.in_features)
+
+        if self.bias_shape:
+            b_shape = self.bias_shape
+        else:
+            b_shape = (1, self.nb_kernels)
+
+        self.W = Tensor(shape=w_shape,
+                        requires_grad=True,
+                        stores_grad=True,
+                        device=x.device)
+        std = math.sqrt(2.0 / (self.in_features + self.nb_kernels))
+        self.W.gaussian(0.0, std)
+
+        if self.bias:
+            self.b = Tensor(shape=b_shape,
+                            requires_grad=True,
+                            stores_grad=True,
+                            device=x.device)
+            self.b.set_value(0.0)
+        else:
+            self.b = None
+
+    def forward(self, x):
+        if self.b:
+            self.device_check(x, self.W, self.b)
+        else:
+            self.device_check(x, self.W)
+
+        if self.transA == 0:
+            in_features = x.shape[-1]
+        else:
+            in_features = x.shape[0]
+
+        if self.transB == 0:
+            in_features_w = self.W.shape[0]
+        else:
+            in_features_w = self.W.shape[-1]
+
+        assert in_features == in_features_w, (
+            "Gemm layer expects input features size %d received %d" %
+            (in_features_w, in_features))
+        y = autograd.gemm(x, self.W, self.b, self.alpha, self.beta, self.transA,
+                          self.transB)
+
+        return y
+
+    def get_params(self):
+        if self.bias:
+            return {self.W.name: self.W, self.b.name: self.b}
+        else:
+            return {self.W.name: self.W}
+
+    def set_params(self, parameters):
+        self.W.copy_from(parameters[self.W.name])
+        if self.bias:
+            self.b.copy_from(parameters[self.b.name])
+
+
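
A sketch of the Gemm layer above, which computes Y = alpha * A' * B' + beta * C with the weight B transposed by default (transB=True); shapes below are illustrative only:

    from singa import tensor, layer

    a = tensor.Tensor((4, 32))
    a.gaussian(0.0, 1.0)

    gemm = layer.Gemm(16, alpha=1.0, beta=1.0)   # W is created with shape (16, 32) at first use
    y = gemm(a)                                  # y has shape (4, 16)
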
+class Embedding(Layer):
+    """
+    Generate an Embedding operator
+    """
+
+    def __init__(self, input_dim, output_dim, initializer="gaussian"):
+        """init the Embedding operator
+        Args:
+            input_dim (int): the number of different words in the dictionary
+            output_dim (int): the dimension of a word after the embedding
+            initializer (str, optional): weight initializer, can be [uniform, gaussian]. Defaults to "gaussian".
+        """
+        super(Embedding, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.initializer = initializer
+
+    def initialize(self, x):
+        w_shape = (self.input_dim, self.output_dim)
+        self.W = Tensor(shape=w_shape,
+                        requires_grad=True,
+                        stores_grad=True,
+                        device=x.device)
+        if self.initializer == 'uniform':
+            self.W.uniform(-1., 1.)
+        else:
+            self.W.gaussian(0., 1.)
+
+    def from_pretrained(self, W, freeze=True):
+        self.set_params({self.W.name: W})
+        self.W.requires_grad = not freeze
+
+    def forward(self, x):
+        return autograd.embedding(x, self.W)
+
+    def get_params(self):
+        return {self.W.name: self.W}
+
+    def set_params(self, parameters):
+        self.W.copy_from(parameters[self.W.name])
+
+
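
A sketch of the Embedding layer above; from_pretrained() goes through set_params(), which per its docstring accepts tensors or numpy arrays. The sizes below are hypothetical:

    import numpy as np
    from singa import tensor, layer

    ids = tensor.Tensor((2, 5))     # word ids held in a float tensor
    ids.set_value(0.0)

    emb = layer.Embedding(input_dim=100, output_dim=8)
    vecs = emb(ids)                 # W of shape (100, 8) is created here

    pretrained = np.random.rand(100, 8).astype(np.float32)
    emb.from_pretrained(pretrained, freeze=True)   # W.requires_grad becomes False
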
+class Conv2d(Layer):
+    """
+    Generate a Conv 2d operator
+    """
+
+    def __init__(self,
+                 nb_kernels,
+                 kernel_size,
+                 *args,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 bias=True,
+                 pad_mode="NOTSET",
+                 activation="NOTSET",
+                 **kwargs):
+        """
+        Args:
+            nb_kernels (int): the number of output channels, i.e. the number of filters
+            kernel_size (int or tuple): kernel size for the two spatial axes,
+                e.g. (2, 3) uses a kernel of height 2 and width 3; if an int
+                is given, the kernel size is taken as (int, int)
+            stride (int or tuple): stride, following the same logic as
+                kernel_size.
+            padding (int, tuple, list or None): padding for the two spatial
+                axes, following the same logic as kernel_size. If pad_mode is
+                "SAME_UPPER" or "SAME_LOWER", padding may be None and the
+                padding will be computed automatically.
+            dilation (int): only support 1
+            group (int): group
+            bias (bool): bias
+            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where
+                default value is NOTSET, which means explicit padding is used.
+                SAME_UPPER or SAME_LOWER mean pad the input so that the output
+                spatial size matches the input. If the padding amount is odd, the extra
+                padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER.
+            activation (string): can be NOTSET, RELU, where default value is NOTSET,
+                which means there is no activation behind the conv2d layer.
+                RELU means there is a ReLU behind current conv2d layer.
+        """
+        super(Conv2d, self).__init__()
+
+        # the old code create the layer like: Conv2d(8, 16, 3), or Conv2d(8, 16, 3, stride=1)
+        # the following code block is for backward compatibility
+        if len(args) > 0:
+            nb_kernels = kernel_size
+            kernel_size = args[0]
+        if len(args) > 1:
+            stride = args[1]
+        if len(args) > 2:
+            padding = args[2]
+
+        self.nb_kernels = nb_kernels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.group = group
+        self.bias = bias
+        self.pad_mode = pad_mode
+        self.activation = activation
+
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        elif isinstance(kernel_size, tuple):
+            self.kernel_size = kernel_size
+        else:
+            raise TypeError("Wrong kernel_size type.")
+
+        if isinstance(stride, int):
+            self.stride = (stride, stride)
+        elif isinstance(stride, tuple):
+            self.stride = stride
+        else:
+            raise TypeError("Wrong stride type.")
+
+        self.odd_padding = (0, 0, 0, 0)
+        if isinstance(padding, int):
+            self.padding = (padding, padding)
+        elif isinstance(padding, tuple) or isinstance(padding, list):
+            if len(padding) == 2:
+                self.padding = padding
+            elif len(padding) == 4:
+                _h_mask = padding[0] - padding[1]
+                _w_mask = padding[2] - padding[3]
+                # the odd padding is the part that cannot be handled by the tuple padding (h, w) mode,
+                # so we first pad the input, then apply the normal padding.
+                self.odd_padding = (max(_h_mask, 0), max(-_h_mask, 0),
+                                    max(_w_mask, 0), max(-_w_mask, 0))
+                self.padding = (
+                    padding[0] - self.odd_padding[0],
+                    padding[2] - self.odd_padding[2],
+                )
+            else:
+                raise TypeError("Wrong padding value.")
+
+        if dilation != 1 and list(dilation) != [1, 1]:
+            raise ValueError("Not implemented yet")
+
+        self.inner_params = {
+            "cudnn_prefer": "fastest",
+            "workspace_MB_limit": 1024,
+        }
+        # TODO valid value of inner_params check
+
+        for kwarg in kwargs:
+            if kwarg not in self.inner_params:
+                raise TypeError("Keyword argument not understood:", kwarg)
+            else:
+                self.inner_params[kwarg] = kwargs[kwarg]
+
+    def initialize(self, x):
+        self.in_channels = x.shape[1]
+        w_shape = (
+            self.nb_kernels,
+            int(self.in_channels / self.group),
+            self.kernel_size[0],
+            self.kernel_size[1],
+        )
+
+        self.W = Tensor(shape=w_shape,
+                        requires_grad=True,
+                        stores_grad=True,
+                        device=x.device)
+        # std = math.sqrt(
+        # 2.0 / (self.in_channels * self.kernel_size[0] * self.kernel_size[1] +
+        # self.nb_kernels))
+        std = math.sqrt(
+            2.0 / (w_shape[1] * self.kernel_size[0] * self.kernel_size[1] +
+                   self.nb_kernels))
+        self.W.gaussian(0.0, std)
+
+        if self.bias:
+            b_shape = (self.nb_kernels,)
+            self.b = Tensor(shape=b_shape,
+                            requires_grad=True,
+                            stores_grad=True,
+                            device=x.device)
+            self.b.set_value(0.0)
+        else:
+            # keep self.b defined so that forward() handles both cases.
+            self.b = None
+            # Tensor(data=CTensor([]), requires_grad=False, stores_grad=False)
+
+        # if same pad mode, re-compute the padding
+        if self.pad_mode in ("SAME_UPPER", "SAME_LOWER"):
+            self.padding, self.odd_padding = utils.get_padding_shape(
+                self.pad_mode, x.shape[2:], self.kernel_size, self.stride)
+            self.padding = [self.padding[0], self.padding[2]]
+
+        _x = x
+        if self.odd_padding != (0, 0, 0, 0):
+            x_shape = list(x.data.shape())
+            x_shape[2] += (self.odd_padding[0] + self.odd_padding[1])
+            x_shape[3] += (self.odd_padding[2] + self.odd_padding[3])
+            _x = Tensor(shape=x_shape, device=x.device)
+            _x.set_value(0.0)
+
+        if _x.device.id() == -1:
+            if self.group != 1:
+                raise ValueError("Not implemented yet")
+            else:
+                if not hasattr(self, "handle"):
+                    self.handle = singa.ConvHandle(
+                        _x.data,
+                        self.kernel_size,
+                        self.stride,
+                        self.padding,
+                        self.in_channels,
+                        self.nb_kernels,
+                        self.bias,
+                        self.group,
+                    )
+        else:
+            if not hasattr(self, "handle"):
+                self.handle = singa.CudnnConvHandle(
+                    _x.data,
+                    self.kernel_size,
+                    self.stride,
+                    self.padding,
+                    self.in_channels,
+                    self.nb_kernels,
+                    self.bias,
+                    self.group,
+                )
+
+    def forward(self, x):
+        # sanitize the device of params/states, TODO: better to decorate forward()
+        self.device_check(x, *[s for k, s in self.get_states().items()])
+
+        assert (self.group >= 1 and self.in_channels % self.group
+                == 0), "please set reasonable group."
+
+        assert (self.nb_kernels >= self.group and self.nb_kernels % self.group
+                == 0), "nb_kernels and group dismatched."
+
+        y = autograd.conv2d(self.handle, x, self.W, self.b, self.odd_padding)
+
+        if self.activation != "NOTSET":
+            if self.activation == "RELU":
+                y = autograd.relu(y)
+
+        return y
+
+    def get_params(self):
+        if self.bias:
+            return {self.W.name: self.W, self.b.name: self.b}
+        else:
+            return {self.W.name: self.W}
+
+    def set_params(self, parameters):
+        self.W.copy_from(parameters[self.W.name])
+        if self.bias:
+            self.b.copy_from(parameters[self.b.name])
+
+
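
A sketch of the Conv2d layer above (NCHW input, in_channels inferred at first use); the second call shows the SAME_UPPER mode in which the padding is derived in initialize(). It assumes a build with the native CPU convolution kernels:

    from singa import tensor, layer

    x = tensor.Tensor((1, 3, 32, 32))     # N, C, H, W
    x.gaussian(0.0, 1.0)

    conv = layer.Conv2d(8, 3, stride=1, padding=1)
    y = conv(x)                           # (1, 8, 32, 32)

    same_conv = layer.Conv2d(8, 3, padding=None, pad_mode="SAME_UPPER")
    y2 = same_conv(x)                     # padding computed automatically
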
+class SeparableConv2d(Layer):
+    """
+    Generate a SeparableConv2d operator
+    """
+
+    def __init__(self,
+                 nb_kernels,
+                 kernel_size,
+                 *args,
+                 stride=1,
+                 padding=0,
+                 bias=False):
+        """
+        Args:
+            nb_kernels (int): the number of output channels, i.e. the number of filters
+            kernel_size (int or tuple): kernel size for the two spatial axes,
+                e.g. (2, 3) uses a kernel of height 2 and width 3; if an int
+                is given, the kernel size is taken as (int, int)
+            stride (int or tuple): stride, following the same logic as
+                kernel_size.
+            padding (int, tuple, list or None): padding for the two spatial
+                axes, following the same logic as kernel_size. If pad_mode is
+                "SAME_UPPER" or "SAME_LOWER", padding may be None and the
+                padding will be computed automatically.
+            bias (bool): bias
+        """
+        super(SeparableConv2d, self).__init__()
+
+        # the following code block is for backward compatibility
+        if len(args) > 0:
+            nb_kernels = kernel_size
+            kernel_size = args[0]
+        if len(args) > 1:
+            stride = args[1]
+        if len(args) > 2:
+            padding = args[2]
+
+        self.nb_kernels = nb_kernels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.bias = bias
+
+    def initialize(self, x):
+        self.in_channels = x.shape[1]
+        self.depthwise_conv = Conv2d(
+            self.in_channels,
+            self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            group=self.in_channels,
+            bias=self.bias,
+        )
+
+        self.point_conv = Conv2d(self.nb_kernels, 1, bias=self.bias)
+
+    def forward(self, x):
+        y = self.depthwise_conv(x)
+        y = self.point_conv(y)
+        return y
+
+
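
A sketch of SeparableConv2d above: a depthwise Conv2d (group equal to in_channels) followed by a 1x1 pointwise Conv2d. Since the CPU path above raises for group != 1, this sketch assumes a CUDA-enabled build:

    from singa import device, tensor, layer

    dev = device.create_cuda_gpu()        # grouped (depthwise) conv uses the cuDNN handle
    x = tensor.Tensor((1, 16, 28, 28), dev)
    x.gaussian(0.0, 1.0)

    sep = layer.SeparableConv2d(32, 3, padding=1)
    y = sep(x)    # depthwise with 16 groups, then 1x1 pointwise -> (1, 32, 28, 28)
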
+class BatchNorm2d(Layer):
+    """
+    Generate a BatchNorm 2d operator
+    """
+
+    def __init__(self, *args, momentum=0.9):
+        """
+        Args:
+            momentum (float): Factor used in computing the running mean and
+                variance.
+        """
+        super(BatchNorm2d, self).__init__()
+
+        if len(args) > 0:
+            self.channels = args[0]
+        if len(args) > 1:
+            momentum = args[1]
+        self.momentum = momentum
+        assert 0 <= self.momentum <= 1.0, ("momentum should be in [0, 1]")
+
+    def initialize(self, x):
+        self.channels = x.shape[1]
+        param_shape = (self.channels,)
+
+        self.scale = Tensor(shape=param_shape,
+                            requires_grad=True,
+                            stores_grad=True)
+        self.scale.set_value(1.0)
+
+        self.bias = Tensor(shape=param_shape,
+                           requires_grad=True,
+                           stores_grad=True)
+        self.bias.set_value(0.0)
+
+        self.running_mean = Tensor(shape=param_shape,
+                                   requires_grad=False,
+                                   stores_grad=False)
+        self.running_mean.set_value(0.0)
+
+        self.running_var = Tensor(shape=param_shape,
+                                  requires_grad=False,
+                                  stores_grad=False)
+        self.running_var.set_value(1.0)
+
+        if not hasattr(self, "handle"):
+            if x.device.id() == -1:
+                self.handle = singa.BatchNormHandle(self.momentum, x.data)
+            else:
+                self.handle = singa.CudnnBatchNormHandle(self.momentum, x.data)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels, (
+            "number of channels dismatched. %d vs %d" %
+            (x.shape[1], self.channels))
+
+        self.device_check(x, self.scale, self.bias, self.running_mean,
+                          self.running_var)
+
+        y = autograd.batchnorm_2d(
+            self.handle,
+            x,
+            self.scale,
+            self.bias,
+            self.running_mean,
+            self.running_var,
+        )
+        return y
+
+    def get_params(self):
+        return {self.scale.name: self.scale, self.bias.name: self.bias}
+
+    def set_params(self, parameters):
+        self.scale.copy_from(parameters[self.scale.name])
+        self.bias.copy_from(parameters[self.bias.name])
+
+    def get_states(self):
+        ret = self.get_params()
+        ret[self.running_mean.name] = self.running_mean
+        ret[self.running_var.name] = self.running_var
+        return ret
+
+    def set_states(self, states):
+        self.set_params(states)
+        self.running_mean.copy_from(states[self.running_mean.name])
+        self.running_var.copy_from(states[self.running_var.name])
+
+
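
A sketch of BatchNorm2d above: scale/bias are parameters while running_mean/running_var are extra states, so checkpointing should go through get_states()/set_states() rather than get_params(). It assumes a build where the batch-norm handle used above is available on the chosen device:

    from singa import tensor, layer

    x = tensor.Tensor((4, 8, 16, 16))
    x.gaussian(0.0, 1.0)

    bn = layer.BatchNorm2d(momentum=0.9)
    y = bn(x)

    states = bn.get_states()   # scale, bias, running_mean, running_var
    bn.set_states(states)      # restores parameters and running statistics
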
+class Pooling2d(Layer):
+    """
+    Generate a Pooling 2d operator
+    """
+
+    def __init__(self,
+                 kernel_size,
+                 stride=None,
+                 padding=0,
+                 is_max=True,
+                 pad_mode="NOTSET"):
+        """
+        Args:
+            kernel_size (int or tuple): kernel size for the two spatial axes,
+                e.g. (2, 3) uses a pooling window of height 2 and width 3; if
+                an int is given, the kernel size is taken as (int, int)
+            stride (int or tuple): stride, following the same logic as
+                kernel_size.
+            padding (int, tuple, list or None): padding for the two spatial
+                axes, following the same logic as kernel_size. If pad_mode is
+                "SAME_UPPER" or "SAME_LOWER", padding may be None and the
+                padding will be computed automatically.
+            is_max (bool): is max pooling or avg pooling
+            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where
+                default value is NOTSET, which means explicit padding is used.
+                SAME_UPPER or SAME_LOWER mean pad the input so that the output
+                spatial size matches the input. If the padding amount is odd, the extra
+                padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER.
+        """
+        super(Pooling2d, self).__init__()
+
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        elif isinstance(kernel_size, tuple):
+            self.kernel_size = kernel_size
+        else:
+            raise TypeError("Wrong kernel_size type.")
+
+        if stride is None:
+            self.stride = self.kernel_size
+        elif isinstance(stride, int):
+            self.stride = (stride, stride)
+        elif isinstance(stride, tuple):
+            self.stride = stride
+            assert stride[0] > 0 or (kernel_size[0] == 1 and padding[0] == 0), (
+                "stride[0]=0, but kernel_size[0]=%d, padding[0]=%d" %
+                (kernel_size[0], padding[0]))
+        else:
+            raise TypeError("Wrong stride type.")
+
+        self.odd_padding = (0, 0, 0, 0)
+        if isinstance(padding, int):
+            self.padding = (padding, padding)
+        elif isinstance(padding, tuple) or isinstance(padding, list):
+            if len(padding) == 2:
+                self.padding = padding
+            elif len(padding) == 4:
+                _h_mask = padding[0] - padding[1]
+                _w_mask = padding[2] - padding[3]
+                # the odd padding is the part that cannot be handled by the tuple padding (h, w) mode,
+                # so we first pad the input, then apply the normal padding.
+                self.odd_padding = (max(_h_mask, 0), max(-_h_mask, 0),
+                                    max(_w_mask, 0), max(-_w_mask, 0))
+                self.padding = (
+                    padding[0] - self.odd_padding[0],
+                    padding[2] - self.odd_padding[2],
+                )
+            else:
+                raise TypeError("Wrong padding value.")
+
+        self.is_max = is_max
+        self.pad_mode = pad_mode
+
+    def initialize(self, x):
+        # if same pad mode, re-compute the padding
+        if self.pad_mode in ("SAME_UPPER", "SAME_LOWER"):
+            self.padding, self.odd_padding = utils.get_padding_shape(
+                self.pad_mode, x.shape[2:], self.kernel_size, self.stride)
+            self.padding = [self.padding[0], self.padding[2]]
+
+        _x = x
+        if self.odd_padding != (0, 0, 0, 0):
+            x_shape = list(x.data.shape())
+            x_shape[2] += (self.odd_padding[0] + self.odd_padding[1])
+            x_shape[3] += (self.odd_padding[2] + self.odd_padding[3])
+            _x = Tensor(shape=x_shape, device=x.device)
+            _x.set_value(0.0)
+
+        if _x.device.id() == -1:
+            self.handle = singa.PoolingHandle(
+                _x.data,
+                self.kernel_size,
+                self.stride,
+                self.padding,
+                self.is_max,
+            )
+        else:
+            self.handle = singa.CudnnPoolingHandle(
+                _x.data,
+                self.kernel_size,
+                self.stride,
+                self.padding,
+                self.is_max,
+            )
+
+    def forward(self, x):
+        y = autograd.pooling_2d(self.handle, x, self.odd_padding)
+        return y
+
+
+class MaxPool2d(Pooling2d):
+    """
+    Generate a Max Pooling 2d operator
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0, pad_mode="NOTSET"):
+        """
+        Args:
+            kernel_size (int or tuple): kernel size for the two spatial axes.
+                For example, (2, 3) uses a kernel of size 2 along the first
+                spatial axis and 3 along the second; if an int is given, the
+                kernel size is expanded to (int, int).
+            stride (int or tuple): stride; follows the same logic as kernel_size.
+            padding (int, tuple, list or None): padding; follows the same logic
+                as kernel_size. If pad_mode is "SAME_UPPER" or "SAME_LOWER",
+                padding may be set to None and is then computed automatically.
+            pad_mode (string): one of NOTSET, SAME_UPPER or SAME_LOWER. The
+                default NOTSET means explicit padding is used. SAME_UPPER and
+                SAME_LOWER pad the input so that the output spatial size matches
+                the input; if the required padding is odd, the extra padding is
+                added at the end for SAME_UPPER and at the beginning for
+                SAME_LOWER.
+        """
+        super(MaxPool2d, self).__init__(kernel_size, stride, padding, True,
+                                        pad_mode)
+
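+# Minimal usage sketch, assuming the Layer base class defined earlier in this
+# file dispatches __call__ to initialize() (on the first call) and forward():
+#
+#   from singa import device
+#   dev = device.get_default_device()
+#   x = Tensor(shape=(2, 3, 8, 8), device=dev)   # NCHW input
+#   x.gaussian(0.0, 1.0)
+#   pool = MaxPool2d(kernel_size=2, stride=2)
+#   y = pool(x)                                  # spatial size 8x8 -> 4x4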
+
+class AvgPool2d(Pooling2d):
+    """
+    Generate an Avg Pooling 2d operator
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0, pad_mode="NOTSET"):
+        """
+        Args:
+            kernel_size (int or tuple): kernel size for the two spatial axes.
+                For example, (2, 3) uses a kernel of size 2 along the first
+                spatial axis and 3 along the second; if an int is given, the
+                kernel size is expanded to (int, int).
+            stride (int or tuple): stride; follows the same logic as kernel_size.
+            padding (int, tuple, list or None): padding; follows the same logic
+                as kernel_size. If pad_mode is "SAME_UPPER" or "SAME_LOWER",
+                padding may be set to None and is then computed automatically.
+            pad_mode (string): one of NOTSET, SAME_UPPER or SAME_LOWER. The
+                default NOTSET means explicit padding is used. SAME_UPPER and
+                SAME_LOWER pad the input so that the output spatial size matches
+                the input; if the required padding is odd, the extra padding is
+                added at the end for SAME_UPPER and at the beginning for
+                SAME_LOWER.
+        """
+        super(AvgPool2d, self).__init__(kernel_size, stride, padding, False,
+                                        pad_mode)
+
+
+class MaxPool1d(Pooling2d):
+    """
+    Generate a Max Pooling 1d operator
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0, pad_mode="NOTSET"):
+        """
+        Args:
+            kernel_size (int or tuple): kernel size for the two spatial axes.
+                For example, (2, 3) uses a kernel of size 2 along the first
+                spatial axis and 3 along the second; if an int is given, the
+                kernel size is expanded to (int, int).
+            stride (int or tuple): stride; follows the same logic as kernel_size.
+            padding (int, tuple, list or None): padding; follows the same logic
+                as kernel_size. If pad_mode is "SAME_UPPER" or "SAME_LOWER",
+                padding may be set to None and is then computed automatically.
+            pad_mode (string): one of NOTSET, SAME_UPPER or SAME_LOWER. The
+                default NOTSET means explicit padding is used. SAME_UPPER and
+                SAME_LOWER pad the input so that the output spatial size matches
+                the input; if the required padding is odd, the extra padding is
+                added at the end for SAME_UPPER and at the beginning for
+                SAME_LOWER.
+        """
+        if stride is None:
+            stride = kernel_size
+        super(MaxPool1d, self).__init__((1, kernel_size), (1, stride),
+                                        (0, padding), True, pad_mode)
+
+
+class AvgPool1d(Pooling2d):
+    """
+    Generate an Avg Pooling 1d operator
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0, pad_mode="NOTSET"):
+        """
+        Args:
+            kernel_size (int or tuple): kernel size for the two spatial axes.
+                For example, (2, 3) uses a kernel of size 2 along the first
+                spatial axis and 3 along the second; if an int is given, the
+                kernel size is expanded to (int, int).
+            stride (int or tuple): stride; follows the same logic as kernel_size.
+            padding (int, tuple, list or None): padding; follows the same logic
+                as kernel_size. If pad_mode is "SAME_UPPER" or "SAME_LOWER",
+                padding may be set to None and is then computed automatically.
+            pad_mode (string): one of NOTSET, SAME_UPPER or SAME_LOWER. The
+                default NOTSET means explicit padding is used. SAME_UPPER and
+                SAME_LOWER pad the input so that the output spatial size matches
+                the input; if the required padding is odd, the extra padding is
+                added at the end for SAME_UPPER and at the beginning for
+                SAME_LOWER.
+        """
+        if stride is None:
+            stride = kernel_size
+        super(AvgPool1d, self).__init__((1, kernel_size), (1, stride),
+                                        (0, padding), False, pad_mode)
+
+
+class RNN_Base(Layer):
+
+    def step_forward(self,
+                     x=None,
+                     h=None,
+                     c=None,
+                     Wx=None,
+                     Wh=None,
+                     Bx=None,
+                     Bh=None,
+                     b=None):
+        raise NotImplementedError
+
+
+class RNN(RNN_Base):
+    """
+    Generate an RNN operator
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_layers=1,
+        nonlinearity="tanh",
+        bias=True,
+        batch_first=False,
+        dropout=0,
+        bidirectional=False,
+    ):
+        """
+        Args:
+            input_size (int):  The number of expected features in the input x
+            hidden_size (int): The number of features in the hidden state h
+            num_layers (int):  Number of recurrent layers. Default: 1
+            nonlinearity (string): The non-linearity to use. Default: 'tanh'
+            bias (bool):  If False, then the layer does not use bias weights.
+                Default: True
+            batch_first (bool):  If True, then the input and output tensors
+                are provided as (batch, seq, feature). Default: False
+            dropout (float): If non-zero, introduces a Dropout layer on the
+                outputs of each RNN layer except the last layer, with dropout
+                probability equal to dropout. Default: 0
+            bidirectional (bool): If True, becomes a bidirectional RNN.
+                Default: False
+        """
+        super(RNN, self).__init__()
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.nonlinearity = nonlinearity
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+
+    def initialize(self, xs, h0):
+        Wx_shape = (self.input_size, self.hidden_size)
+        self.Wx = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+        self.Wx.gaussian(0.0, 1.0)
+
+        Wh_shape = (self.hidden_size, self.hidden_size)
+        self.Wh = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        self.Wh.gaussian(0.0, 1.0)
+
+        b_shape = (self.hidden_size,)
+        self.b = Tensor(shape=b_shape, requires_grad=True, stores_grad=True)
+        self.b.set_value(0.0)
+
+    def forward(self, xs, h0):
+        # xs: a tuple or list of input tensors
+        if not isinstance(xs, tuple):
+            xs = tuple(xs)
+        inputs = xs + (h0,)
+        self.device_check(*inputs)
+        # self.device_check(inputs[0], *self.params)
+        self.device_check(inputs[0], self.Wx, self.Wh, self.b)
+        batchsize = xs[0].shape[0]
+        out = []
+        h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b)
+        out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h = self.step_forward(x, h, self.Wx, self.Wh, self.b)
+            out.append(h)
+        return out, h
+
+    def step_forward(self, x, h, Wx, Wh, b):
+        y2 = autograd.matmul(h, Wh)
+        y1 = autograd.matmul(x, Wx)
+        y = autograd.add(y2, y1)
+        y = autograd.add_bias(y, b, axis=0)
+        if self.nonlinearity == "tanh":
+            y = autograd.tanh(y)
+        elif self.nonlinearity == "relu":
+            y = autograd.relu(y)
+        else:
+            raise ValueError
+        return y
+
+    def get_params(self):
+        return {
+            self.Wx.name: self.Wx,
+            self.Wh.name: self.Wh,
+            self.b.name: self.b
+        }
+
+    def set_params(self, parameters):
+        self.Wx.copy_from(parameters[self.Wx.name])
+        self.Wh.copy_from(parameters[self.Wh.name])
+        self.b.copy_from(parameters[self.b.name])
+
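+# Minimal usage sketch, assuming the Layer.__call__ wiring defined earlier in
+# this file (initialize() on the first call, then forward()):
+#
+#   seq = [Tensor(shape=(4, 10), device=dev) for _ in range(5)]  # 5 time steps
+#   for s in seq:
+#       s.gaussian(0.0, 1.0)
+#   h0 = Tensor(shape=(4, 32), device=dev)
+#   h0.set_value(0.0)
+#   rnn = RNN(input_size=10, hidden_size=32)
+#   outputs, h_last = rnn(seq, h0)   # one output per step plus the final state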
+
+class LSTM(RNN_Base):
+    """
+    Generate an LSTM operator
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        nonlinearity="tanh",
+        num_layers=1,
+        bias=True,
+        batch_first=False,
+        dropout=0,
+        bidirectional=False,
+    ):
+        """
+        Args:
+            input_size (int):  The number of expected features in the input x
+            hidden_size (int): The number of features in the hidden state h
+            num_layers (int):  Number of recurrent layers. Default: 1
+            nonlinearity (string): The non-linearity to use. Default: 'tanh'
+            bias (bool):  If False, then the layer does not use bias weights.
+                Default: True
+            batch_first (bool):  If True, then the input and output tensors
+                are provided as (batch, seq, feature). Default: False
+            dropout (float): If non-zero, introduces a Dropout layer on the
+                outputs of each RNN layer except the last layer, with dropout
+                probability equal to dropout. Default: 0
+            bidirectional (bool): If True, becomes a bidirectional RNN.
+                Default: False
+        """
+        super(LSTM, self).__init__()
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.nonlinearity = nonlinearity
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+
+    def initialize(self, xs, h0_c0):
+        # 1. Wx_i input,  Bx_i
+        # 2. Wx_f forget, Bx_f
+        # 3. Wx_o output, Bx_o
+        # 4. Wx_g candidate, Bx_g
+        Wx_shape = (self.input_size, self.hidden_size)
+        self.Wx_i = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+        self.Wx_f = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+        self.Wx_o = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+        self.Wx_g = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
+
+        Wh_shape = (self.hidden_size, self.hidden_size)
+        self.Wh_i = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        self.Wh_f = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        self.Wh_o = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        self.Wh_g = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
+        [
+            w.gaussian(0.0, 0.01) for w in [
+                self.Wx_i, self.Wx_f, self.Wx_o, self.Wx_g, self.Wh_i,
+                self.Wh_f, self.Wh_o, self.Wh_g
+            ]
+        ]
+
+        Bx_shape = (self.hidden_size,)
+        self.Bx_i = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bx_f = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bx_o = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bx_g = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bh_i = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bh_f = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bh_o = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        self.Bh_g = Tensor(shape=Bx_shape, requires_grad=True, stores_grad=True)
+        [
+            b.set_value(0.0) for b in [
+                self.Bx_i, self.Bx_f, self.Bx_o, self.Bx_g, self.Bh_i,
+                self.Bh_f, self.Bh_o, self.Bh_g
+            ]
+        ]
+
+    def forward(self, xs, h0_c0):
+        # xs: a tuple or list of input tensors
+        # h0_c0: a tuple of (h0, c0)
+        h0, c0 = h0_c0
+        if not isinstance(xs, list):
+            xs = list(xs)
+        inputs = xs + list((h0, c0))
+        self.device_check(*inputs)
+        self.device_check(inputs[0], *[s for k, s in self.get_states().items()])
+        batchsize = xs[0].shape[0]
+        out = []
+        h, c = self.step_forward(xs[0], h0, c0)
+        out.append(h)
+        for x in xs[1:]:
+            assert x.shape[0] == batchsize
+            h, c = self.step_forward(x, h, c)
+            out.append(h)
+        return out, h, c
+
+    def step_forward(self, x, h, c):
+        # input
+        y1 = autograd.matmul(x, self.Wx_i)
+        y1 = autograd.add_bias(y1, self.Bx_i, axis=0)
+        y2 = autograd.matmul(h, self.Wh_i)
+        y2 = autograd.add_bias(y2, self.Bh_i, axis=0)
+        i = autograd.add(y1, y2)
+        i = autograd.sigmoid(i)
+
+        # forget
+        y1 = autograd.matmul(x, self.Wx_f)
+        y1 = autograd.add_bias(y1, self.Bx_f, axis=0)
+        y2 = autograd.matmul(h, self.Wh_f)
+        y2 = autograd.add_bias(y2, self.Bh_f, axis=0)
+        f = autograd.add(y1, y2)
+        f = autograd.sigmoid(f)
+
+        # output
+        y1 = autograd.matmul(x, self.Wx_o)
+        y1 = autograd.add_bias(y1, self.Bx_o, axis=0)
+        y2 = autograd.matmul(h, self.Wh_o)
+        y2 = autograd.add_bias(y2, self.Bh_o, axis=0)
+        o = autograd.add(y1, y2)
+        o = autograd.sigmoid(o)
+
+        y1 = autograd.matmul(x, self.Wx_g)
+        y1 = autograd.add_bias(y1, self.Bx_g, axis=0)
+        y2 = autograd.matmul(h, self.Wh_g)
+        y2 = autograd.add_bias(y2, self.Bh_g, axis=0)
+        g = autograd.add(y1, y2)
+        g = autograd.tanh(g)
+
+        cout1 = autograd.mul(f, c)
+        cout2 = autograd.mul(i, g)
+        cout = autograd.add(cout1, cout2)
+
+        hout = autograd.tanh(cout)
+        hout = autograd.mul(o, hout)
+        return hout, cout
+
+    def get_params(self):
+        ret = {}
+        for w in [
+                self.Wx_i, self.Wx_f, self.Wx_o, self.Wx_g, self.Wh_i,
+                self.Wh_f, self.Wh_o, self.Wh_g
+        ]:
+            ret[w.name] = w
+
+        for b in [
+                self.Bx_i, self.Bx_f, self.Bx_o, self.Bx_g, self.Bh_i,
+                self.Bh_f, self.Bh_o, self.Bh_g
+        ]:
+            ret[b.name] = b
+        return ret
+
+    def set_params(self, parameters):
+        for w in [
+                self.Wx_i, self.Wx_f, self.Wx_o, self.Wx_g, self.Wh_i,
+                self.Wh_f, self.Wh_o, self.Wh_g
+        ]:
+            w.copy_from(parameters[w.name])
+
+        for b in [
+                self.Bx_i, self.Bx_f, self.Bx_o, self.Bx_g, self.Bh_i,
+                self.Bh_f, self.Bh_o, self.Bh_g
+        ]:
+            b.copy_from(parameters[b.name])
+
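+# Minimal usage sketch for LSTM, following the same calling convention as the
+# RNN sketch above but with a (h0, c0) state pair:
+#
+#   c0 = Tensor(shape=(4, 32), device=dev)
+#   c0.set_value(0.0)
+#   lstm = LSTM(input_size=10, hidden_size=32)
+#   outputs, h_n, c_n = lstm(seq, (h0, c0))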
+
+''' layers without params or states
 '''
 
-if singa_wrap.USE_CUDNN:
-    cudnn_version = singa_wrap.CUDNN_VERSION
-else:
-    cudnn_version = 0
 
-
-class Layer(object):
-    '''Base Python layer class.
-
-    Typically, the life cycle of a layer instance includes:
-        1. construct layer without input_sample_shapes, goto 2;
-           construct layer with input_sample_shapes, goto 3;
-        2. call setup to create the parameters and setup other meta fields
-        3. call forward or access layer members
-        4. call backward and get parameters for update
-
-    Args:
-        name (str): layer name
-    '''
-
-    def __init__(self, name, conf=None, **kwargs):
-        if conf is None:
-            self.layer = None  # layer converted by swig
-            self.name = name  # TODO(wangwei) duplicate with self.conf.name
-            self.conf = model_pb2.LayerConf()
-            self.conf.name = name
-            self.param_specs = []
-        else:
-            self.conf = conf
-            self.name = conf.name
-            self.caffe_layer()
-            self.param_specs = []
-
-            # convert caffe proto into singa proto format
-            #   case1: parameters of conv and dense layers
-            #   case2: type of activation layers
-            if (conf.type == 'Convolution' or conf.type == 4) or \
-                    (conf.type == 'InnerProduct' or conf.type == 14):
-                w, b = _construct_param_specs_from_caffe_proto(conf)
-                del conf.param[:]
-                conf.param.extend([w, b])
-                self.param_specs.append(w)
-                self.param_specs.append(b)
-                # print 'conf:\n', conf
-            if conf.type == 'Pooling':
-                conf.pooling_conf.ceil = True
-                # print 'conf:\n', conf
-            elif (conf.type == 'ReLU' or conf.type == 18 or
-                  conf.type == 'Sigmoid' or conf.type == 19 or
-                  conf.type == 'TanH' or conf.type == 23):
-                conf.type = (engine + '_' + conf.type).lower()
-            self.conf = conf
-
-        self.has_setup = False
-
-    def setup(self, in_shapes):
-        '''Call the C++ setup function to create params and set some meta data.
-
-        Args:
-            in_shapes: if the layer accepts a single input Tensor, in_shapes is
-                a single tuple specifying the inpute Tensor shape; if the layer
-                accepts multiple input Tensor (e.g., the concatenation layer),
-                in_shapes is a tuple of tuples, each for one input Tensor
-        '''
-        if self.has_setup:
-            return
-        if type(in_shapes[0]) is tuple:
-            self.layer.SetupWithMultInputs([list(s) for s in in_shapes],
-                                           self.conf.SerializeToString())
-        else:
-            self.layer.Setup(list(in_shapes), self.conf.SerializeToString())
-        self.has_setup = True
-
-    def caffe_layer(self):
-        '''
-        Create a singa layer based on caffe layer configuration.
-        '''
-        _check_engine(engine, ['cudnn', 'singacpp', 'singacuda', 'singacl'])
-        if self.conf.type == 'InnerProduct' or self.conf.type == 14:
-            self.layer = _create_layer(engine, 'Dense')
-        else:
-            self.layer = _create_layer(engine, self.conf.type)
-
-    def get_output_sample_shape(self):
-        '''Called after setup to get the shape of the output sample(s).
-
-        Returns:
-            a tuple for a single output Tensor or a list of tuples if this layer
-            has multiple outputs
-        '''
-        assert self.has_setup, \
-            'Must call setup() before get_output_sample_shape()'
-        return self.layer.GetOutputSampleShape()
-
-    def param_names(self):
-        '''
-        Returns:
-            a list of strings, one for the name of one parameter Tensor
-        '''
-        names = []
-        for x in self.param_specs:
-            names.append(x.name)
-        return names
-
-    def param_values(self):
-        '''Return param value tensors.
-
-        Parameter tensors are not stored as layer members because cpp Tensor
-        could be moved onto diff devices due to the change of layer device,
-        which would result in inconsistency.
-
-        Returns:
-            a list of tensors, one for each paramter
-        '''
-        if self.layer is None:
-            return []
-        else:
-            return tensor.from_raw_tensors(self.layer.param_values())
-
-    def forward(self, flag, x):
-        '''Forward propagate through this layer.
-
-        Args:
-            flag: True (kTrain) for training (kEval); False for evaluating;
-                other values for furture use.
-            x (Tensor or list<Tensor>): an input tensor if the layer is
-                connected from a single layer; a list of tensors if the layer
-                is connected from multiple layers.
-
-        Return:
-            a tensor if the layer is connected to a single layer; a list of
-            tensors if the layer is connected to multiple layers;
-        '''
-        assert self.has_setup, 'Must call setup() before forward()'
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-        if type(x) is list:
-            xs = [t.data for t in x]
-            y = self.layer.ForwardWithMultInputs(flag, xs)
-        else:
-            assert isinstance(x, tensor.Tensor), \
-                'input of %s (type:%s) must be a Tensor or Tensor list'\
-                % (self.name, type(x).__name__)
-            y = self.layer.Forward(flag, x.data)
-        if type(y) is tuple:
-            return tensor.from_raw_tensors(y)
-        else:
-            return tensor.from_raw_tensor(y)
-
-    def backward(self, flag, dy):
-        '''Backward propagate gradients through this layer.
-
-        Args:
-            flag (int): for future use.
-            dy (Tensor or list<Tensor>): the gradient tensor(s) y w.r.t the
-                objective loss
-        Return:
-            <dx, <dp1, dp2..>>, dx is a (set of) tensor(s) for the gradient of x
-            , dpi is the gradient of the i-th parameter
-        '''
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-
-        if type(dy) == list:
-            dys = [t.data for t in dy]
-            ret = self.layer.BackwardWithMultInputs(flag, dys)
-        else:
-            assert isinstance(dy, tensor.Tensor), \
-                'input of %s (type:%s) must be a Tensor or Tensor list'\
-                % (self.name, type(dy).__name__)
-            dys = dy.data
-            ret = self.layer.Backward(flag, dys)
-        if type(ret[0]) is tuple:
-            dxs = tensor.from_raw_tensors(ret[0])
-        else:
-            dxs = tensor.from_raw_tensor(ret[0])
-        return dxs, tensor.from_raw_tensors(ret[1])
-
-    def to_device(self, device):
-        '''Move layer state tensors onto the given device.
-
-        Args:
-            device: swig converted device, created using singa.device
-        '''
-        if self.layer is not None:
-            self.layer.ToDevice(device)
-
-    def as_type(self, dtype):
-        pass
-
-    def __copy__(self):
-        pass
-
-    def __deepcopy__(self, memo):
-        pass
-
-
-class Dummy(Layer):
-    '''A dummy layer that does nothing but just forwards/backwards the data
-    (the input/output is a single tensor).
-    '''
-
-    def __init__(self, name, input_sample_shape=None):
-        super(Dummy, self).__init__(name)
-        self.output_sample_shape = input_sample_shape
-
-    def get_output_sample_shape(self):
-        return self.output_sample_shape
-
-    def setup(self, input_sample_shape):
-        self.output_sample_shape = input_sample_shape
-        self.has_setup = True
-
-    def forward(self, flag, x):
-        '''Return the input x'''
-        return x
-
-    def backward(self, falg, dy):
-        '''Return dy, []'''
-        return dy, []
-
-
-class Conv2D(Layer):
-    """Construct a layer for 2D convolution.
-
-    Args:
-        nb_kernels (int): num of the channels (kernels) of the input Tensor
-        kernel: an integer or a pair of integers for kernel height and width
-        stride: an integer or a pair of integers for stride height and width
-        border_mode (string): padding mode, case in-sensitive,
-            'valid' -> padding is 0 for height and width
-            'same' -> padding is half of the kernel (floor), the kernel must be
-            odd number.
-        cudnn_prefer (string): the preferred algorithm for cudnn convolution
-            which could be 'fastest', 'autotune', 'limited_workspace' and
-            'no_workspace'
-        workspace_byte_limit(int): max workspace size in MB (default is 512MB)
-        data_format (string): either 'NCHW' or 'NHWC'
-        use_bias (bool): True or False
-        pad: an integer or a pair of integers for padding height and width
-        W_specs (dict): used to specify the weight matrix specs, fields
-            include,
-            'name' for parameter name
-            'lr_mult' for learning rate multiplier
-            'decay_mult' for weight decay multiplier
-            'init' for init method, which could be 'gaussian', 'uniform',
-            'xavier' and ''
-            'std', 'mean', 'high', 'low' for corresponding init methods
-            TODO(wangwei) 'clamp' for gradient constraint, value is scalar
-            'regularizer' for regularization, currently support 'l2'
-        b_specs (dict): hyper-parameters for bias vector, similar as W_specs
-        name (string): layer name.
-        input_sample_shape: 3d tuple for the shape of the input Tensor
-            without the batchsize, e.g., (channel, height, width) or
-            (height, width, channel)
+class ReLU(Layer):
+    """
+    Generate a ReLU operator
     """
 
-    def __init__(self,
-                 name,
-                 nb_kernels,
-                 kernel=3,
-                 stride=1,
-                 border_mode='same',
-                 cudnn_prefer='fastest',
-                 workspace_byte_limit=1024,
-                 data_format='NCHW',
-                 use_bias=True,
-                 W_specs=None,
-                 b_specs=None,
-                 pad=None,
-                 input_sample_shape=None):
-        super(Conv2D, self).__init__(name)
-        assert data_format == 'NCHW', 'Not supported data format: %s ' \
-            'only "NCHW" is enabled currently' % (data_format)
-        conf = self.conf.convolution_conf
-        conf.num_output = nb_kernels
-        conf.prefer = cudnn_prefer
-        conf.workspace_byte_limit = workspace_byte_limit
-        self.kernel = kernel
-        self.stride = stride
-        self.pad = pad
-        self.border_mode = border_mode
-        conf.bias_term = use_bias
-        # TODO(wangwei) enable data format for cpp code
-        # conf.data_format = data_format
-        if W_specs is None:
-            W_specs = {'init': 'xavier'}
-        if 'name' not in W_specs:
-            W_specs['name'] = name + '/weight'
-        wspecs = _construct_param_specs_from_dict(W_specs)
-        self.conf.param.extend([wspecs])
-        self.param_specs.append(wspecs)
-        if use_bias:
-            if b_specs is None:
-                b_specs = {'init': 'constant'}
-            if 'name' not in b_specs:
-                b_specs['name'] = name + '/bias'
-            bspecs = _construct_param_specs_from_dict(b_specs)
-            self.conf.param.extend([bspecs])
-            self.param_specs.append(bspecs)
+    def __init__(self):
+        super(ReLU, self).__init__()
 
-        _check_engine(engine, ['cudnn', 'singacpp', 'singacl'])
-        self.layer = _create_layer(engine, 'Convolution')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-    def setup(self, in_shape):
-        '''Set up the kernel, stride and padding; then call the C++ setup
-        function to create params and set some meta data.
-
-        Args:
-                in_shapes is a tuple of int for the input sample shape
-        '''
-        if self.has_setup:
-            return
-        _set_kernel_stride_pad(self.conf.convolution_conf, self.kernel,
-                               self.stride, self.border_mode, self.pad,
-                               in_shape)
-        self.layer.Setup(list(in_shape), self.conf.SerializeToString())
-        self.has_setup = True
+    def forward(self, x):
+        return autograd.relu(x)
 
 
-class Conv1D(Conv2D):
-    """Construct a layer for 1D convolution.
-
-    Most of the args are the same as those for Conv2D except the kernel,
-    stride, pad, which is a scalar instead of a tuple.
-    input_sample_shape is a tuple with a single value for the input feature
-    length
+class Sigmoid(Layer):
+    """
+    Generate a Sigmoid operator
     """
 
-    def __init__(self,
-                 name,
-                 nb_kernels,
-                 kernel=3,
-                 stride=1,
-                 border_mode='same',
-                 cudnn_prefer='fastest',
-                 workspace_byte_limit=1024,
-                 use_bias=True,
-                 W_specs={'init': 'Xavier'},
-                 b_specs={
-                     'init': 'Constant',
-                     'value': 0
-                 },
-                 pad=None,
-                 input_sample_shape=None):
-        pad = None
-        if pad is not None:
-            pad = (0, pad)
-        if input_sample_shape is not None:
-            input_sample_shape = (1, 1, input_sample_shape[0])
-        super(Conv1D, self).__init__(name,
-                                     nb_kernels, (1, kernel), (0, stride),
-                                     border_mode,
-                                     cudnn_prefer,
-                                     workspace_byte_limit,
-                                     use_bias=use_bias,
-                                     pad=pad,
-                                     W_specs=W_specs,
-                                     b_specs=b_specs,
-                                     input_sample_shape=input_sample_shape)
+    def __init__(self):
+        super(Sigmoid, self).__init__()
 
-    def get_output_sample_shape(self):
-        shape = self.layer.GetOutputSampleShape()
-        assert len(shape) == 3, 'The output sample shape should be 3D.'\
-            'But the length is %d' % len(shape)
-        return (shape[0], shape[2])
+    def forward(self, x):
+        return autograd.sigmoid(x)
 
 
-class Pooling2D(Layer):
-    '''2D pooling layer providing max/avg pooling.
-
-    All args are the same as those for Conv2D, except the following one
-
-    Args:
-        mode: pooling type, model_pb2.PoolingConf.MAX or
-            model_pb2.PoolingConf.AVE
-
-    '''
-
-    def __init__(self,
-                 name,
-                 mode,
-                 kernel=3,
-                 stride=2,
-                 border_mode='same',
-                 pad=None,
-                 data_format='NCHW',
-                 input_sample_shape=None):
-        super(Pooling2D, self).__init__(name)
-        assert data_format == 'NCHW', 'Not supported data format: %s ' \
-            'only "NCHW" is enabled currently' % (data_format)
-        conf = self.conf.pooling_conf
-        conf.pool = mode
-        self.kernel = kernel
-        self.stride = stride
-        self.pad = pad
-        self.border_mode = border_mode
-        _check_engine(engine, ['cudnn', 'singacpp', 'singacl'])
-        self.layer = _create_layer(engine, 'Pooling')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-    def setup(self, in_shape):
-        '''Set up the kernel, stride and padding; then call the C++ setup
-        function to create params and set some meta data.
-
-        Args:
-            in_shapes is a tuple of int for the input sample shape
-        '''
-        if self.has_setup:
-            return
-        _set_kernel_stride_pad(self.conf.pooling_conf, self.kernel, self.stride,
-                               self.border_mode, self.pad, in_shape)
-        self.layer.Setup(list(in_shape), self.conf.SerializeToString())
-        self.has_setup = True
-
-
-class MaxPooling2D(Pooling2D):
-
-    def __init__(self,
-                 name,
-                 kernel=3,
-                 stride=2,
-                 border_mode='same',
-                 pad=None,
-                 data_format='NCHW',
-                 input_sample_shape=None):
-        super(MaxPooling2D,
-              self).__init__(name, model_pb2.PoolingConf.MAX, kernel, stride,
-                             border_mode, pad, data_format, input_sample_shape)
-
-
-class AvgPooling2D(Pooling2D):
-
-    def __init__(self,
-                 name,
-                 kernel=3,
-                 stride=2,
-                 border_mode='same',
-                 pad=None,
-                 data_format='NCHW',
-                 input_sample_shape=None):
-        super(AvgPooling2D,
-              self).__init__(name, model_pb2.PoolingConf.AVE, kernel, stride,
-                             border_mode, pad, data_format, input_sample_shape)
-
-
-class MaxPooling1D(MaxPooling2D):
-
-    def __init__(self,
-                 name,
-                 kernel=3,
-                 stride=2,
-                 border_mode='same',
-                 pad=None,
-                 data_format='NCHW',
-                 input_sample_shape=None):
-        """Max pooling for 1D feature.
-
-        Args:
-            input_sample_shape (tuple): 1D tuple for input feature length
-        """
-        pad = None
-        if pad is not None:
-            pad = (0, pad)
-        if input_sample_shape is not None:
-            assert len(input_sample_shape) == 1, \
-                'AvgPooling1D expects input sample to be 1D'
-            input_sample_shape = (1, 1, input_sample_shape[0])
-        else:
-            input_sample_shape = None
-        super(MaxPooling1D,
-              self).__init__(name, (1, kernel), (0, stride), border_mode, pad,
-                             data_format, input_sample_shape)
-
-    def get_output_sample_shape(self):
-        shape = self.layer.GetOutputSampleShape()
-        return (shape[2],)
-
-
-class AvgPooling1D(AvgPooling2D):
-
-    def __init__(self,
-                 name,
-                 kernel=3,
-                 stride=2,
-                 border_mode='same',
-                 pad=None,
-                 data_format='NCHW',
-                 input_sample_shape=None):
-        """input_feature_length is a scalar value"""
-        pad2 = None
-        if pad is not None:
-            pad2 = (pad, 0)
-        if input_sample_shape is not None:
-            assert len(input_sample_shape) == 1, \
-                'AvgPooling1D expects input sample to be 1D'
-            input_sample_shape = (1, 1, input_sample_shape[0])
-        else:
-            input_sample_shape = None
-
-        super(AvgPooling1D,
-              self).__init__(name, (kernel, 1), (0, stride), border_mode, pad2,
-                             data_format, input_sample_shape)
-
-    def get_output_sample_shape(self):
-        shape = self.layer.GetOutputSampleShape()
-        return (shape[2],)
-
-
-class BatchNormalization(Layer):
-    """Batch-normalization.
-
-    Args:
-        momentum (float): for running average mean and variance.
-        beta_specs (dict): dictionary includes the fields for the beta
-            param:
-            'name' for parameter name
-            'lr_mult' for learning rate multiplier
-            'decay_mult' for weight decay multiplier
-            'init' for init method, which could be 'gaussian', 'uniform',
-            'xavier' and ''
-            'std', 'mean', 'high', 'low' for corresponding init methods
-            'clamp' for gradient constraint, value is scalar
-            'regularizer' for regularization, currently support 'l2'
-        gamma_specs (dict): similar to beta_specs, but for the gamma param.
-        name (string): layer name
-        input_sample_shape (tuple): with at least one integer
+class Add(Layer):
+    """
+    Generate an Add operator
     """
 
-    def __init__(self,
-                 name,
-                 momentum=0.9,
-                 beta_specs=None,
-                 gamma_specs=None,
-                 input_sample_shape=None):
-        super(BatchNormalization, self).__init__(name)
-        conf = self.conf.batchnorm_conf
-        conf.factor = momentum
-        if beta_specs is None:
-            beta_specs = {'init': 'Xavier'}
-        if gamma_specs is None:
-            gamma_specs = {'init': 'Xavier'}
-        if 'name' not in beta_specs:
-            beta_specs['name'] = name + '/beta'
-        if 'name' not in gamma_specs:
-            gamma_specs['name'] = name + '/gamma'
-        mean_specs = {'init': 'constant', 'value': 0, 'name': name + '/mean'}
-        var_specs = {'init': 'constant', 'value': 1, 'name': name + '/var'}
-        self.conf.param.extend([_construct_param_specs_from_dict(gamma_specs)])
-        self.conf.param.extend([_construct_param_specs_from_dict(beta_specs)])
-        self.conf.param.extend([_construct_param_specs_from_dict(mean_specs)])
-        self.conf.param.extend([_construct_param_specs_from_dict(var_specs)])
-        self.param_specs.append(_construct_param_specs_from_dict(gamma_specs))
-        self.param_specs.append(_construct_param_specs_from_dict(beta_specs))
-        self.param_specs.append(_construct_param_specs_from_dict(mean_specs))
-        self.param_specs.append(_construct_param_specs_from_dict(var_specs))
-        _check_engine(engine,
-                      ['cudnn', 'singa', 'singacpp', 'singacuda', 'singacl'])
-        self.layer = _create_layer(engine, 'BatchNorm')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
+    def __init__(self):
+        super(Add, self).__init__()
 
-
-class L2Norm(Layer):
-    '''Normalize each sample to have L2 norm = 1'''
-
-    def __init__(self, name, input_sample_shape, epsilon=1e-8):
-        super(L2Norm, self).__init__(name)
-        self.y = None
-        self.norm = None
-        self.name = name
-        self.epsilon = epsilon
-        self.out_sample_shape = input_sample_shape
-
-    def get_output_sample_shape(self):
-        return self.out_sample_shape
-
-    def forward(self, is_train, x):
-        norm = tensor.sum_columns(tensor.square(x))
-        norm += self.epsilon
-        norm = tensor.sqrt(norm)
-        self.y = x.clone()
-        self.y.div_column(norm)
-
-        if is_train:
-            self.norm = norm
-        return self.y
-
-    def backward(self, is_train, dy):
-        # (dy - y * k) / norm, k = sum(dy * y)
-        k = tensor.sum_columns(tensor.eltwise_mult(dy, self.y))
-        self.y.mult_column(k)
-        dx = dy - self.y
-        dx.div_column(self.norm)
-        return dx, []
-
-
-class LRN(Layer):
-    """Local response normalization.
-
-    Args:
-        size (int): # of channels to be crossed
-            normalization.
-        mode (string): 'cross_channel'
-        input_sample_shape (tuple): 3d tuple, (channel, height, width)
-    """
-
-    def __init__(self,
-                 name,
-                 size=5,
-                 alpha=1,
-                 beta=0.75,
-                 mode='cross_channel',
-                 k=1,
-                 input_sample_shape=None):
-        super(LRN, self).__init__(name)
-        conf = self.conf.lrn_conf
-        conf.local_size = size
-        conf.alpha = alpha
-        conf.beta = beta
-        conf.k = k
-        # TODO(wangwei) enable mode = 'within_channel'
-        assert mode == 'cross_channel', 'only support mode="across_channel"'
-        conf.norm_region = model_pb2.LRNConf.ACROSS_CHANNELS
-        _check_engine(engine,
-                      ['cudnn', 'singa', 'singacpp', 'singacuda', 'singacl'])
-        self.layer = _create_layer(engine, 'LRN')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-
-class Dense(Layer):
-    """Apply linear/affine transformation, also called inner-product or
-    fully connected layer.
-
-    Args:
-        num_output (int): output feature length.
-        use_bias (bool): add a bias vector or not to the transformed feature
-        W_specs (dict): specs for the weight matrix
-            'name' for parameter name
-            'lr_mult' for learning rate multiplier
-            'decay_mult' for weight decay multiplier
-            'init' for init method, which could be 'gaussian', 'uniform',
-            'xavier' and ''
-            'std', 'mean', 'high', 'low' for corresponding init methods
-            'clamp' for gradient constraint, value is scalar
-            'regularizer' for regularization, currently support 'l2'
-        b_specs (dict): specs for the bias vector, same fields as W_specs.
-        W_transpose (bool): if true, output=x*W.T+b;
-        input_sample_shape (tuple): input feature length
-    """
-
-    def __init__(self,
-                 name,
-                 num_output,
-                 use_bias=True,
-                 W_specs=None,
-                 b_specs=None,
-                 W_transpose=False,
-                 input_sample_shape=None):
-        """Apply linear/affine transformation, also called inner-product or
-        fully connected layer.
-
-        Args:
-            num_output (int): output feature length.
-            use_bias (bool): add a bias vector or not to the transformed feature
-            W_specs (dict): specs for the weight matrix
-                'name' for parameter name
-                'lr_mult' for learning rate multiplier
-                'decay_mult' for weight decay multiplier
-                'init' for init method, which could be 'gaussian', 'uniform',
-                'xavier' and ''
-                'std', 'mean', 'high', 'low' for corresponding init methods
-                'clamp' for gradient constraint, value is scalar
-                'regularizer' for regularization, currently support 'l2'
-            b_specs (dict): specs for the bias vector, same fields as W_specs.
-            W_transpose (bool): if true, output=x*W.T+b;
-            input_sample_shape (tuple): input feature length
-        """
-        super(Dense, self).__init__(name)
-        conf = self.conf.dense_conf
-        conf.num_output = num_output
-        conf.bias_term = use_bias
-        conf.transpose = W_transpose
-        if W_specs is None:
-            W_specs = {'init': 'xavier'}
-        if 'name' not in W_specs:
-            W_specs['name'] = name + '/weight'
-        wspecs = _construct_param_specs_from_dict(W_specs)
-        self.conf.param.extend([wspecs])
-        self.param_specs.append(wspecs)
-        if use_bias:
-            if b_specs is None:
-                b_specs = {'init': 'constant', 'value': 0}
-            if 'name' not in b_specs:
-                b_specs['name'] = name + '/bias'
-            bspecs = _construct_param_specs_from_dict(b_specs)
-            self.conf.param.extend([bspecs])
-            self.param_specs.append(bspecs)
-        # dense layer is transparent to engine.
-        if engine == 'cudnn':
-            self.layer = _create_layer('singacuda', 'Dense')
-        else:
-            self.layer = _create_layer(engine, 'Dense')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-
-class Dropout(Layer):
-    """Droput layer.
-
-    Args:
-        p (float): probability for dropping out the element, i.e., set to 0
-        name (string): layer name
-    """
-
-    def __init__(self, name, p=0.5, input_sample_shape=None):
-        super(Dropout, self).__init__(name)
-        conf = self.conf.dropout_conf
-        conf.dropout_ratio = p
-        # dropout is support in cudnn since V5
-        if engine.lower() == 'cudnn' and cudnn_version < 5000:
-            myengine = 'singacuda'
-        else:
-            myengine = engine
-        _check_engine(myengine,
-                      ['cudnn', 'singa', 'singacpp', 'singacuda', 'singacl'])
-        self.layer = _create_layer(myengine, 'Dropout')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-
-class Activation(Layer):
-    """Activation layers.
-
-    Args:
-        name (string): layer name
-        mode (string): 'relu', 'sigmoid', or 'tanh'
-        input_sample_shape (tuple): shape of a single sample
-    """
-
-    def __init__(self, name, mode='relu', input_sample_shape=None):
-        super(Activation, self).__init__(name)
-        _check_engine(engine, ['cudnn', 'singacpp', 'singacuda', 'singacl'])
-        self.conf.type = (engine + '_' + mode).lower()
-        self.layer = _create_layer(engine, mode)
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-
-class Softmax(Layer):
-    """Apply softmax.
-
-    Args:
-        axis (int): reshape the input as a matrix with the dimension
-            [0,axis) as the row, the [axis, -1) as the column.
-        input_sample_shape (tuple): shape of a single sample
-    """
-
-    def __init__(self, name, axis=1, input_sample_shape=None):
-        super(Softmax, self).__init__(name)
-        # conf = self.conf.softmax_conf
-        # conf.axis = axis
-        _check_engine(engine,
-                      ['cudnn', 'singa', 'singacpp', 'singacl', 'singacuda'])
-        self.layer = _create_layer(engine, 'Softmax')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
+    def forward(self, a, b):
+        return autograd.add(a, b)
 
 
 class Flatten(Layer):
-    """Reshape the input tensor into a matrix.
-
-    Args:
-        axis (int): reshape the input as a matrix with the dimension
-            [0,axis) as the row, the [axis, -1) as the column.
-        input_sample_shape (tuple): shape for a single sample
+    """
+    Generate a Flatten operator
     """
 
-    def __init__(self, name, axis=1, input_sample_shape=None):
-        super(Flatten, self).__init__(name)
-        conf = self.conf.flatten_conf
-        conf.axis = axis
-        # fltten layer is transparent to engine
-        if engine == 'cudnn':
-            self.layer = _create_layer('singacuda', 'Flatten')
-        else:
-            self.layer = _create_layer(engine, 'Flatten')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-
-class Merge(Layer):
-    '''Sum all input tensors.
-
-    Args:
-        input_sample_shape: sample shape of the input. The sample shape of all
-            inputs should be the same.
-    '''
-
-    def __init__(self, name, input_sample_shape=None):
-        self.in_shape = input_sample_shape
-        self.num_input = 1
-        super(Merge, self).__init__(name)
-
-    def setup(self, in_shape):
-        self.in_shape = in_shape
-        self.has_setup = True
-
-    def get_output_sample_shape(self):
-        return self.in_shape
-
-    def forward(self, flag, inputs):
-        '''Merge all input tensors by summation.
-
-        TODO(wangwei) do element-wise merge operations, e.g., avg, count
-        Args:
-            flag: not used.
-            inputs (list): a list of tensors
-
-        Returns:
-            A single tensor as the sum of all input tensors
-        '''
-        assert len(inputs) > 1, 'There must be multiple input tensors'
-        self.num_input = len(inputs)
-        output = tensor.Tensor()
-        output.reset_like(inputs[0])
-        output.set_value(0)
-        for x in inputs:
-            output += x
-        return output
-
-    def backward(self, flag, grad):
-        '''Replicate the grad for each input source layer.
-
-        Args:
-            grad(Tensor), the gradient tensor of the merged result from forward
-
-        Returns:
-            A list of replicated grad, one per source layer
-        '''
-        assert isinstance(grad, tensor.Tensor), 'The input must be Tensor' \
-            ' instead of %s' % type(grad).__name__
-        return [grad] * self.num_input, []  # * self.num_input
-
-
-class Split(Layer):
-    '''Replicate the input tensor.
-
-    Args:
-        num_output (int): number of output tensors to generate.
-        input_sample_shape: includes a single integer for the input sample
-            feature size.
-    '''
-
-    def __init__(self, name, num_output, input_sample_shape=None):
-        self.num_output = num_output
-        self.in_shape = input_sample_shape
-        super(Split, self).__init__(name)
-
-    def setup(self, in_shape):
-        self.in_shape = in_shape
-        self.has_setup = True
-
-    def get_output_sample_shape(self):
-        return [self.in_shape] * self.num_output
-
-    def forward(self, flag, input):
-        '''Replicate the input tensor into mutiple tensors.
-
-        Args:
-            flag: not used
-            input: a single input tensor
-
-        Returns:
-            a list a output tensor (each one is a copy of the input)
-        '''
-        assert isinstance(input, tensor.Tensor), 'The input must be Tensor'
-        outputs = [input] * self.num_output
-        return outputs
-
-    def backward(self, flag, grads):
-        '''Sum all grad tensors to generate a single output tensor.
-
-        Args:
-            grads(list of Tensor), one per dest layer
-
-        Returns:
-            a single tensor as the sum of all grads
-        '''
-        assert len(grads) > 1, 'There must be multiple gradients'
-        dx = tensor.Tensor()
-        dx.reset_like(grads[0])
-        dx.set_value(0)
-        for g in grads:
-            dx += g
-        return dx, []
-
-
-class Concat(Layer):
-    '''Concatenate tensors vertically (axis = 0) or horizontally (axis = 1).
-
-    Currently, only support tensors with 2 dimensions.
-
-    Args:
-        axis(int): 0 for concat row; 1 for concat columns;
-        input_sample_shapes: a list of sample shape tuples, one per input tensor
-    '''
-
-    def __init__(self, name, axis, input_sample_shapes=None):
-        super(Concat, self).__init__(name)
-        self.in_shapes = input_sample_shapes
+    def __init__(self, axis=1):
+        super(Flatten, self).__init__()
         self.axis = axis
-        self.conf.concat_conf.axis = axis
-        if engine == "cudnn":
-            self.layer = _create_layer('singacuda', 'Concat')
-        else:
-            self.layer = _create_layer(engine, 'Concat')
-        if input_sample_shapes is not None:
-            self.setup(input_sample_shapes)
 
-    def forward(self, flag, inputs):
-        '''Concatenate all input tensors.
-
-        Args:
-            flag: same as Layer::forward()
-            input: a list of tensors
-
-        Returns:
-            a single concatenated tensor
-        '''
-        assert type(inputs) is list, 'Must be a list of Tensors'
-        ys = super(Concat, self).forward(flag, inputs)
-        return ys[0]
-
-    def backward(self, flag, dy):
-        '''Backward propagate gradients through this layer.
-
-        Args:
-            flag: same as Layer::backward()
-            dy(Tensor): the gradient tensors of y w.r.t objective loss
-        Return:
-            <dx, []>, dx is a list tensors for the gradient of the inputs; []
-               is an empty list.
-        '''
-        if type(dy) is tensor.Tensor:
-            dy = [dy]
-        assert type(dy) is list, 'Must be a list(Tensor)'
-        return super(Concat, self).backward(flag, dy)
+    def forward(self, x):
+        return autograd.flatten(x, self.axis)
 
 
-class Slice(Layer):
-    '''Slice the input tensor into multiple sub-tensors vertially (axis=0) or
-    horizontally (axis=1).
+class SoftMaxCrossEntropy(Layer):
+    """
+    Generate a SoftMaxCrossEntropy operator
+    """
 
-    Args:
-        axis (int): 0 for slice rows; 1 for slice columns;
-        slice_point(list): positions along the axis to do slice; there are n-1
-            points for n sub-tensors;
-        input_sample_shape: input tensor sample shape
-    '''
+    def __init__(self):
+        super(SoftMaxCrossEntropy, self).__init__()
 
-    def __init__(self, name, axis, slice_point, input_sample_shape=None):
-        super(Slice, self).__init__(name)
-        self.in_shape = input_sample_shape
+    def forward(self, x, t):
+        return autograd.softmax_cross_entropy(x, t)
+
+
+class SoftMax(Layer):
+    """
+    Generate a SoftMax operator
+    """
+
+    def __init__(self):
+        super(SoftMax, self).__init__()
+
+    def forward(self, x):
+        return autograd.softmax(x)
+
+
+class MeanSquareError(Layer):
+    """
+    Generate a MeanSquareError operator
+    """
+
+    def __init__(self):
+        super(MeanSquareError, self).__init__()
+
+    def forward(self, x, t):
+        return autograd.mse_loss(x, t)
+
+
+class CrossEntropy(Layer):
+    """
+    Generate a CrossEntropy operator
+    """
+
+    def __init__(self):
+        super(CrossEntropy, self).__init__()
+
+    def forward(self, x, t):
+        return autograd.cross_entropy(x, t)
+
+
+class BinaryCrossEntropy(Layer):
+    """
+    Generate a BinaryCrossEntropy operator
+    """
+
+    def __init__(self):
+        super(BinaryCrossEntropy, self).__init__()
+
+    def forward(self, x, t):
+        return autograd.binary_cross_entropy(x, t)
+
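+# Minimal usage sketch for the stateless layers above, assuming the
+# Layer.__call__ wiring defined earlier in this file forwards its arguments
+# to forward():
+#
+#   relu, flat, loss = ReLU(), Flatten(), SoftMaxCrossEntropy()
+#   y = flat(relu(x))        # x: a feature map Tensor
+#   l = loss(y, t)           # t: targets in the format expected by
+#                            # autograd.softmax_cross_entropy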
+
+class Dropout(Layer):
+    """
+    Generate a Dropout operator
+    """
+
+    def __init__(self, ratio=0.5):
+        super(Dropout, self).__init__()
+        self.ratio = ratio
+
+    def forward(self, x):
+        return autograd.dropout(x, self.ratio)
+
+
+class Cat(Layer):
+    """
+    Generate a Cat Operator
+    """
+
+    def __init__(self, axis=0):
+        super(Cat, self).__init__()
         self.axis = axis
-        self.conf.slice_conf.axis = axis
-        self.conf.slice_conf.slice_point.extend(slice_point)
-        if engine == "cudnn":
-            self.layer = _create_layer('singacuda', 'Slice')
+
+    def forward(self, xs):
+        return autograd.cat(xs, self.axis)
+
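+# Minimal usage sketch: Cat takes a list (or tuple) of tensors and
+# concatenates them along the configured axis, e.g. two (4, 16) tensors
+# with Cat(axis=0) produce an (8, 16) result:
+#
+#   cat = Cat(axis=0)
+#   y = cat([x1, x2])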
+
+class Reshape(Layer):
+    """
+    Generate a Reshape Operator
+    """
+
+    def __init__(self):
+        super(Reshape, self).__init__()
+
+    def forward(self, x, shape):
+        return autograd.reshape(x, shape)
+
+
+class CudnnRNN(Layer):
+    """ `CudnnRNN` class implements with c++ backend and run the operation
+          directly on cuDNN
+        While `RNN` class implements with high level singa API
+    """
+
+    def __init__(self,
+                 hidden_size,
+                 activation="tanh",
+                 num_layers=1,
+                 bias=True,
+                 batch_first=True,
+                 dropout=0,
+                 bidirectional=False,
+                 rnn_mode="lstm",
+                 use_mask=False,
+                 return_sequences=True):
+        """
+            Args:
+                hidden_size: hidden feature dim
+                rnn_mode: accepted values: "vanilla", "tanh", "relu", "lstm", "gru"
+        """
+        assert singa.USE_CUDA, "Not able to run without CUDA"
+        assert num_layers > 0, "num layers should be > 0"
+        assert 0 <= dropout < 1, "dropout should be >= 0 and < 1"
+        super(CudnnRNN, self).__init__()
+
+        self.rnn_mode = rnn_mode
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.bidirectional = 1 if bidirectional else 0
+        self.return_sequences = return_sequences
+        self.batch_first = batch_first
+        self.use_mask = use_mask
+
+        # GPU parameter
+        # cudnn_rnn_mode: 0 - RNN RELU, 1 - RNN TANH, 2 - LSTM, 3 - GRU
+        if self.rnn_mode == "lstm":
+            self.cudnn_rnn_mode = 2
+        elif self.rnn_mode == "vanilla" or self.rnn_mode == "tanh":
+            self.cudnn_rnn_mode = 1
+        elif self.rnn_mode == "relu":
+            self.cudnn_rnn_mode = 0
+        elif self.rnn_mode == "gru":
+            self.cudnn_rnn_mode = 3
+
+    def initialize(self, x, hx=None, cx=None, seq_lengths=None):
+        if self.batch_first:
+            x = x.transpose((1, 0, 2))
+        self.input_size = x.shape[1]
+
+        # GPU handle
+        self.handle = singa.CudnnRNNHandle(x.data,
+                                           self.hidden_size,
+                                           mode=self.cudnn_rnn_mode,
+                                           num_layers=self.num_layers,
+                                           dropout=self.dropout,
+                                           bidirectional=self.bidirectional)
+
+        self.W = Tensor(shape=(self.handle.weights_size,),
+                        requires_grad=True,
+                        stores_grad=True,
+                        device=x.device)
+
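+        # W packs all RNN parameters (weights and biases) into one flat tensor
+        # sized by the cuDNN handle; it is initialized uniformly in
+        # (-sqrt(1/hidden_size), +sqrt(1/hidden_size)) below.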
+        k = 1 / self.hidden_size
+        self.W.uniform(-math.sqrt(k), math.sqrt(k))
+
+    def forward(self, x, hx=None, cx=None, seq_lengths=None):
+
+        self.device_check(x, self.W)
+        if self.batch_first:  # (bs,seq,data) -> (seq,bs,data)
+            x = autograd.transpose(x, (1, 0, 2))
+
+        batch_size = x.shape[1]
+        directions = 2 if self.bidirectional else 1
+        if hx is None:
+            hx = Tensor(shape=(self.num_layers * directions, batch_size,
+                               self.hidden_size),
+                        requires_grad=False,
+                        stores_grad=False,
+                        device=x.device).set_value(0.0)
+        if cx is None:
+            cx = Tensor(shape=(self.num_layers * directions, batch_size,
+                               self.hidden_size),
+                        requires_grad=False,
+                        stores_grad=False,
+                        device=x.device).set_value(0.0)
+
+        # the underlying op returns a list of outputs;
+        #   its inputs have shape (sequence length, batch size, feature size)
+        if self.use_mask:
+            assert type(seq_lengths) == Tensor, "wrong type for seq_lengths"
+            y = autograd._RNN(self.handle,
+                              return_sequences=self.return_sequences,
+                              use_mask=self.use_mask,
+                              seq_lengths=seq_lengths)(x, hx, cx, self.W)[0]
         else:
-            self.layer = _create_layer(engine, 'Slice')
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
+            y = autograd._RNN(
+                self.handle,
+                return_sequences=self.return_sequences,
+            )(x, hx, cx, self.W)[0]
+        if self.return_sequences and self.batch_first:
+            # (seq, bs, hid) -> (bs, seq, hid)
+            y = autograd.transpose(y, (1, 0, 2))
+        return y
 
-    def get_output_sample_shape(self):
-        out = []
-        for i in range(len(self.conf.slice_conf.slice_point) + 1):
-            out.append(self.layer.GetOutputSampleShapeAt(i))
-        return out
+    def get_params(self):
+        return {self.W.name: self.W}
 
-    def forward(self, flag, x):
-        '''Slice the input tensor on the given axis.
-
-        Args:
-            flag: same as Layer::forward()
-            x: a single input tensor
-
-        Returns:
-            a list a output tensor
-        '''
-        if type(x) is tensor.Tensor:
-            x = [x]
-        assert type(x) is list, 'Must be a list of Tensor'
-        return super(Slice, self).forward(flag, x)
-
-    def backward(self, flag, grads):
-        '''Concate all grad tensors to generate a single output tensor
-
-        Args:
-            flag: same as Layer::backward()
-            grads: a list of tensors, one for the gradient of one sliced tensor
-
-        Returns:
-            a single tensor for the gradient of the original user, and an empty
-                list.
-        '''
-        assert len(grads) > 1, 'There must be multiple gradients'
-        dxs, _ = super(Slice, self).backward(flag, grads)
-        return dxs[0], []
+    def set_params(self, parameters):
+        self.set_attribute(self.W, parameters[self.W.name])
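+
+# Illustrative CudnnRNN usage, assuming a CUDA build, a CUDA device `dev`, and a
+# batch-first input of shape (batch, seq, feature); not part of this module's API:
+#
+#   x = tensor.Tensor((8, 10, 32), dev)
+#   x.gaussian(0.0, 1.0)
+#   rnn = CudnnRNN(hidden_size=64, rnn_mode="lstm")
+#   y = rnn(x)   # shape (8, 10, 64), since batch_first and return_sequences default to True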
 
 
-class RNN(Layer):
-    '''Recurrent layer with 4 types of units, namely lstm, gru, tanh and relu.
-
-    Args:
-        hidden_size: hidden feature size, the same for all stacks of layers.
-        rnn_mode: decides the rnn unit, which could be one of 'lstm', 'gru',
-            'tanh' and 'relu', refer to cudnn manual for each mode.
-        num_stacks: num of stacks of rnn layers. It is different to the
-            unrolling seqence length.
-        input_mode: 'linear' convert the input feature x by by a linear
-            transformation to get a feature vector of size hidden_size;
-            'skip' does nothing but requires the input feature size equals
-            hidden_size
-        bidirection: True for bidirectional RNN
-        param_specs: config for initializing the RNN parameters.
-        input_sample_shape: includes a single integer for the input sample
-            feature size.
-    '''
-
-    def __init__(self,
-                 name,
-                 hidden_size,
-                 rnn_mode='lstm',
-                 dropout=0.0,
-                 num_stacks=1,
-                 input_mode='linear',
-                 bidirectional=False,
-                 param_specs=None,
-                 input_sample_shape=None):
-        assert cudnn_version >= 5005, 'RNN is supported since CUDNN V5.0.5; '\
-            'This version is %d' % cudnn_version
-        super(RNN, self).__init__(name)
-        conf = self.conf.rnn_conf
-        assert hidden_size > 0, 'Hidden feature size must > 0'
-        conf.hidden_size = hidden_size
-        assert rnn_mode in set(['lstm', 'gru', 'tanh', 'relu']),  \
-            'rnn mode %s is not available' % (rnn_mode)
-        conf.rnn_mode = rnn_mode
-        conf.num_stacks = num_stacks
-        conf.dropout = dropout
-        conf.input_mode = input_mode
-        conf.direction = 'unidirectional'
-        if bidirectional:
-            conf.direction = 'bidirectional'
-        # currently only has rnn layer implemented using cudnn
-        _check_engine(engine, ['cudnn'])
-        if param_specs is None:
-            param_specs = {
-                'name': name + '/weight',
-                'init': 'uniform',
-                'low': 0,
-                'high': 1
-            }
-        self.conf.param.extend([_construct_param_specs_from_dict(param_specs)])
-        self.param_specs.append(_construct_param_specs_from_dict(param_specs))
-
-        self.layer = singa_wrap.CudnnRNN()
-        if input_sample_shape is not None:
-            self.setup(input_sample_shape)
-
-    def forward(self, flag, inputs):
-        '''Forward inputs through the RNN.
-
-        Args:
-            flag: True(kTrain) for training; False(kEval) for evaluation;
-                others values for future use.
-            inputs, <x1, x2,...xn, hx, cx>, where xi is the input tensor for the
-                i-th position, its shape is (batch_size, input_feature_length);
-                the batch_size of xi must >= that of xi+1; hx is the initial
-                hidden state of shape (num_stacks * bidirection?2:1, batch_size,
-                hidden_size). cx is the initial cell state tensor of the same
-                shape as hy. cx is valid for only lstm. For other RNNs there is
-                no cx. Both hx and cx could be dummy tensors without shape and
-                data.
-
-        Returns:
-            <y1, y2, ... yn, hy, cy>, where yi is the output tensor for the i-th
-                position, its shape is (batch_size,
-                hidden_size * bidirection?2:1). hy is the final hidden state
-                tensor. cx is the final cell state tensor. cx is only used for
-                lstm.
-        '''
-        assert self.has_setup, 'Must call setup() before forward()'
-        assert len(inputs) > 1, 'The input to RNN must include at '\
-            'least one input tensor '\
-            'and one hidden state tensor (could be a dummy tensor)'
-        tensors = []
-        for t in inputs:
-            assert isinstance(t, tensor.Tensor), \
-                'input must be py Tensor %s' % (type(t))
-            tensors.append(t.data)
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-        y = self.layer.ForwardWithMultInputs(flag, tensors)
-        return tensor.from_raw_tensors(y)
-
-    def backward(self, flag, grad):
-        '''Backward gradients through the RNN.
-
-        Args:
-            flag, for future use.
-            grad, <dy1, dy2,...dyn, dhy, dcy>, where dyi is the gradient for the
-            i-th output, its shape is (batch_size, hidden_size*bidirection?2:1);
-                dhy is the gradient for the final hidden state, its shape is
-                (num_stacks * bidirection?2:1, batch_size,
-                hidden_size). dcy is the gradient for the final cell state.
-                cx is valid only for lstm. For other RNNs there is
-                no cx. Both dhy and dcy could be dummy tensors without shape and
-                data.
-
-        Returns:
-            <dx1, dx2, ... dxn, dhx, dcx>, where dxi is the gradient tensor for
-                the i-th input, its shape is (batch_size,
-                input_feature_length). dhx is the gradient for the initial
-                hidden state. dcx is the gradient for the initial cell state,
-                which is valid only for lstm.
-        '''
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-
-        tensors = []
-        for t in grad:
-            assert isinstance(t, tensor.Tensor), 'grad must be py Tensor'
-            tensors.append(t.data)
-        ret = self.layer.BackwardWithMultInputs(flag, tensors)
-        return tensor.from_raw_tensors(ret[0]), tensor.from_raw_tensors(ret[1])
-
-
-class LSTM(RNN):
-
-    def __init__(self,
-                 name,
-                 hidden_size,
-                 dropout=0.0,
-                 num_stacks=1,
-                 input_mode='linear',
-                 bidirectional=False,
-                 param_specs=None,
-                 input_sample_shape=None):
-        super(LSTM, self).__init__(name, hidden_size, 'lstm', dropout,
-                                   num_stacks, input_mode, bidirectional,
-                                   param_specs, input_sample_shape)
-
-
-class GRU(RNN):
-
-    def __init__(self,
-                 name,
-                 hidden_size,
-                 dropout=0.0,
-                 num_stacks=1,
-                 input_mode='linear',
-                 bidirectional=False,
-                 param_specs=None,
-                 input_sample_shape=None):
-        super(GRU, self).__init__(name, hidden_size, 'gru', dropout, num_stacks,
-                                  input_mode, bidirectional, param_specs,
-                                  input_sample_shape)
-
-
-def _check_engine(engine, allowed_engines):
-    assert engine.lower() in set(allowed_engines), \
-        '%s is not a supported engine. Pls use one of %s' % \
-        (engine, ', '.join(allowed_engines))
-
-
-def _create_layer(eng, layer):
-    ''' create singa wrap layer.
-
-    Both arguments are case insensitive.
-    Args:
-        engine, implementation engine, either 'singa' or 'cudnn'
-        layer, layer type, e.g., 'convolution', 'pooling'; for activation
-        layers, use the specific activation mode, e.g. 'relu', 'tanh'.
-    '''
-    assert eng != 'cudnn' or cudnn_version > 0, 'CUDNN is not enabled, please '\
-        'change the engine, e.g., layer.engine=singacpp'
-    layer_type = eng + '_' + layer
-    return singa_wrap.CreateLayer(layer_type.lower().encode())
-
-
-def _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad, in_shape):
-    """Private function called by Convolution2D and Pooling2D.
-
-    PyTorch:
-        http://pytorch.org/docs/nn.html#pooling-layers
-        floor for both conv and pooling
-    Caffe:
-        https://github.com/BVLC/caffe/issues/1318#issuecomment-59594323
-        floor for conv and ceil for pooling
-    Tensorflow: https://www.tensorflow.org/api_guides/python/nn#Convolution
-        SAME  outsize = ceil(insize/stride),
-              pad_h_w = max((outsize-1)*stride+k-insize, 0)
-        VALID same as pytorch
-    """
-    if isinstance(kernel, tuple):
-        conf.kernel_h = kernel[0]
-        conf.kernel_w = kernel[1]
-    else:
-        conf.kernel_h = kernel
-        conf.kernel_w = kernel
-    if isinstance(stride, tuple):
-        conf.stride_h = stride[0]
-        conf.stride_w = stride[1]
-    else:
-        conf.stride_h = stride
-        conf.stride_w = stride
-    mode = border_mode.lower()
-    if pad is None:
-        # TODO(wangwei) check the border mode
-        if mode == 'same':
-            if conf.stride_h != 0:
-                out_h = in_shape[1] // conf.stride_h
-                ph = max(
-                    (out_h - 1) * conf.stride_h + conf.kernel_h - in_shape[1],
-                    0)
-            else:
-                ph = 0
-            out_w = in_shape[2] // conf.stride_w
-            pw = max((out_w - 1) * conf.stride_w + conf.kernel_w - in_shape[2],
-                     0)
-            assert ph % 2 == 0 and pw % 2 == 0, 'ph=%d and pw=%d are not even' \
-                % (ph, pw)
-            pad = (ph // 2, pw // 2)
-        elif mode == 'valid':
-            pad = (0, 0)
-        else:
-            assert False, ('Unsupported border_mode: %s. '
-                           'Please use {"VALID", "SAME"}' % border_mode)
-    if isinstance(pad, tuple):
-        conf.pad_h = pad[0]
-        conf.pad_w = pad[1]
-    else:
-        conf.pad_h = pad
-        conf.pad_w = pad
-    return conf
-
-
-def _construct_param_specs_from_dict(specs):
-    """Conver the param specs from a dict into ParamSpec protobuf object.
-
-    Args:
-        specs (dict): the fields inlcude
-            'name' for parameter name
-            'lr_mult' for learning rate multiplier;
-            'decay_mult' for weight decay multiplier;
-            'init' for init method, which could be 'gaussian', 'uniform',
-            'xavier' and 'msra';
-            'std', 'mean', 'high', 'low' are used by corresponding init methods;
-            'constraint' for gradient constraint, value is a float threshold for
-                clampping the gradient.
-            'regularizer' for regularization, currently support 'l2', value is a
-                float for the coefficient.
-
-    Returns:
-        a ParamSpec object
-    """
-    conf = model_pb2.ParamSpec()
-    if 'name' in specs:
-        conf.name = specs['name']
-    if 'lr_mult' in specs:
-        conf.lr_mult = specs['lr_mult']
-    if 'decay_mult' in specs:
-        conf.decay_mult = specs['decay_mult']
-    if 'init' in specs:
-        filler = conf.filler
-        filler.type = specs['init'].lower()
-        if specs['init'].lower() == 'uniform':
-            assert 'low' in specs and 'high' in specs, \
-                'low and high are required for "uniform" init method'
-            filler.min = specs['low']
-            filler.max = specs['high']
-        elif specs['init'].lower() == 'gaussian':
-            assert 'mean' in specs and 'std' in specs, \
-                'std and mean are required for "gaussian" init method'
-            filler.mean = specs['mean']
-            filler.std = specs['std']
-        elif specs['init'].lower() == 'constant' and 'value' in specs:
-            filler.value = specs['value']
-    if 'regularizer' in specs:
-        conf.regularizer.coefficient = specs['regularizer']
-    if 'constraint' in specs:
-        conf.constraint.threshold = specs['constraint']
-    return conf
-
-
-def _construct_param_specs_from_caffe_proto(lyr_conf):
-    """convert the param specs from a caffe layer proto into a singa paramspec
-    protobuf object.
-
-    args:
-        specs (dict): the fields inlcude
-            'name' for parameter name
-            'lr_mult' for learning rate multiplier;
-            'decay_mult' for weight decay multiplier;
-            'init' for init method, which could be 'gaussian', 'uniform',
-            'xavier' and 'msra';
-            'std', 'mean', 'high', 'low' are used by corresponding init methods;
-            caffe model has no 'constraint' and 'regularizer'
-
-    returns:
-        a pair of paramspec objects(weight and bias)
-    """
-    wparam = model_pb2.ParamSpec()
-    bparam = model_pb2.ParamSpec()
-    if len(lyr_conf.param) > 0:
-        wparam.name = lyr_conf.param[0].name
-        wparam.lr_mult = lyr_conf.param[0].lr_mult
-        wparam.decay_mult = lyr_conf.param[0].decay_mult
-        if len(lyr_conf.param) > 1:
-            bparam.name = lyr_conf.param[1].name
-            bparam.lr_mult = lyr_conf.param[1].lr_mult
-            bparam.decay_mult = lyr_conf.param[1].decay_mult
-    if wparam.name == '' or wparam.name is None:
-        wparam.name = lyr_conf.name + '_weight'
-    if bparam.name == '' or bparam.name is None:
-        bparam.name = lyr_conf.name + '_bias'
-    wfiller = wparam.filler
-    bfiller = bparam.filler
-    param = ''
-    if lyr_conf.type == 'Convolution' or lyr_conf.type == 4:
-        param = lyr_conf.convolution_conf
-    elif lyr_conf.type == 'InnerProduct' or lyr_conf.type == 14:
-        param = lyr_conf.dense_conf
-
-    if param != '':
-        wfiller.type = param.weight_filler.type.lower()
-        wfiller.min = param.weight_filler.min
-        wfiller.max = param.weight_filler.max
-        wfiller.mean = param.weight_filler.mean
-        wfiller.std = param.weight_filler.std
-        wfiller.value = param.weight_filler.value
-
-        bfiller.type = param.bias_filler.type.lower()
-        bfiller.min = param.bias_filler.min
-        bfiller.max = param.bias_filler.max
-        bfiller.mean = param.bias_filler.mean
-        bfiller.std = param.bias_filler.std
-        bfiller.value = param.bias_filler.value
-
-    return (wparam, bparam)
-
-
-def get_layer_list():
-    """ Return a list of strings which include the identifiers (tags) of all
-    supported layers
-    """
-    return [str(l) for l in singa_wrap.GetRegisteredLayers()]
+'''Import autograd at the end to resolve the circular import.'''
+from singa import autograd
diff --git a/python/singa/loss.py b/python/singa/loss.py
deleted file mode 100644
index c84ba76..0000000
--- a/python/singa/loss.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-'''
-Loss module includes a set of training loss implmentations. Some are converted
-from C++ implementation, and the rest are implemented directly using python
-Tensor.
-
-Example usage::
-
-    from singa import tensor
-    from singa import loss
-    import numpy as np
-
-    x = tensor.Tensor((3, 5))
-    x.uniform(0, 1)  # randomly generate the prediction activation
-    y = tensor.from_numpy(np.array([0, 1, 3], dtype=np.int))  # set the truth
-
-    f = loss.SoftmaxCrossEntropy()
-    l = f.forward(True, x, y)  # l is tensor with 3 loss values
-    g = f.backward()  # g is a tensor containing all gradients of x w.r.t l
-'''
-from __future__ import division
-from __future__ import absolute_import
-from builtins import object
-
-from . import singa_wrap as singa
-from . import tensor
-from .proto import model_pb2
-
-
-class Loss(object):
-    '''Base loss class.
-
-    Subclasses that wrap the C++ loss classes can use the inherited foward,
-    backward, and evaluate functions of this base class. Other subclasses need
-    to override these functions
-    '''
-
-    def __init__(self):
-        self.swig_loss = None
-
-    def forward(self, flag, x, y):
-        '''Compute the loss values.
-
-        Args:
-            flag: kTrain/kEval or bool. If it is kTrain/True, then the backward
-                function must be called before calling forward again.
-            x (Tensor): the prediction Tensor
-            y (Tensor): the ground truch Tensor, x.shape[0] must = y.shape[0]
-
-        Returns:
-            a tensor of floats for the loss values, one per sample
-        '''
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-        return tensor.from_raw_tensor(
-            self.swig_loss.Forward(flag, x.data, y.data))
-
-    def backward(self):
-        '''
-        Returns:
-            the grad of x w.r.t. the loss
-        '''
-        return tensor.from_raw_tensor(self.swig_loss.Backward())
-
-    def evaluate(self, flag, x, y):  # TODO(wangwei) remove flag
-        '''
-        Args:
-            flag (int): must be kEval, to be removed
-            x (Tensor): the prediction Tensor
-            y (Tensor): the ground truth Tnesor
-
-        Returns:
-            the averaged loss for all samples in x.
-        '''
-        if type(flag) is bool:
-            if flag:
-                flag = model_pb2.kTrain
-            else:
-                flag = model_pb2.kEval
-
-        return self.swig_loss.Evaluate(flag, x.data, y.data)
-
-
-class SoftmaxCrossEntropy(Loss):
-    '''This loss function is a combination of SoftMax and Cross-Entropy loss.
-
-    It converts the inputs via SoftMax function and then
-    computes the cross-entropy loss against the ground truth values.
-
-    For each sample, the ground truth could be a integer as the label index;
-    or a binary array, indicating the label distribution. The ground truth
-    tensor thus could be a 1d or 2d tensor.
-    The data/feature tensor could 1d (for a single sample) or 2d for a batch of
-    samples.
-    '''
-
-    def __init__(self):
-        super(SoftmaxCrossEntropy, self).__init__()
-        self.swig_loss = singa.SoftmaxCrossEntropy()
-
-
-class SigmoidCrossEntropy(Loss):
-    '''This loss evaluates the cross-entropy loss between the prediction and the
-    truth values with the prediction probability generated from Sigmoid.
-    '''
-
-    def __init__(self, epsilon=1e-8):
-        super(SigmoidCrossEntropy, self).__init__()
-        self.truth = None
-        self.prob = None
-        self.epsilon = epsilon  # to avoid log(x) with x being too small
-
-    def forward(self, flag, x, y):
-        '''loss is -yi * log pi - (1-yi) log (1-pi), where pi=sigmoid(xi)
-
-        Args:
-            flag (bool): true for training; false for evaluation
-            x (Tensor): the prediction Tensor
-            y (Tensor): the truth Tensor, a binary array value per sample
-
-        Returns:
-            a Tensor with one error value per sample
-        '''
-        p = tensor.sigmoid(x)
-        if flag:
-            self.truth = y
-            self.prob = p
-        np = 1 - p
-        p += (p < self.epsilon) * self.epsilon
-        np += (np < self.epsilon) * self.epsilon
-        l = (y - 1) * tensor.log(np) - y * tensor.log(p)
-        # TODO(wangwei): add unary operation -Tensor
-        return tensor.average(l, axis=1)
-
-    def backward(self):
-        ''' Compute the gradient of loss w.r.t to x.
-
-        Returns:
-            dx = pi - yi.
-        '''
-        assert self.truth is not None, 'must call forward in a prior'
-        dx = self.prob - self.truth
-        self.truth = None
-        return dx
-
-    def evaluate(self, flag, x, y):
-        '''Compuate the averaged error.
-
-        Returns:
-            a float value as the averaged error
-        '''
-        l = self.forward(False, x, y)
-        return l.l1()
-
-
-class SquaredError(Loss):
-    '''This loss evaluates the squared error between the prediction and the
-    truth values.
-
-    It is implemented using Python Tensor operations.
-    '''
-
-    def __init__(self):
-        super(SquaredError, self).__init__()
-        self.err = None
-
-    def forward(self, flag, x, y):
-        '''Compute the error as 0.5 * ||x-y||^2.
-
-        Args:
-            flag (int): kTrain or kEval; if kTrain, then the backward must be
-                called before calling forward again.
-            x (Tensor): the prediction Tensor
-            y (Tensor): the truth Tensor, an integer value per sample, whose
-                value is [0, x.shape[1])
-
-        Returns:
-            a Tensor with one error value per sample
-        '''
-        self.err = x - y
-        return tensor.square(self.err) * 0.5
-
-    def backward(self):
-        '''Compute the gradient of x w.r.t the error.
-
-        Returns:
-            x - y
-        '''
-        return self.err
-
-    def evaluate(self, flag, x, y):
-        '''Compuate the averaged error.
-
-        Returns:
-            a float value as the averaged error
-        '''
-        return tensor.sum(tensor.square(x - y)) * 0.5 / x.size()
diff --git a/python/singa/metric.py b/python/singa/metric.py
deleted file mode 100644
index 73d57df..0000000
--- a/python/singa/metric.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-"""This module includes a set of metric classes for evaluating the model's
-performance. The specific metric classes could be converted from C++
-implmentation or implemented directly using Python.
-
-
-Example usage::
-
-    from singa import tensor
-    from singa import metric
-    import numpy as np
-
-    x = tensor.Tensor((3, 5))
-    x.uniform(0, 1)  # randomly genearte the prediction activation
-    x = tensor.Softmax(x)  # normalize the prediction into probabilities
-    y = tensor.from_numpy(np.array([0, 1, 3], dtype=np.int))  # set the truth
-
-    f = metric.Accuracy()
-    acc = f.evaluate(x, y)  # averaged accuracy over all 3 samples in x
-
-"""
-from __future__ import division
-from __future__ import absolute_import
-
-from builtins import range
-from builtins import object
-
-from . import singa_wrap as singa
-from . import tensor
-import numpy as np
-
-
-class Metric(object):
-    """Base metric class.
-
-    Subclasses that wrap the C++ loss classes can use the inherited foward,
-    and evaluate functions of this base class. Other subclasses need
-    to override these functions. Users need to feed in the **predictions** and
-    ground truth to get the metric values.
-    """
-
-    def __init__(self):
-        self.swig_metric = None
-
-    def forward(self, x, y):
-        """Compute the metric for each sample.
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth values, one row per sample
-
-        Returns:
-            a tensor of floats, one per sample
-        """
-        return tensor.from_raw_tensor(self.swig_metric.Forward(x.data, y.data))
-
-    def evaluate(self, x, y):
-        """Compute the averaged metric over all samples.
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth values, one row per sample
-        Returns:
-            a float value for the averaged metric
-        """
-        return self.swig_metric.Evaluate(x.data, y.data)
-
-
-class Accuracy(Metric):
-    """Compute the top one accuracy for single label prediction tasks.
-
-    It calls the C++ functions to do the calculation.
-    """
-
-    def __init__(self):
-        self.swig_metric = singa.Accuracy()
-
-
-class Precision(Metric):
-    """Make the top-k labels of max probability as the prediction
-
-    Compute the precision against the groundtruth labels
-    """
-
-    def __init__(self, top_k):
-        self.top_k = top_k
-
-    def forward(self, x, y):
-        """Compute the precision for each sample.
-
-        Convert tensor to numpy for computation
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth labels, one row per sample
-
-        Returns:
-            a tensor of floats, one per sample
-        """
-
-        dev = x.device
-        x.to_host()
-        y.to_host()
-
-        x_np = tensor.to_numpy(x)
-        y_np = tensor.to_numpy(y)
-
-        # Sort in descending order
-        pred_np = np.argsort(-x_np)[:, 0:self.top_k]
-
-        prcs_np = np.zeros(pred_np.shape[0], dtype=np.float32)
-
-        for i in range(pred_np.shape[0]):
-            # groundtruth labels
-            label_np = np.argwhere(y_np[i])
-
-            # num of common labels among prediction and groundtruth
-            num_intersect = np.intersect1d(pred_np[i], label_np).size
-            prcs_np[i] = num_intersect / float(self.top_k)
-
-        precision = tensor.from_numpy(prcs_np)
-
-        x.to_device(dev)
-        y.to_device(dev)
-        precision.to_device(dev)
-
-        return precision
-
-    def evaluate(self, x, y):
-        """Compute the averaged precision over all samples.
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth values, one row per sample
-        Returns:
-            a float value for the averaged metric
-        """
-
-        return tensor.average(self.forward(x, y))
-
-
-class Recall(Metric):
-    """Make the top-k labels of max probability as the prediction
-
-    Compute the recall against the groundtruth labels
-    """
-
-    def __init__(self, top_k):
-        self.top_k = top_k
-
-    def forward(self, x, y):
-        """Compute the recall for each sample.
-
-        Convert tensor to numpy for computation
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth labels, one row per sample
-
-        Returns:
-            a tensor of floats, one per sample
-        """
-
-        dev = x.device
-        x.to_host()
-        y.to_host()
-
-        x_np = tensor.to_numpy(x)
-        y_np = tensor.to_numpy(y)
-
-        # Sort in descending order
-        pred_np = np.argsort(-x_np)[:, 0:self.top_k]
-
-        recall_np = np.zeros(pred_np.shape[0], dtype=np.float32)
-
-        for i in range(pred_np.shape[0]):
-            # Return the index of non-zero dimension of i-th sample
-            label_np = np.argwhere(y_np[i])
-
-            # Num of common labels among prediction and groundtruth
-            num_intersect = np.intersect1d(pred_np[i], label_np).size
-            recall_np[i] = float(num_intersect) / label_np.size
-
-        recall = tensor.from_numpy(recall_np)
-
-        x.to_device(dev)
-        y.to_device(dev)
-        recall.to_device(dev)
-
-        return recall
-
-    def evaluate(self, x, y):
-        """Compute the averaged precision over all samples.
-
-        Args:
-            x (Tensor): predictions, one row per sample
-            y (Tensor): ground truth values, one row per sample
-        Returns:
-            a float value for the averaged metric
-        """
-
-        return tensor.average(self.forward(x, y))
diff --git a/python/singa/model.py b/python/singa/model.py
new file mode 100644
index 0000000..5f1ed2c
--- /dev/null
+++ b/python/singa/model.py
@@ -0,0 +1,354 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# =============================================================================
+'''
+This script includes Model class for python users
+to use Computational Graph in their model.
+'''
+
+import os
+import gc
+import time
+import json
+import zipfile
+import numpy as np
+from functools import wraps
+from collections.abc import Iterable
+
+from singa import tensor
+from singa import autograd
+from singa import layer
+from .tensor import Tensor
+from . import singa_wrap as singa
+
+
+class ModelMeta(layer.LayerMeta):
+
+    def buffer_operation(func):
+
+        def remove_creator(tensors):
+            if not tensors:
+                return
+
+            if isinstance(tensors, Iterable):
+                for item in tensors:
+                    if isinstance(item, Iterable):
+                        remove_creator(item)
+                    elif isinstance(item, tensor.Tensor):
+                        item.creator = None
+            elif isinstance(tensors, tensor.Tensor):
+                tensors.creator = None
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if self.graph_mode and self.training:
+                if len(args) == 0:
+                    raise ValueError('expect at least one input tensor')
+
+                if isinstance(args[0], list):
+                    assert isinstance(
+                        args[0][0],
+                        Tensor), ('function expects PlaceHolders or Tensors')
+                    dev = args[0][0].device
+                else:
+                    assert isinstance(
+                        args[0],
+                        Tensor), ('function expects PlaceHolders or Tensors')
+                    dev = args[0].device
+
+                if not self._buffered:
+                    # buffer operations
+                    dev.EnableGraph(True)
+                    self._results = func(self, *args, **kwargs)
+                    dev.Sync()
+                    dev.EnableGraph(False)
+                    self._buffered = True
+
+                    # deconstruct Operations before running the entire graph
+                    remove_creator(self._results)
+
+                    # make sure all Operations are deallocated
+                    gc.collect()
+
+                # run graph
+                dev.RunGraph(self.sequential)
+                return self._results
+            else:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    def __new__(cls, name, bases, attr):
+        if 'train_one_batch' in attr:
+            attr['train_one_batch'] = ModelMeta.buffer_operation(
+                attr['train_one_batch'])
+
+        return super(ModelMeta, cls).__new__(cls, name, bases, attr)
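+
+# Note: the wrapping above happens once, at class-creation time, so any subclass
+# that defines train_one_batch gets graph buffering (buffer on the first call,
+# replay the recorded graph on later calls) automatically.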
+
+
+class Model(layer.Layer, metaclass=ModelMeta):
+    """ Base class for your neural network models.
+
+    Example usage::
+
+        import numpy as np
+        from singa import opt
+        from singa import tensor
+        from singa import device
+        from singa import autograd
+        from singa import layer
+        from singa import model
+
+        class MyModel(model.Model):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+                self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+                self.conv1 = layer.Conv2d(1, 20, 5, padding=0)
+                self.conv2 = layer.Conv2d(20, 50, 5, padding=0)
+                self.sgd = opt.SGD(lr=0.01)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = self.conv2(y)
+                return y
+
+            def train_one_batch(self, x, y):
+                out = self.forward(x)
+                loss = self.softmax_cross_entropy(out, y)
+                self.sgd(loss)
+                return out, loss
+
+    """
+
+    # save load states constant
+    TENSOR_DICT_FILENAME = '/tensor_dict.npz'
+    STATES_ATTR_FILENAME = '/states_attr.json'
+    MODEL_STATE_TYPE = 0
+    AUX_STATE_TYPE = 1
+
+    def __init__(self):
+        """
+        Initializes internal Model state
+        """
+        super(Model, self).__init__()
+
+        self.training = True
+        self.graph_mode = True
+        self.sequential = False
+        self._buffered = False
+        self._results = None
+
+    def compile(self, inputs, is_train=True, use_graph=False, sequential=False):
+        """ Compile and initialize the model
+
+        This function automatically derives the shape of the parameters
+        in each sublayer based on the shape of the input placeholders. It
+        also records the training and graph settings below.
+
+        Args:
+            inputs(list): the list of input tensors (placeholders)
+            is_train(bool): when is_train is True, the model enters
+            training mode, otherwise it enters evaluation mode
+            use_graph(bool): when use_graph is True, a computational graph
+            will be used to train this model
+            sequential(bool): when sequential is True, the model executes ops
+            in the graph in the order they joined the graph
+        """
+        assert len(inputs) > 0 and isinstance(inputs[0], Tensor), (
+            'compile function expects PlaceHolders or Tensors')
+
+        dev = inputs[0].device
+        dev.EnableGraph(True)
+        self.forward(*inputs)
+        dev.EnableGraph(False)
+        dev.ResetGraph()
+
+        autograd.training = is_train
+        self.training = is_train
+        self.graph_mode = use_graph
+        self.sequential = sequential
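+
+        # Illustrative call (dev, the batch size and the input shape are
+        # assumptions, not defined in this file):
+        #   m = MyModel()
+        #   tx = tensor.Tensor((16, 1, 28, 28), dev)
+        #   m.compile([tx], is_train=True, use_graph=True, sequential=False)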
+
+    def forward(self, *input):
+        """Defines the computation performed in every forward propagation.
+
+        Should be overridden by all subclasses.
+
+        Args:
+            *input: the input training data for the model
+
+        Returns:
+            out: the outputs of the forward propagation.
+        """
+        raise NotImplementedError
+
+    def train_one_batch(self, *input, **kwargs):
+        """Defines the computation performed in every training iteration
+
+        Should be overridden by all subclasses.
+
+        Args:
+            *input: the arguments of train_one_batch
+            **kwargs: the keyword arguments of train_one_batch
+        """
+        raise NotImplementedError
+
+    def train(self, mode=True):
+        """Set the model in evaluation mode.
+
+        Args:
+            mode(bool): when mode is True, this model will enter training mode
+        """
+        self.training = mode
+        autograd.training = mode
+
+    def eval(self):
+        """Sets the model in evaluation mode.
+        """
+        self.train(mode=False)
+
+    def graph(self, mode=True, sequential=False):
+        """ Turn on the computational graph. Specify execution mode.
+
+        Args:
+            mode(bool): when mode is True, model will use computational graph
+            sequential(bool): when sequential is True, the model executes ops
+            in the graph in the order they joined the graph
+        """
+        self.graph_mode = mode
+        self.sequential = sequential
+
+    def __get_name__(self):
+        return self.__class__.__name__
+
+    def __call__(self, *input, **kwargs):
+        if self.training:
+            return self.train_one_batch(*input, **kwargs)
+        else:
+            return self.forward(*input, **kwargs)
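+
+    # Usage note (illustrative, x and y are example names): in training mode a
+    # call like m(x, y) dispatches to train_one_batch(x, y); after m.eval(),
+    # m(x) dispatches to forward(x) only.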
+
+    def save_states(self, fpath, aux_states={}):
+        """Save states.
+
+        Args:
+            fpath: output file path (must not already exist); the states are
+                   written there as a zip archive
+            aux_states(dict): auxiliary states whose values must be Tensors,
+                              e.g., epoch ID, learning rate, optimizer states
+        """
+        assert not os.path.isfile(fpath), (
+            "Failed to save states, %s is already existed." % fpath)
+
+        states = self.get_states()
+
+        # save states data and attr
+        tensor_dict = {}
+        states_attr = {}
+        for k, v in states.items():
+            assert isinstance(v, tensor.Tensor), "Only tensor state is allowed"
+            tensor_dict[k] = tensor.to_numpy(v)
+            states_attr[k] = {
+                'state_type': self.MODEL_STATE_TYPE,
+                'shape': v.shape,
+                'dtype': v.dtype
+            }
+
+        for k, v in aux_states.items():
+            assert isinstance(v,
+                              tensor.Tensor), "Only tensor aux state is allowed"
+            tensor_dict[k] = tensor.to_numpy(v)
+            states_attr[k] = {
+                'state_type': self.AUX_STATE_TYPE,
+                'shape': v.shape,
+                'dtype': v.dtype
+            }
+
+        # save to files
+        timestamp = time.time()
+        tmp_dir = '/tmp/singa_save_states_%s' % timestamp
+        os.mkdir(tmp_dir)
+        tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME
+        states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME
+
+        np.savez(tensor_dict_fp, **tensor_dict)
+
+        with open(states_attr_fp, 'w') as fp:
+            json.dump(states_attr, fp)
+
+        compression = zipfile.ZIP_DEFLATED
+        with zipfile.ZipFile(fpath, mode="w") as zf:
+            zf.write(tensor_dict_fp,
+                     os.path.basename(tensor_dict_fp),
+                     compress_type=compression)
+            zf.write(states_attr_fp,
+                     os.path.basename(states_attr_fp),
+                     compress_type=compression)
+
+        # clean up tmp files
+        os.remove(tensor_dict_fp)
+        os.remove(states_attr_fp)
+        os.rmdir(tmp_dir)
+
+    def load_states(self, fpath):
+        """Load the model states and auxiliary states from disk.
+
+        Usage:
+            m = MyModel()
+            m.compile(...)
+            aux_states = m.load_states('mymodel.zip')
+
+        Args:
+            fpath: input file path, i.e., the zip archive written by save_states
+        Returns:
+            dict: the auxiliary states that were passed to save_states
+        """
+
+        assert os.path.isfile(fpath), (
+            "Failed to load states, %s is not exist." % fpath)
+
+        timestamp = time.time()
+        tmp_dir = '/tmp/singa_load_states_%s' % timestamp
+        os.mkdir(tmp_dir)
+
+        with zipfile.ZipFile(fpath, 'r') as zf:
+            zf.extractall(tmp_dir)
+
+        tensor_dict_fp = tmp_dir + self.TENSOR_DICT_FILENAME
+        states_attr_fp = tmp_dir + self.STATES_ATTR_FILENAME
+
+        with open(states_attr_fp) as f:
+            states_attr = json.load(f)
+
+        tensor_dict = np.load(tensor_dict_fp)
+
+        # restore singa tensor from numpy
+        model_states = dict()
+        aux_states = dict()
+
+        for k in tensor_dict.files:
+            if states_attr[k]['state_type'] == self.MODEL_STATE_TYPE:
+                model_states[k] = tensor.from_numpy(tensor_dict[k])
+            elif states_attr[k]['state_type'] == self.AUX_STATE_TYPE:
+                aux_states[k] = tensor.from_numpy(tensor_dict[k])
+
+        # restore model_states
+        self.set_states(model_states)
+
+        # clean up tmp files
+        os.remove(tensor_dict_fp)
+        os.remove(states_attr_fp)
+        os.rmdir(tmp_dir)
+        return aux_states
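+
+# Illustrative save/load round-trip ('ckpt.zip' and the aux tensor below are
+# assumptions, not defined in this file):
+#   m.save_states('ckpt.zip',
+#                 aux_states={'epoch': tensor.from_numpy(np.array([3.0], dtype=np.float32))})
+#   aux = m.load_states('ckpt.zip')   # returns {'epoch': <Tensor holding 3.0>}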
diff --git a/python/singa/module.py b/python/singa/module.py
deleted file mode 100644
index 88881c0..0000000
--- a/python/singa/module.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-'''
-This script includes Module class for python users
-to use Computational Graph in their model.
-'''
-
-from functools import wraps
-
-from singa import autograd
-from . import singa_wrap as singa
-from .device import get_default_device
-
-import gc
-
-
-class Graph(type):
-
-    def buffer_operation(func):
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if self.graph_mode and self.training:
-                name = func.__name__
-                if name not in self._called:
-                    # tag this function
-                    self._called.add(name)
-                    # buffer operations
-                    self._device.EnableGraph(True)
-                    ret = func(self, *args, **kwargs)
-                    self._device.Sync()
-                    self._device.EnableGraph(False)
-                    # deconstruct Operations before running the entire graph
-                    if name == 'optim':
-                        for fname in self._results:
-                            if isinstance(self._results[fname], list):
-                                for _matrix in self._results[fname]:
-                                    _matrix.creator = None
-                            else:
-                                self._results[fname].creator = None
-                        # make sure all Operations are deallocated
-                        gc.collect()
-                    # add result tensor
-                    self._results[name] = ret
-                    # run graph
-                    self._device.RunGraph(self.sequential)
-                    self.initialized = True
-                    return ret
-
-                return self._results[name]
-            else:
-                return func(self, *args, **kwargs)
-
-        return wrapper
-
-    def __new__(cls, name, bases, attr):
-        attr["forward"] = Graph.buffer_operation(attr["forward"])
-        attr["loss"] = Graph.buffer_operation(attr["loss"])
-        attr["optim"] = Graph.buffer_operation(attr["optim"])
-
-        return super(Graph, cls).__new__(cls, name, bases, attr)
-
-
-class Module(object, metaclass=Graph):
-    """ Base class for your neural network modules.
-
-    Example usage::
-
-        import numpy as np
-        from singa import opt
-        from singa import tensor
-        from singa import device
-        from singa import autograd
-        from singa.module import Module
-
-        class Model(Module):
-            def __init__(self):
-                super(Model, self).__init__()
-
-                self.conv1 = autograd.Conv2d(1, 20, 5, padding=0)
-                self.conv2 = autograd.Conv2d(20, 50, 5, padding=0)
-
-                self.sgd = opt.SGD(lr=0.01)
-
-            def forward(self, x):
-                y = self.conv1(x)
-                y = self.conv2(y)
-                return y
-
-            def loss(self, out, y):
-                return autograd.softmax_cross_entropy(out, y)
-
-            def optim(self, loss):
-                self.sgd.backward_and_update(loss)
-
-    """
-
-    def __init__(self):
-        """
-        Initializes internal Module state
-        """
-        self.training = True
-        self.graph_mode = True
-        self.sequential = False
-        self.initialized = False
-        self._device = get_default_device()
-
-        self._results = {}
-        self._called = set()
-
-    def forward(self, *input):
-        """Defines the computation performed at every call.
-
-        Should be overridden by all subclasses.
-
-        Args:
-            *input: the input training data for the module
-
-        Returns:
-            out: the outputs of the forward propagation.
-        """
-        raise NotImplementedError
-
-    def loss(self, *args, **kwargs):
-        """Defines the loss function performed when training the module.
-        """
-        pass
-
-    def optim(self, *args, **kwargs):
-        """Defines the optim function for backward pass.
-        """
-        pass
-
-    def train(self, mode=True):
-        """Set the module in evaluation mode.
-
-        Args:
-            mode(bool): when mode is True, this module will enter training mode
-        """
-        self.training = mode
-        autograd.training = True
-
-    def eval(self):
-        """Sets the module in evaluation mode.
-        """
-        self.train(mode=False)
-        autograd.training = False
-
-    def graph(self, mode=True, sequential=False):
-        """ Turn on the computational graph. Specify execution mode.
-
-        Args:
-            mode(bool): when mode is True, module will use computational graph
-            sequential(bool): when sequential is True, module will execute ops
-            in the graph follow the order of joining the graph
-        """
-        self.graph_mode = mode
-        self.sequential = sequential
-
-    def on_device(self, device):
-        """Sets the target device.
-
-        The following training will be performed on that device.
-
-        Args:
-            device(Device): the target device
-        """
-        self._device = device
-
-    def __get_name__(self):
-        return self.__class__.__name__
-
-    def __call__(self, *input, **kwargs):
-        if self.graph_mode and self.training:
-            if self.initialized == True:
-                self._device.RunGraph(self.sequential)
-
-        return self.forward(*input, **kwargs)
diff --git a/python/singa/net.py b/python/singa/net.py
deleted file mode 100755
index 4f2a8c3..0000000
--- a/python/singa/net.py
+++ /dev/null
@@ -1,531 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-"""
-Nerual net class for constructing the nets using layers and providing access
-functions for net info, e.g., parameters.
-
-
-Example usages::
-
-    from singa import net as ffnet
-    from singa import metric
-    from singa import loss
-    from singa import layer
-    from singa import device
-
-    # create net and add layers
-    net = ffnet.FeedForwardNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
-    net.add(layer.Conv2D('conv1', 32, 5, 1, input_sample_shape=(3,32,32,)))
-    net.add(layer.Activation('relu1'))
-    net.add(layer.MaxPooling2D('pool1', 3, 2))
-    net.add(layer.Flatten('flat'))
-    net.add(layer.Dense('dense', 10))
-
-    # init parameters
-    for p in net.param_values():
-        if len(p.shape) == 0:
-            p.set_value(0)
-        else:
-            p.gaussian(0, 0.01)
-
-    # move net onto gpu
-    dev = device.create_cuda_gpu()
-    net.to_device(dev)
-
-    # training (skipped)
-
-    # do prediction after training
-    x = tensor.Tensor((2, 3, 32, 32), dev)
-    x.uniform(-1, 1)
-    y = net.predict(x)
-    print tensor.to_numpy(y)
-"""
-from __future__ import print_function
-from __future__ import absolute_import
-
-from builtins import zip
-from builtins import str
-from builtins import object
-import numpy as np
-import os
-
-from .proto.model_pb2 import kTrain, kEval
-from .__init__ import __version__
-from . import tensor
-from . import layer
-from . import snapshot
-
-try:
-    import pickle
-except ImportError:
-    import cPickle as pickle
-'''For display training information, e.g L1 value of layer data'''
-verbose = False
-
-
-class FeedForwardNet(object):
-
-    def __init__(self, loss=None, metric=None):
-        '''Representing a feed-forward neural net.
-
-        Args:
-            loss, a Loss instance. Necessary training
-            metric, a Metric instance. Necessary for evaluation
-        '''
-        self.loss = loss
-        self.metric = metric
-        self.layers = []
-        self.src_of_layer = {}
-        self.dst_of_layer = None
-        self.ordered_layers = None
-        self.out_sample_shape_of_layer = {}
-
-    def to_device(self, dev):
-        '''Move the net onto the given device, including
-        all parameters and intermediate data.
-        '''
-        for lyr in self.layers:
-            lyr.to_device(dev)
-
-    def add(self, lyr, src=None):
-        """Append a layer into the layer list.
-
-        This function will get the sample shape from the src layers to setup the
-        newly added layer. For the first layer, it is setup outside. The calling
-        function should ensure the correctness of the layer order. If src is
-        None, the last layer is the src layer. If there are multiple src layers,
-        the src is a list of the src layers.
-
-        Args:
-            lyr (Layer): the layer to be added
-            src (Layer): the source layer of lyr
-        """
-        if src is not None:
-            if isinstance(src, layer.Layer):
-                assert src.has_setup is True, 'the source layer must be set up'
-                self.src_of_layer[lyr.name] = [src]
-            else:
-                assert type(src) == list, 'the src must be a list of layers'
-                self.src_of_layer[lyr.name] = src
-                # print 'merge------', len(src)
-        else:
-            assert len(self.layers) > 0 or lyr.has_setup, \
-                'Source layers are needed to set up this layer'
-            if len(self.layers) > 0:
-                self.src_of_layer[lyr.name] = [self.layers[-1]]
-            else:
-                self.src_of_layer[lyr.name] = []
-        if lyr.has_setup is False:
-            in_shape = []
-            for src in self.src_of_layer[lyr.name]:
-                shapes = self.out_sample_shape_of_layer[src.name]
-                assert len(shapes) > 0, \
-                    'Cannot get output shape of layer %s' % lyr.name
-                in_shape.append(shapes[0])
-                shapes.pop(0)
-            if len(in_shape) == 1:
-                lyr.setup(in_shape[0])
-            else:
-                lyr.setup(in_shape)
-        out_shape = lyr.get_output_sample_shape()
-        if type(out_shape[0]) is tuple:
-            self.out_sample_shape_of_layer[lyr.name] = out_shape
-        else:
-            self.out_sample_shape_of_layer[lyr.name] = [out_shape]
-        self.layers.append(lyr)
-        print((lyr.name, out_shape))
-        return lyr
-
-    def param_values(self):
-        '''Return a list of tensors for all parameters'''
-        values = []
-        layers = self.layers
-        if self.ordered_layers is not None:
-            layers = self.ordered_layers
-        for lyr in layers:
-            values.extend(lyr.param_values())
-        return values
-
-    def param_specs(self):
-        '''Return a list of ParamSpec for all parameters'''
-        specs = []
-        layers = self.layers
-        if self.ordered_layers is not None:
-            layers = self.ordered_layers
-        for lyr in layers:
-            specs.extend(lyr.param_specs)
-        return specs
-
-    def param_names(self):
-        '''Return a list for the names of all params'''
-        return [spec.name for spec in self.param_specs()]
-
-    def train(self, x, y):
-        '''Run BP for one iteration.
-        This method is deprecated. It is only kept for backward compatibility.
-        The name of this method is confusing since it does not update parameters.
-        Please use backprob() instead.
-        The back progagation algorithm computes gradients but it does not train.
-        '''
-        return self.backprob(x, y)
-
-    def backprob(self, x, y):
-        '''Run BP for one iteration.
-
-        Currently only support nets with a single output layer, and a single
-        loss objective and metric.
-        For multiple outputs (with multiple loss/metric), please manually
-        call forward, compute loss/metric and call backward. backward() is also
-        more memory efficient than this function.
-
-        Args:
-            x: input data, a single input Tensor or a dict: layer name -> Tensor
-            y: label data, a single input Tensor.
-        Returns:
-            gradients of parameters and the loss and metric values.
-        '''
-        out = self.forward(kTrain, x)
-        l = self.loss.forward(kTrain, out, y)
-        g = self.loss.backward()
-        g /= x.shape[0]
-        m = None
-        if self.metric is not None:
-            m = self.metric.evaluate(out, y)
-        grads = []  # store all gradient tensors; memory inefficient
-        for _, _, grad, _ in self.backward(g):
-            grads.extend(grad[::-1])
-        return grads[::-1], (l.l1(), m)
-
-    def evaluate(self, x, y):
-        '''Evaluate the loss and metric of the given data.
-
-        Currently only support nets with a single output layer, and a single
-        loss objective and metric.
-        TODO(wangwei) consider multiple loss objectives and metrics.
-
-        Args:
-            x: input data, a single input Tensor or a dict: layer name -> Tensor
-            y: label data, a single input Tensor.
-        '''
-        out = self.forward(kEval, x)
-        l = None
-        m = None
-        assert self.loss is not None or self.metric is not None,\
-            'Cannot do evaluation, as neither loss nor metic is set'
-        if self.loss is not None:
-            l = self.loss.evaluate(kEval, out, y)
-        if self.metric is not None:
-            m = self.metric.evaluate(out, y)
-        return l, m
-
-    def predict(self, x):
-        '''Forward the input data through each layer to get the values of the
-        output layers.
-
-        Currently only support nets with a single output layer
-        TODO(yujian) to handle multiple outputs from the network
-
-        Args:
-            x: input data, a single input Tensor or a dict: layer name -> Tensor
-
-        Returns:
-            a single output tensor as the prediction result.
-
-        '''
-
-        xx = self.forward(kEval, x)
-        if type(xx) is dict:
-            return tensor.softmax(list(xx.values())[0])
-        else:
-            return tensor.softmax(xx)
-
-    def topo_sort(self, layers, src_of_layer):
-        '''Topology sort of layers.
-
-        It would try to preserve the orders of the input layers.
-
-        Args:
-            layers: a list of layers; the layers from the output of the same
-                layer (e.g., slice layer) should be added by users in correct
-                order; This function would not change their order.
-            src_of_layer: a dictionary: src layer name -> a list of src layers
-
-        Returns:
-            A list of ordered layer
-        '''
-        order = []
-        while len(order) < len(layers):
-            for lyr in self.layers:
-                if lyr not in order:
-                    for src in src_of_layer[lyr.name]:
-                        if src not in order:
-                            break
-                    order.append(lyr)
-        return order
-
-    def forward(self, flag, x, output=[], freeze=None):
-        '''Forward the input(s) through every layer.
-
-        Args:
-            flag: True for training; False for evaluation; could also be
-                model_pb2.kTrain or model_pb2.kEval, or other values for future
-                use.
-            x: a single SINGA tensor if there is a single input; otherwise, a
-                dictionary: layer name-> singa tensor, for each layer accepting
-                input data. Do not associate a layer with input tensor if it is
-                connected from another layer. For such case, use a Dummy() layer
-                to accept the input data and connect the dummy layer to this
-                layer.
-            output(list): a list of layer names whose output would be returned
-                in addition to the default output.
-            freeze(str): layer name, freeze all layers before this layer; flag
-                is set to false for these layers.
-
-        Returns:
-            if there is only one output layer and output arg is empty, return
-                the result from the single output layer; otherwise, return a
-                dictionary: layer name -> output tensor(s)
-        '''
-        if self.ordered_layers is None:
-            self.ordered_layers = self.topo_sort(self.layers, self.src_of_layer)
-        if type(x) is dict:
-            input_of_layer = x
-        else:
-            assert isinstance(x, tensor.Tensor), \
-                'The inputs of a net should be dict or a single tensor'
-            input_of_layer = {self.ordered_layers[0].name: x}
-        output_of_layer = {}  # outputs generated by each layer
-        ret = {}  # outputs to return
-        if freeze is not None:
-            is_valid = False
-            for lyr in self.ordered_layers:
-                is_valid |= lyr.name == freeze
-            assert is_valid, 'Invalid freeze layer name =%s' % freeze
-            old_flag = flag
-            flag = False
-        for cur in self.ordered_layers:
-            if cur.name == freeze:
-                flag = old_flag
-            inputs = []
-            if cur.name in input_of_layer:
-                if type(input_of_layer[cur.name]) is list:
-                    inputs.extend(input_of_layer[cur.name])
-                else:
-                    inputs.append(input_of_layer[cur.name])
-            srcs = self.src_of_layer[cur.name]
-            disp_src = ''
-            for src in srcs:
-                outs = output_of_layer[src.name]
-                if type(outs) == list:
-                    assert len(outs) > 0, \
-                        'the output from layer %s is empty' % src.name
-                    inputs.append(outs[0])
-                    outs.pop(0)
-                    if len(outs) == 0:
-                        output_of_layer.pop(src.name)
-                else:
-                    inputs.append(outs)
-                    output_of_layer[cur.name] = []
-                    output_of_layer.pop(src.name)
-            if len(inputs) == 1:
-                inputs = inputs[0]
-            out = cur.forward(flag, inputs)
-            if verbose:
-                disp_src = '+'.join([src.name for src in srcs])
-                disp_src += '-->' + cur.name
-                if type(out) is list:
-                    print('%s: %s' %
-                          (disp_src, ' '.join([str(o.l1()) for o in out])))
-                else:
-                    print('%s: %f' % (disp_src, out.l1()))
-            output_of_layer[cur.name] = out
-            if cur.name in output:
-                ret[cur.name] = out
-            # print lyr.name, x.l1()
-        # print output_of_layer
-        ret.update(output_of_layer)
-        if len(ret) == 1:
-            return list(ret.values())[0]
-        else:
-            return ret
-
-    def backward(self, dy, output=[], freeze=None):
-        '''Run back-propagation after forward-propagation.
-
-        Args:
-            dy: a single tensor if there is a single loss function; otherwise,
-                a dictionary maps the name of the layer connecting to the loss
-                function -> gradient from the loss function. Do not associate a
-                layer with gradient tensor if it is connecting to another layer.
-                For such case, connect this layer to a Dummy() layer and use the
-                dummy layer to accept the gradient.
-            output(list): a list of layer names whose output gradient would be
-                returned in addition to the param gradient
-            freeze(str): layer name, stop backward after this layer.
-
-        Returns:
-                a geneartor iterator that generates
-                (param_names, param_values, param_grads, layer_grads) after
-                processing each layer h, where the first three lists are for h
-                and the last item is a dictionary which maps
-                layer name -> its output gradient tensor(s). At the end of this
-                function, the key set includes all layers in the output arg.
-        '''
-        if self.dst_of_layer is None:
-            self.dst_of_layer = {}
-            for cur in self.layers:
-                self.dst_of_layer[cur.name] = []
-            for cur in self.ordered_layers[1:]:
-                srcs = self.src_of_layer[cur.name]
-                for src in srcs:
-                    self.dst_of_layer[src.name].append(cur)
-        output_of_layer = {}  # outputs generated by each layer
-        ret = {}  # outputs to return
-        if type(dy) is dict:
-            input_of_layer = dy
-        else:
-            assert isinstance(dy, tensor.Tensor), \
-                'The inputs of a net should be dict or a single tensor'
-            input_of_layer = {self.ordered_layers[-1].name: dy}
-        for cur in reversed(self.ordered_layers):
-            inputs = []
-            if cur.name in input_of_layer:
-                if type(input_of_layer[cur.name]) is list:
-                    inputs.extend(input_of_layer[cur.name])
-                else:
-                    inputs.append(input_of_layer[cur.name])
-            for dst in self.dst_of_layer[cur.name]:
-                outputs = output_of_layer[dst.name]
-                if type(outputs) == list:
-                    assert len(outputs) > 0, \
-                        'the gradient from layer %s is empty' % dst.name
-                    inputs.append(outputs[0])
-                    outputs.pop(0)
-                else:
-                    inputs.append(outputs)
-                    output_of_layer[dst.name] = []
-                # del output_of_layer[dst.name]
-            if len(inputs) == 1:
-                inputs = inputs[0]
-            outs, pgrads = cur.backward(kTrain, inputs)
-            if verbose:
-                disp_src = '+'.join(
-                    [dst.name for dst in self.dst_of_layer[cur.name]])
-                disp_src += '-->' + cur.name
-                if type(outs) is list:
-                    print('%s: %s' %
-                          (disp_src, ' '.join([str(o.l1()) for o in outs])))
-                else:
-                    print('%s: %f' % (disp_src, outs.l1()))
-            if type(outs) is list:
-                output_of_layer[cur.name] = outs[::-1]
-            else:
-                output_of_layer[cur.name] = outs
-            if cur.name in output:
-                ret[cur.name] = outs
-            # ret.update(output_of_layer)
-            yield (cur.param_names(), cur.param_values(), pgrads, ret)
-            if cur.name == freeze:
-                break
-
-    def save(self, f, buffer_size=10, use_pickle=False):
-        '''Save model parameters using io/snapshot.
-
-        Args:
-            f: file name
-            buffer_size: size (MB) of the IO, default setting is 10MB; Please
-                make sure it is larger than any single parameter object.
-            use_pickle(Boolean): if true, it would use pickle for dumping;
-                otherwise, it would use protobuf for serialization, which uses
-                less space.
-        '''
-        if use_pickle:
-            params = {}
-            # since SINGA>=1.1.1  (1101)
-            params['SINGA_VERSION'] = __version__
-            for (name, val) in zip(self.param_names(), self.param_values()):
-                val.to_host()
-                params[name] = tensor.to_numpy(val)
-            if not f.endswith('.pickle'):
-                f = f + '.pickle'
-            with open(f, 'wb') as fd:
-                pickle.dump(params, fd)
-        else:
-            if f.endswith('.bin'):
-                f = f[0:-4]
-            sp = snapshot.Snapshot(f, True, buffer_size)
-            v = tensor.from_numpy(np.array([__version__]))
-            sp.write('SINGA_VERSION', v)
-            for (name, val) in zip(self.param_names(), self.param_values()):
-                val.to_host()
-                sp.write(name, val)
-
-    def load(self, f, buffer_size=10, use_pickle=False):
-        '''Load model parameters using io/snapshot.
-
-        Please refer to the argument description in save().
-        '''
-        version = 0
-
-        def get_name(name):
-            if version < 1101:
-                idx = name.rfind('/')
-                assert idx > 0, '/ must be in the parameter name'
-                name = name[:idx] + '_' + name[idx + 1:]
-            return name
-
-        if use_pickle:
-            print('NOTE: If your model was saved using Snapshot, '
-                  'then set use_pickle=False for loading it')
-            if not os.path.exists(f):
-                # guess the correct path
-                if f.endswith('.pickle'):
-                    f = f[0:-7]
-                else:
-                    f = f + '.pickle'
-            assert os.path.exists(f), 'file not exists %s w/o .pickle' % f
-            with open(f, 'rb') as fd:
-                params = pickle.load(fd, encoding='iso-8859-1')
-        else:
-            print('NOTE: If your model was saved using pickle, '
-                  'then set use_pickle=True for loading it')
-            if f.endswith('.bin'):
-                f = f[0:-4]
-            sp = snapshot.Snapshot(f, False, buffer_size)
-            params = sp.read()
-
-        if 'SINGA_VERSION' in params:
-            version = params['SINGA_VERSION']
-            if isinstance(version, tensor.Tensor):
-                version = tensor.to_numpy(version)[0]
-        else:
-            version = 1100
-        for name, val in zip(self.param_names(), self.param_values()):
-            name = get_name(name)
-            if name not in params:
-                print('Param: %s missing in the checkpoint file' % name)
-                continue
-            try:
-                if isinstance(params[name], tensor.Tensor):
-                    val.copy_data(params[name])
-                else:
-                    val.copy_from_numpy(params[name])
-            except AssertionError as err:
-                print('Error from copying values for param: %s' % name)
-                print(('shape of param vs checkpoint', val.shape,
-                       params[name].shape))
-                raise err
diff --git a/python/singa/opt.py b/python/singa/opt.py
index cacb14f..8eda563 100755
--- a/python/singa/opt.py
+++ b/python/singa/opt.py
@@ -18,9 +18,55 @@
 It replaces the old optimizers from optimizer.py'''
 
 from singa import tensor
+from singa.tensor import Tensor
 from singa import autograd
 from . import singa_wrap as singa
 
+from deprecated import deprecated
+
+
+class DecayScheduler:
+    # for decaying a hyper-parameter over training steps, e.g., the learning
+    # rate, a regularization coefficient or the momentum
+    def __init__(self, init_value):
+        self.init_value = init_value
+
+    def __call__(self, step):
+        assert isinstance(step, Tensor)
+        return self.call(step)
+
+    def call(self, step) -> Tensor:
+        # step is a Tensor with a single scalar value
+        # return the current value as a Tensor
+        raise NotImplementedError
+
+
+class Constant(DecayScheduler):
+
+    def call(self, step: Tensor) -> Tensor:
+        # TODO should be an in-place operator
+        ret = Tensor((1,), step.device)
+        ret.set_value(self.init_value)
+        return ret
+
+
+class ExponentialDecay(DecayScheduler):
+
+    def __init__(self, init_value, decay_steps, decay_rate, staircase=False):
+        super(ExponentialDecay, self).__init__(init_value)
+
+        self.decay_steps = decay_steps
+        self.decay_rate = decay_rate
+        self.staircase = staircase
+
+    def call(self, step):
+        if self.staircase:
+            s = step // self.decay_steps
+        else:
+            s = step / self.decay_steps
+        ret = Tensor((1,), s.device)
+        ret.set_value(self.decay_rate)
+        return self.init_value * tensor.pow(ret, s)
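# A hedged, self-contained restatement of the exponential decay rule above,
# using plain Python floats instead of SINGA tensors. The function name and
# scalar arguments are illustrative only, not part of the patch.
def _exponential_decay(init_value, decay_rate, decay_steps, step, staircase=False):
    # staircase=True decays in discrete jumps every `decay_steps` steps;
    # otherwise the value decays continuously with the step count.
    exponent = step // decay_steps if staircase else step / decay_steps
    return init_value * (decay_rate ** exponent)

# e.g. starting from 0.1 and halving every 1000 steps:
# _exponential_decay(0.1, 0.5, 1000, 2500)                  -> ~0.0177
# _exponential_decay(0.1, 0.5, 1000, 2500, staircase=True)  -> 0.025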
+
 
 class Optimizer(object):
     """Base optimizer.
@@ -29,12 +75,61 @@
         config (Dict): specify the default values of configurable variables.
     """
 
-    def __init__(self, config):
-        self.default_config = config
-        self.iter = 0
-        self.param2config = {}
-        self.param2state = {}
+    def __init__(self, lr):
+        # init lr (could be a constant scalar or a learning rate scheduler)
+        if type(lr) == float or type(lr) == int:
+            self.lr = Constant(lr)
+        elif isinstance(lr, DecayScheduler):
+            self.lr = lr
+        else:
+            raise TypeError("Wrong learning rate type")
 
+        # init step counter
+        # TODO change type to int32
+        self.step_counter = Tensor((1,), dtype=tensor.float32)
+        self.step_counter.set_value(0)
+        self.lr_value = self.lr(self.step_counter)
+
+    def get_states(self):
+        # skip DecayScheduler as it does not have persistent states
+        return {'step_counter': tensor.to_numpy(self.step_counter)[0]}
+
+    def set_states(self, states):
+        self.step_counter = Tensor((1,))
+        self.step_counter.set_value(states['step_counter'])
+        self.lr_value = self.lr(self.step_counter)
+
+    def __call__(self, loss):
+        self.call(loss)
+        self.step()
+
+    def call(self, loss):
+        for p, g in autograd.backward(loss):
+            if p.name is None:
+                p.name = id(p)
+            self.apply(p.name, p, g)
+
+    def step(self):
+        """To increment the step counter and update the lr"""
+        self.step_counter.data += 1
+        lr_value = self.lr(self.step_counter)
+        self.lr_value.copy_from(lr_value)
+
+    def apply(self, param_name, param_value, param_grad):
+        """Performs a single optimization step.
+
+        Args:
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
+                param_grad(Tensor): param gradients; the values may be updated
+                        in this function; do not use it anymore
+        """
+        raise NotImplementedError
+
+    @deprecated(
+        reason=
+        "Update is deprecated, use apply() to do update, refer to apply for more details."
+    )
     def update(self, param, grad):
         """Update the param values with given gradients.
 
@@ -43,30 +138,42 @@
             grad(Tensor): param gradients; the values may be updated
                     in this function; do not use it anymore
         """
-        pass
+        if param.name is None:
+            param.name = id(param)
+        self.apply(param.name, param, grad)
 
-    def step(self):
-        """To increment the step counter"""
-        self.iter += 1
+    def device_check(self, *inputs):
+        # move every input tensor onto the device of the first input, with the
+        # computational graph temporarily disabled
+        flag = inputs[0].device.graph_enabled()
+        inputs[0].device.EnableGraph(False)
+        x_device = inputs[0].device
+        x_dev_id = x_device.id()
+        for var in inputs:
+            if var.device.id() != x_dev_id:
+                var.to_device(x_device)
+        inputs[0].device.EnableGraph(flag)
 
-    def register(self, param_group, config):
-        for param in param_group:
-            assert param not in self.param2config, 'param is already registered'
+    @deprecated(
+        reason=
+        "backward_and_update is deprecated, use __call__() to do update, refer to __call__ for more details."
+    )
+    def backward_and_update(self, loss):
+        """Performs backward propagation from the loss and parameter update.
 
-            self.param2config[param] = config
+        From the loss, it performs backward propagation to get the gradients
+        and then updates the parameters.
 
-    def load(self):
-        pass
-
-    def save(self):
-        pass
+        Args:
+                loss(Tensor): loss is the objective function of the deep learning model
+                optimization, e.g. for classification problem it can be the output of the
+                softmax_cross_entropy function.
+        """
+        self.__call__(loss)
 
 
 class SGD(Optimizer):
     """Implements stochastic gradient descent (optionally with momentum).
 
-    Nesterov momentum is based on the formula from
-    `On the importance of initialization and momentum in deep learning`__.
+    Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__.
 
     Args:
         lr(float): learning rate
@@ -111,77 +218,463 @@
                  dampening=0,
                  weight_decay=0,
                  nesterov=False):
-        if momentum < 0.0:
-            raise ValueError("Invalid momentum value: {}".format(momentum))
-        if weight_decay < 0.0:
-            raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay))
+        super(SGD, self).__init__(lr)
 
-        defaults = dict(lr=lr,
-                        momentum=momentum,
-                        dampening=dampening,
-                        weight_decay=weight_decay,
-                        nesterov=nesterov)
+        # init momentum
+        if type(momentum) == float or type(momentum) == int:
+            if momentum < 0.0:
+                raise ValueError("Invalid momentum value: {}".format(momentum))
+            self.momentum = Constant(momentum)
+        elif isinstance(momentum, DecayScheduler):
+            self.momentum = momentum
+            momentum = momentum.init_value
+        else:
+            raise TypeError("Wrong momentum type")
+        self.mom_value = self.momentum(self.step_counter)
+
+        # init dampening
+        if type(dampening) == float or type(dampening) == int:
+            self.dampening = Constant(dampening)
+        elif isinstance(dampening, DecayScheduler):
+            self.dampening = dampening
+            dampening = dampening.init_value
+        else:
+            raise TypeError("Wrong dampening type")
+        self.dam_value = self.dampening(self.step_counter)
+
+        # init weight_decay
+        if type(weight_decay) == float or type(weight_decay) == int:
+            if weight_decay < 0.0:
+                raise ValueError(
+                    "Invalid weight_decay value: {}".format(weight_decay))
+            self.weight_decay = Constant(weight_decay)
+        elif isinstance(weight_decay, DecayScheduler):
+            self.weight_decay = weight_decay
+        else:
+            raise TypeError("Wrong weight_decay type")
+        self.decay_value = self.weight_decay(self.step_counter)
+
+        # init other params
+        self.nesterov = nesterov
+        self.moments = dict()
+
+        # check value
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError(
                 "Nesterov momentum requires a momentum and zero dampening")
-        super(SGD, self).__init__(defaults)
 
-    def update(self, param, grad):
+    def apply(self, param_name, param_value, param_grad):
         """Performs a single optimization step.
 
         Args:
-                param(Tensor): param values to be update in-place
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
                 grad(Tensor): param gradients; the values may be updated
                         in this function; cannot use it anymore
         """
-        assert param.shape == grad.shape, ("shape mismatch", param.shape,
-                                           grad.shape)
-        group = self.default_config
-        if param in self.param2config:
-            group = self.param2config[param]
-        weight_decay = group['weight_decay']
-        momentum = group['momentum']
-        dampening = group['dampening']
-        nesterov = group['nesterov']
+        assert param_value.shape == param_grad.shape, ("shape mismatch",
+                                                       param_value.shape,
+                                                       param_grad.shape)
+        self.device_check(param_value, self.step_counter, self.lr_value,
+                          self.mom_value, self.dam_value, self.decay_value)
 
-        if weight_decay != 0:
-            singa.Axpy(weight_decay, param.data, grad.data)
-        if momentum != 0:
-            if param not in self.param2state:
-                self.param2state[param] = {}
-            param_state = self.param2state[param]
-            if 'momentum_buffer' not in param_state:
-                flag = param.device.graph_enabled()
-                param.device.EnableGraph(False)
-                buf = param_state['momentum_buffer'] = tensor.zeros_like(param)
-                param.device.EnableGraph(flag)
+        # TODO add branch operator
+        # if self.decay_value != 0:
+        if self.weight_decay.init_value != 0:
+            singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
 
-                buf *= momentum
-                singa.Axpy(1.0, grad.data, buf.data)
+        if self.momentum.init_value != 0:
+            if param_name not in self.moments:
+                flag = param_value.device.graph_enabled()
+                param_value.device.EnableGraph(False)
+                self.moments[param_name] = tensor.zeros_like(param_value)
+                param_value.device.EnableGraph(flag)
+
+            buf = self.moments[param_name]
+            buf *= self.mom_value
+            alpha = 1.0 - self.dam_value
+            singa.Axpy(alpha.data, param_grad.data, buf.data)
+
+            if self.nesterov:
+                singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
             else:
-                buf = param_state['momentum_buffer']
-                buf *= momentum
-                singa.Axpy(1.0 - dampening, grad.data, buf.data)
-            if nesterov:
-                singa.Axpy(momentum, buf.data, grad.data)
-            else:
-                grad = buf
-        singa.Axpy(-group['lr'], grad.data, param.data)
+                param_grad = buf
 
-    def backward_and_update(self, loss):
-        """Performs backward propagation from the loss and parameter update.
+        minus_lr = 0.0 - self.lr_value
+        singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
 
-        From the loss, it performs backward propagation to get the gradients
-        and do the parameter update.
+    def step(self):
+        # increment the step counter; refresh lr, momentum, dampening, decay
+        super().step()
+        mom_value = self.momentum(self.step_counter)
+        dam_value = self.dampening(self.step_counter)
+        decay_value = self.weight_decay(self.step_counter)
+        self.mom_value.copy_from(mom_value)
+        self.dam_value.copy_from(dam_value)
+        self.decay_value.copy_from(decay_value)
+
+    def get_states(self):
+        states = super().get_states()
+        if self.mom_value > 0:
+            states['moments'] = self.moments  # dict of 1st-order moments
+        return states
+
+    def set_states(self, states):
+        super().set_states(states)
+        if 'moments' in states:
+            self.moments = states['moments']
+            self.mom_value = self.momentum(self.step_counter)
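# A hedged NumPy restatement of the update performed by SGD.apply() above,
# assuming plain float hyper-parameters. The function and variable names are
# illustrative only, not SINGA APIs.
import numpy as np

def sgd_step(param, grad, buf, lr, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False):
    if weight_decay != 0:
        grad = grad + weight_decay * param             # L2 decay folded into grad
    if momentum != 0:
        buf = momentum * buf + (1 - dampening) * grad  # momentum buffer
        grad = grad + momentum * buf if nesterov else buf
    return param - lr * grad, buf

p, b = np.array([1.0, -2.0]), np.zeros(2)
p, b = sgd_step(p, np.array([0.5, 0.5]), b, lr=0.1, momentum=0.9)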
+
+
+class RMSProp(Optimizer):
+    '''RMSProp optimizer.
+
+    See the base Optimizer for all constructor args.
+
+    Args:
+        rho (float): float within [0, 1]
+        epsilon (float): small value for preventing numeric error
+    '''
+
+    def __init__(self, lr=0.1, rho=0.9, epsilon=1e-8, weight_decay=0):
+        super(RMSProp, self).__init__(lr)
+
+        # init weight_decay
+        if type(weight_decay) == float or type(weight_decay) == int:
+            if weight_decay < 0.0:
+                raise ValueError(
+                    "Invalid weight_decay value: {}".format(weight_decay))
+            self.weight_decay = Constant(weight_decay)
+        elif isinstance(weight_decay, DecayScheduler):
+            self.weight_decay = weight_decay
+        else:
+            raise TypeError("Wrong weight_decay type")
+        self.decay_value = self.weight_decay(self.step_counter)
+
+        # init rho
+        if type(rho) == float or type(rho) == int:
+            self.rho = Constant(rho)
+        elif isinstance(rho, DecayScheduler):
+            self.rho = rho
+        else:
+            raise TypeError("Wrong rho type")
+        self.rho_value = self.rho(self.step_counter)
+
+        # init epsilon
+        if type(epsilon) == float or type(epsilon) == int:
+            self.epsilon = Constant(epsilon)
+        elif isinstance(epsilon, DecayScheduler):
+            self.epsilon = epsilon
+        else:
+            raise TypeError("Wrong epsilon type")
+        self.epsilon_value = self.epsilon(self.step_counter)
+
+        # init running average
+        self.running_average = dict()
+
+    def apply(self, param_name, param_value, param_grad):
+        """Performs a single optimization step.
 
         Args:
-                loss(Tensor): loss is the objective function of the deep learning model
-                optimization, e.g. for classification problem it can be the output of the
-                softmax_cross_entropy function.
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
+                param_grad(Tensor): param gradients; the values may be updated
+                        in this function; do not use it anymore
         """
-        for p, g in autograd.backward(loss):
-            self.update(p, g)
+        assert param_value.shape == param_grad.shape, ("shape mismatch",
+                                                       param_value.shape,
+                                                       param_grad.shape)
+        self.device_check(param_value, self.step_counter, self.lr_value,
+                          self.rho_value, self.epsilon_value, self.decay_value)
+
+        # if self.decay_value != 0:
+        if self.weight_decay.init_value != 0:
+            singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
+
+        if param_name not in self.running_average:
+            flag = param_value.device.graph_enabled()
+            param_value.device.EnableGraph(False)
+            self.running_average[param_name] = tensor.zeros_like(param_value)
+            param_value.device.EnableGraph(flag)
+
+        # running_average = running_average * rho + param_grad * param_grad * (1 - rho)
+        # param_value = param_value - lr * param_grad / sqrt(running_average + epsilon)
+
+        self.running_average[param_name] *= self.rho_value
+
+        tmp1 = singa.Square(param_grad.data)
+        tmp2 = 1.0 - self.rho_value
+        singa.Axpy(tmp2.data, tmp1, self.running_average[param_name].data)
+
+        minus_lr = 0.0 - self.lr_value
+        tmp3 = self.running_average[param_name] + self.epsilon_value
+        tmp3 = singa.Sqrt(tmp3.data)
+        tmp3 = singa.__div__(param_grad.data, tmp3)
+
+        singa.Axpy(minus_lr.data, tmp3, param_value.data)
+
+    def step(self):
+        # increment the step counter; refresh lr, weight decay, rho, epsilon
+        super().step()
+        decay_value = self.weight_decay(self.step_counter)
+        rho_value = self.rho(self.step_counter)
+        epsilon_value = self.epsilon(self.step_counter)
+        self.decay_value.copy_from(decay_value)
+        self.rho_value.copy_from(rho_value)
+        self.epsilon_value.copy_from(epsilon_value)
+
+    def get_states(self):
+        states = super().get_states()
+        states['running_average'] = self.running_average
+        return states
+
+    def set_states(self, states):
+        super().set_states(states)
+        if 'running_average' in states:
+            self.running_average = states['running_average']
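# A hedged NumPy restatement of the rule coded in RMSProp.apply() above,
# assuming scalar hyper-parameters and no weight decay; the names below are
# illustrative only, not SINGA APIs.
import numpy as np

def rmsprop_step(param, grad, running_average, lr=0.1, rho=0.9, epsilon=1e-8):
    # running average of squared gradients, then a scaled gradient step
    running_average = rho * running_average + (1 - rho) * grad * grad
    param = param - lr * grad / np.sqrt(running_average + epsilon)
    return param, running_average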
+
+
+class AdaGrad(Optimizer):
+    '''AdaGrad optimizer.
+
+    See the base Optimizer for all constructor args.
+
+    Args:
+        epsilon (float): small number for preventing numeric error.
+    '''
+
+    def __init__(self, lr=0.1, epsilon=1e-8, weight_decay=0):
+        super(AdaGrad, self).__init__(lr)
+
+        # init weight_decay
+        if type(weight_decay) == float or type(weight_decay) == int:
+            if weight_decay < 0.0:
+                raise ValueError(
+                    "Invalid weight_decay value: {}".format(weight_decay))
+            self.weight_decay = Constant(weight_decay)
+        elif isinstance(weight_decay, DecayScheduler):
+            self.weight_decay = weight_decay
+        else:
+            raise TypeError("Wrong weight_decay type")
+        self.decay_value = self.weight_decay(self.step_counter)
+
+        # init epsilon
+        if type(epsilon) == float or type(epsilon) == int:
+            self.epsilon = Constant(epsilon)
+        elif isinstance(epsilon, DecayScheduler):
+            self.epsilon = epsilon
+        else:
+            raise TypeError("Wrong epsilon type")
+        self.epsilon_value = self.epsilon(self.step_counter)
+
+        # init history
+        self.history = dict()
+
+    def apply(self, param_name, param_value, param_grad):
+        """Performs a single optimization step.
+
+        Args:
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
+                param_grad(Tensor): param gradients; the values may be updated
+                        in this function; do not use it anymore
+        """
+        assert param_value.shape == param_grad.shape, ("shape mismatch",
+                                                       param_value.shape,
+                                                       param_grad.shape)
+        self.device_check(param_value, self.step_counter, self.lr_value,
+                          self.epsilon_value, self.decay_value)
+
+        # if self.decay_value != 0:
+        if self.weight_decay.init_value != 0:
+            singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
+
+        if param_name not in self.history:
+            flag = param_value.device.graph_enabled()
+            param_value.device.EnableGraph(False)
+            self.history[param_name] = tensor.zeros_like(param_value)
+            param_value.device.EnableGraph(flag)
+
+        # history = history + param_grad * param_grad
+        # param_value = param_value - lr * param_grad / sqrt(history + epsilon)
+
+        tmp = self.history[param_name].data
+        tmp += singa.Square(param_grad.data)
+
+        minus_lr = 0.0 - self.lr_value
+        tmp = self.history[param_name] + self.epsilon_value
+        tmp = singa.Sqrt(tmp.data)
+        tmp = singa.__div__(param_grad.data, tmp)
+        singa.Axpy(minus_lr.data, tmp, param_value.data)
+
+    def step(self):
+        # increment the step counter; refresh lr, weight decay and epsilon
+        super().step()
+        decay_value = self.weight_decay(self.step_counter)
+        epsilon_value = self.epsilon(self.step_counter)
+        self.decay_value.copy_from(decay_value)
+        self.epsilon_value.copy_from(epsilon_value)
+
+    def get_states(self):
+        states = super().get_states()
+        states['history'] = self.history  # accumulated squared gradients
+        return states
+
+    def set_states(self, states):
+        super().set_states(states)
+        if 'history' in states:
+            self.history = states['history']
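# A hedged NumPy restatement of the rule coded in AdaGrad.apply() above,
# assuming scalar hyper-parameters and no weight decay; the names below are
# illustrative only, not SINGA APIs.
import numpy as np

def adagrad_step(param, grad, history, lr=0.1, epsilon=1e-8):
    # accumulate squared gradients, then scale the step by their square root
    history = history + grad * grad
    param = param - lr * grad / np.sqrt(history + epsilon)
    return param, history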
+
+
+class Adam(Optimizer):
+    '''Adam optimizer.
+
+    See the base Optimizer for all constructor args.
+
+    Args:
+        beta_1(float): coefficient of momentum
+        beta_2(float): coefficient of aggregated squared gradient
+        epsilon (float): small value for preventing numeric error
+    '''
+
+    def __init__(self,
+                 lr=0.1,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-8,
+                 weight_decay=0):
+        super(Adam, self).__init__(lr)
+
+        # init weight_decay
+        if type(weight_decay) == float or type(weight_decay) == int:
+            if weight_decay < 0.0:
+                raise ValueError(
+                    "Invalid weight_decay value: {}".format(weight_decay))
+            self.weight_decay = Constant(weight_decay)
+        elif isinstance(weight_decay, DecayScheduler):
+            self.weight_decay = weight_decay
+        else:
+            raise TypeError("Wrong weight_decay type")
+        self.decay_value = self.weight_decay(self.step_counter)
+
+        # init beta_1
+        if type(beta_1) == float or type(beta_1) == int:
+            self.beta_1 = Constant(beta_1)
+        elif isinstance(beta_1, DecayScheduler):
+            self.beta_1 = beta_1
+        else:
+            raise TypeError("Wrong beta_1 type")
+        self.beta_1_value = self.beta_1(self.step_counter)
+
+        # init beta_2
+        if type(beta_2) == float or type(beta_2) == int:
+            self.beta_2 = Constant(beta_2)
+        elif isinstance(beta_2, DecayScheduler):
+            self.beta_2 = beta_2
+        else:
+            raise TypeError("Wrong beta_2 type")
+        self.beta_2_value = self.beta_2(self.step_counter)
+
+        # init epsilon
+        if type(epsilon) == float or type(epsilon) == int:
+            self.epsilon = Constant(epsilon)
+        elif isinstance(epsilon, DecayScheduler):
+            self.epsilon = epsilon
+        else:
+            raise TypeError("Wrong epsilon type")
+        self.epsilon_value = self.epsilon(self.step_counter)
+
+        # init m and v
+        self.m = dict()
+        self.v = dict()
+
+    def apply(self, param_name, param_value, param_grad):
+        """Performs a single optimization step.
+
+        Args:
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
+                param_grad(Tensor): param gradients; the values may be updated
+                        in this function; do not use it anymore
+        """
+        assert param_value.shape == param_grad.shape, ("shape mismatch",
+                                                       param_value.shape,
+                                                       param_grad.shape)
+        self.device_check(param_value, self.step_counter, self.lr_value,
+                          self.beta_1_value, self.beta_2_value,
+                          self.epsilon_value, self.decay_value)
+
+        # if self.decay_value != 0:
+        if self.weight_decay.init_value != 0:
+            singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
+
+        if param_name not in self.m:
+            flag = param_value.device.graph_enabled()
+            param_value.device.EnableGraph(False)
+            self.m[param_name] = tensor.zeros_like(param_value)
+            self.v[param_name] = tensor.zeros_like(param_value)
+            param_value.device.EnableGraph(flag)
+
+        # overall steps
+        # m := beta_1 * m + (1 - beta_1) * grad
+        # v := beta_2 * v + (1 - beta_2) * grad * grad
+        # m_norm = m / (1 - beta_1 ^ step)
+        # v_norm = v / (1 - beta_2 ^ step)
+        # param := param - lr * m_norm / (sqrt(v_norm) + epsilon)
+
+        step = self.step_counter + 1.0
+
+        # m := beta_1 * m + (1 - beta_1) * grad
+        tmp = 1.0 - self.beta_1_value
+        self.m[param_name] *= self.beta_1_value
+        singa.Axpy(tmp.data, param_grad.data, self.m[param_name].data)
+
+        # v := beta_2 * v + (1 - beta_2) * grad * grad
+        tmp = 1.0 - self.beta_2_value
+        self.v[param_name] *= self.beta_2_value
+        singa.Axpy(tmp.data, singa.Square(param_grad.data),
+                   self.v[param_name].data)
+
+        # m_norm = m / (1 - beta_1 ^ step)
+        tmp = tensor.pow(self.beta_1_value, step)
+        tmp = 1.0 - tmp
+        m_norm = self.m[param_name] / tmp
+
+        # v_norm = v / (1 - beta_2 ^ step)
+        tmp = tensor.pow(self.beta_2_value, step)
+        tmp = 1.0 - tmp
+        v_norm = self.v[param_name] / tmp
+
+        # param := param - lr * m_norm / (sqrt(v_norm) + epsilon)
+        a = tensor.sqrt(v_norm) + self.epsilon_value
+        tmp = m_norm / a
+
+        minus_lr = 0.0 - self.lr_value
+        singa.Axpy(minus_lr.data, tmp.data, param_value.data)
+
+    def step(self):
+        # increment the step counter; refresh lr, weight decay, beta_1, beta_2
+        super().step()
+        decay_value = self.weight_decay(self.step_counter)
+        beta_1_value = self.beta_1(self.step_counter)
+        beta_2_value = self.beta_2(self.step_counter)
+        self.decay_value.copy_from(decay_value)
+        self.beta_1_value.copy_from(beta_1_value)
+        self.beta_2_value.copy_from(beta_2_value)
+
+    def get_states(self):
+        states = super().get_states()
+        states['m'] = self.m  # a dict for 1st order moments tensors
+        states['v'] = self.v  # a dict for 2nd order moments tensors
+        return states
+
+    def set_states(self, states):
+        super().set_states(states)
+        if 'm' in states:
+            self.m = states['m']
+        if 'v' in states:
+            self.v = states['v']
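# A hedged NumPy restatement of the bias-corrected Adam rule coded above,
# assuming scalar hyper-parameters and a step counted from 1; the names below
# are illustrative only, not SINGA APIs.
import numpy as np

def adam_step(param, grad, m, v, step, lr=0.001,
              beta_1=0.9, beta_2=0.999, epsilon=1e-8):
    m = beta_1 * m + (1 - beta_1) * grad               # 1st moment
    v = beta_2 * v + (1 - beta_2) * grad * grad        # 2nd moment
    m_hat = m / (1 - beta_1 ** step)                   # bias correction
    v_hat = v / (1 - beta_2 ** step)
    param = param - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return param, m, v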
 
 
 class DistOpt(object):
@@ -226,13 +719,16 @@
             self.communicator = singa.Communicator(buffSize)
         else:
             # constructor for application using python multi-process module
-            self.communicator = singa.Communicator(local_rank, world_size, nccl_id,
-                                                   buffSize)
+            self.communicator = singa.Communicator(local_rank, world_size,
+                                                   nccl_id, buffSize)
 
         self.world_size = self.communicator.world_size
         self.local_rank = self.communicator.local_rank
         self.global_rank = self.communicator.global_rank
 
+    def __call__(self, loss):
+        self.backward_and_update(loss)
+
     def update(self, param, grad):
         """Performs a single optimization step.
 
@@ -360,6 +856,7 @@
         self.wait()
         for p, g in plist:
             self.update(p, g)
+        self.opt.step()
 
     def backward_and_update_half(self,
                                  loss,
@@ -411,6 +908,7 @@
         self.wait()
         for p, g in plist:
             self.update(p, g)
+        self.opt.step()
 
     def backward_and_partial_update(self, loss, threshold=2097152):
         """Performs backward propagation from the loss and parameter update using asychronous training.
@@ -482,6 +980,7 @@
         # the counter returns to zero after a cycle of partial update
         if (k == self.partial):
             self.partial = 0
+        self.opt.step()
 
     def backward_and_sparse_update(self,
                                    loss,
@@ -583,3 +1082,4 @@
         for p, g in plist:
             self.update(p, g)
         self.sparsInit = True
+        self.opt.step()
diff --git a/python/singa/optimizer.py b/python/singa/optimizer.py
deleted file mode 100644
index 8c252c7..0000000
--- a/python/singa/optimizer.py
+++ /dev/null
@@ -1,472 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-'''This module includes a set of optimizers for updating model parameters.
-
-Example usage::
-
-  from singa import optimizer
-  from singa import tensor
-
-  sgd = optimizer.SGD(lr=0.01, momentum=0.9, weight_decay=1e-4)
-  p = tensor.Tensor((3,5))
-  p.uniform(-1, 1)
-  g = tensor.Tensor((3,5))
-  g.gaussian(0, 0.01)
-
-  sgd.apply(1, g, p, 'param')  # use the global lr=0.1 for epoch 1
-  sgd.apply_with_lr(2, 0.03, g, p, 'param')  # use lr=0.03 for epoch 2
-'''
-from __future__ import division
-from __future__ import absolute_import
-
-from builtins import object
-import math
-
-from . import singa_wrap as singa
-from . import tensor
-from .proto import model_pb2
-
-
-class Optimizer(object):
-    '''The base python optimizer class.
-
-    Typically, an optimizer is used as follows:
-
-    1. construct the optimizer
-    2. (optional) register each parameter with its specs.
-    3. use the optimizer to update parameter values given parameter gradients
-       and other optional info
-
-    The subclasses should override the apply_with_lr function to do the real
-    parameter udpate.
-
-    Args:
-        lr (float): a constant value for the learning rate
-        momentum (float): a constant value for the momentum value
-        weight_decay (float): the coefficent for L2 regularizer, which is
-            mutually exclusive with 'regularizer'.
-        regularizer: an instance of Regularizer or RegularizerConf; If set,
-            regularization would be applied in apply_with_lr().
-            Users can also do regularization outside.
-        constraint: an instance of Constraint or ConstraintConf; If set,
-            constraint would be applied inside apply_with_lr(). Users can
-            also apply constraint outside.
-    '''
-
-    def __init__(self,
-                 lr=None,
-                 momentum=None,
-                 weight_decay=None,
-                 regularizer=None,
-                 constraint=None):
-        self.lr = lr
-        self.momentum = momentum
-        if weight_decay is not None:
-            assert regularizer is None, \
-                'Cannot set weight_decay and regularizer at the same time'
-            regularizer = L2Regularizer(weight_decay)
-
-        if regularizer is not None:
-            if isinstance(regularizer, model_pb2.RegularizerConf):
-                self.regularizer = CppRegularizer(regularizer)
-            else:
-                self.regularizer = regularizer
-        else:
-            self.regularizer = None
-        if constraint is not None:
-            if isinstance(constraint, model_pb2.ConstraintConf):
-                self.constraint = CppConstraint(constraint)
-            else:
-                self.constraint = constraint
-        else:
-            self.constraint = None
-        self.regularizers = {}
-        self.constraints = {}
-        self.decay_multiplier = {}
-        self.learning_rate_multiplier = {}
-
-    def register(self, name, specs):
-        '''Register the param specs, including creating regularizer and
-        constraint per param object. Param specific regularizer and constraint
-        have higher priority than the global ones. If all parameters share the
-        same setting for learning rate, regularizer and constraint, then there
-        is no need to call this function.
-
-        Args:
-            name (str): parameter name
-            specs (ParamSpec): protobuf obj, including regularizer and
-                constraint, multipliers for learning rate and weight decay.
-        '''
-        assert isinstance(specs, model_pb2.ParamSpec), \
-            'specs should be model_pb2.ParamSpec instance'
-        if specs.HasField('regularizer'):
-            self.regularizers[name] = CppRegularizer(specs.regularizer)
-        elif specs.decay_mult != 1:
-            self.regularizers[name] = L2Regularizer(
-                specs.decay_mult * self.regularizer.coefficient)
-
-        if specs.HasField('constraint'):
-            self.constraints[name] = CppConstraint(specs.constraint)
-
-        if specs.lr_mult != 1:
-            self.learning_rate_multiplier[name] = specs.lr_mult
-
-    def apply_regularizer_constraint(self,
-                                     epoch,
-                                     value,
-                                     grad,
-                                     name=None,
-                                     step=-1):
-        '''Apply regularization and constraint if available.
-
-        If there are both global regularizer (constraint) and param specific
-        regularizer (constraint), it would use the param specific one.
-
-        Args:
-            epoch (int): training epoch ID
-            value (Tensor): parameter value Tensor
-            grad (Tensor): parameter gradient Tensor
-            name (string): to get parameter specific regularizer or constraint
-            step (int): iteration ID within one epoch
-
-        Returns:
-            the updated gradient Tensor
-        '''
-        if name is not None and name in self.constraints:
-            grad = self.constraints[name].apply(epoch, value, grad, step)
-        elif self.constraint is not None:
-            grad = self.constraint.apply(epoch, value, grad, step)
-
-        if name is not None and name in self.regularizers:
-            grad = self.regularizers[name].apply(epoch, value, grad, step)
-        elif self.regularizer is not None:
-            grad = self.regularizer.apply(epoch, value, grad, step)
-        return grad
-
-    def apply_with_lr(self, epoch, lr, grad, value, name=None, step=-1):
-        '''Do update of parameters with given learning rate if the grad is not
-        empty.
-
-        The subclass optimizer must override this function.
-        This function do nothing if the grad is empty.
-
-        Args:
-            epoch (int): training epoch ID
-            lr (float): learning rate
-            grad (Tensor): parameter gradient
-            value (Tesnor): parameter value
-            name (string): paramter name to index parameter specific
-                updating rules (including regularizer and constraint)
-            step (int): iteration ID within one epoch
-
-        Returns:
-            updated parameter value
-        '''
-        assert False, 'This is the base function, pls call the subclass func'
-
-    def apply(self, epoch, grad, value, name=None, step=-1):
-        '''Do update assuming the learning rate generator is set.
-
-        The subclass optimizer does not need to override this function.
-
-        Args:
-            epoch (int): training epoch ID
-            grad (Tensor): parameter gradient
-            value (Tesnor): parameter value
-            name (string): paramter name to retrieval parameter specific
-                updating rules (including regularizer and constraint)
-            step (int): training iteration ID within one epoch
-
-        Return:
-            updated parameter value
-        '''
-        assert self.lr is not None, 'Must set the learning rate, i.e. "lr"'
-        return self.apply_with_lr(epoch, self.lr, grad, value, name, step)
-
-
-class SGD(Optimizer):
-    '''The vallina Stochasitc Gradient Descent algorithm with momentum.
-
-    See the base Optimizer for all arguments.
-    '''
-
-    def __init__(self,
-                 lr=None,
-                 momentum=None,
-                 weight_decay=None,
-                 regularizer=None,
-                 constraint=None):
-        super(SGD, self).__init__(lr, momentum, weight_decay, regularizer,
-                                  constraint)
-        conf = model_pb2.OptimizerConf()
-        if self.momentum is not None:
-            conf.momentum = self.momentum
-        conf.type = 'sgd'
-        self.opt = singa.CreateOptimizer('SGD'.encode())
-        self.opt.Setup(conf.SerializeToString())
-
-    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
-        if grad.is_empty():
-            return value
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
-        if name is not None and name in self.learning_rate_multiplier:
-            lr = lr * self.learning_rate_multiplier[name]
-        self.opt.Apply(epoch, lr, name.encode(), grad.data, value.data)
-        return value
-
-
-class Nesterov(Optimizer):
-    '''The SGD with Nesterov momentum.
-
-    See the base Optimizer for all arguments.
-    '''
-
-    def __init__(self,
-                 lr=None,
-                 momentum=0.9,
-                 weight_decay=None,
-                 regularizer=None,
-                 constraint=None):
-        super(Nesterov, self).__init__(lr, momentum, weight_decay, regularizer,
-                                       constraint)
-        conf = model_pb2.OptimizerConf()
-        if self.momentum is not None:
-            conf.momentum = momentum
-        conf.type = 'nesterov'
-        self.opt = singa.CreateOptimizer('Nesterov'.encode())
-        self.opt.Setup(conf.SerializeToString())
-
-    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
-        if grad.is_empty():
-            return value
-
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
-        if name is not None and name in self.learning_rate_multiplier:
-            lr = lr * self.learning_rate_multiplier[name]
-        self.opt.Apply(epoch, lr, name.encode(), grad.data, value.data)
-        return value
-
-
-class RMSProp(Optimizer):
-    '''RMSProp optimizer.
-
-    See the base Optimizer for all constructor args.
-
-    Args:
-        rho (float): float within [0, 1]
-        epsilon (float): small value for preventing numeric error
-    '''
-
-    def __init__(self,
-                 rho=0.9,
-                 epsilon=1e-8,
-                 lr=None,
-                 weight_decay=None,
-                 regularizer=None,
-                 constraint=None):
-        super(RMSProp, self).__init__(lr, None, weight_decay, regularizer,
-                                      constraint)
-        conf = model_pb2.OptimizerConf()
-        conf.rho = rho
-        conf.delta = epsilon
-        self.opt = singa.CreateOptimizer('RMSProp'.encode())
-        self.opt.Setup(conf.SerializeToString())
-
-    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
-        if grad.is_empty():
-            return value
-
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
-        if name is not None and name in self.learning_rate_multiplier:
-            lr = lr * self.learning_rate_multiplier[name]
-        self.opt.Apply(step, lr, name.encode(), grad.data, value.data)
-        return value
-
-
-class AdaGrad(Optimizer):
-    '''AdaGrad optimizer.
-
-    See the base Optimizer for all constructor args.
-
-    Args:
-        epsilon (float): small number for preventing numeric error.
-    '''
-
-    def __init__(self,
-                 epsilon=1e-8,
-                 lr=None,
-                 weight_decay=None,
-                 lr_gen=None,
-                 regularizer=None,
-                 constraint=None):
-        super(AdaGrad, self).__init__(lr, None, weight_decay, regularizer,
-                                      constraint)
-        conf = model_pb2.OptimizerConf()
-        conf.delta = epsilon
-        conf.type = 'adagrad'
-        self.opt = singa.CreateOptimizer('AdaGrad'.encode())
-        self.opt.Setup(conf.SerializeToString())
-
-    def apply_with_lr(self, epoch, lr, grad, value, name, step=-1):
-        if grad.is_empty():
-            return value
-
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
-        if name is not None and name in self.learning_rate_multiplier:
-            lr = lr * self.learning_rate_multiplier[name]
-        self.opt.Apply(epoch, lr, name.encode(), grad.data, value.data)
-        return value
-
-
-class Adam(Optimizer):
-    '''Adam optimizer.
-
-    See the base Optimizer for all constructor args.
-
-    Args:
-        beta_1(float): coefficient of momentum
-        beta_2(float): coefficient of aggregated squared gradient
-        epsilon (float): small value for preventing numeric error
-    '''
-
-    def __init__(self,
-                 beta_1=0.9,
-                 beta_2=0.999,
-                 epsilon=1e-8,
-                 lr=None,
-                 weight_decay=None,
-                 regularizer=None,
-                 constraint=None):
-        super(Adam, self).__init__(lr, None, weight_decay, regularizer,
-                                   constraint)
-        self.beta_1 = beta_1
-        self.beta_2 = beta_2
-        self.epsilon = epsilon
-        self.m = {}
-        self.v = {}
-        self.t = 0
-        self.last_epoch = -1
-        self.last_step = -1
-
-    def apply_with_lr(self, epoch, lr, grad, value, name, step):
-        '''Update one parameter object.
-
-        Args:
-            step(int): the accumulated training iterations, not the iteration ID
-        '''
-        if grad.is_empty():
-            return value
-
-        assert step != -1, 'step should >= 0'
-        if epoch != self.last_epoch or step != self.last_step:
-            self.t += 1
-            self.last_step = step
-            self.last_epoch = epoch
-        grad = self.apply_regularizer_constraint(epoch, value, grad, name, step)
-        if name is not None and name in self.learning_rate_multiplier:
-            lr = lr * self.learning_rate_multiplier[name]
-        if name not in self.m or name not in self.v:
-            self.m[name] = tensor.Tensor(grad.shape, grad.device, grad.dtype)
-            self.m[name].set_value(0)
-            self.v[name] = tensor.Tensor(grad.shape, grad.device, grad.dtype)
-            self.v[name].set_value(0)
-
-        self.m[name] *= self.beta_1
-        tensor.axpy(1 - self.beta_1, grad, self.m[name])
-        self.v[name] *= self.beta_2
-        tensor.axpy(1 - self.beta_2, tensor.square(grad), self.v[name])
-        alpha = lr * math.sqrt(1 - math.pow(self.beta_2, self.t)) \
-            / (1 - math.pow(self.beta_1, self.t))
-        value -= alpha * self.m[name] / (tensor.sqrt(self.v[name]) +
-                                         self.epsilon)
-        return value
-
-
-class Regularizer(object):
-    '''Base Python regularizer for parameter gradients.'''
-
-    def apply(self, epoch, value, grad, step=-1):
-        assert False, 'Not Implemented. Call the subclass function.'
-
-
-class CppRegularizer(Regularizer):
-    '''Wrapper for regularizer implemented using C++.
-
-    Args:
-        conf (RegularizerConf): protobuf message for the configuration.
-    '''
-
-    def __init__(self, conf):
-        self.reg = singa.CreateRegularizer(conf.type)
-        self.reg.Setup(conf.SerializeToString())
-
-    def apply(self, epoch, value, grad, step=-1):
-        self.reg.Apply(epoch, value.data, grad.data)
-        return grad
-
-
-class L2Regularizer(Regularizer):
-    '''L2 regularization
-
-    Args:
-        coefficient (float): regularization coefficient.
-    '''
-
-    def __init__(self, coefficient):
-        self.coefficient = coefficient
-
-    def apply(self, epoch, value, grad, step=-1):
-        # print coefficient, value.l1(), grad.l1()
-        if self.coefficient != 0:
-            tensor.axpy(self.coefficient, value, grad)
-        return grad
-
-
-class Constraint(object):
-    '''Base Python constraint class for paramter gradients'''
-
-    def apply(self, epoch, value, grad, step=-1):
-        return grad
-
-
-class CppConstraint(Constraint):
-    '''Wrapper for constraints implemented using C++.
-
-    Args:
-        conf (ConstraintConf): protobuf message for the configuration.
-    '''
-
-    def __init__(self, conf):
-        self.constraint = singa.CreateConstraint(conf.type)
-        self.constraint.Setup(conf.SerializeToString())
-
-    def apply(self, epoch, value, grad, step=-1):
-        self.constraint.Apply(epoch, value.data, grad.data, step)
-        return grad
-
-
-class L2Constraint(Constraint):
-    '''Rescale the gradient to make the L2 norm <= a given threshold'''
-
-    def __init__(self, threshold=None):
-        self.threshold = threshold
-
-    def apply(self, epoch, value, grad, step=-1):
-        nrm = grad.l2()
-        grad *= self.threshold / nrm
-        return grad
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
index cae151d..67f246b 100644
--- a/python/singa/snapshot.py
+++ b/python/singa/snapshot.py
@@ -18,6 +18,9 @@
 '''
 This script includes io::snapshot class and its methods.
 
+Note: This module is deprecated. Please use the model module for
+checkpointing and restore.
+
 Example usages::
 
     from singa import snapshot
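For readers migrating off this module, the model-module workflow that the new note points to looks roughly like the sketch below. This is a minimal illustration only, assuming a `model.Model` subclass and the `save_states`/`load_states` methods of the SINGA 3.x model module; `MyModel` and `ckpt.zip` are placeholder names, not part of this patch.

    from singa import model

    class MyModel(model.Model):
        # hypothetical subclass; layers and forward/train_one_batch go here
        ...

    m = MyModel()
    # ... training loop ...
    m.save_states("ckpt.zip")   # checkpoint the parameter states to a file
    m.load_states("ckpt.zip")   # restore the same states later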
diff --git a/python/singa/sonnx.py b/python/singa/sonnx.py
index 69ec986..6ff7cef 100755
--- a/python/singa/sonnx.py
+++ b/python/singa/sonnx.py
@@ -20,22 +20,57 @@
 from __future__ import division
 
 import numpy as np
-import onnx.utils
+
 import onnx
+import onnx.utils
 from onnx.backend.base import Backend, BackendRep
 from onnx import (checker, helper, numpy_helper, GraphProto, NodeProto,
-                  TensorProto, OperatorSetIdProto, optimizer)
+                  TensorProto, OperatorSetIdProto, optimizer, mapping,
+                  shape_inference)
 import warnings
 
-from . import singa_wrap as singa
+from . import device
 from . import autograd
+from . import layer
 from . import tensor
-from singa import utils
+from . import model
+from . import utils
+from . import singa_wrap as singa
 
 import collections
 OrderedDict = collections.OrderedDict
 namedtuple = collections.namedtuple
 
+# singa only supports float32 and int32
+NP_TYPE_TO_SINGA_SUPPORT_TYPE = {
+    np.dtype('float32'): np.dtype('float32'),
+    np.dtype('uint8'): None,
+    np.dtype('int8'): np.dtype('int32'),
+    np.dtype('uint16'): None,
+    np.dtype('int16'): np.dtype('int32'),
+    np.dtype('int32'): np.dtype('int32'),
+    np.dtype('int64'): np.dtype('int32'),
+    np.dtype('bool'): np.dtype('float32'),
+    np.dtype('float16'): np.dtype('float32'),
+    np.dtype('float64'): np.dtype('float32'),
+    np.dtype('complex64'): None,
+    np.dtype('complex128'): None,
+    np.dtype('uint32'): None,
+    np.dtype('uint64'): None,
+    np.dtype(np.object): None
+}
+
+
+def onnx_type_to_singa_type(onnx_type):
+    np_type = mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type]
+    return NP_TYPE_TO_SINGA_SUPPORT_TYPE[np_type]
+
+
+gpu_dev = None
+if singa.USE_CUDA:
+    gpu_dev = device.create_cuda_gpu()
+cpu_dev = device.get_default_device()
+
 
 class SingaFrontend(object):
     """
@@ -358,10 +393,9 @@
             the onnx node
         """
         node = cls._common_singa_tensor_to_onnx_node(op, op_t)
-        tensor_type = onnx.TensorProto.FLOAT if isinstance(
-            op.value, float) else onnx.TensorProto.INT32
-        tensor_value = onnx.helper.make_tensor("value", tensor_type, [1],
-                                               [op.value])
+        tensor_type = TensorProto.FLOAT if isinstance(
+            op.value, float) else TensorProto.INT32
+        tensor_value = helper.make_tensor("value", tensor_type, [1], [op.value])
         node.attribute.extend([
             helper.make_attribute('value', tensor_value),
         ])
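The refactored lines above pack a Python scalar into a one-element ONNX tensor so it can ride along as the node's 'value' attribute. For instance (standard `onnx.helper` API; 2.5 is an arbitrary example value):

    t = helper.make_tensor("value", TensorProto.FLOAT, [1], [2.5])
    # t is a TensorProto with dims [1] whose single float entry is 2.5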
@@ -829,7 +863,7 @@
     @classmethod
     def _common_singa_tensor_to_onnx_node(cls, op, op_t):
         """
-        get a onnx node from a singa operator, prepare its type, inputs and outputs
+        get an onnx node from a singa operator, prepare its type, inputs and outputs
         Args:
             op: a given operator
         Args:
@@ -960,17 +994,31 @@
     """
 
     def __init__(self, node):
-        self.name = str(node.name)
+        self.name = str(node.name).replace(".", "_")
         self.op_type = str(node.op_type)
         self.attrs = OnnxAttributes.from_onnx(node.attribute)
-        # there may some inputs which we regard as attribute, so we mark them there
-        self.consumed_inputs = list()
+        # inputs as attributes in singa
+        self.attr_inputs = {}
+        # inputs as weights in singa
+        self.weight_inputs = {}
         self.inputs = list(node.input)
         self.outputs = list(node.output)
 
     def getattr(self, key, default=None):
         return self.attrs[key] if key in self.attrs else default
 
+    def set_attr_inputs(self, key, name):
+        self.attr_inputs[key] = name
+
+    def del_attr_inputs(self, key):
+        del self.attr_inputs[key]
+
+    def set_weight_inputs(self, key, name):
+        self.weight_inputs[key] = name
+
+    def del_weight_inputs(self, key):
+        del self.weight_inputs[key]
+
 
 class OnnxAttributes(dict):
     """
@@ -989,585 +1037,618 @@
 class SingaBackend(Backend):
 
     # This number indicates the onnx operator set version
-    _known_opset_version = 11
+    _opset_version = 11
+
+    _ir_version = 0x0000000000000006
 
     # because singa's operators are different from onnx.
     # we define a dict for the name projection
     _rename_operators = {
-        'Relu': 'relu',
-        'Softmax': 'SoftMax',
-        'Sigmoid': 'sigmoid',
-        'Add': 'add',
-        'MatMul': 'matmul',
-        'Conv': '_Conv2d',
-        'MaxPool': '_Pooling2d',
-        'AveragePool': '_Pooling2d',
-        'BatchNormalization': 'batchnorm_2d',
-        'Concat': 'Concat',
-        'Flatten': 'Flatten',
-        'Gemm': 'Gemm',
-        'Reshape': 'Reshape',
-        'Sum': 'sum',
-        'Cos': 'cos',
-        'Cosh': 'cosh',
-        'Sin': 'sin',
-        'Sinh': 'sinh',
-        'Tan': 'tan',
-        'Tanh': 'tanh',
-        'Acos': 'acos',
-        'Acosh': 'acosh',
-        'Asin': 'asin',
-        'Asinh': 'asinh',
-        'Atan': 'atan',
-        'Atanh': 'atanh',
-        'Selu': 'SeLU',
-        'Elu': 'Elu',
-        'Equal': 'equal',
-        'Less': 'less',
-        'Sign': 'sign',
-        'Div': 'div',
-        'Sub': 'sub',
-        'Sqrt': 'sqrt',
-        'Log': 'log',
-        'Greater': 'greater',
-        'HardSigmoid': 'HardSigmoid',
-        'Identity': 'identity',
-        'Softplus': 'softplus',
-        'Softsign': 'softsign',
-        'Mean': 'mean',
-        'Pow': 'pow',
-        'Clip': 'Clip',
-        'PRelu': 'prelu',
-        'Mul': 'mul',
-        'Transpose': 'Transpose',
-        'Max': 'max',
-        'Min': 'min',
-        'Shape': 'shape',
-        'And': '_and',
-        'Or': '_or',
-        'Xor': '_xor',
-        'Not': '_not',
-        'Neg': 'negative',
-        'Reciprocal': 'reciprocal',
-        'ConstantOfShape': 'ConstantOfShape',
-        'Dropout': 'Dropout',
+        # common op
+        'Relu': 'ReLU',
+        'Sigmoid': 'Sigmoid',
+        'Add': 'Add',
+        'MatMul': 'Matmul',
+        'Sum': 'Sum',
+        'Cos': 'Cos',
+        'Cosh': 'Cosh',
+        'Sin': 'Sin',
+        'Sinh': 'Sinh',
+        'Tan': 'Tan',
+        'Tanh': 'Tanh',
+        'Acos': 'Acos',
+        'Acosh': 'Acosh',
+        'Asin': 'Asin',
+        'Asinh': 'Asinh',
+        'Atan': 'Atan',
+        'Atanh': 'Atanh',
+        'Equal': 'Equal',
+        'Less': 'Less',
+        'Sign': 'Sign',
+        'Div': 'Div',
+        'Sub': 'Sub',
+        'Sqrt': 'Sqrt',
+        'Log': 'Log',
+        'Greater': 'Greater',
+        'Identity': 'Identity',
+        'Softplus': 'SoftPlus',
+        'Softsign': 'SoftSign',
+        'Mean': 'Mean',
+        'Pow': 'Pow',
+        'PRelu': 'PRelu',
+        'Mul': 'Mul',
+        'Max': 'Max',
+        'Min': 'Min',
+        'Shape': 'Shape',
+        'And': 'And',
+        'Or': 'Or',
+        'Xor': 'Xor',
+        'Not': 'Not',
+        'Neg': 'Negative',
+        'Reciprocal': 'Reciprocal',
+        'Unsqueeze': 'Unsqueeze',
+        'NonZero': 'NonZero',
+        'Ceil': 'Ceil',
+        'Floor': 'Floor',
+        'Abs': 'Abs',
+        # special op
+        'ScatterElements': 'ScatterElements',
+        'Cast': 'Cast',
+        'Split': 'Split',
+        'Squeeze': 'Squeeze',
+        'GlobalAveragePool': 'GlobalAveragePool',
+        'LeakyRelu': 'LeakyRelu',
         'ReduceSum': 'ReduceSum',
         'ReduceMean': 'ReduceMean',
-        'LeakyRelu': 'LeakyRelu',
-        'GlobalAveragePool': 'GlobalAveragePool',
-        'Squeeze': 'Squeeze',
-        'Unsqueeze': 'Unsqueeze',
-        'Slice': 'Slice',
-        'Ceil': 'Ceil',
-        'Split': 'Split',
-        'Gather': 'Gather',
-        'Tile': 'Tile',
-        'NonZero': 'nonzero',
-        'Cast': 'Cast',
+        'Dropout': 'Dropout',
+        'ConstantOfShape': 'ConstantOfShape',
+        'Transpose': 'Transpose',
+        'HardSigmoid': 'HardSigmoid',
+        'Elu': 'Elu',
+        'Selu': 'SeLU',
+        'Concat': 'Concat',
+        'Softmax': 'SoftMax',
+        'Flatten': 'Flatten',
         'OneHot': 'OneHot',
+        'Tile': 'Tile',
+        'Gather': 'Gather',
+        'Reshape': 'Reshape',
+        'Slice': 'Slice',
+        'Clip': 'Clip',
+        'Expand': 'Expand',
+        'Pad': 'Pad',
+        'Upsample': 'UpSample',
+        'DepthToSpace': 'DepthToSpace',
+        'SpaceToDepth': 'SpaceToDepth',
+        'Where': 'Where',
+        'Erf': 'Erf',
+        'Gemm': 'layer.Gemm',  # layer
+        'BatchNormalization': 'layer.BatchNorm2d',  # layer
+        'Conv': 'layer.Conv2d',  # layer
+        'MaxPool': 'layer.Pooling2d',  # layer
+        'AveragePool': 'layer.Pooling2d',  # layer
     }
 
     # this dict indicates the operators that need extra handle
     # each indicates a function name
     _special_operators = {
+        'Cast': '_create_cast',
+        'Split': '_create_split',
+        'Squeeze': '_create_squeeze_unsqueeze',
+        'Unsqueeze': '_create_squeeze_unsqueeze',
+        'GlobalAveragePool': '_create_global_average_pool',
+        'LeakyRelu': '_create_leakyrelu',
+        'ReduceSum': '_create_reduce_ops',
+        'ReduceMean': '_create_reduce_ops',
+        'Dropout': '_create_dropout',
+        'ConstantOfShape': '_create_constant_of_shape',
+        'Transpose': '_create_transpose',
+        'HardSigmoid': '_create_hardsigmoid',
+        'Elu': '_create_elu',
+        'Selu': '_create_selu',
+        'Concat': '_create_concat',
+        'Softmax': '_create_softmax',
+        'Gemm': '_create_gemm',
+        'Flatten': '_create_flatten',
+        'OneHot': '_create_onehot',
+        'Tile': '_create_tile',
+        'Gather': '_create_gather',
+        'Reshape': '_create_reshape',
+        'Slice': '_create_slice',
+        'Clip': '_create_clip',
+        'BatchNormalization': '_create_batch_norm',
         'Conv': '_create_conv',
         'MaxPool': '_create_max_avg_pool',
         'AveragePool': '_create_max_avg_pool',
-        'BatchNormalization': '_create_batchnorm',
-        'Concat': '_create_concat',
-        'Flatten': '_create_flatten',
-        'Gemm': '_create_gemm',
-        'Reshape': '_create_reshape',
-        'Softmax': '_create_softmax',
-        'Selu': '_create_selu',
-        'Elu': '_create_elu',
-        'HardSigmoid': '_create_hardsigmoid',
-        'Clip': '_create_clip',
-        'Transpose': '_create_transpose',
-        'ConstantOfShape': '_create_constantOfShape',
-        'Dropout': '_create_dropout',
-        'ReduceSum': '_create_reduceOp',
-        'ReduceMean': '_create_reduceOp',
-        'LeakyRelu': '_create_leakyrelu',
-        'GlobalAveragePool': '_create_globalaveragepool',
-        'Squeeze': '_create_squeeze',
-        'Unsqueeze': '_create_squeeze',
-        'Slice': '_create_slice',
-        'Split': '_create_split',
-        'Gather': '_create_gather',
-        'Tile': '_create_tile',
-        'Cast': '_create_cast',
-        'OneHot': '_create_onehot',
-        'Constant': "_create_constant"
+        'Expand': '_create_expand',
+        'Pad': '_create_pad',
+        'Upsample': '_create_upsample',
+        'DepthToSpace': '_create_depth_space',
+        'SpaceToDepth': '_create_depth_space',
+        'ScatterElements': '_create_scatter_elements',
+        'Where': '_create_where',
     }
 
     @classmethod
-    def _create_constant(cls, onnx_node, inputs, opset_version):
+    def _create_depth_space(cls, onnx_node, operator, opset_version=_opset_version):
         """
-        parse onnx constatn node to weights
+        get the DepthToSpace or SpaceToDepth operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        tmp_tensor = onnx_node.getattr('value')
-        np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[tmp_tensor.data_type]
-        np_tensor = np.frombuffer(tmp_tensor.raw_data, dtype=np_dtype)
-        if np_tensor.dtype == "int64":
-            np_tensor = np_tensor.astype(np.int32)
-        # todo, we cannot support scalar tensor
-        if np.ndim(np_tensor) == 0:
-            np_tensor = np.array(np_tensor, ndmin=1)
-        return None, np_tensor
+        blocksize = onnx_node.getattr("blocksize")
+        mode = utils.force_unicode(onnx_node.getattr("mode", "DCR"))
+        return operator(blocksize, mode)
 
     @classmethod
-    def _create_onehot(cls, onnx_node, inputs, opset_version):
+    def _create_where(cls, onnx_node, operator, opset_version=_opset_version):
         """
-        get the OneHot operator from onnx node
+        get the Where operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        axis = onnx_node.getattr("axis", -1)
-        # we move several inputs to singa's attribuates
-        # and mark them so we don't use them when we run this operator
-        depth = tensor.to_numpy(inputs.pop(1)).astype(np.int32)
-        value = tensor.to_numpy(inputs.pop(1))
-        onnx_node.consumed_inputs.extend(onnx_node.inputs[1:])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(axis, depth, value)
+        onnx_node.set_attr_inputs(onnx_node.inputs[0], 'condition')
+        return operator(None)
 
     @classmethod
-    def _create_cast(cls, onnx_node, inputs, opset_version):
+    def _create_pad(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the Pad operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        mode = onnx_node.getattr("mode", "constant")
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'pads')
+        if len(onnx_node.inputs) == 3:
+            onnx_node.set_attr_inputs(onnx_node.inputs[2], 'constant')
+        return operator(mode, None, None)
+
+    @classmethod
+    def _create_upsample(cls,
+                         onnx_node,
+                         operator,
+                         opset_version=_opset_version):
+        """
+        get the UpSample operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        mode = utils.force_unicode(onnx_node.getattr("mode", None))
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'scales')
+        return operator(mode, None)
+
+    @classmethod
+    def _create_expand(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the Expand operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'shape')
+        return operator(None)
+
+    @classmethod
+    def _create_cast(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the Cast operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        to = onnx_node.getattr("to")
-        # singa only supports float32 and int32
-        map_dict = {
-            TensorProto.FLOAT: tensor.float32,  # FLOAT to float32
-            TensorProto.UINT8: None,  # UINT8
-            TensorProto.INT8: tensor.int32,  # INT8 to int32
-            TensorProto.UINT16: None,  # UINT16
-            TensorProto.INT16: tensor.int32,  # INT16 to int32
-            TensorProto.INT32: tensor.int32,  # INT32 to int32
-            TensorProto.INT64: tensor.int32,  # INT64 to int32
-            TensorProto.STRING: None,  # stirng
-            TensorProto.BOOL: None,  # bool
-        }
-        to = map_dict[to]
-        assert to != None, "not support cast type: {}".format(to)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(to)
+        to_type = onnx_type_to_singa_type(onnx_node.getattr("to"))
+        assert to_type is not None, "Unsupported cast type: {}".format(to_type)
+        if to_type == np.dtype('float32'):
+            return operator(tensor.float32)
+        else:
+            return operator(tensor.int32)
 
     @classmethod
-    def _create_tile(cls, onnx_node, inputs, opset_version):
-        """
-        get the Tile operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
-        """
-        # we move several inputs to singa's attribuates
-        # and mark them so we don't use them when we run this operator
-        repeats = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        onnx_node.consumed_inputs.append(onnx_node.inputs[1])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(repeats)
-
-    @classmethod
-    def _create_gather(cls, onnx_node, inputs, opset_version):
-        """
-        get the Gather operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
-        """
-        axis = onnx_node.getattr("axis", 0)
-        # we move several inputs to singa's attribuates
-        # and mark them so we don't use them when we run this operator
-        indices = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        onnx_node.consumed_inputs.append(onnx_node.inputs[1])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(axis, indices)
-
-    @classmethod
-    def _create_split(cls, onnx_node, inputs, opset_version):
+    def _create_split(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the Split operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         axis = onnx_node.getattr("axis", 0)
         split = onnx_node.getattr("split", None)
         num_output = len(onnx_node.outputs)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(axis, split, num_output)
+        return operator(axis, split, num_output)
 
     @classmethod
-    def _create_slice(cls, onnx_node, inputs, opset_version):
-        """
-        get the Slice operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
-        """
-        # we move several inputs to singa's attribuates
-        # and mark them so we don't use them when we run this operator
-        starts = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        ends = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        # sometime onnx may ignore these two inputs, axes and step
-        if len(inputs) >= 2 and onnx_node.inputs[3] != '':
-            axes = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        else:
-            axes = None
-        steps = tensor.to_numpy(inputs.pop(1)).astype(
-            np.int32).tolist() if len(inputs) >= 2 else None
-        onnx_node.consumed_inputs.extend(onnx_node.inputs[1:])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(starts, ends, axes, steps)
-
-    @classmethod
-    def _create_squeeze(cls, onnx_node, inputs, opset_version):
+    def _create_squeeze_unsqueeze(cls,
+                                  onnx_node,
+                                  operator,
+                                  opset_version=_opset_version):
         """
         get the Squeeze and Unsqueeze operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         axes = onnx_node.getattr("axes")
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(axes)
+        return operator(axes)
 
     @classmethod
-    def _create_globalaveragepool(cls, onnx_node, inputs, opset_version):
+    def _create_global_average_pool(cls,
+                                    onnx_node,
+                                    operator,
+                                    opset_version=_opset_version):
         """
         get the GlobalAveragePool operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         data_format = onnx_node.getattr("data_format", 'channels_first')
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(data_format)
+        return operator(data_format)
 
     @classmethod
-    def _create_leakyrelu(cls, onnx_node, inputs, opset_version):
+    def _create_leakyrelu(cls,
+                          onnx_node,
+                          operator,
+                          opset_version=_opset_version):
         """
         get the LeakyRelu operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         alpha = onnx_node.getattr("alpha", 0.01)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(alpha)
+        return operator(alpha)
 
     @classmethod
-    def _create_reduceOp(cls, onnx_node, inputs, opset_version):
+    def _create_reduce_ops(cls,
+                           onnx_node,
+                           operator,
+                           opset_version=_opset_version):
         """
         get the ReduceSum, ReduceMean, ReduceMax, ReduceMin, etc, operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         axes = onnx_node.getattr("axes", None)
         keepdims = onnx_node.getattr("keepdims", 1)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(axes, keepdims)
+        return operator(axes, keepdims)
 
     @classmethod
-    def _create_dropout(cls, onnx_node, inputs, opset_version):
+    def _create_dropout(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the Dropout operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
+        seed = onnx_node.getattr("seed", 0)
         ratio = onnx_node.getattr("ratio", 0)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(ratio)
+        return operator(seed, ratio)
 
     @classmethod
-    def _create_constantOfShape(cls, onnx_node, inputs, opset_version):
+    def _create_constant_of_shape(cls,
+                                  onnx_node,
+                                  operator,
+                                  opset_version=_opset_version):
         """
         get the ConstantOfShape operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         value = onnx_node.getattr("value", 0)
         if isinstance(value, onnx.TensorProto):
             value = numpy_helper.to_array(value)[0].item()
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(value)
+        return operator(value)
 
     @classmethod
-    def _create_transpose(cls, onnx_node, inputs, opset_version):
+    def _create_transpose(cls,
+                          onnx_node,
+                          operator,
+                          opset_version=_opset_version):
         """
         get the Transpose operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        shape = inputs[0].shape
-        perm = onnx_node.getattr("perm", list(range(len(shape) - 1, -1, -1)))
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(perm)
+        perm = onnx_node.getattr("perm")
+        return operator(perm)
 
     @classmethod
-    def _create_clip(cls, onnx_node, inputs, opset_version):
+    def _create_hardsigmoid(cls,
+                            onnx_node,
+                            operator,
+                            opset_version=_opset_version):
         """
-        get the clip operator from onnx node
+        get the hardsigmoid operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
-        """
-        # sometime onnx may ignore these two inputs, min or max or both
-        if len(inputs) >= 2 and onnx_node.inputs[1] != '':
-            min_v = tensor.to_numpy(inputs.pop(1)).tolist()[0]
-        else:
-            min_v = None
-        if len(inputs) >= 2 and onnx_node.inputs[2] != '':
-            max_v = tensor.to_numpy(inputs.pop(1)).tolist()[0]
-        else:
-            max_v = None
-        onnx_node.consumed_inputs.extend(onnx_node.inputs[1:])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(min_v, max_v)
-
-    @classmethod
-    def _create_hardsigmoid(cls, onnx_node, inputs, opset_version):
-        """
-        get the HardSigmoid operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         alpha = onnx_node.getattr("alpha", 0.2)
         beta = onnx_node.getattr("beta", 0.5)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(alpha, beta)
+        return operator(alpha, beta)
 
     @classmethod
-    def _create_elu(cls, onnx_node, inputs, opset_version):
+    def _create_elu(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the elu operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         alpha = onnx_node.getattr("alpha", 1.)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(alpha)
+        return operator(alpha)
 
     @classmethod
-    def _create_selu(cls, onnx_node, inputs, opset_version):
+    def _create_selu(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the selu operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
         alpha = onnx_node.getattr("alpha", 1.67326)
         gamma = onnx_node.getattr("gamma", 1.0507)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(alpha, gamma)
+        return operator(alpha, gamma)
 
     @classmethod
-    def _create_reshape(cls, onnx_node, inputs, opset_version):
+    def _create_concat(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the concat operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        factor = onnx_node.getattr('axis')
+        return operator(axis=factor)
+
+    @classmethod
+    def _create_softmax(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the softmax operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        factor = onnx_node.getattr('axis', 1)
+        return operator(axis=factor)
+
+    @classmethod
+    def _create_gemm(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the gemm operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        alpha = onnx_node.getattr('alpha', 1.)
+        beta = onnx_node.getattr('beta', 1.)
+        transA = onnx_node.getattr('transA', 0)
+        transB = onnx_node.getattr('transB', 0)
+        onnx_node.set_weight_inputs(onnx_node.inputs[1], 'W')
+        bias = False
+        if len(onnx_node.inputs) == 3:
+            onnx_node.set_weight_inputs(onnx_node.inputs[2], 'b')
+            bias = True
+        return operator(None,
+                        alpha=alpha,
+                        beta=beta,
+                        transA=transA,
+                        transB=transB,
+                        bias=bias)
+
+    @classmethod
+    def _create_flatten(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the flatten operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        factor = onnx_node.getattr('axis', 1)
+        return operator(axis=factor)
+
+    @classmethod
+    def _create_onehot(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the OneHot operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        axis = onnx_node.getattr("axis", -1)
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'depth')
+        onnx_node.set_attr_inputs(onnx_node.inputs[2], 'values')
+        return operator(axis, None, None)
+
+    @classmethod
+    def _create_tile(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the Tile operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'repeats')
+        return operator(None)
+
+    @classmethod
+    def _create_gather(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the Gather operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        axis = onnx_node.getattr("axis", 0)
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'indices')
+        return operator(axis, None)
+
+    @classmethod
+    def _create_reshape(cls, onnx_node, operator, opset_version=_opset_version):
         """
         get the reshape operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            the handle of singa operator
-        Returns: 
-            the autograd of singa operator
+            singa operator instance
         """
-        shape = tensor.to_numpy(inputs.pop(1)).astype(np.int32).tolist()
-        onnx_node.consumed_inputs.append(onnx_node.inputs[1])
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(shape)
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'shape')
+        return operator(None)
 
     @classmethod
-    def _create_conv(cls, onnx_node, inputs, opset_version):
+    def _create_slice(cls, onnx_node, operator, opset_version=_opset_version):
         """
-        get the conv operator from onnx node
+        get the Slice operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        kernel = tuple(onnx_node.attrs["kernel_shape"])
-        padding = tuple(
-            onnx_node.attrs["pads"]) if "pads" in onnx_node.attrs else (0, 0)
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'starts')
+        onnx_node.set_attr_inputs(onnx_node.inputs[2], 'ends')
+        if len(onnx_node.inputs) >= 4 and onnx_node.inputs[3] != '':
+            onnx_node.set_attr_inputs(onnx_node.inputs[3], 'axes')
+        if len(onnx_node.inputs) == 5 and onnx_node.inputs[4] != '':
+            onnx_node.set_attr_inputs(onnx_node.inputs[4], 'steps')
+        return operator(None, None, None, None)
+
+    @classmethod
+    def _create_clip(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the clip operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        if len(onnx_node.inputs) >= 2 and onnx_node.inputs[1] != '':
+            onnx_node.set_attr_inputs(onnx_node.inputs[1], 'min')
+        if len(onnx_node.inputs) == 3 and onnx_node.inputs[2] != '':
+            onnx_node.set_attr_inputs(onnx_node.inputs[2], 'max')
+        return operator(None, None)
+
+    @classmethod
+    def _create_batch_norm(cls,
+                           onnx_node,
+                           operator,
+                           opset_version=_opset_version):
+        """
+        get the BatchNormalization operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        factor = onnx_node.getattr('momentum', 0.9)
+        onnx_node.set_weight_inputs(onnx_node.inputs[1], 'scale')
+        onnx_node.set_weight_inputs(onnx_node.inputs[2], 'bias')
+        onnx_node.set_weight_inputs(onnx_node.inputs[3], 'running_mean')
+        onnx_node.set_weight_inputs(onnx_node.inputs[4], 'running_var')
+        return operator(factor)
+
+    @classmethod
+    def _create_conv(cls, onnx_node, operator, opset_version=_opset_version):
+        """
+        get the Conv operator from onnx node
+        Args:
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
+        Returns: 
+            singa operator instance
+        """
+        kernel_size = tuple(onnx_node.getattr('kernel_shape'))
+        padding = tuple(onnx_node.getattr('pads', (0, 0)))
         stride = tuple(onnx_node.getattr('strides', (1, 1)))
-        # default the odd_padding is 0, once there are same pad mode, we modify it
-        # for odd_padding, please refer the autegrade.py
-        odd_padding = (0, 0, 0, 0)
-        if "auto_pad" in onnx_node.attrs:
-            auto_pad = utils.force_unicode(onnx_node.attrs['auto_pad'])
-            if auto_pad in ('SAME_UPPER', 'SAME_LOWER'):
-                padding, odd_padding = utils.get_padding_shape(
-                    auto_pad, inputs[0].shape[2:], kernel, stride)
+        auto_pad = utils.force_unicode(onnx_node.getattr('auto_pad', 'NOTSET'))
 
         # not support dilation
         dilation = onnx_node.getattr('dilations', 1)
@@ -1576,563 +1657,567 @@
         group = onnx_node.getattr('group', 1)
 
         # only support 1d or 2d
-        if len(kernel) > 2:
+        if len(kernel_size) > 2:
             raise ValueError("Only implemented for 1d or 2d")
 
-        bias = len(inputs) == 3
-        x = inputs[0]
-        x_shape = inputs[0].shape
-        in_channels = x_shape[1]
-        w_shape = inputs[1].shape
-        out_channels = w_shape[0]
-        assert w_shape[1] == in_channels // group
-
-        if inputs[0].device.id() == -1:
-            if group != 1:
-                raise NotImplementedError
-            else:
-                handle = singa.ConvHandle(x.data, kernel, stride, padding,
-                                          in_channels, out_channels, bias,
-                                          group)
-        else:
-            handle = singa.CudnnConvHandle(x.data, kernel, stride, padding,
-                                           in_channels, out_channels, bias,
-                                           group)
-
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(handle, odd_padding)
+        onnx_node.set_weight_inputs(onnx_node.inputs[1], 'W')
+        bias = False
+        if len(onnx_node.inputs) == 3:
+            onnx_node.set_weight_inputs(onnx_node.inputs[2], 'b')
+            bias = True
+        return operator(None,
+                        kernel_size,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        group=group,
+                        bias=bias,
+                        pad_mode=auto_pad)
 
     @classmethod
-    def _create_max_avg_pool(cls, onnx_node, inputs, opset_version):
+    def _create_max_avg_pool(cls,
+                             onnx_node,
+                             operator,
+                             opset_version=_opset_version):
         """
-        get the max or avg pool operator from onnx node
+        get the MaxPool or AveragePool operator from onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version (int): the opset version
         Returns: 
-            handle, the handle of singa operator
-        Returns: 
-            forward, the autograd of singa operator
+            singa operator instance
         """
-        kernel = tuple(onnx_node.attrs["kernel_shape"])
-        padding = tuple(
-            onnx_node.attrs["pads"]) if "pads" in onnx_node.attrs else (0, 0)
+        kernel_size = tuple(onnx_node.getattr('kernel_shape'))
+        padding = tuple(onnx_node.getattr('pads', (0, 0)))
         stride = tuple(onnx_node.getattr('strides', (1, 1)))
-        # default the odd_padding is 0, once there are same pad mode, we modify it
-        # for odd_padding, please refer the autegrade.py
-        odd_padding = (0, 0, 0, 0)
-        if "auto_pad" in onnx_node.attrs:
-            auto_pad = utils.force_unicode(onnx_node.attrs['auto_pad'])
-            if auto_pad in ('SAME_UPPER', 'SAME_LOWER'):
-                padding, odd_padding = utils.get_padding_shape(
-                    auto_pad, inputs[0].shape[2:], kernel, stride)
+        auto_pad = utils.force_unicode(onnx_node.getattr('auto_pad', 'NOTSET'))
 
         # not support count_include_pad and auto_pad
-        if "count_include_pad" in onnx_node.attrs or "ceil_mode" in onnx_node.attrs:
+        ceil_mode = onnx_node.getattr('ceil_mode', 0)
+        count_include_pad = onnx_node.getattr('count_include_pad', 0)
+        if ceil_mode != 0 or count_include_pad != 0:
             raise ValueError(
                 "Not implemented yet for count_include_pad or ceil_mode")
 
-        # only support 2d
-        if len(kernel) != 2:
-            raise ValueError("Not implemented yet")
+        # only support 1d or 2d
+        if len(kernel_size) > 2:
+            raise ValueError("Only implemented for 1d or 2d")
 
         is_max = onnx_node.op_type == 'MaxPool'
-        x = inputs[0]
-        if x.device.id() == -1:
-            handle = singa.PoolingHandle(x.data, kernel, stride, padding,
-                                         is_max)
-        else:
-            handle = singa.CudnnPoolingHandle(x.data, kernel, stride, padding,
-                                              is_max)
-
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return _, forward(handle, odd_padding)
+        return operator(kernel_size, stride, padding, is_max, auto_pad)
 
     @classmethod
-    def _create_batchnorm(cls, onnx_node, inputs, opset_version):
+    def _create_scatter_elements(cls,
+                                 onnx_node,
+                                 operator,
+                                 opset_version=_opset_version):
         """
-        get the batch norm operator from onnx node
-        Args:onnx_node: a given onnx node
-        Args:inputs: the input tensor
-        Args:opset_version: the opset version
-        Returns: the handle of singa operator
-        Returns: the autograd of singa operator
+        get the ScatterElements operator from the onnx node
+        Args:
+            onnx_node(OnnxNode): a given onnx node
+            operator (Operator Class): a singa operator class
+            opset_version(int): the opset version
+        Returns: 
+            singa operator instance      
         """
-        x = inputs[0]
-        factor = onnx_node.getattr('momentum', 0.9)
-        if x.device.id() == -1:
-            handle = singa.BatchNormHandle(factor, x.data)
-        else:
-            handle = singa.CudnnBatchNormHandle(factor, x.data)
-
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return handle, forward
+        axis = onnx_node.getattr("axis", 0)
+        onnx_node.set_attr_inputs(onnx_node.inputs[1], 'indices')
+        onnx_node.set_attr_inputs(onnx_node.inputs[2], 'updates')
+        return operator(None, None, axis)
 
     @classmethod
-    def _create_concat(cls, onnx_node, inputs, opset_version):
+    def _onnx_constant_to_np(cls, onnx_node, opset_version=_opset_version):
         """
-        get the concat operator from onnx node
+        parse an onnx Constant node to a numpy array
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            opset_version (int): the opset version
         Returns: 
-            the handle of singa operator
-        Returns: 
-            the autograd of singa operator
+            a numpy ndarray
         """
-        factor = onnx_node.attrs["axis"]
-        if factor < 0:
-            factor = len(inputs[0].shape
-                        ) + factor  # in order to support the negative axis
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return None, forward(axis=factor)
+        onnx_tensor = onnx_node.getattr('value')
+        np_dtype = mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_tensor.data_type]
+        return np.frombuffer(onnx_tensor.raw_data, dtype=np_dtype)
 
     @classmethod
-    def _create_softmax(cls, onnx_node, inputs, opset_version):
+    def _onnx_node_to_singa_op(cls, onnx_node, opset_version=_opset_version):
         """
-        get the concat operator from onnx node
+        get the singa operator from an onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
+            onnx_node (OnnxNode): a given onnx node
+            opset_version (int): the opset version
         Returns: 
-            the handle of singa operator
-        Returns: 
-            the autograd of singa operator
-        """
-        factor = onnx_node.getattr('axis', 1)
-        if factor < 0:
-            # in order to support the negative axis
-            factor = len(inputs[0].shape) + factor
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return None, forward(axis=factor)
-
-    @classmethod
-    def _create_gemm(cls, onnx_node, inputs, opset_version):
-        """
-        get the gemm operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            the handle of singa operator
-        Returns: 
-            the autograd of singa operator
-        """
-        x = inputs[0]
-        alpha = onnx_node.getattr('alpha', 1.)
-        beta = onnx_node.getattr('beta', 1.)
-        transA = onnx_node.getattr('transA', 0)
-        transB = onnx_node.getattr('transB', 0)
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return None, forward(alpha=alpha,
-                             beta=beta,
-                             transA=transA,
-                             transB=transB)
-
-    @classmethod
-    def _create_flatten(cls, onnx_node, inputs, opset_version):
-        """
-        get the flatten operator from onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            the handle of singa operator
-        Returns: 
-            the autograd of singa operator
-        """
-        factor = onnx_node.getattr('axis', 1)
-        if factor < 0:
-            # in order to support the negative axis
-            factor = len(inputs[0].shape) + factor
-
-        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
-                                                       opset_version)
-        return None, forward(axis=factor)
-
-    @classmethod
-    def _common_onnx_node_to_singa_op(cls, onnx_node, inputs, opset_version):
-        """
-        get a common singa operator(only autograd) from a onnx node
-        other special operators also can call this func to get autograd
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            tensor_map: the input tensor
-        Args:
-            opset_version: the opset version
-        Returns: 
-            a dict of tensors
-        Returns: 
-            a list of SingaOps('name', 'op', 'handle', 'forward')
+            singa operator instance
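+
+        Example (an illustrative sketch, not part of the API contract; assumes
+        `relu_node` is an OnnxNode wrapping a supported Relu NodeProto and `x`
+        is a singa Tensor on the target device):
+            >>> op = SingaBackend._onnx_node_to_singa_op(relu_node)
+            >>> y = op(x)  # operators are callable, as in _run_node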
         """
         onnx_op_type = onnx_node.op_type
         assert onnx_op_type in cls._rename_operators, "not support operator: {}".format(
             onnx_op_type)
-        autograd_op = getattr(autograd, cls._rename_operators[onnx_op_type])
-        return None, autograd_op
-
-    @classmethod
-    def _onnx_node_to_singa_op(cls,
-                               onnx_node,
-                               inputs,
-                               opset_version=_known_opset_version):
-        """
-        get a singa operator(handle and autograd) from a onnx node
-        Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input list
-        Args:
-            opset_version: the opset version
-        Returns: 
-            a dict of tensors
-        Returns: 
-            a list of SingaOps('name', 'op', 'handle', 'forward')
-        """
+        renamed_op = cls._rename_operators[onnx_op_type]
+        if renamed_op.startswith('layer.'):
+            op_class = getattr(layer, renamed_op[6:])
+        else:
+            op_class = getattr(autograd, renamed_op)
         if onnx_node.op_type in cls._special_operators:
             translator = getattr(cls, cls._special_operators[onnx_node.op_type])
+            op = translator(onnx_node, op_class, opset_version)
+            op.name = onnx_node.name
         else:
-            translator = cls._common_onnx_node_to_singa_op
-        return translator(onnx_node, inputs, opset_version)
+            op = op_class()
+        # refine the ONNXNode: drop empty (optional) input names
+        onnx_node.inputs = [inp for inp in onnx_node.inputs if inp != '']
+        return op
 
     @classmethod
-    def run_node(cls, onnx_node, inputs, opset_version=_known_opset_version):
+    def run_node(cls, node, inputs, device='CPU', opset_version=_opset_version):
         """
         run a single singa operator from a onnx node
         Args:
-            onnx_node: a given onnx node
-        Args:
-            inputs: the input tensor
-        Args:
-            device: the used device
-        Args:
-            opset_version: the opset version
-        Returns: 
-            list, the output of the 
+            node (NodeProto): a given onnx node
+            inputs (ndarray[]): a list of numpy ndarray
+            device (string): CPU or CUDA
+            opset_version (int): the opset version
+        Returns:
+            an OrderedDict mapping each output name to its value
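+
+        Example (an illustrative sketch; assumes `node_proto` is a NodeProto
+        for a supported op such as Relu):
+            >>> import numpy as np
+            >>> x = np.random.randn(2, 3).astype(np.float32)
+            >>> outs = SingaBackend.run_node(node_proto, [x], device='CPU')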
         """
-        valid_inputs = [x for x in onnx_node.inputs if x != ""]
+        node = OnnxNode(node)
+        valid_inputs = [x for x in node.inputs if x != ""]
         assert len(valid_inputs) == len(
-            inputs), "{}: expected {} but got {}".format(
-                onnx_node.op_type, len(valid_inputs), len(inputs))
+            inputs), "{}: expected {} inputs, but got {}. ".format(
+                node.op_type, len(valid_inputs), len(inputs))
 
-        tmp_inputs = [inputs[x] for x in onnx_node.inputs if x != ""]
-        handle, forward = cls._onnx_node_to_singa_op(onnx_node, tmp_inputs,
-                                                     opset_version)
-        # only give the inputs it needs
-        # consumed_inputs are the inputs marked as attributes
-        # so we remove it here
-        tmp_inputs = [
-            inputs[x]
-            for x in onnx_node.inputs
-            if x not in onnx_node.consumed_inputs
-        ]
-        return cls._run_node(onnx_node, tmp_inputs, handle, forward,
-                             opset_version)
-
-    @classmethod
-    def _run_node(cls,
-                  onnx_node,
-                  inputs,
-                  handle,
-                  forward,
-                  opset_version=_known_opset_version):
-        """
-        run a single singa operator from a onnx node
-        Args:inputs: 
-            the input tensor
-        Args:handle: 
-            the handle of singa operator
-        Args:forward: 
-            the forward of singa operator
-        Args:
-            opset_version: the opset version
-        Returns: 
-            list, the output of the
-        """
-        outputs = forward(*inputs) if handle is None else forward(
-            handle, *inputs)
-        if not isinstance(outputs, collections.Iterable):
-            outputs = [outputs]
+        operator = cls._onnx_node_to_singa_op(node, opset_version)
+        # separate the weights from the inputs, and init the inputs as Tensor
+        weights = {}
+        _inputs = []
+        for (key, val) in zip(valid_inputs, inputs):
+            val = val.astype(onnx_type_to_singa_type(val.dtype))
+            if key in node.weight_inputs:
+                weights[key] = val
+            else:
+                x = tensor.from_numpy(val)
+                if device != 'CPU':
+                    assert singa.USE_CUDA, "Your SINGA is not compiled with CUDA support."
+                    # `device` is a string flag here, so use the module-level device handles
+                    dev = gpu_dev
+                else:
+                    dev = cpu_dev
+                x.to_device(dev)
+                _inputs.append(x)
+        inputs = _inputs
+        # set params
+        params = {}
+        for key, name in node.weight_inputs.items():
+            params[name] = weights[key]
+        operator.set_params(params)
+        outputs = cls._run_node(operator, inputs)
         outputs_dict = OrderedDict()
-        for (key, val) in zip(onnx_node.outputs, outputs):
+        for (key, val) in zip(node.outputs, outputs):
             outputs_dict[key] = val
         return outputs_dict
 
     @classmethod
-    def _init_graph_parameter(cls, graph, init_inputs, device):
+    def _run_node(cls, operator, inputs):
         """
-        init the singa tensor from onnx infos
+        run a single singa operator with the given inputs
         Args:
-            graph: a given onnx graph
-        Args:
-            init_inputs: a list of inputs, which used to init the operators
-        Args:
-            device: the used device
+            operator (Operator): the Operator instance
+            inputs (Tensor[]): a list of SINGA Tensor
         Returns:
-            a dict of tensors
+            list, the output
         """
-        tensor_map = {}
-        # due to https://github.com/onnx/onnx/issues/2417
-        # sometimes, input contains all initializer's info
-        # sometimes, may not
-        all_inputs = OrderedDict()
+        outputs = operator(*inputs)
+        if not isinstance(outputs, collections.Iterable):
+            outputs = [outputs]
+        return outputs
+
+    @classmethod
+    def _parse_graph_params(cls, graph, device):
+        """
+        parse the parameters from onnx graph
+        Args:
+            graph (Graph): a given onnx graph
+            device (string): CPU or CUDA
+        Returns:
+            a dict of numpy ndarray
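+
+        Example (an illustrative sketch; assumes `model` is a loaded ModelProto):
+            >>> params = SingaBackend._parse_graph_params(model.graph, 'CPU')
+            >>> names = sorted(params.keys())  # names of the graph initializers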
+        """
+        params = {}
+        for tp in graph.initializer:
+            val = numpy_helper.to_array(tp)
+            val = val.astype(onnx_type_to_singa_type(tp.data_type))
+            params[tp.name] = val
+        return params
+
+    @classmethod
+    def _parse_graph_inputs_outputs(cls, graph, params, device):
+        """
+        parse the inputs and outputs from the onnx graph
+        Args:
+            graph (Graph): a given onnx graph
+            params (dict): the parsed params, used to exclude weight inputs
+            device (string): CPU or CUDA
+        Returns:
+            a list of input info tuples (name, dtype, shape)
+            a list of output info tuples (name, dtype, shape)
+        """
+        inputs = []
+        outputs = []
+        info_tuple = namedtuple('info_tuple', ['name', 'dtype', 'shape'])
         for t in graph.input:
-            all_inputs[t.name] = t
-        # so we refresh the input by the initializer
-        for t in graph.initializer:
-            all_inputs[t.name] = t
-        initializers = {t.name for t in graph.initializer}
-        inp_idx = 0
-        for name, x in all_inputs.items():
-            if name in initializers:
-                # if it has initializer, we use its value as the input
-                np_tensor = numpy_helper.to_array(x)
-                if np_tensor.dtype == "int64":
-                    np_tensor = np_tensor.astype(np.int32)
-                # todo, we cannot support scalar tensor
-                if np.ndim(np_tensor) == 0:
-                    np_tensor = np.array(np_tensor, ndmin=1)
-            else:
-                # if not, means it's a input rather than a inner weight
-                # so if the user gives values, we use these values
-                # if not, we just use the shape of input gived by onnx to init a random value
-                # HOWEVER, the random value may not be correct for some inputs, such as gather which needs indices
-                # so if have operators, the user must give inputs
-                x_shape = tuple(
-                    dim.dim_value for dim in x.type.tensor_type.shape.dim)
-                if init_inputs is not None:
-                    np_tensor = init_inputs[inp_idx]
-                    inp_idx += 1
-                else:
-                    np_tensor = np.random.randn(*x_shape).astype(np.float32)
-            tmp_tensor = tensor.from_numpy(np_tensor)
-            tmp_tensor.to_device(device)
-            # todo, for backward
-            tmp_tensor.stores_grad = (name in initializers)
-            tensor_map[x.name] = tmp_tensor
-        return tensor_map
+            if t.name not in params:
+                dtype = t.type.tensor_type.elem_type
+                shape = [dim.dim_value for dim in t.type.tensor_type.shape.dim]
+                inputs.extend([info_tuple(t.name, dtype, shape)])
+        for t in graph.output:
+            dtype = t.type.tensor_type.elem_type
+            shape = [dim.dim_value for dim in t.type.tensor_type.shape.dim]
+            outputs.extend([info_tuple(t.name, dtype, shape)])
+        return inputs, outputs
 
     @classmethod
-    def _onnx_model_to_singa_net(cls, model, init_inputs, device,
-                                 opset_version):
+    def _onnx_model_to_singa_ops(cls,
+                                 graph,
+                                 device,
+                                 opset_version=_opset_version):
         """
-        get all intermediate tensors and operators from onnx model
+        get all params, operators, and input/output info from the onnx graph
         Args:
-            model: a given onnx model
-        Args:
-            init_inputs: a list of inputs, which used to init the operators
-        Args:
-            device: the used device
-        Args:
-            opset_version: the opset version
+            graph (Graph): the loaded ONNX graph
+            device (string): CPU or CUDA
+            opset_version (int): the opset version
         Returns:
-            a dict of tensors
-        Returns:
-            a list of SingaOps('name', 'op', 'handle', 'forward')
+            a dict of params (numpy ndarray)
+            a list of input info tuples
+            a list of output info tuples
+            a list of operator_tuple('node', 'operator')
         """
-        # init all tensor input and weight as a tensor map
-        tensor_map = cls._init_graph_parameter(model.graph, init_inputs, device)
-        # only weights tensor
-        weights = {x.name: tensor_map[x.name] for x in model.graph.initializer}
+        # parse the graph params and the input/output info
+        params = cls._parse_graph_params(graph, device)
+        inputs, outputs = cls._parse_graph_inputs_outputs(graph, params, device)
         # the parsed operators queue
-        singa_ops = []
-        singa_op = namedtuple('SingaOps', ['name', 'op', 'handle', 'forward'])
-        for node in model.graph.node:
+        operators = []
+        operator_tuple = namedtuple('operator_tuple', ['node', 'operator'])
+        for node in graph.node:
+            if not node.name:
+                node.name = "%s_%d" % (str(node.op_type), len(operators))
             node = OnnxNode(node)
-            # only give the inputs it needs
-            # consumed_inputs are the inputs marked as attributes
-            # so we remove it here
-            inputs = [
-                tensor_map[x]
-                for x in node.inputs
-                if x not in node.consumed_inputs
-            ]
-            handle, forward = cls._onnx_node_to_singa_op(
-                node, inputs, opset_version)
-            # if it is Constant, we hanlde it as a weight
-            # otherwise, we run it and add its output into map for being used by later operators
+            # convert Constant to param
             if node.op_type == 'Constant':
-                tmp_tensor = tensor.from_numpy(forward)
-                tmp_tensor.to_device(device)
-                tmp_name = node.outputs.pop(0)
-                weights[tmp_name] = tmp_tensor
-                tensor_map[tmp_name] = tmp_tensor
+                params[node.outputs[0]] = cls._onnx_constant_to_np(node)
             else:
-                outputs = cls._run_node(node, inputs, handle, forward)
-                for key, val in outputs.items():
-                    tensor_map[key] = val
-                singa_ops.extend([singa_op(node.name, node, handle, forward)])
-        return weights, singa_ops
+                op = cls._onnx_node_to_singa_op(node, opset_version)
+                operators.append(operator_tuple(node, op))
+        return params, inputs, outputs, operators
 
     @classmethod
-    def prepare(cls, model, device, **kwargs):
+    def prepare(cls, model, device='CPU', **kwargs):
         """
-        get the batch norm operator from onnx node
+        parse the ONNX model and create the layers
         Args:
-            model: a given onnx node
-        Args:
-            device: the used device
-        Returns: 
-            a list of output values
+            model (ModelProto): the loaded ONNX model
+            device (string): CPU or CUDA
+        Returns:
+            a SingaRep instance that stores the layers and weights
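+
+        Example (an illustrative sketch; 'model.onnx' and the numpy input `x`
+        are placeholders for a real model and its input):
+            >>> import onnx
+            >>> rep = SingaBackend.prepare(onnx.load('model.onnx'), device='CPU')
+            >>> y = rep.run([x])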
         """
         super(SingaBackend, cls).prepare(model, device, **kwargs)
-        # when parsing graph, we use the shape of input gived by onnx to init a random value
-        # HOWEVER, the random value may not be correct for some inputs, such as gather which needs indices
-        # so if have operators, the user must give inputs
-        init_inputs = kwargs.get("init_inputs", None)
-        # whether initializers are moved into inputs, due to https://github.com/onnx/onnx/issues/2417
-        # sometimes, input contains all initializer's info, sometimes, may not
-        cls.keep_initializers_as_inputs = kwargs.get(
-            'keep_initializers_as_inputs', True)
         # optimize and infer the shape of the model
         try:
             model = onnx.utils.polish_model(model)
         except IndexError as err:
-            # due to https://github.com/onnx/onnx/issues/2417
-            model = onnx.shape_inference.infer_shapes(model)
+            model = shape_inference.infer_shapes(model)
 
         # check the opset version and ir version
+        # SINGA supports opset version 11 and IR version 6 (ONNX 1.6.0)
         opset_version = None
         for imp in model.opset_import:
             if not imp.HasField("domain") or imp.domain == "":
                 opset_version = imp.version
-                if imp.version > cls._known_opset_version:
+                if imp.version > cls._opset_version:
                     warnings.warn(
-                        "This version of singa targets ONNX operator set version {}, but the model we are trying to import uses version {}.  We will try to import it anyway, but if the model uses operators which had BC-breaking changes in the intervening versions, import will fail."
-                        .format(cls._known_opset_version, imp.version))
+                        "The imported opertor set verion {} is larger than the supported version {}."
+                        .format(imp.version, cls._opset_version))
             else:
                 warnings.warn("Unrecognized operator set {}".format(imp.domain))
-        if opset_version is None:
-            if model.ir_version >= 0x00000003:
-                raise RuntimeError(
-                    "Model with IR version >= 3 did not specify ONNX operator set version (singa requires it)"
-                )
-            else:
-                opset_version = 1
-        weights, singa_ops = cls._onnx_model_to_singa_net(
-            model, init_inputs, device, opset_version)
-        return SingaRep(model, weights, singa_ops,
-                        cls.keep_initializers_as_inputs)
+
+        if model.ir_version > cls._ir_version:
+            warnings.warn(
+                "The imported ir verion {} is larger than the supported version {}."
+                .format(cls._ir_version, imp.version))
+
+        graph = model.graph
+        params, inputs, outputs, layers = cls._onnx_model_to_singa_ops(
+            graph, device, opset_version)
+        return SingaRep(params, inputs, outputs, layers, device)
 
 
 class SingaRep(BackendRep):
 
-    def __init__(self,
-                 model,
-                 weights,
-                 singa_ops,
-                 keep_initializers_as_inputs=True):
+    def __init__(self, params, inputs, outputs, layers, device):
         """
+        https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
         SingaRep provides the intermediate representation of Singa,
         the user can run the forward of the singa model by run func,
         or, the user can append more layers after the singa_ops to do
         the transfer learning
         Args:
-            model: a given operator
-        Args:
-            weights: the tensor of weights
-        Args:
-            singa_ops: the tensor of the operator
+            params (dict{}): a dict of params, data type is numpy ndarray
+            inputs (ValueInfo): a dict of inputs
+            outputs (ValueInfo): a dict of outputs
+            layers (namedtuple('operator_tuple', ['node', 'operator'])[]): a list of (node, operator) tuples
+            device (string): CPU or CUDA
         """
         super(SingaRep, self).__init__()
-        self.model = model
-        self.tensor_map = weights
-        self.keep_initializers_as_inputs = keep_initializers_as_inputs
-        # this each item of singa_ops is: ('name', 'op', 'handle', 'forward')
-        # the name is a string, op is OnnxNode,
-        # handle is Singa handle to store the tensor into singa operator
-        # the forward is singa autograd operator
-        self.singa_ops = singa_ops
+        self.inputs = inputs
+        self.states = params
+        self.outputs = outputs
+        self.dev = cpu_dev if device == "CPU" else gpu_dev
+        self.layers = layers
+        self.tensor_count = {}
+        self.has_initialized = False
+        self.is_graph = False
 
-    def run(self, inputs, **kwargs):
+    def initialize(self):
+        """
+        Init the instance
+        """
+        self.outputs_info = {outp.name: outp for outp in self.outputs}
+        _layers = []  # layers by topo order
+        for node, operator in self.layers:
+            _del_keys = []
+            for key, name in node.weight_inputs.items():
+                if key not in self.states:
+                    # cannot find the weight, so try to get it from the inputs
+                    node.set_attr_inputs(key, name)
+                    _del_keys.append(key)
+            for key in _del_keys:
+                node.del_weight_inputs(key)
+            self.__dict__[node.name] = operator
+            _layers.append(node)
+        self._layers = _layers
+
+    def init_tensor_count(self):
+        """
+        Init the tensor count dict
+        """
+        self.tensor_count = {}
+        for node, operator in self.layers:
+            # init the tensor count
+            all_possible_inputs = node.inputs + list(
+                node.attr_inputs.keys()) + list(node.weight_inputs.keys())
+            for inp in all_possible_inputs:
+                if inp not in self.tensor_count:
+                    self.tensor_count[inp] = 1
+                else:
+                    self.tensor_count[inp] += 1
+
+    def to_input_tensor(self, x):
+        """
+        convert the input to tensors
+        Args:
+            x (np.ndarray[]): a list of numpy ndarray as inputs
+        Returns: 
+            a dict of SINGA Tensors
+        """
+        tensor_dict = {}
+        # init inputs as Tensor
+        for (key, val) in zip(self.inputs, x):
+            if not self.is_graph:
+                val = val.astype(onnx_type_to_singa_type(key.dtype))
+                # todo, scalar
+                val = np.atleast_1d(val)
+                val = tensor.from_numpy(val)
+                val.to_device(self.dev)
+            tensor_dict[key.name] = val
+        return tensor_dict
+
+    def to_output_tensor(self, y, out_name):
+        """
+        convert an output tensor back to a numpy array of the expected dtype
+        Args:
+            y (Tensor): the output SINGA Tensor
+            out_name (str): the name of the output
+        Returns: 
+            a numpy ndarray (or the Tensor itself in graph mode)
+        """
+        if not self.is_graph:
+            y = tensor.to_numpy(y)
+            if out_name in self.outputs_info:
+                np_dtyp = mapping.TENSOR_TYPE_TO_NP_TYPE[
+                    self.outputs_info[out_name].dtype]
+                y = y.astype(np_dtyp)
+        return y
+
+    def get_s(self, name, node, tensor_dict):
+        """
+        get state from the node's weights or tensor_dict
+        Args:
+            name (str): name of the state
+            node (ONNXNode): ONNX node
+            tensor_dict ({}): tensor dict
+        Returns: 
+            the states
+        """
+        if name in node.attr_inputs:
+            return tensor_dict[name]
+        else:
+            return self.states[name]
+
+    def handle_special_ops(self, node, op, tensor_dict):
+        """
+        handle some special operations that need the input shapes
+        Args:
+            node (ONNXNode): ONNX node
+            op (Operator): the corresponding singa operator
+            tensor_dict ({}): tensor dict
+        Returns: 
+            None, the operator is updated in place
+        """
+        # todo, hard code
+        # Conv2d nb_kernels
+        if node.op_type == "Conv":
+            shape = self.get_s(node.inputs[1], node, tensor_dict).shape
+            op.nb_kernels = shape[0]
+        # Gemm nb_kernels and bias_shape
+        elif node.op_type == "Gemm":
+            nb_kernels_flag = 0 if op.transB == 1 else -1
+            shape = self.get_s(node.inputs[1], node, tensor_dict).shape
+            op.nb_kernels = shape[nb_kernels_flag]
+            if op.bias:
+                shape = self.get_s(node.inputs[2], node, tensor_dict).shape
+                op.bias_shape = shape
+
+    def run(self, x, **kwargs):
         """
         run the forward of singa model
         Args:
-            inputs: a given operator
+            x (np.ndarray[]): a list of numpy ndarray as inputs
         Returns: 
-            the onnx node
+            a list of outputs
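+
+        Example (an illustrative sketch; assumes `rep` was returned by `prepare`,
+        `x` is a numpy input, and 'hidden_out' names an intermediate tensor):
+            >>> y = rep.run([x])
+            >>> y_and_aux = rep.run([x], aux_output=('hidden_out',))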
         """
-        graph = self.model.graph
+        if not self.has_initialized:
+            self.initialize()
+            if isinstance(x[0], tensor.Tensor):
+                self.dev = x[0].device
+
+        outputs_dict = OrderedDict([])
+
         # last_layers means we run this model until the last #N layers
-        last_layers = kwargs.get('last_layers', len(self.singa_ops))
-        if last_layers != len(self.singa_ops):
-            final_outputs = self.singa_ops[last_layers-1].op.outputs
+        last_layers = kwargs.get('last_layers', len(self._layers) - 1)
+        last_layers = last_layers if last_layers >= 0 else (
+            last_layers + 1) % len(self._layers)
+        if last_layers != len(self._layers) - 1:
+            for outp in self._layers[last_layers].outputs:
+                outputs_dict[outp] = None
         else:
-            final_outputs =  [outp.name for outp in graph.output]
-        # whether return all outputs
-        all_outputs = kwargs.get('all_outputs', False)
-        # get a specific op by its name
-        op_name = kwargs.get('op_name', None)
-        # record the tensor we added from input
-        tmp_tensor_map = {name: val for name, val in self.tensor_map.items()}
+            for outp in self.outputs:
+                outputs_dict[outp.name] = None
 
-        # the dict will be returned
-        ret_outputs = OrderedDict()
-        if self.keep_initializers_as_inputs:
-            require_input_len = len(graph.input) - len(graph.initializer)
-            actual_input_len = len(inputs)
-        else:
-            require_input_len = len(graph.input)
-            actual_input_len = len(inputs)
-        assert require_input_len == actual_input_len, "The length of graph input is different from the tensor input: %d, %d" % (
-            require_input_len, actual_input_len)
-        # run the handle by the order of the list(the list is Topological Sorting)
-        for inp in graph.input:
-            if inp.name not in tmp_tensor_map:
-                tmp_tensor_map[inp.name] = inputs.pop(0)
+        aux_output = kwargs.get('aux_output', ())
+        for outp in aux_output:
+            outputs_dict[outp] = None
 
-        for _, op, handle, forward in self.singa_ops[:last_layers]:
-            if len(op.consumed_inputs) != 0:
-                # because if op has consumed_inputs, it means it moved some inputs into attributes
-                # so when running, we should update these attributes
-                handle, forward = get_op(op,
-                                         [tmp_tensor_map[x] for x in op.inputs])
-            inputs = [
-                tmp_tensor_map[x]
-                for x in op.inputs
-                if x not in op.consumed_inputs
-            ]
-            outputs = _run_node(op, inputs, handle, forward)
-            for key, val in outputs.items():
-                tmp_tensor_map[key] = val
-                ret_outputs[key] = val
+        tensor_dict = self.to_input_tensor(x)
+        self.init_tensor_count()
 
-        if op_name is not None:
-            if op_name in outputs:
-                return outputs[op_name]
-            else:
-                raise RuntimeError(
-                    "The op_name {} does not exist, please check. The available op_names are: {}"
-                    .format(op_name, [val for key, val in op_name.items()]))
+        # run the layers in topological order
+        for node in self._layers[:last_layers + 1]:
+            op = self.__dict__[node.name]
+            self.handle_special_ops(node, op, tensor_dict)
+            # make input
+            inputs = []
+            for inp in node.inputs:
+                if inp not in node.weight_inputs and inp not in node.attr_inputs:
+                    if inp in tensor_dict:
+                        inputs.append(tensor_dict[inp])
+                    elif inp in self.states:
+                        # todo, scalar
+                        val = np.atleast_1d(self.states[inp])
+                        val = tensor.from_numpy(val)
+                        val.to_device(self.dev)
+                        inputs.append(val)
+                    else:
+                        raise KeyError(
+                            "Not found the input {} for operation {}".format(
+                                inp, node.name))
+            states = {}
+            if callable(getattr(op, "initialize",
+                                None)) and not op._initialized:
+                # init the operator
+                op.initialize(*inputs)
+                op._initialized = True
+                for key, name in node.weight_inputs.items():
+                    if key not in node.attr_inputs:
+                        # the weight exists and is not overridden by an attr input
+                        states[name] = self.states[key]
 
-        # return all outputs if all_outputs==True
-        # else return last outputs
-        if all_outputs:
-            return ret_outputs
-        else:
-            return [ret_outputs[outp] for outp in final_outputs]
+            # replace attrs by inputs
+            for key, name in node.attr_inputs.items():
+                if key in tensor_dict:
+                    ts = tensor_dict[key]
+                elif key in self.states:
+                    ts = self.states[key]
+                if isinstance(ts, tensor.Tensor):
+                    ts = tensor.to_numpy(ts)
+                states[name] = ts
+            # set states
+            if states:
+                if callable(getattr(op, "set_states", None)):
+                    # rename the layer's states
+                    states = {
+                        getattr(op, key).name: val
+                        for (key, val) in states.items()
+                    }
+                    if self.is_graph and not self.has_initialized:
+                        prev_state = self.dev.graph_enabled()
+                        self.dev.EnableGraph(False)
+                        op.set_states(states)
+                        self.dev.EnableGraph(prev_state)
+                    else:
+                        op.set_states(states)
+                else:
+                    for key, value in states.items():
+                        setattr(op, key, value)
+            # run the node
+            outputs = _run_node(op, inputs)
+            # release the input tensor
+            for inp in node.inputs:
+                if inp in self.tensor_count:
+                    self.tensor_count[inp] -= 1
+                if self.tensor_count[inp] == 0:
+                    if inp in tensor_dict:
+                        del tensor_dict[inp]
+                    del self.tensor_count[inp]
+            # store the output
+            for (outp, val) in zip(node.outputs, outputs):
+                tensor_dict[outp] = val
+                if outp in outputs_dict:
+                    outputs_dict[outp] = self.to_output_tensor(val, outp)
+        self.has_initialized = True
+        return list(outputs_dict.values())
+
+
+class SONNXModel(model.Model):
+
+    def __init__(self, onnx_model):
+        """
+        Init a SINGA Model
+        Args:
+            onnx_model (ModelProto): a loaded onnx model
+        """
+        super(SONNXModel, self).__init__()
+        self.sg_ir = prepare(onnx_model)
+        for node, operator in self.sg_ir.layers:
+            self.__dict__[node.name] = operator
+        self.sg_ir.is_graph = True
+
+    def forward(self, *input, aux_output=(), **kwargs):
+        """
+        The forward of the SINGA model
+        Args:
+            input (Tensor[]): a list of input Tensors
+            aux_output (string()): a set of required output names
+
+        Returns:
+            a list of output Tensors
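+
+        Example (an illustrative sketch of the intended subclassing pattern;
+        assumes `onnx_model` is a loaded ModelProto):
+            >>> class MyModel(SONNXModel):
+            ...     def __init__(self, onnx_model):
+            ...         super(MyModel, self).__init__(onnx_model)
+            ...     def forward(self, *x):
+            ...         return super(MyModel, self).forward(*x)[0]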
+        """
+        return self.sg_ir.run(input, aux_output=aux_output, **kwargs)
 
 
 run_node = SingaBackend.run_node
@@ -2141,4 +2226,4 @@
 get_op = SingaBackend._onnx_node_to_singa_op
 to_onnx = SingaFrontend.singa_to_onnx_model
 save = onnx.save
-load = onnx.load
+load = onnx.load
\ No newline at end of file
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 08eee75..4f62a31 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -59,13 +59,13 @@
 from builtins import object
 import numpy as np
 from functools import reduce
+import re
 
-from .proto import core_pb2
 from . import singa_wrap as singa
 from .device import get_default_device
 
-int32 = core_pb2.kInt
-float32 = core_pb2.kFloat32
+int32 = 2  #core.proto.kInt32
+float32 = 0  #core.proto.kFloat32
 CTensor = singa.Tensor
 
 
@@ -155,6 +155,17 @@
 
         return ret
 
+    def is_dummy(self):
+        '''
+        Returns:
+            True if the tensor is a dummy tensor
+        '''
+        match = re.match(r'Dummy#\d+', self.name)
+        if match:
+            return True
+        else:
+            return False
+
     def ndim(self):
         '''
         Returns:
@@ -212,6 +223,11 @@
         '''
         return self.data.MemSize()
 
+    def contiguous(self):
+        '''
+        Returns:
+            a new Tensor whose data is a contiguous copy (via singa.Contiguous)
+        '''
+        t = Tensor(self.shape, self.device, self.dtype)
+        t.data = singa.Contiguous(self.data)
+        return t
+
     def reshape(self, shape):
         '''Return a new tensor with the given shape, and the original
             tensor is not changed.
@@ -336,9 +352,24 @@
         Args:
             t (Tensor): source Tensor.
         '''
+        assert (t.size() == self.size()), "tensor sizes should be the same"
         assert isinstance(t, Tensor), 't must be a singa Tensor instance'
         self.data.CopyData(t.data)
 
+    def copy_from(self, t, offset=0):
+        ''' Copy the data from the numpy array or other Tensor instance
+
+        Args:
+            t (Tensor or np array): source Tensor or numpy array
+            offset (int): destination offset
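+
+        Example (an illustrative sketch):
+            >>> import numpy as np
+            >>> t = Tensor((2, 3))
+            >>> t.copy_from(np.ones((2, 3), dtype=np.float32))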
+        '''
+        if isinstance(t, Tensor):
+            self.copy_data(t)
+        elif isinstance(t, np.ndarray):
+            self.copy_from_numpy(t)
+        else:
+            raise ValueError("t should be Tensor or numpy array.")
+
     def clone(self):
         '''
         Returns:
@@ -638,6 +669,14 @@
         else:
             return _call_singa_func(singa.DivFloat, self.data, rhs)
 
+    def __floordiv__(self, rhs):
+        if isinstance(rhs, Tensor):
+            tmp = from_raw_tensor(singa.__div__(self.data, rhs.data))
+            return _call_singa_func(singa.Floor, tmp.data)
+        else:
+            tmp = _call_singa_func(singa.DivFloat, self.data, rhs)
+            return _call_singa_func(singa.Floor, tmp.data)
+
     def __lt__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__lt__(self.data, rhs.data))
@@ -662,6 +701,14 @@
         else:
             return _call_singa_func(singa.GEFloat, self.data, rhs)
 
+    def __eq__(self, rhs):
+        if isinstance(rhs, Tensor):
+            return from_raw_tensor(singa.__eq__(self.data, rhs.data))
+        elif rhs is None:
+            return False
+        else:
+            return _call_singa_func(singa.EQFloat, self.data, rhs)
+
     def __radd__(self, lhs):
         lhs = float(lhs)
         one = Tensor(self.shape, self.device, self.dtype)
@@ -701,6 +748,9 @@
         return np.array2string(to_numpy(self))
 
 
+''' alias Tensor to PlaceHolder
+'''
+PlaceHolder = Tensor
 ''' python functions for global functions in Tensor.h
 '''
 
@@ -746,6 +796,10 @@
     return singa.SizeOf(dtype)
 
 
+def contiguous(tensor):
+    return _call_singa_func(singa.Contiguous, tensor.data)
+
+
 def reshape(tensor, shape):
     '''Reshape the input tensor with the given shape and
     the original tensor is not changed
@@ -789,7 +843,7 @@
     singa.CopyDataToFrom(dst.data, src.data, size, dst_offset, src_offset)
 
 
-def from_numpy(np_array):
+def from_numpy(np_array, dev=None):
     '''Create a Tensor instance with the shape, dtype and values from the numpy
     array.
 
@@ -808,13 +862,15 @@
         np_array = np_array.astype(np.int32)
 
     if np_array.dtype == np.float32:
-        dtype = core_pb2.kFloat32
+        dtype = float32
     else:
         assert np_array.dtype == np.int32, \
             'Only float and int tensors are supported'
-        dtype = core_pb2.kInt
+        dtype = int32
     ret = Tensor(np_array.shape, dtype=dtype)
     ret.copy_from_numpy(np_array)
+    if dev:
+        ret.to_device(dev)
     return ret
 
 
@@ -842,9 +898,9 @@
         a numpy array
     '''
     th = to_host(t)
-    if th.dtype == core_pb2.kFloat32:
+    if th.dtype == float32:
         np_array = th.data.GetFloatValue(int(th.size()))
-    elif th.dtype == core_pb2.kInt:
+    elif th.dtype == int32:
         np_array = th.data.GetIntValue(int(th.size()))
     else:
         print('Not implemented yet for ', th.dtype)
@@ -1124,6 +1180,20 @@
     return t >= x
 
 
+def eq(t, x):
+    '''Element-wise comparison for t == x.
+
+    Args:
+        t (Tensor): left hand side operand
+        x (Tensor or float): right hand side operand
+
+    Returns:
+        a Tensor with each element being t[i] == x ? 1.0f:0.0f,
+        or t[i] == x[i] ? 1.0f:0.0f
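+
+    Example (an illustrative sketch):
+        >>> import numpy as np
+        >>> from singa import tensor
+        >>> a = tensor.from_numpy(np.array([1., 2., 3.], dtype=np.float32))
+        >>> mask = tensor.eq(a, 2.)  # 1.0 where a[i] == 2.0, else 0.0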
+    '''
+    return t == x
+
+
 def add(lhs, rhs, ret=None):
     '''Elementi-wise addition.
 
@@ -1705,3 +1775,30 @@
     for t in tensors:
         ctensors.append(t.data)
     return _call_singa_func(singa.ConcatOn, ctensors, axis)
+
+
+def random(shape, device=get_default_device()):
+    ''' return a random tensor with given shape
+
+    Args:
+        shape: shape of generated tensor
+        device: device of generated tensor, default is cpu
+
+    Returns:
+        new tensor generated
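+
+    Example (an illustrative sketch):
+        >>> from singa import tensor
+        >>> t = tensor.random((2, 3))  # uniform samples between 0 and 1 on the default device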
+    '''
+    ret = Tensor(shape, device=device)
+    ret.uniform(0, 1)
+    return ret
+
+
+def zeros(shape, device=get_default_device()):
+    ret = Tensor(shape, device=device)
+    ret.set_value(0.0)
+    return ret
+
+
+def ones(shape, device=get_default_device()):
+    ret = Tensor(shape, device=device)
+    ret.set_value(1.0)
+    return ret
diff --git a/python/singa/utils.py b/python/singa/utils.py
index 78c9f2c..d102771 100644
--- a/python/singa/utils.py
+++ b/python/singa/utils.py
@@ -20,7 +20,6 @@
 import numpy as np
 import collections
 
-from singa import tensor
 from . import singa_wrap as singa
 
 OrderedDict = collections.OrderedDict
@@ -55,40 +54,44 @@
     sys.stdout.flush()
 
 
-def handle_odd_pad_fwd(x, odd_padding):
+def handle_odd_pad_fwd(x, odd_padding, is_pool=False):
     """
     handle odd padding mode forward
-    Args:x
-        the input tensor
-    Args:odd_padding
-        the odd_padding
+    Args:
+        x, the input tensor
+        odd_padding, the odd_padding
+        is_pool, if True, pad by replicating edge slices instead of zeros
     Returns: 
         tensor, the output
     """
-    x_tensor = tensor.from_raw_tensor(x)
     # (axis, left padding if True else right padding)
     flags = [(2, True), (2, False), (3, True), (3, False)]
     for (axis, left), pad in zip(flags, odd_padding):
         if pad == 0:
             continue
-        zeros_shape = list(x_tensor.data.shape())
-        zeros_shape[axis] = pad
-        zero_padding = np.zeros(zeros_shape).astype(np.float32)
-        zero_padding = tensor.Tensor(device=x.device(), data=zero_padding)
-        if left:
-            x_tensor = tensor.concatenate((zero_padding, x_tensor), axis)
+        if is_pool:
+            if left:
+                padding = singa.SliceOn(x, 0, pad, axis)
+            else:
+                axis_shape = list(x.shape())[axis]
+                padding = singa.SliceOn(x, axis_shape - pad, axis_shape, axis)
         else:
-            x_tensor = tensor.concatenate((x_tensor, zero_padding), axis)
-    return x_tensor.data
+            pad_shape = list(x.shape())
+            pad_shape[axis] = pad
+            padding = singa.Tensor(list(pad_shape), x.device())
+            padding.SetFloatValue(0.)
+        if left:
+            x = singa.ConcatOn(singa.VecTensor([padding, x]), axis)
+        else:
+            x = singa.ConcatOn(singa.VecTensor([x, padding]), axis)
+    return x
 
 
 def handle_odd_pad_bwd(dx, odd_padding):
     """
     handle odd padding mode backward
-    Args:dx
-        the backward tensor
-    Args:odd_padding
-        the odd_padding
+    Args:
+        dx, the backward tensor
+        odd_padding, the odd_padding
     Returns: 
         tensor, the output
     """
@@ -108,12 +111,10 @@
 def same_pad_shape_check(handle, pad_mode, x):
     """
     check the shape is correct for same padding mode
-    Args:handle
-        the handle
-    Args:pad_mode
-        pad_mode
-    Args:x
-        input tensor
+    Args:
+        handle, the handle
+        pad_mode, pad_mode
+        x, the input tensor
     Returns: 
         tuple, the correct padding(before divide 2)
     """
@@ -132,10 +133,9 @@
 def re_new_handle(handle, x, is_pool=False):
     """
     renew a handle by using the new input tensor
-    Args:handle
-        the handle
-    Args:x
-        input tensor
+    Args:
+        handle, the handle
+        x, input tensor
     Returns: 
         handle, a new handle
     """
@@ -163,9 +163,7 @@
     return padding shape of conv2d or pooling,
     Args:
         pad_mode: string
-    Args:
         kernel_spatial_shape: list[int]
-    Args:
         strides_spatial: list[int]
     Returns: 
         list[int]
@@ -196,11 +194,9 @@
     ! borrow from onnx
     Args:
         auto_pad: string
-    Args:
+        input_spatial_shape: list[int]
         kernel_spatial_shape: list[int]
-    Args:
         strides_spatial: list[int]
-    Args:
         output_spatial_shape: list[int]
     Returns: 
         list[int
@@ -241,19 +237,18 @@
     return a list by the topological ordering (postorder of Depth-first search)
     Args:
         root: singa operator
-    Args:
         root_t: tensor
     Returns: 
         deque[int]
     """
 
-    def recursive(root, yid, root_t, nodes, weights, inputs):
+    def recursive(root, yid, root_t):
         if root:
             # srcop: operator for a input of root
             # yid: id(output of this operator)
             # y: output of this operator
             for srcop, yid, y, _ in root.src:
-                recursive(srcop, yid, y, nodes, weights, inputs)
+                recursive(srcop, yid, y)
 
             if type(root).__name__ == 'Dummy':
                 if root_t != None:
@@ -269,5 +264,5 @@
     weights = OrderedDict()
     inputs = OrderedDict()
 
-    recursive(root, None, root_t, nodes, weights, inputs)
+    recursive(root, None, root_t)
     return nodes, weights, inputs
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..1abc48c
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,440 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+'''Script for building wheel package for installing singa via pip.
+
+This script must be launched at the root dir of the singa project 
+inside the docker container created via tool/docker/devel/centos/cudaxx/Dockerfile.manylinux2014.
+
+    # launch docker container
+    $ nvidia-docker run -v <local singa dir>:/root/singa -it apache/singa:manylinux2014-cuda10.2
+    # build the wheel package; replace cp36-cp36m to compile singa for other py version
+    $ /opt/python/cp36-cp36m/bin/python setup.py bdist_wheel
+    $ /opt/python/cp37-cp37m/bin/python setup.py bdist_wheel
+    $ /opt/python/cp38-cp38/bin/python setup.py bdist_wheel
+
+The generated wheel file should be repaired by the auditwheel tool to make it 
+compatible with PEP513. Otherwise, the dependent libs will not be included in
+the wheel package and the wheel file will be rejected by PYPI website during
+uploading due to file name error.
+
+    # repair the wheel package and upload to pypi
+    $ /opt/python/cp36-cp36m/bin/python setup.py audit
+
+For the Dockerfile with CUDA and CUDNN installed, the CUDA version and 
+CUDNN version are exported as environment variable: CUDA_VERSION, CUDNN_VERSION.
+You can control the script to build CUDA enabled singa package by exporting
+SINGA_CUDA=ON; otherwise the CPU only package will be built.
+
+
+Ref: 
+[1] https://github.com/bytedance/byteps/blob/master/setup.py
+[2] https://setuptools.readthedocs.io/en/latest/setuptools.html
+[3] https://packaging.python.org/tutorials/packaging-projects/ 
+'''
+
+from setuptools import find_packages, setup, Command, Extension
+from setuptools.command.build_ext import build_ext
+from distutils.errors import CompileError, DistutilsSetupError
+
+import os
+import io
+import sys
+import subprocess
+import shutil
+import shlex
+from pathlib import Path
+
+import numpy as np
+
+NAME = 'singa'
+'''
+Pypi does not allow you to overwrite the uploaded package;
+therefore, you have to bump the version.
+Pypi does not allow [local version label](https://www.python.org/dev/peps/pep-0440/#local-version-segments) 
+to appear in the version, therefore, you have to include the public 
+version label only. Currently, due to the pypi size limit, the package 
+uploaded to pypi is cpu only (without cuda and cudnn), which can be installed via
+    
+    $ pip install singa
+    $ pip install singa==3.0.0.dev1
+
+The cuda and cudnn enabled package's version consists of the public 
+version label + local version label, e.g., 3.0.0.dev1+cuda10.2, which
+can be installed via
+
+    $ pip install singa==3.0.0.dev1+cuda10.2 -f <url of the repo>
+
+'''
+from datetime import date
+
+# stable version
+VERSION = '3.0.0'
+# get the git hash
+# git_hash = subprocess.check_output(["git", "describe"]).strip().split('-')[-1][1:]
+# comment the next line to build wheel for stable version
+# VERSION += '.dev' + date.today().strftime('%y%m%d')
+
+SINGA_PY = Path('python')
+SINGA_SRC = Path('src')
+SINGA_HDR = Path('include')
+
+
+class AuditCommand(Command):
+    """Support setup.py upload."""
+
+    description = 'Repair the package via auditwheel tool.'
+    user_options = []
+
+    @staticmethod
+    def status(s):
+        """Prints things in bold."""
+        print('\033[1m{0}\033[0m'.format(s))
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self):
+        self.status('Removing previous wheel files under wheelhouse')
+        shutil.rmtree('wheelhouse', ignore_errors=True)
+        for wheel in os.listdir('dist'):
+            self.status('Repair the dist/{} via auditwheel'.format(wheel))
+            os.system('auditwheel repair dist/{}'.format(wheel))
+
+        # self.status('Uploading the package to PyPI via Twine…')
+        # os.system('{} -m twine upload dist/*'.format(sys.executable))
+        sys.exit()
+
+
+def parse_compile_options():
+    '''Read the environment variables to parse the compile options.
+
+    Returns:
+        a tuple of bool values as the indicators
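+
+    Example (an illustrative sketch; the variables must be exported before
+    setup.py is invoked):
+        >>> import os
+        >>> os.environ['SINGA_CUDA'] = 'ON'
+        >>> with_cuda, _, _, _ = parse_compile_options()
+        >>> bool(with_cuda)
+        True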
+    '''
+    with_cuda = os.environ.get('SINGA_CUDA', False)
+    with_nccl = os.environ.get('SINGA_NCCL', False)
+    with_test = os.environ.get('SINGA_TEST', False)
+    with_debug = os.environ.get('SINGA_DEBUG', False)
+
+    return with_cuda, with_nccl, with_test, with_debug
+
+
+def generate_singa_config(with_cuda, with_nccl):
+    '''Generate singa_config.h file to define some macros for the cpp code.
+
+    Args:
+        with_cuda(bool): indicator for cudnn and cuda lib
+        with_nccl(bool): indicator for nccl lib
+    '''
+    config = ['#define USE_CBLAS', '#define USE_GLOG', '#define USE_DNNL']
+    if not with_cuda:
+        config.append('#define CPU_ONLY')
+    else:
+        config.append('#define USE_CUDA')
+        config.append('#define USE_CUDNN')
+
+    if with_nccl:
+        config.append('#define ENABLE_DIST')
+        config.append('#define USE_DIST')
+
+    # singa_config.h to be included by cpp code
+    cpp_conf_path = SINGA_HDR / 'singa/singa_config.h'
+    print('Writing configs to {}'.format(cpp_conf_path))
+    with cpp_conf_path.open('w') as fd:
+        for line in config:
+            fd.write(line + '\n')
+        versions = [int(x) for x in VERSION.split('+')[0].split('.')[:3]]
+        fd.write('#define SINGA_MAJOR_VERSION {}\n'.format(versions[0]))
+        fd.write('#define SINGA_MINOR_VERSION {}\n'.format(versions[1]))
+        fd.write('#define SINGA_PATCH_VERSION {}\n'.format(versions[2]))
+        fd.write('#define SINGA_VERSION "{}"\n'.format(VERSION))
+
+    # config.i to be included by swig files
+    swig_conf_path = SINGA_SRC / 'api/config.i'
+    with swig_conf_path.open('w') as fd:
+        for line in config:
+            fd.write(line + ' 1 \n')
+
+        fd.write('#define USE_PYTHON 1\n')
+        if not with_nccl:
+            fd.write('#define USE_DIST 0\n')
+        if not with_cuda:
+            fd.write('#define USE_CUDA 0\n')
+            fd.write('#define USE_CUDNN 0\n')
+        else:
+            fd.write('#define CUDNN_VERSION "{}"\n'.format(
+                os.environ.get('CUDNN_VERSION')))
+        versions = [int(x) for x in VERSION.split('+')[0].split('.')[:3]]
+        fd.write('#define SINGA_MAJOR_VERSION {}\n'.format(versions[0]))
+        fd.write('#define SINGA_MINOR_VERSION {}\n'.format(versions[1]))
+        fd.write('#define SINGA_PATCH_VERSION {}\n'.format(versions[2]))
+        fd.write('#define SINGA_VERSION "{}"\n'.format(VERSION))
+
+
+def get_cpp_flags():
+    default_flags = ['-std=c++11', '-fPIC', '-g', '-O2', '-Wall', '-pthread']
+    # avx_flags = [ '-mavx'] #'-mf16c',
+    if sys.platform == 'darwin':
+        # Darwin most likely will have Clang, which has libc++.
+        return default_flags + ['-stdlib=libc++']
+    else:
+        return default_flags
+
+
+def generate_proto_files():
+    print('----------------------')
+    print('Generating proto files')
+    print('----------------------')
+    proto_src = SINGA_SRC / 'proto'
+    cmd = "/usr/bin/protoc --proto_path={} --cpp_out={} {}".format(
+        proto_src, proto_src, proto_src / 'core.proto')
+    subprocess.run(cmd, shell=True, check=True)
+
+    proto_hdr_dir = SINGA_HDR / 'singa/proto'
+    proto_hdr_file = proto_hdr_dir / 'core.pb.h'
+    if proto_hdr_dir.exists():
+        if proto_hdr_file.exists():
+            proto_hdr_file.unlink()
+    else:
+        proto_hdr_dir.mkdir()
+
+    shutil.copyfile(Path(proto_src / 'core.pb.h'), proto_hdr_file)
+    return proto_hdr_file, proto_src / 'core.pb.cc'
+
+
+def path_to_str(path_list):
+    return [str(x) if not isinstance(x, str) else x for x in path_list]
+
+
+def prepare_extension_options():
+    with_cuda, with_nccl, with_test, with_debug = parse_compile_options()
+
+    generate_singa_config(with_cuda, with_nccl)
+    generate_proto_files()
+
+    link_libs = ['glog', 'protobuf', 'openblas', 'dnnl']
+
+    sources = path_to_str([
+        *list((SINGA_SRC / 'core').rglob('*.cc')), *list(
+            (SINGA_SRC / 'model/operation').glob('*.cc')), *list(
+                (SINGA_SRC / 'utils').glob('*.cc')),
+        SINGA_SRC / 'proto/core.pb.cc', SINGA_SRC / 'api/singa.i'
+    ])
+    include_dirs = path_to_str([
+        SINGA_HDR, SINGA_HDR / 'singa/proto',
+        np.get_include(), '/usr/include', '/usr/include/openblas',
+        '/usr/local/include'
+    ])
+
+    try:
+        np_include = np.get_include()
+    except AttributeError:
+        np_include = np.get_numpy_include()
+    include_dirs.append(np_include)
+
+    library_dirs = []  # path_to_str(['/usr/lib64', '/usr/local/lib'])
+
+    if with_cuda:
+        link_libs.extend(['cudart', 'cudnn', 'curand', 'cublas', 'cnmem'])
+        include_dirs.append('/usr/local/cuda/include')
+        library_dirs.append('/usr/local/cuda/lib64')
+        sources.append(str(SINGA_SRC / 'core/tensor/math_kernel.cu'))
+        if with_nccl:
+            link_libs.extend(['nccl', 'cusparse', 'mpicxx', 'mpi'])
+            sources.append(str(SINGA_SRC / 'io/communicator.cc'))
+    # print(link_libs, extra_libs)
+
+    libraries = link_libs
+    runtime_library_dirs = ['.'] + library_dirs
+    extra_compile_args = {'gcc': get_cpp_flags()}
+
+    if with_cuda:
+        cuda9_gencode = (' -gencode arch=compute_35,code=sm_35'
+                         ' -gencode arch=compute_50,code=sm_50'
+                         ' -gencode arch=compute_60,code=sm_60'
+                         ' -gencode arch=compute_70,code=sm_70')
+        cuda10_gencode = ' -gencode arch=compute_75,code=sm_75'
+        cuda11_gencode = ' -gencode arch=compute_80,code=sm_80'
+        cuda9_ptx = ' -gencode arch=compute_70,code=compute_70'
+        cuda10_ptx = ' -gencode arch=compute_75,code=compute_75'
+        cuda11_ptx = ' -gencode arch=compute_80,code=compute_80'
+        if cuda_major >= 11:
+            gencode = cuda9_gencode + cuda10_gencode + cuda11_gencode + cuda11_ptx
+        elif cuda_major >= 10:
+            gencode = cuda9_gencode + cuda10_gencode + cuda10_ptx
+        elif cuda_major >= 9:
+            gencode = cuda9_gencode + cuda9_ptx
+        else:
+            raise CompileError(
+                'CUDA version must be >=9.0, the current version is {}'.format(
+                    cuda_major))
+
+        extra_compile_args['nvcc'] = shlex.split(gencode) + [
+            '-Xcompiler', '-fPIC'
+        ]
+    options = {
+        'sources': sources,
+        'include_dirs': include_dirs,
+        'library_dirs': library_dirs,
+        'libraries': libraries,
+        'runtime_library_dirs': runtime_library_dirs,
+        'extra_compile_args': extra_compile_args
+    }
+
+    return options
+
+
+# credit: https://github.com/rmcgibbo/npcuda-example/blob/master/cython/setup.py#L55
+def customize_compiler_for_nvcc(self):
+    """Inject deep into distutils to customize how the dispatch
+    to gcc/nvcc works.
+    If you subclass UnixCCompiler, it's not trivial to get your subclass
+    injected in, and still have the right customizations (i.e.
+    distutils.sysconfig.customize_compiler) run on it. So instead of going
+    the OO route, I have this. Note, it's kind of like a weird functional
+    subclassing going on.
+    """
+
+    # Tell the compiler it can process .cu files
+    self.src_extensions.append('.cu')
+
+    # Save references to the default compiler_so and _compile methods
+    default_compiler_so = self.compiler_so
+    super = self._compile
+
+    # Now redefine the _compile method. This gets executed for each
+    # object but distutils doesn't have the ability to change compilers
+    # based on source extension: we add it.
+    def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+        if os.path.splitext(src)[1] == '.cu':
+            # use the cuda for .cu files
+            self.set_executable('compiler_so', 'nvcc')
+            # use only a subset of the extra_postargs, which are 1-1
+            # translated from the extra_compile_args in the Extension class
+            postargs = extra_postargs['nvcc']
+        else:
+            postargs = extra_postargs['gcc']
+
+        super(obj, src, ext, cc_args, postargs, pp_opts)
+        # Reset the default compiler_so, which we might have changed for cuda
+        self.compiler_so = default_compiler_so
+
+    # Inject our redefined _compile method into the class
+    self._compile = _compile
+
+
+class custom_build_ext(build_ext):
+    '''Customize the process for building the extension by changing
+    the options for compiling swig files and cu files.
+
+    Ref: https://github.com/python/cpython/blob/master/Lib/distutils/command/build_ext.py
+    '''
+
+    def finalize_options(self):
+        self.swig_cpp = True
+        print('build temp', self.build_temp)
+        print('build lib', self.build_lib)
+        super(custom_build_ext, self).finalize_options()
+        self.swig_opts = '-py3 -outdir {}/singa/'.format(self.build_lib).split()
+        print('build temp', self.build_temp)
+        print('build lib', self.build_lib)
+
+    def build_extensions(self):
+        options = prepare_extension_options()
+        for key, val in options.items():
+            singa_wrap.__dict__[key] = val
+        customize_compiler_for_nvcc(self.compiler)
+        build_ext.build_extensions(self)
+
+
+try:
+    with io.open('README.md', encoding='utf-8') as f:
+        long_description = '\n' + f.read()
+except OSError:
+    long_description = ''
+
+classifiers = [
+    # Trove classifiers
+    # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
+    'License :: OSI Approved :: Apache Software License',
+    'Development Status :: 3 - Alpha',
+    'Intended Audience :: Developers',
+    'Programming Language :: Python :: 3.6',
+    'Programming Language :: Python :: 3.7',
+    'Programming Language :: Python :: 3.8',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence'
+]
+if sys.platform == 'darwin':
+    classifiers.append('Operating System :: MacOS :: MacOS X')
+elif sys.platform == 'linux':
+    classifiers.append('Operating System :: POSIX :: Linux')
+else:
+    raise DistutilsSetupError('Building on Windows is not supported currently.')
+
+keywords = 'deep learning, apache singa'
+with_cuda, with_nccl, _, _ = parse_compile_options()
+if with_cuda:
+    classifiers.append('Environment :: GPU :: NVIDIA CUDA')
+    cuda_version = os.environ.get('CUDA_VERSION')
+    cudnn_version = os.environ.get('CUDNN_VERSION')
+    keywords += ', cuda{}, cudnn{}'.format(cuda_version, cudnn_version)
+    cuda_major = int(cuda_version.split('.')[0])
+    cuda_minor = int(cuda_version.split('.')[1])
+    # local label '+cuda10.2'. Ref: https://www.python.org/dev/peps/pep-0440/
+    VERSION = VERSION + '+cuda{}.{}'.format(cuda_major, cuda_minor)
+    if with_nccl:
+        classifiers.append('Topic :: System :: Distributed Computing')
+        keywords += ', distributed'
+else:
+    keywords += ', cpu-only'
+
+singa_wrap = Extension('singa._singa_wrap', [])
+
+setup(
+    name=NAME,
+    version=VERSION,
+    description='A General Deep Learning System',
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    author='Apache SINGA Community',
+    author_email='dev@singa.apache.org',
+    url='http://singa.apache.org',
+    python_requires='>=3',
+    install_requires=[
+        'numpy >=1.16,<2.0',  #1.16
+        'onnx==1.6',
+        'deprecated',
+        'unittest-xml-reporting',
+        'future',
+        'pillow',
+        'tqdm',
+    ],
+    include_package_data=True,
+    license='Apache 2',
+    classifiers=classifiers,
+    keywords=keywords,
+    packages=find_packages('python'),
+    package_dir={'': 'python'},
+    ext_modules=[singa_wrap],
+    cmdclass={
+        'build_ext': custom_build_ext,
+        'audit': AuditCommand
+    })
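(Illustrative sketch, not part of the patch: it mimics how the patched _compile
above routes .cu sources to nvcc while keeping the 'gcc' flag set for ordinary
C++ files; the file names and flag values below are invented for the example.)

import os

extra_compile_args = {
    'gcc': ['-std=c++11', '-fPIC', '-O2'],
    'nvcc': ['-gencode', 'arch=compute_70,code=sm_70', '-Xcompiler', '-fPIC'],
}

def pick_compiler_and_flags(src):
    # .cu sources go to nvcc with its own flag set; everything else keeps
    # the default C++ compiler and the 'gcc' flags, as in setup.py above
    if os.path.splitext(src)[1] == '.cu':
        return 'nvcc', extra_compile_args['nvcc']
    return 'g++', extra_compile_args['gcc']

for src in ('src/core/tensor/tensor.cc', 'src/core/tensor/math_kernel.cu'):
    compiler, flags = pick_compiler_and_flags(src)
    print(src, '->', compiler, flags)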
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 0752496..5f30299 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -105,6 +105,11 @@
   SET_TARGET_PROPERTIES(singa PROPERTIES LINK_FLAGS "")
 ENDIF()
 
+IF(CODE_COVERAGE)
+    MESSAGE("-- Enabling Code Coverage")
+    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g --coverage")
+ENDIF(CODE_COVERAGE)
+
 #pass configure infor to swig
 FILE(REMOVE "${CMAKE_CURRENT_SOURCE_DIR}/api/config.i")
 CONFIGURE_FILE("${CMAKE_CURRENT_SOURCE_DIR}/api/config.i.in" "${CMAKE_CURRENT_SOURCE_DIR}/api/config.i")
diff --git a/src/api/core_device.i b/src/api/core_device.i
index 1ceacf1..a5a9644 100644
--- a/src/api/core_device.i
+++ b/src/api/core_device.i
@@ -47,12 +47,16 @@
  public:
   virtual void SetRandSeed(unsigned seed) = 0;
   std::shared_ptr<Device> host();
+  void Reset();
   int id() const;
   virtual void Sync();
   void ResetGraph();
   void RunGraph(bool serial = false);
   bool graph_enabled() const;
   void EnableGraph(bool enable);
+  void PrintTimeProfiling();
+  void SetVerbosity(int verbosity);
+  void SetSkipIteration(int skip_iteration);
   static void EnableLazyAlloc(bool enbale);
 };
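(Hedged usage sketch, not part of the patch: it assumes the SWIG wrapper exposes
the new Device methods unchanged and that a CUDA device is obtained via
singa.device.create_cuda_gpu(); adjust the factory call to the actual Python API.)

from singa import device

dev = device.create_cuda_gpu()   # assumed device factory
dev.SetVerbosity(1)              # 1: forward/backward totals, 2: per-op times,
                                 # 3: distributed (Dist*) ops only
dev.SetSkipIteration(5)          # warm-up iterations excluded from the timing

# ... run training iterations with the computational graph enabled ...

dev.PrintTimeProfiling()         # print the accumulated timings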
 
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index a91e8d7..f7e3160 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -32,7 +32,7 @@
 #include "singa/core/tensor.h"
 #include "singa/core/device.h"
 #include "singa/proto/core.pb.h"
-#include "singa/proto/model.pb.h"
+// #include "singa/proto/model.pb.h"
 using singa::DataType;
 %}
 %shared_ptr(singa::Device)
@@ -42,6 +42,9 @@
 %init %{
   import_array();
 %}
+// it is better to use (int DIM1, float* IN_ARRAY1);
+// otherwise, the generated Python method will have the argument name src,
+// which in fact accepts num as the input
 %apply (float *IN_ARRAY1, int DIM1) {
        (const float *src, const size_t num)
 }
@@ -62,7 +65,10 @@
 %apply float[] {float *};
 #endif // USE_JAVA
 
-
+namespace std {
+  %template(VecTensor) vector<singa::Tensor>;
+  %template(VecVecSize) vector<vector<size_t>>;
+}
 
 %template(Shape) std::vector<size_t>;
 
@@ -90,7 +96,7 @@
 
     std::shared_ptr<singa::Device> device() const;
 
-    template <typename SType> void GetValue(SType* value, const size_t num);
+    template <typename SType> void GetValue(SType* value, const size_t num) const;
     %template(GetFloatValue) GetValue<float>;
     %template(GetIntValue) GetValue<int>;
 
@@ -101,11 +107,11 @@
     const std::vector<size_t> &shape() const;
     const size_t shape(size_t idx) const;
     bool transpose() const;
-    size_t nDim() const;    
+    size_t nDim() const;
 
     size_t Size() const;
     size_t MemSize() const;
-    
+
     void ResetLike(const Tensor &t);
     Tensor AsType(DataType type);
     void ToDevice(std::shared_ptr<singa::Device> dev);
@@ -115,16 +121,16 @@
 
     template <typename DType> void CopyDataFromHostPtr(const DType *src,
                                                        const size_t num,
-                                                       const size_t offset = 0);
+                                                       const size_t offset = 0) const;
     %template(CopyFloatDataFromHostPtr) CopyDataFromHostPtr<float>;
     %template(CopyIntDataFromHostPtr) CopyDataFromHostPtr<int>;
 
     void CopyData(const Tensor &other);
     void RepeatData(std::vector<size_t> repeats, int axis, int total_repeats, const Tensor &src);
-    
+
     Tensor Clone() const;
     Tensor Repeat(std::vector<size_t> repeats, int axis);
-    
+
 
 #if USE_JAVA
     %rename(iAdd) operator+=(const Tensor &t);
@@ -161,10 +167,11 @@
   void CopyDataToFrom(Tensor *dst, const Tensor &src, size_t num,
                       size_t src_offset = 0, size_t dst_offset = 0);
 
-  void RepeatDataToFrom(bool broadcast_flag, std::vector<size_t> repeats, int axis, 
+  void RepeatDataToFrom(bool broadcast_flag, std::vector<size_t> repeats, int axis,
                         Tensor *dst, const Tensor &src, const size_t num);
 
   Tensor Reshape(const Tensor &in, const std::vector<size_t> &s);
+  Tensor Contiguous(const Tensor &in);
   Tensor Transpose(const Tensor &in, const std::vector<size_t> &axes);
 
   %rename(DefaultTranspose) Transpose(const Tensor &in);
@@ -172,7 +179,11 @@
 
   Tensor Abs(const Tensor &t);
   Tensor Ceil(const Tensor &t);
+  Tensor Floor(const Tensor &t);
+  Tensor Round(const Tensor &t);
+  Tensor RoundE(const Tensor &t);
   Tensor Exp(const Tensor &t);
+  Tensor Erf(const Tensor &t);
   Tensor Log(const Tensor &t);
   Tensor ReLU(const Tensor &t);
   Tensor Sigmoid(const Tensor &t);
@@ -221,10 +232,12 @@
   %rename(__le__) operator<=(const Tensor &lhs, const Tensor &rhs);
   %rename(__gt__) operator>(const Tensor &lhs, const Tensor &rhs);
   %rename(__ge__) operator>=(const Tensor &lhs, const Tensor &rhs);
+  %rename(__eq__) operator==(const Tensor &lhs, const Tensor &rhs);
   Tensor operator<(const Tensor &lhs, const Tensor &rhs);
   Tensor operator<=(const Tensor &lhs, const Tensor &rhs);
   Tensor operator>(const Tensor &lhs, const Tensor &rhs);
   Tensor operator>=(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator==(const Tensor &lhs, const Tensor &rhs);
 
 
   %rename(LTFloat) operator<(const Tensor &t, const float x);
@@ -244,6 +257,10 @@
   template <typename DType> Tensor operator>=(const Tensor &t, const DType x);
   %template(opge) operator>= <float>;
 
+  %rename(EQFloat) operator==(const Tensor &t, const float x);
+  template <typename DType> Tensor operator==(const Tensor &t, const DType x);
+  %template(opeq) operator== <float>;
+
   Tensor ConcatOn(const std::vector<Tensor> &in, int axis);
   Tensor SliceOn(const Tensor&in, const size_t start, const size_t end, int axis);
 
@@ -312,6 +329,7 @@
   template <typename SType>
   void Axpy(SType alpha, const Tensor &in, Tensor *out);
   %template(Axpy) Axpy<float>;
+  void Axpy(const Tensor &alpha, const Tensor &in, Tensor *out);
 
   Tensor Mult(const Tensor &A, const Tensor &B);
   %rename(MultWithRet) Mult(const Tensor &A, const Tensor &B, Tensor *C);
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index d30a52c..b0d95a0 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -28,6 +28,7 @@
 #include "../src/model/operation/convolution.h"
 #include "../src/model/operation/batchnorm.h"
 #include "../src/model/operation/pooling.h"
+#include "../src/model/operation/rnn.h"
 
 %}
 
@@ -189,6 +190,43 @@
 
 Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const Tensor& x, const Tensor& y);
 
+class CudnnRNNHandle {
+ public:
+  CudnnRNNHandle(const Tensor &x,
+                 const int hidden_size, const int mode = 0,
+                 const int num_layers = 1, const int bias = 1,
+                 const float dropout = 0.0f, const int bidirectional = 0);
+  int bias;
+  int mode;
+  float dropout;
+  int bidirectional;
+  size_t feature_size;
+  size_t hidden_size;
+  size_t weights_size;
+  int num_layers;
+  size_t batch_size;
+  size_t seq_length;
+  size_t workspace_size;
+  size_t reserve_size;
+  Tensor workspace;
+  Tensor reserve_space;
+  void *states;
+};
+
+std::vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, CudnnRNNHandle &h);
+std::vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, CudnnRNNHandle &h);
+std::vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy, const Tensor &dhy, const Tensor &dcy, const Tensor &W, const Tensor &hx, const Tensor &cx, CudnnRNNHandle &h);
+Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y, CudnnRNNHandle &h);
+
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights, Tensor &paramValues, bool is_bias, CudnnRNNHandle &h);
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights, bool is_bias, CudnnRNNHandle &h);
+
+std::vector<Tensor> GpuRNNForwardTrainingEx(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, const Tensor &seq_lengths, CudnnRNNHandle &h);
+std::vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx, const Tensor &cx, const Tensor &W, const Tensor &seq_lengths, CudnnRNNHandle &h);
+std::vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy, const Tensor &dhy, const Tensor &dcy, const Tensor &W, const Tensor &hx, const Tensor &cx, const Tensor &seq_lengths, CudnnRNNHandle &h);
+Tensor GpuRNNBackwardWEx(const Tensor &x, const Tensor &hx, const Tensor &y, const Tensor &seq_lengths, CudnnRNNHandle &h);
+
+
 #endif  // USE_CUDNN
 
 }  //namespace singa
diff --git a/src/api/model_optimizer.i b/src/api/model_optimizer.i
deleted file mode 100644
index 9b73d81..0000000
--- a/src/api/model_optimizer.i
+++ /dev/null
@@ -1,71 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-/*interface file for swig */
-
-%module model_optimizer
-%include "std_vector.i"
-%include "std_string.i"
-%include "std_pair.i"
-%include "std_shared_ptr.i"
-
-%{
-#define SWIG_PYTHON_STRICT_BYTE_CHAR
-#include "singa/model/optimizer.h"
-#include "singa/proto/model.pb.h"
-using singa::Tensor;
-using singa::ParamSpec;
-using singa::OptimizerConf;
-%}
-
-
-%shared_ptr(singa::Optimizer)
-%shared_ptr(singa::Regularizer)
-%shared_ptr(singa::Constraint)
-
-namespace singa {
-class Optimizer {
- public:
-  // Optimizer() = default;
-  virtual ~Optimizer() = default;
-  void Setup(const std::string& str);
-  virtual void Apply(int epoch, float lr, const std::string& name,
-      Tensor& grad, Tensor& value, int step = -1) = 0;
-};
-inline std::shared_ptr<Optimizer> CreateOptimizer(const std::string& type);
-
-class Constraint {
- public:
-  Constraint() = default;
-  void Setup(const std::string& conf_str);
-  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
-};
-
-inline std::shared_ptr<Constraint> CreateConstraint(const std::string& type);
-
-class Regularizer {
- public:
-  Regularizer() = default;
-  void Setup(const std::string& conf_str);
-  void Apply(int epoch, const Tensor& value, Tensor& grad, int step = -1);
-};
-inline std::shared_ptr<Regularizer> CreateRegularizer(const std::string& type);
-}
diff --git a/src/api/singa.i b/src/api/singa.i
index 98ad9f5..98be9d2 100644
--- a/src/api/singa.i
+++ b/src/api/singa.i
@@ -25,10 +25,6 @@
 %include "config.i"
 %include "core_tensor.i"
 %include "core_device.i"
-%include "model_layer.i"
-%include "model_optimizer.i"
-%include "model_loss.i"
-%include "model_metric.i"
 %include "model_operation.i"
-%include "io_snapshot.i"
 %include "dist_communicator.i"
+ // %include "io_snapshot.i"
\ No newline at end of file
diff --git a/src/core/device/cpp_cpu.cc b/src/core/device/cpp_cpu.cc
index 6407c65..af56b1b 100644
--- a/src/core/device/cpp_cpu.cc
+++ b/src/core/device/cpp_cpu.cc
@@ -40,6 +40,19 @@
   fn(&ctx_);
 }
 
+void CppCPU::TimeProfilingDoExec(function<void(Context*)>&& fn, int executor,
+                                 Node* node) {
+  CHECK_EQ(executor, 0);
+
+  auto t_start = std::chrono::high_resolution_clock::now();
+  fn(&ctx_);
+  std::chrono::duration<float> duration =
+      std::chrono::high_resolution_clock::now() - t_start;
+  node->time_elapsed_inc(duration.count());
+}
+
+void CppCPU::EvaluateTimeElapsed(Node* node) {}
+
 void* CppCPU::Malloc(int size) {
   if (size > 0) {
     void* ptr = malloc(size);
diff --git a/src/core/device/cuda_gpu.cc b/src/core/device/cuda_gpu.cc
index 2f3b0d6..6025a5e 100644
--- a/src/core/device/cuda_gpu.cc
+++ b/src/core/device/cuda_gpu.cc
@@ -43,6 +43,11 @@
     CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(status);
   }
 #endif
+
+  // Explicitly destroy and clean up all resources associated with the
+  // current device
+  cudaDeviceReset();
+  // the returned code indicates "driver shutting down" after the reset
 }
 const int kNumCudaStream = 1;
 
@@ -68,6 +73,12 @@
   // Preserse for future use instead of default sync stream, for concurrency
   // cudaStreamCreate(&ctx_.stream);
 
+#ifdef USE_DIST
+  CUDA_CHECK(cudaStreamCreateWithFlags(&ctx_.s, cudaStreamNonBlocking));
+  CUDA_CHECK(cudaStreamCreateWithFlags(&ctx_.c1, cudaStreamNonBlocking));
+  CUDA_CHECK(cudaStreamCreateWithFlags(&ctx_.c2, cudaStreamNonBlocking));
+#endif  // USE_DIST
+
   CUDA_CHECK(cudaSetDevice(id_));
   // use curandCreateGeneratorHost for CudaHost device
   CURAND_CHECK(
@@ -95,6 +106,68 @@
 
 void CudaGPU::DoExec(function<void(Context*)>&& fn, int executor) { fn(&ctx_); }
 
+void CudaGPU::SyncBeforeCountingTime() {
+  // synchronization before counting time
+  bool previous_state = graph_enabled();
+  graph_enabled_ = false;
+  Sync();
+  graph_enabled_ = previous_state;
+}
+
+void CudaGPU::EvaluateTimeElapsed(Node* node) {
+  float totalTime;
+
+  cudaEventElapsedTime(&totalTime, node->start_, node->end_);
+
+  cudaEventDestroy(node->start_);
+  cudaEventDestroy(node->end_);
+
+  node->time_elapsed_inc(totalTime * 0.001);
+}
+
+void CudaGPU::TimeProfilingDoExec(function<void(Context*)>&& fn, int executor,
+                                  Node* node) {
+  // time profiling using cudaEvent
+  cudaEventCreate(&(node->start_));
+  cudaEventCreate(&(node->end_));
+
+#ifdef USE_DIST
+  if (node->op_name().find("Dist") != std::string::npos) {
+    if (node->op_name().find("Dist_s") != std::string::npos)
+      cudaEventRecord(node->start_, ctx_.s);
+    else if (node->op_name().find("Dist_c1") != std::string::npos)
+      cudaEventRecord(node->start_, ctx_.c1);
+    else if (node->op_name().find("Dist_c2") != std::string::npos)
+      cudaEventRecord(node->start_, ctx_.c2);
+    else if (node->op_name().find("Dist_c1c2") != std::string::npos)
+      cudaEventRecord(node->start_, ctx_.c1);
+  } else {
+    cudaEventRecord(node->start_, ctx_.stream);
+  }
+#else
+  cudaEventRecord(node->start_, ctx_.stream);
+#endif  // USE_DIST
+
+  fn(&ctx_);
+
+#ifdef USE_DIST
+  if (node->op_name().find("Dist") != std::string::npos) {
+    if (node->op_name().find("Dist_s") != std::string::npos)
+      cudaEventRecord(node->end_, ctx_.s);
+    else if (node->op_name().find("Dist_c1") != std::string::npos)
+      cudaEventRecord(node->end_, ctx_.c1);
+    else if (node->op_name().find("Dist_c2") != std::string::npos)
+      cudaEventRecord(node->end_, ctx_.c2);
+    else if (node->op_name().find("Dist_c1c2") != std::string::npos)
+      cudaEventRecord(node->end_, ctx_.c2);
+  } else {
+    cudaEventRecord(node->end_, ctx_.stream);
+  }
+#else
+  cudaEventRecord(node->end_, ctx_.stream);
+#endif  // USE_DIST
+}
+
 void CudaGPU::CopyToFrom(void* dst, const void* src, size_t nBytes,
                          CopyDirection direction, Context* ctx) {
   // cudaMemcpy(dst, src, nBytes, copyKind[direction]);
@@ -131,8 +204,8 @@
 }
 
 void CudaGPU::Sync() {
-  Exec([this](Context* ctx) { CUDA_CHECK(cudaStreamSynchronize(ctx_.stream)); },
-       {}, {});
+  Exec([this](Context* ctx) { CUDA_CHECK(cudaDeviceSynchronize()); }, {}, {},
+       "Waiting");
 }
 
 }  // namespace singa
diff --git a/src/core/device/device.cc b/src/core/device/device.cc
index 4114601..15167e2 100644
--- a/src/core/device/device.cc
+++ b/src/core/device/device.cc
@@ -35,11 +35,29 @@
   }
 }
 
+void Device::Reset() {
+  // Sync the device to finish the current calculation
+  graph_enabled_ = false;
+  Sync();
+
+  // Reset Seed
+  // seed_ = std::chrono::system_clock::now().time_since_epoch().count();
+  // SetRandSeed(seed_);
+
+  // Reset Graph
+  graph_->Reset();
+
+  // Others
+  verbosity_ = 0;
+  skip_iteration_ = 5;
+}
+
 void Device::Exec(function<void(Context*)>&& fn,
                   const vector<Block*> read_blocks,
-                  const vector<Block*> write_blocks, bool use_rand_generator) {
+                  const vector<Block*> write_blocks, string op_name,
+                  bool use_rand_generator) {
   if (graph_enabled_ == true) {
-    graph_->AddOperation(std::move(fn), read_blocks, write_blocks);
+    graph_->AddOperation(std::move(fn), read_blocks, write_blocks, op_name);
   } else {
     // printf("immediately ops\n");
     DoExec(std::move(fn), 0);
@@ -63,6 +81,8 @@
   graph_enabled_ = previous_state;
 }
 
+void Device::PrintTimeProfiling() { graph_->PrintTimeProfiling(); }
+
 // Todo(Wangwei) Get Block From The Memory manager
 Block* Device::NewBlock(int size) {
   CHECK_GE(size, 0)
@@ -90,24 +110,19 @@
 
 void Device::CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
                             CopyDirection direct, int dst_offset,
-                            int src_offset) {
-  this->Exec(
-      [this, dst, src, nBytes, direct, dst_offset, src_offset](Context* ctx) {
-        this->CopyToFrom(
-            reinterpret_cast<char*>(dst->mutable_data()) + dst_offset,
-            reinterpret_cast<const char*>(src->data()) + src_offset, nBytes,
-            direct, ctx);
-      },
-      {src}, {dst});
+                            int src_offset, Context* ctx) {
+  this->CopyToFrom(reinterpret_cast<char*>(dst->mutable_data()) + dst_offset,
+                   reinterpret_cast<const char*>(src->data()) + src_offset,
+                   nBytes, direct, ctx);
 }
 
 void Device::CopyDataFromHostPtr(Block* dst, const void* src, size_t nBytes,
-                                 size_t dst_offset) {
+                                 size_t dst_offset, Context* ctx) {
   auto direct = lang_ == kCpp ? kHostToHost : kHostToDevice;
   void* dstptr = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset;
   Exec([this, dstptr, src, nBytes,
         direct](Context* ctx) { CopyToFrom(dstptr, src, nBytes, direct, ctx); },
-       {}, {dst});
+       {}, {dst}, "CopyDataFromHostPtr");
 }
 void Device::Sync() {}
 }  // namespace singa
diff --git a/src/core/device/opencl_device.cc b/src/core/device/opencl_device.cc
index 99edbf7..8d0971d 100644
--- a/src/core/device/opencl_device.cc
+++ b/src/core/device/opencl_device.cc
@@ -59,7 +59,7 @@
 
 void OpenclDevice::CopyDataToFrom(Block* dst, Block* src, size_t nBytes,
                                   CopyDirection direction, int dst_offset,
-                                  int src_offset) {
+                                  int src_offset, Context* ctx) {
   // Pointers must be valid.
   if (!dst || !src) return;
 
diff --git a/src/core/device/platform.cc b/src/core/device/platform.cc
index 1991cbc..d07c67c 100644
--- a/src/core/device/platform.cc
+++ b/src/core/device/platform.cc
@@ -138,6 +138,8 @@
   for (size_t i = 0; i < devices.size(); i++) {
     if (UsedDevice[devices[i]] == nullptr)
       UsedDevice[devices[i]] = std::make_shared<CudaGPU>(devices[i], pool);
+    else
+      UsedDevice[devices[i]]->Reset();
     ret.push_back(UsedDevice[devices[i]]);
   }
   mtx_.unlock();
diff --git a/src/core/scheduler/scheduler.cc b/src/core/scheduler/scheduler.cc
index be31c6e..0172ee8 100644
--- a/src/core/scheduler/scheduler.cc
+++ b/src/core/scheduler/scheduler.cc
@@ -23,7 +23,6 @@
 #include <iomanip>
 #include <sstream>
 #include <thread>
-#include <unordered_set>
 
 #include "singa/core/device.h"
 #include "singa/utils/safe_queue.h"
@@ -65,11 +64,6 @@
   return it->second;
 }
 
-Block *Graph::write_block(const size_t idx) const {
-  CHECK_LT(idx, write_blocks_.size());
-  return write_blocks_[idx];
-}
-
 Node *Graph::begin_node(const size_t idx) const {
   CHECK_LT(idx, begin_nodes_.size());
   return begin_nodes_[idx];
@@ -101,9 +95,15 @@
   }
   blocks_.clear();
 
-  write_blocks_.clear();
+  leaf_blocks_.clear();
+
+  iteration_ = 0;
+
+  time_elapsed_ = 0;
 
   dirty_ = false;
+
+  in_serial_ = false;
 }
 
 void Graph::Debug() {
@@ -112,6 +112,7 @@
   int w = 0;
   size_t max_in_num = 0, max_out_num = 0, max_next_num = 0, max_free_num = 0;
   for (auto &it : nodes_) {
+    if (it->op_name_ == "Waiting") continue;
     max_in_num = std::max(max_in_num, it->in_edges_.size());
     max_out_num = std::max(max_out_num, it->out_edges_.size());
   }
@@ -124,7 +125,8 @@
     max_free_num = std::max(max_free_num, it.size());
   }
 
-  for (int i = std::max(nodes_.size(), blocks_.size()); i > 0; i /= 10, ++w) {
+  size_t max_size = std::max(nodes_.size(), blocks_.size());
+  for (size_t i = max_size; i > 0; i /= 10, ++w) {
   }
 
   std::stringstream ss;
@@ -139,9 +141,16 @@
     ss << "OP[" << std::setw(w) << i;
     auto node = nodes_[i];
 
+    string name;
+    if (node->op_name_.size() > 16) {
+      name = node->op_name_.substr(0, 13) + "...";
+    } else {
+      name = node->op_name_;
+    }
+
     ss << "] Inputs:[";
     size = node->in_edges_.size();
-    for (size_t j = 0; j < max_in_num; ++j) {
+    for (size_t j = 0; j < std::max(max_in_num, size); ++j) {
       if (j < size)
         ss << std::setw(w) << blocks_[node->in_edges_[j]->blk_]->id_ << " ";
       else
@@ -150,7 +159,7 @@
 
     ss << "] Outputs:[";
     size = node->out_edges_.size();
-    for (size_t j = 0; j < max_out_num; ++j) {
+    for (size_t j = 0; j < std::max(max_out_num, size); ++j) {
       if (j < size)
         ss << std::setw(w) << blocks_[node->out_edges_[j]->blk_]->id_ << " ";
       else
@@ -188,7 +197,7 @@
 
   for (auto it : blkInfos) {
     auto blkInfo = it;
-    ss << "Block[" << std::setw(w) << blkInfo->id_ << "] addr[" << std::setw(w)
+    ss << "Block[" << std::setw(w) << blkInfo->id_ << "] addr[" << std::setw(10)
        << blkInfo->blk_ << "] size[" << std::setw(10) << blkInfo->blk_->size()
        << "] graph_ref[" << std::setw(w) << blkInfo->graph_ref_
        << "] ref_count[" << std::setw(w) << blkInfo->blk_->ref_count() << "] ";
@@ -228,10 +237,98 @@
   printf("%s", ss.str().c_str());
 }
 
+void Graph::PrintTimeProfiling() {
+  std::stringstream ss;
+
+  // verbosity level: 1 -> forward and backward propagation time
+  if (device_->verbosity() == 1) {
+    bool forward = true;
+    float forward_time = 0;
+    float backward_time = 0;
+    float time_elapsed;
+
+    for (size_t i = 0; i < nodes_.size(); ++i)
+      if (nodes_[i]->time_elapsed() > 0) {
+        if (forward == true)
+          // detect the cross-entropy backward op; the ops after it are backward ops
+          // note that the function is more accurate when either
+          // SoftmaxCrossEntropy or Softmax is used
+          if (nodes_[i]->op_name().find("Backward") != std::string::npos)
+            forward = false;
+        // once forward becomes false, the backward propagation has started
+
+        time_elapsed = (nodes_[i]->time_elapsed()) /
+                       (iteration_ - device_->skip_iteration());
+
+        if (forward == true) forward_time += time_elapsed;
+      }
+
+    backward_time = (time_elapsed_ / (iteration_ - device_->skip_iteration())) -
+                    forward_time;
+
+    ss << std::endl << "Time Profiling:" << std::endl;
+    ss << "Forward Propagation Time : " << forward_time << " sec" << std::endl;
+    ss << "Backward Propagation Time : " << backward_time << " sec"
+       << std::endl;
+  }
+
+  // verbosity level: 2 -> each operation time (OP_ID, operation name, time)
+  if (device_->verbosity() == 2) {
+    ss << std::endl << "Time Profiling:" << std::endl;
+    for (size_t i = 0; i < nodes_.size(); ++i)
+      if (nodes_[i]->time_elapsed() > 0)
+        ss << "OP_ID" << nodes_[i]->id_ << ". " << nodes_[i]->op_name() << " : "
+           << (nodes_[i]->time_elapsed()) / (iteration_) << " sec" << std::endl;
+  }
+
+  // verbosity level: 3 -> Distributed training operations
+  if (device_->verbosity() == 3) {
+    ss << std::endl << "Time Profiling:" << std::endl;
+    for (size_t i = 0; i < nodes_.size(); ++i)
+      if ((nodes_[i]->op_name().find("Dist") != std::string::npos) &&
+          (nodes_[i]->time_elapsed() > 0))
+        ss << "OP_ID" << nodes_[i]->id_ << ". " << nodes_[i]->op_name() << " : "
+           << (nodes_[i]->time_elapsed()) / (iteration_) << " sec" << std::endl;
+  }
+
+  printf("%s", ss.str().c_str());
+}
+
+void Graph::TimeProfilingDoExec(Node *curNode) {
+  if ((device_->verbosity() > 0) && (curNode->op_name_ != "Waiting") &&
+      (iteration_ >= device_->skip_iteration()))
+    device_->TimeProfilingDoExec(std::move(curNode->op_), 0, curNode);
+  else
+    device_->DoExec(std::move(curNode->op_), 0);
+}
+
+void Graph::EvaluateTimeElapsed(const TimePoint &start) {
+  if ((device_->verbosity() > 0) && (iteration_ > device_->skip_iteration())) {
+    device_->Sync();
+    std::chrono::duration<float> duration =
+        std::chrono::high_resolution_clock::now() - start;
+    time_elapsed_inc(duration.count());
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      Node *curNode = nodes_[i];
+      if (curNode->op_name_ != "Waiting") {
+        device_->EvaluateTimeElapsed(curNode);
+      }
+    }
+  }
+}
+
+void Graph::TakeStartTime(TimePoint &start) {
+  if ((device_->verbosity() > 0) && (iteration_ >= device_->skip_iteration())) {
+    device_->Sync();
+    start = std::chrono::high_resolution_clock::now();
+  }
+}
+
 void Graph::RunGraph() {
   in_serial_ = false;
   if (dirty_) Analyze();
 
+  TimePoint start;
   SafeQueue<Node *> node_queue;
 
   // activate nodes
@@ -239,6 +336,8 @@
     node_queue.Push(it);
   }
 
+  TakeStartTime(start);
+
   // run graph
   while (node_queue.Size()) {
     // step 1: pop the first element, get the node corresponding to the index
@@ -247,7 +346,7 @@
     int curIndex = curNode->id_;
 
     // step 2: execute the operation
-    device_->DoExec(std::move(curNode->op_), 0);
+    TimeProfilingDoExec(curNode);
 
     // step 3: release some blocks' data that won't be used later
     for (auto it : free_blocks_[curIndex]) {
@@ -267,17 +366,24 @@
       node_queue.Push(it);
     }
   }
+
+  // increment iteration counter
+  step();
+  EvaluateTimeElapsed(start);
 }
 
 void Graph::RunInSerial() {
   in_serial_ = true;
   if (dirty_) Analyze();
 
+  TimePoint start;
+  TakeStartTime(start);
+
   for (size_t i = 0; i < nodes_.size(); ++i) {
     Node *curNode = nodes_[i];
 
     // step 1: execute the operation
-    device_->DoExec(std::move(curNode->op_), 0);
+    TimeProfilingDoExec(curNode);
 
     // step 2: release some blocks' data that won't be used later
     for (auto it : free_blocks_[i]) {
@@ -291,21 +397,25 @@
     *)(cb_data), 0));
     */
   }
+
+  // increment iteration counter
+  step();
+  EvaluateTimeElapsed(start);
 }
 
 void Graph::AddOperation(OpFunc &&op, const BlockVec &read_blocks,
-                         const BlockVec &write_blocks) {
+                         const BlockVec &write_blocks, string op_name) {
   dirty_ = true;
 
   // if the size of both read_blocks and write_blocks is zero,
   // this operation is used for synchronization
   if (read_blocks.size() == 0 && write_blocks.size() == 0) {
-    AddSyncOp(std::move(op));
+    AddSyncOp(std::move(op), op_name);
     return;
   }
 
   // create new node
-  Node *node = new Node(nodes_.size(), std::move(op));
+  Node *node = new Node(nodes_.size(), std::move(op), op_name);
 
   // create edges for read_blocks
   for (size_t i = 0; i < read_blocks.size(); ++i) {
@@ -313,6 +423,13 @@
     Node *src_node = nullptr;
     BlkInfo *blkInfo = nullptr;
 
+    // update leaf blocks
+    auto iter = leaf_blocks_.find(blk);
+    if (iter != leaf_blocks_.end()) {
+      leaf_blocks_.erase(iter);
+    }
+
+    // check if the block is already in the computational graph
     auto it = blocks_.find(blk);
     if (it == blocks_.end()) {
       blkInfo = new BlkInfo(blocks_.size(), blk, BlockType::kInput);
@@ -323,6 +440,7 @@
         blkInfo->type_ = BlockType::kInter;
       }
 
+      // a pending write edge exists: set its dst node, or add a new edge
       Edge *write_edge = blkInfo->write_edge_;
       if (write_edge) {
         if (!write_edge->dst_node_) {
@@ -337,6 +455,7 @@
       }
     }
 
+    // create new edge for new block
     Edge *edge = new Edge(edges_.size(), blk, src_node, node);
     blkInfo->graph_ref_ += 1;
     if (src_node) {
@@ -352,6 +471,9 @@
     Block *blk = write_blocks[i];
     BlkInfo *blkInfo = nullptr;
 
+    // update leaf blocks
+    leaf_blocks_.insert(blk);
+
     auto it = blocks_.find(blk);
     if (it == blocks_.end()) {
       blkInfo = new BlkInfo(blocks_.size(), blk, BlockType::kEnd);
@@ -361,8 +483,28 @@
       if (blkInfo->type_ == BlockType::kInput) {
         blkInfo->type_ = BlockType::kParam;
       }
+
+      Edge *write_edge = blkInfo->write_edge_;
+      if (write_edge) {
+        if (!write_edge->dst_node_) {
+          write_edge->dst_node_ = node;
+          node->AddInEdge(write_edge);
+        } else {
+          Node *lastNode = write_edge->src_node_;
+          auto outEdges = lastNode->out_edges();
+          for (auto outEdge : outEdges) {
+            if (outEdge->blk_ == blk && outEdge->dst_node_ != node) {
+              Edge *edge =
+                  new Edge(edges_.size(), blk, outEdge->dst_node_, node);
+              outEdge->dst_node_->AddOutEdge(edge);
+              node->AddInEdge(edge);
+            }
+          }
+        }
+      }
     }
 
+    // create new edge for new block
     Edge *edge = new Edge(edges_.size(), blk, node, nullptr);
     blkInfo->write_edge_ = edge;
     blkInfo->graph_ref_ += 1;
@@ -371,19 +513,16 @@
     edges_.push_back(edge);
   }
 
-  // for sync op
-  write_blocks_ = write_blocks;
-
   // add node into nodes
   nodes_.push_back(node);
 }
 
-void Graph::AddSyncOp(function<void(Context *)> &&op) {
+void Graph::AddSyncOp(function<void(Context *)> &&op, string op_name) {
   // create new node
-  Node *node = new Node(nodes_.size(), std::move(op));
+  Node *node = new Node(nodes_.size(), std::move(op), op_name);
 
-  for (size_t i = 0; i < write_blocks_.size(); ++i) {
-    Block *blk = write_blocks_[i];
+  for (auto it : leaf_blocks_) {
+    Block *blk = it;
     BlkInfo *blkInfo = blocks_[blk];
     Edge *edge = nullptr;
 
@@ -419,6 +558,7 @@
 
 void Graph::Analyze() {
   begin_nodes_.clear();
+  next_nodes_.clear();
   next_nodes_.resize(nodes_.size());
   free_blocks_.clear();
   free_blocks_.resize(nodes_.size());
@@ -440,12 +580,15 @@
   if (in_serial_) {
     begin_nodes_.push_back(nodes_[0]);
 
-    for (size_t i = 0; i < nodes_.size() - 1; ++i) {
+    for (size_t i = 0; i < nodes_.size(); ++i) {
       Node *curNode = nodes_[i];
 
-      next_nodes_[i].push_back(nodes_[i + 1]);
+      next_nodes_[i].clear();
+      if (i + 1 < nodes_.size()) {
+        next_nodes_[i].push_back(nodes_[i + 1]);
+      }
 
-      std::unordered_set<Block *> blks;
+      BlockSet blks;
       for (size_t j = 0; j < curNode->in_edges_.size(); ++j) {
         blks.insert(curNode->in_edges_[j]->blk_);
       }
@@ -510,7 +653,7 @@
       }
 
       // step 3: push_back curNode to the used_nodes_ of relevant blocks
-      std::unordered_set<Block *> blks;
+      BlockSet blks;
       for (size_t j = 0; j < curNode->in_edges_.size(); ++j) {
         blks.insert(curNode->in_edges_[j]->blk_);
       }
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 2ce87ea..43be56d 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -69,7 +69,9 @@
 }
 */
 
-__global__ void KernelBroadcastTo(const size_t n, size_t nDim, const float *in,const float* shape, const float* stride, float *out) {
+__global__ void KernelTraverseUnaryTransform(const size_t n, size_t nDim,
+                                             const float *in, const int *shape,
+                                             const int *stride, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x) {
     int shape_accu = n;
@@ -77,10 +79,10 @@
     int remains = i;
 
     for (int k = 0; k < nDim; k++) {
-      shape_accu = shape_accu/shape[k];
-      int idx = remains/shape_accu;
-      remains = remains%shape_accu;
-      offset = offset + idx*stride[k];
+      shape_accu = shape_accu / shape[k];
+      int idx = remains / shape_accu;
+      remains = remains % shape_accu;
+      offset = offset + idx * stride[k];
     }
     out[i] = in[offset];
   }
@@ -117,12 +119,45 @@
   }
 }
 
+__global__ void KernelErf(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = erff(in[i]);
+  }
+}
+
 __global__ void KernelCeil2(const size_t n, const float *in, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x) {
     out[i] = std::ceil(in[i]);
   }
 }
+__global__ void KernelFloor(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = std::floor(in[i]);
+  }
+}
+
+__global__ void KernelRound(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = roundf(in[i]);
+  }
+}
+
+__global__ void KernelRoundE(const size_t n, const float *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    float doub = in[i] * 2;
+    // only exact .5 values take the round-half-to-even branch;
+    // integer inputs fall through to the plain roundf below
+    if (ceilf(doub) == doub && floorf(in[i]) != in[i]) {
+      out[i] = roundf(in[i] / 2) * 2;
+    } else {
+      out[i] = roundf(in[i]);
+    }
+  }
+}
+
 
 __global__ void KernelLog(const size_t n, const float *in, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
@@ -301,6 +336,23 @@
     out[idx] = in1[idx] >= in2[idx] ? 1.0f : 0.0f;
   }
 }
+
+__global__ void KernelEQ(const size_t num, const float *in, const float x,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in[idx] == x ? 1.0f : 0.0f;
+  }
+}
+
+__global__ void KernelBEQ(const size_t num, const float *in1, const float *in2,
+                         float *out) {
+  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
+       idx += blockDim.x * gridDim.x) {
+    out[idx] = in1[idx] == in2[idx] ? 1.0f : 0.0f;
+  }
+}
+
 __global__ void KernelGT(const size_t num, const float *in, const float x,
                          float *out) {
   for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num;
@@ -539,10 +591,26 @@
   KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
+void erf(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelErf <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
 void ceil2(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelCeil2 <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
+void floor(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelFloor <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
+void round(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelRound <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
+void rounde(const size_t n, const float *in, float *out, cudaStream_t s) {
+  KernelRoundE <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
 void log(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
@@ -585,8 +653,11 @@
   KernelAdd <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, x, out);
 }
 
-void broadcast_to(const size_t n, size_t nDim,const float *in,const float* shape, const float* stride, float *out, cudaStream_t s) {
-  KernelBroadcastTo <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>> (n, nDim, in, shape, stride, out);
+void traverse_unary_transform(const size_t n, size_t nDim, const float *in,
+                              const int *shape, const int *stride, float *out,
+                              cudaStream_t s) {
+  KernelTraverseUnaryTransform<<<ceil(n / CU1DBLOCKF), CU1DBLOCKF>>>(
+      n, nDim, in, shape, stride, out);
 }
 
 void mult(const size_t n, const float *in, const float x, float *out,
@@ -625,6 +696,14 @@
         cudaStream_t s) {
   KernelBGE <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
 }
+void eq(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s) {
+  KernelEQ <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
+}
+void eq(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s) {
+  KernelBEQ <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in1, in2, out);
+}
 void lt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s) {
   KernelLT <<<ceil(num / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (num, in, x, out);
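(Host-side sketch, not part of the patch: it restates the intended
round-half-to-even rule of KernelRoundE above for checking expected outputs;
the helper name is invented.)

import math

def round_half_to_even(x):
    # exact .5 values go to the even neighbour; everything else rounds
    # half away from zero, like roundf() in the CUDA kernel
    if x * 2 == math.floor(x * 2) and x != math.floor(x):
        lower = math.floor(x)
        return float(lower if lower % 2 == 0 else lower + 1)
    return float(math.floor(x + 0.5)) if x >= 0 else float(math.ceil(x - 0.5))

for v in (2.5, 3.5, -2.5, 1.2, 3.0):
    print(v, '->', round_half_to_even(v))
# expected: 2.5 -> 2.0, 3.5 -> 4.0, -2.5 -> -2.0, 1.2 -> 1.0, 3.0 -> 3.0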
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 12398fb..69e5047 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -44,7 +44,11 @@
 void abs(const size_t n, const float *in, float *out, cudaStream_t s);
 void sign(const size_t n, const float *in, float *out, cudaStream_t s);
 void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void erf(const size_t n, const float *in, float *out, cudaStream_t s);
 void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
+void floor(const size_t n, const float *in, float *out, cudaStream_t s);
+void round(const size_t n, const float *in, float *out, cudaStream_t s);
+void rounde(const size_t n, const float *in, float *out, cudaStream_t s);
 void cast_float_2_int(const size_t n, const float *src, int *dst,
                       cudaStream_t s);
 void cast_int_2_float(const size_t n, const int *src, float *dst,
@@ -80,9 +84,9 @@
 void mult(const size_t n, const float *in, const float x, float *out,
           cudaStream_t s);
 
-void broadcast_to(const size_t n, size_t nDim, const float *in,
-                  const float *shape, const float *stride, float *out,
-                  cudaStream_t s);
+void traverse_unary_transform(const size_t n, size_t nDim, const float *in,
+                              const int *shape, const int *stride, float *out,
+                              cudaStream_t s);
 
 void div(const size_t n, const float x, const float *in, float *out,
          cudaStream_t s);
@@ -103,6 +107,11 @@
 void ge(const size_t num, const float *in1, const float *in2, float *out,
         cudaStream_t s);
 
+void eq(const size_t num, const float *in, const float x, float *out,
+        cudaStream_t s);
+void eq(const size_t num, const float *in1, const float *in2, float *out,
+        cudaStream_t s);
+
 void lt(const size_t num, const float *in, const float x, float *out,
         cudaStream_t s);
 void lt(const size_t num, const float *in1, const float *in2, float *out,
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 2dfee71..99d9e2a 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -173,7 +173,7 @@
               [thisRef, ret](Context *ctx) mutable {
                 CastCopy<LDType, RDType, Lang>(&thisRef, &ret, ctx);
               },
-              {this->block()}, {ret.block()});
+              {this->block()}, {ret.block()}, "AsType");
         });
     return ret;
   } else {
@@ -205,24 +205,32 @@
 
 template <typename DType>
 void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num,
-                                 const size_t offset) {
+                                 const size_t offset) const {
   CHECK_EQ(sizeof(DType), SizeOf(data_type_))
       << "data_type is " << DataType_Name(data_type_)
       << " user given type is of size " << sizeof(DType);
   if (src != nullptr) {
-    device_->CopyDataFromHostPtr(block(), src, sizeof(DType) * num,
-                                 sizeof(DType) * offset);
+    Device *dev = device_.get();
+    const Tensor &thisRef = *this;
+    size_t nBytes = sizeof(DType) * num;
+    size_t dst_offset = sizeof(DType) * offset;
+    device_->Exec(
+        [dev, thisRef, src, nBytes, dst_offset](Context *ctx) mutable {
+          dev->CopyDataFromHostPtr(thisRef.block(), src, nBytes, dst_offset,
+                                   ctx);
+        },
+        {}, {block()}, "CopyDataFromHostPtr");
   } else {
     LOG(WARNING) << "Copy data from null host ptr";
   }
 }
 template void Tensor::CopyDataFromHostPtr(const unsigned char *src,
                                           const size_t num,
-                                          const size_t offset);
+                                          const size_t offset) const;
 template void Tensor::CopyDataFromHostPtr(const float *src, const size_t num,
-                                          const size_t offset);
+                                          const size_t offset) const;
 template void Tensor::CopyDataFromHostPtr(const int *src, const size_t num,
-                                          const size_t offset);
+                                          const size_t offset) const;
 
 void Tensor::CopyData(const Tensor &src) {
   CHECK_EQ(Size(), src.Size());
@@ -407,7 +415,7 @@
 
 Tensor Tensor::Clone(std::shared_ptr<Device> device) const {
   if (device == nullptr) device = device_;
-  Tensor t(shape_, device_, data_type_);
+  Tensor t(shape_, device, data_type_);
   // t.transpose_ = transpose_;
   t.stride_ = stride_;
   t.CopyData(*this);
@@ -422,15 +430,19 @@
   return;
 }
 
-Tensor &Tensor::Broadcast(const Shape &shape) {
+Tensor &Tensor::Broadcast(const Shape &shape, const int ignore_last_dim) {
   // TODO(wangwei) do we need to transform the mem layout if the tensor was
   // transposed?
   auto m = shape_.size() - 1, n = shape.size() - 1;
-  for (size_t i = 0; i <= std::min(m, n); i++) {
-    if ((shape.at(n - i) != shape_.at(m - i)) && (shape.at(n - i) != 1)) {
-      CHECK_EQ(shape_.at(m - i), 1) << "i= " << i << "\n";  // << Backtrace();
-      shape_.at(m - i) = shape.at(n - i);
-      stride_.at(m - i) = 0;
+  // ignore_last_dim is useful for mult broadcast
+  // e.g. (2,3,4)x(4,5) to (2,3,4)x(2,4,5)
+  if (ignore_last_dim < std::min(m, n) + 1) {
+    for (size_t i = ignore_last_dim; i <= std::min(m, n); i++) {
+      if ((shape.at(n - i) != shape_.at(m - i)) && (shape.at(n - i) != 1)) {
+        CHECK_EQ(shape_.at(m - i), 1) << "i= " << i << "\n";  // << Backtrace();
+        shape_.at(m - i) = shape.at(n - i);
+        stride_.at(m - i) = 0;
+      }
     }
   }
   if (m < n) {
@@ -442,9 +454,10 @@
   return *this;
 }
 
-Tensor Broadcast(const Tensor &in, const Shape &shape) {
+Tensor Broadcast(const Tensor &in, const Shape &shape,
+                 const int ignore_last_dim) {
   Tensor out(in);
-  return out.Broadcast(shape);
+  return out.Broadcast(shape, ignore_last_dim);
 }
 
 Tensor &Tensor::T() {
@@ -552,24 +565,34 @@
   CHECK_GE(src.MemSize(), s_offset + nBytes);
   CHECK_GE(dst->MemSize(), d_offset + nBytes);
 
+  Device *dev = nullptr;
+  CopyDirection direct;
   std::shared_ptr<Device> src_dev = src.device(), dst_dev = dst->device();
-  Block *from = src.block(), *to = dst->block();
   if (dst_dev->lang() != src_dev->lang()) {
     // let the none cpp device conduct copy op
     if (dst_dev->lang() == kCpp) {
-      src_dev->CopyDataToFrom(to, from, nBytes, kDeviceToHost, (int)d_offset,
-                              (int)s_offset);
+      dev = src_dev.get();
+      direct = kDeviceToHost;
     } else if (src_dev->lang() == kCpp) {
-      dst_dev->CopyDataToFrom(to, from, nBytes, kHostToDevice, (int)d_offset,
-                              (int)s_offset);
+      dev = dst_dev.get();
+      direct = kHostToDevice;
     } else {
-      LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device";
+      LOG(FATAL) << "Not support mem copy between Cuda and OpenCL device";
     }
   } else {
-    auto direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
-    src_dev->CopyDataToFrom(to, from, nBytes, direct, (int)d_offset,
-                            (int)s_offset);
+    dev = src_dev.get();
+    direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
   }
+
+  Tensor &dstRef = *dst;
+  dev->Exec(
+      [dev, dstRef, src, nBytes, direct, d_offset,
+       s_offset](Context *ctx) mutable {
+        Block *from = src.block(), *to = dstRef.block();
+        dev->CopyDataToFrom(to, from, nBytes, direct, (int)d_offset,
+                            (int)s_offset, ctx);
+      },
+      {src.block()}, {dst->block()}, "CopyDataToFrom");
 }
 
 void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
@@ -603,31 +626,42 @@
       chunk *= src.shape()[i];
     }
   }
+
+  Device *dev = nullptr;
+  CopyDirection direct;
+  std::shared_ptr<Device> src_dev = src.device(), dst_dev = dst->device();
+  if (dst_dev->lang() != src_dev->lang()) {
+    // let the none cpp device conduct copy op
+    if (dst_dev->lang() == kCpp) {
+      dev = src_dev.get();
+      direct = kDeviceToHost;
+    } else if (src_dev->lang() == kCpp) {
+      dev = dst_dev.get();
+      direct = kHostToDevice;
+    } else {
+      LOG(FATAL)
+          << "Not support mem repeat copy between Cuda and OpenCL device";
+    }
+  } else {
+    dev = src_dev.get();
+    direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
+  }
+
   int dst_offset = 0;
   int src_offset = 0;
-  std::shared_ptr<Device> src_dev = src.device(), dst_dev = dst->device();
-  Block *from = src.block(), *to = dst->block();
+  Tensor &dstRef = *dst;
   for (int i = 0; i < shape_outer; i++) {
     for (int j = 0; j < axis_shape; j++) {
       int temp = broadcast_flag ? repeats[0] : repeats[j];
       for (int k = 0; k < temp; k++) {
-        if (dst_dev->lang() != src_dev->lang()) {
-          // let the none cpp device conduct copy op
-          if (dst_dev->lang() == kCpp) {
-            src_dev->CopyDataToFrom(to, from, chunk, kDeviceToHost, dst_offset,
-                                    src_offset);
-          } else if (src_dev->lang() == kCpp) {
-            dst_dev->CopyDataToFrom(to, from, chunk, kHostToDevice, dst_offset,
-                                    src_offset);
-          } else {
-            LOG(FATAL)
-                << "Not support mem repeat copy betwee Cuda and OpenCL device";
-          }
-        } else {
-          auto direct = src_dev->lang() == kCpp ? kHostToHost : kDeviceToDevice;
-          src_dev->CopyDataToFrom(to, from, chunk, direct, dst_offset,
-                                  src_offset);
-        }
+        dev->Exec(
+            [dev, dstRef, src, chunk, direct, dst_offset,
+             src_offset](Context *ctx) mutable {
+              Block *from = src.block(), *to = dstRef.block();
+              dev->CopyDataToFrom(to, from, chunk, direct, dst_offset,
+                                  src_offset, ctx);
+            },
+            {src.block()}, {dst->block()}, "CopyDataToFrom");
         dst_offset += chunk;
       }
       src_offset += chunk;
@@ -681,6 +715,12 @@
         { __VA_ARGS__ }                                        \
         break;                                                 \
       }                                                        \
+      case ((kInt << _SwitchShift) + kCuda): {                 \
+        typedef int DType;                                     \
+        typedef lang::Cuda Lang;                               \
+        { __VA_ARGS__ }                                        \
+        break;                                                 \
+      }                                                        \
       case ((kFloat32 << _SwitchShift) + kCpp): {              \
         typedef float DType;                                   \
         typedef lang::Cpp Lang;                                \
@@ -688,7 +728,7 @@
         break;                                                 \
       }                                                        \
       case ((kInt << _SwitchShift) + kCpp): {                  \
-        typedef float DType;                                   \
+        typedef int DType;                                     \
         typedef lang::Cpp Lang;                                \
         { __VA_ARGS__ }                                        \
         break;                                                 \
@@ -716,7 +756,7 @@
           Asum<DType, Lang>(*this, &ret, ctx);
           nrm = TypeCast<DType, float>(ret);
         },
-        {this->block()}, {});
+        {this->block()}, {}, "l1");
   });
   return nrm / Size();
 }
@@ -730,11 +770,9 @@
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     device_->Exec(
         [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Nrm2<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
+          Nrm2<DType, Lang>(*this, &nrm, ctx);
         },
-        {this->block()}, {});
+        {this->block()}, {}, "L1");
   });
   return nrm / Size();
 }
@@ -755,14 +793,14 @@
         [thisRef, x](Context *ctx) mutable {
           Set<DType, Lang>(x, &thisRef, ctx);
         },
-        {}, {ptr});
+        {}, {ptr}, "SetValue");
   });
 }
 template void Tensor::SetValue<float>(const float x);
 template void Tensor::SetValue<int>(const int x);
 
 template <typename SType>
-void Tensor::get_value(SType *value, const size_t num) {
+void Tensor::get_value(SType *value, const size_t num) const {
   CHECK(device_ == defaultDevice);
   Tensor t(shape_, device_, data_type_);
   // transform function arrange data in memory considering stride
@@ -770,16 +808,16 @@
   auto ptr = static_cast<const SType *>(t.block()->data());
   for (size_t i = 0; i < num; i++) value[i] = ptr[i];
 }
-template void Tensor::get_value<float>(float *value, const size_t num);
-template void Tensor::get_value<int>(int *value, const size_t num);
+template void Tensor::get_value<float>(float *value, const size_t num) const;
+template void Tensor::get_value<int>(int *value, const size_t num) const;
 
 // DEPRECATED
 template <typename SType>
-void Tensor::GetValue(SType *value, const size_t num) {
+void Tensor::GetValue(SType *value, const size_t num) const {
   get_value(value, num);
 }
-template void Tensor::GetValue<float>(float *value, const size_t num);
-template void Tensor::GetValue<int>(int *value, const size_t num);
+template void Tensor::GetValue<float>(float *value, const size_t num) const;
+template void Tensor::GetValue<int>(int *value, const size_t num) const;
 
 #define EltwiseUnaryTensorFn(fn, t, ret)                               \
   do {                                                                 \
@@ -789,7 +827,7 @@
           [t, retRef](Context *ctx) mutable {                          \
             fn<DType, Lang>(t, &retRef, ctx);                          \
           },                                                           \
-          {t.block()}, {ret->block()});                                \
+          {t.block()}, {ret->block()}, #fn);                           \
     });                                                                \
   } while (0)
 
@@ -803,7 +841,11 @@
   void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }
 
 GenUnaryTensorFn(Abs);
+GenUnaryTensorFn(Erf);
 GenUnaryTensorFn(Ceil);
+GenUnaryTensorFn(Floor);
+GenUnaryTensorFn(Round);
+GenUnaryTensorFn(RoundE);
 GenUnaryTensorFn(Exp);
 GenUnaryTensorFn(Log);
 GenUnaryTensorFn(ReLU);
@@ -899,7 +941,7 @@
           [in, outRef, fdout](Context *ctx) mutable {
             SoftMaxBackward<DType, Lang>(in, &outRef, fdout, ctx);
           },
-          {in.block(), fdout.block()}, {out->block()});
+          {in.block(), fdout.block()}, {out->block()}, "SoftmaxBackward");
     });
   } while (0);
 
@@ -922,36 +964,62 @@
           [lhs, rhs, retRef](Context *ctx) mutable {                       \
             fn<DType, Lang>(lhs, rhs, &retRef, ctx);                       \
           },                                                               \
-          {lhs.block(), rhs.block()}, {ret->block()});                     \
+          {lhs.block(), rhs.block()}, {ret->block()}, #fn);                \
     });                                                                    \
   } while (0)
 
-#define GenBinaryTensorFn(op, fn)                              \
-  Tensor op(const Tensor &lhs, const Tensor &rhs) {            \
-    if (lhs.shape() != rhs.shape()) {                          \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                 \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                 \
-      Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type()); \
-      fn(lhs_, rhs_, &ret);                                    \
-      return ret;                                              \
-    } else {                                                   \
-      Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());  \
-      fn(lhs, rhs, &ret);                                      \
-      return ret;                                              \
-    }                                                          \
-  }                                                            \
-  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) { \
-    CHECK_EQ(lhs.device(), ret->device());                     \
-    CHECK_EQ(rhs.device(), ret->device());                     \
-    if (lhs.shape() != rhs.shape()) {                          \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                 \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                 \
-      CHECK(lhs_.shape() == ret->shape());                     \
-      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);              \
-    } else {                                                   \
-      CHECK(lhs.shape() == ret->shape());                      \
-      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                \
-    }                                                          \
+#define GenBinaryTensorFn(op, fn)                                           \
+  Tensor op(const Tensor &lhs, const Tensor &rhs) {                         \
+    if (lhs.shape() != rhs.shape()) {                                       \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {     \
+        auto lhs_ = Broadcast(lhs, rhs.shape());                            \
+        auto rhs_ = Broadcast(rhs, lhs.shape());                            \
+        Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type());            \
+        fn(lhs_, rhs_, &ret);                                               \
+        return ret;                                                         \
+      } else {                                                              \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */\
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                      \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                      \
+        tmp_lhs = Broadcast(tmp_lhs, tmp_rhs.shape());                      \
+        tmp_rhs = Broadcast(tmp_rhs, tmp_lhs.shape());                      \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                         \
+        /* if lhs and rhs are both int, cast back to int */                 \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)             \
+          return ret.Clone().AsType(kInt);                                  \
+        return ret;                                                         \
+      }                                                                     \
+    } else {                                                                \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {     \
+        Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());             \
+        fn(lhs, rhs, &ret);                                                 \
+        return ret;                                                         \
+      } else {                                                              \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */\
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                      \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                      \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                         \
+        /* if lhs and rhs are both int, cast back to int */                 \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)             \
+          return ret.Clone().AsType(kInt);                                  \
+        return ret;                                                         \
+      }                                                                     \
+    }                                                                       \
+  }                                                                         \
+  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) {              \
+    CHECK_EQ(lhs.device(), ret->device());                                  \
+    CHECK_EQ(rhs.device(), ret->device());                                  \
+    if (lhs.shape() != rhs.shape()) {                                       \
+      auto lhs_ = Broadcast(lhs, rhs.shape());                              \
+      auto rhs_ = Broadcast(rhs, lhs.shape());                              \
+      CHECK(lhs_.shape() == ret->shape());                                  \
+      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);                           \
+    } else {                                                                \
+      CHECK(lhs.shape() == ret->shape());                                   \
+      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                             \
+    }                                                                       \
   }  // namespace singa
 
 // boradcasting operations:
@@ -965,34 +1033,50 @@
 GenBinaryTensorFn(operator<=, LE);
 GenBinaryTensorFn(operator>, GT);
 GenBinaryTensorFn(operator>=, GE);
+GenBinaryTensorFn(operator==, EQ);
 GenBinaryTensorFn(ReLUBackward, ReLUBackward);
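The macro above encodes an implicit promotion rule: any operand pair that is not already float32/float32 is computed in float, and the result is cast back to int only when both operands were int. A standalone sketch of that rule (hypothetical Dtype enum, not the SINGA DataType):

#include <cassert>

// Hypothetical Dtype enum standing in for kFloat32 / kInt; not the SINGA DataType.
enum class Dtype { F32, I32 };
struct Promotion { Dtype compute; Dtype result; };

// Compute in float unless both operands are already float; cast the result
// back to int only when both operands were int (mirrors GenBinaryTensorFn).
Promotion PromoteBinary(Dtype lhs, Dtype rhs) {
  if (lhs == Dtype::F32 && rhs == Dtype::F32) return {Dtype::F32, Dtype::F32};
  if (lhs == Dtype::I32 && rhs == Dtype::I32) return {Dtype::F32, Dtype::I32};
  return {Dtype::F32, Dtype::F32};  // mixed int/float stays float
}

int main() {
  assert(PromoteBinary(Dtype::I32, Dtype::I32).result == Dtype::I32);
  assert(PromoteBinary(Dtype::I32, Dtype::F32).result == Dtype::F32);
  return 0;
}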
 
 #define EltwiseTensorScalarFn(fn, t, x, ret)                            \
   do {                                                                  \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \
-      static_assert(std::is_same<SType, DType>::value,                  \
-                    "The Scalar type must match the Tensor data type"); \
       Tensor &retRef = *ret;                                            \
       ret->device()->Exec(                                              \
           [t, x, retRef](Context *ctx) mutable {                        \
             fn<DType, Lang>(t, x, &retRef, ctx);                        \
           },                                                            \
-          {t.block()}, {ret->block()});                                 \
+          {t.block()}, {ret->block()}, #fn);                            \
     });                                                                 \
   } while (0)
 
-#define GenTensorScalarFn(op, fn)                             \
-  template <typename SType>                                   \
-  Tensor op(const Tensor &in, const SType x) {                \
-    Tensor ret(in.shape(), in.device(), in.data_type());      \
-    fn(in, x, &ret);                                          \
-    return ret;                                               \
-  }                                                           \
-  template <typename SType>                                   \
-  void fn(const Tensor &in, const SType x, Tensor *ret) {     \
-    EltwiseTensorScalarFn(fn, in, x, ret);                    \
-  }                                                           \
-  template Tensor op<float>(const Tensor &in, const float x); \
+#define GenTensorScalarFn(op, fn)                                          \
+  template <typename SType>                                                \
+  Tensor op(const Tensor &in, const SType x) {                             \
+    if (in.data_type() == kFloat32 && std::is_same<SType, float>::value) { \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      fn(in, x, &ret);                                                     \
+      return ret;                                                          \
+    } else if (in.data_type() == kFloat32) {                               \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      float tmp_x = x;                                                     \
+      fn(in, tmp_x, &ret);                                                 \
+      return ret;                                                          \
+    } else {                                                               \
+      /* tensor and scalar are not both in float, cast to float */         \
+      Tensor tmp_in = in.Clone().AsType(kFloat32);                         \
+      float tmp_x = x;                                                     \
+      Tensor ret(tmp_in.shape(), tmp_in.device(), tmp_in.data_type());     \
+      fn(tmp_in, tmp_x, &ret);                                             \
+      /* if tensor and scalar are both int, cast back to int */            \
+      if (in.data_type() == kInt && std::is_same<SType, int>::value)       \
+        return ret.Clone().AsType(kInt);                                   \
+      return ret;                                                          \
+    }                                                                      \
+  }                                                                        \
+  template <typename SType>                                                \
+  void fn(const Tensor &in, const SType x, Tensor *ret) {                  \
+    EltwiseTensorScalarFn(fn, in, x, ret);                                 \
+  }                                                                        \
+  template Tensor op<float>(const Tensor &in, const float x);              \
   template void fn<float>(const Tensor &in, const float x, Tensor *ret)
 
 GenTensorScalarFn(operator+, Add);
@@ -1004,6 +1088,8 @@
 GenTensorScalarFn(operator<=, LE);
 GenTensorScalarFn(operator>, GT);
 GenTensorScalarFn(operator>=, GE);
+GenTensorScalarFn(operator==, EQ);
+
 template <typename SType>
 Tensor Div(const SType alpha, const Tensor &in) {
   Tensor out(in.shape(), in.device(), in.data_type());
@@ -1023,7 +1109,7 @@
         [alpha, in, outRef](Context *ctx) mutable {
           Div<DType, Lang>(alpha, in, &outRef, ctx);
         },
-        {in.block()}, {out->block()});
+        {in.block()}, {out->block()}, "Div");
   });
 }
 template void Div<float>(const float, const Tensor &, Tensor *);
@@ -1065,7 +1151,7 @@
           Dot<DType, Lang>(in, one, &ret, ctx);
           s = ret;
         },
-        {in.block(), one.block()}, {});
+        {in.block(), one.block()}, {}, "Sum");
   });
   return s;
 }
@@ -1092,7 +1178,7 @@
         [in, one, out](Context *ctx) mutable {
           Dot<DType, Lang>(in, one, &out, ctx);
         },
-        {in.block(), one.block()}, {out.block()});
+        {in.block(), one.block()}, {out.block()}, "SumAll");
   });
   return out;
 }
@@ -1107,7 +1193,7 @@
           // size_t ncol = in.Size() / nrow;
           RowMax<DType, Lang>(in, &ret, ctx);
         },
-        {in.block()}, {ret.block()});
+        {in.block()}, {ret.block()}, "RowMax");
   });
   return ret;
 }
@@ -1324,7 +1410,7 @@
         [MRef, v](Context *ctx) mutable {
           DGMM<DType, Lang>(false, MRef, v, &MRef, ctx);
         },
-        {M->block(), v.block()}, {M->block()});
+        {M->block(), v.block()}, {M->block()}, "MultColumn");
   });
 }
 
@@ -1341,7 +1427,7 @@
         [MRef, v](Context *ctx) mutable {
           DGMM<DType, Lang>(true, MRef, v, &MRef, ctx);
         },
-        {M->block(), v.block()}, {M->block()});
+        {M->block(), v.block()}, {M->block()}, "MultRow");
   });
 }
 
@@ -1390,7 +1476,7 @@
         [prob, outRef](Context *ctx) mutable {
           Bernoulli<DType, Lang>(prob, &outRef, ctx);
         },
-        {}, {out->block()}, true);
+        {}, {out->block()}, "Bernoulli", true);
   });
 }
 
@@ -1406,7 +1492,7 @@
         [l, h, outRef](Context *ctx) mutable {
           Uniform<DType, Lang>(l, h, &outRef, ctx);
         },
-        {}, {out->block()}, true);
+        {}, {out->block()}, "Uniform", true);
   });
 }
 
@@ -1422,7 +1508,7 @@
         [m, s, outRef](Context *ctx) mutable {
           Gaussian<DType, Lang>(m, s, &outRef, ctx);
         },
-        {}, {out->block()}, true);
+        {}, {out->block()}, "Gaussian", true);
   });
 }
 template void Gaussian<float>(const float mean, const float std, Tensor *out);
@@ -1439,26 +1525,42 @@
         [a, in, outRef, fake](Context *ctx) mutable {
           Axpy<DType, Lang>(a, in, &outRef, ctx);
         },
-        {in.block(), out->block()}, {out->block()});
+        {in.block(), out->block()}, {out->block()}, "Axpy");
   });
 }
 
 template void Axpy<float>(const float alpha, const Tensor &in, Tensor *out);
 
+void Axpy(const Tensor &alpha, const Tensor &in, Tensor *out) {
+  TYPE_SWITCH(alpha.data_type(), SType, {
+    TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
+      Tensor fake(*out);
+      Tensor &outRef = *out;
+      out->device()->Exec(
+          [alpha, in, outRef, fake](Context *ctx) mutable {
+            Tensor alphaHost = alpha.Clone(defaultDevice);
+            // synchronize the stream to wait for the data transfer to complete
+            alpha.device()->Sync();
+            const SType value =
+                static_cast<const SType *>(alphaHost.block()->data())[0];
+            auto a = TypeCast<SType, DType>(value);
+            Axpy<DType, Lang>(a, in, &outRef, ctx);
+          },
+          {alpha.block(), in.block(), out->block()}, {out->block()}, "Axpy");
+    });
+  });
+}
+
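The new Axpy overload reads a device-resident scalar on the host (clone to the default device, sync, then dereference) before dispatching the usual Axpy kernel. The same pattern with raw CUDA/cuBLAS, as an illustrative sketch only (handle h and the device pointers are assumptions; error checks omitted):

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Hypothetical helper: cuBLAS handle h and device pointers d_alpha, d_x, d_y
// are assumed to be valid; error checks omitted for brevity.
void AxpyWithDeviceAlpha(cublasHandle_t h, const float *d_alpha,
                         const float *d_x, float *d_y, int n) {
  float alpha = 0.f;
  // blocking device-to-host copy, so alpha is ready before the call below
  cudaMemcpy(&alpha, d_alpha, sizeof(float), cudaMemcpyDeviceToHost);
  cublasSaxpy(h, n, &alpha, d_x, 1, d_y, 1);  // y = alpha * x + y
}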
 Tensor Mult(const Tensor &A, const Tensor &B) {
-  Shape s;
-  s.push_back(A.shape(0));
-  if (B.nDim() == 2) s.push_back(B.shape(1));
-  if (A.nDim() > 2) {
-    // for n>2 dim
-    // A {..., m1, m2} x B {..., m2, m3} = C {..., m1, m3}
-    s = A.shape();
-    s.pop_back();
-    s.push_back(B.shape(B.nDim() - 1));
-  }
+  auto A_ = Broadcast(A, B.shape(), 2);
+  auto B_ = Broadcast(B, A.shape(), 2);
+
+  Shape s = A_.shape();
+  s.pop_back();
+  s.push_back(B.shape(B.nDim() - 1));
 
   Tensor out(s, A.device(), A.data_type());
-  Mult(A, B, &out);
+  Mult(A_, B_, &out);
   return out;
 }
 
@@ -1485,7 +1587,7 @@
           [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
             GEMV<DType, Lang>(a, A, B, b, &CRef, ctx);
           },
-          read_blocks, {C->block()});
+          read_blocks, {C->block()}, "GEMV");
     });
   } else if (B.nDim() == 2u) {
     CHECK_EQ(A.shape().size(), 2u);
@@ -1498,7 +1600,7 @@
           [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
             GEMM<DType, Lang>(a, A, B, b, &CRef, ctx);
           },
-          read_blocks, {C->block()});
+          read_blocks, {C->block()}, "GEMM");
     });
   } else if (B.nDim() == 3u || B.nDim() == 4u) {
     CHECK_EQ(A.shape().size(), B.shape().size());
@@ -1510,14 +1612,14 @@
       Tensor A_tmp;
       Tensor B_tmp;
 
-      if (A.transpose()) {
+      if (A.transpose() || A.broadcasted()) {
         A_tmp = Tensor(A.shape(), A.device(), A.data_type());
         singa::Transform(A, &A_tmp);
       } else {
         A_tmp = A;
       }
 
-      if (B.transpose()) {
+      if (B.transpose() || B.broadcasted()) {
         B_tmp = Tensor(B.shape(), B.device(), B.data_type());
         singa::Transform(B, &B_tmp);
       } else {
@@ -1533,7 +1635,7 @@
           [a, A_tmp, b, B_tmp, CRef, fakeC](Context *ctx) mutable {
             GEMMBatched<DType, Lang>(a, A_tmp, B_tmp, b, &CRef, ctx);
           },
-          read_blocks, {C->block()});
+          read_blocks, {C->block()}, "GEMMBatched");
     });
   } else {
     LOG(FATAL) << "Un-supported tensor dimentions " << A.nDim() << "d matmul "
@@ -1571,7 +1673,7 @@
                                            p.block(), t.block(),
                                            lossRef.block(), ctx);
         },
-        {p.block(), t.block()}, {loss->block()});
+        {p.block(), t.block()}, {loss->block()}, "ComputeCrossEntropy");
   });
 }
 
@@ -1591,10 +1693,24 @@
                                               pRef.block(), t.block(),
                                               pRef.block(), ctx);
         },
-        {p->block(), t.block()}, {p->block()});
+        {p->block(), t.block()}, {p->block()}, "SoftmaxCrossEntropyBackward");
   });
 }
 
+Tensor &Tensor::Contiguous() {
+  if (transpose()) {
+    Tensor t(shape_, device_, data_type_);
+    singa::Transform(*this, &t);
+    std::swap(t.block_, block_);
+  }
+  return *this;
+}
+
+Tensor Contiguous(const Tensor &in) {
+  Tensor out(in);
+  return out.Contiguous();
+}
+
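Contiguous() materializes a transposed (strided) tensor into freshly allocated row-major storage via Transform. A rough standalone illustration of that idea for the 2-D case, using plain vectors and explicit strides rather than the SINGA types:

#include <cstddef>
#include <vector>

// Gather a strided (e.g. transposed) rows x cols view into a new row-major
// buffer; hypothetical free function, not the SINGA Transform implementation.
std::vector<float> Contiguous2D(const std::vector<float> &data, size_t rows,
                                size_t cols, size_t stride_row,
                                size_t stride_col) {
  std::vector<float> out(rows * cols);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      out[r * cols + c] = data[r * stride_row + c * stride_col];
  return out;
}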
 // if tensor is not transposed yet, we change the shape and generate new stride
 // if tensor is already transposed, we reallocate the memory and generate stride
 Tensor &Tensor::Reshape(const Shape &shape) {
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 19d5178..3236e7c 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@
   LOG(FATAL) << "Abs Not Implemented";
 }
 
+template <typename DType, typename Lang>
+void Erf(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Erf Not Implemented";
+}
+
 template <typename DTypeSrc, typename DTypeDst, typename Lang>
 void CastCopy(const Tensor *src, Tensor *dst, Context *ctx) {
   LOG(FATAL) << "CastCopy Not Implemented";
@@ -96,6 +101,21 @@
   LOG(FATAL) << "Ceil Not Implemented";
 }
 
+template <typename DType, typename Lang>
+void Floor(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Floor Not Implemented";
+}
+
+template <typename DType, typename Lang>
+void Round(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Round Not Implemented";
+}
+
+template <typename DType, typename Lang>
+void RoundE(const Tensor &in, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Round Not Implemented";
+}
+
 /// out[i] = in[i] + x
 template <typename DType, typename Lang>
 void Add(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
@@ -205,6 +225,16 @@
 void GT(const Tensor &in, const Tensor &in2, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Tensor-Tensor GT Not Implemented";
 }
+/// out[i]=(in[i]==x)?1.f:0.f
+template <typename DType, typename Lang>
+void EQ(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "EQ Not Implemented";
+}
+/// out[i]=(in[i]==in2[i])?1.f:0.f
+template <typename DType, typename Lang>
+void EQ(const Tensor &in, const Tensor &in2, Tensor *out, Context *ctx) {
+  LOG(FATAL) << "Tensor-Tensor EQ Not Implemented";
+}
 /// out[i] = pow(in[i], x)
 template <typename DType, typename Lang>
 void Pow(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index b43d523..5be46c6 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -241,6 +241,11 @@
 }
 
 template <>
+void Erf<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  traverse_unary<float>(in, out, [](float x) { return erff(x); });
+}
+
+template <>
 void CastCopy<float, int, lang::Cpp>(const Tensor *src, Tensor *dst,
                                      Context *ctx) {
   int *dst_array = static_cast<int *>(dst->block()->mutable_data());
@@ -261,6 +266,28 @@
   traverse_unary<float>(in, out, [](float x) { return std::ceil(x); });
 }
 
+template <>
+void Floor<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  traverse_unary<float>(in, out, [](float x) { return std::floor(x); });
+}
+
+template <>
+void Round<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  traverse_unary<float>(in, out, [](float x) { return std::round(x); });
+}
+
+template <>
+void RoundE<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  traverse_unary<float>(in, out, [](float x) {
+    // round half to even (banker's rounding), e.g. 2.5 -> 2.0, 3.5 -> 4.0
+    float doub = x * 2;
+    if (ceilf(doub) == doub) {
+      return std::round(x / 2) * 2;
+    } else {
+      return std::round(x);
+    }
+  });
+}
+
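For reference, a minimal standalone check of the tie-to-even behaviour implemented above (a sketch, not part of the patch):

#include <cassert>
#include <cmath>

// Same tie-to-even rule as the RoundE lambda above, written as a free function.
float RoundHalfToEven(float x) {
  float doubled = x * 2;
  if (std::ceil(doubled) == doubled) return std::round(x / 2) * 2;
  return std::round(x);
}

int main() {
  assert(RoundHalfToEven(2.5f) == 2.0f);  // tie -> even neighbour
  assert(RoundHalfToEven(3.5f) == 4.0f);  // tie -> even neighbour
  assert(RoundHalfToEven(2.3f) == 2.0f);  // non-tie behaves like std::round
  return 0;
}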
 #ifdef USE_DNNL
 template <>
 void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
@@ -309,7 +336,26 @@
                                         {DNNL_ARG_DST, fdout_mem}});
   ctx->dnnl_stream.wait();
 }
+#else
+// native Softmax without DNNL
+template <>
+void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  CHECK_LE(in.nDim(), 2u) << "Axis is required for SoftMax on multi dimensional tensor";
+  out->CopyData(in);
+  size_t nrow = 1, ncol = in.Size(), size = ncol;
+  if (in.nDim() == 2u) {
+    nrow = in.shape(0);
+    ncol = size / nrow;
+    out->Reshape(Shape{nrow, ncol});
+  }
+  Tensor tmp = RowMax(*out);
+  SubColumn(tmp, out);
+  Exp(*out, out);
 
+  SumColumns(*out, &tmp);
+  DivColumn(tmp, out);
+  out->Reshape(in.shape());
+}
 #endif  // USE_DNNL
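The non-DNNL fallback above is the usual numerically stable softmax: subtract the row maximum, exponentiate, and normalize by the row sum. A compact standalone version of the same computation on a row-major matrix, for comparison:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-wise softmax over a row-major rows x cols matrix: subtract the row max
// (for numerical stability), exponentiate, then divide by the row sum.
void SoftmaxRows(std::vector<float> &m, size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    float *row = m.data() + r * cols;
    const float mx = *std::max_element(row, row + cols);
    float sum = 0.f;
    for (size_t c = 0; c < cols; ++c) {
      row[c] = std::exp(row[c] - mx);
      sum += row[c];
    }
    for (size_t c = 0; c < cols; ++c) row[c] /= sum;
  }
}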
 
 template <>
@@ -403,6 +449,13 @@
 }
 
 template <>
+void GE<int, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto ge_lambda_binary = [](int a, int b) { return (a >= b) ? 1.f : 0.f; };
+  traverse_binary<int>(in1, in2, out, ge_lambda_binary);
+}
+
+template <>
 void GT<float, lang::Cpp>(const Tensor &in, const float x, Tensor *out,
                           Context *ctx) {
   auto gt_lambda = [&x](float a) { return (a > x) ? 1.f : 0.f; };
@@ -417,6 +470,13 @@
 }
 
 template <>
+void GT<int, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto gt_lambda_binary = [](int a, int b) { return (a > b) ? 1.f : 0.f; };
+  traverse_binary<int>(in1, in2, out, gt_lambda_binary);
+}
+
+template <>
 void LE<float, lang::Cpp>(const Tensor &in, const float x, Tensor *out,
                           Context *ctx) {
   auto le_lambda = [&x](float a) { return (a <= x) ? 1.f : 0.f; };
@@ -431,6 +491,13 @@
 }
 
 template <>
+void LE<int, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto le_lambda_binary = [](int a, int b) { return (a <= b) ? 1.f : 0.f; };
+  traverse_binary<int>(in1, in2, out, le_lambda_binary);
+}
+
+template <>
 void Log<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   auto ulog = [](float a) {
     CHECK_GT(a, 0.f);
@@ -454,6 +521,34 @@
 }
 
 template <>
+void LT<int, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto lt_lambda_binary = [](int a, int b) { return (a < b) ? 1.f : 0.f; };
+  traverse_binary<int>(in1, in2, out, lt_lambda_binary);
+}
+
+template <>
+void EQ<float, lang::Cpp>(const Tensor &in, const float x, Tensor *out,
+                          Context *ctx) {
+  auto eq_lambda = [&x](float a) { return (a == x) ? 1.f : 0.f; };
+  traverse_unary<float>(in, out, eq_lambda);
+}
+
+template <>
+void EQ<float, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto eq_lambda_binary = [](float a, float b) { return (a == b) ? 1.f : 0.f; };
+  traverse_binary<float>(in1, in2, out, eq_lambda_binary);
+}
+
+template <>
+void EQ<int, lang::Cpp>(const Tensor &in1, const Tensor &in2, Tensor *out,
+                          Context *ctx) {
+  auto eq_lambda_binary = [](int a, int b) { return (a == b) ? 1.f : 0.f; };
+  traverse_binary<int>(in1, in2, out, eq_lambda_binary);
+}
+
+template <>
 void Pow<float, lang::Cpp>(const Tensor &in, const float x, Tensor *out,
                            Context *ctx) {
   traverse_unary<float>(in, out, [x](float y) { return pow(y, x); });
@@ -564,6 +659,12 @@
 }
 
 template <>
+void Transform<int, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  auto identity = [](int a) { return a; };
+  traverse_unary<int>(in, out, identity);
+}
+
+template <>
 void Bernoulli<float, lang::Cpp>(const float p, Tensor *out, Context *ctx) {
   std::bernoulli_distribution distribution(p);
   float *outPtr = static_cast<float *>(out->block()->mutable_data());
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 8f12337..b3ff100 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -58,8 +58,10 @@
 */
 vector<int> generate_shape_cuda(const Tensor& x) {
   Shape shape = x.shape();
-  CHECK_LE(shape.size(), 5)
-      << "Dimensions (shape) beyond 5 are currently not supported";
+  // the maximum number of dimensions allowed is defined in cudnn.h as CUDNN_DIM_MAX
+  // TODO: check other side effects
+  CHECK_LE(shape.size(), CUDNN_DIM_MAX)
+      << "Dimensions (shape) beyond " << CUDNN_DIM_MAX << " are currently not supported";
   vector<int> shape_arr;
   if (shape.size() < 4) {
     for (int n = 0; n < 4 - int(shape.size()); ++n) {
@@ -73,12 +75,13 @@
 }
 
 int generate_dim_cuda(const Tensor& x) {
-  CHECK_LE(x.nDim(), 5)
-      << "Dimensions (shape) beyond 5 are currently not supported";
+  // the maximum number of dimensions allowed is defined in cudnn.h as CUDNN_DIM_MAX
+  CHECK_LE(x.nDim(), CUDNN_DIM_MAX)
+      << "Dimensions (shape) beyond " << CUDNN_DIM_MAX << " are currently not supported";
   if (x.shape().size() <= 4) {
     return 4;
   } else {
-    return 5;
+    return x.nDim();
   }
 }
 
@@ -206,30 +209,36 @@
                              generate_tensor_nd_desc(*out), outPtr));
 }
 
-inline Tensor get_broadcasted_tensor(const Tensor& in1, Context* ctx) {
-  Tensor in1Bc(in1.shape(), in1.device(), in1.data_type());
-  Tensor shape(Shape{in1.nDim()}, in1.device(), in1.data_type());
-  Tensor stride(Shape{in1.nDim()}, in1.device(), in1.data_type());
-  const vector<float> strideVec(in1.stride().begin(), in1.stride().end());
-  const vector<float> shapeVec(in1.shape().begin(), in1.shape().end());
+template <typename T>
+void TraverseUnaryTransformImpl(const Tensor& in1, Tensor* in1Bc,
+                                Context* ctx) {
+  Tensor shape(Shape{in1.nDim()}, in1.device(), kInt);
+  Tensor stride(Shape{in1.nDim()}, in1.device(), kInt);
+  const vector<int> strideVec(in1.stride().begin(), in1.stride().end());
+  const vector<int> shapeVec(in1.shape().begin(), in1.shape().end());
   shape.CopyDataFromHostPtr(shapeVec.data(), in1.nDim());
   stride.CopyDataFromHostPtr(strideVec.data(), in1.nDim());
+  const int* shapePtr = static_cast<const int*>(shape.block()->data());
+  const int* stridePtr = static_cast<const int*>(stride.block()->data());
 
-  const float* shapePtr = static_cast<const float*>(shape.block()->data());
-  const float* stridePtr = static_cast<const float*>(stride.block()->data());
-  const float* inPtr1 = static_cast<const float*>(in1.block()->data());
-  float* inBcPtr1 = static_cast<float*>(in1Bc.block()->mutable_data());
+  const T* inPtr1 = static_cast<const T*>(in1.block()->data());
+  T* inBcPtr1 = static_cast<T*>(in1Bc->block()->mutable_data());
 
-  const size_t n = Product(in1Bc.shape());
+  const size_t n = Product(in1Bc->shape());
 
-  cuda::broadcast_to(n, in1.nDim(), inPtr1, shapePtr, stridePtr, inBcPtr1,
-                     ctx->stream);
-
-  return in1Bc;
+  cuda::traverse_unary_transform(n, in1.nDim(), inPtr1, shapePtr, stridePtr,
+                                 inBcPtr1, ctx->stream);
 }
+template void TraverseUnaryTransformImpl<float>(const Tensor& in1,
+                                                Tensor* in1Bc, Context* ctx);
 
 template <>
 void Transform<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  if (in.broadcasted()) {
+    TraverseUnaryTransformImpl<float>(in, out, ctx);
+    return;
+  }
+
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
@@ -251,13 +260,7 @@
     float* outPtr = static_cast<float*>(out->block()->mutable_data());    \
     const size_t num = out->Size();                                       \
                                                                           \
-    int strideProduct1 = 1;                                               \
-    for (const auto& i : in1.stride()) strideProduct1 *= i;               \
-                                                                          \
-    int strideProduct2 = 1;                                               \
-    for (const auto& i : in2.stride()) strideProduct2 *= i;               \
-                                                                          \
-    if ((strideProduct1 * strideProduct2) != 0) {                         \
+    if (!in1.broadcasted() && !in2.broadcasted()) {                       \
       if (!in1.transpose() && !in2.transpose() &&                         \
           (in1.stride() == in2.stride())) {                               \
         kernel(num, inPtr1, inPtr2, outPtr, ctx->stream);                 \
@@ -278,16 +281,18 @@
         }                                                                 \
       }                                                                   \
     } else {                                                              \
-      Tensor in1Bc;                                                       \
-      Tensor in2Bc;                                                       \
-      if (strideProduct1 == 0) {                                          \
-        in1Bc = get_broadcasted_tensor(in1, ctx);                         \
-        inPtr1 = static_cast<const float*>(in1Bc.block()->data());        \
+      Tensor in1bc;                                                       \
+      Tensor in2bc;                                                       \
+      if (in1.broadcasted()) {                                            \
+        in1bc = Tensor(in1.shape(), in1.device(), in1.data_type());       \
+        Transform<float, lang::Cuda>(in1, &in1bc, ctx);                   \
+        inPtr1 = static_cast<const float*>(in1bc.block()->data());        \
       }                                                                   \
                                                                           \
-      if (strideProduct2 == 0) {                                          \
-        in2Bc = get_broadcasted_tensor(in2, ctx);                         \
-        inPtr2 = static_cast<const float*>(in2Bc.block()->data());        \
+      if (in2.broadcasted()) {                                            \
+        in2bc = Tensor(in2.shape(), in2.device(), in2.data_type());       \
+        Transform<float, lang::Cuda>(in2, &in2bc, ctx);                   \
+        inPtr2 = static_cast<const float*>(in2bc.block()->data());        \
       }                                                                   \
                                                                           \
       kernel(num, inPtr1, inPtr2, outPtr, ctx->stream);                   \
@@ -363,6 +368,20 @@
 }
 
 template <>
+void Erf<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::erf(num, inPtr, outPtr, ctx->stream);
+  } else {  // otherwise transform in into out first, then apply in place
+    Transform<float, lang::Cuda>(in, out, ctx);
+    cuda::erf(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
+template <>
 void Ceil<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
   const float* inPtr = static_cast<const float*>(in.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
@@ -377,6 +396,48 @@
 }
 
 template <>
+void Floor<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::floor(num, inPtr, outPtr, ctx->stream);
+  } else {  // otherwise transform in into out first, then apply in place
+    Transform<float, lang::Cuda>(in, out, ctx);
+    cuda::floor(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
+template <>
+void Round<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::round(num, inPtr, outPtr, ctx->stream);
+  } else {  // otherwise transform in into out first, then apply in place
+    Transform<float, lang::Cuda>(in, out, ctx);
+    cuda::round(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
+template <>
+void RoundE<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::rounde(num, inPtr, outPtr, ctx->stream);
+  } else {  // otherwise transform in into out first, then apply in place
+    Transform<float, lang::Cuda>(in, out, ctx);
+    cuda::rounde(num, outPtr, outPtr, ctx->stream);
+  }
+}
+
+template <>
 void GE<float, lang::Cuda>(const Tensor& in, const float x, Tensor* out,
                            Context* ctx) {
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
@@ -445,6 +506,29 @@
   cuda::le(num, outPtr, 0.0, outPtr, ctx->stream);
 }
 
+template <>
+void EQ<float, lang::Cuda>(const Tensor& in, const float x, Tensor* out,
+                           Context* ctx) {
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  const size_t num = in.Size();
+
+  if (in.stride() == out->stride()) {
+    cuda::eq(num, inPtr, x, outPtr, ctx->stream);
+  } else {  // otherwise transform in into out first, then apply in place
+    Transform<float, lang::Cuda>(in, out, ctx);
+    cuda::eq(num, outPtr, x, outPtr, ctx->stream);
+  }
+}
+template <>
+void EQ<float, lang::Cuda>(const Tensor& in1, const Tensor& in2, Tensor* out,
+                           Context* ctx) {
+  Sub<float, lang::Cuda>(in1, in2, out, ctx);
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+  const size_t num = in1.Size();
+  cuda::eq(num, outPtr, 0.0, outPtr, ctx->stream);
+}
+
 /// Natual logarithm, the base is e, Neper number out[i]=ln(in[i]).
 template <>
 void Log<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
@@ -788,7 +872,16 @@
   auto rgen = ctx->curand_generator;
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
   const size_t num = out->Size();
-  CURAND_CHECK(curandGenerateNormal(rgen, outPtr, num, mean, std));
+
+  // curandGenerateNormal requires an even n; an odd n fails with CURAND_STATUS_LENGTH_NOT_MULTIPLE
+  if (num % 2 != 0) {
+    Tensor tmp(Shape{num + 1}, out->device());
+    float* outPtr_tmp = static_cast<float*>(tmp.block()->mutable_data());
+    CURAND_CHECK(curandGenerateNormal(rgen, outPtr_tmp, num + 1, mean, std));
+    CopyDataToFrom(out, tmp, num, 0, 0);
+  } else {
+    CURAND_CHECK(curandGenerateNormal(rgen, outPtr, num, mean, std));
+  }
 }
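curandGenerateNormal only accepts an even number of outputs, so the change above pads odd-sized requests by one element and copies back the first num values. A standalone sketch of that workaround with the raw CURAND API (generator gen and device buffer d_out are assumptions; error checks omitted):

#include <cuda_runtime.h>
#include <curand.h>

// Hypothetical helper: generator gen and device buffer d_out of length num are
// assumed to be valid; error checks omitted for brevity.
void GaussianFill(curandGenerator_t gen, float *d_out, size_t num, float mean,
                  float stddev) {
  if (num % 2 == 0) {
    curandGenerateNormal(gen, d_out, num, mean, stddev);
  } else {
    // pad to an even length, then copy back only the first num values
    float *d_tmp = nullptr;
    cudaMalloc(&d_tmp, (num + 1) * sizeof(float));
    curandGenerateNormal(gen, d_tmp, num + 1, mean, stddev);
    cudaMemcpy(d_out, d_tmp, num * sizeof(float), cudaMemcpyDeviceToDevice);
    cudaFree(d_tmp);
  }
}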
 
 // =========================Blas operations==================================
diff --git a/src/io/communicator.cc b/src/io/communicator.cc
index 0e93366..a64c79d 100644
--- a/src/io/communicator.cc
+++ b/src/io/communicator.cc
@@ -105,9 +105,6 @@
 void Communicator::setup() {
   CUDA_CHECK(cudaSetDevice(local_rank));
   NCCLCHECK(ncclCommInitRank(&comm, world_size, id, global_rank));
-  CUDA_CHECK(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
-  CUDA_CHECK(cudaStreamCreateWithFlags(&c1, cudaStreamNonBlocking));
-  CUDA_CHECK(cudaStreamCreateWithFlags(&c2, cudaStreamNonBlocking));
   CUDA_CHECK(cudaMalloc(&fusedSendBuff, maxSize * sizeof(float)));
   CUDA_CHECK(cudaMalloc(&fusedRecvBuff, maxSize * sizeof(float)));
   CUDA_CHECK(cudaEventCreateWithFlags(
@@ -134,7 +131,6 @@
   CUDA_CHECK(cudaMalloc(&xInd, (int)(sizeof(int) * maxSize)));
   CUDA_CHECK(cudaMalloc(&xVal, (int)(sizeof(float) * maxSize)));
   CUSPARSE_CHECK(cusparseCreate(&cusparse_handle));
-  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c2));
   nnz = (int *)malloc(sizeof(int));
   nnzAll = (int *)malloc(sizeof(int) * world_size);
   CUDA_CHECK(cudaMalloc(&nnzGPU, sizeof(int) * world_size));
@@ -143,9 +139,9 @@
 }
 
 void Communicator::allReduce(int size, void *sendbuff, void *recvbuff,
-                             ncclDataType_t ncclType) {
+                             ncclDataType_t ncclType, Context *ctx) {
   NCCLCHECK(ncclAllReduce((const void *)sendbuff, (void *)recvbuff, size,
-                          ncclType, ncclSum, comm, s));
+                          ncclType, ncclSum, comm, ctx->s));
 }
 
 void Communicator::generateBlocks(Tensor &t) {
@@ -179,14 +175,14 @@
   device_->Exec(
       [this](Context *ctx) mutable {
         // synchronizing on all the CUDA streams used by communicator
-        CUDA_CHECK(cudaEventRecord(event, s));
-        CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
-        CUDA_CHECK(cudaEventRecord(event, c1));
-        CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
-        CUDA_CHECK(cudaEventRecord(event, c2));
-        CUDA_CHECK(cudaStreamWaitEvent(NULL, event, 0));
+        CUDA_CHECK(cudaEventRecord(event, ctx->s));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->stream, event, 0));
+        CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->stream, event, 0));
+        CUDA_CHECK(cudaEventRecord(event, ctx->c2));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->stream, event, 0));
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Waiting");
 }
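Most of the communicator changes replace the member streams s/c1/c2 with the streams carried in Context and split each fused operation into separately named Exec calls. The cross-stream ordering they rely on is the standard event pattern, sketched here in isolation (not SINGA code):

#include <cuda_runtime.h>

// Record an event at the current tail of stream a, then make stream b wait on
// it; work queued on b afterwards cannot start before a's queued work is done.
void OrderStreams(cudaStream_t a, cudaStream_t b, cudaEvent_t ev) {
  cudaEventRecord(ev, a);
  cudaStreamWaitEvent(b, ev, 0);
}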
 
 Communicator::~Communicator() {
@@ -195,9 +191,6 @@
   if (UseMPI == true) MPICHECK(MPI_Finalize());
   CUDA_CHECK(cudaFree(fusedSendBuff));
   CUDA_CHECK(cudaFree(fusedRecvBuff));
-  CUDA_CHECK(cudaStreamDestroy(s));
-  CUDA_CHECK(cudaStreamDestroy(c1));
-  CUDA_CHECK(cudaStreamDestroy(c2));
 
   if (halfInitialized == true) {
     CUDA_CHECK(cudaFree(fusedSendBuffHalf));
@@ -226,47 +219,61 @@
     device_->Exec(
         [this, t](Context *ctx) mutable {
           // record the event of the default cuda stream and follow it
-          CUDA_CHECK(cudaEventRecord(event, NULL));
-          CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+          CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
+        },
+        prev_blocks_, prev_blocks_, "Waiting");
 
+    device_->Exec(
+        [this, t](Context *ctx) mutable {
           // memory copy to fusedBuff
           for (size_t i = 0; i < t.size(); i++) {
-            CUDA_CHECK(cudaMemcpyAsync(
-                (void *)(fusedSendBuff + sendBuffOffset),
-                (const void *)t[i].block()->mutable_data(),
-                t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+            CUDA_CHECK(
+                cudaMemcpyAsync((void *)(fusedSendBuff + sendBuffOffset),
+                                (const void *)t[i].block()->mutable_data(),
+                                t[i].Size() * sizeof(float),
+                                cudaMemcpyDeviceToDevice, ctx->c1));
             sendBuffOffset += t[i].Size();
           }
         },
-        prev_blocks_, blocks_);
+        prev_blocks_, blocks_, "Dist_c1_fusedSynch_filling");
+
   } else {
     // send the tensors in the buffer
     device_->Exec(
-        [this, t](Context *ctx) mutable {
+        [this](Context *ctx) mutable {
           // wait for the memcpy to complete
-          CUDA_CHECK(cudaEventRecord(event, c1));
-          CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
-
+          CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
+        },
+        prev_blocks_, prev_blocks_, "Waiting");
+    device_->Exec(
+        [this](Context *ctx) mutable {
           allReduce((int)sendBuffOffset, (void *)fusedSendBuff,
-                    (void *)fusedRecvBuff, ncclFloat);
-
+                    (void *)fusedRecvBuff, ncclFloat, ctx);
           sendBuffOffset = 0;
-
+        },
+        prev_blocks_, blocks_, "Dist_s_fusedSynch_allreduce");
+    device_->Exec(
+        [this](Context *ctx) mutable {
           // wait for the allreduce to complete
-          CUDA_CHECK(cudaEventRecord(event, s));
-          CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
+          CUDA_CHECK(cudaEventRecord(event, ctx->s));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
+        },
+        blocks_, blocks_, "Waiting");
+    device_->Exec(
+        [this, t](Context *ctx) mutable {
           // copy data back to tensors after allreduce
           size_t offset = 0;
           for (size_t i = 0; i < t.size(); i++) {
             CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
                                        (const void *)(fusedRecvBuff + offset),
                                        t[i].Size() * sizeof(float),
-                                       cudaMemcpyDeviceToDevice, c1));
+                                       cudaMemcpyDeviceToDevice, ctx->c1));
             offset += t[i].Size();
           }
         },
-        blocks_, blocks_);
+        blocks_, blocks_, "Dist_c1_fusedSynch_copyBackToTensor");
   }
 }
 
@@ -277,13 +284,17 @@
   device_->Exec(
       [this, t](Context *ctx) mutable {
         // record the event of the default cuda stream and follow it
-        CUDA_CHECK(cudaEventRecord(event, NULL));
-        CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
-
-        void *addr = t.block()->mutable_data();
-        allReduce(t.Size(), addr, addr, ncclFloat);
+        CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
       },
-      {t.block()}, {t.block()});
+      {t.block()}, {t.block()}, "Waiting");
+
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        void *addr = t.block()->mutable_data();
+        allReduce(t.Size(), addr, addr, ncclFloat, ctx);
+      },
+      {t.block()}, {t.block()}, "Dist_s_synch_allreduce");
 }
 
 void Communicator::fusedSynchHalf(vector<Tensor> &t, bool send) {
@@ -296,43 +307,59 @@
   if (!send) {
     // buffer the tensors and convert them into half
     device_->Exec(
-        [this, t](Context *ctx) mutable {
+        [this](Context *ctx) mutable {
           // record the event of the default cuda stream and follow it
-          CUDA_CHECK(cudaEventRecord(event, NULL));
-          CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
+          CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
+        },
+        prev_blocks_, prev_blocks_, "Waiting");
+    device_->Exec(
+        [this, t](Context *ctx) mutable {
           size_t offset = 0;
           // memory copy to fusedBuff
           for (size_t i = 0; i < t.size(); i++) {
-            CUDA_CHECK(cudaMemcpyAsync(
-                (void *)(fusedSendBuff + sendBuffOffset),
-                (const void *)t[i].block()->mutable_data(),
-                t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+            CUDA_CHECK(
+                cudaMemcpyAsync((void *)(fusedSendBuff + sendBuffOffset),
+                                (const void *)t[i].block()->mutable_data(),
+                                t[i].Size() * sizeof(float),
+                                cudaMemcpyDeviceToDevice, ctx->c1));
             sendBuffOffset += t[i].Size();
             offset += t[i].Size();
           }
         },
-        prev_blocks_, blocks_);
+        prev_blocks_, blocks_, "Dist_c1_fusedSynchHalf_filling");
   } else {
     // send the tensors in the buffer
     device_->Exec(
-        [this, t](Context *ctx) mutable {
+        [this](Context *ctx) mutable {
           cuda::float2half(sendBuffOffset, fusedSendBuff, fusedSendBuffHalf,
-                           c1);
-
+                           ctx->c1);
+        },
+        prev_blocks_, blocks_, "Dist_c1_fusedSynchHalf_float2half");
+    device_->Exec(
+        [this](Context *ctx) mutable {
           // wait for the memcpy to complete
-          CUDA_CHECK(cudaEventRecord(event, c1));
-          CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
-
+          CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
+        },
+        blocks_, blocks_, "Waiting");
+    device_->Exec(
+        [this](Context *ctx) mutable {
           allReduce((int)sendBuffOffset, (void *)fusedSendBuffHalf,
-                    (void *)fusedRecvBuffHalf, ncclHalf);
-
+                    (void *)fusedRecvBuffHalf, ncclHalf, ctx);
+        },
+        blocks_, blocks_, "Dist_s_fusedSynchHalf_allreduce");
+    device_->Exec(
+        [this](Context *ctx) mutable {
           // wait for the allreduce to complete
-          CUDA_CHECK(cudaEventRecord(event, s));
-          CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
-
+          CUDA_CHECK(cudaEventRecord(event, ctx->s));
+          CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
+        },
+        blocks_, blocks_, "Waiting");
+    device_->Exec(
+        [this, t](Context *ctx) mutable {
           cuda::half2float(sendBuffOffset, fusedRecvBuffHalf, fusedRecvBuff,
-                           c2);
+                           ctx->c2);
 
           sendBuffOffset = 0;
 
@@ -342,11 +369,11 @@
             CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
                                        (const void *)(fusedRecvBuff + offset),
                                        t[i].Size() * sizeof(float),
-                                       cudaMemcpyDeviceToDevice, c2));
+                                       cudaMemcpyDeviceToDevice, ctx->c2));
             offset += t[i].Size();
           }
         },
-        blocks_, blocks_);
+        blocks_, blocks_, "Dist_c2_fusedSynchHalf_half2floatcopy");
   }
 }
 
@@ -357,28 +384,44 @@
 
   device_->Exec(
       [this, t](Context *ctx) mutable {
-        float *addr = static_cast<float *>(t.block()->mutable_data());
-
         // record the event of the default cuda stream and follow it
-        CUDA_CHECK(cudaEventRecord(event, NULL));
-        CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
-
-        cuda::float2half(t.Size(), addr, fusedSendBuffHalf, c1);
-
-        // wait for conversion to half precision complete
-        CUDA_CHECK(cudaEventRecord(event, c1));
-        CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
-
-        allReduce(t.Size(), (void *)fusedSendBuffHalf,
-                  (void *)fusedRecvBuffHalf, ncclHalf);
-
-        // wait for the allreduce to complete
-        CUDA_CHECK(cudaEventRecord(event, s));
-        CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
-
-        cuda::half2float(t.Size(), fusedRecvBuffHalf, addr, c2);
+        CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Waiting");
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        float *addr = static_cast<float *>(t.block()->mutable_data());
+        cuda::float2half(t.Size(), addr, fusedSendBuffHalf, ctx->c1);
+      },
+      blocks_, blocks_, "Dist_c1_synchHalf_float2half");
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        // wait for conversion to half precision complete
+        CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
+      },
+      blocks_, blocks_, "Waiting");
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        allReduce(t.Size(), (void *)fusedSendBuffHalf,
+                  (void *)fusedRecvBuffHalf, ncclHalf, ctx);
+      },
+      blocks_, blocks_, "Dist_s_synchHalf_allreduce");
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        // wait for the allreduce to complete
+        CUDA_CHECK(cudaEventRecord(event, ctx->s));
+        CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
+      },
+      blocks_, blocks_, "Waiting");
+  device_->Exec(
+      [this, t](Context *ctx) mutable {
+        float *addr = static_cast<float *>(t.block()->mutable_data());
+        cuda::half2float(t.Size(), fusedRecvBuffHalf, addr, ctx->c2);
+      },
+      blocks_, blocks_, "Dist_c2_synchHalf_half2float");
+
 }
 
 void Communicator::sparsification(Tensor &t, Tensor &accumulation,
@@ -388,9 +431,9 @@
 
   device_->Exec(
       [=](Context *ctx) mutable {
-        _sparsification(t, &accumulation, sparsThreshold, topK);
+        _sparsification(t, &accumulation, sparsThreshold, topK, ctx);
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Dist_c1c2_sparsification");
 }
 
 void Communicator::sparsification(Tensor &t, float sparsThreshold, bool topK) {
@@ -398,24 +441,25 @@
 
   t.device()->Exec(
       [=](Context *ctx) mutable {
-        _sparsification(t, (Tensor *)NULL, sparsThreshold, topK);
+        _sparsification(t, (Tensor *)NULL, sparsThreshold, topK, ctx);
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Dist_c1c2_sparsification");
 }
 
 void Communicator::_sparsification(Tensor &t, Tensor *accumulation,
-                                   float sparsThreshold, bool topK) {
+                                   float sparsThreshold, bool topK,
+                                   Context *ctx) {
   // threshold for sprasification
   threshold = sparsThreshold;
 
   // record the event of the default cuda stream and follow it
-  CUDA_CHECK(cudaEventRecord(event, NULL));
-  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
 
   // memory copy to fusedBuff
   CUDA_CHECK(cudaMemcpyAsync(
       (void *)fusedSendBuff, (const void *)t.block()->mutable_data(),
-      t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c1));
+      t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, ctx->c1));
 
   float *accumPtr;
 
@@ -425,14 +469,14 @@
     accumPtr = NULL;
 
   if (topK == false)
-    valSparsAllReduce(t.Size(), accumPtr);
+    valSparsAllReduce(t.Size(), accumPtr, ctx);
   else
-    topKSparsAllReduce(t.Size(), accumPtr);
+    topKSparsAllReduce(t.Size(), accumPtr, ctx);
 
   // copy data back to tensor after allreduce
   CUDA_CHECK(cudaMemcpyAsync(
       (void *)t.block()->mutable_data(), (const void *)fusedRecvBuff,
-      t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, c2));
+      t.Size() * sizeof(float), cudaMemcpyDeviceToDevice, ctx->c2));
 }
 
 void Communicator::fusedSparsification(vector<Tensor> &t, Tensor &accumulation,
@@ -444,9 +488,9 @@
 
   device_->Exec(
       [=](Context *ctx) mutable {
-        _fusedSparsification(t, &accumulation, sparsThreshold, topK);
+        _fusedSparsification(t, &accumulation, sparsThreshold, topK, ctx);
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Dist_c1c2_fusedSparsification");
 }
 
 void Communicator::fusedSparsification(vector<Tensor> &t, float sparsThreshold,
@@ -457,19 +501,20 @@
 
   device_->Exec(
       [=](Context *ctx) mutable {
-        _fusedSparsification(t, (Tensor *)NULL, sparsThreshold, topK);
+        _fusedSparsification(t, (Tensor *)NULL, sparsThreshold, topK, ctx);
       },
-      blocks_, blocks_);
+      blocks_, blocks_, "Dist_c1c2_fusedSparsification");
 }
 
 void Communicator::_fusedSparsification(vector<Tensor> &t, Tensor *accumulation,
-                                        float sparsThreshold, bool topK) {
+                                        float sparsThreshold, bool topK,
+                                        Context *ctx) {
   // threshold for sprasification
   threshold = sparsThreshold;
 
   // record the event of the default cuda stream and follow it
-  CUDA_CHECK(cudaEventRecord(event, NULL));
-  CUDA_CHECK(cudaStreamWaitEvent(c1, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->stream));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
 
   size_t offset = 0;
 
@@ -478,7 +523,7 @@
     CUDA_CHECK(cudaMemcpyAsync((void *)(fusedSendBuff + offset),
                                (const void *)t[i].block()->mutable_data(),
                                t[i].Size() * sizeof(float),
-                               cudaMemcpyDeviceToDevice, c1));
+                               cudaMemcpyDeviceToDevice, ctx->c1));
     offset += t[i].Size();
   }
 
@@ -490,9 +535,9 @@
     accumPtr = NULL;
 
   if (topK == false)
-    valSparsAllReduce(offset, accumPtr);
+    valSparsAllReduce(offset, accumPtr, ctx);
   else
-    topKSparsAllReduce(offset, accumPtr);
+    topKSparsAllReduce(offset, accumPtr, ctx);
 
   // copy data back to tensors after allreduce
   offset = 0;
@@ -500,91 +545,94 @@
     CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
                                (const void *)(fusedRecvBuff + offset),
                                t[i].Size() * sizeof(float),
-                               cudaMemcpyDeviceToDevice, c2));
+                               cudaMemcpyDeviceToDevice, ctx->c2));
     offset += t[i].Size();
   }
 }
 
-void Communicator::valSparsAllReduce(size_t num, float *accumulation) {
+void Communicator::valSparsAllReduce(size_t num, float *accumulation,
+                                     Context *ctx) {
   if (sparsInitialized == false) sparsInit();
 
   if (accumulation != NULL) {
     // add the previous accumulation
-    cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
+    cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, ctx->c1);
     // backup the fusedSendBuff
     CUDA_CHECK(cudaMemcpyAsync((void *)backupBuff, (const void *)fusedSendBuff,
                                sizeof(float) * num, cudaMemcpyDeviceToDevice,
-                               c1));
+                               ctx->c1));
   }
 
   // sparsification based on threshold
-  cuda::sparsabs(num, threshold, fusedSendBuff, fusedSendBuff, c1);
+  cuda::sparsabs(num, threshold, fusedSendBuff, fusedSendBuff, ctx->c1);
 
   // output the gradient accumulation
   if (accumulation != NULL)
-    cuda::sub(num, backupBuff, fusedSendBuff, accumulation, c1);
+    cuda::sub(num, backupBuff, fusedSendBuff, accumulation, ctx->c1);
 
   // produce the index of the sparse array
-  cuda::sparsindex(num, fusedSendBuff, fusedIndex, c1);
+  cuda::sparsindex(num, fusedSendBuff, fusedIndex, ctx->c1);
 
   // remove zeros from the index to form the sparse array and get the number of non-zeros (nnz)
-  cuda::removezeroidx(num, fusedIndex, c1, nnz);
+  cuda::removezeroidx(num, fusedIndex, ctx->c1, nnz);
 
   CUDA_CHECK(cudaMemcpyAsync((void *)nnzGPU, (const void *)nnz, sizeof(int),
-                             cudaMemcpyHostToDevice, c1));
+                             cudaMemcpyHostToDevice, ctx->c1));
 
   // all-gather all the nnz from different ranks
   NCCLCHECK(ncclAllGather((const void *)nnzGPU, (void *)nnzAllGPU, 1, ncclInt,
-                          comm, c1));
+                          comm, ctx->c1));
 
   CUDA_CHECK(cudaMemcpyAsync((void *)nnzAll, (const void *)nnzAllGPU,
                              sizeof(int) * world_size, cudaMemcpyDeviceToHost,
-                             c1));
+                             ctx->c1));
 
-  CUDA_CHECK(cudaStreamSynchronize(c1));
+  CUDA_CHECK(cudaStreamSynchronize(ctx->c1));
 
   int nnzMax = 0;
   for (int i = 0; i < world_size; i++)
     if (nnzAll[i] > nnzMax) nnzMax = nnzAll[i];
 
   // remove zeros from the values to form the sparse array
-  cuda::removezeroval(num, fusedSendBuff, c1);
+  cuda::removezeroval(num, fusedSendBuff, ctx->c1);
 
   CUDA_CHECK(cudaMemcpyAsync((void *)(sparsSendBuff), (const void *)fusedIndex,
                              sizeof(int) * (*nnz), cudaMemcpyDeviceToDevice,
-                             c1));
+                             ctx->c1));
   CUDA_CHECK(cudaMemcpyAsync(
       (void *)(sparsSendBuff + (*nnz)), (const void *)fusedSendBuff,
-      sizeof(float) * (*nnz), cudaMemcpyDeviceToDevice, c1));
+      sizeof(float) * (*nnz), cudaMemcpyDeviceToDevice, ctx->c1));
 
   // wait for the memcpy to complete
-  CUDA_CHECK(cudaEventRecord(event, c1));
-  CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
 
   // all-gather all the sparse gradients
   NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
-                          2 * nnzMax, ncclFloat, comm, s));
+                          2 * nnzMax, ncclFloat, comm, ctx->s));
 
   // wait for the all-gather to complete
-  CUDA_CHECK(cudaEventRecord(event, s));
-  CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->s));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
 
   // reduce the sparse gradients; first set the sum buffer to zero
-  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), c2));
+  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), ctx->c2));
 
   size_t offset = 0;
   float alpha = 1.0;
 
   // add the sparse gradient from each rank to the sum buffer to finish the
   // all-reduce process
+  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c2));
+
   for (int i = 0; i < world_size; i++) {
-    CUDA_CHECK(
-        cudaMemcpyAsync((void *)xInd, (const void *)(sparsRecvBuff + offset),
-                        sizeof(int) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
+    CUDA_CHECK(cudaMemcpyAsync(
+        (void *)xInd, (const void *)(sparsRecvBuff + offset),
+        sizeof(int) * nnzAll[i], cudaMemcpyDeviceToDevice, ctx->c2));
     offset += nnzAll[i];
     CUDA_CHECK(cudaMemcpyAsync(
         (void *)xVal, (const void *)(sparsRecvBuff + offset),
-        sizeof(float) * nnzAll[i], cudaMemcpyDeviceToDevice, c2));
+        sizeof(float) * nnzAll[i], cudaMemcpyDeviceToDevice, ctx->c2));
     offset += (2 * nnzMax - nnzAll[i]);
     CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzAll[i], &alpha, xVal,
                                   xInd, fusedRecvBuff,
@@ -592,22 +640,23 @@
   }
 }
 
-void Communicator::topKSparsAllReduce(size_t num, float *accumulation) {
+void Communicator::topKSparsAllReduce(size_t num, float *accumulation,
+                                      Context *ctx) {
   if (sparsInitialized == false) sparsInit();
 
   // use gradient accumulation
   if (accumulation != NULL) {
     // add the previous accumulation
-    cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, c1);
+    cuda::add(num, fusedSendBuff, accumulation, fusedSendBuff, ctx->c1);
     // backup the fusedSendBuff
     CUDA_CHECK(cudaMemcpyAsync((void *)backupBuff, (const void *)fusedSendBuff,
                                sizeof(float) * num, cudaMemcpyDeviceToDevice,
-                               c1));
+                               ctx->c1));
   }
 
   // generate an index and sort the fusedSendBuff from large to small values
-  cuda::generateindex(num, fusedIndex, c1);
-  cuda::sortbykey(num, fusedSendBuff, fusedIndex, c1);
+  cuda::generateindex(num, fusedIndex, ctx->c1);
+  cuda::sortbykey(num, fusedSendBuff, fusedIndex, ctx->c1);
 
   // determine the number of topK for communication
   int nnzMax = (int)ceil(threshold * num);
@@ -615,51 +664,51 @@
   // output the gradient accumulation
   float alpha = 1.0;
   if (accumulation != NULL) {
-    CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float), c1));
-    CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c1));
+    CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float), ctx->c1));
+    CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c1));
     CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha,
                                   fusedSendBuff, fusedIndex, accumulation,
                                   CUSPARSE_INDEX_BASE_ONE));
-    cuda::sub(num, backupBuff, accumulation, accumulation, c1);
+    cuda::sub(num, backupBuff, accumulation, accumulation, ctx->c1);
   }
 
   // the topK value and index will be sent
   CUDA_CHECK(cudaMemcpyAsync((void *)(sparsSendBuff), (const void *)fusedIndex,
                              sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice,
-                             c1));
+                             ctx->c1));
   CUDA_CHECK(cudaMemcpyAsync(
       (void *)(sparsSendBuff + nnzMax), (const void *)fusedSendBuff,
-      sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c1));
+      sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c1));
 
   // wait for the memcpy to complete
-  CUDA_CHECK(cudaEventRecord(event, c1));
-  CUDA_CHECK(cudaStreamWaitEvent(s, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->c1));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
 
   // all-gather all the sparse gradients
   NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
-                          2 * nnzMax, ncclFloat, comm, s));
+                          2 * nnzMax, ncclFloat, comm, ctx->s));
 
   // wait for the all-gather to complete
-  CUDA_CHECK(cudaEventRecord(event, s));
-  CUDA_CHECK(cudaStreamWaitEvent(c2, event, 0));
+  CUDA_CHECK(cudaEventRecord(event, ctx->s));
+  CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
 
   // reduce the sparse gradients; first set the sum buffer to zero
-  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), c2));
+  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), ctx->c2));
 
   size_t offset = 0;
 
-  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, c2));
+  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c2));
 
   // add the sparse gradient from each rank to the sum buffer to finish the
   // all-reduce process
   for (int i = 0; i < world_size; i++) {
-    CUDA_CHECK(
-        cudaMemcpyAsync((void *)xInd, (const void *)(sparsRecvBuff + offset),
-                        sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, c2));
+    CUDA_CHECK(cudaMemcpyAsync(
+        (void *)xInd, (const void *)(sparsRecvBuff + offset),
+        sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
     offset += nnzMax;
-    CUDA_CHECK(
-        cudaMemcpyAsync((void *)xVal, (const void *)(sparsRecvBuff + offset),
-                        sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, c2));
+    CUDA_CHECK(cudaMemcpyAsync(
+        (void *)xVal, (const void *)(sparsRecvBuff + offset),
+        sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
     offset += nnzMax;
     CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha, xVal, xInd,
                                   fusedRecvBuff, CUSPARSE_INDEX_BASE_ONE));
diff --git a/src/model/layer/cudnn_activation.cc b/src/model/layer/cudnn_activation.cc
index d1b2141..806e714 100644
--- a/src/model/layer/cudnn_activation.cc
+++ b/src/model/layer/cudnn_activation.cc
@@ -81,7 +81,7 @@
     CUDNN_CHECK(cudnnActivationForward(
         ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_,
         inblock->data(), &beta, this->desc_, outblock->mutable_data()));
-  }, {input.block()}, {output.block()});
+  }, {input.block()}, {output.block()}, "cudnnActivationForward");
   if (flag & kTrain) {
     if (cudnn_mode_ == CUDNN_ACTIVATION_SIGMOID ||
         cudnn_mode_ == CUDNN_ACTIVATION_TANH) {
@@ -111,7 +111,7 @@
         ctx->cudnn_handle, this->acti_desc_, &alpha, this->desc_,
         yblock->data(), this->desc_, dyblock->data(), this->desc_,
         xblock->data(), &beta, this->desc_, dxblock->mutable_data()));
-  }, {grad.block(), inout.block()}, {dx.block()});
+  }, {grad.block(), inout.block()}, {dx.block()}, "cudnnActivationBackward");
   return std::make_pair(dx, param_grad);
 }
 }  // namespace singa
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 0aed832..44e1fef 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -197,7 +197,7 @@
                             this->workspace_.block()->mutable_data(),
                             this->workspace_count_ * sizeof(float), &beta,
                             this->y_desc_, outblock->mutable_data());
-  }, {input.block(), weight_.block()}, {output.block()}, workspace_.block());
+  }, {input.block(), weight_.block()}, {output.block(), workspace_.block()});
 
   if (bias_term_) {
     output.device()->Exec([output, this](Context * ctx) {
diff --git a/src/model/layer/cudnn_dropout.cc b/src/model/layer/cudnn_dropout.cc
index 65d7b42..e9c36ee 100644
--- a/src/model/layer/cudnn_dropout.cc
+++ b/src/model/layer/cudnn_dropout.cc
@@ -70,7 +70,7 @@
     if (!has_init_cudnn_) {
       input.device()->Exec([size, dtype, this, dev](Context* ctx) {
           this->InitCudnn(size, dtype, dev, ctx);
-          }, {}, {this->state_.block()});
+          }, {}, {this->state_.block()}, "InitCudnn");
     } else {
       int n, c, h, w, s;
       cudnnDataType_t type;
@@ -79,7 +79,7 @@
       if (size != static_cast<size_t>(w))
         input.device()->Exec([size, dtype, this, dev](Context* ctx) {
             this->InitCudnn(size, dtype, dev, ctx);
-            }, {}, {this->state_.block()});
+            }, {}, {this->state_.block()}, "InitCudnn");
     }
     Tensor output;
     output.ResetLike(input);
@@ -90,7 +90,7 @@
                           inblock->data(), this->y_desc_,
                           outblock->mutable_data(), mblock->mutable_data(),
                           this->reserve_size_);
-    }, {input.block()}, {output.block(), mask_.block()});
+    }, {input.block()}, {output.block(), mask_.block()}, "cudnnDropoutForward");
     return output;
   } else {
     return input;
@@ -110,7 +110,7 @@
                            dyblock->data(), this->x_desc_,
                            dxblock->mutable_data(), mblock->mutable_data(),
                            this->reserve_size_);
-    }, {grad.block(), mask_.block()}, {dx.block()});
+    }, {grad.block(), mask_.block()}, {dx.block()}, "cudnnDropoutBackward");
   } else {
     LOG(ERROR) << "Do not call backward for evaluation phase";
   }
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 74ca76f..dffac1e 100644
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -115,7 +115,7 @@
         ctx->dnnl_stream.wait();
       },
       {x.block(), w.block(), running_mean.block(), running_var.block()},
-      {y.block(), running_mean.block(), running_var.block()});
+      {y.block(), running_mean.block(), running_var.block()}, "CpuBatchNormForwardInference");
 
   return y;
 }
@@ -173,7 +173,7 @@
       },
       {x.block(), w.block(), running_mean.block(), running_var.block()},
       {y.block(), running_mean.block(), running_var.block(), mean.block(),
-       var.block()});
+       var.block()}, "CpuBatchNormForwardTraining");
 
   return {y, mean, var};
 }
@@ -238,7 +238,7 @@
         ctx->dnnl_stream.wait();
       },
       {x.block(), dy.block(), mean.block(), var.block(), w.block(), y.block()},
-      {dx.block(), dw.block()});
+      {dx.block(), dw.block()}, "CpuBatchNormBackwardx");
 
   singa::Tensor dbnScale(bnScale.shape());
   CopyDataToFrom(&dbnScale, dw, bnScale.Size(), 0, 0);
@@ -323,7 +323,7 @@
       {input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
        running_var.block()},
       {output.block(), running_mean.block(), running_var.block(), mean.block(),
-       var.block()});
+       var.block()}, "GpuBatchNormForwardTraining");
   if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
   return {output, mean, var};
 }
@@ -361,7 +361,7 @@
       },
       {input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
        running_var.block()},
-      {output.block()});
+      {output.block()}, "GpuBatchNormForwardInference");
   return output;
 }
 
@@ -396,7 +396,7 @@
             epsilon, mean.block()->data(), var.block()->data()));
       },
       {x.block(), dy.block(), bnScale.block(), mean.block(), var.block()},
-      {dx.block(), dbnScale.block(), dbnBias.block()});
+      {dx.block(), dbnScale.block(), dbnBias.block()}, "GpuBatchNormBackward");
 
   if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
 
diff --git a/src/model/operation/convolution.cc b/src/model/operation/convolution.cc
index e4e6306..052e521 100644
--- a/src/model/operation/convolution.cc
+++ b/src/model/operation/convolution.cc
@@ -18,12 +18,12 @@
  * under the License.
  *
  ************************************************************/
-#include "../layer/convolution.h"
-
-#include <cctype>
+// #include "../layer/convolution.h"
 
 #include "convolution.h"
 
+#include <cctype>
+
 namespace singa {
 
 ConvHandle::ConvHandle(const Tensor &input,
@@ -185,10 +185,11 @@
         // synchronize stream
         s.wait();
       },
-      {x.block(), W.block(), b.block()}, {output.block()});
+      {x.block(), W.block(), b.block()}, {output.block()}, "CpuConvForward");
 
   return output;
-#else   // cpp naive
+#else   // cpp naive, error due to importing Im2col
+/*
   Shape w_shape = W.shape();
   Shape b_shape;
   if (ch.bias_term) b_shape = b.shape();
@@ -219,6 +220,7 @@
   W.Reshape(w_shape);
   if (ch.bias_term) b.Reshape(b_shape);
   return output;
+*/
 #endif  // USE_DNNL
 }
 
@@ -279,11 +281,12 @@
                       {DNNL_ARG_DIFF_SRC, conv_user_src_memory}});
         ctx->dnnl_stream.wait();
       },
-      {x.block(), dy.block(), W.block()}, {dx.block()});
+      {x.block(), dy.block(), W.block()}, {dx.block()}, "CpuConvBackwardx");
 
   return dx;
 
 #else   // NOT USE_DNNL
+/*  // error due to importing Col2im
   Shape w_shape = W.shape();
   W.Reshape(Shape{ch.num_filters, ch.col_height});
 
@@ -303,6 +306,7 @@
   }
   W.Reshape(w_shape);
   return dx;
+*/
 #endif  // USE_DNNL
 }
 
@@ -372,10 +376,12 @@
                       {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});
         ctx->dnnl_stream.wait();
       },
-      {x.block(), dy.block(), W.block()}, {dW.block(), ch.db->block()});
+      {x.block(), dy.block(), W.block()}, {dW.block(), ch.db->block()},
+      "CpuConvBackwardW");
 
   return dW;
 #else   // native cpp
+/* // error due to importing Im2col
   Tensor dW;
   dW.ResetLike(W);
   dW.SetValue(0.0f);
@@ -398,6 +404,7 @@
   }
   dW.Reshape(w_shape);
   return dW;
+*/
 #endif  // USE_DNNL
 }
 
@@ -598,7 +605,8 @@
                                 cch.workspace_count * sizeof(float), &beta,
                                 cch.y_desc, outblock->mutable_data());
       },
-      {x.block(), W.block()}, {output.block(), cch.workspace.block()});
+      {x.block(), W.block()}, {output.block(), cch.workspace.block()},
+      "cudnnConvForward");
 
   if (cch.bias_term) {
     Tensor outputFake(output);
@@ -610,7 +618,7 @@
                          bblock->data(), &beta, cch.y_desc,
                          outblock->mutable_data());
         },
-        {output.block(), b.block()}, {output.block()});
+        {output.block(), b.block()}, {output.block()}, "cudnnAddTensor");
   }
 
   return output;
@@ -634,7 +642,8 @@
             cch.workspace_count * sizeof(float), &beta, cch.x_desc,
             dxblock->mutable_data());
       },
-      {dy.block(), W.block()}, {dx.block(), cch.workspace.block()});
+      {dy.block(), W.block()}, {dx.block(), cch.workspace.block()},
+      "cudnnConvolutionBackwardData");
 
   return dx;
 }
@@ -658,7 +667,8 @@
             cch.workspace_count * sizeof(float), &beta, cch.filter_desc,
             dwblock->mutable_data());
       },
-      {dy.block(), x.block()}, {dW.block(), cch.workspace.block()});
+      {dy.block(), x.block()}, {dW.block(), cch.workspace.block()},
+      "cudnnConvolutionBackwardFilter");
 
   return dW;
 }
@@ -679,7 +689,7 @@
                                      dyblock->data(), &beta, cch.bias_desc,
                                      dbblock->mutable_data());
       },
-      {dy.block()}, {db.block()});
+      {dy.block()}, {db.block()}, "cudnnConvolutionBackwardBias");
 
   return db;
 }
diff --git a/src/model/operation/pooling.cc b/src/model/operation/pooling.cc
index 05a457a..b07ad4b 100644
--- a/src/model/operation/pooling.cc
+++ b/src/model/operation/pooling.cc
@@ -112,7 +112,7 @@
                                         {DNNL_ARG_WORKSPACE, ph.ws_mem}});
         ctx->dnnl_stream.wait();
       },
-      {x.block()}, {y.block()});
+      {x.block()}, {y.block()}, "CpuPoolingForward");
 
   return y;
 }
@@ -139,7 +139,7 @@
                                         {DNNL_ARG_WORKSPACE, ph.ws_mem}});
         ctx->dnnl_stream.wait();
       },
-      {x.block(), y.block(), grad.block()}, {in_grad.block()});
+      {x.block(), y.block(), grad.block()}, {in_grad.block()}, "CpuPoolingBackward");
 
   return in_grad;
 }
@@ -199,7 +199,7 @@
                             cph.x_desc, x.block()->data(), &beta, cph.y_desc,
                             output.block()->mutable_data());
       },
-      {x.block()}, {output.block()});
+      {x.block()}, {output.block()}, "GpuPoolingForward");
 
   return output;
 }
@@ -220,7 +220,7 @@
                              dy.block()->data(), cph.x_desc, x.block()->data(),
                              &beta, cph.x_desc, dx.block()->mutable_data());
       },
-      {dy.block(), y.block(), x.block()}, {dx.block()});
+      {dy.block(), y.block(), x.block()}, {dx.block()}, "GpuPoolingBackward");
 
   return dx;
 };
diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc
new file mode 100644
index 0000000..bc8edfd
--- /dev/null
+++ b/src/model/operation/rnn.cc
@@ -0,0 +1,808 @@
+/*********************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ************************************************************/
+
+#include "rnn.h"
+
+#include <map>
+namespace singa {
+#ifdef USE_CUDNN
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &x, const int hidden_size,
+                               const int mode, const int num_layers,
+                               const int bias, const float dropout,
+                               const int bidirectional)
+    : bias(bias),
+      dropout(dropout),
+      bidirectional(bidirectional),
+      hidden_size(hidden_size),
+      mode(mode),
+      num_layers(num_layers) {
+  // disabling the cudnn rnn bias is not available in cudnn v7.4.5; no such option in cudnn.h
+  CHECK_EQ(bias, 1) << "Current implementation always includes bias";
+  CHECK(bidirectional == 0 || bidirectional == 1)
+      << "bidirectional should be 0 or 1 not " << bidirectional;
+
+  dev = x.device();
+  ctx = x.device()->context(0);
+
+  // TODO: batch first mode failed in cudnn
+  batch_first = 0;
+
+  // x shape {seq, bs, ..}
+  seq_length = x.shape(0);
+  batch_size = x.shape(1);
+  feature_size = x.shape(2);
+
+  cudnnRNNAlgo = CUDNN_RNN_ALGO_STANDARD;
+  cudnnDataType = CUDNN_DATA_FLOAT;
+
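+  // cudnn's packed RNN API takes one tensor descriptor per time step; this
+  // temporary xDesc array is only used to query workspace/parameter sizes and offsets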
+  cudnnTensorDescriptor_t *xDesc = new cudnnTensorDescriptor_t[seq_length];
+  init_xDesc(xDesc, *this);
+
+  init_dropout_desc();
+  init_rnn_desc();
+  init_parameters_desc(xDesc);
+  init_workspace(xDesc);
+  init_param_mapping(xDesc);
+  delete[] xDesc;
+}
+
+void CudnnRNNHandle::init_workspace(cudnnTensorDescriptor_t *xDesc) {
+  /* workspace data */
+  // Needed for every pass
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnnDesc, seq_length,
+                                       xDesc, &workspace_size_bytes));
+  // Only needed in training, shouldn't be touched between passes.
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
+      ctx->cudnn_handle, rnnDesc, seq_length, xDesc, &reserve_size_bytes));
+
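+  // sizes above are in bytes; store the float-element counts as well and
+  // back both buffers with Tensors on the same device as the input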
+  workspace_size = workspace_size_bytes / sizeof(float);
+  reserve_size = reserve_size_bytes / sizeof(float);
+  workspace = Tensor(Shape{workspace_size}, dev);
+  reserve_space = Tensor(Shape{reserve_size}, dev);
+}
+
+void CudnnRNNHandle::init_parameters_desc(cudnnTensorDescriptor_t *xDesc) {
+  /* weights size
+   *   depends on rnn desc */
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnnDesc, xDesc[0],
+                                    &weights_size_bytes, cudnnDataType));
+  /* weights desc
+   *   depends on weights size */
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&wDesc));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&dwDesc));
+
+  weights_size = weights_size_bytes / sizeof(float);  // TODO different types
+  int dimW[3];
+  dimW[0] = weights_size;  // TODO different types
+  dimW[1] = 1;
+  dimW[2] = 1;
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(wDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(dwDesc, cudnnDataType,
+                                         CUDNN_TENSOR_NCHW, 3, dimW));
+}
+
+void CudnnRNNHandle::init_rnn_desc() {
+  /* rnn desc */
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnnDesc));
+  if (mode == 0)
+    RNNMode = CUDNN_RNN_RELU;
+  else if (mode == 1)
+    RNNMode = CUDNN_RNN_TANH;
+  else if (mode == 2)
+    RNNMode = CUDNN_LSTM;
+  else if (mode == 3)
+    RNNMode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      ctx->cudnn_handle, rnnDesc, hidden_size, num_layers, dropoutDesc,
+      CUDNN_LINEAR_INPUT,
+      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNNMode,
+      cudnnRNNAlgo,  // CUDNN_RNN_ALGO_STANDARD,
+      cudnnDataType));
+}
+void CudnnRNNHandle::init_dropout_desc() {
+  /* drop out */
+  size_t seed = 0x1234567;
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropoutDesc));
+  size_t stateSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &stateSize));
+  CUDA_CHECK(cudaMalloc(&states, stateSize));
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(dropoutDesc, ctx->cudnn_handle, dropout,
+                                        states, stateSize, seed));
+}
+
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h) {
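+  // a fully packed 3-D descriptor {batch, hidden (x2 if bidirectional), 1}
+  // is created for every time step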
+  int dimA[] = {h.batch_size,
+                h.bidirectional ? h.hidden_size * 2 : h.hidden_size, 1};
+  int strideA[] = {dimA[1] * dimA[2], dimA[2], 1};
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(yDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h) {
+  int dimA[] = {h.batch_size, h.feature_size, 1};
+  int strideA[] = {dimA[1] * dimA[2], dimA[2], 1};
+  for (int i = 0; i < h.seq_length; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc[i]));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(xDesc[i], h.cudnnDataType, 3, dimA,
+                                           strideA));
+  }
+}
+
+void init_hc_Desc(cudnnTensorDescriptor_t &hxDesc, CudnnRNNHandle &h) {
+  /* If direction is CUDNN_BIDIRECTIONAL then the first dimension should match
+  double the numLayers argument passed to cudnnSetRNNDescriptor(). */
+  /* The second dimension must match the batchSize parameter in xDesc */
+  /* the third dimension must match the hiddenSize argument passed to the
+  cudnnSetRNNDescriptor() call used to initialize rnnDesc. */
+  int dimA[] = {h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                h.hidden_size};
+  int strideA[] = {dimA[2] * dimA[1], dimA[2], 1};
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hxDesc));
+  CUDNN_CHECK(
+      cudnnSetTensorNdDescriptor(hxDesc, h.cudnnDataType, 3, dimA, strideA));
+}
+
+/*
+vector<Tensor> GpuRNNForwardTraining();
+vector<Tensor> GpuRNNForwardInference();
+vector<Tensor> GpuRNNBackwardx();
+Tensor GpuRNNBackwardW();
+*/
+
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+
+  // in
+  // x in shape {seq, bs, ..}
+  // out
+  // y in shape {seq, bs, ..}
+
+  h.batch_size = x.shape(1);  // update batch size to accommodate bs change
+  h.seq_length = x.shape(0);
+
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  y.SetValue(0.0f);
+  hy.SetValue(0.0f);
+  cy.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
+  y.device()->Exec(
+      [y, hy, cy, x, hx, cx, &W, &h](Context *ctx) {
+        // require desc, [x], hx, cx, w, y, hy, cy
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t hyDesc;
+        cudnnTensorDescriptor_t cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
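+        // ensure the input memory is contiguous before passing pointers to cudnn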
+        auto x_con = Contiguous(x);
+
+        auto xptr = x_con.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNForwardInference(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr,
+            cyDesc, cyptr, wsptr, h.workspace_size_bytes));
+
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block(), h.workspace.block()},
+      "cudnnRNNForwardInterface");
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx,
+                                     const Tensor &cx, const Tensor &W,
+                                     CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+
+  // in
+  // x in shape {seq, bs, ..}
+  // out
+  // y in shape {seq, bs, ..}
+
+  // update batch size to accommodate bs change
+  h.batch_size = x.shape(1);
+  h.seq_length = x.shape(0);
+
+  Tensor y(Shape{h.seq_length, h.batch_size,
+                 h.hidden_size * (h.bidirectional ? 2 : 1)},
+           x.device());
+  Tensor hy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  Tensor cy(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                  h.hidden_size},
+            x.device());
+  y.SetValue(0.0f);
+  hy.SetValue(0.0f);
+  cy.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
+  h.reserve_space.SetValue(0.0f);
+
+  y.device()->Exec(
+      [y, hy, cy, x, hx, cx, &W, &h](Context *ctx) {
+        // require desc, [x], hx, cx, w, y, hy, cy
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t hyDesc;
+        cudnnTensorDescriptor_t cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto x_con = Contiguous(x);
+
+        auto xptr = x_con.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+        CUDNN_CHECK(cudnnRNNForwardTraining(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, cxDesc, cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr,
+            cyDesc, cyptr, wsptr, h.workspace_size_bytes, rsptr,
+            h.reserve_size_bytes));
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block(), h.workspace.block(),
+       h.reserve_space.block()},
+      "cudnnRNNForwardTraining");
+
+  return {y, hy, cy};
+}
+vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
+                               const Tensor &dhy, const Tensor &dcy,
+                               const Tensor &W, const Tensor &hx,
+                               const Tensor &cx, CudnnRNNHandle &h) {
+  // in
+  // y shape {seq, bs}
+  // dy shape {seq, bs}
+  Tensor dx(Shape{h.seq_length, h.batch_size, h.feature_size}, y.device());
+  Tensor dhx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                   h.hidden_size},
+             y.device());
+  Tensor dcx(Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                   h.hidden_size},
+             y.device());
+  dx.SetValue(0.0f);
+  dhx.SetValue(0.0f);
+  dcx.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
+  dx.device()->Exec(
+      [dx, dhx, dcx, y, dy, dhy, dcy, &W, hx, cx, &h](Context *ctx) {
+        // require desc:
+        //      [dx], hx, dhx, cx, dcx, w,
+        // [y], [dy],     dhy,     dcy
+        cudnnTensorDescriptor_t *dxDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *dyDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_yDesc(yDesc, h);
+        init_xDesc(dxDesc, h);
+        init_yDesc(dyDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        cudnnTensorDescriptor_t cxDesc;
+        cudnnTensorDescriptor_t dhxDesc;
+        cudnnTensorDescriptor_t dcxDesc;
+        cudnnTensorDescriptor_t dhyDesc;
+        cudnnTensorDescriptor_t dcyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(dhxDesc, h);
+        init_hc_Desc(dcxDesc, h);
+        init_hc_Desc(dhyDesc, h);
+        init_hc_Desc(dcyDesc, h);
+
+        auto y_con = Contiguous(y);
+        auto dy_con = Contiguous(dy);
+
+        auto dxptr = dx.block()->mutable_data();
+        auto hxptr = hx.block()->data();
+        auto dhxptr = dhx.block()->mutable_data();
+        auto cxptr = cx.block()->data();
+        auto dcxptr = dcx.block()->mutable_data();
+        auto Wptr = W.block()->data();
+        auto yptr = y_con.block()->data();
+        auto dyptr = dy_con.block()->data();
+        auto dhyptr = dhy.block()->data();
+        auto dcyptr = dcy.block()->data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNBackwardData(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, yDesc, yptr, dyDesc,
+            dyptr, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc, Wptr, hxDesc,
+            hxptr, cxDesc, cxptr, dxDesc, dxptr, dhxDesc, dhxptr, dcxDesc,
+            dcxptr, wsptr, h.workspace_size_bytes, rsptr,
+            h.reserve_size_bytes));
+        delete[] dxDesc;
+        delete[] yDesc;
+        delete[] dyDesc;
+      },
+      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
+       W.block()},
+      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
+       h.reserve_space.block()},
+      "cudnnRNNBackwardx");
+  return {dx, dhx, dcx};
+}
+
+Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
+                       CudnnRNNHandle &h) {
+  Tensor dW(Shape{h.weights_size}, x.device());
+  // in
+  // x shape {seq, bs}
+  // y shape {seq, bs}
+  dW.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
+  dW.device()->Exec(
+      [dW, x, hx, y, &h](Context *ctx) {
+        cudnnTensorDescriptor_t *xDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        cudnnTensorDescriptor_t *yDesc =
+            new cudnnTensorDescriptor_t[h.seq_length];
+        init_xDesc(xDesc, h);
+        init_yDesc(yDesc, h);
+        cudnnTensorDescriptor_t hxDesc;
+        init_hc_Desc(hxDesc, h);
+
+        auto y_con = Contiguous(y);
+        auto x_con = Contiguous(x);
+
+        auto xptr = x_con.block()->data();
+        auto hxptr = hx.block()->data();
+        auto yptr = y_con.block()->data();
+        auto dWptr = dW.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNBackwardWeights(
+            ctx->cudnn_handle, h.rnnDesc, h.seq_length, xDesc, xptr, hxDesc,
+            hxptr, yDesc, yptr, wsptr, h.workspace_size_bytes, h.dwDesc, dWptr,
+            rsptr, h.reserve_size_bytes));
+        delete[] xDesc;
+        delete[] yDesc;
+      },
+      {x.block(), y.block(), hx.block()},
+      {dW.block(), h.workspace.block(), h.reserve_space.block()},
+      "cudnnRnnBackwardW");
+  return dW;
+}
+
+void CudnnRNNHandle::init_param_mapping(cudnnTensorDescriptor_t *xDesc) {
+  int linLayerIDRange = 2;
+  if (mode == 0 || mode == 1) {
+    // vanilla relu/tanh
+    linLayerIDRange = 2;
+  } else if (mode == 2) {
+    // lstm
+    linLayerIDRange = 8;
+  } else if (mode == 3) {
+    // gru
+    linLayerIDRange = 6;
+  }
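+  // cudnn exposes (bidirectional ? 2 : 1) pseudo layers per stacked layer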
+  int pseudoLayerRange = (bidirectional ? 2 : 1) * num_layers;
+
+  // dummy weights for getting the offset
+  Tensor weights(
+      Shape{
+          weights_size,
+      },
+      dev);
+  weights.SetValue(0.0f);
+  const void *W_ptr = weights.block()->data();
+
+  void *param_ptr = nullptr;
+  int dims[] = {1, 1, 1};
+  cudnnDataType_t data_type;
+  cudnnTensorFormat_t tensor_format;
+  int n_dims;
+  cudnnFilterDescriptor_t paramDesc;
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&paramDesc));
+
+  vector<bool> paramTypes{false, true};
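+  // visit the weight matrix (is_bias == false) and bias (is_bias == true)
+  // of every linear layer in every pseudo layer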
+  for (int linLayerID = 0; linLayerID < linLayerIDRange; linLayerID++) {
+    for (int pseudoLayer = 0; pseudoLayer < pseudoLayerRange; pseudoLayer++) {
+      for (const bool &is_bias : paramTypes) {
+        // get param ptr
+        if (is_bias) {
+          CUDNN_CHECK(cudnnGetRNNLinLayerBiasParams(
+              ctx->cudnn_handle, rnnDesc, pseudoLayer, xDesc[0], wDesc, W_ptr,
+              linLayerID, paramDesc, &param_ptr));
+        } else {
+          CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
+              ctx->cudnn_handle, rnnDesc, pseudoLayer, xDesc[0], wDesc, W_ptr,
+              linLayerID, paramDesc, &param_ptr));
+        }
+
+        // get param dims
+        CUDNN_CHECK(cudnnGetFilterNdDescriptor(paramDesc, 3, &data_type,
+                                               &tensor_format, &n_dims, dims));
+
+        // compute this parameter's offset (in floats) inside the packed weight buffer
+        size_t offset = (float *)param_ptr - (float *)W_ptr;
+
+        // save in map
+        weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)] =
+            std::make_tuple(offset, dims[0] * dims[1] * dims[2]);
+      }
+    }
+  }
+}
+
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights,
+                    Tensor &paramValues, bool is_bias, CudnnRNNHandle &h) {
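+  // copy paramValues into the packed weights tensor at the offset recorded
+  // by init_param_mapping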
+  size_t offset, size;
+  std::tie(offset, size) =
+      h.weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)];
+  CHECK_EQ(size, paramValues.size()) << "param size does not match the recorded size";
+  CopyDataToFrom(&weights, paramValues, size, offset, 0);
+}
+
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights,
+                          bool is_bias, CudnnRNNHandle &h) {
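+  // slice one parameter block out of the packed weights into a new Tensor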
+  size_t offset, size;
+  std::tie(offset, size) =
+      h.weights_mapping[std::make_tuple(linLayerID, pseudoLayer, is_bias)];
+  Tensor paramCopy(
+      Shape{
+          size,
+      },
+      weights.device());
+  CopyDataToFrom(&paramCopy, weights, size, 0, offset);
+  return paramCopy;
+}
+
+/*
+vector<Tensor> GpuRNNForwardTrainingEx();
+vector<Tensor> GpuRNNForwardInferenceEx();
+vector<Tensor> GpuRNNBackwardxEx();
+Tensor GpuRNNBackwardWEx();
+*/
+
+void init_data_desc(cudnnRNNDataDescriptor_t &desc, int data_size,
+                    const Tensor seq_lengths, CudnnRNNHandle &h) {
+  /* cudnnRNNDataDescriptor_t is a pointer to an opaque structure holding
+  the description of an RNN data set. The function
+  cudnnCreateRNNDataDescriptor() is used to create one instance, and
+  cudnnSetRNNDataDescriptor() must be used to initialize this instance.
+  */
+  CUDNN_CHECK(cudnnCreateRNNDataDescriptor(&desc));
+  /* CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED
+       Data layout is padded, with outer stride from one time-step to the next.
+     CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED
+       The sequence length is sorted and packed as in the basic RNN API.
+     CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED
+       Data layout is padded, with outer stride from one batch to the next.
+  */
+  cudnnRNNDataLayout_t layout;
+  if (h.batch_first) {
+    layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
+  } else {
+    layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED;
+  }
+
+  /* This is only effective when the descriptor is describing the RNN
+  output, and the unpacked layout is specified.*/
+  float paddingFill = 0.0f;
+
+  /* Input. An integer array with batchSize number of elements.
+  Describes the length (number of time-steps) of each sequence. Each
+  element in seqLengthArray must be greater than 0 but less than or
+  equal to maxSeqLength. */
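+  // cudnn reads the per-sequence lengths from host memory as int, so clone
+  // seq_lengths to the host and cast it before taking the raw pointer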
+  Tensor tmp = seq_lengths.Clone();
+  tmp.ToHost();
+  tmp = tmp.AsType(singa::kInt);
+  const int *seq_lengths_ptr = static_cast<const int *>(tmp.block()->data());
+
+  CUDNN_CHECK(cudnnSetRNNDataDescriptor(desc, h.cudnnDataType, layout,
+                                        h.seq_length, h.batch_size, data_size,
+                                        seq_lengths_ptr, (void *)&paddingFill));
+}
+
+vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx,
+                                        const Tensor &cx, const Tensor &W,
+                                        const Tensor &seq_lengths,
+                                        CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+
+  Tensor y, hy, cy;
+  Shape yshape, states_shape;
+
+  if (h.batch_first) {
+    LOG(FATAL) << "batch_first not implemented for GpuRNNForwardTrainingEx";
+  } else {
+    h.seq_length = x.shape(0);
+    h.batch_size = x.shape(1);
+    yshape = Shape{h.seq_length, h.batch_size,
+                   h.hidden_size * (h.bidirectional ? 2 : 1)};
+    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                         h.hidden_size};
+  }
+
+  y = Tensor(yshape, x.device());
+  hy = Tensor(states_shape, x.device());
+  cy = Tensor(states_shape, x.device());
+
+  y.device()->Exec(
+      [y, hy, cy, x, seq_lengths, hx, cx, &W, &h](Context *ctx) {
+        // data descriptor
+        cudnnRNNDataDescriptor_t xDesc, yDesc;
+        init_data_desc(xDesc, h.feature_size, seq_lengths, h);
+        init_data_desc(yDesc,
+                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
+                       seq_lengths, h);
+
+        // hidden cell states descriptor
+        cudnnTensorDescriptor_t hxDesc, cxDesc, hyDesc, cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+
+        /* This routine is the extended version of the cudnnRNNForwardInference()
+        function. The cudnnRNNForwardInferenceEx() allows the user to use
+        unpacked (padded) layout for input x and output y.
+        */
+        CUDNN_CHECK(cudnnRNNForwardInferenceEx(
+            ctx->cudnn_handle, h.rnnDesc, xDesc, xptr, hxDesc, hxptr, cxDesc,
+            cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr, cyDesc, cyptr,
+            NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, wsptr,
+            h.workspace_size_bytes));
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block(), h.workspace.block(),
+       h.reserve_space.block()});
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNForwardTrainingEx(const Tensor &x, const Tensor &hx,
+                                       const Tensor &cx, const Tensor &W,
+                                       const Tensor &seq_lengths,
+                                       CudnnRNNHandle &h) {
+  CHECK_EQ(h.feature_size, x.shape(2)) << "feature size should not change";
+
+  Tensor y, hy, cy;
+  Shape yshape, states_shape;
+
+  if (h.batch_first) {
+    LOG(FATAL) << "batch_first not implemented for GpuRNNForwardTrainingEx";
+  } else {
+    h.seq_length = x.shape(0);
+    h.batch_size = x.shape(1);
+    yshape = Shape{h.seq_length, h.batch_size,
+                   h.hidden_size * (h.bidirectional ? 2 : 1)};
+    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                         h.hidden_size};
+  }
+
+  y = Tensor(yshape, x.device());
+  hy = Tensor(states_shape, x.device());
+  cy = Tensor(states_shape, x.device());
+
+  y.device()->Exec(
+      [y, hy, cy, x, seq_lengths, hx, cx, &W, &h](Context *ctx) {
+        // data descriptor
+        cudnnRNNDataDescriptor_t xDesc, yDesc;
+        init_data_desc(xDesc, h.feature_size, seq_lengths, h);
+        init_data_desc(yDesc,
+                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
+                       seq_lengths, h);
+
+        // hidden cell states descriptor
+        cudnnTensorDescriptor_t hxDesc, cxDesc, hyDesc, cyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(hyDesc, h);
+        init_hc_Desc(cyDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto cxptr = cx.block()->data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->mutable_data();
+        auto hyptr = hy.block()->mutable_data();
+        auto cyptr = cy.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        /* This routine is the extended version of the cudnnRNNForwardTraining()
+        function. The cudnnRNNForwardTrainingEx() allows the user to use
+        unpacked (padded) layout for input x and output y.
+        */
+        CUDNN_CHECK(cudnnRNNForwardTrainingEx(
+            ctx->cudnn_handle, h.rnnDesc, xDesc, xptr, hxDesc, hxptr, cxDesc,
+            cxptr, h.wDesc, Wptr, yDesc, yptr, hyDesc, hyptr, cyDesc, cyptr,
+            NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, wsptr,
+            h.workspace_size_bytes, rsptr, h.reserve_size_bytes));
+      },
+      {x.block(), hx.block(), cx.block(), W.block()},
+      {y.block(), hy.block(), cy.block(), h.workspace.block(),
+       h.reserve_space.block()});
+  return {y, hy, cy};
+}
+
+vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy,
+                                 const Tensor &dhy, const Tensor &dcy,
+                                 const Tensor &W, const Tensor &hx,
+                                 const Tensor &cx, const Tensor &seq_lengths,
+                                 CudnnRNNHandle &h) {
+  // y shape: {bs, seq}
+  // dy shape: {bs, seq}
+  // dx shape: {bs, seq}
+  Shape xshape, states_shape;
+  if (h.batch_first) {
+    LOG(FATAL) << "batch_first not implemented for GpuRNNBackwardxEx";
+  } else {
+    xshape = Shape{h.batch_size, h.seq_length, h.feature_size};
+    states_shape = Shape{h.num_layers * (h.bidirectional ? 2 : 1), h.batch_size,
+                         h.hidden_size};
+  }
+  Tensor dx(xshape, y.device());
+  Tensor dhx(states_shape, y.device());
+  Tensor dcx(states_shape, y.device());
+
+  dx.SetValue(0.0f);
+  dhx.SetValue(0.0f);
+  dcx.SetValue(0.0f);
+  h.workspace.SetValue(0.0f);
+
+  dx.device()->Exec(
+      [dx, dhx, dcx, y, dy, dhy, dcy, &W, hx, cx, seq_lengths,
+       &h](Context *ctx) {
+        cudnnRNNDataDescriptor_t yDesc, dyDesc, dxDesc;
+        init_data_desc(yDesc,
+                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
+                       seq_lengths, h);
+        init_data_desc(dyDesc,
+                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
+                       seq_lengths, h);
+        init_data_desc(dxDesc, h.feature_size, seq_lengths, h);
+
+        /* other tensors desc*/
+        cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc, dhyDesc,
+            dcyDesc;
+        init_hc_Desc(hxDesc, h);
+        init_hc_Desc(cxDesc, h);
+        init_hc_Desc(dhxDesc, h);
+        init_hc_Desc(dcxDesc, h);
+        init_hc_Desc(dhyDesc, h);
+        init_hc_Desc(dcyDesc, h);
+
+        auto dxptr = dx.block()->mutable_data();
+        auto hxptr = hx.block()->data();
+        auto dhxptr = dhx.block()->mutable_data();
+        auto cxptr = cx.block()->data();
+        auto dcxptr = dcx.block()->mutable_data();
+        auto Wptr = W.block()->data();
+        auto yptr = y.block()->data();
+        auto dyptr = dy.block()->data();
+        auto dhyptr = dhy.block()->data();
+        auto dcyptr = dcy.block()->data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNBackwardDataEx(
+            ctx->cudnn_handle, h.rnnDesc, yDesc, yptr, dyDesc, dyptr, NULL,
+            NULL, dhyDesc, dhyptr, dcyDesc, dcyptr, h.wDesc, Wptr, hxDesc,
+            hxptr, cxDesc, cxptr, dxDesc, dxptr, dhxDesc, dhxptr, dcxDesc,
+            dcxptr, NULL, NULL, wsptr, h.workspace_size_bytes, rsptr,
+            h.reserve_size_bytes));
+      },
+      {y.block(), dy.block(), dhy.block(), dcy.block(), hx.block(), cx.block(),
+       W.block()},
+      {dx.block(), dhx.block(), dcx.block(), h.workspace.block(),
+       h.reserve_space.block()});
+  return {dx, dhx, dcx};
+}
+
+Tensor GpuRNNBackwardWEx(const Tensor &x, const Tensor &hx, const Tensor &y,
+                         const Tensor &seq_lengths, CudnnRNNHandle &h) {
+  Tensor dW(Shape{h.weights_size}, x.device());
+  dW.SetValue(0.0f);
+
+  dW.device()->Exec(
+      [dW, x, hx, y, seq_lengths, &h](Context *ctx) {
+        cudnnRNNDataDescriptor_t xDesc, yDesc;
+        init_data_desc(xDesc, h.feature_size, seq_lengths, h);
+        init_data_desc(yDesc,
+                       h.bidirectional ? h.hidden_size * 2 : h.hidden_size,
+                       seq_lengths, h);
+
+        /* other tensor desc */
+        cudnnTensorDescriptor_t hxDesc;
+        init_hc_Desc(hxDesc, h);
+
+        auto xptr = x.block()->data();
+        auto hxptr = hx.block()->data();
+        auto yptr = y.block()->data();
+        auto dWptr = dW.block()->mutable_data();
+        auto wsptr = h.workspace.block()->mutable_data();
+        auto rsptr = h.reserve_space.block()->mutable_data();
+
+        CUDNN_CHECK(cudnnRNNBackwardWeightsEx(
+            ctx->cudnn_handle, h.rnnDesc, xDesc, xptr, hxDesc, hxptr, yDesc,
+            yptr, wsptr, h.workspace_size_bytes, h.dwDesc, dWptr, rsptr,
+            h.reserve_size_bytes));
+      },
+      {x.block(), y.block(), hx.block()},
+      {dW.block(), h.workspace.block(), h.reserve_space.block()});
+  return dW;
+}
+
+#endif  // USE_CUDNN
+}  // namespace singa
diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h
new file mode 100644
index 0000000..bbc9266
--- /dev/null
+++ b/src/model/operation/rnn.h
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SRC_MODEL_OPERATION_RNN_H_
+#define SRC_MODEL_OPERATION_RNN_H_
+
+#include <iostream>
+#include <tuple>
+#include <vector>
+
+#include "singa/core/tensor.h"
+#include "singa/singa_config.h"
+#include "singa/utils/logging.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+
+#include "../layer/cudnn_utils.h"
+#endif  // USE_CUDNN
+
+namespace singa {
+
+#ifdef USE_CUDNN
+class CudnnRNNHandle {
+ public:
+  CudnnRNNHandle(const Tensor &x, const int hidden_size, const int mode = 0,
+                 const int num_layers = 1, const int bias = 1,
+                 const float dropout = 0.0f, const int bidirectional = 0);
+
+  Context *ctx;
+  std::shared_ptr<Device> dev;
+
+  // parameters
+  int bias;
+  int mode;
+  float dropout;
+  int bidirectional;
+  size_t feature_size;
+  size_t hidden_size;
+  size_t num_layers;
+  int batch_first;
+
+  size_t weights_size_bytes;
+  size_t weights_size;
+  size_t batch_size;
+  size_t seq_length;
+
+  /* workspace data */
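+  // workspace is per-call scratch; reserve_space must persist between the
+  // forward-training pass and the backward passes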
+  size_t workspace_size;
+  size_t workspace_size_bytes;
+  size_t reserve_size;
+  size_t reserve_size_bytes;
+  Tensor workspace;
+  Tensor reserve_space;
+
+  /* dropout */
+  void *states;
+  cudnnDropoutDescriptor_t dropoutDesc;
+
+  /* rnn desc */
+  cudnnRNNDescriptor_t rnnDesc;
+  cudnnRNNMode_t RNNMode;
+  cudnnRNNAlgo_t cudnnRNNAlgo;
+  cudnnDataType_t cudnnDataType;
+
+  /* weights desc */
+  cudnnFilterDescriptor_t wDesc, dwDesc;
+
+  void init_dropout_desc();
+  void init_rnn_desc();
+  void init_parameters_desc(cudnnTensorDescriptor_t *xDesc);
+  void init_workspace(cudnnTensorDescriptor_t *xDesc);
+  void init_param_mapping(cudnnTensorDescriptor_t *xDesc);
+
+  // linLayerID, pseudoLayer, is_bias => offset, size
+  // e.g. Wx of 1st layer is at <0,0,false> => 0, data_s*hid_s
+  std::map<std::tuple<int, int, bool>, std::tuple<size_t, size_t>>
+      weights_mapping;
+};
+
+void init_xDesc(cudnnTensorDescriptor_t *xDesc, CudnnRNNHandle &h);
+void init_yDesc(cudnnTensorDescriptor_t *yDesc, CudnnRNNHandle &h);
+void init_hc_Desc(cudnnTensorDescriptor_t &hDesc, CudnnRNNHandle &h);
+
+vector<Tensor> GpuRNNForwardTraining(const Tensor &x, const Tensor &hx,
+                                     const Tensor &cx, const Tensor &W,
+                                     CudnnRNNHandle &h);
+vector<Tensor> GpuRNNForwardInference(const Tensor &x, const Tensor &hx,
+                                      const Tensor &cx, const Tensor &W,
+                                      CudnnRNNHandle &h);
+vector<Tensor> GpuRNNBackwardx(const Tensor &y, const Tensor &dy,
+                               const Tensor &dhy, const Tensor &dcy,
+                               const Tensor &W, const Tensor &hx,
+                               const Tensor &cx, CudnnRNNHandle &h);
+Tensor GpuRNNBackwardW(const Tensor &x, const Tensor &hx, const Tensor &y,
+                       CudnnRNNHandle &h);
+
+void GpuRNNSetParam(int linLayerID, int pseudoLayer, Tensor &weights,
+                    Tensor &paramValues, bool is_bias, CudnnRNNHandle &h);
+Tensor GpuRNNGetParamCopy(int linLayerID, int pseudoLayer, Tensor &weights,
+                          bool is_bias, CudnnRNNHandle &h);
+
+vector<Tensor> GpuRNNForwardTrainingEx(const Tensor &x, const Tensor &hx,
+                                       const Tensor &cx, const Tensor &W,
+                                       const Tensor &seq_lengths,
+                                       CudnnRNNHandle &h);
+vector<Tensor> GpuRNNForwardInferenceEx(const Tensor &x, const Tensor &hx,
+                                        const Tensor &cx, const Tensor &W,
+                                        const Tensor &seq_lengths,
+                                        CudnnRNNHandle &h);
+vector<Tensor> GpuRNNBackwardxEx(const Tensor &y, const Tensor &dy,
+                                 const Tensor &dhy, const Tensor &dcy,
+                                 const Tensor &W, const Tensor &hx,
+                                 const Tensor &cx, const Tensor &seq_lengths,
+                                 CudnnRNNHandle &h);
+Tensor GpuRNNBackwardWEx(const Tensor &x, const Tensor &hx, const Tensor &y,
+                         const Tensor &seq_lengths, CudnnRNNHandle &h);
+
+#endif  // USE_CUDNN
+
+}  // namespace singa
+#endif  // SRC_MODEL_OPERATION_RNN_H_
diff --git a/test/python/cuda_helper.py b/test/python/cuda_helper.py
index ed2ee43..8f6bd4f 100644
--- a/test/python/cuda_helper.py
+++ b/test/python/cuda_helper.py
@@ -21,5 +21,5 @@
 # avoid singleton error
 gpu_dev = None
 if singa_wrap.USE_CUDA:
-    gpu_dev = device.create_cuda_gpu(set_default=False)
+    gpu_dev = device.create_cuda_gpu()
 cpu_dev = device.get_default_device()
diff --git a/test/python/run.py b/test/python/run.py
index b6e318f..b787a15 100644
--- a/test/python/run.py
+++ b/test/python/run.py
@@ -16,11 +16,15 @@
 # limitations under the License.
 #
 
+import sys
 import unittest
-# import xmlrunner
 
-loader = unittest.TestLoader()
-tests = loader.discover('.')
-testRunner = unittest.runner.TextTestRunner()
-# testRunner = xmlrunner.XMLTestRunner(output='.')
-testRunner.run(tests)
+def main():
+    loader = unittest.TestLoader()
+    tests = loader.discover('.')
+    testRunner = unittest.runner.TextTestRunner()
+    ret = not testRunner.run(tests).wasSuccessful()
+    sys.exit(ret)
+
+if __name__ == "__main__":
+    main()
diff --git a/test/python/test_api.py b/test/python/test_api.py
index a671923..e307dc9 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -129,7 +129,7 @@
                 tensor.Tensor(device=dev, data=b_0).data, rm_t.data, rv_t.data)
 
             np.testing.assert_array_almost_equal(
-                y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)))
+                y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)), decimal=4)
             np.testing.assert_array_almost_equal(
                 bm_1, tensor.to_numpy(_cTensor_to_pyTensor(bm_2_c)))
             np.testing.assert_array_almost_equal(rm_1, tensor.to_numpy(rm_t))
@@ -183,7 +183,7 @@
             #print(tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)))
 
             np.testing.assert_array_almost_equal(
-                y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)), decimal=5)
+                y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)), decimal=3)
             return
 
         x_0 = np.array([1, 1, 1, 1, 2, 2, 2, 2, 10, 10, 10, 10, 20, 20, 20, 20],
@@ -340,6 +340,19 @@
     def test_transpose_and_arithmetic_op_broadcast_cpu(self):
         self._transpose_and_arithmetic_op_broadcast_helper(cpu_dev)
 
+    def _erf(self, dev=cpu_dev):
+        np1 = np.random.random((2, 3)).astype(np.float32)
+
+        x1 = tensor.from_numpy(np1)
+        x1.to_device(dev)
+        y1 = tensor.from_raw_tensor(singa_api.Erf(x1.data))
+
+        # from scipy.special import erf
+        # np.testing.assert_array_almost_equal(erf(np1), tensor.to_numpy(y1))
+
+    def test_erf_cpu(self):
+        self._erf(cpu_dev)
+
     @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
     def test_transpose_and_arithmetic_op_broadcast_gpu(self):
         self._transpose_and_arithmetic_op_broadcast_helper(gpu_dev)
@@ -666,6 +679,28 @@
     def test_ceil_gpu(self):
         self._ceil_helper(gpu_dev)
 
+    def _floor_helper(self, dev):
+
+        np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+
+        # scale to [0, 10) so floor produces a spread of integer values
+        np1 = np1 * 10
+        np2 = np.floor(np1)
+
+        t1 = tensor.Tensor(device=dev, data=np1)
+
+        t2_ct = singa_api.Floor(t1.data)
+
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(_cTensor_to_pyTensor(t2_ct)), np2)
+
+    def test_floor_cpu(self):
+        self._floor_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_floor_gpu(self):
+        self._floor_helper(gpu_dev)
+
     def _as_type_helper(self, dev):
 
         np1 = np.random.random([3]).astype(np.float32)
@@ -730,6 +765,159 @@
     def test_as_type2_gpu(self):
         self._as_type2_helper(gpu_dev)
 
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_relu(self):
+        self._rnn_helper(0)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_tanh(self):
+        self._rnn_helper(1)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_lstm(self):
+        self._rnn_helper(2)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_gru(self):
+        self._rnn_helper(3)
+
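+    # `mode` selects the cuDNN cell type; judging by the wrapper tests above,
+    # 0 = vanilla RNN (ReLU), 1 = vanilla RNN (tanh), 2 = LSTM, 3 = GRU.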
+    def _rnn_helper(self, mode):
+        dev = gpu_dev
+
+        hidden_size = 7
+        seq_length = 5
+        batch_size = 6
+        feature_size = 3
+        directions = 2
+        num_layers = 2
+
+        x = tensor.Tensor(shape=(seq_length, batch_size, feature_size),
+                          device=dev).gaussian(0, 1)
+        hx = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                  hidden_size),
+                           device=dev).gaussian(0, 1)
+        cx = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                  hidden_size),
+                           device=dev).gaussian(0, 1)
+
+        rnn_handle = singa_api.CudnnRNNHandle(x.data,
+                                              hidden_size,
+                                              mode,
+                                              num_layers=num_layers,
+                                              dropout=0.1,
+                                              bidirectional=1)
+
+        w = tensor.Tensor(shape=(rnn_handle.weights_size,),
+                          device=dev).gaussian(0, 1)
+        # print("weights size is ", rnn_handle.weights_size)
+
+        (y, hy, cy) = singa_api.GpuRNNForwardTraining(x.data, hx.data, cx.data,
+                                                      w.data, rnn_handle)
+        self.assertEqual(y.shape(),
+                         (seq_length, batch_size, directions * hidden_size))
+        self.assertEqual(hy.shape(), hx.shape)
+        self.assertEqual(cy.shape(), cx.shape)
+
+        (y2, hy2,
+         cy2) = singa_api.GpuRNNForwardInference(x.data, hx.data, cx.data,
+                                                 w.data, rnn_handle)
+        self.assertEqual(y2.shape(),
+                         (seq_length, batch_size, directions * hidden_size))
+        self.assertEqual(hy2.shape(), hx.shape)
+        self.assertEqual(cy2.shape(), cx.shape)
+
+        dy = tensor.Tensor(shape=(seq_length, batch_size,
+                                  directions * hidden_size),
+                           device=dev).gaussian(0, 1)
+        dhy = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                   hidden_size),
+                            device=dev).gaussian(0, 1)
+        dcy = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                   hidden_size),
+                            device=dev).gaussian(0, 1)
+
+        (dx, dhx, dcx) = singa_api.GpuRNNBackwardx(y, dy.data, dhy.data,
+                                                   dcy.data, w.data, hx.data,
+                                                   cx.data, rnn_handle)
+        self.assertEqual(dx.shape(), (seq_length, batch_size, feature_size))
+        self.assertEqual(dhx.shape(), hx.shape)
+        self.assertEqual(dcx.shape(), cx.shape)
+
+        dW = singa_api.GpuRNNBackwardW(x.data, hx.data, y, rnn_handle)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_with_seq_lengths(self):
+        dev = gpu_dev
+
+        # params
+        hidden_size = 7
+        seq_length = 5
+        batch_size = 6
+        feature_size = 3
+        directions = 2
+        num_layers = 2
+
+        # shapes
+        x_s = (seq_length, batch_size, feature_size)
+        y_s = (seq_length, batch_size, hidden_size)
+        states_s = (num_layers * directions, batch_size, hidden_size)
+
+        # tensors
+        x = tensor.Tensor(x_s, dev).gaussian(0, 1)
+        y = tensor.Tensor(y_s, dev).gaussian(0, 1)
+        dy = tensor.Tensor(y_s, dev).gaussian(0, 1)
+        dhy = tensor.Tensor(states_s, dev).gaussian(0, 1)
+        dcy = tensor.Tensor(states_s, dev).gaussian(0, 1)
+        hx = tensor.Tensor(states_s, dev).gaussian(0, 1)
+        cx = tensor.Tensor(states_s, dev).gaussian(0, 1)
+
+        # handle
+        rnn_handle = singa_api.CudnnRNNHandle(x.data, hidden_size, 2)
+        w = tensor.Tensor((rnn_handle.weights_size,), dev).gaussian(0, 1)
+
+        # seq lengths
+        seq_lengths = tensor.from_numpy(np.array([seq_length] * batch_size, dtype=np.int32))
+
+        # operations
+        (dx, dhx, dcx) = singa_api.GpuRNNBackwardxEx(y.data, dy.data, dhy.data,
+                                                     dcy.data, w.data, hx.data,
+                                                     cx.data, seq_lengths.data,
+                                                     rnn_handle)
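+        # The padded-sequence forward pass could be exercised analogously
+        # (sketch only, following the declared GpuRNNForwardTrainingEx
+        # signature; not asserted here):
+        #   (y, hy, cy) = singa_api.GpuRNNForwardTrainingEx(
+        #       x.data, hx.data, cx.data, w.data, seq_lengths.data, rnn_handle)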
+
+
+    def test_round_cpu(self):
+        self._round(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_round_gpu(self):
+        self._round(gpu_dev)
+
+    def _round(self, dev=gpu_dev):
+        x = tensor.Tensor(shape=(3, 4, 5), device=dev).gaussian(0, 1)
+        y = tensor._call_singa_func(singa_api.Round, x.data)
+        np.testing.assert_array_almost_equal(np.round(tensor.to_numpy(x)),
+                                             tensor.to_numpy(y))
+
+    def test_round_even_cpu(self):
+        self._round_even(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_round_even_gpu(self):
+        self._round_even(gpu_dev)
+
+    def _round_even(self, dev=gpu_dev):
+        q = np.array([0.1, 0.5, 0.9, 1.2, 1.5,
+                      1.8, 2.3, 2.5, 2.7, -1.1,
+                      -1.5, -1.9, -2.2, -2.5, -2.8]).astype(np.float32)
+        ans = np.array([0., 0., 1., 1., 2.,
+                        2., 2., 2., 3., -1.,
+                        -2., -2., -2., -2., -3.]).astype(np.float32)
+
+        x = tensor.Tensor(shape=q.shape, device=dev)
+        x.copy_from_numpy(q)
+        y = tensor._call_singa_func(singa_api.RoundE, x.data)
+        np.testing.assert_array_almost_equal(ans, tensor.to_numpy(y))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/python/test_dist.py b/test/python/test_dist.py
index c87586c..76c3404 100644
--- a/test/python/test_dist.py
+++ b/test/python/test_dist.py
@@ -25,7 +25,7 @@
 if (singa_wrap.USE_DIST):
     sgd = opt.SGD(lr=0.1)
     sgd = opt.DistOpt(sgd)
-    dev = device.create_cuda_gpu_on(sgd.local_rank, set_default=False)
+    dev = device.create_cuda_gpu_on(sgd.local_rank)
     param = tensor.Tensor((10, 10), dev, tensor.float32)
     grad = tensor.Tensor((10, 10), dev, tensor.float32)
     expected = np.ones((10, 10), dtype=np.float32) * (10 - 0.1)
diff --git a/test/python/test_initializer.py b/test/python/test_initializer.py
new file mode 100644
index 0000000..cbd082e
--- /dev/null
+++ b/test/python/test_initializer.py
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from singa import initializer
+from singa import tensor
+from singa import singa_wrap
+
+from cuda_helper import gpu_dev, cpu_dev
+
+import unittest
+import numpy as np
+
+
+class TestInitializer(unittest.TestCase):
+
+    def setUp(self):
+        self.t1 = tensor.Tensor((40, 90))
+        self.t2 = tensor.Tensor((30, 50, 8))
+        self.t3 = tensor.Tensor((30, 50, 4, 8))
+
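+    # compute_fan appears to follow the Keras-style fan convention; for the
+    # shapes used in setUp it gives:
+    #   (40, 90)       -> fan_in = 40,             fan_out = 90
+    #   (30, 50, 8)    -> fan_in = 50 * 8 = 400,   fan_out = 30 * 8 = 240
+    #   (30, 50, 4, 8) -> fan_in = 50 * 32 = 1600, fan_out = 30 * 32 = 960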
+    def compute_fan(self, shape):
+        if len(shape) == 2:
+            fan_in = shape[0]
+            fan_out = shape[1]
+        elif len(shape) in {3, 4, 5}:
+            fan_in = shape[1] * np.prod(shape[2:])
+            fan_out = shape[0] * np.prod(shape[2:])
+        else:
+            fan_in = fan_out = np.sqrt(np.prod(shape))
+
+        return fan_in, fan_out
+
+    def he_uniform(self, dev):
+
+        def init(shape):
+            fan_in, _ = self.compute_fan(shape)
+            limit = np.sqrt(6 / fan_in)
+            return limit
+
+        self.t1.to_device(dev)
+        initializer.he_uniform(self.t1)
+        np_t1 = tensor.to_numpy(self.t1)
+        limit = init(self.t1.shape)
+        self.assertAlmostEqual(np_t1.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t1.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t1.mean(), 0, delta=limit/10)
+
+        self.t2.to_device(dev)
+        initializer.he_uniform(self.t2)
+        np_t2 = tensor.to_numpy(self.t2)
+        limit = init(self.t2.shape)
+        self.assertAlmostEqual(np_t2.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t2.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t2.mean(), 0, delta=limit/10)
+
+        self.t3.to_device(dev)
+        initializer.he_uniform(self.t3)
+        np_t3 = tensor.to_numpy(self.t3)
+        limit = init(self.t3.shape)
+        self.assertAlmostEqual(np_t3.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t3.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t3.mean(), 0, delta=limit/10)
+
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_he_uniform_gpu(self):
+        self.he_uniform(gpu_dev)
+
+    def test_he_uniform_cpu(self):
+        self.he_uniform(cpu_dev)
+
+    def he_normal(self, dev):
+
+        def init(shape):
+            fan_in, _ = self.compute_fan(shape)
+            stddev = np.sqrt(2 / fan_in)
+            return stddev
+
+        self.t1.to_device(dev)
+        initializer.he_normal(self.t1)
+        np_t1 = tensor.to_numpy(self.t1)
+        stddev = init(self.t1.shape)
+        self.assertAlmostEqual(np_t1.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t1.std(), stddev, delta=stddev/10)
+
+        self.t2.to_device(dev)
+        initializer.he_normal(self.t2)
+        np_t2 = tensor.to_numpy(self.t2)
+        stddev = init(self.t2.shape)
+        self.assertAlmostEqual(np_t2.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t2.std(), stddev, delta=stddev/10)
+
+        self.t3.to_device(dev)
+        initializer.he_normal(self.t3)
+        np_t3 = tensor.to_numpy(self.t3)
+        stddev = init(self.t3.shape)
+        self.assertAlmostEqual(np_t3.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t3.std(), stddev, delta=stddev/10)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_he_normal_gpu(self):
+        self.he_normal(gpu_dev)
+
+    def test_he_normal_cpu(self):
+        self.he_normal(cpu_dev)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
deleted file mode 100755
index f57e863..0000000
--- a/test/python/test_layer.py
+++ /dev/null
@@ -1,276 +0,0 @@
-from builtins import str
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import numpy as np
-
-from singa import layer
-from singa import tensor
-from singa.proto import model_pb2
-
-
-def _tuple_to_string(t):
-    lt = [str(x) for x in t]
-    return '(' + ', '.join(lt) + ')'
-
-
-class TestPythonLayer(unittest.TestCase):
-
-    def check_shape(self, actual, expect):
-        self.assertEqual(
-            actual, expect, 'shape mismatch, actual shape is %s'
-            ' exepcted is %s' %
-            (_tuple_to_string(actual), _tuple_to_string(expect)))
-
-    def setUp(self):
-        layer.engine = 'singacpp'
-        self.w = {'init': 'Xavier', 'regularizer': 1e-4}
-        self.b = {'init': 'Constant', 'value': 0}
-        self.sample_shape = None
-
-    def test_conv2D_shape(self):
-        in_sample_shape = (3, 224, 224)
-        conv = layer.Conv2D('conv',
-                            64,
-                            3,
-                            1,
-                            W_specs=self.w,
-                            b_specs=self.b,
-                            input_sample_shape=in_sample_shape)
-        out_sample_shape = conv.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (64, 224, 224))
-
-    def test_conv2D_forward_backward(self):
-        in_sample_shape = (1, 3, 3)
-        conv = layer.Conv2D('conv',
-                            1,
-                            3,
-                            2,
-                            W_specs=self.w,
-                            b_specs=self.b,
-                            pad=1,
-                            input_sample_shape=in_sample_shape)
-        # cuda = device.create_cuda_gpu()
-        # conv.to_device(cuda)
-        params = conv.param_values()
-
-        raw_x = np.arange(9, dtype=np.float32) + 1
-        x = tensor.from_numpy(raw_x)
-        x = x.reshape((1, 1, 3, 3))
-        w = np.array([1, 1, 0, 0, 0, -1, 0, 1, 0], dtype=np.float32)
-        params[0].copy_from_numpy(w)
-        params[1].set_value(1.0)
-
-        # x.to_device(cuda)
-        y = conv.forward(model_pb2.kTrain, x)
-        # y.to_host()
-        npy = tensor.to_numpy(y).flatten()
-
-        self.assertAlmostEqual(3.0, npy[0])
-        self.assertAlmostEqual(7.0, npy[1])
-        self.assertAlmostEqual(-3.0, npy[2])
-        self.assertAlmostEqual(12.0, npy[3])
-
-        dy = np.asarray([0.1, 0.2, 0.3, 0.4], dtype=np.float32).reshape(y.shape)
-        grad = tensor.from_numpy(dy)
-        # grad.to_device(cuda)
-        (dx, [dw, db]) = conv.backward(model_pb2.kTrain, grad)
-        dx.to_host()
-        dw.to_host()
-        dx = tensor.to_numpy(dx).flatten()
-        dw = tensor.to_numpy(dw).flatten()
-        dy = dy.flatten()
-        self.assertAlmostEqual(dy[0] * w[4], dx[0])
-        self.assertAlmostEqual(dy[0] * w[5] + dy[1] * w[3], dx[1])
-        self.assertAlmostEqual(dy[1] * w[4], dx[2])
-        self.assertAlmostEqual(dy[0] * w[7] + dy[2] * w[1], dx[3])
-        self.assertAlmostEqual(
-            dy[0] * w[8] + dy[1] * w[6] + dy[2] * w[2] + dy[3] * w[0], dx[4])
-        self.assertAlmostEqual(dy[1] * w[7] + dy[3] * w[1], dx[5])
-        self.assertAlmostEqual(dy[2] * w[4], dx[6])
-        self.assertAlmostEqual(dy[2] * w[5] + dy[3] * w[3], dx[7])
-        self.assertAlmostEqual(dy[3] * w[4], dx[8])
-
-        self.assertAlmostEqual(dy[3] * raw_x[4], dw[0])
-        self.assertAlmostEqual(dy[3] * raw_x[5] + dy[2] * raw_x[3], dw[1])
-        self.assertAlmostEqual(dy[2] * raw_x[4], dw[2])
-        self.assertAlmostEqual(dy[1] * raw_x[1] + dy[3] * raw_x[7], dw[3])
-        self.assertAlmostEqual(
-            dy[0] * raw_x[0] + dy[1] * raw_x[2] + dy[2] * raw_x[6] +
-            dy[3] * raw_x[8], dw[4], 5)
-        self.assertAlmostEqual(dy[0] * raw_x[1] + dy[2] * raw_x[7], dw[5])
-        self.assertAlmostEqual(dy[1] * raw_x[4], dw[6])
-        self.assertAlmostEqual(dy[0] * raw_x[3] + dy[1] * raw_x[5], dw[7])
-        self.assertAlmostEqual(dy[0] * raw_x[4], dw[8])
-
-    def test_conv1D(self):
-        in_sample_shape = (224,)
-        conv = layer.Conv1D('conv',
-                            64,
-                            3,
-                            1,
-                            W_specs=self.w,
-                            b_specs=self.b,
-                            pad=1,
-                            input_sample_shape=in_sample_shape)
-        out_sample_shape = conv.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (
-            64,
-            224,
-        ))
-
-    def test_max_pooling2D(self):
-        in_sample_shape = (64, 225, 225)
-        pooling = layer.MaxPooling2D('pool',
-                                     3,
-                                     2,
-                                     input_sample_shape=in_sample_shape)
-        out_sample_shape = pooling.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (64, 112, 112))
-
-    def test_max_pooling1D(self):
-        in_sample_shape = (225,)
-        pooling = layer.MaxPooling1D('pool',
-                                     3,
-                                     2,
-                                     input_sample_shape=in_sample_shape)
-        out_sample_shape = pooling.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (112,))
-
-    def test_avg_pooling2D(self):
-        in_sample_shape = (64, 225, 225)
-        pooling = layer.AvgPooling2D('pool',
-                                     3,
-                                     2,
-                                     input_sample_shape=in_sample_shape)
-        out_sample_shape = pooling.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (64, 112, 112))
-
-    def test_avg_pooling1D(self):
-        in_sample_shape = (224,)
-        pooling = layer.AvgPooling1D('pool',
-                                     3,
-                                     2,
-                                     input_sample_shape=in_sample_shape)
-        out_sample_shape = pooling.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (112,))
-
-    def test_batch_normalization(self):
-        in_sample_shape = (3, 224, 224)
-        bn = layer.BatchNormalization('bn', input_sample_shape=in_sample_shape)
-        out_sample_shape = bn.get_output_sample_shape()
-        self.check_shape(out_sample_shape, in_sample_shape)
-
-    def test_lrn(self):
-        in_sample_shape = (3, 224, 224)
-        lrn = layer.LRN('lrn', input_sample_shape=in_sample_shape)
-        out_sample_shape = lrn.get_output_sample_shape()
-        self.check_shape(out_sample_shape, in_sample_shape)
-
-    def test_dense(self):
-        dense = layer.Dense('ip', 32, input_sample_shape=(64,))
-        out_sample_shape = dense.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (32,))
-
-    def test_dropout(self):
-        input_sample_shape = (64, 1, 12)
-        dropout = layer.Dropout('drop', input_sample_shape=input_sample_shape)
-        out_sample_shape = dropout.get_output_sample_shape()
-        self.check_shape(out_sample_shape, input_sample_shape)
-
-    def test_activation(self):
-        input_sample_shape = (64, 1, 12)
-        act = layer.Activation('act', input_sample_shape=input_sample_shape)
-        out_sample_shape = act.get_output_sample_shape()
-        self.check_shape(out_sample_shape, input_sample_shape)
-
-    def test_softmax(self):
-        input_sample_shape = (12,)
-        softmax = layer.Softmax('soft', input_sample_shape=input_sample_shape)
-        out_sample_shape = softmax.get_output_sample_shape()
-        self.check_shape(out_sample_shape, input_sample_shape)
-
-    def test_flatten(self):
-        input_sample_shape = (64, 1, 12)
-        flatten = layer.Flatten('flat', input_sample_shape=input_sample_shape)
-        out_sample_shape = flatten.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (64 * 1 * 12,))
-
-        flatten = layer.Flatten('flat',
-                                axis=2,
-                                input_sample_shape=input_sample_shape)
-        out_sample_shape = flatten.get_output_sample_shape()
-        self.check_shape(out_sample_shape, (12,))
-
-    def test_concat(self):
-        t1 = tensor.Tensor((2, 3))
-        t2 = tensor.Tensor((1, 3))
-        t1.set_value(1)
-        t2.set_value(2)
-        lyr = layer.Concat('concat', 0, [(3,), (3,)])
-        t = lyr.forward(model_pb2.kTrain, [t1, t2])
-        tnp = tensor.to_numpy(t)
-        self.assertEqual(np.sum(tnp), 12)
-        t3 = tensor.Tensor((3, 3))
-        t3.set_value(1.5)
-        grads, _ = lyr.backward(model_pb2.kTrain, [t3])
-        gnp = tensor.to_numpy(grads[0])
-        self.assertEqual(np.sum(gnp), 6 * 1.5)
-
-    def test_slice(self):
-        t = np.zeros((3, 3))
-        t[:, :2] = float(2)
-        t[:, 2] = float(1)
-        lyr = layer.Slice('slice', 1, [2], (3,))
-        out = lyr.forward(model_pb2.kTrain, [tensor.from_numpy(t)])
-        t1 = tensor.to_numpy(out[0])
-        t2 = tensor.to_numpy(out[1])
-        self.assertEqual(np.average(t1), 2)
-        self.assertEqual(np.average(t2), 1)
-        t1 = tensor.Tensor((3, 2))
-        t2 = tensor.Tensor((3, 1))
-        t1.set_value(1)
-        t2.set_value(2)
-        grad, _ = lyr.backward(model_pb2.kTrain, [t1, t2])
-        gnp = tensor.to_numpy(grad)
-        self.assertEqual(np.sum(gnp), 12)
-
-    def test_l2norm(self):
-        in_sample_shape = (3, 224, 224)
-        l2norm = layer.L2Norm('l2norm', input_sample_shape=in_sample_shape)
-        out_sample_shape = l2norm.get_output_sample_shape()
-        self.check_shape(out_sample_shape, in_sample_shape)
-
-    def test_merge(self):
-        in_sample_shape = (3, 224, 224)
-        merge = layer.Merge('merge', input_sample_shape=in_sample_shape)
-        out_sample_shape = merge.get_output_sample_shape()
-        self.check_shape(out_sample_shape, in_sample_shape)
-
-    def test_split(self):
-        in_sample_shape = (3, 224, 224)
-        split = layer.Split('split',
-                            num_output=3,
-                            input_sample_shape=in_sample_shape)
-        out_sample_shape = split.get_output_sample_shape()
-        self.check_shape(out_sample_shape, [in_sample_shape] * 3)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_loss.py b/test/python/test_loss.py
deleted file mode 100644
index 1d0361d..0000000
--- a/test/python/test_loss.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from __future__ import division
-
-import unittest
-import numpy as np
-
-from singa import loss
-from singa import tensor
-
-
-class TestLoss(unittest.TestCase):
-
-    def setUp(self):
-        self.x_np = np.asarray(
-            [[0.9, 0.2, 0.1], [0.1, 0.4, 0.5], [0.2, 0.4, 0.4]],
-            dtype=np.float32)
-
-        self.y_np = np.asarray([[1, 0, 1], [0, 1, 1], [1, 0, 0]],
-                               dtype=np.float32)
-
-        self.x = tensor.from_numpy(self.x_np)
-        self.y = tensor.from_numpy(self.y_np)
-
-    def test_sigmoid_cross_entropy(self):
-        sig = loss.SigmoidCrossEntropy()
-        l1 = sig.forward(True, self.x, self.y)
-        sig.backward()
-        l2 = sig.evaluate(True, self.x, self.y)
-
-        p = 1.0 / (1 + np.exp(np.negative(self.x_np)))
-        l = -(self.y_np * np.log(p) + (1 - self.y_np) * np.log(1 - p))
-        self.assertAlmostEqual(l1.l1(), l2)
-        self.assertAlmostEqual(l1.l1(), np.average(l))
-
-    def test_squared_error(self):
-        sqe = loss.SquaredError()
-        l1 = sqe.forward(True, self.x, self.y)
-        sqe.backward()
-        l2 = sqe.evaluate(True, self.x, self.y)
-
-        l = 0.5 * (self.y_np - self.x_np)**2
-        self.assertAlmostEqual(l1.l1(), tensor.to_numpy(l2).flatten()[0])
-        self.assertAlmostEqual(l1.l1(), np.average(l))
-
-    def test_softmax_cross_entropy(self):
-        sce = loss.SoftmaxCrossEntropy()
-        l1 = sce.forward(True, self.x, self.y)
-        sce.backward()
-        l2 = sce.evaluate(True, self.x, self.y)
-
-        self.assertAlmostEqual(l1.l1(), l2)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_metric.py b/test/python/test_metric.py
deleted file mode 100644
index e22b5a4..0000000
--- a/test/python/test_metric.py
+++ /dev/null
@@ -1,74 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from __future__ import division
-
-import unittest
-import numpy as np
-
-from singa import metric
-from singa import tensor
-
-
-class TestPrecision(unittest.TestCase):
-
-    def setUp(self):
-        x_np = np.asarray([[0.7, 0.2, 0.1], [0.2, 0.4, 0.5], [0.2, 0.4, 0.4]],
-                          dtype=np.float32)
-
-        y_np = np.asarray([[1, 0, 1], [0, 1, 1], [1, 0, 0]], dtype=np.int32)
-
-        self.prcs = metric.Precision(top_k=2)
-        self.x = tensor.from_numpy(x_np)
-        self.y = tensor.from_numpy(y_np)
-
-    def test_forward(self):
-        p = self.prcs.forward(self.x, self.y)
-        self.assertAlmostEqual(tensor.to_numpy(p)[0], 0.5)
-        self.assertAlmostEqual(tensor.to_numpy(p)[1], 1)
-        self.assertAlmostEqual(tensor.to_numpy(p)[2], 0)
-
-    def test_evaluate(self):
-        e = self.prcs.evaluate(self.x, self.y)
-        self.assertAlmostEqual(e, (0.5 + 1 + 0) / 3)
-
-
-class TestRecall(unittest.TestCase):
-
-    def setUp(self):
-        x_np = np.asarray([[0.7, 0.2, 0.1], [0.2, 0.4, 0.5], [0.2, 0.4, 0.4]],
-                          dtype=np.float32)
-
-        y_np = np.asarray([[1, 0, 1], [1, 1, 1], [1, 0, 0]], dtype=np.int32)
-
-        self.recall = metric.Recall(top_k=2)
-        self.x = tensor.from_numpy(x_np)
-        self.y = tensor.from_numpy(y_np)
-
-    def test_forward(self):
-        r = self.recall.forward(self.x, self.y)
-        self.assertAlmostEqual(tensor.to_numpy(r)[0], 0.5)
-        self.assertAlmostEqual(tensor.to_numpy(r)[1], 2.0 / 3)
-        self.assertAlmostEqual(tensor.to_numpy(r)[2], 0)
-
-    def test_evaluate(self):
-        e = self.recall.evaluate(self.x, self.y)
-        self.assertAlmostEqual(e, (0.5 + (2.0 / 3) + 0) / 3)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_model.py b/test/python/test_model.py
new file mode 100644
index 0000000..aaf1023
--- /dev/null
+++ b/test/python/test_model.py
@@ -0,0 +1,499 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# =============================================================================
+
+from __future__ import division
+
+import os
+import math
+import unittest
+import numpy as np
+
+from singa import singa_wrap as singa_api
+from singa.tensor import Tensor
+from singa import autograd
+from singa import tensor
+from singa import device
+from singa import layer
+from singa import model
+from singa import opt
+
+from cuda_helper import gpu_dev, cpu_dev
+
+
+class DoubleLinear(layer.Layer):
+
+    def __init__(self, a, b, c):
+        super(DoubleLinear, self).__init__()
+        self.l1 = layer.Linear(a, b)
+        self.l2 = layer.Linear(b, c)
+
+    def forward(self, x):
+        y = self.l1(x)
+        y = self.l2(y)
+        return y
+
+
+class MyModel(model.Model):
+
+    def __init__(self):
+        super(MyModel, self).__init__()
+        self.conv1 = layer.Conv2d(2, 2)
+        self.bn1 = layer.BatchNorm2d(2)
+        self.doublelinear1 = DoubleLinear(2, 4, 2)
+        self.optimizer = opt.SGD()
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.bn1(y)
+        y = autograd.reshape(y, (y.shape[0], -1))
+        y = self.doublelinear1(y)
+        return y
+
+    def train_one_batch(self, x, y):
+        y_ = self.forward(x)
+        l = self.loss(y_, y)
+        self.optim(l)
+        return y_, l
+
+    def loss(self, out, ty):
+        return autograd.softmax_cross_entropy(out, ty)
+
+    def optim(self, loss):
+        self.optimizer(loss)
+
+
+class MLP(model.Model):
+
+    def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
+        super(MLP, self).__init__()
+        self.num_classes = num_classes
+        self.dimension = 2
+
+        self.relu = layer.ReLU()
+        self.linear1 = layer.Linear(perceptron_size)
+        self.linear2 = layer.Linear(num_classes)
+        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
+
+    def forward(self, inputs):
+        y = self.linear1(inputs)
+        y = self.relu(y)
+        y = self.linear2(y)
+        return y
+
+    def train_one_batch(self, x, y):
+        out = self.forward(x)
+        loss = self.softmax_cross_entropy(out, y)
+        self.optimizer(loss)
+        return out, loss
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
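+# Typical use of MLP in the tests below (illustrative sketch, mirroring the
+# helpers in TestPythonModule; x/y stand for the input/target tensors):
+#   m = MLP(num_classes=2)
+#   m.set_optimizer(opt.SGD(lr=0.05))
+#   m.compile([x], is_train=True, use_graph=False, sequential=False)
+#   out, loss = m(x, y)  # dispatches to train_one_batch after compile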
+
+# lstm testing
+class LSTMModel3(model.Model):
+
+    def __init__(self, hidden_size):
+        super(LSTMModel3, self).__init__()
+        self.lstm = layer.CudnnRNN(
+            hidden_size=hidden_size,
+            batch_first=True,
+            #    return_sequences=True,
+            use_mask=True)
+        self.l1 = layer.Linear(2)
+        self.optimizer = opt.SGD(0.1)
+
+    def forward(self, x, seq_lengths):
+        y = self.lstm(x, seq_lengths=seq_lengths)
+        y = autograd.reshape(y, (y.shape[0], -1))
+        y = self.l1(y)
+        return y
+
+
+class LSTMModel2(model.Model):
+
+    def __init__(self, hidden_size, bidirectional, num_layers):
+        super(LSTMModel2, self).__init__()
+        self.lstm = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=bidirectional,
+                                   return_sequences=False,
+                                   rnn_mode='lstm',
+                                   batch_first=True)
+        self.optimizer = opt.SGD(0.1)
+
+    def forward(self, x):
+        return self.lstm(x)
+
+
+class LSTMModel(model.Model):
+
+    def __init__(self, hidden_size, seq_length, batch_size, bidirectional,
+                 num_layers, return_sequences, rnn_mode, batch_first):
+        super(LSTMModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.seq_length = seq_length
+        self.return_sequences = return_sequences
+
+        self.lstm = layer.CudnnRNN(hidden_size=hidden_size,
+                                   num_layers=num_layers,
+                                   bidirectional=bidirectional,
+                                   return_sequences=return_sequences,
+                                   rnn_mode=rnn_mode,
+                                   batch_first=batch_first)
+        self.optimizer = opt.SGD(0.1)
+
+    def forward(self, x):
+        y = self.lstm(x)
+        if self.return_sequences:
+            y = autograd.reshape(y, (-1, self.seq_length * self.hidden_size))
+        return y
+
+
+class TestModelMethods(unittest.TestCase):
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_rnn_with_seq_lengths(self, dev=gpu_dev):
+        bs = 2
+        seq_length = 3
+        hidden_size = 2
+        em_size = 2
+        x_np = np.array([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
+                         [[0.3, 0.3], [0.4, 0.4], [0.0,
+                                                   0.0]]]).astype(np.float32)
+        y_np = np.array([[0.4, 0.4], [0.5, 0.5]]).astype(np.float32)
+        seq_lengths_np = np.array([3, 2]).astype(np.int32)
+
+        x = tensor.from_numpy(x_np)
+        x.to_device(dev)
+        y = tensor.from_numpy(y_np)
+        y.to_device(dev)
+        seq_lengths = tensor.from_numpy(seq_lengths_np)
+
+        m = LSTMModel3(hidden_size)
+        m.compile([x, seq_lengths],
+                  is_train=True,
+                  use_graph=False,
+                  sequential=False)
+        m.train()
+        for i in range(10):
+            out = m.forward(x, seq_lengths)
+            loss = autograd.mse_loss(out, y)
+            print("train l:", tensor.to_numpy(loss))
+            m.optimizer(loss)
+        m.eval()
+        out = m.forward(x, seq_lengths)
+        loss = autograd.mse_loss(out, y)
+        print(" eval l:", tensor.to_numpy(loss))
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_lstm_model(self, dev=gpu_dev):
+        hidden_size = 3
+        seq_length = 2
+        batch_size = 4
+        feature_size = 3
+        bidirectional = False
+        directions = 2 if bidirectional else 1
+        num_layers = 2
+        out_size = hidden_size
+        return_sequences = False
+        batch_first = True
+        rnn_mode = "lstm"
+
+        # manual test case
+        x_data = np.array([[[0, 0, 1], [0, 1, 0]], [[0, 1, 0], [1, 0, 0]],
+                           [[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 0, 1]]],
+                          dtype=np.float32).reshape(batch_size, seq_length,
+                                                    feature_size)  # bs, seq, fea
+        if return_sequences:
+            y_data = np.array(
+                [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 0, 1]],
+                 [[0, 1, 0], [1, 0, 0]], [[0, 0, 1], [0, 1, 0]]],
+                dtype=np.float32).reshape(batch_size, seq_length,
+                                          hidden_size)  # bs, seq, hidden
+            y_data = y_data.reshape(batch_size, -1)
+        else:
+            y_data = np.array([[1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]],
+                              dtype=np.float32).reshape(
+                                  batch_size, hidden_size)  # bs, hidden
+
+        x = tensor.Tensor(device=dev, data=x_data)
+        y_t = tensor.Tensor(device=dev, data=y_data)
+
+        m = LSTMModel(hidden_size, seq_length, batch_size, bidirectional,
+                      num_layers, return_sequences, rnn_mode, batch_first)
+        m.compile([x], is_train=True, use_graph=False, sequential=False)
+
+        m.train()
+        for i in range(1000):
+            y = m.forward(x)
+            assert y.shape == y_t.shape
+            loss = autograd.softmax_cross_entropy(y, y_t)
+            if i % 100 == 0:
+                print("loss", loss)
+            m.optimizer(loss)
+
+        m.eval()
+        y = m.forward(x)
+        loss = autograd.softmax_cross_entropy(y, y_t)
+        print("eval loss", loss)
+
+
+class TestModelSaveMethods(unittest.TestCase):
+
+    def _save_states_load_states_helper(self, dev, graph_flag=False):
+        x_shape = (2, 2, 2, 2)
+        x = tensor.PlaceHolder(x_shape, device=dev)
+
+        m = MyModel()
+        m.compile([x], is_train=True, use_graph=graph_flag, sequential=False)
+
+        states = {
+            "conv1.W":
+                tensor.Tensor((2, 2, 2, 2), device=dev).set_value(0.1),
+            "conv1.b":
+                tensor.Tensor((2,), device=dev).set_value(0.2),
+            "bn1.scale":
+                tensor.Tensor((2,), device=dev).set_value(0.3),
+            "bn1.bias":
+                tensor.Tensor((2,), device=dev).set_value(0.4),
+            "bn1.running_mean":
+                tensor.Tensor((2,), device=dev).set_value(0.5),
+            "bn1.running_var":
+                tensor.Tensor((2,), device=dev).set_value(0.6),
+            "doublelinear1.l1.W":
+                tensor.Tensor((2, 4), device=dev).set_value(0.7),
+            "doublelinear1.l1.b":
+                tensor.Tensor((4,), device=dev).set_value(0.8),
+            "doublelinear1.l2.W":
+                tensor.Tensor((4, 2), device=dev).set_value(0.9),
+            "doublelinear1.l2.b":
+                tensor.Tensor((2,), device=dev).set_value(1.0)
+        }
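+        # The keys mirror the layer attribute paths inside MyModel, e.g.
+        # "doublelinear1.l1.W" is the weight of the first Linear held by
+        # DoubleLinear; get_states() below is expected to return the same keys.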
+
+        m.set_states(states)
+        states2 = m.get_states()
+        for k in states2.keys():
+            np.testing.assert_array_almost_equal(tensor.to_numpy(states[k]),
+                                                 tensor.to_numpy(states2[k]))
+
+        opt_state1 = tensor.Tensor((2, 10), device=dev).gaussian(1, 0.1)
+        opt_state2 = tensor.Tensor((20, 2), device=dev).gaussian(0.1, 1)
+        aux = {"opt1": opt_state1, "opt2": opt_state2}
+
+        # save snapshot1
+        zip_fp = 'snapshot1_%s.zip' % self._testMethodName
+        if os.path.exists(zip_fp):
+            os.remove(zip_fp)
+        m.save_states(zip_fp, aux)
+
+        # do some training, states changes
+        cx = tensor.Tensor(x_shape, device=dev).gaussian(1, 1)
+        cy = tensor.Tensor((2, 2), device=dev).gaussian(1, 1)
+        mini_batch_size = 10
+        for i in range(mini_batch_size):
+            m.train_one_batch(cx, cy)
+
+        # restore snapshot
+        aux2 = m.load_states(zip_fp)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(aux2["opt1"]),
+                                             tensor.to_numpy(aux["opt1"]))
+        np.testing.assert_array_almost_equal(tensor.to_numpy(aux2["opt2"]),
+                                             tensor.to_numpy(aux["opt2"]))
+
+        # snapshot states
+        states3 = m.get_states()
+        for k in states3.keys():
+            np.testing.assert_array_almost_equal(tensor.to_numpy(states[k]),
+                                                 tensor.to_numpy(states3[k]))
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_save_states_load_states_gpu(self):
+        self._save_states_load_states_helper(gpu_dev, graph_flag=False)
+        self._save_states_load_states_helper(gpu_dev, graph_flag=True)
+
+    def test_save_states_load_states_cpu(self):
+        self._save_states_load_states_helper(cpu_dev, graph_flag=False)
+        self._save_states_load_states_helper(cpu_dev, graph_flag=True)
+
+
+class TestPythonModule(unittest.TestCase):
+
+    def to_categorical(self, y, num_classes):
+        y = np.array(y, dtype="int")
+        n = y.shape[0]
+        categorical = np.zeros((n, num_classes))
+        categorical[np.arange(n), y] = 1
+        return categorical
+
+    def generate_data(self, dev, num=400):
+        f = lambda x: (5 * x + 1)
+
+        x = np.random.uniform(-1, 1, num)
+        y = f(x) + 2 * np.random.randn(len(x))
+
+        self.label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
+        self.data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
+        self.label = self.to_categorical(self.label, 2).astype(np.float32)
+
+        self.inputs = Tensor(data=self.data, device=dev)
+        self.target = Tensor(data=self.label, device=dev)
+
+    def get_params(self, model):
+        params = model.get_params()
+        self.w0 = params['linear1.W']
+        self.b0 = params['linear1.b']
+        self.w1 = params['linear2.W']
+        self.b1 = params['linear2.b']
+
+        self.W0 = tensor.to_numpy(self.w0)
+        self.B0 = tensor.to_numpy(self.b0)
+        self.W1 = tensor.to_numpy(self.w1)
+        self.B1 = tensor.to_numpy(self.b1)
+
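+    # numpy mirror of MLP.forward (linear1 -> ReLU -> linear2), computed from
+    # the parameter snapshots captured in get_params()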
+    def numpy_forward(self, inputs):
+        self.x1 = np.matmul(inputs, self.W0)
+        self.x2 = np.add(self.x1, self.B0)
+        self.x3 = np.maximum(self.x2, 0)
+        self.x4 = np.matmul(self.x3, self.W1)
+        self.x5 = np.add(self.x4, self.B1)
+        return self.x5
+
+    def numpy_train_one_batch(self, inputs, y):
+        # forward propagation
+        out = self.numpy_forward(inputs)
+
+        # softmax cross entropy loss
+        exp_out = np.exp(out - np.max(out, axis=-1, keepdims=True))
+        self.softmax = exp_out / np.sum(exp_out, axis=-1, keepdims=True)
+        loss = np.sum(y * np.log(self.softmax)) / -self.softmax.shape[0]
+
+        # optimize
+        # calculate gradients
+        label_sum = np.sum(self.label, axis=-1)
+        dloss = self.softmax - self.label / label_sum.reshape(
+            label_sum.shape[0], 1)
+        dloss /= self.softmax.shape[0]
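+        # for one-hot labels label_sum == 1, so this reduces to the usual
+        # (softmax - y) / batch_size gradient of softmax cross entropy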
+
+        dx5 = dloss
+        db1 = np.sum(dloss, 0)
+
+        dx4 = np.matmul(dx5, self.W1.T)
+        dw1 = np.matmul(self.x3.T, dx5)
+
+        dx3 = dx4 * (self.x3 > 0)
+
+        dx2 = dx3
+        db0 = np.sum(dx3, 0)
+
+        dx1 = np.matmul(dx2, self.W0.T)
+        dw0 = np.matmul(self.data.T, dx2)
+
+        # update all the params
+        self.W0 -= 0.05 * dw0
+        self.B0 -= 0.05 * db0
+        self.W1 -= 0.05 * dw1
+        self.B1 -= 0.05 * db1
+        return out, loss
+
+    def setUp(self):
+        self.sgd = opt.SGD(lr=0.05)
+
+        cpu_dev.ResetGraph()
+        if singa_api.USE_CUDA:
+            gpu_dev.ResetGraph()
+
+    def tearDown(self):
+        cpu_dev.ResetGraph()
+        if singa_api.USE_CUDA:
+            gpu_dev.ResetGraph()
+
+    def _forward_helper(self, dev, is_train, use_graph, sequential):
+        self.generate_data(dev)
+        model = MLP(num_classes=2)
+        model.compile([self.inputs],
+                      is_train=is_train,
+                      use_graph=use_graph,
+                      sequential=sequential)
+
+        self.get_params(model)
+
+        out = model(self.inputs)
+        np_out = self.numpy_forward(self.data)
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
+
+    def _train_one_batch_helper(self, dev, is_train, use_graph, sequential):
+        self.generate_data(dev)
+        model = MLP(num_classes=2)
+        model.set_optimizer(self.sgd)
+        model.compile([self.inputs],
+                      is_train=is_train,
+                      use_graph=use_graph,
+                      sequential=sequential)
+
+        self.get_params(model)
+
+        out, loss = model(self.inputs, self.target)
+        np_out, np_loss = self.numpy_train_one_batch(self.data, self.label)
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(loss), np_loss)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(self.w0), self.W0)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(self.b0), self.B0)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(self.w1), self.W1)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(self.b1), self.B1)
+
+    def test_forward_cpu(self):
+        self._forward_helper(cpu_dev, False, True, False)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_forward_gpu(self):
+        self._forward_helper(gpu_dev, False, True, False)
+
+    def test_evaluate_cpu(self):
+        self._forward_helper(cpu_dev, False, False, False)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_evaluate_gpu(self):
+        self._forward_helper(gpu_dev, False, False, False)
+
+    def test_train_one_batch_cpu(self):
+        self._train_one_batch_helper(cpu_dev, True, True, False)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_train_one_batch_gpu(self):
+        self._train_one_batch_helper(gpu_dev, True, True, False)
+
+    def test_without_graph_cpu(self):
+        self._train_one_batch_helper(cpu_dev, True, False, False)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_without_graph_gpu(self):
+        self._train_one_batch_helper(gpu_dev, True, False, False)
+
+    def test_run_in_serial_cpu(self):
+        self._train_one_batch_helper(cpu_dev, True, True, True)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_run_in_serial_gpu(self):
+        self._train_one_batch_helper(gpu_dev, True, True, True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/python/test_module.py b/test/python/test_module.py
deleted file mode 100644
index a8b5cff..0000000
--- a/test/python/test_module.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-import unittest
-import numpy as np
-from builtins import str
-
-from singa import singa_wrap
-from singa import opt
-from singa import device
-from singa import tensor
-from singa import module
-from singa import autograd
-from singa.tensor import Tensor
-
-from cuda_helper import cpu_dev, gpu_dev
-
-
-class MLP(module.Module):
-
-    def __init__(self, optimizer):
-        super(MLP, self).__init__()
-
-        self.w0 = Tensor(shape=(2, 3), requires_grad=True, stores_grad=True)
-        self.b0 = Tensor(shape=(3,), requires_grad=True, stores_grad=True)
-        self.w1 = Tensor(shape=(3, 2), requires_grad=True, stores_grad=True)
-        self.b1 = Tensor(shape=(2,), requires_grad=True, stores_grad=True)
-
-        self.w0.gaussian(0.0, 0.1)
-        self.b0.set_value(0.0)
-        self.w1.gaussian(0.0, 0.1)
-        self.b1.set_value(0.0)
-
-        self.optimizer = optimizer
-
-    def forward(self, inputs):
-        x = autograd.matmul(inputs, self.w0)
-        x = autograd.add_bias(x, self.b0)
-        x = autograd.relu(x)
-        x = autograd.matmul(x, self.w1)
-        x = autograd.add_bias(x, self.b1)
-        return x
-
-    def loss(self, out, target):
-        return autograd.softmax_cross_entropy(out, target)
-
-    def optim(self, loss):
-        return self.optimizer.backward_and_update(loss)
-
-
-class TestPythonModule(unittest.TestCase):
-
-    def to_categorical(self, y, num_classes):
-        y = np.array(y, dtype="int")
-        n = y.shape[0]
-        categorical = np.zeros((n, num_classes))
-        categorical[np.arange(n), y] = 1
-        return categorical
-
-    def generate_data(self, num=400):
-        f = lambda x: (5 * x + 1)
-
-        x = np.random.uniform(-1, 1, num)
-        y = f(x) + 2 * np.random.randn(len(x))
-
-        self.label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
-        self.data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
-        self.label = self.to_categorical(self.label, 2).astype(np.float32)
-
-        self.inputs = Tensor(data=self.data)
-        self.target = Tensor(data=self.label)
-
-    def get_numpy_params(self, model):
-        self.W0 = tensor.to_numpy(model.w0)
-        self.B0 = tensor.to_numpy(model.b0)
-        self.W1 = tensor.to_numpy(model.w1)
-        self.B1 = tensor.to_numpy(model.b1)
-
-    def numpy_forward(self, inputs):
-        self.x1 = np.matmul(inputs, self.W0)
-        self.x2 = np.add(self.x1, self.B0)
-        self.x3 = np.maximum(self.x2, 0)
-        self.x4 = np.matmul(self.x3, self.W1)
-        self.x5 = np.add(self.x4, self.B1)
-        return self.x5
-
-    def numpy_loss(self, out, y):
-        exp_out = np.exp(out - np.max(out, axis=-1, keepdims=True))
-        self.softmax = exp_out / np.sum(exp_out, axis=-1, keepdims=True)
-
-        loss = np.sum(y * np.log(self.softmax)) / -self.softmax.shape[0]
-
-        return loss
-
-    def numpy_optim(self, loss):
-        # calculate gradients
-        label_sum = np.sum(self.label, axis=-1)
-        dloss = self.softmax - self.label / label_sum.reshape(
-            label_sum.shape[0], 1)
-        dloss /= self.softmax.shape[0]
-
-        dx5 = dloss
-        db1 = np.sum(dloss, 0)
-
-        dx4 = np.matmul(dx5, self.W1.T)
-        dw1 = np.matmul(self.x3.T, dx5)
-
-        dx3 = dx4 * (self.x3 > 0)
-
-        dx2 = dx3
-        db0 = np.sum(dx3, 0)
-
-        dx1 = np.matmul(dx2, self.W0.T)
-        dw0 = np.matmul(self.data.T, dx2)
-
-        # update all the params
-        self.W0 -= 0.05 * dw0
-        self.B0 -= 0.05 * db0
-        self.W1 -= 0.05 * dw1
-        self.B1 -= 0.05 * db1
-
-    def setUp(self):
-        self.sgd = opt.SGD(lr=0.05)
-
-        self.generate_data(400)
-
-        cpu_dev.ResetGraph()
-
-        if singa_wrap.USE_CUDA:
-            gpu_dev.ResetGraph()
-
-    def _forward_helper(self, dev):
-        model = MLP(self.sgd)
-        model.train()
-        model.on_device(dev)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-
-        np_out = self.numpy_forward(self.data)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-    
-    def test_forward_cpu(self):
-        self._forward_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_forward_gpu(self):
-        self._forward_helper(gpu_dev)
-
-    def _forward_loss_helper(self, dev):
-        model = MLP(self.sgd)
-        model.train()
-        model.on_device(dev)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-        loss = model.loss(out, self.target)
-
-        np_out = self.numpy_forward(self.data)
-        np_loss = self.numpy_loss(np_out, self.label)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(loss), np_loss)
-
-    def test_forward_loss_cpu(self):
-        self._forward_loss_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_forward_loss_gpu(self):
-        self._forward_loss_helper(gpu_dev)
-
-    def _forward_loss_optim_helper(self, dev):
-        model = MLP(self.sgd)
-        model.train()
-        model.on_device(dev)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-        loss = model.loss(out, self.target)
-        model.optim(loss)
-
-        np_out = self.numpy_forward(self.data)
-        np_loss = self.numpy_loss(np_out, self.label)
-        self.numpy_optim(np_loss)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(loss), np_loss)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w0),
-                                                 self.W0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b0),
-                                                 self.B0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w1),
-                                                 self.W1)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b1),
-                                                 self.B1)
-
-    def test_forward_loss_optim_cpu(self):
-        self._forward_loss_optim_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_forward_loss_optim_gpu(self):
-        self._forward_loss_optim_helper(gpu_dev)
-
-    def _train_without_graph_helper(self, dev):
-        model = MLP(self.sgd)
-        model.train()
-        model.on_device(dev)
-        model.graph(False)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-        loss = model.loss(out, self.target)
-        model.optim(loss)
-
-        np_out = self.numpy_forward(self.data)
-        np_loss = self.numpy_loss(np_out, self.label)
-        self.numpy_optim(np_loss)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(loss), np_loss)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w0),
-                                                 self.W0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b0),
-                                                 self.B0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w1),
-                                                 self.W1)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b1),
-                                                 self.B1)
-
-    def test_without_graph_cpu(self):
-        self._train_without_graph_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_without_graph_gpu(self):
-        self._train_without_graph_helper(gpu_dev)
-
-    def _run_in_serial_helper(self, dev):
-        model = MLP(self.sgd)
-        model.train()
-        model.on_device(dev)
-        model.graph(True, False)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-        loss = model.loss(out, self.target)
-        model.optim(loss)
-
-        np_out = self.numpy_forward(self.data)
-        np_loss = self.numpy_loss(np_out, self.label)
-        self.numpy_optim(np_loss)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(loss), np_loss)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w0),
-                                                 self.W0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b0),
-                                                 self.B0)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.w1),
-                                                 self.W1)
-        np.testing.assert_array_almost_equal(tensor.to_numpy(model.b1),
-                                                 self.B1)
-
-    def test_run_in_serial_cpu(self):
-        self._run_in_serial_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_run_in_serial_gpu(self):
-        self._run_in_serial_helper(gpu_dev)
-
-    def _evaluate_helper(self, dev):
-        model = MLP(self.sgd)
-        model.eval()
-        model.on_device(dev)
-        self.get_numpy_params(model)
-
-        out = model(self.inputs)
-
-        np_out = self.numpy_forward(self.data)
-
-        np.testing.assert_array_almost_equal(tensor.to_numpy(out), np_out)
-
-    def test_evaluate_cpu(self):
-        self._evaluate_helper(cpu_dev)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_evaluate_gpu(self):
-        self._evaluate_helper(gpu_dev)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_net.py b/test/python/test_net.py
deleted file mode 100644
index 10d0135..0000000
--- a/test/python/test_net.py
+++ /dev/null
@@ -1,115 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from __future__ import division
-from builtins import zip
-
-import unittest
-import math
-import numpy as np
-
-from singa import net
-from singa import layer
-from singa import tensor
-from singa import loss
-
-layer.engine = 'singacpp'
-# net.verbose = True
-
-
-class TestFeedForwardNet(unittest.TestCase):
-
-    def test_single_input_output(self):
-        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
-        ffn.add(layer.Activation('relu1', input_sample_shape=(2,)))
-        ffn.add(layer.Activation('relu2'))
-        x = np.array([[-1, 1], [1, 1], [-1, -2]], dtype=np.float32)
-        x = tensor.from_numpy(x)
-        y = tensor.Tensor((3,))
-        y.set_value(0)
-        out, _ = ffn.evaluate(x, y)
-        self.assertAlmostEqual(
-            out * 3,
-            -math.log(1.0 / (1 + math.exp(1))) - math.log(0.5) - math.log(0.5),
-            5)
-
-    def test_mult_inputs(self):
-        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
-        s1 = ffn.add(layer.Activation('relu1', input_sample_shape=(2,)), [])
-        s2 = ffn.add(layer.Activation('relu2', input_sample_shape=(2,)), [])
-        ffn.add(layer.Merge('merge', input_sample_shape=(2,)), [s1, s2])
-        x1 = tensor.Tensor((2, 2))
-        x1.set_value(1.1)
-        x2 = tensor.Tensor((2, 2))
-        x2.set_value(0.9)
-        out = ffn.forward(False, {'relu1': x1, 'relu2': x2})
-        out = tensor.to_numpy(out)
-        self.assertAlmostEqual(np.average(out), 2)
-
-    def test_mult_outputs(self):
-        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
-        s1 = ffn.add(layer.Activation('relu1', input_sample_shape=(2,)), [])
-        s2 = ffn.add(layer.Activation('relu2', input_sample_shape=(2,)), [])
-        ffn.add(layer.Merge('merge', input_sample_shape=(2,)), [s1, s2])
-        split = ffn.add(layer.Split('split', 2))
-        ffn.add(layer.Dummy('split1'), split)
-        ffn.add(layer.Dummy('split2'), split)
-        x1 = tensor.Tensor((2, 2))
-        x1.set_value(1.1)
-        x2 = tensor.Tensor((2, 2))
-        x2.set_value(0.9)
-        out = ffn.forward(False, {'relu1': x1, 'relu2': x2})
-        out = tensor.to_numpy(out['split1'])
-        self.assertAlmostEqual(np.average(out), 2)
-
-    def test_save_load(self):
-        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
-        ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
-        ffn.add(layer.Flatten('flat'))
-        # ffn.add(layer.BatchNorm('bn'))
-        ffn.add(layer.Dense('dense', num_output=4))
-        for pname, pval in zip(ffn.param_names(), ffn.param_values()):
-            pval.set_value(0.1)
-        ffn.save('test_snaphost')
-        ffn.save('test_pickle', use_pickle=True)
-
-        ffn.load('test_snaphost')
-        ffn.load('test_pickle', use_pickle=True)
-
-    def test_train_one_batch(self):
-        ffn = net.FeedForwardNet(loss.SoftmaxCrossEntropy())
-        ffn.add(layer.Conv2D('conv', 4, 3, input_sample_shape=(3, 12, 12)))
-        ffn.add(layer.Flatten('flat'))
-        ffn.add(layer.Dense('dense', num_output=4))
-        for pname, pval in zip(ffn.param_names(), ffn.param_values()):
-            pval.set_value(0.1)
-        x = tensor.Tensor((4, 3, 12, 12))
-        x.gaussian(0, 0.01)
-        y = np.asarray([[1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]],
-                       dtype=np.int32)
-        y = tensor.from_numpy(y)
-        o = ffn.forward(True, x)
-        ffn.loss.forward(True, o, y)
-        g = ffn.loss.backward()
-        for pname, pvalue, pgrad, _ in ffn.backward(g):
-            self.assertEqual(len(pvalue), len(pgrad))
-            for p, g in zip(pvalue, pgrad):
-                self.assertEqual(p.size(), g.size())
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_onnx.py b/test/python/test_onnx.py
index 5a951ba..504fbcf 100644
--- a/test/python/test_onnx.py
+++ b/test/python/test_onnx.py
@@ -22,6 +22,7 @@
 from singa import tensor
 from singa import singa_wrap as singa
 from singa import autograd
+from singa import layer
 from singa import sonnx
 from singa import opt
 
@@ -54,12 +55,13 @@
     def _conv2d_helper(self, dev):
         x = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         x.gaussian(0.0, 1.0)
-        y = autograd.Conv2d(3, 1, 2)(x)
+        y = layer.Conv2d(1, 2)(x)
 
         # frontend
         model = sonnx.to_onnx([x], [y])
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -87,6 +89,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
                                              tensor.to_numpy(y_t[0]),
@@ -102,13 +105,14 @@
     def _avg_pool_helper(self, dev):
         x = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         x.gaussian(0.0, 1.0)
-        y = autograd.AvgPool2d(3, 1, 2)(x)
+        y = layer.AvgPool2d(3, 1, 2)(x)
 
         # frontend
         model = sonnx.to_onnx([x], [y])
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -134,6 +138,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
                                              tensor.to_numpy(y_t[0]),
@@ -158,6 +163,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
                                              tensor.to_numpy(y_t[0]),
@@ -186,6 +192,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x1, x2])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -214,6 +221,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x1, x2])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -244,6 +252,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x1, x2])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -260,7 +269,7 @@
     def _max_pool_helper(self, dev):
         x = tensor.Tensor(shape=(2, 3, 4, 4), device=dev)
         x.gaussian(0.0, 1.0)
-        y = autograd.MaxPool2d(2, 2, 0)(x)
+        y = layer.MaxPool2d(2, 2, 0)(x)
 
         # frontend
         model = sonnx.to_onnx([x], [y])
@@ -268,6 +277,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -312,6 +322,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x, s, bias])  # mean and var has been stored in graph
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -329,7 +340,7 @@
         x = tensor.Tensor(shape=(2, 20), device=dev)
         x.gaussian(0.0, 1.0)
         x1 = x.clone()
-        y = autograd.Linear(20, 1, bias=False)(x)
+        y = layer.Linear(20, 1, bias=False)(x)
 
         # frontend
         model = sonnx.to_onnx([x], [y])
@@ -337,6 +348,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -371,6 +383,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([tA, tB, tC])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -397,6 +410,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])  # shape has been stored in graph
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -425,6 +439,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -450,6 +465,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -475,6 +491,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -500,6 +517,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -525,6 +543,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -550,6 +569,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -575,6 +595,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -600,6 +621,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -625,6 +647,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -650,6 +673,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -675,6 +699,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -700,6 +725,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -725,6 +751,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -755,6 +782,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -784,6 +812,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -816,6 +845,7 @@
 
     #     # backend
     #     sg_ir = sonnx.prepare(model, device=dev)
+    #     sg_ir.is_graph = True
     #     y_t = sg_ir.run([x0, x1])
 
     #     np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -847,6 +877,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -873,6 +904,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -904,6 +936,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -935,6 +968,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -961,6 +995,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev, init_inputs=X)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -992,6 +1027,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1021,6 +1057,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1048,6 +1085,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1075,6 +1113,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1102,6 +1141,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1133,6 +1173,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1163,6 +1204,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1193,6 +1235,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])  # min, max has been stored in model
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1225,6 +1268,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x, slope])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1255,6 +1299,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1283,6 +1328,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1314,6 +1360,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1345,6 +1392,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1372,6 +1420,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1404,6 +1453,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1435,6 +1485,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1467,6 +1518,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x0, x1])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1494,6 +1546,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1521,6 +1574,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1547,6 +1601,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1572,6 +1627,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev, init_inputs=[X])
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(y),
@@ -1598,6 +1654,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         self.check_shape(
@@ -1624,6 +1681,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1650,6 +1708,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1676,6 +1735,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1702,6 +1762,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1728,6 +1789,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1754,6 +1816,7 @@
 
     #       # backend
     #       sg_ir = sonnx.prepare(model, device=dev)
+    #       sg_ir.is_graph = True
     #       y_t = sg_ir.run([x])[0]
 
     #       np.testing.assert_array_almost_equal(tensor.to_numpy(y).shape, tensor.to_numpy(y_t).shape)
@@ -1777,6 +1840,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1802,6 +1866,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1827,6 +1892,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1852,6 +1918,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         np.testing.assert_array_almost_equal(
@@ -1884,6 +1951,7 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x])
 
         self.check_shape(
@@ -1900,15 +1968,31 @@
     def _inference_helper(self, dev):
         x = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         x.gaussian(0.0, 1.0)
-        x1 = autograd.Conv2d(3, 1, 2)(x)
-        y = autograd.Conv2d(1, 1, 2)(x1)
 
+        conv1 = layer.Conv2d(1, 2)
+        conv2 = layer.Conv2d(1, 2)
+
+        class MyLayer(layer.Layer):
+
+            def __init__(self, conv1, conv2):
+                super(MyLayer, self).__init__()
+                self.conv1 = conv1
+                self.conv2 = conv2
+
+            def forward(self, inputs):
+                x = self.conv1(inputs)
+                x = self.conv2(x)
+                return x
+
+        y = MyLayer(conv1, conv2)(x)
+        x1 = conv1(x)
         # frontend
         model = sonnx.to_onnx([x], [y])
         # print('The model is:\n{}'.format(model))
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         y_t = sg_ir.run([x], last_layers=-1)
 
         np.testing.assert_array_almost_equal(tensor.to_numpy(x1),
@@ -1926,16 +2010,28 @@
         # forward
         x = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         x.gaussian(0.0, 1.0)
-        x1 = autograd.Conv2d(3, 1, 2)(x)
-        x2 = autograd.Conv2d(1, 1, 2)(x1)
-        y = autograd.Flatten()(x2)[0]
+
+        class MyLayer(layer.Layer):
+
+            def __init__(self):
+                super(MyLayer, self).__init__()
+                self.conv1 = layer.Conv2d(1, 2)
+                self.conv2 = layer.Conv2d(1, 2)
+
+            def forward(self, inputs):
+                x = self.conv1(inputs)
+                x = self.conv2(x)
+                x = autograd.flatten(x)
+                return x
+
+        y = MyLayer()(x)
         y_t = tensor.Tensor(shape=(2, 1), device=dev)
         y_t.gaussian(0.0, 1.0)
-        loss = autograd.MeanSquareError()(y, y_t)[0]
+        loss = autograd.MeanSquareError(y_t)(y)[0]
         # backward
         sgd = opt.SGD(lr=0.01)
         for p, gp in autograd.backward(loss):
-            sgd.update(p, gp)
+            sgd.apply(p.name, p, gp)
         sgd.step()
 
         # frontend
@@ -1944,17 +2040,14 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
-        for idx, tens in sg_ir.tensor_map.items():
-            tens.requires_grad = True
-            tens.stores_grad = True
-            sg_ir.tensor_map[idx] = tens
+        sg_ir.is_graph = True
         # forward
         y_o = sg_ir.run([x])[0]
         # backward
-        loss = autograd.MeanSquareError()(y_o, y_t)[0]
+        loss = autograd.MeanSquareError(y_t)(y_o)[0]
         sgd = opt.SGD(lr=0.01)
         for p, gp in autograd.backward(loss):
-            sgd.update(p, gp)
+            sgd.apply(p.name, p, gp)
         sgd.step()
 
     def test_retraining_cpu(self):
@@ -1968,15 +2061,26 @@
         # forward
         x = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         x.gaussian(0.0, 1.0)
-        x1 = autograd.Conv2d(3, 1, 2)(x)
-        y = autograd.Flatten()(x1)[0]
+
+        class MyLayer(layer.Layer):
+
+            def __init__(self):
+                super(MyLayer, self).__init__()
+                self.conv1 = layer.Conv2d(1, 2)
+
+            def forward(self, inputs):
+                x = self.conv1(inputs)
+                x = autograd.flatten(x)
+                return x
+
+        y = MyLayer()(x)
         y_t = tensor.Tensor(shape=(2, 4), device=dev)
         y_t.gaussian(0.0, 1.0)
-        loss = autograd.MeanSquareError()(y, y_t)[0]
+        loss = autograd.MeanSquareError(y_t)(y)[0]
         # backward
         sgd = opt.SGD(lr=0.01)
         for p, gp in autograd.backward(loss):
-            sgd.update(p, gp)
+            sgd.apply(p.name, p, gp)
         sgd.step()
 
         # frontend
@@ -1985,17 +2089,31 @@
 
         # backend
         sg_ir = sonnx.prepare(model, device=dev)
+        sg_ir.is_graph = True
         # forward
-        x1 = sg_ir.run([x], last_layers=-1)[0]
-        x2 = autograd.Conv2d(1, 1, 2)(x1)
-        y_o = autograd.Flatten()(x2)[0]
+        class MyLayer2(layer.Layer):
+
+            def __init__(self, sg_ir):
+                super(MyLayer2, self).__init__()
+                self.sg_ir = sg_ir
+                for node, operator in self.sg_ir.layers:
+                    self.__dict__[node.name] = operator
+                self.conv2 = layer.Conv2d(1, 2)
+
+            def forward(self, inputs):
+                x = self.sg_ir.run(inputs, last_layers=-1)[0]
+                x = self.conv2(x)
+                x = autograd.flatten(x)
+                return x
+
+        y_o = MyLayer2(sg_ir)(x)
         # backward
         y_ot = tensor.Tensor(shape=(2, 1), device=dev)
         y_ot.gaussian(0.0, 1.0)
-        loss = autograd.MeanSquareError()(y_o, y_ot)[0]
+        loss = autograd.MeanSquareError(y_ot)(y_o)[0]
         sgd = opt.SGD(lr=0.01)
         for p, gp in autograd.backward(loss):
-            sgd.update(p, gp)
+            sgd.apply(p.name, p, gp)
         sgd.step()
 
     def test_transfer_learning_cpu(self):
diff --git a/test/python/test_onnx_backend.py b/test/python/test_onnx_backend.py
index 2144d1d..0e7bb65 100644
--- a/test/python/test_onnx_backend.py
+++ b/test/python/test_onnx_backend.py
@@ -15,8 +15,10 @@
 # limitations under the License.
 # =============================================================================
 
-import unittest
-from builtins import str
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
 
 from singa import tensor
 from singa import singa_wrap as singa
@@ -24,3125 +26,106 @@
 from singa import sonnx
 from singa import opt
 
-import onnx
-from onnx import (defs, checker, helper, numpy_helper, mapping, ModelProto,
-                  GraphProto, NodeProto, AttributeProto, TensorProto,
-                  OperatorSetIdProto)
-from onnx.helper import make_tensor, make_tensor_value_info, make_node, make_graph
+import os
 
-from cuda_helper import gpu_dev, cpu_dev
+import unittest
+import onnx.backend.test
 
-import numpy as np
-import itertools
+# This is a pytest magic variable to load extra plugins
+pytest_plugins = 'onnx.backend.test.report',
 
-autograd.training = True
+backend_test = onnx.backend.test.BackendTest(sonnx.SingaBackend, __name__)
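+# BackendTest (from the onnx package) generates a unittest case for every
+# standard ONNX node test and runs it through sonnx.SingaBackend via its
+# prepare()/run() interface; the include()/exclude() calls below choose
+# which of the generated cases are actually kept.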
 
-_default_opset_version = 11
+_include_nodes_patterns = {
+    # operators whose test-case names do not follow the default test_<operator> pattern
+    'ReduceSum': r'(test_reduce_sum)',
+    'ReduceMean': r'(test_reduce_mean)',
+    'BatchNormalization': r'(test_batchnorm)',
+    'ScatterElements': r'(test_scatter_elements)',
+    'Conv': r'(test_basic_conv_|test_conv_with_|test_Conv2d)',
+    'MaxPool': r'(test_maxpool_2d)',
+    'AveragePool': r'(test_averagepool_2d)',
+}
 
+_exclude_nodes_patterns = [
+    # unsupported data types
+    r'(uint)',  # does not support uint
+    r'(scalar)',  # does not support scalar
+    r'(STRING)',  # does not support string
+    # unsupported features
+    r'(test_split_zero_size_splits|test_slice_start_out_of_bounds)',  # does not support empty tensors
+    r'(test_batchnorm_epsilon)',  # does not support epsilon
+    r'(dilations)',  # does not support dilations
+    r'(test_maxpool_2d_ceil|test_averagepool_2d_ceil)',  # does not support ceil mode for max or avg pool
+    r'(count_include_pad)',  # pooling does not support count_include_pad
+    # unsupported tests that would otherwise match the include patterns
+    r'(test_matmulinteger)',  # matmulinteger is not supported
+    r'(test_less_equal)',  # less_equal is not supported
+    r'(test_greater_equal)',  # greater_equal is not supported
+    r'(test_negative_log)',  # negative_log is not supported
+    r'(test_softmax_cross_entropy)',  # softmax_cross_entropy is not supported
+    r'(test_reduce_sum_square)',  # reduce_sum_square is not supported
+    r'(test_log_softmax)',  # log_softmax is not supported
+    r'(test_maxunpool)',  # maxunpool is not supported
+    r'(test_gather_elements)',  # gather_elements is not supported
+    r'(test_logsoftmax)',  # logsoftmax is not supported
+    r'(test_gathernd)',  # gathernd is not supported
+    r'(test_maxpool_with_argmax)',  # maxpool_with_argmax is not supported
+    # todo: special errors
+    r'test_transpose',  # the test cases are wrong
+    r'test_conv_with_strides_and_asymmetric_padding',  # the test cases are wrong
+    r'(test_gemm_default_single_elem_vector_bias_cuda)',  # status == CURAND_STATUS_SUCCESS
+    r'(test_equal_bcast_cuda|test_equal_cuda)',  # Unknown combination of data type kInt and language kCuda
+    r'(test_maxpool_1d|test_averagepool_1d|test_maxpool_3d|test_averagepool_3d)',  # Check failed: idx < shape_.size() (3 vs. 3)
+    r'test_depthtospace.*cuda',  # CUDA does not support transpose with more than 4 dims
+]
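+# Note: include() and exclude() take regular expressions that are matched
+# against the generated test-case names (e.g. test_relu_cpu / test_relu_cuda,
+# with a device suffix appended per case), which is why the single r'(cuda)'
+# pattern further below can drop all GPU cases when CUDA is not built in.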
 
-def expect(node,
-           inputs,
-           outputs,
-           name,
-           opset_version=_default_opset_version,
-           decimal=5):
+_include_real_patterns = []  # todo
 
-    def _helper(dev):
-        onnx_node = sonnx.OnnxNode(node)
-        input_tensors = {}
-        input_labels = [x for x in onnx_node.inputs if x != ""]
-        # prepare input tensors
-        for key, val in zip(input_labels, inputs):
-            # very important! must be float
-            if not isinstance(val, np.ndarray) or len(val.shape) == 0:
-                val = np.array([val])
-            x = tensor.from_numpy(val.astype(np.float32))
-            x.to_device(dev)
-            input_tensors[key] = x
-        outputs_dict = sonnx.run_node(onnx_node, input_tensors, opset_version)
-        for out1, out2 in zip(outputs, outputs_dict.values()):
-            np.testing.assert_array_almost_equal(out1,
-                                                 tensor.to_numpy(out2),
-                                                 decimal=decimal)
+_include_simple_patterns = []  # todo
 
-    _helper(cpu_dev)
-    if (singa.USE_CUDA):
-        _helper(gpu_dev)
+_include_pytorch_converted_patterns = []  # todo
 
+_include_pytorch_operator_patterns = []  # todo
 
-class TestPythonOnnxBackend(unittest.TestCase):
-    """
-    This class aims to test the backend functionality of sonnx,
-    The most of the code is borrowed from onnx.
-    """
+# add supported operators into include patterns
+for name in sonnx.SingaBackend._rename_operators.keys():
+    if name not in _include_nodes_patterns:
+        backend_test.include(r'(test_{})'.format(name.lower()))
+    else:
+        # todo, need to fix the conv2d
+        if name == 'Conv':
+            continue
+        backend_test.include(_include_nodes_patterns[name])
 
-    def test_conv2d(self):
-        x = np.array([[[
-            [0., 1., 2., 3., 4.],  # (1, 1, 5, 5) input tensor
-            [5., 6., 7., 8., 9.],
-            [10., 11., 12., 13., 14.],
-            [15., 16., 17., 18., 19.],
-            [20., 21., 22., 23., 24.]
-        ]]]).astype(np.float32)
+# exclude the unsupported operators
+for pattern in _exclude_nodes_patterns:
+    backend_test.exclude(pattern)
 
-        W = np.array([[[
-            [1., 1., 1.],  # (1, 1, 3, 3) tensor for convolution weights
-            [1., 1., 1.],
-            [1., 1., 1.]
-        ]]]).astype(np.float32)
+# exclude the cuda cases
+if not singa.USE_CUDA:
+    backend_test.exclude(r'(cuda)')
 
-        # Convolution with padding
-        node_with_padding = onnx.helper.make_node(
-            'Conv',
-            inputs=['x', 'W'],
-            outputs=['y'],
-            kernel_shape=[3, 3],
-            # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
-            pads=[1, 1, 1, 1],
-        )
+OnnxBackendNodeModelTest = backend_test.enable_report().test_cases['OnnxBackendNodeModelTest']
 
-        y_with_padding = np.array([[[
-            [12., 21., 27., 33., 24.],  # (1, 1, 5, 5) output tensor
-            [33., 54., 63., 72., 51.],
-            [63., 99., 108., 117., 81.],
-            [93., 144., 153., 162., 111.],
-            [72., 111., 117., 123., 84.]
-        ]]]).astype(np.float32)
+# disable and enable training before and after test cases
+def setUp(self):
+    # print("\nIn method", self._testMethodName)
+    autograd.training = False
 
-        expect(node_with_padding,
-               inputs=[x, W],
-               outputs=[y_with_padding],
-               name='test_basic_conv_with_padding')
+def tearDown(self):
+    autograd.training = True
 
-        # Convolution without padding
-        node_without_padding = onnx.helper.make_node(
-            'Conv',
-            inputs=['x', 'W'],
-            outputs=['y'],
-            kernel_shape=[3, 3],
-            # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1
-            pads=[0, 0, 0, 0],
-        )
-        y_without_padding = np.array([[[
-            [54., 63., 72.],  # (1, 1, 3, 3) output tensor
-            [99., 108., 117.],
-            [144., 153., 162.]
-        ]]]).astype(np.float32)
-        expect(node_without_padding,
-               inputs=[x, W],
-               outputs=[y_without_padding],
-               name='test_basic_conv_without_padding')
+OnnxBackendNodeModelTest.setUp = setUp
+OnnxBackendNodeModelTest.tearDown = tearDown
 
-    def test_conv2d_with_strides(self):  # type: () -> None
+# import all test cases at global scope to make them visible to Python's unittest
+# print(backend_test.enable_report().test_cases)
+test_cases = {
+    'OnnxBackendNodeModelTest': OnnxBackendNodeModelTest
+}
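+# For example, a single generated case can be run on its own with pytest's
+# -k filter (test_relu_cpu is only an illustrative case name):
+#   python -m pytest test/python/test_onnx_backend.py -k test_relu_cpu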
 
-        x = np.array([[[
-            [0., 1., 2., 3., 4.],  # (1, 1, 7, 5) input tensor
-            [5., 6., 7., 8., 9.],
-            [10., 11., 12., 13., 14.],
-            [15., 16., 17., 18., 19.],
-            [20., 21., 22., 23., 24.],
-            [25., 26., 27., 28., 29.],
-            [30., 31., 32., 33., 34.]
-        ]]]).astype(np.float32)
-        W = np.array([[[
-            [1., 1., 1.],  # (1, 1, 3, 3) tensor for convolution weights
-            [1., 1., 1.],
-            [1., 1., 1.]
-        ]]]).astype(np.float32)
-
-        # Convolution with strides=2 and padding
-        node_with_padding = onnx.helper.make_node(
-            'Conv',
-            inputs=['x', 'W'],
-            outputs=['y'],
-            kernel_shape=[3, 3],
-            pads=[1, 1, 1, 1],
-            # Default values for other attributes: dilations=[1, 1], groups=1
-            strides=[2, 2],
-        )
-        y_with_padding = np.array([[[
-            [12., 27., 24.],  # (1, 1, 4, 3) output tensor
-            [63., 108., 81.],
-            [123., 198., 141.],
-            [112., 177., 124.]
-        ]]]).astype(np.float32)
-        expect(node_with_padding,
-               inputs=[x, W],
-               outputs=[y_with_padding],
-               name='test_conv_with_strides_padding')
-
-        # Convolution with strides=2 and no padding
-        node_without_padding = onnx.helper.make_node(
-            'Conv',
-            inputs=['x', 'W'],
-            outputs=['y'],
-            kernel_shape=[3, 3],
-            pads=[0, 0, 0, 0],
-            # Default values for other attributes: dilations=[1, 1], groups=1
-            strides=[2, 2],
-        )
-        y_without_padding = np.array([[[
-            [54., 72.],  # (1, 1, 3, 2) output tensor
-            [144., 162.],
-            [234., 252.]
-        ]]]).astype(np.float32)
-        expect(node_without_padding,
-               inputs=[x, W],
-               outputs=[y_without_padding],
-               name='test_conv_with_strides_no_padding')
-
-        # Convolution with strides=2 and padding only along one dimension (the H dimension in NxCxHxW tensor)
-        node_with_asymmetric_padding = onnx.helper.make_node(
-            'Conv',
-            inputs=['x', 'W'],
-            outputs=['y'],
-            kernel_shape=[3, 3],
-            pads=[1, 0, 1, 0],
-            # Default values for other attributes: dilations=[1, 1], groups=1
-            strides=[2, 2],
-        )
-        y_with_asymmetric_padding = np.array([[[
-            [21., 33.],  # (1, 1, 4, 2) output tensor
-            [99., 117.],
-            [189., 207.],
-            [171., 183.]
-        ]]]).astype(np.float32)
-        expect(node_with_asymmetric_padding,
-               inputs=[x, W],
-               outputs=[y_with_asymmetric_padding],
-               name='test_conv_with_strides_and_asymmetric_padding')
-
-    def test_averagepool_2d_precomputed_pads(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 5, 5]
-        pad_shape: [4, 4] -> [2, 2, 2, 2] by axis
-        """
-        node = onnx.helper.make_node('AveragePool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[5, 5],
-                                     pads=[2, 2, 2, 2])
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[7, 7.5, 8, 8.5, 9], [9.5, 10, 10.5, 11, 11.5],
-                        [12, 12.5, 13, 13.5, 14], [14.5, 15, 15.5, 16, 16.5],
-                        [17, 17.5, 18, 18.5, 19]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_averagepool_2d_precomputed_pads')
-
-    def test_averagepool_2d_precomputed_strides(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 2, 2]
-        """
-        node = onnx.helper.make_node('AveragePool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[2, 2],
-                                     strides=[2, 2])
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[4, 6], [14, 16]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_averagepool_2d_precomputed_strides')
-
-    def test_averagepool_2d_precomputed_same_upper(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 3, 3]
-        pad_shape: [2, 2] -> [1, 1, 1, 1] by axis
-        """
-        node = onnx.helper.make_node('AveragePool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[3, 3],
-                                     strides=[2, 2],
-                                     auto_pad='SAME_UPPER')
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[4, 5.5, 7], [11.5, 13, 14.5],
-                        [19, 20.5, 22]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_averagepool_2d_precomputed_same_upper')
-
-    def test_averagepool_2d_default(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 32, 32]
-        output_shape: [1, 3, 31, 31]
-        """
-        node = onnx.helper.make_node(
-            'AveragePool',
-            inputs=['x'],
-            outputs=['y'],
-            kernel_shape=[2, 2],
-        )
-        x = np.random.randn(1, 3, 32, 32).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (2, 2)
-        strides = (1, 1)
-        out_shape = get_output_shape('VALID', x_shape[2:], kernel_shape,
-                                     strides)
-        padded = x
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, (0, 0),
-                 'AVG')
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_averagepool_2d_default')
-
-    def test_averagepool_2d_pads(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 28, 28]
-        output_shape: [1, 3, 30, 30]
-        pad_shape: [4, 4] -> [2, 2, 2, 2] by axis
-        """
-        node = onnx.helper.make_node('AveragePool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[3, 3],
-                                     pads=[2, 2, 2, 2])
-        x = np.random.randn(1, 3, 28, 28).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (3, 3)
-        strides = (1, 1)
-        pad_bottom = 2
-        pad_top = 2
-        pad_right = 2
-        pad_left = 2
-        pad_shape = [pad_top + pad_bottom, pad_left + pad_right]
-        out_shape = get_output_shape('VALID', np.add(x_shape[2:], pad_shape),
-                                     kernel_shape, strides)
-        padded = np.pad(x, ((0, 0), (0, 0), (pad_top, pad_bottom),
-                            (pad_left, pad_right)),
-                        mode='constant',
-                        constant_values=np.nan)
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, pad_shape,
-                 'AVG')
-
-        expect(node, inputs=[x], outputs=[y], name='test_averagepool_2d_pads')
-
-    def test_averagepool_2d_strides(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 32, 32]
-        output_shape: [1, 3, 10, 10]
-        """
-        node = onnx.helper.make_node('AveragePool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[5, 5],
-                                     strides=[3, 3])
-        x = np.random.randn(1, 3, 32, 32).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (5, 5)
-        strides = (3, 3)
-        out_shape = get_output_shape('VALID', x_shape[2:], kernel_shape,
-                                     strides)
-        padded = x
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, (0, 0),
-                 'AVG')
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_averagepool_2d_strides')
-
-    def test_maxpool_2d_precomputed_pads(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 5, 5]
-        pad_shape: [4, 4] -> [2, 2, 2, 2] by axis
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[5, 5],
-                                     pads=[2, 2, 2, 2])
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[13, 14, 15, 15, 15], [18, 19, 20, 20, 20],
-                        [23, 24, 25, 25, 25], [23, 24, 25, 25, 25],
-                        [23, 24, 25, 25, 25]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_maxpool_2d_precomputed_pads')
-
-    def test_maxpool_with_argmax_2d_precomputed_pads(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 5, 5]
-        pad_shape: [4, 4] -> [2, 2, 2, 2] by axis
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y', 'z'],
-                                     kernel_shape=[5, 5],
-                                     pads=[2, 2, 2, 2])
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[13, 14, 15, 15, 15], [18, 19, 20, 20, 20],
-                        [23, 24, 25, 25, 25], [23, 24, 25, 25, 25],
-                        [23, 24, 25, 25, 25]]]]).astype(np.float32)
-        z = np.array([[[[12, 13, 14, 14, 14], [17, 18, 19, 19, 19],
-                        [22, 23, 24, 24, 24], [22, 23, 24, 24, 24],
-                        [22, 23, 24, 24, 24]]]]).astype(np.int64)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y, z],
-               name='test_maxpool_with_argmax_2d_precomputed_pads')
-
-    def test_maxpool_2d_precomputed_strides(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 2, 2]
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[2, 2],
-                                     strides=[2, 2])
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[7, 9], [17, 19]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_maxpool_2d_precomputed_strides')
-
-    def test_maxpool_with_argmax_2d_precomputed_strides(
-        self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 2, 2]
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y', 'z'],
-                                     kernel_shape=[2, 2],
-                                     strides=[2, 2],
-                                     storage_order=1)
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[7, 9], [17, 19]]]]).astype(np.float32)
-        z = np.array([[[[6, 16], [8, 18]]]]).astype(np.int64)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y, z],
-               name='test_maxpool_with_argmax_2d_precomputed_strides')
-
-    def test_maxpool_2d_precomputed_same_upper(self):  # type: () -> None
-        """
-        input_shape: [1, 1, 5, 5]
-        output_shape: [1, 1, 3, 3]
-        pad_shape: [2, 2] -> [1, 1, 1, 1] by axis
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[3, 3],
-                                     strides=[2, 2],
-                                     auto_pad='SAME_UPPER')
-        x = np.array([[[
-            [1, 2, 3, 4, 5],
-            [6, 7, 8, 9, 10],
-            [11, 12, 13, 14, 15],
-            [16, 17, 18, 19, 20],
-            [21, 22, 23, 24, 25],
-        ]]]).astype(np.float32)
-        y = np.array([[[[7, 9, 10], [17, 19, 20], [22, 24,
-                                                   25]]]]).astype(np.float32)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_maxpool_2d_precomputed_same_upper')
-
-    def test_maxpool_2d_default(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 32, 32]
-        output_shape: [1, 3, 31, 31]
-        """
-        node = onnx.helper.make_node(
-            'MaxPool',
-            inputs=['x'],
-            outputs=['y'],
-            kernel_shape=[2, 2],
-        )
-        x = np.random.randn(1, 3, 32, 32).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (2, 2)
-        strides = (1, 1)
-        out_shape = get_output_shape('VALID', x_shape[2:], kernel_shape,
-                                     strides)
-        padded = x
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, (0, 0),
-                 'MAX')
-
-        expect(node, inputs=[x], outputs=[y], name='test_maxpool_2d_default')
-
-    def test_maxpool_2d_pads(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 28, 28]
-        output_shape: [1, 3, 30, 30]
-        pad_shape: [4, 4] -> [2, 2, 2, 2] by axis
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[3, 3],
-                                     pads=[2, 2, 2, 2])
-        x = np.random.randn(1, 3, 28, 28).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (3, 3)
-        strides = (1, 1)
-        pad_bottom = pad_top = pad_right = pad_left = 2
-        pad_shape = [pad_top + pad_bottom, pad_left + pad_right]
-        out_shape = get_output_shape('VALID', np.add(x_shape[2:], pad_shape),
-                                     kernel_shape, strides)
-        padded = np.pad(x, ((0, 0), (0, 0), (pad_top, pad_bottom),
-                            (pad_left, pad_right)),
-                        mode='constant',
-                        constant_values=np.nan)
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, pad_shape,
-                 'MAX')
-
-        expect(node, inputs=[x], outputs=[y], name='test_maxpool_2d_pads')
-
-    def test_maxpool_2d_strides(self):  # type: () -> None
-        """
-        input_shape: [1, 3, 32, 32]
-        output_shape: [1, 3, 10, 10]
-        """
-        node = onnx.helper.make_node('MaxPool',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     kernel_shape=[5, 5],
-                                     strides=[3, 3])
-        x = np.random.randn(1, 3, 32, 32).astype(np.float32)
-        x_shape = np.shape(x)
-        kernel_shape = (5, 5)
-        strides = (3, 3)
-        out_shape = get_output_shape('VALID', x_shape[2:], kernel_shape,
-                                     strides)
-        padded = x
-        y = pool(padded, x_shape, kernel_shape, strides, out_shape, (0, 0),
-                 'MAX')
-
-        expect(node, inputs=[x], outputs=[y], name='test_maxpool_2d_strides')
-
-    def test_reshape(self):  # type: () -> None
-
-        def reshape_reference_implementation(
-            data, shape):  # type: (np.ndarray, np.ndarray) -> np.ndarray
-            # replace zeros with corresponding dim size
-            # we need to do this because np.reshape doesn't support 0
-            new_shape = np.copy(shape)
-            zeros_index = np.where(shape == 0)
-            new_shape[zeros_index] = np.array(data.shape)[zeros_index]
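-            # a -1 entry is left as-is; np.reshape infers that dimension
-            # from the remaining number of elements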
-            reshaped = np.reshape(data, new_shape)
-            return reshaped
-
-        original_shape = [2, 3, 4]
-        test_cases = {
-            'reordered_all_dims': np.array([4, 2, 3], dtype=np.int64),
-            'reordered_last_dims': np.array([2, 4, 3], dtype=np.int64),
-            'reduced_dims': np.array([2, 12], dtype=np.int64),
-            'extended_dims': np.array([2, 3, 2, 2], dtype=np.int64),
-            'one_dim': np.array([24], dtype=np.int64),
-            'negative_dim': np.array([2, -1, 2], dtype=np.int64),
-            'negative_extended_dims': np.array([-1, 2, 3, 4], dtype=np.int64),
-            'zero_dim': np.array([2, 0, 4, 1], dtype=np.int64),
-            'zero_and_negative_dim': np.array([2, 0, 1, -1], dtype=np.int64),
-        }
-        data = np.random.random_sample(original_shape).astype(np.float32)
-
-        for test_name, shape in test_cases.items():
-            node = onnx.helper.make_node(
-                'Reshape',
-                inputs=['data', 'shape'],
-                outputs=['reshaped'],
-            )
-
-            reshaped = reshape_reference_implementation(data, shape)
-
-            expect(node,
-                   inputs=[data, shape],
-                   outputs=[reshaped],
-                   name='test_reshape_' + test_name)
-
-    def test_concat(self):  # type: () -> None
-        test_cases = {
-            # '1d': ([1, 2], [3, 4]),  # todo: 1d is not supported yet
-            '2d': ([[1, 2], [3, 4]], [[5, 6], [7, 8]]),
-            '3d': ([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[9, 10], [11, 12]],
-                                                          [[13, 14], [15, 16]]])
-        }  # type: Dict[Text, Sequence[Any]]
-
-        for test_case, values_ in test_cases.items():
-            values = [np.asarray(v, dtype=np.float32) for v in values_]
-            for i in range(len(values[0].shape)):
-                in_args = ['value' + str(k) for k in range(len(values))]
-                node = onnx.helper.make_node('Concat',
-                                             inputs=[s for s in in_args],
-                                             outputs=['output'],
-                                             axis=i)
-                output = np.concatenate(values, i)
-                expect(node,
-                       inputs=[v for v in values],
-                       outputs=[output],
-                       name='test_concat_' + test_case + '_axis_' + str(i))
-
-            for i in range(-len(values[0].shape), 0):
-                in_args = ['value' + str(k) for k in range(len(values))]
-                node = onnx.helper.make_node('Concat',
-                                             inputs=[s for s in in_args],
-                                             outputs=['output'],
-                                             axis=i)
-                output = np.concatenate(values, i)
-                expect(node,
-                       inputs=[v for v in values],
-                       outputs=[output],
-                       name='test_concat_' + test_case + '_axis_negative_' +
-                       str(abs(i)))
-
-    def test_flatten(self):  # type: () -> None
-        shape = (2, 3, 4, 5)
-        a = np.random.random_sample(shape).astype(np.float32)
-
-        for i in range(len(shape)):
-            node = onnx.helper.make_node(
-                'Flatten',
-                inputs=['a'],
-                outputs=['b'],
-                axis=i,
-            )
-
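-            # Flatten(axis=i) reshapes to (prod(shape[:i]), prod(shape[i:]));
-            # axis=0 gives (1, total_size)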
-            new_shape = (1, -1) if i == 0 else (np.prod(shape[0:i]).astype(int),
-                                                -1)
-            b = np.reshape(a, new_shape)
-            expect(node,
-                   inputs=[a],
-                   outputs=[b],
-                   name='test_flatten_axis' + str(i))
-
-    def test_flatten_with_default_axis(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Flatten',
-            inputs=['a'],
-            outputs=['b'],  # Default value for axis: axis=1
-        )
-
-        shape = (5, 4, 3, 2)
-        a = np.random.random_sample(shape).astype(np.float32)
-        new_shape = (5, 24)
-        b = np.reshape(a, new_shape)
-        expect(node, inputs=[a], outputs=[b], name='test_flatten_default_axis')
-
-    def test_flatten_negative_axis(self):  # type: () -> None
-        shape = (2, 3, 4, 5)
-        a = np.random.random_sample(shape).astype(np.float32)
-
-        for i in range(-len(shape), 0):
-            node = onnx.helper.make_node(
-                'Flatten',
-                inputs=['a'],
-                outputs=['b'],
-                axis=i,
-            )
-
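-            # a negative axis counts from the end, so shape[0:i] keeps all but
-            # the last |i| dimensions in the first output dimension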
-            new_shape = (np.prod(shape[0:i]).astype(int), -1)
-            b = np.reshape(a, new_shape)
-            expect(node,
-                   inputs=[a],
-                   outputs=[b],
-                   name='test_flatten_negative_axis' + str(abs(i)))
-
-    def test_add(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Add',
-            inputs=['x', 'y'],
-            outputs=['sum'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(3, 4, 5).astype(np.float32)
-        expect(node, inputs=[x, y], outputs=[x + y], name='test_add')
-
-    def test_add_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Add',
-            inputs=['x', 'y'],
-            outputs=['sum'],
-        )
-
-        # todo: we don't support 3d here
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(5).astype(np.float32)
-        expect(node, inputs=[x, y], outputs=[x + y], name='test_add_bcast')
-
-    def test_sum(self):  # type: () -> None
-        data_0 = np.array([3, 0, 2]).astype(np.float32)
-        data_1 = np.array([1, 3, 4]).astype(np.float32)
-        data_2 = np.array([2, 6, 6]).astype(np.float32)
-        result = np.array([6, 9, 12]).astype(np.float32)
-        node = onnx.helper.make_node(
-            'Sum',
-            inputs=['data_0', 'data_1', 'data_2'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1, data_2],
-               outputs=[result],
-               name='test_sum_example')
-
-        node = onnx.helper.make_node(
-            'Sum',
-            inputs=['data_0'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0],
-               outputs=[data_0],
-               name='test_sum_one_input')
-
-        result = np.add(data_0, data_1)
-        node = onnx.helper.make_node(
-            'Sum',
-            inputs=['data_0', 'data_1'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1],
-               outputs=[result],
-               name='test_sum_two_inputs')
-
-    def test_relu(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Relu',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf)
-
-        expect(node, inputs=[x], outputs=[y], name='test_relu')
-
-    def test_sigmoid(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sigmoid',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        # expected output [0.26894143, 0.5, 0.7310586]
-        y = 1.0 / (1.0 + np.exp(np.negative(x)))
-        expect(node, inputs=[x], outputs=[y], name='test_sigmoid_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = 1.0 / (1.0 + np.exp(np.negative(x)))
-        expect(node, inputs=[x], outputs=[y], name='test_sigmoid')
-
-    def test_matmul(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'MatMul',
-            inputs=['a', 'b'],
-            outputs=['c'],
-        )
-
-        # 2d
-        a = np.random.randn(3, 4).astype(np.float32)
-        b = np.random.randn(4, 3).astype(np.float32)
-        c = np.matmul(a, b)
-        expect(node, inputs=[a, b], outputs=[c], name='test_matmul_2d')
-
-        # todo: 3d matmul is not supported yet
-        # a = np.random.randn(2, 3, 4).astype(np.float32)
-        # b = np.random.randn(2, 4, 3).astype(np.float32)
-        # c = np.matmul(a, b)
-        # expect(node, inputs=[a, b], outputs=[c],
-        #        name='test_matmul_3d')
-
-        # todo: 4d matmul is not supported yet
-        # a = np.random.randn(1, 2, 3, 4).astype(np.float32)
-        # b = np.random.randn(1, 2, 4, 3).astype(np.float32)
-        # c = np.matmul(a, b)
-        # expect(node, inputs=[a, b], outputs=[c],
-        #        name='test_matmul_4d')
-
-    def test_cos(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Cos',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.cos(x)
-        expect(node, inputs=[x], outputs=[y], name='test_cos_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.cos(x)
-        expect(node, inputs=[x], outputs=[y], name='test_cos')
-
-    def test_cosh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Cosh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.cosh(x)  # expected output [1.54308069,  1.,  1.54308069]
-        expect(node, inputs=[x], outputs=[y], name='test_cosh_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.cosh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_cosh')
-
-    def test_sin(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sin',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.sin(x)
-        expect(node, inputs=[x], outputs=[y], name='test_sin_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.sin(x)
-        expect(node, inputs=[x], outputs=[y], name='test_sin')
-
-    def test_sinh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sinh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.sinh(x)  # expected output [-1.17520118,  0.,  1.17520118]
-        expect(node, inputs=[x], outputs=[y], name='test_sinh_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.sinh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_sinh')
-
-    def test_tan(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Tan',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.tan(x)
-        expect(node, inputs=[x], outputs=[y], name='test_tan_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.tan(x)
-        expect(node, inputs=[x], outputs=[y], name='test_tan', decimal=3)
-
-    def test_tanh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Tanh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.tanh(x)  # expected output [-0.76159418, 0., 0.76159418]
-        expect(node, inputs=[x], outputs=[y], name='test_tanh_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.tanh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_tanh')
-
-    def test_acos(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Acos',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-0.5, 0, 0.5]).astype(np.float32)
-        y = np.arccos(x)
-        expect(node, inputs=[x], outputs=[y], name='test_acos_example')
-
-        x = np.random.rand(3, 4, 5).astype(np.float32)
-        y = np.arccos(x)
-        expect(node, inputs=[x], outputs=[y], name='test_acos')
-
-    def test_acosh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Acosh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([10, np.e, 1]).astype(np.float32)
-        y = np.arccosh(x)  # expected output [2.99322295,  1.65745449,  0.]
-        expect(node, inputs=[x], outputs=[y], name='test_acosh_example')
-
-        x = np.random.uniform(1.0, 10.0, (3, 4, 5)).astype(np.float32)
-        y = np.arccosh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_acosh')
-
-    def test_asin(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Asin',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-0.5, 0, 0.5]).astype(np.float32)
-        y = np.arcsin(x)
-        expect(node, inputs=[x], outputs=[y], name='test_asin_example')
-
-        x = np.random.rand(3, 4, 5).astype(np.float32)
-        y = np.arcsin(x)
-        expect(node, inputs=[x], outputs=[y], name='test_asin')
-
-    def test_asinh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Asinh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.arcsinh(x)  # expected output [-0.88137358,  0.,  0.88137358]
-        expect(node, inputs=[x], outputs=[y], name='test_asinh_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.arcsinh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_asinh')
-
-    def test_atan(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Atan',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.arctan(x)
-        expect(node, inputs=[x], outputs=[y], name='test_atan_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.arctan(x)
-        expect(node, inputs=[x], outputs=[y], name='test_atan')
-
-    def test_atanh(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Atanh',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-0.5, 0, 0.5]).astype(np.float32)
-        y = np.arctanh(x)  # expected output [-0.54930615,  0.,  0.54930615]
-        expect(node, inputs=[x], outputs=[y], name='test_atanh_example')
-
-        x = np.random.uniform(0.0, 1.0, (3, 4, 5)).astype(np.float32)
-        y = np.arctanh(x)
-        expect(node, inputs=[x], outputs=[y], name='test_atanh')
-
-    def test_selu(self):  # type: () -> None
-        node = onnx.helper.make_node('Selu',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     alpha=2.0,
-                                     gamma=3.0)
-
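-        # SELU: y = gamma * (x if x > 0 else alpha * (exp(x) - 1));
-        # here alpha=2.0 and gamma=3.0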
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        # expected output [-3.79272318, 0., 3.]
-        y = np.clip(x, 0, np.inf) * 3.0 + \
-            (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0 * 3.0
-        expect(node, inputs=[x], outputs=[y], name='test_selu_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) * 3.0 + \
-            (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0 * 3.0
-        expect(node, inputs=[x], outputs=[y], name='test_selu')
-
-    def test_selu_default(self):  # type: () -> None
-        default_alpha = 1.67326319217681884765625
-        default_gamma = 1.05070102214813232421875
-        node = onnx.helper.make_node(
-            'Selu',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) * default_gamma + \
-            (np.exp(np.clip(x, -np.inf, 0)) - 1) * default_alpha * default_gamma
-        expect(node, inputs=[x], outputs=[y], name='test_selu_default')
-
-    def test_elu(self):  # type: () -> None
-        node = onnx.helper.make_node('Elu',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     alpha=2.0)
-
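-        # ELU: y = x if x > 0 else alpha * (exp(x) - 1); here alpha=2.0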
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        # expected output [-1.2642411, 0., 1.]
-        y = np.clip(x, 0, np.inf) + (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0
-        expect(node, inputs=[x], outputs=[y], name='test_elu_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) + (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0
-        expect(node, inputs=[x], outputs=[y], name='test_elu')
-
-    def test_elu_default(self):  # type: () -> None
-        default_alpha = 1.0
-        node = onnx.helper.make_node(
-            'Elu',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) + \
-            (np.exp(np.clip(x, -np.inf, 0)) - 1) * default_alpha
-        expect(node, inputs=[x], outputs=[y], name='test_elu_default')
-
-    def test_equal(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Equal',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = (np.random.randn(3, 4, 5) * 10).astype(np.int32)
-        y = (np.random.randn(3, 4, 5) * 10).astype(np.int32)
-        z = np.equal(x, y)
-
-        expect(node, inputs=[x, y], outputs=[z], name='test_equal')
-
-    def test_equal_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Equal',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = (np.random.randn(3, 4, 5) * 10).astype(np.int32)
-        y = (np.random.randn(5) * 10).astype(np.int32)
-        z = np.equal(x, y).astype(np.int32)  # need to convert to int type
-        expect(node, inputs=[x, y], outputs=[z], name='test_equal_bcast')
-
-    def test_less(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Less',
-            inputs=['x', 'y'],
-            outputs=['less'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(3, 4, 5).astype(np.float32)
-        z = np.less(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_less')
-
-    def test_less_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Less',
-            inputs=['x', 'y'],
-            outputs=['less'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(5).astype(np.float32)
-        z = np.less(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_less_bcast')
-
-    def test_sign(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sign',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array(range(-5, 6)).astype(np.float32)
-        y = np.sign(x)
-        expect(node, inputs=[x], outputs=[y], name='test_sign')
-
-    def test_sub(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sub',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.array([1, 2, 3]).astype(np.float32)
-        y = np.array([3, 2, 1]).astype(np.float32)
-        z = x - y  # expected output [-2., 0., 2.]
-        expect(node, inputs=[x, y], outputs=[z], name='test_sub_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(3, 4, 5).astype(np.float32)
-        z = x - y
-        expect(node, inputs=[x, y], outputs=[z], name='test_sub')
-
-    def test_sub_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sub',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(5).astype(np.float32)
-        z = x - y
-        expect(node, inputs=[x, y], outputs=[z], name='test_sub_bcast')
-
-    def test_sqrt(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Sqrt',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([1, 4, 9]).astype(np.float32)
-        y = np.sqrt(x)  # expected output [1., 2., 3.]
-        expect(node, inputs=[x], outputs=[y], name='test_sqrt_example')
-
-        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
-        y = np.sqrt(x)
-        expect(node, inputs=[x], outputs=[y], name='test_sqrt')
-
-    def test_log(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Log',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([1, 10]).astype(np.float32)
-        y = np.log(x)  # expected output [0., 2.30258512]
-        expect(node, inputs=[x], outputs=[y], name='test_log_example')
-
-        x = np.exp(np.random.randn(3, 4, 5).astype(np.float32))
-        y = np.log(x)
-        expect(node, inputs=[x], outputs=[y], name='test_log')
-
-    def test_greater(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Greater',
-            inputs=['x', 'y'],
-            outputs=['greater'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(3, 4, 5).astype(np.float32)
-        z = np.greater(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_greater')
-
-    def test_greater_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Greater',
-            inputs=['x', 'y'],
-            outputs=['greater'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(5).astype(np.float32)
-        z = np.greater(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_greater_bcast')
-
-    def test_hardsigmoid(self):  # type: () -> None
-        node = onnx.helper.make_node('HardSigmoid',
-                                     inputs=['x'],
-                                     outputs=['y'],
-                                     alpha=0.5,
-                                     beta=0.6)
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.clip(x * 0.5 + 0.6, 0, 1)  # expected output [0.1, 0.6, 1.]
-        expect(node, inputs=[x], outputs=[y], name='test_hardsigmoid_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x * 0.5 + 0.6, 0, 1)
-        expect(node, inputs=[x], outputs=[y], name='test_hardsigmoid')
-
-    def test_hardsigmoid_default(self):  # type: () -> None
-        default_alpha = 0.2
-        default_beta = 0.5
-        node = onnx.helper.make_node(
-            'HardSigmoid',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x * default_alpha + default_beta, 0, 1)
-        expect(node, inputs=[x], outputs=[y], name='test_hardsigmoid_default')
-
-    def test_identity(self):
-        node = onnx.helper.make_node(
-            'Identity',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        data = np.array([[[
-            [1, 2],
-            [3, 4],
-        ]]], dtype=np.float32)
-
-        expect(node, inputs=[data], outputs=[data], name='test_identity')
-
-    def test_softplus(self):
-        node = onnx.helper.make_node(
-            'Softplus',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        # expected output [0.31326166, 0.69314718, 1.31326163]
-        y = np.log(np.exp(x) + 1)
-        expect(node, inputs=[x], outputs=[y], name='test_softplus_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.log(np.exp(x) + 1)
-        expect(node, inputs=[x], outputs=[y], name='test_softplus')
-
-    def test_softsign(self):
-        node = onnx.helper.make_node(
-            'Softsign',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.array([-0.5, 0, 0.5]).astype(np.float32)
-        expect(node, inputs=[x], outputs=[y], name='test_softsign_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = x / (1 + np.abs(x))
-        expect(node, inputs=[x], outputs=[y], name='test_softsign')
-
-    def test_mean(self):
-        data_0 = np.array([3, 0, 2]).astype(np.float32)
-        data_1 = np.array([1, 3, 4]).astype(np.float32)
-        data_2 = np.array([2, 6, 6]).astype(np.float32)
-        result = np.array([2, 3, 4]).astype(np.float32)
-        node = onnx.helper.make_node(
-            'Mean',
-            inputs=['data_0', 'data_1', 'data_2'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1, data_2],
-               outputs=[result],
-               name='test_mean_example')
-
-        node = onnx.helper.make_node(
-            'Mean',
-            inputs=['data_0'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0],
-               outputs=[data_0],
-               name='test_mean_one_input')
-
-        result = np.divide(np.add(data_0, data_1), 2.)
-        node = onnx.helper.make_node(
-            'Mean',
-            inputs=['data_0', 'data_1'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1],
-               outputs=[result],
-               name='test_mean_two_inputs')
-
-    def test_transpose_default(self):  # type: () -> None
-        shape = (2, 3, 4)
-        data = np.random.random_sample(shape).astype(np.float32)
-
-        node = onnx.helper.make_node('Transpose',
-                                     inputs=['data'],
-                                     outputs=['transposed'])
-
-        transposed = np.transpose(data)
-        expect(node,
-               inputs=[data],
-               outputs=[transposed],
-               name='test_transpose_default')
-
-    def test_transpose_all_permutations(self):  # type: () -> None
-        shape = (2, 3, 4)
-        data = np.random.random_sample(shape).astype(np.float32)
-        permutations = list(itertools.permutations(np.arange(len(shape))))
-
-        for i in range(len(permutations)):
-            node = onnx.helper.make_node('Transpose',
-                                         inputs=['data'],
-                                         outputs=['transposed'],
-                                         perm=permutations[i])
-            transposed = np.transpose(data, permutations[i])
-            expect(node,
-                   inputs=[data],
-                   outputs=[transposed],
-                   name='test_transpose_all_permutations_' + str(i))
-
-    def test_max(self):
-        data_0 = np.array([3, 2, 1]).astype(np.float32)
-        data_1 = np.array([1, 4, 4]).astype(np.float32)
-        data_2 = np.array([2, 5, 3]).astype(np.float32)
-        result = np.array([3, 5, 4]).astype(np.float32)
-        # todo: 3 inputs are not supported yet
-        node = onnx.helper.make_node(
-            'Max',
-            inputs=['data_0', 'data_1', 'data_2'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1, data_2],
-               outputs=[result],
-               name='test_max_example')
-
-        # todo: a single input is not supported yet
-        node = onnx.helper.make_node(
-            'Max',
-            inputs=['data_0'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0],
-               outputs=[data_0],
-               name='test_max_one_input')
-
-        result = np.maximum(data_0, data_1)
-        node = onnx.helper.make_node(
-            'Max',
-            inputs=['data_0', 'data_1'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1],
-               outputs=[result],
-               name='test_max_two_inputs')
-
-    def test_min(self):
-        data_0 = np.array([3, 2, 1]).astype(np.float32)
-        data_1 = np.array([1, 4, 4]).astype(np.float32)
-        data_2 = np.array([2, 5, 0]).astype(np.float32)
-        result = np.array([1, 2, 0]).astype(np.float32)
-        node = onnx.helper.make_node(
-            'Min',
-            inputs=['data_0', 'data_1', 'data_2'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1, data_2],
-               outputs=[result],
-               name='test_min_example')
-
-        node = onnx.helper.make_node(
-            'Min',
-            inputs=['data_0'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0],
-               outputs=[data_0],
-               name='test_min_one_input')
-
-        result = np.minimum(data_0, data_1)
-        node = onnx.helper.make_node(
-            'Min',
-            inputs=['data_0', 'data_1'],
-            outputs=['result'],
-        )
-        expect(node,
-               inputs=[data_0, data_1],
-               outputs=[result],
-               name='test_min_two_inputs')
-
-    def test_shape(self):
-        node = onnx.helper.make_node(
-            'Shape',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([
-            [1, 2, 3],
-            [4, 5, 6],
-        ]).astype(np.float32)
-        y = np.array([
-            2,
-            3,
-        ]).astype(np.int64)
-
-        expect(node, inputs=[x], outputs=[y], name='test_shape_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.array(x.shape).astype(np.int64)
-
-        expect(node, inputs=[x], outputs=[y], name='test_shape')
-
-    def test_and(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'And',
-            inputs=['x', 'y'],
-            outputs=['and'],
-        )
-
-        # 2d
-        x = (np.random.randn(3, 4) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and2d')
-
-        # 3d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and3d')
-
-        # 4d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and4d')
-
-    def test_and_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'And',
-            inputs=['x', 'y'],
-            outputs=['and'],
-        )
-
-        # 3d vs 1d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(5) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and_bcast3v1d')
-
-        # 3d vs 2d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and_bcast3v2d')
-
-        # 4d vs 2d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(5, 6) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and_bcast4v2d')
-
-        # 4d vs 3d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and_bcast4v3d')
-
-        # 4d vs 4d
-        x = (np.random.randn(1, 4, 1, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 1, 5, 6) > 0).astype(np.bool)
-        z = np.logical_and(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_and_bcast4v4d')
-
-    def test_or(self):
-        node = onnx.helper.make_node(
-            'Or',
-            inputs=['x', 'y'],
-            outputs=['or'],
-        )
-
-        # 2d
-        x = (np.random.randn(3, 4) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or2d')
-
-        # 3d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or3d')
-
-        # 4d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or4d')
-
-    def test_or_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Or',
-            inputs=['x', 'y'],
-            outputs=['or'],
-        )
-
-        # 3d vs 1d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(5) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or_bcast3v1d')
-
-        # 3d vs 2d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or_bcast3v2d')
-
-        # 4d vs 2d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(5, 6) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or_bcast4v2d')
-
-        # 4d vs 3d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or_bcast4v3d')
-
-        # 4d vs 4d
-        x = (np.random.randn(1, 4, 1, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 1, 5, 6) > 0).astype(np.bool)
-        z = np.logical_or(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_or_bcast4v4d')
-
-    def test_xor(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Xor',
-            inputs=['x', 'y'],
-            outputs=['xor'],
-        )
-
-        # 2d
-        x = (np.random.randn(3, 4) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor2d')
-
-        # 3d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor3d')
-
-        # 4d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor4d')
-
-    def test_xor_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Xor',
-            inputs=['x', 'y'],
-            outputs=['xor'],
-        )
-
-        # 3d vs 1d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(5) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor_bcast3v1d')
-
-        # 3d vs 2d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor_bcast3v2d')
-
-        # 4d vs 2d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(5, 6) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor_bcast4v2d')
-
-        # 4d vs 3d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        y = (np.random.randn(4, 5, 6) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor_bcast4v3d')
-
-        # 4d vs 4d
-        x = (np.random.randn(1, 4, 1, 6) > 0).astype(np.bool)
-        y = (np.random.randn(3, 1, 5, 6) > 0).astype(np.bool)
-        z = np.logical_xor(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_xor_bcast4v4d')
-
-    def test_not(self):
-        node = onnx.helper.make_node(
-            'Not',
-            inputs=['x'],
-            outputs=['not'],
-        )
-
-        # 2d
-        x = (np.random.randn(3, 4) > 0).astype(np.bool)
-        expect(node,
-               inputs=[x],
-               outputs=[np.logical_not(x)],
-               name='test_not_2d')
-
-        # 3d
-        x = (np.random.randn(3, 4, 5) > 0).astype(np.bool)
-        expect(node,
-               inputs=[x],
-               outputs=[np.logical_not(x)],
-               name='test_not_3d')
-
-        # 4d
-        x = (np.random.randn(3, 4, 5, 6) > 0).astype(np.bool)
-        expect(node,
-               inputs=[x],
-               outputs=[np.logical_not(x)],
-               name='test_not_4d')
-
-    def test_neg(self):
-        node = onnx.helper.make_node(
-            'Neg',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-4, 2]).astype(np.float32)
-        y = np.negative(x)  # expected output [4., -2.],
-        expect(node, inputs=[x], outputs=[y], name='test_neg_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.negative(x)
-        expect(node, inputs=[x], outputs=[y], name='test_neg')
-
-    def test_reciprocal(self):
-        node = onnx.helper.make_node(
-            'Reciprocal',
-            inputs=['x'],
-            outputs=['y'],
-        )
-
-        x = np.array([-4, 2]).astype(np.float32)
-        y = np.reciprocal(x)  # expected output [-0.25, 0.5],
-        expect(node, inputs=[x], outputs=[y], name='test_reciprocal_example')
-
-        x = np.random.rand(3, 4, 5).astype(np.float32) + 0.5
-        y = np.reciprocal(x)
-        expect(node, inputs=[x], outputs=[y], name='test_reciprocal')
-
-    def test_batchnorm(self):  # type: () -> None
-        # we changed this test case according to the paper
-        # https://arxiv.org/pdf/1502.03167.pdf
-        def _batchnorm_test_mode(x,
-                                 s,
-                                 bias,
-                                 mean,
-                                 var,
-                                 momentum=0.9,
-                                 epsilon=1e-5):  # type: ignore
-            dims_x = len(x.shape)
-            dim_ones = (1,) * (dims_x - 2)
-            s = s.reshape(-1, *dim_ones)
-            bias = bias.reshape(-1, *dim_ones)
-            mean = mean.reshape(-1, *dim_ones)
-            var = var.reshape(-1, *dim_ones)
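-            # note: momentum and the supplied running mean/var are not used;
-            # the batch statistics of x below are used instead (training-mode
-            # normalization)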
-            batch_m = x.mean(axis=(0, 2, 3), keepdims=True)
-            batch_v = x.var(axis=(0, 2, 3), keepdims=True)
-            return s * (x - batch_m) / np.sqrt(batch_v + epsilon) + bias
-
-        # input size: (1, 2, 1, 3)
-        x = np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32)
-        s = np.array([1.0, 1.5]).astype(np.float32)
-        bias = np.array([0, 1]).astype(np.float32)
-        mean = np.array([0, 3]).astype(np.float32)
-        var = np.array([1, 1.5]).astype(np.float32)
-        y = _batchnorm_test_mode(x, s, bias, mean, var).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'BatchNormalization',
-            inputs=['x', 's', 'bias', 'mean', 'var'],
-            outputs=['y'],
-        )
-
-        # output size: (1, 2, 1, 3)
-        expect(node,
-               inputs=[x, s, bias, mean, var],
-               outputs=[y],
-               name='test_batchnorm_example')
-
-        # input size: (2, 3, 4, 5)
-        x = np.random.randn(2, 3, 4, 5).astype(np.float32)
-        s = np.random.randn(3).astype(np.float32)
-        bias = np.random.randn(3).astype(np.float32)
-        mean = np.random.randn(3).astype(np.float32)
-        var = np.random.rand(3).astype(np.float32)
-        epsilon = 1e-2
-        y = _batchnorm_test_mode(x, s, bias, mean, var,
-                                 epsilon=epsilon).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'BatchNormalization',
-            inputs=['x', 's', 'bias', 'mean', 'var'],
-            outputs=['y'],
-            epsilon=epsilon,
-        )
-
-        # output size: (2, 3, 4, 5)
-        expect(node,
-               inputs=[x, s, bias, mean, var],
-               outputs=[y],
-               name='test_batchnorm_epsilon')
-
-    def test_softmax(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        x = np.array([[-1, 0, 1]]).astype(np.float32)
-        # expected output [[0.09003058, 0.24472848, 0.66524094]]
-        y = np.exp(x) / np.sum(np.exp(x), axis=1)
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_example')
-
-    def test_softmax_axis(self):  # type: () -> None
-
-        def softmax_2d(x):  # type: (np.ndarray) -> np.ndarray
-            max_x = np.max(x, axis=1).reshape((-1, 1))
-            exp_x = np.exp(x - max_x)
-            return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1))
-
-        x = np.array([[0, 1, 2, 3], [10000, 10001, 10002,
-                                     10003]]).astype(np.float32)
-        # expected output [[0.0320586, 0.08714432, 0.23688284, 0.64391428],
-        #                 [0.0320586, 0.08714432, 0.23688284, 0.64391428]]
-        y = softmax_2d(x)
-
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_large_number')
-
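-        # Softmax (before opset 13) coerces the input into a 2-D tensor of
-        # shape [prod(dims[:axis]), prod(dims[axis:])] and applies a row-wise
-        # softmax, hence the reshape calls below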
-        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-            axis=0,
-        )
-        y = softmax_2d(x.reshape(1, 60)).reshape(3, 4, 5)
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_axis_0')
-
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-            axis=1,
-        )
-        y = softmax_2d(x.reshape(3, 20)).reshape(3, 4, 5)
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_axis_1')
-
-        # default axis is 1
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-        )
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_default_axis')
-
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-            axis=2,
-        )
-        y = softmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_axis_2')
-
-        node = onnx.helper.make_node(
-            'Softmax',
-            inputs=['x'],
-            outputs=['y'],
-            axis=-1,
-        )
-        y = softmax_2d(x.reshape(12, 5)).reshape(3, 4, 5)
-        expect(node, inputs=[x], outputs=[y], name='test_softmax_negative_axis')
-
-    def test_div(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Div',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.array([3, 4]).astype(np.float32)
-        y = np.array([1, 2]).astype(np.float32)
-        z = x / y  # expected output [3., 2.]
-        expect(node, inputs=[x, y], outputs=[z], name='test_div_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.rand(3, 4, 5).astype(np.float32) + 1.0
-        z = x / y
-        expect(node, inputs=[x, y], outputs=[z], name='test_div')
-
-    def test_div_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Div',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.rand(5).astype(np.float32) + 1.0
-        z = x / y
-        expect(node, inputs=[x, y], outputs=[z], name='test_div_bcast')
-
-    def test_pow(self):
-        node = onnx.helper.make_node(
-            'Pow',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.array([1, 2, 3]).astype(np.float32)
-        y = np.array([4, 5, 6]).astype(np.float32)  # todo: results are not exactly the same
-        z = np.power(x, y)  # expected output [1., 32., 729.]
-        expect(node,
-               inputs=[x, y],
-               outputs=[z],
-               name='test_pow_example',
-               decimal=3)
-
-        x = np.arange(24).reshape(2, 3, 4).astype(
-            np.float32)  # todo: the values cannot be too large here
-        y = np.random.randn(2, 3, 4).astype(np.float32)
-        z = np.power(x, y)
-        expect(node, inputs=[x, y], outputs=[z], name='test_pow', decimal=3)
-
-    def test_pow_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Pow',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.array([1, 2, 3]).astype(np.float32)
-        y = np.array(2).astype(np.float32)
-        z = np.power(x, y)  # expected output [1., 4., 9.]
-        expect(node,
-               inputs=[x, y],
-               outputs=[z],
-               name='test_pow_bcast_scalar',
-               decimal=3)
-
-        node = onnx.helper.make_node(
-            'Pow',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-        x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
-        y = np.array([1, 2, 3]).astype(np.float32)
-        # expected output [[1, 4, 27], [4, 25, 216]]
-        z = np.power(x, y).astype(np.float32)
-        expect(node,
-               inputs=[x, y],
-               outputs=[z],
-               name='test_pow_bcast_array',
-               decimal=3)
-
-    def test_clip(self):
-        node = onnx.helper.make_node(
-            'Clip',
-            inputs=['x', 'min', 'max'],
-            outputs=['y'],
-        )
-
-        x = np.array([-2, 0, 2]).astype(np.float32)
-        min_val = np.float32(-1)
-        max_val = np.float32(1)
-        y = np.clip(x, min_val, max_val)  # expected output [-1., 0., 1.]
-        expect(node,
-               inputs=[x, min_val, max_val],
-               outputs=[y],
-               name='test_clip_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, min_val, max_val)
-        expect(node,
-               inputs=[x, min_val, max_val],
-               outputs=[y],
-               name='test_clip')
-        node = onnx.helper.make_node(
-            'Clip',
-            inputs=['x', 'min', 'max'],
-            outputs=['y'],
-        )
-
-        min_val = np.float32(-5)
-        max_val = np.float32(5)
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.array([-1, 0, 1]).astype(np.float32)
-        expect(node,
-               inputs=[x, min_val, max_val],
-               outputs=[y],
-               name='test_clip_inbounds')
-
-        x = np.array([-6, 0, 6]).astype(np.float32)
-        y = np.array([-5, 0, 5]).astype(np.float32)
-        expect(node,
-               inputs=[x, min_val, max_val],
-               outputs=[y],
-               name='test_clip_outbounds')
-
-        x = np.array([-1, 0, 6]).astype(np.float32)
-        y = np.array([-1, 0, 5]).astype(np.float32)
-        expect(node,
-               inputs=[x, min_val, max_val],
-               outputs=[y],
-               name='test_clip_splitbounds')
-
-    def test_clip_default(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Clip',
-            inputs=['x', 'min'],
-            outputs=['y'],
-        )
-        min_val = np.float32(0)
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, min_val, np.inf)
-        expect(node,
-               inputs=[x, min_val],
-               outputs=[y],
-               name='test_clip_default_min')
-
-        no_min = ""  # optional input, not supplied
-        node = onnx.helper.make_node(
-            'Clip',
-            inputs=['x', no_min, 'max'],
-            outputs=['y'],
-        )
-        max_val = np.float32(0)
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, -np.inf, max_val)
-        expect(node,
-               inputs=[x, max_val],
-               outputs=[y],
-               name='test_clip_default_max')
-
-        no_max = ""  # optional input, not supplied
-        node = onnx.helper.make_node(
-            'Clip',
-            inputs=['x', no_min, no_max],
-            outputs=['y'],
-        )
-
-        x = np.array([-1, 0, 1]).astype(np.float32)
-        y = np.array([-1, 0, 1]).astype(np.float32)
-        expect(node, inputs=[x], outputs=[y], name='test_clip_default_inbounds')
-
-    def test_prelu(self):
-        node = onnx.helper.make_node(
-            'PRelu',
-            inputs=['x', 'slope'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        slope = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * slope
-
-        expect(node, inputs=[x, slope], outputs=[y], name='test_prelu_example')
-
-    # todo: prelu broadcast is not supported yet
-    def test_prelu_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'PRelu',
-            inputs=['x', 'slope'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        slope = np.random.randn(5).astype(np.float32)
-        y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * slope
-
-        expect(node,
-               inputs=[x, slope],
-               outputs=[y],
-               name='test_prelu_broadcast')
-
-    def test_mul(self):
-        node = onnx.helper.make_node(
-            'Mul',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.array([1, 2, 3]).astype(np.float32)
-        y = np.array([4, 5, 6]).astype(np.float32)
-        z = x * y  # expected output [4., 10., 18.]
-        expect(node, inputs=[x, y], outputs=[z], name='test_mul_example')
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(3, 4, 5).astype(np.float32)
-        z = x * y
-        expect(node, inputs=[x, y], outputs=[z], name='test_mul')
-
-    def test_mul_broadcast(self):  # type: () -> None
-        node = onnx.helper.make_node(
-            'Mul',
-            inputs=['x', 'y'],
-            outputs=['z'],
-        )
-
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-        y = np.random.randn(5).astype(np.float32)
-        z = x * y
-        expect(node, inputs=[x, y], outputs=[z], name='test_mul_bcast')
-
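-    # The Gemm tests below check y = alpha * A' * B' + beta * C, where A' and
-    # B' are A and B optionally transposed via transA/transB
-    # (defaults: alpha=1.0, beta=1.0, no transpose).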
-    def test_gemm_default_zero_bias(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'])
-        a = np.random.ranf([3, 5]).astype(np.float32)
-        b = np.random.ranf([5, 4]).astype(np.float32)
-        c = np.zeros([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_default_zero_bias')
-
-    def test_gemm_default_no_bias(self):
-        node = onnx.helper.make_node('Gemm', inputs=['a', 'b'], outputs=['y'])
-        a = np.random.ranf([2, 10]).astype(np.float32)
-        b = np.random.ranf([10, 3]).astype(np.float32)
-        y = gemm_reference_implementation(a, b)
-        expect(node,
-               inputs=[a, b],
-               outputs=[y],
-               name='test_gemm_default_no_bias')
-
-    def test_gemm_default_scalar_bias(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'])
-        a = np.random.ranf([2, 3]).astype(np.float32)
-        b = np.random.ranf([3, 4]).astype(np.float32)
-        c = np.array(3.14).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_default_scalar_bias')
-
-    def test_gemm_default_single_elem_vector_bias(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'])
-        a = np.random.ranf([3, 7]).astype(np.float32)
-        b = np.random.ranf([7, 3]).astype(np.float32)
-        c = np.random.ranf([1]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_default_single_elem_vector_bias')
-
-    def test_gemm_default_vector_bias(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'])
-        a = np.random.ranf([2, 7]).astype(np.float32)
-        b = np.random.ranf([7, 4]).astype(np.float32)
-        c = np.random.ranf([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_default_vector_bias')
-
-    def test_gemm_default_matrix_bias(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'])
-        a = np.random.ranf([3, 6]).astype(np.float32)
-        b = np.random.ranf([6, 4]).astype(np.float32)
-        c = np.random.ranf([3, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_default_matrix_bias')
-
-    def test_gemm_transposeA(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'],
-                                     transA=1)
-        a = np.random.ranf([6, 3]).astype(np.float32)
-        b = np.random.ranf([6, 4]).astype(np.float32)
-        c = np.zeros([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c, transA=1)
-        expect(node, inputs=[a, b, c], outputs=[y], name='test_gemm_transposeA')
-
-    def test_gemm_transposeB(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'],
-                                     transB=1)
-        a = np.random.ranf([3, 6]).astype(np.float32)
-        b = np.random.ranf([4, 6]).astype(np.float32)
-        c = np.zeros([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c, transB=1)
-        expect(node, inputs=[a, b, c], outputs=[y], name='test_gemm_transposeB')
-
-    def test_gemm_alpha(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'],
-                                     alpha=0.5)
-        a = np.random.ranf([3, 5]).astype(np.float32)
-        b = np.random.ranf([5, 4]).astype(np.float32)
-        c = np.zeros([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c, alpha=0.5)
-        expect(node, inputs=[a, b, c], outputs=[y], name='test_gemm_alpha')
-
-    def test_gemm_beta(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'],
-                                     beta=0.5)
-        a = np.random.ranf([2, 7]).astype(np.float32)
-        b = np.random.ranf([7, 4]).astype(np.float32)
-        c = np.random.ranf([1, 4]).astype(np.float32)
-        y = gemm_reference_implementation(a, b, c, beta=0.5)
-        expect(node, inputs=[a, b, c], outputs=[y], name='test_gemm_beta')
-
-    def test_gemm_all_attributes(self):
-        node = onnx.helper.make_node('Gemm',
-                                     inputs=['a', 'b', 'c'],
-                                     outputs=['y'],
-                                     alpha=0.25,
-                                     beta=0.35,
-                                     transA=1,
-                                     transB=1)
-        a = np.random.ranf([4, 3]).astype(np.float32)
-        b = np.random.ranf([5, 4]).astype(np.float32)
-        c = np.random.ranf([1, 5]).astype(np.float32)
-        y = gemm_reference_implementation(a,
-                                          b,
-                                          c,
-                                          transA=1,
-                                          transB=1,
-                                          alpha=0.25,
-                                          beta=0.35)
-        expect(node,
-               inputs=[a, b, c],
-               outputs=[y],
-               name='test_gemm_all_attributes')
-
-    def test_constantOfShape_float_ones(self):
-        x = np.array([4, 3, 2]).astype(np.int64)
-        tensor_value = onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT,
-                                               [1], [1])
-        node = onnx.helper.make_node(
-            'ConstantOfShape',
-            inputs=['x'],
-            outputs=['y'],
-            value=tensor_value,
-        )
-
-        y = np.ones(x, dtype=np.float32)
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_constantofshape_float_ones')
-
-    def test_constantOfShape_int32_zeros(self):
-        x = np.array([10, 6]).astype(np.int64)
-        tensor_value = onnx.helper.make_tensor("value", onnx.TensorProto.INT32,
-                                               [1], [0])
-        node = onnx.helper.make_node(
-            'ConstantOfShape',
-            inputs=['x'],
-            outputs=['y'],
-            value=tensor_value,
-        )
-        y = np.zeros(x, dtype=np.int32)
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_constantofshape_int_zeros')
-
-    # not supported yet
-    # def test_int32_shape_zero(self):
-    #     x = np.array([0, ]).astype(np.int64)
-    #     tensor_value = onnx.helper.make_tensor("value", onnx.TensorProto.INT32,
-    #                                            [1], [0])
-    #     node = onnx.helper.make_node(
-    #         'ConstantOfShape',
-    #         inputs=['x'],
-    #         outputs=['y'],
-    #         value=tensor_value,
-    #     )
-    #     y = np.zeros(x, dtype=np.int32)
-    #     expect(node, inputs=[x], outputs=[y],
-    #            name='test_constantofshape_int_shape_zero')
-
-    def test_reduce_sum_do_not_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [1]
-        keepdims = 0
-
-        node = onnx.helper.make_node('ReduceSum',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
-            dtype=np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-        #print(reduced)
-        #[[4., 6.]
-        # [12., 14.]
-        # [20., 22.]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_do_not_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_do_not_keepdims_random')
-
-    def test_reduce_sum_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [1]
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceSum',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
-            dtype=np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-        #print(reduced)
-        #[[[4., 6.]]
-        # [[12., 14.]]
-        # [[20., 22.]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_keepdims_random')
-
-    def test_reduce_sum_default_axes_keepdims(self):
-        shape = [3, 2, 2]
-        axes = None
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceSum',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
-            dtype=np.float32)
-        reduced = np.sum(data, axis=axes, keepdims=keepdims == 1)
-        #print(reduced)
-        #[[[78.]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_default_axes_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.sum(data, axis=axes, keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_default_axes_keepdims_random')
-
-    def test_reduce_sum_negative_axes_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [-2]
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceSum',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
-            dtype=np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-        # print(reduced)
-        #[[[4., 6.]]
-        # [[12., 14.]]
-        # [[20., 22.]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_negative_axes_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.sum(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_sum_negative_axes_keepdims_random')
-
-    def test_reduce_mean_do_not_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [1]
-        keepdims = 0
-
-        node = onnx.helper.make_node('ReduceMean',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
-            dtype=np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-        #print(reduced)
-        #[[12.5, 1.5]
-        # [35., 1.5]
-        # [57.5, 1.5]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_do_not_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_do_not_keepdims_random')
-
-    def test_reduce_mean_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [1]
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceMean',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
-            dtype=np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-        #print(reduced)
-        #[[[12.5, 1.5]]
-        # [[35., 1.5]]
-        # [[57.5, 1.5]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_keepdims_random')
-
-    def test_reduce_mean_default_axes_keepdims(self):
-        shape = [3, 2, 2]
-        axes = None
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceMean',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
-            dtype=np.float32)
-        reduced = np.mean(data, axis=axes, keepdims=keepdims == 1)
-        #print(reduced)
-        #[[[18.25]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_default_axes_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.mean(data, axis=axes, keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_default_axes_keepdims_random')
-
-    def test_reduce_mean_negative_axes_keepdims(self):
-        shape = [3, 2, 2]
-        axes = [-2]
-        keepdims = 1
-
-        node = onnx.helper.make_node('ReduceMean',
-                                     inputs=['data'],
-                                     outputs=['reduced'],
-                                     axes=axes,
-                                     keepdims=keepdims)
-
-        data = np.array(
-            [[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
-            dtype=np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-        # print(reduced)
-        # [[[12.5, 1.5]]
-        # [[35., 1.5]]
-        # [[57.5, 1.5]]]
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_negative_axes_keepdims_example')
-
-        np.random.seed(0)
-        data = np.random.uniform(-10, 10, shape).astype(np.float32)
-        reduced = np.mean(data, axis=tuple(axes), keepdims=keepdims == 1)
-
-        expect(node,
-               inputs=[data],
-               outputs=[reduced],
-               name='test_reduce_mean_negative_axes_keepdims_random')
-
-    def test_squeeze(self):
-        node = onnx.helper.make_node(
-            'Squeeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[0],
-        )
-        x = np.random.randn(1, 3, 4, 5).astype(np.float32)
-        y = np.squeeze(x, axis=0)
-
-        expect(node, inputs=[x], outputs=[y], name='test_squeeze')
-
-    def test_squeeze_negative_axes(self):
-        node = onnx.helper.make_node(
-            'Squeeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[-2],
-        )
-        x = np.random.randn(1, 3, 1, 5).astype(np.float32)
-        y = np.squeeze(x, axis=-2)
-        expect(node, inputs=[x], outputs=[y], name='test_squeeze_negative_axes')
-
-    def test_unsqueeze_one_axis(self):
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-
-        for i in range(x.ndim):
-            node = onnx.helper.make_node(
-                'Unsqueeze',
-                inputs=['x'],
-                outputs=['y'],
-                axes=[i],
-            )
-            y = np.expand_dims(x, axis=i)
-
-            expect(node,
-                   inputs=[x],
-                   outputs=[y],
-                   name='test_unsqueeze_axis_' + str(i))
-
-    def test_unsqueeze_two_axes(self):
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[1, 4],
-        )
-        y = np.expand_dims(x, axis=1)
-        y = np.expand_dims(y, axis=4)
-
-        expect(node, inputs=[x], outputs=[y], name='test_unsqueeze_two_axes')
-
-    def test_unsqueeze_three_axes(self):
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[2, 4, 5],
-        )
-        y = np.expand_dims(x, axis=2)
-        y = np.expand_dims(y, axis=4)
-        y = np.expand_dims(y, axis=5)
-
-        expect(node, inputs=[x], outputs=[y], name='test_unsqueeze_three_axes')
-
-    def test_unsqueeze_unsorted_axes(self):
-        x = np.random.randn(3, 4, 5).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[5, 4, 2],
-        )
-        y = np.expand_dims(x, axis=2)
-        y = np.expand_dims(y, axis=4)
-        y = np.expand_dims(y, axis=5)
-
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_unsqueeze_unsorted_axes')
-
-    def test_unsqueeze_negative_axes(self):
-        node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['x'],
-            outputs=['y'],
-            axes=[-2],
-        )
-        x = np.random.randn(1, 3, 1, 5).astype(np.float32)
-        y = np.expand_dims(x, axis=-2)
-        expect(node,
-               inputs=[x],
-               outputs=[y],
-               name='test_unsqueeze_negative_axes')
-
-    def test_slice(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes', 'steps'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        y = x[0:3, 0:10]
-        starts = np.array([0, 0], dtype=np.int64)
-        ends = np.array([3, 10], dtype=np.int64)
-        axes = np.array([0, 1], dtype=np.int64)
-        steps = np.array([1, 1], dtype=np.int64)
-
-        expect(node,
-               inputs=[x, starts, ends, axes, steps],
-               outputs=[y],
-               name='test_slice')
-
-    def test_slice_neg(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes', 'steps'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([0], dtype=np.int64)
-        ends = np.array([-1], dtype=np.int64)
-        axes = np.array([1], dtype=np.int64)
-        steps = np.array([1], dtype=np.int64)
-        y = x[:, 0:-1]
-
-        expect(node,
-               inputs=[x, starts, ends, axes, steps],
-               outputs=[y],
-               name='test_slice_neg')
-
-    # empty tensor is not supported yet
-    # def test_slice_start_out_of_bounds(self):
-    #     node = onnx.helper.make_node(
-    #         'Slice',
-    #         inputs=['x', 'starts', 'ends', 'axes', 'steps'],
-    #         outputs=['y'],
-    #     )
-
-    #     x = np.random.randn(20, 10, 5).astype(np.float32)
-    #     starts = np.array([1000], dtype=np.int64)
-    #     ends = np.array([1000], dtype=np.int64)
-    #     axes = np.array([1], dtype=np.int64)
-    #     steps = np.array([1], dtype=np.int64)
-    #     y = x[:, 1000:1000]
-
-    #     expect(node, inputs=[x, starts, ends, axes, steps], outputs=[y],
-    #            name='test_slice_start_out_of_bounds')
-
-    def test_slice_end_out_of_bounds(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes', 'steps'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([1], dtype=np.int64)
-        ends = np.array([1000], dtype=np.int64)
-        axes = np.array([1], dtype=np.int64)
-        steps = np.array([1], dtype=np.int64)
-        y = x[:, 1:1000]
-
-        expect(node,
-               inputs=[x, starts, ends, axes, steps],
-               outputs=[y],
-               name='test_slice_end_out_of_bounds')
-
-    def test_slice_default_axes(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([0, 0, 3], dtype=np.int64)
-        ends = np.array([20, 10, 4], dtype=np.int64)
-        y = x[:, :, 3:4]
-
-        expect(node,
-               inputs=[x, starts, ends],
-               outputs=[y],
-               name='test_slice_default_axes')
-
-    def test_slice_default_steps(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([0, 0, 3], dtype=np.int64)
-        ends = np.array([20, 10, 4], dtype=np.int64)
-        axes = np.array([0, 1, 2], dtype=np.int64)
-        y = x[:, :, 3:4]
-
-        expect(node,
-               inputs=[x, starts, ends, axes],
-               outputs=[y],
-               name='test_slice_default_steps')
-
-    def test_slice_neg_steps(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes', 'steps'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([20, 10, 4], dtype=np.int64)
-        ends = np.array([0, 0, 1], dtype=np.int64)
-        axes = np.array([0, 1, 2], dtype=np.int64)
-        steps = np.array([-1, -3, -2])
-        y = x[20:0:-1, 10:0:-3, 4:1:-2]
-
-        expect(node,
-               inputs=[x, starts, ends, axes, steps],
-               outputs=[y],
-               name='test_slice_neg_steps')
-
-    def test_slice_negative_axes(self):
-        node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'starts', 'ends', 'axes'],
-            outputs=['y'],
-        )
-
-        x = np.random.randn(20, 10, 5).astype(np.float32)
-        starts = np.array([0, 0, 3], dtype=np.int64)
-        ends = np.array([20, 10, 4], dtype=np.int64)
-        axes = np.array([0, -2, -1], dtype=np.int64)
-        y = x[:, :, 3:4]
-
-        expect(node,
-               inputs=[x, starts, ends, axes],
-               outputs=[y],
-               name='test_slice_negative_axes')
-
-    def test_split_1d(self):
-        input = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32)
-
-        node = onnx.helper.make_node(
-            'Split',
-            inputs=['input'],
-            outputs=['output_1', 'output_2', 'output_3'],
-            axis=0)
-
-        expected_outputs = [
-            np.array([1., 2.]).astype(np.float32),
-            np.array([3., 4.]).astype(np.float32),
-            np.array([5., 6.]).astype(np.float32)
-        ]
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_equal_parts_1d')
-
-        node = onnx.helper.make_node('Split',
-                                     inputs=['input'],
-                                     outputs=['output_1', 'output_2'],
-                                     axis=0,
-                                     split=[2, 4])
-
-        expected_outputs = [
-            np.array([1., 2.]).astype(np.float32),
-            np.array([3., 4., 5., 6.]).astype(np.float32)
-        ]
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_variable_parts_1d')
-
-    def test_split_2d(self):
-        input = np.array([[1., 2., 3., 4., 5., 6.], [7., 8., 9., 10., 11.,
-                                                     12.]]).astype(np.float32)
-
-        node = onnx.helper.make_node('Split',
-                                     inputs=['input'],
-                                     outputs=['output_1', 'output_2'],
-                                     axis=1)
-
-        expected_outputs = [
-            np.array([[1., 2., 3.], [7., 8., 9.]]).astype(np.float32),
-            np.array([[4., 5., 6.], [10., 11., 12.]]).astype(np.float32)
-        ]
-
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_equal_parts_2d')
-
-        node = onnx.helper.make_node('Split',
-                                     inputs=['input'],
-                                     outputs=['output_1', 'output_2'],
-                                     axis=1,
-                                     split=[2, 4])
-
-        expected_outputs = [
-            np.array([[1., 2.], [7., 8.]]).astype(np.float32),
-            np.array([[3., 4., 5., 6.], [9., 10., 11., 12.]]).astype(np.float32)
-        ]
-
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_variable_parts_2d')
-
-    def test_split_default_values(self):
-        input = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32)
-
-        # If axis is not specified, split is applied on default axis 0
-        node = onnx.helper.make_node(
-            'Split',
-            inputs=['input'],
-            outputs=['output_1', 'output_2', 'output_3'])
-
-        expected_outputs = [
-            np.array([1., 2.]).astype(np.float32),
-            np.array([3., 4.]).astype(np.float32),
-            np.array([5., 6.]).astype(np.float32)
-        ]
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_equal_parts_default_axis')
-
-        node = onnx.helper.make_node('Split',
-                                     inputs=['input'],
-                                     outputs=['output_1', 'output_2'],
-                                     split=[2, 4])
-
-        expected_outputs = [
-            np.array([1., 2.]).astype(np.float32),
-            np.array([3., 4., 5., 6.]).astype(np.float32)
-        ]
-        expect(node,
-               inputs=[input],
-               outputs=[y for y in expected_outputs],
-               name='test_split_variable_parts_default_axis')
-
-    # empty tensor is not supported yet
-    # def test_split_zero_size_splits(self):
-    #     input = np.array([]).astype(np.float32)
-
-    #     # Split empty tensor to tensors of size zero
-    #     node = onnx.helper.make_node(
-    #         'Split',
-    #         inputs=['input'],
-    #         outputs=['output_1', 'output_2', 'output_3'],
-    #         split=[0, 0, 0]
-    #     )
-
-    #     expected_outputs = [np.array([]).astype(np.float32), np.array([]).astype(np.float32), np.array([]).astype(np.float32)]
-    #     expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_zero_size_splits')
-
-    def test_gather_0(self):
-        node = onnx.helper.make_node(
-            'Gather',
-            inputs=['data', 'indices'],
-            outputs=['y'],
-            axis=0,
-        )
-        data = np.random.randn(5, 4, 3, 2).astype(np.float32)
-        indices = np.array([0, 1, 3])
-        y = np.take(data, indices, axis=0)
-
-        expect(node,
-               inputs=[data, indices.astype(np.int64)],
-               outputs=[y],
-               name='test_gather_0')
-
-    def test_gather_1(self):
-        node = onnx.helper.make_node(
-            'Gather',
-            inputs=['data', 'indices'],
-            outputs=['y'],
-            axis=1,
-        )
-        data = np.random.randn(5, 4, 3, 2).astype(np.float32)
-        indices = np.array([0, 1, 3])
-        y = np.take(data, indices, axis=1)
-
-        expect(node,
-               inputs=[data, indices.astype(np.int64)],
-               outputs=[y],
-               name='test_gather_1')
-
-    def test_gather_negative_indices(self):
-        node = onnx.helper.make_node(
-            'Gather',
-            inputs=['data', 'indices'],
-            outputs=['y'],
-            axis=0,
-        )
-        data = np.arange(10).astype(np.float32)
-        indices = np.array([0, -9, -10])
-        y = np.take(data, indices, axis=0)
-
-        expect(node,
-               inputs=[data, indices.astype(np.int64)],
-               outputs=[y],
-               name='test_gather_negative_indices')
-
-    def test_tile(self):
-        node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
-
-        x = np.random.rand(2, 3, 4, 5).astype(np.float32)
-
-        repeats = np.random.randint(low=1, high=10,
-                                    size=(np.ndim(x),)).astype(np.int64)
-
-        z = np.tile(x, repeats)
-
-        expect(node, inputs=[x, repeats], outputs=[z], name='test_tile')
-
-    def test_tile_precomputed(self):
-        node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
-
-        x = np.array([[0, 1], [2, 3]], dtype=np.float32)
-
-        repeats = np.array([2, 2], dtype=np.int64)
-
-        z = np.array([[0, 1, 0, 1], [2, 3, 2, 3], [0, 1, 0, 1], [2, 3, 2, 3]],
-                     dtype=np.float32)
-
-        expect(node,
-               inputs=[x, repeats],
-               outputs=[z],
-               name='test_tile_precomputed')
-
-    def test_onehot_without_axis(self):
-        on_value = 5
-        off_value = 2
-        output_type = np.int32
-        node = onnx.helper.make_node('OneHot',
-                                     inputs=['indices', 'depth', 'values'],
-                                     outputs=['y'])
-        indices = np.array([0, 7, 8], dtype=np.int64)
-        depth = np.float32(12)
-        values = np.array([off_value, on_value], dtype=output_type)
-        y = one_hot(indices, depth, dtype=output_type)
-        y = y * (on_value - off_value) + off_value
-        expect(node,
-               inputs=[indices, depth, values],
-               outputs=[y],
-               name='test_onehot_without_axis')
-
-    def test_onehot_with_axis(self):
-        axisValue = 1
-        on_value = 3
-        off_value = 1
-        output_type = np.float32
-        node = onnx.helper.make_node('OneHot',
-                                     inputs=['indices', 'depth', 'values'],
-                                     outputs=['y'],
-                                     axis=axisValue)
-        indices = np.array([[1, 9], [2, 4]], dtype=np.float32)
-        depth = np.array([10], dtype=np.float32)
-        values = np.array([off_value, on_value], dtype=output_type)
-        y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
-        y = y * (on_value - off_value) + off_value
-        expect(node,
-               inputs=[indices, depth, values],
-               outputs=[y],
-               name='test_onehot_with_axis')
-
-    def test_onehot_with_negative_indices(self):
-        axisValue = 1
-        on_value = 3
-        off_value = 1
-        output_type = np.float32
-        node = onnx.helper.make_node('OneHot',
-                                     inputs=['indices', 'depth', 'values'],
-                                     outputs=['y'],
-                                     axis=axisValue)
-        indices = np.array([0, -7, -8], dtype=np.int64)
-
-        depth = np.array([10], dtype=np.float32)
-        values = np.array([off_value, on_value], dtype=output_type)
-        y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
-        y = y * (on_value - off_value) + off_value
-        expect(node,
-               inputs=[indices, depth, values],
-               outputs=[y],
-               name='test_onehot_negative_indices')
-
-    def test_onehot_with_negative_axis(self):
-        axisValue = -2
-        on_value = 3
-        off_value = 1
-        output_type = np.float32
-        node = onnx.helper.make_node('OneHot',
-                                     inputs=['indices', 'depth', 'values'],
-                                     outputs=['y'],
-                                     axis=axisValue)
-        indices = np.array([[1, 9], [2, 4]], dtype=np.float32)
-        depth = np.array([10], dtype=np.float32)
-        values = np.array([off_value, on_value], dtype=output_type)
-        y = one_hot(indices, depth, axis=axisValue, dtype=output_type)
-        y = y * (on_value - off_value) + off_value
-        expect(node,
-               inputs=[indices, depth, values],
-               outputs=[y],
-               name='test_onehot_with_negative_axis')
-
-
-def one_hot(indices, depth, axis=-1, dtype=np.float32):  # type: ignore
-    ''' Compute one hot from indices at a specific axis '''
-    values = np.asarray(indices)
-    rank = len(values.shape)
-    depth_range = np.arange(depth)
-    if axis < 0:
-        axis += (rank + 1)
-    ls = values.shape[0:axis]
-    rs = values.shape[axis:rank]
-    targets = np.reshape(depth_range,
-                         (1,) * len(ls) + depth_range.shape + (1,) * len(rs))
-    values = np.reshape(np.mod(values, depth), ls + (1,) + rs)
-    return np.asarray(targets == values, dtype=dtype)
-
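
For reference, the one_hot helper above turns integer indices into one-hot vectors along the chosen axis (wrapping indices with np.mod). A minimal numpy sketch of the same behaviour, with hypothetical inputs:

import numpy as np

indices = np.array([0, 2])                      # hypothetical indices
depth = 3
y = (np.arange(depth) == indices[..., None]).astype(np.float32)
# y == [[1., 0., 0.],
#       [0., 0., 1.]]   -- same result as one_hot(indices, depth)
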
-
-def gemm_reference_implementation(
-    A,
-    B,
-    C=None,
-    alpha=1.,
-    beta=1.,
-    transA=0,
-    transB=0
-):  # type: (np.ndarray, np.ndarray, Optional[np.ndarray], float, float, int, int) -> np.ndarray
-    A = A if transA == 0 else A.T
-    B = B if transB == 0 else B.T
-    C = C if C is not None else np.array(0)
-
-    Y = alpha * np.dot(A, B) + beta * C
-
-    return Y
-
-
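
The gemm_reference_implementation above encodes Y = alpha * A' * B' + beta * C, where A and B are transposed only when transA/transB are set. A quick numpy check of the shape arithmetic, reusing the hypothetical shapes from test_gemm_all_attributes:

import numpy as np

A = np.random.ranf([4, 3]).astype(np.float32)
B = np.random.ranf([5, 4]).astype(np.float32)
C = np.random.ranf([1, 5]).astype(np.float32)
# transA=1, transB=1: A.T is (3, 4), B.T is (4, 5), so the product is (3, 5)
Y = 0.25 * np.dot(A.T, B.T) + 0.35 * C          # C broadcasts over the rows
assert Y.shape == (3, 5)
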
-# return padding shape of conv2d or pooling
-def get_pad_shape(
-    auto_pad,  # type: Text
-    input_spatial_shape,  # type: Sequence[int]
-    kernel_spatial_shape,  # type: Sequence[int]
-    strides_spatial,  # type: Sequence[int]
-    output_spatial_shape  # type: Sequence[int]
-):  # type: (...) -> Sequence[int]
-    pad_shape = [0] * len(input_spatial_shape)
-    if auto_pad in ('SAME_UPPER', 'SAME_LOWER'):
-        for i in range(len(input_spatial_shape)):
-            pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial[i] + \
-                kernel_spatial_shape[i] - input_spatial_shape[i]
-    elif auto_pad == 'VALID':
-        pass
-    return pad_shape
-
-
-# return output shape of conv2d or pooling
-
-
-def get_output_shape(
-    auto_pad,  # type: Text
-    input_spatial_shape,  # type: Sequence[int]
-    kernel_spatial_shape,  # type: Sequence[int]
-    strides_spatial  # type: Sequence[int]
-):  # type: (...) -> Sequence[int]
-    out_shape = [0] * len(input_spatial_shape)
-    if auto_pad in ('SAME_UPPER', 'SAME_LOWER'):
-        for i in range(len(input_spatial_shape)):
-            out_shape[i] = int(
-                np.ceil(
-                    float(input_spatial_shape[i]) / float(strides_spatial[i])))
-    elif auto_pad == 'VALID':
-        for i in range(len(input_spatial_shape)):
-            out_shape[i] = int(
-                np.ceil(
-                    float(input_spatial_shape[i] -
-                          (kernel_spatial_shape[i] - 1)) /
-                    float(strides_spatial[i])))
-    return out_shape
-
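
The two helpers above implement the usual SAME/VALID arithmetic: SAME keeps out = ceil(in / stride) and pads by (out - 1) * stride + kernel - in, while VALID uses out = ceil((in - (kernel - 1)) / stride) with no padding. A quick numeric check with hypothetical sizes (the same 32/4/1 configuration used in the padding tests further below):

import numpy as np

in_size, kernel, stride = 32, 4, 1
same_out = int(np.ceil(in_size / stride))                     # 32
same_pad = (same_out - 1) * stride + kernel - in_size         # 3
valid_out = int(np.ceil((in_size - (kernel - 1)) / stride))   # 29
assert (same_out, same_pad, valid_out) == (32, 3, 29)
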
-
-def pool(
-    padded,  # type: np.ndarray
-    x_shape,  # type: Sequence[int]
-    kernel_shape,  # type: Sequence[int]
-    strides_shape,  # type: Sequence[int]
-    out_shape,  # type: Sequence[int]
-    pad_shape,  # type: Sequence[int]
-    pooling_type,  # type: Text
-    count_include_pad=0  # type: int
-):  # type: (...) -> np.ndarray
-    spatial_size = len(x_shape) - 2
-    y = np.zeros([x_shape[0], x_shape[1]] + list(out_shape))
-
-    for shape in itertools.product(
-            range(x_shape[0]), range(x_shape[1]), *[
-                range(
-                    int((x_shape[i + 2] + pad_shape[i] - kernel_shape[i]) /
-                        strides_shape[i] + 1)) for i in range(spatial_size)
-            ]):
-        window = padded[shape[0], shape[1]]
-        window_vals = np.array([
-            window[i] for i in list(
-                itertools.product(*[
-                    range(strides_shape[i] *
-                          shape[i + 2], strides_shape[i] * shape[i + 2] +
-                          kernel_shape[i]) for i in range(spatial_size)
-                ]))
-        ])
-        if pooling_type == 'AVG':
-            f = np.average
-        elif pooling_type == 'MAX':
-            f = np.max
-        else:
-            raise NotImplementedError(
-                'Pooling type {} is not supported. Should be AVG or MAX'.format(
-                    pooling_type))
-
-        if count_include_pad == 1 and pooling_type == 'AVG':
-            y[shape] = f(window_vals)
-        else:
-            y[shape] = f(window_vals[np.where(~np.isnan(window_vals))])
-    return y.astype(np.float32)
-
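
The pool() reference above walks every output position, slices the kernel window out of the padded input, and applies np.max or np.average. A minimal 1-D sketch of that windowing idea, with hypothetical values:

import numpy as np

x = np.array([1., 3., 2., 5.], dtype=np.float32)
kernel, stride = 2, 2
out = np.array([np.max(x[i * stride:i * stride + kernel])
                for i in range((len(x) - kernel) // stride + 1)])
# out == [3., 5.]
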
+globals().update(test_cases)
 
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main()
\ No newline at end of file
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 557fd49..54d2513 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -21,6 +21,7 @@
 from singa import tensor
 from singa import singa_wrap as singa
 from singa import autograd
+from singa import layer
 from singa import singa_wrap
 from cuda_helper import gpu_dev, cpu_dev
 
@@ -46,7 +47,7 @@
         y_shape: the shape of result
         x_shape: the shape of x
     Return:
-        a tuple refering the axes 
+        a tuple referring to the axes
     """
     res = []
     j = len(x_shape) - 1
@@ -116,9 +117,9 @@
         self._greater_helper(gpu_dev)
 
     def _conv2d_helper(self, dev):
-        # (in_channels, out_channels, kernel_size)
-        conv_0 = autograd.Conv2d(3, 1, 2)
-        conv_without_bias_0 = autograd.Conv2d(3, 1, 2, bias=False)
+        # (out_channels, kernel_size)
+        conv_0 = layer.Conv2d(1, 2)
+        conv_without_bias_0 = layer.Conv2d(1, 2, bias=False)
 
         cpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         cpu_input_tensor.gaussian(0.0, 1.0)
@@ -148,46 +149,31 @@
     def _conv_same_pad(self, dev, pad_mode, is_2d):
         if is_2d:
             x_h, w_h, k_h, p_h = 32, 4, 4, 1
-            if pad_mode == "SAME_LOWER":
-                o_p = (0, 1, 0, 1)
-            else:
-                o_p = (1, 0, 1, 0)
         else:
             x_h, w_h, k_h, p_h = 1, 1, 1, 0
-            if pad_mode == "SAME_LOWER":
-                o_p = (0, 0, 0, 1)
-            else:
-                o_p = (0, 0, 1, 0)
+
         x = tensor.Tensor(shape=(3, 3, x_h, 32), device=dev)
         x.gaussian(0.0, 1.0)
 
-        w = tensor.Tensor(shape=(3, 3, w_h, 4), device=dev)
-        w.gaussian(0.0, 1.0)
-
         # with SAME padding, the total padding should be 3
         # for SAME_UPPER, it is (1, 1) + (0, 1)
         # for SAME_LOWER, it is (1, 1) + (1, 0)
 
-        x_shape = x.shape
         kernel = (k_h, 4)
         padding = (p_h, 1)
         stride = (1, 1)
         group = 1
         bias = False
-        in_channels = x_shape[1]
-        w_shape = w.shape
-        out_channels = w_shape[0]
-        assert w_shape[1] == in_channels // group
+        out_channels = 3
 
-        if dev == cpu_dev:
-            handle = singa.ConvHandle(x.data, kernel, stride, padding,
-                                      in_channels, out_channels, bias, group)
-        else:
-            handle = singa.CudnnConvHandle(x.data, kernel, stride, padding,
-                                           in_channels, out_channels, bias,
-                                           group)
-        y = autograd._Conv2d(handle, o_p)(x, w)[0]
+        conv_0 = layer.Conv2d(out_channels,
+                              kernel,
+                              stride=stride,
+                              group=group,
+                              bias=bias,
+                              pad_mode=pad_mode)
 
+        y = conv_0(x)
         dy = np.ones((3, 3, x_h, 32), dtype=np.float32)
         dy = tensor.from_numpy(dy)
         dy.to_device(dev)
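
As the comments in this hunk note, for input 32, kernel 4 and stride 1 the total SAME padding is 3; SAME_UPPER puts the extra element after the symmetric part and SAME_LOWER puts it before. A small sketch of that split (split_same_pad is a hypothetical helper, not part of the library):

def split_same_pad(total_pad, lower=False):
    # symmetric part plus the odd element on one side
    half = total_pad // 2
    return (total_pad - half, half) if lower else (half, total_pad - half)

assert split_same_pad(3) == (1, 2)               # SAME_UPPER: (1, 1) + (0, 1)
assert split_same_pad(3, lower=True) == (2, 1)   # SAME_LOWER: (1, 1) + (1, 0)
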
@@ -218,16 +204,9 @@
     def _pooling_same_pad(self, dev, pad_mode, is_2d):
         if is_2d:
             x_h, k_h, p_h = 32, 4, 1
-            if pad_mode == "SAME_LOWER":
-                o_p = (0, 1, 0, 1)
-            else:
-                o_p = (1, 0, 1, 0)
         else:
             x_h, k_h, p_h = 1, 1, 0
-            if pad_mode == "SAME_LOWER":
-                o_p = (0, 0, 0, 1)
-            else:
-                o_p = (0, 0, 1, 0)
+
         x = tensor.Tensor(shape=(3, 3, x_h, 32), device=dev)
         x.gaussian(0.0, 1.0)
 
@@ -235,19 +214,14 @@
         # for SAME_UPPER, it is (1, 1) + (0, 1)
         # for SAME_LOWER, it is (1, 1) + (1, 0)
 
-        x_shape = x.shape
         kernel = (k_h, 4)
         # we add 4 paddings here and expect the pooling to trim one of them
         padding = (p_h, 1)
         stride = (1, 1)
 
-        if dev == cpu_dev:
-            handle = singa.PoolingHandle(x.data, kernel, stride, padding, True)
-        else:
-            handle = singa.CudnnPoolingHandle(x.data, kernel, stride, padding,
-                                              True)
+        pooling = layer.Pooling2d(kernel, stride=stride, pad_mode=pad_mode)
 
-        y = autograd._Pooling2d(handle, o_p)(x)[0]
+        y = pooling(x)
 
         dy = np.ones((3, 3, x_h, 32), dtype=np.float32)
         dy = tensor.from_numpy(dy)
@@ -319,11 +293,14 @@
             in_channels = 1
         else:
             in_channels = 8
-        separ_conv = autograd.SeparableConv2d(in_channels, 16, 3, padding=1)
+        separ_conv = layer.SeparableConv2d(16, 3, padding=1)
 
         x = np.random.random((10, in_channels, 28, 28)).astype(np.float32)
         x = tensor.Tensor(device=dev, data=x)
 
+        y = separ_conv(x)
+        self.check_shape(y.shape, (10, 16, 28, 28))
+
         y1 = separ_conv.depthwise_conv(x)
         y2 = separ_conv.point_conv(y1)
 
@@ -338,9 +315,6 @@
         self.check_shape(dx.shape(), (10, in_channels, 28, 28))
         self.check_shape(dW_spacial.shape(), (in_channels, 1, 3, 3))
 
-        y = separ_conv(x)
-        self.check_shape(y.shape, (10, 16, 28, 28))
-
     def test_SeparableConv2d_cpu(self):
         self._SeparableConv2d_helper(cpu_dev)
 
@@ -349,7 +323,7 @@
         self._SeparableConv2d_helper(gpu_dev)
 
     def _batchnorm2d_helper(self, dev):
-        batchnorm_0 = autograd.BatchNorm2d(3)
+        batchnorm_0 = layer.BatchNorm2d(3)
 
         cpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=dev)
         cpu_input_tensor.gaussian(0.0, 1.0)
@@ -411,7 +385,7 @@
     def _vanillaRNN_gpu_tiny_ops_shape_check_helper(self, dev):
         # gradients shape check.
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test(dev)
-        rnn = autograd.RNN(3, 2)
+        rnn = layer.RNN(3, 2)
 
         hs, _ = rnn(inputs, h0)
 
@@ -437,7 +411,7 @@
         c_0 = np.random.random((2, 1)).astype(np.float32)
         c0 = tensor.Tensor(device=dev, data=c_0)
 
-        rnn = autograd.LSTM(3, 2)
+        rnn = layer.LSTM(3, 2)
 
         hs, _, _ = rnn(inputs, (h0, c0))
         loss = autograd.softmax_cross_entropy(hs[0], target[0])
@@ -460,7 +434,7 @@
     def _numerical_gradients_check_for_vallina_rnn_helper(self, dev):
         inputs, target, h0 = prepare_inputs_targets_for_rnn_test(dev)
 
-        rnn = autograd.RNN(3, 2)
+        rnn = layer.RNN(3, 2)
 
         def valinna_rnn_forward():
             hs, _ = rnn(inputs, h0)
@@ -475,11 +449,78 @@
         loss1 = valinna_rnn_forward()
         auto_grads = autograd.gradients(loss1)
 
-        for param in rnn.params:
-            auto_grad = tensor.to_numpy(auto_grads[param])
+        params = rnn.get_params()
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
 
             self.gradients_check(valinna_rnn_forward, param, auto_grad, dev=dev)
 
+    def _gradient_check_cudnn_rnn(self, mode="vanilla", dev=gpu_dev):
+        seq = 10
+        bs = 2
+        fea = 10
+        hid = 10
+        x = np.random.random((seq, bs, fea)).astype(np.float32)
+        tx = tensor.Tensor(device=dev, data=x)
+        y = np.random.random((seq, bs, hid)).astype(np.float32)
+        y = np.reshape(y, (-1, hid))
+        ty = tensor.Tensor(device=dev, data=y)
+        rnn = layer.CudnnRNN(hid, rnn_mode=mode, return_sequences=True)
+
+        def vanilla_rnn_forward():
+            out = rnn(tx)
+            out = autograd.reshape(out, (-1, hid))
+            loss = autograd.softmax_cross_entropy(out, ty)
+            return loss
+
+        loss = vanilla_rnn_forward()
+        auto_grads = autograd.gradients(loss)
+
+        params = rnn.get_params()
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
+            self.gradients_check(vanilla_rnn_forward, param, auto_grad, dev=dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cudnn_rnn_vanilla(self):
+        self._gradient_check_cudnn_rnn(mode="vanilla", dev=gpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cudnn_rnn_lstm(self):
+        self._gradient_check_cudnn_rnn(mode="lstm", dev=gpu_dev)
+
+    # Cos Sim Gradient Check
+    def _gradient_check_cossim(self, dev=gpu_dev):
+        bs = 2
+        vec = 3
+        ta = tensor.random((bs, vec), dev)
+        tb = tensor.random((bs, vec), dev)
+        # treat ta, tb as params
+        ta.stores_grad = True
+        tb.stores_grad = True
+        ty = tensor.random((bs,), dev)
+
+        def _forward():
+            out = autograd.cossim(ta, tb)
+            loss = autograd.mse_loss(out, ty)
+            return loss
+
+        loss = _forward()
+        auto_grads = autograd.gradients(loss)
+
+        params = {id(ta): ta, id(tb): tb}
+
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
+            self.gradients_check(_forward, param, auto_grad, dev=dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gradient_check_cossim_gpu(self):
+        self._gradient_check_cossim(dev=gpu_dev)
+
+    def test_gradient_check_cossim_cpu(self):
+        self._gradient_check_cossim(dev=cpu_dev)
+
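
The cosine-similarity checks above exercise autograd.cossim; the quantity itself is just the row-wise dot product divided by the product of the vector norms, which is easy to verify in plain numpy (hypothetical vectors):

import numpy as np

a = np.array([[1., 0., 0.], [1., 1., 0.]], dtype=np.float32)
b = np.array([[1., 0., 0.], [0., 1., 0.]], dtype=np.float32)
cos = np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
# cos == [1., 0.7071]  (identical vectors, then a 45-degree pair)
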
     def test_numerical_gradients_check_for_vallina_rnn_cpu(self):
         self._numerical_gradients_check_for_vallina_rnn_helper(cpu_dev)
 
@@ -492,7 +533,7 @@
         c_0 = np.zeros((2, 2)).astype(np.float32)
         c0 = tensor.Tensor(device=dev, data=c_0)
 
-        rnn = autograd.LSTM(3, 2)
+        rnn = layer.LSTM(3, 2)
 
         def lstm_forward():
             hs, _, _ = rnn(inputs, (h0, c0))
@@ -506,8 +547,9 @@
         loss1 = lstm_forward()
         auto_grads = autograd.gradients(loss1)
 
-        for param in rnn.params:
-            auto_grad = tensor.to_numpy(auto_grads[param])
+        params = rnn.get_params()
+        for key, param in params.items():
+            auto_grad = tensor.to_numpy(auto_grads[id(param)])
 
             self.gradients_check(lstm_forward, param, auto_grad, dev=dev)
 
@@ -529,7 +571,7 @@
         t.to_device(dev)
 
         loss = autograd.mse_loss(x, t)
-        dx = loss.creator.backward()[0]
+        dx = loss.creator.backward()
 
         loss_np = tensor.to_numpy(loss)[0]
         self.assertAlmostEqual(loss_np, 0.0366666, places=4)
@@ -2525,21 +2567,28 @@
             shapeB = config[5]
             shapeC = config[6]
             shapeY = config[7]
+
             A = np.random.randn(*shapeA).astype(np.float32)
-            B = np.random.randn(*shapeB).astype(np.float32)
-            C = np.random.randn(*shapeC).astype(np.float32)
             DY = np.ones(shapeY, dtype=np.float32)
 
+            if transB == 0:
+                out_features = shapeB[1]
+            else:
+                out_features = shapeB[0]
+
             a = tensor.from_numpy(A)
             a.to_device(dev)
-            b = tensor.from_numpy(B)
-            b.to_device(dev)
-            c = tensor.from_numpy(C)
-            c.to_device(dev)
             dy = tensor.from_numpy(DY)
             dy.to_device(dev)
 
-            result = autograd.gemm(a, b, c, alpha, beta, transA, transB)
+            gemm = layer.Gemm(out_features, alpha, beta, transA == 1,
+                              transB == 1)
+            result = gemm(a)
+
+            params = gemm.get_params()
+            B = tensor.to_numpy(params['W'])
+            C = tensor.to_numpy(params['b'])
+
             da, db, dc = result.creator.backward(dy.data)
 
             # Y = alpha * A' * B' + beta * C
@@ -2832,6 +2881,100 @@
     def test_ceil_gpu(self):
         self.ceil_test(gpu_dev)
 
+    def floor_test(self, dev):
+        X = np.array([-1.9, 1.2]).astype(np.float32)
+        DY = np.ones((2), dtype=np.float32)
+        y = np.floor(X)
+        x = tensor.from_numpy(X)
+        dy = tensor.from_numpy(DY)
+        x.to_device(dev)
+        dy.to_device(dev)
+
+        result = autograd.floor(x)
+        dx = result.creator.backward(dy.data)
+        # floor is piecewise constant, so the expected gradient is zero everywhere
+        DX = np.zeros((2), dtype=np.float32)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(result),
+                                             y,
+                                             decimal=5)
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(dx)), DX, decimal=5)
+
+    def test_floor_cpu(self):
+        self.floor_test(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_floor_gpu(self):
+        self.floor_test(gpu_dev)
+
+    def _test_scatter_elements(self, dev):
+        # testing without axis
+        data = np.zeros((3, 3), dtype=np.float32)
+        indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int32)
+        updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
+        output = np.array([[2.0, 1.1, 0.0], [1.0, 0.0, 2.2], [0.0, 2.1, 1.2]],
+                          dtype=np.float32)
+
+        data = tensor.from_numpy(data)
+        indices = tensor.from_numpy(indices)
+        updates = tensor.from_numpy(updates)
+        data.to_device(dev)
+        indices.to_device(dev)
+        updates.to_device(dev)
+
+        result = autograd.scatter_elements(data, indices, updates)
+        dy = tensor.from_numpy(np.ones(data.shape, dtype=np.float32))
+        dx = result.creator.backward(dy.data)
+        np.testing.assert_almost_equal(tensor.to_numpy(result),
+                                       output,
+                                       decimal=5)
+        self.check_shape(dx.shape(), data.shape)
+
+        # testing with axis
+        data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
+        indices = np.array([[1, 3]], dtype=np.int32)
+        updates = np.array([[1.1, 2.1]], dtype=np.float32)
+        output = np.array([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=np.float32)
+
+        data = tensor.from_numpy(data)
+        indices = tensor.from_numpy(indices)
+        updates = tensor.from_numpy(updates)
+        data.to_device(dev)
+        indices.to_device(dev)
+        updates.to_device(dev)
+
+        result = autograd.scatter_elements(data, indices, updates, axis=1)
+        dy = tensor.from_numpy(np.ones(data.shape, dtype=np.float32))
+        dx = result.creator.backward(dy.data)
+        np.testing.assert_almost_equal(tensor.to_numpy(result),
+                                       output,
+                                       decimal=5)
+        self.check_shape(dx.shape(), data.shape)
+
+        # testing with negative indices:
+        data = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
+        indices = np.array([[1, -3]], dtype=np.int64)
+        updates = np.array([[1.1, 2.1]], dtype=np.float32)
+        output = np.array([[1.0, 1.1, 2.1, 4.0, 5.0]], dtype=np.float32)
+
+        data = tensor.from_numpy(data)
+        indices = tensor.from_numpy(indices)
+        updates = tensor.from_numpy(updates)
+        data.to_device(dev)
+        indices.to_device(dev)
+        updates.to_device(dev)
+
+        result = autograd.scatter_elements(data, indices, updates, axis=1)
+        dy = tensor.from_numpy(np.ones(data.shape, dtype=np.float32))
+        dx = result.creator.backward(dy.data)
+        np.testing.assert_almost_equal(tensor.to_numpy(result),
+                                       output,
+                                       decimal=5)
+        self.check_shape(dx.shape(), data.shape)
+
+    def test_cpu_scatter_elements(self):
+        self._test_scatter_elements(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_gpu_scatter_elements(self):
+        self._test_scatter_elements(gpu_dev)
+
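
The scatter tests above follow ONNX ScatterElements semantics: each entry of updates is written into a copy of data at the position taken from indices along the chosen axis, and negative indices count from the end. A minimal numpy sketch of that forward rule for 2-D inputs (scatter_elements_2d is a hypothetical illustration, not the SINGA operator):

import numpy as np

def scatter_elements_2d(data, indices, updates, axis=0):
    out = data.copy()
    for i in range(indices.shape[0]):
        for j in range(indices.shape[1]):
            if axis == 0:
                out[indices[i, j], j] = updates[i, j]
            else:
                out[i, indices[i, j]] = updates[i, j]
    return out

data = np.array([[1., 2., 3., 4., 5.]], dtype=np.float32)
indices = np.array([[1, -3]])                    # -3 wraps to column 2
updates = np.array([[1.1, 2.1]], dtype=np.float32)
out = scatter_elements_2d(data, indices, updates, axis=1)
# out == [[1., 1.1, 2.1, 4., 5.]], matching the expected output in the test above
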
     def split_test(self, dev):
         X = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32)
         DY1 = np.ones((2), dtype=np.float32)
@@ -3015,6 +3158,499 @@
     def test_onehot_gpu(self):
         self.onehot_test(gpu_dev)
 
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_cudnn_rnn_operation(self, dev=gpu_dev):
+        # init params, inputs
+        hidden_size = 7
+        seq_length = 5
+        batch_size = 6
+        feature_size = 3
+        directions = 2
+        num_layers = 2
+
+        for mode in [0, 1, 2, 3]:  # 0-relu, 1-tanh, 2-lstm, 3-gru
+            x = tensor.Tensor(shape=(seq_length, batch_size, feature_size),
+                              device=dev).gaussian(0, 1)
+            hx = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                      hidden_size),
+                               device=dev).gaussian(0, 1)
+            cx = tensor.Tensor(shape=(num_layers * directions, batch_size,
+                                      hidden_size),
+                               device=dev).gaussian(0, 1)
+            dy = tensor.Tensor(shape=(seq_length, batch_size,
+                                      directions * hidden_size),
+                               device=dev).gaussian(0, 1)
+
+            # init cudnn rnn op
+            rnn_handle = singa.CudnnRNNHandle(x.data,
+                                              hidden_size,
+                                              mode,
+                                              num_layers=num_layers,
+                                              dropout=0.1,
+                                              bidirectional=1)
+
+            w = tensor.Tensor(shape=(rnn_handle.weights_size,),
+                              device=dev).gaussian(0, 1)
+
+            # return sequence, y shape = {seq, bs, hidden}
+            # init operator/operation
+            _rnn = autograd._RNN(rnn_handle, return_sequences=True)
+
+            # forward
+            y = _rnn(x, hx, cx, w)[0]
+            assert y.shape == dy.shape
+            # print(y)
+
+            # backward
+            dx, dhx, dcx, dw = _rnn.backward(dy.data)
+
+            # return no sequence, y shape = {bs, hidden}
+            _rnn = autograd._RNN(rnn_handle, return_sequences=False)
+            dy = tensor.Tensor(shape=(batch_size, directions * hidden_size),
+                               device=dev).gaussian(0, 1)
+            y = _rnn(x, hx, cx, w)[0]
+
+            assert y.shape == dy.shape
+            # backward
+            dx, dhx, dcx, dw = _rnn.backward(dy.data)
+
+    def cossim_helper(self, dev):
+        A = np.random.randn(*[3, 10]).astype(np.float32)
+        B = np.random.randn(*[3, 10]).astype(np.float32)
+
+        a = tensor.from_numpy(A)
+        a.to_device(dev)
+        b = tensor.from_numpy(B)
+        b.to_device(dev)
+
+        DY = np.random.randn(3).astype(np.float32)
+        dy = tensor.from_numpy(DY)
+        dy.to_device(dev)
+
+        y = autograd.cossim(a, b)
+        da, db = y.creator.backward(dy.data)  # CTensor
+
+        self.check_shape(y.shape, (3,))
+        self.check_shape(da.shape(), (3, 10))
+        self.check_shape(db.shape(), (3, 10))
+
+    def test_cossim_cpu(self):
+        self.cossim_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_cossim_gpu(self):
+        self.cossim_helper(gpu_dev)
+
+    def expand_helper(self, dev):
+        shape = [3, 1]
+        X = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32),
+                       shape)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        # dim_changed
+        new_shape = [2, 1, 6]
+        y_t = X * np.ones(new_shape, dtype=np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+        y = autograd.expand(x, new_shape)
+        dx = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        self.check_shape(dx.shape(), tuple(shape))
+
+        # dim_unchanged
+        new_shape_2 = [3, 4]
+        y_t2 = np.tile(X, 4)
+        dy2 = tensor.from_numpy(y_t2)
+        dy2.to_device(dev)
+        y2 = autograd.expand(x, new_shape_2)
+        dx2 = y2.creator.backward(dy2.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y2), y_t2)
+        self.check_shape(dx2.shape(), tuple(shape))
+
+    def test_expand_cpu(self):
+        self.expand_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_expand_gpu(self):
+        self.expand_helper(gpu_dev)
+
+    def pad_helper(self, dev):
+        X = np.array([
+            [1.0, 1.2],
+            [2.3, 3.4],
+            [4.5, 5.7],
+        ]).astype(np.float32)
+        Y1 = np.array([
+            [0.0, 0.0, 1.0, 1.2],
+            [0.0, 0.0, 2.3, 3.4],
+            [0.0, 0.0, 4.5, 5.7],
+        ],).astype(np.float32)
+        Y2 = np.array([
+            [1.0, 1.2, 1.0, 1.2],
+            [2.3, 3.4, 2.3, 3.4],
+            [4.5, 5.7, 4.5, 5.7],
+        ],).astype(np.float32)
+        Y3 = np.array([
+            [1.0, 1.0, 1.0, 1.2],
+            [2.3, 2.3, 2.3, 3.4],
+            [4.5, 4.5, 4.5, 5.7],
+        ],).astype(np.float32)
+
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+        pads = [0, 2, 0, 0]
+
+        DY = np.random.randn(3, 4).astype(np.float32)
+        dy = tensor.from_numpy(DY)
+        dy.to_device(dev)
+
+        y1 = autograd.pad(x, "constant", pads)
+        y2 = autograd.pad(x, "reflect", pads)
+        y3 = autograd.pad(x, "edge", pads)
+        dx1 = y1.creator.backward(dy.data)
+        dx2 = y2.creator.backward(dy.data)
+        dx3 = y3.creator.backward(dy.data)
+        pad_width = []
+        half_width = len(pads) // 2
+        for i in range(half_width):
+            pad_width += [[pads[i], pads[i + half_width]]]
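+        # e.g. the ONNX-style pads = [0, 2, 0, 0] above becomes
+        # pad_width = [[0, 0], [2, 0]]: no padding on axis 0, and
+        # 2 leading / 0 trailing values on axis 1.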
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y1),
+                                             np.pad(
+                                                 X,
+                                                 pad_width=pad_width,
+                                                 mode="constant",
+                                                 constant_values=0.,
+                                             ),
+                                             decimal=5)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y2),
+                                             np.pad(
+                                                 X,
+                                                 pad_width=pad_width,
+                                                 mode="reflect",
+                                             ),
+                                             decimal=5)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y3),
+                                             np.pad(
+                                                 X,
+                                                 pad_width=pad_width,
+                                                 mode="edge",
+                                             ),
+                                             decimal=5)
+        self.check_shape(dx1.shape(), (3, 2))
+        self.check_shape(dx2.shape(), (3, 2))
+        self.check_shape(dx3.shape(), (3, 2))
+
+    def test_pad_cpu(self):
+        self.pad_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_pad_gpu(self):
+        self.pad_helper(gpu_dev)
+
+    def upsample_helper(self, dev):
+        X = np.array([[[
+            [1, 2],
+            [3, 4],
+        ]]], dtype=np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
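+        # nearest upsampling with scales (1, 1, 2, 3): each element of x is
+        # repeated 2x along H and 3x along W, giving the 4x6 expected map below.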
+        y_t = np.array([[[
+            [1, 1, 1, 2, 2, 2],
+            [1, 1, 1, 2, 2, 2],
+            [3, 3, 3, 4, 4, 4],
+            [3, 3, 3, 4, 4, 4],
+        ]]],
+                       dtype=np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+
+        y = autograd.upsample(x, "nearest", scales)
+        dx = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        self.check_shape(dx.shape(), tuple(X.shape))
+
+    def test_upsample_cpu(self):
+        self.upsample_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_upsample_gpu(self):
+        self.upsample_helper(gpu_dev)
+
+    def depth_space_helper(self, dev):
+        # (1, 8, 2, 3) input tensor
+        X = np.array(
+            [[[[0., 1., 2.], [3., 4., 5.]], [[9., 10., 11.], [12., 13., 14.]],
+              [[18., 19., 20.], [21., 22., 23.]],
+              [[27., 28., 29.], [30., 31., 32.]],
+              [[36., 37., 38.], [39., 40., 41.]],
+              [[45., 46., 47.], [48., 49., 50.]],
+              [[54., 55., 56.], [57., 58., 59.]],
+              [[63., 64., 65.], [66., 67., 68.]]]],
+            dtype=np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        # (1, 2, 4, 6) output tensor
+        y_t = np.array(
+            [[[[0., 18., 1., 19., 2., 20.], [36., 54., 37., 55., 38., 56.],
+               [3., 21., 4., 22., 5., 23.], [39., 57., 40., 58., 41., 59.]],
+              [[9., 27., 10., 28., 11., 29.], [45., 63., 46., 64., 47., 65.],
+               [12., 30., 13., 31., 14., 32.], [48., 66., 49., 67., 50., 68.]]]
+            ],
+            dtype=np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
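+        # y_t above follows the ONNX DepthToSpace reference in DCR mode (a
+        # sketch, blocksize b=2): reshape x to (1, b, b, 2, 2, 3), transpose to
+        # axes (0, 3, 4, 1, 5, 2), then reshape to (1, 2, 4, 6).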
+        y = autograd.depth_to_space(x, 2, "DCR")
+        dx = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), X)
+
+        y = autograd.space_to_depth(dy, 2, "DCR")
+        dx = y.creator.backward(x.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), X)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), y_t)
+
+        y_t = np.array(
+            [[[[0., 9., 1., 10., 2., 11.], [18., 27., 19., 28., 20., 29.],
+               [3., 12., 4., 13., 5., 14.], [21., 30., 22., 31., 23., 32.]],
+              [[36., 45., 37., 46., 38., 47.], [54., 63., 55., 64., 56., 65.],
+               [39., 48., 40., 49., 41., 50.], [57., 66., 58., 67., 59., 68.]]]
+            ],
+            dtype=np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
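+        # in CRD mode the reference rearrangement differs (sketch): reshape x
+        # to (1, 2, b, b, 2, 3), transpose to axes (0, 1, 4, 2, 5, 3), reshape.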
+        y = autograd.depth_to_space(x, 2, "CRD")
+        dx = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), X)
+
+        y = autograd.space_to_depth(dy, 2, "CRD")
+        dx = y.creator.backward(x.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), X)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(tensor.from_raw_tensor(dx)), y_t)
+
+    def test_depth_space_cpu(self):
+        self.depth_space_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_depth_space_gpu(self):
+        self.depth_space_helper(gpu_dev)
+
+    def test_invalid_inputs(self, dev=cpu_dev):
+        _1d = tensor.Tensor((10,), dev)
+        _2d = tensor.Tensor((10, 10), dev)
+        _3d = tensor.Tensor((10, 10, 10), dev)
+        self.assertRaises(AssertionError, autograd.softmax_cross_entropy, _2d,
+                          _3d)
+        self.assertRaises(AssertionError, autograd.mse_loss, _2d, _3d)
+        self.assertRaises(AssertionError, autograd.add_bias, _2d, _1d, 3)
+        self.assertRaises(AssertionError, autograd.ranking_loss, _2d, _1d)
+
+    def where_helper(self, dev):
+        X = np.array([[1, 2], [3, 4]], dtype=np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        X2 = np.array([[9, 8], [7, 6]], dtype=np.float32)
+        x2 = tensor.from_numpy(X2)
+        x2.to_device(dev)
+
+        condition = [[True, False], [True, True]]
+        y_t = np.where(condition, X, X2)
+        dx1_t = np.array([[1, 0], [3, 4]], dtype=np.float32)
+        dx2_t = np.array([[0, 8], [0, 0]], dtype=np.float32)
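+        # the expected gradients above correspond to dx1 = dy * condition and
+        # dx2 = dy * (1 - condition), i.e. dy is routed by the boolean mask.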
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+
+        y = autograd.where(x, x2, condition)
+        dx1, dx2 = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(dx1)), dx1_t)
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(dx2)), dx2_t)
+
+    def test_where_cpu(self):
+        self.where_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_where_gpu(self):
+        self.where_helper(gpu_dev)
+
+    def rounde_helper(self, dev):
+        X = np.array([
+            0.1, 0.5, 0.9, 1.2, 1.5, 1.8, 2.3, 2.5, 2.7, -1.1, -1.5, -1.9, -2.2,
+            -2.5, -2.8
+        ]).astype(np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        y_t = np.array(
+            [0., 0., 1., 1., 2., 2., 2., 2., 3., -1., -2., -2., -2., -2.,
+             -3.]).astype(np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+
+        y = autograd.rounde(x)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+
+    def test_rounde_cpu(self):
+        self.rounde_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_rounde_gpu(self):
+        self.rounde_helper(gpu_dev)
+
+    def round_helper(self, dev):
+        X = np.array([
+            0.1, 0.5, 0.9, 1.2, 1.5, 1.8, 2.3, 2.5, 2.7, -1.1, -1.5, -1.9, -2.2,
+            -2.5, -2.8
+        ]).astype(np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        y_t = np.array(
+            [0., 1., 1., 1., 2., 2., 2., 3., 3., -1., -2., -2., -2., -3.,
+             -3.]).astype(np.float32)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+
+        y = autograd.round(x)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+
+    def test_round_cpu(self):
+        self.round_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_round_gpu(self):
+        self.round_helper(gpu_dev)
+
+    def embedding_helper(self, dev):
+        embedding = layer.Embedding(10, 3)
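+        # a vocabulary of 10 indices, each mapped to a 3-d vector, so the (2, 4)
+        # index batch below yields a (2, 4, 3) output and a (10, 3) weight gradient.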
+
+        X = np.array([[0, 1, 2, 3], [9, 8, 7, 6]])
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        dy = tensor.Tensor(shape=(2, 4, 3), device=dev)
+        dy.gaussian(0.0, 1.0)
+
+        y = embedding(x)  # PyTensor
+        dx, dW = y.creator.backward(dy.data)  # CTensor
+
+        self.check_shape(y.shape, (2, 4, 3))
+        self.check_shape(dx.shape(), (2, 4))
+        self.check_shape(dW.shape(), (10, 3))
+
+    def test_embedding_cpu(self):
+        self.embedding_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_embedding_gpu(self):
+        self.embedding_helper(gpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def _cossim_value(self, dev=gpu_dev):
+        # numpy val
+        np.random.seed(0)
+        bs = 1000
+        vec_s = 1200
+        a = np.random.random((bs, vec_s)).astype(np.float32)
+        b = np.random.random((bs, vec_s)).astype(np.float32)
+        dy = np.random.random((bs,)).astype(np.float32)
+
+        # singa tensor
+        ta = tensor.from_numpy(a)
+        tb = tensor.from_numpy(b)
+        tdy = tensor.from_numpy(dy)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        tdy.to_device(dev)
+
+        # singa forward and backward
+        ty = autograd.cossim(ta, tb)
+        tda, tdb = ty.creator.backward(tdy.data)
+
+        np_forward = list()
+        for i in range(len(a)):
+            a_norm = np.linalg.norm(a[i])
+            b_norm = np.linalg.norm(b[i])
+            ab_dot = np.dot(a[i], b[i])
+            out = ab_dot / (a_norm * b_norm)
+            np_forward.append(out)
+
+        np_backward_a = list()
+        np_backward_b = list()
+        for i in range(len(a)):
+            a_norm = np.linalg.norm(a[i])
+            b_norm = np.linalg.norm(b[i])
+            da = dy[i] * (b[i] / (a_norm * b_norm) - (np_forward[i] * a[i]) /
+                          (a_norm * a_norm))
+            db = dy[i] * (a[i] / (a_norm * b_norm) - (np_forward[i] * b[i]) /
+                          (b_norm * b_norm))
+            np_backward_a.append(da)
+            np_backward_b.append(db)
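+        # closed form used above: d cos(a, b) / da = b / (|a||b|) - cos(a, b) * a / |a|^2
+        # (and symmetrically for b), scaled by the upstream gradient dy.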
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(ty),
+                                             np.array(np_forward))
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(tda)), np_backward_a)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_cossim_value_gpu(self):
+        self._cossim_value(gpu_dev)
+
+    def test_cossim_value_cpu(self):
+        self._cossim_value(cpu_dev)
+
+    def test_mse_loss_value(self, dev=cpu_dev):
+        y = np.random.random((1000, 1200)).astype(np.float32)
+        tar = np.random.random((1000, 1200)).astype(np.float32)
+        # get singa value
+        sy = tensor.from_numpy(y, dev)
+        starget = tensor.from_numpy(tar, dev)
+        sloss = autograd.mse_loss(sy, starget)
+        sgrad = sloss.creator.backward()
+        # get np value result
+        np_loss = np.mean(np.square(tar - y))
+        np_grad = -2 * (tar - y) / np.prod(tar.shape)
+        # value check
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(sgrad)), np_grad)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(sloss), np_loss)
+
+    def erf_helper(self, dev):
+        X = np.array([
+            0.1, 0.5, 0.9, 1.2, 1.5, 1.8, 2.3, 2.5, 2.7, -1.1, -1.5, -1.9, -2.2,
+            -2.5, -2.8
+        ]).astype(np.float32)
+        x = tensor.from_numpy(X)
+        x.to_device(dev)
+
+        import math
+
+        y_t = np.vectorize(math.erf)(X)
+        dy = tensor.from_numpy(y_t)
+        dy.to_device(dev)
+        dx_t = 2. / np.pi**0.5 * np.exp(-np.power(y_t, 2))
+
+        y = autograd.erf(x)
+        dx = y.creator.backward(dy.data)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t)
+        np.testing.assert_array_almost_equal(
+            tensor.to_numpy(tensor.from_raw_tensor(dx)), dx_t)
+
+    def test_erf_cpu(self):
+        self.erf_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_erf_gpu(self):
+        self.erf_helper(gpu_dev)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/python/test_opt.py b/test/python/test_opt.py
new file mode 100644
index 0000000..8027d3a
--- /dev/null
+++ b/test/python/test_opt.py
@@ -0,0 +1,230 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# =============================================================================
+from __future__ import division
+
+import math
+import unittest
+import numpy as np
+import functools
+
+
+from singa import tensor
+from singa import singa_wrap as singa
+from singa import opt
+
+from cuda_helper import gpu_dev, cpu_dev
+
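+# compare two singa tensors element-wise on the host, then restore their device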
+def assertTensorEqual(x, y, decimal=6):
+    assert x.shape == y.shape
+    assert x.dtype == y.dtype
+    assert x.device.id() == y.device.id()
+    d = x.device
+    x.to_host()
+    y.to_host()
+    np.testing.assert_array_almost_equal(x.data.GetFloatValue(int(x.size())),
+                                         y.data.GetFloatValue(int(y.size())),
+                                         decimal)
+    x.to_device(d)
+    y.to_device(d)
+
+def on_cpu_gpu(func):
+    @functools.wraps(func)
+    def wrapper_decorator(*args, **kwargs):
+        func(*args, dev=cpu_dev, **kwargs)
+        if singa.USE_CUDA:
+            func(*args, dev=gpu_dev, **kwargs)
+    return wrapper_decorator
+
+class TestDecayScheduler(unittest.TestCase):
+    def test_exponential_decay_cpu(self):
+        lr = opt.ExponentialDecay(0.1, 2, 0.5, True)
+        sgd1 = opt.SGD(lr=lr)
+        for i in range(5):
+            np.testing.assert_array_almost_equal(tensor.to_numpy(sgd1.lr_value), [0.1*0.5**(i//2)])
+            sgd1.step()
+
+    def test_exponential_decay_no_staircase_cpu(self):
+        lr = opt.ExponentialDecay(0.1, 2, 0.5, False)
+        sgd1 = opt.SGD(lr=lr)
+        for i in range(5):
+            np.testing.assert_array_almost_equal(tensor.to_numpy(sgd1.lr_value), [0.1*0.5**(i/2)])
+            sgd1.step()
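+        # for reference, the schedules asserted by the two tests above are
+        # staircase: [0.1, 0.1, 0.05, 0.05, 0.025] and
+        # no staircase: 0.1 * 0.5 ** (i / 2) = [0.1, ~0.0707, 0.05, ~0.0354, 0.025]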
+
+    @on_cpu_gpu
+    def test_const_decay_scheduler(self, dev):
+        c1 = opt.Constant(0.2)
+        step = tensor.Tensor((1,), device=dev).set_value(0)
+        lr_val = c1(step)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(lr_val), [0.2])
+        step += 1
+        np.testing.assert_array_almost_equal(tensor.to_numpy(c1(step)), [0.2])
+
+class TestOptimizer(unittest.TestCase):
+    @on_cpu_gpu
+    def test_optimizer(self, dev):
+        o1 = opt.Optimizer(0.1)
+
+        # test step
+        o1.step()
+        o1.step()
+
+        # test get states
+        s1 = o1.get_states()
+        self.assertAlmostEqual(s1['step_counter'], 2)
+
+        # test set states
+        s2 = {'step_counter': 5}
+        o1.set_states(s2)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(o1.step_counter), [5])
+
+    @on_cpu_gpu
+    def test_sgd_const_lr(self, dev=cpu_dev):
+        cpu_dev.EnableGraph(False)
+        sgd1 = opt.SGD(lr=0.1)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+
+        w_step1 = w-0.1*g
+        sgd1.apply(w.name, w, g)
+
+        assertTensorEqual(w, w_step1)
+
+    @on_cpu_gpu
+    def test_RMSProp_const_lr(self, dev=cpu_dev):
+        cpu_dev.EnableGraph(False)
+        opt1 = opt.RMSProp(lr=0.1)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+
+        # running_average = running_average * rho + param_grad * param_grad * (1 - rho)
+        # param_value = param_value - lr * param_grad / sqrt(running_average + epsilon)
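+        # worked numbers for this test (rho=0.9 and epsilon=1e-8 are assumed
+        # defaults, matching the reference below): running_average = 0.1 * 0.01 = 1e-3,
+        # update = 0.1 * 0.1 / sqrt(1e-3 + 1e-8) ~= 0.316, so w goes from 0.1 to ~ -0.216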
+
+        running_average = 0.1 * tensor.square(g)
+        tmp = running_average + 1e-8
+        tmp = tensor.sqrt(tmp)
+        tmp = g / tmp
+
+        w_step1 = w - 0.1 * tmp
+        opt1.apply(w.name, w, g)
+
+        assertTensorEqual(w, w_step1)
+
+    @on_cpu_gpu
+    def test_AdaGrad_const_lr(self, dev=cpu_dev):
+        cpu_dev.EnableGraph(False)
+        opt1 = opt.AdaGrad(lr=0.1)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+
+        # history = history + param_grad * param_grad
+        # param_value = param_value - lr * param_grad / sqrt(history + epsilon)
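+        # worked numbers for this test: history = 0.01,
+        # update = 0.1 * 0.1 / sqrt(0.01 + 1e-8) ~= 0.1, so w goes from 0.1 to ~0.0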
+
+        history = tensor.square(g)
+        tmp = history + 1e-8
+        tmp = tensor.sqrt(tmp)
+        tmp = g / tmp
+
+        w_step1 = w - 0.1 * tmp
+        opt1.apply(w.name, w, g)
+
+        assertTensorEqual(w, w_step1)
+
+    @on_cpu_gpu
+    def test_Adam_const_lr(self, dev=cpu_dev):
+        cpu_dev.EnableGraph(False)
+        opt1 = opt.Adam(lr=0.1)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(1.0)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+
+        # m := beta_1 * m + (1 - beta_1) * grad
+        # v := beta_2 * v + (1 - beta_2) * grad * grad
+        # m_norm = m / (1 - beta_1 ^ step)
+        # v_norm = v / (1 - beta_2 ^ step)
+        # param := param - (lr * m_norm) / (sqrt(v_norm) + epsilon)
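+        # worked numbers for this test (assuming the defaults beta_1=0.9 and
+        # beta_2=0.999 implied by the reference below): m = 0.01, v = 1e-5,
+        # m_norm = 0.1, v_norm = 0.01, update = 0.1 * 0.1 / (0.1 + 1e-8) ~= 0.1,
+        # so w goes from 1.0 to ~0.9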
+
+        m = 0.1 * g
+        tmp = tensor.square(g)
+        v = 0.001 * tmp
+
+        m_norm = m / 0.1
+        v_norm = v / 0.001
+
+        tmp = tensor.sqrt(v_norm) + 1e-8
+        tmp = m_norm / tmp
+
+        w_step1 = w - 0.1 * tmp
+        opt1.apply(w.name, w, g)
+
+        assertTensorEqual(w, w_step1, decimal=5)
+
+    @on_cpu_gpu
+    def test_sgd_const_lr_momentum(self, dev=cpu_dev):
+        sgd1 = opt.SGD(lr=0.1,momentum=0.9)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.01)
+
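+        # reference rule checked below (sketch): buf := momentum * buf + grad,
+        # then w := w - lr * buf; the first step reduces to plain SGD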
+        w_step1 = w-0.1*g
+        buf = g
+
+        sgd1.apply(w.name, w, g)
+        sgd1.step()
+
+        assertTensorEqual(w,w_step1)
+
+        buf = g + buf*0.9
+        w_step2 = w-0.1*buf
+
+        sgd1.apply(w.name, w, g)
+
+        assertTensorEqual(w, w_step2)
+
+    @on_cpu_gpu
+    def test_sgd_const_lr_momentum_weight_decay(self, dev=cpu_dev):
+        sgd1 = opt.SGD(lr=0.1, weight_decay=0.2)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.01)
+
+        w_step1 = w-0.1*(g+0.2*w)
+
+        sgd1.apply(w.name, w, g)
+
+        assertTensorEqual(w,w_step1)
+
+    # @on_cpu_gpu
+    def test_sgd_const_lr_momentum_nesterov(self, dev=cpu_dev):
+        sgd1 = opt.SGD(lr=0.1, momentum=0.9, nesterov=True)
+        w_shape=(2,3)
+        w = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+        g = tensor.Tensor(w_shape, device=dev).set_value(0.1)
+
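+        # Nesterov reference (sketch): buf := momentum * buf + grad and the
+        # update applies grad + momentum * buf, so the first step is w - lr * 1.9 * g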
+        buf = g
+        w_step1 = w-0.1*(g+0.9*buf)
+
+        sgd1.apply(w.name, w, g)
+
+        assertTensorEqual(w,w_step1)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/python/test_optimizer.py b/test/python/test_optimizer.py
deleted file mode 100644
index f559380..0000000
--- a/test/python/test_optimizer.py
+++ /dev/null
@@ -1,382 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# =============================================================================
-from __future__ import division
-from builtins import zip
-from builtins import range
-
-import unittest
-import math
-import numpy as np
-
-import singa.tensor as tensor
-import singa.optimizer as opt
-from singa import singa_wrap
-
-from cuda_helper import gpu_dev as cuda
-
-
-def np_adam(plist, glist, mlist, vlist, lr, t, b1=0.9, b2=0.999):
-    for p, g, m, v in zip(plist, glist, mlist, vlist):
-        m *= b1
-        m += (1 - b1) * g
-        v *= b2
-        v += (1 - b2) * g * g
-        alpha = lr * math.sqrt(1. - math.pow(b2, t)) / (1. - math.pow(b1, t))
-        p -= alpha * m / (np.sqrt(v) + 1e-8)
-
-
-def np_rmsprop(plist, glist, vlist, lr, t, rho=0.9):
-    for p, g, v in zip(plist, glist, vlist):
-        v *= rho
-        v += (1 - rho) * g * g
-        p -= lr * g / (np.sqrt(v + 1e-8))
-
-
-def np_momentum(plist, glist, vlist, lr, t, momentum=0.9):
-    for p, g, v in zip(plist, glist, vlist):
-        v *= momentum
-        v += lr * g
-        p -= v
-
-
-def np_adagrad(plist, glist, vlist, lr, t):
-    for p, g, v in zip(plist, glist, vlist):
-        v += g * g
-        p -= lr * g / (np.sqrt(v + 1e-8))
-
-
-class TestOptimizer(unittest.TestCase):
-
-    def setUp(self):
-        self.np_W = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
-        self.W = tensor.from_numpy(self.np_W)
-        self.np_g = np.array([0.1, 0.3, 0.1, 0.2], dtype=np.float32)
-        self.g = tensor.from_numpy(self.np_g)
-
-    def to_cuda(self):
-        self.W.to_device(cuda)
-        self.g.to_device(cuda)
-
-    def test_sgd(self):
-        lr = 0.1
-        sgd = opt.SGD(lr)
-        sgd.apply(0, self.g, self.W, 'w')
-        w = tensor.to_numpy(self.W)
-        for i in range(self.W.size()):
-            self.assertAlmostEqual(w[i], self.np_W[i] - lr * self.np_g[i])
-
-    def test_adam(self):
-        lr = 0.1
-        n, m = 4, 6
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        m1 = np.zeros((n, m))
-        m2 = np.zeros((n, m))
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 10):
-            np_adam([p1, p2], [g1, g2], [m1, m2], [v1, v2], lr, t)
-
-        adam = opt.Adam(lr=lr)
-        for t in range(1, 10):
-            adam.apply(0, tg1, t1, 'p1', t)
-            adam.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 6)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_sgd_cuda(self):
-        lr = 0.1
-        sgd = opt.SGD(lr)
-        self.to_cuda()
-        sgd.apply(0, self.g, self.W, 'w')
-        self.W.to_host()
-        w = tensor.to_numpy(self.W)
-        for i in range(self.W.size()):
-            self.assertAlmostEqual(w[i], self.np_W[i] - lr * self.np_g[i])
-
-    def test_constraint(self):
-        threshold = 0.02
-        cons = opt.L2Constraint(threshold)
-        cons.apply(0, self.W, self.g)
-        g = tensor.to_numpy(self.g)
-        nrm = np.linalg.norm(self.np_g) / self.np_g.size
-        for i in range(g.size):
-            self.assertAlmostEqual(g[i], self.np_g[i] * threshold / nrm)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_constraint_cuda(self):
-        threshold = 0.02
-        self.to_cuda()
-        cons = opt.L2Constraint(threshold)
-        cons.apply(0, self.W, self.g)
-        self.g.to_host()
-        g = tensor.to_numpy(self.g)
-        nrm = np.linalg.norm(self.np_g) / self.np_g.size
-        for i in range(g.size):
-            self.assertAlmostEqual(g[i], self.np_g[i] * threshold / nrm)
-
-    def test_regularizer(self):
-        coefficient = 0.0001
-        reg = opt.L2Regularizer(coefficient)
-        reg.apply(0, self.W, self.g)
-        g = tensor.to_numpy(self.g)
-        for i in range(g.size):
-            self.assertAlmostEqual(g[i],
-                                   self.np_g[i] + coefficient * self.np_W[i])
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_regularizer_cuda(self):
-        coefficient = 0.0001
-        reg = opt.L2Regularizer(coefficient)
-        self.to_cuda()
-        reg.apply(0, self.W, self.g)
-        self.g.to_host()
-        g = tensor.to_numpy(self.g)
-        for i in range(g.size):
-            self.assertAlmostEqual(g[i],
-                                   self.np_g[i] + coefficient * self.np_W[i])
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_adam_cuda(self):
-        lr = 0.1
-        n, m = 4, 6
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        m1 = np.zeros((n, m))
-        m2 = np.zeros((n, m))
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 10):
-            np_adam([p1, p2], [g1, g2], [m1, m2], [v1, v2], lr, t)
-
-        adam = opt.Adam(lr=lr)
-        self.to_cuda()
-        for t in range(1, 10):
-            adam.apply(0, tg1, t1, 'p1', t)
-            adam.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 6)
-
-    def test_rmsprop(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_rmsprop([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        rsmprop = opt.RMSProp(lr=lr)
-        for t in range(1, 4):
-            rsmprop.apply(0, tg1, t1, 'p1', t)
-            rsmprop.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_rmsprop_cuda(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_rmsprop([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        rsmprop = opt.RMSProp(lr=lr)
-        self.to_cuda()
-        for t in range(1, 4):
-            rsmprop.apply(0, tg1, t1, 'p1', t)
-            rsmprop.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-    def test_momentum(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_momentum([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        momentum = opt.SGD(lr, momentum=0.9)
-        for t in range(1, 4):
-            momentum.apply(0, tg1, t1, 'p1', t)
-            momentum.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_momentum_cuda(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_momentum([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        momentum = opt.SGD(lr, momentum=0.9)
-        self.to_cuda()
-        for t in range(1, 4):
-            momentum.apply(0, tg1, t1, 'p1', t)
-            momentum.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-    def test_adagrad(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_adagrad([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        adagrad = opt.AdaGrad(lr=lr)
-        for t in range(1, 4):
-            adagrad.apply(0, tg1, t1, 'p1', t)
-            adagrad.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
-    def test_adagrad_cuda(self):
-        lr = 0.1
-        n, m = 2, 2
-        p1 = np.random.rand(n, m)
-        p2 = np.random.rand(n, m)
-        g1 = np.random.rand(n, m) * 0.01
-        g2 = np.random.rand(n, m) * 0.01
-        v1 = np.zeros((n, m))
-        v2 = np.zeros((n, m))
-        t1 = tensor.from_numpy(p1)
-        t2 = tensor.from_numpy(p2)
-        tg1 = tensor.from_numpy(g1)
-        tg2 = tensor.from_numpy(g2)
-
-        for t in range(1, 4):
-            np_adagrad([p1, p2], [g1, g2], [v1, v2], lr, t)
-
-        adagrad = opt.AdaGrad(lr=lr)
-        self.to_cuda()
-        for t in range(1, 4):
-            adagrad.apply(0, tg1, t1, 'p1', t)
-            adagrad.apply(0, tg2, t2, 'p2', t)
-
-        t1 = tensor.to_numpy(t1)
-        t2 = tensor.to_numpy(t2)
-        for t, p in zip([t1, t2], [p1, p2]):
-            for i in range(n):
-                for j in range(m):
-                    self.assertAlmostEqual(t[i, j], p[i, j], 2)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index a3e4b2c..82d6d5c 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -24,7 +24,6 @@
 from singa import tensor
 from singa import singa_wrap as singa_api
 from singa import autograd
-from singa.proto import core_pb2
 
 from cuda_helper import gpu_dev, cpu_dev
 
@@ -47,7 +46,7 @@
         self.assertEqual(tensor.product(shape), 2 * 3)
         self.assertEqual(t.ndim(), 2)
         self.assertEqual(t.size(), 2 * 3)
-        self.assertEqual(t.memsize(), 2 * 3 * tensor.sizeof(core_pb2.kFloat32))
+        self.assertEqual(t.memsize(), 2 * 3 * tensor.sizeof(tensor.float32))
         self.assertFalse(t.is_transpose())
 
     def test_unary_operators(self):
@@ -90,6 +89,8 @@
         self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
         a = t >= 3.45
         self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
+        a = t == 3.45
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
         a = tensor.lt(t, 3.45)
         self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
         a = tensor.le(t, 3.45)
@@ -98,6 +99,8 @@
         self.assertEqual(tensor.to_numpy(a)[0, 0], 0)
         a = tensor.ge(t, 3.45)
         self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
+        a = tensor.eq(t, 3.45)
+        self.assertEqual(tensor.to_numpy(a)[0, 0], 1)
 
     def test_tensor_copy(self):
         t = tensor.Tensor((2, 3))
@@ -158,6 +161,32 @@
         y = 2 / x
         self.assertEqual(tensor.average(y), 2.)
 
+    def matmul_high_dim_helper(self, dev):
+        configs = [
+            [(1, 12, 7, 64), (1, 12, 64, 7)],
+            [(1, 7, 768), (768, 768)],
+        ]
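+        # expected results: (1, 12, 7, 64) @ (1, 12, 64, 7) -> (1, 12, 7, 7), and
+        # (1, 7, 768) @ (768, 768) -> (1, 7, 768) with the 2-D weight broadcast
+        # over the batch dimension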
+        print()
+        for config in configs:
+            X = np.random.random(config[0]).astype(np.float32)
+            x = tensor.from_numpy(X)
+            x.to_device(dev)
+
+            W = np.random.random(config[1]).astype(np.float32)
+            w = tensor.from_numpy(W)
+            w.to_device(dev)
+
+            y_t = np.matmul(X, W)
+            y = autograd.matmul(x, w)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(y), y_t, 3)
+
+    def test_matmul_high_dim_cpu(self):
+        self.matmul_high_dim_helper(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_matmul_high_dim_gpu(self):
+        self.matmul_high_dim_helper(gpu_dev)
+
     def test_tensor_inplace_api(self):
         """ tensor inplace methods alter internal state and also return self
         """
@@ -204,6 +233,27 @@
         np.testing.assert_array_almost_equal(TA1, A1)
         np.testing.assert_array_almost_equal(TA2, A2)
 
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_gpu_6d_transpose(self, dev=gpu_dev):
+        s0 = (2, 3, 4, 5, 6, 7)
+        axes1 = [5, 4, 3, 2, 1, 0]
+        s1 = (2, 7, 6, 5, 4, 3)
+        s2 = (2, 4, 3, 5, 7, 6)
+        a = np.random.random(s1)
+
+        ta = tensor.from_numpy(a)
+        ta.to_device(dev)
+
+        ta = tensor.reshape(ta, s1)
+        ta = tensor.transpose(ta, axes1)
+        ta = tensor.reshape(ta, s2)
+
+        a = np.reshape(a, s1)
+        a = np.transpose(a, axes1)
+        a = np.reshape(a, s2)
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(ta), a)
+
     def test_einsum(self):
 
         a = np.array(
@@ -467,6 +517,101 @@
     def test_matmul_transpose_gpu(self):
         self._matmul_transpose_helper(gpu_dev)
 
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_gaussian_gpu(self, dev=gpu_dev):
+        x = tensor.Tensor((3, 5, 3, 5), device=dev)
+        x.gaussian(0, 1)
+        x = tensor.Tensor((4, 5, 3, 2), device=dev)
+        x.gaussian(0, 1)
+
+    def _kfloat32_int(self, dev=gpu_dev):
+        np.random.seed(0)
+        x_val = np.random.random((2, 3)).astype(np.float32) * 10
+        x = tensor.from_numpy(x_val)
+        x.to_device(dev)
+        scalar = np.random.random((1,))[0] * 100
+        y = x + scalar
+        self.assertEqual(y.dtype, tensor.float32)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), x_val + scalar)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kfloat32_int_gpu(self):
+        self._kfloat32_int(gpu_dev)
+
+    def test_kfloat32_int_cpu(self):
+        self._kfloat32_int(cpu_dev)
+
+    def _kint_float(self, dev=gpu_dev):
+        np.random.seed(0)
+        x_val = np.random.randint(0, 10, (2, 3))
+        x = tensor.from_numpy(x_val)
+        x.to_device(dev)
+        scalar = np.random.random((1,))[0] * 100
+        y = x + scalar
+        self.assertEqual(y.dtype, tensor.float32)
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), x_val + scalar)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_float_gpu(self):
+        self._kint_float(gpu_dev)
+
+    def test_kint_float_cpu(self):
+        self._kint_float(cpu_dev)
+
+    def _kint_kint(self, dev=gpu_dev):
+        a_np = np.array([[[17, 4, 9, 22, 18], [-9, 9, -1, -1, 4],
+                          [1, 14, 7, 1, 4], [3, 14, -2, 3, -8]],
+                         [[-25, 6, 8, -7, 22], [-14, 0, -1, 15, 14],
+                          [1, 3, -8, -19, -3], [1, 12, 12, -3, -3]],
+                         [[-10, -14, -17, 19, -5], [-4, -12, 7, -16, -2],
+                          [-8, 3, -5, -11, 0], [4, 0, 3, -6, -3]]],
+                        dtype=np.int32)
+        b_np = np.array([[[-6, -3, -8, -17, 1], [-4, -16, 4, -9, 0],
+                          [7, 1, 11, -12, 4], [-6, -8, -5, -3, 0]],
+                         [[-11, 9, 4, -15, 14], [18, 11, -1, -10, 10],
+                          [-4, 12, 2, 9, 3], [7, 0, 17, 1, 4]],
+                         [[18, -13, -12, 9, -11], [19, -4, -7, 19, 14],
+                          [18, 9, -8, 19, -2], [8, 9, -1, 6, 9]]],
+                        dtype=np.int32)
+        ta = tensor.from_numpy(a_np)
+        tb = tensor.from_numpy(b_np)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        y = ta - tb
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), a_np - b_np)
+
+    def test_kint_kint_cpu(self, dev=cpu_dev):
+        self._kint_kint(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_kint_gpu(self, dev=gpu_dev):
+        self._kint_kint(gpu_dev)
+
+    def _kint_kint_bc(self, dev=gpu_dev):
+        a_np = np.array([[[17, 4, 9, 22, 18], [-9, 9, -1, -1, 4],
+                          [1, 14, 7, 1, 4], [3, 14, -2, 3, -8]],
+                         [[-25, 6, 8, -7, 22], [-14, 0, -1, 15, 14],
+                          [1, 3, -8, -19, -3], [1, 12, 12, -3, -3]],
+                         [[-10, -14, -17, 19, -5], [-4, -12, 7, -16, -2],
+                          [-8, 3, -5, -11, 0], [4, 0, 3, -6, -3]]],
+                        dtype=np.int32)
+        b_np = np.array([[-6, -3, -8, -17, 1], [-4, -16, 4, -9, 0],
+                         [7, 1, 11, -12, 4], [-6, -8, -5, -3, 0]],
+                        dtype=np.int32)
+        ta = tensor.from_numpy(a_np)
+        tb = tensor.from_numpy(b_np)
+        ta.to_device(dev)
+        tb.to_device(dev)
+        y = ta - tb
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y), a_np - b_np)
+
+    def test_kint_kint_bc_cpu(self, dev=cpu_dev):
+        self._kint_kint_bc(cpu_dev)
+
+    @unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
+    def test_kint_kint_bc_gpu(self, dev=gpu_dev):
+        self._kint_kint_bc(gpu_dev)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/singa/test_cpp_cpu.cc b/test/singa/test_cpp_cpu.cc
index 19be14f..b8efe37 100644
--- a/test/singa/test_cpp_cpu.cc
+++ b/test/singa/test_cpp_cpu.cc
@@ -42,7 +42,7 @@
   CppCPU dev;
   Block* b = dev.NewBlock(4);
   int x = 1, y = 3, z = 0;
-  dev.Exec([x, y, &z](singa::Context* ctx) { z = x + y; }, {b}, {b}, false);
+  dev.Exec([x, y, &z](singa::Context* ctx) { z = x + y; }, {b}, {b});
   EXPECT_EQ(x + y, z);
   dev.FreeBlock(b);
 }
@@ -58,7 +58,7 @@
   EXPECT_EQ('x', bstr[3]);
 
   Block* c = dev.NewBlock(4);
-  dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0, 0);
+  dev.CopyDataToFrom(c, b, 4, singa::kHostToHost, 0, 0, dev.context(0));
   const char* cstr = static_cast<const char*>(c->data());
 
   EXPECT_EQ('a', cstr[0]);
diff --git a/test/singa/test_operation_rnn.cc b/test/singa/test_operation_rnn.cc
new file mode 100644
index 0000000..bf52975
--- /dev/null
+++ b/test/singa/test_operation_rnn.cc
@@ -0,0 +1,141 @@
+/************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+#include "../src/model/operation/rnn.h"
+#include "gtest/gtest.h"
+#include "singa/core/tensor.h"
+#include "singa/singa_config.h"
+
+using namespace singa;
+
+#ifdef USE_CUDNN
+TEST(OperationRNN, training) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+
+  size_t hidden_size = 7;
+  int seq_length = 5;
+  size_t batch_size = 6;
+  size_t feature_size = 3;
+  size_t num_layers = 1;
+  int bdirect = 0;
+
+  Shape s_s{num_layers * (bdirect ? 2 : 1), batch_size, hidden_size};
+  Shape y_s{seq_length, batch_size, hidden_size * (bdirect ? 2 : 1)};
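+  // (sketch) s_s is the hidden/cell state shape (layers * directions, batch, hidden)
+  // and y_s is the output shape (seq_length, batch, hidden * directions)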
+
+  // x
+  Tensor x(Shape{seq_length, batch_size, feature_size}, cuda);
+  Gaussian(0.0f, 1.0f, &x);
+
+  // x hidden states and cell states
+  Tensor hx(s_s, cuda);
+  Tensor cx(s_s, cuda);
+  hx.SetValue(0.0f);
+  cx.SetValue(0.0f);
+
+  // y dy
+  Tensor y(y_s, cuda);
+  Tensor dy(y_s, cuda);
+  Gaussian(0.0f, 1.0f, &y);
+  Gaussian(0.0f, 1.0f, &dy);
+
+  // y hidden states and cell states
+  Tensor dhy(s_s, cuda);
+  Tensor dcy(s_s, cuda);
+  Gaussian(0.0f, 1.0f, &dhy);
+  Gaussian(0.0f, 1.0f, &dcy);
+
+  // init handle and weights
+  CudnnRNNHandle rnn_handle(x, hidden_size);
+  Tensor W(Shape{rnn_handle.weights_size}, cuda);
+  Gaussian(0.0f, 1.0f, &W);
+
+  // forward and backward passes
+  auto outputs = GpuRNNForwardTraining(x, hx, cx, W, rnn_handle);
+  auto outputs2 = GpuRNNForwardInference(x, hx, cx, W, rnn_handle);
+  auto output3 = GpuRNNBackwardx(y, dy, dhy, dcy, W, hx, cx, rnn_handle);
+  auto dW = GpuRNNBackwardW(x, hx, y, rnn_handle);
+}
+
+TEST(OperationRNNEx, training) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+
+  size_t hidden_size = 2;
+  size_t seq_length = 6;
+  size_t batch_size = 6;
+  size_t feature_size = 4;
+  int bdirect = 0;  // 0 or 1
+  size_t num_layers = 1;
+
+  Shape s_s{num_layers * (bdirect ? 2 : 1), batch_size, hidden_size};
+  Shape y_s{seq_length, batch_size, hidden_size * (bdirect ? 2 : 1)};
+  Shape x_s{seq_length, batch_size, feature_size};
+
+  // x
+  Tensor x(x_s, cuda);
+  Gaussian(0.0f, 1.0f, &x);
+
+  // x hidden states and cell states
+  Tensor hx(s_s, cuda);
+  Tensor cx(s_s, cuda);
+  hx.SetValue(0.0f);
+  cx.SetValue(0.0f);
+
+  // y hidden states and cell states
+  Tensor dhy(s_s, cuda);
+  Tensor dcy(s_s, cuda);
+  Gaussian(0.0f, 1.0f, &dhy);
+  Gaussian(0.0f, 1.0f, &dcy);
+
+  // y dy
+  Tensor y(y_s, cuda);
+  Tensor dy(y_s, cuda);
+  Gaussian(0.0f, 1.0f, &y);
+  Gaussian(0.0f, 1.0f, &dy);
+
+  // seq lengths
+  Tensor seq_lengths(
+      Shape{
+          batch_size,
+      },
+      cuda, singa::kInt);
+  vector<int> data(batch_size, seq_length);
+  seq_lengths.CopyDataFromHostPtr(data.data(), batch_size);
+
+  // init handle and weights
+  CudnnRNNHandle rnn_handle(x, hidden_size, 0);
+  Tensor W(Shape{rnn_handle.weights_size}, cuda);
+  Gaussian(0.0f, 1.0f, &W);
+
+  // forward and backward passes for batch first format
+  /* TODO: WARNING: Logging before InitGoogleLogging() is written to STDERR
+    F0619 07:11:43.435175  1094 rnn.cc:658] Check failed: status ==
+    CUDNN_STATUS_SUCCESS (8 vs. 0)  CUDNN_STATUS_EXECUTION_FAILED
+    *** Check failure stack trace: ***
+    Aborted (core dumped)
+    */
+  auto outputs = GpuRNNForwardTrainingEx(x, hx, cx, W, seq_lengths, rnn_handle);
+  auto outputs2 =
+      GpuRNNForwardInferenceEx(x, hx, cx, W, seq_lengths, rnn_handle);
+  auto outputs3 =
+      GpuRNNBackwardxEx(y, dy, dhy, dcy, W, hx, cx, seq_lengths, rnn_handle);
+  auto dW = GpuRNNBackwardWEx(x, hx, y, seq_lengths, rnn_handle);
+}
+
+#endif  // USE_CUDNN
diff --git a/test/singa/test_scheduler.cc b/test/singa/test_scheduler.cc
index 9c95a05..c94f8f7 100644
--- a/test/singa/test_scheduler.cc
+++ b/test/singa/test_scheduler.cc
@@ -32,6 +32,7 @@
 using singa::Blk2InfoMap;
 using singa::BlkInfo;
 using singa::Block;
+using singa::BlockSet;
 using singa::BlockType;
 using singa::BlockVec;
 using singa::Context;
@@ -100,13 +101,13 @@
     }                                                                         \
   } while (false)
 
-#define CheckWriteBlocks(write_blocks, correct_write_blocks)     \
-  do {                                                           \
-    EXPECT_EQ(correct_write_blocks.size(), write_blocks.size()); \
-    for (size_t i = 0; i < write_blocks.size(); ++i) {           \
-      EXPECT_EQ(correct_write_blocks[i], write_blocks[i])        \
-          << "write_blocks is wrong at index [" << i << "]";     \
-    }                                                            \
+#define CheckLeafBlocks(leaf_blocks, correct_leaf_blocks)                   \
+  do {                                                                      \
+    EXPECT_EQ(correct_leaf_blocks.size(), leaf_blocks.size());              \
+    for (auto it : leaf_blocks) {                                           \
+      auto iter = correct_leaf_blocks.find(it);                             \
+      EXPECT_NE(iter, correct_leaf_blocks.end()) << "leaf blocks mismatch"; \
+    }                                                                       \
   } while (false)
 
 #define CheckFreeBlocks(node_id, blocks, free_blocks, correct_free_blocks)   \
@@ -157,7 +158,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor out(Shape{1}, dev);
@@ -168,7 +169,7 @@
     EXPECT_EQ(1u, nodes.size());
     EXPECT_EQ(2u, edges.size());
     EXPECT_EQ(2u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto node = nodes[0];
     auto edge1 = edges[0];
@@ -182,7 +183,7 @@
     CheckBlock(block1, 0, in.block(), BlockType::kInput, 1, nullptr,
                NodeVec({}));
     CheckBlock(block2, 1, out.block(), BlockType::kEnd, 1, edge2, NodeVec({}));
-    CheckWriteBlocks(write_blocks, BlockVec({out.block()}));
+    CheckLeafBlocks(leaf_blocks, BlockSet({out.block()}));
     EXPECT_TRUE(graph.dirty());
   }
 }
@@ -197,7 +198,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor out(Shape{1}, dev);
@@ -209,7 +210,7 @@
     EXPECT_EQ(2u, nodes.size());
     EXPECT_EQ(3u, edges.size());
     EXPECT_EQ(2u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto node1 = nodes[0];
     auto node2 = nodes[1];
@@ -228,7 +229,7 @@
                NodeVec({}));
     CheckBlock(block2, 1, out.block(), BlockType::kInter, 1, edge3,
                NodeVec({}));
-    CheckWriteBlocks(write_blocks, BlockVec({out.block()}));
+    CheckLeafBlocks(leaf_blocks, BlockSet({out.block()}));
     EXPECT_TRUE(graph.dirty());
   }
 }
@@ -243,7 +244,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor out(Shape{1}, dev);
@@ -254,7 +255,7 @@
     EXPECT_EQ(1u, nodes.size());
     EXPECT_EQ(2u, edges.size());
     EXPECT_EQ(1u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto node1 = nodes[0];
     auto edge1 = edges[0];
@@ -265,7 +266,7 @@
     CheckEdge(edge1, 0, in.block(), nullptr, node1);
     CheckEdge(edge2, 1, in.block(), node1, nullptr);
     CheckBlock(block1, 0, in.block(), BlockType::kParam, 2, edge2, NodeVec({}));
-    CheckWriteBlocks(write_blocks, BlockVec({in.block()}));
+    CheckLeafBlocks(leaf_blocks, BlockSet{in.block()});
     EXPECT_TRUE(graph.dirty());
 
     graph.AddOperation(op, {in.block(), out.block()}, {out.block()});
@@ -273,7 +274,7 @@
     EXPECT_EQ(2u, nodes.size());
     EXPECT_EQ(4u, edges.size());
     EXPECT_EQ(2u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto node2 = nodes[1];
     auto edge3 = edges[2];
@@ -288,7 +289,7 @@
     CheckBlock(block1, 0, in.block(), BlockType::kParam, 3, edge2, NodeVec({}));
     CheckBlock(block2, 1, out.block(), BlockType::kParam, 2, edge4,
                NodeVec({}));
-    CheckWriteBlocks(write_blocks, BlockVec({out.block()}));
+    CheckLeafBlocks(leaf_blocks, BlockSet({out.block()}));
     EXPECT_TRUE(graph.dirty());
   }
 }
@@ -303,7 +304,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor out(Shape{1}, dev);
@@ -314,7 +315,7 @@
     EXPECT_EQ(1u, nodes.size());
     EXPECT_EQ(2u, edges.size());
     EXPECT_EQ(2u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto block1 = blocks.find(in.block())->second;
 
@@ -333,7 +334,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor mid(Shape{1}, dev);
@@ -347,7 +348,7 @@
     EXPECT_EQ(3u, nodes.size());
     EXPECT_EQ(5u, edges.size());
     EXPECT_EQ(3u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto edge2 = edges[1];
     auto edge5 = edges[4];
@@ -370,7 +371,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor mid(Shape{1}, dev);
@@ -384,7 +385,7 @@
     EXPECT_EQ(3u, nodes.size());
     EXPECT_EQ(4u, edges.size());
     EXPECT_EQ(3u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(1u, leaf_blocks.size());
 
     auto edge2 = edges[1];
     auto edge4 = edges[3];
@@ -408,7 +409,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor out1(Shape{1}, dev);
@@ -421,7 +422,7 @@
     EXPECT_EQ(2u, nodes.size());
     EXPECT_EQ(3u, edges.size());
     EXPECT_EQ(3u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(2u, leaf_blocks.size());
 
     auto edge2 = edges[1];
     auto edge3 = edges[2];
@@ -443,7 +444,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor mid(Shape{1}, dev);
@@ -493,7 +494,7 @@
     EXPECT_EQ(7u, nodes.size());
     EXPECT_EQ(14u, edges.size());
     EXPECT_EQ(12u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(3u, leaf_blocks.size());
 
     in.SetValue(0);
     b1.SetValue(-1);
@@ -511,6 +512,47 @@
   }
 }
 
+TEST_F(TestGraph, MultipleIndependentOps) {
+  for (auto &it : devices) {
+    GOUT << "Test graph on device [" << it.first << "]" << std::endl;
+
+    auto dev = it.second;
+    Graph graph(dev.get());
+
+    auto &nodes = graph.nodes();
+
+    Tensor workspace(Shape{1}, dev);
+    Tensor b1(Shape{1}, dev);
+    Tensor b2(Shape{1}, dev);
+    Tensor b3(Shape{1}, dev);
+    Tensor b4(Shape{1}, dev);
+
+    // emulate cleaning up the workspace, following the RNN design as a reference
+    auto clean1 = [workspace](Context *ctx) mutable {};
+    auto clean2 = [workspace](Context *ctx) mutable {};
+    auto clean3 = [workspace](Context *ctx) mutable {};
+    auto clean4 = [workspace](Context *ctx) mutable {};
+
+    // emulate using the workspace, following the RNN design as a reference
+    auto op1 = [workspace, b1](Context *ctx) mutable {};
+    auto op2 = [workspace, b2](Context *ctx) mutable {};
+    auto op3 = [workspace, b2](Context *ctx) mutable {};
+    auto op4 = [workspace, b2](Context *ctx) mutable {};
+
+    graph.AddOperation(clean1, {}, {workspace.block()});
+    graph.AddOperation(op1, {b1.block()}, {workspace.block(), b1.block()});
+    graph.AddOperation(clean2, {}, {workspace.block()});
+    graph.AddOperation(op2, {b2.block()}, {workspace.block(), b2.block()});
+    graph.AddOperation(clean3, {}, {workspace.block()});
+    graph.AddOperation(op3, {b3.block()}, {workspace.block(), b3.block()});
+    graph.AddOperation(clean4, {}, {workspace.block()});
+    graph.AddOperation(op4, {b4.block()}, {workspace.block(), b4.block()});
+
+    EXPECT_EQ(8u, nodes.size());
+    graph.RunGraph();
+  }
+}
+
 TEST_F(TestGraph, RunInSerial) {
   for (auto &it : devices) {
     GOUT << "Test graph on device [" << it.first << "]" << std::endl;
@@ -521,7 +563,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     Tensor in(Shape{1}, dev);
     Tensor mid(Shape{1}, dev);
@@ -570,7 +612,7 @@
     EXPECT_EQ(7u, nodes.size());
     EXPECT_EQ(14u, edges.size());
     EXPECT_EQ(12u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(3u, leaf_blocks.size());
 
     in.SetValue(0);
     b1.SetValue(-1);
@@ -598,7 +640,7 @@
     auto &nodes = graph.nodes();
     auto &edges = graph.edges();
     auto &blocks = graph.blocks();
-    auto &write_blocks = graph.write_blocks();
+    auto &leaf_blocks = graph.leaf_blocks();
 
     {
       Tensor in(Shape{1}, dev);
@@ -671,7 +713,7 @@
     EXPECT_EQ(10u, nodes.size());
     EXPECT_EQ(21u, edges.size());
     EXPECT_EQ(15u, blocks.size());
-    EXPECT_EQ(1u, write_blocks.size());
+    EXPECT_EQ(3u, leaf_blocks.size());
 
     graph.RunGraph();
 
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
index f0eafb6..a980f22 100644
--- a/test/singa/test_tensor_math.cc
+++ b/test/singa/test_tensor_math.cc
@@ -16,6 +16,8 @@
  * limitations under the License.
  */
 
+#include <array>
+
 #include "gtest/gtest.h"
 #include "singa/core/tensor.h"
 using singa::Device;
@@ -270,6 +272,14 @@
   EXPECT_FLOAT_EQ(1.0f, dptr1[2]);
 }
 
+TEST_F(TensorMath, EQCpp) {
+  Tensor p1 = a == 2.0f;
+  const float *dptr1 = p1.data<float>();
+  EXPECT_FLOAT_EQ(0.0f, dptr1[0]);
+  EXPECT_FLOAT_EQ(1.0f, dptr1[1]);
+  EXPECT_FLOAT_EQ(0.0f, dptr1[2]);
+}
+
 TEST_F(TensorMath, PowCpp) {
   Tensor p1 = Pow(b, 3.0f);
   const float *dptr1 = p1.data<float>();
diff --git a/tool/conda/dist/meta.yaml b/tool/conda/dist/meta.yaml
index f59dea7..97cc0b3 100644
--- a/tool/conda/dist/meta.yaml
+++ b/tool/conda/dist/meta.yaml
@@ -21,9 +21,13 @@
   name: singa-dist
   version: {{ environ.get('GIT_DESCRIBE_TAG') }}
 
+source:
+  path: ../../../
+  # git_url: https://github.com/apache/singa.git
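+  # the recipe sits in tool/conda/dist, so ../../../ is the repository root;
+  # the package is built from the local checkout instead of a fresh git clone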
+
 requirements:
   run:
-    - singa {{ environ.get('GIT_DESCRIBE_TAG') }} cudnn7.3.1_cuda10.0_nccl2.4.8.1_mpich3.3.2_py{{ py }}
+    - singa {{ environ.get('GIT_DESCRIBE_TAG') }} cudnn7.6.5_cuda10.0_nccl2.4.8.1_mpich3.3.2_py{{ py }}
 
 build:
   number: 0
@@ -33,4 +37,4 @@
 about:
   home: http://singa.apache.org/
   license: Apache V2
-  summary: SINGA is an Apache Incubating project for providing distributed deep learning. Apache disclaimers http://singa.apache.org/en/index.html#disclaimers
\ No newline at end of file
+  summary: SINGA is an Apache Incubating project for providing distributed deep learning. Apache disclaimers http://singa.apache.org/en/index.html#disclaimers
diff --git a/tool/conda/docker/cuda10.2/Dockerfile b/tool/conda/docker/cuda10.2/Dockerfile
new file mode 100644
index 0000000..7526b2f
--- /dev/null
+++ b/tool/conda/docker/cuda10.2/Dockerfile
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Ubuntu 18.04 has errors with ssh, so stay on 16.04
+FROM nvidia/cuda:10.2-devel-ubuntu16.04
+
+# install dependencies
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends \
+        git \
+        build-essential \
+        cmake \
+        wget \
+        openssh-server \
+        ca-certificates \
+    && apt-get clean \
+    && apt-get autoremove \
+    && apt-get autoclean \
+    && rm -rf /var/lib/apt/lists/* \
+    #
+    # install conda, conda-build and anaconda-client
+    #
+    && wget --no-check-certificate https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh \
+    && bash miniconda.sh -b -p /root/miniconda \
+    && /root/miniconda/bin/conda config --set always_yes yes --set changeps1 no \
+    && /root/miniconda/bin/conda update -q conda \
+    && /root/miniconda/bin/conda install -y \
+        conda-build \
+        anaconda-client \
+    && /root/miniconda/bin/conda clean -tipsy \
+    # config ssh service
+    && mkdir /var/run/sshd \
+    && echo 'root:singa' | chpasswd \
+    # Ubuntu 16.04 defaults to "PermitRootLogin prohibit-password", so allow root login
+    && sed -i 's/PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config \
+    # SSH login fix; otherwise the user is kicked off after login
+    && sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd \
+    # dump environment variables into /etc/environment so that ssh sessions can also see them
+    && env | grep _ >> /etc/environment
+
+# Add conda to PATH. Doing this here so other RUN steps can be grouped above
+ENV PATH /root/miniconda/bin:${PATH}
+
+# In nvidia/cuda:10.2-devel-ubuntu16.04, the cuBLAS headers were moved to /usr/include, so copy them back into the CUDA include directory
+RUN cp /usr/include/cublas* /usr/local/cuda/include/
+
+EXPOSE 22
+
+CMD ["/usr/sbin/sshd", "-D"]
diff --git a/tool/conda/gpu/meta.yaml b/tool/conda/gpu/meta.yaml
index c247372..58ef499 100644
--- a/tool/conda/gpu/meta.yaml
+++ b/tool/conda/gpu/meta.yaml
@@ -26,7 +26,7 @@
 
 requirements:
   run:
-    - singa {{ environ.get('GIT_DESCRIBE_TAG') | replace("-", ".") }} cudnn7.3.1_cuda10.0_py{{ py }}
+    - singa {{ environ.get('GIT_DESCRIBE_TAG') | replace("-", ".") }} cudnn7.6.5_cuda10.0_py{{ py }}
 
 build:
   number: 0
diff --git a/tool/conda/singa/conda_build_config.yaml b/tool/conda/singa/conda_build_config.yaml
index 4652e9d..9ac45c0 100644
--- a/tool/conda/singa/conda_build_config.yaml
+++ b/tool/conda/singa/conda_build_config.yaml
@@ -23,22 +23,27 @@
     - 5.4                   # [linux]
 # https://docs.conda.io/projects/conda-build/en/latest/resources/compiler-tools.html#macos-sdk
 CONDA_BUILD_SYSROOT:
-    - "/tmp/MacOSX10.9.sdk" # [osx]
+    - "/Applications/Xcode_11.7.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk" # [osx]
 cudnn:                      # [linux]
-    - "7.3.1 cuda10.0_0"    # [environ.get("CUDA")=="10.0"]
-    - "7.3.1 cuda9.0_0"     # [environ.get("CUDA")=="9.0"]
+    - "7.6.5 cuda10.2_0"    # [environ.get("CUDA")=="10.2"]
+    - "7.6.5 cuda10.0_0"    # [environ.get("CUDA")=="10.0"]
+    - "7.6.5 cuda9.0_0"     # [environ.get("CUDA")=="9.0"]
 dnnl:
     - 1.1
 python:
-    - 3.6
+#    - 3.6
     - 3.7
 nccl:
-    - 2.4.8.1
+    - 2.6.4.1               # [environ.get("CUDA")=="10.2"]
+    - 2.4.8.1               # [environ.get("CUDA")=="10.0"]
+    - 2.4.8.1               # [environ.get("CUDA")=="9.0"]
 mpich:
     - 3.3.2
 build_str:
-    - "cudnn7.3.1_cuda10.0"   # [environ.get("CUDA")=="10.0"] && [environ.get("DIST")=="OFF"]
-    - "cudnn7.3.1_cuda9.0"    # [environ.get("CUDA")=="9.0"] && [environ.get("DIST")=="OFF"]
-    - "cpu"              # [environ.get("CUDA", "")== ""]
-    - "cudnn7.3.1_cuda10.0_nccl2.4.8.1_mpich3.3.2"     # [environ.get("CUDA")=="10.0"] && [environ.get("DIST")=="ON"]
-    - "cudnn7.3.1_cuda9.0_nccl2.4.8.1_mpich3.3.2"     # [environ.get("CUDA")=="9.0"] && [environ.get("DIST")=="ON"]
+    - "cudnn7.6.5_cuda10.2"   # [environ.get("CUDA")=="10.2"] && [environ.get("DIST")=="OFF"]
+    - "cudnn7.6.5_cuda10.0"   # [environ.get("CUDA")=="10.0"] && [environ.get("DIST")=="OFF"]
+    - "cudnn7.6.5_cuda9.0"    # [environ.get("CUDA")=="9.0"] && [environ.get("DIST")=="OFF"]
+    - "cpu"                   # [environ.get("CUDA", "")== ""]
+    - "cudnn7.6.5_cuda10.2_nccl2.6.4.1_mpich3.3.2"     # [environ.get("CUDA")=="10.2"] && [environ.get("DIST")=="ON"]
+    - "cudnn7.6.5_cuda10.0_nccl2.4.8.1_mpich3.3.2"     # [environ.get("CUDA")=="10.0"] && [environ.get("DIST")=="ON"]
+    - "cudnn7.6.5_cuda9.0_nccl2.4.8.1_mpich3.3.2"      # [environ.get("CUDA")=="9.0"] && [environ.get("DIST")=="ON"]
diff --git a/tool/conda/singa/meta.yaml b/tool/conda/singa/meta.yaml
index e4d09d6..cface0d 100644
--- a/tool/conda/singa/meta.yaml
+++ b/tool/conda/singa/meta.yaml
@@ -50,7 +50,9 @@
     - protobuf 3.10.0         # [osx]
     - protobuf 3.9.2          # [linux]
     - glog 0.3.5
-    - numpy 1.16.5
+    - numpy >=1.16,<2.0
+    - pytest
+    - deprecated 1.2.7
     - cudnn {{ cudnn }}       # ['cudnn' in str(build_str)]
     - dnnl {{ dnnl }}
     - python {{ python }}
@@ -73,13 +75,16 @@
     - tqdm
     - onnx 1.6.0
     - deprecated 1.2.7
-
+    
 test:
+  requires:
+    - pytest-cov
+    - tabulate
+    - codecov
   source_files:
     - test/python/*.py
   commands:
-    - cd test/python && python run.py
-
+    - {{ environ.get('TEST_COMMAND', 'cd test/python && python run.py') }}
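+    # TEST_COMMAND lets the build environment override the test command; it
+    # falls back to the original python test runner when unset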
 about:
   home: http://singa.apache.org/
   license: Apache V2
diff --git a/tool/cpplint.py b/tool/cpplint.py
deleted file mode 100755
index 08a304d..0000000
--- a/tool/cpplint.py
+++ /dev/null
@@ -1,6327 +0,0 @@
-#!/usr/bin/env python
-#
-# Copyright (c) 2009 Google Inc. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-#    * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#    * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-#    * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Does google-lint on c++ files.
-
-The goal of this script is to identify places in the code that *may*
-be in non-compliance with google style.  It does not attempt to fix
-up these problems -- the point is to educate.  It does also not
-attempt to find all problems, or to ensure that everything it does
-find is legitimately a problem.
-
-In particular, we can get very confused by /* and // inside strings!
-We do a small hack, which is to ignore //'s with "'s after them on the
-same line, but it is far from perfect (in either direction).
-"""
-
-import codecs
-import copy
-import getopt
-import math  # for log
-import os
-import re
-import sre_compile
-import string
-import sys
-import unicodedata
-
-
-_USAGE = """
-Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...]
-                   [--counting=total|toplevel|detailed] [--root=subdir]
-                   [--linelength=digits]
-        <file> [file] ...
-
-  The style guidelines this tries to follow are those in
-    http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml
-
-  Every problem is given a confidence score from 1-5, with 5 meaning we are
-  certain of the problem, and 1 meaning it could be a legitimate construct.
-  This will miss some errors, and is not a substitute for a code review.
-
-  To suppress false-positive errors of a certain category, add a
-  'NOLINT(category)' comment to the line.  NOLINT or NOLINT(*)
-  suppresses errors of all categories on that line.
-
-  The files passed in will be linted; at least one file must be provided.
-  Default linted extensions are .cc, .cpp, .cu, .cuh and .h.  Change the
-  extensions with the --extensions flag.
-
-  Flags:
-
-    output=vs7
-      By default, the output is formatted to ease emacs parsing.  Visual Studio
-      compatible output (vs7) may also be used.  Other formats are unsupported.
-
-    verbose=#
-      Specify a number 0-5 to restrict errors to certain verbosity levels.
-
-    filter=-x,+y,...
-      Specify a comma-separated list of category-filters to apply: only
-      error messages whose category names pass the filters will be printed.
-      (Category names are printed with the message and look like
-      "[whitespace/indent]".)  Filters are evaluated left to right.
-      "-FOO" and "FOO" means "do not print categories that start with FOO".
-      "+FOO" means "do print categories that start with FOO".
-
-      Examples: --filter=-whitespace,+whitespace/braces
-                --filter=whitespace,runtime/printf,+runtime/printf_format
-                --filter=-,+build/include_what_you_use
-
-      To see a list of all the categories used in cpplint, pass no arg:
-         --filter=
-
-    counting=total|toplevel|detailed
-      The total number of errors found is always printed. If
-      'toplevel' is provided, then the count of errors in each of
-      the top-level categories like 'build' and 'whitespace' will
-      also be printed. If 'detailed' is provided, then a count
-      is provided for each category like 'build/class'.
-
-    root=subdir
-      The root directory used for deriving header guard CPP variable.
-      By default, the header guard CPP variable is calculated as the relative
-      path to the directory that contains .git, .hg, or .svn.  When this flag
-      is specified, the relative path is calculated from the specified
-      directory. If the specified directory does not exist, this flag is
-      ignored.
-
-      Examples:
-        Assuming that src/.git exists, the header guard CPP variables for
-        src/chrome/browser/ui/browser.h are:
-
-        No flag => CHROME_BROWSER_UI_BROWSER_H_
-        --root=chrome => BROWSER_UI_BROWSER_H_
-        --root=chrome/browser => UI_BROWSER_H_
-
-    linelength=digits
-      This is the allowed line length for the project. The default value is
-      80 characters.
-
-      Examples:
-        --linelength=120
-
-    extensions=extension,extension,...
-      The allowed file extensions that cpplint will check
-
-      Examples:
-        --extensions=hpp,cpp
-
-    cpplint.py supports per-directory configurations specified in CPPLINT.cfg
-    files. CPPLINT.cfg file can contain a number of key=value pairs.
-    Currently the following options are supported:
-
-      set noparent
-      filter=+filter1,-filter2,...
-      exclude_files=regex
-      linelength=80
-
-    "set noparent" option prevents cpplint from traversing directory tree
-    upwards looking for more .cfg files in parent directories. This option
-    is usually placed in the top-level project directory.
-
-    The "filter" option is similar in function to --filter flag. It specifies
-    message filters in addition to the |_DEFAULT_FILTERS| and those specified
-    through --filter command-line flag.
-
-    "exclude_files" allows to specify a regular expression to be matched against
-    a file name. If the expression matches, the file is skipped and not run
-    through liner.
-
-    "linelength" allows to specify the allowed line length for the project.
-
-    CPPLINT.cfg has an effect on files in the same directory and all
-    sub-directories, unless overridden by a nested configuration file.
-
-      Example file:
-        filter=-build/include_order,+build/include_alpha
-        exclude_files=.*\.cc
-
-    The above example disables build/include_order warning and enables
-    build/include_alpha as well as excludes all .cc from being
-    processed by linter, in the current directory (where the .cfg
-    file is located) and all sub-directories.
-"""
-
-# We categorize each error message we print.  Here are the categories.
-# We want an explicit list so we can list them all in cpplint --filter=.
-# If you add a new error message with a new category, add it to the list
-# here!  cpplint_unittest.py should tell you if you forget to do this.
-_ERROR_CATEGORIES = [
-    'build/class',
-    'build/c++11',
-    'build/deprecated',
-    'build/endif_comment',
-    'build/explicit_make_pair',
-    'build/forward_decl',
-    'build/header_guard',
-    'build/include',
-    'build/include_alpha',
-    'build/include_order',
-    'build/include_what_you_use',
-    'build/namespaces',
-    'build/printf_format',
-    'build/storage_class',
-    'legal/copyright',
-    'readability/alt_tokens',
-    'readability/braces',
-    'readability/casting',
-    'readability/check',
-    'readability/constructors',
-    'readability/fn_size',
-    'readability/function',
-    'readability/inheritance',
-    'readability/multiline_comment',
-    'readability/multiline_string',
-    'readability/namespace',
-    'readability/nolint',
-    'readability/nul',
-    'readability/strings',
-    'readability/todo',
-    'readability/utf8',
-    'runtime/arrays',
-    'runtime/casting',
-    'runtime/explicit',
-    'runtime/int',
-    'runtime/init',
-    'runtime/invalid_increment',
-    'runtime/member_string_references',
-    'runtime/memset',
-    'runtime/indentation_namespace',
-    'runtime/operator',
-    'runtime/printf',
-    'runtime/printf_format',
-    'runtime/references',
-    'runtime/string',
-    'runtime/threadsafe_fn',
-    'runtime/vlog',
-    'whitespace/blank_line',
-    'whitespace/braces',
-    'whitespace/comma',
-    'whitespace/comments',
-    'whitespace/empty_conditional_body',
-    'whitespace/empty_loop_body',
-    'whitespace/end_of_line',
-    'whitespace/ending_newline',
-    'whitespace/forcolon',
-    'whitespace/indent',
-    'whitespace/line_length',
-    'whitespace/newline',
-    'whitespace/operators',
-    'whitespace/parens',
-    'whitespace/semicolon',
-    'whitespace/tab',
-    'whitespace/todo',
-    ]
-
-# These error categories are no longer enforced by cpplint, but for backwards-
-# compatibility they may still appear in NOLINT comments.
-_LEGACY_ERROR_CATEGORIES = [
-    'readability/streams',
-    ]
-
-# The default state of the category filter. This is overridden by the --filter=
-# flag. By default all errors are on, so only add here categories that should be
-# off by default (i.e., categories that must be enabled by the --filter= flags).
-# All entries here should start with a '-' or '+', as in the --filter= flag.
-_DEFAULT_FILTERS = ['-build/include_alpha']
-
-# We used to check for high-bit characters, but after much discussion we
-# decided those were OK, as long as they were in UTF-8 and didn't represent
-# hard-coded international strings, which belong in a separate i18n file.
-
-# C++ headers
-_CPP_HEADERS = frozenset([
-    # Legacy
-    'algobase.h',
-    'algo.h',
-    'alloc.h',
-    'builtinbuf.h',
-    'bvector.h',
-    'complex.h',
-    'defalloc.h',
-    'deque.h',
-    'editbuf.h',
-    'fstream.h',
-    'function.h',
-    'hash_map',
-    'hash_map.h',
-    'hash_set',
-    'hash_set.h',
-    'hashtable.h',
-    'heap.h',
-    'indstream.h',
-    'iomanip.h',
-    'iostream.h',
-    'istream.h',
-    'iterator.h',
-    'list.h',
-    'map.h',
-    'multimap.h',
-    'multiset.h',
-    'ostream.h',
-    'pair.h',
-    'parsestream.h',
-    'pfstream.h',
-    'procbuf.h',
-    'pthread_alloc',
-    'pthread_alloc.h',
-    'rope',
-    'rope.h',
-    'ropeimpl.h',
-    'set.h',
-    'slist',
-    'slist.h',
-    'stack.h',
-    'stdiostream.h',
-    'stl_alloc.h',
-    'stl_relops.h',
-    'streambuf.h',
-    'stream.h',
-    'strfile.h',
-    'strstream.h',
-    'tempbuf.h',
-    'tree.h',
-    'type_traits.h',
-    'vector.h',
-    # 17.6.1.2 C++ library headers
-    'algorithm',
-    'array',
-    'atomic',
-    'bitset',
-    'chrono',
-    'codecvt',
-    'complex',
-    'condition_variable',
-    'deque',
-    'exception',
-    'forward_list',
-    'fstream',
-    'functional',
-    'future',
-    'initializer_list',
-    'iomanip',
-    'ios',
-    'iosfwd',
-    'iostream',
-    'istream',
-    'iterator',
-    'limits',
-    'list',
-    'locale',
-    'map',
-    'memory',
-    'mutex',
-    'new',
-    'numeric',
-    'ostream',
-    'queue',
-    'random',
-    'ratio',
-    'regex',
-    'set',
-    'sstream',
-    'stack',
-    'stdexcept',
-    'streambuf',
-    'string',
-    'strstream',
-    'system_error',
-    'thread',
-    'tuple',
-    'typeindex',
-    'typeinfo',
-    'type_traits',
-    'unordered_map',
-    'unordered_set',
-    'utility',
-    'valarray',
-    'vector',
-    # 17.6.1.2 C++ headers for C library facilities
-    'cassert',
-    'ccomplex',
-    'cctype',
-    'cerrno',
-    'cfenv',
-    'cfloat',
-    'cinttypes',
-    'ciso646',
-    'climits',
-    'clocale',
-    'cmath',
-    'csetjmp',
-    'csignal',
-    'cstdalign',
-    'cstdarg',
-    'cstdbool',
-    'cstddef',
-    'cstdint',
-    'cstdio',
-    'cstdlib',
-    'cstring',
-    'ctgmath',
-    'ctime',
-    'cuchar',
-    'cwchar',
-    'cwctype',
-    ])
-
-
-# These headers are excluded from [build/include] and [build/include_order]
-# checks:
-# - Anything not following google file name conventions (containing an
-#   uppercase character, such as Python.h or nsStringAPI.h, for example).
-# - Lua headers.
-_THIRD_PARTY_HEADERS_PATTERN = re.compile(
-    r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$')
-
-
-# Assertion macros.  These are defined in base/logging.h and
-# testing/base/gunit.h.  Note that the _M versions need to come first
-# for substring matching to work.
-_CHECK_MACROS = [
-    'DCHECK', 'CHECK',
-    'EXPECT_TRUE_M', 'EXPECT_TRUE',
-    'ASSERT_TRUE_M', 'ASSERT_TRUE',
-    'EXPECT_FALSE_M', 'EXPECT_FALSE',
-    'ASSERT_FALSE_M', 'ASSERT_FALSE',
-    ]
-
-# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE
-_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS])
-
-for op, replacement in [('==', 'EQ'), ('!=', 'NE'),
-                        ('>=', 'GE'), ('>', 'GT'),
-                        ('<=', 'LE'), ('<', 'LT')]:
-  _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement
-  _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement
-  _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement
-  _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement
-  _CHECK_REPLACEMENT['EXPECT_TRUE_M'][op] = 'EXPECT_%s_M' % replacement
-  _CHECK_REPLACEMENT['ASSERT_TRUE_M'][op] = 'ASSERT_%s_M' % replacement
-
-for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'),
-                            ('>=', 'LT'), ('>', 'LE'),
-                            ('<=', 'GT'), ('<', 'GE')]:
-  _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement
-  _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement
-  _CHECK_REPLACEMENT['EXPECT_FALSE_M'][op] = 'EXPECT_%s_M' % inv_replacement
-  _CHECK_REPLACEMENT['ASSERT_FALSE_M'][op] = 'ASSERT_%s_M' % inv_replacement
-
-# Alternative tokens and their replacements.  For full list, see section 2.5
-# Alternative tokens [lex.digraph] in the C++ standard.
-#
-# Digraphs (such as '%:') are not included here since it's a mess to
-# match those on a word boundary.
-_ALT_TOKEN_REPLACEMENT = {
-    'and': '&&',
-    'bitor': '|',
-    'or': '||',
-    'xor': '^',
-    'compl': '~',
-    'bitand': '&',
-    'and_eq': '&=',
-    'or_eq': '|=',
-    'xor_eq': '^=',
-    'not': '!',
-    'not_eq': '!='
-    }
-
-# Compile regular expression that matches all the above keywords.  The "[ =()]"
-# bit is meant to avoid matching these keywords outside of boolean expressions.
-#
-# False positives include C-style multi-line comments and multi-line strings
-# but those have always been troublesome for cpplint.
-_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile(
-    r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)')
-
-
-# These constants define types of headers for use with
-# _IncludeState.CheckNextIncludeOrder().
-_C_SYS_HEADER = 1
-_CPP_SYS_HEADER = 2
-_LIKELY_MY_HEADER = 3
-_POSSIBLE_MY_HEADER = 4
-_OTHER_HEADER = 5
-
-# These constants define the current inline assembly state
-_NO_ASM = 0       # Outside of inline assembly block
-_INSIDE_ASM = 1   # Inside inline assembly block
-_END_ASM = 2      # Last line of inline assembly block
-_BLOCK_ASM = 3    # The whole block is an inline assembly block
-
-# Match start of assembly blocks
-_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)'
-                        r'(?:\s+(volatile|__volatile__))?'
-                        r'\s*[{(]')
-
-
-_regexp_compile_cache = {}
-
-# {str, set(int)}: a map from error categories to sets of linenumbers
-# on which those errors are expected and should be suppressed.
-_error_suppressions = {}
-
-# The root directory used for deriving header guard CPP variable.
-# This is set by --root flag.
-_root = None
-
-# The allowed line length of files.
-# This is set by --linelength flag.
-_line_length = 80
-
-# The allowed extensions for file names
-# This is set by --extensions flag.
-_valid_extensions = set(['cc', 'h', 'cpp', 'cu', 'cuh'])
-
-def ParseNolintSuppressions(filename, raw_line, linenum, error):
-  """Updates the global list of error-suppressions.
-
-  Parses any NOLINT comments on the current line, updating the global
-  error_suppressions store.  Reports an error if the NOLINT comment
-  was malformed.
-
-  Args:
-    filename: str, the name of the input file.
-    raw_line: str, the line of input text, with comments.
-    linenum: int, the number of the current line.
-    error: function, an error handler.
-  """
-  matched = Search(r'\bNOLINT(NEXTLINE)?\b(\([^)]+\))?', raw_line)
-  if matched:
-    if matched.group(1):
-      suppressed_line = linenum + 1
-    else:
-      suppressed_line = linenum
-    category = matched.group(2)
-    if category in (None, '(*)'):  # => "suppress all"
-      _error_suppressions.setdefault(None, set()).add(suppressed_line)
-    else:
-      if category.startswith('(') and category.endswith(')'):
-        category = category[1:-1]
-        if category in _ERROR_CATEGORIES:
-          _error_suppressions.setdefault(category, set()).add(suppressed_line)
-        elif category not in _LEGACY_ERROR_CATEGORIES:
-          error(filename, linenum, 'readability/nolint', 5,
-                'Unknown NOLINT error category: %s' % category)
-
-
-def ResetNolintSuppressions():
-  """Resets the set of NOLINT suppressions to empty."""
-  _error_suppressions.clear()
-
-
-def IsErrorSuppressedByNolint(category, linenum):
-  """Returns true if the specified error category is suppressed on this line.
-
-  Consults the global error_suppressions map populated by
-  ParseNolintSuppressions/ResetNolintSuppressions.
-
-  Args:
-    category: str, the category of the error.
-    linenum: int, the current line number.
-  Returns:
-    bool, True iff the error should be suppressed due to a NOLINT comment.
-  """
-  return (linenum in _error_suppressions.get(category, set()) or
-          linenum in _error_suppressions.get(None, set()))
-
-
-def Match(pattern, s):
-  """Matches the string with the pattern, caching the compiled regexp."""
-  # The regexp compilation caching is inlined in both Match and Search for
-  # performance reasons; factoring it out into a separate function turns out
-  # to be noticeably expensive.
-  if pattern not in _regexp_compile_cache:
-    _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
-  return _regexp_compile_cache[pattern].match(s)
-
-
-def ReplaceAll(pattern, rep, s):
-  """Replaces instances of pattern in a string with a replacement.
-
-  The compiled regex is kept in a cache shared by Match and Search.
-
-  Args:
-    pattern: regex pattern
-    rep: replacement text
-    s: search string
-
-  Returns:
-    string with replacements made (or original string if no replacements)
-  """
-  if pattern not in _regexp_compile_cache:
-    _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
-  return _regexp_compile_cache[pattern].sub(rep, s)
-
-
-def Search(pattern, s):
-  """Searches the string for the pattern, caching the compiled regexp."""
-  if pattern not in _regexp_compile_cache:
-    _regexp_compile_cache[pattern] = sre_compile.compile(pattern)
-  return _regexp_compile_cache[pattern].search(s)
-
-
-class _IncludeState(object):
-  """Tracks line numbers for includes, and the order in which includes appear.
-
-  include_list contains list of lists of (header, line number) pairs.
-  It's a lists of lists rather than just one flat list to make it
-  easier to update across preprocessor boundaries.
-
-  Call CheckNextIncludeOrder() once for each header in the file, passing
-  in the type constants defined above. Calls in an illegal order will
-  raise an _IncludeError with an appropriate error message.
-
-  """
-  # self._section will move monotonically through this set. If it ever
-  # needs to move backwards, CheckNextIncludeOrder will raise an error.
-  _INITIAL_SECTION = 0
-  _MY_H_SECTION = 1
-  _C_SECTION = 2
-  _CPP_SECTION = 3
-  _OTHER_H_SECTION = 4
-
-  _TYPE_NAMES = {
-      _C_SYS_HEADER: 'C system header',
-      _CPP_SYS_HEADER: 'C++ system header',
-      _LIKELY_MY_HEADER: 'header this file implements',
-      _POSSIBLE_MY_HEADER: 'header this file may implement',
-      _OTHER_HEADER: 'other header',
-      }
-  _SECTION_NAMES = {
-      _INITIAL_SECTION: "... nothing. (This can't be an error.)",
-      _MY_H_SECTION: 'a header this file implements',
-      _C_SECTION: 'C system header',
-      _CPP_SECTION: 'C++ system header',
-      _OTHER_H_SECTION: 'other header',
-      }
-
-  def __init__(self):
-    self.include_list = [[]]
-    self.ResetSection('')
-
-  def FindHeader(self, header):
-    """Check if a header has already been included.
-
-    Args:
-      header: header to check.
-    Returns:
-      Line number of previous occurrence, or -1 if the header has not
-      been seen before.
-    """
-    for section_list in self.include_list:
-      for f in section_list:
-        if f[0] == header:
-          return f[1]
-    return -1
-
-  def ResetSection(self, directive):
-    """Reset section checking for preprocessor directive.
-
-    Args:
-      directive: preprocessor directive (e.g. "if", "else").
-    """
-    # The name of the current section.
-    self._section = self._INITIAL_SECTION
-    # The path of last found header.
-    self._last_header = ''
-
-    # Update list of includes.  Note that we never pop from the
-    # include list.
-    if directive in ('if', 'ifdef', 'ifndef'):
-      self.include_list.append([])
-    elif directive in ('else', 'elif'):
-      self.include_list[-1] = []
-
-  def SetLastHeader(self, header_path):
-    self._last_header = header_path
-
-  def CanonicalizeAlphabeticalOrder(self, header_path):
-    """Returns a path canonicalized for alphabetical comparison.
-
-    - replaces "-" with "_" so they both cmp the same.
-    - removes '-inl' since we don't require them to be after the main header.
-    - lowercase everything, just in case.
-
-    Args:
-      header_path: Path to be canonicalized.
-
-    Returns:
-      Canonicalized path.
-    """
-    return header_path.replace('-inl.h', '.h').replace('-', '_').lower()
-
-  def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path):
-    """Check if a header is in alphabetical order with the previous header.
-
-    Args:
-      clean_lines: A CleansedLines instance containing the file.
-      linenum: The number of the line to check.
-      header_path: Canonicalized header to be checked.
-
-    Returns:
-      Returns true if the header is in alphabetical order.
-    """
-    # If previous section is different from current section, _last_header will
-    # be reset to empty string, so it's always less than current header.
-    #
-    # If previous line was a blank line, assume that the headers are
-    # intentionally sorted the way they are.
-    if (self._last_header > header_path and
-        Match(r'^\s*#\s*include\b', clean_lines.elided[linenum - 1])):
-      return False
-    return True
-
-  def CheckNextIncludeOrder(self, header_type):
-    """Returns a non-empty error message if the next header is out of order.
-
-    This function also updates the internal state to be ready to check
-    the next include.
-
-    Args:
-      header_type: One of the _XXX_HEADER constants defined above.
-
-    Returns:
-      The empty string if the header is in the right order, or an
-      error message describing what's wrong.
-
-    """
-    error_message = ('Found %s after %s' %
-                     (self._TYPE_NAMES[header_type],
-                      self._SECTION_NAMES[self._section]))
-
-    last_section = self._section
-
-    if header_type == _C_SYS_HEADER:
-      if self._section <= self._C_SECTION:
-        self._section = self._C_SECTION
-      else:
-        self._last_header = ''
-        return error_message
-    elif header_type == _CPP_SYS_HEADER:
-      if self._section <= self._CPP_SECTION:
-        self._section = self._CPP_SECTION
-      else:
-        self._last_header = ''
-        return error_message
-    elif header_type == _LIKELY_MY_HEADER:
-      if self._section <= self._MY_H_SECTION:
-        self._section = self._MY_H_SECTION
-      else:
-        self._section = self._OTHER_H_SECTION
-    elif header_type == _POSSIBLE_MY_HEADER:
-      if self._section <= self._MY_H_SECTION:
-        self._section = self._MY_H_SECTION
-      else:
-        # This will always be the fallback because we're not sure
-        # enough that the header is associated with this file.
-        self._section = self._OTHER_H_SECTION
-    else:
-      assert header_type == _OTHER_HEADER
-      self._section = self._OTHER_H_SECTION
-
-    if last_section != self._section:
-      self._last_header = ''
-
-    return ''
-
-
-class _CppLintState(object):
-  """Maintains module-wide state.."""
-
-  def __init__(self):
-    self.verbose_level = 1  # global setting.
-    self.error_count = 0    # global count of reported errors
-    # filters to apply when emitting error messages
-    self.filters = _DEFAULT_FILTERS[:]
-    # backup of filter list. Used to restore the state after each file.
-    self._filters_backup = self.filters[:]
-    self.counting = 'total'  # In what way are we counting errors?
-    self.errors_by_category = {}  # string to int dict storing error counts
-
-    # output format:
-    # "emacs" - format that emacs can parse (default)
-    # "vs7" - format that Microsoft Visual Studio 7 can parse
-    self.output_format = 'emacs'
-
-  def SetOutputFormat(self, output_format):
-    """Sets the output format for errors."""
-    self.output_format = output_format
-
-  def SetVerboseLevel(self, level):
-    """Sets the module's verbosity, and returns the previous setting."""
-    last_verbose_level = self.verbose_level
-    self.verbose_level = level
-    return last_verbose_level
-
-  def SetCountingStyle(self, counting_style):
-    """Sets the module's counting options."""
-    self.counting = counting_style
-
-  def SetFilters(self, filters):
-    """Sets the error-message filters.
-
-    These filters are applied when deciding whether to emit a given
-    error message.
-
-    Args:
-      filters: A string of comma-separated filters (eg "+whitespace/indent").
-               Each filter should start with + or -; else we die.
-
-    Raises:
-      ValueError: The comma-separated filters did not all start with '+' or '-'.
-                  E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter"
-    """
-    # Default filters always have less priority than the flag ones.
-    self.filters = _DEFAULT_FILTERS[:]
-    self.AddFilters(filters)
-
-  def AddFilters(self, filters):
-    """ Adds more filters to the existing list of error-message filters. """
-    for filt in filters.split(','):
-      clean_filt = filt.strip()
-      if clean_filt:
-        self.filters.append(clean_filt)
-    for filt in self.filters:
-      if not (filt.startswith('+') or filt.startswith('-')):
-        raise ValueError('Every filter in --filters must start with + or -'
-                         ' (%s does not)' % filt)
-
-  def BackupFilters(self):
-    """ Saves the current filter list to backup storage."""
-    self._filters_backup = self.filters[:]
-
-  def RestoreFilters(self):
-    """ Restores filters previously backed up."""
-    self.filters = self._filters_backup[:]
-
-  def ResetErrorCounts(self):
-    """Sets the module's error statistic back to zero."""
-    self.error_count = 0
-    self.errors_by_category = {}
-
-  def IncrementErrorCount(self, category):
-    """Bumps the module's error statistic."""
-    self.error_count += 1
-    if self.counting in ('toplevel', 'detailed'):
-      if self.counting != 'detailed':
-        category = category.split('/')[0]
-      if category not in self.errors_by_category:
-        self.errors_by_category[category] = 0
-      self.errors_by_category[category] += 1
-
-  def PrintErrorCounts(self):
-    """Print a summary of errors by category, and the total."""
-    for category, count in self.errors_by_category.iteritems():
-      sys.stderr.write('Category \'%s\' errors found: %d\n' %
-                       (category, count))
-    sys.stderr.write('Total errors found: %d\n' % self.error_count)
-
-_cpplint_state = _CppLintState()
-
-
-def _OutputFormat():
-  """Gets the module's output format."""
-  return _cpplint_state.output_format
-
-
-def _SetOutputFormat(output_format):
-  """Sets the module's output format."""
-  _cpplint_state.SetOutputFormat(output_format)
-
-
-def _VerboseLevel():
-  """Returns the module's verbosity setting."""
-  return _cpplint_state.verbose_level
-
-
-def _SetVerboseLevel(level):
-  """Sets the module's verbosity, and returns the previous setting."""
-  return _cpplint_state.SetVerboseLevel(level)
-
-
-def _SetCountingStyle(level):
-  """Sets the module's counting options."""
-  _cpplint_state.SetCountingStyle(level)
-
-
-def _Filters():
-  """Returns the module's list of output filters, as a list."""
-  return _cpplint_state.filters
-
-
-def _SetFilters(filters):
-  """Sets the module's error-message filters.
-
-  These filters are applied when deciding whether to emit a given
-  error message.
-
-  Args:
-    filters: A string of comma-separated filters (eg "whitespace/indent").
-             Each filter should start with + or -; else we die.
-  """
-  _cpplint_state.SetFilters(filters)
-
-def _AddFilters(filters):
-  """Adds more filter overrides.
-
-  Unlike _SetFilters, this function does not reset the current list of filters
-  available.
-
-  Args:
-    filters: A string of comma-separated filters (eg "whitespace/indent").
-             Each filter should start with + or -; else we die.
-  """
-  _cpplint_state.AddFilters(filters)
-
-def _BackupFilters():
-  """ Saves the current filter list to backup storage."""
-  _cpplint_state.BackupFilters()
-
-def _RestoreFilters():
-  """ Restores filters previously backed up."""
-  _cpplint_state.RestoreFilters()
-
-class _FunctionState(object):
-  """Tracks current function name and the number of lines in its body."""
-
-  _NORMAL_TRIGGER = 250  # for --v=0, 500 for --v=1, etc.
-  _TEST_TRIGGER = 400    # about 50% more than _NORMAL_TRIGGER.
-
-  def __init__(self):
-    self.in_a_function = False
-    self.lines_in_function = 0
-    self.current_function = ''
-
-  def Begin(self, function_name):
-    """Start analyzing function body.
-
-    Args:
-      function_name: The name of the function being tracked.
-    """
-    self.in_a_function = True
-    self.lines_in_function = 0
-    self.current_function = function_name
-
-  def Count(self):
-    """Count line in current function body."""
-    if self.in_a_function:
-      self.lines_in_function += 1
-
-  def Check(self, error, filename, linenum):
-    """Report if too many lines in function body.
-
-    Args:
-      error: The function to call with any errors found.
-      filename: The name of the current file.
-      linenum: The number of the line to check.
-    """
-    if Match(r'T(EST|est)', self.current_function):
-      base_trigger = self._TEST_TRIGGER
-    else:
-      base_trigger = self._NORMAL_TRIGGER
-    trigger = base_trigger * 2**_VerboseLevel()
-
-    if self.lines_in_function > trigger:
-      error_level = int(math.log(self.lines_in_function / base_trigger, 2))
-      # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ...
-      if error_level > 5:
-        error_level = 5
-      error(filename, linenum, 'readability/fn_size', error_level,
-            'Small and focused functions are preferred:'
-            ' %s has %d non-comment lines'
-            ' (error triggered by exceeding %d lines).'  % (
-                self.current_function, self.lines_in_function, trigger))
-
-  def End(self):
-    """Stop analyzing function body."""
-    self.in_a_function = False
-
-
-class _IncludeError(Exception):
-  """Indicates a problem with the include order in a file."""
-  pass
-
-
-class FileInfo(object):
-  """Provides utility functions for filenames.
-
-  FileInfo provides easy access to the components of a file's path
-  relative to the project root.
-  """
-
-  def __init__(self, filename):
-    self._filename = filename
-
-  def FullName(self):
-    """Make Windows paths like Unix."""
-    return os.path.abspath(self._filename).replace('\\', '/')
-
-  def RepositoryName(self):
-    """FullName after removing the local path to the repository.
-
-    If we have a real absolute path name here we can try to do something smart:
-    detecting the root of the checkout and truncating /path/to/checkout from
-    the name so that we get header guards that don't include things like
-    "C:\Documents and Settings\..." or "/home/username/..." in them and thus
-    people on different computers who have checked the source out to different
-    locations won't see bogus errors.
-    """
-    fullname = self.FullName()
-
-    if os.path.exists(fullname):
-      project_dir = os.path.dirname(fullname)
-
-      if os.path.exists(os.path.join(project_dir, ".svn")):
-        # If there's a .svn file in the current directory, we recursively look
-        # up the directory tree for the top of the SVN checkout
-        root_dir = project_dir
-        one_up_dir = os.path.dirname(root_dir)
-        while os.path.exists(os.path.join(one_up_dir, ".svn")):
-          root_dir = os.path.dirname(root_dir)
-          one_up_dir = os.path.dirname(one_up_dir)
-
-        prefix = os.path.commonprefix([root_dir, project_dir])
-        return fullname[len(prefix) + 1:]
-
-      # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by
-      # searching up from the current path.
-      root_dir = os.path.dirname(fullname)
-      while (root_dir != os.path.dirname(root_dir) and
-             not os.path.exists(os.path.join(root_dir, ".git")) and
-             not os.path.exists(os.path.join(root_dir, ".hg")) and
-             not os.path.exists(os.path.join(root_dir, ".svn"))):
-        root_dir = os.path.dirname(root_dir)
-
-      if (os.path.exists(os.path.join(root_dir, ".git")) or
-          os.path.exists(os.path.join(root_dir, ".hg")) or
-          os.path.exists(os.path.join(root_dir, ".svn"))):
-        prefix = os.path.commonprefix([root_dir, project_dir])
-        return fullname[len(prefix) + 1:]
-
-    # Don't know what to do; header guard warnings may be wrong...
-    return fullname
-
-  def Split(self):
-    """Splits the file into the directory, basename, and extension.
-
-    For 'chrome/browser/browser.cc', Split() would
-    return ('chrome/browser', 'browser', '.cc')
-
-    Returns:
-      A tuple of (directory, basename, extension).
-    """
-
-    googlename = self.RepositoryName()
-    project, rest = os.path.split(googlename)
-    return (project,) + os.path.splitext(rest)
-
-  def BaseName(self):
-    """File base name - text after the final slash, before the final period."""
-    return self.Split()[1]
-
-  def Extension(self):
-    """File extension - text following the final period."""
-    return self.Split()[2]
-
-  def NoExtension(self):
-    """File has no source file extension."""
-    return '/'.join(self.Split()[0:2])
-
-  def IsSource(self):
-    """File has a source file extension."""
-    return self.Extension()[1:] in ('c', 'cc', 'cpp', 'cxx')
-
-
-def _ShouldPrintError(category, confidence, linenum):
-  """If confidence >= verbose, category passes filter and is not suppressed."""
-
-  # There are three ways we might decide not to print an error message:
-  # a "NOLINT(category)" comment appears in the source,
-  # the verbosity level isn't high enough, or the filters filter it out.
-  if IsErrorSuppressedByNolint(category, linenum):
-    return False
-
-  if confidence < _cpplint_state.verbose_level:
-    return False
-
-  is_filtered = False
-  for one_filter in _Filters():
-    if one_filter.startswith('-'):
-      if category.startswith(one_filter[1:]):
-        is_filtered = True
-    elif one_filter.startswith('+'):
-      if category.startswith(one_filter[1:]):
-        is_filtered = False
-    else:
-      assert False  # should have been checked for in SetFilter.
-  if is_filtered:
-    return False
-
-  return True
-
-
-def Error(filename, linenum, category, confidence, message):
-  """Logs the fact we've found a lint error.
-
-  We log where the error was found, and also our confidence in the error,
-  that is, how certain we are this is a legitimate style regression, and
-  not a misidentification or a use that's sometimes justified.
-
-  False positives can be suppressed by the use of
-  "cpplint(category)"  comments on the offending line.  These are
-  parsed into _error_suppressions.
-
-  Args:
-    filename: The name of the file containing the error.
-    linenum: The number of the line containing the error.
-    category: A string used to describe the "category" this bug
-      falls under: "whitespace", say, or "runtime".  Categories
-      may have a hierarchy separated by slashes: "whitespace/indent".
-    confidence: A number from 1-5 representing a confidence score for
-      the error, with 5 meaning that we are certain of the problem,
-      and 1 meaning that it could be a legitimate construct.
-    message: The error message.
-  """
-  if _ShouldPrintError(category, confidence, linenum):
-    _cpplint_state.IncrementErrorCount(category)
-    if _cpplint_state.output_format == 'vs7':
-      sys.stderr.write('%s(%s):  %s  [%s] [%d]\n' % (
-          filename, linenum, message, category, confidence))
-    elif _cpplint_state.output_format == 'eclipse':
-      sys.stderr.write('%s:%s: warning: %s  [%s] [%d]\n' % (
-          filename, linenum, message, category, confidence))
-    else:
-      sys.stderr.write('%s:%s:  %s  [%s] [%d]\n' % (
-          filename, linenum, message, category, confidence))
-
-
-# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard.
-_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile(
-    r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)')
-# Match a single C style comment on the same line.
-_RE_PATTERN_C_COMMENTS = r'/\*(?:[^*]|\*(?!/))*\*/'
-# Matches multi-line C style comments.
-# This RE is a little bit more complicated than one might expect, because we
-# have to take care of space removals tools so we can handle comments inside
-# statements better.
-# The current rule is: We only clear spaces from both sides when we're at the
-# end of the line. Otherwise, we try to remove spaces from the right side,
-# if this doesn't work we try on left side but only if there's a non-character
-# on the right.
-_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile(
-    r'(\s*' + _RE_PATTERN_C_COMMENTS + r'\s*$|' +
-    _RE_PATTERN_C_COMMENTS + r'\s+|' +
-    r'\s+' + _RE_PATTERN_C_COMMENTS + r'(?=\W)|' +
-    _RE_PATTERN_C_COMMENTS + r')')
-
-
-def IsCppString(line):
-  """Does line terminate so, that the next symbol is in string constant.
-
-  This function does not consider single-line nor multi-line comments.
-
-  Args:
-    line: is a partial line of code starting from the 0..n.
-
-  Returns:
-    True, if next character appended to 'line' is inside a
-    string constant.
-  """
-
-  line = line.replace(r'\\', 'XX')  # after this, \\" does not match to \"
-  return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1
-
-
-def CleanseRawStrings(raw_lines):
-  """Removes C++11 raw strings from lines.
-
-    Before:
-      static const char kData[] = R"(
-          multi-line string
-          )";
-
-    After:
-      static const char kData[] = ""
-          (replaced by blank line)
-          "";
-
-  Args:
-    raw_lines: list of raw lines.
-
-  Returns:
-    list of lines with C++11 raw strings replaced by empty strings.
-  """
-
-  delimiter = None
-  lines_without_raw_strings = []
-  for line in raw_lines:
-    if delimiter:
-      # Inside a raw string, look for the end
-      end = line.find(delimiter)
-      if end >= 0:
-        # Found the end of the string, match leading space for this
-        # line and resume copying the original lines, and also insert
-        # a "" on the last line.
-        leading_space = Match(r'^(\s*)\S', line)
-        line = leading_space.group(1) + '""' + line[end + len(delimiter):]
-        delimiter = None
-      else:
-        # Haven't found the end yet, append a blank line.
-        line = '""'
-
-    # Look for beginning of a raw string, and replace them with
-    # empty strings.  This is done in a loop to handle multiple raw
-    # strings on the same line.
-    while delimiter is None:
-      # Look for beginning of a raw string.
-      # See 2.14.15 [lex.string] for syntax.
-      matched = Match(r'^(.*)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line)
-      if matched:
-        delimiter = ')' + matched.group(2) + '"'
-
-        end = matched.group(3).find(delimiter)
-        if end >= 0:
-          # Raw string ended on same line
-          line = (matched.group(1) + '""' +
-                  matched.group(3)[end + len(delimiter):])
-          delimiter = None
-        else:
-          # Start of a multi-line raw string
-          line = matched.group(1) + '""'
-      else:
-        break
-
-    lines_without_raw_strings.append(line)
-
-  # TODO(unknown): if delimiter is not None here, we might want to
-  # emit a warning for unterminated string.
-  return lines_without_raw_strings
-
-
-def FindNextMultiLineCommentStart(lines, lineix):
-  """Find the beginning marker for a multiline comment."""
-  while lineix < len(lines):
-    if lines[lineix].strip().startswith('/*'):
-      # Only return this marker if the comment goes beyond this line
-      if lines[lineix].strip().find('*/', 2) < 0:
-        return lineix
-    lineix += 1
-  return len(lines)
-
-
-def FindNextMultiLineCommentEnd(lines, lineix):
-  """We are inside a comment, find the end marker."""
-  while lineix < len(lines):
-    if lines[lineix].strip().endswith('*/'):
-      return lineix
-    lineix += 1
-  return len(lines)
-
-
-def RemoveMultiLineCommentsFromRange(lines, begin, end):
-  """Clears a range of lines for multi-line comments."""
-  # Having // dummy comments makes the lines non-empty, so we will not get
-  # unnecessary blank line warnings later in the code.
-  for i in range(begin, end):
-    lines[i] = '/**/'
-
-
-def RemoveMultiLineComments(filename, lines, error):
-  """Removes multiline (c-style) comments from lines."""
-  lineix = 0
-  while lineix < len(lines):
-    lineix_begin = FindNextMultiLineCommentStart(lines, lineix)
-    if lineix_begin >= len(lines):
-      return
-    lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin)
-    if lineix_end >= len(lines):
-      error(filename, lineix_begin + 1, 'readability/multiline_comment', 5,
-            'Could not find end of multi-line comment')
-      return
-    RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1)
-    lineix = lineix_end + 1
-
-
-def CleanseComments(line):
-  """Removes //-comments and single-line C-style /* */ comments.
-
-  Args:
-    line: A line of C++ source.
-
-  Returns:
-    The line with single-line comments removed.
-  """
-  commentpos = line.find('//')
-  if commentpos != -1 and not IsCppString(line[:commentpos]):
-    line = line[:commentpos].rstrip()
-  # get rid of /* ... */
-  return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line)
-
-
-class CleansedLines(object):
-  """Holds 4 copies of all lines with different preprocessing applied to them.
-
-  1) elided member contains lines without strings and comments.
-  2) lines member contains lines without comments.
-  3) raw_lines member contains all the lines without processing.
-  4) lines_without_raw_strings member is same as raw_lines, but with C++11 raw
-     strings removed.
-  All these members are of <type 'list'>, and of the same length.
-  """
-
-  def __init__(self, lines):
-    self.elided = []
-    self.lines = []
-    self.raw_lines = lines
-    self.num_lines = len(lines)
-    self.lines_without_raw_strings = CleanseRawStrings(lines)
-    for linenum in range(len(self.lines_without_raw_strings)):
-      self.lines.append(CleanseComments(
-          self.lines_without_raw_strings[linenum]))
-      elided = self._CollapseStrings(self.lines_without_raw_strings[linenum])
-      self.elided.append(CleanseComments(elided))
-
-  def NumLines(self):
-    """Returns the number of lines represented."""
-    return self.num_lines
-
-  @staticmethod
-  def _CollapseStrings(elided):
-    """Collapses strings and chars on a line to simple "" or '' blocks.
-
-    We nix strings first so we're not fooled by text like '"http://"'
-
-    Args:
-      elided: The line being processed.
-
-    Returns:
-      The line with collapsed strings.
-    """
-    if _RE_PATTERN_INCLUDE.match(elided):
-      return elided
-
-    # Remove escaped characters first to make quote/single quote collapsing
-    # basic.  Things that look like escaped characters shouldn't occur
-    # outside of strings and chars.
-    elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided)
-
-    # Replace quoted strings and digit separators.  Both single quotes
-    # and double quotes are processed in the same loop, otherwise
-    # nested quotes wouldn't work.
-    collapsed = ''
-    while True:
-      # Find the first quote character
-      match = Match(r'^([^\'"]*)([\'"])(.*)$', elided)
-      if not match:
-        collapsed += elided
-        break
-      head, quote, tail = match.groups()
-
-      if quote == '"':
-        # Collapse double quoted strings
-        second_quote = tail.find('"')
-        if second_quote >= 0:
-          collapsed += head + '""'
-          elided = tail[second_quote + 1:]
-        else:
-          # Unmatched double quote, don't bother processing the rest
-          # of the line since this is probably a multiline string.
-          collapsed += elided
-          break
-      else:
-        # Found single quote, check nearby text to eliminate digit separators.
-        #
-        # There is no special handling for floating point here, because
-        # the integer/fractional/exponent parts would all be parsed
-        # correctly as long as there are digits on both sides of the
-        # separator.  So we are fine as long as we don't see something
-        # like "0.'3" (gcc 4.9.0 will not allow this literal).
-        if Search(r'\b(?:0[bBxX]?|[1-9])[0-9a-fA-F]*$', head):
-          match_literal = Match(r'^((?:\'?[0-9a-zA-Z_])*)(.*)$', "'" + tail)
-          collapsed += head + match_literal.group(1).replace("'", '')
-          elided = match_literal.group(2)
-        else:
-          second_quote = tail.find('\'')
-          if second_quote >= 0:
-            collapsed += head + "''"
-            elided = tail[second_quote + 1:]
-          else:
-            # Unmatched single quote
-            collapsed += elided
-            break
-
-    return collapsed
-
-
-def FindEndOfExpressionInLine(line, startpos, stack):
-  """Find the position just after the end of current parenthesized expression.
-
-  Args:
-    line: a CleansedLines line.
-    startpos: start searching at this position.
-    stack: nesting stack at startpos.
-
-  Returns:
-    On finding matching end: (index just after matching end, None)
-    On finding an unclosed expression: (-1, None)
-    Otherwise: (-1, new stack at end of this line)
-  """
-  for i in xrange(startpos, len(line)):
-    char = line[i]
-    if char in '([{':
-      # Found start of parenthesized expression, push to expression stack
-      stack.append(char)
-    elif char == '<':
-      # Found potential start of template argument list
-      if i > 0 and line[i - 1] == '<':
-        # Left shift operator
-        if stack and stack[-1] == '<':
-          stack.pop()
-          if not stack:
-            return (-1, None)
-      elif i > 0 and Search(r'\boperator\s*$', line[0:i]):
-        # operator<, don't add to stack
-        continue
-      else:
-        # Tentative start of template argument list
-        stack.append('<')
-    elif char in ')]}':
-      # Found end of parenthesized expression.
-      #
-      # If we are currently expecting a matching '>', the pending '<'
-      # must have been an operator.  Remove them from expression stack.
-      while stack and stack[-1] == '<':
-        stack.pop()
-      if not stack:
-        return (-1, None)
-      if ((stack[-1] == '(' and char == ')') or
-          (stack[-1] == '[' and char == ']') or
-          (stack[-1] == '{' and char == '}')):
-        stack.pop()
-        if not stack:
-          return (i + 1, None)
-      else:
-        # Mismatched parentheses
-        return (-1, None)
-    elif char == '>':
-      # Found potential end of template argument list.
-
-      # Ignore "->" and operator functions
-      if (i > 0 and
-          (line[i - 1] == '-' or Search(r'\boperator\s*$', line[0:i - 1]))):
-        continue
-
-      # Pop the stack if there is a matching '<'.  Otherwise, ignore
-      # this '>' since it must be an operator.
-      if stack:
-        if stack[-1] == '<':
-          stack.pop()
-          if not stack:
-            return (i + 1, None)
-    elif char == ';':
-      # Found something that looks like the end of a statement.  If we are
-      # currently expecting a '>', the matching '<' must have been an
-      # operator, since a template argument list should not contain statements.
-      while stack and stack[-1] == '<':
-        stack.pop()
-      if not stack:
-        return (-1, None)
-
-  # Did not find end of expression or unbalanced parentheses on this line
-  return (-1, stack)
-
-
-def CloseExpression(clean_lines, linenum, pos):
-  """If input points to ( or { or [ or <, finds the position that closes it.
-
-  If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the
-  linenum/pos that correspond to the closing of the expression.
-
-  TODO(unknown): cpplint spends a fair bit of time matching parentheses.
-  Ideally we would want to index all opening and closing parentheses once
-  and have CloseExpression be just a simple lookup, but due to preprocessor
-  tricks, this is not so easy.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    pos: A position on the line.
-
-  Returns:
-    A tuple (line, linenum, pos) pointer *past* the closing brace, or
-    (line, len(lines), -1) if we never find a close.  Note we ignore
-    strings and comments when matching; and the line we return is the
-    'cleansed' line at linenum.
-  """
-
-  line = clean_lines.elided[linenum]
-  if (line[pos] not in '({[<') or Match(r'<[<=]', line[pos:]):
-    return (line, clean_lines.NumLines(), -1)
-
-  # Check first line
-  (end_pos, stack) = FindEndOfExpressionInLine(line, pos, [])
-  if end_pos > -1:
-    return (line, linenum, end_pos)
-
-  # Continue scanning forward
-  while stack and linenum < clean_lines.NumLines() - 1:
-    linenum += 1
-    line = clean_lines.elided[linenum]
-    (end_pos, stack) = FindEndOfExpressionInLine(line, 0, stack)
-    if end_pos > -1:
-      return (line, linenum, end_pos)
-
-  # Did not find end of expression before end of file, give up
-  return (line, clean_lines.NumLines(), -1)
-
-
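The two functions above are, at heart, bracket matching with an explicit stack, plus heuristics for deciding when '<' and '>' are template brackets rather than operators. A minimal sketch of just the stack part (invented helper find_close; brackets only, single line, no angle-bracket handling):

PAIRS = {')': '(', ']': '[', '}': '{'}

def find_close(line, startpos):
    """Return the index just past the bracket that closes line[startpos]."""
    stack = []
    for i in range(startpos, len(line)):
        char = line[i]
        if char in '([{':
            stack.append(char)
        elif char in ')]}':
            if not stack or stack[-1] != PAIRS[char]:
                return -1          # mismatched or never-opened bracket
            stack.pop()
            if not stack:
                return i + 1       # position just past the matching close
    return -1                      # expression continues past this line

assert find_close('foo(a, (b + c)[0]) + d', 3) == 18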
-def FindStartOfExpressionInLine(line, endpos, stack):
-  """Find position at the matching start of current expression.
-
-  This is almost the reverse of FindEndOfExpressionInLine, but note
-  that the input position and returned position differ by 1.
-
-  Args:
-    line: a CleansedLines line.
-    endpos: start searching at this position.
-    stack: nesting stack at endpos.
-
-  Returns:
-    On finding matching start: (index at matching start, None)
-    On finding an unclosed expression: (-1, None)
-    Otherwise: (-1, new stack at beginning of this line)
-  """
-  i = endpos
-  while i >= 0:
-    char = line[i]
-    if char in ')]}':
-      # Found end of expression, push to expression stack
-      stack.append(char)
-    elif char == '>':
-      # Found potential end of template argument list.
-      #
-      # Ignore it if it's a "->" or ">=" or "operator>"
-      if (i > 0 and
-          (line[i - 1] == '-' or
-           Match(r'\s>=\s', line[i - 1:]) or
-           Search(r'\boperator\s*$', line[0:i]))):
-        i -= 1
-      else:
-        stack.append('>')
-    elif char == '<':
-      # Found potential start of template argument list
-      if i > 0 and line[i - 1] == '<':
-        # Left shift operator
-        i -= 1
-      else:
-        # If there is a matching '>', we can pop the expression stack.
-        # Otherwise, ignore this '<' since it must be an operator.
-        if stack and stack[-1] == '>':
-          stack.pop()
-          if not stack:
-            return (i, None)
-    elif char in '([{':
-      # Found start of expression.
-      #
-      # If there are any unmatched '>' on the stack, they must be
-      # operators.  Remove those.
-      while stack and stack[-1] == '>':
-        stack.pop()
-      if not stack:
-        return (-1, None)
-      if ((char == '(' and stack[-1] == ')') or
-          (char == '[' and stack[-1] == ']') or
-          (char == '{' and stack[-1] == '}')):
-        stack.pop()
-        if not stack:
-          return (i, None)
-      else:
-        # Mismatched parentheses
-        return (-1, None)
-    elif char == ';':
-      # Found something that looks like the end of a statement.  If we are
-      # currently expecting a '<', the matching '>' must have been an
-      # operator, since a template argument list should not contain statements.
-      while stack and stack[-1] == '>':
-        stack.pop()
-      if not stack:
-        return (-1, None)
-
-    i -= 1
-
-  return (-1, stack)
-
-
-def ReverseCloseExpression(clean_lines, linenum, pos):
-  """If input points to ) or } or ] or >, finds the position that opens it.
-
-  If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the
-  linenum/pos that correspond to the opening of the expression.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    pos: A position on the line.
-
-  Returns:
-    A tuple (line, linenum, pos) pointer *at* the opening brace, or
-    (line, 0, -1) if we never find the matching opening brace.  Note
-    we ignore strings and comments when matching; and the line we
-    return is the 'cleansed' line at linenum.
-  """
-  line = clean_lines.elided[linenum]
-  if line[pos] not in ')}]>':
-    return (line, 0, -1)
-
-  # Check last line
-  (start_pos, stack) = FindStartOfExpressionInLine(line, pos, [])
-  if start_pos > -1:
-    return (line, linenum, start_pos)
-
-  # Continue scanning backward
-  while stack and linenum > 0:
-    linenum -= 1
-    line = clean_lines.elided[linenum]
-    (start_pos, stack) = FindStartOfExpressionInLine(line, len(line) - 1, stack)
-    if start_pos > -1:
-      return (line, linenum, start_pos)
-
-  # Did not find start of expression before beginning of file, give up
-  return (line, 0, -1)
-
-
-def CheckForCopyright(filename, lines, error):
-  """Logs an error if no Copyright message appears at the top of the file."""
-
-  # We'll say it should occur by line 10. Don't forget there's a
-  # dummy line at the front.
-  for line in xrange(1, min(len(lines), 11)):
-    if re.search(r'Copyright', lines[line], re.I): break
-  else:                       # means no copyright line was found
-    error(filename, 0, 'legal/copyright', 5,
-          'No copyright message found.  '
-          'You should have a line: "Copyright [year] <Copyright Owner>"')
-
-
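The copyright check reduces to a case-insensitive scan of the first few lines. A hedged standalone version (plain list of lines, without cpplint's leading dummy line or error callback):

import re

def has_copyright(lines, max_lines=10):
    """True if a 'Copyright' notice appears within the first max_lines lines."""
    return any(re.search(r'Copyright', line, re.I) for line in lines[:max_lines])

assert has_copyright(['// copyright 2020 The Apache Software Foundation'])
assert not has_copyright(['int main() {', '  return 0;', '}'])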
-def GetIndentLevel(line):
-  """Return the number of leading spaces in line.
-
-  Args:
-    line: A string to check.
-
-  Returns:
-    An integer count of leading spaces, possibly zero.
-  """
-  indent = Match(r'^( *)\S', line)
-  if indent:
-    return len(indent.group(1))
-  else:
-    return 0
-
-
-def GetHeaderGuardCPPVariable(filename):
-  """Returns the CPP variable that should be used as a header guard.
-
-  Args:
-    filename: The name of a C++ header file.
-
-  Returns:
-    The CPP variable that should be used as a header guard in the
-    named file.
-
-  """
-
-  # Restores original filename in case that cpplint is invoked from Emacs's
-  # flymake.
-  filename = re.sub(r'_flymake\.h$', '.h', filename)
-  filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename)
-  # Replace 'c++' with 'cpp'.
-  filename = filename.replace('C++', 'cpp').replace('c++', 'cpp')
-
-  fileinfo = FileInfo(filename)
-  file_path_from_root = fileinfo.RepositoryName()
-  if _root:
-    file_path_from_root = re.sub('^' + _root + os.sep, '', file_path_from_root)
-  # return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_'
-
-  # wangsheng@singa.apache: change INCLUDE to SINGA
-  singa_path = re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_'
-  return singa_path.replace("INCLUDE_", "")
-
-
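The SINGA-specific tweak above (dropping the INCLUDE_ prefix from the guard name) is easier to see with the path handling stripped away. A sketch of just the naming rule, assuming the path is already relative to the repository root (the real function derives that via FileInfo and _root); the example path is illustrative:

import re

def header_guard(path_from_root):
    # Upper-case the path, map non-alphanumerics to '_', append '_',
    # then drop the INCLUDE_ prefix as the modified cpplint does.
    guard = re.sub(r'[^a-zA-Z0-9]', '_', path_from_root).upper() + '_'
    return guard.replace('INCLUDE_', '')

assert header_guard('include/singa/core/tensor.h') == 'SINGA_CORE_TENSOR_H_'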
-def CheckForHeaderGuard(filename, clean_lines, error):
-  """Checks that the file contains a header guard.
-
-  Logs an error if no #ifndef header guard is present.  For other
-  headers, checks that the full pathname is used.
-
-  Args:
-    filename: The name of the C++ header file.
-    clean_lines: A CleansedLines instance containing the file.
-    error: The function to call with any errors found.
-  """
-
-  # Don't check for header guards if there are error suppression
-  # comments somewhere in this file.
-  #
-  # Because this is silencing a warning for a nonexistent line, we
-  # only support the very specific NOLINT(build/header_guard) syntax,
-  # and not the general NOLINT or NOLINT(*) syntax.
-  raw_lines = clean_lines.lines_without_raw_strings
-  for i in raw_lines:
-    if Search(r'//\s*NOLINT\(build/header_guard\)', i):
-      return
-
-  cppvar = GetHeaderGuardCPPVariable(filename)
-
-  ifndef = ''
-  ifndef_linenum = 0
-  define = ''
-  endif = ''
-  endif_linenum = 0
-  for linenum, line in enumerate(raw_lines):
-    linesplit = line.split()
-    if len(linesplit) >= 2:
-      # find the first occurrence of #ifndef and #define, save arg
-      if not ifndef and linesplit[0] == '#ifndef':
-        # set ifndef to the header guard presented on the #ifndef line.
-        ifndef = linesplit[1]
-        ifndef_linenum = linenum
-      if not define and linesplit[0] == '#define':
-        define = linesplit[1]
-    # find the last occurrence of #endif, save entire line
-    if line.startswith('#endif'):
-      endif = line
-      endif_linenum = linenum
-
-  if not ifndef or not define or ifndef != define:
-    error(filename, 0, 'build/header_guard', 5,
-          'No #ifndef header guard found, suggested CPP variable is: %s' %
-          cppvar)
-    return
-
-  # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__
-  # for backward compatibility.
-  if ifndef != cppvar:
-    error_level = 0
-    if ifndef != cppvar + '_':
-      error_level = 5
-
-    ParseNolintSuppressions(filename, raw_lines[ifndef_linenum], ifndef_linenum,
-                            error)
-    error(filename, ifndef_linenum, 'build/header_guard', error_level,
-          '#ifndef header guard has wrong style, please use: %s' % cppvar)
-
-  # Check for "//" comments on endif line.
-  ParseNolintSuppressions(filename, raw_lines[endif_linenum], endif_linenum,
-                          error)
-  match = Match(r'#endif\s*//\s*' + cppvar + r'(_)?\b', endif)
-  if match:
-    if match.group(1) == '_':
-      # Issue low severity warning for deprecated double trailing underscore
-      error(filename, endif_linenum, 'build/header_guard', 0,
-            '#endif line should be "#endif  // %s"' % cppvar)
-    return
-
-  # Didn't find the corresponding "//" comment.  If this file does not
-  # contain any "//" comments at all, it could be that the compiler
-  # only wants "/**/" comments, look for those instead.
-  no_single_line_comments = True
-  for i in xrange(1, len(raw_lines) - 1):
-    line = raw_lines[i]
-    if Match(r'^(?:(?:\'(?:\.|[^\'])*\')|(?:"(?:\.|[^"])*")|[^\'"])*//', line):
-      no_single_line_comments = False
-      break
-
-  if no_single_line_comments:
-    match = Match(r'#endif\s*/\*\s*' + cppvar + r'(_)?\s*\*/', endif)
-    if match:
-      if match.group(1) == '_':
-        # Low severity warning for double trailing underscore
-        error(filename, endif_linenum, 'build/header_guard', 0,
-              '#endif line should be "#endif  /* %s */"' % cppvar)
-      return
-
-  # Didn't find anything
-  error(filename, endif_linenum, 'build/header_guard', 5,
-        '#endif line should be "#endif  // %s"' % cppvar)
-
-
-def CheckHeaderFileIncluded(filename, include_state, error):
-  """Logs an error if a .cc file does not include its header."""
-
-  # Do not check test files
-  if filename.endswith('_test.cc') or filename.endswith('_unittest.cc'):
-    return
-
-  fileinfo = FileInfo(filename)
-  headerfile = filename[0:len(filename) - 2] + 'h'
-  if not os.path.exists(headerfile):
-    return
-  headername = FileInfo(headerfile).RepositoryName()
-  first_include = 0
-  for section_list in include_state.include_list:
-    for f in section_list:
-      if headername in f[0] or f[0] in headername:
-        return
-      if not first_include:
-        first_include = f[1]
-
-  error(filename, first_include, 'build/include', 5,
-        '%s should include its header file %s' % (fileinfo.RepositoryName(),
-                                                  headername))
-
-
-def CheckForBadCharacters(filename, lines, error):
-  """Logs an error for each line containing bad characters.
-
-  Two kinds of bad characters:
-
-  1. Unicode replacement characters: These indicate that either the file
-  contained invalid UTF-8 (likely) or Unicode replacement characters (which
-  it shouldn't).  Note that it's possible for this to throw off line
-  numbering if the invalid UTF-8 occurred adjacent to a newline.
-
-  2. NUL bytes.  These are problematic for some tools.
-
-  Args:
-    filename: The name of the current file.
-    lines: An array of strings, each representing a line of the file.
-    error: The function to call with any errors found.
-  """
-  for linenum, line in enumerate(lines):
-    if u'\ufffd' in line:
-      error(filename, linenum, 'readability/utf8', 5,
-            'Line contains invalid UTF-8 (or Unicode replacement character).')
-    if '\0' in line:
-      error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.')
-
-
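A toy version of the bad-character scan, returning findings instead of calling the error callback (invented generator name):

def find_bad_characters(lines):
    """Yield (line_number, message) for replacement characters and NUL bytes."""
    for linenum, line in enumerate(lines):
        if u'\ufffd' in line:
            yield linenum, 'invalid UTF-8 (or Unicode replacement character)'
        if '\0' in line:
            yield linenum, 'NUL byte'

assert list(find_bad_characters(['ok', 'bad\0byte'])) == [(1, 'NUL byte')]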
-def CheckForNewlineAtEOF(filename, lines, error):
-  """Logs an error if there is no newline char at the end of the file.
-
-  Args:
-    filename: The name of the current file.
-    lines: An array of strings, each representing a line of the file.
-    error: The function to call with any errors found.
-  """
-
-  # The array lines() was created by adding two newlines to the
-  # original file (go figure), then splitting on \n.
-  # To verify that the file ends in \n, we just have to make sure the
-  # last-but-two element of lines() exists and is empty.
-  if len(lines) < 3 or lines[-2]:
-    error(filename, len(lines) - 2, 'whitespace/ending_newline', 5,
-          'Could not find a newline character at the end of the file.')
-
-
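The lines[-2] test leans on a property of str.split that is easy to forget: splitting on '\n' leaves a trailing empty element exactly when the text ends with a newline. The exact index used here also depends on how the caller pads the line array, which the comment above only hints at; the snippet below just demonstrates the underlying property:

# str.split('\n') keeps a trailing empty string only when the text ends
# with a newline; that empty element is what the EOF-newline check looks for.
assert 'int x;\n'.split('\n') == ['int x;', '']
assert 'int x;'.split('\n') == ['int x;']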
-def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error):
-  """Logs an error if we see /* ... */ or "..." that extend past one line.
-
-  /* ... */ comments are legit inside macros, for one line.
-  Otherwise, we prefer // comments, so it's ok to warn about the
-  other.  Likewise, it's ok for strings to extend across multiple
-  lines, as long as a line continuation character (backslash)
-  terminates each line. Although not currently prohibited by the C++
-  style guide, it's ugly and unnecessary. We don't do well with either
-  in this lint program, so we warn about both.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Remove all \\ (escaped backslashes) from the line. They are OK, and the
-  # second (escaped) slash may trigger later \" detection erroneously.
-  line = line.replace('\\\\', '')
-
-  if line.count('/*') > line.count('*/'):
-    error(filename, linenum, 'readability/multiline_comment', 5,
-          'Complex multi-line /*...*/-style comment found. '
-          'Lint may give bogus warnings.  '
-          'Consider replacing these with //-style comments, '
-          'with #if 0...#endif, '
-          'or with more clearly structured multi-line comments.')
-
-  if (line.count('"') - line.count('\\"')) % 2:
-    error(filename, linenum, 'readability/multiline_string', 5,
-          'Multi-line string ("...") found.  This lint script doesn\'t '
-          'do well with such strings, and may give bogus warnings.  '
-          'Use C++11 raw strings or concatenation instead.')
-
-
-# (non-threadsafe name, thread-safe alternative, validation pattern)
-#
-# The validation pattern is used to eliminate false positives such as:
-#  _rand();               // false positive due to substring match.
-#  ->rand();              // some member function rand().
-#  ACMRandom rand(seed);  // some variable named rand.
-#  ISAACRandom rand();    // another variable named rand.
-#
-# Basically we require the return value of these functions to be used
-# in some expression context on the same line by matching on some
-# operator before the function name.  This eliminates constructors and
-# member function calls.
-_UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)'
-_THREADING_LIST = (
-    ('asctime(', 'asctime_r(', _UNSAFE_FUNC_PREFIX + r'asctime\([^)]+\)'),
-    ('ctime(', 'ctime_r(', _UNSAFE_FUNC_PREFIX + r'ctime\([^)]+\)'),
-    ('getgrgid(', 'getgrgid_r(', _UNSAFE_FUNC_PREFIX + r'getgrgid\([^)]+\)'),
-    ('getgrnam(', 'getgrnam_r(', _UNSAFE_FUNC_PREFIX + r'getgrnam\([^)]+\)'),
-    ('getlogin(', 'getlogin_r(', _UNSAFE_FUNC_PREFIX + r'getlogin\(\)'),
-    ('getpwnam(', 'getpwnam_r(', _UNSAFE_FUNC_PREFIX + r'getpwnam\([^)]+\)'),
-    ('getpwuid(', 'getpwuid_r(', _UNSAFE_FUNC_PREFIX + r'getpwuid\([^)]+\)'),
-    ('gmtime(', 'gmtime_r(', _UNSAFE_FUNC_PREFIX + r'gmtime\([^)]+\)'),
-    ('localtime(', 'localtime_r(', _UNSAFE_FUNC_PREFIX + r'localtime\([^)]+\)'),
-    ('rand(', 'rand_r(', _UNSAFE_FUNC_PREFIX + r'rand\(\)'),
-    ('strtok(', 'strtok_r(',
-     _UNSAFE_FUNC_PREFIX + r'strtok\([^)]+\)'),
-    ('ttyname(', 'ttyname_r(', _UNSAFE_FUNC_PREFIX + r'ttyname\([^)]+\)'),
-    )
-
-
-def CheckPosixThreading(filename, clean_lines, linenum, error):
-  """Checks for calls to thread-unsafe functions.
-
-  Much code was originally written without consideration of
-  multi-threading. Also, engineers rely on their old experience;
-  they learned POSIX before the threading extensions were added. These
-  tests guide the engineers to use thread-safe functions (when using
-  posix directly).
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-  for single_thread_func, multithread_safe_func, pattern in _THREADING_LIST:
-    # Additional pattern matching check to confirm that this is the
-    # function we are looking for
-    if Search(pattern, line):
-      error(filename, linenum, 'runtime/threadsafe_fn', 2,
-            'Consider using ' + multithread_safe_func +
-            '...) instead of ' + single_thread_func +
-            '...) for improved thread safety.')
-
-
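The _UNSAFE_FUNC_PREFIX trick is the interesting part of this check: requiring an operator (or an opening parenthesis) before the function name means the return value is being used, which filters out member calls and variables that merely happen to be named rand. A quick demonstration with the same prefix and illustrative inputs:

import re

UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)'
RAND_PATTERN = UNSAFE_FUNC_PREFIX + r'rand\(\)'

assert re.search(RAND_PATTERN, 'int x = rand();')            # flagged
assert not re.search(RAND_PATTERN, 'value->rand();')          # member function
assert not re.search(RAND_PATTERN, 'ACMRandom rand(seed);')   # variable named rand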
-def CheckVlogArguments(filename, clean_lines, linenum, error):
-  """Checks that VLOG() is only used for defining a logging level.
-
-  For example, VLOG(2) is correct. VLOG(INFO), VLOG(WARNING), VLOG(ERROR), and
-  VLOG(FATAL) are not.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-  if Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)', line):
-    error(filename, linenum, 'runtime/vlog', 5,
-          'VLOG() should be used with numeric verbosity level.  '
-          'Use LOG() if you want symbolic severity levels.')
-
-# Matches invalid increment: *count++, which moves pointer instead of
-# incrementing a value.
-_RE_PATTERN_INVALID_INCREMENT = re.compile(
-    r'^\s*\*\w+(\+\+|--);')
-
-
-def CheckInvalidIncrement(filename, clean_lines, linenum, error):
-  """Checks for invalid increment *count++.
-
-  For example following function:
-  void increment_counter(int* count) {
-    *count++;
-  }
-  is invalid, because it effectively does count++, moving the pointer, and should
-  be replaced with ++*count, (*count)++ or *count += 1.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-  if _RE_PATTERN_INVALID_INCREMENT.match(line):
-    error(filename, linenum, 'runtime/invalid_increment', 5,
-          'Changing pointer instead of value (or unused value of operator*).')
-
-
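The pattern above is narrow on purpose: it only fires when the dereference-and-increment appears as a whole statement. A few illustrative cases against the same regex:

import re

RE_INVALID_INCREMENT = re.compile(r'^\s*\*\w+(\+\+|--);')

assert RE_INVALID_INCREMENT.match('  *count++;')        # moves the pointer
assert not RE_INVALID_INCREMENT.match('  (*count)++;')  # increments the value
assert not RE_INVALID_INCREMENT.match('  ++*count;')    # also fine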
-def IsMacroDefinition(clean_lines, linenum):
-  if Search(r'^#define', clean_lines[linenum]):
-    return True
-
-  if linenum > 0 and Search(r'\\$', clean_lines[linenum - 1]):
-    return True
-
-  return False
-
-
-def IsForwardClassDeclaration(clean_lines, linenum):
-  return Match(r'^\s*(\btemplate\b)*.*class\s+\w+;\s*$', clean_lines[linenum])
-
-
-class _BlockInfo(object):
-  """Stores information about a generic block of code."""
-
-  def __init__(self, seen_open_brace):
-    self.seen_open_brace = seen_open_brace
-    self.open_parentheses = 0
-    self.inline_asm = _NO_ASM
-    self.check_namespace_indentation = False
-
-  def CheckBegin(self, filename, clean_lines, linenum, error):
-    """Run checks that applies to text up to the opening brace.
-
-    This is mostly for checking the text after the class identifier
-    and the "{", usually where the base class is specified.  For other
-    blocks, there isn't much to check, so we always pass.
-
-    Args:
-      filename: The name of the current file.
-      clean_lines: A CleansedLines instance containing the file.
-      linenum: The number of the line to check.
-      error: The function to call with any errors found.
-    """
-    pass
-
-  def CheckEnd(self, filename, clean_lines, linenum, error):
-    """Run checks that applies to text after the closing brace.
-
-    This is mostly used for checking end of namespace comments.
-
-    Args:
-      filename: The name of the current file.
-      clean_lines: A CleansedLines instance containing the file.
-      linenum: The number of the line to check.
-      error: The function to call with any errors found.
-    """
-    pass
-
-  def IsBlockInfo(self):
-    """Returns true if this block is a _BlockInfo.
-
-    This is convenient for verifying that an object is an instance of
-    a _BlockInfo, but not an instance of any of the derived classes.
-
-    Returns:
-      True for this class, False for derived classes.
-    """
-    return self.__class__ == _BlockInfo
-
-
-class _ExternCInfo(_BlockInfo):
-  """Stores information about an 'extern "C"' block."""
-
-  def __init__(self):
-    _BlockInfo.__init__(self, True)
-
-
-class _ClassInfo(_BlockInfo):
-  """Stores information about a class."""
-
-  def __init__(self, name, class_or_struct, clean_lines, linenum):
-    _BlockInfo.__init__(self, False)
-    self.name = name
-    self.starting_linenum = linenum
-    self.is_derived = False
-    self.check_namespace_indentation = True
-    if class_or_struct == 'struct':
-      self.access = 'public'
-      self.is_struct = True
-    else:
-      self.access = 'private'
-      self.is_struct = False
-
-    # Remember initial indentation level for this class.  Using raw_lines here
-    # instead of elided to account for leading comments.
-    self.class_indent = GetIndentLevel(clean_lines.raw_lines[linenum])
-
-    # Try to find the end of the class.  This will be confused by things like:
-    #   class A {
-    #   } *x = { ...
-    #
-    # But it's still good enough for CheckSectionSpacing.
-    self.last_line = 0
-    depth = 0
-    for i in range(linenum, clean_lines.NumLines()):
-      line = clean_lines.elided[i]
-      depth += line.count('{') - line.count('}')
-      if not depth:
-        self.last_line = i
-        break
-
-  def CheckBegin(self, filename, clean_lines, linenum, error):
-    # Look for a bare ':'
-    if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]):
-      self.is_derived = True
-
-  def CheckEnd(self, filename, clean_lines, linenum, error):
-    # If there is a DISALLOW macro, it should appear near the end of
-    # the class.
-    seen_last_thing_in_class = False
-    for i in xrange(linenum - 1, self.starting_linenum, -1):
-      match = Search(
-          r'\b(DISALLOW_COPY_AND_ASSIGN|DISALLOW_IMPLICIT_CONSTRUCTORS)\(' +
-          self.name + r'\)',
-          clean_lines.elided[i])
-      if match:
-        if seen_last_thing_in_class:
-          error(filename, i, 'readability/constructors', 3,
-                match.group(1) + ' should be the last thing in the class')
-        break
-
-      if not Match(r'^\s*$', clean_lines.elided[i]):
-        seen_last_thing_in_class = True
-
-    # Check that closing brace is aligned with beginning of the class.
-    # Only do this if the closing brace is indented by only whitespaces.
-    # This means we will not check single-line class definitions.
-    indent = Match(r'^( *)\}', clean_lines.elided[linenum])
-    if indent and len(indent.group(1)) != self.class_indent:
-      if self.is_struct:
-        parent = 'struct ' + self.name
-      else:
-        parent = 'class ' + self.name
-      error(filename, linenum, 'whitespace/indent', 3,
-            'Closing brace should be aligned with beginning of %s' % parent)
-
-
-class _NamespaceInfo(_BlockInfo):
-  """Stores information about a namespace."""
-
-  def __init__(self, name, linenum):
-    _BlockInfo.__init__(self, False)
-    self.name = name or ''
-    self.starting_linenum = linenum
-    self.check_namespace_indentation = True
-
-  def CheckEnd(self, filename, clean_lines, linenum, error):
-    """Check end of namespace comments."""
-    line = clean_lines.raw_lines[linenum]
-
-    # Check how many lines are enclosed in this namespace.  Don't issue
-    # a warning for missing namespace comments if there aren't enough
-    # lines.  However, do apply checks if there is already an end of
-    # namespace comment and it's incorrect.
-    #
-    # TODO(unknown): We always want to check end of namespace comments
-    # if a namespace is large, but sometimes we also want to apply the
-    # check if a short namespace contained nontrivial things (something
-    # other than forward declarations).  There is currently no logic on
-    # deciding what these nontrivial things are, so this check is
-    # triggered by namespace size only, which works most of the time.
-    if (linenum - self.starting_linenum < 10
-        and not Match(r'};*\s*(//|/\*).*\bnamespace\b', line)):
-      return
-
-    # Look for matching comment at end of namespace.
-    #
-    # Note that we accept C style "/* */" comments for terminating
-    # namespaces, so that code that terminates namespaces inside
-    # preprocessor macros can be cpplint clean.
-    #
-    # We also accept stuff like "// end of namespace <name>." with the
-    # period at the end.
-    #
-    # Besides these, we don't accept anything else, otherwise we might
-    # get false negatives when existing comment is a substring of the
-    # expected namespace.
-    if self.name:
-      # Named namespace
-      if not Match((r'};*\s*(//|/\*).*\bnamespace\s+' + re.escape(self.name) +
-                    r'[\*/\.\\\s]*$'),
-                   line):
-        error(filename, linenum, 'readability/namespace', 5,
-              'Namespace should be terminated with "// namespace %s"' %
-              self.name)
-    else:
-      # Anonymous namespace
-      if not Match(r'};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line):
-        # If "// namespace anonymous" or "// anonymous namespace (more text)",
-        # mention "// anonymous namespace" as an acceptable form
-        if Match(r'}.*\b(namespace anonymous|anonymous namespace)\b', line):
-          error(filename, linenum, 'readability/namespace', 5,
-                'Anonymous namespace should be terminated with "// namespace"'
-                ' or "// anonymous namespace"')
-        else:
-          error(filename, linenum, 'readability/namespace', 5,
-                'Anonymous namespace should be terminated with "// namespace"')
-
-
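The accepted end-of-namespace comments are easiest to read off the regex itself. A small self-check for the named-namespace case, reusing the same pattern shape (invented wrapper name; the sample lines are illustrative):

import re

def is_namespace_close_comment(line, name):
    pattern = (r'};*\s*(//|/\*).*\bnamespace\s+' + re.escape(name) +
               r'[\*/\.\\\s]*$')
    return bool(re.match(pattern, line))

assert is_namespace_close_comment('}  // namespace singa', 'singa')
assert is_namespace_close_comment('}  /* namespace singa */', 'singa')
assert not is_namespace_close_comment('}  // singa', 'singa')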
-class _PreprocessorInfo(object):
-  """Stores checkpoints of nesting stacks when #if/#else is seen."""
-
-  def __init__(self, stack_before_if):
-    # The entire nesting stack before #if
-    self.stack_before_if = stack_before_if
-
-    # The entire nesting stack up to #else
-    self.stack_before_else = []
-
-    # Whether we have already seen #else or #elif
-    self.seen_else = False
-
-
-class NestingState(object):
-  """Holds states related to parsing braces."""
-
-  def __init__(self):
-    # Stack for tracking all braces.  An object is pushed whenever we
-    # see a "{", and popped when we see a "}".  Only 3 types of
-    # objects are possible:
-    # - _ClassInfo: a class or struct.
-    # - _NamespaceInfo: a namespace.
-    # - _BlockInfo: some other type of block.
-    self.stack = []
-
-    # Top of the previous stack before each Update().
-    #
-    # Because the nesting_stack is updated at the end of each line, we
-    # had to do some convoluted checks to find out what the current
-    # scope is at the beginning of the line.  This check is simplified by
-    # saving the previous top of nesting stack.
-    #
-    # We could save the full stack, but we only need the top.  Copying
-    # the full nesting stack would slow down cpplint by ~10%.
-    self.previous_stack_top = []
-
-    # Stack of _PreprocessorInfo objects.
-    self.pp_stack = []
-
-  def SeenOpenBrace(self):
-    """Check if we have seen the opening brace for the innermost block.
-
-    Returns:
-      True if we have seen the opening brace, False if the innermost
-      block is still expecting an opening brace.
-    """
-    return (not self.stack) or self.stack[-1].seen_open_brace
-
-  def InNamespaceBody(self):
-    """Check if we are currently one level inside a namespace body.
-
-    Returns:
-      True if top of the stack is a namespace block, False otherwise.
-    """
-    return self.stack and isinstance(self.stack[-1], _NamespaceInfo)
-
-  def InExternC(self):
-    """Check if we are currently one level inside an 'extern "C"' block.
-
-    Returns:
-      True if top of the stack is an extern block, False otherwise.
-    """
-    return self.stack and isinstance(self.stack[-1], _ExternCInfo)
-
-  def InClassDeclaration(self):
-    """Check if we are currently one level inside a class or struct declaration.
-
-    Returns:
-      True if top of the stack is a class/struct, False otherwise.
-    """
-    return self.stack and isinstance(self.stack[-1], _ClassInfo)
-
-  def InAsmBlock(self):
-    """Check if we are currently one level inside an inline ASM block.
-
-    Returns:
-      True if the top of the stack is a block containing inline ASM.
-    """
-    return self.stack and self.stack[-1].inline_asm != _NO_ASM
-
-  def InTemplateArgumentList(self, clean_lines, linenum, pos):
-    """Check if current position is inside template argument list.
-
-    Args:
-      clean_lines: A CleansedLines instance containing the file.
-      linenum: The number of the line to check.
-      pos: position just after the suspected template argument.
-    Returns:
-      True if (linenum, pos) is inside template arguments.
-    """
-    while linenum < clean_lines.NumLines():
-      # Find the earliest character that might indicate a template argument
-      line = clean_lines.elided[linenum]
-      match = Match(r'^[^{};=\[\]\.<>]*(.)', line[pos:])
-      if not match:
-        linenum += 1
-        pos = 0
-        continue
-      token = match.group(1)
-      pos += len(match.group(0))
-
-      # These things do not look like a template argument list:
-      #   class Suspect {
-      #   class Suspect x; }
-      if token in ('{', '}', ';'): return False
-
-      # These things look like a template argument list:
-      #   template <class Suspect>
-      #   template <class Suspect = default_value>
-      #   template <class Suspect[]>
-      #   template <class Suspect...>
-      if token in ('>', '=', '[', ']', '.'): return True
-
-      # Check if token is an unmatched '<'.
-      # If not, move on to the next character.
-      if token != '<':
-        pos += 1
-        if pos >= len(line):
-          linenum += 1
-          pos = 0
-        continue
-
-      # We can't be sure if we just find a single '<', and need to
-      # find the matching '>'.
-      (_, end_line, end_pos) = CloseExpression(clean_lines, linenum, pos - 1)
-      if end_pos < 0:
-        # Not sure if template argument list or syntax error in file
-        return False
-      linenum = end_line
-      pos = end_pos
-    return False
-
-  def UpdatePreprocessor(self, line):
-    """Update preprocessor stack.
-
-    We need to handle preprocessors due to classes like this:
-      #ifdef SWIG
-      struct ResultDetailsPageElementExtensionPoint {
-      #else
-      struct ResultDetailsPageElementExtensionPoint : public Extension {
-      #endif
-
-    We make the following assumptions (good enough for most files):
-    - Preprocessor condition evaluates to true from #if up to first
-      #else/#elif/#endif.
-
-    - Preprocessor condition evaluates to false from #else/#elif up
-      to #endif.  We still perform lint checks on these lines, but
-      these do not affect nesting stack.
-
-    Args:
-      line: current line to check.
-    """
-    if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line):
-      # Beginning of #if block, save the nesting stack here.  The saved
-      # stack will allow us to restore the parsing state in the #else case.
-      self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack)))
-    elif Match(r'^\s*#\s*(else|elif)\b', line):
-      # Beginning of #else block
-      if self.pp_stack:
-        if not self.pp_stack[-1].seen_else:
-          # This is the first #else or #elif block.  Remember the
-          # whole nesting stack up to this point.  This is what we
-          # keep after the #endif.
-          self.pp_stack[-1].seen_else = True
-          self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack)
-
-        # Restore the stack to how it was before the #if
-        self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if)
-      else:
-        # TODO(unknown): unexpected #else, issue warning?
-        pass
-    elif Match(r'^\s*#\s*endif\b', line):
-      # End of #if or #else blocks.
-      if self.pp_stack:
-        # If we saw an #else, we will need to restore the nesting
-        # stack to its former state before the #else, otherwise we
-        # will just continue from where we left off.
-        if self.pp_stack[-1].seen_else:
-          # Here we can just use a shallow copy since we are the last
-          # reference to it.
-          self.stack = self.pp_stack[-1].stack_before_else
-        # Drop the corresponding #if
-        self.pp_stack.pop()
-      else:
-        # TODO(unknown): unexpected #endif, issue warning?
-        pass
-
-  # TODO(unknown): Update() is too long, but we will refactor later.
-  def Update(self, filename, clean_lines, linenum, error):
-    """Update nesting state with current line.
-
-    Args:
-      filename: The name of the current file.
-      clean_lines: A CleansedLines instance containing the file.
-      linenum: The number of the line to check.
-      error: The function to call with any errors found.
-    """
-    line = clean_lines.elided[linenum]
-
-    # Remember top of the previous nesting stack.
-    #
-    # The stack is always pushed/popped and not modified in place, so
-    # we can just do a shallow copy instead of copy.deepcopy.  Using
-    # deepcopy would slow down cpplint by ~28%.
-    if self.stack:
-      self.previous_stack_top = self.stack[-1]
-    else:
-      self.previous_stack_top = None
-
-    # Update pp_stack
-    self.UpdatePreprocessor(line)
-
-    # Count parentheses.  This is to avoid adding struct arguments to
-    # the nesting stack.
-    if self.stack:
-      inner_block = self.stack[-1]
-      depth_change = line.count('(') - line.count(')')
-      inner_block.open_parentheses += depth_change
-
-      # Also check if we are starting or ending an inline assembly block.
-      if inner_block.inline_asm in (_NO_ASM, _END_ASM):
-        if (depth_change != 0 and
-            inner_block.open_parentheses == 1 and
-            _MATCH_ASM.match(line)):
-          # Enter assembly block
-          inner_block.inline_asm = _INSIDE_ASM
-        else:
-          # Not entering assembly block.  If previous line was _END_ASM,
-          # we will now shift to _NO_ASM state.
-          inner_block.inline_asm = _NO_ASM
-      elif (inner_block.inline_asm == _INSIDE_ASM and
-            inner_block.open_parentheses == 0):
-        # Exit assembly block
-        inner_block.inline_asm = _END_ASM
-
-    # Consume namespace declaration at the beginning of the line.  Do
-    # this in a loop so that we catch same line declarations like this:
-    #   namespace proto2 { namespace bridge { class MessageSet; } }
-    while True:
-      # Match start of namespace.  The "\b\s*" below catches namespace
-      # declarations even if they aren't followed by whitespace; this
-      # is so that we don't confuse our namespace checker.  The
-      # missing spaces will be flagged by CheckSpacing.
-      namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line)
-      if not namespace_decl_match:
-        break
-
-      new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum)
-      self.stack.append(new_namespace)
-
-      line = namespace_decl_match.group(2)
-      if line.find('{') != -1:
-        new_namespace.seen_open_brace = True
-        line = line[line.find('{') + 1:]
-
-    # Look for a class declaration in whatever is left of the line
-    # after parsing namespaces.  The regexp accounts for decorated classes
-    # such as in:
-    #   class LOCKABLE API Object {
-    #   };
-    class_decl_match = Match(
-        r'^(\s*(?:template\s*<[\w\s<>,:]*>\s*)?'
-        r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))'
-        r'(.*)$', line)
-    if (class_decl_match and
-        (not self.stack or self.stack[-1].open_parentheses == 0)):
-      # We do not want to accept classes that are actually template arguments:
-      #   template <class Ignore1,
-      #             class Ignore2 = Default<Args>,
-      #             template <Args> class Ignore3>
-      #   void Function() {};
-      #
-      # To avoid template argument cases, we scan forward and look for
-      # an unmatched '>'.  If we see one, assume we are inside a
-      # template argument list.
-      end_declaration = len(class_decl_match.group(1))
-      if not self.InTemplateArgumentList(clean_lines, linenum, end_declaration):
-        self.stack.append(_ClassInfo(
-            class_decl_match.group(3), class_decl_match.group(2),
-            clean_lines, linenum))
-        line = class_decl_match.group(4)
-
-    # If we have not yet seen the opening brace for the innermost block,
-    # run checks here.
-    if not self.SeenOpenBrace():
-      self.stack[-1].CheckBegin(filename, clean_lines, linenum, error)
-
-    # Update access control if we are inside a class/struct
-    if self.stack and isinstance(self.stack[-1], _ClassInfo):
-      classinfo = self.stack[-1]
-      access_match = Match(
-          r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?'
-          r':(?:[^:]|$)',
-          line)
-      if access_match:
-        classinfo.access = access_match.group(2)
-
-        # Check that access keywords are indented +1 space.  Skip this
-        # check if the keywords are not preceded by whitespaces.
-        indent = access_match.group(1)
-        if (len(indent) != classinfo.class_indent + 1 and
-            Match(r'^\s*$', indent)):
-          if classinfo.is_struct:
-            parent = 'struct ' + classinfo.name
-          else:
-            parent = 'class ' + classinfo.name
-          slots = ''
-          if access_match.group(3):
-            slots = access_match.group(3)
-          error(filename, linenum, 'whitespace/indent', 3,
-                '%s%s: should be indented +1 space inside %s' % (
-                    access_match.group(2), slots, parent))
-
-    # Consume braces or semicolons from what's left of the line
-    while True:
-      # Match first brace, semicolon, or closed parenthesis.
-      matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line)
-      if not matched:
-        break
-
-      token = matched.group(1)
-      if token == '{':
-        # If namespace or class hasn't seen an opening brace yet, mark
-        # namespace/class head as complete.  Push a new block onto the
-        # stack otherwise.
-        if not self.SeenOpenBrace():
-          self.stack[-1].seen_open_brace = True
-        elif Match(r'^extern\s*"[^"]*"\s*\{', line):
-          self.stack.append(_ExternCInfo())
-        else:
-          self.stack.append(_BlockInfo(True))
-          if _MATCH_ASM.match(line):
-            self.stack[-1].inline_asm = _BLOCK_ASM
-
-      elif token == ';' or token == ')':
-        # If we haven't seen an opening brace yet, but we already saw
-        # a semicolon, this is probably a forward declaration.  Pop
-        # the stack for these.
-        #
-        # Similarly, if we haven't seen an opening brace yet, but we
-        # already saw a closing parenthesis, then these are probably
-        # function arguments with extra "class" or "struct" keywords.
-        # Also pop the stack for these.
-        if not self.SeenOpenBrace():
-          self.stack.pop()
-      else:  # token == '}'
-        # Perform end of block checks and pop the stack.
-        if self.stack:
-          self.stack[-1].CheckEnd(filename, clean_lines, linenum, error)
-          self.stack.pop()
-      line = matched.group(2)
-
-  def InnermostClass(self):
-    """Get class info on the top of the stack.
-
-    Returns:
-      A _ClassInfo object if we are inside a class, or None otherwise.
-    """
-    for i in range(len(self.stack), 0, -1):
-      classinfo = self.stack[i - 1]
-      if isinstance(classinfo, _ClassInfo):
-        return classinfo
-    return None
-
-  def CheckCompletedBlocks(self, filename, error):
-    """Checks that all classes and namespaces have been completely parsed.
-
-    Call this when all lines in a file have been processed.
-    Args:
-      filename: The name of the current file.
-      error: The function to call with any errors found.
-    """
-    # Note: This test can result in false positives if #ifdef constructs
-    # get in the way of brace matching. See the testBuildClass test in
-    # cpplint_unittest.py for an example of this.
-    for obj in self.stack:
-      if isinstance(obj, _ClassInfo):
-        error(filename, obj.starting_linenum, 'build/class', 5,
-              'Failed to find complete declaration of class %s' %
-              obj.name)
-      elif isinstance(obj, _NamespaceInfo):
-        error(filename, obj.starting_linenum, 'build/namespaces', 5,
-              'Failed to find complete declaration of namespace %s' %
-              obj.name)
-
-
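NestingState is long, but the core is a stack that grows on '{' and shrinks on '}', with the block kind recorded at push time. The toy model below (invented name, namespaces and classes only, no preprocessor or template handling) shows the shape of the state the checks above walk:

import re

def block_stack_per_line(lines):
    """Return the stack of block kinds after each line; a toy NestingState."""
    stack, snapshots = [], []
    token_re = re.compile(r'namespace\s+\w+\s*\{|class\s+\w+\s*\{|[{}]')
    for line in lines:
        for token in token_re.findall(line):
            if token == '}':
                if stack:
                    stack.pop()
            elif token.startswith('namespace'):
                stack.append('namespace')
            elif token.startswith('class'):
                stack.append('class')
            else:
                stack.append('block')
        snapshots.append(list(stack))
    return snapshots

code = ['namespace singa {',
        'class Tensor {',
        ' public:',
        '};',
        '}  // namespace singa']
assert block_stack_per_line(code) == [['namespace'],
                                      ['namespace', 'class'],
                                      ['namespace', 'class'],
                                      ['namespace'],
                                      []]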
-def CheckForNonStandardConstructs(filename, clean_lines, linenum,
-                                  nesting_state, error):
-  r"""Logs an error if we see certain non-ANSI constructs ignored by gcc-2.
-
-  Complain about several constructs which gcc-2 accepts, but which are
-  not standard C++.  Warning about these in lint is one way to ease the
-  transition to new compilers.
-  - put storage class first (e.g. "static const" instead of "const static").
-  - "%lld" instead of %qd" in printf-type functions.
-  - "%1$d" is non-standard in printf-type functions.
-  - "\%" is an undefined character escape sequence.
-  - text after #endif is not allowed.
-  - invalid inner-style forward declaration.
-  - >? and <? operators, and their >?= and <?= cousins.
-
-  Additionally, check for constructor/destructor style violations and reference
-  members, as it is very convenient to do so while checking for
-  gcc-2 compliance.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: A callable to which errors are reported, which takes 4 arguments:
-           filename, line number, error level, and message
-  """
-
-  # Remove comments from the line, but leave in strings for now.
-  line = clean_lines.lines[linenum]
-
-  if Search(r'printf\s*\(.*".*%[-+ ]?\d*q', line):
-    error(filename, linenum, 'runtime/printf_format', 3,
-          '%q in format strings is deprecated.  Use %ll instead.')
-
-  if Search(r'printf\s*\(.*".*%\d+\$', line):
-    error(filename, linenum, 'runtime/printf_format', 2,
-          '%N$ formats are unconventional.  Try rewriting to avoid them.')
-
-  # Remove escaped backslashes before looking for undefined escapes.
-  line = line.replace('\\\\', '')
-
-  if Search(r'("|\').*\\(%|\[|\(|{)', line):
-    error(filename, linenum, 'build/printf_format', 3,
-          '%, [, (, and { are undefined character escapes.  Unescape them.')
-
-  # For the rest, work with both comments and strings removed.
-  line = clean_lines.elided[linenum]
-
-  if Search(r'\b(const|volatile|void|char|short|int|long'
-            r'|float|double|signed|unsigned'
-            r'|schar|u?int8|u?int16|u?int32|u?int64)'
-            r'\s+(register|static|extern|typedef)\b',
-            line):
-    error(filename, linenum, 'build/storage_class', 5,
-          'Storage class (static, extern, typedef, etc) should be first.')
-
-  if Match(r'\s*#\s*endif\s*[^/\s]+', line):
-    error(filename, linenum, 'build/endif_comment', 5,
-          'Uncommented text after #endif is non-standard.  Use a comment.')
-
-  if Match(r'\s*class\s+(\w+\s*::\s*)+\w+\s*;', line):
-    error(filename, linenum, 'build/forward_decl', 5,
-          'Inner-style forward declarations are invalid.  Remove this line.')
-
-  if Search(r'(\w+|[+-]?\d+(\.\d*)?)\s*(<|>)\?=?\s*(\w+|[+-]?\d+)(\.\d*)?',
-            line):
-    error(filename, linenum, 'build/deprecated', 3,
-          '>? and <? (max and min) operators are non-standard and deprecated.')
-
-  if Search(r'^\s*const\s*string\s*&\s*\w+\s*;', line):
-    # TODO(unknown): Could it be expanded safely to arbitrary references,
-    # without triggering too many false positives? The first
-    # attempt triggered 5 warnings for mostly benign code in the regtest, hence
-    # the restriction.
-    # Here's the original regexp, for the reference:
-    # type_name = r'\w+((\s*::\s*\w+)|(\s*<\s*\w+?\s*>))?'
-    # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;'
-    error(filename, linenum, 'runtime/member_string_references', 2,
-          'const string& members are dangerous. It is much better to use '
-          'alternatives, such as pointers or simple constants.')
-
-  # Everything else in this function operates on class declarations.
-  # Return early if the top of the nesting stack is not a class, or if
-  # the class head is not completed yet.
-  classinfo = nesting_state.InnermostClass()
-  if not classinfo or not classinfo.seen_open_brace:
-    return
-
-  # The class may have been declared with namespace or classname qualifiers.
-  # The constructor and destructor will not have those qualifiers.
-  base_classname = classinfo.name.split('::')[-1]
-
-  # Look for single-argument constructors that aren't marked explicit.
-  # Technically a valid construct, but against style. Also look for
-  # non-single-argument constructors which are also technically valid, but
-  # strongly suggest something is wrong.
-  explicit_constructor_match = Match(
-      r'\s+(?:inline\s+)?(explicit\s+)?(?:inline\s+)?%s\s*'
-      r'\(((?:[^()]|\([^()]*\))*)\)'
-      % re.escape(base_classname),
-      line)
-
-  if explicit_constructor_match:
-    is_marked_explicit = explicit_constructor_match.group(1)
-
-    if not explicit_constructor_match.group(2):
-      constructor_args = []
-    else:
-      constructor_args = explicit_constructor_match.group(2).split(',')
-
-    # collapse arguments so that commas in template parameter lists and function
-    # argument parameter lists don't split arguments in two
-    i = 0
-    while i < len(constructor_args):
-      constructor_arg = constructor_args[i]
-      while (constructor_arg.count('<') > constructor_arg.count('>') or
-             constructor_arg.count('(') > constructor_arg.count(')')):
-        constructor_arg += ',' + constructor_args[i + 1]
-        del constructor_args[i + 1]
-      constructor_args[i] = constructor_arg
-      i += 1
-
-    defaulted_args = [arg for arg in constructor_args if '=' in arg]
-    noarg_constructor = (not constructor_args or  # empty arg list
-                         # 'void' arg specifier
-                         (len(constructor_args) == 1 and
-                          constructor_args[0].strip() == 'void'))
-    onearg_constructor = ((len(constructor_args) == 1 and  # exactly one arg
-                           not noarg_constructor) or
-                          # all but at most one arg defaulted
-                          (len(constructor_args) >= 1 and
-                           not noarg_constructor and
-                           len(defaulted_args) >= len(constructor_args) - 1))
-    initializer_list_constructor = bool(
-        onearg_constructor and
-        Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0]))
-    copy_constructor = bool(
-        onearg_constructor and
-        Match(r'(const\s+)?%s(\s*<[^>]*>)?(\s+const)?\s*(?:<\w+>\s*)?&'
-              % re.escape(base_classname), constructor_args[0].strip()))
-
-    if (not is_marked_explicit and
-        onearg_constructor and
-        not initializer_list_constructor and
-        not copy_constructor):
-      if defaulted_args:
-        error(filename, linenum, 'runtime/explicit', 5,
-              'Constructors callable with one argument '
-              'should be marked explicit.')
-      else:
-        error(filename, linenum, 'runtime/explicit', 5,
-              'Single-parameter constructors should be marked explicit.')
-    elif is_marked_explicit and not onearg_constructor:
-      if noarg_constructor:
-        error(filename, linenum, 'runtime/explicit', 5,
-              'Zero-parameter constructors should not be marked explicit.')
-      else:
-        error(filename, linenum, 'runtime/explicit', 0,
-              'Constructors that require multiple arguments '
-              'should not be marked explicit.')
-
-
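The argument-collapsing loop in the explicit-constructor check is the subtle step: a comma inside a template parameter list or a nested call must not split one constructor argument into two. A standalone sketch of just that step (hypothetical helper name):

def collapse_args(arg_string):
    """Split on commas, then re-join pieces whose unbalanced '<' or '('
    counts show the comma sat inside a template or nested parenthesis."""
    args = arg_string.split(',') if arg_string else []
    i = 0
    while i < len(args):
        arg = args[i]
        while (arg.count('<') > arg.count('>') or
               arg.count('(') > arg.count(')')):
            arg += ',' + args[i + 1]
            del args[i + 1]
        args[i] = arg
        i += 1
    return args

assert collapse_args('const std::map<int, float>& m, int n') == \
    ['const std::map<int, float>& m', ' int n']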
-def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error):
-  """Checks for the correctness of various spacing around function calls.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Since function calls often occur inside if/for/while/switch
-  # expressions - which have their own, more liberal conventions - we
-  # first see if we should be looking inside such an expression for a
-  # function call, to which we can apply more strict standards.
-  fncall = line    # if there's no control flow construct, look at whole line
-  for pattern in (r'\bif\s*\((.*)\)\s*{',
-                  r'\bfor\s*\((.*)\)\s*{',
-                  r'\bwhile\s*\((.*)\)\s*[{;]',
-                  r'\bswitch\s*\((.*)\)\s*{'):
-    match = Search(pattern, line)
-    if match:
-      fncall = match.group(1)    # look inside the parens for function calls
-      break
-
-  # Except in if/for/while/switch, there should never be space
-  # immediately inside parens (eg "f( 3, 4 )").  We make an exception
-  # for nested parens ( (a+b) + c ).  Likewise, there should never be
-  # a space before a ( when it's a function argument.  I assume it's a
-  # function argument when the char before the whitespace is legal in
-  # a function name (alnum + _) and we're not starting a macro. Also ignore
-  # pointers and references to arrays and functions because they're too tricky:
-  # we use a very simple way to recognize these:
-  # " (something)(maybe-something)" or
-  # " (something)(maybe-something," or
-  # " (something)[something]"
-  # Note that we assume the contents of [] to be short enough that
-  # they'll never need to wrap.
-  if (  # Ignore control structures.
-      not Search(r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b',
-                 fncall) and
-      # Ignore pointers/references to functions.
-      not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and
-      # Ignore pointers/references to arrays.
-      not Search(r' \([^)]+\)\[[^\]]+\]', fncall)):
-    if Search(r'\w\s*\(\s(?!\s*\\$)', fncall):      # a ( used for a fn call
-      error(filename, linenum, 'whitespace/parens', 4,
-            'Extra space after ( in function call')
-    elif Search(r'\(\s+(?!(\s*\\)|\()', fncall):
-      error(filename, linenum, 'whitespace/parens', 2,
-            'Extra space after (')
-    if (Search(r'\w\s+\(', fncall) and
-        not Search(r'#\s*define|typedef|using\s+\w+\s*=', fncall) and
-        not Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall) and
-        not Search(r'\bcase\s+\(', fncall)):
-      # TODO(unknown): Space after an operator function seems to be a common
-      # error, silence those for now by restricting them to highest verbosity.
-      if Search(r'\boperator_*\b', line):
-        error(filename, linenum, 'whitespace/parens', 0,
-              'Extra space before ( in function call')
-      else:
-        error(filename, linenum, 'whitespace/parens', 4,
-              'Extra space before ( in function call')
-    # If the ) is followed only by a newline or a { + newline, assume it's
-    # part of a control statement (if/while/etc), and don't complain
-    if Search(r'[^)]\s+\)\s*[^{\s]', fncall):
-      # If the closing parenthesis is preceded by only whitespaces,
-      # try to give a more descriptive error message.
-      if Search(r'^\s+\)', fncall):
-        error(filename, linenum, 'whitespace/parens', 2,
-              'Closing ) should be moved to the previous line')
-      else:
-        error(filename, linenum, 'whitespace/parens', 2,
-              'Extra space before )')
-
-
-def IsBlankLine(line):
-  """Returns true if the given line is blank.
-
-  We consider a line to be blank if the line is empty or consists of
-  only white spaces.
-
-  Args:
-    line: A line of a string.
-
-  Returns:
-    True, if the given line is blank.
-  """
-  return not line or line.isspace()
-
-
-def CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
-                                 error):
-  is_namespace_indent_item = (
-      len(nesting_state.stack) > 1 and
-      nesting_state.stack[-1].check_namespace_indentation and
-      isinstance(nesting_state.previous_stack_top, _NamespaceInfo) and
-      nesting_state.previous_stack_top == nesting_state.stack[-2])
-
-  if ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
-                                     clean_lines.elided, line):
-    CheckItemIndentationInNamespace(filename, clean_lines.elided,
-                                    line, error)
-
-
-def CheckForFunctionLengths(filename, clean_lines, linenum,
-                            function_state, error):
-  """Reports for long function bodies.
-
-  For an overview why this is done, see:
-  http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions
-
-  Uses a simplistic algorithm assuming other style guidelines
-  (especially spacing) are followed.
-  Only checks unindented functions, so class members are unchecked.
-  Trivial bodies are unchecked, so constructors with huge initializer lists
-  may be missed.
-  Blank/comment lines are not counted so as to avoid encouraging the removal
-  of vertical space and comments just to get through a lint check.
-  NOLINT *on the last line of a function* disables this check.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    function_state: Current function name and lines in body so far.
-    error: The function to call with any errors found.
-  """
-  lines = clean_lines.lines
-  line = lines[linenum]
-  joined_line = ''
-
-  starting_func = False
-  regexp = r'(\w(\w|::|\*|\&|\s)*)\('  # decls * & space::name( ...
-  match_result = Match(regexp, line)
-  if match_result:
-    # If the name is all caps and underscores, figure it's a macro and
-    # ignore it, unless it's TEST or TEST_F.
-    function_name = match_result.group(1).split()[-1]
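-    # For example, a hypothetical declaration line "void Foo::Bar(int x) {"
-    # gives group(1) == "void Foo::Bar", so function_name becomes "Foo::Bar".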
-    if function_name == 'TEST' or function_name == 'TEST_F' or (
-        not Match(r'[A-Z_]+$', function_name)):
-      starting_func = True
-
-  if starting_func:
-    body_found = False
-    for start_linenum in xrange(linenum, clean_lines.NumLines()):
-      start_line = lines[start_linenum]
-      joined_line += ' ' + start_line.lstrip()
-      if Search(r'(;|})', start_line):  # Declarations and trivial functions
-        body_found = True
-        break                              # ... ignore
-      elif Search(r'{', start_line):
-        body_found = True
-        function = Search(r'((\w|:)*)\(', line).group(1)
-        if Match(r'TEST', function):    # Handle TEST... macros
-          parameter_regexp = Search(r'(\(.*\))', joined_line)
-          if parameter_regexp:             # Ignore bad syntax
-            function += parameter_regexp.group(1)
-        else:
-          function += '()'
-        function_state.Begin(function)
-        break
-    if not body_found:
-      # No body for the function (or evidence of a non-function) was found.
-      error(filename, linenum, 'readability/fn_size', 5,
-            'Lint failed to find start of function body.')
-  elif Match(r'^\}\s*$', line):  # function end
-    function_state.Check(error, filename, linenum)
-    function_state.End()
-  elif not Match(r'^\s*$', line):
-    function_state.Count()  # Count non-blank/non-comment lines.
-
-
-_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?')
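-# Illustrative (hypothetical) inputs for the pattern above:
-#   "// TODO(user): fix"  -> groups (" ", "(user)", " "): no TODO warning
-#   "//TODO fix"          -> groups ("", None, " "): missing username in TODO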
-
-
-def CheckComment(line, filename, linenum, next_line_start, error):
-  """Checks for common mistakes in comments.
-
-  Args:
-    line: The line in question.
-    filename: The name of the current file.
-    linenum: The number of the line to check.
-    next_line_start: The first non-whitespace column of the next line.
-    error: The function to call with any errors found.
-  """
-  commentpos = line.find('//')
-  if commentpos != -1:
-    # Check if the // may be in quotes.  If so, ignore it
-    # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison
-    if (line.count('"', 0, commentpos) -
-        line.count('\\"', 0, commentpos)) % 2 == 0:   # not in quotes
-      # Allow one space for new scopes, two spaces otherwise:
-      if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and
-          ((commentpos >= 1 and
-            line[commentpos-1] not in string.whitespace) or
-           (commentpos >= 2 and
-            line[commentpos-2] not in string.whitespace))):
-        error(filename, linenum, 'whitespace/comments', 2,
-              'At least two spaces is best between code and comments')
-
-      # Checks for common mistakes in TODO comments.
-      comment = line[commentpos:]
-      match = _RE_PATTERN_TODO.match(comment)
-      if match:
-        # One whitespace is correct; zero whitespace is handled elsewhere.
-        leading_whitespace = match.group(1)
-        if len(leading_whitespace) > 1:
-          error(filename, linenum, 'whitespace/todo', 2,
-                'Too many spaces before TODO')
-
-        username = match.group(2)
-        if not username:
-          error(filename, linenum, 'readability/todo', 2,
-                'Missing username in TODO; it should look like '
-                '"// TODO(my_username): Stuff."')
-
-        middle_whitespace = match.group(3)
-        # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison
-        if middle_whitespace != ' ' and middle_whitespace != '':
-          error(filename, linenum, 'whitespace/todo', 2,
-                'TODO(my_username) should be followed by a space')
-
-      # If the comment contains an alphanumeric character, there
-      # should be a space somewhere between it and the // unless
-      # it's a /// or //! Doxygen comment.
-      if (Match(r'//[^ ]*\w', comment) and
-          not Match(r'(///|//\!)(\s+|$)', comment)):
-        error(filename, linenum, 'whitespace/comments', 4,
-              'Should have a space between // and comment')
-
-
-def CheckAccess(filename, clean_lines, linenum, nesting_state, error):
-  """Checks for improper use of DISALLOW* macros.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]  # get rid of comments and strings
-
-  matched = Match((r'\s*(DISALLOW_COPY_AND_ASSIGN|'
-                   r'DISALLOW_IMPLICIT_CONSTRUCTORS)'), line)
-  if not matched:
-    return
-  if nesting_state.stack and isinstance(nesting_state.stack[-1], _ClassInfo):
-    if nesting_state.stack[-1].access != 'private':
-      error(filename, linenum, 'readability/constructors', 3,
-            '%s must be in the private: section' % matched.group(1))
-
-  else:
-    # Found DISALLOW* macro outside a class declaration, or perhaps it
-    # was used inside a function when it should have been part of the
-    # class declaration.  We could issue a warning here, but it
-    # probably resulted in a compiler error already.
-    pass
-
-
-def CheckSpacing(filename, clean_lines, linenum, nesting_state, error):
-  """Checks for the correctness of various spacing issues in the code.
-
-  Things we check for: spaces around operators, spaces after
-  if/for/while/switch, no spaces around parens in function calls, two
-  spaces between code and comment, don't start a block with a blank
-  line, don't end a function with a blank line, don't add a blank line
-  after public/protected/private, don't have too many blank lines in a row.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-
-  # Don't use "elided" lines here, otherwise we can't check commented lines.
-  # Don't want to use "raw" either, because we don't want to check inside C++11
-  # raw strings.
-  raw = clean_lines.lines_without_raw_strings
-  line = raw[linenum]
-
-  # Before nixing comments, check if the line is blank for no good
-  # reason.  This includes the first line after a block is opened, and
-  # blank lines at the end of a function (i.e., right before a line like '}').
-  #
-  # Skip all the blank line checks if we are immediately inside a
-  # namespace body.  In other words, don't issue blank line warnings
-  # for this block:
-  #   namespace {
-  #
-  #   }
-  #
-  # A warning about missing end of namespace comments will be issued instead.
-  #
-  # Also skip blank line checks for 'extern "C"' blocks, which are formatted
-  # like namespaces.
-  if (IsBlankLine(line) and
-      not nesting_state.InNamespaceBody() and
-      not nesting_state.InExternC()):
-    elided = clean_lines.elided
-    prev_line = elided[linenum - 1]
-    prevbrace = prev_line.rfind('{')
-    # TODO(unknown): Don't complain if line before blank line, and line after,
-    #                both start with alnums and are indented the same amount.
-    #                This ignores whitespace at the start of a namespace block
-    #                because those are not usually indented.
-    if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1:
-      # OK, we have a blank line at the start of a code block.  Before we
-      # complain, we check if it is an exception to the rule: The previous
-      # non-empty line has the parameters of a function header that are indented
-      # 4 spaces (because they did not fit in an 80-column line when placed on
-      # the same line as the function name).  We also check for the case where
-      # the previous line is indented 6 spaces, which may happen when the
-      # initializers of a constructor do not fit into an 80-column line.
-      exception = False
-      if Match(r' {6}\w', prev_line):  # Initializer list?
-        # We are looking for the opening column of initializer list, which
-        # should be indented 4 spaces to cause 6 space indentation afterwards.
-        search_position = linenum-2
-        while (search_position >= 0
-               and Match(r' {6}\w', elided[search_position])):
-          search_position -= 1
-        exception = (search_position >= 0
-                     and elided[search_position][:5] == '    :')
-      else:
-        # Search for the function arguments or an initializer list.  We use a
-        # simple heuristic here: If the line is indented 4 spaces; and we have a
-        # closing paren, without the opening paren, followed by an opening brace
-        # or colon (for initializer lists) we assume that it is the last line of
-        # a function header.  If we have a colon indented 4 spaces, it is an
-        # initializer list.
-        exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)',
-                           prev_line)
-                     or Match(r' {4}:', prev_line))
-
-      if not exception:
-        error(filename, linenum, 'whitespace/blank_line', 2,
-              'Redundant blank line at the start of a code block '
-              'should be deleted.')
-    # Ignore blank lines at the end of a block in a long if-else
-    # chain, like this:
-    #   if (condition1) {
-    #     // Something followed by a blank line
-    #
-    #   } else if (condition2) {
-    #     // Something else
-    #   }
-    if linenum + 1 < clean_lines.NumLines():
-      next_line = raw[linenum + 1]
-      if (next_line
-          and Match(r'\s*}', next_line)
-          and next_line.find('} else ') == -1):
-        error(filename, linenum, 'whitespace/blank_line', 3,
-              'Redundant blank line at the end of a code block '
-              'should be deleted.')
-
-    matched = Match(r'\s*(public|protected|private):', prev_line)
-    if matched:
-      error(filename, linenum, 'whitespace/blank_line', 3,
-            'Do not leave a blank line after "%s:"' % matched.group(1))
-
-  # Next, check comments
-  next_line_start = 0
-  if linenum + 1 < clean_lines.NumLines():
-    next_line = raw[linenum + 1]
-    next_line_start = len(next_line) - len(next_line.lstrip())
-  CheckComment(line, filename, linenum, next_line_start, error)
-
-  # get rid of comments and strings
-  line = clean_lines.elided[linenum]
-
-  # You shouldn't have spaces before your brackets, except maybe after
-  # 'delete []' or 'return []() {};'
-  if Search(r'\w\s+\[', line) and not Search(r'(?:delete|return)\s+\[', line):
-    error(filename, linenum, 'whitespace/braces', 5,
-          'Extra space before [')
-
-  # In a range-based for loop, we want spaces before and after the colon, but
-  # not around "::" tokens that might appear.
-  if (Search(r'for *\(.*[^:]:[^: ]', line) or
-      Search(r'for *\(.*[^: ]:[^:]', line)):
-    error(filename, linenum, 'whitespace/forcolon', 2,
-          'Missing space around colon in range-based for loop')
-
-
-def CheckOperatorSpacing(filename, clean_lines, linenum, error):
-  """Checks for horizontal spacing around operators.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Don't try to do spacing checks for operator methods.  Do this by
-  # replacing the troublesome characters with something else,
-  # preserving column position for all other characters.
-  #
-  # The replacement is done repeatedly to avoid false positives from
-  # operators that call operators.
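-  #
-  # For example, a hypothetical line "T& operator=(const T& rhs)" would be
-  # rewritten below to "T& operator_(const T& rhs)" before further checks.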
-  while True:
-    match = Match(r'^(.*\boperator\b)(\S+)(\s*\(.*)$', line)
-    if match:
-      line = match.group(1) + ('_' * len(match.group(2))) + match.group(3)
-    else:
-      break
-
-  # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )".
-  # Otherwise not.  Note we only check for non-spaces on *both* sides;
-  # sometimes people put non-spaces on one side when aligning ='s among
-  # many lines (not that this is behavior that I approve of...)
-  if ((Search(r'[\w.]=', line) or
-       Search(r'=[\w.]', line))
-      and not Search(r'\b(if|while|for) ', line)
-      # Operators taken from [lex.operators] in C++11 standard.
-      and not Search(r'(>=|<=|==|!=|&=|\^=|\|=|\+=|\*=|\/=|\%=)', line)
-      and not Search(r'operator=', line)):
-    error(filename, linenum, 'whitespace/operators', 4,
-          'Missing spaces around =')
-
-  # It's ok not to have spaces around binary operators like + - * /, but if
-  # there's too little whitespace, we get concerned.  It's hard to tell,
-  # though, so we punt on this one for now.  TODO.
-
-  # You should always have whitespace around binary operators.
-  #
-  # Check <= and >= first to avoid false positives with < and >, then
-  # check non-include lines for spacing around < and >.
-  #
-  # If the operator is followed by a comma, assume it's being used in a
-  # macro context and don't do any checks.  This avoids false
-  # positives.
-  #
-  # Note that && is not included here.  Those are checked separately
-  # in CheckRValueReference
-  match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line)
-  if match:
-    error(filename, linenum, 'whitespace/operators', 3,
-          'Missing spaces around %s' % match.group(1))
-  elif not Match(r'#.*include', line):
-    # Look for < that is not surrounded by spaces.  This is only
-    # triggered if both sides are missing spaces, even though
-    # technically we should flag if at least one side is missing a
-    # space.  This is done to avoid some false positives with shifts.
-    match = Match(r'^(.*[^\s<])<[^\s=<,]', line)
-    if match:
-      (_, _, end_pos) = CloseExpression(
-          clean_lines, linenum, len(match.group(1)))
-      if end_pos <= -1:
-        error(filename, linenum, 'whitespace/operators', 3,
-              'Missing spaces around <')
-
-    # Look for > that is not surrounded by spaces.  Similar to the
-    # above, we only trigger if both sides are missing spaces to avoid
-    # false positives with shifts.
-    match = Match(r'^(.*[^-\s>])>[^\s=>,]', line)
-    if match:
-      (_, _, start_pos) = ReverseCloseExpression(
-          clean_lines, linenum, len(match.group(1)))
-      if start_pos <= -1:
-        error(filename, linenum, 'whitespace/operators', 3,
-              'Missing spaces around >')
-
-  # We allow no-spaces around << when used like this: 10<<20, but
-  # not otherwise (particularly, not when used as streams)
-  #
-  # We also allow operators following an opening parenthesis, since
-  # those tend to be macros that deal with operators.
-  match = Search(r'(operator|[^\s(<])(?:L|UL|ULL|l|ul|ull)?<<([^\s,=<])', line)
-  if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and
-      not (match.group(1) == 'operator' and match.group(2) == ';')):
-    error(filename, linenum, 'whitespace/operators', 3,
-          'Missing spaces around <<')
-
-  # We allow no-spaces around >> for almost anything.  This is because
-  # C++11 allows ">>" to close nested templates, which accounts for
-  # most cases when ">>" is not followed by a space.
-  #
-  # We still warn on ">>" followed by alpha character, because that is
-  # likely due to ">>" being used for right shifts, e.g.:
-  #   value >> alpha
-  #
-  # When ">>" is used to close templates, the alphanumeric letter that
-  # follows would be part of an identifier, and there should still be
-  # a space separating the template type and the identifier.
-  #   type<type<type>> alpha
-  match = Search(r'>>[a-zA-Z_]', line)
-  if match:
-    error(filename, linenum, 'whitespace/operators', 3,
-          'Missing spaces around >>')
-
-  # There shouldn't be space around unary operators
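-  # (e.g. hypothetical lines such as "if (! flag)" or "count ++ ;" would
-  # be flagged by the pattern below).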
-  match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line)
-  if match:
-    error(filename, linenum, 'whitespace/operators', 4,
-          'Extra space for operator %s' % match.group(1))
-
-
-def CheckParenthesisSpacing(filename, clean_lines, linenum, error):
-  """Checks for horizontal spacing around parentheses.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # No spaces after an if, while, switch, or for
-  match = Search(r' (if\(|for\(|while\(|switch\()', line)
-  if match:
-    error(filename, linenum, 'whitespace/parens', 5,
-          'Missing space before ( in %s' % match.group(1))
-
-  # For if/for/while/switch, the left and right parens should be
-  # consistent about how many spaces are inside the parens, and
-  # there should either be zero or one spaces inside the parens.
-  # We don't want: "if ( foo)" or "if ( foo   )".
-  # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed.
-  match = Search(r'\b(if|for|while|switch)\s*'
-                 r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$',
-                 line)
-  if match:
-    if len(match.group(2)) != len(match.group(4)):
-      if not (match.group(3) == ';' and
-              len(match.group(2)) == 1 + len(match.group(4)) or
-              not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)):
-        error(filename, linenum, 'whitespace/parens', 5,
-              'Mismatching spaces inside () in %s' % match.group(1))
-    if len(match.group(2)) not in [0, 1]:
-      error(filename, linenum, 'whitespace/parens', 5,
-            'Should have zero or one spaces inside ( and ) in %s' %
-            match.group(1))
-
-
-def CheckCommaSpacing(filename, clean_lines, linenum, error):
-  """Checks for horizontal spacing near commas and semicolons.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  raw = clean_lines.lines_without_raw_strings
-  line = clean_lines.elided[linenum]
-
-  # You should always have a space after a comma (either as fn arg or operator)
-  #
-  # This does not apply when the non-space character following the
-  # comma is another comma, since the only time when that happens is
-  # for empty macro arguments.
-  #
-  # We run this check in two passes: first pass on elided lines to
-  # verify that lines contain missing whitespaces, second pass on raw
-  # lines to confirm that those missing whitespaces are not due to
-  # elided comments.
-  if (Search(r',[^,\s]', ReplaceAll(r'\boperator\s*,\s*\(', 'F(', line)) and
-      Search(r',[^,\s]', raw[linenum])):
-    error(filename, linenum, 'whitespace/comma', 3,
-          'Missing space after ,')
-
-  # You should always have a space after a semicolon
-  # except for a few corner cases.
-  # TODO(unknown): clarify whether 'if (1) { return 1;}' requires one more
-  # space after ;
-  if Search(r';[^\s};\\)/]', line):
-    error(filename, linenum, 'whitespace/semicolon', 3,
-          'Missing space after ;')
-
-
-def CheckBracesSpacing(filename, clean_lines, linenum, error):
-  """Checks for horizontal spacing near commas.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Except after an opening paren, or after another opening brace (in case of
-  # an initializer list, for instance), you should have spaces before your
-  # braces. And since you should never have braces at the beginning of a line,
-  # this is an easy test.
-  match = Match(r'^(.*[^ ({>]){', line)
-  if match:
-    # Try a bit harder to check for brace initialization.  This
-    # happens in one of the following forms:
-    #   Constructor() : initializer_list_{} { ... }
-    #   Constructor{}.MemberFunction()
-    #   Type variable{};
-    #   FunctionCall(type{}, ...);
-    #   LastArgument(..., type{});
-    #   LOG(INFO) << type{} << " ...";
-    #   map_of_type[{...}] = ...;
-    #   ternary = expr ? new type{} : nullptr;
-    #   OuterTemplate<InnerTemplateConstructor<Type>{}>
-    #
-    # We check for the character following the closing brace, and
-    # silence the warning if it's one of those listed above, i.e.
-    # "{.;,)<>]:".
-    #
-    # To account for nested initializer list, we allow any number of
-    # closing braces up to "{;,)<".  We can't simply silence the
-    # warning on first sight of closing brace, because that would
-    # cause false negatives for things that are not initializer lists.
-    #   Silence this:         But not this:
-    #     Outer{                if (...) {
-    #       Inner{...}            if (...){  // Missing space before {
-    #     };                    }
-    #
-    # There is a false negative with this approach if people inserted
-    # spurious semicolons, e.g. "if (cond){};", but we will catch the
-    # spurious semicolon with a separate check.
-    (endline, endlinenum, endpos) = CloseExpression(
-        clean_lines, linenum, len(match.group(1)))
-    trailing_text = ''
-    if endpos > -1:
-      trailing_text = endline[endpos:]
-    for offset in xrange(endlinenum + 1,
-                         min(endlinenum + 3, clean_lines.NumLines() - 1)):
-      trailing_text += clean_lines.elided[offset]
-    if not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text):
-      error(filename, linenum, 'whitespace/braces', 5,
-            'Missing space before {')
-
-  # Make sure '} else {' has spaces.
-  if Search(r'}else', line):
-    error(filename, linenum, 'whitespace/braces', 5,
-          'Missing space before else')
-
-  # You shouldn't have a space before a semicolon at the end of the line.
-  # There's a special case for "for" since the style guide allows space before
-  # the semicolon there.
-  if Search(r':\s*;\s*$', line):
-    error(filename, linenum, 'whitespace/semicolon', 5,
-          'Semicolon defining empty statement. Use {} instead.')
-  elif Search(r'^\s*;\s*$', line):
-    error(filename, linenum, 'whitespace/semicolon', 5,
-          'Line contains only semicolon. If this should be an empty statement, '
-          'use {} instead.')
-  elif (Search(r'\s+;\s*$', line) and
-        not Search(r'\bfor\b', line)):
-    error(filename, linenum, 'whitespace/semicolon', 5,
-          'Extra space before last semicolon. If this should be an empty '
-          'statement, use {} instead.')
-
-
-def IsDecltype(clean_lines, linenum, column):
-  """Check if the token ending on (linenum, column) is decltype().
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: the number of the line to check.
-    column: end column of the token to check.
-  Returns:
-    True if this token is decltype() expression, False otherwise.
-  """
-  (text, _, start_col) = ReverseCloseExpression(clean_lines, linenum, column)
-  if start_col < 0:
-    return False
-  if Search(r'\bdecltype\s*$', text[0:start_col]):
-    return True
-  return False
-
-
-def IsTemplateParameterList(clean_lines, linenum, column):
-  """Check if the token ending on (linenum, column) is the end of template<>.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: the number of the line to check.
-    column: end column of the token to check.
-  Returns:
-    True if this token is end of a template parameter list, False otherwise.
-  """
-  (_, startline, startpos) = ReverseCloseExpression(
-      clean_lines, linenum, column)
-  if (startpos > -1 and
-      Search(r'\btemplate\s*$', clean_lines.elided[startline][0:startpos])):
-    return True
-  return False
-
-
-def IsRValueType(typenames, clean_lines, nesting_state, linenum, column):
-  """Check if the token ending on (linenum, column) is a type.
-
-  Assumes that text to the right of the column is "&&" or a function
-  name.
-
-  Args:
-    typenames: set of type names from template-argument-list.
-    clean_lines: A CleansedLines instance containing the file.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    linenum: the number of the line to check.
-    column: end column of the token to check.
-  Returns:
-    True if this token is a type, False if we are not sure.
-  """
-  prefix = clean_lines.elided[linenum][0:column]
-
-  # Get one word to the left.  If we failed to do so, this is most
-  # likely not a type, since it's unlikely that the type name and "&&"
-  # would be split across multiple lines.
-  match = Match(r'^(.*)(\b\w+|[>*)&])\s*$', prefix)
-  if not match:
-    return False
-
-  # Check text following the token.  If it's "&&>" or "&&," or "&&...", it's
-  # most likely an rvalue reference used inside a template.
-  suffix = clean_lines.elided[linenum][column:]
-  if Match(r'&&\s*(?:[>,]|\.\.\.)', suffix):
-    return True
-
-  # Check for known types and end of templates:
-  #   int&& variable
-  #   vector<int>&& variable
-  #
-  # Because this function is called recursively, we also need to
-  # recognize pointer and reference types:
-  #   int* Function()
-  #   int& Function()
-  if (match.group(2) in typenames or
-      match.group(2) in ['char', 'char16_t', 'char32_t', 'wchar_t', 'bool',
-                         'short', 'int', 'long', 'signed', 'unsigned',
-                         'float', 'double', 'void', 'auto', '>', '*', '&']):
-    return True
-
-  # If we see a close parenthesis, look for decltype on the other side.
-  # decltype would unambiguously identify a type, anything else is
-  # probably a parenthesized expression and not a type.
-  if match.group(2) == ')':
-    return IsDecltype(
-        clean_lines, linenum, len(match.group(1)) + len(match.group(2)) - 1)
-
-  # Check for casts and cv-qualifiers.
-  #   match.group(1)  remainder
-  #   --------------  ---------
-  #   const_cast<     type&&
-  #   const           type&&
-  #   type            const&&
-  if Search(r'\b(?:const_cast\s*<|static_cast\s*<|dynamic_cast\s*<|'
-            r'reinterpret_cast\s*<|\w+\s)\s*$',
-            match.group(1)):
-    return True
-
-  # Look for a preceding symbol that might help differentiate the context.
-  # These are the cases that would be ambiguous:
-  #   match.group(1)  remainder
-  #   --------------  ---------
-  #   Call         (   expression &&
-  #   Declaration  (   type&&
-  #   sizeof       (   type&&
-  #   if           (   expression &&
-  #   while        (   expression &&
-  #   for          (   type&&
-  #   for(         ;   expression &&
-  #   statement    ;   type&&
-  #   block        {   type&&
-  #   constructor  {   expression &&
-  start = linenum
-  line = match.group(1)
-  match_symbol = None
-  while start >= 0:
-    # We want to skip over identifiers and commas to get to a symbol.
-    # Commas are skipped so that we can find the opening parenthesis
-    # for function parameter lists.
-    match_symbol = Match(r'^(.*)([^\w\s,])[\w\s,]*$', line)
-    if match_symbol:
-      break
-    start -= 1
-    line = clean_lines.elided[start]
-
-  if not match_symbol:
-    # Probably the first statement in the file is an rvalue reference
-    return True
-
-  if match_symbol.group(2) == '}':
-    # Found a closing brace, probably an indication of this:
-    #   block{} type&&
-    return True
-
-  if match_symbol.group(2) == ';':
-    # Found semicolon, probably one of these:
-    #   for(; expression &&
-    #   statement; type&&
-
-    # Look for the previous 'for(' in the previous lines.
-    before_text = match_symbol.group(1)
-    for i in xrange(start - 1, max(start - 6, 0), -1):
-      before_text = clean_lines.elided[i] + before_text
-    if Search(r'for\s*\([^{};]*$', before_text):
-      # This is the condition inside a for-loop
-      return False
-
-    # Did not find a for-init-statement before this semicolon, so this
-    # is probably a new statement and not a condition.
-    return True
-
-  if match_symbol.group(2) == '{':
-    # Found opening brace, probably one of these:
-    #   block{ type&& = ... ; }
-    #   constructor{ expression && expression }
-
-    # Look for a closing brace or a semicolon.  If we see a semicolon
-    # first, this is probably an rvalue reference.
-    line = clean_lines.elided[start][0:len(match_symbol.group(1)) + 1]
-    end = start
-    depth = 1
-    while True:
-      for ch in line:
-        if ch == ';':
-          return True
-        elif ch == '{':
-          depth += 1
-        elif ch == '}':
-          depth -= 1
-          if depth == 0:
-            return False
-      end += 1
-      if end >= clean_lines.NumLines():
-        break
-      line = clean_lines.elided[end]
-    # Incomplete program?
-    return False
-
-  if match_symbol.group(2) == '(':
-    # Opening parenthesis.  Need to check what's to the left of the
-    # parenthesis.
-    before_text = match_symbol.group(1)
-
-    # Patterns that are likely to be types:
-    #   [](type&&
-    #   for (type&&
-    #   sizeof(type&&
-    #   operator=(type&&
-    #
-    if Search(r'(?:\]|\bfor|\bsizeof|\boperator\s*\S+\s*)\s*$', before_text):
-      return True
-
-    # Patterns that are likely to be expressions:
-    #   if (expression &&
-    #   while (expression &&
-    #   : initializer(expression &&
-    #   , initializer(expression &&
-    #   ( FunctionCall(expression &&
-    #   + FunctionCall(expression &&
-    #   + (expression &&
-    #
-    # The last '+' represents operators such as '+' and '-'.
-    if Search(r'(?:\bif|\bwhile|[-+=%^(<!?:,&*]\s*)$', before_text):
-      return False
-
-    # Something else.  Check that tokens to the left look like
-    #   return_type function_name
-    match_func = Match(r'^(.*\S.*)\s+\w(?:\w|::)*(?:<[^<>]*>)?\s*$',
-                       match_symbol.group(1))
-    if match_func:
-      # Check for constructors, which don't have return types.
-      if Search(r'\b(?:explicit|inline)$', match_func.group(1)):
-        return True
-      implicit_constructor = Match(r'\s*(\w+)\((?:const\s+)?(\w+)', prefix)
-      if (implicit_constructor and
-          implicit_constructor.group(1) == implicit_constructor.group(2)):
-        return True
-      return IsRValueType(typenames, clean_lines, nesting_state, linenum,
-                          len(match_func.group(1)))
-
-    # Nothing before the function name.  If this is inside a block scope,
-    # this is probably a function call.
-    return not (nesting_state.previous_stack_top and
-                nesting_state.previous_stack_top.IsBlockInfo())
-
-  if match_symbol.group(2) == '>':
-    # Possibly a closing bracket, check that what's on the other side
-    # looks like the start of a template.
-    return IsTemplateParameterList(
-        clean_lines, start, len(match_symbol.group(1)))
-
-  # Some other symbol, usually something like "a=b&&c".  This is most
-  # likely not a type.
-  return False
-
-
-def IsDeletedOrDefault(clean_lines, linenum):
-  """Check if current constructor or operator is deleted or default.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-  Returns:
-    True if this is a deleted or default constructor.
-  """
-  open_paren = clean_lines.elided[linenum].find('(')
-  if open_paren < 0:
-    return False
-  (close_line, _, close_paren) = CloseExpression(
-      clean_lines, linenum, open_paren)
-  if close_paren < 0:
-    return False
-  return Match(r'\s*=\s*(?:delete|default)\b', close_line[close_paren:])
-
-
-def IsRValueAllowed(clean_lines, linenum, typenames):
-  """Check if RValue reference is allowed on a particular line.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    typenames: set of type names from template-argument-list.
-  Returns:
-    True if line is within the region where RValue references are allowed.
-  """
-  # Allow region marked by PUSH/POP macros
-  for i in xrange(linenum, 0, -1):
-    line = clean_lines.elided[i]
-    if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line):
-      if not line.endswith('PUSH'):
-        return False
-      for j in xrange(linenum, clean_lines.NumLines(), 1):
-        line = clean_lines.elided[j]
-        if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line):
-          return line.endswith('POP')
-
-  # Allow operator=
-  line = clean_lines.elided[linenum]
-  if Search(r'\boperator\s*=\s*\(', line):
-    return IsDeletedOrDefault(clean_lines, linenum)
-
-  # Allow constructors
-  match = Match(r'\s*(?:[\w<>]+::)*([\w<>]+)\s*::\s*([\w<>]+)\s*\(', line)
-  if match and match.group(1) == match.group(2):
-    return IsDeletedOrDefault(clean_lines, linenum)
-  if Search(r'\b(?:explicit|inline)\s+[\w<>]+\s*\(', line):
-    return IsDeletedOrDefault(clean_lines, linenum)
-
-  if Match(r'\s*[\w<>]+\s*\(', line):
-    previous_line = 'ReturnType'
-    if linenum > 0:
-      previous_line = clean_lines.elided[linenum - 1]
-    if Match(r'^\s*$', previous_line) or Search(r'[{}:;]\s*$', previous_line):
-      return IsDeletedOrDefault(clean_lines, linenum)
-
-  # Reject types not mentioned in template-argument-list
-  while line:
-    match = Match(r'^.*?(\w+)\s*&&(.*)$', line)
-    if not match:
-      break
-    if match.group(1) not in typenames:
-      return False
-    line = match.group(2)
-
-  # All RValue types that were in template-argument-list should have
-  # been removed by now.  Those were allowed, assuming that they will
-  # be forwarded.
-  #
-  # If there are no remaining RValue types left (i.e. types that were
-  # not found in template-argument-list), flag those as not allowed.
-  return line.find('&&') < 0
-
-
-def GetTemplateArgs(clean_lines, linenum):
-  """Find list of template arguments associated with this function declaration.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: Line number containing the start of the function declaration,
-             usually one line after the end of the template-argument-list.
-  Returns:
-    Set of type names, or empty set if this does not appear to have
-    any template parameters.
-  """
-  # Find start of function
-  func_line = linenum
-  while func_line > 0:
-    line = clean_lines.elided[func_line]
-    if Match(r'^\s*$', line):
-      return set()
-    if line.find('(') >= 0:
-      break
-    func_line -= 1
-  if func_line == 0:
-    return set()
-
-  # Collapse template-argument-list into a single string
-  argument_list = ''
-  match = Match(r'^(\s*template\s*)<', clean_lines.elided[func_line])
-  if match:
-    # template-argument-list on the same line as function name
-    start_col = len(match.group(1))
-    _, end_line, end_col = CloseExpression(clean_lines, func_line, start_col)
-    if end_col > -1 and end_line == func_line:
-      start_col += 1  # Skip the opening bracket
-      argument_list = clean_lines.elided[func_line][start_col:end_col]
-
-  elif func_line > 1:
-    # template-argument-list one line before function name
-    match = Match(r'^(.*)>\s*$', clean_lines.elided[func_line - 1])
-    if match:
-      end_col = len(match.group(1))
-      _, start_line, start_col = ReverseCloseExpression(
-          clean_lines, func_line - 1, end_col)
-      if start_col > -1:
-        start_col += 1  # Skip the opening bracket
-        while start_line < func_line - 1:
-          argument_list += clean_lines.elided[start_line][start_col:]
-          start_col = 0
-          start_line += 1
-        argument_list += clean_lines.elided[func_line - 1][start_col:end_col]
-
-  if not argument_list:
-    return set()
-
-  # Extract type names
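-  # (e.g. a hypothetical argument_list "typename T, class U" yields
-  # {"T", "U"}; extraction stops at the first parameter that is not
-  # introduced by typename or class).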
-  typenames = set()
-  while True:
-    match = Match(r'^[,\s]*(?:typename|class)(?:\.\.\.)?\s+(\w+)(.*)$',
-                  argument_list)
-    if not match:
-      break
-    typenames.add(match.group(1))
-    argument_list = match.group(2)
-  return typenames
-
-
-def CheckRValueReference(filename, clean_lines, linenum, nesting_state, error):
-  """Check for rvalue references.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-  # Find lines missing spaces around &&.
-  # TODO(unknown): currently we don't check for rvalue references
-  # with spaces surrounding the && to avoid false positives with
-  # boolean expressions.
-  line = clean_lines.elided[linenum]
-  match = Match(r'^(.*\S)&&', line)
-  if not match:
-    match = Match(r'(.*)&&\S', line)
-  if (not match) or '(&&)' in line or Search(r'\boperator\s*$', match.group(1)):
-    return
-
-  # Either poorly formed && or an rvalue reference, check the context
-  # to get a more accurate error message.  Mostly we want to determine
-  # if what's to the left of "&&" is a type or not.
-  typenames = GetTemplateArgs(clean_lines, linenum)
-  and_pos = len(match.group(1))
-  if IsRValueType(typenames, clean_lines, nesting_state, linenum, and_pos):
-    if not IsRValueAllowed(clean_lines, linenum, typenames):
-      error(filename, linenum, 'build/c++11', 3,
-            'RValue references are an unapproved C++ feature.')
-  else:
-    error(filename, linenum, 'whitespace/operators', 3,
-          'Missing spaces around &&')
-
-
-def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error):
-  """Checks for additional blank line issues related to sections.
-
-  Currently the only thing checked here is blank line before protected/private.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    class_info: A _ClassInfo objects.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  # Skip checks if the class is small, where small means 25 lines or less.
-  # 25 lines seems like a good cutoff since that's the usual height of
-  # terminals, and any class that can't fit in one screen can't really
-  # be considered "small".
-  #
-  # Also skip checks if we are on the first line.  This accounts for
-  # classes that look like
-  #   class Foo { public: ... };
-  #
-  # If we didn't find the end of the class, last_line would be zero,
-  # and the check will be skipped by the first condition.
-  if (class_info.last_line - class_info.starting_linenum <= 24 or
-      linenum <= class_info.starting_linenum):
-    return
-
-  matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum])
-  if matched:
-    # Issue warning if the line before public/protected/private was
-    # not a blank line, but don't do this if the previous line contains
-    # "class" or "struct".  This can happen two ways:
-    #  - We are at the beginning of the class.
-    #  - We are forward-declaring an inner class that is semantically
-    #    private, but needed to be public for implementation reasons.
-    # Also ignores cases where the previous line ends with a backslash as can be
-    # common when defining classes in C macros.
-    prev_line = clean_lines.lines[linenum - 1]
-    if (not IsBlankLine(prev_line) and
-        not Search(r'\b(class|struct)\b', prev_line) and
-        not Search(r'\\$', prev_line)):
-      # Try a bit harder to find the beginning of the class.  This is to
-      # account for multi-line base-specifier lists, e.g.:
-      #   class Derived
-      #       : public Base {
-      end_class_head = class_info.starting_linenum
-      for i in range(class_info.starting_linenum, linenum):
-        if Search(r'\{\s*$', clean_lines.lines[i]):
-          end_class_head = i
-          break
-      if end_class_head < linenum - 1:
-        error(filename, linenum, 'whitespace/blank_line', 3,
-              '"%s:" should be preceded by a blank line' % matched.group(1))
-
-
-def GetPreviousNonBlankLine(clean_lines, linenum):
-  """Return the most recent non-blank line and its line number.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file contents.
-    linenum: The number of the line to check.
-
-  Returns:
-    A tuple with two elements.  The first element is the contents of the last
-    non-blank line before the current line, or the empty string if this is the
-    first non-blank line.  The second is the line number of that line, or -1
-    if this is the first non-blank line.
-  """
-
-  prevlinenum = linenum - 1
-  while prevlinenum >= 0:
-    prevline = clean_lines.elided[prevlinenum]
-    if not IsBlankLine(prevline):     # if not a blank line...
-      return (prevline, prevlinenum)
-    prevlinenum -= 1
-  return ('', -1)
-
-
-def CheckBraces(filename, clean_lines, linenum, error):
-  """Looks for misplaced braces (e.g. at the end of line).
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-
-  line = clean_lines.elided[linenum]        # get rid of comments and strings
-
-  if Match(r'\s*{\s*$', line):
-    # We allow an open brace to start a line in the case where someone is using
-    # braces in a block to explicitly create a new scope, which is commonly used
-    # to control the lifetime of stack-allocated variables.  Braces are also
-    # used for brace initializers inside function calls.  We don't detect this
-    # perfectly: we just don't complain if the last non-whitespace character on
-    # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the
-    # previous line starts a preprocessor block.
-    prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
-    if (not Search(r'[,;:}{(]\s*$', prevline) and
-        not Match(r'\s*#', prevline)):
-      error(filename, linenum, 'whitespace/braces', 4,
-            '{ should almost always be at the end of the previous line')
-
-  # An else clause should be on the same line as the preceding closing brace.
-  if Match(r'\s*else\b\s*(?:if\b|\{|$)', line):
-    prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
-    if Match(r'\s*}\s*$', prevline):
-      error(filename, linenum, 'whitespace/newline', 4,
-            'An else should appear on the same line as the preceding }')
-
-  # If braces come on one side of an else, they should be on both.
-  # However, we have to worry about "else if" that spans multiple lines!
-  if Search(r'else if\s*\(', line):       # could be multi-line if
-    brace_on_left = bool(Search(r'}\s*else if\s*\(', line))
-    # find the ( after the if
-    pos = line.find('else if')
-    pos = line.find('(', pos)
-    if pos > 0:
-      (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos)
-      brace_on_right = endline[endpos:].find('{') != -1
-      if brace_on_left != brace_on_right:    # must be brace after if
-        error(filename, linenum, 'readability/braces', 5,
-              'If an else has a brace on one side, it should have it on both')
-  elif Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line):
-    error(filename, linenum, 'readability/braces', 5,
-          'If an else has a brace on one side, it should have it on both')
-
-  # Likewise, an else should never have the else clause on the same line
-  if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line):
-    error(filename, linenum, 'whitespace/newline', 4,
-          'Else clause should never be on same line as else (use 2 lines)')
-
-  # In the same way, a do/while should never be on one line
-  if Match(r'\s*do [^\s{]', line):
-    error(filename, linenum, 'whitespace/newline', 4,
-          'do/while clauses should not be on a single line')
-
-  # Check single-line if/else bodies. The style guide says 'curly braces are not
-  # required for single-line statements'. We additionally allow multi-line
-  # single statements, but we reject anything with more than one semicolon in
-  # it. This means that the first semicolon after the if should be at the end of
-  # its line, and the line after that should have an indent level equal to or
-  # lower than the if. We also check for ambiguous if/else nesting without
-  # braces.
-  if_else_match = Search(r'\b(if\s*\(|else\b)', line)
-  if if_else_match and not Match(r'\s*#', line):
-    if_indent = GetIndentLevel(line)
-    endline, endlinenum, endpos = line, linenum, if_else_match.end()
-    if_match = Search(r'\bif\s*\(', line)
-    if if_match:
-      # This could be a multiline if condition, so find the end first.
-      pos = if_match.end() - 1
-      (endline, endlinenum, endpos) = CloseExpression(clean_lines, linenum, pos)
-    # Check for an opening brace, either directly after the if or on the next
-    # line. If found, this isn't a single-statement conditional.
-    if (not Match(r'\s*{', endline[endpos:])
-        and not (Match(r'\s*$', endline[endpos:])
-                 and endlinenum < (len(clean_lines.elided) - 1)
-                 and Match(r'\s*{', clean_lines.elided[endlinenum + 1]))):
-      while (endlinenum < len(clean_lines.elided)
-             and ';' not in clean_lines.elided[endlinenum][endpos:]):
-        endlinenum += 1
-        endpos = 0
-      if endlinenum < len(clean_lines.elided):
-        endline = clean_lines.elided[endlinenum]
-        # We allow a mix of whitespace and closing braces (e.g. for one-liner
-        # methods) and a single \ after the semicolon (for macros)
-        endpos = endline.find(';')
-        if not Match(r';[\s}]*(\\?)$', endline[endpos:]):
-          # Semicolon isn't the last character, there's something trailing.
-          # Output a warning if the semicolon is not contained inside
-          # a lambda expression.
-          if not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}]*\}\s*\)*[;,]\s*$',
-                       endline):
-            error(filename, linenum, 'readability/braces', 4,
-                  'If/else bodies with multiple statements require braces')
-        elif endlinenum < len(clean_lines.elided) - 1:
-          # Make sure the next line is dedented
-          next_line = clean_lines.elided[endlinenum + 1]
-          next_indent = GetIndentLevel(next_line)
-          # With ambiguous nested if statements, this will error out on the
-          # if that *doesn't* match the else, regardless of whether it's the
-          # inner one or outer one.
-          if (if_match and Match(r'\s*else\b', next_line)
-              and next_indent != if_indent):
-            error(filename, linenum, 'readability/braces', 4,
-                  'Else clause should be indented at the same level as if. '
-                  'Ambiguous nested if/else chains require braces.')
-          elif next_indent > if_indent:
-            error(filename, linenum, 'readability/braces', 4,
-                  'If/else bodies with multiple statements require braces')
-
-
-def CheckTrailingSemicolon(filename, clean_lines, linenum, error):
-  """Looks for redundant trailing semicolon.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-
-  line = clean_lines.elided[linenum]
-
-  # Block bodies should not be followed by a semicolon.  Due to C++11
-  # brace initialization, there are more places where semicolons are
-  # required than not, so we use a whitelist approach to check these
-  # rather than a blacklist.  These are the places where "};" should
-  # be replaced by just "}":
-  # 1. Some flavor of block following closing parenthesis:
-  #    for (;;) {};
-  #    while (...) {};
-  #    switch (...) {};
-  #    Function(...) {};
-  #    if (...) {};
-  #    if (...) else if (...) {};
-  #
-  # 2. else block:
-  #    if (...) else {};
-  #
-  # 3. const member function:
-  #    Function(...) const {};
-  #
-  # 4. Block following some statement:
-  #    x = 42;
-  #    {};
-  #
-  # 5. Block at the beginning of a function:
-  #    Function(...) {
-  #      {};
-  #    }
-  #
-  #    Note that naively checking for the preceding "{" will also match
-  #    braces inside multi-dimensional arrays, but this is fine since
-  #    that expression will not contain semicolons.
-  #
-  # 6. Block following another block:
-  #    while (true) {}
-  #    {};
-  #
-  # 7. End of namespaces:
-  #    namespace {};
-  #
-    #    These semicolons seem far more common than other kinds of
-  #    redundant semicolons, possibly due to people converting classes
-  #    to namespaces.  For now we do not warn for this case.
-  #
-  # Try matching case 1 first.
-  match = Match(r'^(.*\)\s*)\{', line)
-  if match:
-    # Matched closing parenthesis (case 1).  Check the token before the
-    # matching opening parenthesis, and don't warn if it looks like a
-    # macro.  This avoids these false positives:
-    #  - macro that defines a base class
-    #  - multi-line macro that defines a base class
-    #  - macro that defines the whole class-head
-    #
-    # But we still issue warnings for macros that we know are safe to
-    # warn, specifically:
-    #  - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P
-    #  - TYPED_TEST
-    #  - INTERFACE_DEF
-    #  - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED:
-    #
-    # We implement a whitelist of safe macros instead of a blacklist of
-    # unsafe macros, even though the latter appears less frequently in
-    # google code and would have been easier to implement.  This is because
-    # the downside of getting the whitelist wrong is only some extra
-    # semicolons, while getting the blacklist wrong
-    # would result in compile errors.
-    #
-    # In addition to macros, we also don't want to warn on
-    #  - Compound literals
-    #  - Lambdas
-    #  - alignas specifier with anonymous structs:
-    closing_brace_pos = match.group(1).rfind(')')
-    opening_parenthesis = ReverseCloseExpression(
-        clean_lines, linenum, closing_brace_pos)
-    if opening_parenthesis[2] > -1:
-      line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]]
-      macro = Search(r'\b([A-Z_]+)\s*$', line_prefix)
-      func = Match(r'^(.*\])\s*$', line_prefix)
-      if ((macro and
-           macro.group(1) not in (
-               'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST',
-               'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED',
-               'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or
-          (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or
-          Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or
-          Search(r'\s+=\s*$', line_prefix)):
-        match = None
-    if (match and
-        opening_parenthesis[1] > 1 and
-        Search(r'\]\s*$', clean_lines.elided[opening_parenthesis[1] - 1])):
-      # Multi-line lambda-expression
-      match = None
-
-  else:
-    # Try matching cases 2-3.
-    match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line)
-    if not match:
-      # Try matching cases 4-6.  These are always matched on separate lines.
-      #
-      # Note that we can't simply concatenate the previous line to the
-      # current line and do a single match, otherwise we may output
-      # duplicate warnings for the blank line case:
-      #   if (cond) {
-      #     // blank line
-      #   }
-      prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0]
-      if prevline and Search(r'[;{}]\s*$', prevline):
-        match = Match(r'^(\s*)\{', line)
-
-  # Check matching closing brace
-  if match:
-    (endline, endlinenum, endpos) = CloseExpression(
-        clean_lines, linenum, len(match.group(1)))
-    if endpos > -1 and Match(r'^\s*;', endline[endpos:]):
-      # Current {} pair is eligible for semicolon check, and we have found
-      # the redundant semicolon, output warning here.
-      #
-      # Note: because we are scanning forward for opening braces, and
-      # outputting warnings for the matching closing brace, if there are
-      # nested blocks with trailing semicolons, we will get the error
-      # messages in reversed order.
-      error(filename, endlinenum, 'readability/braces', 4,
-            "You don't need a ; after a }")
-
-
-def CheckEmptyBlockBody(filename, clean_lines, linenum, error):
-  """Look for empty loop/conditional body with only a single semicolon.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-
-  # Search for loop keywords at the beginning of the line.  Because only
-  # whitespaces are allowed before the keywords, this will also ignore most
-  # do-while-loops, since those lines should start with closing brace.
-  #
-  # We also check "if" blocks here, since an empty conditional block
-  # is likely an error.
-  line = clean_lines.elided[linenum]
-  matched = Match(r'\s*(for|while|if)\s*\(', line)
-  if matched:
-    # Find the end of the conditional expression
-    (end_line, end_linenum, end_pos) = CloseExpression(
-        clean_lines, linenum, line.find('('))
-
-    # Output warning if what follows the condition expression is a semicolon.
-    # No warning for all other cases, including whitespace or newline, since we
-    # have a separate check for semicolons preceded by whitespace.
-    if end_pos >= 0 and Match(r';', end_line[end_pos:]):
-      if matched.group(1) == 'if':
-        error(filename, end_linenum, 'whitespace/empty_conditional_body', 5,
-              'Empty conditional bodies should use {}')
-      else:
-        error(filename, end_linenum, 'whitespace/empty_loop_body', 5,
-              'Empty loop bodies should use {} or continue')
-
-
-def FindCheckMacro(line):
-  """Find a replaceable CHECK-like macro.
-
-  Args:
-    line: line to search on.
-  Returns:
-    (macro name, start position), or (None, -1) if no replaceable
-    macro is found.
-  """
-  for macro in _CHECK_MACROS:
-    i = line.find(macro)
-    if i >= 0:
-      # Find opening parenthesis.  Do a regular expression match here
-      # to make sure that we are matching the expected CHECK macro, as
-      # opposed to some other macro that happens to contain the CHECK
-      # substring.
-      matched = Match(r'^(.*\b' + macro + r'\s*)\(', line)
-      if not matched:
-        continue
-      return (macro, len(matched.group(1)))
-  return (None, -1)
-
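# The snippet below is a standalone sketch of the lookup above, added for
# illustration only.  It uses a trimmed macro list (assumption: the real
# _CHECK_MACROS list is longer) and the same word-boundary match, so a macro
# that merely contains "CHECK" as a substring is not reported.
import re

_DEMO_CHECK_MACROS = ['DCHECK', 'CHECK', 'EXPECT_TRUE', 'ASSERT_TRUE']

def find_check_macro_demo(line):
  """Return (macro name, index of its opening parenthesis), or (None, -1)."""
  for macro in _DEMO_CHECK_MACROS:
    if macro in line:
      matched = re.match(r'^(.*\b' + macro + r'\s*)\(', line)
      if matched:
        return (macro, len(matched.group(1)))
  return (None, -1)

if __name__ == '__main__':
  print(find_check_macro_demo('  CHECK(x == 42);'))      # ('CHECK', 7)
  print(find_check_macro_demo('  MY_CHECK_HELPER(x);'))  # (None, -1)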
-
-def CheckCheck(filename, clean_lines, linenum, error):
-  """Checks the use of CHECK and EXPECT macros.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-
-  # Decide the set of replacement macros that should be suggested
-  lines = clean_lines.elided
-  (check_macro, start_pos) = FindCheckMacro(lines[linenum])
-  if not check_macro:
-    return
-
-  # Find end of the boolean expression by matching parentheses
-  (last_line, end_line, end_pos) = CloseExpression(
-      clean_lines, linenum, start_pos)
-  if end_pos < 0:
-    return
-
-  # If the check macro is followed by something other than a
-  # semicolon, assume users will log their own custom error messages
-  # and don't suggest any replacements.
-  if not Match(r'\s*;', last_line[end_pos:]):
-    return
-
-  if linenum == end_line:
-    expression = lines[linenum][start_pos + 1:end_pos - 1]
-  else:
-    expression = lines[linenum][start_pos + 1:]
-    for i in xrange(linenum + 1, end_line):
-      expression += lines[i]
-    expression += last_line[0:end_pos - 1]
-
-  # Parse expression so that we can take parentheses into account.
-  # This avoids false positives for inputs like "CHECK((a < 4) == b)",
-  # which is not replaceable by CHECK_LE.
-  lhs = ''
-  rhs = ''
-  operator = None
-  while expression:
-    matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||'
-                    r'==|!=|>=|>|<=|<|\()(.*)$', expression)
-    if matched:
-      token = matched.group(1)
-      if token == '(':
-        # Parenthesized operand
-        expression = matched.group(2)
-        (end, _) = FindEndOfExpressionInLine(expression, 0, ['('])
-        if end < 0:
-          return  # Unmatched parenthesis
-        lhs += '(' + expression[0:end]
-        expression = expression[end:]
-      elif token in ('&&', '||'):
-        # Logical and/or operators.  This means the expression
-        # contains more than one term, for example:
-        #   CHECK(42 < a && a < b);
-        #
-        # These are not replaceable with CHECK_LE, so bail out early.
-        return
-      elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'):
-        # Non-relational operator
-        lhs += token
-        expression = matched.group(2)
-      else:
-        # Relational operator
-        operator = token
-        rhs = matched.group(2)
-        break
-    else:
-      # Unparenthesized operand.  Instead of appending to lhs one character
-      # at a time, we do another regular expression match to consume several
-      # characters at once if possible.  Trivial benchmark shows that this
-      # is more efficient when the operands are longer than a single
-      # character, which is generally the case.
-      matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression)
-      if not matched:
-        matched = Match(r'^(\s*\S)(.*)$', expression)
-        if not matched:
-          break
-      lhs += matched.group(1)
-      expression = matched.group(2)
-
-  # Only apply checks if we got all parts of the boolean expression
-  if not (lhs and operator and rhs):
-    return
-
-  # Check that rhs does not contain logical operators.  We already know
-  # that lhs is fine since the loop above parses out && and ||.
-  if rhs.find('&&') > -1 or rhs.find('||') > -1:
-    return
-
-  # At least one of the operands must be a constant literal.  This is
-  # to avoid suggesting replacements for unprintable things like
-  # CHECK(variable != iterator)
-  #
-  # The following pattern matches decimal, hex integers, strings, and
-  # characters (in that order).
-  lhs = lhs.strip()
-  rhs = rhs.strip()
-  match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$'
-  if Match(match_constant, lhs) or Match(match_constant, rhs):
-    # Note: since we know both lhs and rhs, we can provide a more
-    # descriptive error message like:
-    #   Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42)
-    # Instead of:
-    #   Consider using CHECK_EQ instead of CHECK(a == b)
-    #
-    # We are still keeping the less descriptive message because if lhs
-    # or rhs gets long, the error message might become unreadable.
-    error(filename, linenum, 'readability/check', 2,
-          'Consider using %s instead of %s(a %s b)' % (
-              _CHECK_REPLACEMENT[check_macro][operator],
-              check_macro, operator))
-
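# A quick standalone illustration of the match_constant pattern above: only
# integer, string, and character literals make the check suggest CHECK_EQ and
# friends; plain identifiers and arbitrary expressions do not.
import re

MATCH_CONSTANT = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$'

if __name__ == '__main__':
  for operand in ('42', '-7', '0xFFul', '"foo"', "'x'", 'some_variable', 'a + b'):
    print('%-15s %s' % (operand, bool(re.match(MATCH_CONSTANT, operand))))
  # Prints True for the five literals and False for the last two operands.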
-
-def CheckAltTokens(filename, clean_lines, linenum, error):
-  """Check alternative keywords being used in boolean expressions.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Avoid preprocessor lines
-  if Match(r'^\s*#', line):
-    return
-
-  # Last ditch effort to avoid multi-line comments.  This will not help
-  # if the comment started before the current line or ended after the
-  # current line, but it catches most of the false positives.  At least,
-  # it provides a way to workaround this warning for people who use
-  # multi-line comments in preprocessor macros.
-  #
-  # TODO(unknown): remove this once cpplint has better support for
-  # multi-line comments.
-  if line.find('/*') >= 0 or line.find('*/') >= 0:
-    return
-
-  for match in _ALT_TOKEN_REPLACEMENT_PATTERN.finditer(line):
-    error(filename, linenum, 'readability/alt_tokens', 2,
-          'Use operator %s instead of %s' % (
-              _ALT_TOKEN_REPLACEMENT[match.group(1)], match.group(1)))
-
-
-def GetLineWidth(line):
-  """Determines the width of the line in column positions.
-
-  Args:
-    line: A string, which may be a Unicode string.
-
-  Returns:
-    The width of the line in column positions, accounting for Unicode
-    combining characters and wide characters.
-  """
-  if isinstance(line, unicode):
-    width = 0
-    for uc in unicodedata.normalize('NFC', line):
-      if unicodedata.east_asian_width(uc) in ('W', 'F'):
-        width += 2
-      elif not unicodedata.combining(uc):
-        width += 1
-    return width
-  else:
-    return len(line)
-
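# A standalone Python 3 sketch of the same column-width computation (the code
# above special-cases the Python 2 `unicode` type).  Wide and fullwidth CJK
# characters count as two columns, combining marks as zero.
import unicodedata

def display_width(line):
  width = 0
  for ch in unicodedata.normalize('NFC', line):
    if unicodedata.east_asian_width(ch) in ('W', 'F'):
      width += 2
    elif not unicodedata.combining(ch):
      width += 1
  return width

if __name__ == '__main__':
  print(display_width('abc'))      # 3
  print(display_width('漢字'))      # 4: two wide characters
  print(display_width('e\u0301'))  # 1: 'e' plus a combining acute accent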
-
-def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state,
-               error):
-  """Checks rules from the 'C++ style rules' section of cppguide.html.
-
-  Most of these rules are hard to test (naming, comment style), but we
-  do what we can.  In particular we check for 2-space indents, line lengths,
-  tab usage, spaces inside code, etc.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    file_extension: The extension (without the dot) of the filename.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-
-  # Don't use "elided" lines here, otherwise we can't check commented lines.
-  # Don't want to use "raw" either, because we don't want to check inside C++11
-  # raw strings.
-  raw_lines = clean_lines.lines_without_raw_strings
-  line = raw_lines[linenum]
-
-  if line.find('\t') != -1:
-    error(filename, linenum, 'whitespace/tab', 1,
-          'Tab found; better to use spaces')
-
-  # One or three blank spaces at the beginning of the line are odd; they're
-  # hard to reconcile with 2-space indents.
-  # NOTE: here are the conditions Rob Pike used for his tests.  Mine aren't
-  # as sophisticated, but it may be worth becoming so:  RLENGTH==initial_spaces
-  # if(RLENGTH > 20) complain = 0;
-  # if(match($0, " +(error|private|public|protected):")) complain = 0;
-  # if(match(prev, "&& *$")) complain = 0;
-  # if(match(prev, "\\|\\| *$")) complain = 0;
-  # if(match(prev, "[\",=><] *$")) complain = 0;
-  # if(match($0, " <<")) complain = 0;
-  # if(match(prev, " +for \\(")) complain = 0;
-  # if(prevodd && match(prevprev, " +for \\(")) complain = 0;
-  scope_or_label_pattern = r'\s*\w+\s*:\s*\\?$'
-  # classinfo = nesting_state.InnermostClass()
-  initial_spaces = 0
-  cleansed_line = clean_lines.elided[linenum]
-  while initial_spaces < len(line) and line[initial_spaces] == ' ':
-    initial_spaces += 1
-  if line and line[-1].isspace():
-    error(filename, linenum, 'whitespace/end_of_line', 4,
-          'Line ends in whitespace.  Consider deleting these extra spaces.')
-  # There are certain situations where we allow one space, notably for
-  # section labels, and also lines containing multi-line raw strings.
-  elif ((initial_spaces == 1 or initial_spaces == 3) and
-        not Match(scope_or_label_pattern, cleansed_line) and
-        not (clean_lines.raw_lines[linenum] != line and
-             Match(r'^\s*""', line))):
-    error(filename, linenum, 'whitespace/indent', 3,
-          'Weird number of spaces at line-start.  '
-          'Are you using a 2-space indent?')
-
-  # Check if the line is a header guard.
-  is_header_guard = False
-  if file_extension == 'h':
-    cppvar = GetHeaderGuardCPPVariable(filename)
-    if (line.startswith('#ifndef %s' % cppvar) or
-        line.startswith('#define %s' % cppvar) or
-        line.startswith('#endif  // %s' % cppvar)):
-      is_header_guard = True
-  # #include lines and header guards can be long, since there's no clean way to
-  # split them.
-  #
-  # URLs can be long too.  It's possible to split these, but it makes them
-  # harder to cut&paste.
-  #
-  # The "$Id:...$" comment may also get very long without it being the
-  # developer's fault.
-  if (not line.startswith('#include') and not is_header_guard and
-      not Match(r'^\s*//.*http(s?)://\S*$', line) and
-      not Match(r'^// \$Id:.*#[0-9]+ \$$', line)):
-    line_width = GetLineWidth(line)
-    extended_length = int(_line_length * 1.25)
-    if line_width > extended_length:
-      error(filename, linenum, 'whitespace/line_length', 4,
-            'Lines should very rarely be longer than %i characters' %
-            extended_length)
-    elif line_width > _line_length:
-      error(filename, linenum, 'whitespace/line_length', 2,
-            'Lines should be <= %i characters long' % _line_length)
-
-  if (cleansed_line.count(';') > 1 and
-      # for loops are allowed two ;'s (and may run over two lines).
-      cleansed_line.find('for') == -1 and
-      (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or
-       GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and
-      # It's ok to have many commands in a switch case that fits in 1 line
-      not ((cleansed_line.find('case ') != -1 or
-            cleansed_line.find('default:') != -1) and
-           cleansed_line.find('break;') != -1)):
-    error(filename, linenum, 'whitespace/newline', 0,
-          'More than one command on the same line')
-
-  # Some more style checks
-  CheckBraces(filename, clean_lines, linenum, error)
-  CheckTrailingSemicolon(filename, clean_lines, linenum, error)
-  CheckEmptyBlockBody(filename, clean_lines, linenum, error)
-  CheckAccess(filename, clean_lines, linenum, nesting_state, error)
-  CheckSpacing(filename, clean_lines, linenum, nesting_state, error)
-  CheckOperatorSpacing(filename, clean_lines, linenum, error)
-  CheckParenthesisSpacing(filename, clean_lines, linenum, error)
-  CheckCommaSpacing(filename, clean_lines, linenum, error)
-  CheckBracesSpacing(filename, clean_lines, linenum, error)
-  CheckSpacingForFunctionCall(filename, clean_lines, linenum, error)
-  CheckRValueReference(filename, clean_lines, linenum, nesting_state, error)
-  CheckCheck(filename, clean_lines, linenum, error)
-  CheckAltTokens(filename, clean_lines, linenum, error)
-  classinfo = nesting_state.InnermostClass()
-  if classinfo:
-    CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error)
-
-
-_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$')
-# Matches the first component of a filename delimited by -s and _s. That is:
-#  _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo'
-#  _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo'
-#  _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo'
-#  _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo'
-_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+')
-
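# A quick standalone demonstration of the two patterns above (added for
# illustration; the regexes are copied verbatim from this file).
import re

INCLUDE_RE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$')
FIRST_COMPONENT_RE = re.compile(r'^[^-_.]+')

if __name__ == '__main__':
  m = INCLUDE_RE.match('#include <vector>  // containers')
  print('%s %s' % (m.group(1), m.group(2)))   # < vector
  m = INCLUDE_RE.match('  #include "foo/bar.h"')
  print('%s %s' % (m.group(1), m.group(2)))   # " foo/bar.h
  print(FIRST_COMPONENT_RE.match('foo-bar_baz.cc').group(0))  # foo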
-
-def _DropCommonSuffixes(filename):
-  """Drops common suffixes like _test.cc or -inl.h from filename.
-
-  For example:
-    >>> _DropCommonSuffixes('foo/foo-inl.h')
-    'foo/foo'
-    >>> _DropCommonSuffixes('foo/bar/foo.cc')
-    'foo/bar/foo'
-    >>> _DropCommonSuffixes('foo/foo_internal.h')
-    'foo/foo'
-    >>> _DropCommonSuffixes('foo/foo_unusualinternal.h')
-    'foo/foo_unusualinternal'
-
-  Args:
-    filename: The input filename.
-
-  Returns:
-    The filename with the common suffix removed.
-  """
-  for suffix in ('test.cc', 'regtest.cc', 'unittest.cc',
-                 'inl.h', 'impl.h', 'internal.h'):
-    if (filename.endswith(suffix) and len(filename) > len(suffix) and
-        filename[-len(suffix) - 1] in ('-', '_')):
-      return filename[:-len(suffix) - 1]
-  return os.path.splitext(filename)[0]
-
-
-def _IsTestFilename(filename):
-  """Determines if the given filename has a suffix that identifies it as a test.
-
-  Args:
-    filename: The input filename.
-
-  Returns:
-    True if 'filename' looks like a test, False otherwise.
-  """
-  return (filename.endswith('_test.cc') or
-          filename.endswith('_unittest.cc') or
-          filename.endswith('_regtest.cc'))
-
-
-def _ClassifyInclude(fileinfo, include, is_system):
-  """Figures out what kind of header 'include' is.
-
-  Args:
-    fileinfo: The current file cpplint is running over. A FileInfo instance.
-    include: The path to a #included file.
-    is_system: True if the #include used <> rather than "".
-
-  Returns:
-    One of the _XXX_HEADER constants.
-
-  For example:
-    >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True)
-    _C_SYS_HEADER
-    >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True)
-    _CPP_SYS_HEADER
-    >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False)
-    _LIKELY_MY_HEADER
-    >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'),
-    ...                  'bar/foo_other_ext.h', False)
-    _POSSIBLE_MY_HEADER
-    >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False)
-    _OTHER_HEADER
-  """
-  # This is a list of all standard c++ header files, except
-  # those already checked for above.
-  is_cpp_h = include in _CPP_HEADERS
-
-  if is_system:
-    if is_cpp_h:
-      return _CPP_SYS_HEADER
-    else:
-      return _C_SYS_HEADER
-
-  # If the target file and the include we're checking share a
-  # basename when we drop common extensions, and the include
-  # lives in . , then it's likely to be owned by the target file.
-  target_dir, target_base = (
-      os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName())))
-  include_dir, include_base = os.path.split(_DropCommonSuffixes(include))
-  if target_base == include_base and (
-      include_dir == target_dir or
-      include_dir == os.path.normpath(target_dir + '/../public')):
-    return _LIKELY_MY_HEADER
-
-  # If the target and include share some initial basename
-  # component, it's possible the target is implementing the
-  # include, so it's allowed to be first, but we'll never
-  # complain if it's not there.
-  target_first_component = _RE_FIRST_COMPONENT.match(target_base)
-  include_first_component = _RE_FIRST_COMPONENT.match(include_base)
-  if (target_first_component and include_first_component and
-      target_first_component.group(0) ==
-      include_first_component.group(0)):
-    return _POSSIBLE_MY_HEADER
-
-  return _OTHER_HEADER
-
-
-def CheckIncludeLine(filename, clean_lines, linenum, include_state, error):
-  """Check rules that are applicable to #include lines.
-
-  Strings on #include lines are NOT removed from the elided line, to make
-  certain tasks easier. However, to prevent false positives, checks
-  applicable to #include lines in CheckLanguage must be put here.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    include_state: An _IncludeState instance in which the headers are inserted.
-    error: The function to call with any errors found.
-  """
-  fileinfo = FileInfo(filename)
-  line = clean_lines.lines[linenum]
-
-  # "include" should use the new style "foo/bar.h" instead of just "bar.h"
-  # Only do this check if the included header follows google naming
-  # conventions.  If not, assume that it's a 3rd party API that
-  # requires special include conventions.
-  #
-  # We also make an exception for Lua headers, which follow google
-  # naming convention but not the include convention.
-  match = Match(r'#include\s*"([^/]+\.h)"', line)
-  if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)):
-    error(filename, linenum, 'build/include', 4,
-          'Include the directory when naming .h files')
-
-  # We shouldn't include a file more than once.  Actually, there are a
-  # handful of instances where doing so is okay, but in general it's
-  # not.
-  match = _RE_PATTERN_INCLUDE.search(line)
-  if match:
-    include = match.group(2)
-    is_system = (match.group(1) == '<')
-    duplicate_line = include_state.FindHeader(include)
-    if duplicate_line >= 0:
-      error(filename, linenum, 'build/include', 4,
-            '"%s" already included at %s:%s' %
-            (include, filename, duplicate_line))
-    elif (include.endswith('.cc') and
-          os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)):
-      error(filename, linenum, 'build/include', 4,
-            'Do not include .cc files from other packages')
-    elif not _THIRD_PARTY_HEADERS_PATTERN.match(include):
-      include_state.include_list[-1].append((include, linenum))
-
-      # We want to ensure that headers appear in the right order:
-      # 1) for foo.cc, foo.h  (preferred location)
-      # 2) c system files
-      # 3) cpp system files
-      # 4) for foo.cc, foo.h  (deprecated location)
-      # 5) other google headers
-      #
-      # We classify each include statement as one of those 5 types
-      # using a number of techniques. The include_state object keeps
-      # track of the highest type seen, and complains if we see a
-      # lower type after that.
-      error_message = include_state.CheckNextIncludeOrder(
-          _ClassifyInclude(fileinfo, include, is_system))
-      if error_message:
-        error(filename, linenum, 'build/include_order', 4,
-              '%s. Should be: %s.h, c system, c++ system, other.' %
-              (error_message, fileinfo.BaseName()))
-      canonical_include = include_state.CanonicalizeAlphabeticalOrder(include)
-      if not include_state.IsInAlphabeticalOrder(
-          clean_lines, linenum, canonical_include):
-        error(filename, linenum, 'build/include_alpha', 4,
-              'Include "%s" not in alphabetical order' % include)
-      include_state.SetLastHeader(canonical_include)
-
-
-def _GetTextInside(text, start_pattern):
-  r"""Retrieves all the text between matching open and close parentheses.
-
-  Given a string of lines and a regular expression string, retrieve all the text
-  following the expression and between opening punctuation symbols like
-  (, [, or {, and the matching close-punctuation symbol. This properly handles
-  nested occurrences of the punctuation, so for text like
-    printf(a(), b(c()));
-  a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'.
-  start_pattern must match a string ending with an opening punctuation symbol.
-
-  Args:
-    text: The lines to extract text from. Its comments and strings must be
-          elided. It may be a single line or span multiple lines.
-    start_pattern: The regexp string indicating where to start extracting
-                   the text.
-  Returns:
-    The extracted text.
-    None if either the opening string or ending punctuation could not be found.
-  """
-  # TODO(unknown): Audit cpplint.py to see what places could be profitably
-  # rewritten to use _GetTextInside (they use inferior regexp matching today).
-
-  # Map each opening punctuation symbol to its matching closing symbol.
-  matching_punctuation = {'(': ')', '{': '}', '[': ']'}
-  closing_punctuation = set(matching_punctuation.itervalues())
-
-  # Find the position to start extracting text.
-  match = re.search(start_pattern, text, re.M)
-  if not match:  # start_pattern not found in text.
-    return None
-  start_position = match.end(0)
-
-  assert start_position > 0, (
-      'start_pattern must end with an opening punctuation.')
-  assert text[start_position - 1] in matching_punctuation, (
-      'start_pattern must end with an opening punctuation.')
-  # Stack of closing punctuations we expect to have in text after position.
-  punctuation_stack = [matching_punctuation[text[start_position - 1]]]
-  position = start_position
-  while punctuation_stack and position < len(text):
-    if text[position] == punctuation_stack[-1]:
-      punctuation_stack.pop()
-    elif text[position] in closing_punctuation:
-      # A closing punctuation without matching opening punctuations.
-      return None
-    elif text[position] in matching_punctuation:
-      punctuation_stack.append(matching_punctuation[text[position]])
-    position += 1
-  if punctuation_stack:
-    # Opening punctuations left without matching close-punctuations.
-    return None
-  # punctuations match.
-  return text[start_position:position - 1]
-
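# A standalone sketch of the bracket-matching walk above (Python 3:
# dict.values() instead of itervalues()), shown on the printf example from
# the docstring.
import re

def get_text_inside_demo(text, start_pattern):
  matching = {'(': ')', '{': '}', '[': ']'}
  closing = set(matching.values())
  match = re.search(start_pattern, text, re.M)
  if not match or text[match.end(0) - 1] not in matching:
    return None
  start = match.end(0)
  stack = [matching[text[start - 1]]]   # closing symbols we still expect
  pos = start
  while stack and pos < len(text):
    if text[pos] == stack[-1]:
      stack.pop()
    elif text[pos] in closing:
      return None                       # unbalanced closing punctuation
    elif text[pos] in matching:
      stack.append(matching[text[pos]])
    pos += 1
  return None if stack else text[start:pos - 1]

if __name__ == '__main__':
  print(get_text_inside_demo('printf(a(), b(c()));', r'printf\('))  # a(), b(c())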
-
-# Patterns for matching call-by-reference parameters.
-#
-# Supports nested templates up to 2 levels deep using this messy pattern:
-#   < (?: < (?: < [^<>]*
-#               >
-#           |   [^<>] )*
-#         >
-#     |   [^<>] )*
-#   >
-_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*'  # =~ [[:alpha:]][[:alnum:]]*
-_RE_PATTERN_TYPE = (
-    r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?'
-    r'(?:\w|'
-    r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|'
-    r'::)+')
-# A call-by-reference parameter ends with '& identifier'.
-_RE_PATTERN_REF_PARAM = re.compile(
-    r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*'
-    r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]')
-# A call-by-const-reference parameter either ends with 'const& identifier'
-# or looks like 'const type& identifier' when 'type' is atomic.
-_RE_PATTERN_CONST_REF_PARAM = (
-    r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT +
-    r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')')
-
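# A minimal usage sketch of the two patterns above (assumption: a copy of
# cpplint.py is importable on sys.path).  _RE_PATTERN_REF_PARAM pulls out
# every reference parameter from a declaration, and _RE_PATTERN_CONST_REF_PARAM
# then keeps only the const ones.
import re
import cpplint

if __name__ == '__main__':
  decl = 'void Update(const string& name, int* count, vector<int>& values);'
  for param in re.findall(cpplint._RE_PATTERN_REF_PARAM, decl):
    is_const = bool(re.match(cpplint._RE_PATTERN_CONST_REF_PARAM, param))
    print('%-25s %s' % (param, 'const' if is_const else 'non-const'))
  # Expected: 'const string& name' is const; 'vector<int>& values' is not.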
-
-def CheckLanguage(filename, clean_lines, linenum, file_extension,
-                  include_state, nesting_state, error):
-  """Checks rules from the 'C++ language rules' section of cppguide.html.
-
-  Some of these rules are hard to test (function overloading, using
-  uint32 inappropriately), but we do the best we can.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    file_extension: The extension (without the dot) of the filename.
-    include_state: An _IncludeState instance in which the headers are inserted.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-  # If the line is empty or consists of entirely a comment, no need to
-  # check it.
-  line = clean_lines.elided[linenum]
-  if not line:
-    return
-
-  match = _RE_PATTERN_INCLUDE.search(line)
-  if match:
-    CheckIncludeLine(filename, clean_lines, linenum, include_state, error)
-    return
-
-  # Reset include state across preprocessor directives.  This is meant
-  # to silence warnings for conditional includes.
-  match = Match(r'^\s*#\s*(if|ifdef|ifndef|elif|else|endif)\b', line)
-  if match:
-    include_state.ResetSection(match.group(1))
-
-  # Make Windows paths like Unix.
-  # fullname = os.path.abspath(filename).replace('\\', '/')
-
-  # Perform other checks now that we are sure that this is not an include line
-  CheckCasts(filename, clean_lines, linenum, error)
-  CheckGlobalStatic(filename, clean_lines, linenum, error)
-  CheckPrintf(filename, clean_lines, linenum, error)
-
-  if file_extension == 'h':
-    # TODO(unknown): check that 1-arg constructors are explicit.
-    #                How to tell it's a constructor?
-    #                (handled in CheckForNonStandardConstructs for now)
-    # TODO(unknown): check that classes declare or disable copy/assign
-    #                (level 1 error)
-    pass
-
-  # Check if people are using the verboten C basic types.  The only exception
-  # we regularly allow is "unsigned short port" for port.
-  if Search(r'\bshort port\b', line):
-    if not Search(r'\bunsigned short port\b', line):
-      error(filename, linenum, 'runtime/int', 4,
-            'Use "unsigned short" for ports, not "short"')
-  else:
-    match = Search(r'\b(short|long(?! +double)|long long)\b', line)
-    if match:
-      error(filename, linenum, 'runtime/int', 4,
-            'Use int16/int64/etc, rather than the C type %s' % match.group(1))
-
-  # Check if some verboten operator overloading is going on
-  # TODO(unknown): catch out-of-line unary operator&:
-  #   class X {};
-  #   int operator&(const X& x) { return 42; }  // unary operator&
-  # The trick is it's hard to tell apart from binary operator&:
-  #   class Y { int operator&(const Y& x) { return 23; } }; // binary operator&
-  if Search(r'\boperator\s*&\s*\(\s*\)', line):
-    error(filename, linenum, 'runtime/operator', 4,
-          'Unary operator& is dangerous.  Do not use it.')
-
-  # Check for suspicious usage of "if" like
-  # } if (a == b) {
-  if Search(r'\}\s*if\s*\(', line):
-    error(filename, linenum, 'readability/braces', 4,
-          'Did you mean "else if"? If not, start a new line for "if".')
-
-  # Check for potential format string bugs like printf(foo).
-  # We constrain the pattern not to pick things like DocidForPrintf(foo).
-  # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str())
-  # TODO(unknown): Catch the following case. Need to change the calling
-  # convention of the whole function to process multiple lines to handle it.
-  #   printf(
-  #       boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line);
-  printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(')
-  if printf_args:
-    match = Match(r'([\w.\->()]+)$', printf_args)
-    if match and match.group(1) != '__VA_ARGS__':
-      function_name = re.search(r'\b((?:string)?printf)\s*\(',
-                                line, re.I).group(1)
-      error(filename, linenum, 'runtime/printf', 4,
-            'Potential format string bug. Do %s("%%s", %s) instead.'
-            % (function_name, match.group(1)))
-
-  # Check for potential memset bugs like memset(buf, sizeof(buf), 0).
-  match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line)
-  if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)):
-    error(filename, linenum, 'runtime/memset', 4,
-          'Did you mean "memset(%s, 0, %s)"?'
-          % (match.group(1), match.group(2)))
-
-  if Search(r'\busing namespace\b', line):
-    error(filename, linenum, 'build/namespaces', 5,
-          'Do not use namespace using-directives.  '
-          'Use using-declarations instead.')
-
-  # Detect variable-length arrays.
-  match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line)
-  if (match and match.group(2) != 'return' and match.group(2) != 'delete' and
-      match.group(3).find(']') == -1):
-    # Split the size using space and arithmetic operators as delimiters.
-    # If any of the resulting tokens are not compile time constants then
-    # report the error.
-    tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>', match.group(3))
-    is_const = True
-    skip_next = False
-    for tok in tokens:
-      if skip_next:
-        skip_next = False
-        continue
-
-      if Search(r'sizeof\(.+\)', tok): continue
-      if Search(r'arraysize\(\w+\)', tok): continue
-
-      tok = tok.lstrip('(')
-      tok = tok.rstrip(')')
-      if not tok: continue
-      if Match(r'\d+', tok): continue
-      if Match(r'0[xX][0-9a-fA-F]+', tok): continue
-      if Match(r'k[A-Z0-9]\w*', tok): continue
-      if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue
-      if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue
-      # A catch all for tricky sizeof cases, including 'sizeof expression',
-      # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)'
-      # requires skipping the next token because we split on ' ' and '*'.
-      if tok.startswith('sizeof'):
-        skip_next = True
-        continue
-      is_const = False
-      break
-    if not is_const:
-      error(filename, linenum, 'runtime/arrays', 1,
-            'Do not use variable-length arrays.  Use an appropriately named '
-            "('k' followed by CamelCase) compile-time constant for the size.")
-
-  # Check for use of unnamed namespaces in header files.  Registration
-  # macros are typically OK, so we allow use of "namespace {" on lines
-  # that end with backslashes.
-  if (file_extension == 'h'
-      and Search(r'\bnamespace\s*{', line)
-      and line[-1] != '\\'):
-    error(filename, linenum, 'build/namespaces', 4,
-          'Do not use unnamed namespaces in header files.  See '
-          'http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces'
-          ' for more information.')
-
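# A standalone sketch of the compile-time-constant test used by the
# variable-length-array check above (simplified: the sizeof/arraysize special
# cases and the skip_next handling are omitted here).
import re

def _looks_like_constant(tok):
  tok = tok.lstrip('(').rstrip(')')
  return bool(
      not tok or
      re.match(r'\d+$', tok) or
      re.match(r'0[xX][0-9a-fA-F]+$', tok) or
      re.match(r'(.+::)?k[A-Z0-9]\w*$', tok) or
      re.match(r'(.+::)?[A-Z][A-Z0-9_]*$', tok))

def vla_size_is_constant(size_expr):
  tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>', size_expr)
  return all(_looks_like_constant(tok) for tok in tokens)

if __name__ == '__main__':
  print(vla_size_is_constant('kMaxUsers * 2'))  # True: 'k' constant and literal
  print(vla_size_is_constant('BUFFER_SIZE'))    # True: macro-style constant
  print(vla_size_is_constant('num_items + 1'))  # False: runtime variable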
-
-def CheckGlobalStatic(filename, clean_lines, linenum, error):
-  """Check for unsafe global or static objects.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Match two lines at a time to support multiline declarations
-  if linenum + 1 < clean_lines.NumLines() and not Search(r'[;({]', line):
-    line += clean_lines.elided[linenum + 1].strip()
-
-  # Check for people declaring static/global STL strings at the top level.
-  # This is dangerous because the C++ language does not guarantee that
-  # globals with constructors are initialized before the first access.
-  match = Match(
-      r'((?:|static +)(?:|const +))string +([a-zA-Z0-9_:]+)\b(.*)',
-      line)
-
-  # Remove false positives:
-  # - String pointers (as opposed to values).
-  #    string *pointer
-  #    const string *pointer
-  #    string const *pointer
-  #    string *const pointer
-  #
-  # - Functions and template specializations.
-  #    string Function<Type>(...
-  #    string Class<Type>::Method(...
-  #
-  # - Operators.  These are matched separately because operator names
-  #   cross non-word boundaries, and trying to match both operators
-  #   and functions at the same time would decrease accuracy of
-  #   matching identifiers.
-  #    string Class::operator*()
-  if (match and
-      not Search(r'\bstring\b(\s+const)?\s*\*\s*(const\s+)?\w', line) and
-      not Search(r'\boperator\W', line) and
-      not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(3))):
-    error(filename, linenum, 'runtime/string', 4,
-          'For a static/global string constant, use a C style string instead: '
-          '"%schar %s[]".' %
-          (match.group(1), match.group(2)))
-
-  if Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line):
-    error(filename, linenum, 'runtime/init', 4,
-          'You seem to be initializing a member variable with itself.')
-
-
-def CheckPrintf(filename, clean_lines, linenum, error):
-  """Check for printf related issues.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # When snprintf is used, the second argument shouldn't be a literal.
-  match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line)
-  if match and match.group(2) != '0':
-    # If 2nd arg is zero, snprintf is used to calculate size.
-    error(filename, linenum, 'runtime/printf', 3,
-          'If you can, use sizeof(%s) instead of %s as the 2nd arg '
-          'to snprintf.' % (match.group(1), match.group(2)))
-
-  # Check if some verboten C functions are being used.
-  if Search(r'\bsprintf\s*\(', line):
-    error(filename, linenum, 'runtime/printf', 5,
-          'Never use sprintf. Use snprintf instead.')
-  match = Search(r'\b(strcpy|strcat)\s*\(', line)
-  if match:
-    error(filename, linenum, 'runtime/printf', 4,
-          'Almost always, snprintf is better than %s' % match.group(1))
-
-
-def IsDerivedFunction(clean_lines, linenum):
-  """Check if current line contains an inherited function.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-  Returns:
-    True if current line contains a function with "override"
-    virt-specifier.
-  """
-  # Scan back a few lines for start of current function
-  for i in xrange(linenum, max(-1, linenum - 10), -1):
-    match = Match(r'^([^()]*\w+)\(', clean_lines.elided[i])
-    if match:
-      # Look for "override" after the matching closing parenthesis
-      line, _, closing_paren = CloseExpression(
-          clean_lines, i, len(match.group(1)))
-      return (closing_paren >= 0 and
-              Search(r'\boverride\b', line[closing_paren:]))
-  return False
-
-
-def IsOutOfLineMethodDefinition(clean_lines, linenum):
-  """Check if current line contains an out-of-line method definition.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-  Returns:
-    True if current line contains an out-of-line method definition.
-  """
-  # Scan back a few lines for start of current function
-  for i in xrange(linenum, max(-1, linenum - 10), -1):
-    if Match(r'^([^()]*\w+)\(', clean_lines.elided[i]):
-      return Match(r'^[^()]*\w+::\w+\(', clean_lines.elided[i]) is not None
-  return False
-
-
-def IsInitializerList(clean_lines, linenum):
-  """Check if current line is inside constructor initializer list.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-  Returns:
-    True if current line appears to be inside constructor initializer
-    list, False otherwise.
-  """
-  for i in xrange(linenum, 1, -1):
-    line = clean_lines.elided[i]
-    if i == linenum:
-      remove_function_body = Match(r'^(.*)\{\s*$', line)
-      if remove_function_body:
-        line = remove_function_body.group(1)
-
-    if Search(r'\s:\s*\w+[({]', line):
-      # A lone colon tends to indicate the start of a constructor
-      # initializer list.  It could also be a ternary operator, which
-      # also tends to appear in constructor initializer lists as
-      # opposed to parameter lists.
-      return True
-    if Search(r'\}\s*,\s*$', line):
-      # A closing brace followed by a comma is probably the end of a
-      # brace-initialized member in constructor initializer list.
-      return True
-    if Search(r'[{};]\s*$', line):
-      # Found one of the following:
-      # - A closing brace or semicolon, probably the end of the previous
-      #   function.
-      # - An opening brace, probably the start of current class or namespace.
-      #
-      # Current line is probably not inside an initializer list since
-      # we saw one of those things without seeing the starting colon.
-      return False
-
-  # Got to the beginning of the file without seeing the start of
-  # constructor initializer list.
-  return False
-
-
-def CheckForNonConstReference(filename, clean_lines, linenum,
-                              nesting_state, error):
-  """Check for non-const references.
-
-  Separate from CheckLanguage since it scans backwards from current
-  line, instead of scanning forward.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: The function to call with any errors found.
-  """
-  # Do nothing if there is no '&' on current line.
-  line = clean_lines.elided[linenum]
-  if '&' not in line:
-    return
-
-  # If a function is inherited, current function doesn't have much of
-  # a choice, so any non-const references should not be blamed on
-  # derived function.
-  if IsDerivedFunction(clean_lines, linenum):
-    return
-
-  # Don't warn on out-of-line method definitions, as we would warn on the
-  # in-line declaration, if it isn't marked with 'override'.
-  if IsOutOfLineMethodDefinition(clean_lines, linenum):
-    return
-
-  # Long type names may be broken across multiple lines, usually in one
-  # of these forms:
-  #   LongType
-  #       ::LongTypeContinued &identifier
-  #   LongType::
-  #       LongTypeContinued &identifier
-  #   LongType<
-  #       ...>::LongTypeContinued &identifier
-  #
-  # If we detected a type split across two lines, join the previous
-  # line to current line so that we can match const references
-  # accordingly.
-  #
-  # Note that this only scans back one line, since scanning back
-  # arbitrary number of lines would be expensive.  If you have a type
-  # that spans more than 2 lines, please use a typedef.
-  if linenum > 1:
-    previous = None
-    if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line):
-      # previous_line\n + ::current_line
-      previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$',
-                        clean_lines.elided[linenum - 1])
-    elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line):
-      # previous_line::\n + current_line
-      previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$',
-                        clean_lines.elided[linenum - 1])
-    if previous:
-      line = previous.group(1) + line.lstrip()
-    else:
-      # Check for templated parameter that is split across multiple lines
-      endpos = line.rfind('>')
-      if endpos > -1:
-        (_, startline, startpos) = ReverseCloseExpression(
-            clean_lines, linenum, endpos)
-        if startpos > -1 and startline < linenum:
-          # Found the matching < on an earlier line, collect all
-          # pieces up to current line.
-          line = ''
-          for i in xrange(startline, linenum + 1):
-            line += clean_lines.elided[i].strip()
-
-  # Check for non-const references in function parameters.  A single '&' may
-  # be found in the following places:
-  #   inside expression: binary & for bitwise AND
-  #   inside expression: unary & for taking the address of something
-  #   inside declarators: reference parameter
-  # We will exclude the first two cases by checking that we are not inside a
-  # function body, including one that was just introduced by a trailing '{'.
-  # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare].
-  if (nesting_state.previous_stack_top and
-      not (isinstance(nesting_state.previous_stack_top, _ClassInfo) or
-           isinstance(nesting_state.previous_stack_top, _NamespaceInfo))):
-    # Not at toplevel, not within a class, and not within a namespace
-    return
-
-  # Avoid initializer lists.  We only need to scan back from the
-  # current line for something that starts with ':'.
-  #
-  # We don't need to check the current line, since the '&' would
-  # appear inside the second set of parentheses on the current line as
-  # opposed to the first set.
-  if linenum > 0:
-    for i in xrange(linenum - 1, max(0, linenum - 10), -1):
-      previous_line = clean_lines.elided[i]
-      if not Search(r'[),]\s*$', previous_line):
-        break
-      if Match(r'^\s*:\s+\S', previous_line):
-        return
-
-  # Avoid preprocessors
-  if Search(r'\\\s*$', line):
-    return
-
-  # Avoid constructor initializer lists
-  if IsInitializerList(clean_lines, linenum):
-    return
-
-  # We allow non-const references in a few standard places, like functions
-  # called "swap()" or iostream operators like "<<" or ">>".  Do not check
-  # those function parameters.
-  #
-  # We also accept & in static_assert, which looks like a function but
-  # it's actually a declaration expression.
-  whitelisted_functions = (r'(?:[sS]wap(?:<\w:+>)?|'
-                           r'operator\s*[<>][<>]|'
-                           r'static_assert|COMPILE_ASSERT'
-                           r')\s*\(')
-  if Search(whitelisted_functions, line):
-    return
-  elif not Search(r'\S+\([^)]*$', line):
-    # We don't see a whitelisted function on this line.  Actually we
-    # didn't see any function name on this line, so this is likely a
-    # multi-line parameter list.  Try a bit harder to catch this case.
-    for i in xrange(2):
-      if (linenum > i and
-          Search(whitelisted_functions, clean_lines.elided[linenum - i - 1])):
-        return
-
-  decls = ReplaceAll(r'{[^}]*}', ' ', line)  # exclude function body
-  for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls):
-    if not Match(_RE_PATTERN_CONST_REF_PARAM, parameter):
-      error(filename, linenum, 'runtime/references', 2,
-            'Is this a non-const reference? '
-            'If so, make const or use a pointer: ' +
-            ReplaceAll(' *<', '<', parameter))
-
-
-def CheckCasts(filename, clean_lines, linenum, error):
-  """Various cast related checks.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Check to see if they're using a conversion function cast.
-  # I just try to capture the most common basic types, though there are more.
-  # Parameterless conversion functions, such as bool(), are allowed as they are
-  # probably a member operator declaration or default constructor.
-  match = Search(
-      r'(\bnew\s+|\S<\s*(?:const\s+)?)?\b'
-      r'(int|float|double|bool|char|int32|uint32|int64|uint64)'
-      r'(\([^)].*)', line)
-  expecting_function = ExpectingFunctionArgs(clean_lines, linenum)
-  if match and not expecting_function:
-    matched_type = match.group(2)
-
-    # matched_new_or_template is used to silence two false positives:
-    # - New operators
-    # - Template arguments with function types
-    #
-    # For template arguments, we match on types immediately following
-    # an opening bracket without any spaces.  This is a fast way to
-    # silence the common case where the function type is the first
-    # template argument.  False negative with less-than comparison is
-    # avoided because those operators are usually followed by a space.
-    #
-    #   function<double(double)>   // bracket + no space = false positive
-    #   value < double(42)         // bracket + space = true positive
-    matched_new_or_template = match.group(1)
-
-    # Avoid arrays by looking for brackets that come after the closing
-    # parenthesis.
-    if Match(r'\([^()]+\)\s*\[', match.group(3)):
-      return
-
-    # Other things to ignore:
-    # - Function pointers
-    # - Casts to pointer types
-    # - Placement new
-    # - Alias declarations
-    matched_funcptr = match.group(3)
-    if (matched_new_or_template is None and
-        not (matched_funcptr and
-             (Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(',
-                    matched_funcptr) or
-              matched_funcptr.startswith('(*)'))) and
-        not Match(r'\s*using\s+\S+\s*=\s*' + matched_type, line) and
-        not Search(r'new\(\S+\)\s*' + matched_type, line)):
-      error(filename, linenum, 'readability/casting', 4,
-            'Using deprecated casting style.  '
-            'Use static_cast<%s>(...) instead' %
-            matched_type)
-
-  if not expecting_function:
-    CheckCStyleCast(filename, clean_lines, linenum, 'static_cast',
-                    r'\((int|float|double|bool|char|u?int(16|32|64))\)', error)
-
-  # This doesn't catch all cases. Consider (const char * const)"hello".
-  #
-  # (char *) "foo" should always be a const_cast (reinterpret_cast won't
-  # compile).
-  if CheckCStyleCast(filename, clean_lines, linenum, 'const_cast',
-                     r'\((char\s?\*+\s?)\)\s*"', error):
-    pass
-  else:
-    # Check pointer casts for other than string constants
-    CheckCStyleCast(filename, clean_lines, linenum, 'reinterpret_cast',
-                    r'\((\w+\s?\*+\s?)\)', error)
-
-  # In addition, we look for people taking the address of a cast.  This
-  # is dangerous -- casts can assign to temporaries, so the pointer doesn't
-  # point where you think.
-  #
-  # Some non-identifier character is required before the '&' for the
-  # expression to be recognized as a cast.  These are casts:
-  #   expression = &static_cast<int*>(temporary());
-  #   function(&(int*)(temporary()));
-  #
-  # This is not a cast:
-  #   reference_type&(int* function_param);
-  match = Search(
-      r'(?:[^\w]&\(([^)*][^)]*)\)[\w(])|'
-      r'(?:[^\w]&(static|dynamic|down|reinterpret)_cast\b)', line)
-  if match:
-    # Try a better error message when the & is bound to something
-    # dereferenced by the casted pointer, as opposed to the casted
-    # pointer itself.
-    parenthesis_error = False
-    match = Match(r'^(.*&(?:static|dynamic|down|reinterpret)_cast\b)<', line)
-    if match:
-      _, y1, x1 = CloseExpression(clean_lines, linenum, len(match.group(1)))
-      if x1 >= 0 and clean_lines.elided[y1][x1] == '(':
-        _, y2, x2 = CloseExpression(clean_lines, y1, x1)
-        if x2 >= 0:
-          extended_line = clean_lines.elided[y2][x2:]
-          if y2 < clean_lines.NumLines() - 1:
-            extended_line += clean_lines.elided[y2 + 1]
-          if Match(r'\s*(?:->|\[)', extended_line):
-            parenthesis_error = True
-
-    if parenthesis_error:
-      error(filename, linenum, 'readability/casting', 4,
-            ('Are you taking an address of something dereferenced '
-             'from a cast?  Wrapping the dereferenced expression in '
-             'parentheses will make the binding more obvious'))
-    else:
-      error(filename, linenum, 'runtime/casting', 4,
-            ('Are you taking an address of a cast?  '
-             'This is dangerous: could be a temp var.  '
-             'Take the address before doing the cast, rather than after'))
-
-
-def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error):
-  """Checks for a C-style cast by looking for the pattern.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    cast_type: The string for the C++ cast to recommend.  This is either
-      reinterpret_cast, static_cast, or const_cast, depending.
-    pattern: The regular expression used to find C-style casts.
-    error: The function to call with any errors found.
-
-  Returns:
-    True if an error was emitted.
-    False otherwise.
-  """
-  line = clean_lines.elided[linenum]
-  match = Search(pattern, line)
-  if not match:
-    return False
-
-  # Exclude lines with keywords that tend to look like casts
-  context = line[0:match.start(1) - 1]
-  if Match(r'.*\b(?:sizeof|alignof|alignas|[_A-Z][_A-Z0-9]*)\s*$', context):
-    return False
-
-  # Try expanding the current context to see if we are one level of
-  # parentheses inside a macro.
-  if linenum > 0:
-    for i in xrange(linenum - 1, max(0, linenum - 5), -1):
-      context = clean_lines.elided[i] + context
-  if Match(r'.*\b[_A-Z][_A-Z0-9]*\s*\((?:\([^()]*\)|[^()])*$', context):
-    return False
-
-  # operator++(int) and operator--(int)
-  if context.endswith(' operator++') or context.endswith(' operator--'):
-    return False
-
-  # A single unnamed argument for a function tends to look like old
-  # style cast.  If we see those, don't issue warnings for deprecated
-  # casts, instead issue warnings for unnamed arguments where
-  # appropriate.
-  #
-  # These are things that we want warnings for, since the style guide
-  # explicitly requires all parameters to be named:
-  #   Function(int);
-  #   Function(int) {
-  #   ConstMember(int) const;
-  #   ConstMember(int) const {
-  #   ExceptionMember(int) throw (...);
-  #   ExceptionMember(int) throw (...) {
-  #   PureVirtual(int) = 0;
-  #   [](int) -> bool {
-  #
-  # These are functions of some sort, where the compiler would be fine
-  # if they had named parameters, but people often omit those
-  # identifiers to reduce clutter:
-  #   (FunctionPointer)(int);
-  #   (FunctionPointer)(int) = value;
-  #   Function((function_pointer_arg)(int))
-  #   Function((function_pointer_arg)(int), int param)
-  #   <TemplateArgument(int)>;
-  #   <(FunctionPointerTemplateArgument)(int)>;
-  remainder = line[match.end(0):]
-  if Match(r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)',
-           remainder):
-    # Looks like an unnamed parameter.
-
-    # Don't warn on any kind of template arguments.
-    if Match(r'^\s*>', remainder):
-      return False
-
-    # Don't warn on assignments to function pointers, but keep warnings for
-    # unnamed parameters to pure virtual functions.  Note that this pattern
-    # will also pass on assignments of "0" to function pointers, but the
-    # preferred values for those would be "nullptr" or "NULL".
-    matched_zero = Match(r'^\s=\s*(\S+)\s*;', remainder)
-    if matched_zero and matched_zero.group(1) != '0':
-      return False
-
-    # Don't warn on function pointer declarations.  For this we need
-    # to check what came before the "(type)" string.
-    if Match(r'.*\)\s*$', line[0:match.start(0)]):
-      return False
-
-    # Don't warn if the parameter is named with block comments, e.g.:
-    #  Function(int /*unused_param*/);
-    raw_line = clean_lines.raw_lines[linenum]
-    if '/*' in raw_line:
-      return False
-
-    # Passed all filters, issue warning here.
-    error(filename, linenum, 'readability/function', 3,
-          'All parameters should be named in a function')
-    return True
-
-  # At this point, all that should be left is actual casts.
-  error(filename, linenum, 'readability/casting', 4,
-        'Using C-style cast.  Use %s<%s>(...) instead' %
-        (cast_type, match.group(1)))
-
-  return True
-
-
-def ExpectingFunctionArgs(clean_lines, linenum):
-  """Checks whether where function type arguments are expected.
-
-  Args:
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-
-  Returns:
-    True if the line at 'linenum' is inside something that expects arguments
-    of function types.
-  """
-  line = clean_lines.elided[linenum]
-  return (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or
-          (linenum >= 2 and
-           (Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\((?:\S+,)?\s*$',
-                  clean_lines.elided[linenum - 1]) or
-            Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\(\s*$',
-                  clean_lines.elided[linenum - 2]) or
-            Search(r'\bstd::m?function\s*\<\s*$',
-                   clean_lines.elided[linenum - 1]))))
-
-
-_HEADERS_CONTAINING_TEMPLATES = (
-    ('<deque>', ('deque',)),
-    ('<functional>', ('unary_function', 'binary_function',
-                      'plus', 'minus', 'multiplies', 'divides', 'modulus',
-                      'negate',
-                      'equal_to', 'not_equal_to', 'greater', 'less',
-                      'greater_equal', 'less_equal',
-                      'logical_and', 'logical_or', 'logical_not',
-                      'unary_negate', 'not1', 'binary_negate', 'not2',
-                      'bind1st', 'bind2nd',
-                      'pointer_to_unary_function',
-                      'pointer_to_binary_function',
-                      'ptr_fun',
-                      'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t',
-                      'mem_fun_ref_t',
-                      'const_mem_fun_t', 'const_mem_fun1_t',
-                      'const_mem_fun_ref_t', 'const_mem_fun1_ref_t',
-                      'mem_fun_ref',
-                     )),
-    ('<limits>', ('numeric_limits',)),
-    ('<list>', ('list',)),
-    ('<map>', ('map', 'multimap',)),
-    ('<memory>', ('allocator',)),
-    ('<queue>', ('queue', 'priority_queue',)),
-    ('<set>', ('set', 'multiset',)),
-    ('<stack>', ('stack',)),
-    ('<string>', ('char_traits', 'basic_string',)),
-    ('<tuple>', ('tuple',)),
-    ('<utility>', ('pair',)),
-    ('<vector>', ('vector',)),
-
-    # gcc extensions.
-    # Note: std::hash is their hash, ::hash is our hash
-    ('<hash_map>', ('hash_map', 'hash_multimap',)),
-    ('<hash_set>', ('hash_set', 'hash_multiset',)),
-    ('<slist>', ('slist',)),
-    )
-
-_RE_PATTERN_STRING = re.compile(r'\bstring\b')
-
-_re_pattern_algorithm_header = []
-for _template in ('copy', 'max', 'min', 'min_element', 'sort', 'swap',
-                  'transform'):
-  # Match max<type>(..., ...), max(..., ...), but not foo->max, foo.max or
-  # type::max().
-  _re_pattern_algorithm_header.append(
-      (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'),
-       _template,
-       '<algorithm>'))
-
-_re_pattern_templates = []
-for _header, _templates in _HEADERS_CONTAINING_TEMPLATES:
-  for _template in _templates:
-    _re_pattern_templates.append(
-        (re.compile(r'(\<|\b)' + _template + r'\s*\<'),
-         _template + '<>',
-         _header))
-
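# A standalone illustration of the generated triples above: each pairs a regex
# such as r'(\<|\b)map\s*\<' with the template name and the header that
# provides it, so a line that instantiates map<...> makes <map> required.
import re

if __name__ == '__main__':
  pattern, template, header = re.compile(r'(\<|\b)map\s*\<'), 'map<>', '<map>'
  for code_line in ('std::map<int, string> counts;', 'int remap(int x);'):
    if pattern.search(code_line):
      print('%-30s requires %s (for %s)' % (code_line, header, template))
    else:
      print('%-30s requires nothing' % code_line)
  # The first line requires <map>; 'remap(' does not trigger the pattern.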
-
-def FilesBelongToSameModule(filename_cc, filename_h):
-  """Check if these two filenames belong to the same module.
-
-  The concept of a 'module' here is as follows:
-  foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the
-  same 'module' if they are in the same directory.
-  some/path/public/xyzzy and some/path/internal/xyzzy are also considered
-  to belong to the same module here.
-
-  If the filename_cc contains a longer path than the filename_h, for example,
-  '/absolute/path/to/base/sysinfo.cc', and this file would include
-  'base/sysinfo.h', this function also produces the prefix needed to open the
-  header. This is used by the caller of this function to more robustly open the
-  header file. We don't have access to the real include paths in this context,
-  so we need this guesswork here.
-
-  Known bugs: tools/base/bar.cc and base/bar.h belong to the same module
-  according to this implementation. Because of this, this function gives
-  some false positives. This should be sufficiently rare in practice.
-
-  Args:
-    filename_cc: is the path for the .cc file
-    filename_h: is the path for the header path
-
-  Returns:
-    Tuple with a bool and a string:
-    bool: True if filename_cc and filename_h belong to the same module.
-    string: the additional prefix needed to open the header file.
-  """
-
-  if not filename_cc.endswith('.cc'):
-    return (False, '')
-  filename_cc = filename_cc[:-len('.cc')]
-  if filename_cc.endswith('_unittest'):
-    filename_cc = filename_cc[:-len('_unittest')]
-  elif filename_cc.endswith('_test'):
-    filename_cc = filename_cc[:-len('_test')]
-  filename_cc = filename_cc.replace('/public/', '/')
-  filename_cc = filename_cc.replace('/internal/', '/')
-
-  if not filename_h.endswith('.h'):
-    return (False, '')
-  filename_h = filename_h[:-len('.h')]
-  if filename_h.endswith('-inl'):
-    filename_h = filename_h[:-len('-inl')]
-  filename_h = filename_h.replace('/public/', '/')
-  filename_h = filename_h.replace('/internal/', '/')
-
-  files_belong_to_same_module = filename_cc.endswith(filename_h)
-  common_path = ''
-  if files_belong_to_same_module:
-    common_path = filename_cc[:-len(filename_h)]
-  return files_belong_to_same_module, common_path
-
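# A minimal usage sketch (assumption: a copy of cpplint.py is importable on
# sys.path).  The returned prefix is what the caller prepends when opening the
# header file.
import cpplint

if __name__ == '__main__':
  print(cpplint.FilesBelongToSameModule(
      '/absolute/path/to/base/sysinfo.cc', 'base/sysinfo.h'))
  # -> (True, '/absolute/path/to/')
  print(cpplint.FilesBelongToSameModule('tools/base/bar.cc', 'base/bar.h'))
  # -> (True, 'tools/'), the known false positive noted in the docstring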
-
-def UpdateIncludeState(filename, include_dict, io=codecs):
-  """Fill up the include_dict with new includes found from the file.
-
-  Args:
-    filename: the name of the header to read.
-    include_dict: a dictionary in which the headers are inserted.
-    io: The io factory to use to read the file. Provided for testability.
-
-  Returns:
-    True if a header was successfully added. False otherwise.
-  """
-  headerfile = None
-  try:
-    headerfile = io.open(filename, 'r', 'utf8', 'replace')
-  except IOError:
-    return False
-  linenum = 0
-  for line in headerfile:
-    linenum += 1
-    clean_line = CleanseComments(line)
-    match = _RE_PATTERN_INCLUDE.search(clean_line)
-    if match:
-      include = match.group(2)
-      include_dict.setdefault(include, linenum)
-  return True
-
-
-def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error,
-                              io=codecs):
-  """Reports for missing stl includes.
-
-  This function will output warnings to make sure you are including the headers
-  necessary for the stl containers and functions that you use. We only give one
-  reason to include a header. For example, if you use both equal_to<> and
-  less<> in a .h file, only one (the latter in the file) of these will be
-  reported as a reason to include the <functional>.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    include_state: An _IncludeState instance.
-    error: The function to call with any errors found.
-    io: The IO factory to use to read the header file. Provided for unittest
-        injection.
-  """
-  required = {}  # A map of header name to linenumber and the template entity.
-                 # Example of required: { '<functional>': (1219, 'less<>') }
-
-  for linenum in xrange(clean_lines.NumLines()):
-    line = clean_lines.elided[linenum]
-    if not line or line[0] == '#':
-      continue
-
-    # String is special -- it is a non-templatized type in STL.
-    matched = _RE_PATTERN_STRING.search(line)
-    if matched:
-      # Don't warn about strings in non-STL namespaces:
-      # (We check only the first match per line; good enough.)
-      prefix = line[:matched.start()]
-      if prefix.endswith('std::') or not prefix.endswith('::'):
-        required['<string>'] = (linenum, 'string')
-
-    for pattern, template, header in _re_pattern_algorithm_header:
-      if pattern.search(line):
-        required[header] = (linenum, template)
-
-    # The following function is just a speed up, no semantics are changed.
-    if not '<' in line:  # Reduces the cpu time usage by skipping lines.
-      continue
-
-    for pattern, template, header in _re_pattern_templates:
-      if pattern.search(line):
-        required[header] = (linenum, template)
-
-  # The policy is that if you #include something in foo.h you don't need to
-  # include it again in foo.cc. Here, we will look at possible includes.
-  # Let's flatten the include_state include_list and copy it into a dictionary.
-  include_dict = dict([item for sublist in include_state.include_list
-                       for item in sublist])
-
-  # Did we find the header for this file (if any) and successfully load it?
-  header_found = False
-
-  # Use the absolute path so that matching works properly.
-  abs_filename = FileInfo(filename).FullName()
-
-  # For Emacs's flymake.
-  # If cpplint is invoked from Emacs's flymake, a temporary file is generated
-  # by flymake and that file name might end with '_flymake.cc'. In that case,
-  # restore original file name here so that the corresponding header file can be
-  # found.
-  # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h'
-  # instead of 'foo_flymake.h'
-  abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename)
-
-  # include_dict is modified during iteration, so we iterate over a copy of
-  # the keys.
-  header_keys = include_dict.keys()
-  for header in header_keys:
-    (same_module, common_path) = FilesBelongToSameModule(abs_filename, header)
-    fullpath = common_path + header
-    if same_module and UpdateIncludeState(fullpath, include_dict, io):
-      header_found = True
-
-  # If we can't find the header file for a .cc, assume it's because we don't
-  # know where to look. In that case we'll give up as we're not sure they
-  # didn't include it in the .h file.
-  # TODO(unknown): Do a better job of finding .h files so we are confident that
-  # not having the .h file means there isn't one.
-  if filename.endswith('.cc') and not header_found:
-    return
-
-  # All the lines have been processed, report the errors found.
-  for required_header_unstripped in required:
-    template = required[required_header_unstripped][1]
-    if required_header_unstripped.strip('<>"') not in include_dict:
-      error(filename, required[required_header_unstripped][0],
-            'build/include_what_you_use', 4,
-            'Add #include ' + required_header_unstripped + ' for ' + template)
-
-
-_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<')
-
-
-def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error):
-  """Check that make_pair's template arguments are deduced.
-
-  G++ 4.6 in C++11 mode fails badly if make_pair's template arguments are
-  specified explicitly, and such use isn't intended in any case.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-  match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line)
-  if match:
-    error(filename, linenum, 'build/explicit_make_pair',
-          4,  # 4 = high confidence
-          'For C++11-compatibility, omit template arguments from make_pair'
-          ' OR use pair directly OR if appropriate, construct a pair directly')
-
-
-def CheckDefaultLambdaCaptures(filename, clean_lines, linenum, error):
-  """Check that default lambda captures are not used.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # A lambda introducer specifies a default capture if it starts with "[="
-  # or if it starts with "[&" _not_ followed by an identifier.
-  match = Match(r'^(.*)\[\s*(?:=|&[^\w])', line)
-  if match:
-    # Found a potential error, check what comes after the lambda-introducer.
-    # If it's not open parenthesis (for lambda-declarator) or open brace
-    # (for compound-statement), it's not a lambda.
-    line, _, pos = CloseExpression(clean_lines, linenum, len(match.group(1)))
-    if pos >= 0 and Match(r'^\s*[{(]', line[pos:]):
-      error(filename, linenum, 'build/c++11',
-            4,  # 4 = high confidence
-            'Default lambda captures are an unapproved C++ feature.')
-
-
-def CheckRedundantVirtual(filename, clean_lines, linenum, error):
-  """Check if line contains a redundant "virtual" function-specifier.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  # Look for "virtual" on current line.
-  line = clean_lines.elided[linenum]
-  virtual = Match(r'^(.*)(\bvirtual\b)(.*)$', line)
-  if not virtual: return
-
-  # Ignore "virtual" keywords that are near access-specifiers.  These
-  # are only used in class base-specifier and do not apply to member
-  # functions.
-  if (Search(r'\b(public|protected|private)\s+$', virtual.group(1)) or
-      Match(r'^\s+(public|protected|private)\b', virtual.group(3))):
-    return
-
-  # Ignore the "virtual" keyword from virtual base classes.  Usually
-  # there is a column on the same line in these cases (virtual base
-  # classes are rare in google3 because multiple inheritance is rare).
-  if Match(r'^.*[^:]:[^:].*$', line): return
-
-  # Look for the next opening parenthesis.  This is the start of the
-  # parameter list (possibly on the next line shortly after virtual).
-  # TODO(unknown): doesn't work if there are virtual functions with
-  # decltype() or other things that use parentheses, but csearch suggests
-  # that this is rare.
-  end_col = -1
-  end_line = -1
-  start_col = len(virtual.group(2))
-  for start_line in xrange(linenum, min(linenum + 3, clean_lines.NumLines())):
-    line = clean_lines.elided[start_line][start_col:]
-    parameter_list = Match(r'^([^(]*)\(', line)
-    if parameter_list:
-      # Match parentheses to find the end of the parameter list
-      (_, end_line, end_col) = CloseExpression(
-          clean_lines, start_line, start_col + len(parameter_list.group(1)))
-      break
-    start_col = 0
-
-  if end_col < 0:
-    return  # Couldn't find end of parameter list, give up
-
-  # Look for "override" or "final" after the parameter list
-  # (possibly on the next few lines).
-  for i in xrange(end_line, min(end_line + 3, clean_lines.NumLines())):
-    line = clean_lines.elided[i][end_col:]
-    match = Search(r'\b(override|final)\b', line)
-    if match:
-      error(filename, linenum, 'readability/inheritance', 4,
-            ('"virtual" is redundant since function is '
-             'already declared as "%s"' % match.group(1)))
-
-    # Set end_col to check whole lines after we are done with the
-    # first line.
-    end_col = 0
-    if Search(r'[^\w]\s*$', line):
-      break
-
-
-def CheckRedundantOverrideOrFinal(filename, clean_lines, linenum, error):
-  """Check if line contains a redundant "override" or "final" virt-specifier.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  # Look for closing parenthesis nearby.  We need one to confirm where
-  # the declarator ends and where the virt-specifier starts to avoid
-  # false positives.
-  line = clean_lines.elided[linenum]
-  declarator_end = line.rfind(')')
-  if declarator_end >= 0:
-    fragment = line[declarator_end:]
-  else:
-    if linenum > 1 and clean_lines.elided[linenum - 1].rfind(')') >= 0:
-      fragment = line
-    else:
-      return
-
-  # Check that at most one of "override" or "final" is present, not both
-  if Search(r'\boverride\b', fragment) and Search(r'\bfinal\b', fragment):
-    error(filename, linenum, 'readability/inheritance', 4,
-          ('"override" is redundant since function is '
-           'already declared as "final"'))
-
-
-
-
-# Returns true if we are at a new block, and it is directly
-# inside of a namespace.
-def IsBlockInNameSpace(nesting_state, is_forward_declaration):
-  """Checks that the new block is directly in a namespace.
-
-  Args:
-    nesting_state: The _NestingState object that contains info about our state.
-    is_forward_declaration: If the class is a forward declared class.
-  Returns:
-    Whether or not the new block is directly in a namespace.
-  """
-  if is_forward_declaration:
-    if len(nesting_state.stack) >= 1 and (
-        isinstance(nesting_state.stack[-1], _NamespaceInfo)):
-      return True
-    else:
-      return False
-
-  return (len(nesting_state.stack) > 1 and
-          nesting_state.stack[-1].check_namespace_indentation and
-          isinstance(nesting_state.stack[-2], _NamespaceInfo))
-
-
-def ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item,
-                                    raw_lines_no_comments, linenum):
-  """This method determines if we should apply our namespace indentation check.
-
-  Args:
-    nesting_state: The current nesting state.
-    is_namespace_indent_item: If we just put a new class on the stack, True.
-      If the top of the stack is not a class, or we did not recently
-      add the class, False.
-    raw_lines_no_comments: The lines without the comments.
-    linenum: The current line number we are processing.
-
-  Returns:
-    True if we should apply our namespace indentation check. Currently, it
-    only works for classes and namespaces inside of a namespace.
-  """
-
-  is_forward_declaration = IsForwardClassDeclaration(raw_lines_no_comments,
-                                                     linenum)
-
-  if not (is_namespace_indent_item or is_forward_declaration):
-    return False
-
-  # If we are in a macro, we do not want to check the namespace indentation.
-  if IsMacroDefinition(raw_lines_no_comments, linenum):
-    return False
-
-  return IsBlockInNameSpace(nesting_state, is_forward_declaration)
-
-
-# Call this method if the line is directly inside of a namespace.
-# If the line above is blank (excluding comments) or the start of
-# an inner namespace, it cannot be indented.
-def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum,
-                                    error):
-  line = raw_lines_no_comments[linenum]
-  if Match(r'^\s+', line):
-    error(filename, linenum, 'runtime/indentation_namespace', 4,
-          'Do not indent within a namespace')
-
-
-def ProcessLine(filename, file_extension, clean_lines, line,
-                include_state, function_state, nesting_state, error,
-                extra_check_functions=[]):
-  """Processes a single line in the file.
-
-  Args:
-    filename: Filename of the file that is being processed.
-    file_extension: The extension (dot not included) of the file.
-    clean_lines: An array of strings, each representing a line of the file,
-                 with comments stripped.
-    line: Number of line being processed.
-    include_state: An _IncludeState instance in which the headers are inserted.
-    function_state: A _FunctionState instance which counts function lines, etc.
-    nesting_state: A NestingState instance which maintains information about
-                   the current stack of nested blocks being parsed.
-    error: A callable to which errors are reported, which takes 4 arguments:
-           filename, line number, error level, and message
-    extra_check_functions: An array of additional check functions that will be
-                           run on each source line. Each function takes 4
-                           arguments: filename, clean_lines, line, error
-  """
-  raw_lines = clean_lines.raw_lines
-  ParseNolintSuppressions(filename, raw_lines[line], line, error)
-  nesting_state.Update(filename, clean_lines, line, error)
-  CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line,
-                               error)
-  if nesting_state.InAsmBlock(): return
-  CheckForFunctionLengths(filename, clean_lines, line, function_state, error)
-  CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error)
-  CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error)
-  CheckLanguage(filename, clean_lines, line, file_extension, include_state,
-                nesting_state, error)
-  CheckForNonConstReference(filename, clean_lines, line, nesting_state, error)
-  CheckForNonStandardConstructs(filename, clean_lines, line,
-                                nesting_state, error)
-  CheckVlogArguments(filename, clean_lines, line, error)
-  CheckPosixThreading(filename, clean_lines, line, error)
-  CheckInvalidIncrement(filename, clean_lines, line, error)
-  CheckMakePairUsesDeduction(filename, clean_lines, line, error)
-  CheckDefaultLambdaCaptures(filename, clean_lines, line, error)
-  CheckRedundantVirtual(filename, clean_lines, line, error)
-  CheckRedundantOverrideOrFinal(filename, clean_lines, line, error)
-  for check_fn in extra_check_functions:
-    check_fn(filename, clean_lines, line, error)
-
-def FlagCxx11Features(filename, clean_lines, linenum, error):
-  """Flag those c++11 features that we only allow in certain places.
-
-  Args:
-    filename: The name of the current file.
-    clean_lines: A CleansedLines instance containing the file.
-    linenum: The number of the line to check.
-    error: The function to call with any errors found.
-  """
-  line = clean_lines.elided[linenum]
-
-  # Flag unapproved C++11 headers.
-  include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line)
-  if include and include.group(1) in ('cfenv',
-                                      'condition_variable',
-                                      'fenv.h',
-                                      'future',
-                                      'mutex',
-                                      'thread',
-                                      'chrono',
-                                      'ratio',
-                                      'regex',
-                                      'system_error',
-                                     ):
-    error(filename, linenum, 'build/c++11', 5,
-          ('<%s> is an unapproved C++11 header.') % include.group(1))
-
-  # The only place where we need to worry about C++11 keywords and library
-  # features in preprocessor directives is in macro definitions.
-  if Match(r'\s*#', line) and not Match(r'\s*#\s*define\b', line): return
-
-  # These are classes and free functions.  The classes are always
-  # mentioned as std::*, but we only catch the free functions if
-  # they're not found by ADL.  They're alphabetical by header.
-  for top_name in (
-      # type_traits
-      'alignment_of',
-      'aligned_union',
-      ):
-    if Search(r'\bstd::%s\b' % top_name, line):
-      error(filename, linenum, 'build/c++11', 5,
-            ('std::%s is an unapproved C++11 class or function.  Send c-style '
-             'an example of where it would make your code more readable, and '
-             'they may let you use it.') % top_name)
-
-
-def ProcessFileData(filename, file_extension, lines, error,
-                    extra_check_functions=[]):
-  """Performs lint checks and reports any errors to the given error function.
-
-  Args:
-    filename: Filename of the file that is being processed.
-    file_extension: The extension (dot not included) of the file.
-    lines: An array of strings, each representing a line of the file, with the
-           last element being empty if the file is terminated with a newline.
-    error: A callable to which errors are reported, which takes 4 arguments:
-           filename, line number, error level, and message
-    extra_check_functions: An array of additional check functions that will be
-                           run on each source line. Each function takes 4
-                           arguments: filename, clean_lines, line, error
-  """
-  lines = (['// marker so line numbers and indices both start at 1'] + lines +
-           ['// marker so line numbers end in a known way'])
-
-  include_state = _IncludeState()
-  function_state = _FunctionState()
-  nesting_state = NestingState()
-
-  ResetNolintSuppressions()
-
-  CheckForCopyright(filename, lines, error)
-
-  RemoveMultiLineComments(filename, lines, error)
-  clean_lines = CleansedLines(lines)
-
-  if file_extension == 'h':
-    CheckForHeaderGuard(filename, clean_lines, error)
-
-  for line in xrange(clean_lines.NumLines()):
-    ProcessLine(filename, file_extension, clean_lines, line,
-                include_state, function_state, nesting_state, error,
-                extra_check_functions)
-    FlagCxx11Features(filename, clean_lines, line, error)
-  nesting_state.CheckCompletedBlocks(filename, error)
-
-  CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error)
-
-  # Check that the .cc file has included its header if it exists.
-  if file_extension == 'cc':
-    CheckHeaderFileIncluded(filename, include_state, error)
-
-  # We check here rather than inside ProcessLine so that we see raw
-  # lines rather than "cleaned" lines.
-  CheckForBadCharacters(filename, lines, error)
-
-  CheckForNewlineAtEOF(filename, lines, error)
-
-def ProcessConfigOverrides(filename):
-  """ Loads the configuration files and processes the config overrides.
-
-  Args:
-    filename: The name of the file being processed by the linter.
-
-  Returns:
-    False if the current |filename| should not be processed further.
-  """
-
-  abs_filename = os.path.abspath(filename)
-  cfg_filters = []
-  keep_looking = True
-  while keep_looking:
-    abs_path, base_name = os.path.split(abs_filename)
-    if not base_name:
-      break  # Reached the root directory.
-
-    cfg_file = os.path.join(abs_path, "CPPLINT.cfg")
-    abs_filename = abs_path
-    if not os.path.isfile(cfg_file):
-      continue
-
-    try:
-      with open(cfg_file) as file_handle:
-        for line in file_handle:
-          line, _, _ = line.partition('#')  # Remove comments.
-          if not line.strip():
-            continue
-
-          name, _, val = line.partition('=')
-          name = name.strip()
-          val = val.strip()
-          if name == 'set noparent':
-            keep_looking = False
-          elif name == 'filter':
-            cfg_filters.append(val)
-          elif name == 'exclude_files':
-            # When matching exclude_files pattern, use the base_name of
-            # the current file name or the directory name we are processing.
-            # For example, if we are checking for lint errors in /foo/bar/baz.cc
-            # and we found the .cfg file at /foo/CPPLINT.cfg, then the config
-            # file's "exclude_files" filter is meant to be checked against "bar"
-            # and not "baz" nor "bar/baz.cc".
-            if base_name:
-              pattern = re.compile(val)
-              if pattern.match(base_name):
-                sys.stderr.write('Ignoring "%s": file excluded by "%s". '
-                                 'File path component "%s" matches '
-                                 'pattern "%s"\n' %
-                                 (filename, cfg_file, base_name, val))
-                return False
-          elif name == 'linelength':
-            global _line_length
-            try:
-                _line_length = int(val)
-            except ValueError:
-                sys.stderr.write('Line length must be numeric.')
-          else:
-            sys.stderr.write(
-                'Invalid configuration option (%s) in file %s\n' %
-                (name, cfg_file))
-
-    except IOError:
-      sys.stderr.write(
-          "Skipping config file '%s': Can't open for reading\n" % cfg_file)
-      keep_looking = False
-
-  # Apply all the accumulated filters in reverse order (top-level directory
-  # config options having the least priority).
-  for filter in reversed(cfg_filters):
-     _AddFilters(filter)
-
-  return True
-
-
-def ProcessFile(filename, vlevel, extra_check_functions=[]):
-  """Does google-lint on a single file.
-
-  Args:
-    filename: The name of the file to parse.
-
-    vlevel: The level of errors to report.  Every error of confidence
-    >= verbose_level will be reported.  0 is a good default.
-
-    extra_check_functions: An array of additional check functions that will be
-                           run on each source line. Each function takes 4
-                           arguments: filename, clean_lines, line, error
-  """
-
-  _SetVerboseLevel(vlevel)
-  _BackupFilters()
-
-  if not ProcessConfigOverrides(filename):
-    _RestoreFilters()
-    return
-
-  lf_lines = []
-  crlf_lines = []
-  try:
-    # Support the UNIX convention of using "-" for stdin.  Note that
-    # we are not opening the file with universal newline support
-    # (which codecs doesn't support anyway), so the resulting lines do
-    # contain trailing '\r' characters if we are reading a file that
-    # has CRLF endings.
-    # If after the split a trailing '\r' is present, it is removed
-    # below.
-    if filename == '-':
-      lines = codecs.StreamReaderWriter(sys.stdin,
-                                        codecs.getreader('utf8'),
-                                        codecs.getwriter('utf8'),
-                                        'replace').read().split('\n')
-    else:
-      lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n')
-
-    # Remove trailing '\r'.
-    # The -1 accounts for the extra trailing blank line we get from split()
-    for linenum in range(len(lines) - 1):
-      if lines[linenum].endswith('\r'):
-        lines[linenum] = lines[linenum].rstrip('\r')
-        crlf_lines.append(linenum + 1)
-      else:
-        lf_lines.append(linenum + 1)
-
-  except IOError:
-    sys.stderr.write(
-        "Skipping input '%s': Can't open for reading\n" % filename)
-    _RestoreFilters()
-    return
-
-  # Note, if no dot is found, this will give the entire filename as the ext.
-  file_extension = filename[filename.rfind('.') + 1:]
-
-  # When reading from stdin, the extension is unknown, so no cpplint tests
-  # should rely on the extension.
-  if filename != '-' and file_extension not in _valid_extensions:
-    sys.stderr.write('Ignoring %s; not a valid file name '
-                     '(%s)\n' % (filename, ', '.join(_valid_extensions)))
-  else:
-    ProcessFileData(filename, file_extension, lines, Error,
-                    extra_check_functions)
-
-    # If end-of-line sequences are a mix of LF and CR-LF, issue
-    # warnings on the lines with CR.
-    #
-    # Don't issue any warnings if all lines are uniformly LF or CR-LF,
-    # since critique can handle these just fine, and the style guide
-    # doesn't dictate a particular end of line sequence.
-    #
-    # We can't depend on os.linesep to determine what the desired
-    # end-of-line sequence should be, since that will return the
-    # server-side end-of-line sequence.
-    if lf_lines and crlf_lines:
-      # Warn on every line with CR.  An alternative approach might be to
-      # check whether the file is mostly CRLF or just LF, and warn on the
-      # minority, we bias toward LF here since most tools prefer LF.
-      for linenum in crlf_lines:
-        Error(filename, linenum, 'whitespace/newline', 1,
-              'Unexpected \\r (^M) found; better to use only \\n')
-
-  sys.stderr.write('Done processing %s\n' % filename)
-  _RestoreFilters()
-
-
-def PrintUsage(message):
-  """Prints a brief usage string and exits, optionally with an error message.
-
-  Args:
-    message: The optional error message.
-  """
-  sys.stderr.write(_USAGE)
-  if message:
-    sys.exit('\nFATAL ERROR: ' + message)
-  else:
-    sys.exit(1)
-
-
-def PrintCategories():
-  """Prints a list of all the error-categories used by error messages.
-
-  These are the categories used to filter messages via --filter.
-  """
-  sys.stderr.write(''.join('  %s\n' % cat for cat in _ERROR_CATEGORIES))
-  sys.exit(0)
-
-
-def ParseArguments(args):
-  """Parses the command line arguments.
-
-  This may set the output format and verbosity level as side-effects.
-
-  Args:
-    args: The command line arguments:
-
-  Returns:
-    The list of filenames to lint.
-  """
-  try:
-    (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=',
-                                                 'counting=',
-                                                 'filter=',
-                                                 'root=',
-                                                 'linelength=',
-                                                 'extensions='])
-  except getopt.GetoptError:
-    PrintUsage('Invalid arguments.')
-
-  verbosity = _VerboseLevel()
-  output_format = _OutputFormat()
-  filters = ''
-  counting_style = ''
-
-  for (opt, val) in opts:
-    if opt == '--help':
-      PrintUsage(None)
-    elif opt == '--output':
-      if val not in ('emacs', 'vs7', 'eclipse'):
-        PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.')
-      output_format = val
-    elif opt == '--verbose':
-      verbosity = int(val)
-    elif opt == '--filter':
-      filters = val
-      if not filters:
-        PrintCategories()
-    elif opt == '--counting':
-      if val not in ('total', 'toplevel', 'detailed'):
-        PrintUsage('Valid counting options are total, toplevel, and detailed')
-      counting_style = val
-    elif opt == '--root':
-      global _root
-      _root = val
-    elif opt == '--linelength':
-      global _line_length
-      try:
-          _line_length = int(val)
-      except ValueError:
-          PrintUsage('Line length must be digits.')
-    elif opt == '--extensions':
-      global _valid_extensions
-      try:
-          _valid_extensions = set(val.split(','))
-      except ValueError:
-          PrintUsage('Extensions must be comma seperated list.')
-
-  if not filenames:
-    PrintUsage('No files were specified.')
-
-  _SetOutputFormat(output_format)
-  _SetVerboseLevel(verbosity)
-  _SetFilters(filters)
-  _SetCountingStyle(counting_style)
-
-  return filenames
-
-
-def main():
-  filenames = ParseArguments(sys.argv[1:])
-
-  # Change stderr to write with replacement characters so we don't die
-  # if we try to print something containing non-ASCII characters.
-  sys.stderr = codecs.StreamReaderWriter(sys.stderr,
-                                         codecs.getreader('utf8'),
-                                         codecs.getwriter('utf8'),
-                                         'replace')
-
-  _cpplint_state.ResetErrorCounts()
-  for filename in filenames:
-    ProcessFile(filename, _cpplint_state.verbose_level)
-  _cpplint_state.PrintErrorCounts()
-
-  sys.exit(_cpplint_state.error_count > 0)
-
-
-if __name__ == '__main__':
-  main()
diff --git a/tool/docker/README.md b/tool/docker/README.md
index 4d5206e..740735e 100644
--- a/tool/docker/README.md
+++ b/tool/docker/README.md
@@ -38,7 +38,7 @@
 
     docker run -it apache/singa:1.2.0-cpu-devel-ubuntu18.04 /bin/bash
     # or
-    docker run -it apache/singa:1.2.0-cuda10.0-cudnn7.4.2-devel-ubuntu18.04 /bin/bash
+    nvidia-docker run -it apache/singa:1.2.0-cuda10.0-cudnn7.4.2-devel-ubuntu18.04 /bin/bash
 
 The latest SINGA code is under the `singa` folder.
 
@@ -49,12 +49,13 @@
 New Docker images could be created by executing the following command within the
 Dockerfile folder, e.g., tool/docker/devel/
 
-    docker build -t singa:<TAG> -f Dockerfile
+    docker build -t apache/singa:<TAG> -f Dockerfile .
 
 The `<TAG>` is named as
 
-    devel|runtime[-CUDA|CPU][-CUDNN]
+    VERSION-devel|runtime[-CUDA|CPU][-CUDNN]
 
+* VERSION: e.g., 3.0.0
 * devel: development images with all dependent libs' header files installed and SINGA's source code; runtime: the minimal images which can run SINGA programs.
 * CUDA: cuda10.0, cuda9.0
 * CUDNN: cudnn7
@@ -72,4 +73,4 @@
     level3: Dockerfile, CUDA|MKLDNN
 
 
-For example, the path of the Dockerfile for `devel-cuda9-cudnn7` is `tool/docker/devel/ubuntu/cuda9/Dockerfile`.
+For example, the path of the Dockerfile for `devel-cuda9-cudnn7` is `tool/docker/devel/ubuntu/cuda9/Dockerfile`.
\ No newline at end of file
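As a concrete illustration of the tag scheme documented above (the version and CUDA/cuDNN values here are only examples, not fixed by this patch), a CUDA development image could be built and started roughly as follows:

    cd tool/docker/devel/ubuntu/cuda10
    docker build -t apache/singa:3.0.0-devel-cuda10.0-cudnn7 -f Dockerfile .
    nvidia-docker run -it apache/singa:3.0.0-devel-cuda10.0-cudnn7 /bin/bash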
diff --git a/tool/docker/devel/centos6/cuda10/Dockerfile.manylinux2014 b/tool/docker/devel/centos6/cuda10/Dockerfile.manylinux2014
new file mode 100644
index 0000000..1adb1b1
--- /dev/null
+++ b/tool/docker/devel/centos6/cuda10/Dockerfile.manylinux2014
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# The latest tag uses gcc 9, which is too high for nvcc.
+# The following tag uses gcc 8, which works with nvcc.
+FROM quay.io/pypa/manylinux2014_x86_64:2020-05-01-b37d76b
+
+# install dependencies
+RUN yum install -y \
+    protobuf-devel \
+    openblas-devel \
+    # git \
+    wget \
+    openssh-server \
+    pcre-devel \
+    cmake \
+    && yum clean all \
+    && rm -rf /var/cache/yum/*
+
+# install glog into /usr/local/include/glog /usr/local/lib
+RUN wget https://github.com/google/glog/archive/v0.3.5.tar.gz -P /tmp/\
+    && tar zxf /tmp/v0.3.5.tar.gz  -C /tmp/ \
+    && cd /tmp/glog-0.3.5 \
+    && ./configure && make && make install && cd .. && rm -rf glog-0.3.5
+
+# install dnnl into /usr/local/include  /usr/local/lib
+RUN wget https://github.com/intel/mkl-dnn/releases/download/v1.1/dnnl_lnx_1.1.0_cpu_gomp.tgz -P /tmp/ \
+    && tar zxf /tmp/dnnl_lnx_1.1.0_cpu_gomp.tgz  -C /tmp/ \
+    && cp -r -H /tmp/dnnl_lnx_1.1.0_cpu_gomp/lib/lib* /usr/local/lib/ \
+    && cp -r -H /tmp/dnnl_lnx_1.1.0_cpu_gomp/include/* /usr/local/include/  \
+    && rm -rf /tmp/dnnl_lnx_1.1.0_cpu_gomp
+# ENV DNNL_ROOT /root/dnnl_lnx_1.1.0_cpu_gomp/
+
+# install swig into /usr/local/bin
+RUN wget http://prdownloads.sourceforge.net/swig/swig-3.0.12.tar.gz -P /tmp/ \
+    && tar zxf /tmp/swig-3.0.12.tar.gz -C /tmp/ \
+    && cd /tmp/swig-3.0.12 && ./configure && make && make install && cd .. && rm -rf swig-3.0.12
+
+# numpy and python versions should be matched; 
+# twine works for all python versions
+RUN /opt/python/cp36-cp36m/bin/pip install numpy twine
+RUN /opt/python/cp37-cp37m/bin/pip install numpy 
+RUN /opt/python/cp38-cp38/bin/pip install numpy
+
+# install cuda and cudnn
+# Refer to https://gitlab.com/nvidia/container-images/cuda/-/tree/master/dist for other cuda and cudnn versions
+# 10.2-base-centos7
+RUN NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \
+    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/7fa2af80.pub | sed '/^Version/d' > /etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA && \
+    echo "$NVIDIA_GPGKEY_SUM  /etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA" | sha256sum -c --strict  -
+COPY cuda.repo /etc/yum.repos.d/cuda.repo
+ENV CUDA_VERSION 10.2.89
+ENV CUDA_PKG_VERSION 10-2-$CUDA_VERSION-1
+# For libraries in the cuda-compat-* package: https://docs.nvidia.com/cuda/eula/index.html#attachment-a
+RUN yum install -y \
+    cuda-cudart-$CUDA_PKG_VERSION \
+    cuda-compat-10-2 \
+    && ln -s cuda-10.2 /usr/local/cuda && \
+    rm -rf /var/cache/yum/*
+
+# nvidia-docker 1.0
+RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
+    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH
+
+# nvidia-container-runtime
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=10.2 brand=tesla,driver>=396,driver<397 brand=tesla,driver>=410,driver<411 brand=tesla,driver>=418,driver<419 brand=tesla,driver>=440,driver<441"
+
+# 10.2-runtime-centos7
+RUN yum install -y \
+    cuda-libraries-$CUDA_PKG_VERSION \
+    cuda-nvtx-$CUDA_PKG_VERSION \
+    libcublas10-10.2.2.89-1 \
+    && rm -rf /var/cache/yum/*
+
+
+# 10.2-devel-centos7
+RUN yum install -y \
+    cuda-nvml-dev-$CUDA_PKG_VERSION \
+    cuda-command-line-tools-$CUDA_PKG_VERSION \
+    cuda-cudart-dev-$CUDA_PKG_VERSION \
+    cuda-libraries-dev-$CUDA_PKG_VERSION \
+    cuda-minimal-build-$CUDA_PKG_VERSION \
+    && rm -rf /var/cache/yum/*
+RUN yum install -y xz && NCCL_DOWNLOAD_SUM=a9ee790c3fc64b0ecbb00db92eddc1525552eda10a8656ff4b7380f66d81bda1 && \
+    curl -fsSL https://developer.download.nvidia.com/compute/redist/nccl/v2.7/nccl_2.7.3-1+cuda10.2_x86_64.txz -O && \
+    echo "$NCCL_DOWNLOAD_SUM  nccl_2.7.3-1+cuda10.2_x86_64.txz" | sha256sum -c - && \
+    unxz nccl_2.7.3-1+cuda10.2_x86_64.txz && \
+    tar --no-same-owner --keep-old-files --no-overwrite-dir -xvf  nccl_2.7.3-1+cuda10.2_x86_64.tar -C /usr/local/cuda/include/ --strip-components=2 --wildcards '*/include/*' && \
+    tar --no-same-owner --keep-old-files --no-overwrite-dir -xvf  nccl_2.7.3-1+cuda10.2_x86_64.tar -C /usr/local/cuda/lib64/ --strip-components=2 --wildcards '*/lib/libnccl.so' && \
+    rm -f nccl_2.7.3-1+cuda10.2_x86_64.tar && \
+    ldconfig
+ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs
+
+# 10.2-cudnn7-devel-centos7
+ENV CUDNN_VERSION 7.6.5.32
+# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
+RUN CUDNN_DOWNLOAD_SUM=600267f2caaed2fd58eb214ba669d8ea35f396a7d19b94822e6b36f9f7088c20 && \
+    curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.5/cudnn-10.2-linux-x64-v7.6.5.32.tgz -O && \
+    echo "$CUDNN_DOWNLOAD_SUM  cudnn-10.2-linux-x64-v7.6.5.32.tgz" | sha256sum -c - && \
+    tar --no-same-owner -xzf cudnn-10.2-linux-x64-v7.6.5.32.tgz -C /usr/local && \
+    rm cudnn-10.2-linux-x64-v7.6.5.32.tgz && \
+    ldconfig
+
+# install cnmem to /usr/local/include  /usr/local/lib
+RUN git clone https://github.com/NVIDIA/cnmem.git cnmem \
+    && cd cnmem && mkdir build && cd build && cmake .. && make && make install && cd ../.. && rm -rf cnmem
+
+# install mpich into /usr/local/include  /usr/local/lib
+RUN wget http://www.mpich.org/static/downloads/3.3.2/mpich-3.3.2.tar.gz -P $HOME \
+    && cd $HOME \
+    && tar xfz mpich-3.3.2.tar.gz \
+    && cd mpich-3.3.2 \
+    && ./configure --prefix=/usr/local --disable-fortran \
+    && make && make install  && cd .. && rm -rf mpich-3.3.2
\ No newline at end of file
diff --git a/tool/docker/devel/centos6/cuda10/cuda.repo b/tool/docker/devel/centos6/cuda10/cuda.repo
new file mode 100644
index 0000000..990ac25
--- /dev/null
+++ b/tool/docker/devel/centos6/cuda10/cuda.repo
@@ -0,0 +1,6 @@
+[cuda]
+name=cuda
+baseurl=https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64
+enabled=1
+gpgcheck=1
+gpgkey=file:///etc/pki/rpm-gpg/RPM-GPG-KEY-NVIDIA
\ No newline at end of file
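Since Dockerfile.manylinux2014 copies cuda.repo from the same folder, the wheel-building image could be produced with a command along these lines (the image tag is illustrative, not part of this patch):

    cd tool/docker/devel/centos6/cuda10
    docker build -t singa-wheel:manylinux2014-cuda10.2 -f Dockerfile.manylinux2014 .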
diff --git a/tool/docker/devel/ubuntu/cuda10/Dockerfile b/tool/docker/devel/ubuntu/cuda10/Dockerfile
index 02d546a..5560f8a 100644
--- a/tool/docker/devel/ubuntu/cuda10/Dockerfile
+++ b/tool/docker/devel/ubuntu/cuda10/Dockerfile
@@ -18,7 +18,7 @@
 # Change tags to build with different cuda/cudnn versions:
 FROM nvidia/cuda:10.0-devel-ubuntu18.04
 
-ENV CUDNN_VERSION 7.4.2.24
+ENV CUDNN_VERSION 7.6.5.32
 
 RUN apt-get update && apt-get install -y --no-install-recommends \
     libcudnn7=$CUDNN_VERSION-1+cuda10.0 \
diff --git a/tool/docker/devel/ubuntu/cuda9/Dockerfile b/tool/docker/devel/ubuntu/cuda9/Dockerfile
index 33d0180..50e1279 100644
--- a/tool/docker/devel/ubuntu/cuda9/Dockerfile
+++ b/tool/docker/devel/ubuntu/cuda9/Dockerfile
@@ -18,7 +18,7 @@
 # Change tags to build with different cuda/cudnn versions:
 FROM nvidia/cuda:9.0-devel-ubuntu16.04
 
-ENV CUDNN_VERSION 7.4.2.24
+ENV CUDNN_VERSION 7.6.5.32
 
 RUN apt-get update && apt-get install -y --no-install-recommends \
     libcudnn7=$CUDNN_VERSION-1+cuda9.0 \
diff --git a/tool/linting/py.sh b/tool/linting/py.sh
index e7f793c..ac83822 100644
--- a/tool/linting/py.sh
+++ b/tool/linting/py.sh
@@ -21,8 +21,8 @@
 
 # pylint
 find python/singa/ \
-    examples/autograd \
-    test/python/ -iname *.py | xargs pylint
+    examples/ \
+    test/python/ -iname "*.py" | xargs pylint
 
 LINTRESULT=$?
 if [ $LINTRESULT == 0 ]; then
diff --git a/tool/travis/build.sh b/tool/travis/build.sh
deleted file mode 100644
index 80c7585..0000000
--- a/tool/travis/build.sh
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-set -ex
-
-# anaconda login user name
-USER=nusdbsystem
-OS=$TRAVIS_OS_NAME-64
-
-export PATH="$HOME/miniconda/bin:$PATH"
-conda config --set anaconda_upload no
-
-# save the package at given folder, then we can upload using singa-*.tar.bz2
-suffix=$TRAVIS_JOB_NUMBER  #`TZ=Asia/Singapore date +%Y-%m-%d-%H-%M-%S`
-export CONDA_BLD_PATH=~/conda-bld-$suffix
-mkdir $CONDA_BLD_PATH
-
-# get all tags
-git fetch --unshallow
-
-conda build tool/conda/singa --python 3.6
-conda build tool/conda/singa --python 3.7
-# conda install --use-local singa
-# cd test/python
-# $HOME/miniconda/bin/python run.py
-
-if [[ "$TRAVIS_SECURE_ENV_VARS" == "false" ]];
-  # install and run unittest
-then
-  echo "no uploading if ANACONDA_UPLOAD_TOKEN not set"
-else
-  # turn off debug to hide the token in travis log
-  set +x
-  # upload the package onto anaconda cloud
-
-
-  NEW_VERSION=`git describe --abbrev=0 --tags`
-  echo "[travis]Updating to new version $NEW_VERSION"
-
-  anaconda -t $ANACONDA_UPLOAD_TOKEN upload -u $USER -l main $CONDA_BLD_PATH/$OS/singa-*.tar.bz2 --force
-fi
diff --git a/tool/travis/depends.sh b/tool/travis/depends.sh
deleted file mode 100644
index d711f8e..0000000
--- a/tool/travis/depends.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-# install miniconda
-if [[ "$TRAVIS_OS_NAME" == "linux" ]];
-then
-  wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
-else
-  # https://docs.conda.io/projects/conda-build/en/latest/resources/compiler-tools.html#macos-sdk
-  wget https://github.com/phracker/MacOSX-SDKs/releases/download/10.13/MacOSX10.9.sdk.tar.xz
-  tar xf MacOSX10.9.sdk.tar.xz -C /tmp/
-  wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh;
-fi
-bash miniconda.sh -b -p $HOME/miniconda
-export PATH="$HOME/miniconda/bin:$PATH"
-hash -r
-conda config --set always_yes yes --set changeps1 no
-conda update -q conda
-conda install conda-build
-conda install anaconda-client
-conda config --add channels conda-forge
-conda config --add channels nusdbsystem
-
-# linting
-conda install -c conda-forge pylint
-conda install -c conda-forge cpplint
-conda install -c conda-forge deprecated
-python -c "from deprecated import deprecated"
diff --git a/tool/wheel.sh b/tool/wheel.sh
new file mode 100644
index 0000000..10e419a
--- /dev/null
+++ b/tool/wheel.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# this script should be launched at the root of the singa source folder
+# it builds the cpu-only and cuda-enabled wheel packages for py3.6, 3.7 and 3.8
+
+rm -rf dist
+
+# build cpu only wheel packages
+rm -rf build 
+/opt/python/cp36-cp36m/bin/python setup.py bdist_wheel
+rm -rf build 
+/opt/python/cp37-cp37m/bin/python setup.py bdist_wheel
+rm -rf build 
+/opt/python/cp38-cp38/bin/python setup.py bdist_wheel
+
+# build cuda enabled wheel packages
+export SINGA_CUDA=ON
+rm -rf build 
+/opt/python/cp36-cp36m/bin/python setup.py bdist_wheel
+rm -rf build 
+/opt/python/cp37-cp37m/bin/python setup.py bdist_wheel
+rm -rf build 
+/opt/python/cp38-cp38/bin/python setup.py bdist_wheel
+
+# repair the wheel files in dist/*.whl and store the results into wheelhouse/
+/opt/python/cp38-cp38/bin/python setup.py audit
\ No newline at end of file
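A minimal sketch of how tool/wheel.sh might be driven from the manylinux2014 image defined earlier in this patch (the image tag, mount point, and working directory are assumptions for illustration):

    # run from the SINGA source root; the repo is mounted into the container
    docker run --rm -v $(pwd):/singa -w /singa \
        singa-wheel:manylinux2014-cuda10.2 bash tool/wheel.sh
    # repaired wheels are stored under wheelhouse/
    ls wheelhouse/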