blob: c233ab2e17aae52edc2b2594a719fb19c07b1cfa [file] [log] [blame]
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, cv2\n",
"from keras.models import load_model\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pathes to model and weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path_to_model = '/home/datalab-user/model_1000.json'\n",
"path_to_weights = '/home/datalab-user/weigths_1000.h5'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading test images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ROWS = 128\n",
"COLS = 128\n",
"CHANNELS = 3\n",
"TEST_DIR = '/home/datalab-user/test/'\n",
"all_image_names = os.listdir(TEST_DIR)\n",
"all_image_names.sort()\n",
"test_images = [TEST_DIR+i for i in all_image_names[6:11] + all_image_names[19:32] + all_image_names[33:34]]\n",
"\n",
"def read_image(file_path):\n",
" img = cv2.imread(file_path, cv2.IMREAD_COLOR)\n",
" return cv2.resize(img, (ROWS, COLS), interpolation=cv2.INTER_CUBIC).reshape(ROWS, COLS, CHANNELS)\n",
"\n",
"def prep_data(images):\n",
" count = len(images)\n",
" data = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8)\n",
"\n",
" for i, image_file in enumerate(images):\n",
" image = read_image(image_file)\n",
" data[i] = image\n",
" return data\n",
"test = prep_data(test_images)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading the model and making predictions on test data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with tf.device('/cpu:0'):\n",
" model = load_model(path_to_model)\n",
" model.load_weights(path_to_weights)\n",
" predictions = model.predict(test, verbose=2) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing results (rendering can take about a minute)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"plt.figure(figsize=(16, 12))\n",
"for i in range(0, 12):\n",
" plt.subplot(3, 4, i+1)\n",
" if predictions[i, 0] >= 0.5: \n",
" plt.title('{:.2%} Dog'.format(predictions[i][0]))\n",
" else: \n",
" plt.title('{:.2%} Cat'.format(1-predictions[i][0]))\n",
" \n",
" plt.imshow(cv2.cvtColor(test[i], cv2.COLOR_BGR2RGB))\n",
" plt.axis('off')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "KERNEL_NAME"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}