|  | # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme | 
|  | # pylint: disable=superfluous-parens, no-member, invalid-name | 
|  | from __future__ import print_function | 
|  | import sys, random | 
|  | sys.path.insert(0, "../../python") | 
|  | import numpy as np | 
|  | import mxnet as mx | 
|  |  | 
|  | from lstm import lstm_unroll | 
|  |  | 
|  | from io import BytesIO | 
|  | from captcha.image import ImageCaptcha | 
|  | import cv2, random | 
|  |  | 
class SimpleBatch(object):
    """Plain container for one batch of named data and label arrays.

    ``data`` and ``label`` are sequences of array-like objects exposing a
    ``shape`` attribute; ``data_names`` and ``label_names`` give the name of
    each entry at the same position.
    """

    def __init__(self, data_names, data, label_names, label):
        self.data_names, self.data = data_names, data
        self.label_names, self.label = label_names, label

        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs, one per data array."""
        shapes = (array.shape for array in self.data)
        return list(zip(self.data_names, shapes))

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs, one per label array."""
        shapes = (array.shape for array in self.label)
        return list(zip(self.label_names, shapes))
|  |  | 
def gen_rand():
    """Return a random string of three or four decimal digits.

    The length is drawn first, then each digit, all from the module-level
    ``random`` generator.
    """
    length = random.randint(3, 4)
    digits = [str(random.randint(0, 9)) for _ in range(length)]
    return "".join(digits)
|  |  | 
def get_label(buf, num_label=4):
    """Encode a digit string as a fixed-length, zero-padded label vector.

    Each digit ``d`` is stored as ``d + 1`` so that 0 stays free to mark
    unused trailing positions.

    Parameters
    ----------
    buf : str
        String of decimal digits, at most ``num_label`` characters long.
    num_label : int, optional
        Length of the returned vector. Defaults to 4, the width that was
        previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(num_label,)``.

    Raises
    ------
    IndexError
        If ``buf`` has more than ``num_label`` characters.
    ValueError
        If a character of ``buf`` is not a decimal digit.
    """
    # np.zeros already provides the padding, so short strings need no
    # special handling.
    ret = np.zeros(num_label)
    for pos, char in enumerate(buf):
        ret[pos] = 1 + int(char)
    return ret
|  |  | 
class OCRIter(mx.io.DataIter):
    """Data iterator that synthesizes captcha images and labels on the fly.

    Parameters
    ----------
    count : int
        Number of batches produced by one pass over the iterator.
    batch_size : int
        Number of samples per batch.
    num_label : int
        Stored on the instance; not otherwise used by this class.
    init_states : list of (str, tuple)
        ``(name, shape)`` pairs for the initial-state inputs. A zero array
        of each shape is created once and attached to every batch.
    """

    def __init__(self, count, batch_size, num_label, init_states):
        super(OCRIter, self).__init__()
        # you can get this font from http://font.ubuntu.com/
        self.captcha = ImageCaptcha(fonts=['./font/Ubuntu-M.ttf'])
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        self.provide_data = [('data', (batch_size, 2400))] + init_states
        self.provide_label = [('label', (self.batch_size, 4))]

    def _make_sample(self):
        """Render one random captcha; return ``(pixels, label)``.

        ``pixels`` is a flat array of 80 * 30 grayscale values scaled by
        1/255; ``label`` is the vector produced by ``get_label``.
        """
        num = gen_rand()
        img = self.captcha.generate(num)
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        img = np.frombuffer(img.getvalue(), dtype='uint8')
        img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (80, 30))
        # Transpose before flattening so the two image axes swap order in
        # the flattened vector.
        img = img.transpose(1, 0)
        img = img.reshape((80 * 30))
        img = np.multiply(img, 1/255.0)
        return img, get_label(num)

    def __iter__(self):
        """Yield ``count`` freshly generated ``SimpleBatch`` objects."""
        print('iter')
        init_state_names = [x[0] for x in self.init_states]
        data_names = ['data'] + init_state_names
        label_names = ['label']
        for _ in range(self.count):
            data = []
            label = []
            for _ in range(self.batch_size):
                img, target = self._make_sample()
                data.append(img)
                label.append(target)

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]

            yield SimpleBatch(data_names, data_all, label_names, label_all)

    def reset(self):
        """No-op: samples are generated randomly, so there is nothing to rewind."""
        pass
|  |  | 
# Number of samples per batch. The metrics below also use it as the row
# stride when indexing predictions (pred[k * BATCH_SIZE + i]).
BATCH_SIZE = 32
# Number of time steps the metrics read per sample; also the sequence length
# passed to sym_gen when the network symbol is built.
SEQ_LENGTH = 80
|  |  | 
def ctc_label(p):
    """Collapse a per-step prediction list into a label sequence.

    Walks ``p`` (a list) and drops every entry that is 0 or equal to the
    entry immediately before it; the first entry is compared against 0.
    Returns the surviving entries as a new list.
    """
    collapsed = []
    # Pair each entry with its predecessor, seeding the first with 0.
    for previous, current in zip([0] + p, p):
        if current != 0 and current != previous:
            collapsed.append(current)
    return collapsed
|  |  | 
def remove_blank(l):
    """Return the entries of ``l`` that come before its first 0.

    If ``l`` contains no 0, every entry is returned. The result is always a
    new list.
    """
    end = 0
    size = len(l)
    # Advance to the first 0 (or the end of the sequence).
    while end < size and not l[end] == 0:
        end += 1
    return [l[pos] for pos in range(end)]
|  |  | 
def Accuracy(label, pred):
    """Fraction of samples whose decoded prediction equals the label exactly.

    ``label[i]`` is the zero-terminated label of sample ``i``. ``pred`` is
    indexed as ``pred[k * BATCH_SIZE + i]`` for time step ``k`` of sample
    ``i``; the argmax of each such row is taken as the predicted class.
    Exactly ``BATCH_SIZE`` samples and ``SEQ_LENGTH`` steps are read.
    """
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        expected = remove_blank(label[i])
        steps = [np.argmax(pred[k * BATCH_SIZE + i]) for k in range(SEQ_LENGTH)]
        decoded = ctc_label(steps)
        # A hit needs equal length and equal entries at every position.
        if len(decoded) == len(expected) and all(
                d == int(e) for d, e in zip(decoded, expected)):
            hit += 1.0
        total += 1.0
    return hit / total
|  |  | 
def LCS(p, l):
    """Return the length of the longest common subsequence of ``p`` and ``l``.

    Parameters
    ----------
    p, l : sequence
        Sequences whose items can be compared with ``==``.

    Returns
    -------
    int
        Length of the longest (not necessarily contiguous) subsequence
        common to both inputs; 0 if either input is empty.
    """
    # Guard both inputs: with only ``p`` checked, an empty ``l`` used to
    # reach a max() over a zero-size array and raise ValueError.
    if len(p) == 0 or len(l) == 0:
        return 0

    # Dynamic programming over one rolling row:
    # prev_row[j] is the LCS length of the processed prefix of l and p[:j].
    prev_row = [0] * (len(p) + 1)
    for item_l in l:
        row = [0]
        for j, item_p in enumerate(p):
            if item_p == item_l:
                row.append(prev_row[j] + 1)
            else:
                row.append(max(prev_row[j + 1], row[j]))
        prev_row = row
    return prev_row[-1]
|  |  | 
|  |  | 
def Accuracy_LCS(label, pred):
    """Mean per-sample ratio of LCS length to label length.

    ``label[i]`` is the zero-terminated label of sample ``i``. ``pred`` is
    indexed as ``pred[k * BATCH_SIZE + i]`` for time step ``k`` of sample
    ``i``; the argmax of each such row is taken as the predicted class.
    Each sample scores ``LCS(decoded, label) / len(label)``, averaged over
    ``BATCH_SIZE`` samples.
    """
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        expected = remove_blank(label[i])
        steps = [np.argmax(pred[k * BATCH_SIZE + i]) for k in range(SEQ_LENGTH)]
        decoded = ctc_label(steps)
        hit += LCS(decoded, expected) * 1.0 / len(expected)
        total += 1.0
    return hit / total
|  |  | 
if __name__ == '__main__':
    # Network size.
    num_hidden = 100
    num_lstm_layer = 2

    # Training settings.
    num_epoch = 10
    learning_rate = 0.001
    momentum = 0.9
    num_label = 4

    contexts = [mx.context.gpu(0)]

    def sym_gen(seq_len):
        """Build the network symbol via lstm_unroll for ``seq_len`` steps."""
        return lstm_unroll(
            num_lstm_layer,
            seq_len,
            num_hidden=num_hidden,
            num_label=num_label,
        )

    # One (name, shape) pair per layer for the c states, then for the h states.
    state_shape = (BATCH_SIZE, num_hidden)
    init_c = [('l%d_init_c'%l, state_shape) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, state_shape) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    # 10000 batches per training pass, 1000 per evaluation pass.
    data_train = OCRIter(10000, BATCH_SIZE, num_label, init_states)
    data_val = OCRIter(1000, BATCH_SIZE, num_label, init_states)

    symbol = sym_gen(SEQ_LENGTH)

    model = mx.model.FeedForward(
        ctx=contexts,
        symbol=symbol,
        num_epoch=num_epoch,
        learning_rate=learning_rate,
        momentum=momentum,
        wd=0.00001,
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
    )

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    print('begin fit')

    prefix = 'ocr'
    model.fit(
        X=data_train,
        eval_data=data_val,
        eval_metric=mx.metric.np(Accuracy),
        # Use the following eval_metric if your num_label >= 10, or varies in a wide range
        # eval_metric = mx.metric.np(Accuracy_LCS),
        batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50),
        epoch_end_callback=mx.callback.do_checkpoint(prefix, 1),
    )

    model.save(prefix)