blob: 745a1f87b17ce14095917da1813bd00e1a30dc3f [file] [log] [blame]
import sys, os
# Absolute directory containing this script.
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
# Make the amalgamation Python bindings (mxnet_predict) importable.
# The cwd-relative entry is kept for backward compatibility; the second entry
# is resolved against curr_path so the import also works when the script is
# launched from a directory other than its own.
sys.path.append("../../../amalgamation/python/")
sys.path.append(os.path.normpath(os.path.join(curr_path, "../../../amalgamation/python/")))
from mxnet_predict import Predictor, load_ndarray_file
import logging
import numpy as np
from skimage import io, transform
# Load the pre-trained model.
prefix = "resnet/resnet-18"
# Checkpoint number; zero-padded to four digits to form the params file name.
num_round = 0
symbol_file = "%s-symbol.json" % prefix
# Derived from num_round (instead of a hard-coded "0000") so changing
# num_round actually selects a different checkpoint file.
param_file = "%s-%04d.params" % (prefix, num_round)
# Read the model files inside `with` blocks so the handles are closed
# deterministically rather than left for the garbage collector.
with open(symbol_file, "r") as symbol_fh:
    symbol_json = symbol_fh.read()
with open(param_file, "rb") as param_fh:
    param_bytes = param_fh.read()
# Declared input shape for 'data'. PreprocessImage produces a channel-first
# (3, 224, 224) image; the leading 1 is presumably the batch size — TODO
# confirm against the Predictor API.
predictor = Predictor(symbol_json, param_bytes, {'data': (1, 3, 224, 224)})
# Class labels, one per line, with surrounding whitespace stripped. They are
# indexed by predicted class index below.
with open('resnet/synset.txt') as synset_fh:
    synset = [line.strip() for line in synset_fh]
def PreprocessImage(path, show_img=False):
    """Load an image and convert it to the model's channel-first input layout.

    The image is forced to 3 channels, center-cropped to a square along its
    shorter side, resized to 224x224, multiplied by 255 and transposed from
    (height, width, channel) to (channel, height, width).

    Args:
        path: Image location passed straight to ``skimage.io.imread``.
        show_img: Unused; kept so existing callers keep working.

    Returns:
        A float ``numpy.ndarray`` of shape (3, 224, 224) for a single 2-D
        grayscale or 3-D color image. No mean subtraction or other
        normalization is applied.
    """
    # load image
    img = io.imread(path)
    print("Original Image Shape: ", img.shape)
    # Normalize to exactly 3 channels. Grayscale images load as 2-D arrays
    # (the channel transpose below needs a third axis), and PNGs may carry an
    # alpha channel that would not fit the 3-channel input.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # Grayscale (optionally with alpha): replicate the first channel.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif img.shape[2] > 3:
        # e.g. RGBA: drop the extra channel(s).
        img = img[:, :, :3]
    # Crop the largest centered square.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    # Resize to 224 x 224. Multiplying by 255 assumes resize returned values
    # in [0, 1], which is skimage's default (preserve_range=False) for
    # integer input.
    resized_img = transform.resize(crop_img, (224, 224))
    sample = np.asarray(resized_img) * 255
    # (224, 224, 3) -> (3, 224, 224): channel-first layout.
    sample = np.transpose(sample, (2, 0, 1))
    return sample
# Get preprocessed batch (single image batch)
batch = PreprocessImage('./download.jpg', True)
predictor.forward(data=batch)
# First row of output 0 — presumably the per-class scores for our single
# image (TODO confirm against the Predictor API).
prob = predictor.get_output(0)[0]
# Class indices ordered from highest to lowest score.
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
print("Top1: ", top1)
# Get top5 labels by iterating the ranked indices directly; slicing also
# tolerates models with fewer than five classes.
top5 = [synset[idx] for idx in pred[:5]]
print("Top5: ", top5)