{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperband diagonal using CIFAR-10\n",
"\n",
"Implementation of Hyperband (https://arxiv.org/pdf/1603.06560.pdf) for MPP databases, with a synchronous barrier. It uses the Hyperband schedule but runs it on a diagonal across brackets, rather than one bracket at a time, to make more efficient use of cluster resources.\n",
"\n",
"The CIFAR-10 dataset consists of 60,000 32x32 colour images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.\n",
"https://www.cs.toronto.edu/~kriz/cifar.html\n",
"\n",
"\n",
"## Table of contents \n",
"\n",
"<a href=\"#setup\">0. Setup</a>\n",
"\n",
"<a href=\"#load_dataset\">1. Load dataset into table</a>\n",
"\n",
"<a href=\"#distr\">2. Set up distribution rules and call preprocessor</a>\n",
"\n",
"<a href=\"#arch\">3. Define and load model architectures</a>\n",
"\n",
"<a href=\"#hyperband\">4. Hyperband diagonal</a>\n",
"\n",
"<a href=\"#plot\">5. Plot results</a>\n",
"\n",
"<a href=\"#print\">6. Pretty print schedules</a>\n",
"\n",
"<a href=\"#predict\">7. Inference</a>"
]
},
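{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the next cell is a minimal sketch of the standard Hyperband bracket schedule from the paper: for each bracket `s`, how many configurations start and how much resource (for example, epochs) each one receives in every round of successive halving. The values `R = 81` and `eta = 3` are illustrative assumptions taken from the paper's worked example, not the settings used later in this notebook, and `hyperband_schedule` is a hypothetical helper rather than a MADlib function. The diagonal variant implemented in this notebook reorders these rounds across brackets; that reordering is not reproduced here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the standard Hyperband bracket schedule (Li et al.).\n",
"# R and ETA are illustrative values, not this notebook's settings.\n",
"R = 81    # max resource (e.g., epochs) any single configuration can receive\n",
"ETA = 3   # downsampling rate: keep the top 1/ETA of configurations each round\n",
"\n",
"def hyperband_schedule(R, eta):\n",
"    # s_max = floor(log_eta(R)), computed with integers to avoid float rounding\n",
"    s_max = 0\n",
"    while eta ** (s_max + 1) <= R:\n",
"        s_max += 1\n",
"    schedule = []\n",
"    for s in range(s_max, -1, -1):\n",
"        # n = ceil((s_max + 1) * eta^s / (s + 1)) configurations start bracket s\n",
"        num = (s_max + 1) * eta ** s\n",
"        n = (num + s) // (s + 1)\n",
"        rounds = []\n",
"        for i in range(s + 1):\n",
"            n_i = n // eta ** i                # configurations kept in round i\n",
"            r_i = float(R) / eta ** (s - i)    # resource per configuration in round i\n",
"            rounds.append((n_i, r_i))\n",
"        schedule.append((s, rounds))\n",
"    return schedule\n",
"\n",
"for s, rounds in hyperband_schedule(R, ETA):\n",
"    print('bracket s=%d: %s' % (s, rounds))"
]
},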
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"setup\"></a>\n",
"# 0. Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: fmcquillan@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.x on GCP - via tunnel\n",
"#%sql postgresql://gpadmin@localhost:8000/madlib\n",
"#%sql postgresql://gpadmin@35.230.53.21:5432/cifar_demo\n",
"\n",
"# PostgreSQL local\n",
"%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://fmcquillan@localhost:5432/madlib\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.16, git revision: rc/1.16-rc1, cmake configuration time: Mon Jul 1 17:45:09 UTC 2019, build type: Release, build system: Darwin-16.7.0, C compiler: Clang, C++ compiler: Clang</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.16, git revision: rc/1.16-rc1, cmake configuration time: Mon Jul 1 17:45:09 UTC 2019, build type: Release, build system: Darwin-16.7.0, C compiler: Clang, C++ compiler: Clang',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql select madlib.version();\n",
"#%sql select version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import Keras and related libraries"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from __future__ import print_function\n",
"import keras\n",
"from keras.datasets import cifar10\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other libraries needed in this notebook"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import sys\n",
"import os\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"load_dataset\"></a>\n",
"# 1. Load dataset into table"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Image data can also be loaded with PXF (Greenplum Platform Extension Framework).\n",
"\n",
"For this demo, we get the dataset from Keras and load it into the database with the script madlib_image_loader.py, located at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning\n",
"\n",
"If the script is not in the same folder as the notebook, add its directory to the Python path as shown in the next cell."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"# Change this path to the directory that contains madlib_image_loader.py\n",
"sys.path.insert(1, '/Users/fmcquillan/workspace/madlib-site/community-artifacts/Deep-learning')\n",
"from madlib_image_loader import ImageLoader, DbCredentials\n",
"\n",
"# Specify database credentials, for connecting to db\n",
"#db_creds = DbCredentials(user='gpadmin',\n",
"# host='localhost',\n",
"# port='8000',\n",
"# password='')\n",
"\n",
"db_creds = DbCredentials(user='fmcquillan',\n",
" host='localhost',\n",
" port='5432',\n",
" password='')\n",
"\n",
"# Initialize ImageLoader (increase num_workers to run faster)\n",
"iloader = ImageLoader(num_workers=5, db_creds=db_creds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the training and test data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * postgresql://fmcquillan@localhost:5432/madlib\n",
"Done.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE cifar10_train (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table cifar10_train in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-1 [pid 10828]\n",
"Initializing PoolWorker-2 [pid 10829]\n",
"PoolWorker-1: Created temporary directory /tmp/madlib_DaP40IOgzi\n",
"Initializing PoolWorker-3 [pid 10830]\n",
"PoolWorker-2: Created temporary directory /tmp/madlib_n5XjJvXs5s\n",
"PoolWorker-3: Created temporary directory /tmp/madlib_99mTsCxOFF\n",
"Initializing PoolWorker-4 [pid 10831]\n",
"PoolWorker-5: Connected to madlib db.\n",
"PoolWorker-4: Created temporary directory /tmp/madlib_zGujxaoQIb\n",
"Initializing PoolWorker-5 [pid 10832]\n",
"PoolWorker-1: Connected to madlib db.\n",
"PoolWorker-5: Created temporary directory /tmp/madlib_D6q8olnown\n",
"PoolWorker-2: Connected to madlib db.\n",
"PoolWorker-3: Connected to madlib db.\n",
"PoolWorker-4: Connected to madlib db.\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0000.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0000.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0000.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0000.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0000.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0001.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0001.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0001.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0001.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0001.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0002.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0002.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0002.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0002.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0002.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0003.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0003.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0003.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0003.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0003.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0004.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0004.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0004.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0004.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0004.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0005.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0005.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0005.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0005.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0005.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0006.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0006.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0006.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0006.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0006.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0007.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0007.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0007.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0007.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0007.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0008.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0008.tmp\n",
"PoolWorker-2: Wrote 1000 images to /tmp/madlib_n5XjJvXs5s/cifar10_train0008.tmp\n",
"PoolWorker-4: Wrote 1000 images to /tmp/madlib_zGujxaoQIb/cifar10_train0008.tmp\n",
"PoolWorker-5: Wrote 1000 images to /tmp/madlib_D6q8olnown/cifar10_train0008.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-2: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Loaded 1000 images into cifar10_train\n",
"PoolWorker-5: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0009.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0009.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0010.tmp\n",
"PoolWorker-3: Wrote 1000 images to /tmp/madlib_99mTsCxOFF/cifar10_train0010.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-3: Loaded 1000 images into cifar10_train\n",
"PoolWorker-1: Wrote 1000 images to /tmp/madlib_DaP40IOgzi/cifar10_train0011.tmp\n",
"PoolWorker-1: Loaded 1000 images into cifar10_train\n",
"PoolWorker-4: Removed temporary directory /tmp/madlib_zGujxaoQIb\n",
"PoolWorker-5: Removed temporary directory /tmp/madlib_D6q8olnown\n",
"PoolWorker-2: Removed temporary directory /tmp/madlib_n5XjJvXs5s\n",
"PoolWorker-1: Removed temporary directory /tmp/madlib_DaP40IOgzi\n",
"PoolWorker-3: Removed temporary directory /tmp/madlib_99mTsCxOFF\n",
"Done! Loaded 50000 images in 19.7727279663s\n",
"5 workers terminated.\n",
"MainProcess: Connected to madlib db.\n",
"Executing: CREATE TABLE cifar10_val (id SERIAL, x REAL[], y TEXT)\n",
"CREATE TABLE\n",
"Created table cifar10_val in madlib db\n",
"Spawning 5 workers...\n",
"Initializing PoolWorker-6 [pid 10850]\n",
"PoolWorker-6: Created temporary directory /tmp/madlib_OqFarH4eVS\n",
"Initializing PoolWorker-7 [pid 10851]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"PoolWorker-7: Created temporary directory /tmp/madlib_BHhah9z53T\n",
"Initializing PoolWorker-8 [pid 10852]\n",
"PoolWorker-8: Created temporary directory /tmp/madlib_G5oLCmXwQN\n",
"Initializing PoolWorker-9 [pid 10853]\n",
"PoolWorker-6: Connected to madlib db.\n",
"PoolWorker-9: Created temporary directory /tmp/madlib_THDiiymnsM\n",
"Initializing PoolWorker-10 [pid 10854]\n",
"PoolWorker-7: Connected to madlib db.\n",
"PoolWorker-10: Created temporary directory /tmp/madlib_DLO1TEiyo6\n",
"PoolWorker-8: Connected to madlib db.\n",
"PoolWorker-9: Connected to madlib db.\n",
"PoolWorker-10: Connected to madlib db.\n",
"PoolWorker-6: Wrote 1000 images to /tmp/madlib_OqFarH4eVS/cifar10_val0000.tmp\n",
"PoolWorker-7: Wrote 1000 images to /tmp/madlib_BHhah9z53T/cifar10_val0000.tmp\n",
"PoolWorker-8: Wrote 1000 images to /tmp/madlib_G5oLCmXwQN/cifar10_val0000.tmp\n",
"PoolWorker-9: Wrote 1000 images to /tmp/madlib_THDiiymnsM/cifar10_val0000.tmp\n",
"PoolWorker-10: Wrote 1000 images to /tmp/madlib_DLO1TEiyo6/cifar10_val0000.tmp\n",
"PoolWorker-6: Loaded 1000 images into cifar10_val\n",
"PoolWorker-7: Loaded 1000 images into cifar10_val\n",
"PoolWorker-8: Loaded 1000 images into cifar10_val\n",
"PoolWorker-9: Loaded 1000 images into cifar10_val\n",
"PoolWorker-10: Loaded 1000 images into cifar10_val\n",
"PoolWorker-6: Wrote 1000 images to /tmp/madlib_OqFarH4eVS/cifar10_val0001.tmp\n",
"PoolWorker-7: Wrote 1000 images to /tmp/madlib_BHhah9z53T/cifar10_val0001.tmp\n",
"PoolWorker-8: Wrote 1000 images to /tmp/madlib_G5oLCmXwQN/cifar10_val0001.tmp\n",
"PoolWorker-9: Wrote 1000 images to /tmp/madlib_THDiiymnsM/cifar10_val0001.tmp\n",
"PoolWorker-10: Wrote 1000 images to /tmp/madlib_DLO1TEiyo6/cifar10_val0001.tmp\n",
"PoolWorker-6: Loaded 1000 images into cifar10_val\n",
"PoolWorker-7: Loaded 1000 images into cifar10_val\n",
"PoolWorker-8: Loaded 1000 images into cifar10_val\n",
"PoolWorker-9: Loaded 1000 images into cifar10_val\n",
"PoolWorker-10: Loaded 1000 images into cifar10_val\n",
"PoolWorker-8: Removed temporary directory /tmp/madlib_G5oLCmXwQN\n",
"PoolWorker-7: Removed temporary directory /tmp/madlib_BHhah9z53T\n",
"PoolWorker-10: Removed temporary directory /tmp/madlib_DLO1TEiyo6\n",
"PoolWorker-6: Removed temporary directory /tmp/madlib_OqFarH4eVS\n",
"PoolWorker-9: Removed temporary directory /tmp/madlib_THDiiymnsM\n",
"Done! Loaded 10000 images in 4.03977298737s\n",
"5 workers terminated.\n"
]
}
],
"source": [
"# Load dataset into np array\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
"\n",
"%sql DROP TABLE IF EXISTS cifar10_train, cifar10_val;\n",
"\n",
"# Save images to temporary directories and load into database\n",
"iloader.load_dataset_from_np(x_train, y_train, 'cifar10_train', append=False)\n",
"iloader.load_dataset_from_np(x_test, y_test, 'cifar10_val', append=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"%sql SELECT COUNT(*) FROM cifar10_train;"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>10000</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(10000L,)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql SELECT COUNT(*) FROM cifar10_val;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"distr\"></a>\n",
"# 2. Set up distribution rules and call preprocessor\n",
"\n",
"Get cluster configuration\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"20 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>hostname</th>\n",
" <th>gpu_descr</th>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix0</td>\n",
" <td>device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix0</td>\n",
" <td>device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix0</td>\n",
" <td>device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix0</td>\n",
" <td>device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix1</td>\n",
" <td>device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix1</td>\n",
" <td>device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix1</td>\n",
" <td>device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix1</td>\n",
" <td>device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix2</td>\n",
" <td>device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix2</td>\n",
" <td>device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix2</td>\n",
" <td>device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix2</td>\n",
" <td>device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix3</td>\n",
" <td>device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix3</td>\n",
" <td>device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix3</td>\n",
" <td>device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix3</td>\n",
" <td>device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix4</td>\n",
" <td>device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix4</td>\n",
" <td>device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix4</td>\n",
" <td>device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>phoenix4</td>\n",
" <td>device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'phoenix0', u'device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0'),\n",
" (u'phoenix0', u'device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0'),\n",
" (u'phoenix0', u'device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0'),\n",
" (u'phoenix0', u'device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0'),\n",
" (u'phoenix1', u'device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0'),\n",
" (u'phoenix1', u'device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0'),\n",
" (u'phoenix1', u'device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0'),\n",
" (u'phoenix1', u'device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0'),\n",
" (u'phoenix2', u'device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0'),\n",
" (u'phoenix2', u'device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0'),\n",
" (u'phoenix2', u'device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0'),\n",
" (u'phoenix2', u'device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0'),\n",
" (u'phoenix3', u'device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0'),\n",
" (u'phoenix3', u'device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0'),\n",
" (u'phoenix3', u'device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0'),\n",
" (u'phoenix3', u'device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0'),\n",
" (u'phoenix4', u'device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0'),\n",
" (u'phoenix4', u'device: 1, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:05.0, compute capability: 6.0'),\n",
" (u'phoenix4', u'device: 2, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:06.0, compute capability: 6.0'),\n",
" (u'phoenix4', u'device: 3, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:07.0, compute capability: 6.0')]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS host_gpu_mapping_tf;\n",
"SELECT * FROM madlib.gpu_configuration('host_gpu_mapping_tf');\n",
"SELECT * FROM host_gpu_mapping_tf ORDER BY hostname, gpu_descr;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below are examples of distribution rules tables for different subsets of the cluster. Customize them to match your own cluster.\n",
"\n",
"Build distribution rules table for 4 VMs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS segments_to_use_4VMs;\n",
"CREATE TABLE segments_to_use_4VMs AS\n",
" SELECT DISTINCT dbid, hostname FROM gp_segment_configuration JOIN host_gpu_mapping_tf USING (hostname)\n",
" WHERE role='p' AND content>=0 AND hostname!='phoenix4';\n",
"SELECT * FROM segments_to_use_4VMs ORDER BY hostname, dbid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build distribution rules table for 2 VMs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS segments_to_use_2VMs;\n",
"CREATE TABLE segments_to_use_2VMs AS\n",
" SELECT DISTINCT dbid, hostname FROM gp_segment_configuration JOIN host_gpu_mapping_tf USING (hostname)\n",
" WHERE role='p' AND content>=0 AND (hostname='phoenix0' OR hostname='phoenix1');\n",
"SELECT * FROM segments_to_use_2VMs ORDER BY hostname, dbid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build distribution rules table for 1 VM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS segments_to_use_1VM;\n",
"CREATE TABLE segments_to_use_1VM AS\n",
" SELECT DISTINCT dbid, hostname FROM gp_segment_configuration JOIN host_gpu_mapping_tf USING (hostname)\n",
" WHERE role='p' AND content>=0 AND hostname='phoenix0';\n",
"SELECT * FROM segments_to_use_1VM ORDER BY hostname, dbid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build distribution rules table for 1 segment"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>dbid</th>\n",
" <th>content</th>\n",
" <th>role</th>\n",
" <th>preferred_role</th>\n",
" <th>mode</th>\n",
" <th>status</th>\n",
" <th>port</th>\n",
" <th>hostname</th>\n",
" <th>address</th>\n",
" <th>replication_port</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>p</td>\n",
" <td>p</td>\n",
" <td>s</td>\n",
" <td>u</td>\n",
" <td>5432</td>\n",
" <td>phoenix0</td>\n",
" <td>phoenix0</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>p</td>\n",
" <td>p</td>\n",
" <td>c</td>\n",
" <td>u</td>\n",
" <td>40000</td>\n",
" <td>phoenix0</td>\n",
" <td>phoenix0</td>\n",
" <td>70000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>p</td>\n",
" <td>p</td>\n",
" <td>c</td>\n",
" <td>u</td>\n",
" <td>40001</td>\n",
" <td>phoenix0</td>\n",
" <td>phoenix0</td>\n",
" <td>70001</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>p</td>\n",
" <td>p</td>\n",
" <td>c</td>\n",
" <td>u</td>\n",
" <td>40002</td>\n",
" <td>phoenix0</td>\n",
" <td>phoenix0</td>\n",
" <td>70002</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>p</td>\n",
" <td>p</td>\n",
" <td>c</td>\n",
" <td>u</td>\n",
" <td>40003</td>\n",
" <td>phoenix0</td>\n",
" <td>phoenix0</td>\n",
" <td>70003</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, -1, u'p', u'p', u's', u'u', 5432, u'phoenix0', u'phoenix0', None),\n",
" (2, 0, u'p', u'p', u'c', u'u', 40000, u'phoenix0', u'phoenix0', 70000),\n",
" (3, 1, u'p', u'p', u'c', u'u', 40001, u'phoenix0', u'phoenix0', 70001),\n",
" (4, 2, u'p', u'p', u'c', u'u', 40002, u'phoenix0', u'phoenix0', 70002),\n",
" (5, 3, u'p', u'p', u'c', u'u', 40003, u'phoenix0', u'phoenix0', 70003)]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM gp_segment_configuration WHERE role='p' AND hostname='phoenix0' ORDER BY dbid;"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>dbid</th>\n",
" <th>hostname</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>phoenix0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2, u'phoenix0')]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS segments_to_use_1seg;\n",
"CREATE TABLE segments_to_use_1seg AS\n",
" SELECT DISTINCT dbid, hostname FROM gp_segment_configuration JOIN host_gpu_mapping_tf USING (hostname)\n",
" WHERE dbid=2;\n",
"SELECT * FROM segments_to_use_1seg ORDER BY hostname, dbid;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training dataset (uses training preprocessor):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"16 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>independent_var_shape</th>\n",
" <th>dependent_var_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[3125, 32, 32, 3]</td>\n",
" <td>[3125, 10]</td>\n",
" <td>15</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([3125, 32, 32, 3], [3125, 10], 0),\n",
" ([3125, 32, 32, 3], [3125, 10], 1),\n",
" ([3125, 32, 32, 3], [3125, 10], 2),\n",
" ([3125, 32, 32, 3], [3125, 10], 3),\n",
" ([3125, 32, 32, 3], [3125, 10], 4),\n",
" ([3125, 32, 32, 3], [3125, 10], 5),\n",
" ([3125, 32, 32, 3], [3125, 10], 6),\n",
" ([3125, 32, 32, 3], [3125, 10], 7),\n",
" ([3125, 32, 32, 3], [3125, 10], 8),\n",
" ([3125, 32, 32, 3], [3125, 10], 9),\n",
" ([3125, 32, 32, 3], [3125, 10], 10),\n",
" ([3125, 32, 32, 3], [3125, 10], 11),\n",
" ([3125, 32, 32, 3], [3125, 10], 12),\n",
" ([3125, 32, 32, 3], [3125, 10], 13),\n",
" ([3125, 32, 32, 3], [3125, 10], 14),\n",
" ([3125, 32, 32, 3], [3125, 10], 15)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar10_train_packed, cifar10_train_packed_summary;\n",
"\n",
"SELECT madlib.training_preprocessor_dl('cifar10_train', -- Source table\n",
" 'cifar10_train_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" NULL, -- Buffer size (NULL: computed by MADlib)\n",
" 256.0, -- Normalizing constant\n",
" NULL, -- Number of classes\n",
" 'gpu_segments' -- Distribution rules\n",
" );\n",
"\n",
"SELECT independent_var_shape, dependent_var_shape, buffer_id FROM cifar10_train_packed ORDER BY buffer_id;"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>cifar10_train</td>\n",
" <td>cifar10_train_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>smallint</td>\n",
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]</td>\n",
" <td>3125</td>\n",
" <td>256.0</td>\n",
" <td>10</td>\n",
" <td>[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]</td>\n",
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'cifar10_train', u'cifar10_train_packed', u'y', u'x', u'smallint', [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 3125, 256.0, 10, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM cifar10_train_packed_summary;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
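"Validation dataset (uses the validation preprocessor, which takes the class values and normalizing constant from the training preprocessor output `cifar10_train_packed`):"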
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"16 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>independent_var_shape</th>\n",
" <th>dependent_var_shape</th>\n",
" <th>buffer_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>[625, 32, 32, 3]</td>\n",
" <td>[625, 10]</td>\n",
" <td>15</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([625, 32, 32, 3], [625, 10], 0),\n",
" ([625, 32, 32, 3], [625, 10], 1),\n",
" ([625, 32, 32, 3], [625, 10], 2),\n",
" ([625, 32, 32, 3], [625, 10], 3),\n",
" ([625, 32, 32, 3], [625, 10], 4),\n",
" ([625, 32, 32, 3], [625, 10], 5),\n",
" ([625, 32, 32, 3], [625, 10], 6),\n",
" ([625, 32, 32, 3], [625, 10], 7),\n",
" ([625, 32, 32, 3], [625, 10], 8),\n",
" ([625, 32, 32, 3], [625, 10], 9),\n",
" ([625, 32, 32, 3], [625, 10], 10),\n",
" ([625, 32, 32, 3], [625, 10], 11),\n",
" ([625, 32, 32, 3], [625, 10], 12),\n",
" ([625, 32, 32, 3], [625, 10], 13),\n",
" ([625, 32, 32, 3], [625, 10], 14),\n",
" ([625, 32, 32, 3], [625, 10], 15)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar10_val_packed, cifar10_val_packed_summary;\n",
"\n",
"SELECT madlib.validation_preprocessor_dl('cifar10_val', -- Source table\n",
" 'cifar10_val_packed', -- Output table\n",
" 'y', -- Dependent variable\n",
" 'x', -- Independent variable\n",
" 'cifar10_train_packed', -- From training preprocessor step\n",
" NULL, -- Buffer size (NULL: computed by MADlib)\n",
" 'gpu_segments' -- Distribution rules\n",
" ); \n",
"\n",
"SELECT independent_var_shape, dependent_var_shape, buffer_id FROM cifar10_val_packed ORDER BY buffer_id;"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>source_table</th>\n",
" <th>output_table</th>\n",
" <th>dependent_varname</th>\n",
" <th>independent_varname</th>\n",
" <th>dependent_vartype</th>\n",
" <th>class_values</th>\n",
" <th>buffer_size</th>\n",
" <th>normalizing_const</th>\n",
" <th>num_classes</th>\n",
" <th>distribution_rules</th>\n",
" <th>__internal_gpu_config__</th>\n",
" </tr>\n",
" <tr>\n",
" <td>cifar10_val</td>\n",
" <td>cifar10_val_packed</td>\n",
" <td>y</td>\n",
" <td>x</td>\n",
" <td>smallint</td>\n",
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]</td>\n",
" <td>625</td>\n",
" <td>256.0</td>\n",
" <td>10</td>\n",
" <td>[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]</td>\n",
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'cifar10_val', u'cifar10_val_packed', u'y', u'x', u'smallint', [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 625, 256.0, 10, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT * FROM cifar10_val_packed_summary;"
]
},
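{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `buffer_size` left as NULL and the `'gpu_segments'` distribution rule, the outputs above show one buffer per GPU segment: 16 buffers of 3,125 training images (50,000 total) and 16 buffers of 625 validation images (10,000 total). The sketch below reproduces that arithmetic, assuming the images are split evenly across the 16 GPU segments listed in `__internal_gpu_config__`. `rows_per_buffer` is a hypothetical helper for illustration, not a MADlib function; MADlib computes its default buffer size internally.\n",
"\n",
"```python\n",
"def rows_per_buffer(num_rows, num_segments):\n",
"    # Hypothetical helper (not MADlib): rows per buffer if each segment holds exactly one buffer\n",
"    if num_rows % num_segments != 0:\n",
"        raise ValueError('rows do not divide evenly across segments')\n",
"    return num_rows // num_segments\n",
"\n",
"NUM_GPU_SEGMENTS = 16  # assumption: one segment per entry in __internal_gpu_config__\n",
"\n",
"train_buffer = rows_per_buffer(50000, NUM_GPU_SEGMENTS)  # 3125\n",
"val_buffer = rows_per_buffer(10000, NUM_GPU_SEGMENTS)    # 625\n",
"\n",
"# Dividing 8-bit pixel values (0..255) by the normalizing constant 256.0 keeps them in [0, 1)\n",
"max_pixel_scaled = 255 / 256.0\n",
"\n",
"print(train_buffer, val_buffer, max_pixel_scaled)\n",
"```\n",
"\n",
"The `buffer_size` column of the summary tables is the authoritative value; this calculation only explains why it came out to 3,125 and 625 on this cluster."
]
},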
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"arch\"></a>\n",
"# 3. Define and load model architectures\n",
"\n",
"Here we load some example model architectures from published sources.\n",
"\n",
"a. Model architecture from https://keras.io/examples/cifar10_cnn/"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"num_classes = 10\n",
"\n",
"# TODO: remove. CIFAR-10 is loaded here only to get the input shape\n",
"# (32, 32, 3) used by model1 below; training itself reads the packed tables.\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_1 (Conv2D) (None, 32, 32, 32) 896 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (None, 32, 32, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 30, 30, 32) 9248 \n",
"_________________________________________________________________\n",
"activation_2 (Activation) (None, 30, 30, 32) 0 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 15, 15, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 15, 15, 64) 18496 \n",
"_________________________________________________________________\n",
"activation_3 (Activation) (None, 15, 15, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 13, 13, 64) 36928 \n",
"_________________________________________________________________\n",
"activation_4 (Activation) (None, 13, 13, 64) 0 \n",
"_________________________________________________________________\n",
"max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 6, 6, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 2304) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 512) 1180160 \n",
"_________________________________________________________________\n",
"activation_5 (Activation) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 512) 0 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 5130 \n",
"_________________________________________________________________\n",
"activation_6 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 1,250,858\n",
"Trainable params: 1,250,858\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model1 = Sequential()\n",
"\n",
"model1.add(Conv2D(32, (3, 3), padding='same',\n",
" input_shape=x_train.shape[1:]))\n",
"model1.add(Activation('relu'))\n",
"model1.add(Conv2D(32, (3, 3)))\n",
"model1.add(Activation('relu'))\n",
"model1.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model1.add(Dropout(0.25))\n",
"\n",
"model1.add(Conv2D(64, (3, 3), padding='same'))\n",
"model1.add(Activation('relu'))\n",
"model1.add(Conv2D(64, (3, 3)))\n",
"model1.add(Activation('relu'))\n",
"model1.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model1.add(Dropout(0.25))\n",
"\n",
"model1.add(Flatten())\n",
"model1.add(Dense(512))\n",
"model1.add(Activation('relu'))\n",
"model1.add(Dropout(0.5))\n",
"model1.add(Dense(num_classes))\n",
"model1.add(Activation('softmax'))\n",
"\n",
"model1.summary()"
]
},
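{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the summary above follow from standard layer formulas: a `Conv2D` layer has `(kernel_h * kernel_w * in_channels + 1) * filters` parameters (the `+ 1` is the bias) and a `Dense` layer has `(inputs + 1) * units`. A small sketch that re-derives the total of 1,250,858 for `model1`; the helpers `conv2d_params` and `dense_params` are illustrative, not Keras functions:\n",
"\n",
"```python\n",
"def conv2d_params(kernel, in_ch, filters):\n",
"    # Illustrative helper: square kernel weights plus one bias per filter\n",
"    return (kernel * kernel * in_ch + 1) * filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # Illustrative helper: weight matrix plus one bias per unit\n",
"    return (inputs + 1) * units\n",
"\n",
"# Spatial size through 'same' conv, 'valid' conv, 2x2 pool, 'same', 'valid', pool: 32 -> 32 -> 30 -> 15 -> 15 -> 13 -> 6\n",
"flat = 6 * 6 * 64  # 2304 inputs to dense_1\n",
"\n",
"layer_params = [\n",
"    conv2d_params(3, 3, 32),   # conv2d_1: 896\n",
"    conv2d_params(3, 32, 32),  # conv2d_2: 9248\n",
"    conv2d_params(3, 32, 64),  # conv2d_3: 18496\n",
"    conv2d_params(3, 64, 64),  # conv2d_4: 36928\n",
"    dense_params(flat, 512),   # dense_1: 1180160\n",
"    dense_params(512, 10),     # dense_2: 5130\n",
"]\n",
"print(sum(layer_params))  # 1250858\n",
"```\n",
"\n",
"Activation, pooling, dropout and flatten layers have no weights, which is why they show 0 in the summary."
]
},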
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Conv2D\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"conv2d_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"linear\", \"trainable\": true, \"data_format\": \"channels_last\", \"filters\": 32, \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"batch_input_shape\": [null, 32, 32, 3], \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"relu\", \"trainable\": true, \"name\": \"activation_1\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"conv2d_2\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"linear\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"valid\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 32, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"relu\", \"trainable\": true, \"name\": \"activation_2\"}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_1\", \"trainable\": true, \"data_format\": \"channels_last\", \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2]}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.25, \"noise_shape\": 
null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_1\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"conv2d_3\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"linear\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 64, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"relu\", \"trainable\": true, \"name\": \"activation_3\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"conv2d_4\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"linear\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"valid\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 64, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"relu\", \"trainable\": true, \"name\": \"activation_4\"}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_2\", \"trainable\": true, \"data_format\": \"channels_last\", \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2]}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.25, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_2\"}}, 
{\"class_name\": \"Flatten\", \"config\": {\"trainable\": true, \"name\": \"flatten_1\", \"data_format\": \"channels_last\"}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_1\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"linear\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 512, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"relu\", \"trainable\": true, \"name\": \"activation_5\"}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.5, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_3\"}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_2\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"linear\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"Activation\", \"config\": {\"activation\": \"softmax\", \"trainable\": true, \"name\": \"activation_6\"}}], \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model1.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"b. Model architecture from https://machinelearningmastery.com/how-to-develop-a-cnn-from-scratch-for-cifar-10-photo-classification/"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_5 (Conv2D) (None, 32, 32, 32) 896 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 32, 32, 32) 128 \n",
"_________________________________________________________________\n",
"conv2d_6 (Conv2D) (None, 32, 32, 32) 9248 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 32, 32, 32) 128 \n",
"_________________________________________________________________\n",
"max_pooling2d_3 (MaxPooling2 (None, 16, 16, 32) 0 \n",
"_________________________________________________________________\n",
"dropout_4 (Dropout) (None, 16, 16, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_7 (Conv2D) (None, 16, 16, 64) 18496 \n",
"_________________________________________________________________\n",
"batch_normalization_3 (Batch (None, 16, 16, 64) 256 \n",
"_________________________________________________________________\n",
"conv2d_8 (Conv2D) (None, 16, 16, 64) 36928 \n",
"_________________________________________________________________\n",
"batch_normalization_4 (Batch (None, 16, 16, 64) 256 \n",
"_________________________________________________________________\n",
"max_pooling2d_4 (MaxPooling2 (None, 8, 8, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_5 (Dropout) (None, 8, 8, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_9 (Conv2D) (None, 8, 8, 128) 73856 \n",
"_________________________________________________________________\n",
"batch_normalization_5 (Batch (None, 8, 8, 128) 512 \n",
"_________________________________________________________________\n",
"conv2d_10 (Conv2D) (None, 8, 8, 128) 147584 \n",
"_________________________________________________________________\n",
"batch_normalization_6 (Batch (None, 8, 8, 128) 512 \n",
"_________________________________________________________________\n",
"max_pooling2d_5 (MaxPooling2 (None, 4, 4, 128) 0 \n",
"_________________________________________________________________\n",
"dropout_6 (Dropout) (None, 4, 4, 128) 0 \n",
"_________________________________________________________________\n",
"flatten_2 (Flatten) (None, 2048) 0 \n",
"_________________________________________________________________\n",
"dense_3 (Dense) (None, 128) 262272 \n",
"_________________________________________________________________\n",
"batch_normalization_7 (Batch (None, 128) 512 \n",
"_________________________________________________________________\n",
"dropout_7 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dense_4 (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 552,874\n",
"Trainable params: 551,722\n",
"Non-trainable params: 1,152\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model2 = Sequential()\n",
"\n",
"model2.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))\n",
"model2.add(BatchNormalization())\n",
"model2.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(MaxPooling2D((2, 2)))\n",
"model2.add(Dropout(0.2))\n",
"\n",
"model2.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(MaxPooling2D((2, 2)))\n",
"model2.add(Dropout(0.3))\n",
"\n",
"model2.add(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(MaxPooling2D((2, 2)))\n",
"model2.add(Dropout(0.4))\n",
"\n",
"model2.add(Flatten())\n",
"model2.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))\n",
"model2.add(BatchNormalization())\n",
"model2.add(Dropout(0.5))\n",
"model2.add(Dense(10, activation='softmax'))\n",
"\n",
"model2.summary()"
]
},
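{
"cell_type": "markdown",
"metadata": {},
"source": [
"`model2` adds a `BatchNormalization` layer after each convolution and after `dense_3`. Each such layer holds four values per channel: `gamma` and `beta` are trained, while `moving_mean` and `moving_variance` are running statistics that Keras reports as non-trainable. A short sketch checking the summary totals under that accounting (the helper names are illustrative, not Keras API):\n",
"\n",
"```python\n",
"def conv2d_params(kernel, in_ch, filters):\n",
"    # Illustrative helper: square kernel weights plus one bias per filter\n",
"    return (kernel * kernel * in_ch + 1) * filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # Illustrative helper: weight matrix plus one bias per unit\n",
"    return (inputs + 1) * units\n",
"\n",
"conv = [conv2d_params(3, 3, 32), conv2d_params(3, 32, 32),\n",
"        conv2d_params(3, 32, 64), conv2d_params(3, 64, 64),\n",
"        conv2d_params(3, 64, 128), conv2d_params(3, 128, 128)]\n",
"# 'same' padding keeps each size; three 2x2 pools give 4x4x128 = 2048 inputs to dense_3\n",
"dense = [dense_params(4 * 4 * 128, 128), dense_params(128, 10)]\n",
"\n",
"bn_channels = [32, 32, 64, 64, 128, 128, 128]  # batch_normalization_1 .. _7\n",
"bn_trainable = sum(2 * c for c in bn_channels)      # gamma, beta\n",
"bn_non_trainable = sum(2 * c for c in bn_channels)  # moving_mean, moving_variance\n",
"\n",
"total = sum(conv) + sum(dense) + bn_trainable + bn_non_trainable\n",
"print(total, total - bn_non_trainable, bn_non_trainable)  # 552874 551722 1152\n",
"```"
]
},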
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"class_name\": \"Sequential\", \"keras_version\": \"2.1.6\", \"config\": [{\"class_name\": \"Conv2D\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_5\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"dtype\": \"float32\", \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"filters\": 32, \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"batch_input_shape\": [null, 32, 32, 3], \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_1\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_6\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 32, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": 
true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_2\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_3\", \"trainable\": true, \"data_format\": \"channels_last\", \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2]}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.2, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_4\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_7\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 64, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_3\", \"epsilon\": 0.001, \"trainable\": true, 
\"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_8\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 64, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_4\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_4\", \"trainable\": true, \"data_format\": \"channels_last\", \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2]}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.3, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_5\"}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, 
\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_9\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 128, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_5\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"Conv2D\", \"config\": {\"kernel_constraint\": null, \"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"conv2d_10\", \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"data_format\": \"channels_last\", \"padding\": \"same\", \"strides\": [1, 1], \"dilation_rate\": [1, 1], \"kernel_regularizer\": null, \"filters\": 128, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"use_bias\": true, \"activity_regularizer\": null, \"kernel_size\": [3, 3]}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, 
\"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_6\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, \"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"MaxPooling2D\", \"config\": {\"name\": \"max_pooling2d_5\", \"trainable\": true, \"data_format\": \"channels_last\", \"pool_size\": [2, 2], \"padding\": \"valid\", \"strides\": [2, 2]}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.4, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_6\"}}, {\"class_name\": \"Flatten\", \"config\": {\"trainable\": true, \"name\": \"flatten_2\", \"data_format\": \"channels_last\"}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 2.0, \"seed\": null, \"mode\": \"fan_in\"}}, \"name\": \"dense_3\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"relu\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 128, \"use_bias\": true, \"activity_regularizer\": null}}, {\"class_name\": \"BatchNormalization\", \"config\": {\"beta_constraint\": null, \"gamma_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"moving_mean_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"name\": \"batch_normalization_7\", \"epsilon\": 0.001, \"trainable\": true, \"moving_variance_initializer\": {\"class_name\": \"Ones\", \"config\": {}}, \"beta_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"scale\": true, \"axis\": -1, \"gamma_constraint\": null, \"gamma_regularizer\": null, 
\"beta_regularizer\": null, \"momentum\": 0.99, \"center\": true}}, {\"class_name\": \"Dropout\", \"config\": {\"rate\": 0.5, \"noise_shape\": null, \"trainable\": true, \"seed\": null, \"name\": \"dropout_7\"}}, {\"class_name\": \"Dense\", \"config\": {\"kernel_initializer\": {\"class_name\": \"VarianceScaling\", \"config\": {\"distribution\": \"uniform\", \"scale\": 1.0, \"seed\": null, \"mode\": \"fan_avg\"}}, \"name\": \"dense_4\", \"kernel_constraint\": null, \"bias_regularizer\": null, \"bias_constraint\": null, \"activation\": \"softmax\", \"trainable\": true, \"kernel_regularizer\": null, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"units\": 10, \"use_bias\": true, \"activity_regularizer\": null}}], \"backend\": \"tensorflow\"}'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.to_json()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"c. Another model architecture from https://machinelearningmastery.com/how-to-develop-a-cnn-from-scratch-for-cifar-10-photo-classification/"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_11 (Conv2D) (None, 32, 32, 32) 896 \n",
"_________________________________________________________________\n",
"conv2d_12 (Conv2D) (None, 32, 32, 32) 9248 \n",
"_________________________________________________________________\n",
"max_pooling2d_6 (MaxPooling2 (None, 16, 16, 32) 0 \n",
"_________________________________________________________________\n",
"dropout_8 (Dropout) (None, 16, 16, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_13 (Conv2D) (None, 16, 16, 64) 18496 \n",
"_________________________________________________________________\n",
"conv2d_14 (Conv2D) (None, 16, 16, 64) 36928 \n",
"_________________________________________________________________\n",
"max_pooling2d_7 (MaxPooling2 (None, 8, 8, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_9 (Dropout) (None, 8, 8, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_15 (Conv2D) (None, 8, 8, 128) 73856 \n",
"_________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 8, 8, 128) 147584 \n",
"_________________________________________________________________\n",
"max_pooling2d_8 (MaxPooling2 (None, 4, 4, 128) 0 \n",
"_________________________________________________________________\n",
"dropout_10 (Dropout) (None, 4, 4, 128) 0 \n",
"_________________________________________________________________\n",
"flatten_3 (Flatten) (None, 2048) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 128) 262272 \n",
"_________________________________________________________________\n",
"dropout_11 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 550,570\n",
"Trainable params: 550,570\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model3 = Sequential()\n",
"\n",
"model3.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))\n",
"model3.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model3.add(MaxPooling2D((2, 2)))\n",
"model3.add(Dropout(0.2))\n",
"model3.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model3.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model3.add(MaxPooling2D((2, 2)))\n",
"model3.add(Dropout(0.3))\n",
"model3.add(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model3.add(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))\n",
"model3.add(MaxPooling2D((2, 2)))\n",
"model3.add(Dropout(0.4))\n",
"model3.add(Flatten())\n",
"model3.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))\n",
"model3.add(Dropout(0.5))\n",
"model3.add(Dense(10, activation='softmax'))\n",
"\n",
"model3.summary()"
]
},
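{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the summary above, the parameter counts can be recomputed by hand. A `Conv2D` layer with a k x k kernel has `(k*k*in_channels + 1)*filters` parameters (the +1 is the bias), a `Dense` layer has `(inputs + 1)*units`, and pooling, dropout and flatten layers have none. The cell below is a minimal sketch that assumes `padding='same'` convolutions and 2x2 pooling, as in `model3`; `conv2d_params` and `dense_params` are hypothetical helpers, not Keras functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv2d_params(kernel, in_channels, filters):\n",
"    # kernel x kernel x in_channels weights per filter, plus one bias per filter\n",
"    return (kernel * kernel * in_channels + 1) * filters\n",
"\n",
"def dense_params(inputs, units):\n",
"    # one weight per input per unit, plus one bias per unit\n",
"    return (inputs + 1) * units\n",
"\n",
"# model3: three blocks of two 3x3 convolutions, each block followed by 2x2 pooling\n",
"layer_params = [\n",
"    conv2d_params(3, 3, 32), conv2d_params(3, 32, 32),\n",
"    conv2d_params(3, 32, 64), conv2d_params(3, 64, 64),\n",
"    conv2d_params(3, 64, 128), conv2d_params(3, 128, 128),\n",
"]\n",
"# the 32x32 input is halved three times, giving 4x4x128 = 2048 features after Flatten\n",
"flat_features = (32 // 2 // 2 // 2) ** 2 * 128\n",
"layer_params += [dense_params(flat_features, 128), dense_params(128, 10)]\n",
"\n",
"print(layer_params)\n",
"print(sum(layer_params))  # should match 'Total params: 550,570' above"
]
},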
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the three model architectures into the model architecture table using psycopg2"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"3 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>model_id</th>\n",
" <th>name</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>CNN from Keras docs for CIFAR-10</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>CNN from Jason Brownlee blog post</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>CNN from Jason Brownlee blog post - no batch normalization</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, u'CNN from Keras docs for CIFAR-10'),\n",
" (2, u'CNN from Jason Brownlee blog post'),\n",
" (3, u'CNN from Jason Brownlee blog post - no batch normalization')]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import psycopg2 as p2\n",
"#conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
"#conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
"# this must be the same database that the %sql magic is connected to\n",
"conn = p2.connect('postgresql://gpadmin@localhost:8000/cifar_demo')\n",
"cur = conn.cursor()\n",
"\n",
"%sql DROP TABLE IF EXISTS model_arch_table_cifar10;\n",
"\n",
"# load each Keras architecture as JSON with a descriptive name (psycopg2 binds the %s placeholders)\n",
"query = \"SELECT madlib.load_keras_model('model_arch_table_cifar10', %s, NULL, %s)\"\n",
"for model, name in [(model1, \"CNN from Keras docs for CIFAR-10\"),\n",
"                    (model2, \"CNN from Jason Brownlee blog post\"),\n",
"                    (model3, \"CNN from Jason Brownlee blog post - no batch normalization\")]:\n",
"    cur.execute(query, [model.to_json(), name])\n",
"    conn.commit()\n",
"\n",
"# check models loaded OK\n",
"%sql SELECT model_id, name FROM model_arch_table_cifar10 ORDER BY model_id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"hyperband\"></a>\n",
"# 4. Hyperband diagonal\n",
"\n",
"Create tables for the intermediate and overall results of Hyperband, which runs on top of the MADlib model selection methods."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- overall results table\n",
"DROP TABLE IF EXISTS results_cifar10;\n",
"CREATE TABLE results_cifar10 ( \n",
" mst_key INTEGER, -- note not SERIAL\n",
" model_id INTEGER, \n",
" compile_params TEXT,\n",
" fit_params TEXT, \n",
" model_type TEXT, \n",
" model_size DOUBLE PRECISION, \n",
" metrics_elapsed_time DOUBLE PRECISION[], \n",
" metrics_type TEXT[], \n",
" training_metrics_final DOUBLE PRECISION, \n",
" training_loss_final DOUBLE PRECISION, \n",
" training_metrics DOUBLE PRECISION[], \n",
" training_loss DOUBLE PRECISION[], \n",
" validation_metrics_final DOUBLE PRECISION, \n",
" validation_loss_final DOUBLE PRECISION, \n",
" validation_metrics DOUBLE PRECISION[], \n",
" validation_loss DOUBLE PRECISION[], \n",
" model_arch_table TEXT, \n",
" num_iterations INTEGER, \n",
" start_training_time TIMESTAMP, \n",
" end_training_time TIMESTAMP,\n",
" s INTEGER, -- bracket number from Hyperband\n",
" i INTEGER, -- iteration corresponding to successive halving within a bracket\n",
" run_id SERIAL -- global counter for the training runs\n",
" );\n",
"\n",
"-- all model selections:\n",
"-- model selection table containing all model configs (all brackets)\n",
"DROP TABLE IF EXISTS mst_table_hb_cifar10;\n",
"CREATE TABLE mst_table_hb_cifar10 (\n",
" mst_key SERIAL, \n",
" s INTEGER, -- bracket\n",
" model_id INTEGER, \n",
" compile_params VARCHAR, \n",
" fit_params VARCHAR\n",
" );\n",
"\n",
"-- model selection summary table\n",
"DROP TABLE IF EXISTS mst_table_hb_cifar10_summary;\n",
"CREATE TABLE mst_table_hb_cifar10_summary (model_arch_table VARCHAR);\n",
"INSERT INTO mst_table_hb_cifar10_summary VALUES ('model_arch_table_cifar10');\n",
"\n",
"-- diagonal model selections:\n",
"-- model selection table for diagonal: fit() will be called on a per diagonal basis\n",
"DROP TABLE IF EXISTS mst_diag_table_hb_cifar10;\n",
"CREATE TABLE mst_diag_table_hb_cifar10 (\n",
" mst_key INTEGER, -- note not SERIAL since this table derived from main model selection table\n",
" s INTEGER, -- bracket\n",
" model_id INTEGER, \n",
" compile_params VARCHAR, \n",
" fit_params VARCHAR\n",
" );\n",
"\n",
"-- model selection summary table for diagonal table\n",
"DROP TABLE IF EXISTS mst_diag_table_hb_cifar10_summary;\n",
"CREATE TABLE mst_diag_table_hb_cifar10_summary (model_arch_table VARCHAR);\n",
"INSERT INTO mst_diag_table_hb_cifar10_summary VALUES ('model_arch_table_cifar10');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the table names as Python variables so the cells below can refer to them generically"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"results_table = 'results_cifar10'\n",
"\n",
"output_table = 'cifar10_multi_model'\n",
"output_table_info = '_'.join([output_table, 'info'])\n",
"output_table_summary = '_'.join([output_table, 'summary'])\n",
"\n",
"best_model = 'cifar10_best_model'\n",
"best_model_info = '_'.join([best_model, 'info'])\n",
"best_model_summary = '_'.join([best_model, 'summary'])\n",
"\n",
"\n",
"mst_table = 'mst_table_hb_cifar10'\n",
"mst_table_summary = '_'.join([mst_table, 'summary'])\n",
"\n",
"mst_diag_table = 'mst_diag_table_hb_cifar10'\n",
"mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
"\n",
"model_arch_table = 'model_arch_table_cifar10'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hyperband diagonal logic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define variables for Hyperband"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"max_iter = 27 # maximum iterations per configuration\n",
"eta = 3 # defines downsampling rate (default = 3)\n",
"skip_last = 0 # 1 means skip last run in each bracket, 0 means run full bracket"
]
},
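{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bracket sizes printed further down can be derived from these variables. The cell below is a minimal sketch of the schedule as written in the Hyperband paper: `s_max = floor(log_eta(max_iter))`, `B = (s_max + 1)*max_iter`, `n = ceil(B/max_iter * eta**s / (s + 1))` and `r = max_iter * eta**(-s)`, using integer arithmetic to avoid floating-point round-off in the logarithm. `hyperband_schedule` is a hypothetical helper, not part of the class below. One difference worth knowing: the class truncates `B/max_iter/(s+1)` to an integer before multiplying by `eta**s`, so with `max_iter = 27` and `eta = 3` it starts bracket s=2 with 9 configurations (as in the log below), whereas the paper formula gives 12."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def hyperband_schedule(max_iter=27, eta=3):\n",
"    # s_max = floor(log_eta(max_iter)), found with integer powers instead of log()\n",
"    s_max = 0\n",
"    while eta ** (s_max + 1) <= max_iter:\n",
"        s_max += 1\n",
"    B = (s_max + 1) * max_iter  # budget spent by each bracket\n",
"\n",
"    schedule = {}\n",
"    for s in reversed(range(s_max + 1)):\n",
"        # paper formula n = ceil(B / max_iter * eta**s / (s + 1)), as integer ceiling division\n",
"        n = -(-B * eta ** s // (max_iter * (s + 1)))\n",
"        r = float(max_iter) / eta ** s  # initial iterations per configuration\n",
"        # rung i: floor(n / eta**i) configurations, each trained for r * eta**i iterations\n",
"        schedule[s] = [(n // eta ** i, r * eta ** i) for i in range(s + 1)]\n",
"    return s_max, B, schedule\n",
"\n",
"# same values as max_iter and eta in the cell above\n",
"s_max, B, schedule = hyperband_schedule(max_iter=27, eta=3)\n",
"print('s_max = {}, B = {}'.format(s_max, B))\n",
"for s in sorted(schedule, reverse=True):\n",
"    print('s = {}: (n_i, r_i) = {}'.format(s, schedule[s]))"
]
},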
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from random import random\n",
"from math import log, ceil\n",
"from time import time, ctime\n",
"\n",
"class Hyperband_diagonal:\n",
" \n",
" def __init__( self, get_params_function, try_params_function ):\n",
" self.get_params = get_params_function #\n",
" self.try_params = try_params_function\n",
"\n",
" self.max_iter = max_iter \n",
" self.eta = eta \n",
" self.skip_last = skip_last \n",
"\n",
" self.logeta = lambda x: log( x ) / log( self.eta )\n",
" self.s_max = int( self.logeta( self.max_iter ))\n",
" self.B = ( self.s_max + 1 ) * self.max_iter\n",
" \n",
" #echo output\n",
" print (\"max_iter = \" + str(self.max_iter))\n",
" print (\"eta = \" + str(self.eta))\n",
" print (\"B = \" + str(self.s_max+1) + \"*max_iter = \" + str(self.B))\n",
" print (\"skip_last = \" + str(self.skip_last))\n",
" \n",
" self.setup_full_schedule()\n",
" self.create_mst_superset()\n",
" \n",
" self.best_loss = np.inf\n",
" self.best_accuracy = 0.0\n",
"\n",
" # create full Hyperband schedule for all brackets ahead of time\n",
" def setup_full_schedule(self):\n",
" self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" \n",
" print (\" \")\n",
" print (\"Hyperband brackets\")\n",
"\n",
" # loop through each bracket in reverse order\n",
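"        for s in reversed(range(self.s_max+1)):\n",
"            # bracket s starts with n configurations trained for r iterations each\n",
"            n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
"            r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",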
" \n",
" print (\" \")\n",
" print (\"s=\" + str(s))\n",
" print (\"n_i r_i\")\n",
" print (\"------------\")\n",
"\n",
" for i in range(s+1):\n",
" # n_i configs for r_i iterations\n",
" n_i = n*self.eta**(-i)\n",
" r_i = r*self.eta**(i)\n",
"\n",
" self.n_vals[s][i] = n_i\n",
" self.r_vals[s][i] = r_i\n",
"\n",
" print (str(n_i) + \" \" + str (r_i))\n",
" \n",
" \n",
" # generate model selection tuples for all brackets\n",
" def create_mst_superset(self):\n",
" \n",
" print (\" \")\n",
" print (\"Create superset of MSTs for each bracket s\")\n",
" \n",
" # get hyper parameter configs for each bracket s\n",
" for s in reversed(range(self.s_max+1)):\n",
" n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
" r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
" print (\" \")\n",
" print (\"s=\" + str(s))\n",
" print (\"n=\" + str(n))\n",
" print (\"r=\" + str(r))\n",
" print (\" \")\n",
" \n",
" # n random configurations for each bracket s\n",
" self.get_params(n, s)\n",
" \n",
" \n",
" # Hyperband diagonal logic\n",
" def run(self): \n",
" \n",
" print (\" \")\n",
" print (\"Hyperband diagonal\")\n",
" print (\"Outer loop on diagonal:\")\n",
" \n",
" # outer loop on diagonal\n",
" #for i in range(self.s_max+1):\n",
" for i in range((self.s_max+1) - int(self.skip_last)):\n",
" print (\" \")\n",
" print (\"i=\" + str(i))\n",
" \n",
" # zero out diagonal table\n",
" %sql TRUNCATE TABLE $mst_diag_table\n",
" \n",
" # loop on brackets s desc to create diagonal table\n",
" print (\"Loop on s desc to create diagonal table:\")\n",
" for s in range(self.s_max, self.s_max-i-1, -1):\n",
"\n",
" # build up mst table for diagonal\n",
" %sql INSERT INTO $mst_diag_table (SELECT * FROM $mst_table WHERE s=$s);\n",
" \n",
" # first pass\n",
" if i == 0:\n",
" first_pass = True\n",
" else:\n",
" first_pass = False\n",
" \n",
" # multi-model training\n",
" print (\" \")\n",
" print (\"Try params for i = \" + str(i))\n",
" U = self.try_params(i, self.r_vals[self.s_max][i], first_pass) # r_i is the same for all diagonal elements\n",
" \n",
" # loop on brackets s desc to prune model selection table\n",
" # don't need to prune if finished last diagonal\n",
" #if i < (self.s_max):\n",
" if i < (self.s_max - int(self.skip_last)):\n",
" print (\"Loop on s desc to prune mst table:\")\n",
" for s in range(self.s_max, self.s_max-i-1, -1):\n",
" \n",
" # compute number of configs to keep\n",
" # remember i value is different for each bracket s on the diagonal\n",
" k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
" print (\"Pruning s = {} with k = {}\".format(s, k))\n",
"\n",
"                    # str.format(**locals()) only sees names that are local to run(),\n",
"                    # so pass the global table names to format() explicitly\n",
"                    query = \"\"\"\n",
"                    DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
"                    \"\"\".format(mst_table=mst_table, output_table_info=output_table_info, s=s, k=k)\n",
"                    cur.execute(query)\n",
"                    conn.commit()\n",
" \n",
" # these were not working so used cursor instead\n",
" #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
" #%sql DELETE FROM mst_table_hb_cifar10 WHERE s=1 AND mst_key NOT IN (SELECT cifar10_multi_model_info.mst_key FROM cifar10_multi_model_info JOIN mst_table_hb_cifar10 ON cifar10_multi_model_info.mst_key=mst_table_hb_cifar10.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
" \n",
"            # keep track of best loss so far and save the model for inference\n",
"            # get best loss and accuracy from this diagonal run\n",
"            # (need to check if this will work OK if don't evaluate metrics every iteration)\n",
"            loss = %sql SELECT validation_loss_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
"            accuracy = %sql SELECT validation_metrics_final FROM $output_table_info ORDER BY validation_metrics_final DESC LIMIT 1;\n",
"            # %sql returns a ResultSet, so extract the scalar value before comparing it with a float\n",
"            loss = loss[0][0]\n",
"            accuracy = accuracy[0][0]\n",
"            \n",
"            # save best model based on accuracy (could do loss if you wanted)\n",
"            if accuracy > self.best_accuracy:\n",
"                \n",
"                self.best_accuracy = accuracy\n",
"                \n",
"                # get best mst_key\n",
"                best_mst_key = %sql SELECT mst_key FROM $output_table_info ORDER BY validation_metrics_final DESC LIMIT 1; \n",
"                best_mst_key = best_mst_key.DataFrame().to_numpy()[0][0]\n",
"\n",
"                # save model table (1 row for best model)\n",
"                %sql DROP TABLE IF EXISTS $best_model;\n",
"                %sql CREATE TABLE $best_model AS SELECT * FROM $output_table WHERE mst_key = $best_mst_key;\n",
"\n",
"                # save info table (1 row for best model)\n",
"                %sql DROP TABLE IF EXISTS $best_model_info;\n",
"                %sql CREATE TABLE $best_model_info AS SELECT * FROM $output_table_info WHERE mst_key = $best_mst_key;\n",
"                \n",
"                # save summary table\n",
"                %sql DROP TABLE IF EXISTS $best_model_summary;\n",
"                %sql CREATE TABLE $best_model_summary AS SELECT * FROM $output_table_summary;\n",
"                \n",
"            if loss < self.best_loss:\n",
"                self.best_loss = loss\n",
"                \n",
"            print (\" \")\n",
"            print (\"Best validation loss so far = \" + str(self.best_loss))\n",
"            print (\"Best validation accuracy so far = \" + str(self.best_accuracy))\n",
" \n",
"\n",
" \n",
" return"
]
},
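{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which work `run()` schedules at each step, the cell below replays the diagonal without touching the database. At diagonal step i, brackets `s_max` down to `s_max - i` are active, bracket s is at rung `j = i - (s_max - s)`, and every active rung trains for `max_iter * eta**(i - s_max)` iterations, which is why `run()` can pass a single r to `try_params`. It also expresses the pruning rule in plain Python: keep the `k = n_j // eta` configurations of a bracket with the lowest validation loss, mirroring the `DELETE ... NOT IN (... ORDER BY validation_loss_final ASC LIMIT k)` query. This is a minimal sketch: `diagonal_plan`, `keep_top_k` and the dictionary rows are hypothetical, integer division stands in for the class's float truncation, and the bracket sizes passed in are the ones printed by `create_mst_superset` in the log below (27, 9, 6, 4). The resulting configurations per step (27, 18, 12, 8) match the row counts in that log."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def diagonal_plan(n_per_bracket, max_iter=27, eta=3):\n",
"    # returns one list per diagonal step i of (s, j, n_j, r_j, k) tuples\n",
"    s_max = max(n_per_bracket)\n",
"    plan = []\n",
"    for i in range(s_max + 1):\n",
"        step = []\n",
"        # brackets s_max, s_max - 1, ..., s_max - i are active on diagonal i\n",
"        for s in range(s_max, s_max - i - 1, -1):\n",
"            j = i - (s_max - s)                          # rung index within bracket s\n",
"            n_j = n_per_bracket[s] // eta ** j           # configurations trained at this rung\n",
"            r_j = float(max_iter) * eta ** j / eta ** s  # iterations per configuration\n",
"            k = n_j // eta if i < s_max else None        # survivors; None on the last rung\n",
"            step.append((s, j, n_j, r_j, k))\n",
"        plan.append(step)\n",
"    return plan\n",
"\n",
"def keep_top_k(rows, s, k):\n",
"    # rows: dicts with 'mst_key', 's' and 'validation_loss_final'\n",
"    in_bracket = [row for row in rows if row['s'] == s]\n",
"    best = sorted(in_bracket, key=lambda row: row['validation_loss_final'])[:k]\n",
"    return sorted(row['mst_key'] for row in best)\n",
"\n",
"# bracket sizes as printed by create_mst_superset in the log below\n",
"plan = diagonal_plan({3: 27, 2: 9, 1: 6, 0: 4})\n",
"for i, step in enumerate(plan):\n",
"    print('i = {}: {}'.format(i, step))\n",
"\n",
"# hypothetical validation losses for one diagonal step\n",
"rows = [\n",
"    {'mst_key': 1, 's': 1, 'validation_loss_final': 0.9},\n",
"    {'mst_key': 2, 's': 1, 'validation_loss_final': 0.7},\n",
"    {'mst_key': 3, 's': 1, 'validation_loss_final': 0.8},\n",
"    {'mst_key': 4, 's': 0, 'validation_loss_final': 0.5},\n",
"]\n",
"print(keep_top_k(rows, s=1, k=2))  # keys 2 and 3 have the two lowest losses in bracket 1"
]
},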
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate parameters and insert them into the MST table. This version of get_params uses the same compile parameters for all optimizers, and the same compile/fit parameters for all model architectures, which may be too restrictive in some cases. (Note 3/13: check the SIGMOID paper runs, where I think I may have addressed this to some extent.)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_params(n, s):\n",
" \n",
" from sklearn.model_selection import ParameterSampler\n",
" from scipy.stats.distributions import uniform\n",
" import numpy as np\n",
" \n",
" # model architecture\n",
" model_id = [1,2]\n",
"\n",
" # compile params\n",
" # loss function\n",
" loss = ['categorical_crossentropy']\n",
" # optimizer\n",
" optimizer = ['sgd', 'adam', 'rmsprop']\n",
" # learning rate (sample on log scale here not in ParameterSampler)\n",
" lr_range = [0.0001, 0.01]\n",
" lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n)\n",
" # metrics\n",
" metrics = ['accuracy']\n",
"\n",
" # fit params\n",
" # batch size\n",
" batch_size = [32, 64, 128, 256]\n",
" # epochs\n",
" epochs = [5]\n",
"\n",
" # create random param list\n",
" param_grid = {\n",
" 'model_id': model_id,\n",
" 'loss': loss,\n",
" 'optimizer': optimizer,\n",
" 'lr': lr,\n",
" 'metrics': metrics,\n",
" 'batch_size': batch_size,\n",
" 'epochs': epochs\n",
" }\n",
" param_list = list(ParameterSampler(param_grid, n_iter=n))\n",
" \n",
" for params in param_list:\n",
"\n",
" model_id = str(params.get(\"model_id\"))\n",
" compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
" fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
" row_content = \"(\" + str(s) + \", \" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
" \n",
" %sql INSERT INTO $mst_table (s, model_id, compile_params, fit_params) VALUES $row_content\n",
" \n",
" return"
]
},
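{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning rate above is sampled uniformly in log10 space, so each decade of `lr_range` is equally likely, and each sampled configuration is then turned into the `compile_params` and `fit_params` strings stored in the MST table. The cell below is a minimal standard-library sketch of those two steps, assuming the same string layout as `get_params` but without the `$$` dollar quoting used in the SQL; `sample_log_uniform` and `make_mst_row` are hypothetical helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"\n",
"def sample_log_uniform(low, high, rng):\n",
"    # uniform in log10 space, so every decade between low and high is equally likely\n",
"    return 10 ** rng.uniform(math.log10(low), math.log10(high))\n",
"\n",
"def make_mst_row(s, model_id, optimizer, lr, batch_size, epochs=5):\n",
"    # same layout as the strings built in get_params, minus the $$ quoting\n",
"    compile_params = \"loss='categorical_crossentropy',optimizer='{}(lr={})',metrics=['accuracy']\".format(optimizer, lr)\n",
"    fit_params = 'batch_size={},epochs={}'.format(batch_size, epochs)\n",
"    return (s, model_id, compile_params, fit_params)\n",
"\n",
"rng = random.Random(0)\n",
"lrs = [sample_log_uniform(0.0001, 0.01, rng) for _ in range(1000)]\n",
"# roughly half of the samples should fall below the log-midpoint 0.001\n",
"print(sum(lr < 0.001 for lr in lrs) / float(len(lrs)))\n",
"print(make_mst_row(3, 2, 'adam', 0.001, 128))"
]
},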
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate parameters and insert them into the MST table. This second definition of get_params replaces the one above and allows more customization by optimizer and model architecture. It is somewhat brute force and could be improved."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def get_params(n, s):\n",
" \n",
" from sklearn.model_selection import ParameterSampler\n",
" from scipy.stats.distributions import uniform\n",
" import numpy as np\n",
" \n",
" # number of samples by optimizer\n",
" #n_adam = int(n/3)\n",
" n_adam = int(n/2)\n",
" #n_rmsprop = int(n/3)\n",
" n_rmsprop = 0\n",
" n_sgd = int(n - n_adam - n_rmsprop)\n",
"\n",
" # 1) adam\n",
" \n",
" # model architecture\n",
" model_id = [2,3]\n",
"\n",
" # compile params\n",
" # loss function\n",
" loss = ['categorical_crossentropy']\n",
" # optimizer\n",
" optimizer = ['adam']\n",
" # learning rate (sample on log scale here not in ParameterSampler)\n",
" lr_range = [0.0001, 0.001]\n",
" lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n_adam)\n",
" # metrics\n",
" metrics = ['accuracy']\n",
"\n",
" # fit params\n",
" # batch size\n",
" batch_size = [128, 256]\n",
" # epochs\n",
" epochs = [5]\n",
"\n",
" # create random param list\n",
" param_grid = {\n",
" 'model_id': model_id,\n",
" 'loss': loss,\n",
" 'optimizer': optimizer,\n",
" 'lr': lr,\n",
" 'metrics': metrics,\n",
" 'batch_size': batch_size,\n",
" 'epochs': epochs\n",
" }\n",
" param_list_adam = list(ParameterSampler(param_grid, n_iter=n_adam))\n",
"\n",
" # iterate over params\n",
" for params in param_list_adam:\n",
"\n",
" model_id = str(params.get(\"model_id\"))\n",
" compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
" fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
" row_content = \"(\" + str(s) + \", \" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
" \n",
" # populate mst table\n",
" %sql INSERT INTO $mst_table (s, model_id, compile_params, fit_params) VALUES $row_content\n",
" \n",
" \n",
" # 2) rmsprop\n",
" \n",
" # model architecture\n",
" model_id = [1,2,3]\n",
"\n",
" # compile params\n",
" # loss function\n",
" loss = ['categorical_crossentropy']\n",
" # optimizer\n",
" optimizer = ['rmsprop']\n",
" # learning rate (sample on log scale here not in ParameterSampler)\n",
" lr_range = [0.0001, 0.001]\n",
" lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n_rmsprop)\n",
"    # decay (sample on a log scale here, not in ParameterSampler, if you want multiple values)\n",
" decay = [1e-6]\n",
"\n",
" # metrics\n",
" metrics = ['accuracy']\n",
"\n",
" # fit params\n",
" # batch size\n",
" batch_size = [32, 64, 128, 256]\n",
" # epochs\n",
" epochs = [5]\n",
"\n",
" # create random param list\n",
" param_grid = {\n",
" 'model_id': model_id,\n",
" 'loss': loss,\n",
" 'optimizer': optimizer,\n",
" 'lr': lr,\n",
" 'decay': decay,\n",
" 'metrics': metrics,\n",
" 'batch_size': batch_size,\n",
" 'epochs': epochs\n",
" }\n",
" param_list_rmsprop = list(ParameterSampler(param_grid, n_iter=n_rmsprop))\n",
"\n",
" # iterate over params\n",
" for params in param_list_rmsprop:\n",
"\n",
" model_id = str(params.get(\"model_id\"))\n",
" compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \",decay=\" + str(params.get(\"decay\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
" fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
" row_content = \"(\" + str(s) + \", \" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
" \n",
" # populate mst table\n",
" %sql INSERT INTO $mst_table (s, model_id, compile_params, fit_params) VALUES $row_content\n",
"\n",
"\n",
" # 3) sgd\n",
" \n",
" # model architecture\n",
" model_id = [2,3]\n",
"\n",
" # compile params\n",
" # loss function\n",
" loss = ['categorical_crossentropy']\n",
" # optimizer\n",
" optimizer = ['sgd']\n",
" # learning rate (sample on log scale here not in ParameterSampler)\n",
" lr_range = [0.001, 0.005]\n",
" lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n_sgd)\n",
"    # momentum: sample (1 - momentum) on a log scale here, not in ParameterSampler\n",
"    # recall momentum is an exponentially weighted average of past gradients\n",
" beta_range = [0.9, 0.95]\n",
" beta = 1.0 - 10**np.random.uniform(np.log10(1.0-beta_range[0]), np.log10(1.0-beta_range[1]), n_sgd)\n",
" # metrics\n",
" metrics = ['accuracy']\n",
"\n",
" # fit params\n",
" # batch size\n",
" batch_size = [128, 256]\n",
" # epochs\n",
" epochs = [5]\n",
"\n",
" # create random param list\n",
" param_grid = {\n",
" 'model_id': model_id,\n",
" 'loss': loss,\n",
" 'optimizer': optimizer,\n",
" 'lr': lr,\n",
" 'beta': beta,\n",
" 'metrics': metrics,\n",
" 'batch_size': batch_size,\n",
" 'epochs': epochs\n",
" }\n",
" param_list_sgd = list(ParameterSampler(param_grid, n_iter=n_sgd))\n",
"\n",
" # iterate over params\n",
" for params in param_list_sgd:\n",
"\n",
" model_id = str(params.get(\"model_id\"))\n",
" compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \",momentum=\" + str(params.get(\"beta\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
" fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
" row_content = \"(\" + str(s) + \", \" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
" \n",
" # populate mst table\n",
" %sql INSERT INTO $mst_table (s, model_id, compile_params, fit_params) VALUES $row_content\n",
"\n",
" \n",
" #4) organize mst table\n",
"\n",
" #down sample\n",
" #%sql DELETE from $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $mst_table ORDER BY random() LIMIT $n);\n",
"\n",
" # make mst_keys contiguous\n",
" #%sql ALTER TABLE $mst_table DROP COLUMN mst_key;\n",
" #%sql ALTER TABLE $mst_table ADD COLUMN mst_key SERIAL;\n",
" \n",
" return"
]
},
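{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of this version are easy to miss. First, the budget n is split across optimizers: currently half to Adam, none to RMSprop and the remainder to SGD. Second, SGD momentum is drawn so that `1 - momentum` is log-uniform; since momentum acts as an exponentially weighted average over roughly `1/(1 - momentum)` steps, the range [0.9, 0.95] corresponds to windows of about 10 to 20 steps. The cell below is a minimal standard-library sketch of both steps, with ranges copied from the cell above; `split_budget` and `sample_momentum` are hypothetical helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"\n",
"def split_budget(n, adam_share=0.5, rmsprop_count=0):\n",
"    # mirrors n_adam = int(n/2), n_rmsprop = 0, n_sgd = n - n_adam - n_rmsprop\n",
"    n_adam = int(n * adam_share)\n",
"    n_sgd = n - n_adam - rmsprop_count\n",
"    return n_adam, rmsprop_count, n_sgd\n",
"\n",
"def sample_momentum(beta_low, beta_high, rng):\n",
"    # draw 1 - beta uniformly in log10 space, then map back to beta\n",
"    a = math.log10(1.0 - beta_low)\n",
"    b = math.log10(1.0 - beta_high)\n",
"    return 1.0 - 10 ** rng.uniform(a, b)\n",
"\n",
"for n in (27, 9, 6, 4):\n",
"    print('n = {} -> (adam, rmsprop, sgd) = {}'.format(n, split_budget(n)))\n",
"\n",
"rng = random.Random(1)\n",
"betas = [sample_momentum(0.9, 0.95, rng) for _ in range(1000)]\n",
"print('{:.4f} {:.4f}'.format(min(betas), max(betas)))"
]
},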
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the model hopper (madlib_keras_fit_multiple_model) on the candidate configurations in the diagonal MST table"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def try_params(i, r, first_pass):\n",
" \n",
" # multi-model fit\n",
" if first_pass:\n",
" # cold start\n",
" %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
" # passing vars as madlib args does not seem to work\n",
" #%sql SELECT madlib.madlib_keras_fit_multiple_model('cifar10_train_packed', $output_table, $mst_diag_table, $r_i::INT, 0);\n",
" %sql SELECT madlib.madlib_keras_fit_multiple_model('cifar10_train_packed', 'cifar10_multi_model', 'mst_diag_table_hb_cifar10', $r::INT, True, 'cifar10_val_packed',1);\n",
"\n",
" else:\n",
" # warm start to continue from previous run\n",
" %sql SELECT madlib.madlib_keras_fit_multiple_model('cifar10_train_packed', 'cifar10_multi_model', 'mst_diag_table_hb_cifar10', $r::INT, True, 'cifar10_val_packed', 1, True);\n",
"\n",
" # save results via temp table\n",
" # add everything from info table\n",
" %sql DROP TABLE IF EXISTS temp_results;\n",
" %sql CREATE TABLE temp_results AS (SELECT * FROM $output_table_info);\n",
" \n",
" # add summary table info and i value (same for each row)\n",
" %sql ALTER TABLE temp_results ADD COLUMN model_arch_table TEXT, ADD COLUMN num_iterations INTEGER, ADD COLUMN start_training_time TIMESTAMP, ADD COLUMN end_training_time TIMESTAMP, ADD COLUMN s INTEGER, ADD COLUMN i INTEGER;\n",
" %sql UPDATE temp_results SET model_arch_table = (SELECT model_arch_table FROM $output_table_summary), num_iterations = (SELECT num_iterations FROM $output_table_summary), start_training_time = (SELECT start_training_time FROM $output_table_summary), end_training_time = (SELECT end_training_time FROM $output_table_summary), i = $i;\n",
" \n",
" # get the s value for each run (not the same for each row since diagonal table crosses multiple brackets)\n",
" %sql UPDATE temp_results SET s = m.s FROM mst_diag_table_hb_cifar10 AS m WHERE m.mst_key = temp_results.mst_key;\n",
" \n",
" # copy temp table into results table\n",
" %sql INSERT INTO $results_table (SELECT * FROM temp_results);\n",
"\n",
" return"
]
},
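{
"cell_type": "markdown",
"metadata": {},
"source": [
"`try_params` copies the info table into the results table and tags each row with the diagonal step i, which is the same for every row, and with its bracket s, which differs from row to row because one diagonal spans several brackets; s comes from joining on `mst_key` with the diagonal MST table. The cell below is a minimal sketch of that bookkeeping in plain Python on hypothetical rows represented as dictionaries; `tag_results` is a hypothetical helper, not part of MADlib."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tag_results(info_rows, diag_mst_rows, i):\n",
"    # map each mst_key in the diagonal table to its Hyperband bracket s\n",
"    bracket_of = {row['mst_key']: row['s'] for row in diag_mst_rows}\n",
"    tagged = []\n",
"    for row in info_rows:\n",
"        out = dict(row)\n",
"        out['s'] = bracket_of.get(row['mst_key'])  # like UPDATE ... SET s = m.s ... WHERE m.mst_key = ...\n",
"        out['i'] = i                               # like UPDATE ... SET i = $i\n",
"        tagged.append(out)\n",
"    return tagged\n",
"\n",
"# hypothetical rows: two configurations from bracket 3 and one from bracket 2\n",
"diag = [{'mst_key': 5, 's': 3}, {'mst_key': 8, 's': 3}, {'mst_key': 30, 's': 2}]\n",
"info = [{'mst_key': 5, 'validation_loss_final': 0.9},\n",
"        {'mst_key': 30, 'validation_loss_final': 0.8}]\n",
"for row in tag_results(info, diag, i=1):\n",
"    print(row)"
]
},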
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call Hyperband diagonal"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"max_iter = 27\n",
"eta = 3\n",
"B = 4*max_iter = 108\n",
"skip_last = 0\n",
" \n",
"Hyperband brackets\n",
" \n",
"s=3\n",
"n_i r_i\n",
"------------\n",
"27 1.0\n",
"9.0 3.0\n",
"3.0 9.0\n",
"1.0 27.0\n",
" \n",
"s=2\n",
"n_i r_i\n",
"------------\n",
"9 3.0\n",
"3.0 9.0\n",
"1.0 27.0\n",
" \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"6 9.0\n",
"2.0 27.0\n",
" \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
"4 27\n",
" \n",
"Create superset of MSTs for each bracket s\n",
" \n",
"s=3\n",
"n=27\n",
"r=1.0\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"s=2\n",
"n=9\n",
"r=3.0\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"s=1\n",
"n=6\n",
"r=9.0\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"s=0\n",
"n=4\n",
"r=27\n",
" \n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"Hyperband diagonal\n",
"Outer loop on diagonal:\n",
" \n",
"i=0\n",
"Done.\n",
"Loop on s desc to create diagonal table:\n",
"27 rows affected.\n",
" \n",
"Try params for i = 0\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"27 rows affected.\n",
"Done.\n",
"27 rows affected.\n",
"27 rows affected.\n",
"27 rows affected.\n",
"Loop on s desc to prune mst table:\n",
"Pruning s = 3 with k = 9\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
" \n",
"Best validation loss so far = \n",
"+-----------------------+\n",
"| validation_loss_final |\n",
"+-----------------------+\n",
"| 0.782763898373 |\n",
"+-----------------------+\n",
"Best validation accuracy so far = \n",
"+--------------------------+\n",
"| validation_metrics_final |\n",
"+--------------------------+\n",
"| 0.72729998827 |\n",
"+--------------------------+\n",
" \n",
"i=1\n",
"Done.\n",
"Loop on s desc to create diagonal table:\n",
"9 rows affected.\n",
"9 rows affected.\n",
" \n",
"Try params for i = 1\n",
"1 rows affected.\n",
"Done.\n",
"18 rows affected.\n",
"Done.\n",
"18 rows affected.\n",
"18 rows affected.\n",
"18 rows affected.\n",
"Loop on s desc to prune mst table:\n",
"Pruning s = 3 with k = 3\n",
"Pruning s = 2 with k = 3\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
" \n",
"Best validation loss so far = \n",
"+-----------------------+\n",
"| validation_loss_final |\n",
"+-----------------------+\n",
"| 0.602479159832 |\n",
"+-----------------------+\n",
"Best validation accuracy so far = \n",
"+--------------------------+\n",
"| validation_metrics_final |\n",
"+--------------------------+\n",
"| 0.805599987507 |\n",
"+--------------------------+\n",
" \n",
"i=2\n",
"Done.\n",
"Loop on s desc to create diagonal table:\n",
"3 rows affected.\n",
"3 rows affected.\n",
"6 rows affected.\n",
" \n",
"Try params for i = 2\n",
"1 rows affected.\n",
"Done.\n",
"12 rows affected.\n",
"Done.\n",
"12 rows affected.\n",
"12 rows affected.\n",
"12 rows affected.\n",
"Loop on s desc to prune mst table:\n",
"Pruning s = 3 with k = 1\n",
"Pruning s = 2 with k = 1\n",
"Pruning s = 1 with k = 2\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
" \n",
"Best validation loss so far = \n",
"+-----------------------+\n",
"| validation_loss_final |\n",
"+-----------------------+\n",
"| 0.595765888691 |\n",
"+-----------------------+\n",
"Best validation accuracy so far = \n",
"+--------------------------+\n",
"| validation_metrics_final |\n",
"+--------------------------+\n",
"| 0.824999988079 |\n",
"+--------------------------+\n",
" \n",
"i=3\n",
"Done.\n",
"Loop on s desc to create diagonal table:\n",
"1 rows affected.\n",
"1 rows affected.\n",
"2 rows affected.\n",
"4 rows affected.\n",
" \n",
"Try params for i = 3\n",
"1 rows affected.\n",
"Done.\n",
"8 rows affected.\n",
"Done.\n",
"8 rows affected.\n",
"8 rows affected.\n",
"8 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
"1 rows affected.\n",
" \n",
"Best validation loss so far = \n",
"+-----------------------+\n",
"| validation_loss_final |\n",
"+-----------------------+\n",
"| 0.580716967583 |\n",
"+-----------------------+\n",
"Best validation accuracy so far = \n",
"+--------------------------+\n",
"| validation_metrics_final |\n",
"+--------------------------+\n",
"| 0.834100008011 |\n",
"+--------------------------+\n"
]
}
],
"source": [
"hp = Hyperband_diagonal(get_params, try_params)\n",
"results = hp.run()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"plot\"></a>\n",
"# 5. Review and plot results"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" <th>model_arch_table</th>\n",
" <th>num_iterations</th>\n",
" <th>start_training_time</th>\n",
" <th>end_training_time</th>\n",
" <th>s</th>\n",
" <th>i</th>\n",
" <th>run_id</th>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='sgd(lr=0.004501919010538727,momentum=0.9002808952996391)',metrics=['accuracy']</td>\n",
" <td>batch_size=256,epochs=5</td>\n",
" <td>madlib_keras</td>\n",
" <td>2159.70019531</td>\n",
" <td>[121.955986022949, 245.619317054749, 368.365077972412, 490.415205955505, 614.768485069275, 737.048167943954, 860.508330106735, 984.307431936264, 1106.31793498993, 1229.54079914093, 1352.66811394691, 1477.57317709923, 1599.99458003044, 1723.35215711594, 1847.86346912384, 1971.57312297821, 2096.37913298607, 2221.54790210724, 2346.08665895462, 2470.83494997025, 2595.6411960125, 2722.25887513161, 2846.48335313797, 2971.13271403313, 3097.49445009232, 3222.44972395897, 3348.5662779808]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.941940009594</td>\n",
" <td>0.169452220201</td>\n",
" <td>[0.574479997158051, 0.658760011196136, 0.695840001106262, 0.72733998298645, 0.733219981193542, 0.771200001239777, 0.778680026531219, 0.808700025081635, 0.809000015258789, 0.818579971790314, 0.835739970207214, 0.84799998998642, 0.853200018405914, 0.858900010585785, 0.872919976711273, 0.878780007362366, 0.88808000087738, 0.880240023136139, 0.894320011138916, 0.903779983520508, 0.912299990653992, 0.908439993858337, 0.919539988040924, 0.924639999866486, 0.929180026054382, 0.9375, 0.941940009593964]</td>\n",
" <td>[1.19219434261322, 0.959131419658661, 0.861107409000397, 0.770956337451935, 0.747268915176392, 0.64410811662674, 0.628470838069916, 0.539423823356628, 0.541868448257446, 0.514527797698975, 0.469026476144791, 0.432008743286133, 0.416983753442764, 0.402583330869675, 0.363078087568283, 0.346161216497421, 0.317243546247482, 0.340911239385605, 0.304346263408661, 0.274338334798813, 0.253901869058609, 0.262585163116455, 0.231020957231522, 0.218931555747986, 0.206650838255882, 0.184870630502701, 0.169452220201492]</td>\n",
" <td>0.816399991512</td>\n",
" <td>0.580716967583</td>\n",
" <td>[0.565699994564056, 0.641200006008148, 0.674899995326996, 0.704500019550323, 0.708000004291534, 0.740499973297119, 0.739799976348877, 0.766499996185303, 0.762099981307983, 0.76690000295639, 0.780900001525879, 0.785000026226044, 0.785300016403198, 0.79009997844696, 0.79449999332428, 0.795799970626831, 0.802600026130676, 0.792599976062775, 0.798399984836578, 0.807299971580505, 0.810500025749207, 0.801699995994568, 0.805400013923645, 0.811600029468536, 0.810100018978119, 0.813899993896484, 0.816399991512299]</td>\n",
" <td>[1.20952260494232, 1.00138294696808, 0.919946014881134, 0.846988558769226, 0.835236310958862, 0.748137712478638, 0.745132148265839, 0.670836567878723, 0.688502311706543, 0.673530399799347, 0.646275579929352, 0.626095473766327, 0.629233837127686, 0.623023450374603, 0.601795375347137, 0.603216171264648, 0.587353229522705, 0.635767936706543, 0.61867493391037, 0.594616591930389, 0.586753845214844, 0.60888147354126, 0.601007521152496, 0.593143999576569, 0.601291477680206, 0.583372294902802, 0.580716967582703]</td>\n",
" <td>model_arch_table_cifar10</td>\n",
" <td>27</td>\n",
" <td>2020-01-23 21:12:04.749779</td>\n",
" <td>2020-01-23 22:07:53.819497</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>65</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(45, 2, u\"loss='categorical_crossentropy',optimizer='sgd(lr=0.004501919010538727,momentum=0.9002808952996391)',metrics=['accuracy']\", u'batch_size=256,epochs=5', u'madlib_keras', 2159.70019531, [121.955986022949, 245.619317054749, 368.365077972412, 490.415205955505, 614.768485069275, 737.048167943954, 860.508330106735, 984.307431936264, 1106.31793498993, 1229.54079914093, 1352.66811394691, 1477.57317709923, 1599.99458003044, 1723.35215711594, 1847.86346912384, 1971.57312297821, 2096.37913298607, 2221.54790210724, 2346.08665895462, 2470.83494997025, 2595.6411960125, 2722.25887513161, 2846.48335313797, 2971.13271403313, 3097.49445009232, 3222.44972395897, 3348.5662779808], [u'accuracy'], 0.941940009594, 0.169452220201, [0.574479997158051, 0.658760011196136, 0.695840001106262, 0.72733998298645, 0.733219981193542, 0.771200001239777, 0.778680026531219, 0.808700025081635, 0.809000015258789, 0.818579971790314, 0.835739970207214, 0.84799998998642, 0.853200018405914, 0.858900010585785, 0.872919976711273, 0.878780007362366, 0.88808000087738, 0.880240023136139, 0.894320011138916, 0.903779983520508, 0.912299990653992, 0.908439993858337, 0.919539988040924, 0.924639999866486, 0.929180026054382, 0.9375, 0.941940009593964], [1.19219434261322, 0.959131419658661, 0.861107409000397, 0.770956337451935, 0.747268915176392, 0.64410811662674, 0.628470838069916, 0.539423823356628, 0.541868448257446, 0.514527797698975, 0.469026476144791, 0.432008743286133, 0.416983753442764, 0.402583330869675, 0.363078087568283, 0.346161216497421, 0.317243546247482, 0.340911239385605, 0.304346263408661, 0.274338334798813, 0.253901869058609, 0.262585163116455, 0.231020957231522, 0.218931555747986, 0.206650838255882, 0.184870630502701, 0.169452220201492], 0.816399991512, 0.580716967583, [0.565699994564056, 0.641200006008148, 0.674899995326996, 0.704500019550323, 0.708000004291534, 0.740499973297119, 0.739799976348877, 0.766499996185303, 0.762099981307983, 0.76690000295639, 0.780900001525879, 0.785000026226044, 0.785300016403198, 0.79009997844696, 0.79449999332428, 0.795799970626831, 0.802600026130676, 0.792599976062775, 0.798399984836578, 0.807299971580505, 0.810500025749207, 0.801699995994568, 0.805400013923645, 0.811600029468536, 0.810100018978119, 0.813899993896484, 0.816399991512299], [1.20952260494232, 1.00138294696808, 0.919946014881134, 0.846988558769226, 0.835236310958862, 0.748137712478638, 0.745132148265839, 0.670836567878723, 0.688502311706543, 0.673530399799347, 0.646275579929352, 0.626095473766327, 0.629233837127686, 0.623023450374603, 0.601795375347137, 0.603216171264648, 0.587353229522705, 0.635767936706543, 0.61867493391037, 0.594616591930389, 0.586753845214844, 0.60888147354126, 0.601007521152496, 0.593143999576569, 0.601291477680206, 0.583372294902802, 0.580716967582703], u'model_arch_table_cifar10', 27, datetime.datetime(2020, 1, 23, 21, 12, 4, 749779), datetime.datetime(2020, 1, 23, 22, 7, 53, 819497), 0, 3, 65)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql SELECT * FROM $results_table ORDER BY validation_loss_final ASC LIMIT 1;"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import MaxNLocator\n",
"from collections import defaultdict\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"sns.set_palette(sns.color_palette(\"hls\", 20))\n",
"plt.rcParams.update({'font.size': 12})\n",
"pd.set_option('display.max_colwidth', -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training dataset"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"65 rows affected.\n"
]
},
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" fig.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // This is important to make the div 'focusable'.\n",
" el.attr('tabindex', 0)\n",
" // Reach out to IPython and tell the keyboard manager to turn itself\n",
" // off when our div gets focus.\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" event.shiftKey = false;\n",
" // Send a \"J\" for go to next cell\n",
" event.which = 74;\n",
" event.keyCode = 74;\n",
" manager.command_mode();\n",
" manager.handle_keydown(event);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"1000\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
"df_results = %sql SELECT * FROM $results_table ORDER BY training_loss_final ASC LIMIT 100;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"# Set up plots\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Training metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Training loss curve')\n",
"\n",
"# Plot the per-iteration training metric and loss for each run\n",
"for run_id in df_results['run_id']:\n",
"    df_output_info = %sql SELECT training_metrics,training_loss FROM $results_table WHERE run_id = $run_id\n",
"    df_output_info = df_output_info.DataFrame()\n",
"    training_metrics = df_output_info['training_metrics'][0]\n",
"    training_loss = df_output_info['training_loss'][0]\n",
"    X = range(len(training_metrics))\n",
"\n",
"    ax_metric.plot(X, training_metrics, label=run_id, marker='o')\n",
"    ax_loss.plot(X, training_loss, label=run_id, marker='o')\n",
"\n",
"# Build the legend only after the curves exist: calling fig.legend() before\n",
"# plotting finds no labeled lines. Handles come from one axes so each run\n",
"# appears once instead of once per subplot.\n",
"handles, labels = ax_loss.get_legend_handles_labels()\n",
"fig.legend(handles, labels, ncol=4)\n",
"fig.tight_layout()\n",
"\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Validation dataset"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"65 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"1000\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
"df_results = %sql SELECT * FROM $results_table ORDER BY validation_metrics_final DESC LIMIT 100;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"# Set up one panel for the validation metric and one for the validation loss\n",
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
"\n",
"ax_metric = axs[0]\n",
"ax_loss = axs[1]\n",
"\n",
"ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_metric.set_xlabel('Iteration')\n",
"ax_metric.set_ylabel('Metric')\n",
"ax_metric.set_title('Validation metric curve')\n",
"\n",
"ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
"ax_loss.set_xlabel('Iteration')\n",
"ax_loss.set_ylabel('Loss')\n",
"ax_loss.set_title('Validation loss curve')\n",
"\n",
"# Plot the validation metric and loss curves of each run (top 100 by final metric)\n",
"for run_id in df_results['run_id']:\n",
"    df_output_info = %sql SELECT validation_metrics,validation_loss FROM $results_table WHERE run_id = $run_id\n",
"    df_output_info = df_output_info.DataFrame()\n",
"    validation_metrics = df_output_info['validation_metrics'][0]\n",
"    validation_loss = df_output_info['validation_loss'][0]\n",
"    X = range(len(validation_metrics))\n",
"\n",
"    ax_metric.plot(X, validation_metrics, label=run_id, marker='o')\n",
"    ax_loss.plot(X, validation_loss, label=run_id, marker='o')\n",
"\n",
"# Build the legend and layout only after the curves exist: calling fig.legend()\n",
"# before plotting finds no labeled lines. Use one axis' handles so each run is listed once.\n",
"handles, labels = ax_metric.get_legend_handles_labels()\n",
"fig.legend(handles, labels, ncol=4)\n",
"fig.tight_layout()\n",
"\n",
"# fig.savefig('./lc_keras_fit.png', dpi = 300)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"print\"></a>\n",
"# 6. Print run schedules (display only)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty-print the regular Hyperband run schedule (one bracket at a time)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"max_iter = 27\n",
"eta = 3\n",
"B = 4*max_iter = 108\n",
"skip_last = 0\n",
" \n",
"s=3\n",
"n_i r_i\n",
"------------\n",
"27 1.0\n",
"9.0 3.0\n",
"3.0 9.0\n",
"1.0 27.0\n",
" \n",
"s=2\n",
"n_i r_i\n",
"------------\n",
"9 3.0\n",
"3.0 9.0\n",
"1.0 27.0\n",
" \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"6 9.0\n",
"2.0 27.0\n",
" \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
"4 27\n",
" \n",
"sum of configurations at leaf nodes across all s = 8.0\n",
"(if there are more workers than this, some may not be 100% busy)\n"
]
}
],
"source": [
"import numpy as np\n",
"from math import log, ceil\n",
"\n",
"# input\n",
"max_iter = 27 # maximum iterations/epochs per configuration\n",
"eta = 3 # defines downsampling rate (default=3)\n",
"skip_last = 0 # 1 means skip last run in each bracket, 0 means run full bracket\n",
"\n",
"logeta = lambda x: log(x)/log(eta)\n",
"s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
"B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
"\n",
"# echo input\n",
"print (\"max_iter = \" + str(max_iter))\n",
"print (\"eta = \" + str(eta))\n",
"print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
"print (\"skip_last = \" + str(skip_last))\n",
"\n",
"sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
"\n",
"#### Begin Finite Horizon Hyperband outer loop (the whole procedure can be repeated indefinitely)\n",
"for s in reversed(range(s_max+1)):\n",
"\n",
"    print (\" \")\n",
"    print (\"s=\" + str(s))\n",
"    print (\"n_i r_i\")\n",
"    print (\"------------\")\n",
"    counter = 0\n",
"\n",
"    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
"    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
"    #### Begin Finite Horizon Successive Halving with (n,r)\n",
"    #T = [ get_random_hyperparameter_configuration() for i in range(n) ]\n",
"    for i in range((s+1) - int(skip_last)):\n",
"        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
"        n_i = n*eta**(-i)\n",
"        r_i = r*eta**(i)\n",
"\n",
"        print (str(n_i) + \" \" + str(r_i))\n",
"\n",
"        # check if leaf node for this s\n",
"        if counter == (s-skip_last):\n",
"            sum_leaf_n_i += n_i\n",
"        counter += 1\n",
"\n",
"        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
"        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
"    #### End Finite Horizon Successive Halving with (n,r)\n",
"\n",
"print (\" \")\n",
"print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
"print (\"(if there are more workers than this, some may not be 100% busy)\")"
]
},
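{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printing loop above can also be packaged as a function that returns the schedule. That makes it easier to inspect or test. This is a minimal sketch, not part of the MADlib workflow, and the name `hyperband_schedule` is hypothetical. It departs from the cell in two deliberate ways. First, `s_max` is computed with an integer loop, because `log(x)/log(eta)` can land just below an exact power: for `max_iter=1000, eta=10`, `log(1000)/log(10)` evaluates to `2.9999999999999996`, so `int()` gives 2 instead of 3. Second, `n_i` uses floor division, as in the paper's floor(n * eta^-i), so it is an integer rather than a float.\n",
"\n",
"```python\n",
"def hyperband_schedule(max_iter, eta=3, skip_last=0):\n",
"    # Hypothetical helper: returns {s: [(n_i, r_i), ...]} for each bracket s\n",
"    # s_max = floor(log_eta(max_iter)), computed with integers only\n",
"    s_max = 0\n",
"    while eta ** (s_max + 1) <= max_iter:\n",
"        s_max += 1\n",
"    B = (s_max + 1) * max_iter  # budget per bracket, as in the cell above\n",
"\n",
"    schedule = {}\n",
"    for s in reversed(range(s_max + 1)):\n",
"        # Same value as int(ceil(int(B/max_iter/(s+1))*eta**s)) for integer inputs\n",
"        n = (B // max_iter // (s + 1)) * eta ** s\n",
"        rungs = []\n",
"        for i in range((s + 1) - skip_last):\n",
"            n_i = n // eta ** i  # floor(n * eta^-i)\n",
"            r_i = float(max_iter) / eta ** (s - i)  # r * eta^i, with r = max_iter * eta^-s\n",
"            rungs.append((n_i, r_i))\n",
"        schedule[s] = rungs\n",
"    return schedule\n",
"\n",
"\n",
"schedule = hyperband_schedule(27, eta=3)\n",
"for s in sorted(schedule, reverse=True):\n",
"    print('s=%d %s' % (s, schedule[s]))\n",
"```\n",
"\n",
"For `max_iter=27, eta=3` this returns the same `(n_i, r_i)` pairs as the cell above. Summing the last `n_i` of each bracket gives 1 + 1 + 2 + 4 = 8, which is the leaf-node count printed there."
]
},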
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pretty-print the Hyperband diagonal run schedule (brackets advanced together along diagonals)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from math import log, ceil\n",
"\n",
"# input\n",
"max_iter = 27 # maximum iterations/epochs per configuration\n",
"eta = 3 # defines downsampling rate (default=3)\n",
"skip_last = 1 # 1 means skip last run in each bracket, 0 means run full bracket\n",
"\n",
"logeta = lambda x: log(x)/log(eta)\n",
"s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
"B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
"\n",
"# echo input\n",
"print (\"echo input:\")\n",
"print (\"max_iter = \" + str(max_iter))\n",
"print (\"eta = \" + str(eta))\n",
"print (\"s_max = \" + str(s_max))\n",
"print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
"\n",
"print (\" \")\n",
"print (\"initial n, r values for each s:\")\n",
"initial_n_vals = {}\n",
"initial_r_vals = {}\n",
"# initial number of configurations (n) and iterations (r) for each bracket s\n",
"for s in reversed(range(s_max+1)):\n",
"\n",
"    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
"    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
"    initial_n_vals[s] = n\n",
"    initial_r_vals[s] = r\n",
"\n",
"    print (\"s=\" + str(s))\n",
"    print (\"n=\" + str(n))\n",
"    print (\"r=\" + str(r))\n",
"    print (\" \")\n",
"\n",
"print (\"outer loop on diagonal:\")\n",
"# outer loop on diagonal\n",
"for i in range((s_max+1) - int(skip_last)):\n",
"    print (\" \")\n",
"    print (\"i=\" + str(i))\n",
"\n",
"    print (\"inner loop on s desc:\")\n",
"    # inner loop on s desc: bracket s joins the diagonal at step i = s_max-s,\n",
"    # so at step i it is at rung j = i-s_max+s of its own Successive Halving\n",
"    for s in range(s_max, s_max-i-1, -1):\n",
"        n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
"        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
"\n",
"        print (\"s=\" + str(s))\n",
"        print (\"n_i=\" + str(n_i))\n",
"        print (\"r_i=\" + str(r_i))"
]
},
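{
"cell_type": "markdown",
"metadata": {},
"source": [
"Writing the diagonal schedule as a function makes one property easy to check. Since r_i = max_iter * eta^-s * eta^(i - s_max + s) = max_iter * eta^(i - s_max), every bracket that is active at diagonal step i trains its configurations for the same number of iterations. That suits running each step behind a single synchronous barrier. This is a minimal sketch with a hypothetical helper name, `hyperband_diagonal`, and it is not a MADlib API. It reuses the cell's initial n and r but computes `s_max` with an integer loop and `n_i` with floor division.\n",
"\n",
"```python\n",
"def hyperband_diagonal(max_iter, eta=3, skip_last=1):\n",
"    # Hypothetical helper: returns one list per diagonal step i, holding\n",
"    # (s, n_i, r_i) for every bracket s that is active at that step\n",
"    s_max = 0\n",
"    while eta ** (s_max + 1) <= max_iter:  # integer floor(log_eta(max_iter))\n",
"        s_max += 1\n",
"    B = (s_max + 1) * max_iter\n",
"\n",
"    # Initial (n, r) per bracket, same values as initial_n_vals / initial_r_vals\n",
"    initial = {}\n",
"    for s in range(s_max + 1):\n",
"        initial[s] = ((B // max_iter // (s + 1)) * eta ** s,\n",
"                      float(max_iter) / eta ** s)\n",
"\n",
"    diagonals = []\n",
"    for i in range((s_max + 1) - skip_last):\n",
"        step = []\n",
"        for s in range(s_max, s_max - i - 1, -1):\n",
"            j = i - s_max + s  # rung of bracket s reached at step i\n",
"            n, r = initial[s]\n",
"            step.append((s, n // eta ** j, r * eta ** j))\n",
"        diagonals.append(step)\n",
"    return diagonals\n",
"\n",
"\n",
"for i, step in enumerate(hyperband_diagonal(27, eta=3, skip_last=1)):\n",
"    print('i=%d %s' % (i, step))\n",
"```\n",
"\n",
"With `skip_last=1` the three diagonal steps train for 1, 3 and 9 iterations respectively. The s=0 bracket never starts, because its only rung is the final one that gets skipped."
]
},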
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"predict\"></a>\n",
"# 7. Inference\n",
"\n",
"Use the best model from the last run.\n",
"\n",
"## 7a. Run predict on the whole validation dataset"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>mst_key</th>\n",
" <th>model_id</th>\n",
" <th>compile_params</th>\n",
" <th>fit_params</th>\n",
" <th>model_type</th>\n",
" <th>model_size</th>\n",
" <th>metrics_elapsed_time</th>\n",
" <th>metrics_type</th>\n",
" <th>training_metrics_final</th>\n",
" <th>training_loss_final</th>\n",
" <th>training_metrics</th>\n",
" <th>training_loss</th>\n",
" <th>validation_metrics_final</th>\n",
" <th>validation_loss_final</th>\n",
" <th>validation_metrics</th>\n",
" <th>validation_loss</th>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>2</td>\n",
" <td>loss='categorical_crossentropy',optimizer='adam(lr=0.002826545217978097)',metrics=['accuracy']</td>\n",
" <td>batch_size=128,epochs=5</td>\n",
" <td>madlib_keras</td>\n",
" <td>2159.70019531</td>\n",
" <td>[156.498700857162, 314.38369679451, 471.076618909836]</td>\n",
" <td>[u'accuracy']</td>\n",
" <td>0.89631998539</td>\n",
" <td>0.301868826151</td>\n",
" <td>[0.817480027675629, 0.862479984760284, 0.896319985389709]</td>\n",
" <td>[0.536632478237152, 0.400230169296265, 0.301868826150894]</td>\n",
" <td>0.805899977684</td>\n",
" <td>0.613121390343</td>\n",
" <td>[0.764500021934509, 0.788500010967255, 0.805899977684021]</td>\n",
" <td>[0.717438697814941, 0.662977695465088, 0.613121390342712]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(6, 2, u\"loss='categorical_crossentropy',optimizer='adam(lr=0.002826545217978097)',metrics=['accuracy']\", u'batch_size=128,epochs=5', u'madlib_keras', 2159.70019531, [156.498700857162, 314.38369679451, 471.076618909836], [u'accuracy'], 0.89631998539, 0.301868826151, [0.817480027675629, 0.862479984760284, 0.896319985389709], [0.536632478237152, 0.400230169296265, 0.301868826150894], 0.805899977684, 0.613121390343, [0.764500021934509, 0.788500010967255, 0.805899977684021], [0.717438697814941, 0.662977695465088, 0.613121390342712])]"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql SELECT * FROM $best_model_info;"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
}
],
"source": [
"best_mst_key = %sql SELECT mst_key FROM $best_model_info; \n",
"best_mst_key = best_mst_key.DataFrame().to_numpy()[0][0]"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"5 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>estimated_y</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql DROP TABLE IF EXISTS cifar10_val_predict;\n",
"%sql SELECT madlib.madlib_keras_predict('cifar10_best_model', 'cifar10_val', 'id', 'x', 'cifar10_val_predict', 'response', True, $best_mst_key);\n",
"%sql SELECT * FROM cifar10_val_predict ORDER BY id LIMIT 5;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Count misclassifications"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1941</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1941L,)]"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT COUNT(*) FROM cifar10_val_predict JOIN cifar10_val USING (id) \n",
"WHERE cifar10_val_predict.estimated_y != cifar10_val.y;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accuracy"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>test_accuracy_percent</th>\n",
" </tr>\n",
" <tr>\n",
" <td>80.59</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(Decimal('80.59'),)]"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"SELECT round(100.0 * SUM(CASE WHEN cifar10_val_predict.estimated_y = cifar10_val.y THEN 1 ELSE 0 END) / COUNT(*), 2)\n",
"    AS test_accuracy_percent\n",
"FROM cifar10_val_predict JOIN cifar10_val USING (id);"
]
},
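{
"cell_type": "markdown",
"metadata": {},
"source": [
"The misclassification count and the accuracy above agree for the 10,000-image validation set. They also match the best model's `validation_metrics_final` of about 0.8059:\n",
"\n",
"$$\\text{accuracy} = \\frac{10000 - 1941}{10000} \\times 100 = 80.59\\%$$"
]
},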
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7b. Select a random image from the validation dataset and run predict\n",
"\n",
"Label map"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"label_names = {\n",
"    0: \"airplane\",\n",
"    1: \"automobile\",\n",
"    2: \"bird\",\n",
"    3: \"cat\",\n",
"    4: \"deer\",\n",
"    5: \"dog\",\n",
"    6: \"frog\",\n",
"    7: \"horse\",\n",
"    8: \"ship\",\n",
"    9: \"truck\"\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pick a random image"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar10_val_random;\n",
"CREATE TABLE cifar10_val_random AS\n",
" SELECT * FROM cifar10_val ORDER BY random() LIMIT 1;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>prob_0</th>\n",
" <th>prob_1</th>\n",
" <th>prob_2</th>\n",
" <th>prob_3</th>\n",
" <th>prob_4</th>\n",
" <th>prob_5</th>\n",
" <th>prob_6</th>\n",
" <th>prob_7</th>\n",
" <th>prob_8</th>\n",
" <th>prob_9</th>\n",
" </tr>\n",
" <tr>\n",
" <td>9813</td>\n",
" <td>7.9166554e-08</td>\n",
" <td>0.00038159246</td>\n",
" <td>8.776156e-11</td>\n",
" <td>1.7702625e-08</td>\n",
" <td>1.2219187e-10</td>\n",
" <td>8.096258e-10</td>\n",
" <td>5.192042e-10</td>\n",
" <td>1.5758073e-09</td>\n",
" <td>4.106987e-07</td>\n",
" <td>0.99961793</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(9813, 7.9166554e-08, 0.00038159246, 8.776156e-11, 1.7702625e-08, 1.2219187e-10, 8.096258e-10, 5.192042e-10, 1.5758073e-09, 4.106987e-07, 0.99961793)]"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%sql DROP TABLE IF EXISTS cifar10_val_random_predict;\n",
"%sql SELECT madlib.madlib_keras_predict('cifar10_best_model', 'cifar10_val_random', 'id', 'x', 'cifar10_val_random_predict', 'prob', True, $best_mst_key);\n",
"%sql SELECT * FROM cifar10_val_random_predict ;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
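"Pack the per-class probabilities into a single array with `madlib.cols2vec`, then display the image with its top-3 predicted labels"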
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>feature_vector</th>\n",
" </tr>\n",
" <tr>\n",
" <td>[7.9166554e-08, 0.00038159246, 8.776156e-11, 1.7702625e-08, 1.2219187e-10, 8.096258e-10, 5.192042e-10, 1.5758073e-09, 4.106987e-07, 0.99961793]</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[([7.9166554e-08, 0.00038159246, 8.776156e-11, 1.7702625e-08, 1.2219187e-10, 8.096258e-10, 5.192042e-10, 1.5758073e-09, 4.106987e-07, 0.99961793],)]"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS cifar10_val_random_predict_array, cifar10_val_random_predict_array_summary;\n",
"SELECT madlib.cols2vec(\n",
" 'cifar10_val_random_predict',\n",
" 'cifar10_val_random_predict_array',\n",
" '*',\n",
" 'id'\n",
");\n",
"select * from cifar10_val_random_predict_array;"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n",
"1 rows affected.\n",
" \n",
"truck 0.99961793\n",
"automobile 0.00038159246\n",
"ship 4.106987e-07\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHtdJREFUeJztnWmQXGeVpt+TtWTti1SLSqWSSpttZNmSbHkDG9vtMW3ci2G62wHdQfgH0erogIghoueHg4kYmIn5ARMDBBMxwYRoPJgOBsMAbgxN0IBpjzewXd4ka7G1L6WSqkpSqfYlM8/8yBQjyd97K1VVypL6e58IhbK+N7+8X968J2/mffOcY+4OIUR8pBZ7AUKIxUHBL0SkKPiFiBQFvxCRouAXIlIU/EJEioJfiEhR8AsRKQp+ISKlfD6TzewhAF8HUAbg7939S0n3b2lp8e7u7vlscoFY2F815pJ+JZkgpVL8vTeXNM8S1pILj1vCnKTlJ23Ls1mqZchCkp5Xli0ewMxMhmvTM/wxs+HHTFdW0jkV5TwsqqrSVLOEnWxlVALItKSj1Il69PARnB4cTHjV/j9zDn4zKwPwPwA8COA4gNfM7Bl3383mdHd3o+eVnsvelrOnYkm7hx9IMK7N5W1hairh4Mvw16G6tppq01N8e5UVCWuZDI8nvM8gw2MYdTxGMDk0TLWzY2PB8QkSjAAwPDZNtb6Tg1Q7cbyPaiNkjd1dK+mczrZ2qq2/bi3VKtJ8J1fUcC1VFR6fTHinzFj4Rbv/tjvonPdtt+h7vp/bAex394PuPg3gKQCPzOPxhBAlZD7B3wng2AV/Hy+MCSGuAa74BT8z22ZmPWbWMzAwcKU3J4QokvkEfy+Argv+XlEYuwh33+7uW919a2tr6zw2J4RYSOYT/K8BWG9mq82sEsAnADyzMMsSQlxp5ny1390zZvZZAP+MvNX3hLvvmnVeUSbE4pG0PHbttSpNLtcCyJbzR0wlWAuT49xBmAK/3H/y1FBwfHSU2AAApqe4jXbivTepNpPh88bIY54ZHaFzTg6epdrgGe4sTE9wa8Q8fH7b+85ROqehhrswKzqXU62xuY5qLR1LudYZ/kTc1L6Ezunqag6OW+IRfDHz8vnd/ecAfj6fxxBCLA76hZ8QkaLgFyJSFPxCRIqCX4hIUfALESnzutpfMq4Se5AuIyEtrrwsIXMvIaHmpRdepNqxoyepdub0ueD48PA4nTM5wRNqMMltxYkZbkeOTYbtt4kEexBlPGOuqqqBakkZeqls+LUZOj1K50wMcetwbITv+5kcn5c1/rwz5eEDoXtdN53zJ/82nEYzOZlwUF2CzvxCRIqCX4hIUfALESkKfiEiRcEvRKSU/Gp/Ui05PumyhXlol8/UFL/Km07XUK33GC8/9Q9Pfodqb7/9DtWWNrYFx1uWhscBoLExnCQCANMzPJHFEq7O58rCz7uikhezK0tIkKqo4K5DhjgLADA5Ei4n1ljNE23Gh8OOCQCcOnmCa4P89cySslsAUFYbfm6T03xfbdgY3tbUJHdgLkVnfiEiRcEvRKQo+IWIFAW/EJGi4BciUhT8QkRKya2+udTwW/i8nqT3vIROP2ReZVlCW5uEOn1LG5uotqpzBdWGBsJ1+gBgdCRcq2//3j10Tn19I9Vql3KtuYXbZQ314ec2neUJLmfP8hp+Q0P8OU+R7kAAYCT5KDsxQedUV3CLrbODd/PZcM/dVNt4601UW31DuAvQkja+f9dvCK/xy/+Z26WXojO/EJGi4BciUhT8QkSKgl+ISFHwCxEpCn4hImVeVp+ZHQYwAiALIOPuWxdiUZfC3LIrU9ov4f2QLMTKEnZjhluHNfW8LdRHP/Ig1TZt4LbRyy+9Ghz/p5/yxkqZKV7Db2BiH9UOHHuXamXELqup4y2t6up5BmRdDbdT25vr+TxS368zwUZbuYzbeXfceRvVlq/g86
rquX1YTnZJUjm+adJ9zZOc6ku3W/xdKfe7++ACPI4QooToY78QkTLf4HcAvzSz181s20IsSAhRGub7sf9ud+81szYAvzKzve7+/IV3KLwpbAOAlStXznNzQoiFYl5nfnfvLfzfD+BpALcH7rPd3be6+9bW1nAfciFE6Zlz8JtZrZnVn78N4CMAeHE5IcRVxXw+9rcDeNryFTnLAfxvd//FgqxqQZhj5l5CFt5cGBse4atIqLU4dPoM1dKVvJjl+tXdwfEP3vm+D2W/Z/XqcFYZAAyV8+KevX28mGWWZO9dRzLYAGDLZm5hdi5roVqFJxTHnAnbmEcPvJfweNz6rKkYoNrk6DDVjp3gWYn1S8MZkG0rePHUGtK+LHUZp/M5B7+7HwSwaa7zhRCLi6w+ISJFwS9EpCj4hYgUBb8QkaLgFyJSSl7Ak5ls18S7kJNV5rg/WNvEC2Am2YozMwn958YSrKjacAHHZe38B1btbUuoduMHbqDayGgn1ZqawqlqmzZvpHNWdnVQbfj0Kaodepf/vGSoP2xH3rCaP+djB/dTLTfJC2Q2tfP9UdfAsxKbOsL7qryOFzsdGTsaHPccPzYu5ZqIOSHEwqPgFyJSFPxCRIqCX4hIUfALESklvdrv4BfGUwkF+XLEIkgZv1xeSDgKryOhNlqWtHcCgPLKNNsYf8AZnkQ0PcGv6Dc2cpfg3JnwlV4A6Os7GRxftoxf7Z+cHKfa2dN8W1VV/Ap2fS0pTJfhLbleePZ3VBsa6KXa9at4a7NxOxccL88mHDszp6mWG6+lWtPqLVTr+eU/U63yZDhJ5+aEeoHpVPggtsvITNOZX4hIUfALESkKfiEiRcEvRKQo+IWIFAW/EJFS8sQeY15fGbfLLEXmzLHeXpINWFbG2yrNqT9YBX9/tUn+gCOjYYsKAGYSiv91dIRbRk1MTNA5vb3cRqvuCNeXA4DqBm57WS5sRTnrMwXgum5e2n35rTdSLV3O7dRf7N0ZHB/LcZt1aR1vG9ZQze3NzIFDVBs7y1/PuibSOowvkSeZXUZM6MwvRKQo+IWIFAW/EJGi4BciUhT8QkSKgl+ISJnV6jOzJwD8MYB+d99YGFsC4PsAugEcBvCou/N0rYsekNky3GIz6rFxX8N9bhl/Vp6wS8hDnukfpFNGR3jGXLqct91qbiH2D4DyCj5vcjJspWVP81TGFQlZcdWkLRQAWJbbXpUI17pb3tZF59RX8XNRdoJbZaP9vIXW2NBYcHxJObfzynL8GBgf4nX1ppfy46qpju/jcoQzOKcT6gVWNoYtWEvx1+RSijnzfxvAQ5eMPQ7gWXdfD+DZwt9CiGuIWYPf3Z8HcGnXyEcAPFm4/SSAjy3wuoQQV5i5fudvd/e+wu2TyHfsFUJcQ8z7gp/nv1zTL9hmts3MesysZ3CAfzcTQpSWuQb/KTPrAIDC//3sju6+3d23uvvWllZeSkoIUVrmGvzPAHiscPsxAD9ZmOUIIUpFMVbf9wDcB6DFzI4D+AKALwH4gZl9GsARAI/OdyGeYNulmNWXYNllMtySKZ+DnQcA2WzYpjx2/DidM5FQpLOZ2DUAsLSthWqNS/i84eGR4PjYBLccM+N8X6VQTbWZbELGYio8r7Ep4dNfhmceHjq4m2rDp8ItuQBgfDj83EYreGbkqT7eGqzv5NtU2/aZzVSrrVlGtemZsO3oWd5SDM4sveJTT2cNfnf/JJEeKHorQoirDv3CT4hIUfALESkKfiEiRcEvRKQo+IWIlJIW8DRw286JjQYASCqqScixBn+zkeCUnDsXziw7dqIvOA4Azc3NVBtLKGZZWU36AgJY2thGtevrw1l4re3cYtu1axfV+k5x+625iWf8ZcjuHx9P6IWY5drYJLdMR8b4Gscnwo/57nu82Ob6tddR7Xhv2EoFgCFiKwLAD//xWaqVN4Vt3b9aeTOd00paOSYks74PnfmFiBQFvxCRouAXIlIU/EJEioJfiEhR8AsRKSXv1Z
cibzeZzBwb7xHc59JYD8gwjwrAKVKos/cEzyobIvYgAMzMcGtr6223UG04oY/foSMHg+PpNLcO11y/lmrH+t+j2rQnNJMjR1Z5NS8+Oth7lGq/ffW3VJs4w4vE2EzYTh0fHaZzVq3l+6NnJ98fz77wItVefWsH1e7+wz8Jjreu4MU4Z8jhfTmHvc78QkSKgl+ISFHwCxEpCn4hIkXBL0SklPxqPyOX40kRpVzm1BS/gt0/GL6qbGX8PXT33j1U27t3L9XWrl9DtSSX4B9/+kxwfNUq3iZr48aNVPNy3uZrJjVNNasMuzcVvCQgpnI8QefAsX1UGxvgNfc6m8N18FqX8xqJR/u46/Di716g2sxrb1Htxjs+TLWbbr81ON7Lu8ChbzDchmxioviENp35hYgUBb8QkaLgFyJSFPxCRIqCX4hIUfALESnFtOt6AsAfA+h3942FsS8C+GsA572vz7v7z+ezkLml4Sw80wk22shIuH5bdTX3r4aGhqj2+luv822Nh60cALjhhhuodvPNYduuq4tbffWNvBbf6BS32JpbaqmWtbANODzO90dtA08+2nIbr2fXd4An25w9fjK8rfpOOicDbmEuX8Xn7d7PazlO5/hx9cS3/1dwfDLF98etH7w7OD48ymsMXkoxZ/5vA3goMP41d99c+DevwBdClJ5Zg9/dnwdwpgRrEUKUkPl85/+sme0wsyfMjNenFkJclcw1+L8BYC2AzQD6AHyF3dHMtplZj5n1DAzwogtCiNIyp+B391PunnX3HIBvArg94b7b3X2ru29tbU3ozS6EKClzCn4z67jgz48DeGdhliOEKBXFWH3fA3AfgBYzOw7gCwDuM7PNABzAYQB/U9zmHMiGs/cqKnhLrumpcLZXZUJduokJbnkcOsKzwCoqeI25rpXtwfHrb1hN56R5GTYsbeZWWU3C23JDOd9XD33onvCcxjo651Rf2A4DgJYantU3fprXLmzqXhkcb0w10Tn7+/hrVjbFW2jV13D77Wz58eD4zoO8ZuTYLp5d6LV3Uq3zunGq7Xinn2p1zWGrNVXB1zjVdzo47jNJ2bEXM2vwu/snA8PfKnoLQoirEv3CT4hIUfALESkKfiEiRcEvRKQo+IWIlBIX8DQgFd7kyAhvn8Rg7bOA5Gw6GLfK6uu4JTZNantWV9XTOffccz/Vrlv3AaplZ3gh0f37wy25AKCqIrx/q6qq6JwUeU0AIJvh54dsQou1iYlwFttbb/KipS899zLVdr+9m2qe5WtM5djxNkrnTM1we/NkP7eJhyd4JmYuxfNW73ngrvD4fffSOd3d3cHxp3/63+mcS9GZX4hIUfALESkKfiEiRcEvRKQo+IWIFAW/EJFSWqvPHdnpcHHEl1/6HZ3W1bUqOG4JFlVbK8/0SlfVUM0SbMDxsbAFlJnhNk5lFS/u2dzURrX+kzxjzsAzDzs6VgTH6+q4HTk+xi2qsRGeJbZ29TqqrVi+Njj+6kuv0DnPP/8S1Zy7b8AMtxxHz4UtvdFRbvWdSbCJ+wZ5BuT9D9xHtRs38wKkt31wc3B8w43h4x4ArCz8nFP88H3/fYu/qxDiXxMKfiEiRcEvRKQo+IWIFAW/EJFS0qv92WwWw8Phq6zPPfc8nbdly63B8ZbWZXROupK3oEpK+sklJImYhQvylZFxADhymNduSxm/hJ2u5Ffn65t5kk5tTXjeuXO8Pt742CTVfvvyW1Q7cYz3cpkkBsILv+Gv846de6jW0dpBtdFh/nqeOxvWNmzgLc/Kq7l7c2CAr7GmgR87N97E6zw2Nocv0bvxZDejl/WTbJGL0ZlfiEhR8AsRKQp+ISJFwS9EpCj4hYgUBb8QkVJMu64uAN8B0I58e67t7v51M1sC4PsAupFv2fWou59NeqyRkVH83+deCGqvvtpD57EEmKmpBFvOeS2+iQluh0xNca2mOtxey50nlkxPcRttRWe4/RcANNTxVl5jo9xi29u/LziezYYTqgCgsY
Hvq1s2h+vLAcCbb7xOtdP9vw6OnzrRR+dMT+WoduDAAb6tEV5Xr61hSXC8fflSOuem9vVUe+/QDqr1DxyjWgbctqutXx4cr0uwDuvIa5bU9u5SijnzZwD8nbtvAHAngM+Y2QYAjwN41t3XA3i28LcQ4hph1uB39z53f6NwewTAHgCdAB4B8GThbk8C+NiVWqQQYuG5rO/8ZtYNYAuAVwC0u/v5z3Ankf9aIIS4Rig6+M2sDsCPAHzO3S/6AuP5L73BL75mts3MesysZ3j48mvzCyGuDEUFv5lVIB/433X3HxeGT5lZR0HvABD8Ebu7b3f3re6+taGB/95eCFFaZg1+MzMA3wKwx92/eoH0DIDHCrcfA/CThV+eEOJKUUxW34cAfArATjM7n+L1eQBfAvADM/s0gCMAHp3tgZKy+jIJLZJ2kFZNr722i87pWNZNteZmfnlicIB/NTk3FF57b28vndPWGraaAGDzZt6u66Ybr6NaKsVtu8x0OJ2urb2JzqmoSFPt/vt5uzGWMQcAe3aHX7PBQd5iLZNJeF65cPsvAKgAt7eqa8IZl8s7ef3EnPFtTWfHqVZbx8NpVXfYzgOAtvbG4LiDt2wrKye1IY3bzpcya/C7+4sAWI7jA0VvSQhxVaFf+AkRKQp+ISJFwS9EpCj4hYgUBb8QkVLSAp7pdBpr1qwJamvX8kyqd/ceDI6/+SYvppjL8qKaK1bwNlMz3G2ibbKWLuUZYpPjPIutLMXfezdt3Mi1m7mWzYTtyOoa/lIPn+NZgv/0M/7zjYOHwhmEAHDk6P7geFJWXzbB6qur4XZkfRNviTY+Fd4fVbW85dmBw4f5Ohr4tj5wE7du29paqFZdHX7Mikr+mpUTe9OoMfd+dOYXIlIU/EJEioJfiEhR8AsRKQp+ISJFwS9EpJTU6hsdHcVLL70c1KanM3ReZ2dXcDyV4oUnDx04QbXx8QmqVZbzmgPXXx/u75ZUwPPY0SNU27njHaqd/kOeTTc5yS2xPbvCj3mij69j1zu8H9/eIzwL79y5c1RLlYUz49as7aRzJidIgz8AJ44fpdrYOM/CKy8LW1+//s2v6Jylbc1Ue/jhh6n2kQcfolpjI3/Mhvpw5meS1TdNPOmEQ/F96MwvRKQo+IWIFAW/EJGi4BciUhT8QkSKJV2pXmjaWtv8zz8WLvX3zM9+QedlM+H3qC1beCupu+64j2rfe+ppqg2f4+21amvqg+Pj47yuW3UVTzCC8xptyzv41eG6hFpxVZXh13Plqg465/TgSapNkqvlAJCZ5lfZ2XFlOX68jU+Ek3AAYHR4hGoVCQlSdXVhR6hzBd8ft956K9Vuv+t2qq3sXkW1mpoaqlWkw0lLlvC8srlwzcs77rgDPT09RWX36MwvRKQo+IWIFAW/EJGi4BciUhT8QkSKgl+ISJk1scfMugB8B/kW3A5gu7t/3cy+COCvAQwU7vp5d/950mNVV9dg06ZNQW3ve+E6fQCwj2iHDofrxAHAsvYVVGtqrqXa6dOnqXb4SLj+3JYtW+iculpe8+3NN16hWlcXt/ruvPM2qhlp8TQ+zq2y5ma+rTOTZ6mWSmgNlW/x+H7qavj+WFXHW2glrXFVVzfVurrCSWHLl/P2Wa2tvJ1bQ2O4tVYefi7NJViccDIvYUqZ8RZlxVJMVl8GwN+5+xtmVg/gdTM7nxL1NXf/b/NehRCi5BTTq68PQF/h9oiZ7QHA8zKFENcEl/Wd38y6AWwBcP7z6mfNbIeZPWFm/HOZEOKqo+jgN7M6AD8C8Dl3HwbwDQBrAWxG/pPBV8i8bWbWY2Y9o6P8e6cQorQUFfxmVoF84H/X3X8MAO5+yt2z7p4D8E0AwR89u/t2d9/q7lvr6sK/jRdClJ5Zg9/yl22/BWCPu3/1gvELMyM+DoDXpBJCXHUUc7X/QwA+BWCnmZ0v9vZ5AJ80s83IGxKHAfzNbA9UU1NNrb7TZ3k9uK4V4WypAwd5Xbo33vwd1U718/
ZUQ8PDVFt3XbilWHsHr/v34L95gGqTUwNUS1fxxKzB0zwLb7A/bEcOj/DnvCYhG23deq4xOw/gLaiampronKYGvh8bEyy29rZlVGtrC9uHZWW8XVdlFdeyWe6/uYcz7QAgVc6zO42dgq9wwm0xV/tfBIINwBI9fSHE1Y1+4SdEpCj4hYgUBb8QkaLgFyJSFPxCREpJ23W5A7lcLqgNDXEryhEuFHnzphvpnJkZ3v7reG/YDgOA3/zmOb4OC7eT6ujkNtR1N/A0iPpGvvuHh7gNeOAgb9dVUxV+zDVruumcD931QaqtuWkl1crKeGYZs/qqqqroHEtxbyub4VrSOlAWPr+lE7ILK9N8jRUVCZZdUWUzA7CnlmD1TZI6sx4OryA68wsRKQp+ISJFwS9EpCj4hYgUBb8QkaLgFyJSSmr15XJZjI2F+7EdPnKIztu9a09wvHvNajqntbWVa+3h/m0A0NDEbaPy9ERw/AMbebHQw0d3Ue3I0b1UW72KZ9PdcstGrm3aHBxfuTJcyBIAOhOKWU6leD++nHM7lVm6bBwAchmupVLcRysv54dxOh225hoak2pLJJwT55hpl5Dwx7P6EqzDKtL6jz5WAJ35hYgUBb8QkaLgFyJSFPxCRIqCX4hIUfALESkltfrS6SqsW7cuqN1777103thYuN7/3nd30jkvvHiKanUNvFffuZF+qrW2h+ftO/A2nTM8xAuTrli5lGp/8eifUu2mG2+mWtvScMHK6ir+nGtIBh4AtDRwSyyTYNs5mNXH7UFL6P1XXplgwZYnncPCjzk1TdLiAKRSfFupVMK2ctyb8wQtRcIwVZbg9S1AcU+d+YWIFAW/EJGi4BciUhT8QkSKgl+ISJn1ar+ZVQF4HkC6cP8fuvsXzGw1gKcALAXwOoBPuTsvLgcgXVWJ1eu6g9oflX+UL5Ks8vmX+BXsfft40sz4FO8WnHX+mGfO9gbH39vHr/YvX8Zr+P3lX/0Z1R56kLf5qq3mV+CzU+HLwKkUf6lrEpJcsvziPCpJfTwAALsCbwmHXMLF7VxCcbpMJiHBiCQfpSvTfGPEqciT8JwTT6VJBf7CWT+5aZ4NlM2GNU9wYC6lmDP/FIA/cPdNyLfjfsjM7gTwZQBfc/d1AM4C+HTRWxVCLDqzBr/nOZ+HW1H45wD+AMAPC+NPAvjYFVmhEOKKUNR3fjMrK3To7QfwKwAHAAy5//4z1XEA/POtEOKqo6jgd/esu28GsALA7QBuKHYDZrbNzHrMrGdggNeiF0KUlsu62u/uQwD+BcBdAJrMfn/1ZgWA4NUwd9/u7lvdfWtSdR0hRGmZNfjNrNXMmgq3qwE8CGAP8m8Cf16422MAfnKlFimEWHiKSezpAPCkmZUh/2bxA3f/mZntBvCUmf0XAG8C+Nasj+RAbiZsUXR385p19953T3C8rpG3Vdq1m9elO37iMNUq03yX1NaGE2A2bNhA53Sv5HUG77jtTqqly0iRNgDlKW5TVTSS+oQJll2SC5XUCWtOvxJJSkhJWEdSQk2SjckWmXNem3Au1iEAlKUqqFZRzjVGqjLhOZPnZQm1Di9l1uB39x0AtgTGDyL//V8IcQ2iX/gJESkKfiEiRcEvRKQo+IWIFAW/EJFi7gtQDKzYjZkNADhS+LMFwGDJNs7ROi5G67iYa20dq9y9qF/TlTT4L9qwWY+7b12UjWsdWofWoY/9QsSKgl+ISFnM4N++iNu+EK3jYrSOi/lXu45F+84vhFhc9LFfiEhZlOA3s4fM7F0z229mjy/GGgrrOGxmO83sLTPrKeF2nzCzfjN754KxJWb2KzPbV/i/eZHW8UUz6y3sk7fM7OESrKPLzP7FzHab2S4z+3eF8ZLuk4R1lHSfmFmVmb1qZm8X1vGfCuOrzeyVQtx838wq57Uhdy/pPwBlyJcBWwOgEsDbADaUeh2FtRwG0LII2/0wgFsAvHPB2H8F8Hjh9uMAvrxI6/gigH9f4v3RAeCWwu16AO
8B2FDqfZKwjpLuE+STm+sKtysAvALgTgA/APCJwvj/BPC389nOYpz5bwew390Per7U91MAHlmEdSwa7v48gDOXDD+CfCFUoEQFUck6So6797n7G4XbI8gXi+lEifdJwjpKiue54kVzFyP4OwEcu+DvxSz+6QB+aWavm9m2RVrDedrdva9w+ySA9kVcy2fNbEfha8EV//pxIWbWjXz9iFewiPvkknUAJd4npSiaG/sFv7vd/RYAHwXwGTP78GIvCMi/82NBmjDPiW8AWIt8j4Y+AF8p1YbNrA7AjwB8zt2HL9RKuU8C6yj5PvF5FM0tlsUI/l4AXRf8TYt/Xmncvbfwfz+Ap7G4lYlOmVkHABT+71+MRbj7qcKBlwPwTZRon5hZBfIB9113/3FhuOT7JLSOxdonhW1fdtHcYlmM4H8NwPrClctKAJ8A8EypF2FmtWZWf/42gI8AeCd51hXlGeQLoQKLWBD1fLAV+DhKsE/MzJCvAbnH3b96gVTSfcLWUep9UrKiuaW6gnnJ1cyHkb+SegDAf1ikNaxB3ml4G8CuUq4DwPeQ//g4g/x3t08j3/PwWQD7APwawJJFWsc/ANgJYAfywddRgnXcjfxH+h0A3ir8e7jU+yRhHSXdJwBuRr4o7g7k32j+4wXH7KsA9gP4PwDS89mOfuEnRKTEfsFPiGhR8AsRKQp+ISJFwS9EpCj4hYgUBb8QkaLgFyJSFPxCRMr/A3fhoyps8T+yAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"from matplotlib.pyplot import imshow\n",
"\n",
"# Show the randomly selected validation image (32x32 colour pixel array)\n",
"x = %sql SELECT x FROM cifar10_val_random;\n",
"x = x.DataFrame().to_numpy()\n",
"x_np = np.array(x[0][0], dtype=np.uint8)\n",
"imshow(x_np)\n",
"\n",
"# Print the three most probable labels with their predicted probabilities\n",
"probs = %sql SELECT * FROM cifar10_val_random_predict_array;\n",
"probs = probs.DataFrame().to_numpy()\n",
"probs = np.array(probs[0][0])\n",
"top_3_prob_label_indices = probs.argsort()[-3:][::-1]\n",
"print (\" \")\n",
"for index in top_3_prob_label_indices:\n",
"    print (label_names[index], probs[index])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 1
}