{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Image Classification using Caffe VGG-19 model (Transfer Learning)\n",
"\n",
"This notebook demonstrates importing VGG-19 model from Caffe to SystemML and use that model to do an image classification. VGG-19 model has been trained using ImageNet dataset (1000 classes with ~ 14M images). If an image to be predicted is in one of the class VGG-19 has trained on then accuracy will be higher.\n",
"We expect prediction of any image through SystemML using VGG-19 model will be similar to that of image predicted through Caffe using VGG-19 model directly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prerequisite:\n",
"1. SystemML Python Package\n",
"To run this notebook you need to install systeml 1.0 (Master Branch code as of 08/24/2017 or later) python package.\n",
"2. Download Dogs-vs-Cats Kaggle dataset from https://www.kaggle.com/c/dogs-vs-cats/data location to a directory.\n",
" Unzip the train.zip directory to some location and update the variable \"train_dir\" in bottom two cells in which classifyImagesWTransfLearning() and classifyImages() methods are called to test this change. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### SystemML Python Package information"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Name: systemml\r\n",
"Version: 1.0.0\r\n",
"Summary: Apache SystemML is a distributed and declarative machine learning platform.\r\n",
"Home-page: http://systemml.apache.org/\r\n",
"Author: Apache SystemML\r\n",
"Author-email: dev@systemml.apache.org\r\n",
"License: Apache 2.0\r\n",
"Location: /home/asurve/src/anaconda2/lib/python2.7/site-packages\r\n",
"Requires: Pillow, numpy, scipy, pandas\r\n"
]
}
],
"source": [
"!pip show systemml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### SystemML Build information\n",
"Following code will show SystemML information which is installed in the environment."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SystemML Built-Time:2017-08-17 19:20:41 UTC\n",
"Archiver-Version: Plexus Archiver\n",
"Artifact-Id: systemml\n",
"Build-Jdk: 1.8.0_121\n",
"Build-Time: 2017-08-17 19:20:41 UTC\n",
"Built-By: asurve\n",
"Created-By: Apache Maven 3.3.9\n",
"Group-Id: org.apache.systemml\n",
"Main-Class: org.apache.sysml.api.DMLScript\n",
"Manifest-Version: 1.0\n",
"Minimum-Recommended-Spark-Version: 2.1.0\n",
"Version: 1.0.0-SNAPSHOT\n",
"\n"
]
}
],
"source": [
"from systemml import MLContext\n",
"ml = MLContext(sc)\n",
"print (\"SystemML Built-Time:\"+ ml.buildTime())\n",
"print(ml.info())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Workaround for Python 2.7.13 to avoid certificate validation issue while downloading any file.\n",
"\n",
"import ssl\n",
"\n",
"try:\n",
" _create_unverified_https_context = ssl._create_unverified_context\n",
"except AttributeError:\n",
" # Legacy Python that doesn't verify HTTPS certificates by default\n",
" pass\n",
"else:\n",
" # Handle target environment that doesn't support HTTPS verification\n",
" ssl._create_default_https_context = _create_unverified_https_context"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Create label.txt file\n",
"\n",
"def createLabelFile(fileName):\n",
" file = open(fileName, 'w')\n",
" file.write('1,\"Cat\" \\n')\n",
" file.write('2,\"Dog\" \\n')\n",
" file.close()"
]
},
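{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, self-contained check of the label-file format written by `createLabelFile` above (using a temporary directory so nothing in the working directory is touched):\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"def createLabelFile(fileName):\n",
"    file = open(fileName, 'w')\n",
"    file.write('1,\"Cat\" \\n')\n",
"    file.write('2,\"Dog\" \\n')\n",
"    file.close()\n",
"\n",
"# Write the two-line label file into a scratch directory and show its contents\n",
"path = os.path.join(tempfile.mkdtemp(), 'labels.txt')\n",
"createLabelFile(path)\n",
"print(open(path).read())\n",
"```"
]
},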
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Download model, proto files and convert them to SystemML format.\n",
"\n",
"1. Download Caffe Model (VGG-19), proto files (deployer, network and solver) and label file.\n",
"2. Convert the Caffe model into SystemML input format.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Download caffemodel and proto files \n",
"\n",
"\n",
"def downloadAndConvertModel(downloadDir='.', trained_vgg_weights='trained_vgg_weights'):\n",
" \n",
" # Step 1: Download the VGG-19 model and other files.\n",
" import errno\n",
" import os\n",
" import urllib\n",
"\n",
" # Create directory, if exists don't error out\n",
" try:\n",
" os.makedirs(os.path.join(downloadDir,trained_vgg_weights))\n",
"    except OSError as exc:  # Python >2.5\n",
"        if exc.errno == errno.EEXIST and os.path.isdir(os.path.join(downloadDir, trained_vgg_weights)):\n",
" pass\n",
" else:\n",
" raise\n",
" \n",
" # Download deployer, network, solver proto and label files.\n",
" urllib.urlretrieve('https://raw.githubusercontent.com/apache/systemml/master/scripts/nn/examples/caffe2dml/models/imagenet/vgg19/VGG_ILSVRC_19_layers_deploy.proto', os.path.join(downloadDir,'VGG_ILSVRC_19_layers_deploy.proto'))\n",
" urllib.urlretrieve('https://raw.githubusercontent.com/apache/systemml/master/scripts/nn/examples/caffe2dml/models/imagenet/vgg19/VGG_ILSVRC_19_layers_network.proto',os.path.join(downloadDir,'VGG_ILSVRC_19_layers_network.proto'))\n",
"    # TODO: After downloading the network file (VGG_ILSVRC_19_layers_network.proto), change num_output from 1000 to 2\n",
" \n",
" urllib.urlretrieve('https://raw.githubusercontent.com/apache/systemml/master/scripts/nn/examples/caffe2dml/models/imagenet/vgg19/VGG_ILSVRC_19_layers_solver.proto',os.path.join(downloadDir,'VGG_ILSVRC_19_layers_solver.proto'))\n",
"    # TODO: Set the values described below in VGG_ILSVRC_19_layers_solver.proto (possibly through APIs, whenever available)\n",
" # test_iter: 100\n",
" # stepsize: 40\n",
" # max_iter: 200\n",
" \n",
" # Create labels for data\n",
" ### 1,\"Cat\"\n",
" ### 2,\"Dog\"\n",
" createLabelFile(os.path.join(downloadDir, trained_vgg_weights, 'labels.txt'))\n",
"\n",
"    # TODO: The following line is commented out because the file is ~500MB; if you need it, uncomment the line and run.\n",
" # urllib.urlretrieve('http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel', os.path.join(downloadDir,'VGG_ILSVRC_19_layers.caffemodel'))\n",
"\n",
" # Step 2: Convert the caffemodel to trained_vgg_weights directory\n",
" import systemml as sml\n",
" sml.convert_caffemodel(sc, os.path.join(downloadDir,'VGG_ILSVRC_19_layers_deploy.proto'), os.path.join(downloadDir,'VGG_ILSVRC_19_layers.caffemodel'), os.path.join(downloadDir,trained_vgg_weights))\n",
" \n",
" return"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### PrintTopK\n",
"This function will print top K probabilities and indices from the result."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Print top K indices and probability\n",
"\n",
"def printTopK(prob, label, k):\n",
" print(label, 'Top ', k, ' Index : ', np.argsort(-prob)[0, :k])\n",
" print(label, 'Top ', k, ' Probability : ', prob[0,np.argsort(-prob)[0, :k]])"
]
},
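{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of `printTopK` (the 1x3 probability matrix below is made up for the example; it is not model output):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def printTopK(prob, label, k):\n",
"    # Sort indices by descending probability and keep the first k\n",
"    print(label, 'Top ', k, ' Index : ', np.argsort(-prob)[0, :k])\n",
"    print(label, 'Top ', k, ' Probability : ', prob[0, np.argsort(-prob)[0, :k]])\n",
"\n",
"prob = np.array([[0.1, 0.6, 0.3]])\n",
"printTopK(prob, 'demo', 2)  # top-2 indices [1 2] with probabilities [0.6 0.3]\n",
"```"
]
},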
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Classify images\n",
"\n",
"This function classify images from images specified through urls.\n",
"\n",
"###### Input Parameters: \n",
" urls: List of urls\n",
" printTokKData (default False): Whether to print top K indices and probabilities\n",
" topK: Top K elements to be displayed. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import urllib\n",
"from systemml.mllearn import Caffe2DML\n",
"import systemml as sml\n",
"\n",
"\n",
"def classifyImages(urls,img_shape=(3, 224, 224), printTokKData=False, topK=5, downloadDir='.', trained_vgg_weights='trained_vgg_weights'):\n",
"\n",
" size = (img_shape[1], img_shape[2])\n",
" \n",
" vgg = Caffe2DML(sqlCtx, solver=os.path.join(downloadDir,'VGG_ILSVRC_19_layers_solver.proto'), input_shape=img_shape)\n",
"    vgg.load(os.path.join(downloadDir, trained_vgg_weights))\n",
"\n",
" for url in urls:\n",
" outFile = 'inputTest.jpg'\n",
" urllib.urlretrieve(url, outFile)\n",
" \n",
" from IPython.display import Image, display\n",
" display(Image(filename=outFile))\n",
" \n",
"        print (\"Prediction of above image to ImageNet Class using\")\n",
"\n",
" ## Do image classification through SystemML processing\n",
" from PIL import Image\n",
" input_image = sml.convertImageToNumPyArr(Image.open(outFile), img_shape=img_shape\n",
" , color_mode='BGR', mean=sml.getDatasetMean('VGG_ILSVRC_19_2014'))\n",
" print (\"Image preprocessed through SystemML :: \", vgg.predict(input_image)[0])\n",
"        if printTopKData:\n",
" sysml_proba = vgg.predict_proba(input_image)\n",
" printTopK(sysml_proba, 'SystemML BGR', topK)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.ml.linalg import Vectors\n",
"import os\n",
"import systemml as sml\n",
"\n",
"\n",
"def getLabelFeatures(filename, train_dir, img_shape):\n",
" from PIL import Image\n",
"\n",
" vec = Vectors.dense(sml.convertImageToNumPyArr(Image.open(os.path.join(train_dir, filename)), img_shape=img_shape)[0,:])\n",
" if filename.lower().startswith('cat'):\n",
" return (1, vec)\n",
" elif filename.lower().startswith('dog'):\n",
" return (2, vec)\n",
" else:\n",
" raise ValueError('Expected the filename to start with either cat or dog')\n"
]
},
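{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Kaggle filenames encode the label (cat.*.jpg / dog.*.jpg). Below is a minimal, self-contained sketch of just the filename-to-label rule used by `getLabelFeatures` above, without the Spark and SystemML dependencies; the sample filenames are illustrative:\n",
"\n",
"```python\n",
"def labelFromFilename(filename):\n",
"    # Same labeling rule as getLabelFeatures: 1 = Cat, 2 = Dog (matching labels.txt)\n",
"    name = filename.lower()\n",
"    if name.startswith('cat'):\n",
"        return 1\n",
"    elif name.startswith('dog'):\n",
"        return 2\n",
"    raise ValueError('Expected the filename to start with either cat or dog')\n",
"\n",
"print(labelFromFilename('cat.1234.jpg'), labelFromFilename('DOG.42.jpg'))  # 1 2\n",
"```"
]
},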
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.sql.functions import rand\n",
"import os\n",
"\n",
"def createTrainingDF(train_dir, train_data_file, img_shape):\n",
" list_jpeg_files = os.listdir(train_dir)\n",
" # 10 files per partition\n",
" train_df = sc.parallelize(list_jpeg_files, int(len(list_jpeg_files)/10)).map(lambda filename : getLabelFeatures(filename, train_dir, img_shape)).toDF(['label', 'features']).orderBy(rand())\n",
"    # Optional, but helps separate conversion-related work from training\n",
" # train_df.write.parquet(train_data_file) # 'kaggle-cats-dogs.parquet'\n",
" return train_df"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def readTrainingDF(train_dir, train_data_file):\n",
" train_df = sqlContext.read.parquet(train_data_file)\n",
" return train_df"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# downloadAndConvertModel(downloadDir, trained_vgg_weights)\n",
"# TODO: Take \"TODO\" actions mentioned in the downloadAndConvertModel() function after calling downloadAndConvertModel() function."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def retrainModel(img_shape, downloadDir, trained_vgg_weights, train_dir, train_data_file, vgg_new_model):\n",
"\n",
"    # Leave downloadAndConvertModel() commented out here, as it needs to be called separately (done in the cell above) and manual actions must be taken after calling it.\n",
" # downloadAndConvertModel(downloadDir, trained_vgg_weights)\n",
" # TODO: Take \"TODO\" actions mentioned in the downloadAndConvertModel() function after calling that function.\n",
" \n",
" train_df = createTrainingDF(train_dir, train_data_file, img_shape)\n",
"    ## Write from input files OR read if it's already written/converted\n",
" # train_df = readTrainingDF(train_dir, train_data_file)\n",
" \n",
" # Load the model\n",
" vgg = Caffe2DML(sqlCtx, solver=os.path.join(downloadDir,'VGG_ILSVRC_19_layers_solver.proto'), input_shape=img_shape)\n",
" vgg.load(weights=os.path.join(downloadDir,trained_vgg_weights), ignore_weights=['fc8'])\n",
" vgg.set(debug=True).setExplain(True)\n",
"\n",
" # Train the model using new data\n",
" vgg.fit(train_df)\n",
" \n",
" # Save the trained model\n",
" vgg.save(vgg_new_model)\n",
" \n",
" return vgg"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import urllib\n",
"from systemml.mllearn import Caffe2DML\n",
"import systemml as sml\n",
"\n",
"\n",
"def classifyImagesWTransfLearning(urls, model, img_shape=(3, 224, 224), printTokKData=False, topK=5):\n",
"\n",
" size = (img_shape[1], img_shape[2])\n",
" # vgg.load(trained_vgg_weights)\n",
"\n",
" for url in urls:\n",
" outFile = 'inputTest.jpg'\n",
" urllib.urlretrieve(url, outFile)\n",
" \n",
" from IPython.display import Image, display\n",
" display(Image(filename=outFile))\n",
" \n",
"        print (\"Prediction of above image to ImageNet Class using\")\n",
"\n",
" ## Do image classification through SystemML processing\n",
" from PIL import Image\n",
" input_image = sml.convertImageToNumPyArr(Image.open(outFile), img_shape=img_shape\n",
" , color_mode='BGR', mean=sml.getDatasetMean('VGG_ILSVRC_19_2014'))\n",
"\n",
" print (\"Image preprocessed through SystemML :: \", model.predict(input_image)[0])\n",
"        if printTopKData:\n",
" sysml_proba = model.predict_proba(input_image)\n",
" printTopK(sysml_proba, 'SystemML BGR', topK)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample code to retrain the model and use it to classify image through two different way\n",
"\n",
"There are couple of parameters to set based on what you are looking for.\n",
"1. printTopKData (default False): If this parameter gets set to True, then top K results (probabilities and indices) will be displayed. \n",
"2. topK (default 5): How many entities (K) to be displayed.\n",
"3. Directories, data file name, model name and directory where data has donwloaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# ImageNet specific parameters\n",
"img_shape = (3, 224, 224)\n",
"\n",
"# Setting other than current directory causes \"network file not found\" issue, as network file\n",
"# location is defined in solver file which does not have a path, so it searches in current dir.\n",
"downloadDir = '.' # /home/asurve/caffe_models' \n",
"trained_vgg_weights = 'trained_vgg_weights'\n",
"\n",
"train_dir = '/home/asurve/data/keggle/dogs_vs_cats_2/train'\n",
"train_data_file = 'kaggle-cats-dogs.parquet'\n",
" \n",
"vgg_new_model = 'kaggle-cats-dogs-model_2'\n",
" \n",
"printTopKData=True\n",
"topK=5\n",
"\n",
"urls = ['http://cdn3-www.dogtime.com/assets/uploads/gallery/goldador-dog-breed-pictures/puppy-1.jpg','https://lh3.googleusercontent.com/-YdeAa1Ff4Ac/VkUnQ4vuZGI/AAAAAAAAAEg/nBiUn4pp6aE/w800-h800/images-6.jpeg','https://upload.wikimedia.org/wikipedia/commons/thumb/5/58/MountainLion.jpg/312px-MountainLion.jpg']\n",
"\n",
"vgg = retrainModel(img_shape, downloadDir, trained_vgg_weights, train_dir, train_data_file, vgg_new_model)\n",
"classifyImagesWTransfLearning(urls, vgg, img_shape, printTopKData, topK)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"img_shape = (3, 224, 224)\n",
"\n",
"printTopKData=True\n",
"topK=5\n",
"\n",
"# Setting other than current directory causes \"network file not found\" issue, as network file\n",
"# location is defined in solver file which does not have a path, so it searches in current dir.\n",
"downloadDir = '.' # /home/asurve/caffe_models' \n",
"trained_vgg_weights = 'kaggle-cats-dogs-model_2'\n",
"\n",
"urls = ['http://cdn3-www.dogtime.com/assets/uploads/gallery/goldador-dog-breed-pictures/puppy-1.jpg','https://lh3.googleusercontent.com/-YdeAa1Ff4Ac/VkUnQ4vuZGI/AAAAAAAAAEg/nBiUn4pp6aE/w800-h800/images-6.jpeg','https://upload.wikimedia.org/wikipedia/commons/thumb/5/58/MountainLion.jpg/312px-MountainLion.jpg']\n",
"\n",
"classifyImages(urls,img_shape, printTopKData, topK, downloadDir, trained_vgg_weights)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}