Merge pull request #676 from nudles/master

Update rnn example to use the Module API from v3.0
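
With the Module API, a model subclasses module.Module, defines forward/loss/optim,
and SINGA builds the computational graph from those calls for speed and memory
optimization. Below is a minimal, illustrative sketch of that pattern (the toy
model, its dimensions and the random data are placeholders, not part of the rnn
example; the real CharRNN module is in examples/rnn/train.py in the diff below):

    import numpy as np
    from singa import autograd, device, module, opt, tensor

    class TinyNet(module.Module):
        """One dense layer trained with softmax cross entropy."""

        def __init__(self, in_dim, num_classes):
            super(TinyNet, self).__init__()
            self.dense = autograd.Linear(in_dim, num_classes)
            self.optimizer = opt.SGD(0.01)

        def forward(self, x):
            return self.dense(x)

        def loss(self, out, ty):
            return autograd.softmax_cross_entropy(out, ty)

        def optim(self, loss):
            self.optimizer.backward_and_update(loss)

    dev = device.create_cuda_gpu()
    model = TinyNet(16, 4)
    model.on_device(dev)      # place parameters on the device
    model.graph(True, True)   # let SINGA build and optimize the computational graph
    model.train()

    # one training step on random data (labels are int32 indices of shape (batch, 1))
    x = tensor.from_numpy(np.random.randn(8, 16).astype(np.float32))
    y = tensor.from_numpy(np.random.randint(0, 4, (8, 1)).astype(np.int32))
    x.to_device(dev)
    y.to_device(dev)
    out = model(x)
    model.optim(model.loss(out, y))
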
diff --git a/RELEASE_NOTES b/RELEASE_NOTES
index 5fc4df4..8491a6b 100644
--- a/RELEASE_NOTES
+++ b/RELEASE_NOTES
@@ -61,7 +61,7 @@
     After analyzing the dependency, the computational graph is created, which is further analyzed for
     speed and memory optimization. To enable this feature, use the [Module API](./python/singa/module.py).
 
-  * New website based on Docusaurus. The documentation files are moved to a separate repo [singa-doc]](https://github.com/apache/singa-doc).
+  * New website based on Docusaurus. The documentation files are moved to a separate repo [singa-doc](https://github.com/apache/singa-doc).
     The static website files are stored at [singa-site](https://github.com/apache/singa-site).
 
   * DNNL([Deep Neural Network Library](https://github.com/intel/mkl-dnn)), powered by Intel, 
diff --git a/examples/rnn/README.md b/examples/rnn/README.md
index 6a3a9bd..7c1c697 100644
--- a/examples/rnn/README.md
+++ b/examples/rnn/README.md
@@ -24,14 +24,10 @@
 We will use the [char-rnn](https://github.com/karpathy/char-rnn) model as an
 example, which trains over sentences or
 source code, with each character as an input unit. Particularly, we will train
-a RNN using GRU over Linux kernel source code. After training, we expect to
-generate meaningful code from the model.
-
+an RNN over Linux kernel source code.
 
 ## Instructions
 
-* Compile and install SINGA. Currently the RNN implementation depends on Cudnn with version >= 5.05.
-
 * Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/).
 Other plain text files can also be used.
 
@@ -42,9 +38,3 @@
   Some hyper-parameters could be set through command line,
 
         python train.py -h
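+
+  For example, if the downloaded text is saved as `linux_input.txt` (the file
+  name here is only illustrative), the defaults can be overridden as
+
+        python train.py linux_input.txt -b 32 -l 64 -d 128 -m 50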
-
-* Sample characters from the model by providing the number of characters to sample and the seed string.
-
-        python sample.py 'model.bin' 100 --seed '#include <std'
-
-  Please replace 'model.bin' with the path to one of the checkpoint paths.
diff --git a/examples/rnn/train.py b/examples/rnn/train.py
index c2440d7..30ce680 100644
--- a/examples/rnn/train.py
+++ b/examples/rnn/train.py
@@ -20,26 +20,55 @@
 e.g., http://cs.stanford.edu/people/karpathy/char-rnn/
 '''
 
-
 from __future__ import division
 from __future__ import print_function
-from builtins import zip
 from builtins import range
-from builtins import object
-import pickle as pickle
 import numpy as np
+import sys
 import argparse
+from tqdm import tqdm
 
-from singa import layer
-from singa import loss
 from singa import device
 from singa import tensor
-from singa import optimizer
-from singa import initializer
-from singa import utils
+from singa import autograd
+from singa import module
+from singa import opt
+
+
+class CharRNN(module.Module):
+
+    def __init__(self, vocab_size, hidden_size=32):
+        super(CharRNN, self).__init__()
+        self.rnn = autograd.LSTM(vocab_size, hidden_size)
+        self.dense = autograd.Linear(hidden_size, vocab_size)
+        self.optimizer = opt.SGD(0.01)
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.hx = tensor.Tensor((1, self.hidden_size))
+        self.cx = tensor.Tensor((1, self.hidden_size))
+
+    def reset_states(self, dev):
+        self.hx.to_device(dev)
+        self.cx.to_device(dev)
+        self.hx.set_value(0.0)
+        self.cx.set_value(0.0)
+
+    def forward(self, inputs):
+        x, self.hx, self.cx = self.rnn(inputs, (self.hx, self.cx))
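+        # the rnn returns one output tensor per time step; concatenate and flatten them before the dense layer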
+        x = autograd.cat(x)
+        x = autograd.reshape(x, (-1, self.hidden_size))
+        return self.dense(x)
+
+    def loss(self, out, ty):
+        ty = autograd.reshape(ty, (-1, 1))
+        return autograd.softmax_cross_entropy(out, ty)
+
+    def optim(self, loss):
+        self.optimizer.backward_and_update(loss)
 
 
 class Data(object):
+
     def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
         '''Data object for loading a plain text file.
 
@@ -48,7 +77,8 @@
             train_ratio, split the text file into train and test sets, where
                 train_ratio of the characters are in the train set.
         '''
-        self.raw_data = open(fpath, 'r',encoding='iso-8859-1').read()  # read text file
+        self.raw_data = open(fpath, 'r',
+                             encoding='iso-8859-1').read()  # read text file
         chars = list(set(self.raw_data))
         self.vocab_size = len(chars)
         self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
@@ -56,12 +86,12 @@
         data = [self.char_to_idx[c] for c in self.raw_data]
         # seq_length + 1 for the data + label
         nsamples = len(data) // (1 + seq_length)
-        data = data[0:nsamples * (1 + seq_length)]
+        data = data[0:nsamples * (1 + seq_length)]
         data = np.asarray(data, dtype=np.int32)
         data = np.reshape(data, (-1, seq_length + 1))
         # shuffle all sequences
         np.random.shuffle(data)
-        self.train_dat = data[0:int(data.shape[0]*train_ratio)]
+        self.train_dat = data[0:int(data.shape[0] * train_ratio)]
         self.num_train_batch = self.train_dat.shape[0] // batch_size
         self.val_dat = data[self.train_dat.shape[0]:]
         self.num_test_batch = self.val_dat.shape[0] // batch_size
@@ -69,23 +99,35 @@
         print('val dat', self.val_dat.shape)
 
 
-def numpy2tensors(npx, npy, dev):
+def numpy2tensors(npx, npy, dev, inputs=None, labels=None):
     '''batch, seq, dim -- > seq, batch, dim'''
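+    # if inputs/labels tensors are passed in, reuse them instead of allocating new device tensors every batch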
+    tmpy = np.swapaxes(npy, 0, 1).reshape((-1, 1))
+    if labels is not None:
+        labels.copy_from_numpy(tmpy)
+    else:
+        labels = tensor.from_numpy(tmpy)
+    labels.to_device(dev)
     tmpx = np.swapaxes(npx, 0, 1)
-    tmpy = np.swapaxes(npy, 0, 1)
-    inputs = []
-    labels = []
+    inputs_ = []
     for t in range(tmpx.shape[0]):
-        x = tensor.from_numpy(tmpx[t])
-        y = tensor.from_numpy(tmpy[t])
-        x.to_device(dev)
-        y.to_device(dev)
-        inputs.append(x)
-        labels.append(y)
+        if inputs is not None:
+            inputs[t].copy_from_numpy(tmpx[t])
+        else:
+            x = tensor.from_numpy(tmpx[t])
+            x.to_device(dev)
+            inputs_.append(x)
+    if inputs is None:
+        inputs = inputs_
     return inputs, labels
 
 
-def convert(batch, batch_size, seq_length, vocab_size, dev):
+def convert(batch,
+            batch_size,
+            seq_length,
+            vocab_size,
+            dev,
+            inputs=None,
+            labels=None):
     '''convert a batch of data into a sequence of input tensors'''
     y = batch[:, 1:]
     x1 = batch[:, :seq_length]
@@ -94,127 +136,90 @@
         for t in range(seq_length):
             c = x1[b, t]
             x[b, t, c] = 1
-    return numpy2tensors(x, y, dev)
+    return numpy2tensors(x, y, dev, inputs, labels)
 
 
-def get_lr(epoch):
-    return 0.001 / float(1 << (epoch // 50))
+def sample(model, data, dev, nsamples=100, use_max=False):
+    while True:
+        cmd = input('Do you want to sample text from the model? [y/n] ')
+        if cmd == 'n':
+            return
+        else:
+            seed = input('Please input some seeding text, e.g., #include <c: ')
+            inputs = []
+            for c in seed:
+                x = np.zeros((1, data.vocab_size), dtype=np.float32)
+                x[0, data.char_to_idx[c]] = 1
+                tx = tensor.from_numpy(x)
+                tx.to_device(dev)
+                inputs.append(tx)
+            model.reset_states(dev)
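+            # feed the whole seed string first, then sample one character at a time from the softmax output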
+            outputs = model(inputs)
+            y = tensor.softmax(outputs[-1])
+            sys.stdout.write(seed)
+            for i in range(nsamples):
+                prob = tensor.to_numpy(y)[0]
+                if use_max:
+                    cur = np.argmax(prob)
+                else:
+                    cur = np.random.choice(data.vocab_size, 1, p=prob)[0]
+                sys.stdout.write(data.idx_to_char[cur])
+                x = np.zeros((1, data.vocab_size), dtype=np.float32)
+                x[0, cur] = 1
+                tx = tensor.from_numpy(x)
+                tx.to_device(dev)
+                outputs = model([tx])
+                y = tensor.softmax(outputs[-1])
 
 
-def train(data, max_epoch, hidden_size=100, seq_length=100, batch_size=16,
-          num_stacks=1, dropout=0.5, model_path='model'):
+def evaluate(model, data, batch_size, seq_length, dev):
+    model.eval()
+    val_loss = 0.0
+    for b in range(data.num_test_batch):
+        batch = data.val_dat[b * batch_size:(b + 1) * batch_size]
+        inputs, labels = convert(batch, batch_size, seq_length, data.vocab_size,
+                                 dev)
+        model.reset_states(dev)
+        y = model(inputs)
+        loss = model.loss(y, labels)
+        val_loss += tensor.to_numpy(loss)[0]
+    print('            validation loss is %f' %
+          (val_loss / data.num_test_batch / seq_length))
+
+
+def train(data,
+          max_epoch,
+          hidden_size=100,
+          seq_length=100,
+          batch_size=16,
+          model_path='model'):
     # SGD with L2 gradient normalization
-    opt = optimizer.RMSProp(constraint=optimizer.L2Constraint(5))
     cuda = device.create_cuda_gpu()
-    rnn = layer.LSTM(
-        name='lstm',
-        hidden_size=hidden_size,
-        num_stacks=num_stacks,
-        dropout=dropout,
-        input_sample_shape=(
-            data.vocab_size,
-        ))
-    rnn.to_device(cuda)
-    print('created rnn')
-    rnn_w = rnn.param_values()[0]
-    rnn_w.uniform(-0.08, 0.08)  # init all rnn parameters
-    print('rnn weight l1 = %f' % (rnn_w.l1()))
-    dense = layer.Dense(
-        'dense',
-        data.vocab_size,
-        input_sample_shape=(
-            hidden_size,
-        ))
-    dense.to_device(cuda)
-    dense_w = dense.param_values()[0]
-    dense_b = dense.param_values()[1]
-    print('dense w ', dense_w.shape)
-    print('dense b ', dense_b.shape)
-    initializer.uniform(dense_w, dense_w.shape[0], 0)
-    print('dense weight l1 = %f' % (dense_w.l1()))
-    dense_b.set_value(0)
-    print('dense b l1 = %f' % (dense_b.l1()))
+    model = CharRNN(data.vocab_size, hidden_size)
+    model.on_device(cuda)
+    model.graph(True, True)
 
-    g_dense_w = tensor.Tensor(dense_w.shape, cuda)
-    g_dense_b = tensor.Tensor(dense_b.shape, cuda)
+    inputs, labels = None, None
 
-    lossfun = loss.SoftmaxCrossEntropy()
     for epoch in range(max_epoch):
+        model.train()
         train_loss = 0
-        for b in range(data.num_train_batch):
-            batch = data.train_dat[b * batch_size: (b + 1) * batch_size]
+        for b in tqdm(range(data.num_train_batch)):
+            batch = data.train_dat[b * batch_size:(b + 1) * batch_size]
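+            # one-hot encode the batch (reusing the input/label tensors) and reset the LSTM states,
+            # since the shuffled sequences in each batch are independent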
             inputs, labels = convert(batch, batch_size, seq_length,
-                                     data.vocab_size, cuda)
-            inputs.append(tensor.Tensor())
-            inputs.append(tensor.Tensor())
+                                     data.vocab_size, cuda, inputs, labels)
+            model.reset_states(cuda)
+            y = model(inputs)
+            loss = model.loss(y, labels)
+            model.optim(loss)
+            train_loss += tensor.to_numpy(loss)[0]
 
-            outputs = rnn.forward(True, inputs)[0:-2]
-            grads = []
-            batch_loss = 0
-            g_dense_w.set_value(0.0)
-            g_dense_b.set_value(0.0)
-            for output, label in zip(outputs, labels):
-                act = dense.forward(True, output)
-                lvalue = lossfun.forward(True, act, label)
-                batch_loss += lvalue.l1()
-                grad = lossfun.backward()
-                grad /= batch_size
-                grad, gwb = dense.backward(True, grad)
-                grads.append(grad)
-                g_dense_w += gwb[0]
-                g_dense_b += gwb[1]
-                # print output.l1(), act.l1()
-            utils.update_progress(
-                b * 1.0 / data.num_train_batch, 'training loss = %f' %
-                (batch_loss / seq_length))
-            train_loss += batch_loss
-
-            grads.append(tensor.Tensor())
-            grads.append(tensor.Tensor())
-            g_rnn_w = rnn.backward(True, grads)[1][0]
-            dense_w, dense_b = dense.param_values()
-            opt.apply_with_lr(epoch, get_lr(epoch), g_rnn_w, rnn_w, 'rnnw')
-            opt.apply_with_lr(
-                epoch, get_lr(epoch),
-                g_dense_w, dense_w, 'dense_w')
-            opt.apply_with_lr(
-                epoch, get_lr(epoch),
-                g_dense_b, dense_b, 'dense_b')
         print('\nEpoch %d, train loss is %f' %
               (epoch, train_loss / data.num_train_batch / seq_length))
 
-        eval_loss = 0
-        for b in range(data.num_test_batch):
-            batch = data.val_dat[b * batch_size: (b + 1) * batch_size]
-            inputs, labels = convert(batch, batch_size, seq_length,
-                                     data.vocab_size, cuda)
-            inputs.append(tensor.Tensor())
-            inputs.append(tensor.Tensor())
-            outputs = rnn.forward(False, inputs)[0:-2]
-            for output, label in zip(outputs, labels):
-                output = dense.forward(True, output)
-                eval_loss += lossfun.forward(True, output, label).l1()
-        print('Epoch %d, evaluation loss is %f' %
-              (epoch, eval_loss / data.num_test_batch / seq_length))
+        # evaluate(model, data, batch_size, seq_length, cuda)
+        # sample(model, data, cuda)
 
-        if (epoch + 1) % 30 == 0:
-            # checkpoint the file model
-            with open('%s_%d.bin' % (model_path, epoch), 'wb') as fd:
-                print('saving model to %s' % model_path)
-                d = {}
-                for name, w in zip(
-                        ['rnn_w', 'dense_w', 'dense_b'],
-                        [rnn_w, dense_w, dense_b]):
-                    w.to_host()
-                    d[name] = tensor.to_numpy(w)
-                    w.to_device(cuda)
-                d['idx_to_char'] = data.idx_to_char
-                d['char_to_idx'] = data.char_to_idx
-                d['hidden_size'] = hidden_size
-                d['num_stacks'] = num_stacks
-                d['dropout'] = dropout
-
-                pickle.dump(d, fd)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
@@ -224,9 +229,11 @@
     parser.add_argument('-b', type=int, default=32, help='batch_size')
     parser.add_argument('-l', type=int, default=64, help='sequence length')
     parser.add_argument('-d', type=int, default=128, help='hidden size')
-    parser.add_argument('-s', type=int, default=2, help='num of stacks')
     parser.add_argument('-m', type=int, default=50, help='max num of epoch')
     args = parser.parse_args()
     data = Data(args.data, batch_size=args.b, seq_length=args.l)
-    train(data, args.m,  hidden_size=args.d, num_stacks=args.s,
-          seq_length=args.l, batch_size=args.b)
+    train(data,
+          args.m,
+          hidden_size=args.d,
+          seq_length=args.l,
+          batch_size=args.b)
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index d26e794..e0813da 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -1806,13 +1806,13 @@
     """
 
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        bias=False,
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            bias=False,
     ):
         """
         Args:
@@ -3132,15 +3132,15 @@
     """
 
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_layers=1,
-        nonlinearity="tanh",
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            num_layers=1,
+            nonlinearity="tanh",
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         """
         Args:
@@ -3212,15 +3212,15 @@
     """
 
     def __init__(
-        self,
-        input_size,
-        hidden_size,
-        nonlinearity="tanh",
-        num_layers=1,
-        bias=True,
-        batch_first=False,
-        dropout=0,
-        bidirectional=False,
+            self,
+            input_size,
+            hidden_size,
+            nonlinearity="tanh",
+            num_layers=1,
+            bias=True,
+            batch_first=False,
+            dropout=0,
+            bidirectional=False,
     ):
         """
         Args:
@@ -3244,14 +3244,14 @@
         self.Wx = []
         for i in range(4):
             w = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 1.0)
+            w.gaussian(0.0, 0.01)
             self.Wx.append(w)
 
         Wh_shape = (hidden_size, hidden_size)
         self.Wh = []
         for i in range(4):
             w = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True)
-            w.gaussian(0.0, 1.0)
+            w.gaussian(0.0, 0.01)
             self.Wh.append(w)
 
         Bx_shape = (hidden_size,)
diff --git a/tool/release/release.py b/tool/release/release.py
index 5de1655..28a53db 100755
--- a/tool/release/release.py
+++ b/tool/release/release.py
@@ -49,7 +49,7 @@
         default=False,
         dest='confirmed',
         action='store_true',
-        help="In interactive mode, for user to confirm. Could be used in sript")
+        help="In interactive mode, for user to confirm. Could be used in script")
     parser.add_argument('type',
                         choices=['major', 'minor', 'patch', 'rc', 'stable'],
                         help="Release types")