# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
from __future__ import print_function
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
import random
from lstm import lstm_unroll
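

# Toy CTC example: train an unrolled LSTM to transcribe a random 4-digit
# number from a synthetic 80-frame sequence (each digit is repeated over
# 20 frames as a 10-dimensional one-hot vector). `lstm_unroll`, imported
# from the sibling lstm.py, is assumed to build the unrolled LSTM graph
# with a CTC loss on top; its definition lives outside this file.


# Minimal batch container exposing the attributes mx.model expects:
# the data/label NDArrays plus provide_data/provide_label shape info.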
class SimpleBatch(object):
    def __init__(self, data_names, data, label_names, label):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names

        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def provide_data(self):
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
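

# Synthetic data: gen_rand() draws a number in [0, 9999], zero-pads it to
# 4 digits, and emits 80 one-hot frames (20 frames per digit). get_label()
# shifts every digit by +1 so that class 0 is free to act as the CTC blank.
# For example, "0427" yields the label [1., 5., 3., 8.].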
def gen_feature(n):
    ret = np.zeros(10)
    ret[n] = 1
    return ret


def gen_rand():
    num = random.randint(0, 9999)
    buf = str(num)
    while len(buf) < 4:
        buf = "0" + buf
    ret = []
    for i in range(80):
        c = int(buf[i // 20])
        ret.append(gen_feature(c))
    return buf, ret


def get_label(buf):
    ret = np.zeros(4)
    for i in range(4):
        ret[i] = 1 + int(buf[i])
    return ret
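

# Iterator that materializes `count` random batches per epoch. Each batch
# carries the (batch, 80, 10) one-hot sequences, the (batch, 4) labels, and
# zero-filled initial LSTM states whose names/shapes come from init_states.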
class DataIter(mx.io.DataIter):
    def __init__(self, count, batch_size, num_label, init_states):
        super(DataIter, self).__init__()
        self.batch_size = batch_size
        self.count = count
        self.num_label = num_label
        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
        self.provide_data = [('data', (batch_size, 80, 10))] + init_states
        self.provide_label = [('label', (self.batch_size, 4))]

    def __iter__(self):
        init_state_names = [x[0] for x in self.init_states]
        for k in range(self.count):
            data = []
            label = []
            for i in range(self.batch_size):
                num, img = gen_rand()
                data.append(img)
                label.append(get_label(num))

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names
            label_names = ['label']

            data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        pass


BATCH_SIZE = 32
SEQ_LENGTH = 80
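

# Collapse a per-frame argmax path into a CTC label sequence: drop blanks
# (class 0) and merge consecutive repeats. For example,
# ctc_label([1, 1, 0, 2, 2, 0, 0, 3]) returns [1, 2, 3].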
def ctc_label(p):
    ret = []
    p1 = [0] + p
    for i in range(len(p)):
        c1 = p1[i]
        c2 = p1[i+1]
        if c2 == 0 or c2 == c1:
            continue
        ret.append(c2)
    return ret
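

# Sequence-level accuracy: greedily decode each sample (predictions are laid
# out time-major, so frame k of sample i sits at pred[k * BATCH_SIZE + i]),
# collapse the path with ctc_label, and count a hit only on an exact match
# against the 4-digit label.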
def Accuracy(label, pred):
    global BATCH_SIZE
    global SEQ_LENGTH
    hit = 0.
    total = 0.
    for i in range(BATCH_SIZE):
        l = label[i]
        p = []
        for k in range(SEQ_LENGTH):
            p.append(np.argmax(pred[k * BATCH_SIZE + i]))
        p = ctc_label(p)
        if len(p) == len(l):
            match = True
            for k in range(len(p)):
                if p[k] != int(l[k]):
                    match = False
                    break
            if match:
                hit += 1.0
        total += 1.0
    return hit / total


if __name__ == '__main__':
    num_hidden = 100
    num_lstm_layer = 1

    num_epoch = 10
    learning_rate = 0.001
    momentum = 0.9
    num_label = 4

    contexts = [mx.context.gpu(0)]
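
    # Build the symbol for a fixed sequence length; the unrolled LSTM takes
    # per-layer initial cell/hidden states ('l%d_init_c' / 'l%d_init_h') as
    # extra inputs, so their names and shapes are appended to provide_data.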
    def sym_gen(seq_len):
        return lstm_unroll(num_lstm_layer, seq_len,
                           num_hidden=num_hidden,
                           num_label=num_label)

    init_c = [('l%d_init_c' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h' % l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    data_train = DataIter(100000, BATCH_SIZE, num_label, init_states)
    data_val = DataIter(1000, BATCH_SIZE, num_label, init_states)

    symbol = sym_gen(SEQ_LENGTH)
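
    # mx.model.FeedForward is the old MXNet training front end (later
    # superseded by mx.mod.Module); it binds the symbol on the given
    # contexts and runs the SGD loop with the settings below.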
    model = mx.model.FeedForward(ctx=contexts,
                                 symbol=symbol,
                                 num_epoch=num_epoch,
                                 learning_rate=learning_rate,
                                 momentum=momentum,
                                 wd=0.00001,
                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    print('begin fit')

    model.fit(X=data_train, eval_data=data_val,
              eval_metric=mx.metric.np(Accuracy),
              batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50))

    model.save("ocr")