SINGA-300 Add residual networks for imagenet classification

Merge pull request 307 to master branch
diff --git a/examples/imagenet/resnet/README.md b/examples/imagenet/resnet/README.md
new file mode 100644
index 0000000..4a0f4da
--- /dev/null
+++ b/examples/imagenet/resnet/README.md
@@ -0,0 +1,54 @@
+---
+name: Resnets on ImageNet
+SINGA version: 1.1
+SINGA commit: 45ec92d8ffc1fa1385a9307fdf07e21da939ee2f
+parameter_url: https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-18.tar.gz
+license: Apache V2, https://github.com/facebook/fb.resnet.torch/blob/master/LICENSE
+---
+
+# Image Classification using Residual Networks
+
+
+In this example, we convert Residual Networks trained with [Torch](https://github.com/facebook/fb.resnet.torch) into SINGA for image classification.
+
+## Instructions
+
+* Download one parameter checkpoint file (see below) and the synset word file of ImageNet into this folder, e.g.,
+
+        $ wget https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-18.tar.gz
+        $ wget https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/synset_words.txt
+        $ tar xvf resnet-18.tar.gz
+
+* Usage
+
+        $ python serve.py -h
+
+* Example
+
+        # use cpu
+        $ python serve.py --use_cpu --parameter_file resnet-18.pickle --model resnet --depth 18 &
+        # use gpu
+        $ python serve.py --parameter_file resnet-18.pickle --model resnet --depth 18 &
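+
+  Both commands start the server in the background. It listens on port 9999 by default; the port can be changed via the --port option of serve.py.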
+
+  The parameter files for the following model and depth configuration pairs are provided:
+  * resnet (original resnet), [18](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-18.tar.gz)|[34](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-34.tar.gz)|[101](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-101.tar.gz)|[152](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-152.tar.gz)
+  * addbn (resnet with a batch normalization layer after the addition), [50](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-50.tar.gz)
+  * wrn (wide resnet), [50](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/wrn-50-2.tar.gz)
+  * preact (resnet with pre-activation), [200](https://s3-ap-southeast-1.amazonaws.com/dlfile/resnet/resnet-200.tar.gz)
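+
+  For example, to serve the wide resnet on GPU after downloading and extracting wrn-50-2.tar.gz (the extracted parameter file is assumed to be named wrn-50-2.pickle, matching the default in serve.py):
+
+        $ python serve.py --parameter_file wrn-50-2.pickle --model wrn --depth 50 &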
+
+* Submit images for classification
+
+        $ curl -i -F image=@image1.jpg http://localhost:9999/api
+        $ curl -i -F image=@image2.jpg http://localhost:9999/api
+        $ curl -i -F image=@image3.jpg http://localhost:9999/api
+
+image1.jpg, image2.jpg and image3.jpg are example file names; download some images (or use your own) under these names before executing the above commands.
+
+## Details
+
+The parameter files were extracted from the original [torch files](https://github.com/facebook/fb.resnet.torch/tree/master/pretrained) via
+the convert.py program.
+
+Usage:
+
+    $ python convert.py -h
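+
+For example, to regenerate the resnet-18 parameter file from a torch checkpoint (the checkpoint name resnet-18.t7 is hypothetical; use the name of the file downloaded from the torch repository):
+
+    $ python convert.py resnet-18.t7 resnet 18
+
+This writes the extracted parameters to resnet-18.pickle in the same folder.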
diff --git a/examples/imagenet/resnet/convert.py b/examples/imagenet/resnet/convert.py
new file mode 100644
index 0000000..6bf4101
--- /dev/null
+++ b/examples/imagenet/resnet/convert.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+'''Extract the net parameters from the torch file and store them as python dict
+using cPickle'''
+
+import os
+import torchfile
+import numpy as np
+import cPickle as pickle
+from argparse import ArgumentParser
+
+import model
+
+verbose=False
+
+def add_param(idx, name, val, params):
+    if isinstance(params, dict):
+        assert name not in params, 'duplicated param %s' % name
+        params[name] = val
+    else:
+        assert params[idx].size() == val.size, 'size mismatch for %s: %s - %s' % (name, (params[idx].shape,), (val.shape,))
+        params[idx].copy_from_numpy(val)
+
+    if verbose:
+        print name, val.shape
+
+
+def conv(m, idx, params, param_names):
+    outplane = m['weight'].shape[0]
+    name = param_names[idx]
+    val = np.reshape(m['weight'], (outplane, -1))
+    add_param(idx, name, val, params)
+    return idx + 1
+
+
+def batchnorm(m, idx, params, param_names):
+    add_param(idx, param_names[idx], m['weight'], params)
+    add_param(idx + 1, param_names[idx + 1], m['bias'], params)
+    add_param(idx + 2, param_names[idx + 2], m['running_mean'], params)
+    add_param(idx + 3, param_names[idx + 3], m['running_var'], params)
+    return idx + 4
+
+
+def linear(m, idx, params, param_names):
+    add_param(idx, param_names[idx], np.transpose(m['weight']), params)
+    add_param(idx + 1, param_names[idx + 1], m['bias'], params)
+    return idx + 2
+
+
+def traverse(m, idx, params, param_names):
+    ''' Traverse all modules of the torch checkpoint file to extract params.
+
+    Args:
+        m, a TorchObject
+        idx, index for the current cursor of param_names
+        params, an empty dictionary (name->numpy) to dump the params via pickle;
+            or a list of tensor objects which should be in the same order as
+            param_names, called to initialize net created in Singa directly
+            using param values from torch checkpoint file.
+
+    Returns:
+        the updated idx
+    '''
+    module_type = m.__dict__['_typename']
+    if module_type in ['nn.Sequential', 'nn.ConcatTable']:
+        for x in m.modules:
+            idx = traverse(x, idx, params, param_names)
+    elif 'SpatialConvolution' in module_type:
+        idx = conv(m, idx, params, param_names)
+    elif 'SpatialBatchNormalization' in module_type:
+        idx = batchnorm(m, idx, params, param_names)
+    elif 'Linear' in module_type:
+        idx = linear(m, idx, params, param_names)
+    return idx
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser(description='Convert params from torch to python '
+            'dict. \n resnet could have depth of 18, 34, 101, 152; \n '
+            'wrn has depth 50; preact has depth 200; addbn has depth 50')
+    parser.add_argument("infile", help="torch checkpoint file")
+    parser.add_argument("model", choices = ['resnet', 'wrn', 'preact', 'addbn'])
+    parser.add_argument("depth", type=int, choices = [18, 34, 50, 101, 152, 200])
+    args = parser.parse_args()
+
+    net = model.create_net(args.model, args.depth)
+    # model.init_params(net)
+    m = torchfile.load(args.infile)
+    params = {}
+    # params = net.param_values()
+    param_names = net.param_names()
+    traverse(m, 0, params, param_names)
+    miss = [name for name in param_names if name not in params]
+    if len(miss) > 0:
+        print 'The following params are missing from torch file'
+        print miss
+
+    outfile = os.path.splitext(args.infile)[0] + '.pickle'
+    with open(outfile, 'wb') as fd:
+        pickle.dump(params, fd)
diff --git a/examples/imagenet/resnet/model.py b/examples/imagenet/resnet/model.py
new file mode 100644
index 0000000..bf90da3
--- /dev/null
+++ b/examples/imagenet/resnet/model.py
@@ -0,0 +1,275 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+'''These models are created following https://github.com/facebook/fb.resnet.torch.git
+and https://github.com/szagoruyko/wide-residual-networks
+'''
+from singa.layer import Conv2D, Activation, MaxPooling2D, AvgPooling2D,\
+        Split, Merge, Flatten, Dense, BatchNormalization, Softmax
+from singa import net as ffnet
+from singa import initializer
+from singa import layer
+
+ffnet.verbose=True
+
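+# create conv layers without a bias term; the batchnorm that typically follows makes it redundant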
+conv_bias = False
+
+def conv(net, prefix, n, ksize, stride=1, pad=0, bn=True, relu=True, src=None):
+    '''Add a convolution layer and optionally a batchnorm and relu layer.
+
+    Args:
+        prefix, a string for the prefix of the layer name
+        n, num of filters for the conv layer
+        bn, if true add batchnorm
+        relu, if true add relu
+
+    Returns:
+        the last added layer
+    '''
+    ret = net.add(Conv2D(
+        prefix + '-conv', n, ksize, stride, pad=pad, use_bias=conv_bias), src)
+    if bn:
+        ret = net.add(BatchNormalization(prefix + '-bn'))
+    if relu:
+        ret = net.add(Activation(prefix + '-relu'))
+    return ret
+
+
+def shortcut(net, prefix, inplane, outplane, stride, src, bn=False):
+    '''Add a conv shortcut layer if inplane != outplane; or return the source
+    layer directly.
+
+    Args:
+        prefix, a string for the prefix of the layer name
+        bn, if true add a batchnorm layer after the conv layer
+
+    Returns:
+        return the last added layer or the source layer.
+    '''
+    if inplane == outplane:
+        return src
+    return conv(net, prefix + '-shortcut', outplane, 1, stride, 0, bn, False, src)
+
+
+def bottleneck(name, net, inplane, midplane, outplane, stride=1, preact=False, add_bn=False):
+    '''Add three conv layers, with inplane >= midplane <= outplane feature maps (a bottleneck).
+
+    The default structure is
+    input
+         -split - conv1-bn1-relu1-conv2-bn2-relu2-conv3-bn3
+                - conv-bn or dummy
+         -add
+         -relu
+
+    Args:
+        inplane, num of feature maps of the input
+        midplane, num of feature maps of the middle layer
+        outplane, num of feature maps of the output
+        preact, if true, move bn3 and relu before conv1, i.e., pre-activation as in the identity-mappings paper
+        add_bn, if true, move the last bn after the addition layer (for resnet-50)
+    '''
+    assert not (preact and add_bn), 'preact and batchnorm after addition cannot be true at the same time'
+    split = net.add(Split(name + '-split', 2))
+    if preact:
+        net.add(BatchNormalization(name + '-preact-bn'))
+        net.add(Activation(name + '-preact-relu'))
+    conv(net, name + '-0', midplane, 1, 1, 0, True, True)
+    conv(net, name + '-1', midplane, 3, stride, 1, True, True)
+    br0 = conv(net, name + '-2', outplane, 1, 1, 0, not (preact or add_bn), False)
+    br1 = shortcut(net, name, inplane, outplane, stride, split, not add_bn)
+    ret = net.add(Merge(name + '-add'), [br0, br1])
+    if add_bn:
+        ret = net.add(BatchNormalization(name + '-add-bn'))
+    if not preact:
+        ret = net.add(Activation(name + '-add-relu'))
+    return ret
+
+
+def basicblock(name, net, inplane, midplane, outplane, stride=1, preact=False, add_bn=False):
+    '''Add two conv layers, with midplane <= outplane feature maps.
+
+    The default structure is
+    input
+         -split - conv1-bn1-relu1-conv2-bn2
+                - conv or dummy
+         -add
+         -relu
+
+    Args:
+        inplane, num of feature maps of the input
+        midplane, num of feature maps of the middle layer
+        outplane, num of feature maps of the output
+        preact, if true, move bn2 and relu before conv1, i.e., pre-activation as in the identity-mappings paper
+        add_bn, if true, move the last bn after the addition layer (for resnet-50)
+    '''
+    assert not (preact and add_bn), 'preact and batchnorm after addition cannot be true at the same time'
+    split = net.add(Split(name + '-split', 2))
+    if preact:
+        net.add(BatchNormalization(name + '-preact-bn'))
+        net.add(Activation(name + '-preact-relu'))
+    conv(net, name + '-0', midplane, 3, stride, 1, True, True)
+    br0 = conv(net, name + '-1', outplane, 3, 1, 1, not preact, False)
+    br1 = shortcut(net, name, inplane, outplane, stride, split, False)
+    ret = net.add(Merge(name + '-add'), [br0, br1])
+    if add_bn:
+        ret = net.add(BatchNormalization(name + '-add-bn'))
+    if not preact:
+        ret = net.add(Activation(name + '-add-relu'))
+    return ret
+
+
+def stage(sid, net, num_blk, inplane, midplane, outplane, stride, block, preact=False, add_bn=False):
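+    '''Add num_blk blocks of the given type; the first block applies the given
+    stride (for downsampling) and the remaining blocks use stride 1.'''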
+    block('stage%d-blk%d' % (sid, 0), net, inplane, midplane, outplane, stride, preact, add_bn)
+    for i in range(1, num_blk):
+        block('stage%d-blk%d' % (sid, i), net, outplane, midplane, outplane, 1, preact, add_bn)
+
+def init_params(net, weight_path=None):
+    if weight_path is None:
+        for pname, pval in zip(net.param_names(), net.param_values()):
+            print pname, pval.shape
+            if 'conv' in pname and len(pval.shape) > 1:
+                initializer.gaussian(pval, 0, pval.shape[1])
+            elif 'dense' in pname:
+                if len(pval.shape) > 1:
+                    initializer.gaussian(pval, 0, pval.shape[0])
+                else:
+                    pval.set_value(0)
+            # init params from batch norm layer
+            elif 'mean' in pname or 'beta' in pname:
+                pval.set_value(0)
+            elif 'var' in pname:
+                pval.set_value(1)
+            elif 'gamma' in pname:
+                initializer.uniform(pval, 0, 1)
+    else:
+        net.load(weight_path, use_pickle = 'pickle' in weight_path)
+
+
+cfg = { 18: [2, 2, 2, 2],  # basicblock
+        34: [3, 4, 6, 3],  # basicblock
+        50: [3, 4, 6, 3],  # bottleneck
+        101: [3, 4, 23, 3], # bottleneck
+        152: [3, 8, 36, 3], # bottleneck
+        200: [3, 24, 36, 3]} # bottleneck
+
+
+def create_addbn_resnet(depth=50):
+    '''Original resnet with the last batchnorm of each block moved to after the addition layer'''
+    net = ffnet.FeedForwardNet()
+    net.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False, input_sample_shape=(3, 224, 224)))
+    net.add(BatchNormalization('input-bn'))
+    net.add(Activation('input_relu'))
+    net.add(MaxPooling2D('input_pool', 3, 2, pad=1))
+    conf = cfg[depth]
+    if depth > 34:
+        stage(0, net, conf[0], 64, 64, 256, 1, bottleneck, add_bn=True)
+        stage(1, net, conf[1], 256, 128, 512, 2, bottleneck, add_bn=True)
+        stage(2, net, conf[2], 512, 256, 1024, 2, bottleneck, add_bn=True)
+        stage(3, net, conf[3], 1024, 512, 2048, 2, bottleneck, add_bn=True)
+    else:
+        stage(0, net, conf[0], 64, 64, 64, 1, basicblock, add_bn=True)
+        stage(1, net, conf[1], 64, 128, 128, 2, basicblock, add_bn=True)
+        stage(2, net, conf[2], 128, 256, 256, 2, basicblock, add_bn=True)
+        stage(3, net, conf[3], 256, 512, 512, 2, basicblock, add_bn=True)
+    net.add(AvgPooling2D('avg', 7, 1, pad=0))
+    net.add(Flatten('flat'))
+    net.add(Dense('dense', 1000))
+    return net
+
+
+def create_resnet(depth=18):
+    '''Original resnet, where there is a relu after the addition layer'''
+    net = ffnet.FeedForwardNet()
+    net.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False, input_sample_shape=(3, 224, 224)))
+    net.add(BatchNormalization('input-bn'))
+    net.add(Activation('input_relu'))
+    net.add(MaxPooling2D('input_pool', 3, 2, pad=1))
+    conf = cfg[depth]
+    if depth > 34:
+        stage(0, net, conf[0], 64, 64, 256, 1, bottleneck)
+        stage(1, net, conf[1], 256, 128, 512, 2, bottleneck)
+        stage(2, net, conf[2], 512, 256, 1024, 2, bottleneck)
+        stage(3, net, conf[3], 1024, 512, 2048, 2, bottleneck)
+    else:
+        stage(0, net, conf[0], 64, 64, 64, 1, basicblock)
+        stage(1, net, conf[1], 64, 128, 128, 2, basicblock)
+        stage(2, net, conf[2], 128, 256, 256, 2, basicblock)
+        stage(3, net, conf[3], 256, 512, 512, 2, basicblock)
+    net.add(AvgPooling2D('avg', 7, 1, pad=0))
+    net.add(Flatten('flat'))
+    net.add(Dense('dense', 1000))
+    return net
+
+def create_preact_resnet(depth=200):
+    '''Resnet with the batchnorm and relu moved to before the conv layer for each block'''
+    net = ffnet.FeedForwardNet()
+    net.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False, input_sample_shape=(3, 224, 224)))
+    net.add(BatchNormalization('input-bn'))
+    net.add(Activation('input_relu'))
+    net.add(MaxPooling2D('input_pool', 3, 2, pad=1))
+    conf = cfg[depth]
+    if depth > 34:
+        stage(0, net, conf[0], 64, 64, 256, 1, bottleneck, preact=True)
+        stage(1, net, conf[1], 256, 128, 512, 2, bottleneck, preact=True)
+        stage(2, net, conf[2], 512, 256, 1024, 2, bottleneck, preact=True)
+        stage(3, net, conf[3], 1024, 512, 2048, 2, bottleneck, preact=True)
+    else:
+        stage(0, net, conf[0], 64, 64, 64, 1, basicblock, preact=True)
+        stage(1, net, conf[1], 64, 128, 128, 2, basicblock, preact=True)
+        stage(2, net, conf[2], 128, 256, 256, 2, basicblock, preact=True)
+        stage(3, net, conf[3], 256, 512, 512, 2, basicblock, preact=True)
+    net.add(BatchNormalization('final-bn'))
+    net.add(Activation('final-relu'))
+    net.add(AvgPooling2D('avg', 7, 1, pad=0))
+    net.add(Flatten('flat'))
+    net.add(Dense('dense', 1000))
+    return net
+
+
+def create_wide_resnet(depth=50):
+    '''Same as the original resnet except that inplane <= midplane <= outplane for the bottleneck block (wider middle layers)'''
+    net = ffnet.FeedForwardNet()
+    net.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False, input_sample_shape=(3, 224, 224)))
+    net.add(BatchNormalization('input-bn'))
+    net.add(Activation('input_relu'))
+    net.add(MaxPooling2D('input_pool', 3, 2, pad=1))
+
+    stage(0, net, 3, 64, 128, 256, 1, bottleneck)
+    stage(1, net, 4, 256, 256, 512, 2, bottleneck)
+    stage(2, net, 6, 512, 512, 1024, 2, bottleneck)
+    stage(3, net, 3, 1024, 1024, 2048, 2, bottleneck)
+
+    net.add(AvgPooling2D('avg_pool', 7, 1, pad=0))
+    net.add(Flatten('flat'))
+    net.add(Dense('dense', 1000))
+    return net
+
+
+def create_net(name, depth, use_cpu=False):
+    if use_cpu:
+        layer.engine = 'singacpp'
+    if name == 'resnet':
+        return create_resnet(depth)
+    elif name == 'wrn':
+        return create_wide_resnet(depth)
+    elif name == 'preact':
+        return create_preact_resnet(depth)
+    elif name == 'addbn':
+        return create_addbn_resnet(depth)
+
+
+if __name__ == '__main__':
+    create_net('wrn', 50)
diff --git a/examples/imagenet/resnet/serve.py b/examples/imagenet/resnet/serve.py
new file mode 100644
index 0000000..ba5adb1
--- /dev/null
+++ b/examples/imagenet/resnet/serve.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import time
+import numpy as np
+import threading
+import traceback
+from scipy.misc import imread, imresize
+from argparse import ArgumentParser
+
+from singa import device
+from singa import tensor
+from singa import data
+from singa import image_tool
+from singa import metric
+from rafiki.agent import Agent, MsgType
+import model
+
+tool = image_tool.ImageTool()
+num_augmentation = 10
+crop_size = 224
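+# per-channel mean and std of ImageNet images in [0, 1], as used by fb.resnet.torch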
+mean = np.array([0.485, 0.456, 0.406])
+std = np.array([ 0.229, 0.224, 0.225])
+def image_transform(img):
+    '''Input an image path and return a set of augmented images (type Image)'''
+    global tool
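+    # resize to 256, take 5 crops and their flips to yield num_augmentation (10) patches in total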
+    return tool.load(img).resize_by_list([256]).crop5((crop_size, crop_size), 5).flip(2).get()
+
+
+def predict(net, images, num=10):
+    '''predict probability distribution for one net.
+
+    Args:
+        net: neural net (vgg or resnet)
+        images: a batch of augmented images (type numpy)
+        num: num of augmentations
+    '''
+    prob = net.predict(images)
+    prob = tensor.to_numpy(prob)
+    prob = prob.reshape((images.shape[0] / num, num, -1))
+    prob = np.average(prob, 1)
+    return prob
+
+
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1] in \
+        ["PNG", "png", "jpg", "JPG", "JPEG", "jpeg"]
+
+
+def serve(net, label_map, dev, agent, topk=5):
+    '''Serve to predict image labels.
+
+    It prints the topk food names for each image.
+
+    Args:
+        label_map: a list of food names, corresponding to the index in meta_file
+    '''
+
+    images = tensor.Tensor((num_augmentation, 3, crop_size, crop_size), dev)
+    while True:
+        msg, val = agent.pull()
+        if msg is None:
+            time.sleep(0.1)
+            continue
+        msg = MsgType.parse(msg)
+        if msg.is_request():
+            try:
+                # process images
+                im = [np.array(x.convert('RGB'), dtype=np.float32).transpose(2, 0, 1) for x in image_transform(val['image'])]
+                im = np.array(im) / 256
+                im -= mean[np.newaxis, :, np.newaxis, np.newaxis]
+                im /= std[np.newaxis, :, np.newaxis, np.newaxis]
+                images.copy_from_numpy(im)
+                print "input: ", images.l1()
+                # do prediction
+                prob = predict(net, images, num_augmentation)[0]
+                idx = np.argsort(-prob)
+                # prepare results
+                response = ""
+                for i in range(topk):
+                    response += "%s:%f <br/>" % (label_map[idx[i]], prob[idx[i]])
+            except:
+                traceback.print_exc()
+                response = "sorry, system error during prediction."
+            agent.push(MsgType.kResponse, response)
+        elif msg.is_command():
+            if MsgType.kCommandStop.equal(msg):
+                print 'get stop command'
+                agent.push(MsgType.kStatus, "success")
+                break
+            else:
+                print 'get unsupported command %s' % str(msg)
+                agent.push(MsgType.kStatus, "Unknown command")
+        else:
+            print 'get unsupported message %s' % str(msg)
+            agent.push(MsgType.kStatus, "unsupported msg; going to shutdown")
+            break
+    print "server stop"
+
+def main():
+    try:
+        # Setup argument parser
+        parser = ArgumentParser(description="Wide residual network")
+
+        parser.add_argument("--port", default=9999, help="listen port")
+        parser.add_argument("--use_cpu", action="store_true",
+                            help="If set, load models onto CPU devices")
+        parser.add_argument("--parameter_file", default="wrn-50-2.pickle")
+        parser.add_argument("--model", choices = ['resnet', 'wrn', 'preact', 'addbn'], default='wrn')
+        parser.add_argument("--depth", type=int, choices = [18, 34, 50, 101, 152, 200], default='50')
+
+        # Process arguments
+        args = parser.parse_args()
+        port = args.port
+
+        # start to train
+        agent = Agent(port)
+
+        net = model.create_net(args.model, args.depth, args.use_cpu)
+        if args.use_cpu:
+            print 'Using CPU'
+            dev = device.get_default_device()
+        else:
+            print 'Using GPU'
+            dev = device.create_cuda_gpu()
+            net.to_device(dev)
+        model.init_params(net, args.parameter_file)
+        print 'Finish loading models'
+
+        labels = np.loadtxt('synset_words.txt', str, delimiter='\t ')
+        serve(net, labels, dev, agent)
+
+        # acc = evaluate(net, '../val_list.txt',  'image/val', dev)
+        # print acc
+
+        # wait the agent finish handling http request
+        agent.stop()
+    except SystemExit:
+        return
+    except:
+        traceback.print_exc()
+        sys.stderr.write("  for help use --help \n\n")
+        return 2
+
+
+if __name__ == '__main__':
+    main()
diff --git a/python/singa/device.py b/python/singa/device.py
index 1df4c84..fdd2a92 100644
--- a/python/singa/device.py
+++ b/python/singa/device.py
@@ -132,12 +132,12 @@
 
 def create_opencl_device():
     '''Create the default OpenCL device.
-    
+
     Returns:
         a swig converted OpenCL device.
     '''
     assert singa.USE_OPENCL, 'SINGA has not been compiled with OpenCL enabled.'
-    return singa.Platform.GetDefaultDevice()
+    return singa.Platform.GetDefaultOpenclDevice()
 
 
 default_device = singa.Platform.GetDefaultDevice()
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 0bea2d2..7975042 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -337,18 +337,19 @@
         # conf.data_format = data_format
         if W_specs is None:
             W_specs = {'init': 'xavier'}
-        if b_specs is None:
-            b_specs = {'init': 'constant'}
         if 'name' not in W_specs:
             W_specs['name'] = name + '_weight'
-        if 'name' not in b_specs:
-            b_specs['name'] = name + '_bias'
         wspecs = _construct_param_specs_from_dict(W_specs)
         self.conf.param.extend([wspecs])
         self.param_specs.append(wspecs)
-        bspecs = _construct_param_specs_from_dict(b_specs)
-        self.conf.param.extend([bspecs])
-        self.param_specs.append(bspecs)
+        if use_bias:
+            if b_specs is None:
+                b_specs = {'init': 'constant'}
+            if 'name' not in b_specs:
+                b_specs['name'] = name + '_bias'
+            bspecs = _construct_param_specs_from_dict(b_specs)
+            self.conf.param.extend([bspecs])
+            self.param_specs.append(bspecs)
 
         _check_engine(engine, ['cudnn', 'singacpp', 'singacl'])
         self.layer = _create_layer(engine, 'Convolution')
@@ -610,16 +611,19 @@
         conf.transpose = W_transpose
         if W_specs is None:
             W_specs = {'init': 'xavier'}
-        if b_specs is None:
-            b_specs = {'init': 'constant', 'value': 0}
         if 'name' not in W_specs:
             W_specs['name'] = name + '_weight'
-        if 'name' not in b_specs:
-            b_specs['name'] = name + '_bias'
         wspecs = _construct_param_specs_from_dict(W_specs)
-        bspecs = _construct_param_specs_from_dict(b_specs)
-        self.conf.param.extend([wspecs, bspecs])
-        self.param_specs.extend([wspecs, bspecs])
+        self.conf.param.extend([wspecs])
+        self.param_specs.append(wspecs)
+        if use_bias:
+            if b_specs is None:
+                b_specs = {'init': 'constant', 'value': 0}
+            if 'name' not in b_specs:
+                b_specs['name'] = name + '_bias'
+            bspecs = _construct_param_specs_from_dict(b_specs)
+            self.conf.param.extend([bspecs])
+            self.param_specs.append(bspecs)
         # dense layer is transparent to engine.
         if engine == 'cudnn':
             self.layer = _create_layer('singacuda', 'Dense')
@@ -775,7 +779,6 @@
         input_sample_shape: includes a single integer for the input sample
             feature size.
     '''
-
     def __init__(self, name, num_output, input_sample_shape=None):
         self.num_output = num_output
         self.in_shape = input_sample_shape
diff --git a/python/singa/net.py b/python/singa/net.py
index 027e78c..26fb61d 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -386,16 +386,16 @@
         '''
         if use_pickle:
             params = {}
-            for (specs, val) in zip(self.param_specs(), self.param_values()):
+            for (name, val) in zip(self.param_names(), self.param_values()):
                 val.to_host()
-                params[specs.name] = tensor.to_numpy(val)
+                params[name] = tensor.to_numpy(val)
                 with open(f, 'wb') as fd:
                     pickle.dump(params, fd)
         else:
             sp = snapshot.Snapshot(f, True, buffer_size)
-            for (specs, val) in zip(self.param_specs(), self.param_values()):
+            for (name, val) in zip(self.param_names(), self.param_values()):
                 val.to_host()
-                sp.write(specs.name, val)
+                sp.write(name, val)
 
     def load(self, f, buffer_size=10, use_pickle=False):
         '''Load model parameters using io/snapshot.
@@ -407,18 +407,30 @@
                     'then set use_pickle=False for loading it'
             with open(f, 'rb') as fd:
                 params = pickle.load(fd)
-                for (specs, val) in zip(self.param_specs(),
-                                        self.param_values()):
+                for name, val in zip(self.param_names(), self.param_values()):
+                    if name not in params:
+                        print 'Param: %s missing in the checkpoint file' % name
+                        continue
                     try:
-                        val.copy_from_numpy(params[specs.name])
+                        val.copy_from_numpy(params[name])
                     except AssertionError as err:
-                        print 'Error from copying values for param: %s' % specs.name
-                        print 'shape of param vs checkpoint', val.shape, params[specs.name].shape
+                        print 'Error from copying values for param: %s' % name
+                        print 'shape of param vs checkpoint', \
+                                val.shape, params[name].shape
                         raise err
         else:
             print 'NOTE: If your model was saved using pickle, '\
                     'then set use_pickle=True for loading it'
             sp = snapshot.Snapshot(f, False, buffer_size)
             params = sp.read()
-            for (specs, val) in zip(self.param_specs(), self.param_values()):
-                val.copy_data(params[specs.name])
+            for (name, val) in zip(self.param_names(), self.param_values()):
+                if name not in params:
+                    print 'Param: %s missing in the checkpoint file' % name
+                    continue
+                try:
+                    val.copy_data(params[name])
+                except AssertionError as err:
+                    print 'Error from copying values for param: %s' % name
+                    print 'shape of param vs checkpoint', \
+                            val.shape, params[name].shape
+                    raise err
diff --git a/src/model/layer/convolution.cc b/src/model/layer/convolution.cc
index 78ec1af..8940fb2 100644
--- a/src/model/layer/convolution.cc
+++ b/src/model/layer/convolution.cc
@@ -97,7 +97,8 @@
 
   // Setup shape of weight_ and bias_
   weight_.Reshape(Shape{num_filters_, col_height_});
-  bias_.Reshape(Shape{num_filters_});
+  if (bias_term_)
+    bias_.Reshape(Shape{num_filters_});
   // Assume the order of param is: weight, bias
   for (const auto &spec : conf.param()) param_specs_.push_back(spec);
 }
@@ -143,7 +144,6 @@
   Tensor dx;
   Tensor db, dw;
   dx.ResetLike(src_data);
-  db.ResetLike(bias_);
   dw.ResetLike(weight_);
   dw.SetValue(0.0f);
   size_t batchsize = grad.shape(0);
@@ -156,6 +156,7 @@
     SumColumns(tmp1, &tmp2);
     Tensor tmp3 = Reshape(tmp2, Shape{batchsize, num_filters_});
 
+    db.ResetLike(bias_);
     SumRows(tmp3, &db);
   }
 
@@ -178,7 +179,8 @@
     dx.CopyDataFromHostPtr(dx_b, imagesize, b * imagesize);
   }
   param_grad.push_back(dw);
-  param_grad.push_back(db);
+  if (bias_term_)
+    param_grad.push_back(db);
   delete[] data_col;
   delete[] dx_b;
   return std::make_pair(dx, param_grad);
diff --git a/src/model/layer/convolution.h b/src/model/layer/convolution.h
index 7b7fd00..89b5319 100644
--- a/src/model/layer/convolution.h
+++ b/src/model/layer/convolution.h
@@ -57,7 +57,10 @@
               const int stride_w, float* data_im);
 
   const std::vector<Tensor> param_values() override {
-    return std::vector<Tensor>{weight_, bias_};
+    if (bias_term_)
+      return std::vector<Tensor>{weight_, bias_};
+    else
+      return std::vector<Tensor>{weight_};
   }
 
   size_t kernel_w() const { return kernel_w_; }
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 196d137..03ad8b9 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -60,7 +60,8 @@
   size_t batchsize = input.shape(0);
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_desc_));
   CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_desc_));
-  CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
+  if (bias_term_)
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc_));
   CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_));
   CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc_));
 
@@ -209,11 +210,11 @@
   Tensor dx;
   dx.ResetLike(src_data);
   Tensor db, dw;
-  db.ResetLike(bias_);
   dw.ResetLike(weight_);
 
   // LOG(ERROR) << "backward bias";
   if (bias_term_) {
+    db.ResetLike(bias_);
     dx.device()->Exec([grad, db, this](Context *ctx) {
       Block *dyblock = grad.block(), *dbblock = db.block();
       float alpha = 1.f, beta = 0.f;
@@ -248,7 +249,8 @@
                                  this->x_desc_, dxblock->mutable_data());
   }, {grad.block(), weight_.block()}, {dx.block(), workspace_.block()});
   param_grad.push_back(dw);
-  param_grad.push_back(db);
+  if (bias_term_)
+    param_grad.push_back(db);
   return std::make_pair(dx, param_grad);
 }
 
diff --git a/src/model/layer/dense.cc b/src/model/layer/dense.cc
index 64e3d86..fac9130 100644
--- a/src/model/layer/dense.cc
+++ b/src/model/layer/dense.cc
@@ -38,11 +38,13 @@
   vdim_ = in_sample.at(0);
   hdim_ = dense_conf.num_output();
   transpose_ = dense_conf.transpose();
+  bias_term_ = dense_conf.bias_term();
   if (transpose_)  // was {vdim_, hdim} by zhaojing?
     weight_.Reshape(Shape{hdim_, vdim_});
   else
     weight_.Reshape(Shape{vdim_, hdim_});
-  bias_.Reshape(Shape{hdim_});
+  if (bias_term_)
+    bias_.Reshape(Shape{hdim_});
   for (auto specs: conf.param())
     param_specs_.push_back(specs);
 }
@@ -56,7 +58,8 @@
     output = Mult(input, weight_.T());
   else
     output = Mult(input, weight_);
-  AddRow(bias_, &output);
+  if (bias_term_)
+    AddRow(bias_, &output);
   if (flag & kTrain)
     buf_.push(input);
   return output;
@@ -70,10 +73,12 @@
   Tensor src_data = buf_.top();
   buf_.pop();
   Tensor db, dw, dx;
-  db.ResetLike(bias_);
   dw.ResetLike(weight_);
   dx.ResetLike(src_data);
-  SumRows(grad, &db);
+  if (bias_term_) {
+    db.ResetLike(bias_);
+    SumRows(grad, &db);
+  }
   if (transpose_) {
     dx = Mult(grad, weight_);
     dw = Mult(grad.T(), src_data);
@@ -82,7 +87,8 @@
     dw = Mult(src_data.T(), grad);
   }
   param_grad.push_back(dw);
-  param_grad.push_back(db);
+  if (bias_term_)
+    param_grad.push_back(db);
   return std::make_pair(dx, param_grad);
 }
 
diff --git a/src/model/layer/dense.h b/src/model/layer/dense.h
index 8a149a5..8f53699 100644
--- a/src/model/layer/dense.h
+++ b/src/model/layer/dense.h
@@ -46,7 +46,10 @@
 
   void ToDevice(std::shared_ptr<Device> device) override;
   const std::vector<Tensor> param_values() override {
-    return std::vector<Tensor>{weight_, bias_};
+    if (bias_term_)
+      return std::vector<Tensor>{weight_, bias_};
+    else
+      return std::vector<Tensor>{weight_};
   }
   size_t num_output() const { return hdim_; }
   size_t num_input() const { return vdim_; }
@@ -67,6 +70,8 @@
   /// Used in auto-encoder, where the decoder would share its weight matrix from
   /// the encoder's transposed weight matrix.
   bool transpose_ = false;
+  /// use bias or not;
+  bool bias_term_ = true;
   size_t vdim_, hdim_;
   Tensor weight_, bias_;
   // Tensor data_, grad_;