blob: faa322ea2a1e62160b4a50cc56562682373aea3f [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CNN Using Keras and MADlib\n",
"\n",
"End-to-end (E2E) classification example in which MADlib trains a Keras convolutional neural network (CNN). The model architecture is based on https://keras.io/examples/cifar10_cnn/\n",
"\n",
"To load images into tables, we use the script <em>madlib_image_loader.py</em> from the MADlib site repository (https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning). It uses the Python Imaging Library (http://www.pythonware.com/products/pil/), so it supports multiple image formats.\n",
"\n",
"\n",
"## Table of contents\n",
"<a href=\"#import_libraries\">1. Import libraries</a>\n",
"\n",
"<a href=\"#load_and_prepare_data\">2. Set up image loader and load dataset into table</a>\n",
"\n",
"<a href=\"#image_preproc\">3. Call image preprocessor</a>\n",
"\n",
"<a href=\"#define_and_load_model\">4. Define and load model architecture</a>\n",
"\n",
"<a href=\"#train\">5. Train</a>\n",
"\n",
"<a href=\"#plot\">6. Plots by iteration and by time</a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: fmcquillan@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP for deep learning (PM demo machine)\n",
"#%sql postgresql://gpadmin@35.239.240.26:5432/madlib\n",
" \n",
"# PostgreSQL local\n",
"%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.16, git revision: rc/1.16-rc1, cmake configuration time: Mon Jul 1 17:45:09 UTC 2019, build type: Release, build system: Darwin-16.7.0, C compiler: Clang, C++ compiler: Clang</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.16, git revision: rc/1.16-rc1, cmake configuration time: Mon Jul 1 17:45:09 UTC 2019, build type: Release, build system: Darwin-16.7.0, C compiler: Clang, C++ compiler: Clang',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"import_libraries\"></a>\n",
"# 1. Import libraries\n",
"Import libraries and define some parameters, following https://keras.io/examples/cifar10_cnn/"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n"
]
}
],
"source": [
"from __future__ import print_function\n",
"import keras\n",
"from keras.datasets import cifar10\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Activation, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"import os\n",
"\n",
"batch_size = 32\n",
"num_classes = 10\n",
"epochs = 100"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other libraries needed in this notebook"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import sys\n",
"import os\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load_and_prepare_data\"></a>\n",
"# 2. Set up image loader and load dataset into table\n",
"\n",
"First, set up the image loader. Download the script <em>madlib_image_loader.py</em> from https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning and set `madlib_site_dir` below to the local directory that contains it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"madlib_site_dir = '/Users/fmcquillan/Documents/Product/MADlib/Demos/data'\n",
"sys.path.append(madlib_site_dir)\n",
"\n",
"# Import image loader module\n",
"from madlib_image_loader import ImageLoader, DbCredentials\n",
"\n",
"# Database credentials for the local PostgreSQL instance\n",
"db_creds = DbCredentials(user='fmcquillan',\n",
" host='localhost',\n",
" port='5432',\n",
" password='')\n",
"\n",
"# Alternative: database credentials for the Greenplum cluster on GCP\n",
"#db_creds = DbCredentials(user='gpadmin', \n",
"# db_name='madlib',\n",
"# host='35.239.240.26',\n",
"# port='5432',\n",
"# password='')\n",
"\n",
"# Initialize ImageLoader (increase num_workers to run faster)\n",
"iloader = ImageLoader(num_workers=5, db_creds=db_creds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, load the CIFAR-10 dataset from Keras. It consists of 50,000 32x32 color training images labeled into 10 categories, and 10,000 test images. Both sets are then written to database tables with the image loader."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE cifar_10_train_data (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table cifar_10_train_data in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-1 [pid 28054]\n",
"PoolWorker-1: Created temporary directory /tmp/madlib_tdv3zEFPL1\n",
"Initializing PoolWorker-2 [pid 28055]\n",
"PoolWorker-2: Created temporary directory /tmp/madlib_bWb3jWWKsY\n",
"Initializing PoolWorker-3 [pid 28056]\n",
"PoolWorker-1: Connected to madlib db.\n",
"PoolWorker-3: Created temporary directory /tmp/madlib_KetBMAbjq5\n",
"Initializing PoolWorker-4 [pid 28057]\n",
"PoolWorker-2: Connected to madlib db.\n",
"PoolWorker-4: Created temporary directory /tmp/madlib_sME12BQHb1\n",
"Initializing PoolWorker-5 [pid 28059]\n",
"PoolWorker-3: Connected to madlib db.\n",
"PoolWorker-5: Created temporary directory /tmp/madlib_i6LP0aJJDY\n",
"PoolWorker-4: Connected to madlib db.\n",
"PoolWorker-5: Connected to madlib db.\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0000.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0000.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0000.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0000.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0000.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0001.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0001.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0001.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0001.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0001.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0002.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0002.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0002.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0002.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0002.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0003.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0003.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0003.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0003.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0003.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0004.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0004.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0004.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0004.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0004.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0005.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0005.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0005.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0005.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0005.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0006.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0006.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0006.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0006.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0006.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0007.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0007.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0007.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0007.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0007.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0008.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0008.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_i6LP0aJJDY/cifar_10_train_data0008.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_bWb3jWWKsY/cifar_10_train_data0008.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_KetBMAbjq5/cifar_10_train_data0008.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-5: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-2: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0009.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0009.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0010.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_sME12BQHb1/cifar_10_train_data0010.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-4: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_tdv3zEFPL1/cifar_10_train_data0011.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar_10_train_data\n",
"PoolWorker-3: Removed temporary directory /tmp/madlib_KetBMAbjq5\n",
"PoolWorker-2: Removed temporary directory /tmp/madlib_bWb3jWWKsY\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"PoolWorker-4: Removed temporary directory /tmp/madlib_sME12BQHb1\n",
"PoolWorker-1: Removed temporary directory /tmp/madlib_tdv3zEFPL1\n",
"PoolWorker-5: Removed temporary directory /tmp/madlib_i6LP0aJJDY\n",
"Done! Loaded 50000 images in 24.227011919s\n",
"5 workers terminated.\n",
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE cifar_10_test_data (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table cifar_10_test_data in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-6 [pid 28066]\n",
"PoolWorker-6: Created temporary directory /tmp/madlib_yKNKBHEc3G\n",
"Initializing PoolWorker-7 [pid 28067]\n",
"PoolWorker-7: Created temporary directory /tmp/madlib_hb8ESuQLva\n",
"Initializing PoolWorker-8 [pid 28068]\n",
"PoolWorker-8: Created temporary directory /tmp/madlib_PmtDmYhSBj\n",
"PoolWorker-6: Connected to madlib db.\n",
"Initializing PoolWorker-9 [pid 28069]\n",
"PoolWorker-7: Connected to madlib db.\n",
"PoolWorker-9: Created temporary directory /tmp/madlib_h7oUVpBwyZ\n",
"Initializing PoolWorker-10 [pid 28071]\n",
"PoolWorker-8: Connected to madlib db.\n",
"PoolWorker-10: Created temporary directory /tmp/madlib_9TZoE98hbn\n",
"PoolWorker-9: Connected to madlib db.\n",
"PoolWorker-10: Connected to madlib db.\n",
"PoolWorker-8: Wrote 1000 images to /tmp/madlib_PmtDmYhSBj/cifar_10_test_data0000.tmp\n",
"PoolWorker-6: Wrote 1000 images to /tmp/madlib_yKNKBHEc3G/cifar_10_test_data0000.tmp\n",
"PoolWorker-7: Wrote 1000 images to /tmp/madlib_hb8ESuQLva/cifar_10_test_data0000.tmp\n",
"PoolWorker-9: Wrote 1000 images to /tmp/madlib_h7oUVpBwyZ/cifar_10_test_data0000.tmp\n",
"PoolWorker-10: Wrote 1000 images to /tmp/madlib_9TZoE98hbn/cifar_10_test_data0000.tmp\n",
"PoolWorker-8: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-7: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-6: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-10: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-9: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-8: Wrote 1000 images to /tmp/madlib_PmtDmYhSBj/cifar_10_test_data0001.tmp\n",
"PoolWorker-7: Wrote 1000 images to /tmp/madlib_hb8ESuQLva/cifar_10_test_data0001.tmp\n",
"PoolWorker-6: Wrote 1000 images to /tmp/madlib_yKNKBHEc3G/cifar_10_test_data0001.tmp\n",
"PoolWorker-10: Wrote 1000 images to /tmp/madlib_9TZoE98hbn/cifar_10_test_data0001.tmp\n",
"PoolWorker-9: Wrote 1000 images to /tmp/madlib_h7oUVpBwyZ/cifar_10_test_data0001.tmp\n",
"PoolWorker-8: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-7: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-6: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-10: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-9: Loaded 1000 images into cifar_10_test_data\n",
"PoolWorker-8: Removed temporary directory /tmp/madlib_PmtDmYhSBj\n",
"PoolWorker-7: Removed temporary directory /tmp/madlib_hb8ESuQLva\n",
"PoolWorker-6: Removed temporary directory /tmp/madlib_yKNKBHEc3G\n",
"PoolWorker-9: Removed temporary directory /tmp/madlib_h7oUVpBwyZ\n",
"PoolWorker-10: Removed temporary directory /tmp/madlib_9TZoE98hbn\n",
"Done! Loaded 10000 images in 4.54620194435s\n",
"5 workers terminated.\n"
]
}
],
"source": [
"# Load the CIFAR-10 dataset into NumPy arrays\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
"\n",
"%sql DROP TABLE IF EXISTS cifar_10_train_data, cifar_10_test_data;\n",
"\n",
"# Save images to temporary directories and load into database\n",
"iloader.load_dataset_from_np(x_train, y_train, 'cifar_10_train_data', append=False)\n",
"iloader.load_dataset_from_np(x_test, y_test, 'cifar_10_test_data', append=False)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>50000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(50000L,)]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select count(*) from cifar_10_train_data;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"image_preproc\"></a>\n",
"# 3. Call image preprocessor\n",
"\n",
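"The preprocessor packs the data from one image per row into buffers of multiple images per row (here 1000 images per buffer) for efficient batch processing. It also normalizes the pixel values by dividing by the normalizing constant (255) and one-hot encodes the dependent variable.\n",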
"\n",
"Training data"
]
},
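{
"cell_type": "markdown",
"metadata": {},
"source": [
"The packing, normalization and one-hot encoding described above can be illustrated with a small NumPy sketch. This is not MADlib's implementation: the helper `pack_buffers` is hypothetical, the toy images are random, and the sketch assumes labels are one-hot encoded in the order of the `class_values` list shown in the summary table below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pack_buffers(x, y, class_values, buffer_size, normalizing_const):\n",
"    # Hypothetical helper illustrating the idea; not the MADlib implementation\n",
"    x = np.asarray(x, dtype=np.float32) / normalizing_const  # scale pixels to [0, 1]\n",
"    index = {c: i for i, c in enumerate(class_values)}\n",
"    y_onehot = np.zeros((len(y), len(class_values)), dtype=np.int16)\n",
"    for row, label in enumerate(y):\n",
"        y_onehot[row, index[label]] = 1  # one-hot encode in class_values order\n",
"    buffers = []\n",
"    for start in range(0, len(x), buffer_size):\n",
"        # Each buffer holds up to buffer_size images and their one-hot labels\n",
"        buffers.append((x[start:start + buffer_size], y_onehot[start:start + buffer_size]))\n",
"    return buffers\n",
"\n",
"# Toy data: 5 random 32x32x3 images packed into buffers of 2\n",
"rng = np.random.RandomState(0)\n",
"x_toy = rng.randint(0, 256, size=(5, 32, 32, 3))\n",
"y_toy = ['3', '0', '9', '3', '1']\n",
"class_values = [str(i) for i in range(10)]\n",
"packed = pack_buffers(x_toy, y_toy, class_values, buffer_size=2, normalizing_const=255.0)\n",
"\n",
"print(len(packed))          # 3 buffers holding 2, 2 and 1 images\n",
"print(packed[0][0].shape)   # (2, 32, 32, 3)\n",
"print(packed[0][1][0])      # one-hot row for label '3'"
]
},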
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" </tr>\n",
" <tr>\n",
" <td>cifar_10_train_data</td>\n",
" <td>cifar_10_train_data_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>text</td>\n",
" <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
" <td>1000</td>\n",
" <td>255.0</td>\n",
" <td>10</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'cifar_10_train_data', u'cifar_10_train_data_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 1000, 255.0, 10)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar_10_train_data_packed, cifar_10_train_data_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('cifar_10_train_data', -- Source table\n",
" 'cifar_10_train_data_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" 1000, -- Buffer size\n",
" 255 -- Normalizing constant\n",
" );\n",
"\n",
"SELECT * FROM cifar_10_train_data_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
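"Test data. `validation_preprocessor_dl` takes the packed training table as an argument, so the test data is encoded with the same class values and normalizing constant as the training data (see the summary output below)."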
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" </tr>\n",
" <tr>\n",
" <td>cifar_10_test_data</td>\n",
" <td>cifar_10_test_data_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>text</td>\n",
" <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
" <td>1000</td>\n",
" <td>255.0</td>\n",
" <td>10</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'cifar_10_test_data', u'cifar_10_test_data_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 1000, 255.0, 10)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar_10_test_data_packed, cifar_10_test_data_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('cifar_10_test_data', -- Source table\n",
" 'cifar_10_test_data_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" 'cifar_10_train_data_packed', -- Training preproc table\n",
" 1000 -- Buffer size\n",
" );\n",
"\n",
"SELECT * FROM cifar_10_test_data_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"define_and_load_model\"></a>\n",
"# 4. Define and load model architecture\n",
"\n",
"Model architecture from https://keras.io/examples/cifar10_cnn/"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_1 (Conv2D) (None, 32, 32, 32) 896 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 32, 32, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 30, 30, 32) 9248 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 30, 30, 32) 0 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 15, 15, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 15, 15, 64) 18496 \n",
"_________________________________________________________________\n",
"activation_3 (Activation) (None, 15, 15, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 13, 13, 64) 36928 \n",
"_________________________________________________________________\n",
"activation_4 (Activation) (None, 13, 13, 64) 0 \n",
"_________________________________________________________________\n",
"max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 6, 6, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 2304) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 512) 1180160 \n",
"_________________________________________________________________\n",
"activation_5 (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation_6 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 1,250,858\n",
"Trainable params: 1,250,858\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = Sequential()\n",
"model.add(Conv2D(32, (3, 3), padding='same',\n",
" input_shape=x_train.shape[1:]))\n",
"model.add(Activation('relu'))\n",
"model.add(Conv2D(32, (3, 3)))\n",
"model.add(Activation('relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.25))\n",
"\n",
"model.add(Conv2D(64, (3, 3), padding='same'))\n",
"model.add(Activation('relu'))\n",
"model.add(Conv2D(64, (3, 3)))\n",
"model.add(Activation('relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.25))\n",
"\n",
"model.add(Flatten())\n",
"model.add(Dense(512))\n",
"model.add(Activation('relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(num_classes))\n",
"model.add(Activation('softmax'))\n",
"\n",
"model.summary()"
]
},
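{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the `model.summary()` output above, the parameter counts can be reproduced by hand. A `Conv2D` layer has `kernel_height * kernel_width * input_channels * filters` weights plus one bias per filter; a `Dense` layer has `inputs * units` weights plus one bias per unit. The sketch below uses plain Python arithmetic with hypothetical helper names, and also traces the spatial size through the unpadded 3x3 convolutions and the 2x2 max pooling to get the 2304 inputs of `dense_1`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv2d_params(kernel, in_channels, filters):\n",
"    # kernel x kernel x in_channels weights per filter, plus one bias per filter\n",
"    return kernel * kernel * in_channels * filters + filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # inputs x units weight matrix, plus one bias per unit\n",
"    return inputs * units + units\n",
"\n",
"# Spatial size: padding='same' keeps it, an unpadded 3x3 conv removes 2, 2x2 pooling halves it (floor)\n",
"side = 32          # conv2d_1 (same): 32\n",
"side = side - 2    # conv2d_2: 30\n",
"side = side // 2   # max_pooling2d_1: 15 (conv2d_3 is 'same', so still 15)\n",
"side = side - 2    # conv2d_4: 13\n",
"side = side // 2   # max_pooling2d_2: 6\n",
"flat = side * side * 64  # flatten_1: 2304\n",
"\n",
"layer_params = [\n",
"    conv2d_params(3, 3, 32),   # conv2d_1: 896\n",
"    conv2d_params(3, 32, 32),  # conv2d_2: 9248\n",
"    conv2d_params(3, 32, 64),  # conv2d_3: 18496\n",
"    conv2d_params(3, 64, 64),  # conv2d_4: 36928\n",
"    dense_params(flat, 512),   # dense_1: 1180160\n",
"    dense_params(512, 10),     # dense_2: 5130\n",
"]\n",
"total = sum(layer_params)\n",
"print('flatten size: %d, total params: %d' % (flat, total))  # flatten size: 2304, total params: 1250858"
]
},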
{
"cell_type": "markdown",
"metadata": {},
"source": [
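"Load the model architecture into a model architecture table using psycopg2. `madlib.load_keras_model` stores the JSON string produced by `model.to_json()`; the `NULL` argument means no initial weights are loaded."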
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>name</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>CNN from Keras docs for CIFAR-10</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'CNN from Keras docs for CIFAR-10')]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import psycopg2 as p2\n",
"#conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
"conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
"cur = conn.cursor()\n",
"\n",
"%sql DROP TABLE IF EXISTS model_arch_library;\n",
"query = \"SELECT madlib.load_keras_model('model_arch_library', %s, NULL, %s)\"\n",
"cur.execute(query,[model.to_json(), \"CNN from Keras docs for CIFAR-10\"])\n",
"conn.commit()\n",
"\n",
"# check model loaded OK\n",
"%sql SELECT model_id, name FROM model_arch_library;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"train\"></a>\n",
"# 5. Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n"
]
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar_10_model, cifar_10_model_summary;\n",
"\n",
"SELECT madlib.madlib_keras_fit('cifar_10_train_data_packed', -- source table\n",
" 'cifar_10_model', -- model output table\n",
" 'model_arch_library', -- model arch table\n",
" 1, -- model arch id\n",
" $$ loss='categorical_crossentropy', optimizer='rmsprop(lr=0.0001, decay=1e-6)', metrics=['accuracy']$$, -- compile_params\n",
" $$ batch_size=32, epochs=3 $$, -- fit_params\n",
" 3, -- num_iterations\n",
" 0, -- GPUs per host\n",
" 'cifar_10_test_data_packed', -- validation dataset\n",
" 2 -- metrics compute frequency\n",
" );"
]
},
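{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough reading of these training parameters, as plain arithmetic. It assumes (a simplification, not a description of MADlib internals) that in every iteration each 1000-image buffer is fit with `fit_params`, so Keras takes ceil(1000 / 32) = 32 gradient steps per epoch on a buffer, and that metrics are computed every `metrics_compute_frequency` iterations. It uses `num_iterations=20`, the value recorded in the model summary below, rather than the 3 used in the call above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Values from the preprocessing and fit calls; num_iterations is taken from the model summary below\n",
"num_images = 50000\n",
"buffer_size = 1000\n",
"batch_size = 32\n",
"epochs_per_iteration = 3\n",
"num_iterations = 20\n",
"metrics_compute_frequency = 2\n",
"\n",
"num_buffers = int(math.ceil(num_images / float(buffer_size)))\n",
"steps_per_buffer_epoch = int(math.ceil(buffer_size / float(batch_size)))\n",
"# Assumption: each buffer is fit for epochs_per_iteration epochs in every iteration\n",
"passes_per_buffer = epochs_per_iteration * num_iterations\n",
"# Assumption: metrics every metrics_compute_frequency iterations (matches metrics_iters in the summary)\n",
"metric_iters = list(range(metrics_compute_frequency, num_iterations + 1, metrics_compute_frequency))\n",
"\n",
"print('buffers: %d, steps per buffer per epoch: %d' % (num_buffers, steps_per_buffer_epoch))\n",
"print('passes over each buffer: %d, metric iterations: %s' % (passes_per_buffer, metric_iters))"
]
},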
{
"cell_type": "markdown",
"metadata": {},
"source": [
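"View the model summary. Note that the output shown below comes from an earlier run with `num_iterations=20` (see the `num_iterations` and `metrics_iters` columns), not from the 3-iteration call above:"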
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>model</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>model_arch_table</th>\n",
" <th>model_arch_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>num_iterations</th>\n",
" <th>validation_table</th>\n",
" <th>metrics_compute_frequency</th>\n",
" <th>name</th>\n",
" <th>description</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>madlib_version</th>\n",
" <th>num_classes</th>\n",
" <th>class_values</th>\n",
" <th>dependent_vartype</th>\n",
" <th>normalizing_const</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>metrics_iters</th>\n",
" </tr>\n",
" <tr>\n",
" <td>cifar_10_train_data_packed</td>\n",
" <td>cifar_10_model</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>model_arch_library</td>\n",
" <td>1</td>\n",
" <td> loss='categorical_crossentropy', optimizer='rmsprop(lr=0.0001, decay=1e-6)', metrics=['accuracy']</td>\n",
" <td> batch_size=32, epochs=3 </td>\n",
" <td>20</td>\n",
" <td>cifar_10_test_data_packed</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>madlib_keras</td>\n",
" <td>4886.20019531</td>\n",
" <td>2019-06-25 05:40:29.287703</td>\n",
" <td>2019-06-25 07:59:52.961506</td>\n",
" <td>[798.95044708252, 1616.68976902962, 2447.13853096962, 3273.68762302399, 4116.44566893578, 4962.07483291626, 5805.66080999374, 6665.33687210083, 7526.0603749752, 8363.67366909981]</td>\n",
" <td>1.16-dev</td>\n",
" <td>10</td>\n",
" <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
" <td>text</td>\n",
" <td>255.0</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.836480021477</td>\n",
" <td>0.500134825706</td>\n",
" <td>[0.579240024089813, 0.672980010509491, 0.723999977111816, 0.75764000415802, 0.783959984779358, 0.79475998878479, 0.811240017414093, 0.822780013084412, 0.829559981822968, 0.836480021476746]</td>\n",
" <td>[1.19081699848175, 0.940543830394745, 0.800645172595978, 0.700933694839478, 0.636690974235535, 0.599389910697937, 0.556614756584167, 0.53840559720993, 0.517430067062378, 0.500134825706482]</td>\n",
" <td>0.778900027275</td>\n",
" <td>0.661625564098</td>\n",
" <td>[0.57150000333786, 0.653800010681152, 0.692200005054474, 0.721300005912781, 0.740000009536743, 0.751299977302551, 0.756099998950958, 0.769999980926514, 0.77240002155304, 0.778900027275085]</td>\n",
" <td>[1.20945084095001, 0.987037718296051, 0.871006071567535, 0.800125658512115, 0.751632690429688, 0.72808450460434, 0.704570233821869, 0.684175074100494, 0.675221920013428, 0.661625564098358]</td>\n",
" <td>[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'cifar_10_train_data_packed', u'cifar_10_model', u'y', u'x', u'model_arch_library', 1, u\" loss='categorical_crossentropy', optimizer='rmsprop(lr=0.0001, decay=1e-6)', metrics=['accuracy']\", u' batch_size=32, epochs=3 ', 20, u'cifar_10_test_data_packed', 2, None, None, u'madlib_keras', 4886.20019531, datetime.datetime(2019, 6, 25, 5, 40, 29, 287703), datetime.datetime(2019, 6, 25, 7, 59, 52, 961506), [798.95044708252, 1616.68976902962, 2447.13853096962, 3273.68762302399, 4116.44566893578, 4962.07483291626, 5805.66080999374, 6665.33687210083, 7526.0603749752, 8363.67366909981], u'1.16-dev', 10, [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], u'text', 255.0, [u'accuracy'], 0.836480021477, 0.500134825706, [0.579240024089813, 0.672980010509491, 0.723999977111816, 0.75764000415802, 0.783959984779358, 0.79475998878479, 0.811240017414093, 0.822780013084412, 0.829559981822968, 0.836480021476746], [1.19081699848175, 0.940543830394745, 0.800645172595978, 0.700933694839478, 0.636690974235535, 0.599389910697937, 0.556614756584167, 0.53840559720993, 0.517430067062378, 0.500134825706482], 0.778900027275, 0.661625564098, [0.57150000333786, 0.653800010681152, 0.692200005054474, 0.721300005912781, 0.740000009536743, 0.751299977302551, 0.756099998950958, 0.769999980926514, 0.77240002155304, 0.778900027275085], [1.20945084095001, 0.987037718296051, 0.871006071567535, 0.800125658512115, 0.751632690429688, 0.72808450460434, 0.704570233821869, 0.684175074100494, 0.675221920013428, 0.661625564098358], [2, 4, 6, 8, 10, 12, 14, 16, 18, 20])]"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM cifar_10_model_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluate the trained model on the test data. The resulting loss and accuracy should match the last-iteration validation values in the model summary above."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>loss</th>\n",
" <th>metric</th>\n",
" <th>metrics_type</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.661625564098</td>\n",
" <td>0.778900027275</td>\n",
" <td>[u'accuracy']</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.661625564098358, 0.778900027275085, [u'accuracy'])]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar10_validate;\n",
"\n",
"SELECT madlib.madlib_keras_evaluate('cifar_10_model', -- model\n",
" 'cifar_10_test_data_packed', -- test table\n",
" 'cifar10_validate' -- output table\n",
" );\n",
"\n",
"SELECT * FROM cifar10_validate;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"plot\"></a>\n",
"# 6. Plots by iteration and by time\n",
"Accuracy by iteration"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x1195ef9d0>"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# fetch per-iteration accuracy and the iteration numbers at which metrics were computed\n",
"iters_proxy = %sql SELECT metrics_iters FROM cifar_10_model_summary;\n",
"train_accuracy_proxy = %sql SELECT training_metrics FROM cifar_10_model_summary;\n",
"test_accuracy_proxy = %sql SELECT validation_metrics FROM cifar_10_model_summary;\n",
"\n",
"# number of points at which metrics were recorded (first column of the single result row)\n",
"num_points_proxy = %sql SELECT array_length(metrics_iters,1) FROM cifar_10_model_summary;\n",
"num_points = num_points_proxy[0][0]\n",
"\n",
"# flatten the single-row query results into 1-D NumPy arrays\n",
"iters = np.array(iters_proxy).reshape(num_points)\n",
"train_accuracy = np.array(train_accuracy_proxy).reshape(num_points)\n",
"test_accuracy = np.array(test_accuracy_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('CIFAR-10 accuracy by iteration')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(iters, test_accuracy, 'r.-', label='Test')\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loss by iteration"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x119279910>"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# fetch per-iteration training and validation loss\n",
"train_loss_proxy = %sql SELECT training_loss FROM cifar_10_model_summary;\n",
"test_loss_proxy = %sql SELECT validation_loss FROM cifar_10_model_summary;\n",
"\n",
"# flatten the single-row query results into 1-D NumPy arrays\n",
"train_loss = np.array(train_loss_proxy).reshape(num_points)\n",
"test_loss = np.array(test_loss_proxy).reshape(num_points)\n",
"\n",
"# plot\n",
"plt.title('CIFAR-10 loss by iteration')\n",
"plt.xlabel('Iteration number')\n",
"plt.ylabel('Loss')\n",
"plt.grid(True)\n",
"plt.plot(iters, train_loss, 'g.-', label='Train')\n",
"plt.plot(iters, test_loss, 'r.-', label='Test')\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accuracy by time"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x119664410>"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
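"# fetch cumulative elapsed training time (in seconds) at each metrics point\n",
"time_proxy = %sql SELECT metrics_elapsed_time FROM cifar_10_model_summary;\n",
"\n",
"# flatten into a 1-D NumPy array and convert seconds to minutes\n",
"time = np.array(time_proxy).reshape(num_points)/60.0\n",
"\n",
"# plot\n",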
"plt.title('CIFAR-10 accuracy by time')\n",
"plt.xlabel('Time (min)')\n",
"plt.ylabel('Accuracy')\n",
"plt.grid(True)\n",
"plt.plot(time, train_accuracy, 'g.-', label='Train')\n",
"plt.plot(time, test_accuracy, 'r.-', label='Test')\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to achieve a given accuracy"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x119628690>"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# plot time (min) against accuracy for the training and test sets\n",
"plt.title('CIFAR-10 time by accuracy')\n",
"plt.xlabel('Accuracy')\n",
"plt.ylabel('Time (min)')\n",
"plt.grid(True)\n",
"plt.plot(train_accuracy, time, 'g.-', label='Train')\n",
"plt.plot(test_accuracy, time, 'r.-', label='Test')\n",
"plt.legend()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}