# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
| |
| import sys, os |
# Absolute directory containing this script. The amalgamation Python bindings
# are located relative to it, so the script works regardless of the current
# working directory it is launched from.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.normpath(os.path.join(curr_path, "../../../amalgamation/python/")))
| |
| from mxnet_predict import Predictor, load_ndarray_file |
| import logging |
| import numpy as np |
| from skimage import io, transform |
| |
# Load the pre-trained model: the network definition (JSON symbol file, read
# as text) and its parameters (read as bytes) saved at epoch ``num_round``.
prefix = "resnet/resnet-18"
num_round = 0
symbol_file = "%s-symbol.json" % prefix
# Build the epoch suffix from num_round instead of hard-coding "0000".
param_file = "%s-%04d.params" % (prefix, num_round)
with open(symbol_file, "r") as f:
    symbol_json = f.read()
with open(param_file, "rb") as f:
    param_bytes = f.read()
# Single 3-channel 224x224 image, channel-first (matches PreprocessImage's
# (3, 224, 224) output plus a leading batch axis).
predictor = Predictor(symbol_json, param_bytes, {'data': (1, 3, 224, 224)})

# One label per line; a class index from the network output indexes this list.
with open('resnet/synset.txt') as f:
    synset = [line.strip() for line in f]
| |
def PreprocessImage(path, show_img=False):
    """Load an image and convert it to the predictor's input layout.

    The image is forced to 3 channels, center-cropped to a square along its
    shorter edge, resized to 224x224, rescaled to the 0-255 range and
    transposed from (height, width, channel) to (channel, height, width).

    Parameters
    ----------
    path : str
        Image location passed to ``skimage.io.imread``.
    show_img : bool, optional
        Accepted for backward compatibility; currently unused.

    Returns
    -------
    numpy.ndarray
        Float array of shape (3, 224, 224).
    """
    # load image
    img = io.imread(path)
    print("Original Image Shape: ", img.shape)

    # Normalise to exactly 3 channels so the result always fits the
    # (1, 3, 224, 224) input: a 2-D grayscale image would otherwise crash the
    # axis transpose below, and alpha channels would give the wrong count.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # grayscale, optionally with alpha: replicate the luminance channel
        img = np.repeat(img[:, :, :1], 3, axis=2)
    img = img[:, :, :3]  # drop alpha if present

    # crop the largest centered square
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    # resize to 224x224; for integer input images (e.g. uint8 from imread)
    # transform.resize returns floats in [0, 1], hence the * 255 rescale
    resized_img = transform.resize(crop_img, (224, 224))
    sample = np.asarray(resized_img) * 255

    # (224, 224, 3) -> (3, 224, 224)
    sample = np.transpose(sample, (2, 0, 1))

    # NOTE(review): no mean subtraction is performed (an earlier comment said
    # "sub mean"); presumably the model normalises its input itself -- confirm.
    return sample
| |
# Get preprocessed batch (single image batch)
batch = PreprocessImage('./download.jpg', True)
# Add the leading batch axis so the array matches the declared input shape
# (1, 3, 224, 224) rather than relying on an equal element count.
batch = batch[np.newaxis, ...]

predictor.forward(data=batch)
# First output, first (only) item of the batch: one score per class.
prob = predictor.get_output(0)[0]

# Class indices ordered from highest to lowest score.
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
print("Top1: ", top1)
# Get top5 labels; slicing avoids an IndexError with fewer than 5 classes.
top5 = [synset[i] for i in pred[:5]]
print("Top5: ", top5)
| |