| # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme |
| # pylint: disable=superfluous-parens, no-member, invalid-name |
| from __future__ import print_function |
| import sys |
| sys.path.insert(0, "../../python") |
| import numpy as np |
| import mxnet as mx |
| import random |
| from lstm import lstm_unroll |
| |
class SimpleBatch(object):
    """Minimal batch container exposing the attributes MXNet iterators expect.

    Holds parallel lists of names and arrays for both data and label, and
    publishes their (name, shape) descriptors via ``provide_data`` /
    ``provide_label``.
    """

    def __init__(self, data_names, data, label_names, label):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names

        # This iterator never pads partial batches.
        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def provide_data(self):
        """(name, shape) pairs for each data array in the batch."""
        return [(name, arr.shape) for name, arr in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        """(name, shape) pairs for each label array in the batch."""
        return [(name, arr.shape) for name, arr in zip(self.label_names, self.label)]
| |
def gen_feature(n, size=10):
    """Return a one-hot vector with position *n* set to 1.

    Parameters
    ----------
    n : int
        Index of the hot entry (a digit value 0-9 in this script).
    size : int, optional
        Length of the returned vector. Defaults to 10 (one slot per
        decimal digit), preserving the original behavior; exposed as a
        parameter so the encoding generalizes to other alphabets.

    Returns
    -------
    np.ndarray
        Float vector of length *size*, zero everywhere except index *n*.
    """
    ret = np.zeros(size)
    ret[n] = 1
    return ret
| |
def gen_rand():
    """Sample a random 4-digit string and its 80-step one-hot sequence.

    Each of the 4 digits is held for 20 consecutive time steps, producing
    a sequence of 80 one-hot feature vectors.

    Returns
    -------
    (str, list)
        The zero-padded digit string and its per-step feature list.
    """
    buf = "%04d" % random.randint(0, 9999)
    seq = [gen_feature(int(buf[step // 20])) for step in range(80)]
    return buf, seq
| |
def get_label(buf):
    """Convert the first 4 digit characters of *buf* into a label vector.

    Labels are shifted by +1 so that class 0 stays reserved as the CTC
    blank symbol.
    """
    return np.array([1.0 + int(ch) for ch in buf[:4]])
| |
class DataIter(mx.io.DataIter):
    """Iterator that yields freshly generated digit-sequence batches.

    Samples are drawn at random on every pass, so each "epoch" consists of
    ``count`` batches of brand-new data rather than a fixed dataset.
    """

    def __init__(self, count, batch_size, num_label, init_states):
        super(DataIter, self).__init__()
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        # LSTM initial-state descriptors: list of (name, shape) pairs.
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(shape) for _, shape in init_states]
        self.provide_data = [('data', (batch_size, 80, 10))] + init_states
        self.provide_label = [('label', (self.batch_size, 4))]

    def __iter__(self):
        state_names = [name for name, _ in self.init_states]
        for _ in range(self.count):
            samples = [gen_rand() for _ in range(self.batch_size)]
            data = [img for _, img in samples]
            label = [get_label(num) for num, _ in samples]

            yield SimpleBatch(
                ['data'] + state_names,
                [mx.nd.array(data)] + self.init_state_arrays,
                ['label'],
                [mx.nd.array(label)])

    def reset(self):
        """Nothing to reset: data is generated on the fly."""
        pass
| |
# Number of samples per mini-batch.
BATCH_SIZE = 32
# Unrolled sequence length: 4 digits, each repeated for 20 time steps.
SEQ_LENGTH = 80
| |
def ctc_label(p):
    """Collapse a raw per-step prediction path into a CTC label sequence.

    Drops blanks (class 0) and merges consecutive repeats, e.g.
    [0, 1, 1, 0, 2, 2, 3] -> [1, 2, 3].
    """
    collapsed = []
    prev = 0  # virtual blank before the first step
    for cur in p:
        if cur != 0 and cur != prev:
            collapsed.append(cur)
        prev = cur
    return collapsed
| |
| |
def Accuracy(label, pred):
    """Sequence-level accuracy of CTC-decoded predictions.

    Parameters
    ----------
    label : ndarray, shape (batch, num_label)
        Ground-truth labels, already shifted by +1 (0 is the CTC blank).
    pred : ndarray, shape (seq_len * batch, num_classes)
        Per-time-step class probabilities, laid out time-major: row
        ``k * batch + i`` is time step ``k`` of sample ``i``.

    Returns
    -------
    float
        Fraction of samples whose collapsed prediction matches the label
        exactly.
    """
    # Derive the sizes from the inputs instead of the BATCH_SIZE /
    # SEQ_LENGTH module globals: identical for this script's shapes, but
    # the metric no longer silently breaks if those globals drift from
    # the data actually fed in.
    batch_size = len(label)
    seq_length = len(pred) // batch_size
    hit = 0.
    total = 0.
    for i in range(batch_size):
        truth = label[i]
        # Greedy decode: argmax class at every time step, then CTC-collapse.
        path = [np.argmax(pred[k * batch_size + i]) for k in range(seq_length)]
        decoded = ctc_label(path)
        if len(decoded) == len(truth):
            if all(d == int(t) for d, t in zip(decoded, truth)):
                hit += 1.0
        total += 1.0
    return hit / total
| |
if __name__ == '__main__':
    # Network hyper-parameters.
    num_hidden = 100     # LSTM hidden-state width
    num_lstm_layer = 1   # number of stacked LSTM layers

    # Training hyper-parameters.
    num_epoch = 10
    learning_rate = 0.001
    momentum = 0.9
    num_label = 4        # digits per sample

    # NOTE(review): hard-coded to GPU 0 — training fails without a GPU;
    # switch to mx.context.cpu() to run on CPU.
    contexts = [mx.context.gpu(0)]

    def sym_gen(seq_len):
        # Build the LSTM symbol unrolled to a fixed sequence length.
        return lstm_unroll(num_lstm_layer, seq_len,
                           num_hidden=num_hidden,
                           num_label = num_label)

    # Zero-initialized LSTM cell (c) and hidden (h) states, one pair of
    # (name, shape) descriptors per layer.
    init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    # Randomly generated data: 100000 batches per training epoch,
    # 1000 batches for validation.
    data_train = DataIter(100000, BATCH_SIZE, num_label, init_states)
    data_val = DataIter(1000, BATCH_SIZE, num_label, init_states)

    symbol = sym_gen(SEQ_LENGTH)

    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    print('begin fit')

    # Train with the custom sequence-accuracy metric; log throughput
    # every 50 batches.
    model.fit(X=data_train, eval_data=data_val,
              eval_metric = mx.metric.np(Accuracy),
              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),)

    # Persist symbol + parameters under the "ocr" prefix.
    model.save("ocr")