blob: eec48ead78b643b8cb5cbd42fbeef8ca49cb2ae5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" An example of predicting CAPTCHA image data with a LSTM network pre-trained with a CTC loss"""
from __future__ import print_function
import argparse
import numpy as np
from ctc_metrics import CtcMetrics
import cv2
from hyperparams import Hyperparams
import lstm
import mxnet as mx
from ocr_iter import SimpleBatch
def read_img(path):
""" Reads image specified by path into numpy.ndarray"""
img = cv2.resize(cv2.imread(path, 0), (80, 30)).astype(np.float32) / 255
img = np.expand_dims(img.transpose(1, 0), 0)
return img
def lstm_init_states(batch_size):
""" Returns a tuple of names and zero arrays for LSTM init states"""
hp = Hyperparams()
init_shapes = lstm.init_states(batch_size=batch_size, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden)
init_names = [s[0] for s in init_shapes]
init_arrays = [mx.nd.zeros(x[1]) for x in init_shapes]
return init_names, init_arrays
def load_module(prefix, epoch, data_names, data_shapes):
"""Loads the model from checkpoint specified by prefix and epoch, binds it
to an executor, and sets its parameters and returns a mx.mod.Module
"""
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
# We don't need CTC loss for prediction, just a simple softmax will suffice.
# We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top
pred_fc = sym.get_internals()['pred_fc_output']
sym = mx.sym.softmax(data=pred_fc)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=data_names, label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes)
mod.set_params(arg_params, aux_params, allow_missing=False)
return mod
def main():
"""Program entry point"""
parser = argparse.ArgumentParser()
parser.add_argument("path", help="Path to the CAPTCHA image file")
parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr')
parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100)
args = parser.parse_args()
init_state_names, init_state_arrays = lstm_init_states(batch_size=1)
img = read_img(args.path)
sample = SimpleBatch(
data_names=['data'] + init_state_names,
data=[mx.nd.array(img)] + init_state_arrays)
mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data)
mod.forward(sample)
prob = mod.get_outputs()[0].asnumpy()
prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())
# Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
prediction = [p - 1 for p in prediction]
print("Digits:", prediction)
if __name__ == '__main__':
main()