| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """An example of predicting CAPTCHA image data with an LSTM network pre-trained with a CTC loss.""" |
| |
| from __future__ import print_function |
| |
| import argparse |
| |
| import sys |
| import numpy as np |
| import cv2 |
| |
| |
class lstm_ocr_model(object):
    """LSTM network for Optical Character Recognition of digit CAPTCHAs.

    The class wraps a ``Predictor`` loaded from a symbol JSON file and a
    parameters file, feeds it a resized image, and decodes the per-timestep
    output into a digit string using the CTC collapsing rule.

    NOTE: ``Predictor`` is looked up as a module-level name when the model is
    constructed, so it must have been imported (e.g. from ``mxnet_predict``)
    before an instance is created.
    """
    # Characters the network can emit. Label 0 is reserved for the CTC blank,
    # so label k (k >= 1) maps to CONST_CHAR[k - 1].
    CONST_CHAR = '0123456789'

    # Size the input image is resized to before being fed to the network.
    _IMG_WIDTH = 80
    _IMG_HEIGHT = 30

    def __init__(self, path_of_json, path_of_params):
        """Store the model file locations and create the predictor.

        Parameters
        ----------
        path_of_json: str
            Path of the network symbol (JSON) file.
        path_of_params: str
            Path of the trained parameters file.
        """
        super(lstm_ocr_model, self).__init__()
        self.path_of_json = path_of_json
        self.path_of_params = path_of_params
        self.predictor = None
        self.__init_ocr()

    def __init_ocr(self):
        """Declare the input shapes and build ``self.predictor``.

        Also fills ``self.init_state_dict`` with a zero array for every LSTM
        init state, ready to be passed to ``predictor.forward``.
        """
        num_label = 4  # Set your max length of label, add one more for blank
        batch_size = 1

        num_hidden = 100
        num_lstm_layer = 2
        state_shape = (batch_size, num_hidden)
        init_c = [('l%d_init_c' % layer, state_shape) for layer in range(num_lstm_layer)]
        init_h = [('l%d_init_h' % layer, state_shape) for layer in range(num_lstm_layer)]
        init_states = init_c + init_h

        # Every init state is all zeros, so a single array is shared by all
        # entries; it is only ever passed to the predictor as an input.
        zero_state = np.zeros(state_shape, dtype="float32")
        self.init_state_dict = {name: zero_state for name, _ in init_states}

        data_shape = (batch_size, self._IMG_WIDTH, self._IMG_HEIGHT)
        all_shapes = [('data', data_shape)] + init_states + [('label', (batch_size, num_label))]
        all_shapes_dict = dict(all_shapes)

        # Read both model files through context managers so the file handles
        # are closed as soon as their content has been loaded.
        with open(self.path_of_json, 'rb') as symbol_file:
            symbol_content = symbol_file.read()
        with open(self.path_of_params, 'rb') as params_file:
            params_content = params_file.read()
        self.predictor = Predictor(symbol_content, params_content, all_shapes_dict)

    def forward_ocr(self, img_):
        """Forward the image through the LSTM network model

        Parameters
        ----------
        img_: numpy.ndarray
            Two-dimensional (single channel) image; it is resized to
            80x30 (width x height) and its values are scaled by 1/255
            (presumably 8-bit pixel values -- confirm against the caller).

        Returns
        ----------
        str
            The decoded digit string.
        """
        img_ = cv2.resize(img_, (self._IMG_WIDTH, self._IMG_HEIGHT))
        # cv2 yields (height, width); swap the axes to match the 'data' input
        # shape of (batch, width, height) declared for the predictor.
        img_ = img_.transpose(1, 0)
        img_ = img_.reshape((1, self._IMG_WIDTH, self._IMG_HEIGHT))
        img_ = np.multiply(img_, 1 / 255.0)
        self.predictor.forward(data=img_, **self.init_state_dict)
        prob = self.predictor.get_output(0)
        # Greedy decoding: take the highest-scoring label of every row of the
        # output (assumed to be one row per timestep).
        label_list = [int(np.argmax(p)) for p in prob]
        return self.__get_string(label_list)

    @staticmethod
    def __get_string(label_list):
        """Turn a per-timestep label sequence into the decoded string.

        Applies the CTC label rule: consecutive repeats of the same label are
        merged into one and blank labels (0) are dropped. A blank between two
        identical labels keeps them as two separate characters. Labels with
        no matching entry in ``CONST_CHAR`` are silently skipped.

        Parameters
        ----------
        label_list: list of int
            Best label index for each timestep.

        Returns
        ----------
        str
            The decoded characters.
        """
        charset = lstm_ocr_model.CONST_CHAR
        collapsed = []
        previous = 0  # behave as if the sequence were preceded by a blank
        for label in label_list:
            if label != 0 and label != previous:
                collapsed.append(label)
            previous = label
        # change to ascii (label k maps to charset[k - 1])
        return ''.join(charset[label - 1] for label in collapsed
                       if 0 < label <= len(charset))
| |
| |
if __name__ == '__main__':
    # Command-line entry point: load the prediction library, build the OCR
    # model and print the digits recognized in a CAPTCHA image.
    parser = argparse.ArgumentParser()
    parser.add_argument("predict_lib_path", help="Path to directory containing mxnet_predict.so")
    parser.add_argument("--symbol", default='ocr-symbol.json',
                        help="Path to the network symbol JSON file [Default 'ocr-symbol.json']")
    parser.add_argument("--params", default='ocr-0010.params',
                        help="Path to the trained parameters file [Default 'ocr-0010.params']")
    parser.add_argument("--image", default='sample0.png',
                        help="Path to the CAPTCHA image file [Default 'sample0.png']")
    args = parser.parse_args()

    # The Python binding is imported from the "python" sub-directory of the
    # given path. The import has to happen before the model is built because
    # lstm_ocr_model resolves ``Predictor`` as a module-level name.
    sys.path.append(args.predict_lib_path + "/python")
    from mxnet_predict import Predictor

    _lstm_ocr_model = lstm_ocr_model(args.symbol, args.params)
    img = cv2.imread(args.image, 0)  # flag 0: read as a single-channel grayscale image
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None
        # instead of raising; stop here with a clear message.
        sys.exit("Could not read image file: %s" % args.image)
    _str = _lstm_ocr_model.forward_ocr(img)
    print('Result: ', _str)